Update TensorRT-LLM (#2016)

Kaiyu Xie 2024-07-24 19:50:28 +08:00 committed by GitHub
parent 0d5ffae9a7
commit 5fa9436e17
71 changed files with 400 additions and 375 deletions


@@ -21,7 +21,7 @@ TensorRT-LLM
 🦙 400 tok/s - per node
 🦙 37 tok/s - per user
 🦙 1 node inference
-➡️ [link](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms/?ncid=so-twit-317976%E2%9C%A8)
+➡️ [link](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms)
 <div align="center">
 <img src="docs/source/media/picture-07-23-2024.png" width="45%">
 <div align="left">


@@ -86,6 +86,7 @@ auto constexpr kLoraWeights = "lora_weights";
 // "moe_4h_to_h": 14 # for mixtral adapter for expert mlp layer: down projection
 // "moe_gate": 15 # for mixtral adapter for expert mlp layer: gate
 // "moe_router": 16 # for mixtral adapter for expert router layer
+// "mlp_router": 17 # for qwen2-moe adapter for shared expert gate layer
 //
 // last dim holds [ module_id, layer_idx, adapter_size (D / R value) ]
 auto constexpr kLoraConfig = "lora_config"; // [num_lora_modules_layers, 3]
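For context on the tensor layout documented above, a minimal NumPy sketch follows; the module id 17 comes from this change, while the layer index and adapter size are made-up example values.

```python
import numpy as np

# Illustrative only: one row per (module, layer); the last dim holds
# [module_id, layer_idx, adapter_size]. 17 is the new "mlp_router" id.
lora_config = np.array([[17, 0, 8]], dtype=np.int32)
assert lora_config.shape == (1, 3)  # [num_lora_modules_layers, 3]
```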


@@ -48,6 +48,7 @@ public:
        kMOE_4H_TO_H = 14,
        kMOE_GATE = 15,
        kMOE_ROUTER = 16,
+       kMLP_ROUTER = 17,
    };
    explicit constexpr LoraModule(ModuleType const& t, SizeType32 inDim, SizeType32 outDim, bool inDimFirst,
@@ -216,6 +217,8 @@ public:
            return ModuleType::kMOE_GATE;
        else if (name == "moe_router")
            return ModuleType::kMOE_ROUTER;
+       else if (name == "mlp_router")
+           return ModuleType::kMLP_ROUTER;
        else
            return ModuleType::kINVALID;
    }
@@ -241,6 +244,7 @@ public:
        case ModuleType::kMOE_4H_TO_H: return "moe_4h_to_h";
        case ModuleType::kMOE_GATE: return "moe_gate";
        case ModuleType::kMOE_ROUTER: return "moe_router";
+       case ModuleType::kMLP_ROUTER: return "mlp_router";
        case ModuleType::kINVALID: return "INVALID";
        }
        return "INVALID";


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0f132408402aeb54b82891673aa3050811d69ec264399ed5f8d4f7a5cc63e2d8
-size 4293074
+oid sha256:3e25541cdc2aaa48f6a6e4c386d22ca1832c8e120fc6e8c190db4ee066ebfb1f
+size 4293186


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9f42dae7c82f4c59dc973f2e9f72d41d6f2e0e68b04c12d14f095b647890af86
-size 4395714
+oid sha256:3108cd0580f6328bd46238ef708872d9d8030a9c8645b8b52bc750dfe094bc16
+size 4395794


@@ -1,3 +1,3 @@
-f2252f27a20618d3b7abe865c5192045 libtensorrt_llm_batch_manager_static.a
-ce8405cc0d369bf4fd79d30eef5ad9ed libtensorrt_llm_batch_manager_static.pre_cxx11.a
-3706e7395b9b58994412617992727c8ff2d14c9f commit
+50a839e98b31729198870fc99ef2c5a9 libtensorrt_llm_batch_manager_static.a
+a39a5bf618c8514725b59aac4513223f libtensorrt_llm_batch_manager_static.pre_cxx11.a
+3511a2653f2ba73f6f827aca6d2850b3d3e8e543 commit


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b875b0b7c85fd18492865f6db704be09c55e823d322b65e4af58359d0576ad0a
-size 4154538
+oid sha256:9600435f1b9ab74c752d1831e1a6684a004927c84ab7c61fc076dbc128ca1521
+size 4154674


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c597863e8910b4ef0be61961598653813c0673949dd59b8938a1d6f231ad878e
-size 4133066
+oid sha256:8145ecf59dea64448ca0969553d32bc99e119cc5fc703e7b47eccfb5886594a0
+size 4133178


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2e06ed93f0745bc9196414f36de6ff1d98069110027e1dc95530b2a9be82176e
-size 24008762
+oid sha256:f89f551a880f4c6c1e68ed72b951ac482dec6033e55a336a0ecc401f4e9cf150
+size 24009160


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:510394b3e137e08b7292e68e445ccb9ed6986748e639b032413ad56f265078cb
+oid sha256:33f259b374a02456f2b8d44571d92195b708c2011be4ecabe46267f49ca24c29
 size 1426724


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a4a76bfb7611d6a7ef3c8e4a9e191eb4235739b5cf5e2642ac102c03e87c7e44
+oid sha256:f44786aee0842bdb260de49b734d2119a0521c650f0b733f5ce6f997e72bfb34
 size 1452984


@@ -1,3 +1,3 @@
-93f42e0f10a6efb28073513b8a9c4471 libtensorrt_llm_executor_static.a
-533416c32056580e0e21ac5f771f3371 libtensorrt_llm_executor_static.pre_cxx11.a
-3706e7395b9b58994412617992727c8ff2d14c9f commit
+0d5e559ebc885794ab9e63086ae7a18a libtensorrt_llm_executor_static.a
+f9a3d1bf32f33f88569d4d8635e5445a libtensorrt_llm_executor_static.pre_cxx11.a
+3511a2653f2ba73f6f827aca6d2850b3d3e8e543 commit


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:378f06d390a108fc8fcdf34bd390e70ed98102c0b36647e292d49a9f680867a6
+oid sha256:19bd908d16990cd11a295fcb71403e2ad285dc2c3b84d55228166d9240acd0d9
 size 1476318


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d448b5a52066c61d73bea45b08561ecdfbb5aa49d46dd255fb714e7e0aa0ab41
+oid sha256:bed0b93d23eef43ce46c01e694f9e578c64fe9b30e1b05d65b7feed1a41e5148
 size 1408208


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4c4287066210a2511a5de3984be9d97318aa34fc0cc16c685deb342778d4f777
+oid sha256:473c672353cb813af9ea65250bd79f61f5ea27c369c9f35bc3bace1e22c5e9bb
 size 14325956


@@ -1,2 +1,2 @@
 28ead889239ca8d558c1e1a93f0485b0 libtensorrt_llm_nvrtc_wrapper.so
-3706e7395b9b58994412617992727c8ff2d14c9f commit
+3511a2653f2ba73f6f827aca6d2850b3d3e8e543 commit


@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c116381592aea6404e15ace64a69425b35e59492c074920f867370c280c6ea93
+oid sha256:20824706210bf184641c92fcb728ab0a3a74a36bc0b13e243c713a84c74a51ac
 size 1089536


@@ -60,7 +60,7 @@ enum class RotaryScalingType : int8_t
    kLINEAR = 1,
    kDYNAMIC = 2,
    kLONG = 3,
-   kWAVELEN = 4
+   kLLAMA3 = 4
 };
 struct BlockSparseParams


@@ -58,7 +58,8 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
        case ModuleType::kMOE_GATE:
        case ModuleType::kMOE_4H_TO_H:
        case ModuleType::kMOE_ROUTER:
-       case ModuleType::kINVALID: throw std::runtime_error("Invalid loRA module " + moduleName);
+       case ModuleType::kMLP_ROUTER:
+       case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName);
        }
    }
    return modules;


@@ -112,6 +112,7 @@ The following tensors are for a LoRA which has a `q` and `k` adapter.
 | moe_4h_to_h | 14 | for mixtral adapter for expert mlp layer: down projection |
 | moe_gate | 15 | for mixtral adapter for expert mlp layer: gate |
 | moe_router | 16 | for mixtral adapter for expert router layer |
+| mlp_router | 17 | for qwen2-moe adapter for shared expert gate layer |
 #### LoraCache configuration

Binary file not shown.


Binary file not shown.



@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.15.0
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 evaluate~=0.4.1
 protobuf


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 transformers>=4.31.0
 datasets~=2.14.5
 evaluate~=0.4.1


@@ -3,7 +3,7 @@
 # WAR the new posting of "nvidia-cudnn-cu12~=9.0".
 # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
 nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 flax~=0.8.0
 # jax[cuda12_pip]~=0.4.19; platform_system != "Windows"
 jax~=0.4.19; platform_system == "Windows"


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 rouge_score~=0.1.2
 evaluate~=0.4.1


@@ -1,6 +1,6 @@
 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets==2.14.6
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,2 +1,2 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets==2.14.5
 rouge_score~=0.1.2
 sentencepiece~=0.1.99


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1209,7 +1209,7 @@ Note that the sink tokens is included in the sliding attention tokens, and there
 ## Run LLaMA-3.1 405B Model
-Currently, TensorRT-LLM supports Meta checkpoint and Huggingface checkpoint for LLaMA-3.1. In this section, we demonstrate how to run the LLaMA-3.1 405B model via TensorRT-LLM. Here, we assume users have downloaded the checkpoints and placed them at `llama_3.1_405B_meta_model/` (Meta checkpoint) and `llama_3.1_405B_HF_model/` (HF checkpoint). Before converting the checkpoints to TensorRT-LLM unified checkpoints, **please check that `"use_scaled_rope": true` is set in the configuration file**. With this flag, TensorRT-LLM will enable the rope scaling of LLaMA-3.1. If not, please add it to the config file.
+Currently, TensorRT-LLM supports Meta checkpoint and Huggingface checkpoint for LLaMA-3.1. In this section, we demonstrate how to run the LLaMA-3.1 405B model via TensorRT-LLM. Here, we assume users have downloaded the checkpoints and placed them at `llama_3.1_405B_meta_model/` (Meta BF16 checkpoint), `llama_3.1_405B_HF_model/` (HF BF16 checkpoint) and `llama_3.1_405B_HF_FP8_model/` (HF FP8 checkpoint). Before converting the checkpoints to TensorRT-LLM unified checkpoints, **please check that `{"rope_scaling": {"rope_type": "llama3"}}` is set in the configuration file**. With this flag, TensorRT-LLM will enable the rope scaling of LLaMA-3.1. If not, please add it to the config file.
 Users can run the LLaMA-3.1 model with higher precision (bf16/fp16) or fp8. Here, to prevent accuracy drop, we perform per-channel per-token fp8 quantization (leveraged from https://github.com/pytorch/FBGEMM) on MLP layers, keeping other layers at higher precision. Note that fp8 quantization is only supported on Huggingface checkpoint now. We will support it on Meta checkpoint soon.
@@ -1217,7 +1217,10 @@ Users can run the LLaMA-3.1 model with higher precision (bf16/fp16) or fp8. Here
 To use the fp8 quantization, please add the `--use_fp8_rowwise` flag during the checkpoint conversion. In this demonstration, we convert the Meta checkpoint to bfloat16 with TP8-PP2 and the HF checkpoint to FP8 with TP8.
+Note that you may need to update your transformers installation via `pip install --upgrade transformers`.
 ```bash
+# Run BF16 model by BF16
 python examples/llama/convert_checkpoint.py --meta_ckpt_dir llama_3.1_405B_meta_model/ \
             --output_dir llama_3.1_405B_meta_model/trt_ckpts/tp8-pp2/ \
             --dtype bfloat16 \
@@ -1226,6 +1229,7 @@ python examples/llama/convert_checkpoint.py --meta_ckpt_dir llama_3.1_405B_meta_
             --load_by_shard \
             --workers 8
+# Run BF16 model by FP8
 python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_model/ \
             --output_dir llama_3.1_405B_HF_model/trt_ckpts/tp8-pp1/ \
             --dtype bfloat16 \
@@ -1234,6 +1238,15 @@ python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_model/
             --pp_size 1 \
             --load_by_shard \
             --workers 8
+# Run FP8 model by FP8
+python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_FP8_model/ \
+            --output_dir llama_3.1_405B_HF_FP8_model/trt_ckpts/tp8-pp1/ \
+            --dtype bfloat16 \
+            --tp_size 8 \
+            --pp_size 1 \
+            --load_by_shard \
+            --workers 8
 ```
 ### Build Engine
@@ -1254,6 +1267,14 @@ trtllm-build --checkpoint_dir llama_3.1_405B_HF_model/trt_ckpts/tp8-pp1/ \
             --max_seq_len 65000 \
             --use_paged_context_fmha enable \
             --workers 8
+trtllm-build --checkpoint_dir llama_3.1_405B_HF_FP8_model/trt_ckpts/tp8-pp1/ \
+            --output_dir llama_3.1_405B_HF_FP8_model/trt_engines/tp8-pp1/ \
+            --max_num_tokens 4096 \
+            --max_input_len 64000 \
+            --max_seq_len 65000 \
+            --use_paged_context_fmha enable \
+            --workers 8
 ```
 ### Run Inference
@@ -1298,7 +1319,7 @@ srun --mpi pmi2 -N 1 -n 8 --ntasks-per-node 8 --container-image <your container>
     --container-name llama-3.1-405b \
     --container-workdir <your container work directory> \
     bash -c 'python ./examples/eval_long_context.py --task passkey \
-        --engine_dir llama_3.1_405B_HF_model/trt_ckpts/tp8-pp1/ \
+        --engine_dir llama_3.1_405B_HF_model/trt_engines/tp8-pp1/ \
         --tokenizer_dir llama_3.1_405B_HF_model/ \
        --stop_idx 6 \
        --max_input_length 64000 \
@@ -1307,6 +1328,22 @@ bash -c 'python ./examples/eval_long_context.py --task passkey \
        --max_tokens_in_paged_kv_cache 65064 \
        --data_dir 64k_context \
        --output_dir 64k_context_tp8'
+# Long context test for 64k
+srun --mpi pmi2 -N 1 -n 8 --ntasks-per-node 8 --container-image <your container> \
+    --container-mounts <your container mount> \
+    --container-name llama-3.1-405b \
+    --container-workdir <your container work directory> \
+    bash -c 'python ./examples/eval_long_context.py --task passkey \
+        --engine_dir llama_3.1_405B_HF_FP8_model/trt_engines/tp8-pp1/ \
+        --tokenizer_dir llama_3.1_405B_HF_FP8_model/ \
+        --stop_idx 6 \
+        --max_input_length 64000 \
+        --enable_chunked_context \
+        --kv_cache_free_gpu_memory_fraction 0.999 \
+        --max_tokens_in_paged_kv_cache 65064 \
+        --data_dir 64k_context \
+        --output_dir 64k_context_tp8'
 ```
 The following script shows how to run evaluation on MMLU tasks:
@@ -1332,5 +1369,16 @@ bash -c 'python ./examples/mmlu.py --test_trt_llm \
        --tokenizer_dir llama_3.1_405B_HF_model/ \
        --enable_chunked_context \
        --kv_cache_free_gpu_memory_fraction 0.999 \
-       --max_tokens_in_paged_kv_cache 256064'
+       --max_tokens_in_paged_kv_cache 65064'
+srun --mpi pmi2 -N 1 -n 8 --ntasks-per-node 8 --container-image <your container> \
+    --container-mounts <your container mount> \
+    --container-name llama-3.1-405b \
+    --container-workdir <your container work directory> \
+    bash -c 'python ./examples/mmlu.py --test_trt_llm \
+        --engine_dir llama_3.1_405B_HF_FP8_model/trt_engines/tp8-pp1/ \
+        --tokenizer_dir llama_3.1_405B_HF_FP8_model/ \
+        --enable_chunked_context \
+        --kv_cache_free_gpu_memory_fraction 0.999 \
+        --max_tokens_in_paged_kv_cache 65064'
 ```
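As a convenience when following the note about `rope_scaling` above, here is a small sketch (not part of the example scripts; the path is a placeholder) that checks the HF config before conversion:

```python
import json

# Placeholder path; point it at the downloaded HF checkpoint directory.
with open("llama_3.1_405B_HF_model/config.json") as f:
    cfg = json.load(f)

rope_scaling = cfg.get("rope_scaling") or {}
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
if rope_type != "llama3":
    raise ValueError("rope_scaling.rope_type should be 'llama3' so that "
                     "TensorRT-LLM enables LLaMA-3.1 rope scaling.")
```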


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets==2.14.6
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 transformers>=4.39.0
 datasets~=2.14.5
 evaluate


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 rouge_score~=0.1.2
 sentencepiece~=0.1.99


@@ -1,4 +1,4 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 transformers==4.38.2
 accelerate==0.25.0


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 transformers==4.40.2
 # https://github.com/NVIDIA/NeMo/issues/9793
 huggingface_hub==0.23.5


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.14.5
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets>=2.14.4
 nemo-toolkit[all]<=1.20.0,>=1.18.0
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.16.0
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.16.0
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 git+https://github.com/google-deepmind/recurrentgemma.git
 flax>=0.8.2
 jax~=0.4.23


@@ -304,7 +304,7 @@ def main(args):
     encoder_input_lengths = [x.size(0)
                              for x in encoder_input_ids] if is_enc_dec else None
-    if not supports_inflight_batching(
+    if not args.use_py_session and not supports_inflight_batching(
             os.path.join(args.engine_dir, "decoder") if is_enc_dec else args.
             engine_dir):
         logger.warning(


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets~=2.16.1
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 datasets==2.14.6
 evaluate~=0.4.1
 rouge_score~=0.1.2


@@ -1,5 +1,5 @@
 --extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
 tiktoken
 datasets
 kaldialign


@@ -112,26 +112,12 @@ def trt_version():
     return trt.__version__
-# TRT supports strongly_typed in 9.1
-def support_strongly_type():
-    return version.parse(trt_version()) >= version.parse("9.1.0")
-# Check if TRT version >= 10
-def trt_gte_10():
-    return version.parse(trt_version()).major > 9
-# Check if TRT version >= 10.1
-def trt_gte_10_1():
-    trt_ver = version.parse(trt_version())
-    return trt_ver.major > 9 and trt_ver.minor > 0
-# Check if TRT version >= 10.2
-def trt_gte_10_2():
-    ver = version.parse(trt_version())
-    return (ver.major * 10 + ver.minor) >= 102
+def trt_gte(major: int, minor: int = 0):
+    """
+    Check if TRT version is greater than or equal to major.minor
+    """
+    trt_ver = version.parse(trt_version())
+    return trt_ver.major >= major and trt_ver.minor >= minor
 def torch_version():
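A minimal usage sketch of the consolidated helper; the pinned version string is a stand-in so the example runs without TensorRT, and the mapping of the removed helpers is approximate:

```python
from packaging import version

def trt_gte(major: int, minor: int = 0) -> bool:
    # Mirrors the helper above, with a pretend trt.__version__ for illustration.
    trt_ver = version.parse("10.2.0")
    return trt_ver.major >= major and trt_ver.minor >= minor

# Roughly: trt_gte_10() -> trt_gte(10), trt_gte_10_1() -> trt_gte(10, 1),
#          trt_gte_10_2() -> trt_gte(10, 2)
print(trt_gte(10), trt_gte(10, 2))  # True True for the pretend 10.2.0
```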


@@ -13,7 +13,7 @@ import torch
 from filelock import FileLock
 from tensorrt_llm._utils import (str_dtype_to_trt, trt_dtype_to_np,
-                                 trt_dtype_to_torch, trt_gte_10)
+                                 trt_dtype_to_torch)
 from tensorrt_llm.functional import (AllReduceConfig, AllReduceFusionParams,
                                      AllReduceStrategy, create_allreduce_plugin)
 from tensorrt_llm.logger import logger
@@ -39,7 +39,7 @@ from .tensor_parallel.sharding_strategy import ShardingStrategy
 from .utils import (get_updated_plugin, to_base_class_layer, to_subclass_layer,
                     to_trt_weights)
-default_int_dtype = trt.int64 if trt_gte_10() else trt.int32
+default_int_dtype = trt.int64
 @dataclass


@@ -25,8 +25,7 @@ from typing import Dict, Optional, Union
 import tensorrt as trt
 from ._common import _is_building, check_max_num_tokens, serialize_engine
-from ._utils import (str_dtype_to_trt, support_strongly_type, to_json_file,
-                     trt_gte_10, trt_gte_10_2)
+from ._utils import str_dtype_to_trt, to_json_file
 from .auto_parallel import auto_parallel
 from .auto_parallel.config import AutoParallelConfig
 from .graph_rewriting import optimize
@@ -112,7 +111,7 @@ class Builder():
         explicit_batch_flag = 1 << int(
             trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
-        if support_strongly_type() and self.strongly_typed:
+        if self.strongly_typed:
             return Network()._init(
                 self.trt_builder.create_network(
                     explicit_batch_flag
@@ -145,35 +144,15 @@ class Builder():
         @param int8: whether to build with int8 enabled or not. Can't be used together with refit option
         @return: A BuilderConfig object, return None if failed
         '''
-        if strongly_typed and not support_strongly_type():
-            logger.warning(
-                "TRT version does not support strongly_type. strongly_typed flag is ignored."
-            )
-        # In TRT 10.0, enable strongly_typed by default.
-        self.strongly_typed = self.strongly_typed or (strongly_typed and
-                                                      support_strongly_type())
+        self.strongly_typed = self.strongly_typed or strongly_typed
         quant_mode = kwargs.get("quant_mode", QuantMode(0))
         if not strongly_typed and precision not in self._ALLOWED_PRECISIONS:
             logger.error(
                 f"precision should be one of {self._ALLOWED_PRECISIONS}")
-        if use_strip_plan and not trt_gte_10():
-            logger.error(
-                "cannot use --strip_plan with tensorrt version 9.x or below")
-        if (use_refit or use_strip_plan) and int8 and not trt_gte_10():
-            # TRT folds weights into Myelin graph because network contains int8 tensor or Q/DQ nodes
-            # These folded weights can not be refitted
-            logger.error(
-                "can't use refit/strip_plan and int8 mode at the same time before tensorrt 10.0"
-            )
         config = self.trt_builder.create_builder_config()
         if weight_streaming:
-            assert trt_gte_10(), \
-                "Weight streaming is only supported by TensorRT 10.0 or later."
             config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)
         if not self.strongly_typed:
             fp8 = quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache()
@@ -197,7 +176,7 @@ class Builder():
             config.set_flag(trt.BuilderFlag.REFIT)
             # Use fine-grained refit when strip plan is enabled in TRT10.2+.
-            if use_strip_plan and trt_gte_10_2():
+            if use_strip_plan:
                 config.set_flag(trt.BuilderFlag.REFIT_INDIVIDUAL)
         if use_strip_plan:
@@ -396,7 +375,6 @@ class Builder():
         engine = None
         # Rename weights
-        is_refit_individual_supported = trt_gte_10_2()
         if network.named_parameters is not None:
             for name, param in network.named_parameters:
                 if param._get_weights() is None:
@@ -409,9 +387,8 @@ class Builder():
                 if not network.trt_network.set_weights_name(
                         param._get_weights(), name):
                     raise RuntimeError(f'Failed to set weight: {name}')
-                if is_refit_individual_supported:
-                    # This mark_weights_refittable has no side effect when refit_individual is not enabled.
-                    network.trt_network.mark_weights_refittable(name)
+                # This mark_weights_refittable has no side effect when refit_individual is not enabled.
+                network.trt_network.mark_weights_refittable(name)
         network._fill_weights()
         # Build engine


@@ -447,7 +447,8 @@ def main():
     # Extract rotary scaling which will be used for checks and default value of max_seq_len
     rotary_scaling = getattr(model_config, "rotary_scaling", None)
     if rotary_scaling is not None:
-        rotary_type = rotary_scaling['type']
+        rotary_type = rotary_scaling.get('type',
+                                         rotary_scaling.get('rope_type'))
         rotary_factor = rotary_scaling.get(
             'factor', 1.0) if rotary_type != 'su' else 1
     else:
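For illustration, a tiny runnable sketch (the dictionaries are made-up examples) of the two config shapes the `.get('type', ...)` fallback above accepts:

```python
# Older-style configs use "type"; LLaMA-3.1-style configs use "rope_type".
legacy_scaling = {"type": "linear", "factor": 2.0}
llama3_scaling = {"rope_type": "llama3", "factor": 8.0}

for rotary_scaling in (legacy_scaling, llama3_scaling):
    rotary_type = rotary_scaling.get('type', rotary_scaling.get('rope_type'))
    rotary_factor = rotary_scaling.get('factor',
                                       1.0) if rotary_type != 'su' else 1
    print(rotary_type, rotary_factor)
```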


@@ -30,8 +30,7 @@ from ._common import default_net, default_trtnet, precision
 from ._utils import (bf16_array, bool_array, dim_resolve_negative,
                      dim_to_trt_axes, dims_array, fp16_array, fp32_array,
                      int32_array, int64_array, np_dtype_to_trt,
-                     str_dtype_to_trt, trt_dtype_to_np, trt_dtype_to_str,
-                     trt_gte_10)
+                     str_dtype_to_trt, trt_dtype_to_np, trt_dtype_to_str)
 from .network import PluginInfo, set_np_weight, set_plugin_info
 from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper
 from .quantization import QuantMode
@@ -639,7 +638,7 @@ class RotaryScalingType(IntEnum):
     linear = 1
     dynamic = 2
     longrope = 3
-    wavelen = 4
+    llama3 = 4
     @staticmethod
     def from_string(s):
@@ -1030,10 +1029,6 @@ def matmul(input: Tensor,
     # This option is only supported for fp16, but not bf16 or any other precisions.
     use_fp32_acc = use_fp32_acc and input.dtype == trt.DataType.HALF and mat2.dtype == trt.DataType.HALF
-    # TODO: fp32 accum has issues with strongly_typed and it will be fixed in TensorRT 10.0
-    if default_net().strongly_typed and not trt_gte_10():
-        use_fp32_acc = False
     if use_fp32_acc:
         input = cast(input, 'float32')
         mat2 = cast(mat2, 'float32')
@@ -4139,11 +4134,14 @@ def bert_attention(tensor: Tensor,
 class RopeEmbeddingUtils:
     @staticmethod
-    def apply_wavelen_scaling(inv_freqs: np.ndarray,
-                              scale_factor: float = 8.0,
-                              low_freq_factor: float = 1.0,
-                              high_freq_factor: float = 4.0,
-                              old_context_len: int = 8192):
+    # ref: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L298
+    def apply_llama3_scaling(inv_freqs: np.ndarray, rope_scaling_config: dict):
+        scale_factor = rope_scaling_config.get("factor", 8.0)
+        low_freq_factor = rope_scaling_config.get("low_freq_factor", 1.0)
+        high_freq_factor = rope_scaling_config.get("high_freq_factor", 4.0)
+        old_context_len = rope_scaling_config.get(
+            "original_max_position_embeddings", 8192)
         low_freq_wavelen = old_context_len / low_freq_factor
         high_freq_wavelen = old_context_len / high_freq_factor
@@ -4183,12 +4181,19 @@ class RopeEmbeddingUtils:
             theta: float = 10000.0,
             scale: float = 1.0,
             scale_type: RotaryScalingType = RotaryScalingType.none,
+            # Other scaling configs that only used by certain scaling types.
+            rope_scaling_config: dict = None,
             dtype=np.float32):
         if scale_type == RotaryScalingType.linear:
             scale = 1.0 / scale
-        inv_freq = scale / (theta**(np.arange(0, dim, 2) / dim)).astype(dtype)
-        if scale_type == RotaryScalingType.wavelen:
-            inv_freq = RopeEmbeddingUtils.apply_wavelen_scaling(inv_freq)
+        if scale_type == RotaryScalingType.llama3:
+            assert rope_scaling_config is not None, "rotary_scaling config must be provided."
+            inv_freq = 1.0 / (theta**(np.arange(0, dim, 2) / dim)).astype(dtype)
+            inv_freq = RopeEmbeddingUtils.apply_llama3_scaling(
+                inv_freq, rope_scaling_config)
+        else:
+            inv_freq = scale / (theta
+                                **(np.arange(0, dim, 2) / dim)).astype(dtype)
         sinusoid_inp = np.expand_dims(np.einsum("i , j -> i j",
                                                 np.arange(num_pos, dtype=dtype),
                                                 inv_freq,
@@ -4618,7 +4623,7 @@ def gpt_attention(
         * RotaryScalingType.linear
         * RotaryScalingType.dynamic
         * RotaryScalingType.longrope
-        * RotaryScalingType.wavelen
+        * RotaryScalingType.llama3
     rotary_embedding_scale: float
         The scale value to use for linear/dynamic scaling in RoPE.
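For readers who want the math in isolation, here is a standalone NumPy sketch of llama3-style frequency scaling, following the Hugging Face reference linked in the new comment. The default factors match `apply_llama3_scaling` above; the head dimension and theta are illustrative, and details may differ from the TensorRT-LLM implementation.

```python
import numpy as np

def llama3_scale_inv_freq(inv_freq: np.ndarray,
                          factor: float = 8.0,
                          low_freq_factor: float = 1.0,
                          high_freq_factor: float = 4.0,
                          old_context_len: int = 8192) -> np.ndarray:
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * np.pi / inv_freq
    # Long wavelengths (low frequencies) are divided by the factor,
    # short wavelengths are kept, and the band in between is interpolated.
    scaled = np.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor)
    interpolated = (1 - smooth) * scaled / factor + smooth * scaled
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return np.where(is_medium, interpolated, scaled)

dim, theta = 128, 500000.0  # illustrative head dim and rope theta
inv_freq = 1.0 / theta ** (np.arange(0, dim, 2) / dim)
print(llama3_scale_inv_freq(inv_freq)[:4])
```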


@@ -20,7 +20,7 @@ import tensorrt as trt
 from .._common import default_net, precision
 from .._utils import (fp32_array, int32_array, is_same_dtype, trt_dtype_to_np,
-                      trt_dtype_to_str, trt_gte_10)
+                      trt_dtype_to_str)
 from ..functional import (ACT2FN, AllReduceFusionParams, AttentionMaskType,
                           Conditional, LayerNormType, PositionEmbeddingType,
                           RopeEmbeddingUtils, RotaryScalingType, Tensor, arange,
@@ -362,8 +362,10 @@ class Attention(Module):
         self.rotary_embedding_percentage = rotary_embedding_percentage
         self.use_implicit_relative_attention = self.relative_attention and use_implicit_relative_attention
         if rotary_embedding_scaling is not None:
+            rotary_scaling_type = rotary_embedding_scaling.get(
+                "type", rotary_embedding_scaling.get("rope_type"))
             self.rotary_embedding_scale_type = RotaryScalingType.from_string(
-                rotary_embedding_scaling["type"])
+                rotary_scaling_type)
             self.rotary_embedding_scale = rotary_embedding_scaling.get(
                 "factor", 1.0)
@@ -433,7 +435,8 @@ class Attention(Module):
             rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                 self.max_position_embeddings, self.rotary_embedding_dim,
                 self.rotary_embedding_base, self.rotary_embedding_scale,
-                self.rotary_embedding_scale_type)
+                self.rotary_embedding_scale_type,
+                self.rotary_embedding_scaling)
             self.register_parameter(
                 'rotary_inv_freq',
                 Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
@@ -1153,12 +1156,7 @@ class Attention(Module):
             if norm_before_bmm1:
                 # Apply norm on query earlier to prevent matmul fp16 overflow.
                 query /= (self.q_scaling * self.norm_factor)
-            if trt_gte_10() or self.position_embedding_type.is_alibi():
-                attention_scores = matmul(query, key)
-            else:
-                # For TRT 9.x, OOTB need this WAR to fuse mha.
-                attention_scores = matmul(cast(query, 'float32'),
-                                          cast(key, 'float32'))
+            attention_scores = matmul(query, key)
             if not norm_before_bmm1:
                 attention_scores = attention_scores / (self.q_scaling *
                                                        self.norm_factor)
@@ -1182,24 +1180,16 @@ class Attention(Module):
             attention_probs = softmax(attention_scores, dim=-1)
-            if trt_gte_10() or self.position_embedding_type.is_alibi():
-                # For trt_version() == 9.x and pos_embed == alibi, TRT has gpu buffer management issues. Need this WAR to avoid peak gpu mem regression.
-                # A dummy reshape WAR for mha fusion for 10.0
-                attention_probs = attention_probs.view(
-                    concat([
-                        shape(attention_probs, 0),
-                        shape(attention_probs, 1),
-                        shape(attention_probs, 2),
-                        shape(value, 2)
-                    ]))
-                context = matmul(attention_probs, value,
-                                 use_fp32_acc=False).permute([0, 2, 1, 3])
-            else:
-                # For TRT 9.x, need this WAR to fuse mha.
-                context = matmul(attention_probs,
-                                 cast(value, 'float32')).permute([0, 2, 1, 3])
-                if context.dtype != value.dtype:
-                    context = cast(context, value.dtype)
+            # A dummy reshape WAR for mha fusion
+            attention_probs = attention_probs.view(
+                concat([
+                    shape(attention_probs, 0),
+                    shape(attention_probs, 1),
+                    shape(attention_probs, 2),
+                    shape(value, 2)
+                ]))
+            context = matmul(attention_probs, value,
+                             use_fp32_acc=False).permute([0, 2, 1, 3])
             context = context.view(
                 concat([
                     shape(context, 0),


@@ -18,7 +18,7 @@ import tensorrt as trt
 from .._common import default_net
 from ..functional import (ACT2FN, AllReduceFusionParams, cast, concat,
-                          gemm_swiglu)
+                          gemm_swiglu, is_gated_activation)
 from ..module import Module
 from ..quantization import QuantMode
 from ..quantization.functional import quantize
@@ -28,6 +28,34 @@ from .lora import LoraRuntimeParams
 from .normalization import LayerNorm
+def fc_gate_lora(hidden_states, lora, lora_layer_params):
+    if lora_layer_params is not None:
+        mlp_fc_lora_params = lora_layer_params.get_runtime_params(
+            0, "mlp_h_to_4h")
+        mlp_gate_lora_params = lora_layer_params.get_runtime_params(
+            0, "mlp_gate")
+        if mlp_fc_lora_params is not None and mlp_gate_lora_params is not None:
+            mlp_in_lora_params = LoraRuntimeParams(
+                lora_ranks=[
+                    mlp_fc_lora_params.lora_ranks[0],
+                    mlp_gate_lora_params.lora_ranks[0]
+                ],
+                lora_weights_pointers=[
+                    mlp_fc_lora_params.lora_weights_pointers[0],
+                    mlp_gate_lora_params.lora_weights_pointers[0]
+                ],
+                host_request_types=mlp_fc_lora_params.host_request_types,
+                host_context_lengths=mlp_fc_lora_params.host_context_lengths,
+                max_context_length=mlp_fc_lora_params.max_context_length)
+            mlp_fc_lora, mlp_gate_lora = lora(hidden_states, mlp_in_lora_params)
+            mlp_in_result = concat([mlp_gate_lora, mlp_fc_lora],
+                                   dim=mlp_fc_lora.rank() - 1)
+            return mlp_in_result
+    return None
 class MLP(Module):
     def __init__(
@@ -76,19 +104,28 @@ class MLP(Module):
         self.tp_size = tp_size
         self.quant_mode = quant_mode
         self.eps = eps
+        # see optimize_model's add_lora for LoRA initialization
+        self.lora = None
     def forward(self, hidden_states, lora_layer_params=None, gegelu_limit=None):
-        mlp_fc_lora_params = None
-        if lora_layer_params is not None:
-            mlp_fc_lora_params = lora_layer_params.get_runtime_params(
-                0, "mlp_h_to_4h")
+        if is_gated_activation(self.hidden_act):
+            inter = self.fc(hidden_states)
+            lora_result = fc_gate_lora(hidden_states, self.lora,
+                                       lora_layer_params)
+            if lora_result is not None:
+                inter = inter + lora_result
+        else:
+            mlp_fc_lora_params = None
+            if lora_layer_params is not None:
+                mlp_fc_lora_params = lora_layer_params.get_runtime_params(
+                    0, "mlp_h_to_4h")
+            inter = self.fc(hidden_states, mlp_fc_lora_params)
         mlp_proj_lora_params = None
         if lora_layer_params is not None:
             mlp_proj_lora_params = lora_layer_params.get_runtime_params(
                 0, "mlp_4h_to_h")
-        inter = self.fc(hidden_states, mlp_fc_lora_params)
         if self.hidden_act == 'gegelu':
             inter = ACT2FN[self.hidden_act](inter, gegelu_limit)
         else:
@@ -286,32 +323,9 @@ class FusedGatedMLP(Module):
         inter = self.fused_fc(hidden_states)
-        if lora_layer_params is not None:
-            mlp_fc_lora_params = lora_layer_params.get_runtime_params(
-                0, "mlp_h_to_4h")
-            mlp_gate_lora_params = lora_layer_params.get_runtime_params(
-                0, "mlp_gate")
-            if mlp_fc_lora_params is not None and mlp_gate_lora_params is not None:
-                mlp_in_lora_params = LoraRuntimeParams(
-                    lora_ranks=[
-                        mlp_fc_lora_params.lora_ranks[0],
-                        mlp_gate_lora_params.lora_ranks[0]
-                    ],
-                    lora_weights_pointers=[
-                        mlp_fc_lora_params.lora_weights_pointers[0],
-                        mlp_gate_lora_params.lora_weights_pointers[0]
-                    ],
-                    host_request_types=mlp_fc_lora_params.host_request_types,
-                    host_context_lengths=mlp_fc_lora_params.
-                    host_context_lengths,
-                    max_context_length=mlp_fc_lora_params.max_context_length)
-                mlp_fc_lora, mlp_gate_lora = self.lora(hidden_states,
-                                                       mlp_in_lora_params)
-                mlp_in_result = concat([mlp_gate_lora, mlp_fc_lora],
-                                       dim=mlp_fc_lora.rank() - 1)
-                inter = inter + mlp_in_result
+        lora_result = fc_gate_lora(hidden_states, self.lora, lora_layer_params)
+        if lora_result is not None:
+            inter = inter + lora_result
         if self.hidden_act == 'silu':
             inter = ACT2FN['swiglu'](inter)


@@ -18,10 +18,8 @@ from typing import List, Optional, Type, Union
 import numpy as np
 import tensorrt as trt
-from packaging import version
-from tensorrt_llm._utils import (get_init_params, str_dtype_to_trt, trt_gte_10,
-                                 trt_version)
+from tensorrt_llm._utils import get_init_params, str_dtype_to_trt
 from tensorrt_llm.layers.lora import LoraParams
 from .._common import default_net, default_trtnet
@@ -553,106 +551,81 @@ class MoeOOTB(MOE):
         router_probs = softmax(routing, -1)
         topk_values, topk_indices = topk(router_probs, self.top_k, dim=-1)
-        if trt_gte_10() and version.parse(trt_version()).minor >= 2:
-            # For TRT 10.2 and above, avoid over-computing by using NonZero ops to select tokens for each experts.
-            hidden_size = shape(hidden_states, -1)
-            #[B*sq, hidden]
-            inputs_merged = hidden_states.view(concat([-1, hidden_size]))
-            flat_topk_indices = topk_indices.view(
-                concat([-1, shape(topk_indices, -1)]))
-            flat_topk_values = topk_values.view(
-                concat([-1, shape(topk_values, -1)]))
-            # Create output space
-            zero_buffer = inputs_merged * 0.0
-            output = zero_buffer
-            expert_indices_stack = []
-            indices_stack = []
-            # When topk indices are equal to expert index, the expert will inference the tokens.
-            # Bundle all indices and experts index, then do mask once.
-            for i, expert in enumerate(self.experts):
-                if self.mapping.has_moe_ep():
-                    index = i + self.experts_per_node * self.mapping.moe_ep_rank
-                else:
-                    index = i
-                expert_indices_stack.append(
-                    flat_topk_indices.view(concat([1,
-                                                   shape(flat_topk_indices)])))
-                indices_stack.append(constant(int32_array(index)))
-            all_expert_indices = concat(expert_indices_stack, dim=0)
-            indices = expand(
-                concat(indices_stack).view(concat([len(self.experts), 1, 1])),
-                shape(all_expert_indices))
-            # Create all experts mask
-            all_expert_mask = all_expert_indices == indices
-            experts_weights = cast(
-                sum(flat_topk_values *
-                    cast(all_expert_mask, flat_topk_values.dtype),
-                    dim=-1,
-                    keepdim=True), self.dtype)
-            all_expert_mask = cast(
-                sum(cast(all_expert_mask, flat_topk_values.dtype),
-                    dim=-1,
-                    keepdim=True), 'bool')
-            all_expert_mask = repeat_interleave(all_expert_mask,
-                                                shape(output, -1), 2)
-            # split the mask and weights for each expert
-            experts_mask = split(all_expert_mask, 1, dim=0)
-            expert_weights = split(experts_weights, 1, dim=0)
-            for i, expert in enumerate(self.experts):
-                # get mask token index
-                non_zero_index = nonzero(experts_mask[i].view(
-                    concat([-1, hidden_size])))
-                non_zero_index = non_zero_index.transpose(1, 0)
-                input_for_expert = gather_nd(inputs_merged, non_zero_index, 0)
-                input_for_expert = input_for_expert.view(
-                    concat([-1, hidden_size]), zero_is_placeholder=False)
-                # Expert inference
-                expert_output = expert(
-                    input_for_expert,
-                    lora_layer_params=self.moe_to_expert_lora_params(
-                        lora_layer_params, index))
-                # scatter expert output to real position
-                expert_finialized_output = zero_buffer
-                expert_finialized_output = scatter_nd(
-                    expert_finialized_output, non_zero_index,
-                    expert_output.view([-1])) * expert_weights[i]
-                output += expert_finialized_output
-            output = output.view(shape(hidden_states))
-        else:
-            output = hidden_states * 0.0  # Create output space
-            # Use over-computation when TRT version is too low.
-            # Experts inference
-            for i, expert in enumerate(self.experts):
-                if self.mapping.has_moe_ep():
-                    index = i + self.experts_per_node * self.mapping.moe_ep_rank
-                else:
-                    index = i
-                # inference expert
-                out = expert(hidden_states,
-                             lora_layer_params=self.moe_to_expert_lora_params(
-                                 lora_layer_params, index))
-                expert_mask = topk_indices == index
-                expert_weights = cast(
-                    sum(topk_values * cast(expert_mask, topk_values.dtype),
-                        dim=-1,
-                        keepdim=True), self.dtype)
-                output += out * expert_weights
+        hidden_size = shape(hidden_states, -1)
+        # [B*sq, hidden]
+        inputs_merged = hidden_states.view(concat([-1, hidden_size]))
+        flat_topk_indices = topk_indices.view(
+            concat([-1, shape(topk_indices, -1)]))
+        flat_topk_values = topk_values.view(concat([-1,
+                                                    shape(topk_values, -1)]))
+        # Create output space
+        zero_buffer = inputs_merged * 0.0
+        output = zero_buffer
+        expert_indices_stack = []
+        indices_stack = []
+        # When topk indices are equal to expert index, the expert will inference the tokens.
+        # Bundle all indices and experts index, then do mask once.
+        for i, expert in enumerate(self.experts):
+            if self.mapping.has_moe_ep():
+                index = i + self.experts_per_node * self.mapping.moe_ep_rank
+            else:
+                index = i
+            expert_indices_stack.append(
+                flat_topk_indices.view(concat([1, shape(flat_topk_indices)])))
+            indices_stack.append(constant(int32_array(index)))
+        all_expert_indices = concat(expert_indices_stack, dim=0)
+        indices = expand(
+            concat(indices_stack).view(concat([len(self.experts), 1, 1])),
+            shape(all_expert_indices))
+        # Create all experts mask
+        all_expert_mask = all_expert_indices == indices
+        experts_weights = cast(
+            sum(flat_topk_values *
+                cast(all_expert_mask, flat_topk_values.dtype),
+                dim=-1,
+                keepdim=True), self.dtype)
+        all_expert_mask = cast(
+            sum(cast(all_expert_mask, flat_topk_values.dtype),
+                dim=-1,
+                keepdim=True), 'bool')
+        all_expert_mask = repeat_interleave(all_expert_mask, shape(output, -1),
+                                            2)
+        # split the mask and weights for each expert
+        experts_mask = split(all_expert_mask, 1, dim=0)
+        expert_weights = split(experts_weights, 1, dim=0)
+        for i, expert in enumerate(self.experts):
+            # get mask token index
+            non_zero_index = nonzero(experts_mask[i].view(
+                concat([-1, hidden_size])))
+            non_zero_index = non_zero_index.transpose(1, 0)
+            input_for_expert = gather_nd(inputs_merged, non_zero_index, 0)
+            input_for_expert = input_for_expert.view(concat([-1, hidden_size]),
+                                                     zero_is_placeholder=False)
+            # Expert inference
+            expert_output = expert(
+                input_for_expert,
+                lora_layer_params=self.moe_to_expert_lora_params(
+                    lora_layer_params, index))
+            # scatter expert output to real position
+            expert_finialized_output = zero_buffer
+            expert_finialized_output = scatter_nd(
+                expert_finialized_output, non_zero_index,
+                expert_output.view([-1])) * expert_weights[i]
+            output += expert_finialized_output
+        output = output.view(shape(hidden_states))
         need_ep_reduce = self.mapping.has_moe_ep(
         ) and self.mapping.moe_ep_group is not None
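To make the token-selection idea concrete, here is a small NumPy analogue (shapes, values, and the stand-in "expert" are invented) of the mask, nonzero, gather, and scatter flow used by `MoeOOTB` above:

```python
import numpy as np

np.random.seed(0)
num_tokens, hidden, num_experts, top_k = 6, 4, 3, 2
tokens = np.random.randn(num_tokens, hidden).astype(np.float32)
topk_indices = np.random.randint(0, num_experts, size=(num_tokens, top_k))
topk_values = np.random.rand(num_tokens, top_k).astype(np.float32)

output = np.zeros_like(tokens)
for expert_id in range(num_experts):
    mask = topk_indices == expert_id                # [num_tokens, top_k]
    token_ids = np.nonzero(mask.any(axis=-1))[0]    # tokens routed here
    if token_ids.size == 0:
        continue
    expert_in = tokens[token_ids]                   # gather only routed tokens
    expert_out = expert_in * (expert_id + 1)        # stand-in for the expert MLP
    weights = (topk_values * mask).sum(axis=-1, keepdims=True)[token_ids]
    output[token_ids] += expert_out * weights       # scatter back, weighted
print(output.shape)  # (num_tokens, hidden)
```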


@ -13,7 +13,6 @@ import yaml
from ._utils import (DictConversion, pad_vocab_size, release_gc, from ._utils import (DictConversion, pad_vocab_size, release_gc,
str_dtype_to_torch, torch_to_numpy) str_dtype_to_torch, torch_to_numpy)
from .layers.linear import ColumnLinear from .layers.linear import ColumnLinear
from .logger import logger
from .mapping import Mapping from .mapping import Mapping
from .models.convert_utils import (get_model_path, load_state_dict, from .models.convert_utils import (get_model_path, load_state_dict,
split_matrix_tp) split_matrix_tp)
@ -34,30 +33,43 @@ def get_all_nemo_lora_weights(lora_weights):
m = layer_pattern.match(key) m = layer_pattern.match(key)
layer_idx = int(m.group(1)) layer_idx = int(m.group(1))
layer_weights[layer_idx][inout] = weights layer_weights[layer_idx][inout] = weights
else:
raise KeyError(f"unsupported key {key} from Nemo LoRA weights")
return layer_weights return layer_weights
def get_all_hf_lora_weights(lora_weights, hf_modules, component=None): # The pattern is {layer_prefix:1}.{layer_idx:2}.{module_prefix:3}.{module_name or {expert_name:5}.{expert_idx:6}.{module_name:7} :4}.lora_{A|B:8}.weight
HF_LORA_PATTERN = re.compile(
r'(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.lora_(A|B)\.weight'
)
def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None):
all_weights = defaultdict(lambda: defaultdict(dict)) all_weights = defaultdict(lambda: defaultdict(dict))
pattern = re.compile( pattern = HF_LORA_PATTERN
r'(.*)\.(\d+)\.(\w+)\.(\w+|experts\.(\d+)\.(\w+))\.lora_(A|B)\.weight')
for key, weights in lora_weights.items(): for key, weights in lora_weights.items():
m = pattern.match(key) m = pattern.match(key)
if not m: if not m:
if "lm_head" not in key and "embed_tokens" not in key: if "lm_head" not in key and "embed_tokens" not in key:
logger.warning(f"no match {key} from HF LoRA weights") raise KeyError(f"unsupported key {key} from HF LoRA weights")
continue continue
if component is not None and component not in m.group(1): if component is not None and component not in m.group(1):
continue continue
layer_idx = int(m.group(2)) layer_idx = int(m.group(2))
expert_idx = m.group(5) expert_idx = m.group(6)
is_moe = expert_idx is not None is_moe = expert_idx is not None
module_name = m.group(6 if is_moe else 4) if is_moe:
hf_module = m.group(3) + "." + module_name expert_name = m.group(5)
module_name = m.group(7)
hf_module = m.group(3) + "." + expert_name + "." + module_name
else:
module_name = m.group(4)
hf_module = m.group(3) + "." + module_name
if hf_module not in hf_modules: if hf_module not in hf_modules:
hf_module = module_name hf_module = module_name
assert hf_module in hf_modules assert hf_module in hf_modules
inout = "in" if m.group(7) == "A" else "out" inout = "in" if m.group(8) == "A" else "out"
iter_fn(layer_idx, hf_module, expert_idx, inout, weights)
if not is_moe: if not is_moe:
all_weights[layer_idx][hf_module][inout] = weights all_weights[layer_idx][hf_module][inout] = weights
else: else:
@ -66,31 +78,27 @@ def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):
return all_weights return all_weights
def get_hf_target_modules(lora_weights, hf_modules, lora_target_modules): def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):
hf_target_modules = set()
pattern = re.compile( def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
r'(.*)\.(\d+)\.(\w+)\.(\w+|experts\.(\d+)\.(\w+))\.lora_(A|B)\.weight') if expert_idx is None:
for key in lora_weights.keys(): all_weights[layer_idx][hf_module][inout] = weights
m = pattern.match(key) else:
if not m: all_weights[layer_idx][hf_module].setdefault(expert_idx, {})
if "lm_head" not in key and "embed_tokens" not in key: all_weights[layer_idx][hf_module][expert_idx][inout] = weights
logger.warning(f"no match {key} from HF LoRA weights")
continue all_weights = defaultdict(lambda: defaultdict(dict))
match_target_module = False iterate_hf_lora(iter_fn, lora_weights, hf_modules, component)
for module in lora_target_modules: return all_weights
if module in key:
match_target_module = True
break def get_hf_target_modules(lora_weights, hf_modules):
if not match_target_module:
continue def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
expert_idx = m.group(5)
is_moe = expert_idx is not None
module_name = m.group(6 if is_moe else 4)
hf_module = m.group(3) + "." + module_name
if hf_module not in hf_modules:
hf_module = module_name
assert hf_module in hf_modules
hf_target_modules.add(hf_module) hf_target_modules.add(hf_module)
hf_target_modules = set()
iterate_hf_lora(iter_fn, lora_weights, hf_modules)
return hf_target_modules return hf_target_modules
@@ -146,7 +154,6 @@ class HfLoraLoader:
lora_dir = lora_dirs[0]
with open(f"{lora_dir}/adapter_config.json") as f:
adapter_config = json.load(f)
-self.lora_target_modules = adapter_config["target_modules"]
lora_weight = load_state_dict(get_model_path(lora_dir, "adapter_model"))
self.lora_weight = lora_weight
@@ -162,17 +169,16 @@ class HfLoraLoader:
def get_target_modules(self, trtllm_modules_to_hf_modules):
hf_modules_to_trtllm_modules = invert_module_mapping(
trtllm_modules_to_hf_modules)
-lora_target_modules = []
+lora_target_modules = set()
if self.is_valid:
hf_target_modules = get_hf_target_modules(
self.lora_weight,
hf_modules=set(hf_modules_to_trtllm_modules.keys()),
-lora_target_modules=self.lora_target_modules,
)
for m in hf_target_modules:
trtllm_module = hf_modules_to_trtllm_modules[m]
-lora_target_modules.append(trtllm_module)
+lora_target_modules.add(trtllm_module)
-return lora_target_modules
+return list(lora_target_modules)
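For context, get_target_modules now derives the module list purely from the adapter weights and dedupes through a set, since several HF module names can resolve to the same TRT-LLM module. A minimal, self-contained illustration follows; the toy helper below only approximates invert_module_mapping (the real helper lives in tensorrt_llm.lora_manager), and the mapping entries are made up.

def invert_mapping(trtllm_to_hf):
    # Toy stand-in: map every HF alias back to its TRT-LLM module name.
    hf_to_trtllm = {}
    for trtllm_name, hf_names in trtllm_to_hf.items():
        if isinstance(hf_names, str):
            hf_names = [hf_names]
        for hf_name in hf_names:
            hf_to_trtllm[hf_name] = trtllm_name
    return hf_to_trtllm

mapping = {"attn_q": "q_proj", "mlp_h_to_4h": ["gate_proj", "w1"]}
inverted = invert_mapping(mapping)
# Two HF aliases resolve to one TRT-LLM module; collecting into a set
# (as the updated code does) keeps the result free of duplicates.
targets = {inverted[m] for m in ("q_proj", "gate_proj", "w1")}
print(sorted(targets))  # ['attn_q', 'mlp_h_to_4h']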
class NemoLoraLoader:
@@ -343,6 +349,7 @@ class LoraManager(object):
"moe_4h_to_h": 14,
"moe_gate": 15,
"moe_router": 16,
+"mlp_router": 17,
}
def __init__(self):
@@ -615,7 +622,7 @@ class LoraManager(object):
is_moe = False
t_in = module_weights["in"]
t_out = module_weights["out"]
-if lora_module in ["moe_router"]:
+if lora_module in ["moe_router", "mlp_router"]:
pass
elif "moe" in lora_module and runtime_mapping.has_moe_ep():
pass
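A rough sketch of why the router-style modules take the early pass branch: their adapters are kept whole on every rank, while most other modules shard one side of the LoRA pair across tensor parallelism. Shapes, the split axis, and the module names below are simplified assumptions for illustration, not the actual split logic.

import numpy as np

def shard_lora_out(lora_module, t_out, tp_size, tp_rank):
    # Router / shared-expert-gate adapters stay replicated; others are split
    # across tensor-parallel ranks (axis choice simplified here).
    if lora_module in ("moe_router", "mlp_router"):
        return t_out
    return np.split(t_out, tp_size, axis=0)[tp_rank]

t_out = np.arange(32, dtype=np.float32).reshape(8, 4)
print(shard_lora_out("mlp_router", t_out, tp_size=2, tp_rank=0).shape)  # (8, 4)
print(shard_lora_out("attn_dense", t_out, tp_size=2, tp_rank=0).shape)  # (4, 4)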


@@ -119,10 +119,6 @@ class LLaMAConfig(PretrainedConfig):
attn_bias = getattr(hf_config, 'bias', False) or getattr(
hf_config, 'attention_bias', False)
rotary_scaling = getattr(hf_config, "rope_scaling", None)
-if getattr(hf_config, "use_scaled_rope", False):
-rotary_scaling = {"type": "wavelen"}
-else:
-rotary_scaling = getattr(hf_config, "rope_scaling", None)
rotary_base = getattr(hf_config, "rope_theta", 10000.0)
residual_mlp = getattr(hf_config, "parallel_attn_mlp_res", False)
disable_weight_only_quant_plugin = kwargs.pop(
@@ -219,7 +215,7 @@ class LLaMAConfig(PretrainedConfig):
dtype = 'float16'
if meta_config.get('use_scaled_rope'):
-rotary_scaling = {"type": "wavelen"}
+rotary_scaling = {"type": "llama3"}
else:
rotary_scaling = meta_config.get("rope_scaling")
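In isolation, the Meta-checkpoint branch above now reads as follows (the helper name is invented for this note; the mapping itself is exactly what the diff shows, with the placeholder scaling type "wavelen" renamed to "llama3").

def rotary_scaling_from_meta(meta_config: dict):
    # Meta's Llama 3.1 params.json flags its long-context RoPE scheme with
    # use_scaled_rope; TRT-LLM now labels that scheme "llama3".
    if meta_config.get('use_scaled_rope'):
        return {"type": "llama3"}
    return meta_config.get("rope_scaling")

print(rotary_scaling_from_meta({"use_scaled_rope": True}))  # {'type': 'llama3'}
print(rotary_scaling_from_meta({}))                         # None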


@@ -15,8 +15,8 @@ from .._common import default_net
from .._utils import (get_init_params, numpy_to_torch, release_gc,
str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch)
from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits
-from ..layers import (AttentionParams, Embedding, FusedGatedMLP, FusedRgLru,
-GatedMLP, KeyValueCacheParams, LoraParams,
+from ..layers import (MLP, AttentionParams, Embedding, FusedGatedMLP,
+FusedRgLru, GatedMLP, KeyValueCacheParams, LoraParams,
PromptTuningEmbedding, RgLru)
from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear
@@ -982,7 +982,7 @@ def add_lora(model: PretrainedModel,
out_hidden_sizes=[layer.out_features],
max_low_rank=max_rank,
)
-if isinstance(layer, FusedGatedMLP):
+if isinstance(layer, (MLP, FusedGatedMLP)):
if max_rank is None:
max_rank = min(layer.hidden_size,
layer.ffn_hidden_size // layer.tp_size)
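The default rank cap picked in that branch is simple enough to state on its own. The attribute names below mirror the layer attributes used above, but the numbers are made up for illustration.

def default_max_lora_rank(hidden_size: int, ffn_hidden_size: int, tp_size: int) -> int:
    # The LoRA rank cannot exceed the smaller of the two matrix dimensions a
    # plain or fused gated MLP exposes on each tensor-parallel rank.
    return min(hidden_size, ffn_hidden_size // tp_size)

print(default_max_lora_rank(hidden_size=4096, ffn_hidden_size=14336, tp_size=4))  # 3584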


@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
from typing import Optional, Union
-from tensorrt_llm.lora_manager import LoraConfig, use_lora
from ..._utils import pad_vocab_size
from ...functional import Tensor, recv, send, sigmoid
from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
Embedding, GatedMLP, RmsNorm, RowLinear)
+from ...lora_manager import (LoraConfig,
+get_default_trtllm_modules_to_hf_modules, use_lora)
from ...mapping import Mapping
from ...module import Module
from ...quantization import W8A8_SQ_PLUGIN_LIST, QuantAlgo
@@ -143,10 +144,16 @@ class QWenDecoderLayer(Module):
shared_output = None
if self.config.qwen_type == 'qwen2_moe':
-shared_output = self.shared_expert(hidden_states)
+shared_output = self.shared_expert(
+hidden_states, lora_layer_params=lora_layer_params)
if self.shared_expert_gate is not None:
+gate_lora_params = None
+if lora_layer_params is not None:
+gate_lora_params = lora_layer_params.get_runtime_params(
+0, "mlp_router")
shared_output = sigmoid(
-self.shared_expert_gate(hidden_states)) * shared_output
+self.shared_expert_gate(hidden_states,
+gate_lora_params)) * shared_output
hidden_states = self.mlp(hidden_states,
lora_layer_params=lora_layer_params)
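For readers without the Qwen2-MoE block in their head, here is the gating arithmetic the new LoRA plumbing feeds into, as a toy PyTorch snippet. Tensor shapes and the final residual combination are assumptions based on the reference HF implementation, not code from this diff.

import torch

hidden = torch.randn(2, 8, 64)                 # [batch, seq, hidden]
shared_expert_out = torch.randn(2, 8, 64)      # shared expert MLP output
routed_out = torch.randn(2, 8, 64)             # stand-in for the routed MoE output
gate_weight = torch.randn(1, 64)               # shared_expert_gate: hidden -> 1 logit

# sigmoid(gate(x)) scales the shared-expert branch per token; with the change
# above, both the shared expert and this gate can carry LoRA deltas.
gate = torch.sigmoid(hidden @ gate_weight.T)   # [2, 8, 1]
output = routed_out + gate * shared_expert_out
print(output.shape)                            # torch.Size([2, 8, 64])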
@@ -247,6 +254,25 @@ class QWenForCausalLM(DecoderModelForCausalLM):
"mlp_4h_to_h": "mlp.c_proj",
"mlp_gate": "w1",
}
+elif config.qwen_type == 'qwen2_moe':
+self.trtllm_modules_to_hf_modules = copy.copy(
+get_default_trtllm_modules_to_hf_modules())
+self.trtllm_modules_to_hf_modules.update({
+"mlp_h_to_4h": "mlp.shared_expert.gate_proj",
+"mlp_4h_to_h": "mlp.shared_expert.down_proj",
+"mlp_gate": "mlp.shared_expert.up_proj",
+"mlp_router": "mlp.shared_expert_gate",
+"moe_h_to_4h": "mlp.experts.gate_proj",
+"moe_4h_to_h": "mlp.experts.down_proj",
+"moe_gate": "mlp.experts.up_proj",
+})
else:
self.trtllm_modules_to_hf_modules = None
super().__init__(config, transformer, lm_head)


@@ -27,7 +27,7 @@ import tensorrt as trt
from tensorrt_llm.module import Module
from ._common import set_network
-from ._utils import get_extra_attr, has_extra_attr, set_extra_attr, trt_gte_10_1
+from ._utils import get_extra_attr, has_extra_attr, set_extra_attr
from .logger import logger
from .plugin import PluginConfig
@@ -214,11 +214,8 @@ class Network(object):
logger.debug(
f'Add input: {name}, shape: {shape}, dtype: {dtype}, dimension names:{list(dim_range.keys())}'
)
-# NOTE: Multi-profile build sometimes fails with named dimensions in TRT < 10.1 : https://nvbugs/4645559
-# TODO: Remove this condition once things are stable with TRT 10.1
-if trt_gte_10_1():
-for i, dim_name in enumerate(dim_range.keys()):
-tensor.trt_tensor.set_dimension_name(i, str(dim_name))
+for i, dim_name in enumerate(dim_range.keys()):
+tensor.trt_tensor.set_dimension_name(i, str(dim_name))
else:
logger.debug(f'Add input: {name}, shape: {shape}, dtype: {dtype}')
self._inputs[name] = tensor
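Named dynamic dimensions are now set unconditionally, since the TRT < 10.1 multi-profile workaround and its trt_gte_10_1 gate are gone. A bare-bones illustration of the underlying TensorRT call, assuming a local TensorRT 10.x install; the input name and dimension names are arbitrary.

import tensorrt as trt

trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network()

# Two dynamic dimensions, both given stable names so optimization profiles
# referring to "batch_size" / "seq_len" stay consistent across inputs.
input_ids = network.add_input("input_ids", trt.int32, (-1, -1))
for i, dim_name in enumerate(("batch_size", "seq_len")):
    input_ids.set_dimension_name(i, dim_name)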


@@ -1452,10 +1452,10 @@ class SmoothQuantAttention(Module):
self.rotary_embedding_dim = 0
if rotary_embedding_scaling is not None:
-assert rotary_embedding_scaling["type"] in ["linear", "dynamic"]
-self.rotary_embedding_scale_type = RotaryScalingType.linear if rotary_embedding_scaling[
-"type"] == "linear" else RotaryScalingType.dynamic
-self.rotary_embedding_scale = rotary_embedding_scaling["factor"]
+self.rotary_embedding_scale_type = RotaryScalingType.from_string(
+rotary_embedding_scaling["type"])
+self.rotary_embedding_scale = rotary_embedding_scaling.get(
+"factor", 1.0)
assert self.rotary_embedding_scale > 1.0
if self.position_embedding_type.is_rope():
@@ -1464,7 +1464,7 @@ class SmoothQuantAttention(Module):
rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
self.max_position_embeddings, self.rotary_embedding_dim,
self.rotary_embedding_base, self.rotary_embedding_scale,
-self.rotary_embedding_scale_type)
+self.rotary_embedding_scale_type, rotary_embedding_scaling)
self.register_parameter(
'rotary_inv_freq',
Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
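Conceptually, the hard-coded linear/dynamic branch becomes a name lookup. The enum below is a stand-in written for this note; member names and values of the real RotaryScalingType in tensorrt_llm.functional may differ.

from enum import IntEnum

class ToyRotaryScalingType(IntEnum):
    none = 0
    linear = 1
    dynamic = 2
    llama3 = 3  # value chosen arbitrarily for this sketch

    @staticmethod
    def from_string(s: str) -> "ToyRotaryScalingType":
        try:
            return ToyRotaryScalingType[s]
        except KeyError as e:
            raise ValueError(f"Unsupported rotary scaling type: {s}") from e

scaling = {"type": "llama3"}                    # e.g. a Llama 3.1 rope_scaling dict
scale_type = ToyRotaryScalingType.from_string(scaling["type"])
scale = scaling.get("factor", 1.0)              # mirrors the .get("factor", 1.0) above
print(scale_type.name, scale)                   # llama3 1.0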


@@ -34,7 +34,7 @@ from tensorrt_llm.runtime.redrafter_utils import *
from .._ipc_utils import set_peer_access
from .._utils import (pad_vocab_size, str_dtype_to_torch, torch_to_numpy,
-trt_dtype_to_torch, trt_gte_10)
+trt_dtype_to_torch)
from ..logger import logger
from ..lora_manager import LoraManager
from ..mapping import Mapping
@@ -266,7 +266,7 @@ class _Runtime(object):
self.engine_inspector = self.engine.create_engine_inspector()
# cuda graph ping-pong instances
self.cuda_graph_instances = [None for _ in range(2)]
-if not (trt_gte_10() and self.engine.streamable_weights_size):
+if not self.engine.streamable_weights_size:
# engine does not have weight streaming enabled
self.__prepare_execution_contexts()
@@ -371,20 +371,16 @@
self.context_1 = None
self.ctx_context = None
-if not trt_gte_10():
-assert gpu_weights_percent == 1, "Weight streaming is only supported by TensorRT 10.0 or later."
-return
-else:
-min = self.engine.minimum_weight_streaming_budget
-max = self.engine.streamable_weights_size
-budget = int(min + gpu_weights_percent * (max - min))
+min = self.engine.minimum_weight_streaming_budget
+max = self.engine.streamable_weights_size
+budget = int(min + gpu_weights_percent * (max - min))
budget_config = budget if gpu_weights_percent != 1 else 0
self.engine.weight_streaming_budget = budget_config
assert self.engine.weight_streaming_budget == budget_config, "Failed to set weight streaming budget!"
logger.info(
f"Set gpu weights percent to {gpu_weights_percent}, which is {budget} bytes. Valid range: {min} bytes ~ {max} bytes."
)
if self.engine.streamable_weights_size:
try:
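The budget arithmetic now runs unconditionally and is simple enough to state standalone (helper name and example numbers invented for illustration).

def weight_streaming_budget(gpu_weights_percent: float, min_budget: int,
                            streamable_bytes: int) -> int:
    # Linear interpolation between the engine's minimum streaming budget and
    # the full streamable weight size; a percent of 1.0 maps to 0, which the
    # code above uses to keep every weight resident on the GPU.
    budget = int(min_budget + gpu_weights_percent * (streamable_bytes - min_budget))
    return budget if gpu_weights_percent != 1 else 0

print(weight_streaming_budget(0.25, min_budget=0, streamable_bytes=8 << 30))  # 2147483648
print(weight_streaming_budget(1.0, min_budget=0, streamable_bytes=8 << 30))   # 0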
@@ -1087,9 +1083,6 @@ class GenerationSession(object):
# Create two cuda graph once.If cuda graph has already existed, skip it.
if self.runtime.cuda_graph_instances[instance_idx] is not None:
return
-# WAR for TRT 9.x
-if not trt_gte_10() and step < 3:
-return
# capture cuda graph
CUASSERT(
cudart.cudaStreamBeginCapture(


@@ -24,7 +24,7 @@ import tensorrt as trt
import torch
from .. import profiler
-from .._utils import mpi_comm, mpi_world_size, numpy_to_torch, trt_gte_10
+from .._utils import mpi_comm, mpi_world_size, numpy_to_torch
from ..bindings import MpiComm
from ..bindings.executor import Executor
from ..builder import Engine, get_engine_version
@@ -520,7 +520,7 @@ class ModelRunner(ModelRunnerMixin):
runtime_mapping,
debug_mode=debug_mode,
stream=stream)
-if trt_gte_10() and session.runtime.engine.streamable_weights_size:
+if session.runtime.engine.streamable_weights_size:
session.runtime._set_weight_streaming(gpu_weights_percent)
if session.use_lora_plugin:
@@ -623,7 +623,7 @@ class ModelRunner(ModelRunnerMixin):
else:
lora_manager = None
-if trt_gte_10() and session.runtime.engine.streamable_weights_size:
+if session.runtime.engine.streamable_weights_size:
session.runtime._set_weight_streaming(gpu_weights_percent)
profiler.stop('load tensorrt_llm engine')


@@ -23,7 +23,7 @@ import torch
import tensorrt as trt
# isort: on
-from .._utils import torch_dtype_to_trt, trt_dtype_to_torch, trt_gte_10
+from .._utils import torch_dtype_to_trt, trt_dtype_to_torch
from ..logger import logger
@@ -66,7 +66,7 @@ class Session(object):
self._engine = self.runtime.deserialize_cuda_engine(engine_buffer)
self._context = None
-if not (trt_gte_10() and self.engine.streamable_weights_size):
+if not self.engine.streamable_weights_size:
self.__prepare_execution_contexts()
return self
@@ -210,20 +210,16 @@
self._context = None
-if not trt_gte_10():
-assert gpu_weights_percent == 1, "Weight streaming is only supported by TensorRT 10.0 or later."
-return
-else:
-min = self.engine.minimum_weight_streaming_budget
-max = self.engine.streamable_weights_size
-budget = int(min + gpu_weights_percent * (max - min))
+min = self.engine.minimum_weight_streaming_budget
+max = self.engine.streamable_weights_size
+budget = int(min + gpu_weights_percent * (max - min))
budget_config = budget if gpu_weights_percent != 1 else 0
self.engine.weight_streaming_budget = budget_config
assert self.engine.weight_streaming_budget == budget_config, "Failed to set weight streaming budget!"
logger.info(
f"Set gpu weights percent to {gpu_weights_percent}, which is {budget} bytes. Valid range: {min} bytes ~ {max} bytes."
)
if self.engine.streamable_weights_size:
try:


@@ -12,4 +12,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "0.12.0.dev2024072300"
+__version__ = "0.12.0.dev2024072301"


@@ -219,7 +219,12 @@ class TestMamba(unittest.TestCase):
# gen
part_step1_id = step1_id[i].view(1, 1)
part_hf_gen_outputs = hf_mamba.forward(
-part_step1_id, cache_params=part_cache_params)
+part_step1_id,
+cache_params=part_cache_params,
+cache_position=torch.arange(
+hf_config.conv_kernel - 1,
+hf_config.conv_kernel,
+device=part_step1_id.device))
torch.cuda.synchronize()
gen_ref[i][:] = part_hf_gen_outputs.logits[0, -1, :]
else:
@@ -231,7 +236,11 @@ class TestMamba(unittest.TestCase):
# gen
hf_outputs = hf_mamba.forward(step1_id,
cache_params=cache_params,
-use_cache=True)
+use_cache=True,
+cache_position=torch.arange(
+hf_config.conv_kernel - 1,
+hf_config.conv_kernel,
+device=step1_id.device))
gen_ref = hf_outputs.logits[:, -1, :]
# get tensorrt llm mamba rumtime
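The new cache_position keyword mirrors what recent transformers releases expect for cached Mamba decoding: a one-element position index pointing at the slot right after the warm-up prompt. In isolation it evaluates to the tensor below (the conv_kernel value is invented for the example).

import torch

conv_kernel = 4  # plays the role of hf_config.conv_kernel
cache_position = torch.arange(conv_kernel - 1, conv_kernel)
print(cache_position)  # tensor([3])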