Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (#2016)

Parent: 0d5ffae9a7
Commit: 5fa9436e17
@@ -21,7 +21,7 @@ TensorRT-LLM
🦙 400 tok/s - per node
🦙 37 tok/s - per user
🦙 1 node inference
-➡️ [link](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms/?ncid=so-twit-317976%E2%9C%A8)
+➡️ [link](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms)
<div align="center">
<img src="docs/source/media/picture-07-23-2024.png" width="45%">
<div align="left">
@@ -86,6 +86,7 @@ auto constexpr kLoraWeights = "lora_weights";
// "moe_4h_to_h": 14 # for mixtral adapter for expert mlp layer: down projection
// "moe_gate": 15 # for mixtral adapter for expert mlp layer: gate
// "moe_router": 16 # for mixtral adapter for expert router layer
// "mlp_router": 17 # for qwen2-moe adapter for shared expert gate layer
//
// last dim holds [ module_id, layer_idx, adapter_size (D / R value) ]
auto constexpr kLoraConfig = "lora_config"; // [num_lora_modules_layers, 3]
@@ -48,6 +48,7 @@ public:
kMOE_4H_TO_H = 14,
kMOE_GATE = 15,
kMOE_ROUTER = 16,
kMLP_ROUTER = 17,
};

explicit constexpr LoraModule(ModuleType const& t, SizeType32 inDim, SizeType32 outDim, bool inDimFirst,
@@ -216,6 +217,8 @@ public:
return ModuleType::kMOE_GATE;
else if (name == "moe_router")
return ModuleType::kMOE_ROUTER;
else if (name == "mlp_router")
return ModuleType::kMLP_ROUTER;
else
return ModuleType::kINVALID;
}
@@ -241,6 +244,7 @@ public:
case ModuleType::kMOE_4H_TO_H: return "moe_4h_to_h";
case ModuleType::kMOE_GATE: return "moe_gate";
case ModuleType::kMOE_ROUTER: return "moe_router";
case ModuleType::kMLP_ROUTER: return "mlp_router";
case ModuleType::kINVALID: return "INVALID";
}
return "INVALID";
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:0f132408402aeb54b82891673aa3050811d69ec264399ed5f8d4f7a5cc63e2d8
-size 4293074
+oid sha256:3e25541cdc2aaa48f6a6e4c386d22ca1832c8e120fc6e8c190db4ee066ebfb1f
+size 4293186
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:9f42dae7c82f4c59dc973f2e9f72d41d6f2e0e68b04c12d14f095b647890af86
-size 4395714
+oid sha256:3108cd0580f6328bd46238ef708872d9d8030a9c8645b8b52bc750dfe094bc16
+size 4395794
@@ -1,3 +1,3 @@
-f2252f27a20618d3b7abe865c5192045 libtensorrt_llm_batch_manager_static.a
-ce8405cc0d369bf4fd79d30eef5ad9ed libtensorrt_llm_batch_manager_static.pre_cxx11.a
-3706e7395b9b58994412617992727c8ff2d14c9f commit
+50a839e98b31729198870fc99ef2c5a9 libtensorrt_llm_batch_manager_static.a
+a39a5bf618c8514725b59aac4513223f libtensorrt_llm_batch_manager_static.pre_cxx11.a
+3511a2653f2ba73f6f827aca6d2850b3d3e8e543 commit
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:b875b0b7c85fd18492865f6db704be09c55e823d322b65e4af58359d0576ad0a
-size 4154538
+oid sha256:9600435f1b9ab74c752d1831e1a6684a004927c84ab7c61fc076dbc128ca1521
+size 4154674
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:c597863e8910b4ef0be61961598653813c0673949dd59b8938a1d6f231ad878e
-size 4133066
+oid sha256:8145ecf59dea64448ca0969553d32bc99e119cc5fc703e7b47eccfb5886594a0
+size 4133178
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:2e06ed93f0745bc9196414f36de6ff1d98069110027e1dc95530b2a9be82176e
-size 24008762
+oid sha256:f89f551a880f4c6c1e68ed72b951ac482dec6033e55a336a0ecc401f4e9cf150
+size 24009160
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:510394b3e137e08b7292e68e445ccb9ed6986748e639b032413ad56f265078cb
+oid sha256:33f259b374a02456f2b8d44571d92195b708c2011be4ecabe46267f49ca24c29
size 1426724
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:a4a76bfb7611d6a7ef3c8e4a9e191eb4235739b5cf5e2642ac102c03e87c7e44
+oid sha256:f44786aee0842bdb260de49b734d2119a0521c650f0b733f5ce6f997e72bfb34
size 1452984
@@ -1,3 +1,3 @@
-93f42e0f10a6efb28073513b8a9c4471 libtensorrt_llm_executor_static.a
-533416c32056580e0e21ac5f771f3371 libtensorrt_llm_executor_static.pre_cxx11.a
-3706e7395b9b58994412617992727c8ff2d14c9f commit
+0d5e559ebc885794ab9e63086ae7a18a libtensorrt_llm_executor_static.a
+f9a3d1bf32f33f88569d4d8635e5445a libtensorrt_llm_executor_static.pre_cxx11.a
+3511a2653f2ba73f6f827aca6d2850b3d3e8e543 commit
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:378f06d390a108fc8fcdf34bd390e70ed98102c0b36647e292d49a9f680867a6
+oid sha256:19bd908d16990cd11a295fcb71403e2ad285dc2c3b84d55228166d9240acd0d9
size 1476318
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:d448b5a52066c61d73bea45b08561ecdfbb5aa49d46dd255fb714e7e0aa0ab41
+oid sha256:bed0b93d23eef43ce46c01e694f9e578c64fe9b30e1b05d65b7feed1a41e5148
size 1408208
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:4c4287066210a2511a5de3984be9d97318aa34fc0cc16c685deb342778d4f777
+oid sha256:473c672353cb813af9ea65250bd79f61f5ea27c369c9f35bc3bace1e22c5e9bb
size 14325956
@@ -1,2 +1,2 @@
28ead889239ca8d558c1e1a93f0485b0 libtensorrt_llm_nvrtc_wrapper.so
-3706e7395b9b58994412617992727c8ff2d14c9f commit
+3511a2653f2ba73f6f827aca6d2850b3d3e8e543 commit
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:c116381592aea6404e15ace64a69425b35e59492c074920f867370c280c6ea93
+oid sha256:20824706210bf184641c92fcb728ab0a3a74a36bc0b13e243c713a84c74a51ac
size 1089536
@@ -60,7 +60,7 @@ enum class RotaryScalingType : int8_t
kLINEAR = 1,
kDYNAMIC = 2,
kLONG = 3,
-kWAVELEN = 4
+kLLAMA3 = 4
};

struct BlockSparseParams
@@ -58,7 +58,8 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
case ModuleType::kMOE_GATE:
case ModuleType::kMOE_4H_TO_H:
case ModuleType::kMOE_ROUTER:
-case ModuleType::kINVALID: throw std::runtime_error("Invalid loRA module " + moduleName);
+case ModuleType::kMLP_ROUTER:
+case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName);
}
}
return modules;
@@ -112,6 +112,7 @@ The following tensors are for a LoRA which has a `q` and `k` adapter.
| moe_4h_to_h | 14 | for mixtral adapter for expert mlp layer: down projection |
| moe_gate | 15 | for mixtral adapter for expert mlp layer: gate |
| moe_router | 16 | for mixtral adapter for expert router layer |
| mlp_router | 17 | for qwen2-moe adapter for shared expert gate layer |

#### LoraCache configuration
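For illustration only, a minimal Python sketch (not part of the commit) of how a `lora_config` tensor matching the `[num_lora_modules_layers, 3]` layout described above could be assembled; the module IDs come from the table, while the layer indices and adapter sizes below are made-up values:

```python
import numpy as np

# Module IDs taken from the documentation table above (others exist but are omitted here).
MODULE_IDS = {"moe_4h_to_h": 14, "moe_gate": 15, "moe_router": 16, "mlp_router": 17}

# One row per (module, layer): [module_id, layer_idx, adapter_size].
lora_config = np.array(
    [
        [MODULE_IDS["moe_gate"], 0, 8],    # hypothetical rank-8 adapter on layer 0 expert gate
        [MODULE_IDS["mlp_router"], 0, 8],  # hypothetical rank-8 adapter on layer 0 shared expert gate
    ],
    dtype=np.int32,
)
print(lora_config.shape)  # (num_lora_modules_layers, 3)
```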
Binary file not shown. Before: 673 KiB
Binary file not shown. Before: 126 KiB
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.15.0
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
evaluate~=0.4.1
protobuf
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
transformers>=4.31.0
datasets~=2.14.5
evaluate~=0.4.1
@@ -3,7 +3,7 @@
# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
# "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
flax~=0.8.0
# jax[cuda12_pip]~=0.4.19; platform_system != "Windows"
jax~=0.4.19; platform_system == "Windows"
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
rouge_score~=0.1.2
evaluate~=0.4.1
@@ -1,6 +1,6 @@
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets==2.14.6
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,2 +1,2 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets==2.14.5
rouge_score~=0.1.2
sentencepiece~=0.1.99
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1209,7 +1209,7 @@ Note that the sink tokens is included in the sliding attention tokens, and there

## Run LLaMA-3.1 405B Model

-Currently, TensorRT-LLM supports Meta checkpoint and Huggingface checkpoint for LLaMA-3.1. In this section, we demonstrate how to run the LLaMA-3.1 405B model via TensorRT-LLM. Here, we assume users have downloaded the checkpoints and placed them at `llama_3.1_405B_meta_model/` (Meta checkpoint) and `llama_3.1_405B_HF_model/` (HF checkpoint). Before converting the checkpoints to TensorRT-LLM unified checkpoints, **please check that `"use_scaled_rope": true` is set in the configuration file**. With this flag, TensorRT-LLM will enable the rope scaling of LLaMA-3.1. If not, please add it to the config file.
+Currently, TensorRT-LLM supports Meta checkpoint and Huggingface checkpoint for LLaMA-3.1. In this section, we demonstrate how to run the LLaMA-3.1 405B model via TensorRT-LLM. Here, we assume users have downloaded the checkpoints and placed them at `llama_3.1_405B_meta_model/` (Meta BF16 checkpoint), `llama_3.1_405B_HF_model/` (HF BF16 checkpoint) and `llama_3.1_405B_HF_FP8_model/` (HF FP8 checkpoint). Before converting the checkpoints to TensorRT-LLM unified checkpoints, **please check that `{"rope_scaling": {"rope_type": "llama3"}}` is set in the configuration file**. With this flag, TensorRT-LLM will enable the rope scaling of LLaMA-3.1. If not, please add it to the config file.

Users can run the LLaMA-3.1 model with higher precision (bf16/fp16) or fp8. Here, to prevent accuracy drop, we perform per-channel per-token fp8 quantization (leveraged from https://github.com/pytorch/FBGEMM) on MLP layers, keeping other layers at higher precision. Note that fp8 quantization is only supported on Huggingface checkpoint now. We will support it on Meta checkpoint soon.
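As a hedged aside (not part of the commit), a minimal Python sketch of checking and patching that `rope_scaling` entry in an HF `config.json`; the path is hypothetical and the default values mirror the llama3-scaling defaults used later in this diff, so verify them against your own checkpoint rather than copying blindly:

```python
import json

# Hypothetical path; adjust to where the HF checkpoint was downloaded.
cfg_path = "llama_3.1_405B_HF_model/config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

# Add the entry only if it is missing; values here are assumptions, not authoritative.
cfg.setdefault("rope_scaling", {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
})

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)
```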
@@ -1217,7 +1217,10 @@ Users can run the LLaMA-3.1 model with higher precision (bf16/fp16) or fp8. Here

To use the fp8 quantization, please add the `--use_fp8_rowwise` flag during the checkpoint conversion. In this demonstration, we convert the Meta checkpoint to bfloat16 with TP8-PP2 and the HF checkpoint to FP8 with TP8.

Note that you may need to update your transformers installation via `pip install --upgrade transformers`.

```bash
# Run BF16 model by BF16
python examples/llama/convert_checkpoint.py --meta_ckpt_dir llama_3.1_405B_meta_model/ \
--output_dir llama_3.1_405B_meta_model/trt_ckpts/tp8-pp2/ \
--dtype bfloat16 \
@@ -1226,6 +1229,7 @@ python examples/llama/convert_checkpoint.py --meta_ckpt_dir llama_3.1_405B_meta_
--load_by_shard \
--workers 8

# Run BF16 model by FP8
python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_model/ \
--output_dir llama_3.1_405B_HF_model/trt_ckpts/tp8-pp1/ \
--dtype bfloat16 \
@@ -1234,6 +1238,15 @@ python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_model/
--pp_size 1 \
--load_by_shard \
--workers 8

# Run FP8 model by FP8
python examples/llama/convert_checkpoint.py --model_dir llama_3.1_405B_HF_FP8_model/ \
--output_dir llama_3.1_405B_HF_FP8_model/trt_ckpts/tp8-pp1/ \
--dtype bfloat16 \
--tp_size 8 \
--pp_size 1 \
--load_by_shard \
--workers 8
```

### Build Engine
@@ -1254,6 +1267,14 @@ trtllm-build --checkpoint_dir llama_3.1_405B_HF_model/trt_ckpts/tp8-pp1/ \
--max_seq_len 65000 \
--use_paged_context_fmha enable \
--workers 8

trtllm-build --checkpoint_dir llama_3.1_405B_HF_FP8_model/trt_ckpts/tp8-pp1/ \
--output_dir llama_3.1_405B_HF_FP8_model/trt_engines/tp8-pp1/ \
--max_num_tokens 4096 \
--max_input_len 64000 \
--max_seq_len 65000 \
--use_paged_context_fmha enable \
--workers 8
```

### Run Inference
@@ -1298,7 +1319,7 @@ srun --mpi pmi2 -N 1 -n 8 --ntasks-per-node 8 --container-image <your container>
--container-name llama-3.1-405b \
--container-workdir <your container work directory> \
bash -c 'python ./examples/eval_long_context.py --task passkey \
---engine_dir llama_3.1_405B_HF_model/trt_ckpts/tp8-pp1/ \
+--engine_dir llama_3.1_405B_HF_model/trt_engines/tp8-pp1/ \
--tokenizer_dir llama_3.1_405B_HF_model/ \
--stop_idx 6 \
--max_input_length 64000 \
@@ -1307,6 +1328,22 @@ bash -c 'python ./examples/eval_long_context.py --task passkey \
--max_tokens_in_paged_kv_cache 65064 \
--data_dir 64k_context \
--output_dir 64k_context_tp8'

# Long context test for 64k
srun --mpi pmi2 -N 1 -n 8 --ntasks-per-node 8 --container-image <your container> \
--container-mounts <your container mount> \
--container-name llama-3.1-405b \
--container-workdir <your container work directory> \
bash -c 'python ./examples/eval_long_context.py --task passkey \
--engine_dir llama_3.1_405B_HF_FP8_model/trt_engines/tp8-pp1/ \
--tokenizer_dir llama_3.1_405B_HF_FP8_model/ \
--stop_idx 6 \
--max_input_length 64000 \
--enable_chunked_context \
--kv_cache_free_gpu_memory_fraction 0.999 \
--max_tokens_in_paged_kv_cache 65064 \
--data_dir 64k_context \
--output_dir 64k_context_tp8'
```

The following script shows how to run evaluation on MMLU tasks:
@@ -1332,5 +1369,16 @@ bash -c 'python ./examples/mmlu.py --test_trt_llm \
--tokenizer_dir llama_3.1_405B_HF_model/ \
--enable_chunked_context \
--kv_cache_free_gpu_memory_fraction 0.999 \
---max_tokens_in_paged_kv_cache 256064'
+--max_tokens_in_paged_kv_cache 65064'

srun --mpi pmi2 -N 1 -n 8 --ntasks-per-node 8 --container-image <your container> \
--container-mounts <your container mount> \
--container-name llama-3.1-405b \
--container-workdir <your container work directory> \
bash -c 'python ./examples/mmlu.py --test_trt_llm \
--engine_dir llama_3.1_405B_HF_FP8_model/trt_engines/tp8-pp1/ \
--tokenizer_dir llama_3.1_405B_HF_FP8_model/ \
--enable_chunked_context \
--kv_cache_free_gpu_memory_fraction 0.999 \
--max_tokens_in_paged_kv_cache 65064'
```
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets==2.14.6
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
transformers>=4.39.0
datasets~=2.14.5
evaluate
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
rouge_score~=0.1.2
sentencepiece~=0.1.99
@@ -1,4 +1,4 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
transformers==4.38.2
accelerate==0.25.0
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
transformers==4.40.2
# https://github.com/NVIDIA/NeMo/issues/9793
huggingface_hub==0.23.5
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.14.5
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets>=2.14.4
nemo-toolkit[all]<=1.20.0,>=1.18.0
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.16.0
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.16.0
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
git+https://github.com/google-deepmind/recurrentgemma.git
flax>=0.8.2
jax~=0.4.23
@@ -304,7 +304,7 @@ def main(args):
encoder_input_lengths = [x.size(0)
for x in encoder_input_ids] if is_enc_dec else None

-if not supports_inflight_batching(
+if not args.use_py_session and not supports_inflight_batching(
os.path.join(args.engine_dir, "decoder") if is_enc_dec else args.
engine_dir):
logger.warning(
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets~=2.16.1
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
datasets==2.14.6
evaluate~=0.4.1
rouge_score~=0.1.2
@@ -1,5 +1,5 @@
--extra-index-url https://pypi.nvidia.com
-tensorrt_llm==0.12.0.dev2024072300
+tensorrt_llm==0.12.0.dev2024072301
tiktoken
datasets
kaldialign
@@ -112,26 +112,12 @@ def trt_version():
return trt.__version__


# TRT supports strongly_typed in 9.1
def support_strongly_type():
return version.parse(trt_version()) >= version.parse("9.1.0")


# Check if TRT version >= 10
def trt_gte_10():
return version.parse(trt_version()).major > 9


# Check if TRT version >= 10.1
def trt_gte_10_1():
def trt_gte(major: int, minor: int = 0):
"""
Check if TRT version is greater than or equal to major.minor
"""
trt_ver = version.parse(trt_version())
return trt_ver.major > 9 and trt_ver.minor > 0


# Check if TRT version >= 10.2
def trt_gte_10_2():
ver = version.parse(trt_version())
return (ver.major * 10 + ver.minor) >= 102
return trt_ver.major >= major and trt_ver.minor >= minor


def torch_version():
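As a side note, a minimal self-contained sketch (not from the commit) showing how the consolidated `trt_gte(major, minor)` check behaves; it mirrors the comparison added above, but takes the version string as an argument so it runs without TensorRT installed:

```python
from packaging import version

def trt_gte(trt_version_str: str, major: int, minor: int = 0) -> bool:
    """Same comparison as the new trt_gte helper: TRT >= major.minor."""
    trt_ver = version.parse(trt_version_str)
    return trt_ver.major >= major and trt_ver.minor >= minor

# The old helpers map onto the new one, e.g. trt_gte_10_2() -> trt_gte(10, 2).
print(trt_gte("10.2.0", 10, 2))  # True
print(trt_gte("10.1.0", 10, 2))  # False
```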
@@ -13,7 +13,7 @@ import torch
from filelock import FileLock

from tensorrt_llm._utils import (str_dtype_to_trt, trt_dtype_to_np,
-trt_dtype_to_torch, trt_gte_10)
+trt_dtype_to_torch)
from tensorrt_llm.functional import (AllReduceConfig, AllReduceFusionParams,
AllReduceStrategy, create_allreduce_plugin)
from tensorrt_llm.logger import logger
@@ -39,7 +39,7 @@ from .tensor_parallel.sharding_strategy import ShardingStrategy
from .utils import (get_updated_plugin, to_base_class_layer, to_subclass_layer,
to_trt_weights)

-default_int_dtype = trt.int64 if trt_gte_10() else trt.int32
+default_int_dtype = trt.int64


@dataclass
@@ -25,8 +25,7 @@ from typing import Dict, Optional, Union
import tensorrt as trt

from ._common import _is_building, check_max_num_tokens, serialize_engine
-from ._utils import (str_dtype_to_trt, support_strongly_type, to_json_file,
-trt_gte_10, trt_gte_10_2)
+from ._utils import str_dtype_to_trt, to_json_file
from .auto_parallel import auto_parallel
from .auto_parallel.config import AutoParallelConfig
from .graph_rewriting import optimize
@@ -112,7 +111,7 @@ class Builder():
explicit_batch_flag = 1 << int(
trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

-if support_strongly_type() and self.strongly_typed:
+if self.strongly_typed:
return Network()._init(
self.trt_builder.create_network(
explicit_batch_flag
@@ -145,35 +144,15 @@ class Builder():
@param int8: whether to build with int8 enabled or not. Can't be used together with refit option
@return: A BuilderConfig object, return None if failed
'''
if strongly_typed and not support_strongly_type():
logger.warning(
"TRT version does not support strongly_type. strongly_typed flag is ignored."
)

# In TRT 10.0, enable strongly_typed by default.
self.strongly_typed = self.strongly_typed or (strongly_typed and
support_strongly_type())
self.strongly_typed = self.strongly_typed or strongly_typed

quant_mode = kwargs.get("quant_mode", QuantMode(0))
if not strongly_typed and precision not in self._ALLOWED_PRECISIONS:
logger.error(
f"precision should be one of {self._ALLOWED_PRECISIONS}")

if use_strip_plan and not trt_gte_10():
logger.error(
"cannot use --strip_plan with tensorrt version 9.x or below")

if (use_refit or use_strip_plan) and int8 and not trt_gte_10():
# TRT folds weights into Myelin graph because network contains int8 tensor or Q/DQ nodes
# These folded weights can not be refitted
logger.error(
"can't use refit/strip_plan and int8 mode at the same time before tensorrt 10.0"
)

config = self.trt_builder.create_builder_config()
if weight_streaming:
assert trt_gte_10(), \
"Weight streaming is only supported by TensorRT 10.0 or later."
config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)
if not self.strongly_typed:
fp8 = quant_mode.has_fp8_qdq() or quant_mode.has_fp8_kv_cache()
@@ -197,7 +176,7 @@ class Builder():
config.set_flag(trt.BuilderFlag.REFIT)

# Use fine-grained refit when strip plan is enabled in TRT10.2+.
-if use_strip_plan and trt_gte_10_2():
+if use_strip_plan:
config.set_flag(trt.BuilderFlag.REFIT_INDIVIDUAL)

if use_strip_plan:
@@ -396,7 +375,6 @@ class Builder():
engine = None

# Rename weights
-is_refit_individual_supported = trt_gte_10_2()
if network.named_parameters is not None:
for name, param in network.named_parameters:
if param._get_weights() is None:
@@ -409,9 +387,8 @@ class Builder():
if not network.trt_network.set_weights_name(
param._get_weights(), name):
raise RuntimeError(f'Failed to set weight: {name}')
if is_refit_individual_supported:
# This mark_weights_refittable has no side effect when refit_individual is not enabled.
network.trt_network.mark_weights_refittable(name)
# This mark_weights_refittable has no side effect when refit_individual is not enabled.
network.trt_network.mark_weights_refittable(name)

network._fill_weights()
# Build engine
@@ -447,7 +447,8 @@ def main():
# Extract rotary scaling which will be used for checks and default value of max_seq_len
rotary_scaling = getattr(model_config, "rotary_scaling", None)
if rotary_scaling is not None:
-rotary_type = rotary_scaling['type']
+rotary_type = rotary_scaling.get('type',
+rotary_scaling.get('rope_type'))
rotary_factor = rotary_scaling.get(
'factor', 1.0) if rotary_type != 'su' else 1
else:
@@ -30,8 +30,7 @@ from ._common import default_net, default_trtnet, precision
from ._utils import (bf16_array, bool_array, dim_resolve_negative,
dim_to_trt_axes, dims_array, fp16_array, fp32_array,
int32_array, int64_array, np_dtype_to_trt,
-str_dtype_to_trt, trt_dtype_to_np, trt_dtype_to_str,
-trt_gte_10)
+str_dtype_to_trt, trt_dtype_to_np, trt_dtype_to_str)
from .network import PluginInfo, set_np_weight, set_plugin_info
from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper
from .quantization import QuantMode
@@ -639,7 +638,7 @@ class RotaryScalingType(IntEnum):
linear = 1
dynamic = 2
longrope = 3
-wavelen = 4
+llama3 = 4

@staticmethod
def from_string(s):
@@ -1030,10 +1029,6 @@ def matmul(input: Tensor,
# This option is only supported for fp16, but not bf16 or any other precisions.
use_fp32_acc = use_fp32_acc and input.dtype == trt.DataType.HALF and mat2.dtype == trt.DataType.HALF

-# TODO: fp32 accum has issues with strongly_typed and it will be fixed in TensorRT 10.0
-if default_net().strongly_typed and not trt_gte_10():
-use_fp32_acc = False

if use_fp32_acc:
input = cast(input, 'float32')
mat2 = cast(mat2, 'float32')
@@ -4139,11 +4134,14 @@ def bert_attention(tensor: Tensor,
class RopeEmbeddingUtils:

@staticmethod
def apply_wavelen_scaling(inv_freqs: np.ndarray,
scale_factor: float = 8.0,
low_freq_factor: float = 1.0,
high_freq_factor: float = 4.0,
old_context_len: int = 8192):
# ref: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L298
def apply_llama3_scaling(inv_freqs: np.ndarray, rope_scaling_config: dict):

scale_factor = rope_scaling_config.get("factor", 8.0)
low_freq_factor = rope_scaling_config.get("low_freq_factor", 1.0)
high_freq_factor = rope_scaling_config.get("high_freq_factor", 4.0)
old_context_len = rope_scaling_config.get(
"original_max_position_embeddings", 8192)

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
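To make the effect of the llama3 scaling concrete, here is a hedged, standalone sketch of the per-frequency rule. The diff above only shows the beginning of `apply_llama3_scaling`, so the branch logic below is taken from the Hugging Face implementation it references, not from this commit:

```python
import numpy as np

def llama3_scale_inv_freqs(inv_freqs, factor=8.0, low_freq_factor=1.0,
                           high_freq_factor=4.0, old_context_len=8192):
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * np.pi / inv_freqs
    # Short wavelengths (high frequencies) are left untouched, long wavelengths
    # are divided by `factor`, and the band in between is smoothly interpolated.
    scaled = np.where(wavelen > low_freq_wavelen, inv_freqs / factor, inv_freqs)
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freqs / factor + smooth * inv_freqs
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return np.where(is_medium, smoothed, scaled)

# The highest-frequency term (wavelength far below high_freq_wavelen) is unchanged.
print(llama3_scale_inv_freqs(np.array([1.0, 1e-2, 1e-4]))[0])  # 1.0
```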
@@ -4183,12 +4181,19 @@ class RopeEmbeddingUtils:
theta: float = 10000.0,
scale: float = 1.0,
scale_type: RotaryScalingType = RotaryScalingType.none,
# Other scaling configs that only used by certain scaling types.
rope_scaling_config: dict = None,
dtype=np.float32):
if scale_type == RotaryScalingType.linear:
scale = 1.0 / scale
inv_freq = scale / (theta**(np.arange(0, dim, 2) / dim)).astype(dtype)
if scale_type == RotaryScalingType.wavelen:
inv_freq = RopeEmbeddingUtils.apply_wavelen_scaling(inv_freq)
if scale_type == RotaryScalingType.llama3:
assert rope_scaling_config is not None, "rotary_scaling config must be provided."
inv_freq = 1.0 / (theta**(np.arange(0, dim, 2) / dim)).astype(dtype)
inv_freq = RopeEmbeddingUtils.apply_llama3_scaling(
inv_freq, rope_scaling_config)
else:
inv_freq = scale / (theta
**(np.arange(0, dim, 2) / dim)).astype(dtype)
sinusoid_inp = np.expand_dims(np.einsum("i , j -> i j",
np.arange(num_pos, dtype=dtype),
inv_freq,
@@ -4618,7 +4623,7 @@ def gpt_attention(
* RotaryScalingType.linear
* RotaryScalingType.dynamic
* RotaryScalingType.longrope
-* RotaryScalingType.wavelen
+* RotaryScalingType.llama3

rotary_embedding_scale: float
The scale value to use for linear/dynamic scaling in RoPE.
@@ -20,7 +20,7 @@ import tensorrt as trt

from .._common import default_net, precision
from .._utils import (fp32_array, int32_array, is_same_dtype, trt_dtype_to_np,
-trt_dtype_to_str, trt_gte_10)
+trt_dtype_to_str)
from ..functional import (ACT2FN, AllReduceFusionParams, AttentionMaskType,
Conditional, LayerNormType, PositionEmbeddingType,
RopeEmbeddingUtils, RotaryScalingType, Tensor, arange,
@@ -362,8 +362,10 @@ class Attention(Module):
self.rotary_embedding_percentage = rotary_embedding_percentage
self.use_implicit_relative_attention = self.relative_attention and use_implicit_relative_attention
if rotary_embedding_scaling is not None:
+rotary_scaling_type = rotary_embedding_scaling.get(
+"type", rotary_embedding_scaling.get("rope_type"))
self.rotary_embedding_scale_type = RotaryScalingType.from_string(
-rotary_embedding_scaling["type"])
+rotary_scaling_type)
self.rotary_embedding_scale = rotary_embedding_scaling.get(
"factor", 1.0)
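A small illustrative sketch (plain dicts standing in for the model config; not part of the commit) of why the `.get("type", ....get("rope_type"))` lookup above accepts both the legacy TensorRT-LLM key and the Hugging Face key:

```python
def resolve_scaling_type(rotary_embedding_scaling: dict) -> str:
    # Same lookup order as the Attention constructor in this diff.
    return rotary_embedding_scaling.get(
        "type", rotary_embedding_scaling.get("rope_type"))

print(resolve_scaling_type({"type": "linear", "factor": 2.0}))        # linear
print(resolve_scaling_type({"rope_type": "llama3", "factor": 8.0}))   # llama3
```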
@@ -433,7 +435,8 @@ class Attention(Module):
rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
self.max_position_embeddings, self.rotary_embedding_dim,
self.rotary_embedding_base, self.rotary_embedding_scale,
-self.rotary_embedding_scale_type)
+self.rotary_embedding_scale_type,
+self.rotary_embedding_scaling)
self.register_parameter(
'rotary_inv_freq',
Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
@@ -1153,12 +1156,7 @@ class Attention(Module):
if norm_before_bmm1:
# Apply norm on query earlier to prevent matmul fp16 overflow.
query /= (self.q_scaling * self.norm_factor)
if trt_gte_10() or self.position_embedding_type.is_alibi():
attention_scores = matmul(query, key)
else:
# For TRT 9.x, OOTB need this WAR to fuse mha.
attention_scores = matmul(cast(query, 'float32'),
cast(key, 'float32'))
attention_scores = matmul(query, key)
if not norm_before_bmm1:
attention_scores = attention_scores / (self.q_scaling *
self.norm_factor)
@@ -1182,24 +1180,16 @@ class Attention(Module):

attention_probs = softmax(attention_scores, dim=-1)

if trt_gte_10() or self.position_embedding_type.is_alibi():
# For trt_version() == 9.x and pos_embed == alibi, TRT has gpu buffer management issues. Need this WAR to avoid peak gpu mem regression.
# A dummy reshape WAR for mha fusion for 10.0
attention_probs = attention_probs.view(
concat([
shape(attention_probs, 0),
shape(attention_probs, 1),
shape(attention_probs, 2),
shape(value, 2)
]))
context = matmul(attention_probs, value,
use_fp32_acc=False).permute([0, 2, 1, 3])
else:
# For TRT 9.x, need this WAR to fuse mha.
context = matmul(attention_probs,
cast(value, 'float32')).permute([0, 2, 1, 3])
if context.dtype != value.dtype:
context = cast(context, value.dtype)
# A dummy reshape WAR for mha fusion
attention_probs = attention_probs.view(
concat([
shape(attention_probs, 0),
shape(attention_probs, 1),
shape(attention_probs, 2),
shape(value, 2)
]))
context = matmul(attention_probs, value,
use_fp32_acc=False).permute([0, 2, 1, 3])
context = context.view(
concat([
shape(context, 0),
@@ -18,7 +18,7 @@ import tensorrt as trt

from .._common import default_net
from ..functional import (ACT2FN, AllReduceFusionParams, cast, concat,
-gemm_swiglu)
+gemm_swiglu, is_gated_activation)
from ..module import Module
from ..quantization import QuantMode
from ..quantization.functional import quantize
@@ -28,6 +28,34 @@ from .lora import LoraRuntimeParams
from .normalization import LayerNorm


def fc_gate_lora(hidden_states, lora, lora_layer_params):
if lora_layer_params is not None:
mlp_fc_lora_params = lora_layer_params.get_runtime_params(
0, "mlp_h_to_4h")
mlp_gate_lora_params = lora_layer_params.get_runtime_params(
0, "mlp_gate")

if mlp_fc_lora_params is not None and mlp_gate_lora_params is not None:
mlp_in_lora_params = LoraRuntimeParams(
lora_ranks=[
mlp_fc_lora_params.lora_ranks[0],
mlp_gate_lora_params.lora_ranks[0]
],
lora_weights_pointers=[
mlp_fc_lora_params.lora_weights_pointers[0],
mlp_gate_lora_params.lora_weights_pointers[0]
],
host_request_types=mlp_fc_lora_params.host_request_types,
host_context_lengths=mlp_fc_lora_params.host_context_lengths,
max_context_length=mlp_fc_lora_params.max_context_length)

mlp_fc_lora, mlp_gate_lora = lora(hidden_states, mlp_in_lora_params)
mlp_in_result = concat([mlp_gate_lora, mlp_fc_lora],
dim=mlp_fc_lora.rank() - 1)
return mlp_in_result
return None


class MLP(Module):

def __init__(
@@ -76,19 +104,28 @@ class MLP(Module):
self.tp_size = tp_size
self.quant_mode = quant_mode
self.eps = eps
# see optimize_model's add_lora for LoRA initialization
self.lora = None

def forward(self, hidden_states, lora_layer_params=None, gegelu_limit=None):
mlp_fc_lora_params = None
if lora_layer_params is not None:
mlp_fc_lora_params = lora_layer_params.get_runtime_params(
0, "mlp_h_to_4h")
if is_gated_activation(self.hidden_act):
inter = self.fc(hidden_states)
lora_result = fc_gate_lora(hidden_states, self.lora,
lora_layer_params)
if lora_result is not None:
inter = inter + lora_result
else:
mlp_fc_lora_params = None
if lora_layer_params is not None:
mlp_fc_lora_params = lora_layer_params.get_runtime_params(
0, "mlp_h_to_4h")
inter = self.fc(hidden_states, mlp_fc_lora_params)

mlp_proj_lora_params = None
if lora_layer_params is not None:
mlp_proj_lora_params = lora_layer_params.get_runtime_params(
0, "mlp_4h_to_h")

inter = self.fc(hidden_states, mlp_fc_lora_params)
if self.hidden_act == 'gegelu':
inter = ACT2FN[self.hidden_act](inter, gegelu_limit)
else:
@@ -286,32 +323,9 @@ class FusedGatedMLP(Module):

inter = self.fused_fc(hidden_states)

if lora_layer_params is not None:
mlp_fc_lora_params = lora_layer_params.get_runtime_params(
0, "mlp_h_to_4h")
mlp_gate_lora_params = lora_layer_params.get_runtime_params(
0, "mlp_gate")

if mlp_fc_lora_params is not None and mlp_gate_lora_params is not None:
mlp_in_lora_params = LoraRuntimeParams(
lora_ranks=[
mlp_fc_lora_params.lora_ranks[0],
mlp_gate_lora_params.lora_ranks[0]
],
lora_weights_pointers=[
mlp_fc_lora_params.lora_weights_pointers[0],
mlp_gate_lora_params.lora_weights_pointers[0]
],
host_request_types=mlp_fc_lora_params.host_request_types,
host_context_lengths=mlp_fc_lora_params.
host_context_lengths,
max_context_length=mlp_fc_lora_params.max_context_length)

mlp_fc_lora, mlp_gate_lora = self.lora(hidden_states,
mlp_in_lora_params)
mlp_in_result = concat([mlp_gate_lora, mlp_fc_lora],
dim=mlp_fc_lora.rank() - 1)
inter = inter + mlp_in_result
lora_result = fc_gate_lora(hidden_states, self.lora, lora_layer_params)
if lora_result is not None:
inter = inter + lora_result

if self.hidden_act == 'silu':
inter = ACT2FN['swiglu'](inter)
@@ -18,10 +18,8 @@ from typing import List, Optional, Type, Union

import numpy as np
import tensorrt as trt
from packaging import version

from tensorrt_llm._utils import (get_init_params, str_dtype_to_trt, trt_gte_10,
trt_version)
from tensorrt_llm._utils import get_init_params, str_dtype_to_trt
from tensorrt_llm.layers.lora import LoraParams

from .._common import default_net, default_trtnet
@@ -553,106 +551,81 @@ class MoeOOTB(MOE):
router_probs = softmax(routing, -1)
topk_values, topk_indices = topk(router_probs, self.top_k, dim=-1)

if trt_gte_10() and version.parse(trt_version()).minor >= 2:
# For TRT 10.2 and above, avoid over-computing by using NonZero ops to select tokens for each experts.
hidden_size = shape(hidden_states, -1)
# [B*sq, hidden]
inputs_merged = hidden_states.view(concat([-1, hidden_size]))
flat_topk_indices = topk_indices.view(
concat([-1, shape(topk_indices, -1)]))
flat_topk_values = topk_values.view(concat([-1,
shape(topk_values, -1)]))

hidden_size = shape(hidden_states, -1)
#[B*sq, hidden]
inputs_merged = hidden_states.view(concat([-1, hidden_size]))
flat_topk_indices = topk_indices.view(
concat([-1, shape(topk_indices, -1)]))
flat_topk_values = topk_values.view(
concat([-1, shape(topk_values, -1)]))
# Create output space
zero_buffer = inputs_merged * 0.0
output = zero_buffer

# Create output space
zero_buffer = inputs_merged * 0.0
output = zero_buffer
expert_indices_stack = []
indices_stack = []
# When topk indices are equal to expert index, the expert will inference the tokens.
# Bundle all indices and experts index, then do mask once.
for i, expert in enumerate(self.experts):
if self.mapping.has_moe_ep():
index = i + self.experts_per_node * self.mapping.moe_ep_rank
else:
index = i
expert_indices_stack.append(
flat_topk_indices.view(concat([1, shape(flat_topk_indices)])))

expert_indices_stack = []
indices_stack = []
# When topk indices are equal to expert index, the expert will inference the tokens.
# Bundle all indices and experts index, then do mask once.
for i, expert in enumerate(self.experts):
if self.mapping.has_moe_ep():
index = i + self.experts_per_node * self.mapping.moe_ep_rank
else:
index = i
expert_indices_stack.append(
flat_topk_indices.view(concat([1,
shape(flat_topk_indices)])))
indices_stack.append(constant(int32_array(index)))

indices_stack.append(constant(int32_array(index)))
all_expert_indices = concat(expert_indices_stack, dim=0)
indices = expand(
concat(indices_stack).view(concat([len(self.experts), 1, 1])),
shape(all_expert_indices))

all_expert_indices = concat(expert_indices_stack, dim=0)
indices = expand(
concat(indices_stack).view(concat([len(self.experts), 1, 1])),
shape(all_expert_indices))
# Create all experts mask
all_expert_mask = all_expert_indices == indices

# Create all experts mask
all_expert_mask = all_expert_indices == indices
experts_weights = cast(
sum(flat_topk_values *
cast(all_expert_mask, flat_topk_values.dtype),
dim=-1,
keepdim=True), self.dtype)

experts_weights = cast(
sum(flat_topk_values *
cast(all_expert_mask, flat_topk_values.dtype),
dim=-1,
keepdim=True), self.dtype)
all_expert_mask = cast(
sum(cast(all_expert_mask, flat_topk_values.dtype),
dim=-1,
keepdim=True), 'bool')
all_expert_mask = repeat_interleave(all_expert_mask, shape(output, -1),
2)

all_expert_mask = cast(
sum(cast(all_expert_mask, flat_topk_values.dtype),
dim=-1,
keepdim=True), 'bool')
all_expert_mask = repeat_interleave(all_expert_mask,
shape(output, -1), 2)
# split the mask and weights for each expert
experts_mask = split(all_expert_mask, 1, dim=0)
expert_weights = split(experts_weights, 1, dim=0)

# split the mask and weights for each expert
experts_mask = split(all_expert_mask, 1, dim=0)
expert_weights = split(experts_weights, 1, dim=0)
for i, expert in enumerate(self.experts):
# get mask token index
non_zero_index = nonzero(experts_mask[i].view(
concat([-1, hidden_size])))
non_zero_index = non_zero_index.transpose(1, 0)
input_for_expert = gather_nd(inputs_merged, non_zero_index, 0)
input_for_expert = input_for_expert.view(concat([-1, hidden_size]),
zero_is_placeholder=False)

for i, expert in enumerate(self.experts):
# get mask token index
non_zero_index = nonzero(experts_mask[i].view(
concat([-1, hidden_size])))
non_zero_index = non_zero_index.transpose(1, 0)
input_for_expert = gather_nd(inputs_merged, non_zero_index, 0)
input_for_expert = input_for_expert.view(
concat([-1, hidden_size]), zero_is_placeholder=False)
# Expert inference
expert_output = expert(
input_for_expert,
lora_layer_params=self.moe_to_expert_lora_params(
lora_layer_params, index))

# Expert inference
expert_output = expert(
input_for_expert,
lora_layer_params=self.moe_to_expert_lora_params(
lora_layer_params, index))
# scatter expert output to real position
expert_finialized_output = zero_buffer
expert_finialized_output = scatter_nd(
expert_finialized_output, non_zero_index,
expert_output.view([-1])) * expert_weights[i]

# scatter expert output to real position
expert_finialized_output = zero_buffer
expert_finialized_output = scatter_nd(
expert_finialized_output, non_zero_index,
expert_output.view([-1])) * expert_weights[i]
output += expert_finialized_output

output += expert_finialized_output

output = output.view(shape(hidden_states))
else:
output = hidden_states * 0.0 # Create output space
# Use over-computation when TRT version is too low.
# Experts inference
for i, expert in enumerate(self.experts):
if self.mapping.has_moe_ep():
index = i + self.experts_per_node * self.mapping.moe_ep_rank
else:
index = i
# inference expert
out = expert(hidden_states,
lora_layer_params=self.moe_to_expert_lora_params(
lora_layer_params, index))

expert_mask = topk_indices == index
expert_weights = cast(
sum(topk_values * cast(expert_mask, topk_values.dtype),
dim=-1,
keepdim=True), self.dtype)

output += out * expert_weights
output = output.view(shape(hidden_states))

need_ep_reduce = self.mapping.has_moe_ep(
) and self.mapping.moe_ep_group is not None
@@ -13,7 +13,6 @@ import yaml
from ._utils import (DictConversion, pad_vocab_size, release_gc,
str_dtype_to_torch, torch_to_numpy)
-from .layers.linear import ColumnLinear
from .logger import logger
from .mapping import Mapping
from .models.convert_utils import (get_model_path, load_state_dict,
split_matrix_tp)
@@ -34,30 +33,43 @@ def get_all_nemo_lora_weights(lora_weights):
m = layer_pattern.match(key)
layer_idx = int(m.group(1))
layer_weights[layer_idx][inout] = weights
else:
raise KeyError(f"unsupported key {key} from Nemo LoRA weights")
return layer_weights


def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):
# The pattern is {layer_prefix:1}.{layer_idx:2}.{module_prefix:3}.{module_name or {expert_name:5}.{expert_idx:6}.{module_name:7} :4}.lora_{A|B:8}.weight
HF_LORA_PATTERN = re.compile(
r'(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.lora_(A|B)\.weight'
)


def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None):
all_weights = defaultdict(lambda: defaultdict(dict))
pattern = re.compile(
r'(.*)\.(\d+)\.(\w+)\.(\w+|experts\.(\d+)\.(\w+))\.lora_(A|B)\.weight')
pattern = HF_LORA_PATTERN
for key, weights in lora_weights.items():
m = pattern.match(key)
if not m:
if "lm_head" not in key and "embed_tokens" not in key:
logger.warning(f"no match {key} from HF LoRA weights")
raise KeyError(f"unsupported key {key} from HF LoRA weights")
continue
if component is not None and component not in m.group(1):
continue
layer_idx = int(m.group(2))
expert_idx = m.group(5)
expert_idx = m.group(6)
is_moe = expert_idx is not None
module_name = m.group(6 if is_moe else 4)
hf_module = m.group(3) + "." + module_name
if is_moe:
expert_name = m.group(5)
module_name = m.group(7)
hf_module = m.group(3) + "." + expert_name + "." + module_name
else:
module_name = m.group(4)
hf_module = m.group(3) + "." + module_name
if hf_module not in hf_modules:
hf_module = module_name
assert hf_module in hf_modules
inout = "in" if m.group(7) == "A" else "out"
inout = "in" if m.group(8) == "A" else "out"
iter_fn(layer_idx, hf_module, expert_idx, inout, weights)
if not is_moe:
all_weights[layer_idx][hf_module][inout] = weights
else:
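As an illustration only (the key names below are hypothetical PEFT-style examples, not taken from the commit), here is how the `HF_LORA_PATTERN` regex above carves up plain and MoE LoRA weight keys:

```python
import re

HF_LORA_PATTERN = re.compile(
    r'(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.lora_(A|B)\.weight')

plain = HF_LORA_PATTERN.match(
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight")
print(plain.group(2), plain.group(3), plain.group(4), plain.group(8))
# -> 0 self_attn q_proj A   (the expert groups 5-7 stay None)

moe = HF_LORA_PATTERN.match(
    "base_model.model.model.layers.3.mlp.experts.2.gate_proj.lora_B.weight")
print(moe.group(2), moe.group(5), moe.group(6), moe.group(7), moe.group(8))
# -> 3 experts 2 gate_proj B
```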
@@ -66,31 +78,27 @@ def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):
return all_weights


def get_hf_target_modules(lora_weights, hf_modules, lora_target_modules):
hf_target_modules = set()
pattern = re.compile(
r'(.*)\.(\d+)\.(\w+)\.(\w+|experts\.(\d+)\.(\w+))\.lora_(A|B)\.weight')
for key in lora_weights.keys():
m = pattern.match(key)
if not m:
if "lm_head" not in key and "embed_tokens" not in key:
logger.warning(f"no match {key} from HF LoRA weights")
continue
match_target_module = False
for module in lora_target_modules:
if module in key:
match_target_module = True
break
if not match_target_module:
continue
expert_idx = m.group(5)
is_moe = expert_idx is not None
module_name = m.group(6 if is_moe else 4)
hf_module = m.group(3) + "." + module_name
if hf_module not in hf_modules:
hf_module = module_name
assert hf_module in hf_modules
def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):

def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
if expert_idx is None:
all_weights[layer_idx][hf_module][inout] = weights
else:
all_weights[layer_idx][hf_module].setdefault(expert_idx, {})
all_weights[layer_idx][hf_module][expert_idx][inout] = weights

all_weights = defaultdict(lambda: defaultdict(dict))
iterate_hf_lora(iter_fn, lora_weights, hf_modules, component)
return all_weights


def get_hf_target_modules(lora_weights, hf_modules):

def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
hf_target_modules.add(hf_module)

hf_target_modules = set()
iterate_hf_lora(iter_fn, lora_weights, hf_modules)
return hf_target_modules


@@ -146,7 +154,6 @@ class HfLoraLoader:
lora_dir = lora_dirs[0]
with open(f"{lora_dir}/adapter_config.json") as f:
adapter_config = json.load(f)
-self.lora_target_modules = adapter_config["target_modules"]

lora_weight = load_state_dict(get_model_path(lora_dir, "adapter_model"))
self.lora_weight = lora_weight
@@ -162,17 +169,16 @@ class HfLoraLoader:
def get_target_modules(self, trtllm_modules_to_hf_modules):
hf_modules_to_trtllm_modules = invert_module_mapping(
trtllm_modules_to_hf_modules)
-lora_target_modules = []
+lora_target_modules = set()
if self.is_valid:
hf_target_modules = get_hf_target_modules(
self.lora_weight,
hf_modules=set(hf_modules_to_trtllm_modules.keys()),
-lora_target_modules=self.lora_target_modules,
)
for m in hf_target_modules:
trtllm_module = hf_modules_to_trtllm_modules[m]
-lora_target_modules.append(trtllm_module)
-return lora_target_modules
+lora_target_modules.add(trtllm_module)
+return list(lora_target_modules)


class NemoLoraLoader:
@@ -343,6 +349,7 @@ class LoraManager(object):
"moe_4h_to_h": 14,
"moe_gate": 15,
"moe_router": 16,
"mlp_router": 17,
}

def __init__(self):
@@ -615,7 +622,7 @@ class LoraManager(object):
is_moe = False
t_in = module_weights["in"]
t_out = module_weights["out"]
-if lora_module in ["moe_router"]:
+if lora_module in ["moe_router", "mlp_router"]:
pass
elif "moe" in lora_module and runtime_mapping.has_moe_ep():
pass
@@ -119,10 +119,6 @@ class LLaMAConfig(PretrainedConfig):
attn_bias = getattr(hf_config, 'bias', False) or getattr(
hf_config, 'attention_bias', False)
+rotary_scaling = getattr(hf_config, "rope_scaling", None)
-if getattr(hf_config, "use_scaled_rope", False):
-rotary_scaling = {"type": "wavelen"}
-else:
-rotary_scaling = getattr(hf_config, "rope_scaling", None)
rotary_base = getattr(hf_config, "rope_theta", 10000.0)
residual_mlp = getattr(hf_config, "parallel_attn_mlp_res", False)
disable_weight_only_quant_plugin = kwargs.pop(
@@ -219,7 +215,7 @@ class LLaMAConfig(PretrainedConfig):
dtype = 'float16'

if meta_config.get('use_scaled_rope'):
-rotary_scaling = {"type": "wavelen"}
+rotary_scaling = {"type": "llama3"}
else:
rotary_scaling = meta_config.get("rope_scaling")

@@ -15,8 +15,8 @@ from .._common import default_net
from .._utils import (get_init_params, numpy_to_torch, release_gc,
str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch)
from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits
-from ..layers import (AttentionParams, Embedding, FusedGatedMLP, FusedRgLru,
-GatedMLP, KeyValueCacheParams, LoraParams,
+from ..layers import (MLP, AttentionParams, Embedding, FusedGatedMLP,
+FusedRgLru, GatedMLP, KeyValueCacheParams, LoraParams,
PromptTuningEmbedding, RgLru)
from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear
@@ -982,7 +982,7 @@ def add_lora(model: PretrainedModel,
out_hidden_sizes=[layer.out_features],
max_low_rank=max_rank,
)
-if isinstance(layer, FusedGatedMLP):
+if isinstance(layer, (MLP, FusedGatedMLP)):
if max_rank is None:
max_rank = min(layer.hidden_size,
layer.ffn_hidden_size // layer.tp_size)
@ -13,14 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import Optional, Union
|
||||
|
||||
from tensorrt_llm.lora_manager import LoraConfig, use_lora
|
||||
|
||||
from ..._utils import pad_vocab_size
|
||||
from ...functional import Tensor, recv, send, sigmoid
|
||||
from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
|
||||
Embedding, GatedMLP, RmsNorm, RowLinear)
|
||||
from ...lora_manager import (LoraConfig,
|
||||
get_default_trtllm_modules_to_hf_modules, use_lora)
|
||||
from ...mapping import Mapping
|
||||
from ...module import Module
|
||||
from ...quantization import W8A8_SQ_PLUGIN_LIST, QuantAlgo
|
||||
@ -143,10 +144,16 @@ class QWenDecoderLayer(Module):
|
||||
|
||||
shared_output = None
|
||||
if self.config.qwen_type == 'qwen2_moe':
|
||||
shared_output = self.shared_expert(hidden_states)
|
||||
shared_output = self.shared_expert(
|
||||
hidden_states, lora_layer_params=lora_layer_params)
|
||||
if self.shared_expert_gate is not None:
|
||||
gate_lora_params = None
|
||||
if lora_layer_params is not None:
|
||||
gate_lora_params = lora_layer_params.get_runtime_params(
|
||||
0, "mlp_router")
|
||||
shared_output = sigmoid(
|
||||
self.shared_expert_gate(hidden_states)) * shared_output
|
||||
self.shared_expert_gate(hidden_states,
|
||||
gate_lora_params)) * shared_output
|
||||
|
||||
hidden_states = self.mlp(hidden_states,
|
||||
lora_layer_params=lora_layer_params)
|
||||
@ -247,6 +254,25 @@ class QWenForCausalLM(DecoderModelForCausalLM):
                "mlp_4h_to_h": "mlp.c_proj",
                "mlp_gate": "w1",
            }
        elif config.qwen_type == 'qwen2_moe':
            self.trtllm_modules_to_hf_modules = copy.copy(
                get_default_trtllm_modules_to_hf_modules())
            self.trtllm_modules_to_hf_modules.update({
                "mlp_h_to_4h": "mlp.shared_expert.gate_proj",
                "mlp_4h_to_h": "mlp.shared_expert.down_proj",
                "mlp_gate": "mlp.shared_expert.up_proj",
                "mlp_router": "mlp.shared_expert_gate",
                "moe_h_to_4h": "mlp.experts.gate_proj",
                "moe_4h_to_h": "mlp.experts.down_proj",
                "moe_gate": "mlp.experts.up_proj",
            })
        else:
            self.trtllm_modules_to_hf_modules = None
        super().__init__(config, transformer, lm_head)

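Illustrative only: the table added above lets the LoRA loader translate TensorRT-LLM module names into the Hugging Face module names found in a Qwen2-MoE adapter checkpoint, and the inverse view answers which TRT-LLM LoRA module a given HF layer feeds. A trimmed sketch of such a lookup:

```python
# Trimmed copy of the qwen2_moe entries added above.
trtllm_to_hf = {
    "mlp_router": "mlp.shared_expert_gate",
    "moe_h_to_4h": "mlp.experts.gate_proj",
    "moe_4h_to_h": "mlp.experts.down_proj",
    "moe_gate": "mlp.experts.up_proj",
}
# Invert it to classify a checkpoint weight by its module suffix.
hf_to_trtllm = {v: k for k, v in trtllm_to_hf.items()}
print(hf_to_trtllm["mlp.shared_expert_gate"])  # mlp_router
```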
@ -27,7 +27,7 @@ import tensorrt as trt
from tensorrt_llm.module import Module

from ._common import set_network
from ._utils import get_extra_attr, has_extra_attr, set_extra_attr, trt_gte_10_1
from ._utils import get_extra_attr, has_extra_attr, set_extra_attr
from .logger import logger
from .plugin import PluginConfig

@ -214,11 +214,8 @@ class Network(object):
            logger.debug(
                f'Add input: {name}, shape: {shape}, dtype: {dtype}, dimension names:{list(dim_range.keys())}'
            )
            # NOTE: Multi-profile build sometimes fails with named dimensions in TRT < 10.1 : https://nvbugs/4645559
            # TODO: Remove this condition once things are stable with TRT 10.1
            if trt_gte_10_1():
                for i, dim_name in enumerate(dim_range.keys()):
                    tensor.trt_tensor.set_dimension_name(i, str(dim_name))
            for i, dim_name in enumerate(dim_range.keys()):
                tensor.trt_tensor.set_dimension_name(i, str(dim_name))
        else:
            logger.debug(f'Add input: {name}, shape: {shape}, dtype: {dtype}')
        self._inputs[name] = tensor

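The change above drops the TensorRT-version guard and always names the dynamic dimensions of each network input. A minimal sketch of the idea, assuming `input_tensor` is a `tensorrt.ITensor` returned by `network.add_input(...)` and `dim_range` is an ordered dict mapping each dimension name to its profile range (as in the surrounding code); per the removed NOTE, this assumes a TensorRT new enough that named dimensions do not break multi-profile builds.

```python
import tensorrt as trt  # assumes a recent TensorRT (>= 10.1 per the removed NOTE)


def name_input_dimensions(input_tensor: "trt.ITensor", dim_range: dict) -> None:
    # Stable dimension names let TensorRT relate equal dynamic dimensions
    # across inputs and optimization profiles.
    for i, dim_name in enumerate(dim_range.keys()):
        input_tensor.set_dimension_name(i, str(dim_name))
```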
@ -1452,10 +1452,10 @@ class SmoothQuantAttention(Module):
        self.rotary_embedding_dim = 0

        if rotary_embedding_scaling is not None:
            assert rotary_embedding_scaling["type"] in ["linear", "dynamic"]
            self.rotary_embedding_scale_type = RotaryScalingType.linear if rotary_embedding_scaling[
                "type"] == "linear" else RotaryScalingType.dynamic
            self.rotary_embedding_scale = rotary_embedding_scaling["factor"]
            self.rotary_embedding_scale_type = RotaryScalingType.from_string(
                rotary_embedding_scaling["type"])
            self.rotary_embedding_scale = rotary_embedding_scaling.get(
                "factor", 1.0)
            assert self.rotary_embedding_scale > 1.0

        if self.position_embedding_type.is_rope():

@ -1464,7 +1464,7 @@ class SmoothQuantAttention(Module):
            rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
                self.max_position_embeddings, self.rotary_embedding_dim,
                self.rotary_embedding_base, self.rotary_embedding_scale,
                self.rotary_embedding_scale_type)
                self.rotary_embedding_scale_type, rotary_embedding_scaling)
            self.register_parameter(
                'rotary_inv_freq',
                Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))

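The first hunk above replaces the hard-coded linear/dynamic handling with a name-based lookup and lets the scale factor fall back to 1.0 when the config omits it. A simplified sketch with a stand-in enum (the real `RotaryScalingType` lives in TensorRT-LLM and covers more cases; member names here are illustrative):

```python
from enum import IntEnum


class RotaryScaling(IntEnum):
    # Stand-in for RotaryScalingType; names mirror common rope_scaling types.
    none = 0
    linear = 1
    dynamic = 2
    llama3 = 3

    @staticmethod
    def from_string(s: str) -> "RotaryScaling":
        # Enum lookup by member name, e.g. "llama3" -> RotaryScaling.llama3.
        return RotaryScaling[s]


scaling_cfg = {"type": "llama3"}                     # no "factor" entry
scale_type = RotaryScaling.from_string(scaling_cfg["type"])
scale = scaling_cfg.get("factor", 1.0)               # default when absent
print(scale_type.name, scale)                        # llama3 1.0
```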
@ -34,7 +34,7 @@ from tensorrt_llm.runtime.redrafter_utils import *

from .._ipc_utils import set_peer_access
from .._utils import (pad_vocab_size, str_dtype_to_torch, torch_to_numpy,
                      trt_dtype_to_torch, trt_gte_10)
                      trt_dtype_to_torch)
from ..logger import logger
from ..lora_manager import LoraManager
from ..mapping import Mapping

@ -266,7 +266,7 @@ class _Runtime(object):
        self.engine_inspector = self.engine.create_engine_inspector()
        # cuda graph ping-pong instances
        self.cuda_graph_instances = [None for _ in range(2)]
        if not (trt_gte_10() and self.engine.streamable_weights_size):
        if not self.engine.streamable_weights_size:
            # engine does not have weight streaming enabled
            self.__prepare_execution_contexts()

@ -371,20 +371,16 @@ class _Runtime(object):
        self.context_1 = None
        self.ctx_context = None

        if not trt_gte_10():
            assert gpu_weights_percent == 1, "Weight streaming is only supported by TensorRT 10.0 or later."
            return
        else:
            min = self.engine.minimum_weight_streaming_budget
            max = self.engine.streamable_weights_size
            budget = int(min + gpu_weights_percent * (max - min))
        min = self.engine.minimum_weight_streaming_budget
        max = self.engine.streamable_weights_size
        budget = int(min + gpu_weights_percent * (max - min))

            budget_config = budget if gpu_weights_percent != 1 else 0
            self.engine.weight_streaming_budget = budget_config
            assert self.engine.weight_streaming_budget == budget_config, "Failed to set weight streaming budget!"
            logger.info(
                f"Set gpu weights percent to {gpu_weights_percent}, which is {budget} bytes. Valid range: {min} bytes ~ {max} bytes."
            )
        budget_config = budget if gpu_weights_percent != 1 else 0
        self.engine.weight_streaming_budget = budget_config
        assert self.engine.weight_streaming_budget == budget_config, "Failed to set weight streaming budget!"
        logger.info(
            f"Set gpu weights percent to {gpu_weights_percent}, which is {budget} bytes. Valid range: {min} bytes ~ {max} bytes."
        )

        if self.engine.streamable_weights_size:
            try:

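With the TensorRT 9 fallback removed, the budget computation above is a plain linear interpolation between the engine's minimum streaming budget and its total streamable weight size, with 0 meaning "keep all weights on the GPU". A small worked example (the helper and the byte values are illustrative; the real numbers come from the deserialized engine):

```python
def weight_streaming_budget(gpu_weights_percent: float, min_budget: int,
                            max_budget: int) -> int:
    # Interpolate between the minimum budget and the full streamable size.
    budget = int(min_budget + gpu_weights_percent * (max_budget - min_budget))
    # As in the code above, a percent of exactly 1 is encoded as budget 0
    # (weight streaming effectively disabled).
    return budget if gpu_weights_percent != 1 else 0


print(weight_streaming_budget(1.0, 0, 4_000_000_000))             # 0
print(weight_streaming_budget(0.5, 0, 4_000_000_000))             # 2000000000
print(weight_streaming_budget(0.25, 1_000_000, 4_000_000_000))    # 1000750000
```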
@ -1087,9 +1083,6 @@ class GenerationSession(object):
        # Create two cuda graph once.If cuda graph has already existed, skip it.
        if self.runtime.cuda_graph_instances[instance_idx] is not None:
            return
        # WAR for TRT 9.x
        if not trt_gte_10() and step < 3:
            return
        # capture cuda graph
        CUASSERT(
            cudart.cudaStreamBeginCapture(

@ -24,7 +24,7 @@ import tensorrt as trt
import torch

from .. import profiler
from .._utils import mpi_comm, mpi_world_size, numpy_to_torch, trt_gte_10
from .._utils import mpi_comm, mpi_world_size, numpy_to_torch
from ..bindings import MpiComm
from ..bindings.executor import Executor
from ..builder import Engine, get_engine_version

@ -520,7 +520,7 @@ class ModelRunner(ModelRunnerMixin):
                              runtime_mapping,
                              debug_mode=debug_mode,
                              stream=stream)
            if trt_gte_10() and session.runtime.engine.streamable_weights_size:
            if session.runtime.engine.streamable_weights_size:
                session.runtime._set_weight_streaming(gpu_weights_percent)

            if session.use_lora_plugin:

@ -623,7 +623,7 @@ class ModelRunner(ModelRunnerMixin):
        else:
            lora_manager = None

        if trt_gte_10() and session.runtime.engine.streamable_weights_size:
        if session.runtime.engine.streamable_weights_size:
            session.runtime._set_weight_streaming(gpu_weights_percent)

        profiler.stop('load tensorrt_llm engine')

@ -23,7 +23,7 @@ import torch
import tensorrt as trt
# isort: on

from .._utils import torch_dtype_to_trt, trt_dtype_to_torch, trt_gte_10
from .._utils import torch_dtype_to_trt, trt_dtype_to_torch
from ..logger import logger


@ -66,7 +66,7 @@ class Session(object):
        self._engine = self.runtime.deserialize_cuda_engine(engine_buffer)

        self._context = None
        if not (trt_gte_10() and self.engine.streamable_weights_size):
        if not self.engine.streamable_weights_size:
            self.__prepare_execution_contexts()
        return self

@ -210,20 +210,16 @@ class Session(object):

        self._context = None

        if not trt_gte_10():
            assert gpu_weights_percent == 1, "Weight streaming is only supported by TensorRT 10.0 or later."
            return
        else:
            min = self.engine.minimum_weight_streaming_budget
            max = self.engine.streamable_weights_size
            budget = int(min + gpu_weights_percent * (max - min))
        min = self.engine.minimum_weight_streaming_budget
        max = self.engine.streamable_weights_size
        budget = int(min + gpu_weights_percent * (max - min))

            budget_config = budget if gpu_weights_percent != 1 else 0
            self.engine.weight_streaming_budget = budget_config
            assert self.engine.weight_streaming_budget == budget_config, "Failed to set weight streaming budget!"
            logger.info(
                f"Set gpu weights percent to {gpu_weights_percent}, which is {budget} bytes. Valid range: {min} bytes ~ {max} bytes."
            )
        budget_config = budget if gpu_weights_percent != 1 else 0
        self.engine.weight_streaming_budget = budget_config
        assert self.engine.weight_streaming_budget == budget_config, "Failed to set weight streaming budget!"
        logger.info(
            f"Set gpu weights percent to {gpu_weights_percent}, which is {budget} bytes. Valid range: {min} bytes ~ {max} bytes."
        )

        if self.engine.streamable_weights_size:
            try:

@ -12,4 +12,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.12.0.dev2024072300"
__version__ = "0.12.0.dev2024072301"

@ -219,7 +219,12 @@ class TestMamba(unittest.TestCase):
                    # gen
                    part_step1_id = step1_id[i].view(1, 1)
                    part_hf_gen_outputs = hf_mamba.forward(
                        part_step1_id, cache_params=part_cache_params)
                        part_step1_id,
                        cache_params=part_cache_params,
                        cache_position=torch.arange(
                            hf_config.conv_kernel - 1,
                            hf_config.conv_kernel,
                            device=part_step1_id.device))
                    torch.cuda.synchronize()
                    gen_ref[i][:] = part_hf_gen_outputs.logits[0, -1, :]
            else:

@ -231,7 +236,11 @@ class TestMamba(unittest.TestCase):
                # gen
                hf_outputs = hf_mamba.forward(step1_id,
                                              cache_params=cache_params,
                                              use_cache=True)
                                              use_cache=True,
                                              cache_position=torch.arange(
                                                  hf_config.conv_kernel - 1,
                                                  hf_config.conv_kernel,
                                                  device=step1_id.device))
                gen_ref = hf_outputs.logits[:, -1, :]

        # get tensorrt llm mamba rumtime

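The test updates above reflect that newer transformers releases expect an explicit `cache_position` when calling Mamba's `forward` for a single generation step. A small illustrative helper (`conv_kernel` is the convolution width from the HF Mamba config; pointing at the last slot of the convolution window mirrors the test, not a documented contract):

```python
import torch


def decode_step_cache_position(conv_kernel: int,
                               device: torch.device) -> torch.Tensor:
    # One-element position tensor for a single decode step, matching the
    # torch.arange(conv_kernel - 1, conv_kernel, ...) calls in the test above.
    return torch.arange(conv_kernel - 1, conv_kernel, device=device)


print(decode_step_cache_position(4, torch.device("cpu")))  # tensor([3])
```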