Update TensorRT-LLM (#1358)

Co-authored-by: Kaiyu <26294424+kaiyux@users.noreply.github.com>
石晓伟 2024-03-26 20:47:14 +08:00 committed by GitHub
parent 66ca3378c6
commit 850b6fa1e7
328 changed files with 436793 additions and 6630 deletions

View File

@ -34,7 +34,7 @@
- Optimize AllReduce for parallel attention on Falcon and GPT-J
- Enable split-k for weight-only cutlass kernel when SM>=75
* Documentation
- Add [documentation for new builder workflow](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/new_workflow.md)
- Add [documentation for convert/build workflow](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/checkpoint.md)
## Versions 0.6.0 / 0.6.1

View File

@ -282,6 +282,7 @@ The list of supported models is:
* [InternLM](examples/internlm)
* [LLaMA](examples/llama)
* [LLaMA-v2](examples/llama)
* [Mamba](examples/mamba)
* [mBART](examples/enc_dec)
* [Mistral](examples/llama#mistral-v01)
* [MPT](examples/mpt)
@ -454,7 +455,7 @@ For example: `mpirun -n 1 python3 examples/run.py ...`
- Support FP16 fMHA on NVIDIA V100 GPU
* API
- Add a set of High-level APIs for end-to-end generation tasks (see examples/high-level-api/README.md)
- **[BREAKING CHANGES]** Migrate models to the new build workflow, including LLaMA, Mistral, Mixtral, InternLM, ChatGLM, Falcon, GPT-J, GPT-NeoX, Medusa, MPT, Baichuan and Phi (see docs/source/new_workflow.md)
- **[BREAKING CHANGES]** Migrate models to the new build workflow, including LLaMA, Mistral, Mixtral, InternLM, ChatGLM, Falcon, GPT-J, GPT-NeoX, Medusa, MPT, Baichuan and Phi (see docs/source/checkpoint.md)
- **[BREAKING CHANGES]** Deprecate `LayerNorm` and `RMSNorm` plugins and remove the corresponding build parameters
- **[BREAKING CHANGES]** Remove optional parameter `maxNumSequences` for GPT manager
* Bug fixes
@ -482,7 +483,7 @@ For example: `mpirun -n 1 python3 examples/run.py ...`
- Batch manager arguments documentation updates
- Add documentation for best practices for tuning the performance of TensorRT-LLM (See docs/source/perf_best_practices.md)
- Add documentation for Falcon AWQ support (See examples/falcon/README.md)
- Update to the `docs/source/new_workflow.md` documentation
- Update to the `docs/source/checkpoint.md` documentation
- Update AWQ INT4 weight only quantization documentation for GPT-J
- Add blog: Speed up inference with SOTA quantization techniques in TRT-LLM
- Refine TensorRT-LLM backend README structure #133

View File

@ -19,8 +19,10 @@ set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")
add_custom_target(benchmarks)
set(CXXOPTS_SRC_DIR ${PROJECT_SOURCE_DIR}/../3rdparty/cxxopts)
add_subdirectory(${CXXOPTS_SRC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/cxxopts)
if(NOT TARGET cxxopts::cxxopts)
set(CXXOPTS_SRC_DIR ${PROJECT_SOURCE_DIR}/../3rdparty/cxxopts)
add_subdirectory(${CXXOPTS_SRC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/cxxopts)
endif()
function(add_benchmark test_name test_src)
add_executable(${test_name} ${test_src})

View File

@ -127,6 +127,7 @@ python prepare_dataset.py \
For `tokenizer`, you can specify either the path to a local tokenizer that has already been downloaded, or simply the name of a tokenizer hosted on HuggingFace, such as `meta-llama/Llama-2-7b`; in the latter case the tokenizer is downloaded automatically.
#### Prepare TensorRT-LLM engines
Please make sure that the engines are built with the `--use_inflight_batching` and `--remove_input_padding` arguments if you'd like to benchmark inflight batching. For more details, please see the documentation in the TensorRT-LLM examples.
@ -187,3 +188,130 @@ Take GPT-350M as an example for single GPU with static batching
--static_emulated_timeout 100 \
--dataset ../../benchmarks/cpp/tokens-fixed-lengths.json
```
#### Benchmarking LoRA
When using either of the `prepare_dataset.py` methods above, add `--rand-task-id <start-id> <end-id>` to the command. This assigns each request a random `task_id` between `<start-id>` and `<end-id>`, inclusive.
You can then use `utils/generate_rand_loras.py`, which takes an example LoRA for the model you are benchmarking, to generate random LoRA weights for benchmarking purposes.
Finally, run `gptManagerBenchmark` with `--type IFB` and `--lora_dir /path/to/utils/generate_rand_loras/output`, as in the end-to-end script below.
End-to-end LoRA benchmarking script
```
git-lfs clone https://huggingface.co/meta-llama/Llama-2-13b-hf
git-lfs clone https://huggingface.co/hfl/chinese-llama-2-lora-13b
MODEL_CHECKPOINT=Llama-2-13b-hf
CONVERTED_CHECKPOINT=Llama-2-13b-hf-ckpt
TOKENIZER=Llama-2-13b-hf
LORA_ENGINE=Llama-2-13b-hf-engine
DTYPE=float16
TP=2
PP=1
MAX_LEN=1024
MAX_BATCH=32
MAX_LORA_RANK=32
SOURCE_LORA=chinese-llama-2-lora-13b
CPP_LORA=chinese-llama-2-lora-13b-cpp
EG_DIR=/tmp/lora-eg
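# NOTE: the three values below are assumptions added to make this script self-contained;
# adjust them to match your model and engine configuration.
NUM_LAYERS=40      # Llama-2-13b has 40 decoder layers
NUM_LORA_MODS=1    # only attn_qkv is targeted below
EOS_ID=2           # Llama-2 end-of-sequence token id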
# Build lora enabled engine
python examples/llama/convert_checkpoint.py --model_dir ${MODEL_CHECKPOINT} \
--output_dir ${CONVERTED_CHECKPOINT} \
--dtype ${DTYPE} \
--tp_size ${TP} \
--pp_size 1 \
--lora_target_modules attn_qkv \
--max_lora_rank ${MAX_LORA_RANK}
${HOME}/.local/bin/trtllm-build \
--checkpoint_dir ${CONVERTED_CHECKPOINT} \
--output_dir ${LORA_ENGINE} \
--max_batch_size ${MAX_BATCH} \
--max_input_len $MAX_LEN \
--max_output_len $MAX_LEN \
--gpt_attention_plugin float16 \
--paged_kv_cache enable \
--remove_input_padding enable \
--gemm_plugin float16 \
--lora_plugin float16 \
--use_paged_context_fmha enable \
--use_custom_all_reduce disable
NUM_LORAS=(8 16 24 32 64 128 256)
NUM_REQUESTS=1024
# Convert LoRA to cpp format
python examples/gpt/nemo_lora_convert.py \
-i $SOURCE_LORA \
--storage-type $DTYPE \
--write-cpp-runtime-tensors \
-o $CPP_LORA
# Prepare datasets
mkdir -p $EG_DIR/data
# Prepare dataset without lora_task_id
python benchmarks/cpp/prepare_dataset.py \
--output "${EG_DIR}/data/token-norm-dist.json" \
--request-rate -1 \
--time-delay-dist constant \
--tokenizer $TOKENIZER \
token-norm-dist \
--num-requests $NUM_REQUESTS \
--input-mean 256 --input-stdev 16 --output-mean 128 --output-stdev 24
# Prepare datasets with lora_task_ids from 0 to $nloras - 1
for nloras in ${NUM_LORAS[@]}; do
python benchmarks/cpp/prepare_dataset.py \
--output "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \
--request-rate -1 \
--time-delay-dist constant \
--rand-task-id 0 $(( $nloras - 1 )) \
--tokenizer $TOKENIZER \
token-norm-dist \
--num-requests $NUM_REQUESTS \
--input-mean 256 --input-stdev 16 --output-mean 128 --output-stdev 24
done
# Generate random lora weights for 256 adapters
python benchmarks/cpp/utils/generate_rand_loras.py ${CPP_LORA} ${EG_DIR}/loras 256
# Perform benchmarking
# First run inference without LoRAs
mkdir -p ${EG_DIR}/log-base-lora
mpirun -n ${TP} --output-filename ${EG_DIR}/log-base-lora \
cpp/build_Debug/benchmarks/gptManagerBenchmark \
--engine_dir $LORA_ENGINE \
--type IFB \
--dataset "${EG_DIR}/data/token-norm-dist.json" \
--lora_host_cache_bytes 8589934592 \
--lora_num_device_mod_layers $(( 32 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \
--kv_cache_free_gpu_mem_fraction 0.80 \
--log_level info \
--eos_id ${EOS_ID}
# Now run inference with various numbers of LoRAs.
# The host cache is set large enough to hold all the LoRAs in lora_dir.
# The GPU cache is sized to hold 32 LoRAs.
# This benchmark preloads all the LoRAs into the host cache.
# We run inference with a range of active LoRAs to exercise different cache miss rates.
for nloras in ${NUM_LORAS[@]}; do
mkdir -p ${EG_DIR}/log-lora-${nloras}
mpirun -n ${TP} --output-filename "${EG_DIR}/log-lora-${nloras}" \
cpp/build_Debug/benchmarks/gptManagerBenchmark \
--engine_dir $LORA_ENGINE \
--type IFB \
--dataset "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \
--lora_host_cache_bytes 8589934592 \
--lora_num_device_mod_layers $(( 32 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \
--kv_cache_free_gpu_mem_fraction 0.80 \
--log_level info \
--eos_id ${EOS_ID} \
--lora_dir ${EG_DIR}/loras
done
```

View File

@ -34,6 +34,7 @@
#include <cstdint>
#include <cxxopts.hpp>
#include <iostream>
#include <memory>
#include <nlohmann/json.hpp>
#include <string>
#include <thread>
@ -300,9 +301,13 @@ struct BenchInfo
class Recorder
{
using TensorPtr = ITensor::SharedPtr;
public:
explicit Recorder(std::string opCsvFile)
explicit Recorder(std::string opCsvFile, std::string responsesJsonFile = "", bool excludeInputInOutput = false)
: mOpCsvFile(std::move(opCsvFile))
, mRespJsonFile(std::move(responsesJsonFile))
, mOutputHasInput(!excludeInputInOutput)
{
}
@ -328,6 +333,10 @@ public:
mRequestBenchInfos[requestId] = BenchInfo(inputLength, outputLength, start);
}
// The number of output tokens is not computed from the output sequence here; it is set to max_output_len.
// - If eos_id == -1 (the default behavior), this is correct, since the output sequence has the maximum permissible
//   length.
// - However, if eos_id != -1, the output sequence may contain fewer than max_output_len tokens, and the reported
//   token throughput may be inaccurate.
void recordStart(SizeType inputLength, SizeType maxNewTokens, uint64_t requestId,
std::chrono::time_point<std::chrono::steady_clock> const& start)
{
@ -342,20 +351,62 @@ public:
.count();
}
void recordEnd(uint64_t requestId, std::list<NamedTensor> const& responseTensors)
{
this->recordEnd(requestId);
if (mRespJsonFile.empty())
return;
int32_t outputSeqLen;
for (auto& tensor : responseTensors)
{
if (tensor.name == inference_request::kOutputIdsTensorName)
{
mResponseTensors[requestId] = tensor.tensor;
}
else if (tensor.name == inference_request::kSequenceLengthTensorName)
{
// Tensor of shape nBeams, and we only need the first one
outputSeqLen = *(bufferCast<int32_t>(*(tensor.tensor)));
if (mOutputHasInput)
{
int inputSeqLen = mRequestBenchInfos[requestId].inputLength;
outputSeqLen -= inputSeqLen;
}
mRequestBenchInfos[requestId].outputLength = outputSeqLen;
}
}
}
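// Note: assumes `latencies` is already sorted in ascending order; calculateMetrics() sorts before calling this.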
float calcPercentile(std::vector<float> const& latencies, int percentile)
{
int const index = static_cast<int>(std::ceil((percentile / 100.0) * latencies.size())) - 1;
return latencies[index];
}
void calculateMetrics()
{
mNumSamples = mRequestBenchInfos.size();
mTotalLatency = std::chrono::duration<float, std::milli>(mEnd - mStart).count();
mSeqThroughput = mNumSamples / (mTotalLatency / 1000);
mAvgSeqLatency = 0;
std::vector<float> reqLatencies;
int totalOutputTokens = 0;
for (auto reqInfo : mRequestBenchInfos)
{
mAvgSeqLatency += reqInfo.second.latency;
reqLatencies.push_back(reqInfo.second.latency);
totalOutputTokens += reqInfo.second.outputLength;
}
mAvgSeqLatency /= mNumSamples;
mTokenThroughput = totalOutputTokens / (mTotalLatency / 1000);
std::sort(reqLatencies.begin(), reqLatencies.end());
mP99SeqLatency = calcPercentile(reqLatencies, 99);
mP90SeqLatency = calcPercentile(reqLatencies, 90);
mP50SeqLatency = calcPercentile(reqLatencies, 50);
}
void report()
@ -363,8 +414,11 @@ public:
printf("[BENCHMARK] num_samples %d\n", mNumSamples);
printf("[BENCHMARK] total_latency(ms) %.2f\n", mTotalLatency);
printf("[BENCHMARK] seq_throughput(seq/sec) %.2f\n", mSeqThroughput);
printf("[BENCHMARK] avg_sequence_latency(ms) %.2f\n", mAvgSeqLatency);
printf("[BENCHMARK] token_throughput(token/sec) %.2f\n", mTokenThroughput);
printf("[BENCHMARK] avg_sequence_latency(ms) %.2f\n", mAvgSeqLatency);
printf("[BENCHMARK] p99_sequence_latency(ms) %.2f\n", mP99SeqLatency);
printf("[BENCHMARK] p90_sequence_latency(ms) %.2f\n", mP90SeqLatency);
printf("[BENCHMARK] p50_sequence_latency(ms) %.2f\n", mP50SeqLatency);
}
void writeOpMetricsToCsv()
@ -372,7 +426,8 @@ public:
if (!mOpCsvFile.empty())
{
std::vector<std::string> headers = {"num_samples", "total_latency(ms)", "seq_throughput(seq/sec)",
"avg_sequence_latency(ms)", "token_throughput(token/sec)"};
"token_throughput(token/sec)", "avg_sequence_latency(ms)", "p99_sequence_latency(ms)",
"p90_sequence_latency(ms)", "p50_sequence_latency(ms)"};
std::ofstream outputFile(mOpCsvFile);
@ -383,8 +438,9 @@ public:
outputFile << header << ",";
}
outputFile << "\n";
outputFile << mNumSamples << "," << mTotalLatency << "," << mSeqThroughput << "," << mAvgSeqLatency
<< "," << mTokenThroughput;
outputFile << mNumSamples << "," << mTotalLatency << "," << mSeqThroughput << "," << mTokenThroughput
<< "," << mAvgSeqLatency << "," << mP99SeqLatency << "," << mP90SeqLatency << ","
<< mP50SeqLatency;
outputFile << "\n";
}
else
@ -394,6 +450,32 @@ public:
}
}
void dumpResponseSeqs()
{
if (mRespJsonFile.empty())
return;
nlohmann::json jsonResponses = nlohmann::json::array();
for (auto const& [respId, respTokensTensor] : mResponseTensors)
{
int inputLength = mRequestBenchInfos[respId].inputLength;
int outputLength = mRequestBenchInfos[respId].outputLength;
std::vector<int32_t> outputTokens(outputLength);
int32_t* outputToksBufferPtr = bufferCast<int32_t>(*respTokensTensor);
if (mOutputHasInput)
outputToksBufferPtr += inputLength;
std::copy(outputToksBufferPtr, outputToksBufferPtr + outputLength, outputTokens.begin());
nlohmann::json currResp;
currResp["response_id"] = respId;
currResp["response_tokens"] = outputTokens;
jsonResponses.push_back(currResp);
}
std::ofstream outFile(mRespJsonFile);
outFile << jsonResponses;
outFile.close();
}
private:
std::unordered_map<uint64_t, BenchInfo> mRequestBenchInfos;
@ -404,7 +486,14 @@ private:
float mSeqThroughput{};
float mAvgSeqLatency{};
float mTokenThroughput{};
float mP99SeqLatency{};
float mP90SeqLatency{};
float mP50SeqLatency{};
std::string mOpCsvFile;
std::string mRespJsonFile;
std::unordered_map<uint64_t, TensorPtr> mResponseTensors;
bool mOutputHasInput;
}; // class Recorder
class ExecutorServer
@ -430,7 +519,7 @@ public:
maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
executorConfig.setPeftCacheConfig(peftCacheConfig);
mExecutor = std::make_shared<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
mExecutor = std::make_unique<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
if (logIterationData)
{
@ -519,7 +608,7 @@ public:
}
private:
std::shared_ptr<texec::Executor> mExecutor;
std::unique_ptr<texec::Executor> mExecutor;
std::thread mCollectStatsThread;
std::shared_ptr<Recorder> mRecorder;
std::chrono::milliseconds mWaitSleep;
@ -535,7 +624,7 @@ public:
batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
std::optional<SizeType> const staticEmulatedBatchSize,
std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData)
std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData, bool excludeInputInOutput)
: mRecorder(std::move(recorder))
, mTerminateReqId(terminateReqId)
, mWaitSleep(waitSleep)
@ -564,7 +653,7 @@ public:
[this](uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
std::string const& errMsg)
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
nullptr, iterationDataCallback, optionalParams, terminateReqId);
nullptr, iterationDataCallback, optionalParams, terminateReqId, std::nullopt, excludeInputInOutput);
}
~GptServer()
@ -729,7 +818,7 @@ public:
if (final_response)
{
mWorkItemsQueue.markFinished(requestId);
mRecorder->recordEnd(requestId);
mRecorder->recordEnd(requestId, response_tensors);
mActiveCount--;
}
}
@ -852,7 +941,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
BenchmarkParams const& benchmarkParams, batch_scheduler::SchedulerPolicy schedulerPolicy,
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
std::optional<SizeType> const staticEmulatedBatchSize, std::optional<std::chrono::milliseconds> const batchTimeout,
bool logIterationData)
bool logIterationData, bool excludeInputInOutput, std::string const& responsesJsonFile)
{
auto const worldConfig = WorldConfig::mpi();
@ -885,10 +974,11 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
auto const numSamples = samples.size();
int const maxBeamWidth = beamWidth;
auto recorder = std::make_shared<Recorder>(opCsvFile);
auto recorder = std::make_shared<Recorder>(opCsvFile, responsesJsonFile, excludeInputInOutput);
uint64_t terminateReqId = numSamples + 1;
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams,
recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData);
auto gptServer
= std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams, recorder,
terminateReqId, waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData, excludeInputInOutput);
ITensor::SharedPtr eosIdTensor{
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
@ -962,6 +1052,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
recorder->calculateMetrics();
recorder->report();
recorder->writeOpMetricsToCsv();
recorder->dumpResponseSeqs();
// Send terminateReqId to terminate servers on all ranks
// Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
@ -1154,6 +1245,14 @@ int main(int argc, char* argv[])
options.add_options()("lora_host_cache_bytes", "LoRA host cache memory in bytes", cxxopts::value<size_t>());
options.add_options()("lora_num_device_mod_layers", "LoRA number 1d cache rows", cxxopts::value<int>());
options.add_options()("exclude_input_in_output_seq",
"When enabled, GptManager will exclude the input sequence from output. (Only works if --api is gptManager)",
cxxopts::value<bool>());
options.add_options()("responses_json_file",
"When specified, dumps the responses to JSON file. (only works if --api is gptManager)",
cxxopts::value<std::string>()->default_value(""));
auto result = options.parse(argc, argv);
if (result.count("help"))
@ -1329,7 +1428,8 @@ int main(int argc, char* argv[])
benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile,
maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy,
waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, batchTimeout,
logIterationData);
logIterationData, result["exclude_input_in_output_seq"].as<bool>(),
result["responses_json_file"].as<std::string>());
}
catch (std::exception const& e)
{

View File

@ -29,7 +29,9 @@ from tensorrt_llm.functional import AllReduceStrategy, allreduce
from tensorrt_llm.plugin.plugin import current_all_reduce_helper
def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
def allreduce_benchmark(dtype: str,
test_range: str = "10,10000000,10",
no_header: bool = False):
tllm.logger.set_level('error')
world_size = tllm.mpi_world_size()
rank = tllm.mpi_rank()
@ -48,11 +50,15 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
size = min_size
dtype_size = torch.finfo(torch_dtype).bits // 8
if mapping.rank == 0 and not no_header:
print(
f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<15}, {'duration (ms)':<10}"
)
while size < max_size:
input = torch.ones(size, dtype=torch_dtype, device="cuda")
for strategy in [
AllReduceStrategy.RING, AllReduceStrategy.ONESHOT,
AllReduceStrategy.NCCL, AllReduceStrategy.ONESHOT,
AllReduceStrategy.TWOSHOT
]:
builder = tllm.Builder()
@ -95,33 +101,40 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
session = tllm.runtime.Session.from_engine(build_engine())
_, start = cuda.cuEventCreate(0)
_, stop = cuda.cuEventCreate(0)
runtimes = []
with peer_access(mapping):
MPI.COMM_WORLD.barrier()
cuda.cuEventRecord(start, stream.cuda_stream)
session.run(inputs=feed_dict,
outputs={"output": output},
stream=stream.cuda_stream)
cuda.cuEventRecord(stop, stream.cuda_stream)
torch.cuda.synchronize()
_, ms = cuda.cuEventElapsedTime(start, stop)
for _ in range(10):
cuda.cuEventRecord(start, stream.cuda_stream)
session.run(inputs=feed_dict,
outputs={"output": output},
stream=stream.cuda_stream)
cuda.cuEventRecord(stop, stream.cuda_stream)
torch.cuda.synchronize()
_, ms = cuda.cuEventElapsedTime(start, stop)
runtimes.append(ms)
median_ms = sorted(runtimes)[len(runtimes) // 2]
assert torch.allclose(output, (input * world_size)**inner_loop)
if mapping.rank == 0:
print(f"{size=}, {strategy=}, {ms=}")
print(
f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<15}, {median_ms:<10.2f}"
)
size *= ratio
if mapping.rank == 0:
print("")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--dtype", "-t", default="float16")
parser.add_argument("--range",
"-r",
default="256,25600000,10",
help="min_size,max_size,multiplicative_ratio")
parser.add_argument(
"--range",
"-r",
default="256,256000000,10", # 256 to 256M
help="min_size,max_size,multiplicative_ratio")
parser.add_argument("--no-header", action="store_true")
args = parser.parse_args()
allreduce_benchmark(args.dtype, args.range)
allreduce_benchmark(args.dtype, args.range, args.no_header)

View File

@ -293,6 +293,42 @@ _allowed_configs = {
pre_norm=True,
do_layer_norm_before=True,
)),
"starcoder":
ModelConfig(name="starcoder_15.5b",
family="gpt",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=40,
num_heads=48,
num_kv_heads=1,
hidden_size=6144,
vocab_size=49152,
hidden_act='gelu',
n_positions=8192,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
builder_opt=None,
)),
"starcoder2_3b":
ModelConfig(name="starcoder2_3b",
family="gpt",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=30,
num_heads=24,
num_kv_heads=2,
hidden_size=3072,
vocab_size=49152,
hidden_act='gelu',
n_positions=16384,
position_embedding_type='rope_gpt_neox',
rotary_pct=1.0,
max_batch_size=256,
max_input_len=512,
max_output_len=200,
builder_opt=None,
)),
"llama_7b":
ModelConfig(name="llama_7b",
family="llama",

View File

@ -37,7 +37,7 @@ from tensorrt_llm.models import PretrainedConfig, quantize_model
from tensorrt_llm.models.modeling_utils import optimize_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.quantization import QuantAlgo, QuantMode
def parse_arguments():
@ -224,23 +224,23 @@ def get_quant_mode(quantization):
def get_quant_algo(quantization):
if quantization == "fp8":
return "FP8", "FP8"
return QuantAlgo.FP8, QuantAlgo.FP8
elif quantization == "fp8_gemm":
return "FP8", None
return QuantAlgo.FP8, None
elif quantization == "fp8_kv_cache":
return None, "FP8"
return None, QuantAlgo.FP8
elif quantization == "int8_sq_per_tensor":
return "W8A8_SQ_PER_TENSOR_PLUGIN", None
return QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN, None
elif quantization == "int8_sq_per_token_channel":
return "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN", None
return QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN, None
elif quantization == "int8_weight_only":
return "W8A16", None
return QuantAlgo.W8A16, None
elif quantization == "int4_weight_only":
return "W4A16", None
return QuantAlgo.W4A16, None
elif quantization == "int4_weight_only_awq":
return "W4A16_AWQ", None
return QuantAlgo.W4A16_AWQ, None
elif quantization == "int4_weight_only_gptq":
return "W4A16_GPTQ", None
return QuantAlgo.W4A16_GPTQ, None
elif quantization is None:
return None, None
@ -764,13 +764,11 @@ def build_gpt(args):
else:
raise Exception(f'Unexpected model: {args.model}')
quant_kwargs = {}
if family not in [
'gpt', 'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3'
]:
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
**quant_kwargs)
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode)
if family in ['llama']:
tensorrt_llm_model = optimize_model(tensorrt_llm_model,
@ -788,6 +786,7 @@ def build_gpt(args):
network.plugin_config.enable_remove_input_padding()
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
network.plugin_config.set_moe_plugin(dtype=args.dtype)
network.plugin_config.set_mamba_conv1d_plugin(dtype=args.dtype)
if args.quantization is None or "fp8" not in args.quantization:
network.plugin_config.set_gemm_plugin(dtype=args.dtype)

View File

@ -96,6 +96,7 @@ class GPTBenchmark(BaseBenchmark):
self.use_gpt_attention_plugin = True
self.remove_input_padding = True
self.use_moe_plugin = True
self.use_mamba_conv1d_plugin = True
elif args.mode == 'ootb-except-mha':
self.use_gpt_attention_plugin = True
@ -121,6 +122,8 @@ class GPTBenchmark(BaseBenchmark):
gpt_attention_plugin=self.use_gpt_attention_plugin,
paged_kv_cache=self.paged_kv_cache if hasattr(
self, 'paged_kv_cache') else False,
paged_state=self.paged_state
if hasattr(self, 'paged_state') else False,
dtype=self.dtype,
remove_input_padding=self.remove_input_padding,
quant_mode=self.quant_mode,
@ -148,8 +151,6 @@ class GPTBenchmark(BaseBenchmark):
model_config.mamba_d_state = self.mamba_d_state
model_config.mamba_d_conv = self.mamba_d_conv
model_config.mamba_expand = self.mamba_expand
self.remove_input_padding = False
model_config.remove_input_padding = False
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
end_id=0, pad_id=0, top_k=args.top_k, top_p=args.top_p)
self.decoder = tensorrt_llm.runtime.MambaLMHeadModelGenerationSession(

View File

@ -25,6 +25,7 @@
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
//! @brief Encapsulates parameters to configure paged KV cache.
class KvCacheConfig
{
public:

View File

@ -53,7 +53,7 @@ public:
using VecLogProbs = std::vector<float>;
using BeamTokens = std::vector<VecTokens>;
using TensorPtr = TTensor;
using LogitsPostProcessor = std::function<TensorPtr(RequestIdType, TensorPtr&, BeamTokens const&, TStream)>;
using LogitsPostProcessor = std::function<void(RequestIdType, TensorPtr&, BeamTokens const&, TStream)>;
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/medusaModule.h"
#include <optional>
#include <vector>
@ -40,7 +41,8 @@ public:
bool enableTrtOverlap = false, std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt,
bool normalizeLogProbs = true, bool enableChunkedContext = false,
std::optional<runtime::DecodingMode> const& decodingMode = std::nullopt,
PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{})
PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{},
std::optional<runtime::MedusaModule::MedusaChoices> const& medusaChoices = std::nullopt)
: kvCacheConfig{kvCacheConfig}
, enableTrtOverlap{enableTrtOverlap}
, deviceIds(deviceIds)
@ -48,6 +50,7 @@ public:
, enableChunkedContext{enableChunkedContext}
, decodingMode{decodingMode}
, peftCacheManagerConfig(peftCacheManagerConfig)
, medusaChoices(medusaChoices)
{
}
@ -55,7 +58,8 @@ public:
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false,
executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(), std::nullopt,
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig()))
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
executorConfig.getMedusaChoices())
{
}
@ -74,6 +78,7 @@ public:
bool enableChunkedContext;
std::optional<runtime::DecodingMode> decodingMode;
PeftCacheManagerConfig peftCacheManagerConfig;
std::optional<runtime::MedusaModule::MedusaChoices> medusaChoices;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -64,15 +64,17 @@ std::string fmtstr(char const* format, ...) __attribute__((format(printf, 1, 2))
#define __PRETTY_FUNCTION__ __FUNCSIG__
#endif
auto constexpr kDefaultDelimiter = ", ";
template <typename U, typename TStream, typename T>
inline TStream& arr2outCasted(TStream& out, T* arr, size_t size)
inline TStream& arr2outCasted(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
{
out << "(";
if (size > 0)
{
for (size_t i = 0; i < size - 1; ++i)
{
out << static_cast<U>(arr[i]) << ", ";
out << static_cast<U>(arr[i]) << delim;
}
out << static_cast<U>(arr[size - 1]);
}
@ -81,22 +83,22 @@ inline TStream& arr2outCasted(TStream& out, T* arr, size_t size)
}
template <typename TStream, typename T>
inline TStream& arr2out(TStream& out, T* arr, size_t size)
inline TStream& arr2out(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
{
return arr2outCasted<T>(out, arr, size);
return arr2outCasted<T>(out, arr, size, delim);
}
template <typename T>
inline std::string arr2str(T* arr, size_t size)
inline std::string arr2str(T* arr, size_t size, char const* delim = kDefaultDelimiter)
{
std::stringstream ss;
return arr2out(ss, arr, size).str();
return arr2out(ss, arr, size, delim).str();
}
template <typename T>
inline std::string vec2str(std::vector<T> vec)
inline std::string vec2str(std::vector<T> vec, char const* delim = kDefaultDelimiter)
{
return arr2str(vec.data(), vec.size());
return arr2str(vec.data(), vec.size(), delim);
}
inline bool strStartsWith(std::string const& str, std::string const& prefix)

View File

@ -27,8 +27,6 @@
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <variant>
#include <vector>
namespace tensorrt_llm::executor
@ -43,18 +41,19 @@ class SamplingConfig
public:
/// @brief Constructor for SamplingConfig
/// See description of parameters below
SamplingConfig(SizeType beamWidth = 1, std::optional<SizeType> topK = std::nullopt,
std::optional<FloatType> topP = std::nullopt, std::optional<FloatType> topPMin = std::nullopt,
std::optional<SizeType> topPResetIds = std::nullopt, std::optional<FloatType> topPDecay = std::nullopt,
std::optional<RandomSeedType> randomSeed = std::nullopt, std::optional<FloatType> temperature = std::nullopt,
std::optional<SizeType> minLength = std::nullopt,
std::optional<FloatType> beamSearchDiversityRate = std::nullopt,
std::optional<FloatType> repetitionPenalty = std::nullopt,
std::optional<FloatType> presencePenalty = std::nullopt,
std::optional<FloatType> frequencyPenalty = std::nullopt, std::optional<FloatType> lengthPenalty = std::nullopt,
std::optional<SizeType> earlyStopping = std::nullopt);
~SamplingConfig();
explicit SamplingConfig(SizeType beamWidth = 1, std::optional<SizeType> const& topK = std::nullopt,
std::optional<FloatType> const& topP = std::nullopt, std::optional<FloatType> const& topPMin = std::nullopt,
std::optional<SizeType> const& topPResetIds = std::nullopt,
std::optional<FloatType> const& topPDecay = std::nullopt,
std::optional<RandomSeedType> const& randomSeed = std::nullopt,
std::optional<FloatType> const& temperature = std::nullopt,
std::optional<SizeType> const& minLength = std::nullopt,
std::optional<FloatType> const& beamSearchDiversityRate = std::nullopt,
std::optional<FloatType> const& repetitionPenalty = std::nullopt,
std::optional<FloatType> const& presencePenalty = std::nullopt,
std::optional<FloatType> const& frequencyPenalty = std::nullopt,
std::optional<FloatType> const& lengthPenalty = std::nullopt,
std::optional<SizeType> const& earlyStopping = std::nullopt);
bool operator==(SamplingConfig const& other) const;
@ -91,23 +90,24 @@ private:
std::optional<FloatType> mTopPDecay;
/// @brief Controls the random seed used by the random number generator in sampling
std::optional<RandomSeedType> mRandomSeed;
/// @brief Controls the modulation of logits when sampling new tokens. Default is 1.0f
/// @brief Controls the modulation of logits when sampling new tokens. It can have values > 0.f. Default is 1.0f
std::optional<FloatType> mTemperature;
/// @brief Lower bound on the number of tokens to generate
/// @brief Lower bound on the number of tokens to generate. Values < 1 have no effect. Default is 1.
std::optional<SizeType> mMinLength;
/// @brief Controls the diversity in beam search.
std::optional<FloatType> mBeamSearchDiversityRate;
/// @brief Used to penalize tokens based on how often they appear in the sequence. Default is 0.f
/// @brief Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f.
/// Values < 1.f encourages repetition, values > 1.f discourages it. Default is 1.f
std::optional<FloatType> mRepetitionPenalty;
/// @brief Used to penalize tokens already present in the sequence (irrespective of the number of appearances).
/// Default is 0.f
/// @brief Used to penalize tokens already present in the sequence (irrespective of the number of appearances). It
/// can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
std::optional<FloatType> mPresencePenalty;
/// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). Default
/// is 0.f
/// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can
/// have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
std::optional<FloatType> mFrequencyPenalty;
/// @brief Controls how to penalize longer sequences in beam search. Default is 0.f
std::optional<FloatType> mLengthPenalty;
/// @brief Controls whether the generation process finishes once beamWidth sentences are generated (end with
/// @brief Controls whether the generation process finishes once beamWidth sentences are generated (ends with
/// end_token)
std::optional<SizeType> mEarlyStopping;
};
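To make these knobs concrete, here is a minimal sketch (not part of this diff) that builds a `SamplingConfig` with the constructor declared above; all numeric values are illustrative.
```
#include <optional>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Greedy-ish sampling: a temperature below 1 sharpens the distribution, and a
// repetition penalty slightly above 1 discourages repeats. Every other knob is
// left at std::nullopt so the defaults documented above apply.
texec::SamplingConfig makeSamplingConfig()
{
    return texec::SamplingConfig(
        /*beamWidth=*/1,
        /*topK=*/50,
        /*topP=*/0.9f,
        /*topPMin=*/std::nullopt,
        /*topPResetIds=*/std::nullopt,
        /*topPDecay=*/std::nullopt,
        /*randomSeed=*/1234ULL,
        /*temperature=*/0.8f,
        /*minLength=*/std::nullopt,
        /*beamSearchDiversityRate=*/std::nullopt,
        /*repetitionPenalty=*/1.1f,
        /*presencePenalty=*/std::nullopt,
        /*frequencyPenalty=*/std::nullopt,
        /*lengthPenalty=*/std::nullopt,
        /*earlyStopping=*/std::nullopt);
}
```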
@ -116,12 +116,12 @@ private:
class OutputConfig
{
public:
OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
bool excludeInputFromOutput = false);
explicit OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false,
bool returnGenerationLogits = false, bool excludeInputFromOutput = false);
/// @brief Controls if Result should contain log probabilities. Default is false
/// @brief Controls if Result should contain log probabilities. Default is false.
bool returnLogProbs;
/// @brief Controls if Result should contain the context logits. Default is false
/// @brief Controls if Result should contain the context logits. Default is false.
bool returnContextLogits;
/// @brief Controls if Result should contain the generation logits. Default is false.
bool returnGenerationLogits;
@ -135,9 +135,7 @@ class SpeculativeDecodingConfig
{
public:
explicit SpeculativeDecodingConfig(VecTokens tokens, std::optional<Tensor> logits = std::nullopt,
std::optional<FloatType> acceptanceThreshold = std::nullopt);
~SpeculativeDecodingConfig();
std::optional<FloatType> const& acceptanceThreshold = std::nullopt);
[[nodiscard]] VecTokens getTokens() const;
[[nodiscard]] std::optional<Tensor> getLogits() const;
@ -147,9 +145,9 @@ private:
friend class Serialization;
/// @brief The draft tokens
VecTokens mTokens;
/// @brief The draft logits
/// @brief The draft logits. Expected shape: [num_draft_tokens, vocab_size].
std::optional<Tensor> mLogits;
/// @brief The acceptance threshold
/// @brief The acceptance threshold. Must be > 0.f and <= 1.f
std::optional<FloatType> mAcceptanceThreshold;
};
@ -157,14 +155,14 @@ private:
class PromptTuningConfig
{
public:
PromptTuningConfig(Tensor embeddingTable);
~PromptTuningConfig();
explicit PromptTuningConfig(Tensor embeddingTable);
[[nodiscard]] Tensor getEmbeddingTable() const;
private:
friend class Serialization;
/// @brief The prompt embedding table
/// @brief The prompt embedding table. Expected shape: [task vocab_size, hidden_size]. Data type must match model
/// weights.
Tensor mEmbeddingTable;
};
@ -172,9 +170,8 @@ private:
class LoraConfig
{
public:
LoraConfig(
explicit LoraConfig(
IdType taskId, std::optional<Tensor> weights = std::nullopt, std::optional<Tensor> config = std::nullopt);
~LoraConfig();
[[nodiscard]] IdType getTaskId() const;
[[nodiscard]] std::optional<Tensor> getWeights() const;
@ -185,9 +182,9 @@ private:
/// @brief The Lora task id
IdType mTaskId;
/// @brief The Lora weights
/// @brief The Lora weights. See TRT-LLM documentation for expected shapes and types
std::optional<Tensor> mWeights;
/// @brief The Lora configuration
/// @brief The Lora configuration. See TRT-LLM documentation for detailed description of the config tensor
std::optional<Tensor> mConfig;
};
@ -199,7 +196,7 @@ public:
/// @param inputTokenIds The input token ids
/// @param maxNewTokens The maximum number of tokens to generate
/// @param streaming Indicates if the responses should be streamed or not
/// @param streaming Indicates if the responses should be streamed or not. Default is false.
/// @param samplingConfig The sampling configuration
/// @param outputConfig The output configuration
/// @param endId The end token id
@ -213,8 +210,8 @@ public:
/// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor
/// names provided to the ExecutorConfig.
Request(VecTokens inputTokenIds, SizeType maxNewTokens, bool streaming = false,
SamplingConfig samplingConfig = SamplingConfig(), OutputConfig outputConfig = OutputConfig(),
std::optional<SizeType> endId = std::nullopt, std::optional<SizeType> padId = std::nullopt,
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
std::optional<SizeType> const& endId = std::nullopt, std::optional<SizeType> const& padId = std::nullopt,
std::optional<std::list<VecTokens>> badWords = std::nullopt,
std::optional<std::list<VecTokens>> stopWords = std::nullopt,
std::optional<Tensor> embeddingBias = std::nullopt,
@ -245,16 +242,16 @@ public:
[[nodiscard]] std::optional<std::string> getLogitsPostProcessorName() const;
void setStreaming(bool streaming);
void setSamplingConfig(SamplingConfig config);
void setOutputConfig(OutputConfig outputConfig);
void setSamplingConfig(SamplingConfig const& config);
void setOutputConfig(OutputConfig const& outputConfig);
void setEndId(SizeType endId);
void setPadId(SizeType padId);
void setBadWords(std::list<VecTokens> badWords);
void setStopWords(std::list<VecTokens> stopWords);
void setEmbeddingBias(Tensor);
void setSpeculativeDecodingConfig(SpeculativeDecodingConfig specDecodingConfig);
void setPromptTuningConfig(PromptTuningConfig pTuningConfig);
void setLoraConfig(LoraConfig loraConfig);
void setBadWords(std::list<VecTokens> const& badWords);
void setStopWords(std::list<VecTokens> const& stopWords);
void setEmbeddingBias(Tensor const& embeddingBias);
void setSpeculativeDecodingConfig(SpeculativeDecodingConfig const& specDecodingConfig);
void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig);
void setLoraConfig(LoraConfig const& loraConfig);
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
private:
@ -275,7 +272,7 @@ struct Result
/// @brief The cumulative log probabilities. Size beamSize.
std::optional<VecLogProbs> cumLogProbs;
/// @brief The log probabilities for each generated token. Size [beamSize, seqLen]
/// @brief The log probabilities for each generated token. Size [beamSize, outputLen]
std::optional<std::vector<VecLogProbs>> logProbs;
/// @brief The context logits. Size [promptLen, vocabSizePadded]
@ -299,18 +296,18 @@ public:
Response& operator=(Response&& other) noexcept;
/// @brief Get the id of the request for which this response was generated
IdType getRequestId() const;
[[nodiscard]] IdType getRequestId() const;
/// @brief Indicates if this response has an error or not
bool hasError() const;
[[nodiscard]] bool hasError() const;
/// @brief Get the error msg for this response
/// Will throw an exception if hasError is false
std::string getErrorMsg() const;
[[nodiscard]] std::string getErrorMsg() const;
/// @brief Get the result for this response
/// Will throw an exception if hasResult is true
Result getResult() const;
[[nodiscard]] Result getResult() const;
private:
class Impl;
@ -322,7 +319,6 @@ class SchedulerConfig
{
public:
explicit SchedulerConfig(SchedulerPolicy policy = SchedulerPolicy::kGUARANTEED_NO_EVICT);
~SchedulerConfig();
[[nodiscard]] SchedulerPolicy getPolicy() const;
@ -335,10 +331,10 @@ private:
class KvCacheConfig
{
public:
KvCacheConfig(bool enableBlockReuse = false, std::optional<SizeType> maxTokens = std::nullopt,
std::optional<SizeType> maxAttentionWindow = std::nullopt,
std::optional<SizeType> sinkTokenLength = std::nullopt,
std::optional<FloatType> freeGpuMemoryFraction = std::nullopt);
explicit KvCacheConfig(bool enableBlockReuse = false, std::optional<SizeType> const& maxTokens = std::nullopt,
std::optional<SizeType> const& maxAttentionWindow = std::nullopt,
std::optional<SizeType> const& sinkTokenLength = std::nullopt,
std::optional<FloatType> const& freeGpuMemoryFraction = std::nullopt);
[[nodiscard]] bool getEnableBlockReuse() const;
[[nodiscard]] std::optional<SizeType> getMaxTokens() const;
@ -383,11 +379,10 @@ public:
/// @param deviceIds The IDs of the GPUs involved in the execution of the model
/// @param participantIds The participant IDs (MPI ranks if commType == kMPI) involved in the execution of the
/// model. The first participant is considered to be the leader.
ParallelConfig(CommunicationType commType = CommunicationType::kMPI,
explicit ParallelConfig(CommunicationType commType = CommunicationType::kMPI,
CommunicationMode commMode = CommunicationMode::kLEADER,
std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
std::optional<std::vector<SizeType>> participantIds = std::nullopt);
~ParallelConfig();
[[nodiscard]] CommunicationType getCommunicationType() const;
[[nodiscard]] CommunicationMode getCommunicationMode() const;
@ -396,8 +391,8 @@ public:
void setCommunicationType(CommunicationType type);
void setCommunicationMode(CommunicationMode mode);
void setDeviceIds(std::vector<SizeType> deviceIds);
void setParticipantIds(std::vector<SizeType> participantIds);
void setDeviceIds(std::vector<SizeType> const& deviceIds);
void setParticipantIds(std::vector<SizeType> const& participantIds);
private:
/// @brief The type of communication protocol used. Default is MPI.
@ -417,10 +412,11 @@ private:
class PeftCacheConfig
{
public:
PeftCacheConfig(SizeType numHostModuleLayer = 0, SizeType numDeviceModuleLayer = 0, SizeType optimalAdapterSize = 8,
SizeType maxAdapterSize = 64, SizeType numPutWorkers = 1, SizeType numEnsureWorkers = 1,
SizeType numCopyStreams = 1, SizeType maxPagesPerBlockHost = 24, SizeType maxPagesPerBlockDevice = 8,
std::optional<float> deviceCachePercent = std::nullopt, std::optional<size_t> hostCacheSize = std::nullopt);
explicit PeftCacheConfig(SizeType numHostModuleLayer = 0, SizeType numDeviceModuleLayer = 0,
SizeType optimalAdapterSize = 8, SizeType maxAdapterSize = 64, SizeType numPutWorkers = 1,
SizeType numEnsureWorkers = 1, SizeType numCopyStreams = 1, SizeType maxPagesPerBlockHost = 24,
SizeType maxPagesPerBlockDevice = 8, std::optional<float> const& deviceCachePercent = std::nullopt,
std::optional<size_t> const& hostCacheSize = std::nullopt);
[[nodiscard]] SizeType getNumHostModuleLayer() const;
[[nodiscard]] SizeType getNumDeviceModuleLayer() const;
@ -462,16 +458,16 @@ private:
/// @brief Configuration class for the model executor
class ExecutorConfig
{
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
public:
ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig schedulerConfig = SchedulerConfig(),
KvCacheConfig kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false, bool normalizeLogProbs = true,
SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
explicit ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig const& schedulerConfig = SchedulerConfig(),
KvCacheConfig const& kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false,
bool normalizeLogProbs = true, SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
SizeType requestStatsMaxIterations = kDefaultRequestStatsMaxIterations,
BatchingType batchingType = BatchingType::kINFLIGHT,
std::optional<ParallelConfig> parallelConfig = std::nullopt,
PeftCacheConfig peftCacheConfig = PeftCacheConfig(), LogitsPostProcessorMap = {});
std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
std::optional<LogitsPostProcessorMap> logitsPostProcessorMap = std::nullopt,
std::optional<MedusaChoices> medusaChoices = std::nullopt);
[[nodiscard]] SizeType getMaxBeamWidth() const;
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@ -482,20 +478,22 @@ public:
[[nodiscard]] SizeType getRequestStatsMaxIterations() const;
[[nodiscard]] BatchingType getBatchingType() const;
[[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;
[[nodiscard]] PeftCacheConfig getPeftCacheConfig() const;
[[nodiscard]] LogitsPostProcessorMap getLogitsPostProcessorMap() const;
[[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
[[nodiscard]] std::optional<LogitsPostProcessorMap> getLogitsPostProcessorMap() const;
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;
void setMaxBeamWidth(SizeType maxBeamWidth);
void setSchedulerConfig(SchedulerConfig schedulerConfig);
void setKvCacheConfig(KvCacheConfig kvCacheConfig);
void setSchedulerConfig(SchedulerConfig const& schedulerConfig);
void setKvCacheConfig(KvCacheConfig const& kvCacheConfig);
void setEnableChunkedContext(bool enableChunkedContext);
void setNormalizeLogProbs(bool normalizeLogProbs);
void setIterStatsMaxIterations(SizeType iterStatsMaxIterations);
void setRequestStatsMaxIterations(SizeType requestStatsMaxIterations);
void setBatchingType(BatchingType batchingType);
void setParallelConfig(ParallelConfig parallelConfig);
void setPeftCacheConfig(PeftCacheConfig peftCacheConfig);
void setLogitsPostProcessorMap(LogitsPostProcessorMap logitsPostProcessorMap);
void setParallelConfig(ParallelConfig const& parallelConfig);
void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap);
void setMedusaChoices(MedusaChoices const& medusaChoices);
private:
/// @brief The beam width value of requests that will be sent to the executor
@ -524,14 +522,14 @@ private:
/// @brief The parallel execution configuration.
std::optional<ParallelConfig> mParallelConfig;
PeftCacheConfig mPeftCacheConfig;
LogitsPostProcessorMap mLogitsPostProcessorMap;
std::optional<PeftCacheConfig> mPeftCacheConfig;
std::optional<LogitsPostProcessorMap> mLogitsPostProcessorMap;
std::optional<MedusaChoices> mMedusaChoices;
};
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
class Executor
{
using RequestPtr = std::shared_ptr<Request>;
public:
/// @brief
@ -539,38 +537,38 @@ public:
/// @param modelType The type of model
/// @param executorConfig The configuration for the executor
/// @param comm An optional inter-process communicator configuration
Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig);
Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig const& executorConfig);
Executor(std::vector<uint8_t> const& engineBuffer, std::string const& jsonConfigStr, ModelType modelType,
ExecutorConfig executorConfig);
ExecutorConfig const& executorConfig);
Executor(std::shared_ptr<Model> model, ExecutorConfig executorConfig);
Executor(std::shared_ptr<Model> model, ExecutorConfig const& executorConfig);
~Executor();
/// @brief Enqueue a new request
/// @param request The LLM request which contains input tokens and request parameters
/// @return A unique id that identifies the request
IdType enqueueRequest(Request request);
[[nodiscard]] IdType enqueueRequest(Request const& request);
/// @brief Enqueue a batch of request
std::vector<IdType> enqueueRequests(std::vector<Request> requests);
[[nodiscard]] std::vector<IdType> enqueueRequests(std::vector<Request> const& requests);
/// @brief Await for ready responses
/// @param id An optional request id. If not specified, responses for any request can be returned
/// @param timeout The maximum time to wait for new responses
/// @return A vector of responses
std::vector<Response> awaitResponses(
std::optional<IdType> id = std::nullopt, std::optional<std::chrono::milliseconds> timeout = std::nullopt);
[[nodiscard]] std::vector<Response> awaitResponses(std::optional<IdType> const& requestId = std::nullopt,
std::optional<std::chrono::milliseconds> const& timeout = std::nullopt);
/// @brief Get the number of ready responses
/// @param id The request id
/// @param requestId An optional request id
/// @return The number of ready responses
SizeType getNumResponsesReady(std::optional<IdType> id = std::nullopt);
[[nodiscard]] SizeType getNumResponsesReady(std::optional<IdType> const& requestId = std::nullopt) const;
/// @brief Cancel the request with provided request id
/// @param id The request id for which to cancel the response
void cancelRequest(IdType id);
void cancelRequest(IdType requestId);
/// @brief Signals the server to shutdown
/// This call is blocking. Only returns when all requests have terminated or timeout has been reached
@ -586,6 +584,9 @@ public:
/// @return Request stats grouped by iterations
std::deque<RequestStatsPerIteration> getLatestRequestStats();
/// @brief Indicates if the current process is allowed to enqueueRequests
[[nodiscard]] bool canEnqueueRequests() const;
private:
class Impl;
std::unique_ptr<Impl> mImpl;
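For orientation, a hedged end-to-end sketch (not part of this diff) of the updated `Executor` API; the engine path and token ids are placeholders.
```
#include <iostream>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Enqueue a single non-streaming request and inspect its response.
void runSingleRequest()
{
    texec::ExecutorConfig executorConfig(/*maxBeamWidth=*/1);
    texec::Executor executor("/path/to/engine_dir", texec::ModelType::kDECODER_ONLY, executorConfig);

    texec::Request request(/*inputTokenIds=*/{1, 2, 3, 4}, /*maxNewTokens=*/32);
    auto const requestId = executor.enqueueRequest(request);

    // A non-streaming request is expected to produce a single, final response.
    auto const responses = executor.awaitResponses(requestId);
    for (auto const& response : responses)
    {
        if (response.hasError())
        {
            std::cout << "Request " << response.getRequestId() << " failed: " << response.getErrorMsg() << std::endl;
        }
        else
        {
            texec::Result const result = response.getResult();
            // Inspect the result here; e.g. the optional cumulative log probabilities.
            std::cout << "cumLogProbs present: " << result.cumLogProbs.has_value() << std::endl;
        }
    }
}
```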

View File

@ -52,7 +52,9 @@ using IterationType = std::uint64_t;
using RandomSeedType = std::uint64_t;
using VecLogProbs = std::vector<FloatType>;
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using LogitsPostProcessor = std::function<Tensor(IdType, Tensor&, BeamTokens const&, StreamPtr&)>;
using LogitsPostProcessor = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr&)>;
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
using MedusaChoices = std::vector<std::vector<SizeType>>;
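`LogitsPostProcessor` now returns `void` rather than a new `Tensor`, so callbacks are expected to modify the logits tensor in place. A minimal sketch of a map entry matching the new signature (the callback body and the map key are illustrative):
```
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Build a LogitsPostProcessorMap entry whose callback mutates `logits` in place
// and returns nothing, as required by the new std::function signature.
texec::LogitsPostProcessorMap makeLogitsPostProcessorMap()
{
    texec::LogitsPostProcessor processor
        = [](texec::IdType /*requestId*/, texec::Tensor& logits, texec::BeamTokens const& /*beamTokens*/,
              texec::StreamPtr& /*stream*/)
    {
        // Edit `logits` here (e.g. mask out banned token ids); no tensor is returned.
        (void) logits;
    };
    return texec::LogitsPostProcessorMap{{"my_post_processor", processor}};
}
```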
enum class DataType
{
@ -153,15 +155,29 @@ enum class ModelType
kDECODER_ONLY = 0,
};
/// @brief The batching type
enum class BatchingType
{
/// @brief STATIC refers to the traditional batching scheme with a batch of requests running in lockstep until the
/// full generation for all of them is complete. Requests in a batch are all padded up to the maximum input and
/// output sequence length of any member of the batch.
kSTATIC = 0,
/// @brief INFLIGHT refers to a scheme where newly arrived requests are dynamically incorporated into the batch
/// under execution, and requests are returned as soon as the end condition is met without any padding.
kINFLIGHT = 1,
};
/// @brief The policy used to select the subset of available requests in each iteration of the executor generation loop
enum class SchedulerPolicy
{
/// @brief MAX_UTILIZATION packs as many requests as the underlying TRT engine can support in any iteration of the
/// InflightBatching generation loop. While this is expected to maximize GPU throughput, it might require that some
/// requests be paused and restarted depending on peak KV cache memory availability.
kMAX_UTILIZATION = 0,
/// @brief GUARANTEED_NO_EVICT uses KV cache more conservatively guaranteeing that a request, once started, will run
/// to completion without eviction.
kGUARANTEED_NO_EVICT = 1,
};
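As a hedged illustration of how these enums are consumed (values are illustrative, not part of this diff), an `ExecutorConfig` can combine in-flight batching with the `kMAX_UTILIZATION` policy via the setters declared earlier:
```
#include <optional>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// In-flight batching with the MAX_UTILIZATION scheduler policy and a KV cache
// allowed to use up to 90% of the free GPU memory.
texec::ExecutorConfig makeExecutorConfig()
{
    texec::ExecutorConfig config(/*maxBeamWidth=*/1);
    config.setBatchingType(texec::BatchingType::kINFLIGHT);
    config.setSchedulerConfig(texec::SchedulerConfig(texec::SchedulerPolicy::kMAX_UTILIZATION));
    config.setKvCacheConfig(texec::KvCacheConfig(
        /*enableBlockReuse=*/false,
        /*maxTokens=*/std::nullopt,
        /*maxAttentionWindow=*/std::nullopt,
        /*sinkTokenLength=*/std::nullopt,
        /*freeGpuMemoryFraction=*/0.9f));
    return config;
}
```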
@ -228,7 +244,7 @@ struct IterationStats
/// @brief Ending time of this iteration
std::string timestamp;
/// @brief Iteration id
SizeType iter;
IterationType iter;
/// @brief Number of active requests
SizeType numActiveRequests;
/// @brief Number of max active requests

View File

@ -16,27 +16,37 @@
#pragma once
#include <NvInferRuntime.h>
#include <cstdint>
#include <mutex>
// Forward declarations
namespace nvinfer1
{
class ILoggerFinder;
class IPluginCreator;
class ILogger;
} // namespace nvinfer1
namespace tensorrt_llm::plugins::api
{
auto constexpr kDefaultNamespace = "tensorrt_llm";
class LoggerFinder : public nvinfer1::ILoggerFinder
class LoggerManager
{
public:
//! Set the logger finder.
void setLoggerFinder(nvinfer1::ILoggerFinder* finder);
//! Get the logger.
nvinfer1::ILogger* findLogger() override;
[[maybe_unused]] nvinfer1::ILogger* logger();
static LoggerFinder& getInstance() noexcept;
static LoggerManager& getInstance() noexcept;
static nvinfer1::ILogger* defaultLogger() noexcept;
private:
LoggerFinder() = default;
LoggerManager() = default;
nvinfer1::ILoggerFinder* mLoggerFinder{nullptr};
std::mutex mMutex;
@ -47,10 +57,11 @@ private:
extern "C"
{
// This function is used for explicitly registering the TRT-LLM plugins and the default logger.
bool initTrtLlmPlugins(void* logger, char const* libNamespace = tensorrt_llm::plugins::api::kDefaultNamespace);
bool initTrtLlmPlugins(void* logger = tensorrt_llm::plugins::api::LoggerManager::defaultLogger(),
char const* libNamespace = tensorrt_llm::plugins::api::kDefaultNamespace);
// The functions below are used by TensorRT when loading a shared plugin library with automatic registration.
// see https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#generating-plugin-library
TENSORRTAPI [[maybe_unused]] void setLoggerFinder([[maybe_unused]] nvinfer1::ILoggerFinder* finder);
TENSORRTAPI [[maybe_unused]] nvinfer1::IPluginCreator* const* getPluginCreators(int32_t& nbCreators);
[[maybe_unused]] void setLoggerFinder([[maybe_unused]] nvinfer1::ILoggerFinder* finder);
[[maybe_unused]] nvinfer1::IPluginCreator* const* getPluginCreators(int32_t& nbCreators);
}
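With the new default arguments, plugin registration can be a single call. A minimal sketch, assuming the header path below is where `initTrtLlmPlugins` is declared:
```
// Header path is an assumption for this sketch.
#include "tensorrt_llm/plugins/api/tllmPlugin.h"

int main()
{
    // Registers the TRT-LLM plugins with the default logger and the default
    // "tensorrt_llm" namespace, thanks to the new default arguments.
    bool const ok = initTrtLlmPlugins();
    return ok ? 0 : 1;
}
```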

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/runtime/iTensor.h"
#include <NvInferRuntime.h>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
@ -107,6 +108,9 @@ public:
return allocate(memoryType, ITensor::makeShape({}), type);
}
//! \brief Set the contents of the given `buffer` to `value`.
void setMem(IBuffer& buffer, int32_t value) const;
//! \brief Set the contents of the given `buffer` to zero.
void setZero(IBuffer& buffer) const;

View File

@ -76,6 +76,20 @@ public:
// parameters for beam search
TensorPtr cacheIndirection; // [maxBatchSize, beamWidth, maxSeqLen] - the k/v cache index for beam search, on gpu
// Medusa
class MedusaInputs
{
public:
TensorPtr medusaPaths; // [maxBatchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
TensorPtr medusaTreeIds; // [maxBatchSize, maxTokensPerStep], on gpu
std::vector<std::vector<TensorPtr>>
medusaLogits; // [maxBatchSize][maxMedusaHeads][tokensPerStep, vocabSizePadded], on gpu
TensorPtr medusaCurTokensPerStep; // [maxBatchSize], on gpu
TensorPtr medusaTargetTokensPerStep; // [maxBatchSize], on gpu
};
std::optional<MedusaInputs> medusaInputs;
};
} // namespace tensorrt_llm::runtime

View File

@ -19,7 +19,7 @@
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <optional>
#include <utility>
namespace tensorrt_llm::runtime
@ -84,6 +84,18 @@ public:
TensorPtr cacheIndirection; // [batchSize, beamWidth, maxSeqLen], k/v indirection for next generation step, on gpu
BeamHypotheses beamHypotheses;
// Medusa
class MedusaOutputs
{
public:
TensorPtr medusaNextDraftTokens; // [maxBatchSize, maxTokensPerStep], on gpu
TensorPtr medusaAcceptedTokensLen; // [maxBatchSize], on gpu
TensorPtr medusaAcceptedLengthsCumSum; // [maxBatchSize + 1], on gpu
TensorPtr medusaPathsOffsets; // [maxBatchSize * maxNumHeads], on gpu
};
std::optional<MedusaOutputs> medusaOutputs;
};
} // namespace tensorrt_llm::runtime

View File

@ -26,6 +26,67 @@
namespace tensorrt_llm::runtime
{
//! @details
//! ***Mandatory inputs***
//!
//! * `endId`, is the token ID that marks the end of the input sequence (aka `EOS`
//! or end-of-sequence). It's `50,256` for the GPT2 model which has a vocabulary
//! of `50,257` tokens, for example,
//! * `padId`, is the token ID that is used for padding (i.e. fills in the slots
//! that are at an index greater-or-equal to the input length for padded
//! sequences). It can be set to the same value as `endId`,
//! * `ids`, is the tensor of input IDs. That tensor must be allocated on the GPU.
//! When the input tensor is padded, the shape of `ids` is `[batchSize,
//! maxInputLength]`, where `batchSize` and `maxInputLength` must respect the
//! maximum sizes in `sessionConfig` passed to the `GptSession` constructor.
//! When the input is packed, the shape of `ids` is `[numTokens]`, where
//! `numTokens` is the sum of the lengths of the different sequences in the batch,
//! * `lengths`, is the tensor of input sequence lengths. That tensor must be
//! allocated on the GPU and contain `batchSize` values,
//! * `packed`, indicates if the `ids` tensor is packed or padded. In this
//! release, that flag must match the value passed to the constructor through
//! the instance of the `ModelConfig` class. In a future release, the session
//! may be made more flexible and automatically pad or pack the input,
//!
//! ***Optional inputs***
//!
//! * `embeddingBiasOpt`, is a tensor of floating-point values on the GPU that
//! contains the bias to add to the logits during sampling (after the projection
//! from hidden states to logits as the last step of the model). This tensor
//! must have `vocabSize` elements (as defined in the `modelConfig` argument
//! passed to the constructor),
//! * `badWordsList`, is a tensor of integers on the GPU that encodes the list of
//! words that have to be banned from generated sequences. Its shape is `[2,
//! badWordsLength]`, as explained below, or `[batchSize, 2, badWordsLength]`
//! when there is a different list for each sequence in the batch,
//! * `stopWordsList`, is a tensor of integers on the GPU that encodes the list of
//! words that trigger the end of the generation for a sequence. Its shape is
//! `[2, stopWordsLength]`, as explained below, or `[batchSize, 2,
//! stopWordsLength]` when there is a different list for each sequence in the
//! batch,
//! * `maxNewTokens`, is the maximum number of tokens to generate.
//!
//! The `badWordsList` and `stopWordsList` tensors have the same shape `[2,
//! length]`. Let's consider an example with three words to describe the
//! representation of those lists. The first word contains tokens `[5, 7, 3]`, the
//! second one contains `[9, 2]` and the third one is composed of tokens `[6, 2, 4,
//! 1]`. In total, there are 9 tokens. That's the length. The shape of the tensor
//! is `[2, 9]`. The first row of the tensor must contain the 9 token IDs and the
//! second row must store the
//! [inclusive prefix-sum](https://en.wikipedia.org/wiki/Prefix_sum)
//! of the word lengths as shown on the following diagram:
//!
//! ```
//! 0 3 5 9
//! | | | |
//! V V V V
//! [ 5, 7, 3, 9, 2, 6, 2, 4, 1]
//! [ 3, 5, 9, -1, -1, -1, -1, -1, -1]
//! ```
//!
//! In case all the words are made of a single token, the inner-most dimension of
//! the tensor must be increased by 1 (i.e. the length for 4 words, each made of a
//! single token, must be 5 instead of 4 -- the shape is `[2, 5]`).
template <typename TTensor, typename PromptTuningParams>
class GenericGenerationInput
{

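To make the `[2, length]` encoding described above concrete, here is a small, self-contained helper (not part of the library) that builds the two rows from a list of words, each given as a list of token IDs. Only the prefix-sum row is documented to use `-1` padding; padding the token row with `-1` here is an illustrative choice.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone illustration of the bad/stop words encoding described above:
// row 0 = concatenated token IDs, row 1 = inclusive prefix-sum of word lengths,
// padded with -1. If every word is a single token, the length grows by one.
std::vector<std::vector<std::int32_t>> encodeWordList(std::vector<std::vector<std::int32_t>> const& words)
{
    std::vector<std::int32_t> tokens;
    std::vector<std::int32_t> offsets;
    for (auto const& word : words)
    {
        tokens.insert(tokens.end(), word.begin(), word.end());
        offsets.push_back(static_cast<std::int32_t>(tokens.size()));
    }
    // Inner dimension: number of tokens, or number of words + 1 when all words are single tokens.
    std::size_t const length = std::max(tokens.size(), offsets.size() + 1);
    tokens.resize(length, -1);  // padding value for the token row is an illustrative choice
    offsets.resize(length, -1); // the prefix-sum row is padded with -1, as described above
    return {tokens, offsets};
}
```

For the three-word example in the comment above, `encodeWordList({{5, 7, 3}, {9, 2}, {6, 2, 4, 1}})` returns the rows `[5, 7, 3, 9, 2, 6, 2, 4, 1]` and `[3, 5, 9, -1, -1, -1, -1, -1, -1]`, matching the diagram.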
View File

@ -25,6 +25,51 @@
namespace tensorrt_llm::runtime
{
//! @details
//! ***Mandatory outputs***
//!
//! * `ids`, is a tensor that contains the output token IDs. Its shape is
//! `[batchSize, beamWidth, maxSeqLength]` where `maxSeqLength` is the sum of
//! `maxInputLength` and `maxNewTokens`. After generation, it contains, for each
//! sequence, a copy of the input tokens followed by the output tokens. When a
//! sequence is shorter than `maxSeqLength`, padding tokens are added at the end
//! of the sequence.
//!
//! _Note that the shape of that tensor is different in this version of
//! TensorRT-LLM from its shape in previous versions where it was `[maxSeqLength,
//! batchSize, beamWidth]`_.
//!
//! ***Optional outputs***
//!
//! * `logProbs`, is a tensor of floating-point values on the GPU to store the
//! log-prob of the generated tokens. Its shape is `[maxNewTokens, batchSize,
//! beamWidth]`. Its shape will likely change in a future release to match the
//! shape of the output `ids` tensor.
//! * `contextLogits`, is a tensor of values on the GPU (same datatype as the
//! computation type) to store the logits for the context. Its shape is
//! `[batchSize, maxSequenceLength, vocabSizePadded]`. If `remove_input_padding` is enabled, its shape is `[packedSize,
//! vocabSizePadded]`. This buffer will only be filled in if the TensorRT engine was built with the
//! `gather_context_logits` or `gather_all_token_logits` parameter enabled.
//!
//! After inference is complete, the context logits can be read from `GenerationOutput.contextLogits`; these
//! tensors reside on the GPU. For an example of how to retrieve them, see
//! [gptSessionBenchmark.cpp](https://github.com/NVIDIA/TensorRT-LLM/blob/main/benchmarks/cpp/gptSessionBenchmark.cpp).
//!
//! It is important to point out
//! that enabling the computation may have an impact on performance (the language modeling head (LM head) has to
//! perform a matrix multiplication on all the context tokens instead of just the last one).
//! * `generationLogits`, is a tensor of values on the GPU (same datatype as the
//! computation type) to store the logits for the generation. Its shape is
//! `[batchSize, beamWidth, maxOutputLen, vocabSizePadded]`. This buffer will only be
//! filled in if the TensorRT engine was built with the `gather_generation_logits` or
//! `gather_all_token_logits` parameter enabled.
//!
//! Generation logits can also be obtained through `GenerationOutput.generationLogits` after inference is completed.
//! * `onTokenGenerated`, is a callback function invoked in the generation loop to
//! pass newly generated tokens to the caller while the loop continues to
//! execute. An implementation of that callback must accept the output `ids`
//! tensor, the generation `step` and a boolean flag that indicates if the
//! generation is complete.
template <typename TTensor>
class GenericGenerationOutput
{

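As a usage sketch for the optional outputs above, the snippet below attaches a streaming callback. It assumes the callback signature matches the description (output `ids` tensor, generation `step`, completion flag) and that the header path follows the repository layout.

```cpp
#include <cstdio>

#include "tensorrt_llm/runtime/generationOutput.h"

using namespace tensorrt_llm::runtime;

// Hedged sketch: stream tokens back to the caller while the generation loop runs.
void attachStreamingCallback(GenerationOutput& outputs)
{
    outputs.onTokenGenerated = [](GenerationOutput::TensorPtr const& ids, SizeType step, bool finished)
    {
        // `ids` lives on the GPU; copy it to the host before decoding tokens.
        std::printf("generation step %d, finished=%d\n", step, static_cast<int>(finished));
    };
}
```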
View File

@ -81,7 +81,8 @@ public:
static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize,
size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream);
BufferManager::CudaStreamPtr const& stream, std::optional<runtime::SizeType> maxTokensPerStep = std::nullopt,
std::optional<runtime::SizeType> maxNumMedusaHeads = std::nullopt);
};
template <typename T>
@ -93,7 +94,9 @@ public:
using TensorPtr = std::shared_ptr<ITensor>;
GptDecoder(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream);
size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream,
std::optional<runtime::SizeType> maxTokensPerStep = std::nullopt,
std::optional<runtime::SizeType> maxNumMedusaHeads = std::nullopt);
void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength,
std::optional<TensorPtr> const& batchSlots = std::nullopt) override;
@ -119,20 +122,23 @@ private:
SamplingConfig mSamplingConfig;
cudaDeviceProp mProp; // Avoid dangling pointers in mDynamicDecodeLayer
size_t mMaxBatchSize;
};
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(DecodingMode const& mode, nvinfer1::DataType dtype,
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream)
BufferManager::CudaStreamPtr const& stream, std::optional<runtime::SizeType> maxTokensPerStep,
std::optional<runtime::SizeType> maxNumMedusaHeads)
{
switch (dtype)
{
case nvinfer1::DataType::kFLOAT:
return std::make_unique<GptDecoder<float>>(
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSequenceLength, stream);
return std::make_unique<GptDecoder<float>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
maxSequenceLength, stream, maxTokensPerStep, maxNumMedusaHeads);
case nvinfer1::DataType::kHALF:
return std::make_unique<GptDecoder<half>>(
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSequenceLength, stream);
return std::make_unique<GptDecoder<half>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
maxSequenceLength, stream, maxTokensPerStep, maxNumMedusaHeads);
default: return nullptr;
}
}

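A hedged sketch of calling the extended factory: the decoding mode and CUDA stream come from elsewhere, and all numeric sizes (including the Medusa `maxTokensPerStep` and `maxNumMedusaHeads`) are illustrative.

```cpp
#include <memory>

#include "tensorrt_llm/runtime/gptDecoder.h"

using namespace tensorrt_llm::runtime;

// Hedged sketch: create a half-precision decoder with the new optional Medusa arguments.
std::unique_ptr<IGptDecoder> makeMedusaDecoder(DecodingMode const& mode, BufferManager::CudaStreamPtr const& stream)
{
    return IGptDecoder::create(mode, nvinfer1::DataType::kHALF,
        /*maxBatchSize=*/8, /*maxBeamWidth=*/1, /*vocabSize=*/32000, /*vocabSizePadded=*/32000,
        /*maxSequenceLength=*/2048, stream,
        /*maxTokensPerStep=*/64, /*maxNumMedusaHeads=*/4);
}
```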
View File

@ -43,7 +43,7 @@ public:
//! Setup the decoder before calling `forward()`
void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, bool fusedDecoder,
nvinfer1::DataType dtype) override;
nvinfer1::DataType dtype, GptModelConfig const& modelConfig) override;
void newBatch(
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
@ -156,13 +156,19 @@ public:
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
[[nodiscard]] TensorPtr getNextDraftTokens() const override
{
return mNextDraftTokens;
return mJointDecodingOutput->medusaOutputs->medusaNextDraftTokens;
}
//! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
[[nodiscard]] TensorPtr getNextDraftTokenLengths() const override
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
[[nodiscard]] TensorPtr getMedusaAcceptedLengthsCumSum() const override
{
return mNextDraftTokenLengths;
return mJointDecodingOutput->medusaOutputs->medusaAcceptedLengthsCumSum;
}
//! @returns [batchSize * maxMedusaHeads], accepted paths packed into continuous tensor, on gpu
[[nodiscard]] TensorPtr getMedusaAcceptedPackedPaths() const override
{
return mJointDecodingOutput->medusaOutputs->medusaPathsOffsets;
}
private:
@ -172,6 +178,27 @@ private:
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
void newRequest(SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
//! @brief Allocate buffers for medusa decoding.
void allocateMedusaBuffers();
//! @brief Setup buffers for medusa decoding.
void setupMedusa(GptModelConfig const& modelConfig);
//! @brief Sets up decoder internal tensors for a new speculative decoding request
void newRequestSpeculativeDecoding(
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
//! @brief Sets up decoder internal tensors for a new Medusa request
void newRequestMedusa(SizeType batchIdx, decoder_batch::Request const& request);
//! @brief Asynchronously calls unfused decoder for whole batch in loop
void forwardAsyncUnfusedDecoder(
SizeType step, decoder_batch::Output& output, decoder_batch::Input const& input, CudaEvent const& eventStart);
//! @brief Asynchronously calls fused decoder for whole batch
void forwardAsyncFusedDecoder(
SizeType step, decoder_batch::Output& output, decoder_batch::Input const& input, CudaEvent const& eventStart);
private:
std::size_t const mVocabSize;
std::size_t const mVocabSizePadded;
@ -200,7 +227,7 @@ private:
TensorPtr mFinishedSum;
std::vector<SizeType> mMaxNewTokens;
std::vector<SizeType> mBeamWidths;
std::vector<SizeType> mGeneratedTokensPerStep;
std::vector<SizeType> mGeneratedTokensPerEngineStep;
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
// for each generated token of maxTokensPerStep, on gpu
@ -216,16 +243,17 @@ private:
TensorPtr mBatchSlotsAcceptTokens; // [maxBatchSize], int32_t, address map, pinned
TensorPtr mBatchSlotsAcceptLogits; // [maxBatchSize], int32_t, address map, pinned
TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned
TensorPtr mNextDraftTokens;
TensorPtr mNextDraftTokenLengths;
SizeType mMaxSequenceLength{};
SizeType mMaxAttentionWindow{};
SizeType mSinkTokenLength{};
SizeType mActualBatchSize{};
SizeType mMaxTokensPerStep{};
SizeType mMaxTokensPerEngineStep{};
SizeType mMaxStopWordsLen{};
SizeType mMaxBadWordsLen{};
// How many tokens for one request can be processed per mDecoders call
SizeType mMaxTokensPerDecoderStep{};
bool mFusedDecoder{false};
bool mUseMedusa{false};
};
} // namespace tensorrt_llm::runtime

View File

@ -25,13 +25,21 @@
namespace tensorrt_llm::runtime
{
struct MambaConfig
{
SizeType dState = 0;
SizeType dConv = 0;
SizeType expand = 0;
};
class GptModelConfig
{
public:
enum class ModelVariant : std::int32_t
{
kGpt = 0,
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
kMamba = 2, // https://github.com/state-spaces/mamba
};
explicit GptModelConfig(
@ -44,8 +52,10 @@ public:
, mSizePerHead(mHiddenSize / mNbHeads)
, mDataType(dtype)
, mUseGptAttentionPlugin(false)
, mUseMambaConv1dPlugin(false)
, mInputPacked{false}
, mPagedKvCache{false}
, mPagedState{false}
, mTokensPerBlock{64}
, mQuantMode{common::QuantMode::none()}
, mMaxBatchSize(0)
@ -128,6 +138,16 @@ public:
mUseGptAttentionPlugin = useGptAttentionPlugin;
}
[[nodiscard]] bool constexpr useMambaConv1dPlugin() const noexcept
{
return mUseMambaConv1dPlugin;
}
void constexpr useMambaConv1dPlugin(bool useMambaConv1dPlugin) noexcept
{
mUseMambaConv1dPlugin = useMambaConv1dPlugin;
}
[[nodiscard]] bool constexpr usePackedInput() const noexcept
{
return mInputPacked;
@ -148,6 +168,16 @@ public:
mPagedKvCache = pagedKvCache;
}
[[nodiscard]] bool constexpr usePagedState() const noexcept
{
return mPagedState;
}
void constexpr usePagedState(bool pagedState) noexcept
{
mPagedState = pagedState;
}
[[nodiscard]] SizeType constexpr getTokensPerBlock() const noexcept
{
return mTokensPerBlock;
@ -170,7 +200,8 @@ public:
[[nodiscard]] bool constexpr supportsInflightBatching() const noexcept
{
return mUseGptAttentionPlugin && mInputPacked && mPagedKvCache;
return (isTransformerBased() && mUseGptAttentionPlugin && mInputPacked && mPagedKvCache)
|| (isSsmBased() && mUseMambaConv1dPlugin && mInputPacked && mPagedState);
}
[[nodiscard]] SizeType constexpr getMaxBatchSize() const noexcept
@ -368,6 +399,47 @@ public:
mMedusaModule = medusaModule;
}
[[nodiscard]] nvinfer1::DataType getKvDataType() const noexcept
{
if (getQuantMode().hasFp8KvCache())
{
return nvinfer1::DataType::kFP8;
}
else if (getQuantMode().hasInt8KvCache())
{
return nvinfer1::DataType::kINT8;
}
else
{
return getDataType();
}
}
[[nodiscard]] bool constexpr isTransformerBased() const noexcept
{
return mModelVariant == ModelVariant::kGpt || mModelVariant == ModelVariant::kGlm;
}
[[nodiscard]] bool hasMambaConfig() const noexcept
{
return mMambaConfig.has_value();
}
[[nodiscard]] std::optional<MambaConfig> getMambaConfig() const noexcept
{
return mMambaConfig;
}
void setMambaConfig(MambaConfig const& mambaConfig) noexcept
{
mMambaConfig = mambaConfig;
}
[[nodiscard]] bool constexpr isSsmBased() const noexcept
{
return mModelVariant == ModelVariant::kMamba;
}
private:
SizeType mVocabSize;
SizeType mNbLayers;
@ -377,8 +449,10 @@ private:
SizeType mSizePerHead;
nvinfer1::DataType mDataType;
bool mUseGptAttentionPlugin;
bool mUseMambaConv1dPlugin;
bool mInputPacked;
bool mPagedKvCache;
bool mPagedState;
SizeType mTokensPerBlock;
common::QuantMode mQuantMode;
SizeType mMaxBatchSize;
@ -404,5 +478,7 @@ private:
SizeType mMaxLoraRank;
std::optional<MedusaModule> mMedusaModule;
std::optional<MambaConfig> mMambaConfig;
};
} // namespace tensorrt_llm::runtime

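To tie the new Mamba-related pieces together, here is a hedged sketch of the flags an SSM engine would set so that `supportsInflightBatching()` evaluates to true under the updated condition. The setters follow this class's overloaded getter/setter naming pattern; `setModelVariant` and the `usePackedInput(bool)` setter are assumed to exist with that pattern, and the Mamba dimensions are illustrative.

```cpp
#include "tensorrt_llm/runtime/gptModelConfig.h"

using namespace tensorrt_llm::runtime;

// Hedged sketch: configure an existing GptModelConfig for a Mamba (SSM) engine
// so that supportsInflightBatching() returns true per the condition above.
void enableMambaInflightBatching(GptModelConfig& config)
{
    config.setModelVariant(GptModelConfig::ModelVariant::kMamba); // assumed setter, mirrors the variant getter
    config.useMambaConv1dPlugin(true);
    config.usePackedInput(true); // assumed setter, mirrors the usePackedInput() getter
    config.usePagedState(true);
    config.setMambaConfig(MambaConfig{/*dState=*/16, /*dConv=*/4, /*expand=*/2}); // illustrative dimensions
}
```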
View File

@ -83,13 +83,23 @@ public:
{
}
// The maximum number of sequences in a batch
SizeType maxBatchSize;
// The maximum width of the beams in beam-search
SizeType maxBeamWidth;
// The length of the longest input sequence
SizeType maxSequenceLength;
// Whether the session will use a different decoder per request.
// It must be set to `true` when running in-flight batching
bool decoderPerRequest{false};
// Whether the session will use CUDA graphs for the engine execution in generation phase
bool cudaGraphMode{false};
KvCacheConfig kvCacheConfig{};
// The micro batch size to be used in the context phase.
// Batches passed to `GptSession::generate` will be split into smaller micro batches of this size
std::optional<SizeType> ctxMicroBatchSize = std::nullopt;
// The micro batch size to be used in the generation phase.
// Batches passed to `GptSession::generate` will be split into smaller micro batches of this size.
std::optional<SizeType> genMicroBatchSize = std::nullopt;
std::optional<DecodingMode> decodingMode = std::nullopt;
bool normalizeLogProbs = true;
@ -134,6 +144,12 @@ public:
CudaEvent end;
};
//! @param sessionConfig Configuration of the session,
//! @param modelConfig Description of the model,
//! @param worldConfig Description of the environment,
//! @param engineBuffer The compiled TensorRT engine (const void*),
//! @param engineSize The size in bytes of the TensorRT engine (size_t),
//! @param logger The optional logger.
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr);
@ -176,6 +192,31 @@ public:
[[nodiscard]] nvinfer1::DataType getLogitDataType() const;
//! @brief This function performs the generation loop.
//! @details Given input tensors to read from, output tensors to populate, that member function
//! will run the generation loop until it reaches the maximum number of tokens that
//! can be produced or each sequence has reached completion (due to the production
//! of "end-of-sequence" or a word in the list of "stop words"). The pseudo-code of
//! that function looks like (member function names were changed to keep the
//! presentation simple):
//!
//! ```cpp
//! // Have all the sequences in the batch reached completion?
//! bool allFinished = false;
//!
//! // Until all sequences are finished or the number of steps reaches the limit...
//! for (int step = 0; !allFinished && step < maxNewTokens; ++step) {
//!
//! // Trigger the computation of the logits...
//! computeLogits(...);
//!
//! // Run the sampling to produce a token (for each active sequence) from the logits.
//! allFinished = generateTokensFromLogits(...);
//!
//! // Callback to stream the output tokens while the generation loop continues.
//! onTokenGenerated(...);
//! }
//! ```
void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig,
std::shared_ptr<GenerationProfiler> const generationProfiler = nullptr);

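Putting the configuration knobs and the generation loop together, the following hedged sketch constructs a session and runs `generate`. The engine buffer, model/world configs, and input/output tensors are assumed to be prepared elsewhere, and the batch and micro-batch sizes are illustrative.

```cpp
#include <cstddef>

#include "tensorrt_llm/runtime/gptSession.h"

using namespace tensorrt_llm::runtime;

// Hedged sketch: build a session from a serialized engine and run the generation loop.
void runGeneration(GptModelConfig const& modelConfig, WorldConfig const& worldConfig, void const* engineBuffer,
    std::size_t engineSize, GenerationInput const& inputs, GenerationOutput& outputs,
    SamplingConfig const& samplingConfig)
{
    GptSession::Config sessionConfig{/*maxBatchSize=*/8, /*maxBeamWidth=*/1, /*maxSequenceLength=*/2048};
    sessionConfig.ctxMicroBatchSize = 4; // split the context phase into micro batches of 4
    sessionConfig.genMicroBatchSize = 8; // micro batch size for the generation phase

    GptSession session(sessionConfig, modelConfig, worldConfig, engineBuffer, engineSize);
    session.generate(outputs, inputs, samplingConfig);
}
```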
View File

@ -46,7 +46,7 @@ public:
, endId{endId}
, computeCumLogProbs(false)
, computeLogProbs(false)
, generatedTokensPerStep(1)
, generatedTokensPerEngineStep(1)
{
}
@ -66,7 +66,9 @@ public:
bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
bool computeLogProbs; // boolean that controls if cumLogProbs should be computed for that request
SizeType generatedTokensPerStep;
SizeType generatedTokensPerEngineStep;
TensorPtr medusaPaths; // [tokensPerStep, medusaHeads + 1], on gpu
TensorPtr medusaTreeIds; // [tokensPerStep], on gpu
};
class Input
@ -109,6 +111,8 @@ public:
// parameters for beam search
TensorConstPtr cacheIndirection; // [batchSize, maxBeamWidth, maxSeqLen] - indices into KV cache of different rays
// within one beam for beam search, on gpu
std::vector<std::vector<TensorConstPtr>>
medusaLogits; // [maxBatchSize][maxNumHeads][tokensPerStep, vocabSizePadded]
};
using Output = decoder::Output;
@ -183,8 +187,11 @@ public:
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
virtual TensorPtr getNextDraftTokens() const = 0;
//! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
virtual TensorPtr getNextDraftTokenLengths() const = 0;
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
virtual TensorPtr getMedusaAcceptedLengthsCumSum() const = 0;
//! @returns [batchSize * maxMedusaHeads], accepted paths packed into continuous tensor, on gpu
virtual TensorPtr getMedusaAcceptedPackedPaths() const = 0;
protected:
IGptDecoderBatch() = default;

View File

@ -75,7 +75,7 @@ public:
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
virtual void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
bool fusedDecoder, nvinfer1::DataType dtype)
bool fusedDecoder, nvinfer1::DataType dtype, GptModelConfig const& modelConfig)
= 0;
//! @brief Initialize the decoder with new batch of inputs.

View File

@ -34,7 +34,6 @@ private:
template <typename T>
using OptVec = std::optional<std::vector<T>>;
private:
template <typename T>
static OptVec<T> fuseValues(
std::vector<SamplingConfig> const& configs, std::function<OptVec<T>(SizeType ci)> accessor)
@ -91,6 +90,8 @@ public:
earlyStopping = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].earlyStopping; });
draftAcceptanceThreshold
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
topKMedusaHeads = fuseValues<std::vector<runtime::SizeType>>(
configs, [&configs](SizeType ci) { return configs[ci].topKMedusaHeads; });
}
explicit SamplingConfig(executor::SamplingConfig const& samplingConfig,
@ -152,7 +153,22 @@ public:
// speculative decoding, only the first value is used (in gptDecoderBatch.cpp)
OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batch_size]
// medusa params
OptVec<std::vector<runtime::SizeType>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
std::optional<bool> normalizeLogProbs;
bool operator==(SamplingConfig const& other) const
{
return beamWidth == other.beamWidth && temperature == other.temperature && minLength == other.minLength
&& repetitionPenalty == other.repetitionPenalty && presencePenalty == other.presencePenalty
&& frequencyPenalty == other.frequencyPenalty && topK == other.topK && topP == other.topP
&& randomSeed == other.randomSeed && topPDecay == other.topPDecay && topPMin == other.topPMin
&& topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate
&& lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
&& draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
&& normalizeLogProbs == other.normalizeLogProbs;
}
};
} // namespace tensorrt_llm::runtime

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fd8e608359009dffbcc5817cd96531254c3ad13df7030b3b7cdf2d609fea99e1
size 2408892
oid sha256:ba545e1931c9405b75028b019ac3949ec5cec57c304aaa10ea6c854f572225b1
size 2856456

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e59449c78d8682be1f0671fa6d8073c71eb37ae452417b70f70bb7db4a68f48b
size 2434826
oid sha256:8ef69cd446d54a1c876237f812839e6ecd9174c327edc5ff4f6594bb2b203aae
size 2885046

View File

@ -1,3 +1,3 @@
ae7c209c38b4c343b0fc49decff6fed5 libtensorrt_llm_batch_manager_static.a
f2fdaabe328c0eb1e46e8ded7bec4d87 libtensorrt_llm_batch_manager_static.pre_cxx11.a
d2cce02a8 commit
cd113ef1af7d78ac0791d4323a8ef370 libtensorrt_llm_batch_manager_static.a
c3583c5524c71f2cd5ae7a3bba864377 libtensorrt_llm_batch_manager_static.pre_cxx11.a
45a8cb4ea commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:88e519a38b4172b960083acf12db2ce17c880ce355cc1c9361f1ae85d839551d
size 2377646
oid sha256:32ca7c2a6701457ecb537a56d9558fb62d35ec5443905d63f1f1a288d8f48f87
size 2780748

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54199fac4bbe94dc314bed8c889753cbb00d2bad1e672384a350dc2b97e4a0b1
size 2343620
oid sha256:c8fdf3d223bb7e0a5eeffbea0a82e50a8e0ec3815b274cdc95d6fb1c36f2178d
size 2755044

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57a1c54097341e561ae44f5ae69fa6a7e33061e2d0451d2f42a37f22993a22bb
size 818584
oid sha256:36f02388a9bd2ae3d45f0d6480bd95cd99f8ea30eebf1a315b8d54e742fed479
size 846308

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d443d55b92501991a6102c523d46ddfdf620fa5ab37abcee3e2d6ee4c4d9e90
size 833262
oid sha256:c707f67abccca217d81e8d85e361b6d214131045763df5f806cb789157ea4f80
size 857730

View File

@ -1,3 +1,3 @@
b92b19f8d7eff851dadb8a8e3010a565 libtensorrt_llm_executor_static.a
a546902e11b24c1b890fd913c3e844c5 libtensorrt_llm_executor_static.pre_cxx11.a
d2cce02a8 commit
a552c727c128f6ca402ddc119d295ab0 libtensorrt_llm_executor_static.a
38ad482b0be0996970bc572622967acf libtensorrt_llm_executor_static.pre_cxx11.a
45a8cb4ea commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9233382570d3c9c5417ed1f279c234d323b4dd465bbdca86612e137fabfb9962
size 866182
oid sha256:cc2d59c4878e74f7e38a65187ed303a77b43a3b71753b3e4dcc99a937ccbcdf8
size 884870

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:03ee314aa8ca65abf013c6e5106b701defb5c1435d5fe8879829952c1d2cab1f
size 812078
oid sha256:dc7e967c9aa7ef50227a791c670fe71a9bdef907ce45d3282955ebd5e2ead88f
size 837988

View File

@ -303,6 +303,25 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_80_sm70_c
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_128_sm70_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin[];
// FP8
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin[];
extern uint32_t cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len;
@ -582,6 +601,25 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_80_sm70_cu_cub
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_128_sm70_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin_len;
// FP8
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len;
static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{
@ -1511,7 +1549,62 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_kernel", 98304, 128, 0, false, true, false, 1, false, false },
{ DATA_TYPE_FP16, 0, 160, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_kernel_nl", 98304, 128, 64, false, true, false, 1, false, false },
{ DATA_TYPE_FP16, 0, 256, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_kernel", 98304, 128, 0, false, true, false, 1, false, false },
{ DATA_TYPE_FP16, 0, 256, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_kernel_nl", 98304, 128, 64, false, true, false, 1, false, false }
{ DATA_TYPE_FP16, 0, 256, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_kernel_nl", 98304, 128, 64, false, true, false, 1, false, false },
// FP8
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_causal_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_causal_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_causal_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_causal_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_causal_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 2, true, false}
};
// clang-format on

View File

@ -86,7 +86,9 @@ public:
{
TLLM_CHECK_WITH_INFO(
(sm == kSM_70 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89 || sm == kSM_90), "Unsupported architecture");
TLLM_CHECK_WITH_INFO((mDataType == DATA_TYPE_FP16 || mDataType == DATA_TYPE_BF16), "Unsupported data type");
TLLM_CHECK_WITH_INFO(
(mDataType == DATA_TYPE_FP16 || mDataType == DATA_TYPE_BF16 || mDataType == DATA_TYPE_E4M3),
"Unsupported data type");
pagedKVXmmaKernel = getPagedKVXMMAKernelsV2(mDataType, sm);
xmmaKernel = getXMMAKernelsV2(mDataType, sm);
@ -117,7 +119,7 @@ public:
float const scale_softmax = 1.f; // Seems to be only required for int8
float const scale_bmm2 = 1.f;
Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
Data_type scale_type = mLaunchParams.force_fp32_acc || mDataType == DATA_TYPE_E4M3 ? DATA_TYPE_FP32 : mDataType;
// Use exp2f optimization for warp-specialized ws kernels on Hopper.
if (mLaunchParams.useBase2ExpTrick)
{
@ -130,6 +132,7 @@ public:
set_alpha(params.scale_bmm1, scale_bmm1, scale_type);
}
set_alpha(params.scale_softmax, scale_softmax, scale_type);
// Host scale_bmm2 will not be used.
set_alpha(params.scale_bmm2, scale_bmm2, scale_type);
params.b = b;
@ -138,7 +141,7 @@ public:
params.d = mHeadSize;
params.sliding_window_size = sliding_window_size;
params.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
params.o_stride_in_bytes = get_size_in_bytes(mNumHeads * mHeadSize, mDataType);
// Total sequence length needed by TMA descriptor
// it should be actual total seq length if non-padded input is given.
@ -153,8 +156,8 @@ public:
}
// Support packed QKV.
void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
bool const scale_alibi, int const tp_size, int const tp_rank)
void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
float const* scale_bmm2_d, bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
{
// Determine launch parameters.
@ -169,7 +172,14 @@ public:
bool const isSm90 = (sm == kSM_90);
bool const isSm8x = (sm == kSM_86 || sm == kSM_89);
bool const isSm80 = (sm == kSM_80);
if (isSm70)
// Only warp-specialized FMHA kernels support FP8 on Hopper.
if (isSm90 && mDataType == DATA_TYPE_E4M3)
{
mLaunchParams.flash_attention = true;
mLaunchParams.force_unroll = true;
}
else if (isSm70)
{
mLaunchParams.flash_attention = true;
mLaunchParams.force_unroll = true; // need more profile
@ -234,7 +244,10 @@ public:
// Set kernel parameters.
setup_params(mParams, b, s, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
mParams.qkv_stride_in_bytes = (mNumHeads + 2 * mParams.h_kv) * mHeadSize * sizeof(half);
// TODO: move this to setup_params when fp8 paged kv fmha is supported.
// TRT doesn't support host scales. Use device scales instead.
mParams.scale_bmm2_d = reinterpret_cast<uint32_t const*>(scale_bmm2_d);
mParams.qkv_stride_in_bytes = get_size_in_bytes((mNumHeads + 2 * mParams.h_kv) * mHeadSize, mDataType);
}
// Support paged_kv_cache and chunked_attention.
@ -309,6 +322,7 @@ public:
mLaunchParams.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
}
// TODO: add paged kv FP8 FMHA.
setup_params(
mPagedKVParams, b, s_q, s_kv, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
mPagedKVParams.q_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
@ -320,7 +334,7 @@ public:
void set_tma_descriptors()
{
// split D into multiple groups in order to match the TMA swizzle mode (128B)
const uint32_t d_in_bytes = mLaunchParams.padded_d * sizeof(uint16_t);
const uint32_t d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mDataType);
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
// separate q, k, and v tma descriptors
@ -351,9 +365,9 @@ public:
// stride size in bytes. Assumes least significant dim is 1 (?)
uint64_t tensor_stride_qkv[3];
tensor_stride_qkv[0] = tensor_size_qkv[0] * sizeof(uint16_t); // d
tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0]; // d*h
tensor_stride_qkv[2] = tensor_size_qkv[2] * tensor_stride_qkv[1]; // d*h*3
tensor_stride_qkv[0] = get_size_in_bytes(tensor_size_qkv[0], mDataType); // d
tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0]; // d*h
tensor_stride_qkv[2] = tensor_size_qkv[2] * tensor_stride_qkv[1]; // d*h*3
// traversal stride
uint32_t traversal_stride_qkv[4] = {1, 1, 1, 1};
@ -365,7 +379,7 @@ public:
uint32_t fp32_to_tf32 = 0;
// gmma descriptor mode
const uint32_t d_bytes_per_group = (mLaunchParams.padded_d * sizeof(uint16_t)) / d_groups;
const uint32_t d_bytes_per_group = d_in_bytes / d_groups;
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
? cudaTmaDescSwizzle::SWIZZLE_128B
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
@ -388,21 +402,21 @@ public:
// Q: STEP_Q
box_size[3] = q_step;
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
&mParams.tma_desc_q);
// Desc Format (data type).
const cudaTmaDescFormat desc_format
= (get_size_in_bytes(1, mDataType) == 1) ? cudaTmaDescFormat::U8 : cudaTmaDescFormat::F16_RN;
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &mParams.tma_desc_q);
// K/V: STEP_KV
box_size[3] = kv_step;
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
&mParams.tma_desc_k);
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
&mParams.tma_desc_v);
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &mParams.tma_desc_k);
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &mParams.tma_desc_v);
}
// Q are contiguous in the shape of [B, S, H, D]
@ -520,8 +534,9 @@ public:
void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads)
{
// BF16 FMHA only accumulates on FP32
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
// BF16 FMHA only accumulates on FP32.
// E4M3 FMHA only supports fp32 accumulation currently.
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || mDataType == DATA_TYPE_E4M3 || force_fp32_acc;
mLaunchParams.attention_mask_type
= causal_mask ? ContextAttentionMaskType::CAUSAL : ContextAttentionMaskType::PADDING;
@ -646,9 +661,9 @@ FusedMHARunnerV2::FusedMHARunnerV2(
FusedMHARunnerV2::~FusedMHARunnerV2() = default;
void FusedMHARunnerV2::setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
float const* scale_bmm2_d, bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
{
pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
pimpl->setup(b, s, sliding_window_size, total_seqlen, scale_bmm2_d, has_alibi, scale_alibi, tp_size, tp_rank);
}
void FusedMHARunnerV2::setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,

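As a usage sketch for the new device-side scale argument, the call below passes a `float*` that lives in GPU memory, which the FP8 (E4M3) path requires since TRT does not support host scales here. The runner instance, the header path, and the shape parameters are illustrative assumptions.

```cpp
#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h" // path assumed from the repo layout

// Hedged sketch: configure the runner with a device-resident bmm2 scale for the FP8 path.
void setupFp8Fmha(tensorrt_llm::kernels::FusedMHARunnerV2& runner, float const* scaleBmm2Device)
{
    int const b = 8;    // batch size (illustrative)
    int const s = 1024; // max sequence length (illustrative)
    runner.setup(b, s, /*sliding_window_size=*/s, /*total_seqlen=*/b * s,
        /*scale_bmm2_d=*/scaleBmm2Device,
        /*has_alibi=*/false, /*scale_alibi=*/false, /*tp_size=*/1, /*tp_rank=*/0);
}
```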
View File

@ -48,7 +48,8 @@ public:
virtual ~MHARunner() = default;
virtual void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0)
float const* scale_bmm2_d = nullptr, bool const has_alibi = false, bool const scale_alibi = false,
int const tp_size = 1, int const tp_rank = 0)
= 0;
virtual void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
@ -91,8 +92,8 @@ public:
~FusedMHARunnerV2(); // for pimpl
void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1,
int const tp_rank = 0) override;
float const* scale_bmm2_d = nullptr, bool const has_alibi = false, bool const scale_alibi = false,
int const tp_size = 1, int const tp_rank = 0) override;
void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen,

View File

@ -122,8 +122,9 @@ struct Fused_multihead_attention_params_v2
int *counters, *max_barriers, *sum_barriers, *locks;
// Scratch buffers to finalize softmax.
float *max_scratch_ptr, *sum_scratch_ptr;
// Scratch buffer to finalize the output (not needed for FP16).
int* o_scratch_ptr;
// Scale bmm2 in the device memory.
uint32_t const* scale_bmm2_d;
// In multi-query or grouped-query attention (MQA/GQA), several Q heads are associated with one KV head
int h_kv;
@ -199,6 +200,10 @@ struct Fused_multihead_attention_paged_kv_params_v2
// The scaling factors for the kernel.
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
// TODO: add fp8 paged kv fmha later.
// Scale bmm2 in the device memory.
// uint32_t const* scale_bmm2_d;
// Do we use Niall's trick to avoid I2F/F2I in the INT8 kernel.
// See https://confluence.nvidia.com/pages/viewpage.action?pageId=302779721 for details.
bool enable_i2f_trick;

View File

@ -17,17 +17,19 @@
#include "customAllReduceKernels.h"
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/dataType.h"
#include <tuple>
#include <type_traits>
namespace tensorrt_llm::kernels
{
using tensorrt_llm::common::datatype_enum;
using tensorrt_llm::common::divUp;
using tensorrt_llm::common::roundUp;
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr)
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
@ -39,13 +41,15 @@ static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_add
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_addr)
static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr)
{
uint32_t flag;
#if __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#endif
return flag;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -107,72 +111,155 @@ inline __device__ int4 add128b(T& a, T& b)
return c.packed;
}
__inline__ __device__ void multi_gpu_barrier(
uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, int const tidx, int const bidx)
__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
size_t const world_size, int const tidx, int const bidx)
{
// At the end of the function, we know that at least block 0 from all other GPUs has reached that point.
uint32_t volatile* my_signals = signals[rank];
// After this function, at least one block in each GPU has reached the barrier
if (tidx < world_size)
{
// The 1st block notifies the other ranks.
// we can think of signals having the shape [world_size, world_size]
// Dimension 0 is the "listening" dimension, dimension 1 is the "emitting" dimension
// Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
if (bidx == 0)
{
signals[tidx][rank] = flag;
signals[tidx][local_rank] = flag;
}
// Busy-wait until all ranks are ready.
// All blocks check that the corresponding block 0 on the other GPUs has set the flag
// No deadlock because block #0 is always the first block started
uint32_t volatile* my_signals = signals[local_rank];
while (my_signals[tidx] != flag)
{
}
}
// Make sure we can move on...
__syncthreads();
}
__global__ void multiGpuBarrierKernel(AllReduceParams params)
__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
size_t const world_size, int const tidx, int const bidx)
{
multi_gpu_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, params.ranks_per_node,
threadIdx.x, blockIdx.x);
// After this function, the block of id == bidx of each GPU has reached the barrier
if (tidx < world_size)
{
// we can think of signals having the shape [world_size, num_blocks, world_size]
// (+ an offset on dim 1 to account for flags used in multi_gpu_barrier)
// Dimension 0 is the "listening" dimension, dimension 2 is "emitting" dimension
// The block broadcasts its flag (local_rank on the emitting dimension) to all receivers
uint32_t flag_block_offset = world_size + bidx * world_size;
st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
// Blocks check that corresponding blocks on other GPUs have also set the flag
uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
while (ld_flag_acquire(peer_barrier_d) != flag)
{
}
}
__syncthreads();
}
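A minimal host-side sketch of the flag indexing implied by the two barrier helpers above, assuming the layouts described in their comments ([world_size] flags per receiver for multi_gpu_barrier, plus world_size extra slots per block for block_barrier); the function names are illustrative only.
#include <cstddef>
// Slot written by block 0 of `emitter_rank` into signals[receiver] for multi_gpu_barrier.
inline std::size_t gpuBarrierSlot(std::size_t emitter_rank)
{
    return emitter_rank;
}
// Slot written by block `bidx` of `emitter_rank` into signals[receiver] for block_barrier;
// the first world_size slots are reserved for multi_gpu_barrier.
inline std::size_t blockBarrierSlot(std::size_t world_size, std::size_t bidx, std::size_t emitter_rank)
{
    return world_size + bidx * world_size + emitter_rank;
}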
template <typename T, int RANKS_PER_NODE>
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true, bool PUSH_MODE = false>
static __global__ void oneShotAllReduceKernel(AllReduceParams params)
{
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// GPU 0 | B0 | B1 | B2 | B3 |
// GPU 1 | B0 | B1 | B2 | B3 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
// 3. B0 on GPU 0 pulls and sums the chunk from GPU 1 and writes the result to local_output
//
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
// We only need to know whether the other GPU has arrived at the AR kernel, which means its data is ready
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 pushes the chunk it is responsible for into all other GPUs:
// params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
// 2. block sync so the chunk has been shared with the other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
int const bidx = blockIdx.x;
int const tidx = threadIdx.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
// Packed data type for comms
static constexpr int PACKED_ELTS = 16 / sizeof(T);
using PackedStruct = typename PackedOn16Bytes<T>::Type;
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
// The source pointers. Distributed round-robin for the different warps.
T const* src_d[RANKS_PER_NODE];
// Start and end offsets of the thread
size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
size_t const chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_total);
T* buffers[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
// buffers[0] is always the local buffer. Helps load-balance reads.
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
}
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
// The end of the segment computed by that block.
size_t max_offset = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
if constexpr (PUSH_MODE || COPY_INPUT)
{
// Copy from local buffer to shareable buffer
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * PACKED_ELTS)
{
if constexpr (PUSH_MODE)
{
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
*reinterpret_cast<int4*>(&buffers[ii][params.local_rank * params.elts_total + iter_offset])
= *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
}
}
else
{
*reinterpret_cast<int4*>(&local_shared_buffer[iter_offset])
= *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
}
}
// wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
}
else
{
// In the non-copy case, we assume that once the kernel has been started, data is ready to be consumed
multi_gpu_barrier(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
}
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = offset; iter_offset < max_offset; iter_offset += blockDim.x * NUM_ELTS)
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * PACKED_ELTS)
{
// Iterate over the different ranks/devices on the node to load the values.
PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][iter_offset]);
if constexpr (PUSH_MODE)
{
vals[ii].packed
= *reinterpret_cast<int4 const*>(&buffers[params.local_rank][ii * params.elts_total + iter_offset]);
}
else
{
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
}
}
// Sum the values from the different ranks.
@ -185,55 +272,130 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
}
// Store to the destination buffer.
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
*reinterpret_cast<int4*>(&local_output_buffer[iter_offset]) = sums.packed;
}
}
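As a sanity reference for the one-shot scheme described in the comments of the kernel above, here is a hedged host-side sketch (plain C++, illustrative names): after the copy/barrier step, every rank reduces the full message by summing all peers' shared buffers.
#include <cstddef>
#include <vector>
// Reference one-shot all-reduce: `shared[r]` stands in for peer_comm_buffer_ptrs[r].
std::vector<float> oneShotReference(std::vector<std::vector<float>> const& shared)
{
    std::size_t const n = shared.front().size();
    std::vector<float> out(n, 0.f);
    for (auto const& peer : shared)
    {
        for (std::size_t i = 0; i < n; ++i)
        {
            out[i] += peer[i];
        }
    }
    return out; // the result is identical on every rank
}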
template <typename T, int RANKS_PER_NODE>
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true, bool PUSH_MODE = false>
static __global__ void twoShotAllReduceKernel(AllReduceParams params)
{
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
// The message is partitioned into chunks as detailed below:
// message
// |-------------------|
// |--GPU 0--|--GPU 1--| (GPU responsibility parts)
// GPU 0 | B0 | B1 | B0 | B1 |
// GPU 1 | B0 | B1 | B0 | B1 |
//
// Here the step-by-step behavior of one block:
// 1. B0 copies all chunks it is responsible for, from local_input to shareable buffer
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
// 3. B0 on GPU 0 gathers and sums the B0 chunks from GPU 1 that are in the GPU 0 responsibility
// part (the first half of the message, see GPU responsibility row above)
// 3bis. Likewise, B0 on GPU 1 copies and sums the chunks for GPU 0,
// where GPU 1 is responsible: the second half of the message.
// 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
// 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU.
// For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1.
//
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
// We only need to know whether the other GPU has arrived at the AR kernel, which means its data is ready
// to be read.
//
// Note that compared to one-shot, one block (CTA) writes multiple input chunks and multiple output chunks.
// However, it's only responsible for the summation of a single chunk.
//
// With PUSH_MODE, we consider that the shared buffer is of size:
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
//
// Here the step-by-step behavior of one block:
// 1. B0 pushes the chunks it is responsible for into the corresponding GPUs:
// params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
// 2. block sync so the chunks have been shared with the other GPUs
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
// 4. block barrier (corresponding blocks have finished reduction)
// 5. pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is
// written at index 0 of 2nd dim)
// The block index.
int const bidx = blockIdx.x;
// The thread index with the block.
int const tidx = threadIdx.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
// Packed data type for comms
static constexpr int PACKED_ELTS = 16 / sizeof(T);
using PackedType = typename PackedOn16Bytes<T>::Type;
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
const size_t block_offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
const size_t block_start = params.rank_offset + block_offset;
// The end of the segment computed by that block.
size_t max_offset = min(block_start + params.elts_per_block, params.rank_offset + params.elts_per_rank);
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
// The source pointers. Distributed round-robin for the different warps.
T* src_d[RANKS_PER_NODE];
// The destination ranks for round-robin gathering
size_t dst_rank[RANKS_PER_NODE];
T* buffers[RANKS_PER_NODE];
int ranks[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
// A mapping of the ranks to scatter reads as much as possible
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
dst_rank[ii] = rank;
ranks[ii] = rank;
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
}
if constexpr (PUSH_MODE || COPY_INPUT)
{
// Copy all blocks from local buffer to shareable buffer
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS)
{
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
size_t offset_rank = ii * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total)
{
continue;
}
if constexpr (PUSH_MODE)
{
*reinterpret_cast<int4*>(&buffers[ii][params.local_rank * params.elts_per_rank + local_offset])
= *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
}
else
{
*reinterpret_cast<int4*>(&local_shared_buffer[offset_rank])
= *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
}
}
}
block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
}
else
{
// In the non-copy case, we assume that once the kernel has been started, data is ready to be consumed
multi_gpu_barrier(
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
}
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = block_start; local_offset < max_offset; local_offset += blockDim.x * NUM_ELTS)
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS)
{
size_t const responsible_block_offset = local_offset + params.rank_offset;
// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][local_offset]);
if constexpr (PUSH_MODE)
{
vals[ii].packed
= *reinterpret_cast<int4 const*>(&local_shared_buffer[ii * params.elts_per_rank + local_offset]);
}
else
{
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
}
}
// Sum the values from the different ranks.
@ -246,86 +408,77 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
}
// Store to the local buffer.
*reinterpret_cast<int4*>(&src_d[0][local_offset]) = sums.packed;
}
// sync threads to make sure all block threads have the sums
__syncthreads();
// barriers among the blocks with the same idx (release-acquire semantics)
if (tidx < RANKS_PER_NODE)
{
// All blocks notify the other ranks.
uint32_t flag_block_offset = RANKS_PER_NODE + bidx * RANKS_PER_NODE;
st_flag_release(params.barrier_flag, params.peer_barrier_ptrs_in[tidx] + flag_block_offset + params.local_rank);
// Busy-wait until all ranks are ready.
uint32_t rank_barrier = 0;
uint32_t* peer_barrier_d = params.peer_barrier_ptrs_in[params.local_rank] + flag_block_offset + tidx;
do
if constexpr (PUSH_MODE)
{
ld_flag_acquire(rank_barrier, peer_barrier_d);
} while (rank_barrier != params.barrier_flag);
*reinterpret_cast<int4*>(&local_shared_buffer[local_offset]) = sums.packed;
}
else
{
*reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
}
}
// sync threads to make sure all other ranks have the final partial results
__syncthreads();
block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
size_t max_block_offset = min(block_offset + params.elts_per_block, params.elts_per_rank);
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = block_offset; local_offset < max_block_offset; local_offset += blockDim.x * NUM_ELTS)
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS)
{
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
// use round-robin gathering from other ranks
size_t offset_rank = dst_rank[ii] * params.elts_per_rank + local_offset;
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
if (offset_rank >= params.elts_total)
{
continue;
}
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank])
= *reinterpret_cast<int4*>(&src_d[ii][offset_rank]);
if constexpr (PUSH_MODE)
{
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank])
= *reinterpret_cast<int4*>(&buffers[ii][local_offset]);
}
else
{
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank])
= *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
}
}
}
}
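For comparison, a hedged host-side reference of the two-shot scheme (reduce-scatter of per-rank slices, then all-gather), assuming elts_total is divisible by the number of ranks; it is a sketch of the math, not of the kernel's traversal order.
#include <cstddef>
#include <vector>
std::vector<float> twoShotReference(std::vector<std::vector<float>> const& shared, std::size_t world_size)
{
    std::size_t const elts_total = shared.front().size();
    std::size_t const elts_per_rank = elts_total / world_size;
    std::vector<float> out(elts_total, 0.f);
    // Phase 1 (reduce-scatter): rank r owns slice [r * elts_per_rank, (r + 1) * elts_per_rank)
    // and sums that slice over all peers' shared buffers.
    // Phase 2 (all-gather): every rank copies each reduced slice from its owner; since the
    // reduced values are identical everywhere, the two phases collapse into one loop here.
    for (std::size_t r = 0; r < world_size; ++r)
    {
        for (std::size_t i = r * elts_per_rank; i < (r + 1) * elts_per_rank; ++i)
        {
            for (std::size_t peer = 0; peer < world_size; ++peer)
            {
                out[i] += shared[peer][i];
            }
        }
    }
    return out;
}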
////////////////////////////////////////////////////////////////////////////////////////////////////
bool configurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t n_ranks, nvinfer1::DataType type)
{
size_t elts_per_thread = 16 / common::getDTypeSize(type);
int const msg_align = (algo == AllReduceStrategyType::TWOSHOT) ? n_ranks * elts_per_thread : elts_per_thread;
bool supported_algo = (algo == AllReduceStrategyType::ONESHOT || algo == AllReduceStrategyType::TWOSHOT);
return supported_algo && (msg_size % msg_align == 0);
}
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& param, size_t elts_per_thread)
{
TLLM_CHECK(param.elts_total % elts_per_thread == 0);
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
const size_t total_threads = param.elts_total / elts_per_thread;
switch (algo)
{
case AllReduceStrategyType::ONESHOT:
{ // one stage all reduce algo
if (total_threads <= DEFAULT_BLOCK_SIZE)
{ // local reduce
threads_per_block = WARP_SIZE * divUp(total_threads, WARP_SIZE);
blocks_per_grid = 1;
}
else
{ // local reduce
threads_per_block = DEFAULT_BLOCK_SIZE;
blocks_per_grid = divUp(total_threads, DEFAULT_BLOCK_SIZE);
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), blocks_per_grid);
}
param.elts_per_rank = param.elts_total;
param.elts_per_block = elts_per_thread * divUp(param.elts_per_rank, elts_per_thread * blocks_per_grid);
{
TLLM_CHECK(param.elts_total % elts_per_thread == 0);
size_t const total_threads = roundUp(param.elts_total / elts_per_thread, WARP_SIZE);
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
param.elts_per_block = roundUp(divUp(param.elts_total, blocks_per_grid), elts_per_thread);
break;
}
case AllReduceStrategyType::TWOSHOT:
{ // two stage all reduce algo
const size_t elts_per_rank = param.elts_total / param.ranks_per_node;
TLLM_CHECK(elts_per_rank % elts_per_thread == 0);
{
TLLM_CHECK(param.elts_total % (elts_per_thread * param.ranks_per_node) == 0);
size_t const total_threads = roundUp(param.elts_total / (elts_per_thread * param.ranks_per_node), WARP_SIZE);
size_t total_threads = elts_per_rank / elts_per_thread;
total_threads = WARP_SIZE * ((total_threads + WARP_SIZE - 1) / WARP_SIZE);
TLLM_CHECK(total_threads % WARP_SIZE == 0);
/*
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
*/
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE)
{
@ -345,9 +498,8 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
blocks_per_grid /= iter_factor;
}
param.elts_per_rank = param.elts_total / param.ranks_per_node;
param.elts_per_block = param.elts_per_rank / blocks_per_grid;
param.elts_per_block = elts_per_thread * divUp(param.elts_per_block, elts_per_thread);
param.rank_offset = param.rank * param.elts_per_rank;
param.rank_offset = param.local_rank * param.elts_per_rank;
param.elts_per_block = roundUp(divUp(param.elts_per_rank, blocks_per_grid), elts_per_thread);
break;
}
default: TLLM_THROW("Algorithm not supported here.");
@ -356,44 +508,74 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
return std::make_tuple(blocks_per_grid, threads_per_block);
}
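A worked host-side example of the ONESHOT branch of the launch configuration above, assuming fp16 payloads (8 elements per 16-byte access), the 1024-thread default block from the header, and an illustrative cap of 24 blocks; the block cap is a placeholder, not the library's actual constant.
#include <algorithm>
#include <cstddef>
#include <cstdio>
int main()
{
    auto divUp = [](std::size_t a, std::size_t b) { return (a + b - 1) / b; };
    auto roundUp = [&](std::size_t a, std::size_t b) { return divUp(a, b) * b; };
    std::size_t const elts_total = 1 << 20; // message size in elements
    std::size_t const elts_per_thread = 8;  // 16 bytes / sizeof(half)
    std::size_t const warp_size = 32, default_block = 1024, max_blocks = 24; // max_blocks is illustrative
    std::size_t const total_threads = roundUp(elts_total / elts_per_thread, warp_size);
    std::size_t const threads_per_block = std::min(default_block, total_threads);
    std::size_t const blocks_per_grid = std::min(max_blocks, divUp(total_threads, threads_per_block));
    std::size_t const elts_per_block = roundUp(divUp(elts_total, blocks_per_grid), elts_per_thread);
    std::printf("blocks=%zu threads=%zu elts_per_block=%zu\n", blocks_per_grid, threads_per_block, elts_per_block);
    return 0;
}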
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, int RANKS_PER_NODE>
void dispatchARKernels(
AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block, cudaStream_t stream)
template <typename T, int RANKS_PER_NODE, bool PUSH_MODE = false, bool USE_MEMCPY = false>
void AllReduceDispatchMemcpy(
AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams& param, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(!(USE_MEMCPY && PUSH_MODE), "Memcpy cannot be used with PUSH_MODE.");
size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(algo, param, elts_per_thread);
if (USE_MEMCPY)
{
cudaMemcpyAsync(param.peer_comm_buffer_ptrs[param.local_rank], param.local_input_buffer_ptr,
param.elts_total * sizeof(T), cudaMemcpyDeviceToDevice, stream);
}
if (algo == AllReduceStrategyType::ONESHOT)
{
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
oneShotAllReduceKernel<T, RANKS_PER_NODE, !USE_MEMCPY, PUSH_MODE>
<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
}
else
{
twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
twoShotAllReduceKernel<T, RANKS_PER_NODE, !USE_MEMCPY, PUSH_MODE>
<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
}
}
template <typename T, int RANKS_PER_NODE, bool PUSH_MODE = false>
void AllReduceDispatchPushMode(
AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams& param, cudaStream_t stream)
{
if (static_cast<std::underlying_type_t<AllReduceStrategyConfig>>(config)
& static_cast<std::underlying_type_t<AllReduceStrategyConfig>>(AllReduceStrategyConfig::USE_MEMCPY))
{
AllReduceDispatchMemcpy<T, RANKS_PER_NODE, PUSH_MODE, true>(algo, config, param, stream);
}
else
{
AllReduceDispatchMemcpy<T, RANKS_PER_NODE, PUSH_MODE, false>(algo, config, param, stream);
}
}
template <typename T, int RANKS_PER_NODE> //, bool USE_MEMCPY = false, bool PUSH_MODE = false>
void AllReduceDispatchRanksPerNode(
AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams& param, cudaStream_t stream)
{
if (static_cast<std::underlying_type_t<AllReduceStrategyConfig>>(config)
& static_cast<std::underlying_type_t<AllReduceStrategyConfig>>(AllReduceStrategyConfig::PULL_MODE))
{
AllReduceDispatchPushMode<T, RANKS_PER_NODE, true>(algo, config, param, stream);
}
else
{
AllReduceDispatchPushMode<T, RANKS_PER_NODE, false>(algo, config, param, stream);
}
}
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream)
void AllReduceDispatchType(
AllReduceParams& param, AllReduceStrategyType strat, AllReduceStrategyConfig config, cudaStream_t stream)
{
TLLM_CHECK(strat == AllReduceStrategyType::ONESHOT || strat == AllReduceStrategyType::TWOSHOT);
sync_check_cuda_error();
size_t elts_per_thread = 16 / sizeof(T);
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
switch (param.ranks_per_node)
{
case 2: dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream); break;
case 4: dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream); break;
case 6: dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream); break;
case 8: dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream); break;
default: break;
case 2: AllReduceDispatchRanksPerNode<T, 2>(strat, config, param, stream); break;
case 4: AllReduceDispatchRanksPerNode<T, 4>(strat, config, param, stream); break;
case 6: AllReduceDispatchRanksPerNode<T, 6>(strat, config, param, stream); break;
case 8: AllReduceDispatchRanksPerNode<T, 8>(strat, config, param, stream); break;
default: TLLM_THROW("Custom all reduce only supported on {2, 4, 6, 8} GPUs per node.");
}
sync_check_cuda_error();
}
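The chain of AllReduceDispatch* helpers above turns runtime values (config bits, ranks per node, data type) into compile-time template arguments. A generic, hedged sketch of that pattern with illustrative names:
#include <stdexcept>
template <int RANKS_PER_NODE>
void launchSpecializedKernel()
{
    // A real implementation would launch kernel<T, RANKS_PER_NODE, ...> here.
}
void dispatchRanksPerNode(int ranksPerNode)
{
    switch (ranksPerNode)
    {
    case 2: launchSpecializedKernel<2>(); break;
    case 4: launchSpecializedKernel<4>(); break;
    case 6: launchSpecializedKernel<6>(); break;
    case 8: launchSpecializedKernel<8>(); break;
    default: throw std::invalid_argument("custom all-reduce supports 2, 4, 6 or 8 GPUs per node");
    }
}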
void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream)
{
multiGpuBarrierKernel<<<1, param.ranks_per_node, 0, stream>>>(param);
}
AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value)
@ -425,30 +607,25 @@ AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSiz
return params;
}
void customAllReduce(kernels::AllReduceParams& params, void* data, size_t elts, size_t size_per_elem,
datatype_enum dataType, AllReduceStrategyType strat, cudaStream_t stream)
void customAllReduce(kernels::AllReduceParams& params, nvinfer1::DataType dataType, AllReduceStrategyType strat,
AllReduceStrategyConfig config, cudaStream_t stream)
{
params.local_output_buffer_ptr = data;
params.elts_total = elts;
TLLM_CHECK_WITH_INFO(configurationSupported(strat, params.elts_total, params.ranks_per_node, dataType),
"Custom all-reduce configuration unsupported");
if (dataType == datatype_enum::TYPE_FP32)
sync_check_cuda_error();
switch (dataType)
{
kernels::invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
}
else if (dataType == datatype_enum::TYPE_FP16)
{
kernels::invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
}
case nvinfer1::DataType::kFLOAT: AllReduceDispatchType<float>(params, strat, config, stream); break;
case nvinfer1::DataType::kHALF: AllReduceDispatchType<half>(params, strat, config, stream); break;
#ifdef ENABLE_BF16
else if (dataType == datatype_enum::TYPE_BF16)
{
kernels::invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
}
case nvinfer1::DataType::kBF16: AllReduceDispatchType<__nv_bfloat16>(params, strat, config, stream); break;
#endif
else
{
TLLM_THROW("Unsupported dataType for customAllReduce");
default: TLLM_THROW("Unsupported dataType for customAllReduce");
}
sync_check_cuda_error();
}
} // namespace tensorrt_llm::kernels

View File

@ -16,12 +16,10 @@
#pragma once
#include <assert.h>
#include <NvInferRuntime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <iostream>
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/tensor.h"
@ -38,12 +36,18 @@ constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
// they must be kept in sync
enum class AllReduceStrategyType : int8_t
{
RING = 0,
NCCL = 0,
ONESHOT = 1,
TWOSHOT = 2,
AUTO = 3,
};
enum class AllReduceStrategyConfig : int8_t
{
USE_MEMCPY = 1 << 0,
PULL_MODE = 1 << 1,
};
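AllReduceStrategyConfig is a bitmask, so several options can be OR-ed into a single value. A small sketch of how such flags can be tested, mirroring the underlying_type casts used in the dispatch code (the names here are illustrative only):
#include <cstdint>
#include <type_traits>
enum class Config : int8_t
{
    USE_MEMCPY = 1 << 0,
    PULL_MODE = 1 << 1,
};
constexpr bool hasFlag(Config value, Config flag)
{
    using U = std::underlying_type_t<Config>;
    return (static_cast<U>(value) & static_cast<U>(flag)) != 0;
}
static_assert(hasFlag(static_cast<Config>(3), Config::PULL_MODE), "both bits set implies PULL_MODE");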
struct AllReduceParams
{
size_t elts_total;
@ -56,16 +60,14 @@ struct AllReduceParams
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
void* local_output_buffer_ptr;
void const* local_input_buffer_ptr;
static AllReduceParams deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value);
};
template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream);
bool configurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t n_ranks, nvinfer1::DataType type);
void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream);
void customAllReduce(kernels::AllReduceParams& params, void* data, size_t elts, size_t size_per_elem,
common::datatype_enum dataType, AllReduceStrategyType strat, cudaStream_t stream);
void customAllReduce(kernels::AllReduceParams& params, nvinfer1::DataType dataType, AllReduceStrategyType strat,
AllReduceStrategyConfig config, cudaStream_t stream);
} // namespace tensorrt_llm::kernels

View File

@ -148,6 +148,7 @@ struct Multihead_attention_params_base
float const* qkv_scale_quant_orig = nullptr;
float const* attention_out_scale_orig_quant = nullptr;
// 8 bits kv cache scales.
float const* kv_scale_orig_quant = nullptr;
float const* kv_scale_quant_orig = nullptr;

View File

@ -1638,13 +1638,13 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
{
if (HANDLE_KV)
{
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.rotary_embedding_base,
params.rotary_embedding_scale, current_pos_idx);
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_embedding_base,
rotary_embedding_scale, current_pos_idx);
}
else
{
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.rotary_embedding_base,
params.rotary_embedding_scale, current_pos_idx);
apply_rotary_embedding(
q, tidx, params.rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, current_pos_idx);
}
break;
}
@ -2589,6 +2589,9 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
__syncthreads();
}
// Quantized output only supports fp8 currently, which should be used together with FP8 Context FMHA.
using Quantized_t = __nv_fp8_e4m3;
using Quantized_vec = typename packed_type<__nv_fp8_e4m3, num_elems<V_vec_accum>::value>::type;
auto const bhi = tensorrt_llm::common::flat_index2(batch_beam_idx, hi, num_heads);
auto const bhi_seq_len_tile = bhi * params.seq_len_tile;
// Output the final values.
@ -2596,35 +2599,36 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
{
auto const bhvi = tensorrt_llm::common::flat_index2(bhi, vi, Dh);
#ifdef MMHA_USE_FP32_ACCUM_FOR_OUT
if (write_attention_quant)
if (!MULTI_BLOCK_FLAG)
{
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_accum>::value>::type;
out = mul<V_vec_accum, float>(*params.attention_out_scale_orig_quant, out);
*reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhvi])) = cast_to_int8(out);
}
else
{
if (!MULTI_BLOCK_FLAG)
if (write_attention_quant)
{
out = mul<V_vec_accum, float>(*params.attention_out_scale_orig_quant, out);
Quantized_vec final_out;
convert_to_fp8(&final_out, out);
*reinterpret_cast<Quantized_vec*>(reinterpret_cast<Quantized_t*>(params.out) + bhvi) = final_out;
}
else
{
// This makes sure we have coalesced memory access.
V_vec_k final_out;
convert_from_float(&final_out, out);
*reinterpret_cast<V_vec_k*>(&params.out[bhvi]) = final_out;
}
else
{
// for write partial output to partial_out
int partial_out_offset = c_tile * params.batch_size * num_heads * params.hidden_size_per_head;
// for write partial statistics to partial_max and partial_sum
int partial_stats_offset = bhi_seq_len_tile + c_tile;
}
else
{
// for write partial output to partial_out
int partial_out_offset = c_tile * params.batch_size * num_heads * params.hidden_size_per_head;
// for write partial statistics to partial_max and partial_sum
int partial_stats_offset = bhi_seq_len_tile + c_tile;
// This makes sure we have coalesced memory access.
V_vec_k partial_out;
convert_from_float(&partial_out, out);
*reinterpret_cast<V_vec_k*>(&params.partial_out[partial_out_offset + bhvi]) = partial_out;
convert_from_float(reinterpret_cast<float*>(&params.partial_max[partial_stats_offset]), qk_max);
convert_from_float(reinterpret_cast<float*>(&params.partial_sum[partial_stats_offset]), sum);
}
// This makes sure we have coalesced memory access.
V_vec_k partial_out;
convert_from_float(&partial_out, out);
*reinterpret_cast<V_vec_k*>(&params.partial_out[partial_out_offset + bhvi]) = partial_out;
convert_from_float(reinterpret_cast<float*>(&params.partial_max[partial_stats_offset]), qk_max);
convert_from_float(reinterpret_cast<float*>(&params.partial_sum[partial_stats_offset]), sum);
}
#else // MMHA_USE_FP32_ACCUM_FOR_OUT
*reinterpret_cast<V_vec_accum*>(&params.out[bhvi]) = out;
@ -2768,13 +2772,25 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
if (oo == 0 && (Dh == Dh_MAX || oi < Dh))
{
auto const inv_sum = __fdividef(1.f, final_sum + 1.e-6f);
auto const inv_sum = __fdividef(
write_attention_quant ? *params.attention_out_scale_orig_quant : 1.f, final_sum + 1.e-6f);
Tk inv_sum_compute;
convert_from_float(&inv_sum_compute, inv_sum);
thread_accumulated_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_accumulated_out);
*reinterpret_cast<V_vec_k*>(&params.out[bhi * Dh + oi]) = thread_accumulated_out;
if (write_attention_quant)
{
Quantized_vec final_out;
convert_to_fp8(&final_out, thread_accumulated_out);
*reinterpret_cast<Quantized_vec*>(reinterpret_cast<Quantized_t*>(params.out) + bhi * Dh + oi)
= final_out;
}
else
{
*reinterpret_cast<V_vec_k*>(&params.out[bhi * Dh + oi]) = thread_accumulated_out;
}
}
// Reset qk_current_smem and block_counter for the next timestep

View File

@ -279,7 +279,7 @@ public:
// NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache.
void const* xqa_q_input_ptr = xqaParams.output;
invokeApplyBiasRopeUpdateKVCache<T, KVCacheBuffer, true>(static_cast<T*>(const_cast<void*>(xqaParams.qkv)),
static_cast<T*>(const_cast<void*>(xqaParams.output)), kv_cache_buffer,
(__nv_fp8_e4m3*) nullptr, static_cast<T*>(const_cast<void*>(xqaParams.output)), kv_cache_buffer,
static_cast<T const*>(xqaParams.qkv_bias), xqaParams.sequence_lengths, nullptr, nullptr,
xqaParams.batch_size, xqaParams.generation_input_length, xqaParams.cyclic_attention_window_size,
xqaParams.sink_token_length, xqaParams.batch_size * beam_width * xqaParams.generation_input_length,
@ -287,7 +287,7 @@ public:
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type,
xqaParams.medusa_position_offsets, xqaParams.position_shift_enabled, (float*) nullptr, 0, cache_type,
xqaParams.kv_scale_orig_quant, true, beam_width, rotary_kernel_launch_cache, stream);
xqaParams.kv_scale_orig_quant, true, false, beam_width, rotary_kernel_launch_cache, stream);
sync_check_cuda_error();

View File

@ -1838,6 +1838,16 @@ inline __device__ Float8_ mul(float a, uint4 b)
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
inline __device__ uint2 mul(float a, uint2 b)
{
uint16_t h = float_to_half(a);
uint2 c = mul<uint2, uint16_t, uint2>(h, b);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
inline __device__ uint4 mul(float a, uint4 b)
{
@ -1877,6 +1887,14 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
inline __device__ __nv_bfloat162 mul(float a, __nv_bfloat162 b)
{
return mul<__nv_bfloat162>(__float2bfloat162_rn(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
@ -1908,6 +1926,14 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
inline __device__ bf16_4_t mul(float a, bf16_4_t b)
{
return mul<bf16_4_t>(__float2bfloat16(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
{

View File

@ -734,16 +734,18 @@ __device__ __forceinline__ int4 reduceMaxInt4(int4 const& a, int4 const& b)
}
template <typename T, SizeType BLOCK_SIZE>
__global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds,
SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots,
SizeType const* paths, TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, SizeType batchSize,
SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads,
SizeType maxTokensPerStep)
__global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal,
SizeType const* batchSlots, SizeType const* paths, TokenIdType const* endIds, T const** medusaLogits,
T const** logitsPtrs, SizeType* curTokensPerStep, SizeType const* targetTokensPerStep, SizeType* bestPathIds,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen,
SizeType maxNumHeads, SizeType maxTokensPerStep)
{
auto const batchIdx = static_cast<SizeType>(blockIdx.x);
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const inputLength = sequenceLengths[batchSlot];
auto const endId = endIds[batchSlot];
auto const numTokensPerStep = curTokensPerStep[batchSlot];
auto const maxNumDraftTokens = maxNumHeads + 1;
int4 partialMax{-1, -1, 0, 0};
@ -761,7 +763,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
{
continue;
}
auto const targetTokenIdx = batchSlot * maxTargetSeqLen + tokenId;
auto const targetTokenIdx = batchSlot * maxDraftTokens + tokenId;
auto targetToken = targetIds[targetTokenIdx];
auto nextIdx = tokenId;
@ -775,17 +777,16 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
acceptedLength = ti;
break;
}
auto const targetTokenIdx = batchSlot * maxTargetSeqLen + tokenId;
auto const draftTokenIdx = batchSlot * maxDraftSeqLen + inputLength + tokenId;
auto const draftToken = outputIds[draftTokenIdx];
auto const targetTokenIdx = batchSlot * maxDraftTokens + tokenId;
auto const draftTokenIdx = batchSlot * (maxDraftTokens - 1) + tokenId - 1;
// In context phase, no draft tokens are given. Set draft token to -1 to get guaranteed rejection
auto const draftToken = tokenId >= numTokensPerStep ? -1 : draftIds[draftTokenIdx];
// Check if draft tokens are the same as target tokens
bool const accepted = draftToken == targetToken;
hasEnd = targetToken == endId;
if (!accepted || hasEnd)
{
acceptedLength = hasEnd ? ti - 1 : ti;
nextIdx = tokenId;
break;
}
targetToken = targetIds[targetTokenIdx];
@ -816,16 +817,16 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
auto const acceptedLength = totalShared.x;
auto const bestPathIdx = totalShared.y;
auto const bestNextIdx = totalShared.w;
auto const bestNextIdx = numTokensPerStep == 1 ? 0 : totalShared.w;
auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxTokensPerStep, maxNumDraftTokens);
for (auto ti = static_cast<SizeType>(threadIdx.x); ti < acceptedLength; ti += static_cast<SizeType>(blockDim.x))
{
auto const tokenId = paths[pathOffset + ti];
auto const targetSrcTokenIdx = batchSlot * maxTargetSeqLen + tokenId;
auto const draftDstTokenIdx = batchSlot * maxDraftSeqLen + inputLength + ti;
auto const targetSrcTokenIdx = batchSlot * maxDraftTokens + tokenId;
auto const outputTokenIdx = batchSlot * maxSeqLen + inputLength + ti;
auto const targetToken = targetIds[targetSrcTokenIdx];
// Copy accepted tokens to the sequence with draft tokens (outputIds === outputIds)
outputIds[draftDstTokenIdx] = targetToken;
outputIds[outputTokenIdx] = targetToken;
}
// Leading thread reconstructs winning path and sets new data
@ -840,41 +841,135 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
// Make correction to the sequence length
sequenceLengths[batchSlot] += acceptedLength;
acceptedLengths[batchSlot] = acceptedLength;
// In the context (first) Medusa decoding step, the number of draft tokens is 0 and must be updated for the next steps
if (numTokensPerStep == 1)
{
curTokensPerStep[batchSlot] = targetTokensPerStep[batchSlot];
}
bestPathIds[batchSlot] = bestPathIdx;
}
// Prepare logits pointers to respective logits from Medusa Heads for the all-top-K sampling kernel
for (auto hi = static_cast<SizeType>(threadIdx.x); hi < maxNumHeads; hi += static_cast<SizeType>(blockDim.x))
{
logitsPtrs[batchIdx * maxNumHeads + hi]
= medusaLogits + flat_index4(hi, batchIdx, bestNextIdx, 0, maxBatchSize, maxTokensPerStep, vocabSize);
= medusaLogits[batchSlot * maxNumHeads + hi] + flat_index2(bestNextIdx, 0, vocabSize);
}
}
template <typename T>
void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, SizeType* sequenceLengths,
SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths,
TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, SizeType batchSize, SizeType vocabSize,
SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads,
void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds, TokenIdType const* targetIds,
SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots,
SizeType const* paths, TokenIdType const* endIds, T const** medusaLogits, T const** logitsPtrs,
SizeType* curTokensPerStep, SizeType const* targetTokensPerStep, SizeType* bestPathIds, SizeType batchSize,
SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen, SizeType maxNumHeads,
SizeType maxTokensPerStep, cudaStream_t stream)
{
constexpr SizeType BLOCK_SIZE = 256;
dim3 block(BLOCK_SIZE);
dim3 grid(batchSize);
acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, stream>>>(outputIds, targetIds, sequenceLengths,
acceptedLengths, finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs, batchSize, vocabSize,
maxBatchSize, maxDraftSeqLen, maxTargetSeqLen, maxNumHeads, maxTokensPerStep);
acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, stream>>>(outputIds, draftIds, targetIds,
sequenceLengths, acceptedLengths, finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs,
curTokensPerStep, targetTokensPerStep, bestPathIds, batchSize, vocabSize, maxBatchSize, maxDraftTokens,
maxSeqLen, maxNumHeads, maxTokensPerStep);
}
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds,
SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots,
SizeType const* paths, TokenIdType const* endIds, float const* medusaLogits, float const** logitsPtrs,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen,
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal,
SizeType const* batchSlots, SizeType const* paths, TokenIdType const* endIds, float const** medusaLogits,
float const** logitsPtrs, SizeType* curTokensPerStep, SizeType const* targetTokensPerStep, SizeType* bestPathIds,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen,
SizeType maxNumHeads, SizeType maxTokensPerStep, cudaStream_t stream);
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds,
SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots,
SizeType const* paths, TokenIdType const* endIds, half const* medusaLogits, half const** logitsPtrs,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, int32_t maxDraftSeqLen, SizeType maxTargetSeqLen,
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal,
SizeType const* batchSlots, SizeType const* paths, TokenIdType const* endIds, half const** medusaLogits,
half const** logitsPtrs, SizeType* curTokensPerStep, SizeType const* targetTokensPerStep, SizeType* bestPathIds,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen,
SizeType maxNumHeads, SizeType maxTokensPerStep, cudaStream_t stream);
__global__ void scatterMedusaDraftTokens(TokenIdType* treeDraftIds, TokenIdType const* sourceDraftIds,
SizeType const* treeIds, SizeType const* tokensPerStepData, SizeType const* batchSlots, SizeType maxTokensPerStep)
{
auto const batchIdx = static_cast<SizeType>(blockIdx.x);
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const tokensPerStep = tokensPerStepData[batchSlot];
auto const maxDraftTokens = maxTokensPerStep - 1;
for (auto index = static_cast<SizeType>(threadIdx.x); index < tokensPerStep - 1;
index += static_cast<SizeType>(blockDim.x))
{
auto const indexInTree = treeIds[batchSlot * maxDraftTokens + index];
auto const treeDraftIdx = batchSlot * maxDraftTokens + index;
auto const sourceDraftIdx = batchSlot * maxTokensPerStep + indexInTree;
treeDraftIds[treeDraftIdx] = sourceDraftIds[sourceDraftIdx];
}
}
void scatterMedusaDraftTokens(TokenIdType* treeDraftIds, TokenIdType const* sourceDraftIds, SizeType const* treeIds,
SizeType const* tokensPerStep, SizeType const* batchSlots, SizeType maxDraftTokens, SizeType batchSize,
cudaStream_t stream)
{
constexpr SizeType BLOCK_SIZE = 256;
scatterMedusaDraftTokens<<<batchSize, BLOCK_SIZE, 0, stream>>>(
treeDraftIds, sourceDraftIds, treeIds, tokensPerStep, batchSlots, maxDraftTokens);
}
template <int32_t BLOCK_SIZE>
__global__ void packAcceptedPaths(SizeType* acceptedLengthsCumSum, SizeType* pathsOffsets,
SizeType const* acceptedLengths, SizeType const* bestPathIds, SizeType const* paths, SizeType const* batchSlots,
SizeType batchSize, SizeType maxTokensPerStep, SizeType maxNumDraftTokens)
{
// Specialize BlockScan for a 1D block of BLOCK_SIZE threads operating on SizeType
typedef cub::BlockScan<SizeType, BLOCK_SIZE> BlockScan;
// Allocate shared memory for BlockScan
__shared__ typename BlockScan::TempStorage tempStorage;
auto const batchSizeRounded = ((batchSize + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
__shared__ SizeType currentCumSum;
if (threadIdx.x == 0)
{
currentCumSum = 0;
}
__syncthreads();
for (SizeType bi = static_cast<SizeType>(threadIdx.x); bi < batchSizeRounded;
bi += static_cast<SizeType>(blockDim.x))
{
auto const valid = bi < batchSize;
auto const batchSlot = valid ? batchSlots[bi] : 0;
auto const acceptedLen = valid ? acceptedLengths[batchSlot] - 1 : 0;
SizeType cumSum;
BlockScan(tempStorage).ExclusiveSum(acceptedLen + currentCumSum, cumSum);
if (threadIdx.x == blockDim.x - 1)
{
currentCumSum = cumSum;
}
__syncthreads();
if (valid)
{
acceptedLengthsCumSum[bi] = cumSum;
auto const bestPathIdx = bestPathIds[batchSlot];
auto const pathIdx = flat_index3(batchSlot, bestPathIdx, 0, maxTokensPerStep, maxNumDraftTokens);
for (SizeType ti = 0; ti < acceptedLen; ++ti)
{
pathsOffsets[cumSum + ti] = paths[pathIdx + ti + 1] - 1;
}
}
}
if (threadIdx.x == 0)
{
acceptedLengthsCumSum[batchSize] = currentCumSum;
}
}
void invokePackAcceptedPaths(SizeType* acceptedLengthsCumSum, SizeType* pathsOffsets, SizeType const* acceptedLengths,
SizeType const* bestPathIds, SizeType const* paths, SizeType const* batchSlots, SizeType batchSize,
SizeType maxTokensPerStep, SizeType maxNumDraftTokens, cudaStream_t stream)
{
constexpr SizeType BLOCK_SIZE = 1024;
packAcceptedPaths<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(acceptedLengthsCumSum, pathsOffsets, acceptedLengths,
bestPathIds, paths, batchSlots, batchSize, maxTokensPerStep, maxNumDraftTokens);
}
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -161,11 +161,9 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti
//! accepted tokens. Fills logitsPtrs tensor with the pointers to the respective medusa logits tensor according
//! to the next after the last accepted token.
//!
//! \param outputIds input/output buffer [maxBatchSize, maxDraftSeqLen],
//! input tokens followed by draft tokens to be verified.
//! After accepting tokens, gets overwritten such that input tokens are followed by the accepted tokens
//! and one additional token -- next after the last accepted.
//! \param targetIds input buffer [maxBatchSize, maxTargetSeqLen], tokens predicted from the target medusa head
//! \param outputIds output buffer [maxBatchSize, maxSeqLen], input tokens.
//! \param draftIds input buffer [maxBatchSize, maxDraftTokens], draft tokens
//! \param targetIds input buffer [maxBatchSize, maxDraftTokens], tokens predicted from the target medusa head
//! \param sequenceLengths input/output buffer [maxBatchSize], length of the data in outputIds without draft tokens
//! Incremented according to the accepted length
//! \param acceptedLengths output buffer [maxBatchSize], number of accepted tokens per sequence
@ -179,20 +177,65 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti
//! to the logits from medusa heads
//! \param logitsPtrs output buffer [batchSize, maxNumHeads], contains pointers to the
//! respective rows of the medusaLogits for the next after the accepted token
//! \param curTokensPerStep current number of tokens to compute per step; will be updated to
//! targetTokensPerStep if curTokensPerStep == 1
//! \param targetTokensPerStep target values of tokens to compute per step
//! \param bestPathIds output buffer [maxBatchSize], indices of the selected paths
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param vocabSize vocab size
//! \param maxDraftSeqLen maximum sequence length of the sequence containing draft tokens
//! \param maxTargetSeqLen maximum sequence length predicted from target head
//! \param maxDraftTokens maximum sequence length of the sequence containing draft tokens
//! \param maxSeqLen maximum sequence length of output ids
//! \param maxNumHeads maximum number of medusa heads
//! \param maxTokensPerStep maximum number of tokens per step configured in the system
//! \param stream stream
template <typename T>
void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* targetIds,
runtime::SizeType* sequenceLengths, runtime::SizeType* acceptedLengths, FinishedState* finishedFinal,
runtime::SizeType const* batchSlots, runtime::SizeType const* paths, runtime::TokenIdType const* endIds,
T const* medusaLogits, T const** logitsPtrs, runtime::SizeType batchSize, runtime::SizeType maxBatchSize,
runtime::SizeType vocabSize, runtime::SizeType maxDraftSeqLen, runtime::SizeType maxTargetSeqLen,
runtime::SizeType maxNumHeads, runtime::SizeType maxTokensPerStep, cudaStream_t stream);
void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* draftIds,
runtime::TokenIdType const* targetIds, runtime::SizeType* sequenceLengths, runtime::SizeType* acceptedLengths,
FinishedState* finishedFinal, runtime::SizeType const* batchSlots, runtime::SizeType const* paths,
runtime::TokenIdType const* endIds, T const** medusaLogits, T const** logitsPtrs,
runtime::SizeType* curTokensPerStep, runtime::SizeType const* targetTokensPerStep, runtime::SizeType* bestPathIds,
runtime::SizeType batchSize, runtime::SizeType maxBatchSize, runtime::SizeType vocabSize,
runtime::SizeType maxDraftTokens, runtime::SizeType maxSeqLen, runtime::SizeType maxNumHeads,
runtime::SizeType maxTokensPerStep, cudaStream_t stream);
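As a simplified illustration of the acceptance rule only (ignoring paths, end ids and the Medusa head bookkeeping handled by the full kernel), a greedy draft-vs-target comparison looks like the sketch below; it is not a replacement for the function above.
#include <cstddef>
#include <vector>
// Accept draft tokens greedily while they match the tokens predicted by the target model.
int countAcceptedDraftTokens(std::vector<int> const& draftIds, std::vector<int> const& targetIds)
{
    std::size_t accepted = 0;
    while (accepted < draftIds.size() && accepted < targetIds.size() && draftIds[accepted] == targetIds[accepted])
    {
        ++accepted;
    }
    return static_cast<int>(accepted);
}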
//! \brief assembles draft tokens to treeDraftIds from sourceDraftIds using indices of treeIds
//!
//! \param treeDraftIds output buffer [maxBatchSize, maxDraftTokens], output draft tokens
//! scattered from sourceDraftIds according to treeIds
//! \param sourceDraftIds input buffer [maxBatchSize, maxDraftTokens], draft tokens saved linearly after
//! sampling from Medusa heads with TopK.
//! \param treeIds input buffer [maxBatchSize, maxDraftTokens], address map from sourceDraftIds to treeDraftIds
//! [0, uniqueDraftTokens] -> [0, maxDraftTokens], where uniqueDraftTokens = sum(MedusaHeadsTopK)
//! uniqueDraftTokens <= maxDraftTokens
//! \param tokensPerStep input buffer [maxBatchSize], number of output draft tokens
//! \param batchSlots input buffer [maxBatchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param maxDraftTokens maximum number of tokens per step configured in the system
//! \param batchSize current batch size
//! \param stream cuda stream
void scatterMedusaDraftTokens(runtime::TokenIdType* treeDraftIds, runtime::TokenIdType const* sourceDraftIds,
runtime::SizeType const* treeIds, runtime::SizeType const* tokensPerStep, runtime::SizeType const* batchSlots,
runtime::SizeType maxDraftTokens, runtime::SizeType batchSize, cudaStream_t stream);
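A hedged host-side reference of the scatter described above: position i of the tree layout takes the token at index treeIds[i] of the linear layout, for the first tokensPerStep - 1 draft positions of each sequence.
#include <cstddef>
#include <vector>
void scatterMedusaDraftTokensReference(std::vector<std::vector<int>>& treeDraftIds, // [batch][maxDraftTokens]
    std::vector<std::vector<int>> const& sourceDraftIds,                            // [batch][>= maxDraftTokens]
    std::vector<std::vector<int>> const& treeIds,                                   // [batch][maxDraftTokens]
    std::vector<int> const& tokensPerStep)                                          // [batch]
{
    for (std::size_t b = 0; b < treeDraftIds.size(); ++b)
    {
        for (int i = 0; i + 1 < tokensPerStep[b]; ++i)
        {
            treeDraftIds[b][i] = sourceDraftIds[b][treeIds[b][i]];
        }
    }
}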
//! \brief Linearly packs accepted paths in memory according to the acceptedLengths and bestPathIds
//!
//! \param acceptedLengthsCumSum output buffer [maxBatchSize + 1], exclusive sum of accepted lengths
//! (indexed linearly in memory).
//! \param pathsOffsets output buffer [maxBatchSize * maxDraftLen], slices of accepted paths packed in memory
//! \param acceptedLengths input buffer [maxBatchSize], number of accepted tokens per sequence
//! \param bestPathIds input buffer [maxBatchSize], indices of the selected paths
//! \param paths input buffer [maxBatchSize, maxTokensPerStep, maxNumHeads+1],
//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not path.
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param batchSize current batch size
//! \param maxTokensPerStep maximum number of tokens per step configured in the system
//! \param maxNumDraftTokens maximum number of draft tokens per path
//! \param stream stream
void invokePackAcceptedPaths(runtime::SizeType* acceptedLengthsCumSum, runtime::SizeType* pathsOffsets,
runtime::SizeType const* acceptedLengths, runtime::SizeType const* bestPathIds, runtime::SizeType const* paths,
runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType maxTokensPerStep,
runtime::SizeType maxNumDraftTokens, cudaStream_t stream);
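A hedged host-side reference of the packing described above, assuming paths is indexed as [batch][path][token] and that the first accepted token of each path carries no draft offset; it mirrors the cumulative-sum logic of the kernel but is not a drop-in replacement.
#include <cstddef>
#include <vector>
void packAcceptedPathsReference(std::vector<int> const& acceptedLengths, std::vector<int> const& bestPathIds,
    std::vector<std::vector<std::vector<int>>> const& paths, // [batch][path][token]
    std::vector<int>& acceptedLengthsCumSum, std::vector<int>& pathsOffsets)
{
    int running = 0;
    acceptedLengthsCumSum.assign(acceptedLengths.size() + 1, 0);
    pathsOffsets.clear();
    for (std::size_t bi = 0; bi < acceptedLengths.size(); ++bi)
    {
        acceptedLengthsCumSum[bi] = running;
        int const len = acceptedLengths[bi] - 1; // the first accepted token needs no offset
        for (int ti = 0; ti < len; ++ti)
        {
            pathsOffsets.push_back(paths[bi][bestPathIds[bi]][ti + 1] - 1);
        }
        running += len;
    }
    acceptedLengthsCumSum[acceptedLengths.size()] = running;
}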
} // namespace kernels
} // namespace tensorrt_llm

File diff suppressed because it is too large

View File

@ -0,0 +1,64 @@
/*
* Adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
* Copyright (c) 2023, Tri Dao.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Not a contribution
* Changes made by NVIDIA CORPORATION & AFFILIATES or otherwise documented as
* NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: LicenseRef-NvidiaProprietary
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm
{
namespace kernels
{
struct MambaConv1dParamsBase
{
int batch, dim, max_seqlen, dconv;
bool remove_padding;
void* __restrict__ in_ptr;
void* state_in_ptr;
void* state_out_ptr;
void* __restrict__ weight_ptr;
void* __restrict__ bias_ptr;
void* __restrict__ out_ptr;
int const* __restrict__ last_token_ids_ptr;
int const* __restrict__ state_slot_mapping_ptr;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename input_t>
void invokeMambaConv1dContext(MambaConv1dParamsBase& params, cudaStream_t stream);
template <typename input_t>
void invokeMambaConv1dGeneration(MambaConv1dParamsBase& params, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm
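For context, a caller could wire the new parameter struct up roughly as below. This is a sketch, not the plugin's actual code: the half instantiation, the padded-input layout, and the helper signature are assumptions, and the kernel header's include path (not visible in this diff) is left as a placeholder.

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>
// ... include the Mamba conv1d kernel header shown above (path not visible in this diff) ...

void runMambaConv1dContext(half* in, half* out, half* weight, half* bias, half* stateIn, half* stateOut,
    int const* lastTokenIds, int const* stateSlotMapping, int batch, int dim, int maxSeqlen, int dconv,
    cudaStream_t stream)
{
    tensorrt_llm::kernels::MambaConv1dParamsBase params{};
    params.batch = batch;
    params.dim = dim;
    params.max_seqlen = maxSeqlen;
    params.dconv = dconv;
    params.remove_padding = false;                     // padded [batch, max_seqlen, dim] activations in this sketch
    params.in_ptr = in;
    params.out_ptr = out;
    params.weight_ptr = weight;
    params.bias_ptr = bias;
    params.state_in_ptr = stateIn;
    params.state_out_ptr = stateOut;
    params.last_token_ids_ptr = lastTokenIds;          // per-sequence valid token counts (assumed)
    params.state_slot_mapping_ptr = stateSlotMapping;  // may be nullptr for an identity mapping (assumed)
    tensorrt_llm::kernels::invokeMambaConv1dContext<half>(params, stream); // assumes a half instantiation exists
}
```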

View File

@ -16,6 +16,7 @@
#pragma once
#include "tensorrt_llm/common/assert.h"
#include <limits.h>
#include <stdint.h>
@ -36,6 +37,25 @@ enum Data_type
DATA_TYPE_E5M2
};
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline size_t get_size_in_bytes(size_t n, Data_type dtype)
{
switch (dtype)
{
case DATA_TYPE_FP32: return n * 4;
case DATA_TYPE_FP16: return n * 2;
case DATA_TYPE_INT32: return n * 4;
case DATA_TYPE_INT8: return n;
case DATA_TYPE_BF16: return n * 2;
case DATA_TYPE_E4M3: return n;
case DATA_TYPE_E5M2: return n;
default: TLLM_CHECK_WITH_INFO(false, "FMHA Data Type is not supported."); return 0;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr int32_t kSM_70 = 70;
constexpr int32_t kSM_72 = 72;
constexpr int32_t kSM_75 = 75;

View File

@ -41,6 +41,8 @@ static int const SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256;
#define TOPK_FP16_STORAGE 0
#pragma nv_diag_suppress static_var_with_dynamic_init
template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(
int const* __restrict topk_id, T const* __restrict topk_val, BeamHypotheses bh, int const candidate_size)

View File

@ -31,7 +31,7 @@ template <typename KVCacheBuffer, int MaxLayerCount, typename MoveEltType>
__global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheBuffer, MaxLayerCount> kvCacheBuffers,
int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
int32_t const* pastKeyValueLengths, int rewindDraftTokenCommonCount, int const* rewindDraftTokenSeparateAdjustments,
int eltCountPerHead)
int const* seqSlotRemapping, int eltCountPerHead)
{
int seqIdx = blockIdx.x;
int headIdx = blockIdx.y;
@ -41,16 +41,17 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int laneIdx = threadIdx.x & 0x1f;
int seqDraftTokenStart = seqAcceptedDraftTokenOffsets[seqIdx];
int seqDraftTokenEnd = seqAcceptedDraftTokenOffsets[seqIdx + 1];
auto const seqSlot = seqSlotRemapping == nullptr ? seqIdx : seqSlotRemapping[seqIdx];
int seqDraftCount = seqDraftTokenEnd - seqDraftTokenStart;
if (seqDraftCount == 0)
{
return;
}
KVCacheBuffer& kvCacheBuffer = kvCacheBuffers[layerIdx];
int tokenStartIdx = pastKeyValueLengths[seqIdx] - rewindDraftTokenCommonCount;
int tokenStartIdx = pastKeyValueLengths[seqSlot] - rewindDraftTokenCommonCount;
if (rewindDraftTokenSeparateAdjustments != nullptr)
{
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[seqIdx];
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[seqSlot];
}
int maxEltCountPerMove = kUpdateKVCacheKernelShmSize / sizeof(MoveEltType) / seqDraftCount;
int eltCountPerMove = min(maxEltCountPerMove, eltCountPerHead);
@ -65,7 +66,7 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqIdx, tokenKVPosition));
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
@ -80,7 +81,7 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int tokenPos = tokenIdx;
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqIdx, tokenKVPosition));
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
@ -95,7 +96,7 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqIdx, tokenKVPosition));
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
@ -110,7 +111,7 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int tokenPos = tokenIdx;
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenPos * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqIdx, tokenKVPosition));
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
@ -126,7 +127,8 @@ template <typename KVCacheBuffer, int MaxLayerCount>
void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
int32_t const* pastKeyValueLengths, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, cudaStream_t stream)
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
cudaStream_t stream)
{
// make sure the kernel launch parameter buffer is large enough
static_assert(MaxLayerCount * sizeof(KVCacheBuffer) <= 3072);
@ -148,8 +150,8 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
{
kvCacheBufferArray[i] = kvCacheBuffers[i];
}
void (*pKernelFunc)(
std::array<KVCacheBuffer, MaxLayerCount>, int const*, IndexType const*, int32_t const*, int, int const*, int)
void (*pKernelFunc)(std::array<KVCacheBuffer, MaxLayerCount>, int const*, IndexType const*, int32_t const*, int,
int const*, int const*, int)
= nullptr;
switch (alignedBytes)
{
@ -182,7 +184,7 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
}
pKernelFunc<<<grid, block, 0, stream>>>(kvCacheBufferArray, seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, eltCountPerHead);
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, eltCountPerHead);
TLLM_CUDA_CHECK(cudaGetLastError());
}
@ -207,7 +209,7 @@ template <typename KVCacheBuffer>
void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers, int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int layerCount, int seqCount,
int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments,
cudaStream_t stream)
int const* seqSlotRemapping, cudaStream_t stream)
{
int startLayer = 0;
static constexpr int kMaxLayersPerIter = 32;
@ -217,7 +219,7 @@ void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers, int co
updateKVCacheDraftTokenLocationBatched<KVCacheBuffer, kMaxLayersPerIter>(kvCacheBuffers + startLayer,
seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices, pastKeyValueLengths, microBatchLayerCount,
seqCount, numKVHeads, sizeInBytesPerKVHead, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, stream);
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
startLayer += microBatchLayerCount;
}
}
@ -225,7 +227,8 @@ void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers, int co
void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream)
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, cudaStream_t stream)
{
std::vector<KVLinearBuffer> kvLinearBuffers;
kvLinearBuffers.reserve(layerCount);
@ -237,14 +240,14 @@ void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffse
}
updateKVCacheDraftTokenLocation(kvLinearBuffers.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
}
void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount,
int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
cudaStream_t stream)
int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq,
int tokensPerBlock, cudaStream_t stream)
{
std::vector<KVBlockArray> kvBlockArrays;
kvBlockArrays.reserve(layerCount);
@ -256,47 +259,47 @@ void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffset
}
updateKVCacheDraftTokenLocation(kvBlockArrays.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
}
void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream)
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, maxKVCacheLen, stream);
rewindDraftTokenCount, nullptr, seqSlotRemapping, maxKVCacheLen, stream);
}
void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
rewindDraftTokenCount, nullptr, seqSlotRemapping, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}
void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream)
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, maxKVCacheLen, stream);
rewindDraftTokenCounts, seqSlotRemapping, maxKVCacheLen, stream);
}
void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
rewindDraftTokenCounts, seqSlotRemapping, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}
} // namespace tensorrt_llm::kernels::parallel_decoding

View File

@ -39,13 +39,17 @@ using IndexType = int;
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCount : Count to rewind
* @param seqSlotRemapping mapping from batch index to index of the seqSlot in the sorted seqSlot buffer
* e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0]
* Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
* and pointerArray and pastKeyValueLengths from runtimeBuffers.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream);
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream);
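One way to build seqSlotRemapping on the host, reproducing the documented example (seqSlots {5, 3, 4} -> {1, 2, 0}), is an argsort over the seqSlots. This is only an illustrative sketch; the runtime assembles this buffer itself.

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Sketch only: reproduces the documented example, seqSlots {5, 3, 4} -> remapping {1, 2, 0}.
std::vector<int> buildSeqSlotRemapping(std::vector<int> const& seqSlots)
{
    std::vector<int> remapping(seqSlots.size());
    std::iota(remapping.begin(), remapping.end(), 0);
    // Order batch indices by their seqSlot value (an argsort of seqSlots).
    std::sort(remapping.begin(), remapping.end(),
        [&seqSlots](int a, int b) { return seqSlots[a] < seqSlots[b]; });
    return remapping;
}
```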
/*!
* Update Block KV cache using common rewind count.
@ -59,6 +63,10 @@ void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDra
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCount : Count to rewind
* @param seqSlotRemapping mapping from batch index to index of the seqSlot in the sorted seqSlot buffer
* e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0]
* Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
* and pointerArray and pastKeyValueLengths from runtimeBuffers.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
@ -67,7 +75,7 @@ void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDra
void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
/*!
* Update Linear KV cache using separate rewind count for each sequence.
@ -82,13 +90,17 @@ void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraf
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCounts : Pointer to an array of length seqCount, each element indicates the rewind count of
* one sequence.
* @param seqSlotRemapping mapping from batch index to index of the seqSlot in the sorted seqSlot buffer
* e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0]
* Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
* and pointerArray and pastKeyValueLengths from runtimeBuffers.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream);
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream);
/*!
* Update Block KV cache using separate rewind count for each sequence.
@ -103,6 +115,10 @@ void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedD
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCounts : Pointer to an array of length seqCount, each element indicates the rewind count of
* one sequence.
* @param seqSlotRemapping mapping from batch index to index of the seqSlot in the sorted seqSlot buffer
* e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0]
* Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
* and pointerArray and pastKeyValueLengths from runtimeBuffers.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
@ -111,7 +127,7 @@ void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedD
void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
/*!
* Update Linear KV cache using both common rewind and separate rewind count for each sequence. The common
@ -129,13 +145,18 @@ void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDr
* @param rewindDraftTokenCommonCount : Common token count to rewind
* @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount, each element indicates the
* rewind adjustment for one sequence.
* @param seqSlotRemapping mapping from batch index to index of the seqSlot in the sorted seqSlot buffer
* e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0]
* Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
* and pointerArray and pastKeyValueLengths from runtimeBuffers.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream);
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, cudaStream_t stream);
/*!
* Update Block KV cache using both common rewind and separate rewind count for each sequence. The common
@ -153,6 +174,10 @@ void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffse
* @param rewindDraftTokenCommonCount : Common token count to rewind
* @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount, each element indicates the
* rewind adjustment for one sequence.
* @param seqSlotRemapping mapping from batch index to index of the seqSlot in the sorted seqSlot buffer
* e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0]
* Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
* and pointerArray and pastKeyValueLengths from runtimeBuffers.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
@ -161,7 +186,7 @@ void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffse
void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount,
int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
cudaStream_t stream);
int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq,
int tokensPerBlock, cudaStream_t stream);
} // namespace tensorrt_llm::kernels::parallel_decoding

View File

@ -150,7 +150,6 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm
{
return;
}
if (tokensPerStep != nullptr && tokenIdx >= tokensPerStep[batchSlot])
{
return;

View File

@ -65,6 +65,8 @@ __device__ void convertAndStore(__nv_bfloat16* output, float input)
}
#endif
#pragma nv_diag_suppress static_var_with_dynamic_init
template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128, int STAGES = 12,
int SEQ_UNROLL = 6>
__launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBase params)
@ -74,8 +76,8 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
input_t* x = reinterpret_cast<input_t*>(params.u_ptr);
input_t* dt = reinterpret_cast<input_t*>(params.delta_ptr);
weight_t* A = reinterpret_cast<weight_t*>(params.A_ptr);
input_t* B = reinterpret_cast<input_t*>(params.B_ptr);
input_t* C = reinterpret_cast<input_t*>(params.C_ptr);
input_t* B = reinterpret_cast<input_t*>(params.BC_ptr);
input_t* C = reinterpret_cast<input_t*>(params.BC_ptr);
weight_t* D = reinterpret_cast<weight_t*>(params.D_ptr);
input_t* z = reinterpret_cast<input_t*>(params.z_ptr);
weight_t* dt_bias = reinterpret_cast<weight_t*>(params.delta_bias_ptr);
@ -101,11 +103,24 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
int const channel = blockIdx.x * blockDim.x + threadIdx.x;
int const sample = blockIdx.y; // batch id
int const slot_idx = params.slot_mapping_ptr == nullptr ? sample : params.slot_mapping_ptr[sample];
int const bc_cols = DSTATE * 2 + params.dt_rank;
int const b_offset = params.dt_rank;
int const c_offset = params.dt_rank + DSTATE;
int num_tokens;
int start_token_idx;
start_token_idx = sample * params.seqlen;
num_tokens = params.last_token_ids_ptr[sample];
if (params.remove_padding)
{
start_token_idx = sample == 0 ? 0 : params.last_token_ids_ptr[sample - 1];
int end_token_idx = params.last_token_ids_ptr[sample];
num_tokens = end_token_idx - start_token_idx;
}
else
{
start_token_idx = sample * params.max_seqlen;
num_tokens = params.last_token_ids_ptr[sample];
}
int const seq_loops = (num_tokens + SEQ_UNROLL - 1) / SEQ_UNROLL;
int const input_matrix_row_id = start_token_idx;
@ -132,8 +147,8 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
for (int token_id = si * SEQ_UNROLL; token_id < num_tokens && token_id < (si + 1) * SEQ_UNROLL; token_id++)
{
input_t* my_B = &B[input_matrix_row_id * DSTATE + token_id * DSTATE];
input_t* my_C = &C[input_matrix_row_id * DSTATE + token_id * DSTATE];
input_t* my_B = &B[(input_matrix_row_id + token_id) * bc_cols + b_offset];
input_t* my_C = &C[(input_matrix_row_id + token_id) * bc_cols + c_offset];
int block_channel_per_token = blockIdx.x * blockDim.x;
int block_channel
@ -304,7 +319,7 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
// Write the new state back out to the cache
for (int i = 0; i < DSTATE; i++)
{
input_t* my_state = &state[sample * num_channels * DSTATE];
input_t* my_state = &state[slot_idx * num_channels * DSTATE];
int offset = i * num_channels + channel;
convertAndStore(&my_state[offset], state_reg[i]);
}
@ -351,8 +366,8 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
input_t* x = reinterpret_cast<input_t*>(params.u_ptr);
input_t* dt = reinterpret_cast<input_t*>(params.delta_ptr);
weight_t* A = reinterpret_cast<weight_t*>(params.A_ptr);
input_t* B = reinterpret_cast<input_t*>(params.B_ptr);
input_t* C = reinterpret_cast<input_t*>(params.C_ptr);
input_t* B = reinterpret_cast<input_t*>(params.BC_ptr);
input_t* C = reinterpret_cast<input_t*>(params.BC_ptr);
weight_t* D = reinterpret_cast<weight_t*>(params.D_ptr);
input_t* z = reinterpret_cast<input_t*>(params.z_ptr);
weight_t* dt_bias = reinterpret_cast<weight_t*>(params.delta_bias_ptr);
@ -363,8 +378,12 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
if (channel >= num_channels)
return;
int const sample = blockIdx.y;
int const slot_idx = params.slot_mapping_ptr == nullptr ? sample : params.slot_mapping_ptr[sample];
int const bc_cols = DSTATE * 2 + params.dt_rank;
int const b_offset = params.dt_rank;
int const c_offset = params.dt_rank + DSTATE;
input_t* my_state = &state[sample * num_channels * DSTATE];
input_t* my_state = &state[slot_idx * num_channels * DSTATE];
input_t* my_output = &output[sample * num_channels];
float rA[DSTATE];
@ -377,8 +396,8 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
for (int i = 0; i < DSTATE; i++)
{
rA[i] = toFloat(A[i * num_channels + channel]);
rB[i] = toFloat(B[sample * DSTATE + i]);
rC[i] = toFloat(C[sample * DSTATE + i]);
rB[i] = toFloat(B[sample * bc_cols + b_offset + i]);
rC[i] = toFloat(C[sample * bc_cols + c_offset + i]);
rState[i] = toFloat(my_state[i * num_channels + channel]);
}

View File

@ -40,9 +40,9 @@ namespace kernels
struct SSMParamsBase
{
using index_t = uint32_t;
int batch, dim, seqlen, dstate;
int batch, dim, dstate, dt_rank;
int max_seqlen; // only valid for padded input.
bool remove_padding;
bool is_variable_B;
bool is_variable_C;
@ -50,8 +50,7 @@ struct SSMParamsBase
// Common data pointers.
void* __restrict__ A_ptr;
void* __restrict__ B_ptr;
void* __restrict__ C_ptr;
void* __restrict__ BC_ptr;
void* __restrict__ D_ptr;
void* __restrict__ u_ptr;
void* __restrict__ delta_ptr;
@ -60,6 +59,7 @@ struct SSMParamsBase
void* __restrict__ x_ptr;
void* __restrict__ z_ptr;
int const* __restrict__ last_token_ids_ptr;
int const* __restrict__ slot_mapping_ptr;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -107,7 +107,7 @@ void invokeStopWordsCriterion(int32_t const** outputIds, int32_t const** parentI
}
__global__ void lengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth)
int32_t* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth)
{
int32_t threadFinishedCount = 0;
auto const batchIdx = blockIdx.x;
@ -122,6 +122,7 @@ __global__ void lengthCriterion(FinishedState* finished, int32_t* finishedSum, u
if (sequenceLengths[batchSlotBeamWidthIdx] >= sequenceLimitLength[batchSlot])
{
finishState.setFinishedMaxLength();
sequenceLengths[batchSlotBeamWidthIdx] = sequenceLimitLength[batchSlot];
}
threadFinishedCount += finishState.isFinished() ? 1 : 0;
finished[batchSlotBeamWidthIdx] = finishState;
@ -148,8 +149,7 @@ __global__ void lengthCriterion(FinishedState* finished, int32_t* finishedSum, u
}
void invokeLengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth,
cudaStream_t stream)
int32_t* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth, cudaStream_t stream)
{
// Check if we have attained the sequence length limit. If so, stop the
// sequence. In addition, check if all sequences are stopped and return the

View File

@ -54,14 +54,13 @@ void invokeStopWordsCriterion(int32_t const** outputIds, int32_t const** parentI
//! \param finishedSum output buffer [1].
//! Total sum of finished requests
//! \param sequenceLimitLength input buffer [maxBatchSize]. Maximum sequence length.
//! \param sequenceLengths input buffer [maxBatchSize, beamWidth].
//! \param sequenceLengths input/output buffer [maxBatchSize, beamWidth].
//! Current sequence lengths of the request tokens.
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
//! \param batchSize batch size
//! \param beamWidth beam width
//! \param stream stream
void invokeLengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth,
cudaStream_t stream);
int32_t* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm
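A host-side sketch of what the updated criterion does per beam may help explain why sequenceLengths became an input/output buffer. FinishedState is reduced to a plain flag here, so this is an illustration only, not the kernel.

```cpp
#include <cstdint>

// Sketch only: per beam, finishing by max length now also clamps the recorded sequence length.
void lengthCriterionRef(bool* finishedMaxLength, int32_t* sequenceLengths, uint32_t const* sequenceLimitLength,
    int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth)
{
    for (int32_t bi = 0; bi < batchSize; ++bi)
    {
        int32_t const slot = batchSlots == nullptr ? bi : batchSlots[bi];
        for (int32_t beam = 0; beam < beamWidth; ++beam)
        {
            int32_t const idx = slot * beamWidth + beam;
            if (sequenceLengths[idx] >= static_cast<int32_t>(sequenceLimitLength[slot]))
            {
                finishedMaxLength[idx] = true;                                           // setFinishedMaxLength()
                sequenceLengths[idx] = static_cast<int32_t>(sequenceLimitLength[slot]);  // new: clamp to the limit
            }
        }
    }
}
```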

View File

@ -108,15 +108,15 @@ void invokeTranspose4dBatchMajor(T const* k_src, T const* v_src, KVCacheBuffer&
// NOTE: this kernel operates in-place and modifies QKV; if other kernels need the original QKV, copy it or use it before this call.
template <typename T, typename KVCacheBuffer, bool IsGenerate = false>
void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens,
int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
void invokeApplyBiasRopeUpdateKVCache(T* QKV, void* O, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias,
int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num,
int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base,
const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale,
int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type,
int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale, int const int8_mode,
const KvCacheDataType cache_type, float const* kvScaleOrigQuant, bool const enable_paged_kv_fmha,
int const beam_width, int2& grid_block_cache, cudaStream_t stream);
bool const quantized_fp8_output, int const beam_width, int2& grid_block_cache, cudaStream_t stream);
template <typename T, typename BT>
void invokeAddRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, int const batch_size,

View File

@ -46,7 +46,9 @@ template <typename T, int Dh_MAX>
struct Rotary_vec_t
{
using Type = T;
using Packed_type = T;
// Quantized output type only supports fp8 currently.
using Packed_type = __nv_fp8_e4m3;
using Quantized_type = void;
static constexpr int size = 1;
};
@ -56,6 +58,7 @@ template <>
struct Rotary_vec_t<float, 32>
{
using Type = float;
using Quantized_type = __nv_fp8_e4m3;
using Packed_type = float;
static constexpr int size = 1;
};
@ -64,6 +67,7 @@ template <>
struct Rotary_vec_t<float, 64>
{
using Type = float2;
using Quantized_type = mmha::fp8_2_t;
using Packed_type = float2;
static constexpr int size = 2;
};
@ -72,6 +76,7 @@ template <>
struct Rotary_vec_t<float, 128>
{
using Type = float4;
using Quantized_type = mmha::fp8_4_t;
using Packed_type = float2;
static constexpr int size = 4;
};
@ -80,6 +85,7 @@ template <>
struct Rotary_vec_t<float, 256>
{
using Type = mmha::Float8_;
using Quantized_type = mmha::fp8_8_t;
using Packed_type = float2;
static constexpr int size = 8;
};
@ -90,6 +96,7 @@ template <>
struct Rotary_vec_t<half, 32>
{
using Type = uint16_t;
using Quantized_type = __nv_fp8_e4m3;
using Packed_type = uint16_t;
static constexpr int size = 1;
};
@ -98,6 +105,7 @@ template <>
struct Rotary_vec_t<half, 64>
{
using Type = uint32_t;
using Quantized_type = mmha::fp8_2_t;
using Packed_type = uint32_t;
static constexpr int size = 2;
};
@ -106,6 +114,7 @@ template <>
struct Rotary_vec_t<half, 128>
{
using Type = uint2;
using Quantized_type = mmha::fp8_4_t;
using Packed_type = uint32_t;
static constexpr int size = 4;
};
@ -114,6 +123,7 @@ template <>
struct Rotary_vec_t<half, 256>
{
using Type = uint4;
using Quantized_type = mmha::fp8_8_t;
using Packed_type = uint32_t;
static constexpr int size = 8;
};
@ -126,6 +136,7 @@ template <>
struct Rotary_vec_t<__nv_bfloat16, 32>
{
using Type = __nv_bfloat16;
using Quantized_type = __nv_fp8_e4m3;
using Packed_type = __nv_bfloat16;
static constexpr int size = 1;
};
@ -134,6 +145,7 @@ template <>
struct Rotary_vec_t<__nv_bfloat16, 64>
{
using Type = __nv_bfloat162;
using Quantized_type = mmha::fp8_2_t;
using Packed_type = __nv_bfloat162;
static constexpr int size = 2;
};
@ -142,6 +154,7 @@ template <>
struct Rotary_vec_t<__nv_bfloat16, 128>
{
using Type = mmha::bf16_4_t;
using Quantized_type = mmha::fp8_4_t;
using Packed_type = __nv_bfloat162;
static constexpr int size = 4;
};
@ -150,6 +163,7 @@ template <>
struct Rotary_vec_t<__nv_bfloat16, 256>
{
using Type = mmha::bf16_8_t;
using Quantized_type = mmha::fp8_8_t;
using Packed_type = __nv_bfloat162;
static constexpr int size = 8;
};
@ -158,13 +172,14 @@ struct Rotary_vec_t<__nv_bfloat16, 256>
template <typename T, typename T_cache, int Dh_MAX, bool ADD_BIAS, bool STORE_QKV, bool POS_SHIFT,
typename KVCacheBuffer, bool IS_GENERATE>
__global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBuffer, T const* __restrict qkv_bias,
int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, float const* kvScaleOrigQuant,
int const num_tokens, int const batch_size, int const seq_len, int const cyclic_kv_cache_len,
int const sink_token_len, int const head_num, int const kv_head_num, int const qheads_per_kv_head,
int const size_per_head, int const rotary_embedding_dim, float rotary_embedding_base,
__global__ void applyBiasRopeUpdateKVCache(T* QKV, void* O, T* Q, KVCacheBuffer kvCacheBuffer,
T const* __restrict qkv_bias, int const* seq_lens, int const* kv_seq_lens, int const* padding_offset,
float const* kvScaleOrigQuant, int const num_tokens, int const batch_size, int const seq_len,
int const cyclic_kv_cache_len, int const sink_token_len, int const head_num, int const kv_head_num,
int const qheads_per_kv_head, int const size_per_head, int const rotary_embedding_dim, float rotary_embedding_base,
RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions,
PositionEmbeddingType const position_embedding_type, int const* medusa_position_offsets, int const beam_width)
PositionEmbeddingType const position_embedding_type, int const* medusa_position_offsets,
bool const quantized_fp8_output, int const beam_width)
{
// This kernel adds bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head]
// Extract the Q input when using paged KV FMHA.
@ -190,6 +205,9 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
// VEC_SIZE is power of 2.
constexpr int VEC_SIZE = Rotary_vec_t<T, Dh_MAX>::size;
using Vec_type = typename Rotary_vec_t<T, Dh_MAX>::Type;
// Quantized output only supports fp8 currently.
using Quantized_elt_type = __nv_fp8_e4m3;
using Quantized_type = typename Rotary_vec_t<T, Dh_MAX>::Quantized_type;
using Packed_type = typename Rotary_vec_t<T, Dh_MAX>::Packed_type;
bool const has_padding = padding_offset == nullptr;
@ -348,7 +366,16 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
if constexpr (STORE_QKV)
{
*reinterpret_cast<Vec_type*>(&QKV[src_q_idx]) = q;
if (quantized_fp8_output)
{
// use 1.0f scale currently for qkv input of FP8 FMHA.
mmha::convert_to_fp8(
reinterpret_cast<Quantized_type*>(reinterpret_cast<Quantized_elt_type*>(O) + src_q_idx), q);
}
else
{
*reinterpret_cast<Vec_type*>(&QKV[src_q_idx]) = q;
}
}
else
{
@ -358,10 +385,21 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
{
if constexpr (STORE_QKV)
{
*reinterpret_cast<Vec_type*>(&QKV[src_k_idx]) = k;
if constexpr (ADD_BIAS)
if (quantized_fp8_output)
{
*reinterpret_cast<Vec_type*>(&QKV[src_v_idx]) = v;
// use 1.0f scale currently for qkv input of FP8 FMHA.
mmha::convert_to_fp8(
reinterpret_cast<Quantized_type*>(reinterpret_cast<Quantized_elt_type*>(O) + src_k_idx), k);
mmha::convert_to_fp8(
reinterpret_cast<Quantized_type*>(reinterpret_cast<Quantized_elt_type*>(O) + src_v_idx), v);
}
else
{
*reinterpret_cast<Vec_type*>(&QKV[src_k_idx]) = k;
if constexpr (ADD_BIAS)
{
*reinterpret_cast<Vec_type*>(&QKV[src_v_idx]) = v;
}
}
}
@ -405,22 +443,22 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
= std::min((grid_size + head_num - 1) / head_num, (token_num + tokens_per_block - 1) / tokens_per_block); \
dim3 grid(blocks_per_sequence, head_num); \
applyBiasRopeUpdateKVCache<T, T_cache, Dh_MAX, ADD_BIAS, STORE_QKV, POS_SHIFT, KVCacheBuffer, IS_GENERATE> \
<<<grid, block, 0, stream>>>(QKV, Q, kvTable, qkv_bias, seq_lens, kv_seq_lens, padding_offset, \
<<<grid, block, 0, stream>>>(QKV, O, Q, kvTable, qkv_bias, seq_lens, kv_seq_lens, padding_offset, \
kvScaleOrigQuant, token_num, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, head_num, \
kv_head_num, head_num / kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, \
rotary_scale_type, updated_rotary_embedding_scale, rotary_embedding_max_positions, \
position_embedding_type, medusa_position_offsets, beam_width);
position_embedding_type, medusa_position_offsets, quantized_fp8_output, beam_width);
template <int Dh_MAX, typename T, typename T_cache, typename KVCacheBuffer, bool IS_GENERATE>
void kernelDispatchHeadSize(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens,
void kernelDispatchHeadSize(T* QKV, void* O, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens,
int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num,
int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base,
const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale,
int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type,
int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale,
float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha, int const beam_width,
int2& grid_block_cache, cudaStream_t stream)
float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha,
bool const quantized_fp8_output, int const beam_width, int2& grid_block_cache, cudaStream_t stream)
{
bool const add_bias = qkv_bias != nullptr;
bool const store_contiguous_qkv = !enable_paged_kv_fmha;
@ -482,15 +520,15 @@ void kernelDispatchHeadSize(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_b
}
template <typename T, typename T_cache, typename KVCacheBuffer, bool IS_GENERATE>
void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias,
void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, void* O, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias,
int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num,
int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base,
const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale,
int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type,
int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale,
float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha, int const beam_width,
int2& grid_block_cache, cudaStream_t stream)
float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha,
bool const quantized_fp8_output, int const beam_width, int2& grid_block_cache, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with RoPE"); // TODO
if constexpr (!IS_GENERATE)
@ -509,30 +547,30 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, T* Q, KVCacheBuffer& kvTab
// GPTJ Rotary embedding needs at least two elements per thread.
if (size_per_head <= 64)
{
kernelDispatchHeadSize<64, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias, seq_lens,
kernelDispatchHeadSize<64, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias, seq_lens,
kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num, head_num,
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
grid_block_cache, stream);
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
beam_width, grid_block_cache, stream);
}
else if (size_per_head <= 128)
{
kernelDispatchHeadSize<128, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias, seq_lens,
kernelDispatchHeadSize<128, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias, seq_lens,
kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num, head_num,
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
grid_block_cache, stream);
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
beam_width, grid_block_cache, stream);
}
else if (size_per_head <= 256)
{
kernelDispatchHeadSize<256, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias, seq_lens,
kernelDispatchHeadSize<256, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias, seq_lens,
kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num, head_num,
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
grid_block_cache, stream);
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
beam_width, grid_block_cache, stream);
}
else
{
@ -541,15 +579,15 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, T* Q, KVCacheBuffer& kvTab
}
template <typename T, typename KVCacheBuffer, bool IS_GENERATE>
void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens,
int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
void invokeApplyBiasRopeUpdateKVCache(T* QKV, void* O, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias,
int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num,
int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base,
const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale,
int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type,
int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale, int const int8_mode,
const KvCacheDataType cache_type, float const* kvScaleOrigQuant, bool const enable_paged_kv_fmha,
int const beam_width, int2& grid_block_cache, cudaStream_t stream)
bool const quantized_fp8_output, int const beam_width, int2& grid_block_cache, cudaStream_t stream)
{
// Block handles both K and V tile.
constexpr int x = (sizeof(T) == 4) ? 4 : 8;
@ -557,38 +595,38 @@ void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T co
if (cache_type == KvCacheDataType::INT8)
{
invokeApplyBiasRopeUpdateKVCacheDispatch<T, int8_t, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias,
invokeApplyBiasRopeUpdateKVCacheDispatch<T, int8_t, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias,
seq_lens, kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num,
head_num, kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
grid_block_cache, stream);
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
beam_width, grid_block_cache, stream);
}
#ifdef ENABLE_FP8
else if (cache_type == KvCacheDataType::FP8)
{
invokeApplyBiasRopeUpdateKVCacheDispatch<T, __nv_fp8_e4m3, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable,
invokeApplyBiasRopeUpdateKVCacheDispatch<T, __nv_fp8_e4m3, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable,
qkv_bias, seq_lens, kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len,
token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base,
rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type,
medusa_position_offsets, position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha,
beam_width, grid_block_cache, stream);
quantized_fp8_output, beam_width, grid_block_cache, stream);
}
#endif // ENABLE_FP8
else
{
invokeApplyBiasRopeUpdateKVCacheDispatch<T, T, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias, seq_lens,
kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num, head_num,
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
invokeApplyBiasRopeUpdateKVCacheDispatch<T, T, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias,
seq_lens, kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num,
head_num, kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
grid_block_cache, stream);
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
beam_width, grid_block_cache, stream);
}
}
#define INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(T, KVCacheBuffer, IS_GENERATE) \
template void invokeApplyBiasRopeUpdateKVCache<T, KVCacheBuffer, IS_GENERATE>(T * QKV, T * Q, \
KVCacheBuffer & kvTable, const T* qkv_bias, const int* seq_lens, const int* kv_seq_lens, \
template void invokeApplyBiasRopeUpdateKVCache<T, KVCacheBuffer, IS_GENERATE>(T * QKV, void* O, T* Q, \
KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens, const int* kv_seq_lens, \
const int* padding_offset, const int batch_size, const int seq_len, const int cyclic_kv_cache_len, \
const int sink_token_len, const int token_num, const int head_num, const int kv_head_num, \
const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, \
@ -596,7 +634,8 @@ void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T co
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, \
const int* medusa_position_offsets, const bool position_shift_enabled, const float* scale, \
const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, \
const bool enable_paged_kv_fmha, const int beam_width, int2& grid_block_cache, cudaStream_t stream)
const bool enable_paged_kv_fmha, bool const quantized_fp8_output, const int beam_width, \
int2& grid_block_cache, cudaStream_t stream)
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,172 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/int8SQ.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace smooth_quant
{
template <typename Type, int CtaM, int CtaN, int Threads, bool PerChannel, bool PerToken>
__global__ void int8_sq(int8_t const* act, int8_t const* weight, float const* scale_channels, float const* scale_tokens,
Type* output, int m, int n, int k)
{
using VecType = int4;
static constexpr int kStepK = 128 / (8 * sizeof(int8_t));
static constexpr int CtaK = kStepK * Threads;
int tile_id_m = blockIdx.x * CtaM;
int tile_id_n = blockIdx.y * CtaN;
int tid = threadIdx.x;
int8_t tile_a[kStepK], tile_w[CtaN * kStepK];
int acc[CtaM * CtaN];
#pragma unroll
for (int i = 0; i < CtaM * CtaN; ++i)
{
acc[i] = 0;
}
act += tile_id_m * k;
weight += tile_id_n * k;
output += tile_id_m * n + tile_id_n;
for (int idx_k = tid * kStepK; idx_k < k; idx_k += CtaK)
{
#pragma unroll
for (int i = 0; i < CtaN; ++i)
{
reinterpret_cast<VecType*>(tile_w)[i] = reinterpret_cast<VecType const*>(weight + i * k + idx_k)[0];
}
#pragma unroll
for (int i = 0; i < CtaM; ++i)
{
reinterpret_cast<VecType*>(tile_a)[0] = reinterpret_cast<VecType const*>(act + i * k + idx_k)[0];
#pragma unroll
for (int j = 0; j < CtaN; ++j)
{
#pragma unroll
for (int l = 0; l < kStepK; l += 4)
{
acc[i * CtaN + j] = __dp4a(reinterpret_cast<int*>(tile_a + l)[0],
reinterpret_cast<int*>(tile_w + j * kStepK + l)[0], acc[i * CtaN + j]);
}
}
}
}
static constexpr int kWarpSize = 32;
static constexpr int kWarpNum = Threads / kWarpSize;
__shared__ int shmem[CtaM * CtaN * kWarpNum];
int warp_id = tid / kWarpSize, lane_id = tid % kWarpSize;
#pragma unroll
for (int i = 0; i < CtaM; ++i)
{
#pragma unroll
for (int j = 0; j < CtaN; ++j)
{
int val = acc[i * CtaN + j];
val += __shfl_xor_sync(~0, val, 16);
val += __shfl_xor_sync(~0, val, 8);
val += __shfl_xor_sync(~0, val, 4);
val += __shfl_xor_sync(~0, val, 2);
val += __shfl_xor_sync(~0, val, 1);
if (lane_id == 0)
{
shmem[i * CtaN + j + warp_id * CtaM * CtaN] = val;
}
}
}
__syncthreads();
#pragma unroll
for (int ii = tid; ii < CtaM * CtaN; ii += Threads)
{
int mid = ii / CtaN, nid = ii % CtaN;
float scale_channel, scale_token;
if constexpr (PerChannel)
{
scale_channel = scale_channels[tile_id_n + nid];
}
else
{
scale_channel = scale_channels[0];
}
if constexpr (PerToken)
{
scale_token = scale_tokens[tile_id_m + mid];
}
else
{
scale_token = scale_tokens[0];
}
int val = 0;
#pragma unroll
for (int jj = 0; jj < kWarpNum; ++jj)
{
val += shmem[jj * CtaM * CtaN + ii];
}
output[mid * n + nid] = static_cast<Type>(static_cast<float>(val) * scale_channel * scale_token);
}
}
template <typename Type, int CtaM, int CtaN, int Threads, bool PerChannel, bool PerToken>
void int8_sq_kernel(Params& params, cudaStream_t s)
{
dim3 block(Threads);
dim3 grid(params.m / CtaM, params.n / CtaN);
int8_sq<Type, CtaM, CtaN, Threads, PerChannel, PerToken><<<grid, block, 0, s>>>(params.act, params.weight,
params.scale_channels, params.scale_tokens, reinterpret_cast<Type*>(params.output), params.m, params.n,
params.k);
}
template <typename Type, bool PerChannel, bool PerToken>
void algo_tactic_dispatcher(Params& params, cudaStream_t s)
{
#define DISPATCH(TargetM, CtaM, CtaN, Threads) \
if (params.m == TargetM) \
{ \
int8_sq_kernel<Type, CtaM, CtaN, Threads, PerChannel, PerToken>(params, s); \
return; \
}
DISPATCH(1, 1, 2, 128);
DISPATCH(2, 2, 2, 128);
DISPATCH(3, 3, 2, 128);
DISPATCH(4, 4, 2, 128);
#undef DISPATCH
}
template <typename Type>
void int8_sq_launcher(Params& params, cudaStream_t s)
{
#define DISPATCH(PerChannel, PerToken) \
if (per_channel == PerChannel && per_token == PerToken) \
{ \
algo_tactic_dispatcher<Type, PerChannel, PerToken>(params, s); \
return; \
}
bool per_channel = params.quant_mode.hasPerChannelScaling();
bool per_token = params.quant_mode.hasPerTokenScaling();
DISPATCH(false, false);
DISPATCH(false, true);
DISPATCH(true, false);
DISPATCH(true, true);
#undef DISPATCH
}
template void int8_sq_launcher<float>(Params& params, cudaStream_t s);
template void int8_sq_launcher<half>(Params& params, cudaStream_t s);
template void int8_sq_launcher<int>(Params& params, cudaStream_t s);
} // namespace smooth_quant
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,63 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/quantization.h"
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <iostream>
namespace tensorrt_llm
{
namespace kernels
{
namespace smooth_quant
{
struct Params
{
int8_t const* act;
int8_t const* weight;
float const* scale_tokens;
float const* scale_channels;
void* output;
int m, n, k;
tensorrt_llm::common::QuantMode quant_mode;
Params(int8_t const* _act, int8_t const* _weight, float const* _scale_tokens, float const* _scale_channels,
void* _output, int _m, int _n, int _k, tensorrt_llm::common::QuantMode _quant_mode)
: act(_act)
, weight(_weight)
, scale_tokens(_scale_tokens)
, scale_channels(_scale_channels)
, output(_output)
, m(_m)
, n(_n)
, k(_k)
, quant_mode(_quant_mode)
{
}
};
template <typename>
void int8_sq_launcher(Params& params, cudaStream_t s);
} // namespace smooth_quant
} // namespace kernels
} // namespace tensorrt_llm
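
For reference, a hypothetical host-side usage sketch of the Params struct and int8_sq_launcher declared above. The include name is an assumption (whichever header holds these declarations), and the QuantMode value is taken as a parameter because its factory functions are defined elsewhere in tensorrt_llm/common/quantization.h.

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "smooth_quant.h" // hypothetical: whatever header contains the Params/int8_sq_launcher declarations above

using tensorrt_llm::kernels::smooth_quant::Params;
using tensorrt_llm::kernels::smooth_quant::int8_sq_launcher;

// Hedged usage sketch: the caller is assumed to have allocated and filled the
// device buffers (int8 activations/weights plus float scales) already.
void runInt8SqHalf(int8_t const* act, int8_t const* weight, float const* scaleTokens, float const* scaleChannels,
    half* output, int m, int n, int k, tensorrt_llm::common::QuantMode quantMode, cudaStream_t stream)
{
    // Pack the raw pointers and problem sizes in the order the constructor expects.
    Params params(act, weight, scaleTokens, scaleChannels, output, m, n, k, quantMode);
    // The dispatcher shown in the .cu file only covers m in [1, 4]; larger m is
    // presumably routed to a different GEMM path by the caller.
    int8_sq_launcher<half>(params, stream);
}
```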

View File

@ -41,12 +41,12 @@ public:
class SetupParams : public DecodingSetupParams
{
public:
std::optional<std::vector<runtime::SizeType>> runtime_top_k; // [1] or [batchSize] on cpu
std::optional<std::vector<float>> runtime_top_p; // [1] or [batchSize] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
std::optional<std::vector<float>> top_p_decay; // [batchSize], must between [0, 1]
std::optional<std::vector<float>> top_p_min; // [batchSize], must between [0, 1]
std::optional<std::vector<std::int32_t>> top_p_reset_ids; // [batchSize]
std::optional<std::vector<runtime::SizeType>> runtime_top_k; // [1] or [batchSize] on cpu
std::optional<std::vector<float>> runtime_top_p; // [1] or [batchSize] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
std::optional<std::vector<float>> top_p_decay; // [batchSize], must between [0, 1]
std::optional<std::vector<float>> top_p_min; // [batchSize], must between [0, 1]
std::optional<std::vector<runtime::TokenIdType>> top_p_reset_ids; // [batchSize]
std::optional<bool> normalize_log_probs;
};

View File

@ -78,8 +78,10 @@ public:
tc::Tensor output_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len]
// Medusa params
std::optional<tc::Tensor> nextDraftTokens; // [batch_size, max_draft_tokens_per_step]
std::optional<tc::Tensor> acceptedLengths; // [batch_size]
std::optional<tc::Tensor> nextDraftTokens; // [batch_size, max_draft_tokens_per_step]
std::optional<tc::Tensor> acceptedLengths; // [batch_size]
std::optional<tc::Tensor> acceptedLengthsCumSum; // [batch_size + 1]
std::optional<tc::Tensor> medusaPathsOffsets; // [batch_size * max_medusa_heads]
};
} // namespace tensorrt_llm::layers

View File

@ -292,7 +292,6 @@ void DynamicDecodeLayer<T>::setupLayers(
typename MedusaDecodingLayer<T>::MedusaSetupParams medusaSetupParams;
medusaSetupParams.runtimeTopK = setupParams.runtime_top_k;
medusaSetupParams.runtimeHeadsTopK = setupParams.topKMedusaHeads;
medusaSetupParams.tokensPerStep = setupParams.tokensPerStep;
medusaSetupParams.randomSeed = setupParams.randomSeed;
mMedusaDecodingLayer->setup(batchSize, batchSlots, medusaSetupParams);
}
@ -555,14 +554,19 @@ void DynamicDecodeLayer<T>::layersForward(Tensor& logits, OutputParams& outputs,
typename MedusaDecodingLayer<T>::MedusaForwardParams medusaInputParams(logits, endIds);
medusaInputParams.finished = outputs.finished.value();
medusaInputParams.batch_slots = params.batch_slots;
medusaInputParams.paths = params.paths.value();
medusaInputParams.medusaLogits = params.medusaLogits.value();
medusaInputParams.paths = params.medusaInputs->medusaPaths;
medusaInputParams.medusaLogits = params.medusaInputs->medusaLogits;
medusaInputParams.medusaCurTokensPerStep = params.medusaInputs->medusaCurTokensPerStep;
medusaInputParams.medusaTargetTokensPerStep = params.medusaInputs->medusaTargetTokensPerStep;
medusaInputParams.treeIds = params.medusaInputs->medusaTreeIds;
DecodingOutputParams medusaOutputParams(outputs.output_ids);
medusaOutputParams.sequence_length = outputs.sequence_length.value();
medusaOutputParams.finished = outputs.finished.value();
medusaOutputParams.nextDraftTokens = outputs.nextDraftTokens.value();
medusaOutputParams.acceptedLengths = outputs.acceptedLengths.value();
medusaOutputParams.nextDraftTokens = outputs.medusaOutputs->nextDraftTokens;
medusaOutputParams.acceptedLengths = outputs.medusaOutputs->acceptedLengths;
medusaOutputParams.acceptedLengthsCumSum = outputs.medusaOutputs->medusaAcceptedLengthsCumSum;
medusaOutputParams.medusaPathsOffsets = outputs.medusaOutputs->medusaPathsOffsets;
mMedusaDecodingLayer->forward(medusaOutputParams, medusaInputParams);
}
@ -614,7 +618,8 @@ void DynamicDecodeLayer<T>::applyPenalties(OutputParams& outputs, ForwardParams
#undef GET_PENALTIES
auto const tokensPerStep = params.tokensPerStep ? params.tokensPerStep->template getPtr<SizeType const>() : nullptr;
auto const tokensPerStep
= params.medusaInputs ? params.medusaInputs->medusaCurTokensPerStep.template getPtr<SizeType const>() : nullptr;
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T const* const*>(logitsPtrsHostData),
mRuntimeLogitsDevice, embeddingBias, mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice, temperatures,
repetitionPenalties, presencePenalties, frequencyPenalties,
@ -788,8 +793,9 @@ void DynamicDecodeLayer<T>::prepareOutputData(OutputParams& outputs, ForwardPara
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto idsPtrHostSlice = ITensor::slice(idsPtrsHost, cyclicStep, 1);
auto idsPtrHost = reinterpret_cast<TokenIdType**>(runtime::bufferCast<int64_t>(*idsPtrHostSlice));
auto const numNewTokens = outputs.acceptedLengths ? outputs.acceptedLengths->template getPtr<SizeType const>()
: static_cast<SizeType*>(nullptr);
auto const numNewTokens = outputs.medusaOutputs
? outputs.medusaOutputs->acceptedLengths.template getPtr<SizeType const>()
: static_cast<SizeType*>(nullptr);
invokeCopyNextStepIds(outputs.newTokens.template getPtr<TokenIdType>(), idsPtrHost,
outputs.sequence_length->template getPtr<SizeType>(), numNewTokens, batchSlots, batchSize, maxBatchSize,
beamWidth, maxSeqLen, maxTokensPerStep, stream);

View File

@ -84,7 +84,6 @@ public:
// Medusa params
std::optional<std::vector<std::vector<runtime::SizeType>>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
std::optional<std::vector<runtime::SizeType>> tokensPerStep; // [batchSize]
};
void setup(runtime::SizeType batch_size, runtime::SizeType beam_width, int const* batch_slots,
@ -139,10 +138,18 @@ public:
std::optional<tc::Tensor> batch_slots; // [batch_size], in pinned memory
// Medusa inputs
std::optional<tc::Tensor> tokensPerStep; // [batch_size], optional, on gpu
std::optional<tc::Tensor> paths; // [batch_size, max_tokens_per_step, max_num_heads + 1], optional, on gpu
std::optional<tc::Tensor>
medusaLogits; // [max_num_heads, batch_size, max_tokens_per_step, vocab_size], optional, on gpu
class MedusaInputs
{
public:
tc::Tensor medusaCurTokensPerStep; // [batch_size], optional, on gpu
tc::Tensor medusaTargetTokensPerStep; // [batch_size], optional, on gpu
tc::Tensor medusaPaths; // [batch_size, max_tokens_per_step, max_num_heads + 1], optional, on gpu
tc::Tensor medusaTreeIds; // [batch_size, max_tokens_per_step], optional, on gpu
std::vector<std::vector<tc::Tensor>>
medusaLogits; // [max_batch_size][max_num_heads][tokens_per_step, vocab_size], optional, on gpu
};
std::optional<MedusaInputs> medusaInputs;
};
class OutputParams
@ -173,9 +180,16 @@ public:
tc::Tensor parent_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len]
// Medusa outputs
std::optional<tc::Tensor>
nextDraftTokens; // [batch_size, max_tokens_per_step], draft tokens predicted by Medusa heads
std::optional<tc::Tensor> acceptedLengths; // [batch_size], lengths of the accepted draft tokens + 1
class MedusaOutputs
{
public:
tc::Tensor nextDraftTokens; // [batch_size, max_tokens_per_step], draft tokens predicted by Medusa heads
tc::Tensor acceptedLengths; // [batch_size], lengths of the accepted draft tokens + 1
tc::Tensor medusaAcceptedLengthsCumSum; // [batch_size + 1]
tc::Tensor medusaPathsOffsets; // [batch_size * max_medusa_heads]
};
std::optional<MedusaOutputs> medusaOutputs;
};
void forward(OutputParams& outputs, ForwardParams const& params);
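
The refactor above replaces several independent std::optional tensors with a single optional MedusaInputs/MedusaOutputs group, and call sites then read either a raw pointer or nullptr from the group (as in prepareOutputData in the preceding .cpp diff). A minimal, self-contained sketch of that pattern, with std::vector<int> standing in for tc::Tensor:

```cpp
#include <cstdio>
#include <optional>
#include <vector>

// std::vector<int> stands in for tc::Tensor here; the point is grouping related
// optional fields behind one std::optional and the pointer-or-null access.
struct MedusaOutputsSketch
{
    std::vector<int> nextDraftTokens;
    std::vector<int> acceptedLengths;
};

struct OutputParamsSketch
{
    std::optional<MedusaOutputsSketch> medusaOutputs;
};

int main()
{
    OutputParamsSketch outputs;
    // Without Medusa, the whole group is absent and downstream code sees nullptr.
    int const* numNewTokens
        = outputs.medusaOutputs ? outputs.medusaOutputs->acceptedLengths.data() : nullptr;
    std::printf("no medusa: %p\n", static_cast<void const*>(numNewTokens));

    // With Medusa, a single assignment makes the whole group available at once.
    outputs.medusaOutputs = MedusaOutputsSketch{{1, 2, 3}, {2}};
    numNewTokens = outputs.medusaOutputs->acceptedLengths.data();
    std::printf("with medusa: %d\n", numNewTokens[0]);
    return 0;
}
```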

View File

@ -22,6 +22,7 @@
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <algorithm>
@ -68,7 +69,8 @@ void MedusaDecodingLayer<T>::allocateBuffer()
// Get sampling workspace size
{
auto samplingSizePrimarySampling = getTopKWorkspaceSize<T>(mMaxBatchSize, 1, TOP_K_MAX, mVocabSizePadded);
auto samplingSizePrimarySampling
= getTopKWorkspaceSize<T>(mMaxBatchSize, mMaxTokensPerStep, TOP_K_MAX, mVocabSizePadded);
auto const maxBatchSizeHeadNums = mMaxBatchSize * mMaxNumHeads;
auto samplingSizeMedusaHeadsSampling
@ -82,35 +84,40 @@ void MedusaDecodingLayer<T>::allocateBuffer()
runtime::TRTDataType<TokenIdType*>::value);
mCummulativeTopK.resize(mMaxBatchSize * mMaxNumHeads);
std::array<size_t, 10> deviceBufferSizes;
std::array<size_t, 11> deviceBufferSizes;
deviceBufferSizes[0] = mMaxBatchSize * sizeof(curandState_t);
deviceBufferSizes[1] = mMaxBatchSize * sizeof(SizeType);
deviceBufferSizes[2] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
deviceBufferSizes[3] = mSamplingWorkspaceSize;
deviceBufferSizes[4] = mMaxBatchSize * sizeof(SizeType);
deviceBufferSizes[5] = mMaxBatchSize * mMaxTokensPerStep * sizeof(TokenIdType);
deviceBufferSizes[6] = mMaxBatchSize * mMaxNumHeads * sizeof(uint64_t);
deviceBufferSizes[7] = mMaxBatchSize * mMaxNumHeads * sizeof(T*);
deviceBufferSizes[8] = mMaxBatchSize * mMaxNumHeads * sizeof(curandState_t);
deviceBufferSizes[9] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
deviceBufferSizes[1] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
deviceBufferSizes[2] = mSamplingWorkspaceSize;
deviceBufferSizes[3] = mMaxBatchSize * sizeof(SizeType);
deviceBufferSizes[4] = mMaxBatchSize * mMaxTokensPerStep * sizeof(TokenIdType);
deviceBufferSizes[5] = mMaxBatchSize * mMaxNumHeads * sizeof(uint64_t);
deviceBufferSizes[6] = mMaxBatchSize * mMaxNumHeads * sizeof(T*);
deviceBufferSizes[7] = mMaxBatchSize * mMaxNumHeads * sizeof(curandState_t);
deviceBufferSizes[8] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
deviceBufferSizes[9] = mMaxBatchSize * mMaxTokensPerStep * sizeof(TokenIdType);
deviceBufferSizes[10] = mMaxBatchSize * sizeof(SizeType);
mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
mTokensPerStepDevice = mAllocator->reMalloc(mTokensPerStepDevice, deviceBufferSizes[1], false);
mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[2], false);
mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[3], false);
mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[4], false);
mTargetTokensDevice = mAllocator->reMalloc(mTargetTokensDevice, deviceBufferSizes[5], false);
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[6], false);
mMedusaLogitsPtrsDevice = mAllocator->reMalloc(mMedusaLogitsPtrsDevice, deviceBufferSizes[7], false);
mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[1], false);
mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[2], false);
mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[3], false);
mTargetTokensDevice = mAllocator->reMalloc(mTargetTokensDevice, deviceBufferSizes[4], false);
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[5], false);
mMedusaSelectedLogitsPtrsDevice
= mAllocator->reMalloc(mMedusaSelectedLogitsPtrsDevice, deviceBufferSizes[6], false);
mCurandStatesMedusaLogitsDevice
= mAllocator->reMalloc(mCurandStatesMedusaLogitsDevice, deviceBufferSizes[8], false);
= mAllocator->reMalloc(mCurandStatesMedusaLogitsDevice, deviceBufferSizes[7], false);
mRuntimeTopKPerRequestPerMedusaHeadDevice
= mAllocator->reMalloc(mRuntimeTopKPerRequestPerMedusaHeadDevice, deviceBufferSizes[9], false);
= mAllocator->reMalloc(mRuntimeTopKPerRequestPerMedusaHeadDevice, deviceBufferSizes[8], false);
mNewDraftTokensDevice = mAllocator->reMalloc(mNewDraftTokensDevice, deviceBufferSizes[9], false);
mBestPathIdsDevice = mAllocator->reMalloc(mBestPathIdsDevice, deviceBufferSizes[10], false);
mTiledBatchSlotsSetup = BufferManager::pinnedPool(
ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize * mMaxNumHeads)}), nvinfer1::DataType::kINT32);
mTiledBatchSlotsForward = BufferManager::pinnedPool(
ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize * mMaxNumHeads)}), nvinfer1::DataType::kINT32);
mMedusaInputLogitsPtrs = BufferManager::pinnedPool(
ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize * mMaxNumHeads)}), TRTDataType<T*>::value);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -121,15 +128,16 @@ void MedusaDecodingLayer<T>::freeBuffer()
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
mAllocator->free((void**) (&mCurandStatesDevice));
mAllocator->free((void**) (&mTokensPerStepDevice));
mAllocator->free((void**) (&mSetupWorkspaceDevice));
mAllocator->free((void**) (&mSamplingWorkspaceDevice));
mAllocator->free((void**) (&mRuntimeTopKDevice));
mAllocator->free((void**) (&mTargetTokensDevice));
mAllocator->free((void**) (&mRandomSeedsDevice));
mAllocator->free((void**) (&mMedusaLogitsPtrsDevice));
mAllocator->free((void**) (&mMedusaSelectedLogitsPtrsDevice));
mAllocator->free((void**) (&mCurandStatesMedusaLogitsDevice));
mAllocator->free((void**) (&mRuntimeTopKPerRequestPerMedusaHeadDevice));
mAllocator->free((void**) (&mNewDraftTokensDevice));
mAllocator->free((void**) (&mBestPathIdsDevice));
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -191,17 +199,6 @@ void MedusaDecodingLayer<T>::setup(SizeType batchSize, SizeType const* batchSlot
}
initCurandStates({tiledRandomSeed}, batchSizeMaxNumHeads, tiledBatchSlots, mCurandStatesMedusaLogitsDevice);
// Prepare tokens per step
{
auto tokensPerStep = setupParams.tokensPerStep.value_or(std::vector<SizeType>{batchSize, mMaxTokensPerStep});
TLLM_CHECK_WITH_INFO(tokensPerStep.size() == batchSize,
fmtstr("tokensPerStep.size() (%lu) == batchSize (%d) is not satisfied!", tokensPerStep.size(), batchSize));
cudaAutoCpy(reinterpret_cast<SizeType*>(mSetupWorkspaceDevice), tokensPerStep.data(), batchSize, mStream);
invokeScatterDecodingParams(
reinterpret_cast<SizeType*>(mSetupWorkspaceDevice), mTokensPerStepDevice, batchSlots, batchSize, mStream);
}
// Prepare runtime top K
auto prepareRuntimeTopK = [this](std::vector<SizeType> const& runtimeTopK, SizeType batchSize,
SizeType const* batchSlots, SizeType* runtimeTopKDevice)
@ -221,7 +218,7 @@ void MedusaDecodingLayer<T>::setup(SizeType batchSize, SizeType const* batchSlot
auto constexpr defaultTopK = 1u;
{
auto runtimeTopK = setupParams.runtimeTopK.value_or(std::vector<SizeType>{batchSize, defaultTopK});
auto runtimeTopK = setupParams.runtimeTopK.value_or(std::vector<SizeType>(batchSize, defaultTopK));
auto const curMaxTopK = prepareRuntimeTopK(runtimeTopK, batchSize, batchSlots, mRuntimeTopKDevice);
mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, curMaxTopK);
}
@ -279,6 +276,10 @@ void MedusaDecodingLayer<T>::forward(DecodingOutputParams& outputs, MedusaForwar
sampleNewDraftTokens(outputs, inputs);
scatterNewDraftTokens(outputs, inputs);
packAcceptedPaths(outputs, inputs);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -290,10 +291,9 @@ void MedusaDecodingLayer<T>::samplePrimeHeadTokens(DecodingOutputParams& outputs
auto const batchSize = inputs.logits.shape[0];
auto logits = inputs.logits.template getPtr<T>();
auto batchSlots
= inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : static_cast<SizeType*>(nullptr);
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>()
: static_cast<SizeType*>(nullptr);
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : nullptr;
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>() : nullptr;
auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType>();
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
@ -302,11 +302,11 @@ void MedusaDecodingLayer<T>::samplePrimeHeadTokens(DecodingOutputParams& outputs
// Sequence length is not modified, endIds is not checked, outputLogProbs are not supported.
// Finished state is not set.
invokeBatchTopKSampling(mSamplingWorkspaceDevice, logits, /* logProbsPtrs */ static_cast<T const* const*>(nullptr),
/* outputIdsPtrs */ nullptr, mTargetTokensDevice, sequenceLengths,
/* outputIdsPtrs */ nullptr, mTargetTokensDevice, /* sequenceLengths */ nullptr,
/* finishedInput */ nullptr, /* finishedOutput */ nullptr,
/* cumLogProbs */ nullptr, /* outputLogProbs */ nullptr, mCurandStatesDevice, mRuntimeMaxTopK,
mRuntimeTopKDevice, 1.0f, /* runtimeTopPDevice */ nullptr, mVocabSizePadded, /* endIds */ nullptr, batchSlots,
mStream, batchSize, mMaxBatchSize, mTokensPerStepDevice, mMaxTokensPerStep, mMaxTokensPerStep,
mStream, batchSize, mMaxBatchSize, tokensPerStepDevice, mMaxTokensPerStep, mMaxTokensPerStep,
/* skipDecode */ nullptr, /* normalizeLogProbs */ false,
/* probsComputed */ false, /* return all Top-K*/ false);
@ -324,7 +324,6 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(DecodingOutputParams& outputs, Me
auto outputIds = outputs.output_ids.template getPtr<TokenIdType>();
auto endIds = inputs.end_ids.template getPtr<TokenIdType const>();
auto paths = inputs.paths.template getPtr<SizeType const>();
auto medusaLogits = inputs.medusaLogits.template getPtr<T const>();
auto batchSlots
= inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : static_cast<SizeType*>(nullptr);
@ -332,20 +331,42 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(DecodingOutputParams& outputs, Me
: static_cast<SizeType*>(nullptr);
auto acceptedLengths = outputs.acceptedLengths ? outputs.acceptedLengths->template getPtr<SizeType>()
: static_cast<SizeType*>(nullptr);
auto curTokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType>();
auto targetTokensPerStepDevice = inputs.medusaTargetTokensPerStep.template getPtr<SizeType>();
auto medusaInputLogitsPtrs = BufferRange<T*>(*mMedusaInputLogitsPtrs);
for (SizeType bi = 0; bi < batchSize; ++bi)
{
auto const slot = batchSlots[bi];
for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
{
medusaInputLogitsPtrs[slot * mMaxNumHeads + hi] = inputs.medusaLogits[slot][hi].template getPtr<T>();
}
}
auto draftIds = outputs.nextDraftTokens ? outputs.nextDraftTokens->template getPtr<TokenIdType>() : nullptr;
TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(
curTokensPerStepDevice != nullptr, "Current tokens per step must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(
targetTokensPerStepDevice != nullptr, "Target tokens per step must be provided for MedusaDecoding");
auto finishedStates
= reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>());
// Compare draft tokens from outputIds with sampled target tokens at mTargetTokensDevice using paths.
// Select the longest accepted path, modify outputIds in-place, increment sequenceLengths accordingly.
// Fill mMedusaLogitsPtrsDevice with respective Medusa logits
acceptDraftTokensByIdsWithPaths(outputIds, mTargetTokensDevice, sequenceLengths, acceptedLengths, finishedStates,
batchSlots, paths, endIds, medusaLogits, const_cast<T const**>(mMedusaLogitsPtrsDevice), batchSize, mVocabSize,
mMaxBatchSize, maxSeqLen, mMaxTokensPerStep, mMaxNumHeads, mMaxTokensPerStep, mStream);
// Fill mMedusaSelectedLogitsPtrsDevice with respective Medusa logits
acceptDraftTokensByIdsWithPaths(outputIds, draftIds, mTargetTokensDevice, sequenceLengths, acceptedLengths,
finishedStates, batchSlots, paths, endIds,
reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs)),
const_cast<T const**>(mMedusaSelectedLogitsPtrsDevice), curTokensPerStepDevice, targetTokensPerStepDevice,
mBestPathIdsDevice, batchSize, mVocabSize, mMaxBatchSize, mMaxTokensPerStep, maxSeqLen, mMaxNumHeads,
mMaxTokensPerStep, mStream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -358,8 +379,7 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(DecodingOutputParams& outputs,
auto const batchSize = inputs.logits.shape[0];
auto batchSlots
= inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : static_cast<SizeType*>(nullptr);
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>()
: static_cast<SizeType*>(nullptr);
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>() : nullptr;
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
@ -378,9 +398,6 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(DecodingOutputParams& outputs,
}
auto draftIdsPtrs = reinterpret_cast<TokenIdType**>(bufferCast<int64_t>(*mDraftIdsPtrHost));
auto draftIds = (outputs.nextDraftTokens) ? outputs.nextDraftTokens->template getPtr<TokenIdType>()
: static_cast<TokenIdType*>(nullptr);
TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
for (SizeType bi = 0; bi < batchSize; ++bi)
{
@ -388,12 +405,13 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(DecodingOutputParams& outputs,
for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
{
draftIdsPtrs[slot * mMaxNumHeads + hi]
= draftIds + slot * mMaxTokensPerStep + mCummulativeTopK[slot * mMaxNumHeads + hi];
= mNewDraftTokensDevice + slot * mMaxTokensPerStep + mCummulativeTopK[slot * mMaxNumHeads + hi];
}
}
invokeBatchTopKSampling(mSamplingWorkspaceDevice,
/* logits */ static_cast<T const*>(nullptr), const_cast<T const* const*>(mMedusaLogitsPtrsDevice), draftIdsPtrs,
/* logits */ static_cast<T const*>(nullptr), const_cast<T const* const*>(mMedusaSelectedLogitsPtrsDevice),
draftIdsPtrs,
/* outputIds */ nullptr, /* sequenceLength */ nullptr,
/* finishedInput */ nullptr, /* finishedOutput */ nullptr,
/* cumLogProbs */ nullptr, /* outputLogProbs */ nullptr, mCurandStatesMedusaLogitsDevice,
@ -408,6 +426,54 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(DecodingOutputParams& outputs,
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void MedusaDecodingLayer<T>::scatterNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
auto batchSlots
= inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : static_cast<SizeType*>(nullptr);
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
auto draftIds = outputs.nextDraftTokens ? outputs.nextDraftTokens->template getPtr<TokenIdType>() : nullptr;
auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType>();
auto treeIds = inputs.treeIds.template getPtr<SizeType>();
TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(tokensPerStepDevice != nullptr, "Tokens per step must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(treeIds != nullptr, "Tree ids must be provided for MedusaDecoding");
scatterMedusaDraftTokens(draftIds, mNewDraftTokensDevice, treeIds, tokensPerStepDevice, batchSlots,
mMaxTokensPerStep, batchSize, mStream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void MedusaDecodingLayer<T>::packAcceptedPaths(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
auto paths = inputs.paths.template getPtr<SizeType const>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : nullptr;
auto acceptedLengths = outputs.acceptedLengths ? outputs.acceptedLengths->template getPtr<SizeType>() : nullptr;
auto acceptedLengthsCumSum
= outputs.acceptedLengthsCumSum ? outputs.acceptedLengthsCumSum->template getPtr<SizeType>() : nullptr;
auto medusaPathsOffsets
= outputs.medusaPathsOffsets ? outputs.medusaPathsOffsets->template getPtr<SizeType>() : nullptr;
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(acceptedLengthsCumSum != nullptr, "acceptedLengthsCumSum must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(medusaPathsOffsets != nullptr, "medusaPathsOffsets must be provided for MedusaDecoding");
invokePackAcceptedPaths(acceptedLengthsCumSum, medusaPathsOffsets, acceptedLengths, mBestPathIdsDevice, paths,
batchSlots, batchSize, mMaxTokensPerStep, mMaxNumHeads + 1, mStream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template class MedusaDecodingLayer<float>;
template class MedusaDecodingLayer<half>;
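
One easy-to-miss fix in the setup code above: std::vector<SizeType>{batchSize, defaultTopK} brace-initializes a two-element vector holding {batchSize, defaultTopK}, while std::vector<SizeType>(batchSize, defaultTopK) constructs batchSize elements all equal to defaultTopK, which is what the value_or default is meant to be. A self-contained illustration:

```cpp
#include <cstdio>
#include <vector>

int main()
{
    int batchSize = 8;
    unsigned defaultTopK = 1u;

    // Brace init: a 2-element vector holding {8, 1} (the form the diff replaces).
    std::vector<unsigned> braced{static_cast<unsigned>(batchSize), defaultTopK};

    // Paren init: 8 elements, each equal to 1 (the intended per-request default).
    std::vector<unsigned> filled(batchSize, defaultTopK);

    std::printf("braced.size()=%zu filled.size()=%zu\n", braced.size(), filled.size()); // 2 vs 8
    return 0;
}
```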

View File

@ -44,11 +44,10 @@ public:
class MedusaSetupParams : public DecodingSetupParams
{
public:
std::optional<std::vector<runtime::SizeType>> runtimeTopK; // [1] or [batchSize] on cpu
std::optional<std::vector<runtime::SizeType>> runtimeTopK; // [1] or [batchSize] on cpu
std::optional<std::vector<std::vector<runtime::SizeType>>>
runtimeHeadsTopK; // [batchSize, maxMedusaHeads] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
std::optional<std::vector<runtime::SizeType>> tokensPerStep; // [1] or [batchSize] on cpu
runtimeHeadsTopK; // [batchSize, maxMedusaHeads] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
};
class MedusaForwardParams : public DecodingParams
@ -59,8 +58,12 @@ public:
{
}
tc::Tensor paths; // [maxBatchSize, maxTokensPerStep, maxNumHeads + 1] on gpu
tc::Tensor medusaLogits; // [maxNumHeads, maxBatchSize, maxTokensPerStep, vocabSize] on gpu
tc::Tensor paths; // [maxBatchSize, maxTokensPerStep, maxNumHeads + 1] on gpu
std::vector<std::vector<tc::Tensor>>
medusaLogits; // [maxBatchSize][maxNumHeads][tokensPerStep, vocabSize] on gpu
tc::Tensor medusaCurTokensPerStep; // [maxBatchSize] on gpu
tc::Tensor medusaTargetTokensPerStep; // [maxBatchSize] on gpu
tc::Tensor treeIds; // [maxBatchSize, maxTokensPerStep] on gpu
};
MedusaDecodingLayer(runtime::SizeType maxBatchSize, runtime::SizeType vocabSize, runtime::SizeType vocabSizePadded,
@ -80,6 +83,8 @@ private:
void samplePrimeHeadTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs);
void acceptDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs);
void sampleNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs);
void scatterNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs);
void packAcceptedPaths(DecodingOutputParams& outputs, MedusaForwardParams& inputs);
private:
using Base::mStream;
@ -97,19 +102,21 @@ private:
runtime::SizeType mRuntimeMaxTopKPerRequestPerMedusaHead{0};
curandState_t* mCurandStatesDevice{nullptr};
runtime::SizeType* mTokensPerStepDevice{nullptr};
void* mSetupWorkspaceDevice{nullptr};
void* mSamplingWorkspaceDevice{nullptr};
runtime::SizeType* mRuntimeTopKDevice{nullptr};
runtime::TokenIdType* mTargetTokensDevice{nullptr};
uint64_t* mRandomSeedsDevice{nullptr};
T** mMedusaLogitsPtrsDevice{nullptr};
T** mMedusaSelectedLogitsPtrsDevice{nullptr};
curandState_t* mCurandStatesMedusaLogitsDevice{nullptr};
runtime::SizeType* mRuntimeTopKPerRequestPerMedusaHeadDevice{nullptr};
runtime::TokenIdType* mNewDraftTokensDevice{nullptr};
runtime::SizeType* mBestPathIdsDevice{nullptr};
runtime::ITensor::UniquePtr mTiledBatchSlotsSetup;
runtime::ITensor::UniquePtr mTiledBatchSlotsForward;
runtime::ITensor::UniquePtr mDraftIdsPtrHost;
runtime::ITensor::UniquePtr mMedusaInputLogitsPtrs;
std::vector<runtime::SizeType> mCummulativeTopK;
};
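
packAcceptedPaths, declared above, hands invokePackAcceptedPaths an acceptedLengthsCumSum buffer of shape [batch_size + 1]; conceptually that is an exclusive prefix sum over the per-request accepted lengths, so entry i is the write offset for request i and the final entry is the total. A host-side sketch of that layout (the device kernel itself is not part of this diff):

```cpp
#include <cstdio>
#include <vector>

// Host-side sketch of the [batch_size + 1] cumulative-sum layout used when
// packing accepted paths: cumSum[i] is the offset of request i, cumSum.back()
// is the total number of accepted tokens.
std::vector<int> exclusiveCumSum(std::vector<int> const& acceptedLengths)
{
    std::vector<int> cumSum(acceptedLengths.size() + 1, 0);
    for (size_t i = 0; i < acceptedLengths.size(); ++i)
    {
        cumSum[i + 1] = cumSum[i] + acceptedLengths[i];
    }
    return cumSum;
}

int main()
{
    std::vector<int> acceptedLengths{3, 1, 2}; // per-request accepted draft tokens + 1
    for (int v : exclusiveCumSum(acceptedLengths))
    {
        std::printf("%d ", v); // prints: 0 3 4 6
    }
    std::printf("\n");
    return 0;
}
```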

View File

@ -66,7 +66,7 @@ void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, So
bh.length_penalties = length_penalties_buf_;
bh.early_stoppings = early_stoppings_buf_;
bh.batch_size = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[0]);
bh.batch_size = static_cast<std::int32_t>(params.end_ids.shape[0]);
bh.beam_width = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[1]);
bh.ite = params.ite;
bh.local_batch_size = params.logits.shape[0];

View File

@ -45,7 +45,8 @@ set(PLUGIN_LISTS
lookupPlugin
loraPlugin
mixtureOfExperts
selectiveScanPlugin)
selectiveScanPlugin
mambaConv1dPlugin)
foreach(PLUGIN_ITER ${PLUGIN_LISTS})
include_directories(${PLUGIN_ITER})

View File

@ -26,6 +26,7 @@
#include "tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.h"
#include "tensorrt_llm/plugins/lookupPlugin/lookupPlugin.h"
#include "tensorrt_llm/plugins/loraPlugin/loraPlugin.h"
#include "tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.h"
#include "tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h"
#if ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.h"
@ -73,7 +74,7 @@ public:
GlobalLoggerFinder gGlobalLoggerFinder{};
#if !defined(_MSC_VER)
__attribute__((constructor))
[[maybe_unused]] __attribute__((constructor))
#endif
void initOnLoad()
{
@ -89,6 +90,40 @@ bool pluginsInitialized = false;
} // namespace
namespace tensorrt_llm::plugins::api
{
LoggerManager& tensorrt_llm::plugins::api::LoggerManager::getInstance() noexcept
{
static LoggerManager instance;
return instance;
}
void LoggerManager::setLoggerFinder(nvinfer1::ILoggerFinder* finder)
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder == nullptr && finder != nullptr)
{
mLoggerFinder = finder;
}
}
[[maybe_unused]] nvinfer1::ILogger* LoggerManager::logger()
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder != nullptr)
{
return mLoggerFinder->findLogger();
}
return nullptr;
}
nvinfer1::ILogger* LoggerManager::defaultLogger() noexcept
{
return gLogger;
}
} // namespace tensorrt_llm::plugins::api
// New Plugin APIs
extern "C"
@ -127,7 +162,7 @@ extern "C"
[[maybe_unused]] void setLoggerFinder([[maybe_unused]] nvinfer1::ILoggerFinder* finder)
{
tensorrt_llm::plugins::api::LoggerFinder::getInstance().setLoggerFinder(finder);
tensorrt_llm::plugins::api::LoggerManager::getInstance().setLoggerFinder(finder);
}
[[maybe_unused]] nvinfer1::IPluginCreator* const* getPluginCreators(std::int32_t& nbCreators)
@ -155,6 +190,7 @@ extern "C"
static tensorrt_llm::plugins::LookupPluginCreator lookupPluginCreator;
static tensorrt_llm::plugins::LoraPluginCreator loraPluginCreator;
static tensorrt_llm::plugins::SelectiveScanPluginCreator selectiveScanPluginCreator;
static tensorrt_llm::plugins::MambaConv1dPluginCreator mambaConv1DPluginCreator;
static std::array pluginCreators
= { creatorPtr(identityPluginCreator),
@ -179,37 +215,10 @@ extern "C"
creatorPtr(lookupPluginCreator),
creatorPtr(loraPluginCreator),
creatorPtr(selectiveScanPluginCreator),
creatorPtr(mambaConv1DPluginCreator),
};
nbCreators = pluginCreators.size();
return pluginCreators.data();
}
} // extern "C"
namespace tensorrt_llm::plugins::api
{
LoggerFinder& tensorrt_llm::plugins::api::LoggerFinder::getInstance() noexcept
{
static LoggerFinder instance;
return instance;
}
void LoggerFinder::setLoggerFinder(nvinfer1::ILoggerFinder* finder)
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder == nullptr && finder != nullptr)
{
mLoggerFinder = finder;
}
}
nvinfer1::ILogger* LoggerFinder::findLogger()
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder != nullptr)
{
return mLoggerFinder->findLogger();
}
return nullptr;
}
} // namespace tensorrt_llm::plugins::api
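
The LoggerManager code moved above uses a common shape: a Meyers singleton (getInstance returning a function-local static) plus a mutex-guarded, set-once setter for the externally supplied finder. A self-contained sketch of that shape, with a plain interface standing in for nvinfer1::ILoggerFinder:

```cpp
#include <cstdio>
#include <mutex>

// Stand-in for nvinfer1::ILoggerFinder in this sketch.
struct FinderLike
{
    virtual ~FinderLike() = default;
    virtual void const* findLogger() = 0;
};

class ManagerSketch
{
public:
    // Meyers singleton: constructed on first use, thread-safe since C++11.
    static ManagerSketch& getInstance() noexcept
    {
        static ManagerSketch instance;
        return instance;
    }

    // Set-once: later calls (or null finders) are ignored, as in setLoggerFinder above.
    void setFinder(FinderLike* finder)
    {
        std::lock_guard<std::mutex> lk(mMutex);
        if (mFinder == nullptr && finder != nullptr)
        {
            mFinder = finder;
        }
    }

    void const* logger()
    {
        std::lock_guard<std::mutex> lk(mMutex);
        return mFinder != nullptr ? mFinder->findLogger() : nullptr;
    }

private:
    ManagerSketch() = default;
    FinderLike* mFinder{nullptr};
    std::mutex mMutex;
};

int main()
{
    std::printf("logger before set: %p\n", ManagerSketch::getInstance().logger());
    return 0;
}
```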

View File

@ -33,21 +33,22 @@ PluginFieldCollection GemmPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> GemmPluginCreator::mPluginAttributes;
void getProblemParams(cublasOperation_t& transa, cublasOperation_t& transb, int& m, int& n, int& k, int& lda, int& ldb,
int& ldc, bool transA, bool transB, int M, int N, int K)
int& ldc, bool transA, bool transB, int M, int N, int K, int padLda, int padLdb)
{
transa = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
transb = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
m = N;
n = M;
k = K;
lda = transB ? K : N;
ldb = transA ? M : K;
lda = transB ? K + padLdb : N + padLdb;
ldb = transA ? M + padLda : K + padLda;
ldc = N;
}
void runGemm(int const M, int const N, int const K, bool const transA, bool const transB, const nvinfer1::DataType type,
CublasGemmWrapperPtr const& cublasWrapperPtr, void const* act, void const* weight, void* output,
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic, void* workspace, cudaStream_t stream)
void runGemm(int const M, int const N, int const K, bool const transA, bool const transB, int const padLda,
int const padLdb, const nvinfer1::DataType type, CublasGemmWrapperPtr const& cublasWrapperPtr, void const* act,
void const* weight, void* output, std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic, void* workspace,
cudaStream_t stream)
{
cublasWrapperPtr->setStream(stream);
cublasWrapperPtr->setWorkspace(workspace);
@ -55,7 +56,7 @@ void runGemm(int const M, int const N, int const K, bool const transA, bool cons
cublasOperation_t transa, transb;
int m, n, k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, transA, transB, M, N, K);
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, transA, transB, M, N, K, padLda, padLdb);
cublasWrapperPtr->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
cublasWrapperPtr->Gemm(transa, transb, m, n, k, weight, lda, act, ldb, output, ldc, heuristic);
@ -78,7 +79,8 @@ void CublasLtGemmPluginProfiler::runTactic(
nextWorkspacePtrWithAlignment(reinterpret_cast<int8_t*>(weightPtr), n * k * dataSize, ALIGNMENT));
char* workspacePtr = reinterpret_cast<char*>(
nextWorkspacePtrWithAlignment(reinterpret_cast<int8_t*>(outputPtr), m * n * dataSize, ALIGNMENT));
runGemm(m, n, k, mTransA, mTransB, mType, mRunner, actPtr, weightPtr, outputPtr, {tactic}, workspacePtr, stream);
runGemm(m, n, k, mTransA, mTransB, mPadLda, mPadLdb, mType, mRunner, actPtr, weightPtr, outputPtr, {tactic},
workspacePtr, stream);
}
bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, Config const& tactic) const
@ -86,7 +88,7 @@ bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, Config const&
cublasOperation_t transa, transb;
int M = m, N = n, K = k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K);
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K, mPadLda, mPadLdb);
mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
@ -117,7 +119,7 @@ std::vector<CublasLtGemmPluginProfiler::Config> CublasLtGemmPluginProfiler::getT
cublasOperation_t transa, transb;
int m, n, k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K);
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K, mPadLda, mPadLdb);
mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
auto const heruistics = mRunner->getTactics(transa, transb, m, n, k, lda, ldb, ldc);
@ -126,10 +128,12 @@ std::vector<CublasLtGemmPluginProfiler::Config> CublasLtGemmPluginProfiler::getT
return heruistics;
}
GemmPlugin::GemmPlugin(
int transA, int transB, nvinfer1::DataType type, bool useFp8, GemmPlugin::PluginProfilerPtr const& pluginProfiler)
GemmPlugin::GemmPlugin(int transA, int transB, int padLda, int padLdb, nvinfer1::DataType type, bool useFp8,
GemmPlugin::PluginProfilerPtr const& pluginProfiler)
: mTransA(transA)
, mTransB(transB)
, mPadLda(padLda)
, mPadLdb(padLdb)
, mType(type)
, mUseFp8(useFp8)
, mPluginProfiler(pluginProfiler)
@ -145,6 +149,8 @@ GemmPlugin::GemmPlugin(void const* data, size_t length, GemmPlugin::PluginProfil
char const *d = reinterpret_cast<char const*>(data), *a = d;
read(d, mTransA);
read(d, mTransB);
read(d, mPadLda);
read(d, mPadLdb);
read(d, mType);
read(d, mUseFp8);
read(d, mDims);
@ -169,6 +175,7 @@ void GemmPlugin::init()
mPluginProfiler->setTranspose(mTransA, mTransB);
mPluginProfiler->setOutputType(mOutputType);
mPluginProfiler->setPadLd(mPadLda, mPadLdb);
mGemmId = GemmIdCublas(mDims.n, mDims.k, mType, mTransA, mTransB, mOutputType);
}
@ -363,13 +370,16 @@ int GemmPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P
int const nbDimsA = inputDesc[0].dims.nbDims;
int const nbDimsB = inputDesc[1].dims.nbDims;
auto const M = computeMDimension(mTransA, nbDimsA, inputDesc[0].dims.d);
auto const N = computeNDimension(mTransB, nbDimsB, inputDesc[1].dims.d);
int const K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1];
int const padM = mTransA ? mPadLda : 0;
int const padN = mTransB ? 0 : mPadLdb;
int const padK = mTransA ? 0 : mPadLda;
auto const M = computeMDimension(mTransA, nbDimsA, inputDesc[0].dims.d) - padM;
auto const N = computeNDimension(mTransB, nbDimsB, inputDesc[1].dims.d) - padN;
int const K = mTransA ? inputDesc[0].dims.d[0] - padK : inputDesc[0].dims.d[nbDimsA - 1] - padK;
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0], bestTactic, workspace,
stream);
runGemm(M, N, K, mTransA, mTransB, mPadLda, mPadLdb, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0],
bestTactic, workspace, stream);
return 0;
}
@ -411,8 +421,9 @@ void GemmPlugin::destroy() noexcept
size_t GemmPlugin::getSerializationSize() const noexcept
{
return sizeof(mTransA) + sizeof(mTransB) + sizeof(mType) + sizeof(mDims) + sizeof(mUseFp8)
+ mPluginProfiler->getSerializationSize(mGemmId) + sizeof(mOutputType); // selected tactics container size
return sizeof(mTransA) + sizeof(mTransB) + sizeof(mPadLda) + sizeof(mPadLdb) + sizeof(mType) + sizeof(mDims)
+ sizeof(mUseFp8) + mPluginProfiler->getSerializationSize(mGemmId)
+ sizeof(mOutputType); // selected tactics container size
}
void GemmPlugin::serialize(void* buffer) const noexcept
@ -420,6 +431,8 @@ void GemmPlugin::serialize(void* buffer) const noexcept
char *d = static_cast<char*>(buffer), *a = d;
write(d, mTransA);
write(d, mTransB);
write(d, mPadLda);
write(d, mPadLdb);
write(d, mType);
write(d, mUseFp8);
write(d, mDims);
@ -439,6 +452,8 @@ GemmPluginCreator::GemmPluginCreator()
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("transA", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("transB", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("padLda", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("padLdb", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("use_fp8", nullptr, PluginFieldType::kINT32, 0));
mFC.nbFields = mPluginAttributes.size();
@ -463,7 +478,7 @@ PluginFieldCollection const* GemmPluginCreator::getFieldNames() noexcept
IPluginV2* GemmPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
PluginField const* fields = fc->fields;
int transA, transB;
int transA, transB, padLda, padLdb;
nvinfer1::DataType type;
int useFp8;
// Read configurations from each fields
@ -480,6 +495,16 @@ IPluginV2* GemmPluginCreator::createPlugin(char const* name, PluginFieldCollecti
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
transB = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "pad_lda"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
padLda = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "pad_ldb"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
padLdb = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "type_id"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
@ -497,7 +522,7 @@ IPluginV2* GemmPluginCreator::createPlugin(char const* name, PluginFieldCollecti
// Create plugin profiler with shared tactics map
// FIXME enable tactic profiler
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false, /* skip */ true);
auto* obj = new GemmPlugin(transA, transB, type, useFp8, pluginProfiler);
auto* obj = new GemmPlugin(transA, transB, padLda, padLdb, type, useFp8, pluginProfiler);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
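
The padLda/padLdb plumbing above ends up in the leading dimensions computed by getProblemParams, which maps the row-major problem onto cuBLAS's column-major convention by swapping A and B. A self-contained sketch of that mapping, reusing the variable names from the function in the diff:

```cpp
#include <cstdio>

enum class Op { N, T }; // stands in for cublasOperation_t in this sketch

// Mirrors getProblemParams in the diff: a row-major C[M,N] = A[M,K] * B[K,N]
// is computed as column-major C^T = B^T * A^T, and the leading dimensions of
// the swapped operands absorb padLda / padLdb.
void problemParams(Op& transa, Op& transb, int& m, int& n, int& k, int& lda, int& ldb, int& ldc, bool transA,
    bool transB, int M, int N, int K, int padLda, int padLdb)
{
    transa = transB ? Op::T : Op::N;
    transb = transA ? Op::T : Op::N;
    m = N;
    n = M;
    k = K;
    lda = transB ? K + padLdb : N + padLdb;
    ldb = transA ? M + padLda : K + padLda;
    ldc = N;
}

int main()
{
    Op ta, tb;
    int m, n, k, lda, ldb, ldc;
    // No transpose, pad the leading dimension of A by 16.
    problemParams(ta, tb, m, n, k, lda, ldb, ldc, /*transA=*/false, /*transB=*/false,
        /*M=*/4, /*N=*/4096, /*K=*/4096, /*padLda=*/16, /*padLdb=*/0);
    std::printf("m=%d n=%d k=%d lda=%d ldb=%d ldc=%d\n", m, n, k, lda, ldb, ldc);
    return 0;
}
```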

View File

@ -42,6 +42,12 @@ public:
mTransB = transposeB;
}
void setPadLd(int padLda, int padLdb)
{
mPadLda = padLda;
mPadLdb = padLdb;
}
void setOutputType(nvinfer1::DataType type)
{
mOutputType = type;
@ -59,6 +65,8 @@ protected:
private:
bool mTransA;
bool mTransB;
int mPadLda;
int mPadLdb;
nvinfer1::DataType mOutputType;
static constexpr size_t ALIGNMENT = 256;
@ -71,7 +79,8 @@ public:
GemmPlugin() = delete;
GemmPlugin(int transA, int transB, nvinfer1::DataType type, bool useFp8, PluginProfilerPtr const& profiler);
GemmPlugin(int transA, int transB, int padLda, int padLdb, nvinfer1::DataType type, bool useFp8,
PluginProfilerPtr const& profiler);
GemmPlugin(void const* data, size_t length, PluginProfilerPtr const& profiler);
@ -114,6 +123,8 @@ private:
int mTransA;
int mTransB;
int mPadLda;
int mPadLdb;
nvinfer1::DataType mType;
nvinfer1::DataType mOutputType;
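
GemmPlugin's serialization in the preceding .cpp diff appends each new field (mPadLda, mPadLdb) with write(d, field) and adds it to getSerializationSize, and deserialization reads the fields back in the same order. Below is a minimal sketch of that byte-stream pattern; the memcpy-based helpers are hypothetical stand-ins for the plugin's own read/write utilities.

```cpp
#include <cassert>
#include <cstdio>
#include <cstring>
#include <vector>

// Hypothetical stand-ins for the plugin's write/read helpers: copy a trivially
// copyable value into the stream and advance the cursor.
template <typename T>
void writeValue(char*& d, T const& value)
{
    std::memcpy(d, &value, sizeof(T));
    d += sizeof(T);
}

template <typename T>
void readValue(char const*& d, T& value)
{
    std::memcpy(&value, d, sizeof(T));
    d += sizeof(T);
}

int main()
{
    int transA = 0, transB = 1, padLda = 16, padLdb = 0;

    // The size must account for every serialized field, as getSerializationSize does.
    size_t const size = sizeof(transA) + sizeof(transB) + sizeof(padLda) + sizeof(padLdb);
    std::vector<char> buffer(size);

    char* d = buffer.data();
    writeValue(d, transA);
    writeValue(d, transB);
    writeValue(d, padLda);
    writeValue(d, padLdb);
    assert(d == buffer.data() + size);

    int a, b, pa, pb;
    char const* r = buffer.data();
    readValue(r, a);
    readValue(r, b);
    readValue(r, pa);
    readValue(r, pb);
    std::printf("%d %d %d %d\n", a, b, pa, pb); // 0 1 16 0
    return 0;
}
```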

View File

@ -98,6 +98,7 @@ struct FusedQKVMaskedAttentionDispatchParams
T const* ia3_key_weights;
T const* ia3_value_weights;
float const* qkv_scale_out;
bool fp8_context_fmha;
float const* attention_out_scale;
bool mUnfuseQkvGemm;
tc::QuantMode quant_option;
@ -320,9 +321,12 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
params.ia3_key_weights = reinterpret_cast<DataType const*>(input_params.ia3_key_weights);
params.ia3_value_weights = reinterpret_cast<DataType const*>(input_params.ia3_value_weights);
if (input_params.quant_option.hasStaticActivationScaling())
if (input_params.quant_option.hasStaticActivationScaling() || input_params.fp8_context_fmha)
{
// qkv_scale_out is nullptr currently (no scale).
params.qkv_scale_quant_orig = input_params.qkv_scale_out;
TLLM_CHECK_WITH_INFO(!input_params.fp8_context_fmha || input_params.attention_out_scale != nullptr,
"attention output scale should be provided.");
params.attention_out_scale_orig_quant = input_params.attention_out_scale;
}
@ -374,7 +378,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
bool use_paged_context_fmha, bool use_fp8_context_fmha, bool use_cache, bool is_medusa_enabled)
: mLayerIdx(layer_idx)
, mNumHeads(num_heads)
, mNumKVHeads(num_kv_heads)
@ -408,6 +412,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
, mPosShiftEnabled(pos_shift_enabled)
, mDenseContextFMHA(dense_context_fmha)
, mPagedContextFMHA(use_paged_context_fmha)
, mFP8ContextFMHA(use_fp8_context_fmha)
, mUseKVCache(use_cache)
, mIsMedusaEnabled(is_medusa_enabled)
{
@ -442,6 +447,14 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
}
}
// Pre-Check of FP8 Context FMHA.
if (mFP8ContextFMHA)
{
TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "FP8 FMHA cannot be enabled because Context FMHA is not supported.");
TLLM_CHECK_WITH_INFO(mSM == 90, "FP8 FMHA cannot be enabled on Pre-Hopper Arch.");
TLLM_CHECK_WITH_INFO(!mPagedContextFMHA, "FP8 Context Paged KV FMHA hasn't been implemented yet.");
}
TLLM_CHECK(isRoPE() == (rotary_embedding_dim != 0));
TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
@ -511,6 +524,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
read(d, mPosShiftEnabled);
read(d, mDenseContextFMHA);
read(d, mPagedContextFMHA);
read(d, mFP8ContextFMHA);
read(d, mUseKVCache);
read(d, mIsMedusaEnabled);
@ -558,13 +572,16 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t
const size_t qk_buf_float_size = mEnableContextFMHA ? 0
: sizeof(float) * batch_size * mNumHeads * input_seq_length
* (isCrossAttention() ? cross_qkv_length : input_seq_length);
const size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA
? batch_size * input_seq_length * (local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
const size_t padding_offset_size = sizeof(int) * batch_size * input_seq_length;
// It is assumed that the number of tokens per paged kv block should be >= 128.
const size_t paged_kv_tma_desc_size = mPagedKVCache && mPagedContextFMHA
? batch_size * 2 * TMA_DESC_SIZE_IN_BYTE * tc::divUp(max_attention_window, mTokensPerBlock)
: 0;
int const NUM_BUFFERS = 12;
int const NUM_BUFFERS = 13;
size_t workspaces[NUM_BUFFERS];
workspaces[0] = CUBLAS_WORKSPACE_SIZE;
workspaces[1] = attention_mask_size;
@ -576,8 +593,9 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t
workspaces[7] = qk_buf_size;
workspaces[8] = qkv_buf_2_size;
workspaces[9] = qk_buf_float_size;
workspaces[10] = padding_offset_size;
workspaces[11] = paged_kv_tma_desc_size;
workspaces[10] = fp8_qkv_buffer_size;
workspaces[11] = padding_offset_size;
workspaces[12] = paged_kv_tma_desc_size;
context_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
return context_workspace_size;
}
@ -725,6 +743,9 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
const size_t qk_buf_float_size = mEnableContextFMHA ? 0
: sizeof(float) * params.batch_size * mNumHeads
* params.input_seq_length * (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length);
const size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA
? params.batch_size * params.input_seq_length * (local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
const size_t padding_offset_size
= sizeof(int) * params.batch_size * (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length);
const size_t paged_kv_tma_desc_size = mPagedKVCache && mPagedContextFMHA
@ -746,6 +767,8 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
T* qk_buf_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_size));
T* qkv_buf_2_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, qkv_buf_2_size));
float* qk_buf_float_ = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_float_size));
__nv_fp8_e4m3* fp8_qkv_buffer
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_qkv_buffer_size));
int* padding_offset = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, padding_offset_size));
void* paged_kv_tma_desc
= reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, paged_kv_tma_desc_size));
@ -826,13 +849,15 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
// Paged Context FMHA doesn't work with fp8/int8 kv cache currently.
TLLM_CHECK_WITH_INFO(cache_type == KvCacheDataType::BASE || !enablePagedKVContextFMHA,
"Paged Context FMHA doesn't work with fp8/int8 kv cache currently.");
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), q_buf_2_, kv_cache_buffer,
const_cast<T*>(params.qkv_bias), params.q_seq_lengths, params.kv_seq_lengths,
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), fp8_qkv_buffer, q_buf_2_,
kv_cache_buffer, const_cast<T*>(params.qkv_bias), params.q_seq_lengths, params.kv_seq_lengths,
mRemovePadding ? padding_offset : nullptr, params.batch_size, params.input_seq_length,
params.cyclic_attention_window_size, params.sink_token_length, params.num_tokens, mNumHeads, mNumKVHeads,
getHeadSize(), mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
mRotaryEmbeddingMaxPositions, position_embedding_type, (int*) nullptr, mPosShiftEnabled, (float*) nullptr,
0, cache_type, params.kv_scale_orig_quant, enablePagedKVContextFMHA, 1, mLaunchGridBlockCache, stream);
mRotaryEmbeddingMaxPositions, position_embedding_type, (int*) nullptr, mPosShiftEnabled, nullptr, 0,
cache_type, params.kv_scale_orig_quant, enablePagedKVContextFMHA, mFP8ContextFMHA, 1, mLaunchGridBlockCache,
stream);
sync_check_cuda_error();
// It is not needed with packed QKV input.
@ -855,6 +880,7 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
{
TLLM_LOG_ERROR("Cannot support StreamingLLM now when enabling paged KV context FMHA.");
}
// TODO: add support for fp8 paged kv fmha later.
mFMHARunner->setup_paged_kv(params.batch_size, params.input_seq_length, params.max_past_kv_len,
params.max_blocks_per_sequence, mTokensPerBlock, params.cyclic_attention_window_size, params.num_tokens,
isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
@ -869,8 +895,10 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
int const attention_window_size
= mDenseContextFMHA ? params.num_tokens : params.cyclic_attention_window_size;
mFMHARunner->setup(params.batch_size, params.input_seq_length, attention_window_size, params.num_tokens,
isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
mFMHARunner->run(const_cast<T*>(params.attention_input), cu_q_seqlens, params.context_buf, stream);
params.attention_output_orig_quant, isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
void const* fmha_input_tensor = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
: reinterpret_cast<void const*>(params.attention_input);
mFMHARunner->run(fmha_input_tensor, cu_q_seqlens, params.context_buf, stream);
}
sync_check_cuda_error();
}
@ -1177,7 +1205,6 @@ int GPTAttentionPluginCommon::enqueueGeneration(
auto const quant_option = tc::QuantMode::fromDescription();
float const* qkv_scale_out = nullptr;
float const* attention_out_scale = nullptr;
int const* ia3_tasks = nullptr;
T const* ia3_key_weights = nullptr;
@ -1323,7 +1350,8 @@ int GPTAttentionPluginCommon::enqueueGeneration(
dispatch_params.ia3_key_weights = ia3_key_weights;
dispatch_params.ia3_value_weights = ia3_value_weights;
dispatch_params.qkv_scale_out = qkv_scale_out;
dispatch_params.attention_out_scale = attention_out_scale;
dispatch_params.fp8_context_fmha = mFP8ContextFMHA;
dispatch_params.attention_out_scale = params.attention_output_orig_quant;
dispatch_params.quant_option = quant_option;
dispatch_params.multi_block_mode = enable_multi_block;
dispatch_params.max_seq_len_tile = max_num_seq_len_tiles;
@ -1447,6 +1475,12 @@ int GPTAttentionPluginCommon::initialize() noexcept
TLLM_CHECK_WITH_INFO(false, "GPTAttentionPlugin received wrong data type.");
}
// FP8 FMHA should be used with fp8 workflow together.
if (mFP8ContextFMHA)
{
data_type = DATA_TYPE_E4M3;
}
// Load kernels for contiguous cache and paged kv cache at the same time.
mFMHARunner.reset(new FusedMHARunnerV2(data_type, mNumHeads, getHeadSize(false), mQScaling));
// Set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads.
@ -1504,8 +1538,8 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
+ sizeof(unsigned int) // mKVCacheQuantMode
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType)
+ sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + sizeof(mCrossAttention) + sizeof(mMaxDistance)
+ sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mUseKVCache)
+ sizeof(mUnfuseQkvGemm) + sizeof(mIsMedusaEnabled);
+ sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mFP8ContextFMHA)
+ sizeof(mUseKVCache) + sizeof(mUnfuseQkvGemm) + sizeof(mIsMedusaEnabled);
}
void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
@ -1543,6 +1577,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
write(d, mPosShiftEnabled);
write(d, mDenseContextFMHA);
write(d, mPagedContextFMHA);
write(d, mFP8ContextFMHA);
write(d, mUseKVCache);
write(d, mIsMedusaEnabled);
assert(d == a + getCommonSerializationSize());
@ -1589,6 +1624,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
mPluginAttributes.emplace_back(PluginField("pos_shift_enabled", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("dense_context_fmha", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("use_paged_context_fmha", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("use_fp8_context_fmha", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("use_cache", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("is_medusa_enabled", nullptr, PluginFieldType::kINT8, 0));
mFC.nbFields = mPluginAttributes.size();
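
getWorkspaceSizeForContext and enqueueContext above use a single-allocation workspace scheme: the sizes of all sub-buffers (now 13, including the new fp8_qkv_buffer) are summed up front, and at enqueue time pointers are carved out of one workspace in the same order. A self-contained sketch of that idea; the alignment helpers here are hypothetical stand-ins for calculateTotalWorkspaceSize and nextWorkspacePtr.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical helpers mirroring the pattern, not the library's own functions.
constexpr size_t kAlignment = 256;

size_t alignUp(size_t x)
{
    return (x + kAlignment - 1) / kAlignment * kAlignment;
}

size_t totalWorkspaceSize(std::vector<size_t> const& sizes)
{
    size_t total = 0;
    for (size_t s : sizes)
    {
        total += alignUp(s);
    }
    return total;
}

int8_t* nextWorkspace(int8_t* base, size_t& offset, size_t size)
{
    int8_t* ptr = base + offset;
    offset += alignUp(size);
    return ptr;
}

int main()
{
    // Example sub-buffer sizes in bytes (illustrative values only); a size of 0
    // models an optional buffer that is disabled for this configuration.
    std::vector<size_t> sizes{1 << 20, 4096, 130, 0};
    size_t const total = totalWorkspaceSize(sizes);
    std::vector<int8_t> workspace(total);

    size_t offset = 0;
    for (size_t i = 0; i < sizes.size(); ++i)
    {
        int8_t* ptr = nextWorkspace(workspace.data(), offset, sizes[i]);
        std::printf("buffer %zu at offset %zu (size %zu)\n", i, static_cast<size_t>(ptr - workspace.data()), sizes[i]);
    }
    std::printf("total workspace: %zu bytes\n", total);
    return 0;
}
```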

View File

@ -46,8 +46,8 @@ public:
int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention = false, int max_distance = 0, bool pos_shift_enabled = false,
bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_cache = true,
bool is_medusa_enabled = false);
bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false,
bool use_cache = true, bool is_medusa_enabled = false);
GPTAttentionPluginCommon(void const* data, size_t length);
@ -102,6 +102,7 @@ protected:
int32_t const* kv_seq_lengths;
float const* kv_scale_orig_quant;
float const* kv_scale_quant_orig;
float const* attention_output_orig_quant;
T const* alibi_slopes;
T* context_buf;
void* key_value_cache;
@ -137,6 +138,7 @@ protected:
int32_t const* context_lengths;
float const* kv_scale_orig_quant;
float const* kv_scale_quant_orig;
float const* attention_output_orig_quant;
T const* alibi_slopes;
T* context_buf;
void* key_value_cache;
@ -241,6 +243,7 @@ protected:
int mMaxDistance = 0;
bool mPosShiftEnabled = false;
bool mPagedContextFMHA = false;
bool mFP8ContextFMHA = false;
bool mDenseContextFMHA = false;
bool mIsMedusaEnabled = false;

View File

@ -47,13 +47,13 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_
int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
bool use_paged_context_fmha, bool use_fp8_context_fmha, bool use_cache, bool is_medusa_enabled)
: GPTAttentionPluginCommon(layer_idx, num_heads, num_kv_heads, head_size, unidirectional, q_scaling,
position_embedding_type, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type,
rotary_embedding_scale, rotary_embedding_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type,
multi_block_mode, enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type, paged_kv_cache,
tokens_per_block, type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
dense_context_fmha, use_paged_context_fmha, use_cache, is_medusa_enabled)
dense_context_fmha, use_paged_context_fmha, use_fp8_context_fmha, use_cache, is_medusa_enabled)
{
initEntryIdx();
}
@ -83,6 +83,7 @@ bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const
case IdxEntry::PAST_KEY_VALUE: return useKVCache() && !mPagedKVCache;
case IdxEntry::KV_CACHE_QUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant();
case IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant();
case IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE: return mFP8ContextFMHA && mKVCacheQuantMode.hasFp8Qdq();
case IdxEntry::ALIBI_SLOPES: return isALiBi();
case IdxEntry::RELATIVE_ATTENTION_BIAS: return isRelativePosition();
case IdxEntry::CROSS_QKV: return isCrossAttention();
@ -186,6 +187,10 @@ bool GPTAttentionPlugin::supportsFormatCombination(
// kv_scale for mType->int8/fp8 and int8/fp8->mType conversion
return inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (mFP8ContextFMHA && pos == getIdx(IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE))
{
return inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (mPagedKVCache
&& (pos == getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS) || pos == getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_POINTERS)))
{
@ -213,6 +218,11 @@ bool GPTAttentionPlugin::supportsFormatCombination(
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (pos == nbInputs && mFP8ContextFMHA)
{
// Output tensor now supports fp8 data type.
return (inOut[pos].type == nvinfer1::DataType::kFP8) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else
{
return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
@ -247,6 +257,7 @@ void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc c
/*context_lengths=*/nullptr,
/*kv_scale_orig_quant=*/nullptr,
/*kv_scale_quant_orig=*/nullptr,
/*attention_out_orig_quant=*/nullptr,
/*alibi_slopes=*/nullptr,
/*context_buf_=*/nullptr,
/*key_value_cache=*/nullptr,
@ -498,6 +509,14 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
kv_scale_quant_orig = reinterpret_cast<float const*>(inputs[getIdx(IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE)]);
}
float const* attention_output_orig_quant = nullptr;
if (mFP8ContextFMHA)
{
assert(inputDesc[getIdx(IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE)].type == nvinfer1::DataType::kFLOAT);
attention_output_orig_quant
= reinterpret_cast<float const*>(inputs[getIdx(IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE)]);
}
int max_blocks_per_sequence = 0;
void* block_pointers = nullptr;
void* host_block_pointers = nullptr;
@ -584,9 +603,9 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
EnqueueContextParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, max_context_q_len,
max_context_kv_len, max_attention_window_size, cyclic_attention_window_size, sink_token_length,
context_q_lengths, sequence_kv_length, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_,
key_value_cache, block_pointers, host_block_pointers, batch_size, localNbTokens, max_blocks_per_sequence,
workspace};
context_q_lengths, sequence_kv_length, kv_scale_orig_quant, kv_scale_quant_orig,
attention_output_orig_quant, alibi_slopes, context_buf_, key_value_cache, block_pointers,
host_block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace};
if (isRelativePosition())
{
enqueue_params.relative_attention_bias
@ -627,9 +646,9 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
TLLM_CHECK_WITH_INFO(input_seq_length == num_medusa_tokens + 1, "The generation input length is not expected.");
EnqueueGenerationParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, input_seq_length,
sequence_kv_length, max_context_kv_len, beamWidth, context_q_lengths, kv_scale_orig_quant,
kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache, block_pointers, max_attention_window_size,
cyclic_attention_window_size, sink_token_length, num_requests, max_blocks_per_sequence, cache_indir,
workspace, max_context_kv_len_list};
kv_scale_quant_orig, attention_output_orig_quant, alibi_slopes, context_buf_, key_value_cache,
block_pointers, max_attention_window_size, cyclic_attention_window_size, sink_token_length, num_requests,
max_blocks_per_sequence, cache_indir, workspace, max_context_kv_len_list};
enqueue_params.host_context_lengths = host_context_lengths;
if (isRelativePosition())
{
@ -699,7 +718,8 @@ nvinfer1::DataType GPTAttentionPlugin::getOutputDataType(
TLLM_CHECK(index == 0 || (!mPagedKVCache && index == 1));
if (index == 0)
{
return inputTypes[getIdx(IdxEntry::QKV_TENSOR)];
return mFP8ContextFMHA && mEnableContextFMHA ? nvinfer1::DataType::kFP8
: inputTypes[getIdx(IdxEntry::QKV_TENSOR)];
}
else
{
@ -794,6 +814,7 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
static_cast<bool>(p.getScalar<int8_t>("pos_shift_enabled").value()),
static_cast<bool>(p.getScalar<int8_t>("dense_context_fmha").value()),
static_cast<bool>(p.getScalar<int8_t>("use_paged_context_fmha").value()),
static_cast<bool>(p.getScalar<int8_t>("use_fp8_context_fmha").value()),
static_cast<bool>(p.getScalar<int32_t>("use_cache").value()),
static_cast<bool>(p.getScalar<int8_t>("is_medusa_enabled").value()));
obj->setPluginNamespace(mNamespace.c_str());
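Taken together, the gptAttentionPlugin.cpp changes above add an optional `attention_output_orig_quant` scale input, accept it in `supportsFormatCombination`, thread it into the context and generation enqueue parameters, and switch the plugin's first output to `DataType::kFP8` when FP8 context FMHA is enabled. The diff does not spell out the scaling convention; assuming it mirrors the existing `kv_scale_orig_quant` naming (a multiplicative scale that maps original-precision values into the quantized range), the per-element effect would look roughly like the sketch below. This is an assumption for illustration only; the real kernels perform the conversion inside the fused FMHA.

```cpp
// Sketch only: presumed per-element effect of attention_output_orig_quant when the
// attention output is produced in FP8 (requires CUDA 11.8+ for <cuda_fp8.h>).
// The "*_orig_quant" interpretation (original dtype -> quantized dtype) is an assumption
// based on the existing kv_scale_orig_quant / kv_scale_quant_orig pair, not on this diff.
#include <cuda_fp8.h>

__host__ __device__ inline __nv_fp8_e4m3 quantizeAttentionOutput(
    float attentionOutput, float attention_output_orig_quant)
{
    // The __nv_fp8_e4m3 constructor converts with saturation to the E4M3 range.
    return __nv_fp8_e4m3(attentionOutput * attention_output_orig_quant);
}
```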


@ -80,8 +80,8 @@ public:
int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention = false, int max_distance = 0, bool pos_shift_enabled = false,
bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_cache = true,
bool is_medusa_enabled = false);
bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false,
bool use_cache = true, bool is_medusa_enabled = false);
GPTAttentionPlugin(void const* data, size_t length);
@ -164,6 +164,7 @@ private:
PAST_KEY_VALUE,
KV_CACHE_QUANTIZATION_SCALE,
KV_CACHE_DEQUANTIZATION_SCALE,
ATTENTION_OUTPUT_QUANTIZATION_SCALE,
ALIBI_SLOPES,
RELATIVE_ATTENTION_BIAS,
CROSS_QKV,
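The header hunk above inserts `ATTENTION_OUTPUT_QUANTIZATION_SCALE` into the `IdxEntry` enum; the .cpp changes gate it behind `mFP8ContextFMHA && mKVCacheQuantMode.hasFp8Qdq()` and resolve it to a concrete input slot via `getIdx()` after `initEntryIdx()`. A minimal, self-contained sketch of that optional-input numbering pattern follows; the flag names and entry subset are illustrative, not the plugin's actual members.

```cpp
// Illustrative sketch (not the plugin's actual members): optional inputs are declared in a
// fixed enum order, but only the entries enabled by the current configuration receive a
// dense input index, which is what getIdx() later returns.
#include <cstdint>
#include <unordered_map>

struct Flags
{
    bool fp8ContextFmha; // stands in for mFP8ContextFMHA
    bool kvCacheQuant;   // stands in for mKVCacheQuantMode.hasKvCacheQuant()/hasFp8Qdq()
};

enum class IdxEntry : int32_t
{
    QKV_TENSOR,
    KV_CACHE_QUANTIZATION_SCALE,
    KV_CACHE_DEQUANTIZATION_SCALE,
    ATTENTION_OUTPUT_QUANTIZATION_SCALE,
    COUNT
};

static bool isEntryUsed(IdxEntry e, Flags const& f)
{
    switch (e)
    {
    case IdxEntry::QKV_TENSOR: return true;
    case IdxEntry::KV_CACHE_QUANTIZATION_SCALE:
    case IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE: return f.kvCacheQuant;
    case IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE: return f.fp8ContextFmha && f.kvCacheQuant;
    default: return false;
    }
}

// Densely number the entries that are actually present, preserving enum order.
static std::unordered_map<IdxEntry, int32_t> buildEntryIdx(Flags const& f)
{
    std::unordered_map<IdxEntry, int32_t> idx;
    int32_t next = 0;
    for (int32_t e = 0; e < static_cast<int32_t>(IdxEntry::COUNT); ++e)
    {
        auto const entry = static_cast<IdxEntry>(e);
        if (isEntryUsed(entry, f))
        {
            idx[entry] = next++;
        }
    }
    return idx;
}
```

Numbering only the active entries keeps the tensor positions dense for every configuration, which is why adding a new optional input like the FP8 output scale does not shift the indices of configurations that leave it disabled.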


@ -0,0 +1,21 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES
${PLUGIN_SOURCES}
PARENT_SCOPE)

Some files were not shown because too many files have changed in this diff.