Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00

Update TensorRT-LLM (#1554)

parent 06c0e9b1ec
commit 89ba1b1a67
.gitattributes (vendored): 2 changes
@@ -1,2 +1,4 @@
*.a filter=lfs diff=lfs merge=lfs -text
*.lib filter=lfs diff=lfs merge=lfs -text
*.so filter=lfs diff=lfs merge=lfs -text
*.dll filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug_report.yml (vendored): 2 changes
@@ -17,7 +17,7 @@ body:
- Libraries
- TensorRT-LLM branch or tag (e.g., main, v0.7.1)
- TensorRT-LLM commit (if known)
- Versions of TensorRT, AMMO, CUDA, cuBLAS, etc. used
- Versions of TensorRT, Modelopt, CUDA, cuBLAS, etc. used
- Container used (if running TensorRT-LLM in a container)
- NVIDIA driver version
- OS (Ubuntu 22.04, CentOS 7, Windows 10)
.gitignore (vendored): 2 changes
@@ -6,7 +6,6 @@ __pycache__/
*.nsys-rep
.VSCodeCounter
build*/
*.so
*.egg-info/
.coverage
*.csv
@@ -34,6 +33,7 @@ tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/*.pyi
*docs/cpp_docs*
*docs/source/_cpp_gen*
*.swp

# Testing
.coverage.*
@@ -6,9 +6,9 @@ TensorRT-LLM

[](https://nvidia.github.io/TensorRT-LLM/)
[](https://www.python.org/downloads/release/python-31012/)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/tensorrt)
[](./setup.py)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/tensorrt)
[](./setup.py)
[](./LICENSE)

[Architecture](./docs/source/architecture/overview.md) | [Results](./docs/source/performance/perf-overview.md) | [Examples](./examples/) | [Documentation](./docs/source/)
@@ -170,8 +170,6 @@ Given a `static_emulated_batch_size` of `n` the server will wait for `n` request
```
python prepare_dataset.py \
--output tokens-fixed-lengths.json \
--request-rate -1 \
--time-delay-dist constant \
--tokenizer <path/to/tokenizer> \
token-norm-dist \
--num-requests 128 \
@@ -184,6 +182,7 @@ Take GPT-350M as an example for single GPU with static batching
./benchmarks/gptManagerBenchmark \
--engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
--type IFB \
--request-rate -1 \
--static_emulated_batch_size 32 \
--static_emulated_timeout 100 \
--dataset ../../benchmarks/cpp/tokens-fixed-lengths.json
@@ -212,6 +211,7 @@ PP=1
MAX_LEN=1024
MAX_BATCH=32
MAX_LORA_RANK=32
NUM_LORA_MODS=7

SOURCE_LORA=chinese-llama-2-lora-13b
CPP_LORA=chinese-llama-2-lora-13b-cpp
@@ -241,10 +241,9 @@ NUM_LORAS=(8 16 24 32 64 128 256)
NUM_REQUESTS=1024

# Convert LoRA to cpp format
python examples/gpt/nemo_lora_convert.py \
python examples/hf_lora_convert.py \
-i $SOURCE_LORA \
--storage-type $DTYPE \
--write-cpp-runtime-tensors \
-o $CPP_LORA

# Prepare datasets
@ -151,6 +151,7 @@ struct BenchmarkParams
|
||||
bool enableExpDelays{false};
|
||||
std::optional<float> requestRate{std::nullopt};
|
||||
int randomSeed = 430;
|
||||
std::optional<int> maxAttentionWindow{std::nullopt};
|
||||
|
||||
// lora / peft params
|
||||
std::optional<std::string> loraDir{std::nullopt};
|
||||
@ -746,8 +747,8 @@ public:
|
||||
|
||||
texec::SchedulerConfig schedulerConfig(batch_scheduler::batchManagerToExecSchedPolicy(schedulerPolicy));
|
||||
texec::KvCacheConfig kvCacheConfig(benchmarkParams.enableBlockReuse, benchmarkParams.maxTokensInPagedKvCache,
|
||||
std::nullopt, std::nullopt, benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize,
|
||||
benchmarkParams.kvOnboardBlocks);
|
||||
benchmarkParams.maxAttentionWindow, std::nullopt, benchmarkParams.freeGpuMemoryFraction,
|
||||
benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks);
|
||||
texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8,
|
||||
std::nullopt, benchmarkParams.loraHostCacheSize);
|
||||
texec::ExecutorConfig executorConfig(
|
||||
@ -909,6 +910,16 @@ public:
|
||||
mWorkItemsQueue.clear();
|
||||
}
|
||||
|
||||
std::string getLayerProfileInfo()
|
||||
{
|
||||
return mBatchManager->getLayerProfileInfo();
|
||||
}
|
||||
|
||||
void setLayerProfiler()
|
||||
{
|
||||
return mBatchManager->setLayerProfiler();
|
||||
}
|
||||
|
||||
void enqueue(std::shared_ptr<InferenceRequest> const& request)
|
||||
{
|
||||
TLLM_CHECK(request != nullptr);
|
||||
@ -1267,7 +1278,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
|
||||
std::optional<SizeType> const staticEmulatedBatchSize, std::optional<std::chrono::milliseconds> const batchTimeout,
|
||||
bool logIterationData, bool excludeInputInOutput, std::string const& responsesJsonFile,
|
||||
std::optional<SizeType> const maxPromptLen)
|
||||
std::optional<SizeType> const maxPromptLen, bool dumpProfile)
|
||||
{
|
||||
TrtGptModelOptionalParams optionalParams;
|
||||
|
||||
@ -1279,6 +1290,10 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
{
|
||||
optionalParams.kvCacheConfig.freeGpuMemoryFraction = benchmarkParams.freeGpuMemoryFraction;
|
||||
}
|
||||
if (benchmarkParams.maxAttentionWindow)
|
||||
{
|
||||
optionalParams.kvCacheConfig.maxAttentionWindow = benchmarkParams.maxAttentionWindow;
|
||||
}
|
||||
optionalParams.kvCacheConfig.enableBlockReuse = benchmarkParams.enableBlockReuse;
|
||||
optionalParams.enableChunkedContext = benchmarkParams.enableChunkedContext;
|
||||
optionalParams.enableTrtOverlap = benchmarkParams.enableTrtOverlap;
|
||||
@ -1391,6 +1406,23 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
recorder->report();
|
||||
recorder->writeOpMetricsToCsv();
|
||||
recorder->dumpResponseSeqs();
|
||||
if (dumpProfile)
|
||||
{
|
||||
// Do per-layer profiling after normal benchmarking to avoid introducing perf overhead.
|
||||
gptServer->resetBatchDeadline();
|
||||
gptServer->setLayerProfiler();
|
||||
for (std::size_t i = 0; i < numSamples; ++i)
|
||||
{
|
||||
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
|
||||
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
|
||||
gptServer->enqueue(request);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
if (worldConfig.getRank() == 0)
|
||||
{
|
||||
printf("[BENCHMARK] Per layer performance profile\n%s\n", gptServer->getLayerProfileInfo().c_str());
|
||||
}
|
||||
}
|
||||
// Send terminateReqId to terminate servers on all ranks
|
||||
// Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
|
||||
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
|
||||
@ -1554,6 +1586,7 @@ int main(int argc, char* argv[])
|
||||
"eos_id", "Specify the end-of-sequence token id.", cxxopts::value<TokenIdType>()->default_value("-1"));
|
||||
options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<TokenIdType>());
|
||||
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
|
||||
options.add_options()("max_attention_window", "Max KV cache length per sequence", cxxopts::value<int>());
|
||||
options.add_options()(
|
||||
"random_seed", "integer random seed for exponential time delays.", cxxopts::value<int>()->default_value("420"));
|
||||
options.add_options()(
|
||||
@ -1614,6 +1647,8 @@ int main(int argc, char* argv[])
|
||||
options.add_options()(
|
||||
"max_prompt_len", "Truncate all prompts from dataset to the length specified.", cxxopts::value<SizeType>());
|
||||
|
||||
options.add_options()("dump_profile", "Print profile information per layer.", cxxopts::value<bool>());
|
||||
|
||||
auto result = options.parse(argc, argv);
|
||||
|
||||
if (result.count("help"))
|
||||
@ -1674,6 +1709,12 @@ int main(int argc, char* argv[])
|
||||
benchmarkParams.maxTokensInPagedKvCache = result["max_tokens_in_paged_kvcache"].as<int>();
|
||||
}
|
||||
|
||||
// Argument: Max KV cache length
|
||||
if (result.count("max_attention_window"))
|
||||
{
|
||||
benchmarkParams.maxAttentionWindow = result["max_attention_window"].as<int>();
|
||||
}
|
||||
|
||||
if (result.count("random_seed"))
|
||||
{
|
||||
benchmarkParams.randomSeed = result["random_seed"].as<int>();
|
||||
@ -1811,6 +1852,9 @@ int main(int argc, char* argv[])
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Argument: dump profile
|
||||
bool dumpProfile = result["dump_profile"].as<bool>();
|
||||
|
||||
initTrtLlmPlugins(logger.get());
|
||||
|
||||
if (api == "gptManager")
|
||||
@ -1821,7 +1865,7 @@ int main(int argc, char* argv[])
|
||||
maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy,
|
||||
waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, batchTimeout,
|
||||
logIterationData, result["exclude_input_in_output_seq"].as<bool>(),
|
||||
result["responses_json_file"].as<std::string>(), maxPromptLen);
|
||||
result["responses_json_file"].as<std::string>(), maxPromptLen, dumpProfile);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
|
||||
@ -68,7 +68,7 @@ size_t monitorMemory(std::atomic_bool& done)
|
||||
void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector<int> const& batchSizes, int beamWidth,
|
||||
std::vector<std::vector<int>> const& inOutLen, std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp,
|
||||
int numRuns, int duration, GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits,
|
||||
bool disableForceMaxTokens, bool dumpLayerInfo)
|
||||
bool disableForceMaxTokens, bool dumpLayerInfo, bool dumpProfile)
|
||||
{
|
||||
std::filesystem::path jsonFileName = dataPath / "config.json";
|
||||
auto const json = GptJsonConfig::parse(jsonFileName);
|
||||
@ -298,6 +298,46 @@ void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector<int>
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
// Do per-layer profiling after normal benchmarking to avoid introducing perf overhead.
|
||||
if (dumpProfile)
|
||||
{
|
||||
session.setLayerProfiler();
|
||||
iterIdx = 0;
|
||||
|
||||
while (iterIdx < numRuns)
|
||||
{
|
||||
auto const start = std::chrono::steady_clock::now();
|
||||
SizeType numSteps = 0;
|
||||
generationOutput.onTokenGenerated
|
||||
= [&numSteps, maxNewTokens](GenerationOutput::TensorPtr const& outputIds, SizeType step,
|
||||
bool finished) { ++numSteps; };
|
||||
session.generate(generationOutput, generationInput, samplingConfig, generationProfiler);
|
||||
bufferManager.getStream().synchronize();
|
||||
auto const end = std::chrono::steady_clock::now();
|
||||
|
||||
iterIdx += 1;
|
||||
float latency = std::chrono::duration<float, std::milli>(end - start).count();
|
||||
curDuration += latency;
|
||||
latencies.emplace_back(latency);
|
||||
generationTimes.emplace_back(generationProfiler->getElapsedTimeMs());
|
||||
|
||||
bool durationLimitReached{curDuration / 1000 >= duration};
|
||||
if (worldConfig.getSize() > 1)
|
||||
{
|
||||
bool result{false};
|
||||
comm.allreduce(&durationLimitReached, &result, 1, tmpi::MpiType::kBOOL, tmpi::MpiOp::LOR);
|
||||
durationLimitReached = result;
|
||||
}
|
||||
if (durationLimitReached)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (worldConfig.getRank() == 0)
|
||||
{
|
||||
printf("%s\n", session.getLayerProfileInfo().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (std::runtime_error& e)
|
||||
{
|
||||
@ -377,6 +417,7 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("print_all_logits", "Print all context and generation logits.");
|
||||
options.add_options()("disable_force_max_tokens", "Disable force the engine generating new max_tokens.");
|
||||
options.add_options()("dump_layer_info", "Print layer information of the engine to console.");
|
||||
options.add_options()("dump_profile", "Print profile information per layer.");
|
||||
|
||||
auto result = options.parse(argc, argv);
|
||||
|
||||
@ -487,6 +528,7 @@ int main(int argc, char* argv[])
|
||||
auto printAllLogits = result.count("print_all_logits") > 0;
|
||||
auto disableForceMaxTokens = result.count("disable_force_max_tokens") > 0;
|
||||
auto dumpLayerInfo = result.count("dump_layer_info") > 0;
|
||||
auto dumpProfile = result.count("dump_profile") > 0;
|
||||
|
||||
initTrtLlmPlugins(logger.get());
|
||||
|
||||
@ -494,7 +536,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
benchmarkGptSession(result["engine_dir"].as<std::string>(), batchSizes, beamWidth, inOutLen, logger,
|
||||
result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(), sessionConfig,
|
||||
enableCudaGraph, printAllLogits, disableForceMaxTokens, dumpLayerInfo);
|
||||
enableCudaGraph, printAllLogits, disableForceMaxTokens, dumpLayerInfo, dumpProfile);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
|
||||
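The gptSessionBenchmark changes above wire the new `--dump_profile` option to `GptSession::setLayerProfiler()` and `GptSession::getLayerProfileInfo()`. Below is a minimal sketch, not part of this commit, of that per-layer profiling flow; the include path and the already-constructed session, inputs, outputs, and sampling config are assumptions, while the call signatures come from the diff.

```cpp
#include <iostream>

#include "tensorrt_llm/runtime/gptSession.h" // assumed header location

namespace tr = tensorrt_llm::runtime;

// Run one generation pass with per-layer profiling enabled and print the report.
// Doing this as a separate pass avoids perturbing the main benchmark numbers.
void dumpLayerProfile(tr::GptSession& session, tr::GenerationOutput& outputs,
    tr::GenerationInput const& inputs, tr::SamplingConfig const& samplingConfig)
{
    session.setLayerProfiler();                        // start collecting per-layer timings
    session.generate(outputs, inputs, samplingConfig); // timings are gathered during this pass
    std::cout << session.getLayerProfileInfo() << std::endl; // formatted per-layer report
}
```

In the benchmark itself this extra pass is guarded by `dumpProfile`, and only rank 0 prints the report in multi-GPU runs.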
@@ -48,3 +48,11 @@ mpirun -n 8 python benchmark.py \
--batch_size "1;8;64" \
--input_output_len "60,20;128,20"
```

Note: Building multi-GPU engines in parallel can be a heavy workload for the CPU system. Tuning the `mpirun --map-by <XXX>` option on your system may achieve a significant boost in build time, for example:
```
mpirun --map-by socket -n 8 python build.py \
--model gpt_175b \
--mode ootb \
--quantization fp8
```
@ -67,6 +67,8 @@ class BuildConfig:
|
||||
layer_types: List[str] = field(default_factory=list)
|
||||
rnn_hidden_size: int = 0
|
||||
logits_soft_cap: float = 0.0
|
||||
opt_batch_size: Optional[int] = None
|
||||
opt_num_tokens: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -268,6 +268,25 @@ def parse_arguments():
|
||||
help=
|
||||
"Print layer information of the engine to console (default = disabled)")
|
||||
|
||||
parser.add_argument(
|
||||
'--opt_batch_size',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
"If opt_batch_size option is specified, it will override the opt batch size."
|
||||
"This flag only takes effect when `--mode=ootb` is added. For other modes, please use --opt_num_tokens to replace it."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--opt_num_tokens',
|
||||
type=int,
|
||||
default=None,
|
||||
help="It equals to max_batch_size*max_beam_width by default, set this "
|
||||
"value as close as possible to the actual number of tokens on your workload. "
|
||||
"Note that this argument might be removed in the future."
|
||||
"This flag only takes effect when `--mode` is not `ootb`. For ootb mode, please use --opt_batch_size to replace it."
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -334,9 +353,6 @@ def main(args):
|
||||
if args.build_only:
|
||||
return
|
||||
|
||||
if args.dump_profile and benchmark_profiler is not None:
|
||||
benchmark_profiler.set_recording_perf_profile(True)
|
||||
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
benchmarker.print_report_header(args.csv,
|
||||
@ -432,6 +448,39 @@ def main(args):
|
||||
csv=args.csv,
|
||||
benchmark_profiler=benchmark_profiler)
|
||||
|
||||
# Rerun for dumping profile per layer.
|
||||
if args.dump_profile and benchmark_profiler is not None:
|
||||
benchmark_profiler.set_recording_perf_profile(True)
|
||||
logger.info(f'Dump profile information per layer')
|
||||
iter_idx = 0
|
||||
try:
|
||||
# Warm up
|
||||
for _ in range(args.warm_up):
|
||||
benchmarker.run(inputs, config)
|
||||
if benchmark_profiler is not None:
|
||||
benchmark_profiler.clean()
|
||||
benchmark_profiler.start()
|
||||
cur_duration = 0
|
||||
start_time = time()
|
||||
while iter_idx < args.num_runs or cur_duration < args.duration:
|
||||
start.record()
|
||||
benchmarker.run(inputs,
|
||||
config,
|
||||
benchmark_profiler=benchmark_profiler)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
latencies.append(start.elapsed_time(end))
|
||||
iter_idx += 1
|
||||
cur_duration = round(time() - start_time, 3)
|
||||
benchmarker.report_profiler(
|
||||
benchmark_profiler=benchmark_profiler)
|
||||
except Exception as e:
|
||||
logger.error("Found exception during benchmarking",
|
||||
e.with_traceback())
|
||||
if not disable_mem_monitor:
|
||||
memory_monitor.kill()
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mp.set_start_method('spawn')
|
||||
|
||||
@ -168,6 +168,24 @@ def parse_arguments():
|
||||
help=
|
||||
"The number of gpus to be used for inference, only used when --serial_build is specified"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--opt_batch_size',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
"If opt_batch_size option is specified, it will override the opt batch size."
|
||||
"This flag only takes effect when `--mode=ootb` is added. For other modes, please use --opt_num_tokens to replace it."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--opt_num_tokens',
|
||||
type=int,
|
||||
default=None,
|
||||
help="It equals to max_batch_size*max_beam_width by default, set this "
|
||||
"value as close as possible to the actual number of tokens on your workload. "
|
||||
"Note that this argument might be removed in the future."
|
||||
"This flag only takes effect when `--mode` is not `ootb`. For ootb mode, please use --opt_batch_size to replace it."
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@ -229,6 +247,21 @@ def build_gpt(args):
|
||||
max_beam_width = build_config['max_beam_width'] \
|
||||
if args.max_beam_width is None else args.max_beam_width
|
||||
|
||||
opt_batch_size = build_config[
|
||||
'opt_batch_size'] if args.opt_batch_size is None else args.opt_batch_size
|
||||
|
||||
opt_num_tokens = build_config[
|
||||
'opt_num_tokens'] if args.opt_num_tokens is None else args.opt_num_tokens
|
||||
|
||||
if args.mode != "ootb" and opt_batch_size is not None:
|
||||
raise Exception(
|
||||
f'--opt_batch_size only used when mode is ootb. Please using --opt_num_tokens instead it.'
|
||||
)
|
||||
if args.mode == "ootb" and opt_num_tokens is not None:
|
||||
raise Exception(
|
||||
f'--opt_num_tokens does not support ootb mode. Please using --opt_batch_size instead it.'
|
||||
)
|
||||
|
||||
quant_config = get_quant_config(args.quantization)
|
||||
quant_algo = quant_config.quant_algo
|
||||
kv_cache_quant_algo = quant_config.kv_cache_quant_algo
|
||||
@ -873,9 +906,11 @@ def build_gpt(args):
|
||||
# Inflight batching
|
||||
if args.mode == 'plugin-ifb':
|
||||
network.plugin_config.enable_paged_kv_cache()
|
||||
network.plugin_config.enable_paged_state()
|
||||
elif args.mode == 'ootb-except-mha':
|
||||
network.plugin_config.set_gpt_attention_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
network.plugin_config.enable_remove_input_padding()
|
||||
|
||||
if world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(
|
||||
@ -895,7 +930,9 @@ def build_gpt(args):
|
||||
max_input_len=max_input_len,
|
||||
max_seq_len=max_input_len + max_output_len,
|
||||
use_cache=True,
|
||||
max_beam_width=max_beam_width)
|
||||
max_beam_width=max_beam_width,
|
||||
opt_batch_size=opt_batch_size,
|
||||
opt_num_tokens=opt_num_tokens)
|
||||
|
||||
tensorrt_llm_model(**inputs)
|
||||
|
||||
|
||||
benchmarks/python/check_accuracy_mlperf.py (new file, 163 lines)
@@ -0,0 +1,163 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
import evaluate
|
||||
import nltk
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from transformers import AutoTokenizer, LlamaTokenizerFast
|
||||
|
||||
nltk.download("punkt", quiet=False)
|
||||
import argparse
|
||||
|
||||
|
||||
class Model(Enum):
|
||||
Llama_v2_70B = 1
|
||||
GPT_J = 2
|
||||
|
||||
|
||||
ACCURACY_TARGETS = {
|
||||
Model.Llama_v2_70B: {
|
||||
"rouge1": 44.4312 * 0.999,
|
||||
"rouge2": 22.0352 * 0.999,
|
||||
"rougeL": 28.6162 * 0.999,
|
||||
"tokens_per_sample": 294.45 * 0.9
|
||||
},
|
||||
Model.GPT_J: {
|
||||
"rouge1": 42.9435135,
|
||||
"rouge2": 20.1033765,
|
||||
"rougeL": 29.9581119,
|
||||
# "tokens_per_sample": ??
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_reference_df(processed_dataset_file):
|
||||
data = pd.read_pickle(processed_dataset_file)
|
||||
return data["output"].tolist()
|
||||
|
||||
|
||||
def get_reference_json(cnn_dailymail_valset):
|
||||
# Load from CNN dailymail
|
||||
with open(cnn_dailymail_valset, 'r') as fh:
|
||||
list_data_dict = json.load(fh)
|
||||
|
||||
targets = [f"{example['output']}" for example in list_data_dict]
|
||||
|
||||
print(f"Loaded {len(targets)} samples from {cnn_dailymail_valset}")
|
||||
return targets
|
||||
|
||||
|
||||
def get_responses_json(response_file):
|
||||
f = open(response_file)
|
||||
responses = json.load(f)
|
||||
ordered_responses = sorted(responses, key=lambda x: int(x['response_id']))
|
||||
return ordered_responses
|
||||
|
||||
|
||||
def postprocess_text(preds, targets):
|
||||
# Post-process output texts for ROUGE evaluation
|
||||
preds = [pred.strip() for pred in preds]
|
||||
targets = [target.strip() for target in targets]
|
||||
|
||||
# rougeLSum expects newline after each sentence
|
||||
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
|
||||
targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]
|
||||
|
||||
return preds, targets
|
||||
|
||||
|
||||
def strip_eos(pred_toks, eos_id):
|
||||
while len(pred_toks) > 0 and pred_toks[-1] == eos_id:
|
||||
pred_toks.pop()
|
||||
if len(pred_toks) == 0:
|
||||
raise RuntimeError("Empty output sequence detected with EOS")
|
||||
return pred_toks
|
||||
|
||||
|
||||
def calculate_toks_per_sample(preds, eos_id):
|
||||
preds = [strip_eos(pred, eos_id) for pred in preds]
|
||||
avg_len = sum(len(pred) for pred in preds)
|
||||
num_samples = len(preds)
|
||||
return avg_len / num_samples
|
||||
|
||||
|
||||
def calculate_rouge_score(preds, targets):
|
||||
print("Calculating ROUGE scores...")
|
||||
metric = evaluate.load("rouge")
|
||||
preds, targets = postprocess_text(preds, targets[0:len(preds)])
|
||||
result = metric.compute(predictions=preds,
|
||||
references=targets,
|
||||
use_stemmer=True,
|
||||
use_aggregator=False)
|
||||
result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
|
||||
prediction_lens = [len(pred) for pred in preds]
|
||||
result["gen_len"] = np.sum(prediction_lens)
|
||||
result["gen_num"] = len(preds)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
help=
|
||||
"Path to the reference dataset against which the responses are evaluated for accuracy. MLPerf uses open-orca (pkl) and cnn-dailymail (np) for Llama2-70B and GPT-J respectively."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--responses",
|
||||
type=str,
|
||||
help="Path to the json file holding the responses from our benchmark run"
|
||||
)
|
||||
parser.add_argument("--base_model",
|
||||
type=str,
|
||||
help="Location of the model used (to create tokenizer)")
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
if args.dataset.lower().endswith(".pkl"):
|
||||
target_texts = get_reference_df(args.dataset)
|
||||
model = Model.Llama_v2_70B
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(args.base_model)
|
||||
relaxing_factor = 1.0
|
||||
elif args.dataset.lower().endswith(".json"):
|
||||
target_texts = get_reference_json(args.dataset)
|
||||
model = Model.GPT_J
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.base_model,
|
||||
model_max_length=2047,
|
||||
padding_side="left",
|
||||
use_fast=False)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
relaxing_factor = 0.93
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Dataset expected to be pkl (open-orca) or json (cnn-dailymail)")
|
||||
|
||||
pred_out = get_responses_json(args.responses)
|
||||
pred_toks = [x['response_tokens'] for x in pred_out]
|
||||
|
||||
tps_score = calculate_toks_per_sample(pred_toks, tokenizer.eos_token)
|
||||
|
||||
pred_texts = tokenizer.batch_decode(pred_toks, skip_special_tokens=True)
|
||||
achieved_scores = calculate_rouge_score(pred_texts, target_texts)
|
||||
|
||||
achieved_scores['tokens_per_sample'] = tps_score
|
||||
targets = ACCURACY_TARGETS[model]
|
||||
|
||||
print("Achieved rouge scores: ", achieved_scores)
|
||||
print("Tokens per sample: ", tps_score)
|
||||
print("Targets: ", targets)
|
||||
|
||||
for k, _ in targets.items():
|
||||
assert targets[k] * relaxing_factor <= achieved_scores[k]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -381,6 +381,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
layer_idx, trt.LayerInformationFormat.ONELINE)
|
||||
print(layer_info)
|
||||
|
||||
def report_profiler(self, benchmark_profiler=None):
|
||||
if benchmark_profiler is not None and benchmark_profiler.is_recording_perf_profile:
|
||||
perf_profile_data = self.decoder.profiler.results
|
||||
if not perf_profile_data:
|
||||
@ -418,8 +419,9 @@ class GPTBenchmark(BaseBenchmark):
|
||||
def dump_kernel_profile_table(name: str, profile_data: list,
|
||||
iter_cnt: int):
|
||||
table = pd.DataFrame(
|
||||
[[k, '{:0.3f}'.format(v)] for k, v in profile_data.items()],
|
||||
columns=['{} Phase LayerName'.format(name), 'times (ms)'])
|
||||
[['{:0.3f}'.format(v), k]
|
||||
for k, v in profile_data.items() if v != 0.0],
|
||||
columns=['times (ms)', '{} Phase LayerName'.format(name)])
|
||||
|
||||
def ljust(s):
|
||||
s = s.astype(str).str.strip()
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
from multiprocessing import Event, Process, Queue
|
||||
from queue import Empty
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.profiler import (MemUnitType, bytes_to_target_unit,
|
||||
@ -52,11 +53,16 @@ class MemoryMonitor:
|
||||
def stop(self):
|
||||
self.signal_event.set()
|
||||
logger.debug("Sent signal to stop memory monitor subprocess.")
|
||||
peak_mem_use = self.peak_mem_queue.get(timeout=20)
|
||||
|
||||
self._peak_host_memory = max(self._peak_host_memory, peak_mem_use[0])
|
||||
self._peak_device_memory = max(self._peak_device_memory,
|
||||
peak_mem_use[1])
|
||||
try:
|
||||
peak_mem_use = self.peak_mem_queue.get(timeout=20)
|
||||
except Empty:
|
||||
logger.warning("peak_mem_queue was empty.")
|
||||
else:
|
||||
self._peak_host_memory = max(self._peak_host_memory,
|
||||
peak_mem_use[0])
|
||||
self._peak_device_memory = max(self._peak_device_memory,
|
||||
peak_mem_use[1])
|
||||
|
||||
self.mem_monitor_process.join(timeout=20)
|
||||
self.mem_monitor_process = None
|
||||
|
||||
@ -37,6 +37,7 @@ option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
|
||||
option(FAST_BUILD "Skip compiling some kernels to accelerate compiling" OFF)
|
||||
option(FAST_MATH "Compiling in fast math mode" OFF)
|
||||
option(INDEX_RANGE_CHECK "Compiling with index range checks" OFF)
|
||||
option(USE_SHARED_NVRTC "Use shared NVRTC library instead of static" OFF)
|
||||
|
||||
if(NVTX_DISABLE)
|
||||
add_compile_definitions("NVTX_DISABLE")
|
||||
@ -75,6 +76,23 @@ else()
|
||||
message(STATUS "Importing executor")
|
||||
endif()
|
||||
|
||||
if(EXISTS
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/CMakeLists.txt"
|
||||
)
|
||||
set(BUILD_NVRTC_WRAPPER_DEFAULT ON)
|
||||
else()
|
||||
set(BUILD_NVRTC_WRAPPER_DEFAULT OFF)
|
||||
endif()
|
||||
|
||||
option(BUILD_NVRTC_WRAPPER "Build nvrtc wrapper from source"
|
||||
${BUILD_NVRTC_WRAPPER_DEFAULT})
|
||||
|
||||
if(BUILD_NVRTC_WRAPPER)
|
||||
message(STATUS "Building nvrtc wrapper")
|
||||
else()
|
||||
message(STATUS "Importing nvrtc wrapper")
|
||||
endif()
|
||||
|
||||
if(BUILD_PYT)
|
||||
message(STATUS "Building PyTorch")
|
||||
else()
|
||||
@ -172,6 +190,41 @@ message(STATUS " include path: ${CUDAToolkit_INCLUDE_DIRS}")
|
||||
# pick up on the includes
|
||||
set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0)
|
||||
|
||||
if(USE_SHARED_NVRTC)
|
||||
if(WIN32)
|
||||
message(FATAL_ERROR "Cannot use NVRTC shared library on Windows.")
|
||||
else()
|
||||
find_library(
|
||||
NVRTC_LIB nvrtc
|
||||
HINTS ${CUDAToolkit_LIBRARY_DIR}
|
||||
PATH_SUFFIXES lib64 lib lib/x64)
|
||||
find_library(
|
||||
NVRTC_BUILTINS_LIB nvrtc-builtins
|
||||
HINTS ${CUDAToolkit_LIBRARY_DIR}
|
||||
PATH_SUFFIXES lib64 lib lib/x64)
|
||||
endif()
|
||||
else()
|
||||
if(WIN32)
|
||||
find_library(
|
||||
NVRTC_LIB nvrtc
|
||||
HINTS ${CUDAToolkit_LIBRARY_DIR}
|
||||
PATH_SUFFIXES lib64 lib lib/x64)
|
||||
else()
|
||||
find_library(
|
||||
NVRTC_LIB nvrtc_static
|
||||
HINTS ${CUDAToolkit_LIBRARY_DIR}
|
||||
PATH_SUFFIXES lib64 lib lib/x64)
|
||||
find_library(
|
||||
NVRTC_BUILTINS_LIB nvrtc-builtins_static
|
||||
HINTS ${CUDAToolkit_LIBRARY_DIR}
|
||||
PATH_SUFFIXES lib64 lib lib/x64)
|
||||
find_library(
|
||||
NVPTXCOMPILER_LIB nvptxcompiler_static
|
||||
HINTS ${CUDAToolkit_LIBRARY_DIR}
|
||||
PATH_SUFFIXES lib64 lib lib/x64)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(CUBLAS_LIB CUDA::cublas)
|
||||
set(CUBLASLT_LIB CUDA::cublasLt)
|
||||
set(CUDA_DRV_LIB CUDA::cuda_driver)
|
||||
@ -204,7 +257,15 @@ include_directories(
|
||||
set_ifndef(TRT_LIB_DIR ${CMAKE_BINARY_DIR})
|
||||
set_ifndef(TRT_INCLUDE_DIR /usr/include/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu)
|
||||
set(TRT_LIB nvinfer)
|
||||
find_library_create_target(${TRT_LIB} nvinfer SHARED ${TRT_LIB_DIR})
|
||||
|
||||
# On Windows major version is appended to nvinfer libs.
|
||||
if(WIN32)
|
||||
set(TRT_LIB_NAME nvinfer_10)
|
||||
else()
|
||||
set(TRT_LIB_NAME nvinfer)
|
||||
endif()
|
||||
|
||||
find_library_create_target(${TRT_LIB} ${TRT_LIB_NAME} SHARED ${TRT_LIB_DIR})
|
||||
|
||||
if(${CUDAToolkit_VERSION} VERSION_GREATER_EQUAL "11")
|
||||
add_definitions("-DENABLE_BF16")
|
||||
|
||||
@ -78,6 +78,10 @@ public:
|
||||
|
||||
virtual ~GptManager();
|
||||
|
||||
void setLayerProfiler();
|
||||
|
||||
[[nodiscard]] std::string getLayerProfileInfo() const;
|
||||
|
||||
protected:
|
||||
/* Synchronizes the decoder */
|
||||
virtual BatchManagerErrorCode_t forwardSync();
|
||||
@ -91,6 +95,7 @@ private:
|
||||
[[nodiscard]] SizeType getMaxInputLen() const;
|
||||
[[nodiscard]] SizeType getMaxSequenceLen() const;
|
||||
[[nodiscard]] SizeType getMaxNumSequences() const;
|
||||
[[nodiscard]] SizeType getMaxDraftLen() const;
|
||||
|
||||
void validateLlmRequest(
|
||||
LlmRequest& newReq, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) const;
|
||||
|
||||
@ -115,12 +115,22 @@ public:
|
||||
, mPadId(req.getPadId())
|
||||
, mOrigPromptLen(mPromptLen)
|
||||
, mMaxSentTokenPos(mPromptLen - 1)
|
||||
, mEmbeddingBias(std::nullopt)
|
||||
, mBadWordsList(std::nullopt)
|
||||
, mStopWordsList(std::nullopt)
|
||||
, mPromptEmbeddingTable(std::nullopt)
|
||||
, mPromptVocabSize(std::nullopt)
|
||||
, mLoraTaskId(std::nullopt)
|
||||
, mLoraWeights(std::nullopt)
|
||||
, mLoraConfig(std::nullopt)
|
||||
, mReturnLogProbs(req.getOutputConfig().returnLogProbs)
|
||||
, mContextChunkSize(std::nullopt)
|
||||
, mContextCurrentPosition(0)
|
||||
, mLogProbs(mSamplingConfig.beamWidth)
|
||||
, mCumLogProbs(mSamplingConfig.beamWidth)
|
||||
, mDraftTokens(std::make_shared<VecTokens>())
|
||||
, mDraftLogits(std::nullopt)
|
||||
, mNumTokensPerIteration(1)
|
||||
, mReturnContextLogits(req.getOutputConfig().returnContextLogits)
|
||||
, mReturnGenerationLogits(req.getOutputConfig().returnGenerationLogits)
|
||||
, mExcludeInputFromOutput(req.getOutputConfig().excludeInputFromOutput)
|
||||
@ -183,20 +193,42 @@ public:
|
||||
initialize(req.getInputTokenIds());
|
||||
}
|
||||
|
||||
void validate(SizeType maxInputLen, SizeType maxSequenceLen)
|
||||
void validate(SizeType maxInputLen, SizeType maxSequenceLen, SizeType maxDraftLen)
|
||||
{
|
||||
if (mPromptLen > maxInputLen)
|
||||
{
|
||||
TLLM_THROW("Prompt length (%d) exceeds maximum input length (%d).", mPromptLen, maxInputLen);
|
||||
}
|
||||
|
||||
if (mPromptLen + mMaxNewTokens > maxSequenceLen)
|
||||
// Maximum number of draft tokens per request we pass to the engine for single runtime iteration.
|
||||
// It depends on the speculative decoding mode.
|
||||
auto draftLenPerEngineStep = maxDraftLen;
|
||||
auto const& draftTokens = getDraftTokens();
|
||||
if (draftTokens && !draftTokens->empty())
|
||||
{
|
||||
auto const maxNewTokens = maxSequenceLen - mPromptLen;
|
||||
auto const inputDraftTokensLen = static_cast<SizeType>(draftTokens->size());
|
||||
if (inputDraftTokensLen > maxDraftLen)
|
||||
{
|
||||
TLLM_THROW("Draft tokens length (%d) exceeds maximum draft tokens length (%d).", inputDraftTokensLen,
|
||||
maxDraftLen);
|
||||
}
|
||||
draftLenPerEngineStep = inputDraftTokensLen;
|
||||
|
||||
if (mPromptLen + draftLenPerEngineStep > maxInputLen)
|
||||
{
|
||||
TLLM_THROW("Prompt length + number of draft tokens (%d + %d) exceeds maximum input length (%d).",
|
||||
mPromptLen, draftLenPerEngineStep, maxInputLen);
|
||||
}
|
||||
}
|
||||
|
||||
if (mPromptLen + mMaxNewTokens + draftLenPerEngineStep > maxSequenceLen)
|
||||
{
|
||||
auto const maxNewTokens = maxSequenceLen - mPromptLen - draftLenPerEngineStep;
|
||||
TLLM_LOG_WARNING(
|
||||
"Prompt length + number of requested output tokens (%d + %d) exceeds maximum sequence length (%d). "
|
||||
"Prompt length + number of requested output tokens + draft tokens per step (%d + %d + %d) exceeds "
|
||||
"maximum sequence length (%d). "
|
||||
"Number of requested output tokens is changed to (%d).",
|
||||
mPromptLen, mMaxNewTokens, maxSequenceLen, maxNewTokens);
|
||||
mPromptLen, mMaxNewTokens, draftLenPerEngineStep, maxSequenceLen, maxNewTokens);
|
||||
mMaxNewTokens = maxNewTokens;
|
||||
}
|
||||
|
||||
@ -537,9 +569,16 @@ public:
|
||||
mReturnGenerationLogits = returnGenerationLogits;
|
||||
}
|
||||
|
||||
// Return all generation logits for model w/o draft token
|
||||
[[nodiscard]] bool getReturnGenerationLogits() const
|
||||
{
|
||||
return mReturnGenerationLogits;
|
||||
return mReturnGenerationLogits && (getNumDraftTokens() == 0);
|
||||
}
|
||||
|
||||
// Return accepted tokens logits for target model
|
||||
[[nodiscard]] bool getReturnTargetModelAcceptedLogits() const
|
||||
{
|
||||
return mReturnGenerationLogits && (getNumDraftTokens() > 0);
|
||||
}
|
||||
|
||||
[[nodiscard]] TensorPtr const& getContextLogitsHost() const
|
||||
@ -701,7 +740,8 @@ public:
|
||||
auto maxNbTokens = getMaxBeamNumTokens();
|
||||
// FIXME(nkorobov): For streaming we do not allow beam search and
|
||||
// streaming index calculation here applies only for sampling
|
||||
int nbTokensOut = mIsStreaming ? 1 : maxNbTokens;
|
||||
// getNumTokensPerIteration takes accepted draft tokens into account
|
||||
int nbTokensOut = mIsStreaming ? std::max(getNumTokensPerIteration(), 1) : maxNbTokens;
|
||||
if (mExcludeInputFromOutput && !mIsStreaming)
|
||||
{
|
||||
nbTokensOut -= getOrigPromptLen();
|
||||
@ -722,6 +762,11 @@ public:
|
||||
{
|
||||
auto tokens = getTokens(beam);
|
||||
auto nbTokens = mIsStreaming ? (tokenPos - getMaxSentTokenPos()) : tokens.size();
|
||||
|
||||
// Take accepted draft tokens into account when streaming
|
||||
auto const numAcceptedTokens = std::max(0, getNumTokensPerIteration() - 1);
|
||||
nbTokens += mIsStreaming ? numAcceptedTokens : 0;
|
||||
|
||||
if (mExcludeInputFromOutput && !mIsStreaming)
|
||||
{
|
||||
nbTokens -= getOrigPromptLen();
|
||||
@ -731,6 +776,8 @@ public:
|
||||
result.outputTokenIds.at(beam).assign(
|
||||
tokens.data() + tokenPos, tokens.data() + tokenPos + nbTokens);
|
||||
}
|
||||
// Correct next token position by accepted draft tokens
|
||||
tokenPos += numAcceptedTokens;
|
||||
}
|
||||
|
||||
if (returnLogProbs())
|
||||
|
||||
@ -19,7 +19,9 @@
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/workerPool.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
#include <NvInferRuntimeBase.h>
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
@ -61,7 +61,7 @@ namespace utils
|
||||
std::vector<uint8_t> loadEngine(std::string const& enginePath);
|
||||
}
|
||||
|
||||
class IpcMemory;
|
||||
class AllReduceBuffers;
|
||||
class IStatefulGptDecoder;
|
||||
class NcclCommunicator;
|
||||
class RuntimeBuffers;
|
||||
@ -229,6 +229,12 @@ public:
|
||||
void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig,
|
||||
std::shared_ptr<GenerationProfiler> const generationProfiler = nullptr);
|
||||
|
||||
//! @brief Set LayerProfiler to collect performance per layer.
|
||||
void setLayerProfiler();
|
||||
|
||||
//! @brief Print profile information per layer.
|
||||
[[nodiscard]] std::string getLayerProfileInfo() const;
|
||||
|
||||
private:
|
||||
[[nodiscard]] bool useCudaGraphs()
|
||||
{
|
||||
@ -349,9 +355,7 @@ private:
|
||||
std::shared_ptr<CudaStream> mCommStream;
|
||||
CudaEvent mCommEvent{};
|
||||
|
||||
// tensor parallelism with custom allreduce plugin
|
||||
ITensor::SharedPtr mCommPtrs;
|
||||
std::vector<std::shared_ptr<IpcMemory>> mIpcMemoryHandles;
|
||||
std::shared_ptr<AllReduceBuffers> mAllReduceBuffers;
|
||||
|
||||
SizeType mDecoderMaxSequenceLength{};
|
||||
SizeType mDecoderMaxAttentionWindow{};
|
||||
|
||||
@ -17,39 +17,56 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "tensorrt_llm/kernels/customAllReduceKernels.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
void setPeerAccess(WorldConfig const& worldConfig, bool enable = true);
|
||||
|
||||
class IpcMemory
|
||||
{
|
||||
public:
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
using BufferPtr = IBuffer::SharedPtr;
|
||||
|
||||
// MAX_ALL_REDUCE_BLOCKS for block_barrier, 1 for multi_gpu_barrier
|
||||
size_t static constexpr FLAGS_SIZE = (kernels::MAX_ALL_REDUCE_BLOCKS + 1) * sizeof(uint32_t);
|
||||
|
||||
IpcMemory(WorldConfig const& worldConfig, std::size_t bufferSize);
|
||||
IpcMemory(std::size_t bufferSize, BufferManager const& manager, WorldConfig const& worldConfig);
|
||||
~IpcMemory();
|
||||
|
||||
[[nodiscard]] std::vector<void*> const& getCommPtrsTensor() const
|
||||
IpcMemory(IpcMemory const&) = delete;
|
||||
IpcMemory& operator=(IpcMemory const&) = delete;
|
||||
|
||||
IpcMemory(IpcMemory&&) = default;
|
||||
IpcMemory& operator=(IpcMemory&&) = default;
|
||||
|
||||
[[nodiscard]] std::vector<void*> const& getCommPtrs() const
|
||||
{
|
||||
return mCommPtrs;
|
||||
}
|
||||
|
||||
private:
|
||||
void allocateIpcMemory();
|
||||
void allocateIpcMemory(std::size_t bufferSize, BufferManager const& manager, WorldConfig const& worldConfig);
|
||||
void destroyIpcMemory();
|
||||
|
||||
WorldConfig mWorldConfig;
|
||||
SizeType mTpRank;
|
||||
std::vector<void*> mCommPtrs;
|
||||
std::size_t mBufferSize;
|
||||
void* mBufferPtr{nullptr};
|
||||
BufferPtr mBuffer;
|
||||
};
|
||||
|
||||
class AllReduceBuffers
|
||||
{
|
||||
public:
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
|
||||
AllReduceBuffers(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength, SizeType hiddenSize,
|
||||
BufferManager const& manager, WorldConfig const& worldConfig);
|
||||
|
||||
TensorPtr mAllReduceCommPtrs;
|
||||
std::vector<runtime::IpcMemory> mIpcMemoryHandles;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -23,7 +23,9 @@
|
||||
#include "tensorrt_llm/runtime/loraModule.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
#include <NvInferRuntimeBase.h>
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include <deque>
|
||||
#include <list>
|
||||
#include <map>
|
||||
|
||||
@ -19,7 +19,9 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
#include <NvInferRuntimeBase.h>
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
|
||||
@ -25,21 +25,34 @@
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
struct MambaConfig
|
||||
{
|
||||
SizeType dState = 0;
|
||||
SizeType dConv = 0;
|
||||
SizeType expand = 0;
|
||||
};
|
||||
|
||||
class ModelConfig
|
||||
{
|
||||
public:
|
||||
enum class ModelVariant : std::int32_t
|
||||
{
|
||||
kGpt = 0,
|
||||
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
|
||||
kMamba = 2, // https://github.com/state-spaces/mamba
|
||||
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
|
||||
kMamba = 2, // https://github.com/state-spaces/mamba
|
||||
kRecurrentGemma = 3, // https://github.com/google-deepmind/recurrentgemma
|
||||
};
|
||||
|
||||
struct MambaConfig
|
||||
{
|
||||
SizeType dState = 0;
|
||||
SizeType dConv = 0;
|
||||
SizeType expand = 0;
|
||||
};
|
||||
|
||||
struct RnnConfig
|
||||
{
|
||||
SizeType dConv = 0;
|
||||
SizeType hiddenSize = 0;
|
||||
};
|
||||
|
||||
enum class LayerType : std::int32_t
|
||||
{
|
||||
kATTENTION,
|
||||
kRECURRENT,
|
||||
};
|
||||
|
||||
explicit ModelConfig(SizeType vocabSize, SizeType nbAttentionLayers, SizeType nbSsmLayers, SizeType nbHeads,
|
||||
@ -478,7 +491,8 @@ public:
|
||||
|
||||
[[nodiscard]] bool constexpr isTransformerBased() const noexcept
|
||||
{
|
||||
return mModelVariant == ModelVariant::kGpt || mModelVariant == ModelVariant::kGlm;
|
||||
return mModelVariant == ModelVariant::kGpt || mModelVariant == ModelVariant::kGlm
|
||||
|| mModelVariant == ModelVariant::kRecurrentGemma;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool hasMambaConfig() const noexcept
|
||||
@ -498,7 +512,32 @@ public:
|
||||
|
||||
[[nodiscard]] bool constexpr isSsmBased() const noexcept
|
||||
{
|
||||
return mModelVariant == ModelVariant::kMamba;
|
||||
return mModelVariant == ModelVariant::kMamba || mModelVariant == ModelVariant::kRecurrentGemma;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool hasRnnConfig() const noexcept
|
||||
{
|
||||
return mRnnConfig.has_value();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::optional<RnnConfig> getRnnConfig() const noexcept
|
||||
{
|
||||
return mRnnConfig;
|
||||
}
|
||||
|
||||
void setRnnConfig(RnnConfig const& rnnConfig) noexcept
|
||||
{
|
||||
mRnnConfig = rnnConfig;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<LayerType> const& getLayerTypes() const noexcept
|
||||
{
|
||||
return mLayerTypes;
|
||||
}
|
||||
|
||||
void setLayerTypes(std::vector<LayerType> const& layerTypes) noexcept
|
||||
{
|
||||
mLayerTypes = layerTypes;
|
||||
}
|
||||
|
||||
private:
|
||||
@ -548,6 +587,10 @@ private:
|
||||
bool mUsePositionEmbedding;
|
||||
bool mUseTokenTypeEmbedding;
|
||||
SizeType mFfnHiddenSize; // indicates encoder output hidden size
|
||||
|
||||
std::optional<RnnConfig> mRnnConfig;
|
||||
|
||||
std::vector<LayerType> mLayerTypes;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
cpp/include/tensorrt_llm/runtime/speculativeDecodingMode.h (new file, 141 lines)
@@ -0,0 +1,141 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
|
||||
class SpeculativeDecodingMode
|
||||
{
|
||||
// [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/models/modeling_utils.py
|
||||
public:
|
||||
static auto constexpr None()
|
||||
{
|
||||
return SpeculativeDecodingMode{kNone};
|
||||
}
|
||||
|
||||
static auto constexpr DraftModel()
|
||||
{
|
||||
return SpeculativeDecodingMode{kDraftModel};
|
||||
}
|
||||
|
||||
static auto constexpr Medusa()
|
||||
{
|
||||
return SpeculativeDecodingMode{kMedusa};
|
||||
}
|
||||
|
||||
static auto constexpr LookaheadDecoding()
|
||||
{
|
||||
return SpeculativeDecodingMode{kLookaheadDecoding};
|
||||
}
|
||||
|
||||
bool constexpr isNone() const
|
||||
{
|
||||
return anyBitSet(kNone);
|
||||
}
|
||||
|
||||
bool constexpr isDraftModel() const
|
||||
{
|
||||
return anyBitSet(kDraftModel);
|
||||
}
|
||||
|
||||
bool constexpr isMedusa() const
|
||||
{
|
||||
return anyBitSet(kMedusa);
|
||||
}
|
||||
|
||||
bool constexpr isLookaheadDecoding() const
|
||||
{
|
||||
return anyBitSet(kLookaheadDecoding);
|
||||
}
|
||||
|
||||
bool constexpr requiresAttentionMask() const
|
||||
{
|
||||
return anyBitSet(kLookaheadDecoding | kMedusa);
|
||||
}
|
||||
|
||||
bool constexpr predictsDraftTokens() const
|
||||
{
|
||||
return anyBitSet(kLookaheadDecoding | kMedusa);
|
||||
}
|
||||
|
||||
bool constexpr needsKVCacheRewind() const
|
||||
{
|
||||
return anyBitSet(kLookaheadDecoding | kMedusa);
|
||||
}
|
||||
|
||||
bool constexpr hasDraftLogits() const
|
||||
{
|
||||
return anyBitSet(kMedusa);
|
||||
}
|
||||
|
||||
using UnderlyingType = uint8_t;
|
||||
|
||||
bool operator==(SpeculativeDecodingMode const& other) const
|
||||
{
|
||||
return mState == other.mState;
|
||||
}
|
||||
|
||||
constexpr SpeculativeDecodingMode(UnderlyingType state)
|
||||
: mState(state)
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
// No speculative decoding is used.
|
||||
static UnderlyingType constexpr kNone{1u << 0};
|
||||
static UnderlyingType constexpr kDraftModel{1u << 1};
|
||||
static UnderlyingType constexpr kMedusa{1u << 2};
|
||||
static UnderlyingType constexpr kLookaheadDecoding{1u << 3};
|
||||
|
||||
bool constexpr anyBitSet(UnderlyingType bits) const
|
||||
{
|
||||
return (mState & bits) != 0;
|
||||
}
|
||||
|
||||
bool constexpr allBitSet(UnderlyingType bits) const
|
||||
{
|
||||
return (mState & bits) == bits;
|
||||
}
|
||||
|
||||
UnderlyingType mState{kNone};
|
||||
};
|
||||
|
||||
static_assert(SpeculativeDecodingMode::None().isNone());
|
||||
static_assert(!SpeculativeDecodingMode::None().isDraftModel());
|
||||
static_assert(!SpeculativeDecodingMode::None().isMedusa());
|
||||
static_assert(!SpeculativeDecodingMode::None().isLookaheadDecoding());
|
||||
|
||||
static_assert(SpeculativeDecodingMode::DraftModel().isDraftModel());
|
||||
static_assert(!SpeculativeDecodingMode::DraftModel().isNone());
|
||||
static_assert(!SpeculativeDecodingMode::DraftModel().isMedusa());
|
||||
static_assert(!SpeculativeDecodingMode::DraftModel().isLookaheadDecoding());
|
||||
|
||||
static_assert(SpeculativeDecodingMode::Medusa().isMedusa());
|
||||
static_assert(!SpeculativeDecodingMode::Medusa().isNone());
|
||||
static_assert(!SpeculativeDecodingMode::Medusa().isDraftModel());
|
||||
static_assert(!SpeculativeDecodingMode::Medusa().isLookaheadDecoding());
|
||||
|
||||
static_assert(SpeculativeDecodingMode::LookaheadDecoding().isLookaheadDecoding());
|
||||
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isNone());
|
||||
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isDraftModel());
|
||||
static_assert(!SpeculativeDecodingMode::LookaheadDecoding().isMedusa());
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace tensorrt_llm
|
||||
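Since `speculativeDecodingMode.h` is a new, self-contained header, a small standalone sketch (not part of this commit) shows how its bit-flag queries compose; the only assumption is that the new header is on the include path.

```cpp
#include <cassert>

#include "tensorrt_llm/runtime/speculativeDecodingMode.h"

int main()
{
    using tensorrt_llm::runtime::SpeculativeDecodingMode;

    auto const mode = SpeculativeDecodingMode::Medusa();

    // Medusa predicts draft tokens, needs an attention mask and KV-cache rewind,
    // and is the only mode that carries draft logits.
    assert(mode.predictsDraftTokens());
    assert(mode.requiresAttentionMask());
    assert(mode.needsKVCacheRewind());
    assert(mode.hasDraftLogits());

    // The "None" mode sets only its own bit, so every capability query is false.
    assert(!SpeculativeDecodingMode::None().predictsDraftTokens());

    return 0;
}
```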
@ -35,11 +35,7 @@ add_subdirectory(layers)
|
||||
add_subdirectory(runtime)
|
||||
add_subdirectory(executor_worker)
|
||||
|
||||
set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static)
|
||||
set(BATCH_MANAGER_TARGET_ARCH "unknown")
|
||||
|
||||
set(EXECUTOR_TARGET tensorrt_llm_executor_static)
|
||||
set(EXECUTOR_TARGET_ARCH "unknown")
|
||||
set(TARGET_ARCH "unknown")
|
||||
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
||||
if(NOT WIN32) # Linux
|
||||
@ -58,11 +54,9 @@ if(NOT WIN32) # Linux
|
||||
message(STATUS "Operating System: ${OS_ID}, ${OS_VERSION_ID}")
|
||||
|
||||
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
|
||||
set(BATCH_MANAGER_TARGET_ARCH "x86_64-linux-gnu")
|
||||
set(EXECUTOR_TARGET_ARCH "x86_64-linux-gnu")
|
||||
set(TARGET_ARCH "x86_64-linux-gnu")
|
||||
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
|
||||
set(BATCH_MANAGER_TARGET_ARCH "aarch64-linux-gnu")
|
||||
set(EXECUTOR_TARGET_ARCH "aarch64-linux-gnu")
|
||||
set(TARGET_ARCH "aarch64-linux-gnu")
|
||||
if(NOT ${OS_ID} MATCHES "ubuntu" OR ${OS_VERSION_ID} VERSION_LESS 22.04)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
@ -76,8 +70,7 @@ if(NOT WIN32) # Linux
|
||||
else() # Windows
|
||||
# AMD64, IA64, ARM64, EM64T, X86
|
||||
if(CMAKE_SYSTEM_PROCESSOR MATCHES "AMD64")
|
||||
set(BATCH_MANAGER_TARGET_ARCH "x86_64-windows-msvc")
|
||||
set(EXECUTOR_TARGET_ARCH "x86_64-windows-msvc")
|
||||
set(TARGET_ARCH "x86_64-windows-msvc")
|
||||
else()
|
||||
message(
|
||||
FATAL_ERROR
|
||||
@ -85,6 +78,9 @@ else() # Windows
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(BATCH_MANAGER_TARGET tensorrt_llm_batch_manager_static)
|
||||
set(BATCH_MANAGER_TARGET_ARCH ${TARGET_ARCH})
|
||||
|
||||
if(BUILD_BATCH_MANAGER)
|
||||
add_subdirectory(batch_manager)
|
||||
else()
|
||||
@ -115,6 +111,9 @@ else()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(EXECUTOR_TARGET tensorrt_llm_executor_static)
|
||||
set(EXECUTOR_TARGET_ARCH ${TARGET_ARCH})
|
||||
|
||||
if(BUILD_EXECUTOR)
|
||||
add_subdirectory(executor)
|
||||
else()
|
||||
@ -189,6 +188,45 @@ else()
|
||||
add_custom_target(check_symbol_executor)
|
||||
endif()
|
||||
|
||||
set(NVRTC_WRAPPER_TARGET tensorrt_llm_nvrtc_wrapper)
|
||||
set(NVRTC_WRAPPER_TARGET_ARCH ${TARGET_ARCH})
|
||||
|
||||
if(BUILD_NVRTC_WRAPPER)
|
||||
add_subdirectory(
|
||||
kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper)
|
||||
else()
|
||||
add_library(${NVRTC_WRAPPER_TARGET} SHARED IMPORTED)
|
||||
if(NOT WIN32) # Linux
|
||||
set(NVRTC_WRAPPER_LIB_SOURCE_REL_LOC
|
||||
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${NVRTC_WRAPPER_TARGET_ARCH}/libtensorrt_llm_nvrtc_wrapper.so"
|
||||
)
|
||||
set(NVRTC_WRAPPER_LIB_BINARY_REL_LOC
|
||||
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/libtensorrt_llm_nvrtc_wrapper.so"
|
||||
)
|
||||
else()
|
||||
set(NVRTC_WRAPPER_LIB_SOURCE_REL_LOC
|
||||
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${NVRTC_WRAPPER_TARGET_ARCH}/libtensorrt_llm_nvrtc_wrapper.dll"
|
||||
)
|
||||
set(NVRTC_WRAPPER_LIB_BINARY_REL_LOC
|
||||
"kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/libtensorrt_llm_nvrtc_wrapper.dll"
|
||||
)
|
||||
endif()
|
||||
set(NVRTC_WRAPPER_LIB_LOC
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/${NVRTC_WRAPPER_LIB_SOURCE_REL_LOC}")
|
||||
# Copy the .so to build directory, which is needed in build_wheel.py.
|
||||
configure_file(${NVRTC_WRAPPER_LIB_SOURCE_REL_LOC}
|
||||
${NVRTC_WRAPPER_LIB_BINARY_REL_LOC} COPYONLY)
|
||||
set_property(TARGET ${NVRTC_WRAPPER_TARGET} PROPERTY IMPORTED_LOCATION
|
||||
${NVRTC_WRAPPER_LIB_LOC})
|
||||
file(SIZE ${NVRTC_WRAPPER_LIB_LOC} NVRTC_WRAPPER_LIB_SIZE)
|
||||
if(NVRTC_WRAPPER_LIB_SIZE LESS 1024)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"The nvrtc wrapper library is truncated or incomplete. This is usually caused by using Git LFS (Large File Storage) incorrectly. Please try running command `git lfs install && git lfs pull`."
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(TRTLLM_LINK_LIBS
|
||||
${CUBLAS_LIB}
|
||||
${CUBLASLT_LIB}
|
||||
@ -247,6 +285,13 @@ target_link_libraries(${BATCH_MANAGER_TARGET} INTERFACE ${SHARED_TARGET})
|
||||
# Cyclic dependency of executor on TRT-LLM
|
||||
target_link_libraries(${EXECUTOR_TARGET} INTERFACE ${SHARED_TARGET})
|
||||
|
||||
if(NOT WIN32)
|
||||
set_target_properties(${SHARED_TARGET} PROPERTIES LINK_FLAGS
|
||||
"-Wl,-rpath='$ORIGIN'")
|
||||
endif()
|
||||
|
||||
target_link_libraries(${SHARED_TARGET} PUBLIC ${NVRTC_WRAPPER_TARGET})
|
||||
|
||||
add_dependencies(${SHARED_TARGET} check_symbol)
|
||||
add_dependencies(${SHARED_TARGET} check_symbol_executor)
|
||||
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3a3c08bd9777149ddf546c2bd02fa78ec0d8a10e7e51fb05f29e63f089caffa9
|
||||
size 3215202
|
||||
oid sha256:97866290105b98bc63d2d38c7176b8e2d79969c99f9c456b04428fef81bd8780
|
||||
size 3309008
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:57b677069d5673dfba53aa2ff89240320f72f21707865f73fe29ce74a36f9a57
|
||||
size 3257948
|
||||
oid sha256:891a0a6f2053b011ba2c58101b279ab583442ff3585f01c919e25a26e75e51d1
|
||||
size 3353702
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
e33ec506a35e58225744944654645de5 libtensorrt_llm_batch_manager_static.a
|
||||
e0e0525dc521f70ba9b2f19638d82187 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
0ff5eb9f3ac62b2672bef68a7117bdef779926e7 commit
|
||||
ba4b89ea4ddf64403656d3626559ceae libtensorrt_llm_batch_manager_static.a
|
||||
decbd28e89ac740f9755b2b2537fa71b libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
942b83732d029cc3eaef9f5a849218d75161ec12 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:860ce68e8062b45dd15160834a5f223da1f3ae205caca5e8a99ce0037a55c900
|
||||
size 3117888
|
||||
oid sha256:04326319261c7b196048535872990497461eed46ed4b989a31527c2ef9ef8c92
|
||||
size 3205910
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3eacf70f4b6b0f959c7b5b29a2f17d2d0f40283334e2decc6ea8ac67eb3523b7
|
||||
size 3097564
|
||||
oid sha256:8dffe215e14b2f67af2e8a77ecb8281db3fe54cc5184e635904f752a7ef84a0c
|
||||
size 3185774
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fc4351557104103d44a1bc38b967e34337777e3c45b441c0057d4a16d68dc458
|
||||
size 19620324
|
||||
oid sha256:5889c4e0dd2109a30c49a554780f43415528a710bf438bf57e0b34ec5c49a695
|
||||
size 19782918
|
||||
|
||||
@ -38,6 +38,28 @@ namespace tensorrt_llm
|
||||
namespace common
|
||||
{
|
||||
|
||||
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
|
||||
{
|
||||
static std::mutex mutex;
|
||||
static std::weak_ptr<CUDADriverWrapper> instance;
|
||||
std::shared_ptr<CUDADriverWrapper> result = instance.lock();
|
||||
if (result)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
result = instance.lock();
|
||||
if (!result)
|
||||
{
|
||||
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
|
||||
instance = result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
CUDADriverWrapper::CUDADriverWrapper()
|
||||
{
|
||||
handle = dllOpen(CUDA_LIB_NAME);
|
||||
@ -63,6 +85,7 @@ CUDADriverWrapper::CUDADriverWrapper()
|
||||
*(void**) (&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
|
||||
*(void**) (&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
|
||||
*(void**) (&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
|
||||
*(void**) (&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
|
||||
}
|
||||
|
||||
CUDADriverWrapper::~CUDADriverWrapper()
|
||||
@ -153,5 +176,10 @@ CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUten
|
||||
boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
|
||||
{
|
||||
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -19,10 +19,12 @@
|
||||
|
||||
#include <cstdio>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
#define cuErrCheck(stat, wrap) \
|
||||
{ \
|
||||
cuErrCheck_((stat), wrap, __FILE__, __LINE__); \
|
||||
cuErrCheck_((stat), wrap.get(), __FILE__, __LINE__); \
|
||||
}
|
||||
|
||||
namespace tensorrt_llm
|
||||
@ -32,9 +34,12 @@ namespace common
|
||||
|
||||
class CUDADriverWrapper
|
||||
{
|
||||
public:
|
||||
// Use getInstance() instead.
|
||||
CUDADriverWrapper();
|
||||
|
||||
public:
|
||||
static std::shared_ptr<CUDADriverWrapper> getInstance();
|
||||
|
||||
~CUDADriverWrapper();
|
||||
|
||||
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
|
||||
@ -75,6 +80,8 @@ public:
|
||||
cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
|
||||
CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
|
||||
|
||||
CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;
|
||||
|
||||
private:
|
||||
void* handle;
|
||||
CUresult (*_cuGetErrorName)(CUresult, char const**);
|
||||
@ -98,14 +105,15 @@ private:
|
||||
cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
|
||||
cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
|
||||
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
|
||||
};
|
||||
|
||||
inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const& wrap, char const* file, int line)
|
||||
inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const* wrap, char const* file, int line)
|
||||
{
|
||||
if (stat != CUDA_SUCCESS)
|
||||
{
|
||||
char const* msg = nullptr;
|
||||
wrap.cuGetErrorName(stat, &msg);
|
||||
wrap->cuGetErrorName(stat, &msg);
|
||||
fprintf(stderr, "CUDA Error: %s %s %d\n", msg, file, line);
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,11 +23,8 @@ namespace tensorrt_llm::utils::customAllReduceUtils
|
||||
|
||||
constexpr size_t NUM_POINTERS_PER_RANK = 4;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
|
||||
size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
|
||||
inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
|
||||
{
|
||||
if (worldSize <= 2)
|
||||
{
|
||||
@ -35,6 +32,5 @@ size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
|
||||
}
|
||||
return 8 * 1000 * 1000;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
} // namespace tensorrt_llm::utils::customAllReduceUtils
|
||||
|
||||
@ -55,6 +55,25 @@ std::optional<int32_t> envXqaNbCtaPerKVHead()
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool getEnvDisableXQAJIT()
|
||||
{
|
||||
static bool init = false;
|
||||
static bool disableXQAJIT = false;
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
char const* disable_xqa_jit_var = std::getenv("TRTLLM_DISABLE_XQA_JIT");
|
||||
if (disable_xqa_jit_var)
|
||||
{
|
||||
if (disable_xqa_jit_var[0] == '1' && disable_xqa_jit_var[1] == '\0')
|
||||
{
|
||||
disableXQAJIT = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return disableXQAJIT;
|
||||
}
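
As a hypothetical illustration (this helper is not part of the diff), the flag lets users force the precompiled XQA cubins even when JIT compilation is available:

// Hypothetical selection helper, shown only to illustrate the intended effect of
// exporting TRTLLM_DISABLE_XQA_JIT=1 in the environment.
bool useXqaJit()
{
    return !tensorrt_llm::common::getEnvDisableXQAJIT();
}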
|
||||
|
||||
// Tune the number of blocks per sequence for accuracy/performance purpose.
|
||||
bool getEnvMmhaMultiblockDebug()
|
||||
{
|
||||
|
||||
@ -33,6 +33,9 @@ int32_t xqaMaxNbCtaPerKVHeadFactor();
|
||||
|
||||
std::optional<int32_t> envXqaNbCtaPerKVHead();
|
||||
|
||||
// Whether XQA JIT is disabled.
|
||||
bool getEnvDisableXQAJIT();
|
||||
|
||||
// Tune the number of blocks per sequence for accuracy/performance purpose.
|
||||
bool getEnvMmhaMultiblockDebug();
|
||||
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:eeda6a94352bd7bff125b1645ccd7d1e049acf4d316057f7a3adc71f38de54b0
|
||||
size 1228412
|
||||
oid sha256:418820fec34c660cf94828f74159b0856517faf21b877d0a29b6a5e7dc71ece2
|
||||
size 1235256
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3d2c63df67c83b0970032e549d477bccc6e07883bb82562df6fbaa3a7f22dbd5
|
||||
size 1247068
|
||||
oid sha256:65e3acc4d6e33b30775f3fce8c6b171c22b1842eb5a4e04fb2a109b5f56082c7
|
||||
size 1253184
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
91b15059b1b7ea4662db71c7af0abe2b libtensorrt_llm_executor_static.a
|
||||
fe5af71bf010a17fdf34c253ab187c28 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
0ff5eb9f3ac62b2672bef68a7117bdef779926e7 commit
|
||||
0d429aff4a27797c9a4b3078d59bb3d3 libtensorrt_llm_executor_static.a
|
||||
e5012c4a7e70b6d2e9d80563c26d2c83 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
942b83732d029cc3eaef9f5a849218d75161ec12 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:812ce5bd5effd252b642d31ec261e8de1e93bc71017dff91fdb84f833a66029a
|
||||
size 1249594
|
||||
oid sha256:a46eec8c1209e4499478d656fce44bce280e74fb846b669c6601c3d6ea87a21a
|
||||
size 1255814
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:01dc32257ebafd712a527d62cca4c4880a636e65ede63857b7bea62bf21b975e
|
||||
size 1204654
|
||||
oid sha256:7aa3c2841123c7db28fd0e81197e3fec59709d15d6dee8436c138c597bcec4bd
|
||||
size 1210336
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7da73ddfa6393c8f040e92c32206b96c5dab936742fdcdca8e91992c26f80146
|
||||
size 11870092
|
||||
oid sha256:740bd924898b2cdd1181d9dfe60bdba710166ab0e4e11eef424ebed3c6de8ab6
|
||||
size 11912588
|
||||
|
||||
@ -74,7 +74,7 @@ int main(int argc, char* argv[])
|
||||
// In orchestrator mode, the spawned threads will wait for termination signal from orchestrator
|
||||
auto executor = tle::Executor(modelPath, modelType, executorConfig);
|
||||
|
||||
TLLM_LOG_INFO("Executor worker exiting");
|
||||
TLLM_LOG_INFO("Executor instance created by worker");
|
||||
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
|
||||
@ -72,7 +72,8 @@ public:
|
||||
|
||||
TFusedMultiHeadAttentionXMMAKernel(
|
||||
TKernelMeta const* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm)
|
||||
: mDataType(type)
|
||||
: mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
|
||||
, mDataType(type)
|
||||
, mKernelMeta(pMetaStart)
|
||||
, mKernelMetaCount(nMetaCount)
|
||||
, mSM(sm)
|
||||
@ -99,16 +100,17 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
cuErrCheck(mDriver.cuModuleLoadData(&hmod, kernelMeta.mCubin), mDriver);
|
||||
cuErrCheck(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin), mDriver);
|
||||
mModules.insert(std::make_pair(kernelMeta.mCubin, hmod));
|
||||
}
|
||||
|
||||
FusedMultiHeadAttentionKernelInfo funcInfo;
|
||||
funcInfo.mMetaInfoIndex = i;
|
||||
cuErrCheck(mDriver.cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName), mDriver);
|
||||
cuErrCheck(
|
||||
mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName), mDriver);
|
||||
if (kernelMeta.mSharedMemBytes >= 48 * 1024)
|
||||
{
|
||||
cuErrCheck(mDriver.cuFuncSetAttribute(funcInfo.mDeviceFunction,
|
||||
cuErrCheck(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
|
||||
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, kernelMeta.mSharedMemBytes),
|
||||
mDriver);
|
||||
}
|
||||
@ -133,7 +135,7 @@ public:
|
||||
const CUfunction func = findIter->second.mDeviceFunction;
|
||||
|
||||
void* kernelParams[] = {¶ms, nullptr};
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
}
|
||||
@ -143,7 +145,7 @@ public:
|
||||
virtual ~TFusedMultiHeadAttentionXMMAKernel() = default;
|
||||
|
||||
protected:
|
||||
tensorrt_llm::common::CUDADriverWrapper mDriver;
|
||||
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;
|
||||
|
||||
Data_type mDataType;
|
||||
TKernelMeta const* mKernelMeta;
|
||||
@ -306,7 +308,7 @@ public:
|
||||
|
||||
if (!forceUnroll)
|
||||
{
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
} // forceunroll = true for flash attention kernels
|
||||
@ -357,8 +359,8 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, block_size.x, block_size.y, block_size.z, kernelMeta.mThreadsPerCTA,
|
||||
1, 1, kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, block_size.x, block_size.y, block_size.z,
|
||||
kernelMeta.mThreadsPerCTA, 1, 1, kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
}
|
||||
else
|
||||
@ -374,13 +376,13 @@ public:
|
||||
// on Hopper non-flash-attention, we still launch blocks (h, b, steps)
|
||||
if (mSM == kSM_90 && !launch_params.flash_attention)
|
||||
{
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
} // on Ampere/Ada/Volta flash attention, we launch blocks (steps, h, b)
|
||||
else
|
||||
{
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, unroll, params.h, params.b, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, unroll, params.h, params.b, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
}
|
||||
|
||||
cpp/tensorrt_llm/kernels/cumsumLastDim.cu (new file, 74 lines)
@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include <cub/device/device_scan.cuh>
|
||||
|
||||
#include "cumsumLastDim.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
template <typename input_t>
|
||||
size_t invokeComputeCumsumLastDimWorkspaceSize(int input_length)
|
||||
{
|
||||
input_t* iodata = nullptr;
|
||||
size_t temp_storage_bytes;
|
||||
cub::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, iodata, iodata, input_length);
|
||||
return temp_storage_bytes;
|
||||
}
|
||||
|
||||
#define INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(input_t) \
|
||||
template size_t invokeComputeCumsumLastDimWorkspaceSize<input_t>(int input_length)
|
||||
|
||||
INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(int);
|
||||
INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(float);
|
||||
INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(half);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE(__nv_bfloat16);
|
||||
#endif
|
||||
#undef INSTANTIATE_COMPUTE_CUMSUM_LastDim_WORKSPACE_SIZE_DATA_TYPE
|
||||
|
||||
///////////////
|
||||
|
||||
template <typename input_t>
|
||||
void invokeCumsumLastDim(int batch_size, int input_length, void const* __restrict__ input, void* __restrict__ output,
|
||||
void* d_temp_storage, size_t temp_storage_bytes, cudaStream_t stream)
|
||||
{
|
||||
for (int i = 0; i < batch_size; i++)
|
||||
{
|
||||
input_t const* input_ptr = reinterpret_cast<input_t const*>(input) + i * input_length;
|
||||
input_t* output_ptr = reinterpret_cast<input_t*>(output) + i * input_length;
|
||||
cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, input_ptr, output_ptr, input_length, stream);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(input_t) \
|
||||
template void invokeCumsumLastDim<input_t>(int batch_size, int input_length, const void* __restrict__ input, \
|
||||
void* __restrict__ output, void* workspace, size_t temp_storage_bytes, cudaStream_t stream)
|
||||
|
||||
INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(int);
|
||||
INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(float);
|
||||
INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(half);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_CUMSUM_LastDim_DATA_TYPE(__nv_bfloat16);
|
||||
#endif
|
||||
#undef INSTANTIATE_CUMSUM_LastDim_DATA_TYPE
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
cpp/tensorrt_llm/kernels/cumsumLastDim.h (new file, 35 lines)
@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template <typename input_t>
|
||||
size_t invokeComputeCumsumLastDimWorkspaceSize(int input_length);
|
||||
|
||||
template <typename input_t>
|
||||
void invokeCumsumLastDim(int batch_size, int input_length, void const* __restrict__ input, void* __restrict__ output,
|
||||
void* workspace, size_t temp_storage_bytes, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
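
A minimal host-side usage sketch for the two entry points above; the wrapper function, the float element type, and the include path are illustrative, and error checking is omitted.

#include <cuda_runtime_api.h>

#include "tensorrt_llm/kernels/cumsumLastDim.h"

// Illustrative helper: inclusive cumsum over the last dimension of a [batch_size, input_length] tensor.
void cumsumRows(float const* d_in, float* d_out, int batch_size, int input_length, cudaStream_t stream)
{
    using namespace tensorrt_llm::kernels;
    // 1) Ask CUB how much temporary storage one row of `input_length` elements needs.
    size_t temp_bytes = invokeComputeCumsumLastDimWorkspaceSize<float>(input_length);
    // 2) Allocate the workspace.
    void* d_temp = nullptr;
    cudaMalloc(&d_temp, temp_bytes);
    // 3) Scan each row; the helper loops over the batch rows on the host.
    invokeCumsumLastDim<float>(batch_size, input_length, d_in, d_out, d_temp, temp_bytes, stream);
    cudaFree(d_temp);
}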
|
||||
@ -525,7 +525,8 @@ size_t MoeGemmRunner<T, WeightType>::calcMaxWorkspaceSize(int num_experts) const
|
||||
return max_size;
|
||||
}
|
||||
|
||||
assert(false); // Unreachable
|
||||
TLLM_CHECK_WITH_INFO(false, "Unsupported MoE GEMM configuration"); // Unreachable
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
|
||||
@ -123,7 +123,11 @@ struct Multihead_attention_params_base
|
||||
float rotary_embedding_base = 0.0f;
|
||||
RotaryScalingType rotary_embedding_scale_type = RotaryScalingType::kNONE;
|
||||
float rotary_embedding_scale = 0.0f;
|
||||
float rotary_embedding_m_scale = 0.0f;
|
||||
float const* rotary_embedding_scaling_factors = nullptr;
|
||||
int rotary_embedding_max_positions = 0;
|
||||
int rotary_cogvlm_vision_start = -1;
|
||||
int rotary_cogvlm_vision_length = -1;
|
||||
// Position shift for streamingllm
|
||||
bool position_shift_enabled = false;
|
||||
// The current timestep. TODO: check whether we only need this param for cross attention.
|
||||
|
||||
@ -18,6 +18,9 @@
|
||||
file(GLOB_RECURSE SRC_CPP *.cpp)
|
||||
file(GLOB_RECURSE SRC_CU *.cu)
|
||||
|
||||
# Exclude files in nvrtcWrapper folder.
|
||||
list(FILTER SRC_CPP EXCLUDE REGEX ".*nvrtcWrapper/src.*")
|
||||
|
||||
# skip mmha 48, 80, 96, 112, 144, 160, 192 and 224 for fast build
|
||||
if(FAST_BUILD)
|
||||
list(FILTER SRC_CU EXCLUDE REGEX
|
||||
|
||||
File diff suppressed because it is too large
@ -72,7 +72,8 @@ inline size_t smem_size_in_bytes(Multihead_attention_params<T, DO_CROSS_ATTENTIO
|
||||
size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(Tk) / 2;
|
||||
|
||||
size_t transpose_rotary_size = 0;
|
||||
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX)
|
||||
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|
||||
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE)
|
||||
{
|
||||
assert(params.rotary_embedding_dim > 0);
|
||||
transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(Tk);
|
||||
@ -365,7 +366,8 @@ void mmha_launch_kernel(KernelParamsType const& params, KVCacheBuffer const& kv_
|
||||
{
|
||||
assert((params.rotary_embedding_dim != 0)
|
||||
== (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|
||||
|| params.position_embedding_type == PositionEmbeddingType::kROPE_GPTJ));
|
||||
|| params.position_embedding_type == PositionEmbeddingType::kROPE_GPTJ
|
||||
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE));
|
||||
if (params.beam_width == 1)
|
||||
{
|
||||
mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, false, IMPLICIT_REL_ATTN_BIAS>(
|
||||
|
||||
@ -1639,15 +1639,16 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
|
||||
if (HANDLE_KV)
|
||||
{
|
||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_embedding_base,
|
||||
rotary_embedding_scale, current_pos_idx);
|
||||
rotary_embedding_scale, 0, nullptr, current_pos_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_rotary_embedding(
|
||||
q, tidx, params.rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, current_pos_idx);
|
||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale,
|
||||
0, nullptr, current_pos_idx);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case PositionEmbeddingType::kLONG_ROPE:
|
||||
case PositionEmbeddingType::kROPE_GPT_NEOX:
|
||||
{
|
||||
bool const do_rotary = is_valid_qk_vec && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
|
||||
@ -1683,14 +1684,18 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
|
||||
mmha::vec_from_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
|
||||
|
||||
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_embedding_scale, current_pos_idx);
|
||||
rotary_embedding_base, rotary_embedding_scale, params.rotary_embedding_m_scale,
|
||||
params.rotary_embedding_scaling_factors, current_pos_idx, params.rotary_cogvlm_vision_start,
|
||||
params.rotary_cogvlm_vision_length);
|
||||
|
||||
mmha::write_smem_transpose(k, k_smem_, transpose_idx, smem_pitch);
|
||||
}
|
||||
else
|
||||
{
|
||||
mmha::apply_rotary_embedding(q, transpose_idx / tidx_factor, params.rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_embedding_scale, current_pos_idx);
|
||||
rotary_embedding_base, rotary_embedding_scale, params.rotary_embedding_m_scale,
|
||||
params.rotary_embedding_scaling_factors, current_pos_idx, params.rotary_cogvlm_vision_start,
|
||||
params.rotary_cogvlm_vision_length);
|
||||
}
|
||||
mmha::write_smem_transpose(q, q_smem_, transpose_idx, smem_pitch);
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
*/
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.h"
|
||||
|
||||
#include <cassert>
|
||||
@ -44,12 +45,10 @@ std::unique_ptr<DecoderXQAImpl> DecoderXQAImpl::create(DecoderXQARunner* runner,
|
||||
switch (implType)
|
||||
{
|
||||
case ImplType::kPrecompiled: return std::unique_ptr<DecoderXQAImpl>(new DecoderXQAImplPrecompiled(runner));
|
||||
// TODO(minwei): JIT impl.
|
||||
case ImplType::kJIT: return nullptr;
|
||||
case ImplType::kJIT: return std::unique_ptr<DecoderXQAImpl>(new DecoderXQAImplJIT(runner));
|
||||
}
|
||||
// Shouldn't reach here.
|
||||
assert(false);
|
||||
return nullptr;
|
||||
TLLM_THROW("Unknown DecoderXQAImpl::ImplType");
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
@ -33,15 +33,15 @@ class DecoderXQARunner;
|
||||
* We need this layer of abstraction to hide implementation details. Two possible implementations:
|
||||
* 1. Precompiled, i.e. kernels are compiled and saved as cubins in advance.
|
||||
* 2. JIT, i.e. kernels are compiled on the fly via NVRTC.
|
||||
*
|
||||
* This class is written as Composition over Inheritance, primarily because C++ does not support virtual template
|
||||
* functions.
|
||||
*/
|
||||
class DecoderXQAImpl
|
||||
{
|
||||
public:
|
||||
// TODO(minwei): shouldUse()/prepare() should be templated with KVCacheBuffer.
|
||||
// Whether it is beneficial to use this XQA codepath.
|
||||
virtual bool shouldUse(XQAParams const& xqaParams) = 0;
|
||||
//
|
||||
// forConfigurePlugin: whether this method is called in configure plugin phase.
|
||||
virtual bool shouldUse(XQAParams const& xqaParams, bool forConfigurePlugin) = 0;
|
||||
// Prepares for the kernel running. Must be called before calling run.
|
||||
virtual void prepare(XQAParams const& xqa_params) = 0;
|
||||
// Run XQA kernel with KVCacheBuffer.
|
||||
|
||||
@ -0,0 +1,311 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
* Common utils to be shared between Precompiled and JIT implementation.
|
||||
*/
|
||||
#pragma once
|
||||
#include "decoderXQAConstants.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/common/workspace.h"
|
||||
#include "tensorrt_llm/kernels/kvCacheUtils.h"
|
||||
#include "tensorrt_llm/kernels/multiHeadAttentionCommon.h"
|
||||
#include "xqaParams.h"
|
||||
#include <cstddef>
|
||||
#include <utility>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
struct XQAKernelLoadHashKey
|
||||
{
|
||||
Data_type data_type;
|
||||
unsigned int sm;
|
||||
|
||||
bool operator==(XQAKernelLoadHashKey const& other) const
|
||||
{
|
||||
return data_type == other.data_type && sm == other.sm;
|
||||
}
|
||||
};
|
||||
|
||||
struct XQAKernelLoadHasher
|
||||
{
|
||||
size_t operator()(XQAKernelLoadHashKey const& s) const
|
||||
{
|
||||
size_t key = s.data_type;
|
||||
key <<= 16;
|
||||
key ^= s.sm;
|
||||
return key;
|
||||
}
|
||||
};
|
||||
|
||||
struct XQAKernelRuntimeHashKey
|
||||
{
|
||||
Data_type kv_data_type;
|
||||
unsigned int head_size;
|
||||
unsigned int beam_size;
|
||||
unsigned int num_q_heads_per_kv;
|
||||
unsigned int m_tilesize;
|
||||
unsigned int tokens_per_page;
|
||||
bool paged_kv_cache;
|
||||
bool multi_query_tokens;
|
||||
|
||||
bool operator==(XQAKernelRuntimeHashKey const& other) const
|
||||
{
|
||||
return kv_data_type == other.kv_data_type && head_size == other.head_size
|
||||
&& num_q_heads_per_kv == other.num_q_heads_per_kv && beam_size == other.beam_size
|
||||
&& multi_query_tokens == other.multi_query_tokens && m_tilesize == other.m_tilesize
|
||||
&& tokens_per_page == other.tokens_per_page && paged_kv_cache == other.paged_kv_cache;
|
||||
}
|
||||
};
|
||||
|
||||
struct XQAKernelRuntimeHasher
|
||||
{
|
||||
size_t operator()(XQAKernelRuntimeHashKey const& s) const
|
||||
{
|
||||
size_t key = s.kv_data_type;
|
||||
key <<= 16;
|
||||
key ^= s.head_size;
|
||||
key <<= 8;
|
||||
key ^= s.num_q_heads_per_kv;
|
||||
key <<= 8;
|
||||
key ^= s.beam_size;
|
||||
key <<= 6;
|
||||
key ^= s.m_tilesize;
|
||||
key <<= 10;
|
||||
key ^= s.tokens_per_page;
|
||||
key <<= 1;
|
||||
key ^= s.paged_kv_cache;
|
||||
key <<= 1;
|
||||
key ^= s.multi_query_tokens;
|
||||
return key;
|
||||
}
|
||||
};
|
||||
|
||||
// XQA kernel can be uniquely identified by (LoadHashKey, RuntimeHashKey).
|
||||
struct XQAKernelFullHashKey
|
||||
{
|
||||
XQAKernelLoadHashKey load_key;
|
||||
XQAKernelRuntimeHashKey runtime_key;
|
||||
|
||||
XQAKernelFullHashKey() = default;
|
||||
|
||||
XQAKernelFullHashKey(XQAKernelLoadHashKey const& load_key, XQAKernelRuntimeHashKey const& runtime_key)
|
||||
: load_key(load_key)
|
||||
, runtime_key(runtime_key)
|
||||
{
|
||||
}
|
||||
|
||||
XQAKernelFullHashKey(void const* buffer, size_t buffer_size)
|
||||
{
|
||||
TLLM_CHECK(sizeof(*this) <= buffer_size);
|
||||
memcpy(this, buffer, sizeof(*this));
|
||||
}
|
||||
|
||||
bool operator==(XQAKernelFullHashKey const& other) const
|
||||
{
|
||||
return load_key == other.load_key && runtime_key == other.runtime_key;
|
||||
}
|
||||
|
||||
size_t getSerializationSize() const
|
||||
{
|
||||
return sizeof(*this);
|
||||
}
|
||||
|
||||
void serialize(void* buffer, size_t buffer_size) const
|
||||
{
|
||||
TLLM_CHECK(sizeof(*this) <= buffer_size);
|
||||
memcpy(buffer, this, sizeof(*this));
|
||||
}
|
||||
};
|
||||
|
||||
struct XQAKernelFullHasher
|
||||
{
|
||||
size_t operator()(XQAKernelFullHashKey const& s) const
|
||||
{
|
||||
return XQAKernelLoadHasher()(s.load_key) ^ XQAKernelRuntimeHasher()(s.runtime_key);
|
||||
}
|
||||
};
|
||||
|
||||
// NOTE: we use int32_t sequence lengths because the GPT attention plugins use int32_t for them.
// XQA kernels assume all lengths are uint32_t.
// NOTE: Linear KV cache and paged KV cache use the same structure.
|
||||
|
||||
template <typename KVCacheBuffer>
|
||||
struct KVCache
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KVCache<KVBlockArray>
|
||||
{
|
||||
// Start address of the paged kv block pool.
|
||||
void* poolPtr = nullptr;
|
||||
// Block indices in the memory pool.
|
||||
int32_t const* blockIndices = nullptr;
|
||||
int32_t const* sequence_lengths = nullptr;
|
||||
// NOTE: max_num_blocks_per_sequence for paged kv cache.
|
||||
uint32_t capacity = 0;
|
||||
|
||||
KVCache(KVBlockArray& kv_cache_buffer)
|
||||
{
|
||||
poolPtr = kv_cache_buffer.mPrimaryPoolPtr;
|
||||
blockIndices = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(kv_cache_buffer.data);
|
||||
}
|
||||
|
||||
KVCache() = default;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KVCache<KVLinearBuffer>
|
||||
{
|
||||
// Buffer address.
|
||||
void* data = nullptr;
|
||||
int32_t const* sequence_lengths = nullptr;
|
||||
// NOTE: max_sequence_length for linear kv cache.
|
||||
uint32_t capacity = 0;
|
||||
|
||||
KVCache(KVLinearBuffer& kv_cache_buffer)
|
||||
{
|
||||
data = kv_cache_buffer.data;
|
||||
}
|
||||
|
||||
KVCache() = default;
|
||||
};
|
||||
|
||||
struct BeamSearchParams
|
||||
{
|
||||
int32_t const* indices; // cacheIndir with shape: [batchSize][beamWidth][capacity]
|
||||
int32_t capacity;
|
||||
int32_t const* ctxLenList; // shape: [batchSize][beamWidth]. Should be [batchSize] but we have to match trt-llm API.
|
||||
};
|
||||
|
||||
// XQA kernels assume all integer values should use uint32_t.
|
||||
template <typename KVCacheBuffer>
|
||||
struct XQALaunchParam
|
||||
{
|
||||
uint32_t num_k_heads;
|
||||
void* output;
|
||||
void const* qkv;
|
||||
KVCache<KVCacheBuffer> kvCacheParams;
|
||||
std::optional<BeamSearchParams> beamSearchParams;
|
||||
uint32_t batch_size;
|
||||
float const* kv_scale_quant_orig = nullptr;
|
||||
int* cu_seq_lens = nullptr;
|
||||
float* rotary_inv_freq_buf = nullptr;
|
||||
void* scratch = nullptr;
|
||||
};
|
||||
|
||||
// Setup launch params.
|
||||
template <typename KVCacheBuffer>
|
||||
void buildXQALaunchParams(
|
||||
XQALaunchParam<KVCacheBuffer>& launchParams, XQAParams const& params, KVCacheBuffer kv_cache_buffer)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
params.data_type == DATA_TYPE_FP16 || params.data_type == DATA_TYPE_BF16, "Only fp16 or bf16 supported now.");
|
||||
memset(&launchParams, 0, sizeof(XQALaunchParam<KVCacheBuffer>));
|
||||
launchParams.num_k_heads = params.num_kv_heads;
|
||||
launchParams.output = static_cast<uint8_t*>(params.output);
|
||||
launchParams.qkv = static_cast<uint8_t const*>(params.qkv);
|
||||
launchParams.batch_size = params.batch_size;
|
||||
launchParams.kv_scale_quant_orig = params.kv_scale_quant_orig;
|
||||
|
||||
// Workspace.
|
||||
size_t offset = 0;
|
||||
int8_t* workspace = reinterpret_cast<int8_t*>(params.workspaces);
|
||||
unsigned int batch_beam_size = params.batch_size * params.beam_width;
|
||||
const size_t cu_seqlens_size = sizeof(int) * (batch_beam_size + 1);
|
||||
const size_t rotary_inv_freq_size = sizeof(float) * batch_beam_size * params.rotary_embedding_dim / 2;
|
||||
launchParams.cu_seq_lens = reinterpret_cast<int*>(workspace);
|
||||
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(workspace, cu_seqlens_size);
|
||||
launchParams.rotary_inv_freq_buf = reinterpret_cast<float*>(workspace);
|
||||
auto const multi_block_workspace_alignment = tensorrt_llm::common::roundUp(
|
||||
sizeof(half) * params.head_size * (params.num_q_heads / params.num_kv_heads) * params.beam_width, 128);
|
||||
workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(
|
||||
workspace, rotary_inv_freq_size, multi_block_workspace_alignment);
|
||||
launchParams.scratch = reinterpret_cast<void*>(workspace);
|
||||
|
||||
launchParams.kvCacheParams = KVCache<KVCacheBuffer>(kv_cache_buffer);
|
||||
launchParams.kvCacheParams.sequence_lengths = params.sequence_lengths;
|
||||
launchParams.kvCacheParams.capacity
|
||||
= params.paged_kv_cache ? params.max_blocks_per_sequence : params.max_attention_window_size;
|
||||
// TODO: beam searching has not been implemented yet.
|
||||
if (params.beam_width > 1)
|
||||
{
|
||||
launchParams.beamSearchParams
|
||||
= BeamSearchParams{params.cache_indir, params.max_attention_window_size, params.context_lengths};
|
||||
}
|
||||
else
|
||||
{
|
||||
launchParams.beamSearchParams = std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::optional<T> getGlobalVar(std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> const& driver, CUmodule hmod,
|
||||
char const* const name, bool required = false)
|
||||
{
|
||||
T* pVar = nullptr;
|
||||
size_t size = 0;
|
||||
auto const error = driver->cuModuleGetGlobal(reinterpret_cast<CUdeviceptr*>(&pVar), &size, hmod, name);
|
||||
T ret;
|
||||
switch (error)
|
||||
{
|
||||
case CUDA_SUCCESS:
|
||||
TLLM_CHECK(size == sizeof(T));
|
||||
tensorrt_llm::common::check_cuda_error(cudaMemcpy(&ret, pVar, size, cudaMemcpyDeviceToHost));
|
||||
break;
|
||||
case CUDA_ERROR_NOT_FOUND:
|
||||
if (!required)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
[[fallthrough]];
|
||||
default: cuErrCheck(("Failed to retrieve global variable from cubin.", error), driver);
|
||||
}
|
||||
return std::optional<T>{std::move(ret)};
|
||||
}
|
||||
|
||||
inline int computeMultiBlockCount(XQAParams const& xqaParams, int batch_size, int multiprocessor_count)
|
||||
{
|
||||
if (tensorrt_llm::common::envXqaNbCtaPerKVHead().has_value())
|
||||
{
|
||||
return tensorrt_llm::common::envXqaNbCtaPerKVHead().value();
|
||||
}
|
||||
int multi_block_count = 1;
|
||||
int num_kv_heads = xqaParams.num_kv_heads;
|
||||
int history_length = xqaParams.timestep;
|
||||
|
||||
multi_block_count = history_length / kMinHistoryTokensPerBlock;
|
||||
multi_block_count = std::max(multi_block_count, 1);
|
||||
// Adjust toward kTargetWaveFactor; the count was initialized from kMinHistoryTokensPerBlock, so it only needs to decrease.
|
||||
double wave_count = (double) batch_size * num_kv_heads * multi_block_count / (double) multiprocessor_count;
|
||||
double adj_factor = wave_count / (double) kTargetWaveFactor;
|
||||
if (adj_factor > 1.0)
|
||||
{
|
||||
multi_block_count = floor(multi_block_count / adj_factor);
|
||||
}
|
||||
multi_block_count = std::max(multi_block_count, 1);
|
||||
|
||||
// add limitation on upper bound.
|
||||
multi_block_count = std::min(tensorrt_llm::common::xqaMaxNbCtaPerKVHeadFactor(), multi_block_count);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(multi_block_count >= 1, "MultiBlock count should be at least 1");
|
||||
return multi_block_count;
|
||||
}
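
To make the heuristic concrete: with assumed values kMinHistoryTokensPerBlock = 256 and kTargetWaveFactor = 8 (placeholders, not the actual constants from decoderXQAConstants.h), a request with timestep 4096, batch_size 4, and num_kv_heads 8 on a 108-SM GPU starts from 4096 / 256 = 16 blocks per KV head; the wave count is 4 * 8 * 16 / 108 ≈ 4.7, which is below the target factor, so no reduction is applied and only the xqaMaxNbCtaPerKVHeadFactor() cap can lower the final value.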
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,85 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "compileEngine.h"
|
||||
|
||||
#include "cubinObj.h"
|
||||
#include "nvrtcWrapper/include/nvrtcWrapper.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
void CHECK_TLLM_XQA_JIT_ERROR_(tllmXqaJitStatus result, char const* const func, char const* const file, int const line)
|
||||
{
|
||||
if (result != TLLM_XQA_JIT_SUCCESS)
|
||||
{
|
||||
std::vector<char> log(tllmXqaJitGetLastErrorStringSize());
|
||||
tllmXqaJitGetLastErrorString(log.data());
|
||||
throw tensorrt_llm::common::TllmException(file, line,
|
||||
tensorrt_llm::common::fmtstr("[TensorRT-LLM][ERROR] TllmXqaJit runtime error in %s: %s", func, log.data()));
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_TLLM_XQA_JIT_ERROR(val) CHECK_TLLM_XQA_JIT_ERROR_((val), #val, __FILE__, __LINE__)
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace jit
|
||||
{
|
||||
|
||||
CubinObj CompileEngine::compile() const
|
||||
{
|
||||
tllmXqaJitProgram program;
|
||||
tllmXqaJitContext context{/*sm=*/mSM,
|
||||
/*head_size=*/static_cast<uint32_t>(mXqaParams.head_size),
|
||||
/*num_q_heads=*/static_cast<uint32_t>(mXqaParams.num_q_heads),
|
||||
/*num_kv_heads=*/static_cast<uint32_t>(mXqaParams.num_kv_heads),
|
||||
/*beam_width=*/static_cast<uint32_t>(mXqaParams.beam_width),
|
||||
/*tokens_per_block=*/static_cast<uint32_t>(mXqaParams.tokens_per_block),
|
||||
/*multi_query_tokens=*/mXqaParams.multi_query_tokens,
|
||||
/*paged_kv_cache=*/mXqaParams.paged_kv_cache,
|
||||
/*data_type=*/static_cast<int>(mXqaParams.data_type),
|
||||
/*kv_cache_data_type=*/static_cast<int>(mXqaParams.kv_cache_data_type)};
|
||||
|
||||
CHECK_TLLM_XQA_JIT_ERROR(tllmXqaJitCreateAndCompileProgram(&program, &context));
|
||||
|
||||
size_t cubinSize;
|
||||
CHECK_TLLM_XQA_JIT_ERROR(tllmXqaJitGetCUBINSize(program, &cubinSize));
|
||||
std::string cubinContent(cubinSize, ' ');
|
||||
CHECK_TLLM_XQA_JIT_ERROR(tllmXqaJitGetCUBIN(program, const_cast<char*>(cubinContent.c_str())));
|
||||
|
||||
CHECK_TLLM_XQA_JIT_ERROR(tllmXqaJitDestroyProgram(&program));
|
||||
|
||||
return CubinObj(cubinContent);
|
||||
}
|
||||
|
||||
CompileEngine::CompileEngine(int SM, XQAParams const& xqaParams)
|
||||
: mSM(SM)
|
||||
, mXqaParams(xqaParams)
|
||||
{
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "cubinObj.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h"
|
||||
#include <nvrtc.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace jit
|
||||
{
|
||||
|
||||
// A thin wrapper around NVRTC for compiling CUDA programs.
|
||||
class CompileEngine
|
||||
{
|
||||
public:
|
||||
CompileEngine(int SM, XQAParams const& xqaParams);
|
||||
|
||||
CubinObj compile() const;
|
||||
|
||||
~CompileEngine() = default;
|
||||
|
||||
private:
|
||||
int mSM;
|
||||
XQAParams const& mXqaParams;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cubinObj.h"
|
||||
|
||||
#include "serializationUtils.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaDriverWrapper.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace jit
|
||||
{
|
||||
|
||||
CubinObj::CubinObj(void const* buffer_, size_t buffer_size)
|
||||
: mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
|
||||
{
|
||||
uint8_t const* buffer = static_cast<uint8_t const*>(buffer_);
|
||||
size_t remaining_buffer_size = buffer_size;
|
||||
uint32_t len = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
|
||||
mContent.resize(len);
|
||||
TLLM_CHECK(len <= remaining_buffer_size);
|
||||
memcpy(mContent.data(), buffer, len);
|
||||
|
||||
initialize(mContent.c_str(), "kernel_mha");
|
||||
}
|
||||
|
||||
size_t CubinObj::getSerializationSize() const noexcept
|
||||
{
|
||||
size_t result = sizeof(uint32_t) + mContent.size();
|
||||
// Round result up to a multiple of 4.
|
||||
result = (result + 3) & ~3;
|
||||
return result;
|
||||
}
|
||||
|
||||
void CubinObj::serialize(void* buffer_, size_t buffer_size) const noexcept
|
||||
{
|
||||
size_t remaining_buffer_size = buffer_size;
|
||||
uint8_t* buffer = static_cast<uint8_t*>(buffer_);
|
||||
uint32_t len = mContent.size();
|
||||
writeToBuffer<uint32_t>(len, buffer, remaining_buffer_size);
|
||||
TLLM_CHECK(len <= remaining_buffer_size);
|
||||
memcpy(buffer, mContent.c_str(), len);
|
||||
}
|
||||
|
||||
CubinObj::CubinObj(std::string const& content)
|
||||
: mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
|
||||
, mContent(content)
|
||||
, mModule(nullptr)
|
||||
, mFunction(nullptr)
|
||||
, mSharedMemBytes(0)
|
||||
{
|
||||
initialize(mContent.c_str(), "kernel_mha");
|
||||
}
|
||||
|
||||
void CubinObj::launch(dim3 gridDim, dim3 blockDim, CUstream hStream, void** kernelParams)
|
||||
{
|
||||
cuErrCheck(mDriver->cuLaunchKernel(mFunction, gridDim.x, gridDim.y, gridDim.z, blockDim.x, blockDim.y, blockDim.z,
|
||||
mSharedMemBytes, hStream, kernelParams, /*extra=*/nullptr),
|
||||
mDriver);
|
||||
}
|
||||
|
||||
void CubinObj::initialize(char const* content, char const* funcName)
|
||||
{
|
||||
cuErrCheck(mDriver->cuModuleLoadData(&mModule, content), mDriver);
|
||||
TLLM_CHECK(mModule != nullptr);
|
||||
cuErrCheck(mDriver->cuModuleGetFunction(&mFunction, mModule, funcName), mDriver);
|
||||
TLLM_CHECK(mFunction != nullptr);
|
||||
|
||||
// Populate mSharedMemBytes.
|
||||
CUdeviceptr shmem_dev_ptr = 0;
|
||||
cuErrCheck(mDriver->cuModuleGetGlobal(&shmem_dev_ptr, nullptr, mModule, "smemSize"), mDriver);
|
||||
TLLM_CHECK(shmem_dev_ptr != 0);
|
||||
cuErrCheck(mDriver->cuMemcpyDtoH(&mSharedMemBytes, shmem_dev_ptr, sizeof(unsigned int)), mDriver);
|
||||
|
||||
TLLM_CHECK(mSharedMemBytes > 0);
|
||||
|
||||
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
|
||||
if (mSharedMemBytes >= 46 * 1024)
|
||||
{
|
||||
cuErrCheck(
|
||||
mDriver->cuFuncSetAttribute(mFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, mSharedMemBytes),
|
||||
mDriver);
|
||||
}
|
||||
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <string>
|
||||
|
||||
#include "tensorrt_llm/common/cudaDriverWrapper.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace jit
|
||||
{
|
||||
|
||||
class CubinObj
|
||||
{
|
||||
public:
|
||||
// Default constructor constructs an empty unusable CubinObj instance.
|
||||
CubinObj() = default;
|
||||
CubinObj(std::string const& content);
|
||||
CubinObj(void const* buffer, size_t buffer_size);
|
||||
void launch(dim3 gridDim, dim3 blockDim, CUstream hStream, void** kernelParams);
|
||||
|
||||
size_t getSerializationSize() const noexcept;
|
||||
void serialize(void* buffer, size_t buffer_size) const noexcept;
|
||||
|
||||
private:
|
||||
void initialize(char const* content, char const* funcName);
|
||||
|
||||
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;
|
||||
|
||||
std::string mContent;
|
||||
|
||||
CUmodule mModule;
|
||||
CUfunction mFunction;
|
||||
unsigned int mSharedMemBytes;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
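
A sketch of the compile / serialize / reload flow built from the two classes above; the helper name, namespace, and buffer handling are illustrative only.

#include <cstdint>
#include <vector>

#include "compileEngine.h"
#include "cubinObj.h"

namespace example // hypothetical
{

tensorrt_llm::kernels::jit::CubinObj buildAndRoundTrip(int sm, tensorrt_llm::kernels::XQAParams const& params)
{
    using namespace tensorrt_llm::kernels::jit;
    // JIT-compile the XQA kernel for this SM via the thin NVRTC wrapper.
    CompileEngine engine(sm, params);
    CubinObj obj = engine.compile();

    // Persist the cubin, e.g. as part of plugin serialization ...
    std::vector<uint8_t> buffer(obj.getSerializationSize());
    obj.serialize(buffer.data(), buffer.size());

    // ... and rebuild a launch-ready object from the serialized bytes later.
    return CubinObj(buffer.data(), buffer.size());
}

} // namespace example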
|
||||
@ -0,0 +1,144 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "cubinObj.h"
|
||||
|
||||
#include "compileEngine.h"
|
||||
#include "serializationUtils.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace jit
|
||||
{
|
||||
|
||||
// A collection of CubinObjs, with caching functionality.
|
||||
template <typename Key, class Hash = std::hash<Key>>
|
||||
class CubinObjRegistryTemplate
|
||||
{
|
||||
public:
|
||||
CubinObjRegistryTemplate() = default;
|
||||
|
||||
CubinObjRegistryTemplate(void const* buffer_, size_t buffer_size)
|
||||
{
|
||||
size_t remaining_buffer_size = buffer_size;
|
||||
uint8_t const* buffer = static_cast<uint8_t const*>(buffer_);
|
||||
// First 4 bytes: num of elements.
|
||||
uint32_t n = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
|
||||
|
||||
for (uint32_t i = 0; i < n; ++i)
|
||||
{
|
||||
uint32_t key_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
|
||||
TLLM_CHECK(key_size <= remaining_buffer_size);
|
||||
Key key(buffer, key_size);
|
||||
buffer += key_size;
|
||||
remaining_buffer_size -= key_size;
|
||||
|
||||
uint32_t obj_size = readFromBuffer<uint32_t>(buffer, remaining_buffer_size);
|
||||
TLLM_CHECK(obj_size <= remaining_buffer_size);
|
||||
CubinObj obj(buffer, obj_size);
|
||||
buffer += obj_size;
|
||||
remaining_buffer_size -= obj_size;
|
||||
|
||||
mMap.insert({key, std::move(obj)});
|
||||
}
|
||||
TLLM_CHECK(remaining_buffer_size == 0);
|
||||
}
|
||||
|
||||
std::unique_ptr<CubinObjRegistryTemplate<Key, Hash>> clone() const noexcept
|
||||
{
|
||||
auto result = std::make_unique<CubinObjRegistryTemplate<Key, Hash>>();
|
||||
for (auto const& p : mMap)
|
||||
{
|
||||
result->mMap.insert(p);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
size_t getSerializationSize() const noexcept
|
||||
{
|
||||
size_t result = sizeof(uint32_t);
|
||||
for (auto&& p : mMap)
|
||||
{
|
||||
result += 2 * sizeof(uint32_t);
|
||||
result += p.first.getSerializationSize() + p.second.getSerializationSize();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void serialize(void* buffer_, size_t buffer_size) const noexcept
|
||||
{
|
||||
size_t remaining_buffer_size = buffer_size;
|
||||
uint8_t* buffer = static_cast<uint8_t*>(buffer_);
|
||||
uint32_t n = mMap.size();
|
||||
writeToBuffer<uint32_t>(n, buffer, remaining_buffer_size);
|
||||
for (auto&& p : mMap)
|
||||
{
|
||||
uint32_t key_size = p.first.getSerializationSize();
|
||||
TLLM_CHECK(key_size <= remaining_buffer_size);
|
||||
writeToBuffer<uint32_t>(key_size, buffer, remaining_buffer_size);
|
||||
p.first.serialize(buffer, key_size);
|
||||
buffer += key_size;
|
||||
remaining_buffer_size -= key_size;
|
||||
|
||||
uint32_t obj_size = p.second.getSerializationSize();
|
||||
TLLM_CHECK(obj_size <= remaining_buffer_size);
|
||||
writeToBuffer<uint32_t>(obj_size, buffer, remaining_buffer_size);
|
||||
p.second.serialize(buffer, obj_size);
|
||||
buffer += obj_size;
|
||||
remaining_buffer_size -= obj_size;
|
||||
}
|
||||
TLLM_CHECK(remaining_buffer_size == 0);
|
||||
}
|
||||
|
||||
// Returns the cached cubin directly if it already exists in the registry; otherwise calls compileEngine to compile it.
|
||||
//
|
||||
// compileEngine may be nullptr.
|
||||
CubinObj* getCubin(Key const& key, CompileEngine* compileEngine)
|
||||
{
|
||||
auto iter = mMap.find(key);
|
||||
if (iter != mMap.end())
|
||||
{
|
||||
return &(iter->second);
|
||||
}
|
||||
|
||||
TLLM_CHECK_WITH_INFO(compileEngine != nullptr, "Key not found; compileEngine shouldn't be nullptr.");
|
||||
|
||||
CubinObj obj = compileEngine->compile();
|
||||
auto insertResultIter = mMap.insert({key, std::move(obj)}).first;
|
||||
return &(insertResultIter->second);
|
||||
}
|
||||
|
||||
void clear()
|
||||
{
|
||||
mMap.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<Key, CubinObj, Hash> mMap;
|
||||
};
|
||||
|
||||
using CubinObjKey = XQAKernelFullHashKey;
|
||||
using CubinObjHasher = XQAKernelFullHasher;
|
||||
using CubinObjRegistry = CubinObjRegistryTemplate<CubinObjKey, CubinObjHasher>;
|
||||
|
||||
} // namespace jit
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
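
A usage sketch of the caching contract above (not actual runner code; the helper and the registry header name are assumptions):

#include "cubinObjRegistry.h" // assumed header name for the registry declared above
#include "compileEngine.h"

namespace example // hypothetical
{

// Look the key up first and only pay the NVRTC compilation cost on a miss; the
// returned pointer stays owned by the registry.
tensorrt_llm::kernels::jit::CubinObj* getOrCompile(tensorrt_llm::kernels::jit::CubinObjRegistry& registry,
    tensorrt_llm::kernels::jit::CubinObjKey const& key, int sm, tensorrt_llm::kernels::XQAParams const& params)
{
    tensorrt_llm::kernels::jit::CompileEngine engine(sm, params);
    return registry.getCubin(key, &engine);
}

} // namespace example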
|
||||
@ -0,0 +1,305 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.h"
|
||||
|
||||
#include "compileEngine.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/cubin/xqa_kernel_cubin.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAConstants.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h"
|
||||
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
using ::tensorrt_llm::kernels::XQAKernelRuntimeHashKey;
|
||||
using ::tensorrt_llm::kernels::XQAParams;
|
||||
using ::tensorrt_llm::kernels::XQAKernelMetaInfo;
|
||||
|
||||
XQAKernelRuntimeHashKey getRuntimeHashKeyFromKernelMeta(XQAKernelMetaInfo const& kernelMeta)
|
||||
{
|
||||
return {kernelMeta.mKVDataType, kernelMeta.mHeadDim, kernelMeta.mBeamWidth, kernelMeta.mNumQHeadsOverKV,
|
||||
kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens};
|
||||
}
|
||||
|
||||
XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParams)
|
||||
{
|
||||
unsigned int head_size = xqaParams.head_size;
|
||||
int num_q_heads = xqaParams.num_q_heads;
|
||||
int num_kv_heads = xqaParams.num_kv_heads;
|
||||
TLLM_CHECK_WITH_INFO(num_q_heads % num_kv_heads == 0, "numQHeads should be multiple of numKVHeads.");
|
||||
unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
|
||||
unsigned int beam_width = xqaParams.beam_width;
|
||||
// MultiQueryToken kernels can support any num_q_heads_over_kv that is a power of 2.
unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
// MultiQueryToken kernels can handle either 16 or 32 for the M direction per CTA.
unsigned int m_tilesize = xqaParams.multi_query_tokens ? 16 : num_q_heads_over_kv;
|
||||
|
||||
return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, m_tilesize,
|
||||
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
|
||||
xqaParams.multi_query_tokens};
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
DecoderXQAImplJIT::DecoderXQAImplJIT(DecoderXQARunner* runner)
|
||||
: DecoderXQAImpl(runner)
|
||||
, mForceXQA(tensorrt_llm::common::forceXQAKernels())
|
||||
, mSM(tensorrt_llm::common::getSMVersion())
|
||||
, mCubinObjRegistry(runner->mResource->getCubinObjRegistry())
|
||||
{
|
||||
initSupportedConfigs();
|
||||
}
|
||||
|
||||
void DecoderXQAImplJIT::initSupportedConfigs()
|
||||
{
|
||||
mSupportedConfigs.clear();
|
||||
|
||||
size_t nbConfigs = sizeof(sXqaKernelMetaInfo) / sizeof(sXqaKernelMetaInfo[0]);
|
||||
for (size_t i = 0; i < nbConfigs; ++i)
|
||||
{
|
||||
XQAKernelMetaInfo const& kernelMeta = sXqaKernelMetaInfo[i];
|
||||
if (!kernelMeta.mMultiQueryTokens)
|
||||
{
|
||||
// Exclude medusa kernels from JIT because they are compiled from a different CUDA source file.
|
||||
mSupportedConfigs.insert(getRuntimeHashKeyFromKernelMeta(kernelMeta));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool DecoderXQAImplJIT::supportConfig(XQAParams const& xqaParams) const
|
||||
{
|
||||
return mSupportedConfigs.find(getRuntimeHashKeyFromXQAParams(xqaParams)) != mSupportedConfigs.end();
|
||||
}
|
||||
|
||||
bool DecoderXQAImplJIT::mayHavePerfGain(XQAParams const& xqaParams) const
|
||||
{
|
||||
// NOTE: only XQA supports multi_query_tokens (Medusa mode).
|
||||
if (mForceXQA || xqaParams.multi_query_tokens)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
int num_kv_heads = xqaParams.num_kv_heads;
|
||||
int batch_size = static_cast<int>(xqaParams.batch_size);
|
||||
int multi_block_count = 1;
|
||||
if (xqaParams.multi_block_mode)
|
||||
{
|
||||
int history_length = xqaParams.timestep;
|
||||
multi_block_count = history_length / kMinHistoryTokensPerBlock;
|
||||
}
|
||||
int block_count = num_kv_heads * batch_size * multi_block_count;
|
||||
return static_cast<float>(block_count) * kEnableMinBlockFactor >= static_cast<float>(mRunner->mMultiProcessorCount);
|
||||
}
|
||||
|
||||
bool DecoderXQAImplJIT::shouldUse(XQAParams const& xqaParams, bool forConfigurePlugin)
|
||||
{
|
||||
bool is_config_supported = supportConfig(xqaParams);
|
||||
if (forConfigurePlugin)
|
||||
{
|
||||
return is_config_supported;
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_config_supported && mayHavePerfGain(xqaParams);
|
||||
}
|
||||
}
|
||||
|
||||
jit::CubinObjKey DecoderXQAImplJIT::getCubinObjKeyFromXQAParams(XQAParams const& xqaParams) const
|
||||
{
|
||||
XQAKernelLoadHashKey loadKey;
|
||||
loadKey.data_type = xqaParams.data_type;
|
||||
loadKey.sm = mSM;
|
||||
|
||||
XQAKernelRuntimeHashKey runtimeKey = getRuntimeHashKeyFromXQAParams(xqaParams);
|
||||
return {loadKey, runtimeKey};
|
||||
}
|
||||
|
||||
void DecoderXQAImplJIT::prepare(XQAParams const& xqaParams)
|
||||
{
|
||||
jit::CubinObjKey key = getCubinObjKeyFromXQAParams(xqaParams);
|
||||
|
||||
jit::CompileEngine compileEngine(mSM, xqaParams);
|
||||
|
||||
// Discard getCubin() result.
|
||||
mCubinObjRegistry->getCubin(key, &compileEngine);
|
||||
}
|
||||
|
||||
void DecoderXQAImplJIT::runWithKVLinearBuffer(
|
||||
XQAParams const& xqaParams, KVLinearBuffer const& kv_linear_buffer, cudaStream_t const& stream)
|
||||
{
|
||||
runDispatchKVCacheBuffer<KVLinearBuffer>(xqaParams, kv_linear_buffer, stream);
|
||||
}
|
||||
|
||||
void DecoderXQAImplJIT::runWithKVBlockArray(
|
||||
XQAParams const& xqaParams, KVBlockArray const& kv_block_array, cudaStream_t const& stream)
|
||||
{
|
||||
runDispatchKVCacheBuffer<KVBlockArray>(xqaParams, kv_block_array, stream);
|
||||
}
|
||||
|
||||
#define XQA_KERNEL_RUN(DATA_TYPE) \
|
||||
runImpl<DATA_TYPE, KVCacheBuffer>(xqa_params, kv_cache_buffer, mRunner->mMultiProcessorCount, stream)
|
||||
|
||||
template <typename KVCacheBuffer>
|
||||
void DecoderXQAImplJIT::runDispatchKVCacheBuffer(
|
||||
XQAParams const& xqa_params, KVCacheBuffer const& kv_cache_buffer, cudaStream_t const& stream)
|
||||
{
|
||||
if (mRunner->mDataType == DATA_TYPE_FP16)
|
||||
{
|
||||
XQA_KERNEL_RUN(__half);
|
||||
}
|
||||
else
|
||||
{
|
||||
XQA_KERNEL_RUN(__nv_bfloat16);
|
||||
}
|
||||
}
|
||||
|
||||
#undef XQA_KERNEL_RUN
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const& kv_cache_buffer,
|
||||
int multiprocessor_count, cudaStream_t const& stream)
|
||||
{
|
||||
unsigned int head_size = xqaParams.head_size;
|
||||
int num_q_heads = xqaParams.num_q_heads;
|
||||
int num_kv_heads = xqaParams.num_kv_heads;
|
||||
TLLM_CHECK_WITH_INFO(num_q_heads % num_kv_heads == 0, "numQHeads should be a multiple of numKVHeads.");
|
||||
unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
|
||||
unsigned int beam_width = xqaParams.beam_width;
|
||||
unsigned int batch_beam_size = xqaParams.batch_size * beam_width;
|
||||
|
||||
const KvCacheDataType cache_type = xqaParams.kv_cache_quant_mode.hasInt8KvCache()
|
||||
? KvCacheDataType::INT8
|
||||
: (xqaParams.kv_cache_quant_mode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
|
||||
|
||||
XQALaunchParam<KVCacheBuffer> launchParams;
|
||||
buildXQALaunchParams(launchParams, xqaParams, kv_cache_buffer);
|
||||
|
||||
// Build cu_seqlens, padding_offset, and rotary inv freq tensors
|
||||
BuildDecoderInfoParams<T> decoder_params;
|
||||
memset(&decoder_params, 0, sizeof(decoder_params));
|
||||
decoder_params.seqQOffsets = launchParams.cu_seq_lens;
|
||||
decoder_params.seqKVLengths = xqaParams.sequence_lengths;
|
||||
decoder_params.batchSize = int(batch_beam_size);
|
||||
decoder_params.maxQSeqLength = xqaParams.generation_input_length;
|
||||
// Rotary embedding inv_freq buffer.
|
||||
decoder_params.rotaryEmbeddingScale = xqaParams.rotary_embedding_scale;
|
||||
decoder_params.rotaryEmbeddingBase = xqaParams.rotary_embedding_base;
|
||||
decoder_params.rotaryEmbeddingDim = xqaParams.rotary_embedding_dim;
|
||||
decoder_params.rotaryScalingType = xqaParams.rotary_embedding_scale_type;
|
||||
decoder_params.rotaryEmbeddingInvFreq = launchParams.rotary_inv_freq_buf;
|
||||
decoder_params.rotaryEmbeddingMaxPositions = xqaParams.rotary_embedding_max_positions;
|
||||
|
||||
invokeBuildDecoderInfo(decoder_params, stream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
// IDEA: Store rotary_processed Q buffer to output buffer.
|
||||
// NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache.
|
||||
void const* xqa_q_input_ptr = xqaParams.output;
|
||||
QKVPreprocessingParams<T, KVCacheBuffer> preprocessingParms{static_cast<T*>(const_cast<void*>(xqaParams.qkv)),
|
||||
nullptr, static_cast<T*>(const_cast<void*>(xqaParams.output)), kv_cache_buffer,
|
||||
static_cast<T const*>(xqaParams.qkv_bias), nullptr, xqaParams.sequence_lengths, nullptr,
|
||||
launchParams.rotary_inv_freq_buf, (float2 const*) nullptr, xqaParams.kv_scale_orig_quant,
|
||||
xqaParams.spec_decoding_position_offsets, int(batch_beam_size), xqaParams.generation_input_length,
|
||||
xqaParams.timestep, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
|
||||
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length), xqaParams.num_q_heads,
|
||||
xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size,
|
||||
xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type,
|
||||
xqaParams.rotary_embedding_scale, xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type,
|
||||
xqaParams.position_shift_enabled, cache_type, true, false, multiprocessor_count};
|
||||
|
||||
invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParms, stream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
// Use mTileSize = 16 kernels when qSeqLen <= 16.
|
||||
unsigned int qSeqLen = static_cast<unsigned int>(xqaParams.generation_input_length);
|
||||
unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;
|
||||
// MultiQueryToken kernels can support any num_q_heads_over_kv that is a power of 2.
|
||||
unsigned int kernel_num_q_heads_over_kv = xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
|
||||
// MultiQueryToken kernels can handle either 16 or 32 for the M direction per CTA.
|
||||
unsigned int kernel_m_tilesize = xqaParams.multi_query_tokens ? mTileSize : num_q_heads_over_kv;
|
||||
|
||||
jit::CubinObjKey key = getCubinObjKeyFromXQAParams(xqaParams);
|
||||
jit::CubinObj* cubinObj = mCubinObjRegistry->getCubin(key, /*compileEngine=*/nullptr);
|
||||
|
||||
if (xqaParams.multi_query_tokens)
|
||||
{
|
||||
// MultiQueryTokens (generation_input_length > 1) need extra parameters (like qSeqLen, log2HeadGrpSize, and
|
||||
// mask). Input parameters for MultiQueryTokens kernels.
|
||||
unsigned int log2HeadGrpSize = log2(num_q_heads_over_kv);
|
||||
unsigned int nbTokenBlocksPerGrp = divUp(qSeqLen << log2HeadGrpSize, mTileSize);
|
||||
int const* maskPtr = xqaParams.spec_decoding_packed_mask;
|
||||
// TODO: add fp8/int8 kv cache kernels.
|
||||
float kvCacheQuantOrig = 1.0f;
|
||||
// TODO: merge SingleQueryToken params and MultiQueryTokens params into one kernelParams.
|
||||
void* kernelParams[]
|
||||
= {&qSeqLen, &launchParams.num_k_heads, &log2HeadGrpSize, &launchParams.output, &xqa_q_input_ptr, &maskPtr,
|
||||
&launchParams.kvCacheParams, &launchParams.batch_size, &kvCacheQuantOrig, &launchParams.scratch};
|
||||
int multi_block = 1;
|
||||
if (xqaParams.multi_block_mode)
|
||||
{
|
||||
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
|
||||
cudaMemsetAsync(
|
||||
xqaParams.workspaces, 0, sizeof(int) * xqaParams.batch_size * xqaParams.num_kv_heads, stream);
|
||||
}
|
||||
dim3 gridDim(multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp, xqaParams.batch_size);
|
||||
dim3 blockDim(128, 1, 2);
|
||||
cubinObj->launch(gridDim, blockDim, stream, kernelParams);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 9;
|
||||
uint32_t idxNextParam = 0;
|
||||
void* kernelParams[kMAX_NB_KERNEL_PARAMS];
|
||||
auto appendParam = [&](auto* p) mutable
|
||||
{
|
||||
TLLM_CHECK(idxNextParam < kMAX_NB_KERNEL_PARAMS);
|
||||
kernelParams[idxNextParam++] = p;
|
||||
};
|
||||
appendParam(&launchParams.num_k_heads);
|
||||
appendParam(&launchParams.output);
|
||||
appendParam(&xqa_q_input_ptr);
|
||||
appendParam(&launchParams.kvCacheParams);
|
||||
if (xqaParams.beam_width > 1)
|
||||
{
|
||||
appendParam(&launchParams.beamSearchParams.value());
|
||||
}
|
||||
appendParam(&launchParams.batch_size);
|
||||
appendParam(&launchParams.kv_scale_quant_orig);
|
||||
appendParam(&launchParams.scratch);
|
||||
kernelParams[idxNextParam] = nullptr; // one extra nullptr at end as guard.
|
||||
int multi_block = 1;
|
||||
if (xqaParams.multi_block_mode)
|
||||
{
|
||||
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
|
||||
cudaMemsetAsync(
|
||||
xqaParams.workspaces, 0, sizeof(int) * xqaParams.batch_size * xqaParams.num_kv_heads, stream);
|
||||
}
|
||||
|
||||
dim3 gridDim(multi_block, xqaParams.num_kv_heads, xqaParams.batch_size);
|
||||
dim3 blockDim(128, 1, 2);
|
||||
cubinObj->launch(gridDim, blockDim, stream, kernelParams);
|
||||
}
|
||||
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.h"
|
||||
|
||||
#include "compileEngine.h"
|
||||
#include "cubinObjRegistry.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
|
||||
#include <unordered_set>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
class DecoderXQAImplJIT : public DecoderXQAImpl
|
||||
{
|
||||
public:
|
||||
DecoderXQAImplJIT(DecoderXQARunner* runner);
|
||||
|
||||
bool shouldUse(XQAParams const& xqaParams, bool forConfigurePlugin) override;
|
||||
void prepare(XQAParams const& xqaParams) override;
|
||||
|
||||
protected:
|
||||
void runWithKVLinearBuffer(
|
||||
XQAParams const& xqaParams, KVLinearBuffer const& kv_linear_buffer, cudaStream_t const& stream) override;
|
||||
void runWithKVBlockArray(
|
||||
XQAParams const& xqaParams, KVBlockArray const& kv_block_array, cudaStream_t const& stream) override;
|
||||
|
||||
private:
|
||||
void initSupportedConfigs();
|
||||
//! Whether DecoderXQAImplJIT supports xqaParams.
|
||||
bool supportConfig(XQAParams const& xqaParams) const;
|
||||
//! Whether DecoderXQAImplJIT has perf gain over the default (non-XQA-optimized) implementation.
|
||||
bool mayHavePerfGain(XQAParams const& xqaParams) const;
|
||||
|
||||
template <typename T, typename KVCacheBuffer>
|
||||
void runImpl(XQAParams const& xqaParams, KVCacheBuffer const& kv_cache_buffer, int multiprocessor_count,
|
||||
cudaStream_t const& stream);
|
||||
|
||||
template <typename KVCacheBuffer>
|
||||
void runDispatchKVCacheBuffer(
|
||||
XQAParams const& xqa_params, KVCacheBuffer const& kv_cache_buffer, cudaStream_t const& stream);
|
||||
|
||||
bool mForceXQA;
|
||||
int mSM;
|
||||
|
||||
jit::CubinObjRegistry* mCubinObjRegistry;
|
||||
jit::CubinObjKey getCubinObjKeyFromXQAParams(XQAParams const& xqaParams) const;
|
||||
|
||||
//! The first prototype just takes whatever is available from the Precompiled cubins.
|
||||
std::unordered_set<XQAKernelRuntimeHashKey, XQAKernelRuntimeHasher> mSupportedConfigs;
|
||||
};
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d7ab77d5678faa3cf90712f9919d71cd9b4d68f5e334f87c6047593963e861bf
|
||||
size 80328568
|
||||
@ -0,0 +1,2 @@
|
||||
24849f03d35877abb0e0f393d32e5000 libtensorrt_llm_nvrtc_wrapper.so
|
||||
942b83732d029cc3eaef9f5a849218d75161ec12 commit
|
||||
@ -0,0 +1,82 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
* This file is NOT thread safe.
|
||||
*/
|
||||
#pragma once
|
||||
#include <stddef.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
#if COMPILING_DLL
|
||||
#define DLLEXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define DLLEXPORT __declspec(dllimport)
|
||||
#endif
|
||||
|
||||
#else // _WIN32
|
||||
#define DLLEXPORT // Nothing.
|
||||
#endif
|
||||
|
||||
#if __cplusplus
|
||||
extern "C"
|
||||
{
|
||||
#endif
|
||||
|
||||
typedef struct
|
||||
{
|
||||
// Compute capability, e.g. 89.
|
||||
int sm;
|
||||
|
||||
unsigned int head_size;
|
||||
unsigned int num_q_heads;
|
||||
unsigned int num_kv_heads;
|
||||
unsigned int beam_width;
|
||||
unsigned int tokens_per_block;
|
||||
bool multi_query_tokens;
|
||||
bool paged_kv_cache;
|
||||
|
||||
// Actual type: tensorrt_llm::kernels::Data_type
|
||||
int data_type;
|
||||
int kv_cache_data_type;
|
||||
} tllmXqaJitContext;
|
||||
|
||||
// tllmXqaJitProgram is an opaque handle for a program.
|
||||
typedef struct _tllmXqaJitProgram* tllmXqaJitProgram;
|
||||
|
||||
typedef enum
|
||||
{
|
||||
TLLM_XQA_JIT_SUCCESS = 0,
|
||||
TLLM_XQA_JIT_INVALID_INPUT = 1,
|
||||
TLLM_XQA_JIT_INTERNAL_ERROR = 2,
|
||||
} tllmXqaJitStatus;
|
||||
|
||||
// context must outlive prog.
|
||||
DLLEXPORT tllmXqaJitStatus tllmXqaJitCreateAndCompileProgram(
|
||||
tllmXqaJitProgram* prog, tllmXqaJitContext const* context);
|
||||
DLLEXPORT tllmXqaJitStatus tllmXqaJitGetCUBINSize(tllmXqaJitProgram prog, size_t* cubinSizeRet);
|
||||
DLLEXPORT tllmXqaJitStatus tllmXqaJitGetCUBIN(tllmXqaJitProgram prog, char* cubin);
|
||||
DLLEXPORT tllmXqaJitStatus tllmXqaJitDestroyProgram(tllmXqaJitProgram* prog);
|
||||
|
||||
// Returns the size of the error string associated with the last non-success tllmXqaJit function call (including the
|
||||
// trailing \0). Returns 0 if there is no such non-success function call.
|
||||
DLLEXPORT size_t tllmXqaJitGetLastErrorStringSize();
|
||||
// Returns the error string.
|
||||
// Output can be nullptr if the returned value of tllmXqaJitGetLastErrorStringSize() is 0.
|
||||
DLLEXPORT void tllmXqaJitGetLastErrorString(char* output);
|
||||
|
||||
#if __cplusplus
|
||||
} // extern "C"
|
||||
#endif
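// Illustrative usage sketch (not part of the original header), showing the intended call sequence of
// this C API with placeholder parameter values; error handling is reduced to printing the last error.
//
//   #include <cstdio>
//   #include <vector>
//
//   std::vector<char> compileXqaCubinSketch()
//   {
//       tllmXqaJitContext context{};
//       context.sm = 90;                  // placeholder compute capability
//       context.head_size = 128;
//       context.num_q_heads = 32;
//       context.num_kv_heads = 8;
//       context.beam_width = 1;
//       context.tokens_per_block = 64;
//       context.multi_query_tokens = false;
//       context.paged_kv_cache = true;
//       context.data_type = 0;            // placeholder tensorrt_llm::kernels::Data_type value
//       context.kv_cache_data_type = 0;   // placeholder
//
//       std::vector<char> cubin;
//       tllmXqaJitProgram prog;
//       if (tllmXqaJitCreateAndCompileProgram(&prog, &context) != TLLM_XQA_JIT_SUCCESS)
//       {
//           size_t msgSize = tllmXqaJitGetLastErrorStringSize();
//           if (msgSize > 0)
//           {
//               std::vector<char> msg(msgSize);
//               tllmXqaJitGetLastErrorString(msg.data());
//               std::fprintf(stderr, "%s\n", msg.data());
//           }
//           return cubin;
//       }
//       size_t cubinSize = 0;
//       tllmXqaJitGetCUBINSize(prog, &cubinSize);
//       cubin.resize(cubinSize);
//       tllmXqaJitGetCUBIN(prog, cubin.data());
//       tllmXqaJitDestroyProgram(&prog);
//       return cubin;
//   }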
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cbf7c5d33ad7d0533569e1be71e6e13f04c7a001cab15ed55eba81c9f8bb6ad3
|
||||
size 83431088
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:45dcdb034ff53e4f862cc035545973f1b9efae4a8aa3e83555fd77f8b55311db
|
||||
size 966144
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4ca7c531980130dfd37c59132bbce8e90b821ecc31fa20d86726eec153bb016e
|
||||
size 3488
|
||||
@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace jit
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
T readFromBuffer(uint8_t const*& buffer, size_t& remaining_buffer_size)
|
||||
{
|
||||
TLLM_CHECK(sizeof(T) <= remaining_buffer_size);
|
||||
|
||||
T result = *reinterpret_cast<T const*>(buffer);
|
||||
buffer += sizeof(T);
|
||||
remaining_buffer_size -= sizeof(T);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void writeToBuffer(T output, uint8_t*& buffer, size_t& remaining_buffer_size)
|
||||
{
|
||||
TLLM_CHECK(sizeof(T) <= remaining_buffer_size);
|
||||
|
||||
*reinterpret_cast<T*>(buffer) = output;
|
||||
buffer += sizeof(T);
|
||||
remaining_buffer_size -= sizeof(T);
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
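// Illustrative usage sketch (not part of the original header): a round trip through the helpers above
// with placeholder values. Write and read order must match, and the TLLM_CHECKs guard against overruns.
//
//   uint8_t storage[16];
//   uint8_t* writeCursor = storage;
//   size_t writeRemaining = sizeof(storage);
//   tensorrt_llm::kernels::jit::writeToBuffer<uint32_t>(42u, writeCursor, writeRemaining);
//   tensorrt_llm::kernels::jit::writeToBuffer<float>(1.5f, writeCursor, writeRemaining);
//
//   uint8_t const* readCursor = storage;
//   size_t readRemaining = sizeof(storage);
//   auto const i = tensorrt_llm::kernels::jit::readFromBuffer<uint32_t>(readCursor, readRemaining); // 42
//   auto const f = tensorrt_llm::kernels::jit::readFromBuffer<float>(readCursor, readRemaining);    // 1.5f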
|
||||
@ -20,6 +20,7 @@
|
||||
#include "tensorrt_llm/common/workspace.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/cubin/xqa_kernel_cubin.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAConstants.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h"
|
||||
#include "tensorrt_llm/kernels/kvCacheUtils.h"
|
||||
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
|
||||
@ -36,219 +37,14 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
struct XQAKernelLoadHashKey
|
||||
{
|
||||
Data_type data_type;
|
||||
unsigned int sm;
|
||||
|
||||
bool operator==(const XQAKernelLoadHashKey other) const
|
||||
{
|
||||
return data_type == other.data_type && sm == other.sm;
|
||||
}
|
||||
};
|
||||
|
||||
struct XQAKernelLoadHasher
|
||||
{
|
||||
size_t operator()(XQAKernelLoadHashKey const& s) const
|
||||
{
|
||||
size_t key = s.data_type;
|
||||
key <<= 16;
|
||||
key ^= s.sm;
|
||||
return key;
|
||||
}
|
||||
};
|
||||
|
||||
struct XQAKernelRuntimeHashKey
|
||||
{
|
||||
Data_type kv_data_type;
|
||||
unsigned int head_size;
|
||||
unsigned int beam_size;
|
||||
unsigned int num_q_heads_per_kv;
|
||||
unsigned int m_tilesize;
|
||||
unsigned int tokens_per_page;
|
||||
bool paged_kv_cache;
|
||||
bool multi_query_tokens;
|
||||
|
||||
bool operator==(const XQAKernelRuntimeHashKey other) const
|
||||
{
|
||||
return kv_data_type == other.kv_data_type && head_size == other.head_size
|
||||
&& num_q_heads_per_kv == other.num_q_heads_per_kv && beam_size == other.beam_size
|
||||
&& multi_query_tokens == other.multi_query_tokens && m_tilesize == other.m_tilesize
|
||||
&& tokens_per_page == other.tokens_per_page && paged_kv_cache == other.paged_kv_cache;
|
||||
}
|
||||
};
|
||||
|
||||
struct XQAKernelRuntimeHasher
|
||||
{
|
||||
size_t operator()(XQAKernelRuntimeHashKey const& s) const
|
||||
{
|
||||
size_t key = s.kv_data_type;
|
||||
key <<= 16;
|
||||
key ^= s.head_size;
|
||||
key <<= 8;
|
||||
key ^= s.num_q_heads_per_kv;
|
||||
key <<= 8;
|
||||
key ^= s.beam_size;
|
||||
key <<= 6;
|
||||
key ^= s.m_tilesize;
|
||||
key <<= 10;
|
||||
key ^= s.tokens_per_page;
|
||||
key <<= 1;
|
||||
key ^= s.paged_kv_cache;
|
||||
key <<= 1;
|
||||
key ^= s.multi_query_tokens;
|
||||
return key;
|
||||
}
|
||||
};
|
||||
|
||||
// NOTE: we use int32_t sequence lengths as gpt attention plugins use int32_t for that.
|
||||
// XQA kernels assume all lengths use uint32_t.
|
||||
// NOTE: Linear KV cache and paged KV cache use the same structure.
|
||||
|
||||
template <typename KVCacheBuffer>
|
||||
struct KVCache
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KVCache<KVBlockArray>
|
||||
{
|
||||
// Start address of the paged kv block pool.
|
||||
void* poolPtr = nullptr;
|
||||
// Block indices in the memory pool.
|
||||
int32_t const* blockIndices = nullptr;
|
||||
int32_t const* sequence_lengths = nullptr;
|
||||
// NOTE: max_num_blocks_per_sequence for paged kv cache.
|
||||
uint32_t capacity = 0;
|
||||
|
||||
KVCache(KVBlockArray& kv_cache_buffer)
|
||||
{
|
||||
poolPtr = kv_cache_buffer.mPrimaryPoolPtr;
|
||||
blockIndices = reinterpret_cast<KVCacheIndex::UnderlyingType const*>(kv_cache_buffer.data);
|
||||
}
|
||||
|
||||
KVCache() = default;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KVCache<KVLinearBuffer>
|
||||
{
|
||||
// Buffer address.
|
||||
void* data = nullptr;
|
||||
int32_t const* sequence_lengths = nullptr;
|
||||
// NOTE: max_sequence_length for linear kv cache.
|
||||
uint32_t capacity = 0;
|
||||
|
||||
KVCache(KVLinearBuffer& kv_cache_buffer)
|
||||
{
|
||||
data = kv_cache_buffer.data;
|
||||
}
|
||||
|
||||
KVCache() = default;
|
||||
};
|
||||
|
||||
struct BeamSearchParams
|
||||
{
|
||||
int32_t const* indices; // cacheIndir with shape: [batchSize][beamWidth][capacity]
|
||||
int32_t capacity;
|
||||
int32_t const* ctxLenList; // shape: [batchSize][beamWidth]. Should be [batchSize] but we have to match trt-llm API.
|
||||
};
|
||||
|
||||
// XQA kernels assume all integer values should use uint32_t.
|
||||
template <typename KVCacheBuffer>
|
||||
struct XQALaunchParam
|
||||
{
|
||||
uint32_t num_k_heads;
|
||||
void* output;
|
||||
void const* qkv;
|
||||
KVCache<KVCacheBuffer> kvCacheParams;
|
||||
std::optional<BeamSearchParams> beamSearchParams;
|
||||
uint32_t batch_size;
|
||||
float const* kv_scale_quant_orig = nullptr;
|
||||
int* cu_seq_lens = nullptr;
|
||||
float* rotary_inv_freq_buf = nullptr;
|
||||
void* scratch = nullptr;
|
||||
};
|
||||
|
||||
// Setup launch params.
|
||||
template <typename KVCacheBuffer>
|
||||
void buildXQALaunchParams(
|
||||
XQALaunchParam<KVCacheBuffer>& launchParams, XQAParams const& params, KVCacheBuffer kv_cache_buffer)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
params.data_type == DATA_TYPE_FP16 || params.data_type == DATA_TYPE_BF16, "Only fp16 or bf16 supported now.");
|
||||
memset(&launchParams, 0, sizeof(XQALaunchParam<KVCacheBuffer>));
|
||||
launchParams.num_k_heads = params.num_kv_heads;
|
||||
launchParams.output = static_cast<uint8_t*>(params.output);
|
||||
launchParams.qkv = static_cast<uint8_t const*>(params.qkv);
|
||||
launchParams.batch_size = params.batch_size;
|
||||
launchParams.kv_scale_quant_orig = params.kv_scale_quant_orig;
|
||||
|
||||
// Workspace.
|
||||
size_t offset = 0;
|
||||
int8_t* workspace = reinterpret_cast<int8_t*>(params.workspaces);
|
||||
unsigned int batch_beam_size = params.batch_size * params.beam_width;
|
||||
const size_t cu_seqlens_size = sizeof(int) * (batch_beam_size + 1);
|
||||
const size_t rotary_inv_freq_size = sizeof(float) * batch_beam_size * params.rotary_embedding_dim / 2;
|
||||
launchParams.cu_seq_lens = reinterpret_cast<int*>(workspace);
|
||||
workspace = nextWorkspacePtrWithAlignment(workspace, cu_seqlens_size);
|
||||
launchParams.rotary_inv_freq_buf = reinterpret_cast<float*>(workspace);
|
||||
auto const multi_block_workspace_alignment = roundUp(
|
||||
sizeof(half) * params.head_size * (params.num_q_heads / params.num_kv_heads) * params.beam_width, 128);
|
||||
workspace = nextWorkspacePtrWithAlignment(workspace, rotary_inv_freq_size, multi_block_workspace_alignment);
|
||||
launchParams.scratch = reinterpret_cast<void*>(workspace);
|
||||
|
||||
launchParams.kvCacheParams = KVCache<KVCacheBuffer>(kv_cache_buffer);
|
||||
launchParams.kvCacheParams.sequence_lengths = params.sequence_lengths;
|
||||
launchParams.kvCacheParams.capacity
|
||||
= params.paged_kv_cache ? params.max_blocks_per_sequence : params.max_attention_window_size;
|
||||
// TODO: beam searching has not been implemented yet.
|
||||
if (params.beam_width > 1)
|
||||
{
|
||||
launchParams.beamSearchParams
|
||||
= BeamSearchParams{params.cache_indir, params.max_attention_window_size, params.context_lengths};
|
||||
}
|
||||
else
|
||||
{
|
||||
launchParams.beamSearchParams = std::nullopt;
|
||||
}
|
||||
}
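// Illustrative note (not part of the original source): how the workspace is carved up above, with
// placeholder sizes. For batch_size = 4, beam_width = 1, rotary_embedding_dim = 128, head_size = 128,
// num_q_heads = 32, num_kv_heads = 8:
//   cu_seq_lens        : sizeof(int)   * (4 * 1 + 1)     = 20 bytes at the start of the workspace
//   rotary_inv_freq_buf: sizeof(float) * 4 * 1 * 128 / 2 = 1024 bytes, placed after cu_seq_lens
//   scratch            : the remainder, with the rotary region advanced using an alignment of
//                        roundUp(sizeof(half) * 128 * (32 / 8) * 1, 128) = 1024 bytes
// (exact padding between regions depends on nextWorkspacePtrWithAlignment).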
|
||||
|
||||
namespace
|
||||
{
|
||||
template <typename T>
|
||||
std::optional<T> getGlobalVar(
|
||||
tensorrt_llm::common::CUDADriverWrapper const& driver, CUmodule hmod, char const* const name, bool required = false)
|
||||
{
|
||||
T* pVar = nullptr;
|
||||
size_t size = 0;
|
||||
auto const error = driver.cuModuleGetGlobal(reinterpret_cast<CUdeviceptr*>(&pVar), &size, hmod, name);
|
||||
T ret;
|
||||
switch (error)
|
||||
{
|
||||
case CUDA_SUCCESS:
|
||||
TLLM_CHECK(size == sizeof(T));
|
||||
check_cuda_error(cudaMemcpy(&ret, pVar, size, cudaMemcpyDeviceToHost));
|
||||
break;
|
||||
case CUDA_ERROR_NOT_FOUND:
|
||||
if (!required)
|
||||
{
|
||||
return std::nullopt;
|
||||
}
|
||||
[[fallthrough]];
|
||||
default: cuErrCheck(("Failed to retrieve global variable from cubin.", error), driver);
|
||||
}
|
||||
return std::optional<T>{std::move(ret)};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class XQAKernelList
|
||||
{
|
||||
public:
|
||||
using TKernelMeta = XQAKernelMetaInfo;
|
||||
|
||||
XQAKernelList(Data_type type, unsigned int sm)
|
||||
: mDataType(type)
|
||||
: mDriver(tensorrt_llm::common::CUDADriverWrapper::getInstance())
|
||||
, mDataType(type)
|
||||
, mKernelMetaCount(sizeof(sXqaKernelMetaInfo) / sizeof(sXqaKernelMetaInfo[0]))
|
||||
, mKernelMeta(&sXqaKernelMetaInfo[0])
|
||||
, mSM(sm)
|
||||
@ -268,6 +64,10 @@ public:
|
||||
if (kernelMeta.mSM != mSM || kernelMeta.mDataType != mDataType)
|
||||
continue;
|
||||
|
||||
// Cubins for kernels that would take the JIT path are removed from kernelMeta.
|
||||
if (kernelMeta.mCubin == nullptr)
|
||||
continue;
|
||||
|
||||
CUmodule hmod{0};
|
||||
auto findModuleIter = mModules.find(kernelMeta.mCubin);
|
||||
if (findModuleIter != mModules.end())
|
||||
@ -276,13 +76,13 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
cuErrCheck(mDriver.cuModuleLoadData(&hmod, kernelMeta.mCubin), mDriver);
|
||||
cuErrCheck(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin), mDriver);
|
||||
mModules.insert(std::make_pair(kernelMeta.mCubin, hmod));
|
||||
}
|
||||
|
||||
XQAKernelFuncInfo funcInfo{};
|
||||
funcInfo.mMetaInfoIndex = i;
|
||||
cuErrCheck(mDriver.cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName), mDriver);
|
||||
cuErrCheck(mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName), mDriver);
|
||||
funcInfo.mSharedMemBytes = getGlobalVar<uint32_t>(mDriver, hmod, "smemSize", true).value();
|
||||
funcInfo.mKernelType = getGlobalVar<XQAKernelType>(mDriver, hmod, "kernelType", false)
|
||||
.value_or(XQAKernelType::kAMPERE_WARP_SPECIALIZED);
|
||||
@ -290,7 +90,7 @@ public:
|
||||
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
|
||||
if (funcInfo.mSharedMemBytes >= 46 * 1024)
|
||||
{
|
||||
cuErrCheck(mDriver.cuFuncSetAttribute(funcInfo.mDeviceFunction,
|
||||
cuErrCheck(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
|
||||
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, funcInfo.mSharedMemBytes),
|
||||
mDriver);
|
||||
}
|
||||
@ -386,7 +186,7 @@ public:
|
||||
nullptr, static_cast<T*>(const_cast<void*>(xqaParams.output)), kv_cache_buffer,
|
||||
static_cast<T const*>(xqaParams.qkv_bias), nullptr, xqaParams.sequence_lengths, nullptr,
|
||||
launchParams.rotary_inv_freq_buf, (float2 const*) nullptr, xqaParams.kv_scale_orig_quant,
|
||||
xqaParams.medusa_position_offsets, int(batch_beam_size), xqaParams.generation_input_length,
|
||||
xqaParams.spec_decoding_position_offsets, int(batch_beam_size), xqaParams.generation_input_length,
|
||||
xqaParams.timestep, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
|
||||
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length), xqaParams.num_q_heads,
|
||||
xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size,
|
||||
@ -424,7 +224,7 @@ public:
|
||||
// mask). Input parameters for MultiQueryTokens kernels.
|
||||
unsigned int log2HeadGrpSize = log2(num_q_heads_over_kv);
|
||||
unsigned int nbTokenBlocksPerGrp = divUp(qSeqLen << log2HeadGrpSize, mTileSize);
|
||||
int const* maskPtr = xqaParams.medusa_packed_mask;
|
||||
int const* maskPtr = xqaParams.spec_decoding_packed_mask;
|
||||
// TODO: add fp8/int8 kv cache kernels.
|
||||
float kvCacheQuantOrig = 1.0f;
|
||||
// TODO: merge SingleQueryToken params and MultiQueryTokens params into one kernelParams.
|
||||
@ -439,7 +239,7 @@ public:
|
||||
launchParams.scratch, 0, sizeof(int) * xqaParams.batch_size * xqaParams.num_kv_heads, stream));
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp,
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp,
|
||||
xqaParams.batch_size, 128, 1, 2, shared_mem_bytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
}
|
||||
@ -484,7 +284,7 @@ public:
|
||||
launchParams.scratch, 0, sizeof(int) * xqaParams.batch_size * xqaParams.num_kv_heads, stream));
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads, xqaParams.batch_size, 128, 1,
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads, xqaParams.batch_size, 128, 1,
|
||||
isGmmaKernel ? 3 : 2, shared_mem_bytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
}
|
||||
@ -492,34 +292,6 @@ public:
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
static int computeMultiBlockCount(XQAParams const& xqaParams, int batch_size, int multiprocessor_count)
|
||||
{
|
||||
if (envXqaNbCtaPerKVHead().has_value())
|
||||
{
|
||||
return envXqaNbCtaPerKVHead().value();
|
||||
}
|
||||
int multi_block_count = 1;
|
||||
int num_kv_heads = xqaParams.num_kv_heads;
|
||||
int history_length = xqaParams.timestep;
|
||||
|
||||
multi_block_count = history_length / kMinHistoryTokensPerBlock;
|
||||
multi_block_count = std::max(multi_block_count, 1);
|
||||
// Adjust to kTargetWaveFactor; since multi_block_count was initialized from kMinHistoryTokensPerBlock, it only needs to decrease.
|
||||
double wave_count = (double) batch_size * num_kv_heads * multi_block_count / (double) multiprocessor_count;
|
||||
double adj_factor = wave_count / (double) kTargetWaveFactor;
|
||||
if (adj_factor > 1.0)
|
||||
{
|
||||
multi_block_count = floor(multi_block_count / adj_factor);
|
||||
}
|
||||
multi_block_count = std::max(multi_block_count, 1);
|
||||
|
||||
// add limitation on upper bound.
|
||||
multi_block_count = std::min(xqaMaxNbCtaPerKVHeadFactor(), multi_block_count);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(multi_block_count >= 1, "MultiBlock count should be at least 1");
|
||||
return multi_block_count;
|
||||
}
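// Illustrative note (not part of the original source): a worked example of the heuristic above, using
// hypothetical constants kMinHistoryTokensPerBlock = 512, kTargetWaveFactor = 8 and
// xqaMaxNbCtaPerKVHeadFactor() = 8. With timestep = 8192, batch_size = 2, num_kv_heads = 8 and 108 SMs:
//   multi_block_count = 8192 / 512       = 16
//   wave_count        = 2 * 8 * 16 / 108 ~ 2.37
//   adj_factor        = 2.37 / 8 < 1.0, so no reduction is applied
//   upper bound       = min(8, 16)       = 8 CTAs per KV head.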
|
||||
|
||||
private:
|
||||
static uint32_t getElemBytes(CUtensorMapDataType_enum dataType)
|
||||
{
|
||||
@ -565,7 +337,7 @@ private:
|
||||
}
|
||||
}();
|
||||
|
||||
cuErrCheck(mDriver.cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
|
||||
cuErrCheck(mDriver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
|
||||
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE),
|
||||
mDriver);
|
||||
@ -594,7 +366,7 @@ private:
|
||||
}
|
||||
}();
|
||||
|
||||
cuErrCheck(mDriver.cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
|
||||
cuErrCheck(mDriver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
|
||||
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE),
|
||||
mDriver);
|
||||
@ -619,7 +391,7 @@ private:
|
||||
}
|
||||
|
||||
protected:
|
||||
tensorrt_llm::common::CUDADriverWrapper mDriver;
|
||||
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> mDriver;
|
||||
|
||||
Data_type mDataType;
|
||||
TKernelMeta const* mKernelMeta;
|
||||
@ -706,7 +478,7 @@ void DecoderXQAImplPrecompiled::runDispatchBuffer(
|
||||
|
||||
#undef XQA_KERNEL_RUN
|
||||
|
||||
bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams)
|
||||
bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams, bool /*forConfigurePlugin*/)
|
||||
{
|
||||
XQAKernelList const* xqa_kernel = getXQAKernels(mRunner->mDataType, tensorrt_llm::common::getSMVersion());
|
||||
return xqa_kernel->supportConfig(xqaParams)
|
||||
|
||||
@ -29,7 +29,7 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
bool shouldUse(XQAParams const& xqaParams) override;
|
||||
bool shouldUse(XQAParams const& xqaParams, bool forConfigurePlugin) override;
|
||||
void prepare(XQAParams const& xqa_params) override;
|
||||
|
||||
protected:
|
||||
|
||||
@ -22,7 +22,6 @@
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorrt_llm/common/cudaDriverWrapper.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/cubin/xqa_kernel_cubin.h"
|
||||
@ -36,9 +35,9 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
DecoderXQARunner::DecoderXQARunner(
|
||||
const XQADataType data_type, int num_heads, int num_kv_heads, int head_size, bool multi_block_mode)
|
||||
: mPrepareCalled(false)
|
||||
DecoderXQARunner::DecoderXQARunner(Resource* resource, const XQADataType data_type, int num_heads, int num_kv_heads,
|
||||
int head_size, bool multi_block_mode)
|
||||
: mResource(resource)
|
||||
, mDataType(data_type)
|
||||
, mNumHeads(num_heads)
|
||||
, mNumKVHeads(num_kv_heads)
|
||||
@ -46,9 +45,12 @@ DecoderXQARunner::DecoderXQARunner(
|
||||
, mMultiBlockMode(multi_block_mode)
|
||||
{
|
||||
mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
// The initialization of mImpl must be the last line because *this needs to be fully initialized before calling
|
||||
// DecoderXQAImpl::create().
|
||||
mImpl = DecoderXQAImpl::create(this, DecoderXQAImpl::ImplType::kPrecompiled);
|
||||
|
||||
// TODO(minwei): needs both impls because medusa kernels haven't been migrated to JIT yet (they should be).
|
||||
// mJITImpl/mPrecompiledImpl assignments must be the last lines of this constructor. DecoderXQAImpl::create() relies
|
||||
// on *this being fully initialized.
|
||||
mJITImpl = DecoderXQAImpl::create(this, DecoderXQAImpl::ImplType::kJIT);
|
||||
mPrecompiledImpl = DecoderXQAImpl::create(this, DecoderXQAImpl::ImplType::kPrecompiled);
|
||||
}
|
||||
|
||||
DecoderXQARunner::~DecoderXQARunner() = default;
|
||||
@ -96,21 +98,40 @@ size_t DecoderXQARunner::getWorkspaceSize(int max_batch_beam_size)
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
bool DecoderXQARunner::shouldUseImpl(XQAParams const& xqaParams)
|
||||
DecoderXQAImpl* DecoderXQARunner::getImplFromXQAParams(XQAParams const& xqaParams)
|
||||
{
|
||||
return mImpl->shouldUse(xqaParams);
|
||||
if (tensorrt_llm::common::getEnvDisableXQAJIT())
|
||||
{
|
||||
// Always use Precompiled impl if TRTLLM_DISABLE_XQA_JIT is ON.
|
||||
return mPrecompiledImpl.get();
|
||||
}
|
||||
if (xqaParams.multi_query_tokens)
|
||||
{
|
||||
// Use precompiled cubin for medusa, because medusa cubins are generated from a different CUDA source file than
|
||||
// non-medusa.
|
||||
return mPrecompiledImpl.get();
|
||||
}
|
||||
else
|
||||
{
|
||||
return mJITImpl.get();
|
||||
}
|
||||
}
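// Illustrative summary (not part of the original source) of the dispatch above:
//   TRTLLM_DISABLE_XQA_JIT set         -> precompiled impl
//   multi_query_tokens (Medusa mode)   -> precompiled impl (cubins come from a different CUDA source)
//   otherwise                          -> JIT impl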
|
||||
|
||||
bool DecoderXQARunner::shouldUseImpl(XQAParams const& xqa_params, bool for_configure_plugin)
|
||||
{
|
||||
return getImplFromXQAParams(xqa_params)->shouldUse(xqa_params, for_configure_plugin);
|
||||
}
|
||||
|
||||
void DecoderXQARunner::prepareForRun(XQAParams const& xqa_params)
|
||||
{
|
||||
return mImpl->prepare(xqa_params);
|
||||
return getImplFromXQAParams(xqa_params)->prepare(xqa_params);
|
||||
}
|
||||
|
||||
template <typename KVCacheBuffer>
|
||||
void DecoderXQARunner::run(
|
||||
XQAParams const& xqa_params, KVCacheBuffer const& kv_cache_buffer, cudaStream_t const& stream)
|
||||
{
|
||||
return mImpl->run(xqa_params, kv_cache_buffer, stream);
|
||||
return getImplFromXQAParams(xqa_params)->run(xqa_params, kv_cache_buffer, stream);
|
||||
}
|
||||
|
||||
template void DecoderXQARunner::run(
|
||||
@ -118,6 +139,42 @@ template void DecoderXQARunner::run(
|
||||
template void DecoderXQARunner::run(
|
||||
XQAParams const& xqa_params, KVBlockArray const& kv_block_array, cudaStream_t const& stream);
|
||||
|
||||
//// DecoderXQARunner::Resource
|
||||
DecoderXQARunner::Resource::Resource()
|
||||
: mCubinObjRegistry(std::make_unique<jit::CubinObjRegistry>())
|
||||
{
|
||||
}
|
||||
|
||||
DecoderXQARunner::Resource::Resource(DecoderXQARunner::Resource const& other)
|
||||
: mCubinObjRegistry(other.mCubinObjRegistry->clone())
|
||||
{
|
||||
}
|
||||
|
||||
DecoderXQARunner::Resource& DecoderXQARunner::Resource::operator=(DecoderXQARunner::Resource const& other)
|
||||
{
|
||||
if (this == &other)
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
mCubinObjRegistry = other.mCubinObjRegistry->clone();
|
||||
return *this;
|
||||
}
|
||||
|
||||
DecoderXQARunner::Resource::Resource(void const* buffer, size_t buffer_size)
|
||||
: mCubinObjRegistry(std::make_unique<jit::CubinObjRegistry>(buffer, buffer_size))
|
||||
{
|
||||
}
|
||||
|
||||
size_t DecoderXQARunner::Resource::getSerializationSize() const noexcept
|
||||
{
|
||||
return mCubinObjRegistry->getSerializationSize();
|
||||
}
|
||||
|
||||
void DecoderXQARunner::Resource::serialize(void* buffer, size_t buffer_size) const noexcept
|
||||
{
|
||||
mCubinObjRegistry->serialize(buffer, buffer_size);
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -22,6 +22,8 @@
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/quantization.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/cubinObjRegistry.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h"
|
||||
#include "tensorrt_llm/kernels/gptKernels.h"
|
||||
@ -75,8 +77,10 @@ struct XQADispatchHelper<__nv_bfloat16, KVBlockArray>
|
||||
class DecoderXQARunner
|
||||
{
|
||||
public:
|
||||
DecoderXQARunner(
|
||||
const XQADataType data_type, int num_heads, int num_kv_heads, int head_size, bool multi_block_mode);
|
||||
// Resources for constructing a DecoderXQARunner object.
|
||||
class Resource;
|
||||
DecoderXQARunner(Resource* resource, const XQADataType data_type, int num_heads, int num_kv_heads, int head_size,
|
||||
bool multi_block_mode);
|
||||
~DecoderXQARunner();
|
||||
|
||||
/**
|
||||
@ -155,41 +159,25 @@ public:
|
||||
SUPPORT_RETURN_FALSE("nbHeads");
|
||||
}
|
||||
}
|
||||
return shouldUseImpl(xqaParams);
|
||||
return shouldUseImpl(xqaParams, forConfigurePlugin);
|
||||
}
|
||||
|
||||
size_t getWorkspaceSize(int max_batch_beam_size);
|
||||
|
||||
void prepare(XQAParams const& xqa_params)
|
||||
{
|
||||
if (!mPrepareCalled)
|
||||
{
|
||||
this->prepareForRun(xqa_params);
|
||||
mPrepareCalled = true;
|
||||
}
|
||||
this->prepareForRun(xqa_params);
|
||||
}
|
||||
|
||||
template <typename KVCacheBuffer>
|
||||
void dispatch(XQAParams const& xqa_params, KVCacheBuffer const& kv_cache_buffer, cudaStream_t const& stream)
|
||||
{
|
||||
/*
|
||||
TODO(minwei): re-enable the mPrepareCalled check once we figure out the root cause.
|
||||
|
||||
See https://github.com/NVIDIA/TensorRT-LLM/issues/1256.
|
||||
It is safe to remove the check for now, because this->prepareForRun() is effectively a no-op. It calls into
|
||||
DecoderXQAImplPrecompiled::prepare(), which does nothing in its body.
|
||||
|
||||
if (!mPrepareCalled)
|
||||
{
|
||||
TLLM_THROW("DecoderXQARunner::prepare() hasn't been called before DecoderXQARunner::dispatch().");
|
||||
}
|
||||
*/
|
||||
sync_check_cuda_error();
|
||||
this->run(xqa_params, kv_cache_buffer, stream);
|
||||
}
|
||||
|
||||
private:
|
||||
bool shouldUseImpl(XQAParams const& xqa_params);
|
||||
bool shouldUseImpl(XQAParams const& xqa_params, bool for_configure_plugin);
|
||||
void prepareForRun(XQAParams const& xqa_params);
|
||||
|
||||
template <typename KVCacheBuffer>
|
||||
@ -197,7 +185,7 @@ private:
|
||||
|
||||
static constexpr int kMaxBeamWidth = 4;
|
||||
|
||||
bool mPrepareCalled;
|
||||
Resource* mResource;
|
||||
|
||||
XQADataType mDataType;
|
||||
int mNumHeads;
|
||||
@ -206,9 +194,35 @@ private:
|
||||
bool mMultiBlockMode;
|
||||
int mMultiProcessorCount;
|
||||
|
||||
std::unique_ptr<DecoderXQAImpl> mImpl;
|
||||
std::unique_ptr<DecoderXQAImpl> mJITImpl, mPrecompiledImpl;
|
||||
DecoderXQAImpl* getImplFromXQAParams(XQAParams const& params);
|
||||
|
||||
friend DecoderXQAImplPrecompiled;
|
||||
friend DecoderXQAImplJIT;
|
||||
};
|
||||
|
||||
class DecoderXQARunner::Resource
|
||||
{
|
||||
public:
|
||||
Resource();
|
||||
Resource(Resource const& other);
|
||||
Resource& operator=(Resource const& other);
|
||||
Resource(Resource&& other) = default;
|
||||
Resource& operator=(Resource&& other) = default;
|
||||
// Construct from a serialized buffer.
|
||||
Resource(void const* buffer, size_t buffer_size);
|
||||
~Resource() = default;
|
||||
|
||||
jit::CubinObjRegistry* getCubinObjRegistry()
|
||||
{
|
||||
return mCubinObjRegistry.get();
|
||||
}
|
||||
|
||||
size_t getSerializationSize() const noexcept;
|
||||
void serialize(void* buffer, size_t buffer_size) const noexcept;
|
||||
|
||||
private:
|
||||
std::unique_ptr<jit::CubinObjRegistry> mCubinObjRegistry;
|
||||
};
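// Illustrative usage sketch (not part of the original header): serializing a Resource and rebuilding an
// equivalent one from the buffer, e.g. across plugin serialization. std::vector is used here only as a
// placeholder container.
//
//   DecoderXQARunner::Resource resource;
//   std::vector<char> buffer(resource.getSerializationSize());
//   resource.serialize(buffer.data(), buffer.size());
//
//   // Later, in the deserialization path:
//   DecoderXQARunner::Resource restored(buffer.data(), buffer.size());
//   jit::CubinObjRegistry* registry = restored.getCubinObjRegistry();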
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
@ -44,11 +44,11 @@ struct XQAParams
|
||||
int32_t sink_token_length = 0;
|
||||
int timestep = 0;
|
||||
void const* qkv_bias;
|
||||
int32_t const* sequence_lengths; //
|
||||
int32_t const* context_lengths; // maybe not used now
|
||||
void const* alibi_slopes; // maybe not used now
|
||||
int32_t const* medusa_packed_mask;
|
||||
int const* medusa_position_offsets; // rotary embedding.
|
||||
int32_t const* sequence_lengths; //
|
||||
int32_t const* context_lengths; // maybe not used now
|
||||
void const* alibi_slopes; // maybe not used now
|
||||
int32_t const* spec_decoding_packed_mask;
|
||||
int const* spec_decoding_position_offsets; // rotary embedding.
|
||||
|
||||
// almost copy from GPTAttentionPluginCommon.
|
||||
// maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here.
|
||||
|
||||
@ -2609,10 +2609,28 @@ inline __device__ void update_rotary_base_n_scale(float& base, float& scale, Rot
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ float2 rotary_embedding_coefficient(
|
||||
int const zid, int const rot_embed_dim, float const base, float const scale, float const t_step)
|
||||
inline __device__ float2 rotary_embedding_coefficient(int const zid, int const rot_embed_dim, float const base,
|
||||
float const scale, float const t_step, int const vision_start = -1, int const vision_length = -1)
|
||||
{
|
||||
float const inv_freq = float(t_step * scale) / powf(base, zid / (float) rot_embed_dim);
|
||||
float real_step = t_step;
|
||||
if (vision_start != -1 && vision_length != -1)
|
||||
{
|
||||
int t_step_int = (int) t_step;
|
||||
if (t_step_int <= vision_start)
|
||||
{
|
||||
real_step = t_step_int;
|
||||
}
|
||||
else if (t_step_int > vision_start && t_step_int <= (vision_length + vision_start))
|
||||
{
|
||||
real_step = vision_start + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
real_step = t_step_int - (vision_length - 1);
|
||||
}
|
||||
}
|
||||
|
||||
float const inv_freq = (real_step * scale) / powf(base, zid / (float) rot_embed_dim);
|
||||
return {cosf(inv_freq), sinf(inv_freq)};
|
||||
}
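// Illustrative note (not part of the original source): behaviour of the remapping above for a
// hypothetical prompt with vision_start = 10 and vision_length = 100 (vision tokens at steps 11..110):
//   t_step =  10 -> real_step = 10              (text before the vision segment keeps its position)
//   t_step =  50 -> real_step = 11              (all vision tokens share position vision_start + 1)
//   t_step = 110 -> real_step = 11
//   t_step = 111 -> real_step = 111 - 99 = 12   (text after the segment continues right behind it)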
|
||||
|
||||
@ -2640,42 +2658,50 @@ inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162
|
||||
}
|
||||
#endif
|
||||
|
||||
inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
float& q, float& k, int zid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
float2& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (2 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef
|
||||
= rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q = rotary_embedding_transform(q, coef);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
float2& q, float2& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (2 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef
|
||||
= rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q = rotary_embedding_transform(q, coef);
|
||||
k = rotary_embedding_transform(k, coef);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
float4& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (4 * tid >= rot_embed_dim)
|
||||
{
|
||||
@ -2683,14 +2709,17 @@ inline __device__ void apply_rotary_embedding(
|
||||
}
|
||||
|
||||
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
|
||||
auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef0
|
||||
= rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q_.x = rotary_embedding_transform(q_.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef1
|
||||
= rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q_.y = rotary_embedding_transform(q_.y, coef1);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
float4& q, float4& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (4 * tid >= rot_embed_dim)
|
||||
{
|
||||
@ -2699,16 +2728,19 @@ inline __device__ void apply_rotary_embedding(
|
||||
|
||||
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
|
||||
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
|
||||
auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef0
|
||||
= rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q_.x = rotary_embedding_transform(q_.x, coef0);
|
||||
k_.x = rotary_embedding_transform(k_.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef1
|
||||
= rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q_.y = rotary_embedding_transform(q_.y, coef1);
|
||||
k_.y = rotary_embedding_transform(k_.y, coef1);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
Float8_& q, Float8_& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(Float8_& q, Float8_& k, int tid, int rot_embed_dim, float base,
|
||||
float scale, float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (8 * tid >= rot_embed_dim)
|
||||
{
|
||||
@ -2717,205 +2749,289 @@ inline __device__ void apply_rotary_embedding(
|
||||
|
||||
Float8_& q_ = *reinterpret_cast<Float8_*>(&q);
|
||||
Float8_& k_ = *reinterpret_cast<Float8_*>(&k);
|
||||
auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef0
|
||||
= rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q_.x = rotary_embedding_transform(q_.x, coef0);
|
||||
k_.x = rotary_embedding_transform(k_.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef1
|
||||
= rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q_.y = rotary_embedding_transform(q_.y, coef1);
|
||||
k_.y = rotary_embedding_transform(k_.y, coef1);
|
||||
auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef2
|
||||
= rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q_.z = rotary_embedding_transform(q_.z, coef2);
|
||||
k_.z = rotary_embedding_transform(k_.z, coef2);
|
||||
auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef3
|
||||
= rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q_.w = rotary_embedding_transform(q_.w, coef3);
|
||||
k_.w = rotary_embedding_transform(k_.w, coef3);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
uint32_t& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (2 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef
|
||||
= rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q = rotary_embedding_transform(q, coef);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, float base,
|
||||
float scale, float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (2 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef
|
||||
= rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q = rotary_embedding_transform(q, coef);
|
||||
k = rotary_embedding_transform(k, coef);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(half2& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(half2& q, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
return apply_rotary_embedding(*reinterpret_cast<uint32_t*>(&q), tid, rot_embed_dim, base, scale, t_step);
|
||||
return apply_rotary_embedding(*reinterpret_cast<uint32_t*>(&q), tid, rot_embed_dim, base, scale, mscale,
|
||||
rotary_embedding_scaling_factors, t_step, vision_start, vision_length);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
half2& q, half2& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(half2& q, half2& k, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
return apply_rotary_embedding(
|
||||
*reinterpret_cast<uint32_t*>(&q), *reinterpret_cast<uint32_t*>(&k), tid, rot_embed_dim, base, scale, t_step);
|
||||
return apply_rotary_embedding(*reinterpret_cast<uint32_t*>(&q), *reinterpret_cast<uint32_t*>(&k), tid,
|
||||
rot_embed_dim, base, scale, mscale, rotary_embedding_scaling_factors, t_step, vision_start, vision_length);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (4 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef0
|
||||
= rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.x = rotary_embedding_transform(q.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef1
|
||||
= rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.y = rotary_embedding_transform(q.y, coef1);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
uint2& q, uint2& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ float2 rotary_embedding_coefficient_long_rope(
|
||||
int const zid, int const rot_embed_dim, float const base, float const scale, float const mscale, float const t_step)
|
||||
{
|
||||
float const inv_freq = float(t_step * scale) / powf(base, zid / (float) rot_embed_dim);
|
||||
return {cosf(inv_freq) * mscale, sinf(inv_freq) * mscale};
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (4 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step);
|
||||
float2 coef0, coef1;
|
||||
if (rotary_embedding_scaling_factors != nullptr)
|
||||
{
|
||||
float fscale = *(rotary_embedding_scaling_factors + (2 * tid));
|
||||
fscale = 1.0 / fscale;
|
||||
coef0 = rotary_embedding_coefficient_long_rope(4 * tid, rot_embed_dim, base, fscale, mscale, t_step);
|
||||
|
||||
fscale = *(rotary_embedding_scaling_factors + (2 * tid) + 1);
|
||||
fscale = 1.0 / fscale;
|
||||
coef1 = rotary_embedding_coefficient_long_rope(4 * tid + 2, rot_embed_dim, base, fscale, mscale, t_step);
|
||||
}
|
||||
else
|
||||
{
|
||||
coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
coef1 = rotary_embedding_coefficient(
|
||||
4 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
}
|
||||
q.x = rotary_embedding_transform(q.x, coef0);
|
||||
k.x = rotary_embedding_transform(k.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
q.y = rotary_embedding_transform(q.y, coef1);
|
||||
k.y = rotary_embedding_transform(k.y, coef1);
|
||||
}
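// Illustrative note (not part of the original source): when rotary_embedding_scaling_factors is non-null,
// each (cos, sin) pair uses the reciprocal of its per-pair factor together with mscale. For a
// hypothetical factor table {1.0, 4.0, ...} and mscale = 1.1, thread tid = 0 computes
//   coef0 = rotary_embedding_coefficient_long_rope(0, rot_embed_dim, base, 1.0 / 1.0, 1.1, t_step)
//   coef1 = rotary_embedding_coefficient_long_rope(2, rot_embed_dim, base, 1.0 / 4.0, 1.1, t_step)
// so larger factors slow the rotation of their dimensions, and both cos and sin are scaled by mscale.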
|
||||
|
||||
inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (8 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef0
|
||||
= rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.x = rotary_embedding_transform(q.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef1
|
||||
= rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.y = rotary_embedding_transform(q.y, coef1);
|
||||
auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef2
|
||||
= rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.z = rotary_embedding_transform(q.z, coef2);
|
||||
auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef3
|
||||
= rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.w = rotary_embedding_transform(q.w, coef3);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
uint4& q, uint4& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (8 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef0
|
||||
= rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.x = rotary_embedding_transform(q.x, coef0);
|
||||
k.x = rotary_embedding_transform(k.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef1
|
||||
= rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.y = rotary_embedding_transform(q.y, coef1);
|
||||
k.y = rotary_embedding_transform(k.y, coef1);
|
||||
auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef2
|
||||
= rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.z = rotary_embedding_transform(q.z, coef2);
|
||||
k.z = rotary_embedding_transform(k.z, coef2);
|
||||
auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef3
|
||||
= rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.w = rotary_embedding_transform(q.w, coef3);
|
||||
k.w = rotary_embedding_transform(k.w, coef3);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
__nv_bfloat162& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (2 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef
|
||||
= rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q = rotary_embedding_transform(q, coef);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim,
|
||||
float base, float scale, float mscale, float const* rotary_embedding_scaling_factors, int t_step,
|
||||
int vision_start = -1, int vision_length = -1)
|
||||
{
|
||||
if (2 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef
|
||||
= rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q = rotary_embedding_transform(q, coef);
|
||||
k = rotary_embedding_transform(k, coef);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
bf16_4_t& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (4 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef0
|
||||
= rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.x = rotary_embedding_transform(q.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef1
|
||||
= rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.y = rotary_embedding_transform(q.y, coef1);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, float base,
|
||||
float scale, float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (4 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step);
|
||||
|
||||
float2 coef0, coef1;
|
||||
if (rotary_embedding_scaling_factors != nullptr)
|
||||
{
|
||||
float fscale = *(rotary_embedding_scaling_factors + (2 * tid));
|
||||
fscale = 1.0 / fscale;
|
||||
coef0 = rotary_embedding_coefficient_long_rope(4 * tid, rot_embed_dim, base, fscale, mscale, t_step);
|
||||
|
||||
fscale = *(rotary_embedding_scaling_factors + (2 * tid) + 1);
|
||||
fscale = 1.0 / fscale;
|
||||
coef1 = rotary_embedding_coefficient_long_rope(4 * tid + 2, rot_embed_dim, base, fscale, mscale, t_step);
|
||||
}
|
||||
else
|
||||
{
|
||||
coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
coef1 = rotary_embedding_coefficient(
|
||||
4 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
}
|
||||
q.x = rotary_embedding_transform(q.x, coef0);
|
||||
k.x = rotary_embedding_transform(k.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
q.y = rotary_embedding_transform(q.y, coef1);
|
||||
k.y = rotary_embedding_transform(k.y, coef1);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
bf16_8_t& q, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, float base, float scale,
|
||||
float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (8 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef0
|
||||
= rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.x = rotary_embedding_transform(q.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef1
|
||||
= rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.y = rotary_embedding_transform(q.y, coef1);
|
||||
auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef2
|
||||
= rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.z = rotary_embedding_transform(q.z, coef2);
|
||||
auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef3
|
||||
= rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.w = rotary_embedding_transform(q.w, coef3);
|
||||
}
|
||||
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, float base, float scale, int t_step)
|
||||
inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, float base,
|
||||
float scale, float mscale, float const* rotary_embedding_scaling_factors, int t_step, int vision_start = -1,
|
||||
int vision_length = -1)
|
||||
{
|
||||
if (8 * tid >= rot_embed_dim)
|
||||
{
|
||||
return;
|
||||
}
|
||||
auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef0
|
||||
= rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.x = rotary_embedding_transform(q.x, coef0);
|
||||
k.x = rotary_embedding_transform(k.x, coef0);
|
||||
auto const coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef1
|
||||
= rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.y = rotary_embedding_transform(q.y, coef1);
|
||||
k.y = rotary_embedding_transform(k.y, coef1);
|
||||
auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef2
|
||||
= rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.z = rotary_embedding_transform(q.z, coef2);
|
||||
k.z = rotary_embedding_transform(k.z, coef2);
|
||||
auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step);
|
||||
auto const coef3
|
||||
= rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step, vision_start, vision_length);
|
||||
q.w = rotary_embedding_transform(q.w, coef3);
|
||||
k.w = rotary_embedding_transform(k.w, coef3);
|
||||
}
|
||||
|
||||
@ -508,12 +508,12 @@ __global__ void copyNextStepIds(TokenIdType* nextStepIds, TokenIdType const* con
        auto const newTokens = numNewTokens == nullptr ? 1 : numNewTokens[batchSlot];
        auto const batchBeamIdx = batchSlot * beamWidth + beamIdx;
        auto const tokenBatchBeamIdx = tokenIdx * maxBatchSize * beamWidth + batchSlot * beamWidth + beamIdx;
        if (tokenIdx >= newTokens)
        auto const index_src = beamIdx * maxSeqLen + sequenceLengths[batchBeamIdx] - newTokens + tokenIdx;
        if (tokenIdx >= newTokens || index_src < 0)
        {
            continue;
        }
        nextStepIds[tokenBatchBeamIdx]
            = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + sequenceLengths[batchBeamIdx] - newTokens + tokenIdx];
        nextStepIds[tokenBatchBeamIdx] = outputIdsPtr[batchSlot][index_src];
    }
}
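The added index_src bound check above matters in the first decoding steps; a sketch of the case it guards against (the numbers are illustrative):

// With beamIdx = 0, maxSeqLen = 8, sequenceLengths[batchBeamIdx] = 1, newTokens = 2, tokenIdx = 0:
//     index_src = 0 * 8 + 1 - 2 + 0 = -1
// Reading outputIdsPtr[batchSlot][index_src] would be out of bounds, so such tokens are now skipped
// the same way tokenIdx >= newTokens is skipped.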
|
||||
|
||||
|
||||
@ -42,11 +42,12 @@ enum class PositionEmbeddingType : int8_t
    kLEARNED_ABSOLUTE = 0,
    kROPE_GPTJ = 1,
    kROPE_GPT_NEOX = 2,
    kLONG_ROPE = 3,
    // Workflow: (bmm1_output * scale_bmm1 + alibi).
    kALIBI = 3,
    kALIBI = 4,
    // Workflow: (bmm1_output + alibi) * scale_bmm1.
    kALIBI_WITH_SCALE = 4,
    kRELATIVE = 5
    kALIBI_WITH_SCALE = 5,
    kRELATIVE = 6,
};

enum class RotaryScalingType : int8_t
|
||||
|
||||
@ -200,8 +200,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
void topkGatingSoftmax(float const* input, bool const* finished, float* output, int const num_rows, int* indices,
|
||||
int* source_rows, int const k, int const start_expert, int const end_expert)
|
||||
void topkGatingSoftmax(float const* input, bool const* finished, float* output, int64_t const num_rows,
|
||||
int* indices, int* source_rows, int const k, int const start_expert, int const end_expert)
|
||||
{
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||
@ -403,7 +403,7 @@ template <int EXPERTS, int BYTES_PER_LDG>
|
||||
struct TopkConstants
|
||||
{
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0);
|
||||
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
@ -413,7 +413,8 @@ struct TopkConstants
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB>
|
||||
void topkGatingSoftmaxLauncherHelper(float const* input, bool const* finished, float* output, int* indices,
|
||||
int* source_row, int const num_rows, int const k, int const start_expert, int const end_expert, cudaStream_t stream)
|
||||
int* source_row, int64_t const num_rows, int const k, int const start_expert, int const end_expert,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
@ -430,8 +431,8 @@ void topkGatingSoftmaxLauncherHelper(float const* input, bool const* finished, f
|
||||
}
|
||||
|
||||
void topkGatingSoftmaxKernelLauncher(float const* input, bool const* finished, float* output,
|
||||
float* softmax_temp_output, int* indices, int* source_row, int const num_rows, int const num_experts, int const k,
|
||||
int const start_expert, int const end_expert, cudaStream_t stream)
|
||||
float* softmax_temp_output, int* indices, int* source_row, int64_t const num_rows, int const num_experts,
|
||||
int const k, int const start_expert, int const end_expert, cudaStream_t stream)
|
||||
{
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
|
||||
@ -523,11 +524,11 @@ void CubKeyValueSorter::updateNumExperts(int const num_experts)

size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, int const num_experts)
{
    size_t num_bits = (int) log2(num_experts) + 1;
    int num_bits = static_cast<int>(log2(num_experts)) + 1;
    size_t required_storage = 0;
    int* null_int = nullptr;
    cub::DeviceRadixSort::SortPairs(
        NULL, required_storage, null_int, null_int, null_int, null_int, num_key_value_pairs, 0, num_bits);
        nullptr, required_storage, null_int, null_int, null_int, null_int, num_key_value_pairs, 0, num_bits);
    return required_storage;
}
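As a quick check on the bit width computed above (values illustrative), cub::DeviceRadixSort only needs enough bits to cover the expert indices being sorted:

// num_experts = 8   -> (int) log2(8)   + 1 = 4 bits, covers keys in [0, 15]
// num_experts = 64  -> (int) log2(64)  + 1 = 7 bits, covers keys in [0, 127]
// num_experts = 100 -> (int) log2(100) + 1 = 7 bits, covers keys in [0, 127]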
|
||||
|
||||
@ -546,7 +547,7 @@ void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, int co
|
||||
// ============================== Infer GEMM sizes =================================
|
||||
// TODO Could linear search be better for small # experts
|
||||
template <class T>
|
||||
__device__ inline int findTotalEltsLeqTarget(T const* sorted_indices, int const arr_length, const T target)
|
||||
__device__ inline int64_t findTotalEltsLeqTarget(T const* sorted_indices, int const arr_length, const T target)
|
||||
{
|
||||
int64_t low = 0, high = arr_length - 1, target_location = -1;
|
||||
while (low <= high)
|
||||
@ -613,7 +614,7 @@ CUTLASS_HOST_DEVICE cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> make_cu
|
||||
} // namespace detail
|
||||
|
||||
__device__ void computeHopperInputStrides(
|
||||
HopperGroupedGemmInput layout_info, int gemm_m, int gemm_n, int gemm_k, int out_idx)
|
||||
HopperGroupedGemmInput layout_info, int gemm_m, int gemm_n, int gemm_k, int64_t out_idx)
|
||||
{
|
||||
layout_info.stride_a[out_idx] = detail::make_cute_packed_stride(
|
||||
HopperGroupedGemmInput::StrideA{}, cute::make_shape(gemm_m, gemm_k, cute::Int<1>{}));
|
||||
@ -677,6 +678,9 @@ __global__ void computeStridesHopperKernel(int64_t const* total_rows_before_expe
|
||||
layout_info.alpha_scale_ptr_array[expert] = fp8_dequant + expert;
|
||||
}
|
||||
|
||||
assert(gemm_m <= INT32_MAX);
|
||||
assert(gemm_n <= INT32_MAX);
|
||||
assert(gemm_k <= INT32_MAX);
|
||||
computeHopperInputStrides(layout_info, gemm_m, gemm_n, gemm_k, expert);
|
||||
|
||||
computeHopperInputPointers(
|
||||
@ -699,24 +703,25 @@ __global__ void computeStridesHopperKernel(int64_t const* total_rows_before_expe
template <typename T, bool CHECK_SKIPPED>
__global__ void expandInputRowsKernel(T const* unpermuted_input, T* permuted_output,
    int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
    int const num_rows, int64_t const* num_dest_rows, int const cols)
    int64_t const num_rows, int64_t const* num_dest_rows, int64_t const cols)
{

    // Reverse permutation map.
    // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the
    // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
    // thread block will be responsible for all k summations.
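    // Illustrative example: with num_rows = 4 and k = 2, original row 1 appears as expanded source rows
    // 1 and 1 + num_rows = 5 (source_rows[row][k_idx] = k_idx * num_rows + row), so the block reducing
    // original row 1 can locate both expanded destinations through this reverse map without atomics.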
|
||||
    int const expanded_dest_row = blockIdx.x;
    int const expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row];
    int64_t const expanded_dest_row = blockIdx.x;
    int64_t const expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row];
    if (threadIdx.x == 0)
    {
        expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row;
        assert(expanded_dest_row <= INT32_MAX);
        expanded_source_row_to_expanded_dest_row[expanded_source_row] = static_cast<int>(expanded_dest_row);
    }

    if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows)
    {
        // Duplicate and permute rows
        int const source_row = expanded_source_row % num_rows;
        int64_t const source_row = expanded_source_row % num_rows;

        T const* source_row_ptr = unpermuted_input + source_row * cols;
        T* dest_row_ptr = permuted_output + expanded_dest_row * cols;
|
||||
@ -731,10 +736,10 @@ __global__ void expandInputRowsKernel(T const* unpermuted_input, T* permuted_out
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(T const* unpermuted_input, T* permuted_output,
|
||||
int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row,
|
||||
int const num_rows, int64_t const* num_valid_tokens_ptr, int const cols, int const k, cudaStream_t stream)
|
||||
int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, cudaStream_t stream)
|
||||
{
|
||||
int const blocks = num_rows * k;
|
||||
int const threads = std::min(cols, 1024);
|
||||
int64_t const blocks = num_rows * k;
|
||||
int64_t const threads = std::min(cols, int64_t{1024});
|
||||
auto func = (num_valid_tokens_ptr != nullptr) ? expandInputRowsKernel<T, true> : expandInputRowsKernel<T, false>;
|
||||
func<<<blocks, threads, 0, stream>>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row, num_rows, num_valid_tokens_ptr, cols);
|
||||
@ -752,8 +757,8 @@ enum class ScaleMode : int
|
||||
template <typename T, typename OutputType, class GemmOutputType, ScaleMode SCALE_MODE, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted_rows,
|
||||
OutputType* reduced_unpermuted_output, T const* bias, float const* scales,
|
||||
int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int const cols, int const k,
|
||||
int64_t const* num_valid_ptr)
|
||||
int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const cols,
|
||||
int64_t const k, int64_t const* num_valid_ptr)
|
||||
{
|
||||
int const original_row = blockIdx.x;
|
||||
int const num_rows = gridDim.x;
|
||||
@ -766,10 +771,10 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted
|
||||
float row_rescale{0.f};
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
int const expanded_original_row = original_row + k_idx * num_rows;
|
||||
int const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
int64_t const expanded_original_row = original_row + k_idx * num_rows;
|
||||
int64_t const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
|
||||
const int64_t k_offset = original_row * k + k_idx;
|
||||
int64_t const k_offset = original_row * k + k_idx;
|
||||
float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset];
|
||||
if constexpr (SCALE_MODE == ScaleMode::RENORM_SCALE)
|
||||
{
|
||||
@ -784,7 +789,7 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted
|
||||
|
||||
auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols;
|
||||
|
||||
int const expert_idx = expert_for_source_row[k_offset];
|
||||
int64_t const expert_idx = expert_for_source_row[k_offset];
|
||||
|
||||
T const* bias_ptr = bias + expert_idx * cols;
|
||||
float const bias_value = bias ? static_cast<float>(bias_ptr[tid]) : 0.f;
|
||||
@ -807,12 +812,12 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted
|
||||
template <class T, class OutputType, class GemmOutputType = T>
|
||||
void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows,
|
||||
OutputType* reduced_unpermuted_output, T const* bias, float const* scales,
|
||||
int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int const num_rows,
|
||||
int const cols, int const k, int64_t const* num_valid_ptr, MOEParallelismConfig parallelism_config,
|
||||
int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int64_t const num_rows,
|
||||
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, MOEParallelismConfig parallelism_config,
|
||||
MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
|
||||
{
|
||||
int const blocks = num_rows;
|
||||
int const threads = std::min(cols, 1024);
|
||||
int64_t const blocks = num_rows;
|
||||
int64_t const threads = std::min(cols, int64_t{1024});
|
||||
|
||||
// Only add bias on rank 0 for tensor parallelism
|
||||
bool const is_rank_0 = parallelism_config.tp_rank == 0;
|
||||
@ -848,10 +853,10 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
|
||||
|
||||
template <class T, class ActFn>
|
||||
__global__ void doGatedActivationKernel(
|
||||
T* output, T const* gemm_result, int64_t const* num_valid_tokens_ptr, size_t inter_size)
|
||||
T* output, T const* gemm_result, int64_t const* num_valid_tokens_ptr, int64_t inter_size)
|
||||
{
|
||||
int const tid = threadIdx.x;
|
||||
int const token = blockIdx.x;
|
||||
int64_t const tid = threadIdx.x;
|
||||
int64_t const token = blockIdx.x;
|
||||
if (num_valid_tokens_ptr && token >= *num_valid_tokens_ptr)
|
||||
{
|
||||
return;
|
||||
@ -860,7 +865,7 @@ __global__ void doGatedActivationKernel(
|
||||
ActFn fn{};
|
||||
output = output + token * inter_size;
|
||||
gemm_result = gemm_result + token * inter_size * 2;
|
||||
for (int i = tid; i < inter_size; i += blockDim.x)
|
||||
for (int64_t i = tid; i < inter_size; i += blockDim.x)
|
||||
{
|
||||
auto fc1_value = static_cast<float>(gemm_result[i]);
|
||||
// BF16 isn't supported, use FP32 for activation function
|
||||
@ -871,11 +876,11 @@ __global__ void doGatedActivationKernel(
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void doGatedActivation(T* output, T const* gemm_result, int64_t const* num_valid_tokens_ptr, int inter_size,
|
||||
int num_tokens, ActivationType activation_type, cudaStream_t stream)
|
||||
void doGatedActivation(T* output, T const* gemm_result, int64_t const* num_valid_tokens_ptr, int64_t inter_size,
|
||||
int64_t num_tokens, ActivationType activation_type, cudaStream_t stream)
|
||||
{
|
||||
int const blocks = num_tokens;
|
||||
int const threads = std::min(inter_size, 1024);
|
||||
int64_t const blocks = num_tokens;
|
||||
int64_t const threads = std::min(inter_size, int64_t{1024});
|
||||
|
||||
// TODO Instead of T use a vectored type if performance would benefit
|
||||
// TODO For some reason Volta fails on GELU_taylor here with Warp Illegal Instruction.
|
||||
@ -890,10 +895,10 @@ void doGatedActivation(T* output, T const* gemm_result, int64_t const* num_valid
|
||||
template <class T, class ActFn>
|
||||
__global__ void doActivationKernel(T* output, HopperGroupedGemmInput::OutputTypeAdaptor_t<T> const* gemm_result,
|
||||
float const* fp8_quant, T const* bias_ptr, int64_t const* total_rows_before_expert_, int num_experts,
|
||||
size_t inter_size, bool gated)
|
||||
int64_t inter_size, bool gated)
|
||||
{
|
||||
int const tid = threadIdx.x;
|
||||
int const token = blockIdx.x;
|
||||
int64_t const tid = threadIdx.x;
|
||||
int64_t const token = blockIdx.x;
|
||||
if (token >= total_rows_before_expert_[num_experts - 1])
|
||||
{
|
||||
return;
|
||||
@ -906,7 +911,7 @@ __global__ void doActivationKernel(T* output, HopperGroupedGemmInput::OutputType
|
||||
gemm_result = gemm_result + token * inter_size * gated_mul;
|
||||
output = output + token * inter_size; // Aliases gemm_result for non-gated, non-fp8 cases
|
||||
|
||||
int expert = 0;
|
||||
int64_t expert = 0;
|
||||
if (bias_ptr)
|
||||
{
|
||||
// TODO this is almost certainly faster as a linear scan
|
||||
@ -919,7 +924,7 @@ __global__ void doActivationKernel(T* output, HopperGroupedGemmInput::OutputType
|
||||
{
|
||||
bias_ptr = bias_ptr + expert * inter_size * gated_mul;
|
||||
}
|
||||
for (int i = tid; i < inter_size; i += blockDim.x)
|
||||
for (int64_t i = tid; i < inter_size; i += blockDim.x)
|
||||
{
|
||||
auto fc1_value = static_cast<float>(gemm_result[i + gated_off]);
|
||||
if (bias_ptr)
|
||||
@ -940,11 +945,11 @@ __global__ void doActivationKernel(T* output, HopperGroupedGemmInput::OutputType
|
||||
|
||||
template <class T>
|
||||
void doActivation(T* output, HopperGroupedGemmInput::OutputTypeAdaptor_t<T> const* gemm_result, float const* fp8_quant,
|
||||
T const* bias, int64_t const* total_rows_before_expert_, int num_experts, int inter_size, int num_tokens,
|
||||
T const* bias, int64_t const* total_rows_before_expert_, int num_experts, int64_t inter_size, int64_t num_tokens,
|
||||
ActivationType activation_type, cudaStream_t stream)
|
||||
{
|
||||
int const blocks = num_tokens;
|
||||
int const threads = std::min(inter_size, 1024);
|
||||
int64_t const blocks = num_tokens;
|
||||
int64_t const threads = std::min(inter_size, int64_t{1024});
|
||||
|
||||
// TODO Instead of T use a vectored type if performance would benefit
|
||||
auto fn_list = std::array{
|
||||
@ -961,9 +966,9 @@ void doActivation(T* output, HopperGroupedGemmInput::OutputTypeAdaptor_t<T> cons
|
||||
}
|
||||
|
||||
template <class T, class WeightType, class OutputType, class Enable>
|
||||
std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWorkspaceBufferSizes(int const num_rows,
|
||||
int const hidden_size, int const inter_size, int const num_experts, int const num_experts_per_node, int const k,
|
||||
ActivationType activation_type) const
|
||||
std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWorkspaceBufferSizes(
|
||||
int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts,
|
||||
int const num_experts_per_node, int const k, ActivationType activation_type) const
|
||||
{
|
||||
const size_t num_moe_inputs = k * num_rows;
|
||||
const size_t permuted_elems = num_moe_inputs * hidden_size;
|
||||
@ -979,7 +984,7 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWo
|
||||
// We need to have separate memory for these as we can no longer alias the output buffer for reuse
|
||||
glu_inter_elems = interbuf_elems;
|
||||
}
|
||||
int num_softmax_outs = 0;
|
||||
size_t num_softmax_outs = 0;
|
||||
|
||||
bool using_hopper = moe_gemm_runner_.supportsHopperSpecialisation();
|
||||
const size_t gemm_output_dtype = using_hopper ? sizeof(HopperGemmOutputType) : sizeof(T);
|
||||
@ -1011,9 +1016,9 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWo
|
||||
}
|
||||
|
||||
template <class T, class WeightType, class OutputType, class Enable>
|
||||
size_t CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWorkspaceSize(int const num_rows,
|
||||
int const hidden_size, int const inter_size, int const num_experts, int const k, ActivationType activation_type,
|
||||
MOEParallelismConfig parallelism_config) const
|
||||
size_t CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWorkspaceSize(int64_t const num_rows,
|
||||
int64_t const hidden_size, int64_t const inter_size, int const num_experts, int const k,
|
||||
ActivationType activation_type, MOEParallelismConfig parallelism_config) const
|
||||
{
|
||||
int const ep_size = parallelism_config.ep_size;
|
||||
TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of tp size");
|
||||
@ -1023,9 +1028,9 @@ size_t CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::getWorkspaceSize(i
|
||||
}
|
||||
|
||||
template <class T, class WeightType, class OutputType, class Enable>
|
||||
void CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::configureWsPtrs(char* ws_ptr, int const num_rows,
|
||||
int const hidden_size, int const inter_size, int const num_experts, int const num_experts_per_node, int const k,
|
||||
ActivationType activation_type)
|
||||
void CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::configureWsPtrs(char* ws_ptr, int64_t const num_rows,
|
||||
int64_t const hidden_size, int64_t const inter_size, int const num_experts, int const num_experts_per_node,
|
||||
int const k, ActivationType activation_type)
|
||||
{
|
||||
auto ws_sizes = getWorkspaceBufferSizes(
|
||||
num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, activation_type);
|
||||
@ -1070,26 +1075,27 @@ template <class T, class WeightType, class OutputType, class Enable>
|
||||
void CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::runMoe(void const* input_activations_void,
|
||||
float const* gating_output, void const* fc1_expert_weights_void, void const* fc1_expert_biases_void,
|
||||
ActivationType fc1_activation_type, void const* fc2_expert_weights_void, void const* fc2_expert_biases_void,
|
||||
QuantParams quant_params, int const num_rows, int const hidden_size, int const inter_size, int const num_experts,
|
||||
int const k, char* workspace_ptr, void* final_output_void, bool const* finished, int const active_rows,
|
||||
void* expert_scales_void, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row,
|
||||
MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
|
||||
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
int const num_experts, int const k, char* workspace_ptr, void* final_output_void, bool const* finished,
|
||||
int64_t const active_rows, void* expert_scales_void, int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expert_for_source_row, MOEParallelismConfig parallelism_config,
|
||||
MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
|
||||
{
|
||||
static constexpr bool int_scales_required
|
||||
= std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value;
|
||||
static constexpr bool fp8_scales_required
|
||||
= std::is_same<WeightType, __nv_fp8_e4m3>::value || std::is_same<WeightType, __nv_fp8_e5m2>::value;
|
||||
|
||||
auto* input_activations = static_cast<T const*>(input_activations_void);
|
||||
auto* fc1_expert_weights = static_cast<WeightType const*>(fc1_expert_weights_void);
|
||||
auto* fc1_expert_biases = static_cast<T const*>(fc1_expert_biases_void);
|
||||
auto* fc2_expert_weights = static_cast<WeightType const*>(fc2_expert_weights_void);
|
||||
auto* fc1_int_scales = static_cast<T const*>(quant_params.fc1_weight_scales);
|
||||
auto* fc2_int_scales = static_cast<T const*>(quant_params.fc2_weight_scales);
|
||||
auto* fc1_fp8_dequant = static_cast<float const*>(quant_params.dequant_fc1);
|
||||
auto* fc2_fp8_quant = static_cast<float const*>(quant_params.quant_fc2);
|
||||
auto* fc2_fp8_dequant = static_cast<float const*>(quant_params.dequant_fc2);
|
||||
auto* fc2_expert_biases = static_cast<T const*>(fc2_expert_biases_void);
|
||||
auto const* input_activations = static_cast<T const*>(input_activations_void);
|
||||
auto const* fc1_expert_weights = static_cast<WeightType const*>(fc1_expert_weights_void);
|
||||
auto const* fc1_expert_biases = static_cast<T const*>(fc1_expert_biases_void);
|
||||
auto const* fc2_expert_weights = static_cast<WeightType const*>(fc2_expert_weights_void);
|
||||
auto const* fc1_int_scales = static_cast<T const*>(quant_params.fc1_weight_scales);
|
||||
auto const* fc2_int_scales = static_cast<T const*>(quant_params.fc2_weight_scales);
|
||||
auto const* fc1_fp8_dequant = quant_params.dequant_fc1;
|
||||
auto const* fc2_fp8_quant = quant_params.quant_fc2;
|
||||
auto const* fc2_fp8_dequant = quant_params.dequant_fc2;
|
||||
auto const* fc2_expert_biases = static_cast<T const*>(fc2_expert_biases_void);
|
||||
auto* final_output = static_cast<OutputType*>(final_output_void);
|
||||
auto* expert_scales = static_cast<float*>(expert_scales_void);
|
||||
|
||||
@ -1105,6 +1111,11 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::runMoe(void const* i
|
||||
TLLM_CHECK_WITH_INFO(hidden_size >= 128 / cutlass::sizeof_bits<WeightType>::value,
|
||||
"Hidden size is too small to meet alignment requirements for MOE GEMM");
|
||||
|
||||
// These values must fit into an int for building the source maps
|
||||
TLLM_CHECK_WITH_INFO(num_rows <= std::numeric_limits<int>::max(), "Number of rows is too large");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
num_rows * num_experts <= std::numeric_limits<int>::max(), "Number of rows * num_experts is too large");
|
||||
|
||||
if (int_scales_required)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
@ -1166,7 +1177,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::runMoe(void const* i
|
||||
const size_t fc1_out_size = is_gated_activation ? inter_size * 2 : inter_size;
|
||||
|
||||
// Upper bound on number of expanded rows
|
||||
int const expanded_active_expert_rows = k * active_rows;
|
||||
int64_t const expanded_active_expert_rows = k * active_rows;
|
||||
computeTotalRowsBeforeExpert(
|
||||
permuted_experts_, expanded_active_expert_rows, num_experts_per_node, total_rows_before_expert_, stream);
|
||||
|
||||
@ -1272,7 +1283,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::computeTotalRowsBefo
|
||||
|
||||
template <class T, class WeightType, class OutputType, class Enable>
|
||||
HopperGroupedGemmInput CutlassMoeFCRunner<T, WeightType, OutputType, Enable>::computeStridesHopper(
|
||||
int64_t const* total_rows_before_expert, HopperGroupedGemmInput layout_info, int gemm_n, int gemm_k,
|
||||
int64_t const* total_rows_before_expert, HopperGroupedGemmInput layout_info, int64_t gemm_n, int64_t gemm_k,
|
||||
int const num_experts, T const* in, WeightType const* weights, float const* fp8_dequant, T const* bias,
|
||||
HopperGemmOutputType* output, cudaStream_t stream)
|
||||
{
|
||||
@ -1322,7 +1333,7 @@ void makeLoadBalancedRoutingConfiguration(
|
||||
void* data_void, int num_experts, int num_tokens, int k, nvinfer1::DataType type, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(type == nvinfer1::DataType::kFLOAT, "Routing configuration must be float");
|
||||
check_cuda_error(cudaMemsetAsync(data_void, 0x0, num_experts * num_tokens * sizeof(float), stream));
|
||||
check_cuda_error(cudaMemsetAsync(data_void, 0x0, int64_t{num_experts} * num_tokens * sizeof(float), stream));
|
||||
|
||||
int stride = tensorrt_llm::common::ceilDiv(num_experts, k);
|
||||
|
||||
|
||||
@ -33,26 +33,6 @@ static inline size_t pad_to_multiple_of_16(size_t const& input)
|
||||
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
|
||||
}
|
||||
|
||||
/*
|
||||
Launches the topk gating softmax required for the MoE layers.
|
||||
|
||||
Params:
|
||||
input - a [num_rows x num_experts]
|
||||
finished - [num_rows] vector with 1 if the sentence at this row is done translating and 0 otherwise.
|
||||
output - a buffer of shape [num_rows x k] containing the top-k values of the softmax for each row.
|
||||
indices - a matrix of shape [num_rows x k] containing the top-k experts each row should get routed to.
|
||||
source_rows - a matrix of shape [num_rows x k] used internally for permuting. source_rows[row][k] = k * num_rows +
|
||||
row. It is constructed like this so we can track where each of the original rows end up in order to perform the
|
||||
"k-way" reduction later in the routing.
|
||||
|
||||
num_rows - The number of rows in the matrix
|
||||
num_experts - The number of expert layers present
|
||||
k - k value in topk
|
||||
*/
|
||||
template <typename T>
|
||||
void topk_gating_softmax_kernelLauncher(T const* input, bool const* finished, T* output, T* softmax_temp_out,
|
||||
int* indices, int* source_row, int const num_rows, int const num_experts, int const k, cudaStream_t stream);
|
||||
|
||||
class CubKeyValueSorter
|
||||
{
|
||||
public:
|
||||
@ -155,7 +135,7 @@ class CutlassMoeFCRunnerInterface
|
||||
{
|
||||
public:
|
||||
virtual ~CutlassMoeFCRunnerInterface() = default;
|
||||
virtual size_t getWorkspaceSize(int const num_rows, int const hidden_size, int const inter_size,
|
||||
virtual size_t getWorkspaceSize(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
int const num_experts, int const k, ActivationType activation_type,
|
||||
MOEParallelismConfig parallelism_config) const
|
||||
= 0;
|
||||
@ -164,11 +144,12 @@ public:
|
||||
|
||||
virtual void runMoe(void const* input_activations, float const* gating_output, void const* fc1_expert_weights,
|
||||
void const* fc1_expert_biases, ActivationType fc1_activation_type, void const* fc2_expert_weights,
|
||||
void const* fc2_expert_biases, QuantParams quant_params, int const num_rows, int const hidden_size,
|
||||
int const inter_size, int const num_experts, int const k, char* workspace_ptr, void* final_output,
|
||||
bool const* finished, int const active_rows, void* expert_scales, int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expert_for_source_row, MOEParallelismConfig parallelism_config,
|
||||
MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream)
|
||||
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
|
||||
int64_t const inter_size, int const num_experts, int const k, char* workspace_ptr, void* final_output,
|
||||
bool const* finished, int64_t const active_rows, void* expert_scales,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row,
|
||||
MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode,
|
||||
cudaStream_t stream)
|
||||
= 0;
|
||||
|
||||
bool is_profiler = false;
|
||||
@ -191,8 +172,9 @@ public:
|
||||
static_assert(
|
||||
std::is_same_v<T, WeightType> || !std::is_same_v<T, float>, "Does not support float with quantized weights");
|
||||
|
||||
size_t getWorkspaceSize(int const num_rows, int const hidden_size, int const fc1_output_size, int const num_experts,
|
||||
int const k, ActivationType activation_type, MOEParallelismConfig parallelism_config) const override;
|
||||
size_t getWorkspaceSize(int64_t const num_rows, int64_t const hidden_size, int64_t const fc1_output_size,
|
||||
int const num_experts, int const k, ActivationType activation_type,
|
||||
MOEParallelismConfig parallelism_config) const override;
|
||||
|
||||
void setTactic(std::optional<cutlass_extensions::CutlassGemmConfig> gemm_config) override
|
||||
{
|
||||
@ -206,11 +188,12 @@ public:
|
||||
|
||||
void runMoe(void const* input_activations, float const* gating_output, void const* fc1_expert_weights,
|
||||
void const* fc1_expert_biases, ActivationType fc1_activation_type, void const* fc2_expert_weights,
|
||||
void const* fc2_expert_biases, QuantParams quant_params, int const num_rows, int const hidden_size,
|
||||
int const inter_size, int const num_experts, int const k, char* workspace_ptr, void* final_output,
|
||||
bool const* finished, int const active_rows, void* expert_scales, int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expert_for_source_row, MOEParallelismConfig parallelism_config,
|
||||
MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) override;
|
||||
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
|
||||
int64_t const inter_size, int const num_experts, int const k, char* workspace_ptr, void* final_output,
|
||||
bool const* finished, int64_t const active_rows, void* expert_scales,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row,
|
||||
MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
private:
|
||||
using HopperGemmOutputType = typename HopperGroupedGemmInput::OutputTypeAdaptor_t<T>;
|
||||
@ -218,12 +201,13 @@ private:
|
||||
void computeTotalRowsBeforeExpert(int const* sorted_indices, int const total_indices, int const num_experts,
|
||||
int64_t* total_rows_before_expert, cudaStream_t stream);
|
||||
HopperGroupedGemmInput computeStridesHopper(int64_t const* total_rows_before_expert,
|
||||
HopperGroupedGemmInput layout_info, int gemm_n, int gemm_k, int const num_experts, T const* in,
|
||||
HopperGroupedGemmInput layout_info, int64_t gemm_n, int64_t gemm_k, int const num_experts, T const* in,
|
||||
WeightType const* weights, float const* fp8_dequant, T const* bias, HopperGemmOutputType* output,
|
||||
cudaStream_t stream);
|
||||
std::vector<size_t> getWorkspaceBufferSizes(int const num_rows, int const hidden_size, int const inter_size,
|
||||
int const num_experts, int const num_experts_per_node, int const k, ActivationType activation_type) const;
|
||||
void configureWsPtrs(char* ws_ptr, int const num_rows, int const hidden_size, int const inter_size,
|
||||
std::vector<size_t> getWorkspaceBufferSizes(int64_t const num_rows, int64_t const hidden_size,
|
||||
int64_t const inter_size, int const num_experts, int const num_experts_per_node, int const k,
|
||||
ActivationType activation_type) const;
|
||||
void configureWsPtrs(char* ws_ptr, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
int const num_experts, int const num_experts_per_node, int const k, ActivationType activation_type);
|
||||
|
||||
private:
|
||||
|
||||
@ -49,8 +49,9 @@ inline std::pair<float, float> getLimitsPenalty(DecodingPenaltyType penaltyType)
|
||||
case DecodingPenaltyType::Presence: return std::make_pair(fltMin, fltMax);
|
||||
case DecodingPenaltyType::Frequency: return std::make_pair(fltMin, fltMax);
|
||||
case DecodingPenaltyType::MinLength: return std::make_pair(-fltEpsilon, fltMax);
|
||||
default: TLLM_CHECK_WITH_INFO(false, "Unknown penalty type %d", static_cast<int32_t>(penaltyType));
|
||||
}
|
||||
TLLM_CHECK_WITH_INFO(false, "Unknown penalty type %d", static_cast<int32_t>(penaltyType));
|
||||
return std::make_pair(fltMin, fltMax);
|
||||
}
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -84,9 +84,6 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
|
||||
bool dt_softplus = params.delta_softplus;
|
||||
int num_channels = params.dim;
|
||||
|
||||
// static const int STAGES = 12;
|
||||
// static const int SEQ_UNROLL = 6;
|
||||
|
||||
__shared__ cuda::pipeline_shared_state<cuda::thread_scope::thread_scope_block, STAGES / SEQ_UNROLL> pipeline_state;
|
||||
auto block = cooperative_groups::this_thread_block();
|
||||
|
||||
@ -97,9 +94,6 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
|
||||
__shared__ input_t sh_x[STAGES][CHANNELS_PER_BLOCK];
|
||||
__shared__ input_t sh_z[STAGES][CHANNELS_PER_BLOCK];
|
||||
|
||||
__shared__ weight_t sh_D[CHANNELS_PER_BLOCK];
|
||||
__shared__ weight_t sh_dt_bias[CHANNELS_PER_BLOCK];
|
||||
|
||||
int const channel = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int const sample = blockIdx.y; // batch id
|
||||
|
||||
@ -127,14 +121,6 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
|
||||
|
||||
if (threadIdx.y == 1)
|
||||
{
|
||||
// Data loading warps
|
||||
|
||||
// Bias is independent of token
|
||||
sh_dt_bias[threadIdx.x] = dt_bias[channel];
|
||||
// D is independent of token
|
||||
if (D)
|
||||
sh_D[threadIdx.x] = D[channel];
|
||||
|
||||
cuda::pipeline pipeline = cuda::make_pipeline(block, &pipeline_state, cuda::pipeline_role::producer);
|
||||
|
||||
int stage = 0;
|
||||
@ -220,10 +206,11 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
|
||||
float A_reg[DSTATE];
|
||||
for (int i = 0; i < DSTATE; i++)
|
||||
{
|
||||
// state_reg[i] = toFloat(state[sample*num_channels*DSTATE + i*num_channels + channel]);
|
||||
state_reg[i] = 0.f;
|
||||
A_reg[i] = toFloat(A[i * num_channels + channel]);
|
||||
}
|
||||
float dt_bias_reg = dt_bias[channel];
|
||||
float D_reg = D ? D[channel] : 0.f;
|
||||
|
||||
cuda::pipeline pipeline = cuda::make_pipeline(block, &pipeline_state, cuda::pipeline_role::consumer);
|
||||
int stage = 0;
|
||||
@ -236,14 +223,14 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
            for (int token_id = si * SEQ_UNROLL; token_id < num_tokens && token_id < (si + 1) * SEQ_UNROLL; token_id++)
            {

                float dt_b = toFloat(sh_dt[stage][threadIdx.x]) + toFloat(sh_dt_bias[threadIdx.x]);
                float dt_b = toFloat(sh_dt[stage][threadIdx.x]) + dt_bias_reg;
                float dt_b_sp;
                if (dt_softplus)
                {
                    dt_b_sp = dt_b <= 20.f ? log1pf(__expf(dt_b)) : dt_b; // softplus
                    dt_b_sp = dt_b <= 20.f ? __logf(1.f + __expf(dt_b)) : dt_b; // softplus
                }
                float my_x = toFloat(sh_x[stage][threadIdx.x]);
                float Dx = my_x * (D ? toFloat(sh_D[threadIdx.x]) : 0.f);
                float Dx = my_x * D_reg;
                float dtx = dt_b_sp * my_x;
                float my_z = z ? toFloat(sh_z[stage][threadIdx.x]) : 0.f;

@ -303,7 +290,7 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
                {
                    float enz = __expf(0.f - my_z);
                    enz += 1.0;
                    float sig_z = 1.0 / enz;
                    float sig_z = __fdividef(1.f, enz);
                    float silu_z = my_z * sig_z;
                    out *= silu_z;
                }
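Both fast-math rewrites in this hunk keep the same functions and only swap in device intrinsics; for reference, the math being computed is (a plain-C sketch, helper names illustrative):

#include <math.h>

// softplus(x) = log(1 + exp(x)); for x > 20, 1 + exp(x) is indistinguishable from exp(x) in fp32,
// so the kernel short-circuits to x itself.
static float softplus_ref(float x) { return x <= 20.f ? logf(1.f + expf(x)) : x; }

// silu(z) = z * sigmoid(z) = z / (1 + exp(-z)), which is what the enz / sig_z / silu_z sequence computes.
static float silu_ref(float z) { return z / (1.f + expf(-z)); }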
|
||||
@ -332,16 +319,15 @@ void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream)
|
||||
int samples = params.batch;
|
||||
int channels = params.dim;
|
||||
|
||||
TLLM_CHECK(params.is_variable_B);
|
||||
TLLM_CHECK(params.is_variable_C);
|
||||
TLLM_CHECK(params.dstate == 16);
|
||||
|
||||
int const threads = 128;
|
||||
int const blocks = (channels + threads - 1) / threads;
|
||||
dim3 block(threads, 2);
|
||||
dim3 grid(blocks, samples);
|
||||
TLLM_CHECK((channels % block.x) == 0);
|
||||
|
||||
TLLM_CHECK(params.is_variable_B);
|
||||
TLLM_CHECK(params.is_variable_C);
|
||||
TLLM_CHECK(params.dstate == 16);
|
||||
|
||||
selective_scan_loop_kernel<input_t, weight_t><<<grid, block, 0, stream>>>(params);
|
||||
}
|
||||
|
||||
@ -412,15 +398,17 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
|
||||
float dt_b_sp;
|
||||
if (dt_softplus)
|
||||
{
|
||||
dt_b_sp = dt_b <= 20.f ? logf(1.f + expf(dt_b)) : dt_b; // softplus
|
||||
// dt_b_sp = dt_b <= 20.f ? logf(1.f + expf(dt_b)) : dt_b; // softplus
|
||||
dt_b_sp = dt_b <= 20.f ? __logf(1.f + __expf(dt_b)) : dt_b; // softplus
|
||||
}
|
||||
|
||||
float out = 0.f;
|
||||
float out = D ? my_D * my_x : 0.f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DSTATE; i++)
|
||||
{
|
||||
float dA = expf(rA[i] * dt_b_sp);
|
||||
// float dA = expf(rA[i] * dt_b_sp);
|
||||
float dA = __expf(rA[i] * dt_b_sp);
|
||||
float dB = rB[i] * dt_b_sp;
|
||||
float sdA = rState[i] * dA;
|
||||
float dBx = dB * my_x;
|
||||
@ -429,11 +417,10 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
|
||||
out += newState * rC[i];
|
||||
}
|
||||
|
||||
if (D)
|
||||
out += my_D * my_x;
|
||||
if (z)
|
||||
{
|
||||
float sig_z = 1.0 / (1.0 + exp(0.f - my_z));
|
||||
// float sig_z = 1.0 / (1.0 + exp(0.f - my_z));
|
||||
float sig_z = __fdividef(1.f, (1.f + __expf(0.f - my_z)));
|
||||
float silu_z = my_z * sig_z;
|
||||
out *= silu_z;
|
||||
}
|
||||
|
||||
@ -1347,10 +1347,11 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf,
|
||||
{
|
||||
case PositionEmbeddingType::kROPE_GPTJ:
|
||||
{
|
||||
mmha::apply_rotary_embedding(
|
||||
q, k, tidx, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, dst_kv_seq_idx);
|
||||
mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, 0,
|
||||
nullptr, dst_kv_seq_idx);
|
||||
break;
|
||||
}
|
||||
case PositionEmbeddingType::kLONG_ROPE:
|
||||
case PositionEmbeddingType::kROPE_GPT_NEOX:
|
||||
{
|
||||
bool const do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim;
|
||||
@ -1379,7 +1380,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf,
|
||||
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
|
||||
mmha::apply_rotary_embedding(q, k, transpose_idx / tidx_factor, rotary_embedding_dim, rotary_embedding_base,
|
||||
rotary_embedding_scale, dst_kv_seq_idx);
|
||||
rotary_embedding_scale, 0, nullptr, dst_kv_seq_idx);
|
||||
|
||||
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
||||
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
@ -1469,9 +1470,10 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, T cons
|
||||
// To implement rotary embeddings, each thread processes two QKV elems:
|
||||
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
|
||||
dim3 grid(token_num, head_num);
|
||||
size_t smem_size
|
||||
= (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX ? 2 * rotary_embedding_dim * sizeof(T)
|
||||
: 0);
|
||||
size_t smem_size = (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|
||||
|| position_embedding_type == PositionEmbeddingType::kLONG_ROPE
|
||||
? 2 * rotary_embedding_dim * sizeof(T)
|
||||
: 0);
|
||||
// NOTE: add offset for rotary embedding
|
||||
if (qkv_bias != nullptr)
|
||||
{
|
||||
@ -1858,9 +1860,10 @@ __global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCa
|
||||
case PositionEmbeddingType::kROPE_GPTJ:
|
||||
{
|
||||
mmha::apply_rotary_embedding(
|
||||
k, tidx, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, token_pos_idx);
|
||||
k, tidx, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, 0, nullptr, token_pos_idx);
|
||||
break;
|
||||
}
|
||||
case PositionEmbeddingType::kLONG_ROPE:
|
||||
case PositionEmbeddingType::kROPE_GPT_NEOX:
|
||||
{
|
||||
bool const do_rotary = vec_size * tidx < rotary_embedding_dim;
|
||||
@ -1885,7 +1888,7 @@ __global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCa
|
||||
{
|
||||
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
mmha::apply_rotary_embedding(k, transpose_idx / tidx_factor, rotary_embedding_dim, rotary_embedding_base,
|
||||
rotary_embedding_scale, token_pos_idx);
|
||||
rotary_embedding_scale, 0, nullptr, token_pos_idx);
|
||||
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
}
|
||||
|
||||
@ -1919,8 +1922,10 @@ void invokeShiftKCache(KVCacheBuffer const& kvCacheBuffer, KVLinearBuffer const&
|
||||
int const vec_size = 16u / sizeof(T);
|
||||
dim3 block((sizePerHead / vec_size + 31) / 32 * 32);
|
||||
dim3 grid(token_num_in_k, kv_head_num, batch_beam);
|
||||
size_t smem_size
|
||||
= (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX ? 2 * rotary_embedding_dim * sizeof(T) : 0);
|
||||
size_t smem_size = (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|
||||
|| position_embedding_type == PositionEmbeddingType::kLONG_ROPE
|
||||
? 2 * rotary_embedding_dim * sizeof(T)
|
||||
: 0);
|
||||
|
||||
if (cache_type == KvCacheDataType::INT8)
|
||||
{
|
||||
|
||||
@ -77,7 +77,7 @@ struct QKVPreprocessingParams
|
||||
float const* rotary_embedding_inv_freq;
|
||||
float2 const* rotary_coef_cache_buffer;
|
||||
float const* kvScaleOrigQuant;
|
||||
int const* medusa_position_offsets;
|
||||
int const* spec_decoding_position_offsets;
|
||||
|
||||
// Scalars.
|
||||
int const batch_size;
|
||||
@ -101,6 +101,8 @@ struct QKVPreprocessingParams
|
||||
bool const enable_paged_kv_fmha;
|
||||
bool const quantized_fp8_output;
|
||||
int const multi_processor_count;
|
||||
int const rotary_vision_start;
|
||||
int const rotary_vision_length;
|
||||
|
||||
// Pre-compute on host.
|
||||
int half_rotary_dim;
|
||||
|
||||
@ -217,7 +217,8 @@ struct Rotary_base_t<float, RotaryPositionEmbeddingType::GPTJ>
template <typename VecType, typename T, int VEC_SIZE, bool RECOMPUTE>
inline __device__ void apply_rotary_embedding_gptneox(VecType& q, VecType& q_pair, VecType& k, VecType& k_pair,
    bool first_half, float2 (&rotary_coef_cache)[VEC_SIZE], float const* rotary_inv_freq_buffer,
    int const rotary_dim_idx, int const half_rotary_dim, int const rotary_position)
    int const rotary_dim_idx, int const half_rotary_dim, int const rotary_position, int const vision_start = -1,
    int const vision_length = -1)
{
    // Each thread holds NUM_ELTS elements.
    // Currently we apply the rotary embedding in float data type for accuracy reasons.
@ -234,8 +235,25 @@ inline __device__ void apply_rotary_embedding_gptneox(VecType& q, VecType& q_pai

        if (RECOMPUTE)
        {
            float const rotary_inv_freq
                = float(rotary_position) * rotary_inv_freq_buffer[min(rotary_dim_idx + elt_id, half_rotary_dim - 1)];
            int real_rotary_position = rotary_position;
            if (vision_start != -1 && vision_length != -1)
            {
                int t_step_int = rotary_position;
                if (t_step_int <= vision_start)
                {
                    real_rotary_position = t_step_int;
                }
                else if (t_step_int > vision_start && t_step_int <= (vision_length + vision_start))
                {
                    real_rotary_position = vision_start + 1;
                }
                else
                {
                    real_rotary_position = t_step_int - (vision_length - 1);
                }
            }
            float const rotary_inv_freq = float(real_rotary_position)
                * rotary_inv_freq_buffer[min(rotary_dim_idx + elt_id, half_rotary_dim - 1)];
            rotary_coef_cache[elt_id] = make_float2(cosf(rotary_inv_freq), sinf(rotary_inv_freq));
        }
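A minimal host-side sketch of the position remapping introduced above for vision tokens; the helper name is illustrative, the branch structure mirrors the kernel:

// Tokens before the vision range keep their position, every token inside the range collapses onto
// vision_start + 1, and tokens after the range are shifted back by vision_length - 1 so the rotary
// index stays contiguous across the image.
int remapRotaryPositionRef(int position, int vision_start, int vision_length)
{
    if (vision_start == -1 || vision_length == -1)
    {
        return position; // no vision range configured
    }
    if (position <= vision_start)
    {
        return position;
    }
    if (position <= vision_start + vision_length)
    {
        return vision_start + 1;
    }
    return position - (vision_length - 1);
}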
|
||||
|
||||
@ -356,10 +374,10 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
|
||||
int const token_idx_in_seq = (cache_seq_len - actual_seq_len) + local_token_idx;
|
||||
bool const valid_token = token_idx_in_seq < cache_seq_len;
|
||||
|
||||
// NOTE: only Medusa needs the position offsets.
|
||||
// NOTE: only spec decoding needs the position offsets.
|
||||
// In the generation phase, we assume all sequences should have the same input length.
|
||||
int const rotary_position = params.medusa_position_offsets != nullptr
|
||||
? (params.medusa_position_offsets[local_token_idx + batch_idx * params.max_input_seq_len]
|
||||
int const rotary_position = params.spec_decoding_position_offsets != nullptr
|
||||
? (params.spec_decoding_position_offsets[local_token_idx + batch_idx * params.max_input_seq_len]
|
||||
+ cache_seq_len - actual_seq_len)
|
||||
: token_idx_in_seq;
|
||||
|
||||
@ -445,7 +463,8 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
|
||||
apply_rotary_embedding_gptneox<VecType, BaseType, ROTARY_COEF_VEC_SIZE, true>(q, q_pair, k, k_pair,
|
||||
first_half, rotary_coef_cache,
|
||||
params.rotary_embedding_inv_freq + batch_idx * params.half_rotary_dim, gptneox_rotary_dim_idx,
|
||||
params.half_rotary_dim, rotary_position);
|
||||
params.half_rotary_dim, rotary_position, params.rotary_vision_start,
|
||||
params.rotary_vision_length);
|
||||
cached_rotary_position = rotary_position;
|
||||
}
|
||||
else
|
||||
@ -453,7 +472,8 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
|
||||
apply_rotary_embedding_gptneox<VecType, BaseType, ROTARY_COEF_VEC_SIZE, false>(q, q_pair, k, k_pair,
|
||||
first_half, rotary_coef_cache,
|
||||
params.rotary_embedding_inv_freq + batch_idx * params.half_rotary_dim, gptneox_rotary_dim_idx,
|
||||
params.half_rotary_dim, rotary_position);
|
||||
params.half_rotary_dim, rotary_position, params.rotary_vision_start,
|
||||
params.rotary_vision_length);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -670,11 +690,11 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
|
||||
local_token_idx = std::min(local_token_idx, actual_seq_len - 1);
|
||||
int const global_token_idx = local_token_idx + global_token_offset;
|
||||
|
||||
// NOTE: only Medusa needs the position offsets.
|
||||
// NOTE: only spec decoding needs the position offsets.
|
||||
// In the generation phase, we assume all sequences should have the same input length.
|
||||
int const rotary_position = params.medusa_position_offsets != nullptr
|
||||
? (params.medusa_position_offsets[local_token_idx + batch_idx * params.max_input_seq_len] + cache_seq_len
|
||||
- actual_seq_len)
|
||||
int const rotary_position = params.spec_decoding_position_offsets != nullptr
|
||||
? (params.spec_decoding_position_offsets[local_token_idx + batch_idx * params.max_input_seq_len]
|
||||
+ cache_seq_len - actual_seq_len)
|
||||
: token_idx_in_kv_cache;
|
||||
|
||||
// head_num == kv_head_num:
|
||||
@ -837,7 +857,8 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
|
||||
dim3 grid(params.max_input_seq_len, params.head_num); \
|
||||
grid.z = std::min(int(divUp(params.multi_processor_count * WARPS_PER_SM, grid.x * grid.y)), \
|
||||
int(divUp(params.batch_size, MIN_SEQUENCES_PER_WARP))); \
|
||||
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX) \
|
||||
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX \
|
||||
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE) \
|
||||
{ \
|
||||
applyBiasRopeUpdateKVCache<T, TCache, Dh_MAX, ADD_BIAS, STORE_QKV, KVCacheBuffer, \
|
||||
RotaryPositionEmbeddingType::GPT_NEOX, DYNAMIC_ROTARY_SCALING><<<grid, block, 0, stream>>>(params); \
|
||||
@ -863,7 +884,8 @@ void kernelDispatchHeadSize(QKVPreprocessingParams<T, KVCacheBuffer> params, cud
|
||||
|
||||
constexpr int VEC_SIZE = Rotary_vec_t<T, Dh_MAX>::size;
|
||||
// Make sure we have multiple of paired vectors so that the access is aligned.
|
||||
TLLM_CHECK_WITH_INFO(params.position_embedding_type != PositionEmbeddingType::kROPE_GPT_NEOX
|
||||
TLLM_CHECK_WITH_INFO((params.position_embedding_type != PositionEmbeddingType::kROPE_GPT_NEOX
|
||||
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE)
|
||||
|| params.half_rotary_dim % VEC_SIZE == 0,
|
||||
"Rotary dim size is not supported.");
|
||||
|
||||
@ -946,7 +968,8 @@ void kernelV1Dispatch(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStrea
|
||||
#define APPLY_BIAS_ROPE_UPDATE_KV_CACHE_V2(ADD_BIAS, STORE_QKV) \
|
||||
dim3 block(BLOCK_SIZE); \
|
||||
dim3 grid(int(divUp(params.max_input_seq_len, tokens_per_cuda_block)), params.batch_size, params.head_num); \
|
||||
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX) \
|
||||
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX \
|
||||
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE) \
|
||||
{ \
|
||||
applyBiasRopeUpdateKVCacheV2<T, TCache, BLOCK_SIZE, Dh, ADD_BIAS, STORE_QKV, KVCacheBuffer, \
|
||||
RotaryPositionEmbeddingType::GPT_NEOX><<<grid, block, 0, stream>>>(params); \
|
||||
@ -1010,7 +1033,8 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(QKVPreprocessingParams<T, KVCacheB
|
||||
bool const has_rotary_cos_sin_cache = params.rotary_coef_cache_buffer != nullptr;
|
||||
bool const has_sink_tokens = params.sink_token_len > 0;
|
||||
// V2 implementation requires multiple of paired 16 bytes for gpt-neox rotation.
|
||||
bool const support_rotary_for_v2 = params.position_embedding_type != PositionEmbeddingType::kROPE_GPT_NEOX
|
||||
bool const support_rotary_for_v2 = (params.position_embedding_type != PositionEmbeddingType::kROPE_GPT_NEOX
|
||||
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE)
|
||||
|| params.rotary_embedding_dim % 16 == 0;
|
||||
|
||||
if (long_seq_rotary_support || !has_rotary_cos_sin_cache || has_sink_tokens || !support_rotary_for_v2)
|
||||
|
||||
75
cpp/tensorrt_llm/layers/lookaheadDecodingUtils.cpp
Normal file
@ -0,0 +1,75 @@

#include <sstream>

#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"

namespace tensorrt_llm::layers
{

using namespace tensorrt_llm::runtime;
using TensorPtr = ITensor::SharedPtr;

void printTokens2d(char const* name, TensorPtr const& tensor)
{
auto M = tensor->getShape().d[0];
auto N = tensor->getShape().d[1];
auto tr = BufferRange<TokenIdType>(*tensor);
std::ostringstream buf;
buf << name << ": " << tensor->getShape() << "(\n";
for (SizeType mi = 0; mi < M; mi++)
{
for (SizeType ni = 0; ni < N; ni++)
{
auto token = tr[mi * N + ni];
if (token >= 0 && token <= 255)
{
buf << "'" << static_cast<char>(token) << "'";
}
else
{
buf << token;
}
buf << (ni == (N - 1) ? ';' : ',');
}
if (mi != M - 1)
{
buf << std::endl;
}
}
buf << ")" << std::endl;
TLLM_LOG_DEBUG(buf.str());
}

void printTokens(char const* name, TensorPtr const& tensor)
{
std::ostringstream buf;
buf << name << ": " << tensor->getShape() << "(";
for (auto const& token : BufferRange<TokenIdType>(*tensor))
{
if (token >= 0 && token <= 255)
{
buf << "'" << static_cast<char>(token) << "',";
}
else
{
buf << token << ",";
}
}
buf << ")" << std::endl << std::flush;
TLLM_LOG_DEBUG(buf.str());
}

void printTensor(char const* name, TensorPtr const& tensor)
{
std::ostringstream buf;
buf << name << ": " << tensor->getShape() << "(";
for (auto const& token : BufferRange<TokenIdType>(*tensor))
{
buf << token << ",";
}
buf << ")" << std::endl << std::flush;
TLLM_LOG_DEBUG(buf.str());
}

} // namespace tensorrt_llm::layers

17
cpp/tensorrt_llm/layers/lookaheadDecodingUtils.h
Normal file
@ -0,0 +1,17 @@
#pragma once

#include "tensorrt_llm/runtime/iTensor.h"

namespace tensorrt_llm::layers
{

void printTokens(char const* name, runtime::ITensor::SharedPtr const& tensor);
#define PRINT_TOKENS(x) tensorrt_llm::layers::printTokens(#x, x)

void printTokens2d(char const* name, runtime::ITensor::SharedPtr const& tensor);
#define PRINT_TOKENS2D(x) tensorrt_llm::layers::printTokens2d(#x, x)

void printTensor(char const* name, runtime::ITensor::SharedPtr const& tensor);
#define PRINT_TENSOR(x) tensorrt_llm::layers::printTensor(#x, x)

} // namespace tensorrt_llm::layers

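These helpers are debug aids: each macro stringizes the variable name and routes a formatted dump through TLLM_LOG_DEBUG, so output only appears at debug log level. A minimal sketch of how they might be called from layer code (the tensor names here are hypothetical, not part of this commit):

```
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
#include "tensorrt_llm/runtime/iTensor.h"

namespace tl = tensorrt_llm;

// Hypothetical call site: dump lookahead state while debugging a layer.
void debugDumpLookaheadState(
    tl::runtime::ITensor::SharedPtr const& draftTokens, // [len], token ids
    tl::runtime::ITensor::SharedPtr const& ngramTable)  // [window, ngramLen]
{
    PRINT_TOKENS(draftTokens);  // printable ids rendered as characters
    PRINT_TOKENS2D(ngramTable); // one row per lookahead window slot
    PRINT_TENSOR(draftTokens);  // raw numeric dump of the same buffer
}
```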
83
cpp/tensorrt_llm/layers/lookaheadPoolManager.cpp
Normal file
@ -0,0 +1,83 @@
|
||||
|
||||
|
||||
#include "tensorrt_llm/layers/lookaheadPoolManager.h"
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
void LookaheadPoolManager::insertOne(Key key, TensorPtr ngram)
|
||||
{
|
||||
auto search = mTokenMap.find(key);
|
||||
if (search != mTokenMap.end())
|
||||
{
|
||||
search->second.remove_if(
|
||||
[&ngram](TensorPtr const& item)
|
||||
{
|
||||
auto ar = BufferRange<TokenIdType>(*ngram);
|
||||
auto br = BufferRange<TokenIdType>(*item);
|
||||
return std::equal(ar.begin(), ar.end(), br.begin());
|
||||
});
|
||||
if (mGuessSetSize >= 0 && search->second.size() >= mGuessSetSize)
|
||||
{
|
||||
search->second.pop_front();
|
||||
}
|
||||
search->second.push_back(ngram);
|
||||
}
|
||||
else
|
||||
{
|
||||
mTokenMap.insert({key, std::list<TensorPtr>({ngram})});
|
||||
}
|
||||
}
|
||||
|
||||
void LookaheadPoolManager::fillWithPrompt(TensorPtr prompt, SizeType level)
|
||||
{
|
||||
SizeType length = prompt->getShape().d[0];
|
||||
auto pr = BufferRange<Key>(*prompt);
|
||||
for (SizeType ti = 0; ti + level - 1 < length; ti++)
|
||||
{
|
||||
auto key = pr[ti];
|
||||
TensorPtr ngram
|
||||
= mBufferManager->copyFrom(*ITensor::slice(prompt, ti + 1, level - 1), runtime::MemoryType::kCPU);
|
||||
insertOne(key, ngram);
|
||||
}
|
||||
}
|
||||
|
||||
std::list<LookaheadPoolManager::TensorPtr> LookaheadPoolManager::guess(Key lastToken, SizeType guessSize) const
|
||||
{
|
||||
auto search = mTokenMap.find(lastToken);
|
||||
if (search != mTokenMap.end())
|
||||
{
|
||||
auto ngrams = search->second;
|
||||
if (ngrams.size() > guessSize)
|
||||
{
|
||||
auto it = std::prev(ngrams.end(), guessSize);
|
||||
return std::list<TensorPtr>(it, ngrams.end());
|
||||
}
|
||||
else
|
||||
{
|
||||
return ngrams;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::list<TensorPtr>();
|
||||
}
|
||||
}
|
||||
|
||||
void LookaheadPoolManager::update(TensorPtr keyTokens, TensorPtr ngramTokens)
|
||||
{
|
||||
TLLM_CHECK(keyTokens->getShape().d[0] == ngramTokens->getShape().d[0]);
|
||||
auto kr = BufferRange<Key>(*keyTokens);
|
||||
auto window = ngramTokens->getShape().d[0];
|
||||
|
||||
for (SizeType wi = 0; wi < window; wi++)
|
||||
{
|
||||
TensorPtr ngram = mBufferManager->copyFrom(*ITensor::slice(ngramTokens, wi, 1), runtime::MemoryType::kCPU);
|
||||
ngram->squeeze(0);
|
||||
insertOne(kr[wi], ngram);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
57
cpp/tensorrt_llm/layers/lookaheadPoolManager.h
Normal file
@ -0,0 +1,57 @@
#pragma once

#include <list>
#include <unordered_map>

#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"

namespace tensorrt_llm::layers
{

//! @brief A helper class for managing the key-ngram pool.
class LookaheadPoolManager
{
public:
using TensorPtr = runtime::ITensor::SharedPtr;
using Key = runtime::TokenIdType;

LookaheadPoolManager(runtime::SizeType g, std::shared_ptr<runtime::BufferManager> bufferManager)
: mGuessSetSize(g)
, mBufferManager(bufferManager)
{
}

//! @brief fill the token map from the prompt
//! @param prompt the user input prompt, [length] on cpu
//! @param level the n-gram length
void fillWithPrompt(TensorPtr prompt, runtime::SizeType level);

//! @brief get a list of guess tokens
//! @param lastToken the newest golden token
//! @param guessSize at most guessSize candidates returned
//! @return the list of guess tokens, with list size <= guessSize
std::list<TensorPtr> guess(Key lastToken, runtime::SizeType guessSize) const;

//! @brief update the token map with newly generated tokens
//! @param keyTokens the new shifted-out tokens from each window, used as the keys, [window] on cpu
//! @param ngramTokens the new shifted lookahead window, used as the ngrams, [window, ngramLen] on cpu
void update(TensorPtr keyTokens, TensorPtr ngramTokens);

std::unordered_map<Key, std::list<TensorPtr>> const& getMap() const
{
return mTokenMap;
}

private:
void insertOne(Key key, TensorPtr ngram);

private:
std::shared_ptr<runtime::BufferManager> mBufferManager;
//! @brief the token map with a token as key and a list of n-grams as value
std::unordered_map<Key, std::list<TensorPtr>> mTokenMap;
//! @brief guess set size, -1 for infinite size
runtime::SizeType mGuessSetSize;
};

} // namespace tensorrt_llm::layers

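A minimal sketch of how this pool might be driven during lookahead decoding, based only on the declarations above; the makeCpuTokens/makeCpuTokens2d helpers are hypothetical stand-ins for building the [length] and [window, ngramLen] CPU tensors, and the BufferManager construction is an assumption:

```
#include "tensorrt_llm/layers/lookaheadPoolManager.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"

using namespace tensorrt_llm;

void lookaheadPoolSketch()
{
    auto manager = std::make_shared<runtime::BufferManager>(std::make_shared<runtime::CudaStream>());
    // Keep at most 4 n-grams per key; -1 would make the guess set unbounded.
    layers::LookaheadPoolManager pool(/*g=*/4, manager);

    // Seed (key, ngram) pairs sliced from the prompt, with n-gram level 3.
    pool.fillWithPrompt(makeCpuTokens({'h', 'e', 'l', 'l', 'o'}), /*level=*/3);

    // Ask for up to 2 candidate n-grams that historically followed the latest token.
    auto candidates = pool.guess(/*lastToken=*/'l', /*guessSize=*/2);

    // After a lookahead step, feed back the shifted-out keys and the new window rows.
    pool.update(makeCpuTokens({'l', 'o'}),              // keys, [window]
        makeCpuTokens2d({{'o', '!'}, {'!', '?'}}));     // ngrams, [window, ngramLen]
}
```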
@ -47,7 +47,8 @@ set(PLUGIN_LISTS
|
||||
mixtureOfExperts
|
||||
selectiveScanPlugin
|
||||
mambaConv1dPlugin
|
||||
lruPlugin)
|
||||
lruPlugin
|
||||
cumsumLastDimPlugin)
|
||||
|
||||
foreach(PLUGIN_ITER ${PLUGIN_LISTS})
|
||||
include_directories(${PLUGIN_ITER})
|
||||
|
||||
@ -36,6 +36,7 @@
|
||||
#include "tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.h"
|
||||
#include "tensorrt_llm/plugins/ncclPlugin/sendPlugin.h"
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
#include "tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h"
|
||||
#include "tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.h"
|
||||
#include "tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.h"
|
||||
#include "tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.h"
|
||||
@ -193,6 +194,7 @@ extern "C"
|
||||
static tensorrt_llm::plugins::SelectiveScanPluginCreator selectiveScanPluginCreator;
|
||||
static tensorrt_llm::plugins::MambaConv1dPluginCreator mambaConv1DPluginCreator;
|
||||
static tensorrt_llm::plugins::lruPluginCreator lruPluginCreator;
|
||||
static tensorrt_llm::plugins::CumsumLastDimPluginCreator cumsumLastDimPluginCreator;
|
||||
|
||||
static std::array pluginCreators
|
||||
= { creatorPtr(identityPluginCreator),
|
||||
@ -219,6 +221,7 @@ extern "C"
|
||||
creatorPtr(selectiveScanPluginCreator),
|
||||
creatorPtr(mambaConv1DPluginCreator),
|
||||
creatorPtr(lruPluginCreator),
|
||||
creatorPtr(cumsumLastDimPluginCreator),
|
||||
};
|
||||
nbCreators = pluginCreators.size();
|
||||
return pluginCreators.data();
|
||||
|
||||
@ -16,30 +16,32 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "pluginUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <shared_mutex>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
|
||||
namespace tensorrt_llm::plugins
|
||||
{
|
||||
|
||||
struct GemmDims
|
||||
{
|
||||
int32_t minM;
|
||||
int32_t maxM;
|
||||
int32_t n;
|
||||
int32_t k;
|
||||
using DimType = utils::DimType;
|
||||
|
||||
DimType minM;
|
||||
DimType maxM;
|
||||
DimType n;
|
||||
DimType k;
|
||||
|
||||
GemmDims()
|
||||
: minM(-1)
|
||||
@ -49,7 +51,7 @@ struct GemmDims
|
||||
{
|
||||
}
|
||||
|
||||
GemmDims(int32_t minM_, int32_t maxM_, int32_t n_, int32_t k_)
|
||||
GemmDims(DimType minM_, DimType maxM_, DimType n_, DimType k_)
|
||||
: minM(minM_)
|
||||
, maxM(maxM_)
|
||||
, n(n_)
|
||||
@ -57,7 +59,7 @@ struct GemmDims
|
||||
{
|
||||
}
|
||||
|
||||
bool isInitialized() const
|
||||
[[nodiscard]] bool isInitialized() const
|
||||
{
|
||||
return minM >= 0 && maxM >= 0 && n >= 0 && k >= 0;
|
||||
}
|
||||
|
||||
@ -22,21 +22,18 @@
|
||||
#include "tensorrt_llm/plugins/common/checkMacrosPlugin.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include <NvInferRuntimeBase.h>
|
||||
#include <cstring>
|
||||
#include <cublasLt.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
#include <nccl.h>
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
66
cpp/tensorrt_llm/plugins/common/pluginUtils.h
Normal file
@ -0,0 +1,66 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <NvInferRuntime.h>

namespace tensorrt_llm::plugins::utils
{
using DimType = int32_t;

inline DimType computeMDimension(bool transA, nvinfer1::Dims const& dims)
{
DimType M{1};
if (transA)
{
for (int i = dims.nbDims - 1; i > 0; --i)
{
M *= dims.d[i];
}
}
else
{
for (int i = 0; i < dims.nbDims - 1; ++i)
{
M *= dims.d[i];
}
}
return M;
}

inline DimType computeNDimension(bool transB, nvinfer1::Dims const& dims)
{
DimType N{1};
if (transB)
{
for (int32_t i = 0; i < dims.nbDims - 1; ++i)
{
N *= dims.d[i];
}
}
else
{
for (int32_t i = dims.nbDims - 1; i > 0; --i)
{
N *= dims.d[i];
}
}
return N;
}

} // namespace tensorrt_llm::plugins::utils

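A small sanity check of what these helpers fold, following the loops above: M multiplies every dimension of the A tensor except its K axis, and N does the same on the B side (sketch only; build wiring and include paths are assumed):

```
#include <NvInferRuntime.h>
#include <cassert>

#include "tensorrt_llm/plugins/common/pluginUtils.h"

int main()
{
    using namespace tensorrt_llm::plugins::utils;

    // Activation A of shape [batch=4, seq=16, k=1024], not transposed:
    // M folds everything but the last (K) dimension -> 4 * 16 = 64.
    nvinfer1::Dims a{3, {4, 16, 1024}};
    assert(computeMDimension(/*transA=*/false, a) == 64);

    // Weight B of shape [k=1024, n=4096], not transposed:
    // N folds everything but the first (K) dimension -> 4096.
    nvinfer1::Dims b{2, {1024, 4096}};
    assert(computeNDimension(/*transB=*/false, b) == 4096);

    // With transB = true the weight is stored as [n, k] and N comes from dim 0.
    nvinfer1::Dims bT{2, {4096, 1024}};
    assert(computeNDimension(/*transB=*/true, bT) == 4096);

    return 0;
}
```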
22
cpp/tensorrt_llm/plugins/cumsumLastDimPlugin/CMakeLists.txt
Normal file
@ -0,0 +1,22 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#

file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES
${PLUGIN_SOURCES}
PARENT_SCOPE)

@ -0,0 +1,276 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cumsumLastDimPlugin.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
|
||||
using namespace nvinfer1;
|
||||
using namespace tensorrt_llm::kernels;
|
||||
using namespace tensorrt_llm::common;
|
||||
using tensorrt_llm::plugins::CumsumLastDimPluginCreator;
|
||||
using tensorrt_llm::plugins::CumsumLastDimPlugin;
|
||||
|
||||
static char const* CUMSUM_LAST_DIM_PLUGIN_VERSION{"1"};
|
||||
static char const* CUMSUM_LAST_DIM_PLUGIN_NAME{"CumsumLastDim"};
|
||||
PluginFieldCollection CumsumLastDimPluginCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField> CumsumLastDimPluginCreator::mPluginAttributes;
|
||||
|
||||
CumsumLastDimPlugin::CumsumLastDimPlugin(int input_length, nvinfer1::DataType type)
|
||||
: mInputLength(input_length)
|
||||
, mType(type)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16),
|
||||
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
|
||||
TLLM_CHECK_WITH_INFO((mType == DataType::kBF16) || (mType == DataType::kFLOAT) || (mType == DataType::kHALF)
|
||||
|| (mType == DataType::kINT32),
|
||||
"Only support int, float, half, and bfloat16.");
|
||||
}
|
||||
|
||||
// Parameterized constructor
|
||||
CumsumLastDimPlugin::CumsumLastDimPlugin(void const* data, size_t length)
|
||||
{
|
||||
char const *d = reinterpret_cast<char const*>(data), *a = d;
|
||||
read(d, mInputLength);
|
||||
read(d, mType);
|
||||
TLLM_CHECK(d == a + length);
|
||||
TLLM_CHECK_WITH_INFO((getSMVersion() >= 80) || (mType != DataType::kBF16), "Unsupported data type");
|
||||
TLLM_CHECK_WITH_INFO((mType == DataType::kBF16) || (mType == DataType::kFLOAT) || (mType == DataType::kHALF)
|
||||
|| (mType == DataType::kINT32),
|
||||
"Only support int, float, half, and bfloat16.");
|
||||
}
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt* CumsumLastDimPlugin::clone() const noexcept
|
||||
{
|
||||
auto* plugin = new CumsumLastDimPlugin(mInputLength, mType);
|
||||
plugin->setPluginNamespace(mNamespace.c_str());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
// Outputs
|
||||
// output_tensor: [batch_size, input_length]
|
||||
nvinfer1::DimsExprs CumsumLastDimPlugin::getOutputDimensions(
|
||||
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(outputIndex == 0, "Only one output.");
|
||||
return inputs[getInputTensorIdx()];
|
||||
}
|
||||
|
||||
bool CumsumLastDimPlugin::supportsFormatCombination(
|
||||
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
|
||||
{
|
||||
return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
}
|
||||
|
||||
void CumsumLastDimPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
|
||||
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
|
||||
{
|
||||
}
|
||||
|
||||
size_t CumsumLastDimPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
|
||||
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
|
||||
{
|
||||
if (mType == DataType::kINT32)
|
||||
{
|
||||
return invokeComputeCumsumLastDimWorkspaceSize<int>(mInputLength);
|
||||
}
|
||||
else if (mType == DataType::kHALF)
|
||||
{
|
||||
return invokeComputeCumsumLastDimWorkspaceSize<half>(mInputLength);
|
||||
}
|
||||
else if (mType == DataType::kFLOAT)
|
||||
{
|
||||
return invokeComputeCumsumLastDimWorkspaceSize<float>(mInputLength);
|
||||
}
|
||||
#ifdef ENABLE_BF16
|
||||
else if (mType == DataType::kBF16)
|
||||
{
|
||||
return invokeComputeCumsumLastDimWorkspaceSize<__nv_bfloat16>(mInputLength);
|
||||
}
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int CumsumLastDimPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// inputs
|
||||
// 0. input_tensor [batch_size, input_length]
|
||||
// outputs
|
||||
// 0. output_tensor [batch_size, input_length]
|
||||
auto const batch_size = inputDesc[getInputTensorIdx()].dims.d[0];
|
||||
size_t temp_storage_bytes = invokeComputeCumsumLastDimWorkspaceSize<T>(mInputLength);
|
||||
invokeCumsumLastDim<T>(
|
||||
batch_size, mInputLength, inputs[getInputTensorIdx()], outputs[0], workspace, temp_storage_bytes, stream);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int CumsumLastDimPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept
|
||||
{
|
||||
if (mType == DataType::kINT32)
|
||||
{
|
||||
return enqueueImpl<int>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
|
||||
}
|
||||
else if (mType == DataType::kHALF)
|
||||
{
|
||||
return enqueueImpl<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
|
||||
}
|
||||
else if (mType == DataType::kFLOAT)
|
||||
{
|
||||
return enqueueImpl<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
|
||||
}
|
||||
#ifdef ENABLE_BF16
|
||||
else if (mType == DataType::kBF16)
|
||||
{
|
||||
return enqueueImpl<__nv_bfloat16>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
|
||||
}
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType CumsumLastDimPlugin::getOutputDataType(
|
||||
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(index == 0, "Only one output.");
|
||||
return inputTypes[getInputTensorIdx()];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
|
||||
char const* CumsumLastDimPlugin::getPluginType() const noexcept
|
||||
{
|
||||
return CUMSUM_LAST_DIM_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
char const* CumsumLastDimPlugin::getPluginVersion() const noexcept
|
||||
{
|
||||
return CUMSUM_LAST_DIM_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int CumsumLastDimPlugin::getNbOutputs() const noexcept
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
int CumsumLastDimPlugin::initialize() noexcept
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
void CumsumLastDimPlugin::terminate() noexcept {}
|
||||
|
||||
size_t CumsumLastDimPlugin::getSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(mInputLength) + sizeof(mType);
|
||||
}
|
||||
|
||||
void CumsumLastDimPlugin::serialize(void* buffer) const noexcept
|
||||
{
|
||||
char *d = static_cast<char*>(buffer), *a = d;
|
||||
write(d, mInputLength);
|
||||
write(d, mType);
|
||||
assert(d == a + getSerializationSize());
|
||||
}
|
||||
|
||||
void CumsumLastDimPlugin::destroy() noexcept
|
||||
{
|
||||
delete this;
|
||||
}
|
||||
|
||||
///////////////
|
||||
|
||||
CumsumLastDimPluginCreator::CumsumLastDimPluginCreator()
|
||||
{
|
||||
// Fill PluginFieldCollection with PluginField arguments metadata
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(PluginField("mInputLength", nullptr, PluginFieldType::kINT32, 49152));
|
||||
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
char const* CumsumLastDimPluginCreator::getPluginName() const noexcept
|
||||
{
|
||||
return CUMSUM_LAST_DIM_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
char const* CumsumLastDimPluginCreator::getPluginVersion() const noexcept
|
||||
{
|
||||
return CUMSUM_LAST_DIM_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
PluginFieldCollection const* CumsumLastDimPluginCreator::getFieldNames() noexcept
|
||||
{
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
IPluginV2* CumsumLastDimPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
|
||||
{
|
||||
PluginField const* fields = fc->fields;
|
||||
int input_length;
|
||||
nvinfer1::DataType type;
|
||||
// Read configurations from each fields
|
||||
for (int i = 0; i < fc->nbFields; ++i)
|
||||
{
|
||||
char const* attrName = fields[i].name;
|
||||
if (!strcmp(attrName, "input_length"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
input_length = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "type_id"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
|
||||
}
|
||||
}
|
||||
try
|
||||
{
|
||||
auto* obj = new CumsumLastDimPlugin(input_length, type);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
IPluginV2* CumsumLastDimPluginCreator::deserializePlugin(
|
||||
char const* name, void const* serialData, size_t serialLength) noexcept
|
||||
{
|
||||
// This object will be deleted when the network is destroyed, which will
|
||||
// call CumsumLastDimPlugin::destroy()
|
||||
try
|
||||
{
|
||||
auto* obj = new CumsumLastDimPlugin(serialData, serialLength);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
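For a [batch_size, input_length] input the plugin emits a same-shaped tensor holding the running sum of each row along the last dimension; the actual scan lives in invokeCumsumLastDim, which is not part of this diff. A CPU reference for the intended per-row result, assuming an inclusive scan:

```
#include <numeric>
#include <vector>

// CPU reference for one row of CumsumLastDim, assuming an inclusive scan:
// {1, 2, 3, 4} -> {1, 3, 6, 10}.
std::vector<float> cumsumLastDimRef(std::vector<float> const& row)
{
    std::vector<float> out(row.size());
    std::inclusive_scan(row.begin(), row.end(), out.begin());
    return out;
}
```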
@ -0,0 +1,98 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef TRT_CUMSUM_LAST_DIM_PLUGIN_H
|
||||
#define TRT_CUMSUM_LAST_DIM_PLUGIN_H
|
||||
|
||||
#include "tensorrt_llm/kernels/cumsumLastDim.h"
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include <cassert>
|
||||
|
||||
namespace tensorrt_llm::plugins
|
||||
{
|
||||
class CumsumLastDimPlugin : public BasePlugin
|
||||
{
|
||||
public:
|
||||
CumsumLastDimPlugin(int mInputLength, nvinfer1::DataType type);
|
||||
CumsumLastDimPlugin(void const* data, size_t length);
|
||||
~CumsumLastDimPlugin() override = default;
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
|
||||
bool supportsFormatCombination(
|
||||
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
|
||||
void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
|
||||
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
|
||||
size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
|
||||
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
|
||||
int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
|
||||
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
|
||||
template <typename T>
|
||||
int enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
|
||||
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream);
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
char const* getPluginType() const noexcept override;
|
||||
char const* getPluginVersion() const noexcept override;
|
||||
int getNbOutputs() const noexcept override;
|
||||
int initialize() noexcept override;
|
||||
void terminate() noexcept override;
|
||||
size_t getSerializationSize() const noexcept override;
|
||||
void serialize(void* buffer) const noexcept override;
|
||||
void destroy() noexcept override;
|
||||
|
||||
private:
|
||||
using IndexType = std::int32_t;
|
||||
|
||||
IndexType getInputTensorIdx() const
|
||||
{
|
||||
return 0;
|
||||
};
|
||||
|
||||
private:
|
||||
int mInputLength;
|
||||
nvinfer1::DataType mType;
|
||||
};
|
||||
|
||||
class CumsumLastDimPluginCreator : public BaseCreator
|
||||
{
|
||||
public:
|
||||
CumsumLastDimPluginCreator();
|
||||
char const* getPluginName() const noexcept override;
|
||||
|
||||
char const* getPluginVersion() const noexcept override;
|
||||
|
||||
nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
|
||||
|
||||
nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
|
||||
|
||||
nvinfer1::IPluginV2* deserializePlugin(
|
||||
char const* name, void const* serialData, size_t serialLength) noexcept override;
|
||||
|
||||
private:
|
||||
static nvinfer1::PluginFieldCollection mFC;
|
||||
static std::vector<nvinfer1::PluginField> mPluginAttributes;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::plugins
|
||||
|
||||
#endif
|
||||
@ -14,13 +14,20 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "gemmPlugin.h"
|
||||
|
||||
#include "gemmPluginProfiler.h"
|
||||
#include "plugin.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include <NvInferRuntimeBase.h>
|
||||
#include "pluginUtils.h"
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include <cassert>
|
||||
|
||||
using namespace nvinfer1;
|
||||
using namespace tensorrt_llm::common;
|
||||
using tensorrt_llm::plugins::GemmDims;
|
||||
using tensorrt_llm::plugins::GemmPluginCreator;
|
||||
using tensorrt_llm::plugins::GemmPlugin;
|
||||
using tensorrt_llm::plugins::CublasLtGemmPluginProfiler;
|
||||
@ -47,10 +54,13 @@ void getProblemParams(cublasOperation_t& transa, cublasOperation_t& transb, int&
|
||||
}
|
||||
|
||||
void runGemm(int const M, int const N, int const K, bool const transA, bool const transB, int const padLda,
|
||||
int const padLdb, const nvinfer1::DataType type, CublasGemmWrapperPtr const& cublasWrapperPtr, void const* act,
|
||||
int const padLdb, nvinfer1::DataType const type, CublasGemmWrapperPtr const& cublasWrapperPtr, void const* act,
|
||||
void const* weight, void* output, std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic, void* workspace,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
if (M == 0 || N == 0 || K == 0)
|
||||
return;
|
||||
|
||||
cublasWrapperPtr->setStream(stream);
|
||||
cublasWrapperPtr->setWorkspace(workspace);
|
||||
|
||||
@ -291,60 +301,19 @@ bool GemmPlugin::supportsFormatCombination(
|
||||
return desc.type == mType || desc.type == nvinfer1::DataType::kFLOAT;
|
||||
}
|
||||
|
||||
int32_t computeMDimension(bool transA, const int32_t nbDims, tensorrt_llm::runtime::ITensor::DimType const* dims)
|
||||
{
|
||||
int32_t M = 1;
|
||||
if (transA)
|
||||
{
|
||||
for (int i = nbDims - 1; i > 0; --i)
|
||||
{
|
||||
M *= dims[i];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = 0; i < nbDims - 1; ++i)
|
||||
{
|
||||
M *= dims[i];
|
||||
}
|
||||
}
|
||||
return M;
|
||||
}
|
||||
|
||||
int32_t computeNDimension(bool transB, const int32_t nbDims, tensorrt_llm::runtime::ITensor::DimType const* dims)
|
||||
{
|
||||
int32_t N = 1;
|
||||
if (transB)
|
||||
{
|
||||
for (int i = 0; i < nbDims - 1; ++i)
|
||||
{
|
||||
N *= dims[i];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i = nbDims - 1; i > 0; --i)
|
||||
{
|
||||
N *= dims[i];
|
||||
}
|
||||
}
|
||||
return N;
|
||||
}
|
||||
|
||||
void GemmPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
|
||||
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
|
||||
{
|
||||
int const nbDimsA = in[0].max.nbDims;
|
||||
int const nbDimsB = in[1].max.nbDims;
|
||||
auto const nbDimsA = in[0].max.nbDims;
|
||||
|
||||
auto const minM = computeMDimension(mTransA, nbDimsA, in[0].min.d);
|
||||
auto const maxM = computeMDimension(mTransA, nbDimsA, in[0].max.d);
|
||||
auto const N = computeNDimension(mTransB, nbDimsB, in[1].max.d);
|
||||
auto const K = mTransA ? in[0].max.d[0] : in[0].max.d[nbDimsA - 1];
|
||||
auto const minM = utils::computeMDimension(mTransA, in[0].min);
|
||||
auto const maxM = utils::computeMDimension(mTransA, in[0].max);
|
||||
auto const N = utils::computeNDimension(mTransB, in[1].max);
|
||||
auto const K = static_cast<utils::DimType>(mTransA ? in[0].max.d[0] : in[0].max.d[nbDimsA - 1]);
|
||||
|
||||
if (!mDims.isInitialized())
|
||||
{
|
||||
mDims = {minM, maxM, N, static_cast<runtime::SizeType>(K)};
|
||||
mDims = {minM, maxM, N, K};
|
||||
}
|
||||
mGemmId.n = N;
|
||||
mGemmId.k = K;
|
||||
@ -370,13 +339,13 @@ int GemmPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P
|
||||
setGemmConfig();
|
||||
|
||||
int const nbDimsA = inputDesc[0].dims.nbDims;
|
||||
int const nbDimsB = inputDesc[1].dims.nbDims;
|
||||
int const padM = mTransA ? mPadLda : 0;
|
||||
int const padN = mTransB ? 0 : mPadLdb;
|
||||
int const padK = mTransA ? 0 : mPadLda;
|
||||
auto const M = computeMDimension(mTransA, nbDimsA, inputDesc[0].dims.d) - padM;
|
||||
auto const N = computeNDimension(mTransB, nbDimsB, inputDesc[1].dims.d) - padN;
|
||||
int const K = mTransA ? inputDesc[0].dims.d[0] - padK : inputDesc[0].dims.d[nbDimsA - 1] - padK;
|
||||
auto const M = utils::computeMDimension(mTransA, inputDesc[0].dims) - padM;
|
||||
auto const N = utils::computeNDimension(mTransB, inputDesc[1].dims) - padN;
|
||||
int const K = static_cast<utils::DimType>(
|
||||
mTransA ? inputDesc[0].dims.d[0] - padK : inputDesc[0].dims.d[nbDimsA - 1] - padK);
|
||||
|
||||
auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
|
||||
runGemm(M, N, K, mTransA, mTransB, mPadLda, mPadLdb, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0],
|
||||
|
||||
@ -16,11 +16,11 @@
|
||||
*/
|
||||
#ifndef TRT_GEMM_PLUGIN_H
|
||||
#define TRT_GEMM_PLUGIN_H
|
||||
|
||||
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
||||
#include "tensorrt_llm/plugins/common/gemmPluginProfiler.h"
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include <cassert>
|
||||
#include <set>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@ -68,7 +68,11 @@ struct FusedQKVMaskedAttentionDispatchParams
|
||||
float rotary_embedding_base;
|
||||
RotaryScalingType rotary_embedding_scale_type;
|
||||
float rotary_embedding_scale;
|
||||
float rotary_embedding_m_scale;
|
||||
float const* rotary_embedding_scaling_factors;
|
||||
int rotary_embedding_max_positions;
|
||||
int rotary_cogvlm_vision_start;
|
||||
int rotary_cogvlm_vision_length;
|
||||
PositionEmbeddingType position_embedding_type;
|
||||
bool position_shift_enabled;
|
||||
int max_attention_window;
|
||||
@ -179,7 +183,7 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
|
||||
xqaParams.max_distance = mMaxDistance;
|
||||
xqaParams.multi_block_mode = mMultiBlockMode;
|
||||
// Medusa mode will have multiple query tokens.
|
||||
xqaParams.multi_query_tokens = mIsMedusaEnabled;
|
||||
xqaParams.multi_query_tokens = mIsSpecDecodingEnabled;
|
||||
|
||||
if (mKVCacheQuantMode.hasInt8KvCache())
|
||||
{
|
||||
@ -209,7 +213,7 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
|
||||
xqaParams.workspaces = generationsParams.workspace;
|
||||
xqaParams.batch_size = generationsParams.num_requests;
|
||||
xqaParams.beam_width = generationsParams.beam_width;
|
||||
// Medusa mode has generation input_length > 1.
|
||||
// Speculative decoding mode has generation input_length > 1.
|
||||
xqaParams.generation_input_length = generationsParams.input_seq_length;
|
||||
xqaParams.max_attention_window_size = generationsParams.max_attention_window;
|
||||
xqaParams.cyclic_attention_window_size = generationsParams.cyclic_attention_window_size;
|
||||
@ -222,12 +226,12 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
|
||||
xqaParams.alibi_slopes = generationsParams.alibi_slopes;
|
||||
if (!forConfigurePlugin)
|
||||
{
|
||||
// Medusa (need to take new generated ids into consideration).
|
||||
TLLM_CHECK_WITH_INFO(!mIsMedusaEnabled || generationsParams.medusa_packed_mask != nullptr,
|
||||
"Medusa mode needs a valid packed_mask input tensor.");
|
||||
// Speculative decoding (need to take new generated ids into consideration).
|
||||
TLLM_CHECK_WITH_INFO(!mIsSpecDecodingEnabled || generationsParams.spec_decoding_packed_mask != nullptr,
|
||||
"Speculative decoding mode needs a valid packed_mask input tensor.");
|
||||
}
|
||||
xqaParams.medusa_packed_mask = generationsParams.medusa_packed_mask;
|
||||
xqaParams.medusa_position_offsets = generationsParams.medusa_position_offsets;
|
||||
xqaParams.spec_decoding_packed_mask = generationsParams.spec_decoding_packed_mask;
|
||||
xqaParams.spec_decoding_position_offsets = generationsParams.spec_decoding_position_offsets;
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -290,7 +294,11 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
|
||||
params.rotary_embedding_base = input_params.rotary_embedding_base;
|
||||
params.rotary_embedding_scale_type = input_params.rotary_embedding_scale_type;
|
||||
params.rotary_embedding_scale = input_params.rotary_embedding_scale;
|
||||
params.rotary_embedding_m_scale = input_params.rotary_embedding_m_scale;
|
||||
params.rotary_embedding_scaling_factors = input_params.rotary_embedding_scaling_factors;
|
||||
params.rotary_embedding_max_positions = input_params.rotary_embedding_max_positions;
|
||||
params.rotary_cogvlm_vision_start = input_params.rotary_cogvlm_vision_start;
|
||||
params.rotary_cogvlm_vision_length = input_params.rotary_cogvlm_vision_length;
|
||||
params.position_embedding_type = input_params.position_embedding_type;
|
||||
params.position_shift_enabled = input_params.position_shift_enabled;
|
||||
// Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
|
||||
@ -358,19 +366,23 @@ INSTANTIATE_MMHA_DISPATCH(__nv_bfloat16, __nv_bfloat16)
|
||||
#endif
|
||||
#undef INSTANTIATE_MMHA_DISPATCH
|
||||
|
||||
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int num_kv_heads, int head_size,
|
||||
int unidirectional, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length,
|
||||
int num_kv_heads, int head_size, int unidirectional, float q_scaling,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
|
||||
bool unfuse_qkv_gemm, // for AutoPP
|
||||
float rotary_embedding_scale, float rotary_embedding_m_scale, int rotary_embedding_max_positions, int tp_size,
|
||||
int tp_rank, // for ALiBi
|
||||
bool unfuse_qkv_gemm, // for AutoPP
|
||||
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, bool enable_xqa,
|
||||
int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
|
||||
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
|
||||
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
|
||||
bool use_paged_context_fmha, bool use_fp8_context_fmha, bool use_cache, bool is_medusa_enabled)
|
||||
bool use_paged_context_fmha, bool use_fp8_context_fmha, bool use_cache, bool is_spec_decoding_enabled)
|
||||
: mLayerIdx(layer_idx)
|
||||
, mNumHeads(num_heads)
|
||||
, mVisionStart(vision_start)
|
||||
, mVisionLength(vision_length)
|
||||
, mNumKVHeads(num_kv_heads)
|
||||
, mHeadSize(head_size)
|
||||
, mUnidirectional(unidirectional)
|
||||
@ -379,6 +391,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
|
||||
, mRotaryEmbeddingBase(rotary_embedding_base)
|
||||
, mRotaryEmbeddingScaleType(rotary_embedding_scale_type)
|
||||
, mRotaryEmbeddingScale(rotary_embedding_scale)
|
||||
, mRotaryEmbeddingMscale(rotary_embedding_m_scale)
|
||||
, mRotaryEmbeddingMaxPositions(rotary_embedding_max_positions)
|
||||
, mPositionEmbeddingType(position_embedding_type)
|
||||
, mEnableContextFMHA(context_fmha_type != ContextFMHAType::DISABLED)
|
||||
@ -404,7 +417,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
|
||||
, mPagedContextFMHA(use_paged_context_fmha)
|
||||
, mFP8ContextFMHA(use_fp8_context_fmha)
|
||||
, mUseKVCache(use_cache)
|
||||
, mIsMedusaEnabled(is_medusa_enabled)
|
||||
, mIsSpecDecodingEnabled(is_spec_decoding_enabled)
|
||||
, mDriver(CUDADriverWrapper::getInstance())
|
||||
{
|
||||
// Pre-check whether FMHA is supported in order to save memory allocation.
|
||||
if (mEnableContextFMHA)
|
||||
@ -473,12 +487,15 @@ int const GPTAttentionPluginCommon::getHeadSize(bool checkInit) const
|
||||
|
||||
// Parameterized constructor
|
||||
GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t length)
|
||||
: mDriver(CUDADriverWrapper::getInstance())
|
||||
{
|
||||
char const *d = reinterpret_cast<char const*>(data), *a = d;
|
||||
unsigned int kvCacheQuantMode;
|
||||
|
||||
read(d, mLayerIdx);
|
||||
read(d, mNumHeads);
|
||||
read(d, mVisionStart);
|
||||
read(d, mVisionLength);
|
||||
read(d, mNumKVHeads);
|
||||
read(d, mHeadSize);
|
||||
read(d, mUnidirectional);
|
||||
@ -488,6 +505,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
|
||||
read(d, mRotaryEmbeddingBase);
|
||||
read(d, mRotaryEmbeddingScaleType);
|
||||
read(d, mRotaryEmbeddingScale);
|
||||
read(d, mRotaryEmbeddingMscale);
|
||||
read(d, mRotaryEmbeddingMaxPositions);
|
||||
read(d, mTpSize);
|
||||
read(d, mTpRank);
|
||||
@ -511,11 +529,16 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
|
||||
read(d, mPagedContextFMHA);
|
||||
read(d, mFP8ContextFMHA);
|
||||
read(d, mUseKVCache);
|
||||
read(d, mIsMedusaEnabled);
|
||||
read(d, mIsSpecDecodingEnabled);
|
||||
read(d, mNbMultiBlockSemaphores);
|
||||
|
||||
mKVCacheQuantMode = tc::QuantMode(kvCacheQuantMode);
|
||||
|
||||
uint32_t decoderXQARunnerResourceSerializedSize;
|
||||
read(d, decoderXQARunnerResourceSerializedSize);
|
||||
mDecoderXQARunnerResource = DecoderXQARunner::Resource(d, decoderXQARunnerResourceSerializedSize);
|
||||
d += decoderXQARunnerResourceSerializedSize;
|
||||
|
||||
TLLM_CHECK_WITH_INFO(d == a + length,
|
||||
"Expected length (%d) != real length (%d). This is often "
|
||||
"caused by using different TensorRT-LLM version to build "
|
||||
@ -860,7 +883,7 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
|
||||
params.sink_token_length, params.num_tokens, mNumHeads, mNumKVHeads, mNumHeads / mNumKVHeads, getHeadSize(),
|
||||
mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
|
||||
mRotaryEmbeddingMaxPositions, position_embedding_type, mPosShiftEnabled, cache_type,
|
||||
enablePagedKVContextFMHA, mFP8ContextFMHA, mMultiProcessorCount};
|
||||
enablePagedKVContextFMHA, mFP8ContextFMHA, mMultiProcessorCount, mVisionStart, mVisionLength};
|
||||
|
||||
invokeQKVPreprocessing(preprocessingParms, stream);
|
||||
|
||||
@ -1247,9 +1270,9 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
mDecoderXQARunner->template dispatch<KVCacheBuffer>(xqaParams, kv_cache_buffer, stream);
|
||||
return 0;
|
||||
}
|
||||
else if (mIsMedusaEnabled)
|
||||
else if (mIsSpecDecodingEnabled)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "No available XQA kernels are found for medusa mode.");
|
||||
TLLM_CHECK_WITH_INFO(false, "No available XQA kernels are found for speculative decoding mode.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -1367,8 +1390,12 @@ int GPTAttentionPluginCommon::enqueueGeneration(
|
||||
dispatch_params.rotary_embedding_base = mRotaryEmbeddingBase;
|
||||
dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType;
|
||||
dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale;
|
||||
dispatch_params.rotary_embedding_m_scale = mRotaryEmbeddingMscale;
|
||||
dispatch_params.rotary_embedding_scaling_factors = params.rotary_embedding_scaling_factors;
|
||||
dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions;
|
||||
dispatch_params.position_shift_enabled = mPosShiftEnabled;
|
||||
dispatch_params.rotary_cogvlm_vision_start = mVisionStart;
|
||||
dispatch_params.rotary_cogvlm_vision_length = mVisionLength;
|
||||
dispatch_params.cross_attention = mCrossAttention;
|
||||
dispatch_params.memory_length_per_sample = params.encoder_input_lengths;
|
||||
|
||||
@ -1487,7 +1514,7 @@ int GPTAttentionPluginCommon::initialize() noexcept
|
||||
mFMHARunner->setup_flags(mFMHAForceFP32Acc, !mRemovePadding, true, mNumKVHeads);
|
||||
}
|
||||
|
||||
bool useXQAKernels = (mEnableXQA || mIsMedusaEnabled) && !mCrossAttention
|
||||
bool useXQAKernels = (mEnableXQA || mIsSpecDecodingEnabled) && !mCrossAttention
|
||||
&& (mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16);
|
||||
|
||||
if (useXQAKernels)
|
||||
@ -1502,22 +1529,22 @@ int GPTAttentionPluginCommon::initialize() noexcept
|
||||
xqa_runner_data_type = DATA_TYPE_BF16;
|
||||
}
|
||||
TLLM_LOG_DEBUG("Enabling XQA kernels for GPTAttention.");
|
||||
if (mIsMedusaEnabled)
|
||||
if (mIsSpecDecodingEnabled)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mNumHeads % mNumKVHeads == 0, "mNumHeads should be multiples of mNumKVHeads.");
|
||||
int numQHeadsPerKV = mNumHeads / mNumKVHeads;
|
||||
bool isPowerOfTwo = ((numQHeadsPerKV & (numQHeadsPerKV - 1)) == 0);
|
||||
TLLM_CHECK_WITH_INFO(isPowerOfTwo,
|
||||
"numQHeadsPerKV should be power of 2 for Medusa, mNumHeads=%d, mNumKVHeads=%d.", mNumHeads,
|
||||
mNumKVHeads);
|
||||
"numQHeadsPerKV should be power of 2 for Speculative decoding, mNumHeads=%d, mNumKVHeads=%d.",
|
||||
mNumHeads, mNumKVHeads);
|
||||
}
|
||||
|
||||
mDecoderXQARunner.reset(
|
||||
new DecoderXQARunner(xqa_runner_data_type, mNumHeads, mNumKVHeads, mHeadSize, mMultiBlockMode));
|
||||
mDecoderXQARunner.reset(new DecoderXQARunner(
|
||||
&mDecoderXQARunnerResource, xqa_runner_data_type, mNumHeads, mNumKVHeads, mHeadSize, mMultiBlockMode));
|
||||
}
|
||||
else if (mIsMedusaEnabled)
|
||||
else if (mIsSpecDecodingEnabled)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Medusa mode doesn't support the data type or cross attention.");
|
||||
TLLM_CHECK_WITH_INFO(false, "Speculative decoding mode doesn't support the data type or cross attention.");
|
||||
}
|
||||
|
||||
if (mNbMultiBlockSemaphores != 0)
|
||||
@ -1533,18 +1560,20 @@ void GPTAttentionPluginCommon::destroy() noexcept
|
||||
delete this;
|
||||
}
|
||||
|
||||
size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
|
||||
size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(mLayerIdx) + sizeof(mNumHeads) + sizeof(mNumKVHeads) + sizeof(mHeadSize) + sizeof(mUnidirectional)
|
||||
+ sizeof(mQScaling) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim)
|
||||
+ sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale)
|
||||
+ sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA)
|
||||
+ sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA)
|
||||
+ sizeof(unsigned int) // mKVCacheQuantMode
|
||||
return sizeof(mLayerIdx) + sizeof(mNumHeads) + sizeof(mVisionStart) + sizeof(mVisionLength) + sizeof(mNumKVHeads)
|
||||
+ sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling) + sizeof(mPositionEmbeddingType)
|
||||
+ sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType)
|
||||
+ sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingMscale) + sizeof(mRotaryEmbeddingMaxPositions)
|
||||
+ sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc)
|
||||
+ sizeof(mMultiBlockMode) + sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode
|
||||
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType)
|
||||
+ sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + sizeof(mCrossAttention) + sizeof(mMaxDistance)
|
||||
+ sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mFP8ContextFMHA)
|
||||
+ sizeof(mUseKVCache) + sizeof(mUnfuseQkvGemm) + sizeof(mIsMedusaEnabled) + sizeof(mNbMultiBlockSemaphores);
|
||||
+ sizeof(mUseKVCache) + sizeof(mUnfuseQkvGemm) + sizeof(mIsSpecDecodingEnabled)
|
||||
+ sizeof(mNbMultiBlockSemaphores) + sizeof(uint32_t) // size of mDecoderXQARunnerResource buffer.
|
||||
+ mDecoderXQARunnerResource.getSerializationSize();
|
||||
}
|
||||
|
||||
void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
|
||||
@ -1552,6 +1581,8 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
    char *d = static_cast<char*>(buffer), *a = d;
    write(d, mLayerIdx);
    write(d, mNumHeads);
    write(d, mVisionStart);
    write(d, mVisionLength);
    write(d, mNumKVHeads);
    write(d, mHeadSize);
    write(d, mUnidirectional);
@ -1561,6 +1592,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
    write(d, mRotaryEmbeddingBase);
    write(d, mRotaryEmbeddingScaleType);
    write(d, mRotaryEmbeddingScale);
    write(d, mRotaryEmbeddingMscale);
    write(d, mRotaryEmbeddingMaxPositions);
    write(d, mTpSize);
    write(d, mTpRank);
@ -1584,8 +1616,15 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
    write(d, mPagedContextFMHA);
    write(d, mFP8ContextFMHA);
    write(d, mUseKVCache);
    write(d, mIsMedusaEnabled);
    write(d, mIsSpecDecodingEnabled);
    write(d, mNbMultiBlockSemaphores);

    // A uint32_t that specifies the size of the serialized buffer, followed by the actual content.
    uint32_t decoderXQARunnerResourceSerializedSize = mDecoderXQARunnerResource.getSerializationSize();
    write(d, decoderXQARunnerResourceSerializedSize);
    mDecoderXQARunnerResource.serialize(d, decoderXQARunnerResourceSerializedSize);
    d += decoderXQARunnerResourceSerializedSize;

    assert(d == a + getCommonSerializationSize());
}

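Because serializeCommon now appends a variable-length blob, the matching deserialization path (not shown in this diff) has to consume the same layout: read the uint32_t length prefix first, then hand the following bytes to the resource. A hedged, self-contained sketch under that assumption; restoreResource is a placeholder for whatever actually rebuilds the DecoderXQARunner resource.

#include <cstdint>
#include <cstring>
#include <functional>

// Reads one length-prefixed section: a uint32_t byte count followed by that
// many payload bytes. Returns the cursor positioned after the section.
inline char const* readLengthPrefixedSection(
    char const* cursor, std::function<void(void const*, uint32_t)> const& restoreResource)
{
    uint32_t size = 0;
    std::memcpy(&size, cursor, sizeof(size)); // the prefix written by serializeCommon above
    cursor += sizeof(size);
    restoreResource(cursor, size);            // placeholder: rebuild the resource from the blob
    return cursor + size;                     // skip past the variable-length payload
}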
@ -1630,6 +1669,8 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
    // Fill PluginFieldCollection with PluginField arguments metadata
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(PluginField("num_heads", nullptr, PluginFieldType::kINT32, -1));
    mPluginAttributes.emplace_back(PluginField("vision_start", nullptr, PluginFieldType::kINT32, -1));
    mPluginAttributes.emplace_back(PluginField("vision_length", nullptr, PluginFieldType::kINT32, -1));
    mPluginAttributes.emplace_back(PluginField("num_kv_heads", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(PluginField("head_size", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(PluginField("unidirectional", nullptr, PluginFieldType::kINT32, 1));
@ -1639,6 +1680,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
    mPluginAttributes.emplace_back(PluginField("rotary_embedding_base", nullptr, PluginFieldType::kFLOAT32, 0));
    mPluginAttributes.emplace_back(PluginField("rotary_embedding_scale_type", nullptr, PluginFieldType::kINT8, 0));
    mPluginAttributes.emplace_back(PluginField("rotary_embedding_scale", nullptr, PluginFieldType::kFLOAT32, 0));
    mPluginAttributes.emplace_back(PluginField("rotary_embedding_m_scale", nullptr, PluginFieldType::kFLOAT32, 0));
    mPluginAttributes.emplace_back(PluginField("rotary_embedding_max_positions", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(PluginField("tp_size", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(PluginField("tp_rank", nullptr, PluginFieldType::kINT32, 0));
@ -1661,7 +1703,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
    mPluginAttributes.emplace_back(PluginField("use_paged_context_fmha", nullptr, PluginFieldType::kINT8, 0));
    mPluginAttributes.emplace_back(PluginField("use_fp8_context_fmha", nullptr, PluginFieldType::kINT8, 0));
    mPluginAttributes.emplace_back(PluginField("use_cache", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(PluginField("is_medusa_enabled", nullptr, PluginFieldType::kINT8, 0));
    mPluginAttributes.emplace_back(PluginField("is_spec_decoding_enabled", nullptr, PluginFieldType::kINT8, 0));
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}

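The creator side of this change, createPlugin reading the new fields back out of the PluginFieldCollection, is not part of this hunk. Below is a hedged sketch of how the newly registered fields would typically be parsed using the standard nvinfer1::PluginField layout; the struct and function names are illustrative, not repository code.

#include <cstdint>
#include <cstring>
#include <NvInferRuntime.h>

using nvinfer1::PluginFieldCollection;
using nvinfer1::PluginFieldType;

// Illustrative container for the fields added in this commit.
struct NewAttentionFields
{
    float rotaryEmbeddingMscale = 1.0F; // assumed default, for illustration only
    bool isSpecDecodingEnabled = false;
    int32_t visionStart = -1;
    int32_t visionLength = -1;
};

inline NewAttentionFields parseNewFields(PluginFieldCollection const* fc)
{
    NewAttentionFields out;
    for (int32_t i = 0; i < fc->nbFields; ++i)
    {
        auto const& f = fc->fields[i];
        if (!std::strcmp(f.name, "rotary_embedding_m_scale") && f.type == PluginFieldType::kFLOAT32)
            out.rotaryEmbeddingMscale = *static_cast<float const*>(f.data);
        else if (!std::strcmp(f.name, "is_spec_decoding_enabled") && f.type == PluginFieldType::kINT8)
            out.isSpecDecodingEnabled = *static_cast<int8_t const*>(f.data) != 0;
        else if (!std::strcmp(f.name, "vision_start") && f.type == PluginFieldType::kINT32)
            out.visionStart = *static_cast<int32_t const*>(f.data);
        else if (!std::strcmp(f.name, "vision_length") && f.type == PluginFieldType::kINT32)
            out.visionLength = *static_cast<int32_t const*>(f.data);
    }
    return out;
}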
@ -37,18 +37,20 @@ class GPTAttentionPluginCommon : public BasePlugin
public:
    GPTAttentionPluginCommon() = delete;

    GPTAttentionPluginCommon(int layer_idx, int num_heads, int num_kv_heads, int head_size, int unidirectional,
        float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
    GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
        int head_size, int unidirectional, float q_scaling,
        tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
        int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
        float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
        float rotary_embedding_scale, int rotary_embedding_max_positions, int tp_size, int tp_rank, // for ALiBi
        bool unfuse_qkv_gemm, // for AutoPP
        float rotary_embedding_scale, float rotary_embedding_m_scale, int rotary_embedding_max_positions, int tp_size,
        int tp_rank, // for ALiBi
        bool unfuse_qkv_gemm, // for AutoPP
        tensorrt_llm::kernels::ContextFMHAType context_fmha_type, bool multi_block_mode, bool enable_xqa,
        int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
        bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
        bool qkv_bias_enabled, bool cross_attention = false, int max_distance = 0, bool pos_shift_enabled = false,
        bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false,
        bool use_cache = true, bool is_medusa_enabled = false);
        bool use_cache = true, bool is_spec_decoding_enabled = false);

    GPTAttentionPluginCommon(void const* data, size_t length);

@ -73,7 +75,7 @@ public:
    //! So plugin should put the resource release inside destroy.
    void destroy() noexcept override;

    static size_t getCommonSerializationSize() noexcept;
    size_t getCommonSerializationSize() const noexcept;
    void serializeCommon(void* buffer) const noexcept;
    int const getHeadSize(bool checkInit = true) const;

@ -144,6 +146,7 @@ protected:
    float const* kv_scale_orig_quant;
    float const* kv_scale_quant_orig;
    float const* attention_output_orig_quant;
    float const* rotary_embedding_scaling_factors;
    T const* alibi_slopes;
    T* context_buf;
    void* key_value_cache;
@ -168,10 +171,10 @@ protected:
    // optional when cross attention
    int32_t const* encoder_input_lengths = nullptr;
    int32_t const* host_context_lengths = nullptr;
    // optional when medusa is used.
    bool const* medusa_mask = nullptr;
    int32_t const* medusa_packed_mask = nullptr;
    int32_t const* medusa_position_offsets = nullptr;
    // optional when speculative decoding is used.
    bool const* spec_decoding_mask = nullptr;
    int32_t const* spec_decoding_packed_mask = nullptr;
    int32_t const* spec_decoding_position_offsets = nullptr;
};

template <typename T, typename KVCacheBuffer>
@ -204,7 +207,13 @@ protected:
    bool isRoPE() const
    {
        return mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPTJ
            || mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX;
            || mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX
            || mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kLONG_ROPE;
    }

    bool isLongRoPE() const
    {
        return mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kLONG_ROPE;
    }

    bool isCrossAttention() const
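Note the semantics of the new predicate pair: kLONG_ROPE now satisfies isRoPE() as well, so generic RoPE handling still applies and isLongRoPE() only selects the additional LongRoPE-specific work (for example the new rotary_embedding_scaling_factors pointer added to the params struct above). A hypothetical, self-contained illustration of that relationship; apart from the enum value names, nothing below comes from the repository.

#include <cstdio>

enum class PositionEmbeddingType { kROPE_GPTJ, kROPE_GPT_NEOX, kLONG_ROPE, kOTHER };

static bool isRoPE(PositionEmbeddingType t)
{
    return t == PositionEmbeddingType::kROPE_GPTJ || t == PositionEmbeddingType::kROPE_GPT_NEOX
        || t == PositionEmbeddingType::kLONG_ROPE;
}

static bool isLongRoPE(PositionEmbeddingType t)
{
    return t == PositionEmbeddingType::kLONG_ROPE;
}

int main()
{
    auto const t = PositionEmbeddingType::kLONG_ROPE;
    if (isRoPE(t))
    {
        // LongRoPE is a RoPE variant, so the generic RoPE path is taken, and
        // isLongRoPE() gates the extra scaling-factor handling on top of it.
        std::printf("RoPE path, long variant: %d\n", isLongRoPE(t));
    }
    return 0;
}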
@ -228,6 +237,8 @@ protected:

    int mLayerIdx;
    int mNumHeads;
    int mVisionStart;
    int mVisionLength;
    int mNumKVHeads;
    int mHeadSize;
    int mUnidirectional;
@ -236,6 +247,7 @@ protected:
    float mRotaryEmbeddingBase;
    tensorrt_llm::kernels::RotaryScalingType mRotaryEmbeddingScaleType;
    float mRotaryEmbeddingScale;
    float mRotaryEmbeddingMscale;
    int mRotaryEmbeddingMaxPositions;
    tensorrt_llm::kernels::PositionEmbeddingType mPositionEmbeddingType;
    bool mRemovePadding = false;
@ -256,11 +268,11 @@ protected:
    bool mPagedContextFMHA = false;
    bool mFP8ContextFMHA = false;
    bool mDenseContextFMHA = false;
    bool mIsMedusaEnabled = false;
    bool mIsSpecDecodingEnabled = false;

    // Medusa packed mask.
    uint4* mMedusaPackedMask;
    uint4* mMedusaPackedHostMask;
    // Speculative decoding packed mask.
    uint4* mSpecDecodingPackedMask;
    uint4* mSpecDecodingPackedHostMask;

    // fmha runner (disable by default)
    // flag: disabled = 0, enabled = 1, enabled with fp32 accumulation = 2
@ -270,7 +282,9 @@ protected:
    int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
    int mMaxSharedMemoryPerBlockOptin = tensorrt_llm::common::getMaxSharedMemoryPerBlockOptin();
    // The default copy constructor will leave it as nullptr. clone() shall initialize it.
    std::shared_ptr<CUDADriverWrapper> mDriver;
    UniqPtrWNullCopy<tensorrt_llm::kernels::MHARunner> mFMHARunner;
    tensorrt_llm::kernels::DecoderXQARunner::Resource mDecoderXQARunnerResource;
    UniqPtrWNullCopy<tensorrt_llm::kernels::DecoderXQARunner> mDecoderXQARunner;

    bool mMultiBlockMode;

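The comment on mDriver spells out a common plugin pattern: a handle member that is intentionally not carried over by copying and must be re-created in clone(). A hedged, self-contained sketch of that pattern follows; DriverHandle, makeDriverHandle and PluginLike are placeholders, not types from this repository.

#include <memory>

// Stand-in for a wrapper around a driver/library handle.
struct DriverHandle
{
};

inline std::shared_ptr<DriverHandle> makeDriverHandle()
{
    return std::make_shared<DriverHandle>();
}

class PluginLike
{
public:
    PluginLike()
        : mDriver(makeDriverHandle())
    {
    }

    // If the copy path leaves mDriver empty (as the diff's comment describes),
    // clone() is responsible for re-acquiring the handle for the new instance.
    PluginLike* clone() const
    {
        auto* copy = new PluginLike(*this);
        copy->mDriver = makeDriverHandle(); // re-initialize what copying did not
        return copy;
    }

private:
    std::shared_ptr<DriverHandle> mDriver;
};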
Some files were not shown because too many files have changed in this diff.