Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (#2156)
Co-authored-by: Bruno Magalhaes <bruno.magalhaes@synthesia.io>
This commit is contained in: parent 32ed92e449, commit b8fc6633ba
@@ -17,11 +17,14 @@ TensorRT-LLM
<div align="left">

## Latest News

* [2024/08/13] 🏎️SDXL with #TensorRT Model Optimizer ⏱️⚡ 🏁 cache diffusion 🏁 quantization aware training 🏁 QLoRA 🏁 #Python 3.12
[➡️ link](https://developer.nvidia.com/blog/nvidia-tensorrt-model-optimizer-v0-15-boosts-inference-performance-and-expands-model-support/)
<div align="center">
<img src="docs/source/media/picture-08-20-2024.png" width="40%">
<div align="left">

* [2024/08/13] 🐍 DIY Code Completion with #Mamba ⚡ #TensorRT #LLM for speed 🤖 NIM for ease ☁️ deploy anywhere
[➡️ link](https://developer.nvidia.com/blog/revolutionizing-code-completion-with-codestral-mamba-the-next-gen-coding-llm/)
<div align="center">
<img src="docs/source/media/picture-08-13-2024.png" width="50%">
<div align="left">

* [2024/08/06] 🗫 Multilingual Challenge Accepted 🗫
🤖 #TensorRT #LLM boosts low-resource languages like Hebrew, Indonesian and Vietnamese ⚡[➡️ link](https://developer.nvidia.com/blog/accelerating-hebrew-llm-performance-with-nvidia-tensorrt-llm/?linkId=100000278659647)
@@ -13,6 +13,9 @@ TensorRT-LLM provides a packaged benchmarking utility that is accessible via the
- [`meta-llama/Meta-Llama-3-8B`](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
- [`meta-llama/Meta-Llama-3-70B`](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
- [`EleutherAI/gpt-j-6b`](https://huggingface.co/EleutherAI/gpt-j-6b)
- [`mistralai/Mistral-7B-v0.1`](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [`mistralai/Mixtral-8x7B-v0.1`](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)

#### Supported Quantization Modes
@@ -43,20 +46,20 @@ For this quick start guide, we will focus on running a short max throughput benc
of 128:128. In order to run the benchmark from start to finish simply run the following commands:

```shell
python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1400 > /tmp/synthetic_128_128.txt
python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 3000 > /tmp/synthetic_128_128.txt
trtllm-bench --model meta-llama/Llama-2-7b-hf build --dataset /tmp/synthetic_128_128.txt --quantization FP8
trtllm-bench --model meta-llama/Llama-2-7b-hf throughput --dataset /tmp/synthetic_128_128.txt --engine-path /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
trtllm-bench --model meta-llama/Llama-2-7b-hf throughput --dataset /tmp/synthetic_128_128.txt --engine_dir /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
```

And that's it! Once the benchmark completes, a summary of the benchmark metrics will be printed.

```
```shell
===========================================================
= ENGINE DETAILS
===========================================================
Model: meta-llama/Llama-2-7b-hf
Engine Directory: /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
TensorRT-LLM Version: 0.12.0.dev2024073000
TensorRT-LLM Version: 0.12.0
Dtype: float16
KV Cache Dtype: FP8
Quantization: FP8
@@ -135,7 +138,7 @@ directory. For example, to generate a synthetic dataset of 1000 requests with a
128/128 for [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b), simply run:

```shell
benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 > $PATH_TO_DATASET
benchmarks/cpp/prepare_dataset.py --stdout --tokenizer meta-llama/Llama-2-7b-hf token-norm-dist --input-mean 128 --output-mean 128 --input-stdev 0 --output-stdev 0 --num-requests 1000 > /tmp/synthetic_128_128.txt
```

You can pipe the above command to a file to reuse the same dataset, or simply pipe its output to the
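The `token-norm-dist` mode draws each request's input and output lengths from a normal distribution, so `--input-stdev 0 --output-stdev 0` yields a uniform 128:128 workload. Below is a minimal sketch of the sampling idea behind those flags (illustrative only, not the actual `prepare_dataset.py` implementation, which also generates and tokenizes prompt text):

```python
import numpy as np

def sample_lengths(mean: int, stdev: int, num_requests: int, rng=None):
    """Draw per-request token counts from a normal distribution (token-norm-dist style)."""
    if rng is None:
        rng = np.random.default_rng(0)
    lengths = rng.normal(loc=mean, scale=stdev, size=num_requests)
    return np.clip(np.rint(lengths), 1, None).astype(int)

isl = sample_lengths(mean=128, stdev=0, num_requests=1000)
osl = sample_lengths(mean=128, stdev=0, num_requests=1000)
assert (isl == 128).all() and (osl == 128).all()  # stdev 0 => every request is exactly 128:128
```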
@@ -148,16 +151,78 @@ build a pre-configured engine for one of the supported ISL:OSL combinations, you
using the dataset you generated with `prepare_dataset.py` to build an FP8 quantized engine:

```shell
trtllm-bench --model $HF_MODEL_NAME build --dataset $PATH_TO_DATASET --quantization FP8
trtllm-bench --model meta-llama/Llama-2-7b-hf build --dataset /tmp/synthetic_128_128.txt --quantization FP8
```

or manually set a max sequence length thatL you plan to run with specifically:
or manually set a max sequence length that you plan to run with specifically:

```shell
trtllm-bench --model $HF_MODEL_NAME build --max_seq_len $MAX_SEQ_LEN --quantization FP8
trtllm-bench --model meta-llama/Llama-2-7b-hf build --max_seq_len 256 --quantization FP8
```

The engine in this case will be written to the `/tmp/$HF_MODEL_NAME/tp_1_pp_1/` directory.
Looking a little closer, the `build` sub-command
will perform a lookup and build an engine using those reference settings. The
look up table directly corresponds to the performance table found in our
[Performance Overview](../docs/source/performance/perf-overview.md#throughput-measurements). The
output of the `build` sub-command looks similar to the snippet below (for `meta-llama/Llama-2-7b-hf`):

```shell
trtllm-bench --model meta-llama/Llama-2-7b-hf build --dataset /tmp/synthetic_128_128.txt --quantization FP8
[TensorRT-LLM] TensorRT-LLM version: 0.12.0
[08/12/2024-19:13:06] [TRT-LLM] [I] Found dataset.
[08/12/2024-19:13:07] [TRT-LLM] [I]
===========================================================
= DATASET DETAILS
===========================================================
Max Input Sequence Length: 128
Max Output Sequence Length: 128
Max Sequence Length: 256
Number of Sequences: 3000
===========================================================

[08/12/2024-19:13:07] [TRT-LLM] [I] Set multiple_profiles to True.
[08/12/2024-19:13:07] [TRT-LLM] [I] Set use_paged_context_fmha to True.
[08/12/2024-19:13:07] [TRT-LLM] [I] Set use_fp8_context_fmha to True.
[08/12/2024-19:13:07] [TRT-LLM] [I]
===========================================================
= ENGINE BUILD INFO
===========================================================
Model Name: meta-llama/Llama-2-7b-hf
Workspace Directory: /tmp
Engine Directory: /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1

===========================================================
= ENGINE CONFIGURATION DETAILS
===========================================================
Max Sequence Length: 256
Max Batch Size: 4096
Max Num Tokens: 8192
Quantization: FP8
===========================================================

Loading Model: [1/3] Downloading HF model
Downloaded model to /data/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9
Time: 0.115s
Loading Model: [2/3] Loading HF model to memory
current rank: 0, tp rank: 0, pp rank: 0
Time: 60.786s
Loading Model: [3/3] Building TRT-LLM engine
Time: 163.331s
Loading model done.
Total latency: 224.232s
[TensorRT-LLM][INFO] Engine version 0.12.0 found in the config file, assuming engine(s) built by new builder API.

<snip verbose logging>

[08/12/2024-19:17:09] [TRT-LLM] [I]

===========================================================
ENGINE SAVED: /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
===========================================================
```

The engine in this case will be written to `/tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1` (the end of the log).
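The lookup the `build` sub-command performs maps the dataset's maximum ISL/OSL to reference engine settings such as the ones shown under ENGINE CONFIGURATION DETAILS above. A toy stand-in for that table, using only the values visible in this log (the real table lives inside `trtllm-bench` and mirrors the Performance Overview page; the names here are hypothetical):

```python
# Hypothetical, minimal stand-in for the reference-settings lookup described above.
REFERENCE_SETTINGS = {
    # (max input len, max output len) -> engine build settings seen in the build log
    (128, 128): {"max_seq_len": 256, "max_batch_size": 4096, "max_num_tokens": 8192},
}

def settings_for(max_isl: int, max_osl: int) -> dict:
    try:
        return REFERENCE_SETTINGS[(max_isl, max_osl)]
    except KeyError:
        # Outside the pre-tuned table you would fall back to --max_seq_len (ISL + OSL).
        return {"max_seq_len": max_isl + max_osl}

print(settings_for(128, 128))  # {'max_seq_len': 256, 'max_batch_size': 4096, 'max_num_tokens': 8192}
```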
### Running a Max Throughput Benchmark
@@ -175,111 +240,49 @@ list of logits. Otherwise, tokenize the `prompt` with as specified by `--model $
3. Submit the dataset to the TensorRT-LLM `Executor` API as fast as possible (offline mode).
4. Wait for all requests to return, compute statistics, then report out results (a simplified sketch of this submit-and-wait flow follows).
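A minimal sketch of that offline pattern, assuming generic `submit`/`await_all` callables rather than the real `Executor` bindings (names and result fields are illustrative):

```python
import time

def run_offline_benchmark(requests, submit, await_all):
    """Enqueue every request back-to-back (offline mode), wait for all, then aggregate stats."""
    start = time.perf_counter()
    handles = [submit(req) for req in requests]     # no pacing between submissions
    results = await_all(handles)                    # block until every response has returned
    total_latency = time.perf_counter() - start
    total_output_tokens = sum(r["output_tokens"] for r in results)
    return {
        "request_throughput_req_per_s": len(requests) / total_latency,
        "token_throughput_tok_per_s": total_output_tokens / total_latency,
        "total_latency_s": total_latency,
    }
```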
To run the benchmarker, run the following with the engine and dataset generated above:
To run the benchmarker, run the following with the [engine](#building-a-benchmark-engine) and
[dataset](#preparing-a-dataset) generated above:

```
trtllm-bench --model $HF_MODEL_NAME throughput --dataset $PATH_TO_DATASET --engine_dir /tmp/$HF_MODEL_NAME/tp_1_pp_1/
```
```shell
trtllm-bench --model meta-llama/Llama-2-7b-hf throughput --dataset /tmp/synthetic_128_128.txt --engine_dir /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
[TensorRT-LLM] TensorRT-LLM version: 0.12.0
[08/12/2024-19:36:48] [TRT-LLM] [I] Preparing to run throughput benchmark...
[08/12/2024-19:36:49] [TRT-LLM] [I] Setting up benchmarker and infrastructure.
[08/12/2024-19:36:49] [TRT-LLM] [I] Ready to start benchmark.
[08/12/2024-19:36:49] [TRT-LLM] [I] Initializing Executor.
[TensorRT-LLM][INFO] Engine version 0.12.0 found in the config file, assuming engine(s) built by new builder API.

When the benchmark runs, you will see output similar to the following:
<snip verbose logging>
```
|
||||
Preparing to run throughput benchmark...
|
||||
Setting up benchmarker and infrastructure.
|
||||
Initializing Throughput Benchmark. [rate=%d req/s]
|
||||
Ready to start benchmark.
|
||||
Initializing Executor.
|
||||
[TensorRT-LLM][INFO] Engine version 0.12.0.dev2024073000 found in the config file, assuming engine(s) built by new builder API.
|
||||
[TensorRT-LLM][INFO] Initializing MPI with thread mode 3
|
||||
[TensorRT-LLM][INFO] Initialized MPI
|
||||
[TensorRT-LLM][INFO] Engine version 0.12.0.dev2024073000 found in the config file, assuming engine(s) built by new builder API.
|
||||
[TensorRT-LLM][INFO] MPI size: 1, MPI local size: 1, rank: 0
|
||||
[TensorRT-LLM][INFO] Rank 0 is using GPU 0
|
||||
[TensorRT-LLM][INFO] TRTGptModel maxNumSequences: 4096
|
||||
[TensorRT-LLM][INFO] TRTGptModel maxBatchSize: 4096
|
||||
[TensorRT-LLM][INFO] TRTGptModel maxBeamWidth: 1
|
||||
[TensorRT-LLM][INFO] TRTGptModel maxSequenceLen: 4098
|
||||
[TensorRT-LLM][INFO] TRTGptModel maxDraftLen: 0
|
||||
[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: 4098
|
||||
[TensorRT-LLM][INFO] TRTGptModel enableTrtOverlap: 0
|
||||
[TensorRT-LLM][INFO] TRTGptModel normalizeLogProbs: 1
|
||||
[TensorRT-LLM][INFO] TRTGptModel maxNumTokens: 8192
|
||||
[TensorRT-LLM][INFO] TRTGptModel maxInputLen: 4097 = maxSequenceLen - 1 since chunked context is enabled
|
||||
[TensorRT-LLM][INFO] Capacity Scheduler Policy: GUARANTEED_NO_EVICT
|
||||
[TensorRT-LLM][INFO] Context Chunking Scheduler Policy: FIRST_COME_FIRST_SERVED
|
||||
[TensorRT-LLM][INFO] Loaded engine size: 6214 MiB
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 928.77 MiB for execution context memory.
|
||||
[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
|
||||
[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
|
||||
[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
|
||||
[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
|
||||
[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
|
||||
[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
|
||||
[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
|
||||
[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
|
||||
[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 1. Please ensure there are no enqueued operations pending in this context prior to switching profiles
|
||||
[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
|
||||
[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
|
||||
[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
|
||||
[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
|
||||
[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 2. Please ensure there are no enqueued operations pending in this context prior to switching profiles
|
||||
[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
|
||||
[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
|
||||
[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
|
||||
[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
|
||||
[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 3. Please ensure there are no enqueued operations pending in this context prior to switching profiles
|
||||
[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
|
||||
[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
|
||||
[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
|
||||
[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
|
||||
[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 4. Please ensure there are no enqueued operations pending in this context prior to switching profiles
|
||||
[TensorRT-LLM][INFO] [MS] Running engine with multi stream info
|
||||
[TensorRT-LLM][INFO] [MS] Number of aux streams is 1
|
||||
[TensorRT-LLM][INFO] [MS] Number of total worker streams is 2
|
||||
[TensorRT-LLM][INFO] [MS] The main stream provided by execute/enqueue calls is the first worker stream
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 6166 (MiB)
|
||||
[TensorRT-LLM][INFO] Switching optimization profile from: 0 to 5. Please ensure there are no enqueued operations pending in this context prior to switching profiles
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 1.14 GB GPU memory for runtime buffers.
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 4.35 GB GPU memory for decoder.
|
||||
[TensorRT-LLM][INFO] Memory usage when calculating max tokens in paged kv cache: total: 79.10 GiB, available: 63.62 GiB
|
||||
[TensorRT-LLM][INFO] Number of blocks in KV cache primary pool: 4607
|
||||
[TensorRT-LLM][INFO] Number of blocks in KV cache secondary pool: 0, onboard blocks to primary memory before reuse: true
|
||||
[TensorRT-LLM][INFO] Max KV cache pages per sequence: 65
|
||||
[TensorRT-LLM][INFO] Number of tokens per block: 64.
|
||||
[TensorRT-LLM][INFO] [MemUsageChange] Allocated 62.99 GiB for max tokens in paged KV cache (294848).
|
||||
[TensorRT-LLM][INFO] Executor instance created by worker
|
||||
Starting response daemon...Executor started.
|
||||
|
||||
Request serving started.
|
||||
Starting statistics collection.
|
||||
Collecting live stats...
|
||||
Benchmark started.
|
||||
Request serving stopped.
|
||||
Collecting last stats...
|
||||
Ending statistics collection.
|
||||
Stop received.
|
||||
Stopping response parsing.
|
||||
Collecting last responses before shutdown.
|
||||
Completed request parsing.
|
||||
Parsing stopped.
|
||||
Request generator successfully joined.
|
||||
Statistics process successfully joined.
|
||||
[08/12/2024-19:36:58] [TRT-LLM] [I] Starting response daemon...
|
||||
[08/12/2024-19:36:58] [TRT-LLM] [I] Executor started.
|
||||
[08/12/2024-19:36:58] [TRT-LLM] [I] Request serving started.
|
||||
[08/12/2024-19:36:58] [TRT-LLM] [I] Starting statistics collection.
|
||||
[08/12/2024-19:36:58] [TRT-LLM] [I] Benchmark started.
|
||||
[08/12/2024-19:36:58] [TRT-LLM] [I] Collecting live stats...
|
||||
[08/12/2024-19:36:59] [TRT-LLM] [I] Request serving stopped.
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I] Collecting last stats...
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I] Ending statistics collection.
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I] Stop received.
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I] Stopping response parsing.
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I] Collecting last responses before shutdown.
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I] Completed request parsing.
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I] Parsing stopped.
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I] Request generator successfully joined.
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I] Statistics process successfully joined.
|
||||
[08/12/2024-19:37:19] [TRT-LLM] [I]
===========================================================
= ENGINE DETAILS
===========================================================
Model: meta-llama/Llama-2-7b-hf
Engine Directory: /tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1
TensorRT-LLM Version: 0.12.0.dev2024073000
TensorRT-LLM Version: 0.12.0
Dtype: float16
KV Cache Dtype: FP8
Quantization: FP8
Max Input Length: 2048
Max Sequence Length: 4098
Max Input Length: 256
Max Sequence Length: 256

===========================================================
= WORLD + RUNTIME INFORMATION
@@ -289,28 +292,36 @@ PP Size: 1
Max Runtime Batch Size: 4096
Max Runtime Tokens: 8192
Scheduling Policy: Guaranteed No Evict
KV Memory Percentage: 99.0%
Issue Rate (req/sec): 3.680275266452667e+18
KV Memory Percentage: 90.0%
Issue Rate (req/sec): 2.0827970096792666e+19
===========================================================
= STATISTICS
===========================================================
Number of requests: 3000
Average Input Length (tokens): 128.0
Average Output Length (tokens): 128.0
Token Throughput (tokens/sec): 23405.927228471104
Request Throughput (req/sec): 182.8588064724305
Total Latency (seconds): 16.406100739
Token Throughput (tokens/sec): 18886.813971319196
Request Throughput (req/sec): 147.55323415093122
Total Latency (seconds): 20.331645167
===========================================================

Benchmark Shutdown called!
Shutting down ExecutorServer.
[TensorRT-LLM][INFO] Orchestrator sendReq thread exiting
[TensorRT-LLM][INFO] Orchestrator recv thread exiting
Executor shutdown.
[TensorRT-LLM][INFO] Leader sendThread exiting
[TensorRT-LLM][INFO] Leader recvReq thread exiting
[TensorRT-LLM][INFO] Refreshed the MPI local session
```
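The throughput figures in the summary are plain ratios over the whole run; a quick sanity check against the (post-change) numbers printed above:

```python
num_requests = 3000
avg_output_len = 128.0
total_latency_s = 20.331645167

token_throughput = num_requests * avg_output_len / total_latency_s   # ~18886.8 tokens/sec
request_throughput = num_requests / total_latency_s                  # ~147.55 req/sec
print(f"{token_throughput:.1f} tokens/sec, {request_throughput:.2f} req/sec")
```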
> [!WARNING]
> Some statistics are not reported.
> There are some statistics that are not reported in the summary (typically as 0.0). These statistics
> are not available currently.
## Summary

In summary, the general process for reproducing a benchmark point is as follows (a scripted version of these steps is sketched below):

- Prepare a dataset: `python benchmarks/cpp/prepare_dataset.py --stdout --tokenizer $HF_MODEL token-norm-dist --input-mean $ISL --output-mean $OSL --input-stdev 0 --output-stdev 0 --num-requests $NUM_REQUESTS > $DATASET_PATH`
- Build engine: `trtllm-bench --model $HF_MODEL build --dataset $DATASET_PATH`
- Benchmark engine: `trtllm-bench --model $HF_MODEL throughput --dataset $DATASET_PATH --engine_dir $ENGINE_DIR`

where:
- `$HF_MODEL` is the Huggingface name of a model.
- `$NUM_REQUESTS` is the number of requests to generate.
- `$DATASET_PATH` is the path where the dataset was written when preparing the dataset.
- `$ENGINE_DIR` is the engine directory as printed by `trtllm-bench build`.
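A minimal way to script those three steps end to end. This is a sketch only; it assumes the TensorRT-LLM repository checkout is the working directory and that `trtllm-bench` is on the PATH, and it simply shells out to the documented commands:

```python
import subprocess

def reproduce_benchmark(hf_model: str, isl: int, osl: int, num_requests: int,
                        dataset_path: str, engine_dir: str) -> None:
    # 1. Prepare a synthetic dataset.
    with open(dataset_path, "w") as dataset:
        subprocess.run(
            ["python", "benchmarks/cpp/prepare_dataset.py", "--stdout", "--tokenizer", hf_model,
             "token-norm-dist", "--input-mean", str(isl), "--output-mean", str(osl),
             "--input-stdev", "0", "--output-stdev", "0", "--num-requests", str(num_requests)],
            stdout=dataset, check=True)
    # 2. Build the engine from the dataset.
    subprocess.run(["trtllm-bench", "--model", hf_model, "build", "--dataset", dataset_path], check=True)
    # 3. Benchmark the engine (engine_dir is the directory printed by the build step).
    subprocess.run(["trtllm-bench", "--model", hf_model, "throughput",
                    "--dataset", dataset_path, "--engine_dir", engine_dir], check=True)

reproduce_benchmark("meta-llama/Llama-2-7b-hf", 128, 128, 3000,
                    "/tmp/synthetic_128_128.txt", "/tmp/meta-llama/Llama-2-7b-hf/tp_1_pp_1")
```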
@ -403,6 +403,7 @@ struct BenchInfo
|
||||
float firstTokenLatency{};
|
||||
std::optional<float> avgGenT2TLatency{};
|
||||
bool firstTokenSeen{false};
|
||||
SizeType32 decodingIter{0};
|
||||
};
|
||||
|
||||
class Recorder
|
||||
@ -527,6 +528,7 @@ public:
|
||||
outSeqLen -= inputSeqLen;
|
||||
}
|
||||
mRequestBenchInfos[requestId].outputLength = outSeqLen;
|
||||
mRequestBenchInfos[requestId].decodingIter = response.getResult().decodingIter;
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -572,6 +574,7 @@ public:
|
||||
std::vector<float> genT2TLatencies;
|
||||
|
||||
int totalOutputTokens{0};
|
||||
int totalDecodingIter{0};
|
||||
mNumErrorSamples = 0;
|
||||
mNumSamples = 0;
|
||||
for (auto reqInfo : mRequestBenchInfos)
|
||||
@ -580,6 +583,7 @@ public:
|
||||
{
|
||||
reqLatencies.push_back(reqInfo.second.latency);
|
||||
totalOutputTokens += reqInfo.second.outputLength;
|
||||
totalDecodingIter += reqInfo.second.decodingIter;
|
||||
|
||||
if (mStreaming)
|
||||
{
|
||||
@ -601,6 +605,9 @@ public:
|
||||
mTotalLatency = std::chrono::duration<float, std::milli>(mEnd - mStart).count();
|
||||
mSeqThroughput = mNumSamples / (mTotalLatency / 1000);
|
||||
mTokenThroughput = totalOutputTokens / (mTotalLatency / 1000);
|
||||
mAcceptanceRate = totalDecodingIter
|
||||
? (static_cast<float>(totalOutputTokens) / static_cast<float>(totalDecodingIter))
|
||||
: 0.0f;
|
||||
|
||||
mAvgSeqLatency = std::accumulate(reqLatencies.begin(), reqLatencies.end(), 0.F) / reqLatencies.size();
|
||||
|
||||
@ -648,7 +655,8 @@ public:
|
||||
printf("\n[BENCHMARK] num_samples %d\n", mNumSamples);
|
||||
printf("[BENCHMARK] total_latency(ms) %.2f\n", mTotalLatency);
|
||||
printf("[BENCHMARK] seq_throughput(seq/sec) %.2f\n", mSeqThroughput);
|
||||
printf("[BENCHMARK] token_throughput(token/sec) %.2f\n\n", mTokenThroughput);
|
||||
printf("[BENCHMARK] token_throughput(token/sec) %.2f\n", mTokenThroughput);
|
||||
printf("[BENCHMARK] avg_acceptance_rate(tokens/decoding steps) %.2f\n\n", mAcceptanceRate);
|
||||
|
||||
printf("[BENCHMARK] avg_sequence_latency(ms) %.2f\n", mAvgSeqLatency);
|
||||
printf("[BENCHMARK] max_sequence_latency(ms) %.2f\n", mMaxSeqLatency);
|
||||
@ -763,6 +771,7 @@ private:
|
||||
float mAvgGenT2TLatency{};
|
||||
float mAvgFtLatency{};
|
||||
float mTokenThroughput{};
|
||||
float mAcceptanceRate{};
|
||||
float mP99SeqLatency{};
|
||||
float mP90SeqLatency{};
|
||||
float mP50SeqLatency{};
|
||||
|
||||
@ -9,6 +9,7 @@ import pandas as pd
|
||||
from transformers import AutoTokenizer, LlamaTokenizerFast
|
||||
|
||||
nltk.download("punkt", quiet=False)
|
||||
nltk.download('punkt_tab')
|
||||
import argparse
|
||||
|
||||
|
||||
@ -25,10 +26,9 @@ ACCURACY_TARGETS = {
|
||||
"tokens_per_sample": 294.45 * 0.9
|
||||
},
|
||||
Model.GPT_J: {
|
||||
"rouge1": 42.9435135,
|
||||
"rouge2": 20.1033765,
|
||||
"rougeL": 29.9581119,
|
||||
# "tokens_per_sample": ??
|
||||
"rouge1": 42.9865 * 0.99,
|
||||
"rouge2": 20.1235 * 0.99,
|
||||
"rougeL": 29.9881 * 0.99,
|
||||
}
|
||||
}
|
||||
|
||||
@ -138,7 +138,6 @@ def main():
|
||||
target_texts = get_reference_df(args.dataset)
|
||||
model = Model.Llama_v2_70B
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(args.base_model)
|
||||
relaxing_factor = 1.0
|
||||
elif args.dataset.lower().endswith(".json"):
|
||||
target_texts = get_reference_json(args.dataset)
|
||||
model = Model.GPT_J
|
||||
@ -147,7 +146,6 @@ def main():
|
||||
padding_side="left",
|
||||
use_fast=False)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
relaxing_factor = 0.93
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Dataset expected to be pkl (open-orca) or json (cnn-dailymail)")
|
||||
@ -169,7 +167,7 @@ def main():
|
||||
print("Targets: ", targets)
|
||||
|
||||
for k, _ in targets.items():
|
||||
assert targets[k] * relaxing_factor <= achieved_scores[k]
|
||||
assert targets[k] <= achieved_scores[k]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -83,7 +83,9 @@ public:
|
||||
bool applyLogitsPostProcessorBatched = false,
|
||||
std::optional<std::shared_ptr<VecTokens>> encoderInputTokens = std::nullopt, bool returnEncoderOutput = false,
|
||||
std::optional<RequestIdType> clientId = std::nullopt,
|
||||
executor::PriorityType priority = executor::Request::kDefaultPriority)
|
||||
executor::PriorityType priority = executor::Request::kDefaultPriority,
|
||||
std::optional<TensorPtr> encoderInputFeatures = std::nullopt,
|
||||
std::optional<SizeType32> encoderOutputLength = std::nullopt)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(inputTokens->size())
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
@ -123,8 +125,10 @@ public:
|
||||
, mDecodingIter(0)
|
||||
, mPriority(priority)
|
||||
, mFinishReasons(samplingConfig.beamWidth)
|
||||
, mEncoderInputFeatures(std::move(encoderInputFeatures))
|
||||
, mEncoderOutputLength(encoderOutputLength)
|
||||
{
|
||||
if (mEncoderTokens.has_value())
|
||||
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
|
||||
{
|
||||
mState = REQUEST_STATE_ENCODER_INIT;
|
||||
}
|
||||
@ -170,6 +174,7 @@ public:
|
||||
, mPriority(req.getPriority())
|
||||
, mFinishReasons(mSamplingConfig.beamWidth)
|
||||
, mContextPhaseParams(req.getContextPhaseParams())
|
||||
, mEncoderOutputLength(req.getEncoderOutputLength())
|
||||
{
|
||||
if (mIsStreaming && mSamplingConfig.beamWidth > 1 && !mReturnAllGeneratedTokens)
|
||||
{
|
||||
@ -189,10 +194,14 @@ public:
|
||||
"since logits are not. Disabling returnGenerationLogits.");
|
||||
mReturnGenerationLogits = false;
|
||||
}
|
||||
if (req.getEncoderInputTokenIds())
|
||||
|
||||
if (req.getEncoderInputTokenIds().has_value() || req.getEncoderInputFeatures().has_value())
|
||||
{
|
||||
mState = REQUEST_STATE_ENCODER_INIT;
|
||||
mEncoderTokens = std::make_shared<VecTokens>(req.getEncoderInputTokenIds().value());
|
||||
if (req.getEncoderInputTokenIds().has_value())
|
||||
{
|
||||
mEncoderTokens = std::make_shared<VecTokens>(req.getEncoderInputTokenIds().value());
|
||||
}
|
||||
}
|
||||
if (req.getEmbeddingBias())
|
||||
{
|
||||
@ -254,14 +263,24 @@ public:
|
||||
// NOTE: Draft acceptance threshold is stored in mSamplingConfig
|
||||
}
|
||||
|
||||
auto const& encoderInputFeatures = req.getEncoderInputFeatures();
|
||||
if (encoderInputFeatures.has_value())
|
||||
{
|
||||
mEncoderInputFeatures = executor::detail::toITensor(encoderInputFeatures.value());
|
||||
}
|
||||
else
|
||||
{
|
||||
mEncoderInputFeatures = std::nullopt;
|
||||
}
|
||||
|
||||
initialize(req.getInputTokenIds(), req.getOutputConfig().returnLogProbs);
|
||||
}
|
||||
|
||||
void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen,
|
||||
std::optional<SizeType32> maxEncoderInputLen = std::nullopt)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(!(maxEncoderInputLen.has_value() && getEncoderLen() > maxEncoderInputLen.value()),
|
||||
"Encoder length (%d) exceeds maximum encoder input length (%d).", getEncoderLen(),
|
||||
TLLM_CHECK_WITH_INFO(!(maxEncoderInputLen.has_value() && getEncoderInputLen() > maxEncoderInputLen.value()),
|
||||
"Encoder length (%d) exceeds maximum encoder input length (%d).", getEncoderInputLen(),
|
||||
maxEncoderInputLen.value());
|
||||
|
||||
if (mPromptLen > maxInputLen)
|
||||
@ -383,12 +402,36 @@ public:
|
||||
return mEncoderTokens;
|
||||
}
|
||||
|
||||
/// @brief Get the number of input tokens to encoder
|
||||
/// @return The number of encoder input tokens.
|
||||
[[nodiscard]] SizeType32 getEncoderLen() const
|
||||
/// @brief Get length of encoder input (could be tokens or features length)
|
||||
/// @return An integer.
|
||||
[[nodiscard]] SizeType32 getEncoderInputLen() const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(getEncoderTokens().has_value(), "Encoder tokens are not given");
|
||||
return getEncoderTokens().value()->size();
|
||||
if (mEncoderInputFeatures.has_value())
|
||||
{
|
||||
return getEncoderInputFeatures()->getShape().d[0];
|
||||
}
|
||||
else if (getEncoderTokens().has_value())
|
||||
{
|
||||
return getEncoderTokens().value()->size();
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("GenericLlmRequest::getEncoderInputLen - Do not have encoder length!");
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Get length of encoder output. Fall back to encoder input length if not present
|
||||
/// @return An integer.
|
||||
[[nodiscard]] SizeType32 getEncoderOutputLen() const
|
||||
{
|
||||
if (mEncoderOutputLength.has_value())
|
||||
{
|
||||
return mEncoderOutputLength.value();
|
||||
}
|
||||
else
|
||||
{
|
||||
return getEncoderInputLen();
|
||||
}
|
||||
}
|
||||
|
||||
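The new `getEncoderInputLen`/`getEncoderOutputLen` pair encodes a simple precedence: the input length comes from the feature tensor's first dimension when encoder input features are present, otherwise from the encoder token count; the output length falls back to the input length unless an explicit `encoderOutputLength` was supplied (for example Whisper-style downsampling). A Python paraphrase of that logic, illustrative only and not the C++ API:

```python
def encoder_input_len(encoder_features, encoder_tokens):
    """Mirror of getEncoderInputLen: prefer feature length, then token count, else fail."""
    if encoder_features is not None:
        return encoder_features.shape[0]
    if encoder_tokens is not None:
        return len(encoder_tokens)
    raise RuntimeError("Do not have encoder length!")

def encoder_output_len(encoder_output_length, encoder_features, encoder_tokens):
    """Mirror of getEncoderOutputLen: explicit output length wins, otherwise reuse the input length."""
    if encoder_output_length is not None:
        return encoder_output_length
    return encoder_input_len(encoder_features, encoder_tokens)
```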
/// @brief Get the draft tokens
|
||||
@ -513,7 +556,8 @@ public:
|
||||
}
|
||||
|
||||
// for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase
|
||||
mState = mEncoderTokens.has_value() ? REQUEST_STATE_ENCODER_INIT : REQUEST_STATE_CONTEXT_INIT;
|
||||
mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? REQUEST_STATE_ENCODER_INIT
|
||||
: REQUEST_STATE_CONTEXT_INIT;
|
||||
mContextCurrentPosition = 0;
|
||||
mContextChunkSize = std::nullopt;
|
||||
mSeqSlot.reset();
|
||||
@ -716,6 +760,11 @@ public:
|
||||
return mEncoderOutputHost;
|
||||
}
|
||||
|
||||
[[nodiscard]] TensorPtr const getEncoderInputFeatures() const
|
||||
{
|
||||
return mEncoderInputFeatures.value_or(nullptr);
|
||||
}
|
||||
|
||||
void setEncoderOutputHost(TensorPtr encoderOutputHost)
|
||||
{
|
||||
mEncoderOutputHost = std::move(encoderOutputHost);
|
||||
@ -724,7 +773,7 @@ public:
|
||||
void allocEncoderOutputHost(SizeType32 encoderHiddenSize, nvinfer1::DataType dataType)
|
||||
{
|
||||
mEncoderOutputHost = runtime::BufferManager::pinned(
|
||||
runtime::ITensor::makeShape({getEncoderLen(), encoderHiddenSize}), dataType);
|
||||
runtime::ITensor::makeShape({getEncoderOutputLen(), encoderHiddenSize}), dataType);
|
||||
}
|
||||
|
||||
[[nodiscard]] TensorPtr const& getEncoderOutput() const noexcept
|
||||
@ -1091,6 +1140,7 @@ public:
|
||||
}
|
||||
|
||||
result.finishReasons = mFinishReasons;
|
||||
result.decodingIter = mDecodingIter;
|
||||
|
||||
// Update position of last sent response
|
||||
setMaxSentTokenLen(maxNbTokens);
|
||||
@ -1196,6 +1246,11 @@ protected:
|
||||
std::vector<executor::FinishReason> mFinishReasons;
|
||||
std::optional<executor::ContextPhaseParams> mContextPhaseParams;
|
||||
|
||||
std::optional<TensorPtr> mEncoderInputFeatures; // Input features of encoder for multimodal models
|
||||
std::optional<SizeType32>
|
||||
mEncoderOutputLength; // For some models like Whisper, encoder output shape cannot be inferred from encoder
|
||||
// input shape due to downsampling. Thus this is needed for setting buffer sizes correctly
|
||||
|
||||
private:
|
||||
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
|
||||
{
|
||||
|
||||
@ -76,6 +76,15 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
// Copy constructor
|
||||
TrtGptModelOptionalParams(TrtGptModelOptionalParams const& other)
|
||||
: TrtGptModelOptionalParams(other.kvCacheConfig, other.enableTrtOverlap, other.deviceIds,
|
||||
other.normalizeLogProbs, other.enableChunkedContext, other.peftCacheManagerConfig, other.decodingConfig,
|
||||
other.gpuWeightsPercent, other.maxBeamWidth, other.maxBatchSize, other.maxNumTokens, other.schedulerConfig,
|
||||
other.extendedRuntimePerfKnobConfig)
|
||||
{
|
||||
}
|
||||
|
||||
bool operator==(TrtGptModelOptionalParams const& other) const
|
||||
{
|
||||
return kvCacheConfig == other.kvCacheConfig //
|
||||
|
||||
@ -80,7 +80,7 @@ public:
|
||||
|
||||
void log(std::exception const& ex, Level level = Level::ERROR);
|
||||
|
||||
Level getLevel()
|
||||
Level getLevel() const
|
||||
{
|
||||
return level_;
|
||||
}
|
||||
@ -91,6 +91,11 @@ public:
|
||||
log(INFO, "Set logger level to %s", getLevelName(level));
|
||||
}
|
||||
|
||||
bool isEnabled(Level const level) const
|
||||
{
|
||||
return level_ <= level;
|
||||
}
|
||||
|
||||
private:
|
||||
static auto constexpr kPREFIX = "[TensorRT-LLM]";
|
||||
|
||||
@ -131,7 +136,7 @@ private:
|
||||
template <typename... Args>
|
||||
void Logger::log(Logger::Level level, char const* format, Args const&... args)
|
||||
{
|
||||
if (level_ <= level)
|
||||
if (isEnabled(level))
|
||||
{
|
||||
auto const fmt = getPrefix(level) + format;
|
||||
auto& out = level_ < WARNING ? std::cout : std::cerr;
|
||||
@ -150,7 +155,7 @@ void Logger::log(Logger::Level level, char const* format, Args const&... args)
|
||||
template <typename... Args>
|
||||
void Logger::log(Logger::Level const level, int const rank, char const* format, Args const&... args)
|
||||
{
|
||||
if (level_ <= level)
|
||||
if (isEnabled(level))
|
||||
{
|
||||
auto const fmt = getPrefix(level, rank) + format;
|
||||
auto& out = level_ < WARNING ? std::cout : std::cerr;
|
||||
@ -166,7 +171,16 @@ void Logger::log(Logger::Level const level, int const rank, char const* format,
|
||||
}
|
||||
}
|
||||
|
||||
#define TLLM_LOG(level, ...) tensorrt_llm::common::Logger::getLogger()->log(level, __VA_ARGS__)
|
||||
#define TLLM_LOG(level, ...) \
|
||||
do \
|
||||
{ \
|
||||
auto* const logger = tensorrt_llm::common::Logger::getLogger(); \
|
||||
if (logger->isEnabled(level)) \
|
||||
{ \
|
||||
logger->log(level, __VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_LOG_TRACE(...) TLLM_LOG(tensorrt_llm::common::Logger::TRACE, __VA_ARGS__)
|
||||
#define TLLM_LOG_DEBUG(...) TLLM_LOG(tensorrt_llm::common::Logger::DEBUG, __VA_ARGS__)
|
||||
#define TLLM_LOG_INFO(...) TLLM_LOG(tensorrt_llm::common::Logger::INFO, __VA_ARGS__)
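The reworked `TLLM_LOG` macro checks `isEnabled(level)` before calling `log`, so the format arguments are never evaluated for suppressed levels. The same call-site guarding pattern in Python's standard logging, shown purely as an analogy (this is not TensorRT-LLM code):

```python
import logging

logger = logging.getLogger("tensorrt_llm")
logging.basicConfig(level=logging.INFO)

def expensive_dump():
    # Stand-in for work you only want to do when the message will actually be emitted.
    return ", ".join(str(i) for i in range(10_000))

# Guard at the call site, as the reworked TLLM_LOG macro now does: when DEBUG is
# disabled, neither the formatting nor expensive_dump() runs.
if logger.isEnabledFor(logging.DEBUG):
    logger.debug("tensor contents: %s", expensive_dump())
```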
|
||||
|
||||
@ -97,7 +97,7 @@ inline std::string arr2str(T* arr, size_t size, char const* delim = kDefaultDeli
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::string vec2str(std::vector<T> vec, char const* delim = kDefaultDelimiter)
|
||||
inline std::string vec2str(std::vector<T> const& vec, char const* delim = kDefaultDelimiter)
|
||||
{
|
||||
return arr2str(vec.data(), vec.size(), delim);
|
||||
}
|
||||
|
||||
@ -326,6 +326,9 @@ public:
|
||||
/// @param returnAllGeneratedTokens Indicates whether to return the full beams or just the newly generated tokens
|
||||
/// after every streaming step.
|
||||
/// @param priority Sets the execution priority of this request.
|
||||
/// @param encoderInputFeatures Encoder input features for multimodal models.
|
||||
/// @param encoderOutputLength Encoder output length if encoder input and output have different lengths (due to
|
||||
/// convolution down-sampling, etc.)
|
||||
Request(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming = false,
|
||||
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
|
||||
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
|
||||
@ -339,7 +342,9 @@ public:
|
||||
std::optional<std::string> logitsPostProcessorName = std::nullopt,
|
||||
std::optional<VecTokens> encoderInputTokenIds = std::nullopt, std::optional<IdType> clientId = std::nullopt,
|
||||
bool returnAllGeneratedTokens = false, PriorityType priority = kDefaultPriority,
|
||||
std::optional<ContextPhaseParams> contextPhaseParams = std::nullopt);
|
||||
std::optional<ContextPhaseParams> contextPhaseParams = std::nullopt,
|
||||
std::optional<Tensor> encoderInputFeatures = std::nullopt,
|
||||
std::optional<SizeType32> encoderOutputLength = std::nullopt);
|
||||
|
||||
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
|
||||
static auto constexpr kBatchedPostProcessorName = "batched";
|
||||
@ -370,6 +375,8 @@ public:
|
||||
[[nodiscard]] PriorityType getPriority() const;
|
||||
[[nodiscard]] bool getReturnAllGeneratedTokens() const;
|
||||
[[nodiscard]] std::optional<ContextPhaseParams> const& getContextPhaseParams() const;
|
||||
[[nodiscard]] std::optional<Tensor> getEncoderInputFeatures() const;
|
||||
[[nodiscard]] std::optional<SizeType32> getEncoderOutputLength() const;
|
||||
|
||||
void setStreaming(bool streaming);
|
||||
void setSamplingConfig(SamplingConfig const& config);
|
||||
@ -389,6 +396,8 @@ public:
|
||||
void setPriority(PriorityType priority);
|
||||
void setReturnAllGeneratedTokens(bool returnAllGeneratedTokens);
|
||||
void setContextPhaseParams(ContextPhaseParams contextPhaseParams);
|
||||
void setEncoderInputFeatures(Tensor encoderInputFeatures);
|
||||
void setEncoderOutputLength(SizeType32 encoderOutputLength);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
@ -428,6 +437,9 @@ struct Result
|
||||
|
||||
/// @brief The params of the context phase.
|
||||
std::optional<ContextPhaseParams> contextPhaseParams;
|
||||
|
||||
/// @brief The decoding iterations it takes.
|
||||
SizeType32 decodingIter{0};
|
||||
};
|
||||
|
||||
/// @brief Class that holds either an error or a result
|
||||
@ -728,42 +740,6 @@ private:
|
||||
std::optional<size_t> mHostCacheSize;
|
||||
};
|
||||
|
||||
/// @brief Configuration class for the speculative decoding.
|
||||
// struct LookaheadDecodingConfig
|
||||
//{
|
||||
// LookaheadDecodingConfig(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize);
|
||||
//
|
||||
// explicit LookaheadDecodingConfig()
|
||||
// : LookaheadDecodingConfig(1, 1, 0)
|
||||
// {
|
||||
// }
|
||||
//
|
||||
// bool operator==(LookaheadDecodingConfig const& other) const;
|
||||
// [[nodiscard]] std::tuple<SizeType32 const, SizeType32 const, SizeType32 const> get() const;
|
||||
// [[nodiscard]] SizeType32 getWindowSize() const;
|
||||
// [[nodiscard]] SizeType32 getNgramSize() const;
|
||||
// [[nodiscard]] SizeType32 getVerificationSetSize() const;
|
||||
//
|
||||
// /// @brief return <maxDecodingTokens, maxPathLen, maxDraftTokens, maxDraftPathLen>
|
||||
// std::tuple<SizeType32, SizeType32, SizeType32, SizeType32> calculateSpeculativeResource() const;
|
||||
//
|
||||
// /// @brief return true when `this` can be executed on resources defined by `that`
|
||||
// bool isLE(LookaheadDecodingConfig const& that) const;
|
||||
//
|
||||
// /// @brief return true when the parameter combination is valid.
|
||||
// static bool isLegal(SizeType32 windowSize, SizeType32 ngramSize, SizeType32 verificationSetSize) noexcept;
|
||||
//
|
||||
// private:
|
||||
// friend class Serialization;
|
||||
//
|
||||
// // Number of NGrams in lookahead branch per step.
|
||||
// SizeType32 mWindowSize;
|
||||
// // Number of tokens per NGram.
|
||||
// SizeType32 mNgramSize;
|
||||
// // Number of NGrams in verification branch per step.
|
||||
// SizeType32 mVerificationSetSize;
|
||||
// };
|
||||
|
||||
/// @brief Configuration class for the decoding.
|
||||
class DecodingConfig
|
||||
{
|
||||
|
||||
@ -635,6 +635,16 @@ public:
|
||||
mManageWeightsType = manageWeightType;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string const& getModelName() const noexcept
|
||||
{
|
||||
return mModelName;
|
||||
}
|
||||
|
||||
void setModelName(std::string const& modelName)
|
||||
{
|
||||
mModelName = modelName;
|
||||
}
|
||||
|
||||
private:
|
||||
SizeType32 mVocabSize;
|
||||
SizeType32 mNbAttentionLayers;
|
||||
@ -692,6 +702,7 @@ private:
|
||||
nvinfer1::DataType mLogitsDtype;
|
||||
bool mUseShapeInference;
|
||||
ManageWeightsType mManageWeightsType;
|
||||
std::string mModelName;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -60,9 +60,10 @@ public:
|
||||
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
|
||||
|
||||
SizeType32 generatedTokensPerEngineStep;
|
||||
TensorPtr medusaPaths; // [maxDraftTokens + 1, maxAcceptedDraftTokensPerStep + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [maxDraftTokens + 1], on gpu
|
||||
TensorPtr medusaPaths; // [maxDraftTokens + 1, maxAcceptedDraftTokensPerStep + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [maxDraftTokens + 1], on gpu
|
||||
std::optional<executor::LookaheadDecodingConfig> lookaheadRuntimeConfig;
|
||||
nvinfer1::DataType dtype; // request data type, only used by explicit draft tokens.
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime::decoder_batch
|
||||
|
||||
@ -454,7 +454,7 @@ public:
|
||||
auto const gated_inter = mInterSize * mGatedMultiplier;
|
||||
|
||||
size_t workspace_size = mMoERunner.getWorkspaceSize(
|
||||
mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, {}, mUseLora);
|
||||
mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, mNormMode, {}, mUseLora);
|
||||
|
||||
mWorkspace = allocBuffer<char>(workspace_size);
|
||||
size_t const expert_matrix_size = mNumExperts * mHiddenSize * mInterSize;
|
||||
@ -640,8 +640,8 @@ public:
|
||||
auto stream = streamPtr->get();
|
||||
mMoERunner.runMoe(mInputTensor, mInputProbabilities, mExpertWeight1, mExpertBias1, mActType, mExpertWeight2,
|
||||
mExpertBias2, mQuantParams, mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace,
|
||||
mFinalOutput, nullptr, mTotalTokens, mScaleProbs, mSourceToExpandedMap, mSelectedExpert, parallelism_config,
|
||||
mNormMode, mUseLora, mLoraParams, stream);
|
||||
mFinalOutput, nullptr, mTotalTokens, mScaleProbs, mSourceToExpandedMap, mSelectedExpert, 0.01,
|
||||
parallelism_config, mNormMode, mUseLora, mLoraParams, stream);
|
||||
}
|
||||
|
||||
void runBenchmark(benchmark::State& state);
|
||||
|
||||
@ -612,7 +612,7 @@ void help()
|
||||
"- \"bias\" - If bias should be used, 0 = no bias, 1 = bias\n"
|
||||
"- \"act_fn\" - The enum value of the activation function. See\n"
|
||||
"\"cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h\"\n"
|
||||
"- \"norm_mode\" - The normalization mode. 0 = NONE, 1 = RENORM. See\n"
|
||||
"- \"norm_mode\" - The normalization mode. 0 = NONE, 1 = RENORM, 2 = SPARSE_MIXER. See\n"
|
||||
"\"cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h\"\n"
|
||||
"- \"tactic_id, tactic_id1, tactic_id2\"\n"
|
||||
"The config for the CUTLASS GEMM. tactic_id sets the same tactic for both to the same tactic (except in "
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:460b75a97c0de65941839ccd5e0458cf5929574b9345b3cb723a695ae5a056e0
|
||||
size 4404838
|
||||
oid sha256:b7500141fc49186966fe8b7b729c7d154249d401fcd86076c65fc297e718b673
|
||||
size 4436976
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:645bbbad2c38b573df7c6e56588a6728d356a58444ac7c2f881d773faaca7593
|
||||
size 4516944
|
||||
oid sha256:b751f275e2c2968e04f7055bba8c26c9ca8baff727127b5a4848d765b44b9088
|
||||
size 4547626
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
a348613d480961aa14d4e77939be8a34 libtensorrt_llm_batch_manager_static.a
|
||||
317ec85caec48184c9c8b9cbd3eb44b1 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
49402939d007b39393cabaa8fe96c110d16f5b35 commit
|
||||
a043d2afdd14a0c111202a26b3059141 libtensorrt_llm_batch_manager_static.a
|
||||
45a299004e39be51aa1830db49e26428 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
d461cf4c2e17afd6a88d2a63e85823dec7a3aab1 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a785c4459bdb4a7dad9df0c832211f26f699a331ce0b2b9516e7a666f83b895a
|
||||
size 4272894
|
||||
oid sha256:3238d25b58d025c38cff2a74a81c4dc746bfd0bcb479f5c53e3d38c8224b4039
|
||||
size 4299590
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9a6b98589222f8bf8e82f122110cb1824b1728646bad45a41e8b9ada632539dc
|
||||
size 4248190
|
||||
oid sha256:225652665f121266a95c8ff531407ddbb6e0a328622657274cd001b516ee1bbb
|
||||
size 4268298
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7daa6c306a2fb738bbe8b3d30324691c83d59aa933c79b0e48342976edb4e356
|
||||
size 25540884
|
||||
oid sha256:8c9bf8fde3bb12fcfa4830954bf1aefc0bd215b8a5a645b8dc0a713b8688c95c
|
||||
size 25712734
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f75b47945f8bb945a7086a0bcde038490ebfd2fbb406dfa0f3391f262cfac365
|
||||
size 1529360
|
||||
oid sha256:2180d92388c90b9ce4b357fd71dd97ba29c408a433f0251049f7ce1fc0bcc1ff
|
||||
size 1541172
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0790f83b79f8ff2a2313d238bdd409d8f082d92edf8e22e6dc75f6f5dfa8327d
|
||||
size 1553716
|
||||
oid sha256:bcafe4aac60dcbc4577d629e8736478577fe4d31894f875a17210a138a53e1b4
|
||||
size 1570164
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
bf15d213c14dcbe75d2116945bd24c82 libtensorrt_llm_executor_static.a
|
||||
492e0b37b7f004c5b7a7c46d079f354d libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
49402939d007b39393cabaa8fe96c110d16f5b35 commit
|
||||
e2cff2c22c9f06cf5fc380e46c3db75a libtensorrt_llm_executor_static.a
|
||||
b8c421b53aac0224619cf34de5d29057 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
d461cf4c2e17afd6a88d2a63e85823dec7a3aab1 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7f5fed27f812506b319a1275a6f00b71e3b8e3c0a8a2f71370b7c4673820306f
|
||||
size 1588916
|
||||
oid sha256:b3aac4dd334f191f39faad22dc2cf08efceb7976a564541bf10256f5753e41a9
|
||||
size 1602826
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c9e712e014960458ae1fbda43fcb882eb98f04f00c9e95afce2d881b29d2c5cf
|
||||
size 1517700
|
||||
oid sha256:6038eda2983dd0462b1207bc7fc190e4783b83e3df799df673962bc73e698522
|
||||
size 1530498
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e433610d288aa1533fd36c467fd67929fefec68043e486f45dd3a774a55667cd
|
||||
size 16515186
|
||||
oid sha256:9a21866ea77defc901454e49be3b35e9ccfc03cdaaa658a6d03bbea55b00e466
|
||||
size 16586574
|
||||
|
||||
@ -200,7 +200,15 @@ void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams)
|
||||
mKernelParams.cu_q_seqlens = reinterpret_cast<int const*>(runnerParams.cuQSeqLenPtr);
|
||||
mKernelParams.tile_id_counter_ptr = reinterpret_cast<uint32_t*>(runnerParams.tileCounterPtr);
|
||||
// TRT doesn't support host scales. Use device scales instead.
|
||||
mKernelParams.scale_bmm2_d = reinterpret_cast<uint32_t const*>(runnerParams.scaleBmm2Ptr);
|
||||
// The scaleBmm1Ptr offset.
|
||||
// 2 scales prepared for scaleBmm1 in the device memory: float scale, float (scale with log2e).
|
||||
int64_t scaleBmm1PtrOffset = (mLaunchParams.useBase2ExpTrick ? 1 : 0);
|
||||
// Only fp8 kernels need to load scales from the device memory.
|
||||
if (mFixedParams.dataType == DATA_TYPE_E4M3)
|
||||
{
|
||||
mKernelParams.scale_bmm1_d = reinterpret_cast<uint32_t const*>(runnerParams.scaleBmm1Ptr + scaleBmm1PtrOffset);
|
||||
mKernelParams.scale_bmm2_d = reinterpret_cast<uint32_t const*>(runnerParams.scaleBmm2Ptr);
|
||||
}
|
||||
|
||||
// Separate q and kv buffers may have different q and kv sequence lengths.
|
||||
if (mFixedParams.attentionInputLayout != AttentionInputLayout::PACKED_QKV)
|
||||
|
||||
@ -200,9 +200,9 @@ struct MHARunnerParams
|
||||
void const* cuMaskRowsPtr;
|
||||
// The dynamic scheduler tile counter.
|
||||
void* tileCounterPtr;
|
||||
// The bmm1 scale device ptr.
|
||||
uint32_t const* scaleBmm1Ptr;
|
||||
// The bmm2 scale device ptr.
|
||||
// The bmm1 scale device ptr (only used by fp8 kernels).
|
||||
float const* scaleBmm1Ptr;
|
||||
// The bmm2 scale device ptr (only used by fp8 kernels).
|
||||
float const* scaleBmm2Ptr;
|
||||
// The cuda stream.
|
||||
cudaStream_t stream;
|
||||
|
||||
@ -208,7 +208,6 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
|
||||
decoder_params.seqKVLengths = xqaParams.sequence_lengths;
|
||||
decoder_params.batchSize = int(batch_beam_size);
|
||||
decoder_params.maxQSeqLength = xqaParams.generation_input_length;
|
||||
decoder_params.removePadding = xqaParams.multi_query_tokens;
|
||||
TLLM_CHECK_WITH_INFO(!xqaParams.multi_query_tokens || xqaParams.spec_decoding_generation_lengths != nullptr,
|
||||
"Spec_decoding_generation_lengths must be provided.");
|
||||
// Rotary embedding inv_freq buffer.
|
||||
@ -233,12 +232,12 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
|
||||
(float2 const*) nullptr, xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets,
|
||||
int(batch_beam_size), xqaParams.generation_input_length, xqaParams.timestep,
|
||||
xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
|
||||
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length), xqaParams.num_q_heads,
|
||||
xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size,
|
||||
xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type,
|
||||
xqaParams.rotary_embedding_scale, xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type,
|
||||
xqaParams.position_shift_enabled, cache_type, true, false, multiprocessor_count, xqaParams.rotary_vision_start,
|
||||
xqaParams.rotary_vision_length};
|
||||
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
|
||||
/*remove_padding*/ true, xqaParams.num_q_heads, xqaParams.num_kv_heads,
|
||||
xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,
|
||||
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
|
||||
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type, xqaParams.position_shift_enabled,
|
||||
cache_type, true, false, multiprocessor_count, xqaParams.rotary_vision_start, xqaParams.rotary_vision_length};
|
||||
|
||||
invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParms, stream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:749eb5de849e31263fcb1c4a3982d6a852fd6c7002fbfa2cecf26718cc2ec8ae
|
||||
size 81591456
|
||||
oid sha256:38b241619f08ee636eee1d00a91d2fa2fc8a70f4afe1e12d01b180e6adeef7aa
|
||||
size 81578928
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
47b5d2e14616709b1dfb86b16213308e libtensorrt_llm_nvrtc_wrapper.so
|
||||
49402939d007b39393cabaa8fe96c110d16f5b35 commit
|
||||
4a08f099886e0595057a20115658be51 libtensorrt_llm_nvrtc_wrapper.so
|
||||
d461cf4c2e17afd6a88d2a63e85823dec7a3aab1 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fe774a929c614dcd0e99d4e8b721645f4d4c0729e7897c3fa87776973591ff3b
|
||||
size 84835264
|
||||
oid sha256:28e39c448442c950d41012ad964057d84c8afc51aa116bbee17ccacd76b43e9f
|
||||
size 84839528
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6b16e47ce5d366249f54bf1c5edb46841efa84be58de97d78539ac0ba4fc710b
|
||||
size 1127936
|
||||
oid sha256:1f650bc6f1c5c12d929948c12f3e67119db06e9289e6d58c76ff55b3c59db306
|
||||
size 1129472
|
||||
|
||||
@ -184,7 +184,6 @@ public:
|
||||
decoder_params.seqKVLengths = xqaParams.sequence_lengths;
|
||||
decoder_params.batchSize = int(batch_beam_size);
|
||||
decoder_params.maxQSeqLength = xqaParams.generation_input_length;
|
||||
decoder_params.removePadding = xqaParams.multi_query_tokens;
|
||||
TLLM_CHECK_WITH_INFO(!xqaParams.multi_query_tokens || xqaParams.spec_decoding_generation_lengths != nullptr,
|
||||
"Spec_decoding_generation_lengths must be provided.");
|
||||
// Rotary embedding inv_freq buffer.
|
||||
@ -209,12 +208,13 @@ public:
|
||||
(float2 const*) nullptr, xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets,
|
||||
int(batch_beam_size), xqaParams.generation_input_length, xqaParams.timestep,
|
||||
xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
|
||||
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length), xqaParams.num_q_heads,
|
||||
xqaParams.num_kv_heads, xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size,
|
||||
xqaParams.rotary_embedding_dim, xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type,
|
||||
xqaParams.rotary_embedding_scale, xqaParams.rotary_embedding_max_positions,
|
||||
xqaParams.position_embedding_type, xqaParams.position_shift_enabled, cache_type, true, false,
|
||||
multiprocessor_count, xqaParams.rotary_vision_start, xqaParams.rotary_vision_length};
|
||||
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
|
||||
/*remove_padding*/ true, xqaParams.num_q_heads, xqaParams.num_kv_heads,
|
||||
xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,
|
||||
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
|
||||
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type,
|
||||
xqaParams.position_shift_enabled, cache_type, true, false, multiprocessor_count,
|
||||
xqaParams.rotary_vision_start, xqaParams.rotary_vision_length};
|
||||
|
||||
invokeQKVPreprocessing<T, KVCacheBuffer>(preprocessingParms, stream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
@ -52,7 +52,7 @@ __global__ void curandBatchInitialize(
|
||||
SizeType32 const bid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (bid < size)
|
||||
{
|
||||
auto const batchSlot = batchSlots[bid];
|
||||
auto const batchSlot = batchSlots != nullptr ? batchSlots[bid] : bid;
|
||||
curand_init(randomSeeds[bid], 0, 0, &states[batchSlot]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -159,7 +159,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets
|
||||
// Store the result.
|
||||
if (batchIdx <= batchSizeBound && storeSeqOffsets)
|
||||
{
|
||||
params.seqQOffsets[batchIdx] = params.removePadding ? seqQOffset : batchIdx * params.maxQSeqLength;
|
||||
params.seqQOffsets[batchIdx] = seqQOffset;
|
||||
if (calculate_packed_mask_row_offsets)
|
||||
{
|
||||
params.packedMaskRowOffsets[batchIdx] = packedMaskRowOffset;
|
||||
@ -228,10 +228,28 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets
|
||||
}
|
||||
}
|
||||
|
||||
// Reset fmha tile counter to 0 before launching fmha kernels.
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0 && params.fmhaTileCounter != nullptr)
|
||||
// Prepare values for fmha.
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
params.fmhaTileCounter[0] = 0u;
|
||||
// Reset fmha tile counter to 0 before launching fmha kernels.
|
||||
if (params.fmhaTileCounter)
|
||||
{
|
||||
params.fmhaTileCounter[0] = 0u;
|
||||
}
|
||||
// Take the quantization scales into consideration.
|
||||
if (params.fmhaBmm1Scale)
|
||||
{
|
||||
// The scale after fmha bmm1.
|
||||
params.fmhaBmm1Scale[0] = params.dequantScaleQkv[0] * params.dequantScaleQkv[0] * params.fmhaHostBmm1Scale;
|
||||
// The scale prepared for log2 optimization.
|
||||
constexpr float kLog2e = 1.4426950408889634074f;
|
||||
params.fmhaBmm1Scale[1] = params.fmhaBmm1Scale[0] * kLog2e;
|
||||
}
|
||||
if (params.fmhaBmm2Scale)
|
||||
{
|
||||
// The scale after fmha bmm2.
|
||||
params.fmhaBmm2Scale[0] = params.quantScaleO[0] * params.dequantScaleQkv[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
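The new branch precomputes the FP8 FMHA scales from the quantization scales: `fmhaBmm1Scale[0] = dequantScaleQkv^2 * fmhaHostBmm1Scale`, a second copy premultiplied by log2(e) for the base-2 exponential trick, and `fmhaBmm2Scale[0] = quantScaleO * dequantScaleQkv`. A host-side paraphrase of that arithmetic, with made-up example scale values (in the kernel these land in device buffers):

```python
import math

def fmha_scales(dequant_scale_qkv: float, quant_scale_o: float, host_bmm1_scale: float):
    """Reproduce the scale math from the kernel above on the host, for illustration."""
    bmm1 = dequant_scale_qkv * dequant_scale_qkv * host_bmm1_scale
    bmm1_log2 = bmm1 * math.log2(math.e)   # kLog2e variant used when the base-2 exp trick is enabled
    bmm2 = quant_scale_o * dequant_scale_qkv
    return bmm1, bmm1_log2, bmm2

# Example with head size 128, where the default host scale is 1/sqrt(head_size).
print(fmha_scales(dequant_scale_qkv=0.5, quant_scale_o=2.0, host_bmm1_scale=1.0 / math.sqrt(128)))
```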
@ -122,14 +122,24 @@ struct BuildDecoderInfoParams
|
||||
// The fmha tile counter ptr (set to 0 before fmha).
|
||||
uint32_t* fmhaTileCounter;
|
||||
|
||||
// Scales for fmha only.
|
||||
// The scale to dequant Qkv input.
|
||||
float const* dequantScaleQkv;
|
||||
// The scale to quant O output.
|
||||
float const* quantScaleO;
|
||||
// The fmha bmm1 host scale (1.0f / sqrt(headSize) by default).
|
||||
float fmhaHostBmm1Scale;
|
||||
// The scale after fmha bmm1.
|
||||
float* fmhaBmm1Scale;
|
||||
// The scale after fmha bmm2.
|
||||
float* fmhaBmm2Scale;
|
||||
+
     // The number of sequences in the batch.
     int batchSize;
     // The maximum query length of a sequence for Decoder (max_input_length), N for ctx phase, 1 for gen phase.
     int maxQSeqLength;
     // The maximum query length of a sequence for Encoder, for cross attention (cross_qkv_length).
     int maxEncoderQSeqLength;
-    // Whether remove the input padding or not.
-    bool removePadding;
     // The kv cache capacity.
     // We will apply the limited_length_causal mask when there are not enough kv cache.
     int attentionWindowSize;
@@ -184,7 +194,6 @@ struct BuildDecoderInfoParams
         ss << "batchSize: " << batchSize << std::endl;
         ss << "maxQSeqLength: " << maxQSeqLength << std::endl;
         ss << "maxEncoderQSeqLength: " << maxEncoderQSeqLength << std::endl;
-        ss << "removePadding: " << std::boolalpha << removePadding << std::endl;
         ss << "attentionWindowSize: " << attentionWindowSize << std::endl;
         ss << "sinkTokenLength: " << sinkTokenLength << std::endl;
         ss << "numTokens: " << numTokens << std::endl;
|
||||
@ -132,8 +132,8 @@ __launch_bounds__(TPB) __global__
|
||||
|
||||
template <int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax, bool const* finished, float* output,
|
||||
int* indices, int* source_rows, int const num_experts, int const k, int const start_expert, int const end_expert,
|
||||
MOEExpertScaleNormalizationMode renorm_mode)
|
||||
int* indices, int* source_rows, int const num_experts, int const k, int const startk, int const endk,
|
||||
int const start_expert, int const end_expert, MOEExpertScaleNormalizationMode norm_mode)
|
||||
{
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, float>;
|
||||
@ -149,7 +149,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax
|
||||
float renorm_value = 0.0f;
|
||||
bool const row_is_active = finished ? !finished[block_row] : true;
|
||||
int64_t const thread_read_offset = blockIdx.x * num_experts;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
for (int k_idx = startk; k_idx < endk; ++k_idx)
|
||||
{
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
|
||||
@ -161,7 +161,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = inputs_after_softmax[idx];
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k)
|
||||
for (int prior_k = startk; prior_k < k_idx; ++prior_k)
|
||||
{
|
||||
int const prior_winning_expert = indices[k * block_row + prior_k];
|
||||
|
||||
@ -184,11 +184,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax
|
||||
|
||||
int64_t const idx = k * block_row + k_idx;
|
||||
output[idx] = result_kvp.value;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : (num_experts + expert);
|
||||
assert(indices[idx] >= 0);
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
|
||||
if (renorm_mode == MOEExpertScaleNormalizationMode::RENORMALIZE)
|
||||
if (norm_mode == MOEExpertScaleNormalizationMode::RENORMALIZE)
|
||||
{
|
||||
renorm_value += result_kvp.value;
|
||||
}
|
||||
@ -196,8 +196,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (renorm_mode == MOEExpertScaleNormalizationMode::RENORMALIZE && threadIdx.x == 0 && renorm_value != 0.f)
|
||||
if (norm_mode == MOEExpertScaleNormalizationMode::RENORMALIZE && threadIdx.x == 0 && renorm_value != 0.f)
|
||||
{
|
||||
assert(startk == 0 && endk == k);
|
||||
renorm_value = 1 / renorm_value;
|
||||
for (int k_idx = 0; k_idx < k; k_idx++)
|
||||
{
|
||||
@ -223,8 +224,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(float const* input, bool const* finished,
|
||||
float* output, int64_t const num_rows, int* indices, int* source_rows, int const k, int const start_expert,
|
||||
int const end_expert, MOEExpertScaleNormalizationMode renorm_mode)
|
||||
float* output, int64_t const num_rows, int* indices, int* source_rows, int const k, int const startk,
|
||||
int const endk, int const start_expert, int const end_expert, MOEExpertScaleNormalizationMode norm_mode)
|
||||
{
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||
@ -349,7 +350,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(fl
|
||||
|
||||
float renorm_value = 0.0f;
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
for (int k_idx = startk; k_idx < endk; ++k_idx)
|
||||
{
|
||||
// First, each thread does the local argmax
|
||||
float max_val = row_chunk[0];
|
||||
@ -400,18 +401,18 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(fl
|
||||
// single) thread per row of the input/output matrices.
|
||||
int64_t const idx = k * thread_row + k_idx;
|
||||
output[idx] = max_val;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : (NUM_EXPERTS + expert);
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
|
||||
// Accumulate renorm scalar
|
||||
if (renorm_mode == MOEExpertScaleNormalizationMode::RENORMALIZE)
|
||||
if (norm_mode == MOEExpertScaleNormalizationMode::RENORMALIZE)
|
||||
{
|
||||
renorm_value += max_val;
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
|
||||
if (k_idx + 1 < k)
|
||||
if (k_idx + 1 < endk)
|
||||
{
|
||||
int const ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||
int const thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||
@ -426,8 +427,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(fl
|
||||
}
|
||||
}
|
||||
|
||||
if (renorm_mode == MOEExpertScaleNormalizationMode::RENORMALIZE && thread_group_idx == 0 && renorm_value != 0.f)
|
||||
if (norm_mode == MOEExpertScaleNormalizationMode::RENORMALIZE && thread_group_idx == 0 && renorm_value != 0.f)
|
||||
{
|
||||
assert(startk == 0 && endk == k);
|
||||
renorm_value = 1 / renorm_value;
|
||||
for (int k_idx = 0; k_idx < k; k_idx++)
|
||||
{
|
||||
@ -454,8 +456,8 @@ struct TopkConstants
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB>
|
||||
void topkGatingSoftmaxLauncherHelper(float const* input, bool const* finished, float* output, int* indices,
|
||||
int* source_row, int64_t const num_rows, int const k, int const start_expert, int const end_expert,
|
||||
MOEExpertScaleNormalizationMode renorm_mode, cudaStream_t stream)
|
||||
int* source_row, int64_t const num_rows, int const k, int const startk, int const endk, int const start_expert,
|
||||
int const end_expert, MOEExpertScaleNormalizationMode norm_mode, cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
@ -468,13 +470,12 @@ void topkGatingSoftmaxLauncherHelper(float const* input, bool const* finished, f
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renorm_mode);
|
||||
input, finished, output, num_rows, indices, source_row, k, startk, endk, start_expert, end_expert, norm_mode);
|
||||
}
|
||||
|
||||
void topkGatingSoftmaxKernelLauncher(float const* input, bool const* finished, float* output,
|
||||
float* softmax_temp_output, int* indices, int* source_row, int64_t const num_rows, int const num_experts,
|
||||
int const k, int const start_expert, int const end_expert, MOEExpertScaleNormalizationMode renorm_mode,
|
||||
cudaStream_t stream)
|
||||
void topkGatingSoftmaxKernelLauncher(float const* input, float* output, float* softmax_temp_output, int* indices,
|
||||
int* source_row, int64_t const num_rows, int const num_experts, int const k, int const startk, int const endk,
|
||||
int const start_expert, int const end_expert, MOEExpertScaleNormalizationMode norm_mode, cudaStream_t stream)
|
||||
{
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
|
||||
@ -482,69 +483,143 @@ void topkGatingSoftmaxKernelLauncher(float const* input, bool const* finished, f
|
||||
{
|
||||
case 1:
|
||||
{
|
||||
topkGatingSoftmaxLauncherHelper<1, WARPS_PER_TB>(
|
||||
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, renorm_mode, stream);
|
||||
topkGatingSoftmaxLauncherHelper<1, WARPS_PER_TB>(input, nullptr, output, indices, source_row, num_rows, k,
|
||||
startk, endk, start_expert, end_expert, norm_mode, stream);
|
||||
break;
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
topkGatingSoftmaxLauncherHelper<2, WARPS_PER_TB>(
|
||||
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, renorm_mode, stream);
|
||||
topkGatingSoftmaxLauncherHelper<2, WARPS_PER_TB>(input, nullptr, output, indices, source_row, num_rows, k,
|
||||
startk, endk, start_expert, end_expert, norm_mode, stream);
|
||||
break;
|
||||
}
|
||||
case 4:
|
||||
{
|
||||
topkGatingSoftmaxLauncherHelper<4, WARPS_PER_TB>(
|
||||
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, renorm_mode, stream);
|
||||
topkGatingSoftmaxLauncherHelper<4, WARPS_PER_TB>(input, nullptr, output, indices, source_row, num_rows, k,
|
||||
startk, endk, start_expert, end_expert, norm_mode, stream);
|
||||
break;
|
||||
}
|
||||
case 8:
|
||||
{
|
||||
topkGatingSoftmaxLauncherHelper<8, WARPS_PER_TB>(
|
||||
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, renorm_mode, stream);
|
||||
topkGatingSoftmaxLauncherHelper<8, WARPS_PER_TB>(input, nullptr, output, indices, source_row, num_rows, k,
|
||||
startk, endk, start_expert, end_expert, norm_mode, stream);
|
||||
break;
|
||||
}
|
||||
case 16:
|
||||
{
|
||||
topkGatingSoftmaxLauncherHelper<16, WARPS_PER_TB>(
|
||||
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, renorm_mode, stream);
|
||||
topkGatingSoftmaxLauncherHelper<16, WARPS_PER_TB>(input, nullptr, output, indices, source_row, num_rows, k,
|
||||
startk, endk, start_expert, end_expert, norm_mode, stream);
|
||||
break;
|
||||
}
|
||||
case 32:
|
||||
{
|
||||
topkGatingSoftmaxLauncherHelper<32, WARPS_PER_TB>(
|
||||
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, renorm_mode, stream);
|
||||
topkGatingSoftmaxLauncherHelper<32, WARPS_PER_TB>(input, nullptr, output, indices, source_row, num_rows, k,
|
||||
startk, endk, start_expert, end_expert, norm_mode, stream);
|
||||
break;
|
||||
}
|
||||
case 64:
|
||||
{
|
||||
topkGatingSoftmaxLauncherHelper<64, WARPS_PER_TB>(
|
||||
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, renorm_mode, stream);
|
||||
topkGatingSoftmaxLauncherHelper<64, WARPS_PER_TB>(input, nullptr, output, indices, source_row, num_rows, k,
|
||||
startk, endk, start_expert, end_expert, norm_mode, stream);
|
||||
break;
|
||||
}
|
||||
case 128:
|
||||
{
|
||||
topkGatingSoftmaxLauncherHelper<128, WARPS_PER_TB>(
|
||||
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, renorm_mode, stream);
|
||||
topkGatingSoftmaxLauncherHelper<128, WARPS_PER_TB>(input, nullptr, output, indices, source_row, num_rows, k,
|
||||
startk, endk, start_expert, end_expert, norm_mode, stream);
|
||||
break;
|
||||
}
|
||||
case 256:
|
||||
{
|
||||
topkGatingSoftmaxLauncherHelper<256, WARPS_PER_TB>(
|
||||
input, finished, output, indices, source_row, num_rows, k, start_expert, end_expert, renorm_mode, stream);
|
||||
topkGatingSoftmaxLauncherHelper<256, WARPS_PER_TB>(input, nullptr, output, indices, source_row, num_rows, k,
|
||||
startk, endk, start_expert, end_expert, norm_mode, stream);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
static constexpr int TPB = 256;
|
||||
TLLM_CHECK(softmax_temp_output != nullptr);
|
||||
moeSoftmax<TPB><<<num_rows, TPB, 0, stream>>>(input, finished, softmax_temp_output, num_experts);
|
||||
moeTopK<TPB><<<num_rows, TPB, 0, stream>>>(softmax_temp_output, finished, output, indices, source_row,
|
||||
num_experts, k, start_expert, end_expert, renorm_mode);
|
||||
moeSoftmax<TPB><<<num_rows, TPB, 0, stream>>>(input, nullptr, softmax_temp_output, num_experts);
|
||||
moeTopK<TPB><<<num_rows, TPB, 0, stream>>>(softmax_temp_output, nullptr, output, indices, source_row,
|
||||
num_experts, k, startk, endk, start_expert, end_expert, norm_mode);
|
||||
}
|
||||
}
|
||||
}
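Note the new sentinel convention in the top-k kernels above: an expert that this expert-parallel rank does not own is stored as `num_experts + expert` rather than a flat `num_experts`, so downstream code (for example the sparse-mixer mask below) can still recover the original expert id. A small sketch of decoding that convention, assuming exactly the encoding shown above:

```cpp
// Decode the per-token expert index written by the top-k kernels:
// local experts are stored relative to start_expert, non-local ones are
// offset by num_experts so the original expert id is still recoverable.
struct DecodedExpert
{
    int expertId; // global expert id
    bool isLocal; // handled by this expert-parallel rank?
};

inline DecodedExpert decodeExpertIndex(int storedIndex, int numExperts, int startExpert)
{
    if (storedIndex >= numExperts)
    {
        return {storedIndex - numExperts, false};
    }
    return {storedIndex + startExpert, true};
}
```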
|
||||
|
||||
__global__ void sparseMixerMask(float const* input, float* output, int const* indices, int k_idx, int k, int num_tokens,
|
||||
int num_experts, int start_expert, float epsilon)
|
||||
{
|
||||
int token_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (token_idx >= num_tokens)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Mask out the largest value selected in the previous iteration
|
||||
int last_selected = (k_idx > 0) ? indices[k * token_idx + (k_idx - 1)] : INT_MIN;
|
||||
// Adjust the selected index to correct for the expert parallel transformation
|
||||
last_selected = last_selected >= num_experts ? last_selected - num_experts : last_selected + start_expert;
|
||||
|
||||
// Find the max value in the current row
|
||||
float max_val = -INFINITY;
|
||||
for (int i = 0; i < num_experts; ++i)
|
||||
{
|
||||
if (i != last_selected)
|
||||
{
|
||||
float const val = input[token_idx * num_experts + i];
|
||||
max_val = max(val, max_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Mask out any values that fail the condition '(max - value) / std::max(abs(value), max) > 2 * epsilon'
|
||||
for (int i = 0; i < num_experts; ++i)
|
||||
{
|
||||
float val = input[token_idx * num_experts + i];
|
||||
float mask = (max_val - val) / max(abs(val), max_val);
|
||||
bool mask_value = (mask > 2 * epsilon) || i == last_selected;
|
||||
output[token_idx * num_experts + i] = mask_value ? -INFINITY : val;
|
||||
}
|
||||
}
|
||||
|
||||
void sparseMixerTopkSoftmax(float const* input, float* output, float* mixer_temp_output, float* softmax_temp_output,
|
||||
int* indices, int* source_row, int64_t const num_rows, int const num_experts, int const k, int const start_expert,
|
||||
int const end_expert, float epsilon, cudaStream_t stream)
|
||||
{
|
||||
// TODO we need to update the sparseMixerMask() function to mask all previous experts instead of just the most
|
||||
// recent one.
|
||||
TLLM_CHECK_WITH_INFO(k <= 2, "Current sparse mixer only supports k <= 2");
|
||||
|
||||
// Each thread handles one token
|
||||
constexpr int threads_per_block = 256;
|
||||
int num_blocks = ceilDiv(num_rows, threads_per_block);
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
// Run softmax and topk in serial for each selection, recalculating the mask for each step
|
||||
sparseMixerMask<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
input, mixer_temp_output, indices, k_idx, k, num_rows, num_experts, start_expert, epsilon);
|
||||
|
||||
topkGatingSoftmaxKernelLauncher(mixer_temp_output, output, softmax_temp_output, indices, source_row, num_rows,
|
||||
num_experts, k, k_idx, k_idx + 1, start_expert, end_expert, MOEExpertScaleNormalizationMode::NONE, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void selectExpertsForTokens(float const* input, float* output, float* mixer_temp_output, float* softmax_temp_output,
|
||||
int* indices, int* source_row, int64_t const num_rows, int const num_experts, int const k, int const start_expert,
|
||||
int const end_expert, float mixer_epsilon, MOEExpertScaleNormalizationMode norm_mode, cudaStream_t stream)
|
||||
{
|
||||
if (norm_mode == MOEExpertScaleNormalizationMode::SPARSE_MIXER)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mixer_temp_output, "Sparse mixer output is null when running sparse mixer");
|
||||
sparseMixerTopkSoftmax(input, output, mixer_temp_output, softmax_temp_output, indices, source_row, num_rows,
|
||||
num_experts, k, start_expert, end_expert, mixer_epsilon, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
topkGatingSoftmaxKernelLauncher(input, output, softmax_temp_output, indices, source_row, num_rows, num_experts,
|
||||
k, 0, k, start_expert, end_expert, norm_mode, stream);
|
||||
}
|
||||
}
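`selectExpertsForTokens` routes `SPARSE_MIXER` through the iterative mask-then-top-1 path above, where the mask drops the previously chosen expert and any logit whose relative gap to the row maximum exceeds 2·epsilon. A scalar, host-side sketch of that masking rule for one token (an illustrative simplification of `sparseMixerMask`, not the kernel itself):

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Scalar sketch of the sparse-mixer mask for a single token's logits.
// Entries too far below the row maximum (relative gap > 2 * epsilon), and the
// previously selected expert, are pushed to -inf before the next softmax.
std::vector<float> sparseMixerMaskRow(std::vector<float> const& logits, int lastSelected, float epsilon)
{
    float maxVal = -std::numeric_limits<float>::infinity();
    for (size_t i = 0; i < logits.size(); ++i)
    {
        if (static_cast<int>(i) != lastSelected)
        {
            maxVal = std::max(maxVal, logits[i]);
        }
    }
    std::vector<float> masked(logits.size());
    for (size_t i = 0; i < logits.size(); ++i)
    {
        float const gap = (maxVal - logits[i]) / std::max(std::fabs(logits[i]), maxVal);
        bool const maskOut = (gap > 2.0f * epsilon) || static_cast<int>(i) == lastSelected;
        masked[i] = maskOut ? -std::numeric_limits<float>::infinity() : logits[i];
    }
    return masked;
}
```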
|
||||
|
||||
// ========================== CUB Sorting things ====================================
|
||||
CubKeyValueSorter::CubKeyValueSorter()
|
||||
: num_experts_(0)
|
||||
@ -552,21 +627,28 @@ CubKeyValueSorter::CubKeyValueSorter()
|
||||
{
|
||||
}
|
||||
|
||||
int CubKeyValueSorter::expertsToBits(int num_experts)
|
||||
{
|
||||
// Max value we represent is V = num_experts + (num_experts - 1) = 2 * num_experts - 1
|
||||
// The maximum number of bits is therefore floor(log2(V)) + 1
|
||||
return static_cast<int>(log2(2 * num_experts - 1)) + 1;
|
||||
}
|
||||
|
||||
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
|
||||
: num_experts_(num_experts)
|
||||
, num_bits_((int) log2(num_experts) + 1)
|
||||
, num_bits_(expertsToBits(num_experts))
|
||||
{
|
||||
}
|
||||
|
||||
void CubKeyValueSorter::updateNumExperts(int const num_experts)
|
||||
{
|
||||
num_experts_ = num_experts;
|
||||
num_bits_ = (int) log2(num_experts) + 1;
|
||||
num_bits_ = expertsToBits(num_experts);
|
||||
}
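Because a non-local expert is now encoded as `num_experts + expert`, radix-sort keys can reach `2 * num_experts - 1`, which is what `expertsToBits` accounts for. A quick standalone illustration of why the old `log2(num_experts) + 1` was no longer enough:

```cpp
#include <cmath>
#include <cstdio>

// Bits needed to radix-sort keys in [0, 2 * numExperts - 1].
int expertsToBitsSketch(int numExperts)
{
    return static_cast<int>(std::log2(2 * numExperts - 1)) + 1;
}

int main()
{
    // For a non-power-of-two count like 60, keys reach 119 and need 7 bits,
    // while the old log2(num_experts) + 1 formula only gave 6.
    for (int e : {2, 60, 64, 256})
    {
        std::printf("experts=%d -> bits=%d (max key %d)\n", e, expertsToBitsSketch(e), 2 * e - 1);
    }
    return 0;
}
```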
|
||||
|
||||
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts)
|
||||
{
|
||||
int num_bits = static_cast<int>(log2(num_experts)) + 1;
|
||||
int num_bits = expertsToBits(num_experts);
|
||||
size_t required_storage = 0;
|
||||
int* null_int = nullptr;
|
||||
cub::DeviceRadixSort::SortPairs(
|
||||
@ -1282,7 +1364,8 @@ void loraReorder(T* output, T const* lora_result, int64_t const* num_valid_token
|
||||
template <class T, class WeightType, class OutputType, class ScaleBiasType, class Enable>
|
||||
std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType, Enable>::getWorkspaceBufferSizes(
|
||||
int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts,
|
||||
int const num_experts_per_node, int const k, ActivationType activation_type, bool use_lora) const
|
||||
int const num_experts_per_node, int const k, ActivationType activation_type,
|
||||
MOEExpertScaleNormalizationMode norm_mode, bool use_lora) const
|
||||
{
|
||||
size_t const num_moe_inputs = k * num_rows;
|
||||
size_t const permuted_elems = num_moe_inputs * hidden_size;
|
||||
@ -1299,12 +1382,18 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType,
|
||||
// We need to have separate memory for these as we can no longer alias the output buffer for reuse
|
||||
glu_inter_elems = interbuf_elems;
|
||||
}
|
||||
size_t num_softmax_outs = 0;
|
||||
|
||||
bool using_hopper = moe_gemm_runner_.supportsHopperSpecialisation();
|
||||
|
||||
size_t const gemm_output_dtype = sizeof(UnfusedGemmOutputType);
|
||||
|
||||
size_t sparse_mixer_outs = 0;
|
||||
if (norm_mode == MOEExpertScaleNormalizationMode::SPARSE_MIXER)
|
||||
{
|
||||
sparse_mixer_outs = num_rows * num_experts;
|
||||
}
|
||||
|
||||
size_t num_softmax_outs = 0;
|
||||
bool const is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256)
|
||||
{
|
||||
@ -1316,6 +1405,7 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType,
|
||||
size_t const permuted_experts_size = num_moe_inputs * sizeof(int);
|
||||
size_t const permuted_data_size = permuted_elems * sizeof(T);
|
||||
size_t const expert_first_token_offset_size = (num_experts_per_node + 1) * sizeof(int64_t);
|
||||
size_t const sparse_mixer_out_size = sparse_mixer_outs * sizeof(float);
|
||||
size_t const softmax_out_size = num_softmax_outs * sizeof(float);
|
||||
size_t const permuted_scales_size = mayHaveFinalizeFused() ? num_moe_inputs * sizeof(float) : 0;
|
||||
size_t const glu_inter_size = glu_inter_elems * gemm_output_dtype; // May be an intermediate type for quantization
|
||||
@ -1359,6 +1449,7 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType,
|
||||
permuted_rows_size, //
|
||||
permuted_experts_size, //
|
||||
expert_first_token_offset_size, //
|
||||
sparse_mixer_out_size, //
|
||||
softmax_out_size, //
|
||||
permuted_scales_size, //
|
||||
sorter_size, //
|
||||
@ -1377,12 +1468,13 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType,
|
||||
template <class T, class WeightType, class OutputType, class ScaleBiasType, class Enable>
|
||||
size_t CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType, Enable>::getWorkspaceSize(int64_t const num_rows,
|
||||
int64_t const hidden_size, int64_t const inter_size, int const num_experts, int const k,
|
||||
ActivationType activation_type, MOEParallelismConfig parallelism_config, bool use_lora) const
|
||||
ActivationType activation_type, MOEExpertScaleNormalizationMode norm_mode, MOEParallelismConfig parallelism_config,
|
||||
bool use_lora) const
|
||||
{
|
||||
int const ep_size = parallelism_config.ep_size;
|
||||
TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of ep size");
|
||||
auto workspace = getWorkspaceBufferSizes(
|
||||
num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type, use_lora);
|
||||
num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type, norm_mode, use_lora);
|
||||
auto ws_size = tensorrt_llm::common::calculateTotalWorkspaceSize(workspace.data(), workspace.size());
|
||||
TLLM_LOG_DEBUG("Mixture Of Experts Plugin requires workspace of %2f MiB", ws_size / 1024.f / 1024.f);
|
||||
return ws_size;
|
||||
@ -1391,11 +1483,12 @@ size_t CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType, Enable>::get
|
||||
template <class T, class WeightType, class OutputType, class ScaleBiasType, class Enable>
|
||||
void CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType, Enable>::configureWsPtrs(char* ws_ptr,
|
||||
int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts,
|
||||
int const num_experts_per_node, int const k, ActivationType activation_type, bool use_lora)
|
||||
int const num_experts_per_node, int const k, ActivationType activation_type,
|
||||
MOEExpertScaleNormalizationMode norm_mode, bool use_lora)
|
||||
|
||||
{
|
||||
auto ws_sizes = getWorkspaceBufferSizes(
|
||||
num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, activation_type, use_lora);
|
||||
num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, activation_type, norm_mode, use_lora);
|
||||
|
||||
std::vector<int8_t*> ws_sliced{(int8_t*) ws_ptr};
|
||||
for (auto size : ws_sizes)
|
||||
@ -1410,20 +1503,26 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType, Enable>::confi
|
||||
|
||||
expert_first_token_offset_ = (int64_t*) ws_sliced[3];
|
||||
|
||||
sparse_mixer_out_ = nullptr;
|
||||
if (norm_mode == MOEExpertScaleNormalizationMode::SPARSE_MIXER)
|
||||
{
|
||||
sparse_mixer_out_ = (float*) ws_sliced[4];
|
||||
}
|
||||
|
||||
softmax_out_ = nullptr;
|
||||
bool const is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256)
|
||||
{
|
||||
softmax_out_ = (float*) ws_sliced[4];
|
||||
softmax_out_ = (float*) ws_sliced[5];
|
||||
}
|
||||
|
||||
bool const gemm2_using_hopper = moe_gemm_runner_.isHopperSpecialised(*gemm2_config_);
|
||||
permuted_scales_ = (gemm2_using_hopper && mayHaveFinalizeFused()) ? (float*) ws_sliced[5] : nullptr;
|
||||
permuted_scales_ = (gemm2_using_hopper && mayHaveFinalizeFused()) ? (float*) ws_sliced[6] : nullptr;
|
||||
|
||||
sorter_ws_ = (char*) ws_sliced[6];
|
||||
sorter_ws_ = (char*) ws_sliced[7];
|
||||
|
||||
// Always same index, but overlapped with either fc1_result_ or fc2_result_
|
||||
permuted_data_ = (T*) ws_sliced[7];
|
||||
permuted_data_ = (T*) ws_sliced[8];
|
||||
|
||||
bool const is_gated_activation = isGatedActivation(activation_type);
|
||||
bool const gemm1_using_fused_moe
|
||||
@ -1434,20 +1533,20 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType, Enable>::confi
|
||||
bool const non_hopper_has_glu = !gemm1_using_fused_moe && is_gated_activation;
|
||||
bool const has_glu_inter_result = hopper_has_glu || non_hopper_has_glu || use_fp8;
|
||||
// Always same index, ignored if not needed
|
||||
glu_inter_result_ = has_glu_inter_result ? (T*) ws_sliced[8] : nullptr;
|
||||
glu_inter_result_ = has_glu_inter_result ? (T*) ws_sliced[9] : nullptr;
|
||||
|
||||
// fc1 and fc2 alias one of the above pointers, but it depends on if actfn is fused/unfused which is overlapped
|
||||
// NOTE: It is important to get the order of these correct as the wrong order will cause the buffer to be used as an
|
||||
// input and output for the same gemm, which will cause corruption
|
||||
fc1_result_ = has_glu_inter_result ? (T*) ws_sliced[7] : (T*) ws_sliced[8];
|
||||
fc2_result_ = has_glu_inter_result ? (T*) ws_sliced[8] : (T*) ws_sliced[7];
|
||||
fc1_result_ = has_glu_inter_result ? (T*) ws_sliced[8] : (T*) ws_sliced[9];
|
||||
fc2_result_ = has_glu_inter_result ? (T*) ws_sliced[9] : (T*) ws_sliced[8];
|
||||
|
||||
alpha_scale_ptr_array_ = reinterpret_cast<float const**>(ws_sliced[9]);
|
||||
alpha_scale_ptr_array_ = reinterpret_cast<float const**>(ws_sliced[10]);
|
||||
|
||||
hopper_grouped_gemm_input_ = {};
|
||||
if (moe_gemm_runner_.supportsHopperSpecialisation())
|
||||
{
|
||||
hopper_grouped_gemm_input_.configureWorkspace(ws_sliced[10], num_experts_per_node, ws_sliced[11], ws_sizes[11]);
|
||||
hopper_grouped_gemm_input_.configureWorkspace(ws_sliced[11], num_experts_per_node, ws_sliced[12], ws_sizes[12]);
|
||||
}
|
||||
|
||||
lora_fc1_result_ = {};
|
||||
@ -1456,9 +1555,9 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType, Enable>::confi
|
||||
|
||||
if (use_lora)
|
||||
{
|
||||
lora_fc1_result_ = (T*) ws_sliced[12];
|
||||
lora_add_bias_ = (T*) ws_sliced[13];
|
||||
lora_fc2_result_ = (T*) ws_sliced[14];
|
||||
lora_fc1_result_ = (T*) ws_sliced[13];
|
||||
lora_add_bias_ = (T*) ws_sliced[14];
|
||||
lora_fc2_result_ = (T*) ws_sliced[15];
|
||||
}
|
||||
}
|
||||
|
||||
@ -1834,7 +1933,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType, Enable>::runMo
|
||||
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
int const num_experts, int const k, char* workspace_ptr, void* final_output_void, bool const* finished,
|
||||
int64_t const active_rows, void* token_topk_final_scales_void, int* expanded_source_row_to_expanded_dest_row,
|
||||
int* expert_for_source_row, MOEParallelismConfig parallelism_config,
|
||||
int* expert_for_source_row, float sparse_mixer_epsilon, MOEParallelismConfig parallelism_config,
|
||||
MOEExpertScaleNormalizationMode normalization_mode, bool use_lora, LoraParams& lora_params, cudaStream_t stream)
|
||||
{
|
||||
static constexpr bool int_scales_required
|
||||
@ -1931,14 +2030,14 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType, Enable>::runMo
|
||||
int const num_experts_per_node = num_experts / parallelism_config.ep_size;
|
||||
|
||||
configureWsPtrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k,
|
||||
fc1_activation_type, use_lora);
|
||||
fc1_activation_type, normalization_mode, use_lora);
|
||||
|
||||
int const start_expert = num_experts_per_node * parallelism_config.ep_rank;
|
||||
int const end_expert = start_expert + num_experts_per_node;
|
||||
|
||||
topkGatingSoftmaxKernelLauncher(gating_output, finished, token_topk_unpermuted_scales, softmax_out_,
|
||||
expert_for_source_row, source_rows_, num_rows, num_experts, k, start_expert, end_expert, normalization_mode,
|
||||
stream);
|
||||
selectExpertsForTokens(gating_output, token_topk_unpermuted_scales, sparse_mixer_out_, softmax_out_,
|
||||
expert_for_source_row, source_rows_, num_rows, num_experts, k, start_expert, end_expert, sparse_mixer_epsilon,
|
||||
normalization_mode, stream);
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
|
||||
@ -49,15 +49,18 @@ public:
|
||||
int* values_out, size_t const num_key_value_pairs, cudaStream_t stream);
|
||||
|
||||
private:
|
||||
static int expertsToBits(int experts);
|
||||
int num_experts_;
|
||||
int num_bits_;
|
||||
};
|
||||
|
||||
enum class MOEExpertScaleNormalizationMode : int
|
||||
{
|
||||
NONE = 0, //!< Run the softmax on all scales and select the topk
|
||||
RENORMALIZE, //!< Renormalize the selected scales so they sum to one. This is equivalent to only running softmax on
|
||||
//!< the topk selected experts
|
||||
NONE = 0, //!< Run the softmax on all scales and select the topk
|
||||
RENORMALIZE, //!< Renormalize the selected scales so they sum to one. This is equivalent to only running softmax on
|
||||
//!< the topk selected experts
|
||||
SPARSE_MIXER, //!< Uses the sparse mixer algorithm for selecting the routing probabilities @link
|
||||
//!< https://arxiv.org/abs/2310.00811
|
||||
};
|
||||
|
||||
/**
|
||||
@ -193,8 +196,8 @@ class CutlassMoeFCRunnerInterface
|
||||
public:
|
||||
virtual ~CutlassMoeFCRunnerInterface() = default;
|
||||
virtual size_t getWorkspaceSize(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
int const num_experts, int const k, ActivationType activation_type, MOEParallelismConfig parallelism_config,
|
||||
bool use_lora) const
|
||||
int const num_experts, int const k, ActivationType activation_type, MOEExpertScaleNormalizationMode norm_mode,
|
||||
MOEParallelismConfig parallelism_config, bool use_lora) const
|
||||
= 0;
|
||||
virtual void setTactic(std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config,
|
||||
std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config)
|
||||
@ -206,7 +209,7 @@ public:
|
||||
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
|
||||
int64_t const inter_size, int const num_experts, int const k, char* workspace_ptr, void* final_output,
|
||||
bool const* finished, int64_t const active_rows, void* token_topk_unpermuted_scales,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, float sparse_mixer_epsilon,
|
||||
MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, bool use_lora,
|
||||
LoraParams& lora_params, cudaStream_t stream)
|
||||
= 0;
|
||||
@ -274,8 +277,8 @@ public:
|
||||
std::is_same_v<T, WeightType> || !std::is_same_v<T, float>, "Does not support float with quantized weights");
|
||||
|
||||
size_t getWorkspaceSize(int64_t const num_rows, int64_t const hidden_size, int64_t const fc1_output_size,
|
||||
int const num_experts, int const k, ActivationType activation_type, MOEParallelismConfig parallelism_config,
|
||||
bool use_lora) const override;
|
||||
int const num_experts, int const k, ActivationType activation_type, MOEExpertScaleNormalizationMode norm_mode,
|
||||
MOEParallelismConfig parallelism_config, bool use_lora) const override;
|
||||
|
||||
void setTactic(std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config,
|
||||
std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config) override
|
||||
@ -300,7 +303,7 @@ public:
|
||||
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
|
||||
int64_t const inter_size, int const num_experts, int const k, char* workspace_ptr, void* final_output,
|
||||
bool const* finished, int64_t const active_rows, void* token_topk_unpermuted_scales,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, float sparse_mixer_epsilon,
|
||||
MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, bool use_lora,
|
||||
LoraParams& lora_params, cudaStream_t stream) override;
|
||||
|
||||
@ -378,10 +381,10 @@ private:
|
||||
cudaStream_t stream);
|
||||
std::vector<size_t> getWorkspaceBufferSizes(int64_t const num_rows, int64_t const hidden_size,
|
||||
int64_t const inter_size, int const num_experts, int const num_experts_per_node, int const k,
|
||||
ActivationType activation_type, bool use_lora) const;
|
||||
ActivationType activation_type, MOEExpertScaleNormalizationMode norm_mode, bool use_lora) const;
|
||||
void configureWsPtrs(char* ws_ptr, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
|
||||
int const num_experts, int const num_experts_per_node, int const k, ActivationType activation_type,
|
||||
bool use_lora);
|
||||
MOEExpertScaleNormalizationMode norm_mode, bool use_lora);
|
||||
|
||||
private:
|
||||
bool mayHaveDifferentGEMMOutputType() const
|
||||
@ -418,6 +421,7 @@ private:
|
||||
int* permuted_experts_{};
|
||||
char* sorter_ws_{};
|
||||
T* permuted_data_{};
|
||||
float* sparse_mixer_out_{};
|
||||
float* softmax_out_{};
|
||||
float* permuted_scales_{};
|
||||
|
||||
|
||||
@ -188,6 +188,15 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits,
|
||||
}
|
||||
}
|
||||
}
|
||||
// do clamp to prevent overflow
|
||||
if (logit > static_cast<float>(-MASK_VAL))
|
||||
{
|
||||
logit = static_cast<float>(-MASK_VAL);
|
||||
}
|
||||
else if (logit < static_cast<float>(MASK_VAL))
|
||||
{
|
||||
logit = static_cast<float>(MASK_VAL);
|
||||
}
|
||||
outLogitsPtr[index] = logit;
|
||||
}
|
||||
else
|
||||
|
||||
@ -23,6 +23,9 @@
|
||||
#else
|
||||
#include "3rdparty/cub/cub.cuh"
|
||||
#endif
|
||||
#if ENABLE_BF16
|
||||
#include <cuda_bf16.h>
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
@ -221,6 +224,10 @@ void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<T> const& par
|
||||
|
||||
template void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
|
||||
template void invokeFillContextBuffers(FillContextExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
|
||||
#if ENABLE_BF16
|
||||
template void invokeFillContextBuffers(
|
||||
FillContextExplicitDraftTokensParams<__nv_bfloat16> const& params, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
namespace
|
||||
{
|
||||
@ -265,6 +272,10 @@ void invokeFillRandData(FillRandDataExplicitDraftTokensParams<T> const& params,
|
||||
|
||||
template void invokeFillRandData(FillRandDataExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
|
||||
template void invokeFillRandData(FillRandDataExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
|
||||
#if ENABLE_BF16
|
||||
template void invokeFillRandData(
|
||||
FillRandDataExplicitDraftTokensParams<__nv_bfloat16> const& params, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
namespace
|
||||
{
|
||||
@ -404,6 +415,10 @@ template void invokeExtractExplicitDraftTokens(
|
||||
ExtractExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
|
||||
template void invokeExtractExplicitDraftTokens(
|
||||
ExtractExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
|
||||
#if ENABLE_BF16
|
||||
template void invokeExtractExplicitDraftTokens(
|
||||
ExtractExplicitDraftTokensParams<__nv_bfloat16> const& params, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
namespace
|
||||
{
|
||||
@ -475,6 +490,9 @@ void invokeCopyProbs(ExtractExplicitDraftTokensParams<T> const& params, cudaStre
|
||||
|
||||
template void invokeCopyProbs(ExtractExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
|
||||
template void invokeCopyProbs(ExtractExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
|
||||
#if ENABLE_BF16
|
||||
template void invokeCopyProbs(ExtractExplicitDraftTokensParams<__nv_bfloat16> const& params, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
namespace
|
||||
{
|
||||
@ -503,6 +521,10 @@ void invokePackGenerationLengths(PackExplicitDraftTokensParams<T> const& params,
|
||||
|
||||
template void invokePackGenerationLengths(PackExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
|
||||
template void invokePackGenerationLengths(PackExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
|
||||
#if ENABLE_BF16
|
||||
template void invokePackGenerationLengths(
|
||||
PackExplicitDraftTokensParams<__nv_bfloat16> const& params, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
namespace
|
||||
{
|
||||
@ -591,6 +613,10 @@ void invokePackExplicitDraftTokens(PackExplicitDraftTokensParams<T> const& param
|
||||
|
||||
template void invokePackExplicitDraftTokens(PackExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
|
||||
template void invokePackExplicitDraftTokens(PackExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
|
||||
#if ENABLE_BF16
|
||||
template void invokePackExplicitDraftTokens(
|
||||
PackExplicitDraftTokensParams<__nv_bfloat16> const& params, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T>
|
||||
void invokeCopyProbs(PackExplicitDraftTokensParams<T> const& params, cudaStream_t stream)
|
||||
@ -606,4 +632,8 @@ void invokeCopyProbs(PackExplicitDraftTokensParams<T> const& params, cudaStream_
|
||||
|
||||
template void invokeCopyProbs(PackExplicitDraftTokensParams<float> const& params, cudaStream_t stream);
|
||||
template void invokeCopyProbs(PackExplicitDraftTokensParams<half> const& params, cudaStream_t stream);
|
||||
#if ENABLE_BF16
|
||||
template void invokeCopyProbs(PackExplicitDraftTokensParams<__nv_bfloat16> const& params, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
} // namespace tensorrt_llm::kernels::speculative_decoding
|
||||
|
||||
@ -104,6 +104,7 @@ struct QKVPreprocessingParams
|
||||
int cyclic_kv_cache_len{0};
|
||||
int sink_token_len{0};
|
||||
int token_num{0};
|
||||
bool remove_padding{true};
|
||||
int head_num{0};
|
||||
int kv_head_num{0};
|
||||
int qheads_per_kv_head{0};
|
||||
|
||||
@ -373,7 +373,8 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
|
||||
{
|
||||
// The index of the token in the batch.
|
||||
int const global_token_idx = local_token_idx
|
||||
+ (variable_sequence_length ? params.cu_seq_lens[batch_idx] : batch_idx * params.max_input_seq_len);
|
||||
+ ((variable_sequence_length && params.remove_padding) ? params.cu_seq_lens[batch_idx]
|
||||
: batch_idx * params.max_input_seq_len);
|
||||
int const cache_seq_len = params.cache_seq_lens[batch_idx];
|
||||
int const actual_seq_len = variable_sequence_length ? params.seq_lens[batch_idx] : params.max_input_seq_len;
|
||||
// Chunked attention: takes past_kv_sequence_length into consideration.
|
||||
@ -688,8 +689,9 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
|
||||
local_token_idx < seq_len_loop_end; local_token_idx += TOKENS_PER_BLOCK * gridDim.x)
|
||||
{
|
||||
// The index of the token in the batch.
|
||||
int const global_token_offset
|
||||
= variable_sequence_length ? params.cu_seq_lens[batch_idx] : batch_idx * params.max_input_seq_len;
|
||||
int const global_token_offset = (variable_sequence_length && params.remove_padding)
|
||||
? params.cu_seq_lens[batch_idx]
|
||||
: batch_idx * params.max_input_seq_len;
|
||||
int const cache_seq_len = params.cache_seq_lens[batch_idx];
|
||||
int const actual_seq_len = variable_sequence_length ? params.seq_lens[batch_idx] : params.max_input_seq_len;
|
||||
// Chunked attention: takes past_kv_sequence_length into consideration.
|
||||
|
||||
@ -186,6 +186,7 @@ public:
|
||||
// Hack to init some data for the context phase in the setup.
|
||||
TensorPtr randomDataSample; // [maxBatchSize], on gpu
|
||||
TensorPtr temperatures; // [maxBatchSize], on gpu
|
||||
nvinfer1::DataType dtype; // [1], on cpu
|
||||
};
|
||||
|
||||
class DynamicDecodeSetupParams : public BaseSetupParams
|
||||
|
||||
@ -116,31 +116,28 @@ void ExplicitDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWid
|
||||
// Setup penalties.
|
||||
FillBuffers const fillBuffers{batchSize, mDecoderDomain.getBatchSize(), mBufferManager};
|
||||
|
||||
// Set decoder dtype to WAR the lack of bf16 support in decoder.
|
||||
if (!mDecoderDtype)
|
||||
{
|
||||
mDecoderDtype = setupParams->dtype;
|
||||
}
|
||||
|
||||
fillBuffers(setupParams->temperature, DefaultDecodingParams::getTemperature(), mTemperature, mTemperatureDevice,
|
||||
batchSlots, getLimitsPenalty(DecodingPenaltyType::Temperature), "temperature penalty");
|
||||
|
||||
fillContextBuffers(batchSize, batchSlots, *setupParams);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExplicitDraftTokensLayer<T>::fillContextBuffers(
|
||||
SizeType32 batchSize, BufferConstPtr batchSlots, ExplicitDraftTokensSetupParams const& setupParams)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
FillContextExplicitDraftTokensParams<T> params;
|
||||
params.randDataSample = bufferCast<T>(*setupParams.randomDataSample);
|
||||
params.outputTemperatures = bufferCast<T>(*setupParams.temperatures);
|
||||
params.inputTemperatures = bufferCastOrNull<float>(mTemperatureDevice);
|
||||
params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
|
||||
params.batchSlots = bufferCast<SizeType32>(*batchSlots);
|
||||
params.batchSize = batchSize;
|
||||
|
||||
params.checkParams();
|
||||
|
||||
invokeFillContextBuffers(params, getStream());
|
||||
// Dispatch context buffer fill
|
||||
if (mDecoderDtype == nvinfer1::DataType::kFLOAT)
|
||||
{
|
||||
fillContextBuffers<float>(batchSize, batchSlots, *setupParams);
|
||||
}
|
||||
else if (mDecoderDtype == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
fillContextBuffers<half>(batchSize, batchSlots, *setupParams);
|
||||
}
|
||||
else if (mDecoderDtype == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
fillContextBuffers<__nv_bfloat16>(batchSize, batchSlots, *setupParams);
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
@ -160,7 +157,18 @@ void ExplicitDraftTokensLayer<T>::forwardAsync(
|
||||
convertPackedMask(*outputs, *inputs);
|
||||
|
||||
// Slice output ids, pos ids, next draft tokens.
|
||||
splitInputDataToBatchSlots(*outputs, *inputs);
|
||||
if (mDecoderDtype == nvinfer1::DataType::kFLOAT)
|
||||
{
|
||||
splitInputDataToBatchSlots<float>(*outputs, *inputs);
|
||||
}
|
||||
else if (mDecoderDtype == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
splitInputDataToBatchSlots<half>(*outputs, *inputs);
|
||||
}
|
||||
else if (mDecoderDtype == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
splitInputDataToBatchSlots<__nv_bfloat16>(*outputs, *inputs);
|
||||
}
|
||||
|
||||
// Pack accepted paths for KV cache rewind.
|
||||
packAcceptedPaths(*outputs, *inputs);
|
||||
@ -174,6 +182,96 @@ size_t ExplicitDraftTokensLayer<T>::getWorkspaceSize() const noexcept
|
||||
return mWorkspaceDevice->getSizeInBytes();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename Dtype>
|
||||
void ExplicitDraftTokensLayer<T>::fillContextBuffers(
|
||||
SizeType32 batchSize, BufferConstPtr batchSlots, ExplicitDraftTokensSetupParams const& setupParams)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
FillContextExplicitDraftTokensParams<Dtype> params;
|
||||
params.randDataSample = bufferCast<Dtype>(*setupParams.randomDataSample);
|
||||
params.outputTemperatures = bufferCast<Dtype>(*setupParams.temperatures);
|
||||
params.inputTemperatures = bufferCastOrNull<float>(mTemperatureDevice);
|
||||
params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
|
||||
params.batchSlots = bufferCast<SizeType32>(*batchSlots);
|
||||
params.batchSize = batchSize;
|
||||
|
||||
params.checkParams();
|
||||
|
||||
invokeFillContextBuffers(params, getStream());
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename Dtype>
|
||||
void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
|
||||
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = inputs.localBatchSize;
|
||||
auto const maxSeqLen = outputs.outputIds->getDimension<-1>();
|
||||
|
||||
ExtractExplicitDraftTokensParams<Dtype> params;
|
||||
|
||||
params.outputIds = bufferCast<TokenIdType>(*outputs.outputIds);
|
||||
params.outputPositionIdsBase = bufferCast<SizeType32>(*outputs.positionIdsBase);
|
||||
params.outputPositionIds = bufferCast<SizeType32>(*outputs.nextDraftPosIds);
|
||||
params.outputNextDraftTokens = bufferCast<TokenIdType>(*outputs.nextDraftTokens);
|
||||
params.unpackedNextDraftTokens = bufferCast<TokenIdType>(*outputs.unpackedNextDraftTokens);
|
||||
params.unpackedNextDraftIndices = bufferCast<SizeType32>(*outputs.unpackedNextDraftIndices);
|
||||
params.acceptedLengths = bufferCast<SizeType32>(*outputs.numNewTokens.value());
|
||||
params.nextDraftLengths = bufferCast<SizeType32>(*outputs.nextDraftLengths);
|
||||
params.prevDraftLengths = bufferCast<SizeType32>(*outputs.prevDraftLengths);
|
||||
params.sequenceLengths = bufferCast<SizeType32>(*outputs.sequenceLength.value());
|
||||
params.randDataSample = bufferCast<Dtype>(*outputs.randomDataSample);
|
||||
params.randDataVerification = bufferCast<Dtype>(*outputs.randomDataValidation);
|
||||
params.outputDraftProbs = bufferCast<Dtype>(*outputs.nextDraftProbs);
|
||||
params.outputTemperatures = bufferCast<Dtype>(*outputs.temperatures);
|
||||
params.outputGenerationLengths = bufferCast<SizeType32>(*outputs.generationLengths);
|
||||
params.outputBestPathIndices = bufferCast<SizeType32>(*mBestPathIndicesSlots);
|
||||
params.outputLastDraftIndices = bufferCast<SizeType32>(*mLastDraftIndicesSlots);
|
||||
|
||||
params.batchSlots = bufferCast<SizeType32>(*inputs.seqSlots);
|
||||
params.nextDraftTokens = bufferCast<TokenIdType>(*inputs.nextDraftTokens);
|
||||
params.lastDraftTokens = bufferCast<TokenIdType>(*inputs.lastDraftTokens);
|
||||
params.inputUnpackedNextDraftIndices = bufferCast<SizeType32>(*inputs.nextDraftIndices);
|
||||
params.bestPathLengths = bufferCast<SizeType32>(*inputs.bestPathLengths);
|
||||
params.bestPathIndices = bufferCast<SizeType32>(*inputs.bestPathIndices);
|
||||
params.inputPositionIdsBase = bufferCast<SizeType32>(*inputs.positionIdsBase);
|
||||
params.packedPositionIds = bufferCast<SizeType32>(*inputs.packedPosIds);
|
||||
params.nextFlatTokens = bufferCast<TokenIdType>(*inputs.nextFlatTokens);
|
||||
params.nextDraftProbs = bufferCast<Dtype>(*inputs.nextDraftProbs);
|
||||
params.lastGenerationLengths = bufferCastOrNull<SizeType32>(inputs.lastGenerationLengths);
|
||||
params.generationLengthInclusiveSum = bufferCast<SizeType32>(*mGenerationLengthInclusiveSum);
|
||||
params.lastDraftIndices = bufferCast<SizeType32>(*inputs.lastDraftIndices);
|
||||
params.inputTemperatures = bufferCast<float>(*mTemperatureDevice);
|
||||
params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
|
||||
params.batchSize = batchSize;
|
||||
params.numPaths = mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths();
|
||||
params.maxPathLength = mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen();
|
||||
params.maxSeqLen = maxSeqLen;
|
||||
params.vocabSize = mDecoderDomain.getVocabSizePadded();
|
||||
params.numContextRequests = batchSize - inputs.lastDraftTokens->getDimension<0>();
|
||||
params.numGenerationRequests = inputs.lastDraftTokens->getDimension<0>();
|
||||
|
||||
params.checkParams();
|
||||
|
||||
// Copy max generation length
|
||||
mBufferManager->copy(*inputs.maxGenLengthDevice, *outputs.maxGenLengthHost);
|
||||
|
||||
invokeExtractExplicitDraftTokens(params, getStream());
|
||||
|
||||
invokeCopyProbs(params, getStream());
|
||||
|
||||
// Copy generation lengths
|
||||
mBufferManager->copy(*outputs.generationLengths, *outputs.generationLengthsHost);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExplicitDraftTokensLayer<T>::convertPackedMask(
|
||||
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
|
||||
@ -201,73 +299,6 @@ void ExplicitDraftTokensLayer<T>::convertPackedMask(
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExplicitDraftTokensLayer<T>::splitInputDataToBatchSlots(
|
||||
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = inputs.localBatchSize;
|
||||
auto const maxSeqLen = outputs.outputIds->getDimension<-1>();
|
||||
|
||||
ExtractExplicitDraftTokensParams<T> params;
|
||||
|
||||
params.outputIds = bufferCast<TokenIdType>(*outputs.outputIds);
|
||||
params.outputPositionIdsBase = bufferCast<SizeType32>(*outputs.positionIdsBase);
|
||||
params.outputPositionIds = bufferCast<SizeType32>(*outputs.nextDraftPosIds);
|
||||
params.outputNextDraftTokens = bufferCast<TokenIdType>(*outputs.nextDraftTokens);
|
||||
params.unpackedNextDraftTokens = bufferCast<TokenIdType>(*outputs.unpackedNextDraftTokens);
|
||||
params.unpackedNextDraftIndices = bufferCast<SizeType32>(*outputs.unpackedNextDraftIndices);
|
||||
params.acceptedLengths = bufferCast<SizeType32>(*outputs.numNewTokens.value());
|
||||
params.nextDraftLengths = bufferCast<SizeType32>(*outputs.nextDraftLengths);
|
||||
params.prevDraftLengths = bufferCast<SizeType32>(*outputs.prevDraftLengths);
|
||||
params.sequenceLengths = bufferCast<SizeType32>(*outputs.sequenceLength.value());
|
||||
params.randDataSample = bufferCast<T>(*outputs.randomDataSample);
|
||||
params.randDataVerification = bufferCast<T>(*outputs.randomDataValidation);
|
||||
params.outputDraftProbs = bufferCast<T>(*outputs.nextDraftProbs);
|
||||
params.outputTemperatures = bufferCast<T>(*outputs.temperatures);
|
||||
params.outputGenerationLengths = bufferCast<SizeType32>(*outputs.generationLengths);
|
||||
params.outputBestPathIndices = bufferCast<SizeType32>(*mBestPathIndicesSlots);
|
||||
params.outputLastDraftIndices = bufferCast<SizeType32>(*mLastDraftIndicesSlots);
|
||||
|
||||
params.batchSlots = bufferCast<SizeType32>(*inputs.seqSlots);
|
||||
params.nextDraftTokens = bufferCast<TokenIdType>(*inputs.nextDraftTokens);
|
||||
params.lastDraftTokens = bufferCast<TokenIdType>(*inputs.lastDraftTokens);
|
||||
params.inputUnpackedNextDraftIndices = bufferCast<SizeType32>(*inputs.nextDraftIndices);
|
||||
params.bestPathLengths = bufferCast<SizeType32>(*inputs.bestPathLengths);
|
||||
params.bestPathIndices = bufferCast<SizeType32>(*inputs.bestPathIndices);
|
||||
params.inputPositionIdsBase = bufferCast<SizeType32>(*inputs.positionIdsBase);
|
||||
params.packedPositionIds = bufferCast<SizeType32>(*inputs.packedPosIds);
|
||||
params.nextFlatTokens = bufferCast<TokenIdType>(*inputs.nextFlatTokens);
|
||||
params.nextDraftProbs = bufferCast<T>(*inputs.nextDraftProbs);
|
||||
params.lastGenerationLengths = bufferCastOrNull<SizeType32>(inputs.lastGenerationLengths);
|
||||
params.generationLengthInclusiveSum = bufferCast<SizeType32>(*mGenerationLengthInclusiveSum);
|
||||
params.lastDraftIndices = bufferCast<SizeType32>(*inputs.lastDraftIndices);
|
||||
params.inputTemperatures = bufferCast<float>(*mTemperatureDevice);
|
||||
params.curandState = reinterpret_cast<curandState_t*>(bufferCastOrNull<int8_t>(mCurandStatesDevice));
|
||||
params.batchSize = batchSize;
|
||||
params.numPaths = mDecoderDomain.getSpeculativeDecodingModule()->getMaxNumPaths();
|
||||
params.maxPathLength = mDecoderDomain.getSpeculativeDecodingModule()->getMaxPathLen();
|
||||
params.maxSeqLen = maxSeqLen;
|
||||
params.vocabSize = mDecoderDomain.getVocabSizePadded();
|
||||
params.numContextRequests = batchSize - inputs.lastDraftTokens->getDimension<0>();
|
||||
params.numGenerationRequests = inputs.lastDraftTokens->getDimension<0>();
|
||||
|
||||
params.checkParams();
|
||||
|
||||
// Copy max generation length
|
||||
mBufferManager->copy(*inputs.maxGenLengthDevice, *outputs.maxGenLengthHost);
|
||||
|
||||
invokeExtractExplicitDraftTokens(params, getStream());
|
||||
|
||||
invokeCopyProbs(params, getStream());
|
||||
|
||||
// Copy generation lengths
|
||||
mBufferManager->copy(*outputs.generationLengths, *outputs.generationLengthsHost);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExplicitDraftTokensLayer<T>::packAcceptedPaths(
|
||||
ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs)
|
||||
|
||||
@ -48,15 +48,17 @@ public:
|
||||
private:
|
||||
void allocateBuffer();
|
||||
|
||||
void fillContextBuffers(
|
||||
SizeType32 batchSize, BufferConstPtr batchSlots, ExplicitDraftTokensSetupParams const& params);
|
||||
|
||||
void convertPackedMask(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);
|
||||
|
||||
void splitInputDataToBatchSlots(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);
|
||||
|
||||
void packAcceptedPaths(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);
|
||||
|
||||
template <typename Dtype>
|
||||
void fillContextBuffers(
|
||||
SizeType32 batchSize, BufferConstPtr batchSlots, ExplicitDraftTokensSetupParams const& setupParams);
|
||||
|
||||
template <typename Dtype>
|
||||
void splitInputDataToBatchSlots(ExplicitDraftTokensOutputs const& outputs, ExplicitDraftTokensInputs const& inputs);
|
||||
|
||||
private:
|
||||
using Base::mDecoderDomain;
|
||||
|
||||
@ -76,6 +78,8 @@ private:
|
||||
TensorPtr mLastDraftIndicesSlots;
|
||||
|
||||
TensorPtr mTemperature;
|
||||
|
||||
std::optional<nvinfer1::DataType> mDecoderDtype{std::nullopt};
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
|
||||
@ -271,7 +271,6 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
params.batchSize = batch_size;
|
||||
params.maxQSeqLength = input_seq_len;
|
||||
params.numTokens = num_tokens;
|
||||
params.removePadding = mRemovePadding;
|
||||
params.attentionMaskType = AttentionMaskType::PADDING;
|
||||
params.fmhaTileCounter = fmha_tile_counter_ptr;
|
||||
invokeBuildDecoderInfo(params, stream);
|
||||
|
||||
@ -19,6 +19,8 @@
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
|
||||
namespace tensorrt_llm::plugins::utils
|
||||
{
|
||||
using DimType64 = int64_t;
|
||||
@ -63,9 +65,14 @@ inline DimType64 computeNDimension(bool transB, nvinfer1::Dims const& dims)
|
||||
return N;
|
||||
}
|
||||
|
||||
inline std::int32_t logErrorReturn0(char const* variable)
|
||||
{
|
||||
TLLM_LOG_ERROR("Value of %s is out of range for int32_t", variable);
|
||||
return 0;
|
||||
}
|
||||
|
||||
#define TLLM_INT32_CAST(value) \
|
||||
((value > 0x7FFFFFFFLL || value < -0x80000000LL) \
|
||||
? (TLLM_LOG_ERROR("Value of " #value " is out of range for int32_t"), 0) \
|
||||
: static_cast<int32_t>(value))
|
||||
((value > 0x7FFFFFFFLL || value < -0x80000000LL) ? tensorrt_llm::plugins::utils::logErrorReturn0(#value) \
|
||||
: static_cast<int32_t>(value))
|
||||
|
||||
} // namespace tensorrt_llm::plugins::utils
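Replacing the comma expression with `logErrorReturn0` keeps both arms of the ternary the same type and reuses the stringified argument. A hedged, standalone sketch of how such a checked narrowing cast behaves (a reimplementation for illustration, not the library macro itself):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative stand-in: log and return 0 when a 64-bit value does not fit
// into int32_t, otherwise narrow it. Mirrors the TLLM_INT32_CAST intent.
inline std::int32_t logErrorReturn0Sketch(char const* what)
{
    std::fprintf(stderr, "Value of %s is out of range for int32_t\n", what);
    return 0;
}

#define INT32_CAST_SKETCH(value)                                               \
    (((value) > 0x7FFFFFFFLL || (value) < -0x80000000LL)                       \
            ? logErrorReturn0Sketch(#value)                                    \
            : static_cast<std::int32_t>(value))

int main()
{
    long long ok = 42;
    long long tooBig = 1LL << 40;
    std::printf("%d %d\n", INT32_CAST_SKETCH(ok), INT32_CAST_SKETCH(tooBig)); // prints: 42 0
    return 0;
}
```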
|
||||
|
||||
@ -620,8 +620,10 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t
|
||||
size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
|
||||
size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
|
||||
size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
|
||||
size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
|
||||
size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
|
||||
|
||||
int const NUM_BUFFERS = 16;
|
||||
int const NUM_BUFFERS = 18;
|
||||
size_t workspaces[NUM_BUFFERS];
|
||||
workspaces[0] = CUBLAS_WORKSPACE_SIZE;
|
||||
workspaces[1] = attention_mask_size;
|
||||
@ -639,6 +641,8 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t
|
||||
workspaces[13] = padding_offset_size;
|
||||
workspaces[14] = encoder_padding_offset_size;
|
||||
workspaces[15] = fmha_scheduler_counter;
|
||||
workspaces[16] = fmha_bmm1_scale_size;
|
||||
workspaces[17] = fmha_bmm2_scale_size;
|
||||
context_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
|
||||
|
||||
return context_workspace_size;
|
||||
@ -802,6 +806,8 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
|
||||
size_t const encoder_padding_offset_size
|
||||
= mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.cross_qkv_length;
|
||||
size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
|
||||
size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
|
||||
size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
|
||||
|
||||
bool const is_qk_buf_float_ = true;
|
||||
|
||||
@ -831,6 +837,10 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
|
||||
: reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, encoder_padding_offset_size));
|
||||
uint32_t* fmha_tile_counter_ptr
|
||||
= reinterpret_cast<uint32_t*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_scheduler_counter));
|
||||
float* fmha_bmm1_scale_ptr
|
||||
= reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_bmm1_scale_size));
|
||||
float* fmha_bmm2_scale_ptr
|
||||
= reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, fmha_bmm2_scale_size));
|
||||
|
||||
// build attention_mask, cu_seqlens, and padding_offset tensors
|
||||
// Note: self attn and cross attn should use different params
|
||||
@ -852,13 +862,17 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
|
||||
decoder_params.maxQSeqLength = params.input_seq_length;
|
||||
decoder_params.maxEncoderQSeqLength
|
||||
= isCrossAttention() ? params.cross_qkv_length : 0; // cross attention uses encoder seq length
|
||||
decoder_params.removePadding = mRemovePadding;
|
||||
decoder_params.attentionWindowSize = params.cyclic_attention_window_size;
|
||||
decoder_params.sinkTokenLength = params.sink_token_length;
|
||||
decoder_params.numTokens = params.num_tokens;
|
||||
decoder_params.attentionMaskType = mMaskType;
|
||||
decoder_params.blockSparseParams = mBlockSparseParams;
|
||||
decoder_params.fmhaTileCounter = fmha_tile_counter_ptr;
|
||||
decoder_params.quantScaleO = params.attention_output_orig_quant;
|
||||
decoder_params.dequantScaleQkv = params.kv_scale_quant_orig;
|
||||
decoder_params.fmhaHostBmm1Scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling);
|
||||
decoder_params.fmhaBmm1Scale = fmha_bmm1_scale_ptr;
|
||||
decoder_params.fmhaBmm2Scale = fmha_bmm2_scale_ptr;
|
||||
// Rotary embedding inv_freq buffer.
|
||||
decoder_params.rotaryEmbeddingScale = mRotaryEmbeddingScale;
|
||||
decoder_params.rotaryEmbeddingBase = mRotaryEmbeddingBase;
|
||||
@ -900,6 +914,12 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
|
||||
stream);
|
||||
}
|
||||
|
||||
// FIXME: a temporary solution to make sure the padding part is 0.
|
||||
if (!mRemovePadding)
|
||||
{
|
||||
cudaMemsetAsync(params.context_buf, 0, params.num_tokens * local_hidden_units_qo * sizeof(T), stream);
|
||||
}
|
||||
|
||||
KvCacheDataType const cache_type = mKVCacheQuantMode.hasInt8KvCache()
|
||||
? KvCacheDataType::INT8
|
||||
: (mKVCacheQuantMode.hasFp8KvCache() ? KvCacheDataType::FP8 : KvCacheDataType::BASE);
|
||||
@ -953,6 +973,7 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
|
||||
preprocessingParams.cyclic_kv_cache_len = params.cyclic_attention_window_size;
|
||||
preprocessingParams.sink_token_len = params.sink_token_length;
|
||||
preprocessingParams.token_num = params.num_tokens;
|
||||
preprocessingParams.remove_padding = mRemovePadding;
|
||||
preprocessingParams.head_num = mNumHeads;
|
||||
preprocessingParams.kv_head_num = mNumKVHeads;
|
||||
preprocessingParams.qheads_per_kv_head = mNumHeads / mNumKVHeads;
|
||||
@ -1032,7 +1053,8 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
|
||||
fmhaParams.cuKvSeqLenPtr = cu_kv_seqlens;
|
||||
fmhaParams.cuMaskRowsPtr = cu_mask_rows;
|
||||
fmhaParams.tileCounterPtr = fmha_tile_counter_ptr;
|
||||
fmhaParams.scaleBmm2Ptr = params.attention_output_orig_quant;
|
||||
fmhaParams.scaleBmm1Ptr = fmha_bmm1_scale_ptr;
|
||||
fmhaParams.scaleBmm2Ptr = fmha_bmm2_scale_ptr;
|
||||
fmhaParams.stream = stream;
|
||||
fmhaParams.forceFp32Acc = mFMHAForceFP32Acc;
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(bool remove_input_padding, int nu
|
||||
int expert_hidden_size, int expert_inter_size, tensorrt_llm::ActivationType activation_type,
|
||||
nvinfer1::DataType type, nvinfer1::DataType weight_type, nvinfer1::DataType output_type, QuantMode quant_mode,
|
||||
bool use_finished, bool use_bias, int tp_size, int tp_rank, int ep_size, int ep_rank,
|
||||
MOEExpertScaleNormalizationMode normalization_mode, bool force_determinism,
|
||||
MOEExpertScaleNormalizationMode normalization_mode, float sparse_mixer_epsilon, bool force_determinism,
|
||||
MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora, nvinfer1::DataType lora_type,
|
||||
LoraPluginProfilerPtr lora_profiler, int max_low_rank)
|
||||
: mRemoveInputPadding(remove_input_padding)
|
||||
@ -58,6 +58,7 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(bool remove_input_padding, int nu
|
||||
, mUseBias(use_bias)
|
||||
, mParallelismConfig(MOEParallelismConfig{tp_size, tp_rank, ep_size, ep_rank})
|
||||
, mNormalizationMode(normalization_mode)
|
||||
, mSparseMixerEpsilon(sparse_mixer_epsilon)
|
||||
, mUseDeterministicKernels(force_determinism)
|
||||
, mGemmProfiler(std::move(gemm_profiler_ptr))
|
||||
, mUseLora(use_lora)
|
||||
@ -87,6 +88,7 @@ tensorrt_llm::plugins::MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(MixtureOfE
|
||||
, mDims(other.mDims)
|
||||
, mGemmId1(other.mGemmId1)
|
||||
, mGemmId2(other.mGemmId2)
|
||||
, mSparseMixerEpsilon(other.mSparseMixerEpsilon)
|
||||
, mUseDeterministicKernels(other.mUseDeterministicKernels)
|
||||
, mGemmProfiler(other.mGemmProfiler)
|
||||
, mUseLora(other.mUseLora)
|
||||
@ -108,7 +110,7 @@ size_t MixtureOfExpertsPlugin::getSerializationSize() const noexcept
|
||||
size_t size = sizeof(mRemoveInputPadding) + sizeof(mNumExperts) + sizeof(mK) + sizeof(mExpertHiddenSize)
|
||||
+ sizeof(mExpertInterSize) + sizeof(mActivationType) + sizeof(mType) + sizeof(mWeightType) + sizeof(mOutputType)
|
||||
+ sizeof(QuantMode::BaseType) + sizeof(mUseFinished) + sizeof(mUseBias) + sizeof(mParallelismConfig)
|
||||
+ sizeof(mNormalizationMode) + sizeof(mUseDeterministicKernels) + sizeof(mDims)
|
||||
+ sizeof(mNormalizationMode) + sizeof(mSparseMixerEpsilon) + sizeof(mDims) + sizeof(mUseDeterministicKernels)
|
||||
+ mGemmProfiler->getSerializationSize(mGemmId1) + mGemmProfiler->getSerializationSize(mGemmId2)
|
||||
+ sizeof(mUseLora) + sizeof(mLoraType) + sizeof(mMaxLowRank);
|
||||
|
||||
@ -144,6 +146,7 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(void const* data, size_t length,
|
||||
read(d, mUseBias);
|
||||
read(d, mParallelismConfig);
|
||||
read(d, mNormalizationMode);
|
||||
read(d, mSparseMixerEpsilon);
|
||||
read(d, mDims);
|
||||
read(d, mUseDeterministicKernels);
|
||||
read(d, mUseLora);
|
||||
@ -187,6 +190,7 @@ void MixtureOfExpertsPlugin::serialize(void* buffer) const noexcept
|
||||
write(d, mUseBias);
|
||||
write(d, mParallelismConfig);
|
||||
write(d, mNormalizationMode);
|
||||
write(d, mSparseMixerEpsilon);
|
||||
write(d, mDims);
|
||||
write(d, mUseDeterministicKernels);
|
||||
write(d, mUseLora);
|
||||
@ -449,7 +453,7 @@ auto MixtureOfExpertsPlugin::setupWorkspace(void* base_ptr, int64_t num_tokens,
|
||||
size_t dtype_size = tensorrt_llm::common::getDTypeSize(mType);
|
||||
|
||||
size_t moe_workspace_size = mMOERunner->getWorkspaceSize(num_tokens, mExpertHiddenSize, mExpertInterSize,
|
||||
mNumExperts, mK, mActivationType, mParallelismConfig, mUseLora);
|
||||
mNumExperts, mK, mActivationType, mNormalizationMode, mParallelismConfig, mUseLora);
|
||||
|
||||
// Output of post-softmax routing probabilities
|
||||
size_t scale_probabilities_size = num_tokens * mNumExperts * sizeof(float);
|
||||
@ -591,7 +595,7 @@ LoraParams MixtureOfExpertsPlugin::getLoraParams(
|
||||
int idx = 0;
|
||||
for (int req_id = 0; req_id < num_reqs; req_id++)
|
||||
{
|
||||
const RequestType reqType = static_cast<RequestType const>(req_types[req_id]);
|
||||
RequestType const reqType = static_cast<RequestType const>(req_types[req_id]);
|
||||
if (reqType == RequestType::kGENERATION)
|
||||
{
|
||||
mLoraExpandFC1WeightPtrs.push_back(fc1_lora_weight_ptrs[req_id * 2]);
|
||||
@ -719,8 +723,8 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
outputs[getOutputTensorIndex()],
|
||||
hasFinishedTensor() ? static_cast<bool const*>(inputs[getFinishedTensorIndex()]) : nullptr, num_not_finished,
|
||||
workspace.scale_probs, static_cast<int*>(workspace.src_to_dest_map),
|
||||
static_cast<int*>(workspace.selected_experts), mParallelismConfig, mNormalizationMode, mUseLora, lora_params,
|
||||
stream);
|
||||
static_cast<int*>(workspace.selected_experts), mSparseMixerEpsilon, mParallelismConfig, mNormalizationMode,
|
||||
mUseLora, lora_params, stream);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@ -828,6 +832,8 @@ MixtureOfExpertsPluginCreator::MixtureOfExpertsPluginCreator()
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("ep_rank", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("normalization_mode", nullptr, PluginFieldType::kINT32,
|
||||
static_cast<int>(MOEExpertScaleNormalizationMode::NONE)));
|
||||
mPluginAttributes.emplace_back(
|
||||
nvinfer1::PluginField("sparse_mixer_epsilon", nullptr, PluginFieldType::kFLOAT32, 0));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("use_lora", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("lora_type_id", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("max_low_rank", nullptr, PluginFieldType::kINT32, 0));
|
||||
@ -861,6 +867,8 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
|
||||
int mLoraType{INT_MAX};
|
||||
int mMaxLowRank{0};
|
||||
|
||||
float mSparseMixerEpsilon = -INFINITY;
|
||||
|
||||
// Read configurations from each fields
|
||||
struct MapPair
|
||||
{
|
||||
@ -908,6 +916,13 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
|
||||
item.set = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Non integer inputs
|
||||
if (!strcmp(attrName, "sparse_mixer_epsilon"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
|
||||
mSparseMixerEpsilon = *static_cast<float const*>(fields[i].data);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& item : input_map)
|
||||
@ -927,6 +942,13 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
|
||||
"MoE fuse lora, lora_type_id and max_low_rank are required but not set");
|
||||
}
|
||||
|
||||
if (static_cast<MOEExpertScaleNormalizationMode>(mNormalizationMode)
|
||||
== MOEExpertScaleNormalizationMode::SPARSE_MIXER)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mSparseMixerEpsilon > 0, "sparse_mixer_epsilon must be set when normalization mode is SPARSE_MIXER");
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
auto gemmProfiler = moePluginProfiler.createGemmPluginProfiler(/* inference */ false);
|
||||
@ -937,8 +959,9 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
|
||||
static_cast<tensorrt_llm::ActivationType>(mActivationType), static_cast<nvinfer1::DataType>(mType),
|
||||
static_cast<nvinfer1::DataType>(mWeightType), static_cast<nvinfer1::DataType>(mOutputType),
|
||||
QuantMode(mQuantMode), mUseFinished != 0, mUseBias != 0, mTPSize, mTPRank, mEPSize, mEPRank,
|
||||
static_cast<MOEExpertScaleNormalizationMode>(mNormalizationMode), mRequiresDeterminism != 0, gemmProfiler,
|
||||
mUseLora != 0, static_cast<nvinfer1::DataType>(mLoraType), loraProfiler, mMaxLowRank);
|
||||
static_cast<MOEExpertScaleNormalizationMode>(mNormalizationMode), mSparseMixerEpsilon,
|
||||
mRequiresDeterminism != 0, gemmProfiler, mUseLora != 0, static_cast<nvinfer1::DataType>(mLoraType),
|
||||
loraProfiler, mMaxLowRank);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
|
||||
@ -105,7 +105,7 @@ public:
|
||||
int expert_inter_size, tensorrt_llm::ActivationType activation_type, nvinfer1::DataType type,
|
||||
nvinfer1::DataType weight_type, nvinfer1::DataType output_type, tensorrt_llm::common::QuantMode quant_mode,
|
||||
bool use_finished, bool use_bias, int tp_size, int tp_rank, int ep_size, int ep_rank,
|
||||
MOEExpertScaleNormalizationMode normalization_mode, bool force_determinism,
|
||||
MOEExpertScaleNormalizationMode normalization_mode, float sparse_mixer_epsilon, bool force_determinism,
|
||||
MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora, nvinfer1::DataType lora_type,
|
||||
LoraPluginProfilerPtr lora_profiler, int max_low_rank);
|
||||
MixtureOfExpertsPlugin(void const* data, size_t length, MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr,
|
||||
@ -166,6 +166,7 @@ private:
|
||||
bool mUseBias{};
|
||||
MOEParallelismConfig mParallelismConfig{};
|
||||
MOEExpertScaleNormalizationMode mNormalizationMode{};
|
||||
float mSparseMixerEpsilon = false;
|
||||
|
||||
GemmDims mDims{};
|
||||
bool mUseDeterministicKernels = false;
|
||||
|
||||
@ -250,7 +250,8 @@ void InitBindings(pybind11::module_& m)
|
||||
std::optional<std::list<VecTokens>>, std::optional<std::list<VecTokens>>, std::optional<Tensor>,
|
||||
std::optional<tle::ExternalDraftTokensConfig>, std::optional<tle::PromptTuningConfig>,
|
||||
std::optional<tle::LoraConfig>, std::optional<tle::LookaheadDecodingConfig>,
|
||||
std::optional<std::string>, std::optional<VecTokens>, std::optional<IdType>, bool>(),
|
||||
std::optional<std::string>, std::optional<VecTokens>, std::optional<IdType>, bool, tle::PriorityType,
|
||||
std::optional<tle::ContextPhaseParams>, std::optional<Tensor>, std::optional<SizeType32>>(),
|
||||
py::arg("input_token_ids"), py::arg("max_new_tokens"), py::arg("streaming") = false,
|
||||
py::arg_v("sampling_config", tle::SamplingConfig(), "SamplingConfig()"),
|
||||
py::arg_v("output_config", tle::OutputConfig(), "OutputConfig()"), py::arg("end_id") = py::none(),
|
||||
@ -259,7 +260,9 @@ void InitBindings(pybind11::module_& m)
|
||||
py::arg("prompt_tuning_config") = py::none(), py::arg("lora_config") = py::none(),
|
||||
py::arg("lookahead_config") = py::none(), py::arg("logits_post_processor_name") = py::none(),
|
||||
py::arg("encoder_input_token_ids") = py::none(), py::arg("client_id") = py::none(),
|
||||
py::arg("return_all_generated_tokens") = false)
|
||||
py::arg("return_all_generated_tokens") = false, py::arg("priority") = tle::Request::kDefaultPriority,
|
||||
py::arg("contextPhaseParams") = py::none(), py::arg("encoder_input_features") = py::none(),
|
||||
py::arg("encoder_output_length") = py::none())
|
||||
.def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds)
|
||||
.def_property_readonly("max_new_tokens", &tle::Request::getMaxNewTokens)
|
||||
.def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming)
|
||||
@ -282,7 +285,9 @@ void InitBindings(pybind11::module_& m)
|
||||
"encoder_input_token_ids", &tle::Request::getEncoderInputTokenIds, &tle::Request::setEncoderInputTokenIds)
|
||||
.def_property("client_id", &tle::Request::getClientId, &tle::Request::setClientId)
|
||||
.def_property("return_all_generated_tokens", &tle::Request::getReturnAllGeneratedTokens,
|
||||
&tle::Request::setReturnAllGeneratedTokens);
|
||||
&tle::Request::setReturnAllGeneratedTokens)
|
||||
.def_property(
|
||||
"encoder_input_features", &tle::Request::getEncoderInputFeatures, &tle::Request::setEncoderInputFeatures);
|
||||
request.attr("BATCHED_POST_PROCESSOR_NAME") = tle::Request::kBatchedPostProcessorName;
|
||||
|
||||
py::enum_<tle::FinishReason>(m, "FinishReason")
|
||||
|
||||
@@ -47,6 +47,7 @@ set(SRCS
tllmRuntime.cpp
tllmLogger.cpp
transformerBuffers.cpp
workerPool.cpp
worldConfig.cpp)

include_directories(${API_INCLUDE_DIR}/tensorrt_llm/runtime)

@ -306,6 +306,10 @@ void ExplicitDraftTokensBuffers::setFromInputs(SizeType32 numCtxSequences, SizeT
|
||||
setFromInputs<half>(numCtxSequences, numGenSequences, vocabSizePadded, seqSlots, draftBuffers,
|
||||
contextPositionIds, *explicitDraftTokensModule, stream);
|
||||
break;
|
||||
case nvinfer1::DataType::kBF16:
|
||||
setFromInputs<__nv_bfloat16>(numCtxSequences, numGenSequences, vocabSizePadded, seqSlots, draftBuffers,
|
||||
contextPositionIds, *explicitDraftTokensModule, stream);
|
||||
break;
|
||||
default:
|
||||
TLLM_THROW("DataType %d not supported in ExplicitDraftTokensBuffers", static_cast<SizeType32>(dtype));
|
||||
break;
|
||||
|
||||
@ -133,6 +133,9 @@ void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize
|
||||
explicitDraftTokensParams->temperature = mSamplingConfig.temperature;
|
||||
explicitDraftTokensParams->randomDataSample = output->explicitDraftTokensBuffers->randomDataSample;
|
||||
explicitDraftTokensParams->temperatures = output->explicitDraftTokensBuffers->temperatures;
|
||||
TLLM_CHECK(requestsOpt);
|
||||
// Ignore the dtype from all other requests assuming that it is the same for all.
|
||||
explicitDraftTokensParams->dtype = requestsOpt.value()[0].dtype;
|
||||
|
||||
setupParams->decodingParams = explicitDraftTokensParams;
|
||||
}
|
||||
|
||||
@ -332,6 +332,7 @@ GptJsonConfig parseJson(InputType&& input)
|
||||
}();
|
||||
|
||||
auto modelConfig = createModelConfig(json, engineVersionNone, tensorParallelism, dataType);
|
||||
modelConfig.setModelName(name);
|
||||
|
||||
parseBuilderConfig(modelConfig, builderConfig);
|
||||
|
||||
|
||||
@ -68,6 +68,9 @@ template void invokeFill(IBuffer&, std::uint8_t, CudaStream const&);
|
||||
template void invokeFill(IBuffer&, bool, CudaStream const&);
|
||||
template void invokeFill(IBuffer&, half, CudaStream const&);
|
||||
template void invokeFill(IBuffer&, float, CudaStream const&);
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeFill(IBuffer&, __nv_bfloat16, CudaStream const&);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
cpp/tensorrt_llm/runtime/workerPool.cpp (new file, 71 lines)
@ -0,0 +1,71 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "workerPool.h"
#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm::runtime
{
WorkerPool::WorkerPool(std::size_t numWorkers, std::int32_t deviceId)
{
for (std::size_t i = 0; i < numWorkers; ++i)
{
mWorkers.emplace_back(
[this, deviceId]
{
if (deviceId >= 0)
{
TLLM_CUDA_CHECK(cudaSetDevice(deviceId));
}
else
{
TLLM_LOG_WARNING("WorkerPool did not set cuda device");
}

while (true)
{
std::function<void()> task;

{
std::unique_lock<std::mutex> lock(this->mQueueMutex);
this->condition.wait(lock, [this] { return this->stop || !this->mTasks.empty(); });
if (this->stop && this->mTasks.empty())
{
return;
}
task = std::move(this->mTasks.front());
this->mTasks.pop();
}

task();
}
});
}
}

WorkerPool::~WorkerPool()
{
{
std::unique_lock<std::mutex> lock(mQueueMutex);
stop = true;
}
condition.notify_all();
for (std::thread& worker : mWorkers)
{
worker.join();
}
}
} // namespace tensorrt_llm::runtime
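
The new `WorkerPool` pins each worker thread to a CUDA device and then drains a mutex-protected task queue until it is destroyed. A minimal usage sketch of the pool's `enqueue` interface (declared in the reworked `workerPool.h` further down in this commit); the caller code here is hypothetical, not part of the commit:

```cpp
#include "tensorrt_llm/runtime/workerPool.h"

#include <iostream>

int main()
{
    // Two workers; each calls cudaSetDevice(0) before it starts polling the queue.
    tensorrt_llm::runtime::WorkerPool pool(/*numWorkers=*/2, /*deviceId=*/0);

    // enqueue() wraps the callable in a promise and hands back the matching future,
    // so results and exceptions propagate to the caller.
    auto answer = pool.enqueue([] { return 6 * 7; });
    std::cout << answer.get() << std::endl; // prints 42

    // The destructor sets the stop flag, wakes the workers, and joins them.
    return 0;
}
```
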
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -16,17 +16,16 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include <cassert>
|
||||
#include <condition_variable>
|
||||
#include <exception>
|
||||
#include <functional>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <stdexcept>
|
||||
#include <thread>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
@ -34,120 +33,53 @@ namespace tensorrt_llm::runtime
|
||||
class WorkerPool
|
||||
{
|
||||
public:
|
||||
explicit WorkerPool(std::size_t numWorkers = 1, int device = -1)
|
||||
: mNumWorkers(numWorkers)
|
||||
, mShutdown(false)
|
||||
, mDevice(device)
|
||||
{
|
||||
initThreads();
|
||||
}
|
||||
explicit WorkerPool(std::size_t numWorkers = 1, std::int32_t deviceId = -1);
|
||||
|
||||
~WorkerPool()
|
||||
{
|
||||
shutdown();
|
||||
}
|
||||
WorkerPool(WorkerPool const&) = delete;
|
||||
WorkerPool(WorkerPool&&) = delete;
|
||||
WorkerPool& operator=(WorkerPool const&) = delete;
|
||||
WorkerPool& operator=(WorkerPool&&) = delete;
|
||||
~WorkerPool();
|
||||
|
||||
template <typename Function, typename Return = std::invoke_result_t<std::decay_t<Function>>>
|
||||
std::future<Return> enqueue(Function&& task)
|
||||
template <class F>
|
||||
auto enqueue(F&& task) -> std::future<typename std::invoke_result<F>::type>
|
||||
{
|
||||
if (mShutdown)
|
||||
using returnType = typename std::invoke_result<F>::type;
|
||||
auto const taskPromise = std::make_shared<std::promise<returnType>>();
|
||||
{
|
||||
throw std::runtime_error("WorkerPool is shutdown cannot enqueue new tasks");
|
||||
std::lock_guard<std::mutex> lock(mQueueMutex);
|
||||
mTasks.push(
|
||||
[task = std::forward<F>(task), taskPromise]()
|
||||
{
|
||||
try
|
||||
{
|
||||
if constexpr (std::is_void_v<returnType>)
|
||||
{
|
||||
task();
|
||||
taskPromise->set_value();
|
||||
}
|
||||
else
|
||||
{
|
||||
taskPromise->set_value(task());
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
taskPromise->set_exception(std::current_exception());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
auto const taskPromise = std::make_shared<std::promise<Return>>();
|
||||
std::lock_guard<std::mutex> lock(mTasksMutex);
|
||||
mTasks.push(
|
||||
[task = std::forward<Function>(task), taskPromise]()
|
||||
{
|
||||
try
|
||||
{
|
||||
if constexpr (std::is_void_v<Return>)
|
||||
{
|
||||
task();
|
||||
taskPromise->set_value();
|
||||
}
|
||||
else
|
||||
{
|
||||
taskPromise->set_value(task());
|
||||
}
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
taskPromise->set_exception(std::current_exception());
|
||||
}
|
||||
});
|
||||
mTasksCv.notify_one();
|
||||
condition.notify_one();
|
||||
return taskPromise->get_future();
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr size_t kMaxNumWorkers = 128;
|
||||
std::size_t mNumWorkers;
|
||||
std::vector<std::thread> mWorkers;
|
||||
std::queue<std::function<void()>> mTasks;
|
||||
|
||||
std::queue<std::function<void()>> mTasks{};
|
||||
mutable std::mutex mTasksMutex;
|
||||
std::condition_variable mTasksCv;
|
||||
|
||||
std::atomic<bool> mShutdown = false;
|
||||
|
||||
std::thread mThreads[kMaxNumWorkers];
|
||||
|
||||
int mDevice{-1};
|
||||
|
||||
void shutdown()
|
||||
{
|
||||
if (mShutdown)
|
||||
{
|
||||
return;
|
||||
}
|
||||
mShutdown = true;
|
||||
mTasksCv.notify_all();
|
||||
for (std::size_t i = 0; i < mNumWorkers; ++i)
|
||||
{
|
||||
mThreads[i].join();
|
||||
}
|
||||
}
|
||||
|
||||
void initThreads()
|
||||
{
|
||||
if (mNumWorkers > kMaxNumWorkers)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"numWorker > maxNumWorkers " + std::to_string(mNumWorkers) + " > " + std::to_string(kMaxNumWorkers));
|
||||
}
|
||||
for (std::size_t i = 0; i < mNumWorkers; ++i)
|
||||
{
|
||||
mThreads[i] = std::thread(&WorkerPool::doWork, this);
|
||||
}
|
||||
}
|
||||
|
||||
void doWork()
|
||||
{
|
||||
if (mDevice >= 0)
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaSetDevice(mDevice));
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_WARNING("WorkerPool did not set cuda device");
|
||||
}
|
||||
while (!mShutdown)
|
||||
{
|
||||
std::function<void()> task;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mTasksMutex);
|
||||
mTasksCv.wait(lock, [this]() { return !mTasks.empty() || mShutdown; });
|
||||
if (mTasks.empty())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
task = mTasks.front();
|
||||
mTasks.pop();
|
||||
}
|
||||
|
||||
task();
|
||||
}
|
||||
}
|
||||
std::mutex mQueueMutex;
|
||||
std::condition_variable condition;
|
||||
bool stop{};
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -21,6 +21,10 @@
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
|
||||
#if ENABLE_BF16
|
||||
#include <cuda_bf16.h>
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
|
||||
@ -135,6 +139,23 @@ void prepareRandomTensors(th::Tensor& curandState, // [maxBatchSize, 48], uint8_
|
||||
tksd::invokeFillRandData(params, stream);
|
||||
}
|
||||
break;
|
||||
#ifdef ENABLE_BF16
|
||||
case at::ScalarType::BFloat16:
|
||||
{
|
||||
tksd::FillRandDataExplicitDraftTokensParams<__nv_bfloat16> params;
|
||||
params.batchSize = static_cast<tr::SizeType32>(batchSize);
|
||||
params.numPaths = static_cast<tr::SizeType32>(numPaths);
|
||||
params.draftLength = static_cast<tr::SizeType32>(draftLength);
|
||||
params.randDataSample = get_ptr<__nv_bfloat16>(randDataSample);
|
||||
params.randDataVerification = get_ptr<__nv_bfloat16>(randDataValidation);
|
||||
params.curandState = get_ptr<curandState_t>(curandState);
|
||||
params.batchSlots = nullptr;
|
||||
params.skipVerification = initialize;
|
||||
|
||||
tksd::invokeFillRandData(params, stream);
|
||||
}
|
||||
break;
|
||||
#endif // ENABLE_BF16
|
||||
default: throw std::runtime_error("Unsupported tensor type.");
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
|
||||
@ -246,6 +246,8 @@ protected:
|
||||
tensorrt_llm::ActivationType mActType = tensorrt_llm::ActivationType::Relu;
|
||||
MOEExpertScaleNormalizationMode mNormMode = MOEExpertScaleNormalizationMode::NONE;
|
||||
|
||||
float mSparseMixerEpsilon = 0.2f;
|
||||
|
||||
// Default this to true. This only matters for K>2, and so by doing this we will test the fused and unfused paths
|
||||
bool mUseDeterminsiticHopperReduce = true;
|
||||
|
||||
@ -287,7 +289,7 @@ protected:
|
||||
size_t const weight_size = hidden_size * (hidden_size * 4) * num_experts * sizeof(WeightStorage) * num_gemms;
|
||||
// Workspace size
|
||||
size_t const workspace_size = this->mMoERunner.getWorkspaceSize(
|
||||
num_tokens, hidden_size, hidden_size * 4, num_experts, k, this->mActType, {}, mUseLora);
|
||||
num_tokens, hidden_size, hidden_size * 4, num_experts, k, this->mActType, mNormMode, {}, mUseLora);
|
||||
// The input/output buffers
|
||||
size_t const in_out_size = 2 * num_tokens * hidden_size * sizeof(DataType);
|
||||
|
||||
@ -332,7 +334,7 @@ protected:
|
||||
}
|
||||
|
||||
size_t workspace_size = mMoERunner.getWorkspaceSize(
|
||||
mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, parallelism_config, mUseLora);
|
||||
mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, mNormMode, parallelism_config, mUseLora);
|
||||
|
||||
auto const stream = mStream->get();
|
||||
|
||||
@ -779,7 +781,7 @@ protected:
|
||||
mMoERunner.runMoe(mInputTensor, mInputProbabilities, weight1_ptr, bias1_ptr, mActType, weight2_ptr, bias2_ptr,
|
||||
quant_params, mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, mNumExperts, mK,
|
||||
mWorkspace, mFinalOutput, mFinished, mActiveRows, mScaleProbs, mSourceToExpandedMap, mSelectedExpert,
|
||||
parallelism_config, mNormMode, mUseLora, lora_params, stream);
|
||||
mSparseMixerEpsilon, parallelism_config, mNormMode, mUseLora, lora_params, stream);
|
||||
|
||||
check_cuda_error(cudaStreamSynchronize(stream));
|
||||
}
|
||||
@ -805,7 +807,7 @@ protected:
|
||||
{
|
||||
if (entry >= num_experts_per_node * tp_rank && entry < num_experts_per_node * (tp_rank + 1))
|
||||
return entry;
|
||||
return (int) mNumExperts;
|
||||
return (int) mNumExperts + entry;
|
||||
});
|
||||
return result;
|
||||
}
|
||||
@ -953,9 +955,9 @@ protected:
|
||||
return res;
|
||||
});
|
||||
|
||||
for (int64_t token = 0; token < mTotalTokens; token++)
|
||||
for (int64_t token = 0; token < softmax.size(); token += mNumExperts)
|
||||
{
|
||||
auto start = softmax.begin() + token * mNumExperts;
|
||||
auto start = softmax.begin() + token;
|
||||
auto end = start + mNumExperts;
|
||||
auto sum = std::accumulate(start, end, 0.f);
|
||||
std::transform(start, end, start, [=](auto in) { return in / sum; });
|
||||
@ -966,7 +968,7 @@ protected:
|
||||
|
||||
void renormScales(float* probs, int const* experts)
|
||||
{
|
||||
if (mNormMode == MOEExpertScaleNormalizationMode::NONE)
|
||||
if (mNormMode != MOEExpertScaleNormalizationMode::RENORMALIZE)
|
||||
return;
|
||||
float sum = 0;
|
||||
for (int k_idx = 0; k_idx < mK; k_idx++)
|
||||
@ -980,12 +982,49 @@ protected:
|
||||
}
|
||||
}
|
||||
|
||||
float sparseMixer(std::vector<float> logits, int token_idx, int k_idx, int expected_expert)
|
||||
{
|
||||
EXPECT_LE(mK, 2);
|
||||
EXPECT_LT(k_idx, mK);
|
||||
EXPECT_LT(token_idx * mNumExperts, logits.size());
|
||||
EXPECT_LE((token_idx + 1) * mNumExperts, logits.size());
|
||||
|
||||
auto start_it = logits.begin() + token_idx * mNumExperts;
|
||||
auto end_it = logits.begin() + (token_idx + 1) * mNumExperts;
|
||||
|
||||
// Mask old maxes and get the kth largest
|
||||
auto max_it = end_it;
|
||||
for (int i = 0; i <= k_idx; i++)
|
||||
{
|
||||
max_it = std::max_element(start_it, end_it);
|
||||
if (i != k_idx)
|
||||
{
|
||||
EXPECT_NE(max_it, end_it);
|
||||
*max_it = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ((max_it - start_it), expected_expert)
|
||||
<< "Expected token " << token_idx << " k_idx " << k_idx << " to select expert " << expected_expert;
|
||||
|
||||
std::vector<float> masked;
|
||||
std::transform(start_it, end_it, std::back_inserter(masked),
|
||||
[this, max_it](auto val)
|
||||
{
|
||||
float mask_value = (*max_it - val) / max(abs(val), *max_it);
|
||||
return (mask_value > 2 * mSparseMixerEpsilon) ? -INFINITY : val;
|
||||
});
|
||||
auto output_probs = softmax(masked);
|
||||
return output_probs[expected_expert];
|
||||
}
|
||||
|
||||
void compareSoftmax(std::vector<int> const& expected_experts, std::vector<float> const& expected_probs,
|
||||
std::vector<float> scale_probs = {})
|
||||
{
|
||||
if (scale_probs.empty())
|
||||
scale_probs = getDataFromDevice(mScaleProbs, mTotalTokens * mK);
|
||||
auto softmax_probs = softmax(expected_probs);
|
||||
|
||||
for (int64_t token_id = 0; token_id < mTotalTokens; token_id++)
|
||||
{
|
||||
renormScales(&softmax_probs[token_id * mNumExperts], &expected_experts[token_id * mK]);
|
||||
@ -995,8 +1034,13 @@ protected:
|
||||
int selected_expert = expected_experts[token_id * mK + k_idx];
|
||||
if (selected_expert < mNumExperts) // Ignore 'finished' values
|
||||
{
|
||||
ASSERT_NEAR(softmax_probs[token_id * mNumExperts + selected_expert],
|
||||
scale_probs[token_id * mK + k_idx], getTolerance())
|
||||
float expected_value = softmax_probs[token_id * mNumExperts + selected_expert];
|
||||
if (mNormMode == tensorrt_llm::kernels::MOEExpertScaleNormalizationMode::SPARSE_MIXER)
|
||||
{
|
||||
expected_value = sparseMixer(expected_probs, token_id, k_idx, selected_expert);
|
||||
}
|
||||
|
||||
ASSERT_NEAR(expected_value, scale_probs[token_id * mK + k_idx], getTolerance())
|
||||
<< "Scales mismatched for token: " << token_id << " k: " << k_idx
|
||||
<< " selected_expert: " << selected_expert;
|
||||
}
|
||||
@ -1022,9 +1066,16 @@ protected:
|
||||
for (int k_idx = 0; k_idx < mK; k_idx++)
|
||||
{
|
||||
int selected_expert = expected_experts[token_id * mK + k_idx];
|
||||
|
||||
float scale_value = softmax_probs[token_id * mNumExperts + selected_expert];
|
||||
if (mNormMode == tensorrt_llm::kernels::MOEExpertScaleNormalizationMode::SPARSE_MIXER)
|
||||
{
|
||||
scale_value = sparseMixer(expected_probs, token_id, k_idx, selected_expert);
|
||||
}
|
||||
|
||||
sum += float(calcMLPValWithFinalBias(
|
||||
static_cast<float>(input_data[token_id * mHiddenSize + hidden_id]), selected_expert))
|
||||
* softmax_probs[token_id * mNumExperts + selected_expert];
|
||||
* scale_value;
|
||||
}
|
||||
|
||||
ASSERT_NEAR(OutputType{sum}, final_results[token_id * mHiddenSize + hidden_id], getTolerance(sum))
|
||||
@ -1170,6 +1221,13 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteRenormalization)
|
||||
this->BasicPermuteTest(3);
|
||||
}
|
||||
|
||||
TYPED_TEST(MixtureOfExpertsTest, PermuteSparseMixer)
|
||||
{
|
||||
this->mNormMode = tensorrt_llm::kernels::MOEExpertScaleNormalizationMode::SPARSE_MIXER;
|
||||
this->BasicPermuteTest();
|
||||
this->BasicPermuteTest(2);
|
||||
}
|
||||
|
||||
TYPED_TEST(MixtureOfExpertsTest, PermuteGeglu)
|
||||
{
|
||||
this->mActType = tensorrt_llm::ActivationType::Geglu;
|
||||
@ -1228,7 +1286,7 @@ std::vector<int> MixtureOfExpertsTest<TypeParam_>::calcPermuteMapExpertParallel(
|
||||
std::vector<int> map(expected_experts.size());
|
||||
auto getInterleavedIndex = [this](int i) { return (i % mK) * mTotalTokens + i / mK; };
|
||||
int map_idx = 0;
|
||||
for (int expert = 0; expert <= mNumExperts; expert++)
|
||||
for (int expert = 0; expert < mNumExperts * 2; expert++)
|
||||
{
|
||||
for (int i = 0; i < map.size(); i++)
|
||||
{
|
||||
@ -1293,7 +1351,7 @@ void MixtureOfExpertsTest<TypeParam_>::ExpertParallelTest(int k)
|
||||
// Note the index is [0,num_experts_per_node), so we offset the experts by the start for this node
|
||||
int const start_expert = i * (mNumExperts / parallelism);
|
||||
std::transform(selected_expert.begin(), selected_expert.end(), selected_expert.begin(),
|
||||
[&](int val) { return val == mNumExperts ? mNumExperts : val + start_expert; });
|
||||
[&](int val) { return val >= mNumExperts ? val : val + start_expert; });
|
||||
auto masked_expected_experts = maskSelectedExpertsForTP(expected_experts, parallelism, i);
|
||||
ASSERT_EQ(selected_expert, masked_expected_experts);
|
||||
|
||||
@ -1336,6 +1394,13 @@ TYPED_TEST(MixtureOfExpertsTest, ExpertParallelRenorm)
|
||||
this->ExpertParallelTest(2);
|
||||
}
|
||||
|
||||
TYPED_TEST(MixtureOfExpertsTest, ExpertParallelSparseMixer)
|
||||
{
|
||||
this->mNormMode = MOEExpertScaleNormalizationMode::SPARSE_MIXER;
|
||||
this->ExpertParallelTest();
|
||||
this->ExpertParallelTest(2);
|
||||
}
|
||||
|
||||
TYPED_TEST(MixtureOfExpertsTest, ExpertParallelGeglu)
|
||||
{
|
||||
this->mActType = tensorrt_llm::ActivationType::Geglu;
|
||||
@ -1451,6 +1516,13 @@ TYPED_TEST(MixtureOfExpertsTest, TensorParallelRenorm)
|
||||
this->TensorParallelTest(3);
|
||||
}
|
||||
|
||||
TYPED_TEST(MixtureOfExpertsTest, TensorParallelSparseMixer)
|
||||
{
|
||||
this->mNormMode = MOEExpertScaleNormalizationMode::SPARSE_MIXER;
|
||||
this->TensorParallelTest();
|
||||
this->TensorParallelTest(2);
|
||||
}
|
||||
|
||||
TYPED_TEST(MixtureOfExpertsTest, TensorParallelGeglu)
|
||||
{
|
||||
this->mActType = tensorrt_llm::ActivationType::Geglu;
|
||||
@ -1523,7 +1595,7 @@ void MixtureOfExpertsTest<TypeParam_>::MixedParallelTest(int k)
|
||||
// Note the index is [0,num_experts_per_node), so we offset the experts by the start for this node
|
||||
int const start_expert = j * (mNumExperts / ep_parallelism);
|
||||
std::transform(selected_expert.begin(), selected_expert.end(), selected_expert.begin(),
|
||||
[&](int val) { return val == mNumExperts ? mNumExperts : val + start_expert; });
|
||||
[&](int val) { return val >= mNumExperts ? val : val + start_expert; });
|
||||
auto masked_expected_experts = maskSelectedExpertsForTP(expected_experts, ep_parallelism, j);
|
||||
ASSERT_EQ(selected_expert, masked_expected_experts);
|
||||
|
||||
@ -1567,6 +1639,13 @@ TYPED_TEST(MixtureOfExpertsTest, MixedParallelRenorm)
|
||||
this->MixedParallelTest(2);
|
||||
}
|
||||
|
||||
TYPED_TEST(MixtureOfExpertsTest, MixedParallelSparseMixer)
|
||||
{
|
||||
this->mNormMode = MOEExpertScaleNormalizationMode::SPARSE_MIXER;
|
||||
this->MixedParallelTest();
|
||||
this->MixedParallelTest(2);
|
||||
}
|
||||
|
||||
TYPED_TEST(MixtureOfExpertsTest, MixedParallelGeglu)
|
||||
{
|
||||
this->mActType = tensorrt_llm::ActivationType::Geglu;
|
||||
@ -1825,3 +1904,32 @@ TEST_F(MixtureOfExpertsProfilerTest, TestGeneratedProfilerDistribution)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using MixtureOfExpertsUnitTests = MixtureOfExpertsTest<WeightParams<half, half>>;
|
||||
|
||||
TEST_F(MixtureOfExpertsUnitTests, SparseMixerReferenceTest)
|
||||
{
|
||||
// Test the sparse mixer reference implementation is doing the correct thing
|
||||
// This makes sure we are testing the correct behaviour
|
||||
this->mNumExperts = 4;
|
||||
this->mK = 2;
|
||||
auto res = this->sparseMixer({1.0f, 1.0f, -INFINITY, -INFINITY}, 0, 0, 0);
|
||||
ASSERT_FLOAT_EQ(res, 0.5f);
|
||||
res = this->sparseMixer({1.0f, 1.0f, -INFINITY, -INFINITY}, 0, 1, 1);
|
||||
ASSERT_FLOAT_EQ(res, 1.0f);
|
||||
|
||||
res = this->sparseMixer({2.0f, 0.0f, -INFINITY, -INFINITY}, 0, 0, 0);
|
||||
ASSERT_FLOAT_EQ(res, 1.0f);
|
||||
res = this->sparseMixer({2.0f, 0.0f, -INFINITY, -INFINITY}, 0, 1, 1);
|
||||
ASSERT_FLOAT_EQ(res, 1.0f);
|
||||
|
||||
res = this->sparseMixer({0.0f, 2.0f, -INFINITY, -INFINITY}, 0, 0, 1);
|
||||
ASSERT_FLOAT_EQ(res, 1.0f);
|
||||
res = this->sparseMixer({0.0f, 2.0f, -INFINITY, -INFINITY}, 0, 1, 0);
|
||||
ASSERT_FLOAT_EQ(res, 1.0f);
|
||||
|
||||
res = this->sparseMixer({1.0f, 1.0f, 1.0f, -INFINITY}, 0, 0, 0);
|
||||
ASSERT_FLOAT_EQ(res, 1.f / 3.f);
|
||||
res = this->sparseMixer({1.0f, 1.0f, 1.0f, -INFINITY}, 0, 1, 1);
|
||||
ASSERT_FLOAT_EQ(res, 0.5f);
|
||||
}
|
||||
|
||||
@ -489,7 +489,6 @@ protected:
|
||||
decoderParams.batchSize = batch_size;
|
||||
decoderParams.maxQSeqLength = input_seq_length;
|
||||
decoderParams.maxEncoderQSeqLength = mCrossAttention ? cross_qkv_length : 0;
|
||||
decoderParams.removePadding = mRemovePadding;
|
||||
decoderParams.attentionWindowSize = cyclic_attention_window_size;
|
||||
decoderParams.sinkTokenLength = sink_token_length;
|
||||
decoderParams.numTokens = num_tokens;
|
||||
@ -541,6 +540,7 @@ protected:
|
||||
preprocessingParams.cyclic_kv_cache_len = cyclic_attention_window_size;
|
||||
preprocessingParams.sink_token_len = sink_token_length;
|
||||
preprocessingParams.token_num = num_tokens;
|
||||
preprocessingParams.remove_padding = mRemovePadding;
|
||||
preprocessingParams.head_num = mNumHeads;
|
||||
preprocessingParams.kv_head_num = mNumKVHeads;
|
||||
preprocessingParams.qheads_per_kv_head = mNumHeads / mNumKVHeads;
|
||||
|
||||
@ -641,15 +641,16 @@ void ExplicitDraftTokensLayerTest<T>::SetUp()
|
||||
template <typename T>
|
||||
void ExplicitDraftTokensLayerTest<T>::allocateBuffers()
|
||||
{
|
||||
auto const dataType = TRTDataType<T>::value;
|
||||
using DataType = typename T::DataType;
|
||||
auto const dataType = TRTDataType<DataType>::value;
|
||||
|
||||
auto speculativeDecodingModule = std::make_shared<SpeculativeDecodingModule>(mSamplingParams.getMaxDraftPathLen(),
|
||||
mSamplingParams.getMaxDecodingDraftTokens(), mSamplingParams.getMaxNumPaths());
|
||||
auto const decodingDomain = tensorrt_llm::layers::DecoderDomain(mSamplingParams.getMaxBatchSize(), 1,
|
||||
mSamplingParams.getVocabSize(), mSamplingParams.getVocabSize(), speculativeDecodingModule);
|
||||
|
||||
mExplicitDraftTokensLayer
|
||||
= std::make_shared<tensorrt_llm::layers::ExplicitDraftTokensLayer<T>>(decodingDomain, mBufferManager);
|
||||
mExplicitDraftTokensLayer = std::make_shared<tensorrt_llm::layers::ExplicitDraftTokensLayer<typename T::LayerType>>(
|
||||
decodingDomain, mBufferManager);
|
||||
|
||||
// outputs
|
||||
mOutputIds = BufferManager::pinnedPool(
|
||||
@ -828,6 +829,7 @@ void ExplicitDraftTokensLayerTest<T>::allocateBuffers()
|
||||
template <typename T>
|
||||
void ExplicitDraftTokensLayerTest<T>::setup()
|
||||
{
|
||||
using DataType = typename T::DataType;
|
||||
// outputs
|
||||
trk::invokeFill(*mOutputIds, TokenIdType{-1}, *mStream);
|
||||
trk::invokeFill(*mSeqLengths, SizeType32{0}, *mStream);
|
||||
@ -835,19 +837,20 @@ void ExplicitDraftTokensLayerTest<T>::setup()
|
||||
trk::invokeFill(*mAcceptedLengthCumSum, SizeType32{-1}, *mStream);
|
||||
trk::invokeFill(*mOutputNextDraftTokens, TokenIdType{-1}, *mStream);
|
||||
trk::invokeFill(*mOutputPositionIdsBase, SizeType32{0}, *mStream);
|
||||
trk::invokeFill(*mRandomDataSample, T{0}, *mStream);
|
||||
trk::invokeFill(*mRandomDataValidation, T{0}, *mStream);
|
||||
trk::invokeFill(*mRandomDataSample, DataType{0}, *mStream);
|
||||
trk::invokeFill(*mRandomDataValidation, DataType{0}, *mStream);
|
||||
trk::invokeFill(*mPackedMasks, SizeType32{0}, *mStream);
|
||||
trk::invokeFill(*mNextPosIds, SizeType32{0}, *mStream);
|
||||
trk::invokeFill(*mOutputUnpackedNextDraftTokens, TokenIdType{-1}, *mStream);
|
||||
trk::invokeFill(*mOutputUnpackedNextDraftIndices, SizeType32{0}, *mStream);
|
||||
trk::invokeFill(*mEndIds, TokenIdType{-1}, *mStream);
|
||||
|
||||
auto inDraftProbs = BufferRange<T>(*mNextDraftProbs);
|
||||
auto inDraftProbs = BufferRange<DataType>(*mNextDraftProbs);
|
||||
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_real_distribution<float> distr(0.0, 1.0);
|
||||
std::generate(inDraftProbs.begin(), inDraftProbs.end(), [&gen, &distr]() { return static_cast<T>(distr(gen)); });
|
||||
std::generate(
|
||||
inDraftProbs.begin(), inDraftProbs.end(), [&gen, &distr]() { return static_cast<DataType>(distr(gen)); });
|
||||
|
||||
auto batchSlotsPtr = bufferCast<SizeType32>(*mBatchSlots);
|
||||
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
|
||||
@ -870,6 +873,7 @@ void ExplicitDraftTokensLayerTest<T>::setup()
|
||||
setupParams->temperature = mTemperatures;
|
||||
setupParams->randomDataSample = mRandomDataSample;
|
||||
setupParams->temperatures = mOutputTemperatures;
|
||||
setupParams->dtype = TRTDataType<DataType>::value;
|
||||
|
||||
mExplicitDraftTokensLayer->setup(mSamplingParams.getBatchSize(), 1, mBatchSlots, setupParams);
|
||||
|
||||
@ -1045,20 +1049,22 @@ std::vector<int32_t> boolArrayToBitmask(BufferRange<bool>::iterator boolIterator
|
||||
template <typename T>
|
||||
void ExplicitDraftTokensLayerTest<T>::checkLayerResult()
|
||||
{
|
||||
using DataType = typename T::DataType;
|
||||
auto const batchSlots = BufferRange<SizeType32>(*mBatchSlots);
|
||||
|
||||
// Check generated random data
|
||||
{
|
||||
auto const randomDataSample = BufferRange<T>(*mRandomDataSample);
|
||||
auto const randomDataValidation = BufferRange<T>(*mRandomDataValidation);
|
||||
auto const randomDataSample = BufferRange<DataType>(*mRandomDataSample);
|
||||
auto const randomDataValidation = BufferRange<DataType>(*mRandomDataValidation);
|
||||
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
|
||||
{
|
||||
auto const batchSlot = batchSlots[bi];
|
||||
// Check that all fields are filled with non zero data
|
||||
EXPECT_NE(randomDataSample[batchSlot], T{0}) << " bi: " << bi;
|
||||
EXPECT_NE(randomDataSample[batchSlot], DataType{0}) << " bi: " << bi;
|
||||
auto const stride = mSamplingParams.getMaxNumPaths() * mSamplingParams.getMaxDraftPathLen();
|
||||
EXPECT_FALSE(std::any_of(randomDataValidation.begin() + batchSlot * stride,
|
||||
randomDataValidation.begin() + (batchSlot + 1) * stride, [](T val) { return val == T{0}; }))
|
||||
randomDataValidation.begin() + (batchSlot + 1) * stride,
|
||||
[](DataType val) { return val == DataType{0}; }))
|
||||
<< " bi: " << bi;
|
||||
}
|
||||
}
|
||||
@ -1211,8 +1217,8 @@ void ExplicitDraftTokensLayerTest<T>::checkLayerResult()
|
||||
|
||||
// Check draft probs
|
||||
{
|
||||
auto const outDraftProbs = BufferRange<T>(*mOutputDraftProbs);
|
||||
auto const inDraftProbs = BufferRange<T>(*mNextDraftProbs);
|
||||
auto const outDraftProbs = BufferRange<DataType>(*mOutputDraftProbs);
|
||||
auto const inDraftProbs = BufferRange<DataType>(*mNextDraftProbs);
|
||||
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
|
||||
{
|
||||
auto const batchSlot = batchSlots[bi];
|
||||
@ -1239,7 +1245,8 @@ void ExplicitDraftTokensLayerTest<T>::checkLayerResult()
|
||||
for (SizeType32 bi = 0; bi < mSamplingParams.getBatchSize(); ++bi)
|
||||
{
|
||||
auto const batchSlot = batchSlots[bi];
|
||||
EXPECT_EQ(BufferRange<T>(*mOutputTemperatures)[batchSlot], static_cast<T>(1.f / mTemperatures[bi]))
|
||||
EXPECT_EQ(
|
||||
BufferRange<DataType>(*mOutputTemperatures)[batchSlot], static_cast<DataType>(1.f / mTemperatures[bi]))
|
||||
<< " bi: " << bi;
|
||||
}
|
||||
}
|
||||
@ -1248,7 +1255,8 @@ void ExplicitDraftTokensLayerTest<T>::checkLayerResult()
|
||||
template <typename T>
|
||||
void ExplicitDraftTokensLayerTest<T>::packData()
|
||||
{
|
||||
tksd::PackExplicitDraftTokensParams<T> params;
|
||||
using DataType = typename T::DataType;
|
||||
tksd::PackExplicitDraftTokensParams<DataType> params;
|
||||
params.batchSlots = bufferCast<SizeType32>(*mBatchSlots);
|
||||
params.cumSumGenerationLengths = bufferCast<SizeType32>(*mCumSumGenerationLengths);
|
||||
params.maxGenerationLength = bufferCast<SizeType32>(*mMaxGenerationLength);
|
||||
@ -1259,11 +1267,11 @@ void ExplicitDraftTokensLayerTest<T>::packData()
|
||||
params.outputGenerationLengths = bufferCast<SizeType32>(*mPackedGenerationLengths);
|
||||
params.inputGenerationLengths = bufferCast<SizeType32>(*mSpecDecodingGenerationLengths);
|
||||
|
||||
params.outputRandomDataSample = bufferCast<T>(*mPackedRandomDataSample);
|
||||
params.inputRandomDataSample = bufferCast<T>(*mRandomDataSample);
|
||||
params.outputRandomDataSample = bufferCast<DataType>(*mPackedRandomDataSample);
|
||||
params.inputRandomDataSample = bufferCast<DataType>(*mRandomDataSample);
|
||||
|
||||
params.outputRandomDataValidation = bufferCast<T>(*mPackedRandomDataVerification);
|
||||
params.inputRandomDataValidation = bufferCast<T>(*mRandomDataValidation);
|
||||
params.outputRandomDataValidation = bufferCast<DataType>(*mPackedRandomDataVerification);
|
||||
params.inputRandomDataValidation = bufferCast<DataType>(*mRandomDataValidation);
|
||||
|
||||
params.outputNextDraftTokens = bufferCast<TokenIdType>(*mPackedNextDraftTokens);
|
||||
params.inputNextDraftTokens = bufferCast<TokenIdType>(*mOutputUnpackedNextDraftTokens);
|
||||
@ -1278,11 +1286,11 @@ void ExplicitDraftTokensLayerTest<T>::packData()
|
||||
params.outputPositionOffsets = bufferCast<SizeType32>(*mPackedPositionOffsets);
|
||||
params.outputPositionIds = bufferCast<SizeType32>(*mPackedPackedPosIds);
|
||||
|
||||
params.outputDraftProbs = bufferCast<T>(*mPackedDraftProbs);
|
||||
params.inputDraftProbs = bufferCast<T>(*mOutputDraftProbs);
|
||||
params.outputDraftProbs = bufferCast<DataType>(*mPackedDraftProbs);
|
||||
params.inputDraftProbs = bufferCast<DataType>(*mOutputDraftProbs);
|
||||
|
||||
params.outputTemperatures = bufferCast<T>(*mPackedTemperatures);
|
||||
params.inputTemperatures = bufferCast<T>(*mOutputTemperatures);
|
||||
params.outputTemperatures = bufferCast<DataType>(*mPackedTemperatures);
|
||||
params.inputTemperatures = bufferCast<DataType>(*mOutputTemperatures);
|
||||
|
||||
params.batchSize = mSamplingParams.getBatchSize();
|
||||
params.numPaths = mSamplingParams.getMaxNumPaths();
|
||||
@ -1313,6 +1321,7 @@ void ExplicitDraftTokensLayerTest<T>::packData()
|
||||
template <typename T>
|
||||
void ExplicitDraftTokensLayerTest<T>::checkPackResult()
|
||||
{
|
||||
using DataType = typename T::DataType;
|
||||
auto const batchSlots = BufferRange<SizeType32>(*mBatchSlots);
|
||||
auto const maxGenLength = mNetwork.getMaxNextGenerationLength();
|
||||
auto const numPackedMasks = static_cast<SizeType32>(divUp(mSamplingParams.getMaxDecodingTokens(), 32));
|
||||
@ -1325,22 +1334,24 @@ void ExplicitDraftTokensLayerTest<T>::checkPackResult()
|
||||
EXPECT_EQ(BufferRange<SizeType32>(*mPackedGenerationLengths)[bi],
|
||||
BufferRange<SizeType32>(*mSpecDecodingGenerationLengths)[batchSlot])
|
||||
<< "bi: " << bi;
|
||||
EXPECT_EQ(BufferRange<T>(*mPackedRandomDataSample)[bi], BufferRange<T>(*mRandomDataSample)[batchSlot])
|
||||
EXPECT_EQ(
|
||||
BufferRange<DataType>(*mPackedRandomDataSample)[bi], BufferRange<DataType>(*mRandomDataSample)[batchSlot])
|
||||
<< "bi: " << bi;
|
||||
EXPECT_EQ(BufferRange<T>(*mPackedTemperatures)[bi], BufferRange<T>(*mOutputTemperatures)[batchSlot])
|
||||
EXPECT_EQ(
|
||||
BufferRange<DataType>(*mPackedTemperatures)[bi], BufferRange<DataType>(*mOutputTemperatures)[batchSlot])
|
||||
<< "bi: " << bi;
|
||||
|
||||
for (SizeType32 pi = 0; pi < mSamplingParams.getMaxNumPaths(); ++pi)
|
||||
{
|
||||
for (SizeType32 ti = 0; ti < mSamplingParams.getMaxDraftPathLen(); ++ti)
|
||||
{
|
||||
EXPECT_EQ(bufferCast<T>(*ITensor::at(mPackedRandomDataVerification, {bi, pi, ti}))[0],
|
||||
bufferCast<T>(*ITensor::at(mRandomDataValidation, {batchSlot, pi, ti}))[0])
|
||||
EXPECT_EQ(bufferCast<DataType>(*ITensor::at(mPackedRandomDataVerification, {bi, pi, ti}))[0],
|
||||
bufferCast<DataType>(*ITensor::at(mRandomDataValidation, {batchSlot, pi, ti}))[0])
|
||||
<< "bi: " << bi << " pi: " << pi << " ti: " << ti;
|
||||
for (SizeType32 vi = 0; vi < mSamplingParams.getVocabSize(); ++vi)
|
||||
{
|
||||
EXPECT_EQ(bufferCast<T>(*ITensor::at(mPackedDraftProbs, {bi, pi, ti, vi}))[0],
|
||||
bufferCast<T>(*ITensor::at(mOutputDraftProbs, {batchSlot, pi, ti, vi}))[0])
|
||||
EXPECT_EQ(bufferCast<DataType>(*ITensor::at(mPackedDraftProbs, {bi, pi, ti, vi}))[0],
|
||||
bufferCast<DataType>(*ITensor::at(mOutputDraftProbs, {batchSlot, pi, ti, vi}))[0])
|
||||
<< "bi: " << bi << " pi: " << pi << " ti: " << ti << " vi: " << vi;
|
||||
}
|
||||
}
|
||||
@ -1408,10 +1419,13 @@ void ExplicitDraftTokensLayerTest<T>::runTest(std::vector<std::string> const& pr
|
||||
checkPackResult();
|
||||
}
|
||||
|
||||
template class ExplicitDraftTokensLayerTest<float>;
|
||||
template class ExplicitDraftTokensLayerTest<half>;
|
||||
template class ExplicitDraftTokensLayerTest<TypePair<float, float>>;
|
||||
template class ExplicitDraftTokensLayerTest<TypePair<half, half>>;
|
||||
#ifdef ENABLE_BF16
|
||||
template class ExplicitDraftTokensLayerTest<TypePair<half, __nv_bfloat16>>;
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
TYPED_TEST_SUITE(ExplicitDraftTokensLayerTest, FloatAndHalfTypes);
|
||||
TYPED_TEST_SUITE(ExplicitDraftTokensLayerTest, TestTypes);
|
||||
|
||||
TYPED_TEST(ExplicitDraftTokensLayerTest, SimpleTestBS1)
|
||||
{
|
||||
@ -1556,15 +1570,25 @@ protected:
|
||||
|
||||
void TearDown() override {}
|
||||
|
||||
void runTest(
|
||||
SizeType32 batchSize, SizeType32 numPaths, SizeType32 draftLength, bool skipVerification, uint64_t randomSeed)
|
||||
void runTest(SizeType32 batchSize, SizeType32 numPaths, SizeType32 draftLength, bool skipVerification,
|
||||
uint64_t randomSeed, bool batchInit)
|
||||
{
|
||||
SizeType32* batchSlotsPtr{nullptr};
|
||||
|
||||
auto curandState = mBufferManager->gpu(ITensor::makeShape({batchSize, 48}), nvinfer1::DataType::kUINT8);
|
||||
auto* curandStatePtr = reinterpret_cast<curandState_t*>(bufferCast<uint8_t>(*curandState));
|
||||
|
||||
tk::invokeCurandInitialize(curandStatePtr, batchSlotsPtr, batchSize, randomSeed, mStream->get());
|
||||
if (batchInit)
|
||||
{
|
||||
auto randomSeeds = mBufferManager->gpu(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT64);
|
||||
trk::invokeFill(*randomSeeds, static_cast<int64_t>(randomSeed), *mStream);
|
||||
auto* randomSeedsPtr = bufferCast<uint64_t>(*randomSeeds);
|
||||
tk::invokeCurandBatchInitialize(curandStatePtr, batchSlotsPtr, batchSize, randomSeedsPtr, mStream->get());
|
||||
}
|
||||
else
|
||||
{
|
||||
tk::invokeCurandInitialize(curandStatePtr, batchSlotsPtr, batchSize, randomSeed, mStream->get());
|
||||
}
|
||||
mStream->synchronize();
|
||||
|
||||
tksd::FillRandDataExplicitDraftTokensParams<T> params;
|
||||
@ -1605,7 +1629,12 @@ private:
|
||||
std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
using FloatHalfBfloatTypes = testing::Types<float, half, __nv_bfloat16>;
|
||||
TYPED_TEST_SUITE(FillRandDataTest, FloatHalfBfloatTypes);
|
||||
#else
|
||||
TYPED_TEST_SUITE(FillRandDataTest, FloatAndHalfTypes);
|
||||
#endif
|
||||
|
||||
TYPED_TEST(FillRandDataTest, SimpleTest)
|
||||
{
|
||||
@ -1616,7 +1645,19 @@ TYPED_TEST(FillRandDataTest, SimpleTest)
|
||||
|
||||
uint64_t randomSeed{0};
|
||||
|
||||
this->runTest(batchSize, numPaths, draftLength, skipVerification, randomSeed);
|
||||
this->runTest(batchSize, numPaths, draftLength, skipVerification, randomSeed, false);
|
||||
}
|
||||
|
||||
TYPED_TEST(FillRandDataTest, BatchInit)
|
||||
{
|
||||
SizeType32 constexpr batchSize{3};
|
||||
SizeType32 constexpr numPaths{2};
|
||||
SizeType32 constexpr draftLength{5};
|
||||
bool constexpr skipVerification{false};
|
||||
|
||||
uint64_t randomSeed{42};
|
||||
|
||||
this->runTest(batchSize, numPaths, draftLength, skipVerification, randomSeed, true);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::tests::layers
|
||||
|
||||
@ -310,7 +310,7 @@ private:
|
||||
|
||||
std::shared_ptr<tensorrt_llm::runtime::CudaStream> mStream;
|
||||
std::shared_ptr<tensorrt_llm::runtime::BufferManager> mBufferManager;
|
||||
std::shared_ptr<tensorrt_llm::layers::ExplicitDraftTokensLayer<T>> mExplicitDraftTokensLayer;
|
||||
std::shared_ptr<tensorrt_llm::layers::ExplicitDraftTokensLayer<typename T::LayerType>> mExplicitDraftTokensLayer;
|
||||
|
||||
ExplicitDraftTokensDummyNetwork mNetwork;
|
||||
|
||||
@ -334,6 +334,17 @@ public:
|
||||
DraftLettersVec const& nextDraftLetters, DraftLettersVec const& lastDraftLetters, SamplingParams& params);
|
||||
};
|
||||
|
||||
using FloatAndHalfTypes = testing::Types<float, half>;
|
||||
template <typename T, typename U>
|
||||
struct TypePair
|
||||
{
|
||||
using LayerType = T;
|
||||
using DataType = U;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
using TestTypes = testing::Types<TypePair<float, float>, TypePair<half, half>, TypePair<half, __nv_bfloat16>>;
|
||||
#else
|
||||
using TestTypes = testing::Types<TypePair<float, float>, TypePair<half, half>>;
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
} // namespace tensorrt_llm::tests::layers
|
||||
|
||||
@ -23,7 +23,7 @@ class Arguments:
|
||||
tp: int = 1
|
||||
pp: int = 1
|
||||
|
||||
beams: int = 1
|
||||
beams: str = '1'
|
||||
gpus_per_node: int = 4
|
||||
debug: bool = False
|
||||
|
||||
@ -32,6 +32,14 @@ class Arguments:
|
||||
|
||||
max_new_tokens: int = 64
|
||||
|
||||
@property
|
||||
def beams_tuple(self):
|
||||
return eval(f'tuple([{self.beams}])')
|
||||
|
||||
@property
|
||||
def max_beam(self):
|
||||
return max(self.beams_tuple)
|
||||
|
||||
@property
|
||||
def ckpt(self):
|
||||
return self.hf_repo_name.split('/')[-1]
|
||||
@ -130,7 +138,7 @@ class Build(RunCMDMixin):
|
||||
f'--paged_kv_cache disable',
|
||||
f'--moe_plugin disable',
|
||||
f'--enable_xqa disable',
|
||||
f'--max_beam_width {args.beams}',
|
||||
f'--max_beam_width {args.max_beam}',
|
||||
f'--max_batch_size 8',
|
||||
f'--max_input_len 512',
|
||||
f'--gemm_plugin {args.dtype}',
|
||||
@ -146,7 +154,7 @@ class Build(RunCMDMixin):
|
||||
f'--paged_kv_cache enable',
|
||||
f'--moe_plugin disable',
|
||||
f'--enable_xqa disable',
|
||||
f'--max_beam_width {args.beams}',
|
||||
f'--max_beam_width {args.max_beam}',
|
||||
f'--max_batch_size 8',
|
||||
f'--max_seq_len 201',
|
||||
f'--max_encoder_input_len 512',
|
||||
|
||||
@ -6,18 +6,22 @@ class Run(RunCMDMixin):
|
||||
def command(self):
|
||||
args = self.args
|
||||
world_size = args.tp * args.pp
|
||||
mpi_run = f'mpirun --allow-run-as-root -np {world_size} ' if world_size > 1 else ''
|
||||
ret = (
|
||||
f'python3 examples/enc_dec/run.py --engine_dir {args.engines_dir}',
|
||||
f'--engine_name {args.ckpt}',
|
||||
f'--model_name "{args.hf_models_dir}"',
|
||||
f'--max_new_tokens={args.max_new_tokens}',
|
||||
f'--num_beams={args.beams}',
|
||||
f'--compare_hf_fp32',
|
||||
f'--output_npy={args.data_dir}',
|
||||
"--debug_mode" if args.debug else "",
|
||||
)
|
||||
ret = mpi_run + ' '.join(ret)
|
||||
mpi_run = f'mpirun --allow-run-as-root -np {world_size}' if world_size > 1 else ''
|
||||
ret = []
|
||||
for beam in args.beams_tuple:
|
||||
ret.append((
|
||||
mpi_run,
|
||||
f'python3 examples/enc_dec/run.py --engine_dir {args.engines_dir}',
|
||||
f'--engine_name {args.ckpt}',
|
||||
f'--model_name "{args.hf_models_dir}"',
|
||||
f'--max_new_tokens={args.max_new_tokens}',
|
||||
f'--num_beams={beam}',
|
||||
f'--compare_hf_fp32',
|
||||
f'--output_npy={args.data_dir}',
|
||||
"--debug_mode" if args.debug else "",
|
||||
))
|
||||
ret = [' '.join(x) for x in ret]
|
||||
ret = ' && '.join(ret)
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
@ -395,17 +395,20 @@ def prepare_model_tests(model_name: str,
|
||||
|
||||
model_env = {**_os.environ, "PYTHONPATH": f"examples/{model_name}"}
|
||||
enc_dec_model_name_arg = []
|
||||
beams_arg = []
|
||||
if model_name in ('bart', 't5'):
|
||||
enc_dec_model_name_arg = [
|
||||
'--hf_repo_name',
|
||||
'facebook/bart-large-cnn' if model_name == 'bart' else 't5-small'
|
||||
]
|
||||
if model_name == 't5' and (not only_multi_gpu_arg):
|
||||
beams_arg = ['--beams', '1,2']
|
||||
model_name = 'enc_dec'
|
||||
|
||||
build_engines = [
|
||||
python_exe,
|
||||
str(scripts_dir / f"build_{model_name}_engines.py")
|
||||
] + model_cache_arg + only_fp8_arg + only_multi_gpu_arg + enc_dec_model_name_arg
|
||||
] + model_cache_arg + only_fp8_arg + only_multi_gpu_arg + enc_dec_model_name_arg + beams_arg
|
||||
|
||||
if model_name in ['gpt']:
|
||||
build_engines += ['--clean']
|
||||
@ -418,6 +421,7 @@ def prepare_model_tests(model_name: str,
|
||||
] + only_fp8_arg + only_multi_gpu_arg + enc_dec_model_name_arg
|
||||
if "enc_dec" in model_name:
|
||||
generate_expected_output += model_cache_arg
|
||||
generate_expected_output += beams_arg
|
||||
|
||||
if model_name in ['gpt']:
|
||||
generate_expected_output += ['--clean']
|
||||
@ -512,6 +516,7 @@ def run_single_gpu_tests(build_dir: _pl.Path,
|
||||
included_tests.append("BartBasicTest")
|
||||
if run_t5:
|
||||
included_tests.append("T5BasicTest")
|
||||
included_tests.append("T5Beam2Test")
|
||||
if run_redrafter:
|
||||
included_tests.append("ExplicitDraftTokens")
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "tensorrt_llm/runtime/workerPool.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <random>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
@ -26,13 +27,13 @@ TEST(WorkerPool, basic)
|
||||
WorkerPool pool(2);
|
||||
|
||||
auto fn = []() { return 12345; };
|
||||
auto resultFuture = pool.enqueue<std::function<int()>, int>(std::move(fn));
|
||||
auto resultFuture = pool.enqueue(fn);
|
||||
|
||||
auto fn2 = []() { return 12.345f; };
|
||||
auto f2 = pool.enqueue<std::function<float()>, float>(std::move(fn2));
|
||||
auto f2 = pool.enqueue(fn2);
|
||||
|
||||
auto fn3 = []() { return 40.78f; };
|
||||
auto f3 = pool.enqueue<std::function<float()>, float>(std::move(fn3));
|
||||
auto f3 = pool.enqueue(fn3);
|
||||
|
||||
auto r1 = resultFuture.get();
|
||||
auto r2 = f2.get();
|
||||
@ -68,4 +69,46 @@ TEST(WorkerPool, voidReturn)
|
||||
EXPECT_EQ(returnVal2, 10002);
|
||||
EXPECT_EQ(returnVal3, 10003);
|
||||
}
|
||||
|
||||
class WorkerPoolTest : public ::testing::TestWithParam<std::tuple<int, int>>
|
||||
{
|
||||
protected:
|
||||
void SetUp() override
|
||||
{
|
||||
mNumTasks = std::get<0>(GetParam());
|
||||
mNumWorkers = std::get<1>(GetParam());
|
||||
pool = std::make_unique<WorkerPool>(mNumWorkers);
|
||||
}
|
||||
|
||||
int mNumTasks;
|
||||
int mNumWorkers;
|
||||
std::unique_ptr<WorkerPool> pool;
|
||||
};
|
||||
|
||||
TEST_P(WorkerPoolTest, ScheduleTasks)
|
||||
{
|
||||
std::vector<std::future<void>> futures;
|
||||
std::random_device randomDevice;
|
||||
std::mt19937 generator(randomDevice());
|
||||
std::uniform_int_distribution<> distribution(1, 5);
|
||||
|
||||
for (int i = 0; i < mNumTasks; ++i)
|
||||
{
|
||||
futures.push_back(
|
||||
pool->enqueue([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(distribution(generator))); }));
|
||||
}
|
||||
|
||||
for (auto& f : futures)
|
||||
{
|
||||
f.get();
|
||||
}
|
||||
|
||||
// This is a smoke test to try and catch threading and synchronization issues by stress testing. No assertion.
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(WorkerPoolTests, WorkerPoolTest,
|
||||
::testing::Combine(::testing::Values(1, 2, 4, 8, 16, 32, 64, 128), // Range for number of tasks
|
||||
::testing::Values(1, 2, 4, 8, 16, 32, 64, 128) // Range for number of workers
|
||||
));
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@@ -1,3 +1,5 @@
import logging
import math
from pathlib import Path

@@ -18,15 +20,21 @@ def generate_examples():

    # Source paths
    script_dir = root_dir / "examples/high-level-api"
    script_paths = sorted(script_dir.glob("*.py"))
    script_paths = sorted(
        script_dir.glob("*.py"),
        # The autoPP example should be at the end since it is a preview example
        key=lambda x: math.inf if 'llm_auto_parallel' in x.stem else 0)

    # Destination paths
    doc_dir = root_dir / "docs/source/high-level-api-examples"
    doc_paths = [doc_dir / f"{path.stem}.rst" for path in script_paths]

    black_list = {'__init__.py', 'quickstart_example.py'}

    # Generate the example docs for each example script
    for script_path, doc_path in zip(script_paths, doc_paths):
        if script_path.name == '__init__.py':
        if script_path.name in black_list:
            logging.warning(f"Skipping HLAPI file: {script_path.name}")
            continue
        script_url = f"https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/high-level-api/{script_path.name}"
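The sort key above relies on `sorted` being stable: every script stem gets key `0` and keeps its original order, while the auto-parallel example gets `math.inf` and is pushed to the end. A small self-contained illustration (the stems mirror the example files added in this change):

```python
# Illustrative only: how the math.inf key moves the auto-parallel example last.
import math

stems = ['llm_auto_parallel', 'llm_generate', 'llm_generate_async', 'llm_quantization']
ordered = sorted(stems, key=lambda stem: math.inf if 'llm_auto_parallel' in stem else 0)
print(ordered)
# -> ['llm_generate', 'llm_generate_async', 'llm_quantization', 'llm_auto_parallel']
```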
@@ -32,6 +32,11 @@ We recommend checking out the [v0.12.0 tag](https://github.com/NVIDIA/TensorRT-L

2. Install [CUDA 12.5.1 Toolkit](https://developer.nvidia.com/cuda-12-5-1-download-archive?target_os=Windows&target_arch=x86_64). Use the Express Installation option. Installation may require a restart.

3. If using a conda environment, run the following command before installing TensorRT-LLM.
```bash
conda install -c conda-forge pyarrow
```

**Steps**

@@ -58,3 +63,11 @@ We recommend checking out the [v0.12.0 tag](https://github.com/NVIDIA/TensorRT-L

2. Build the model.
3. Deploy the model.

**Known Issue**

1. `OSError: exception: access violation reading 0x0000000000000000` during `import tensorrt_llm` or `trtllm-build`.

   This may be caused by an outdated Microsoft Visual C++ Redistributable version. Please install
   [the latest MSVC](https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170#latest-microsoft-visual-c-redistributable-version)
   and retry. Check the system path to make sure the latest version installed in `System32` is found first, and check dependencies to make sure no other package is using an outdated copy (e.g. the `pyarrow` package might contain an outdated MSVC DLL).
[Binary image files not shown in this diff: a previously tracked image (381 KiB) and docs/source/media/picture-08-20-2024.png, added as a new file (96 KiB).]
@@ -129,30 +129,13 @@ curl -X POST localhost:8000/v2/models/ensemble/generate -d \
We are working on a Python high-level API (HLAPI) for the LLM workflow, which is still in incubation and may change later.
Here we show you a preview of how it works and how to use it.

Note that the APIs are not stable and only support the few models. We appreciate your patience and understanding as we improve this API.
Note that the APIs are not stable and only support a [few models](https://nvidia.github.io/TensorRT-LLM/high-level-api-examples/introduction.html#hlapi-supported-model). We appreciate your patience and understanding as we improve this API.

Here is a simple example to show how to use the HLAPI with TinyLlama.
```python
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```{eval-rst}
.. literalinclude:: ../../examples/high-level-api/quickstart_example.py
    :language: python
    :linenos:
```
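The quick start above now pulls its code from `examples/high-level-api/quickstart_example.py` via `literalinclude` instead of inlining it, and that file is not shown in this diff. Based on the HLAPI usage in the examples added by this change, a minimal sketch of such a quick start (the actual file contents may differ) is:

```python
### Minimal HLAPI quick start sketch (assumed, not the literal quickstart_example.py)
from tensorrt_llm import LLM, SamplingParams

# The model argument accepts a HF model name or a path to a local HF model.
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

for output in llm.generate(prompts, sampling_params):
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")
```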
## Next Steps

@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.13.0.dev2024082000
tensorrt_llm==0.13.0.dev2024082700
datasets~=2.15.0
evaluate~=0.4.1
rouge_score~=0.1.2

@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.13.0.dev2024082000
tensorrt_llm==0.13.0.dev2024082700
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2

@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.13.0.dev2024082000
tensorrt_llm==0.13.0.dev2024082700
datasets~=2.14.5
evaluate~=0.4.1
protobuf

@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
tensorrt_llm==0.13.0.dev2024082000
tensorrt_llm==0.13.0.dev2024082700
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
@ -276,7 +276,9 @@ if __name__ == "__main__":
|
||||
tokenizer.pad_token_id).sum(dim=1)
|
||||
output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
|
||||
dim=1) - decoder_input_lengths
|
||||
print("--------------------------------------")
|
||||
print(
|
||||
f"------ HF beam = {args.num_beams} --------------------------------"
|
||||
)
|
||||
print("HF output_ids: ", output_ids)
|
||||
print("HF output text: ", hf_output_text)
|
||||
print("HF output generated lengths: ", output_gen_lengths)
|
||||
@ -315,7 +317,9 @@ if __name__ == "__main__":
|
||||
output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
|
||||
dim=1) - decoder_input_lengths
|
||||
|
||||
print("--------------------------------------")
|
||||
print(
|
||||
f"------ TRT-LLM beam = {args.num_beams} --------------------------------"
|
||||
)
|
||||
if 'encoder_output' in tllm_output:
|
||||
encoder_output = tllm_output['encoder_output']
|
||||
print_tensor('TRT-LLM encoder_output:', encoder_output)
|
||||
@ -351,7 +355,7 @@ if __name__ == "__main__":
|
||||
save_npy(encoder_output,
|
||||
'encoder_output') # [num_tokens, hidden_size]
|
||||
save_npy(
|
||||
output_ids, 'output_ids'
|
||||
output_ids, f'output_ids_beam{args.num_beams}'
|
||||
) # [batch_size, max_output_tokens], max_output_tokens = decoder_input_tokens + max_new_tokens
|
||||
|
||||
# simple accuracy check
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
transformers>=4.31.0
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
|
||||
# "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
|
||||
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
flax~=0.8.0
|
||||
# jax[cuda12_pip]~=0.4.19; platform_system != "Windows"
|
||||
jax~=0.4.19; platform_system == "Windows"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
datasets~=2.14.5
|
||||
rouge_score~=0.1.2
|
||||
evaluate~=0.4.1
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
datasets==2.14.6
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
examples/high-level-api/llm_auto_parallel.py (new file, 35 lines)
@@ -0,0 +1,35 @@
|
||||
### Automatic Parallelism with LLM
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
|
||||
|
||||
def main():
|
||||
llm = LLM(
|
||||
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
|
||||
# Enable auto parallelism
|
||||
auto_parallel=True,
|
||||
world_size=2)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
for output in llm.generate(prompts, sampling_params):
|
||||
print(
|
||||
f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
|
||||
)
|
||||
|
||||
|
||||
# Got output like
|
||||
# Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
|
||||
# Prompt: 'The president of the United States is', Generated text: 'likely to nominate a new Supreme Court justice to fill the seat vacated by the death of Antonin Scalia. The Senate should vote to confirm the'
|
||||
# Prompt: 'The capital of France is', Generated text: 'Paris.'
|
||||
# Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,250 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import click
|
||||
import torch
|
||||
|
||||
from tensorrt_llm import LLM
|
||||
from tensorrt_llm.hlapi import KvCacheConfig
|
||||
from tensorrt_llm.hlapi.llm import SamplingParams
|
||||
from tensorrt_llm.hlapi.llm_utils import KvCacheConfig, QuantAlgo, QuantConfig
|
||||
from tensorrt_llm.hlapi.utils import get_device_count
|
||||
|
||||
# NOTE, Currently, the following examples are only available for LLaMA models.
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
pass
|
||||
|
||||
|
||||
@click.command('run_llm_generate')
|
||||
@click.option('--prompt', type=str, default="What is LLM?")
|
||||
@click.option('--model_dir', type=str, help='The directory of the model.')
|
||||
@click.option('--engine_dir',
|
||||
type=str,
|
||||
help='The directory of the engine.',
|
||||
default=None)
|
||||
@click.option('--tp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The number of GPUs for Tensor Parallel.')
|
||||
@click.option('--pp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The number of GPUs for Pipeline Parallel.')
|
||||
@click.option('--prompt_is_digit',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='Whether the prompt is a list of integers.')
|
||||
def run_llm_generate(
|
||||
prompt: str,
|
||||
model_dir: str,
|
||||
engine_dir: Optional[str] = None,
|
||||
tp_size: int = 1,
|
||||
pp_size: int = 1,
|
||||
prompt_is_digit: bool = False,
|
||||
end_id: int = 2,
|
||||
):
|
||||
''' Running LLM with arbitrary model formats including:
|
||||
- HF model
|
||||
- TRT-LLM checkpoint
|
||||
- TRT-LLM engine
|
||||
|
||||
It will dump the engine to `engine_dir` if specified.
|
||||
|
||||
Args:
|
||||
prompts: A list of prompts. Each prompt can be either a string or a list of integers when tokenizer is disabled.
|
||||
model_dir: The directory of the model.
|
||||
engine_dir: The directory of the engine, if specified different than model_dir then it will save the engine to `engine_dir`.
|
||||
tp_size: The number of GPUs for Tensor Parallel.
|
||||
pp_size: The number of GPUs for Pipeline Parallel.
|
||||
'''
|
||||
|
||||
# Avoid the tp_size and pp_size setting override the ones loaded from built engine
|
||||
world_size = tp_size * pp_size
|
||||
if get_device_count() < world_size:
|
||||
print(
|
||||
"Skip the example for TP!!! Since the number of GPUs is less than required"
|
||||
)
|
||||
return
|
||||
if world_size > 1:
|
||||
print(f'Running LLM with Tensor Parallel on {tp_size} GPUs.')
|
||||
|
||||
llm = LLM(model_dir,
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size)
|
||||
|
||||
if engine_dir and os.path.abspath(model_dir) != os.path.abspath(engine_dir):
|
||||
print(f"Saving engine to {engine_dir}...")
|
||||
llm.save(engine_dir)
|
||||
|
||||
prompts = parse_prompts(prompt, prompt_is_digit)
|
||||
|
||||
sampling_params = SamplingParams(end_id=end_id,
|
||||
pad_id=end_id) if prompt_is_digit else None
|
||||
|
||||
for output in llm.generate(prompts, sampling_params=sampling_params):
|
||||
print("OUTPUT:", output)
|
||||
|
||||
|
||||
@click.command('run_llm_generate_async_example')
|
||||
@click.option('--prompt', type=str, default="What is LLM?")
|
||||
@click.option('--model_dir', type=str, help='The directory of the model.')
|
||||
@click.option('--streaming',
|
||||
is_flag=True,
|
||||
help='Whether to enable streaming generation.')
|
||||
@click.option('--tp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The number of GPUs for Tensor Parallel.')
|
||||
@click.option('--pp_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The number of GPUs for Pipeline Parallel.')
|
||||
def run_llm_generate_async_example(prompt: str,
|
||||
model_dir: str,
|
||||
streaming: bool = False,
|
||||
tp_size: int = 1,
|
||||
pp_size: int = 1):
|
||||
''' Running LLM generation asynchronously. '''
|
||||
|
||||
if get_device_count() < tp_size:
|
||||
print(
|
||||
"Skip the example for TP!!! Since the number of GPUs is less than required"
|
||||
)
|
||||
return
|
||||
if tp_size > 1:
|
||||
print(f'Running LLM with Tensor Parallel on {tp_size} GPUs.')
|
||||
|
||||
# Avoid the tp_size and pp_size setting override the ones loaded from built engine
|
||||
llm = LLM(model_dir,
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4))
|
||||
prompts = parse_prompts(prompt, False)
|
||||
|
||||
async def task(prompt: str):
|
||||
outputs = []
|
||||
async for output in llm.generate_async(prompt, streaming=streaming):
|
||||
outputs.append(output.outputs[0].text)
|
||||
print(' '.join(outputs))
|
||||
|
||||
async def main():
|
||||
tasks = [task(prompt) for prompt in prompts]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
@click.command('run_llm_with_quantization')
|
||||
@click.option('--prompt', type=str, default="What is LLM?")
|
||||
@click.option('--model_dir', type=str, help='The directory of the model.')
|
||||
@click.option('--quant_type',
|
||||
type=str,
|
||||
default='int4_awq',
|
||||
help='The quantization type.')
|
||||
def run_llm_with_quantization(prompt: str, model_dir: str, quant_type: str):
|
||||
''' Running LLM with quantization.
|
||||
quant_type could be 'int4_awq' or 'fp8'.
|
||||
'''
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if not (major >= 8):
|
||||
print("Quantization currently only supported on post Ampere")
|
||||
return
|
||||
|
||||
if 'fp8' in quant_type:
|
||||
if not (major > 8):
|
||||
print("Hopper GPUs are required for fp8 quantization")
|
||||
return
|
||||
|
||||
quant_config = QuantConfig()
|
||||
if quant_type == 'int4_awq':
|
||||
quant_config.quant_algo = QuantAlgo.W4A16_AWQ
|
||||
else:
|
||||
quant_config.quant_algo = QuantAlgo.FP8
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
|
||||
llm = LLM(model_dir, quant_config=quant_config)
|
||||
prompts = parse_prompts(prompt, False)
|
||||
|
||||
for output in llm.generate(prompts):
|
||||
print(output)
|
||||
|
||||
|
||||
@click.command('run_llm_with_async_future')
|
||||
@click.option('--prompt', type=str, default="What is LLM?")
|
||||
@click.option('--model_dir', type=str, help='The directory of the model.')
|
||||
def run_llm_with_async_future(prompt: str, model_dir: str):
|
||||
llm = LLM(model_dir,
|
||||
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4))
|
||||
|
||||
prompts = parse_prompts(prompt)
|
||||
# The result of generate() is similar to a Future, it won't block the main thread, call .result() to explicitly wait for the result
|
||||
futures = [llm.generate_async(prompt) for prompt in prompts]
|
||||
for future in futures:
|
||||
# .result() is a blocking call, call it when you want to wait for the result
|
||||
output = future.result()
|
||||
print(output.outputs[0].text)
|
||||
|
||||
# Similar to .result(), there is an async version of .result(), which is .aresult(), and it works with the generate_async().
|
||||
async def task(prompt: str):
|
||||
generation = llm.generate_async(prompt, streaming=False)
|
||||
output = await generation.aresult()
|
||||
print(output.outputs[0].text)
|
||||
|
||||
async def main():
|
||||
tasks = [task(prompt) for prompt in prompts]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
@click.command('run_llm_with_auto_parallel')
|
||||
@click.option('--prompt', type=str, default="What is LLM?")
|
||||
@click.option('--model_dir', type=str, help='The directory of the model.')
|
||||
@click.option('--world_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The number of GPUs for Auto Parallel.')
|
||||
def run_llm_with_auto_parallel(prompt: str,
|
||||
model_dir: str,
|
||||
world_size: int = 1):
|
||||
''' Running LLM with auto parallel enabled. '''
|
||||
if get_device_count() < world_size:
|
||||
print(
|
||||
"Skip the example for auto parallel!!! Since the number of GPUs is less than required"
|
||||
)
|
||||
return
|
||||
if world_size > 1:
|
||||
print(f'Running LLM with Auto Parallel on {world_size} GPUs.')
|
||||
|
||||
llm = LLM(
|
||||
model_dir,
|
||||
auto_parallel=True,
|
||||
world_size=world_size,
|
||||
)
|
||||
prompts = parse_prompts(prompt)
|
||||
|
||||
for output in llm.generate(prompts):
|
||||
print(output)
|
||||
|
||||
|
||||
def parse_prompts(prompt: str, is_digit: bool = False) -> Union[str, List[int]]:
|
||||
''' Process a single prompt. '''
|
||||
if is_digit:
|
||||
return [[int(i) for i in prompt.split()]]
|
||||
else:
|
||||
return [prompt]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli.add_command(run_llm_generate)
|
||||
cli.add_command(run_llm_generate_async_example)
|
||||
cli.add_command(run_llm_with_quantization)
|
||||
cli.add_command(run_llm_with_async_future)
|
||||
cli.add_command(run_llm_with_auto_parallel)
|
||||
cli()
|
||||
examples/high-level-api/llm_generate.py (new file, 32 lines)
@@ -0,0 +1,32 @@
|
||||
### Generate text
|
||||
import tempfile
|
||||
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
|
||||
# Model could accept HF model name or a path to local HF model.
|
||||
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||
|
||||
# You can save the engine to disk and load it back later, the LLM class can accept either a HF model or a TRT-LLM engine.
|
||||
llm.save(tempfile.mkdtemp())
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create a sampling params.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
for output in llm.generate(prompts, sampling_params):
|
||||
print(
|
||||
f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
|
||||
)
|
||||
|
||||
# Got output like
|
||||
# Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
|
||||
# Prompt: 'The president of the United States is', Generated text: 'likely to nominate a new Supreme Court justice to fill the seat vacated by the death of Antonin Scalia. The Senate should vote to confirm the'
|
||||
# Prompt: 'The capital of France is', Generated text: 'Paris.'
|
||||
# Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'
|
||||
examples/high-level-api/llm_generate_async.py (new file, 40 lines)
@@ -0,0 +1,40 @@
|
||||
### Generate Text Asynchronously
|
||||
import asyncio
|
||||
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
|
||||
# model could accept HF model name or a path to local HF model.
|
||||
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create a sampling params.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
|
||||
# Async based on Python coroutines
|
||||
async def task(prompt: str):
|
||||
output = await llm.generate_async(prompt, sampling_params)
|
||||
print(
|
||||
f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
tasks = [task(prompt) for prompt in prompts]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
# Got output like follows:
|
||||
# Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
|
||||
# Prompt: 'The president of the United States is', Generated text: 'likely to nominate a new Supreme Court justice to fill the seat vacated by the death of Antonin Scalia. The Senate should vote to confirm the'
|
||||
# Prompt: 'The capital of France is', Generated text: 'Paris.'
|
||||
# Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'
|
||||
examples/high-level-api/llm_generate_async_streaming.py (new file, 59 lines)
@@ -0,0 +1,59 @@
|
||||
### Generate Text in Streaming
|
||||
import asyncio
|
||||
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
|
||||
# model could accept HF model name or a path to local HF model.
|
||||
llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create a sampling params.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
|
||||
# Async based on Python coroutines
|
||||
async def task(id: int, prompt: str):
|
||||
|
||||
# streaming=True is used to enable streaming generation.
|
||||
async for output in llm.generate_async(prompt,
|
||||
sampling_params,
|
||||
streaming=True):
|
||||
print(f"Generation for prompt-{id}: {output.outputs[0].text!r}")
|
||||
|
||||
|
||||
async def main():
|
||||
tasks = [task(id, prompt) for id, prompt in enumerate(prompts)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
# Got output like follows:
|
||||
# Generation for prompt-0: '\n'
|
||||
# Generation for prompt-3: 'an'
|
||||
# Generation for prompt-2: 'Paris'
|
||||
# Generation for prompt-1: 'likely'
|
||||
# Generation for prompt-0: '\n\n'
|
||||
# Generation for prompt-3: 'an exc'
|
||||
# Generation for prompt-2: 'Paris.'
|
||||
# Generation for prompt-1: 'likely to'
|
||||
# Generation for prompt-0: '\n\nJ'
|
||||
# Generation for prompt-3: 'an exciting'
|
||||
# Generation for prompt-2: 'Paris.'
|
||||
# Generation for prompt-1: 'likely to nomin'
|
||||
# Generation for prompt-0: '\n\nJane'
|
||||
# Generation for prompt-3: 'an exciting time'
|
||||
# Generation for prompt-1: 'likely to nominate'
|
||||
# Generation for prompt-0: '\n\nJane Smith'
|
||||
# Generation for prompt-3: 'an exciting time for'
|
||||
# Generation for prompt-1: 'likely to nominate a'
|
||||
# Generation for prompt-0: '\n\nJane Smith.'
|
||||
# Generation for prompt-3: 'an exciting time for us'
|
||||
# Generation for prompt-1: 'likely to nominate a new'
|
||||
examples/high-level-api/llm_generate_distributed.py (new file, 40 lines)
@@ -0,0 +1,40 @@
|
||||
### Distributed LLM Generation
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
|
||||
|
||||
def main():
|
||||
# model could accept HF model name or a path to local HF model.
|
||||
llm = LLM(
|
||||
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
# Distributed settings
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create a sampling params.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
for output in llm.generate(prompts, sampling_params):
|
||||
print(
|
||||
f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
|
||||
)
|
||||
|
||||
# Got output like
|
||||
# Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
|
||||
# Prompt: 'The president of the United States is', Generated text: 'likely to nominate a new Supreme Court justice to fill the seat vacated by the death of Antonin Scalia. The Senate should vote to confirm the'
|
||||
# Prompt: 'The capital of France is', Generated text: 'Paris.'
|
||||
# Prompt: 'The future of AI is', Generated text: 'an exciting time for us. We are constantly researching, developing, and improving our platform to create the most advanced and efficient model available. We are'
|
||||
|
||||
|
||||
# Due to the requirement of the underlying mpi4py, for multi-gpu, the main function must be placed inside the
|
||||
# `if __name__ == '__main__':` block.
|
||||
# Refer to https://mpi4py.readthedocs.io/en/stable/mpi4py.futures.html#mpipoolexecutor
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
examples/high-level-api/llm_quantization.py (new file, 51 lines)
@@ -0,0 +1,51 @@
|
||||
### Generation with Quantization
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm.hlapi import QuantAlgo, QuantConfig
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
post_ada = major > 8 or (major == 8 and minor >= 9)
|
||||
|
||||
quant_configs = [
|
||||
QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ),
|
||||
]
|
||||
|
||||
if post_ada:
|
||||
quant_configs.append(
|
||||
QuantConfig(quant_algo=QuantAlgo.FP8,
|
||||
kv_cache_quant_algo=QuantAlgo.FP8))
|
||||
else:
|
||||
logging.error(
|
||||
"FP8 quantization only works on post-ada GPUs, skipped in the example.")
|
||||
|
||||
for quant_config in quant_configs:
|
||||
|
||||
llm = LLM(
|
||||
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
# define the quantization config to trigger built-in end-to-end quantization.
|
||||
quant_config=quant_config)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create a sampling params.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
for output in llm.generate(prompts, sampling_params):
|
||||
print(
|
||||
f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
|
||||
)
|
||||
|
||||
# Got output like
|
||||
# Prompt: 'Hello, my name is', Generated text: 'Jane Smith. I am a resident of the city. Can you tell me more about the public services provided in the area?'
|
||||
# Prompt: 'The president of the United States is', Generated text: 'considered the head of state, and the vice president of the United States is considered the head of state. President and Vice President of the United States (US)'
|
||||
# Prompt: 'The capital of France is', Generated text: 'located in Paris, France. The population of Paris, France, is estimated to be 2 million. France is home to many famous artists, including Picasso'
|
||||
# Prompt: 'The future of AI is', Generated text: 'an open and collaborative project. The project is an ongoing effort, and we invite participation from members of the community.\n\nOur community is'
|
||||
@ -1,2 +1,2 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
|
||||
@ -1,127 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from llm_examples import *
|
||||
|
||||
from tensorrt_llm.hlapi.utils import print_colored
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
pass
|
||||
|
||||
|
||||
@click.command('run_single_gpu')
|
||||
@click.option('--prompt', type=str, default="What is LLM?")
|
||||
@click.option('--model_dir', type=str, help='The directory of the model.')
|
||||
@click.option('--examples_root',
|
||||
type=str,
|
||||
help='The root directory of the examples.')
|
||||
@click.option('--llm_examples',
|
||||
type=str,
|
||||
help='The path to the llm_examples.py.',
|
||||
default='llm_examples.py')
|
||||
@click.option('--engine_dir',
|
||||
type=str,
|
||||
help='The directory of the engine.',
|
||||
default="/tmp/hlapi.engine.example")
|
||||
def run_single_gpu(
|
||||
prompt: str,
|
||||
model_dir: str,
|
||||
examples_root: str,
|
||||
llm_examples: str,
|
||||
engine_dir: str,
|
||||
):
|
||||
run_example(
|
||||
"Running LLM from HuggingFace model",
|
||||
f"{sys.executable} {llm_examples} run_llm_generate --prompt=\"{prompt}\" --model_dir={model_dir} --engine_dir={engine_dir}"
|
||||
)
|
||||
|
||||
run_example(
|
||||
"Running LLM from built engine with streaming enabled",
|
||||
f"{sys.executable} {llm_examples} run_llm_generate_async_example --prompt=\"{prompt}\" --model_dir={engine_dir} --streaming"
|
||||
)
|
||||
|
||||
run_example(
|
||||
"Running LLM with async future",
|
||||
f"{sys.executable} {llm_examples} run_llm_with_async_future --prompt=\"{prompt}\" --model_dir={engine_dir}"
|
||||
)
|
||||
|
||||
|
||||
@click.command("run_multi_gpu")
|
||||
@click.option('--prompt', type=str, default="What is LLM?")
|
||||
@click.option('--model_dir', type=str, help='The directory of the model.')
|
||||
@click.option('--examples_root',
|
||||
type=str,
|
||||
help='The root directory of the examples.')
|
||||
@click.option('--llm_examples',
|
||||
type=str,
|
||||
help='The path to the llm_examples.py.',
|
||||
default='llm_examples.py')
|
||||
@click.option('--engine_dir',
|
||||
type=str,
|
||||
help='The directory of the engine.',
|
||||
default="/tmp/hlapi.engine.example")
|
||||
@click.option('--run_autopp',
|
||||
type=bool,
|
||||
help='Whether to run with auto parallel.',
|
||||
default=True)
|
||||
def run_multi_gpu(
|
||||
prompt: str,
|
||||
model_dir: str,
|
||||
examples_root: str,
|
||||
llm_examples: str,
|
||||
engine_dir: str,
|
||||
run_autopp: bool = True,
|
||||
):
|
||||
run_example(
|
||||
"Running LLM from HuggingFace model with TP enabled",
|
||||
f"{sys.executable} {llm_examples} run_llm_generate --prompt=\"{prompt}\" --model_dir={model_dir} --tp_size=2 --engine_dir={engine_dir}.tp2"
|
||||
)
|
||||
|
||||
run_example(
|
||||
"Running LLM from built engine with streaming enabled and TP=2",
|
||||
f"{sys.executable} {llm_examples} run_llm_generate_async_example --prompt=\"{prompt}\" --model_dir={engine_dir}.tp2 --streaming"
|
||||
) # Loading the engine with TP=2.
|
||||
|
||||
if run_autopp:
|
||||
run_example(
|
||||
"Running LLM with auto parallel",
|
||||
f"{sys.executable} {llm_examples} run_llm_with_auto_parallel --prompt=\"{prompt}\" --model_dir={model_dir} --world_size=2"
|
||||
)
|
||||
|
||||
|
||||
@click.command("run_quant")
|
||||
@click.option('--prompt', type=str, default="What is LLM?")
|
||||
@click.option('--model_dir', type=str, help='The directory of the model.')
|
||||
@click.option('--examples_root',
|
||||
type=str,
|
||||
help='The root directory of the examples.')
|
||||
@click.option('--llm_examples',
|
||||
type=str,
|
||||
help='The path to the llm_examples.py.',
|
||||
default='llm_examples.py')
|
||||
def run_quant(
|
||||
prompt: str,
|
||||
model_dir: str,
|
||||
examples_root: str,
|
||||
llm_examples: str,
|
||||
):
|
||||
run_example(
|
||||
"Running LLM with quantization",
|
||||
f"{sys.executable} {llm_examples} run_llm_with_quantization --quant_type=int4_awq --prompt=\"{prompt}\" --model_dir={model_dir}"
|
||||
)
|
||||
|
||||
|
||||
def run_example(hint: str, command: str):
|
||||
print_colored(hint + "\n", "bold_green")
|
||||
print(command)
|
||||
subprocess.run(command, shell=True, check=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli.add_command(run_single_gpu)
|
||||
cli.add_command(run_multi_gpu)
|
||||
cli.add_command(run_quant)
|
||||
cli()
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
datasets==2.14.5
|
||||
rouge_score~=0.1.2
|
||||
sentencepiece~=0.1.99
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
datasets==2.14.6
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
transformers>=4.39.0
|
||||
datasets~=2.14.5
|
||||
evaluate
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
datasets~=2.14.5
|
||||
rouge_score~=0.1.2
|
||||
sentencepiece~=0.1.99
|
||||
|
||||
@@ -126,6 +126,7 @@ of the different top-k values.

- 0 (NONE) corresponds to: `scales = topk(softmax(routing values))`
- 1 (RENORM) corresponds to: `scales = softmax(topk(routing values))`
- 2 (SPARSE_MIXER) corresponds to: `scales = sparsemixer(routing values)`

Mixtral uses `RENORM` mode, which is set as the default. To use a different mode, use the `--moe_normalization_mode` flag.
See [tensorrt_llm/layers/moe.py](../../tensorrt_llm/layers/moe.py#L56) for the available values.
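To make the difference between the first two modes concrete, here is a small self-contained sketch (illustrative only, not code from the repository; `SPARSE_MIXER` is omitted since its formula is not spelled out here):

```python
# Illustrative sketch of the NONE vs. RENORM top-k normalization for one token.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, -1.0])  # hypothetical routing values for 4 experts
k = 2
topk_idx = np.argsort(logits)[-k:]        # indices of the top-k experts

# 0 (NONE): softmax over all experts, then keep the top-k probabilities as-is.
scales_none = softmax(logits)[topk_idx]

# 1 (RENORM, the Mixtral default): keep the top-k logits, then softmax over just those.
scales_renorm = softmax(logits[topk_idx])

print(scales_none, scales_none.sum())      # does not sum to 1
print(scales_renorm, scales_renorm.sum())  # renormalized, sums to 1
```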
@ -1,4 +1,4 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
transformers==4.38.2
|
||||
accelerate==0.25.0
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
--extra-index-url https://pypi.nvidia.com
|
||||
tensorrt_llm==0.13.0.dev2024082000
|
||||
tensorrt_llm==0.13.0.dev2024082700
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.