doc: Update DeepSeekV3 doc (#3052)

* Update DeepGEMM and flashMLA related content

* Add single-node command for deepgemm

* Fix spelling

---------

Signed-off-by: xiaoweiw-nv <100599594+xiaoweiw-nv@users.noreply.github.com>
Xiaowei Wang 2025-03-25 18:17:26 +08:00 committed by GitHub
parent e9df23f815
commit d9acce72bb

@@ -1,8 +1,8 @@
# DeepSeekV3
This guide walks you through the complete process of running the Deepseekv3 model using NVIDIA's TRTLLM framework with the PyTorch backend. It covers everything from downloading the model weights, preparing the dataset and configuration files, to running the throughput benchmark.
This guide walks you through the complete process of running the DeepSeekv3 model using NVIDIA's TensorRT-LLM framework with the PyTorch backend. It covers everything from downloading the model weights, preparing the dataset and configuration files, to running the throughput benchmark.
> **Note:** This guide assumes you have access to the required hardware (with sufficient GPU memory) and that you replace placeholder values (e.g. `<YOUR_MODEL_DIR>`) with the appropriate paths. Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/build-from-source-linux.html) for how to build TRT-LLM from source and docker image.
> **Note:** This guide assumes you have access to the required hardware (with sufficient GPU memory) and that you replace placeholder values (e.g. `<YOUR_MODEL_DIR>`) with the appropriate paths. Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/build-from-source-linux.html) for how to build TensorRT-LLM from source and create the Docker image.
---
@@ -26,7 +26,7 @@ This guide walks you through the complete process of running the Deepseekv3 m
## Overview
Deepseekv3 is a highcapacity language model that can be executed using NVIDIA's TRTLLM framework with a PyTorch backend. This guide details a benchmark recipe where you will:
DeepSeekv3 is a high-capacity language model that can be executed using NVIDIA's TensorRT-LLM framework with a PyTorch backend. This guide details a benchmark recipe where you will:
- Download the model weights.
- Build a test dataset.
@@ -38,13 +38,13 @@ Deepseekv3 is a highcapacity language model that can be executed using NVI
## Hardware
DeepSeek-v3 has 671B parameters, which need about 671GB of GPU memory, so 8\*H100 (640GB) is not enough to accommodate the weights. The following steps have been tested on 8\*H20 141GB; we will test on 8\*H20 96GB in the future.
DeepSeek-v3 is trained natively with FP8 precision, we only provide FP8 solution in TRTLLM at this moment. Ampere architecture (SM80 & SM86) is not supported.
DeepSeek-v3 is trained natively with FP8 precision, so we only provide an FP8 solution in TensorRT-LLM at this moment. Ampere architecture (SM80 & SM86) is not supported.
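As a rough sanity check on the memory figure above (a quick sketch, assuming FP8 weights occupy roughly one byte per parameter; activations and KV cache need additional memory on top of this):

```bash
# Back-of-the-envelope weight footprint: ~1 byte per parameter in FP8.
python3 -c "params = 671e9; print(f'{params / 1e9:.0f} GB of FP8 weights')"   # -> 671 GB
```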
---
## Downloading the Model Weights
Deepseekv3 model weights are available on [Hugging Face](https://huggingface.co/deepseek-ai/DeepSeek-V3). To download the weights, execute the following commands (replace `<YOUR_MODEL_DIR>` with the target directory where you want the weights stored):
DeepSeekv3 model weights are available on [Hugging Face](https://huggingface.co/deepseek-ai/DeepSeek-V3). To download the weights, execute the following commands (replace `<YOUR_MODEL_DIR>` with the target directory where you want the weights stored):
```bash
git lfs install
@@ -165,7 +165,7 @@ Benchmark logs are saved to `/workspace/trt_bench.log`.
---
## Multi-node
TRT-LLM supports multi-node inference. You can use mpirun or Slurm to launch multi-node jobs. We will use two nodes for this example.
TensorRT-LLM supports multi-node inference. You can use mpirun or Slurm to launch multi-node jobs. We will use two nodes for this example.
### mpirun
mpirun requires each node to have passwordless ssh access to the other node. We need to set up the environment inside the docker container: run the container with the host network and mount the current directory as well as the model directory into the container, as sketched below.
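A minimal sketch of such a container launch (the image tag `<YOUR_TRTLLM_IMAGE>` and the mount paths are placeholders, not values prescribed by this guide):

```bash
# Hypothetical example: run on every node, replacing <YOUR_TRTLLM_IMAGE> and the
# paths with your own. --network host lets mpirun/ssh reach the other node directly.
docker run --rm -it --gpus all \
    --network host \
    -v "$(pwd)":/workspace \
    -v <YOUR_MODEL_DIR>:/models/DeepSeek-V3 \
    -w /workspace \
    <YOUR_TRTLLM_IMAGE> bash
```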
@@ -276,6 +276,58 @@ trtllm-llmapi-launch trtllm-bench --model deepseek-ai/DeepSeek-V3 --model_path /
bash -c "trtllm-llmapi-launch trtllm-bench --model deepseek-ai/DeepSeek-V3 --model_path <YOUR_MODEL_DIR> throughput --backend pytorch --max_batch_size 161 --max_num_tokens 1160 --dataset /workspace/dataset.txt --tp 16 --ep 4 --kv_cache_free_gpu_mem_fraction 0.95 --extra_llm_api_options ./extra-llm-api-config.yml"
```
### Advanced Features
#### FlashMLA
TensorRT-LLM has already integrated FlashMLA in the PyTorch backend. It is enabled automatically when running DeepSeek-V3/R1.
#### DeepGEMM
TensorRT-LLM also supports DeepGEMM for DeepSeek-V3/R1. DeepGEMM provides a significant end-to-end performance boost and is enabled by the environment variable `TRTLLM_DG_ENABLED`:
```bash
# single-node
TRTLLM_DG_ENABLED=1 \
trtllm-bench \
--model deepseek-ai/DeepSeek-V3 \
--model_path /models/DeepSeek-V3 \
throughput \
--backend pytorch \
--max_batch_size ${MAX_BATCH_SIZE} \
--max_num_tokens ${MAX_NUM_TOKENS} \
--dataset dataset.txt \
--tp 8 \
--ep 8 \
--kv_cache_free_gpu_mem_fraction 0.9 \
--extra_llm_api_options /workspace/extra-llm-api-config.yml \
--concurrency ${CONCURRENCY} \
--num_requests ${NUM_REQUESTS} \
--streaming \
--report_json "${OUTPUT_FILENAME}.json"
# multi-node
mpirun -H <HOST1>:8,<HOST2>:8 \
-n 16 \
-x "TRTLLM_DG_ENABLED=1" \
-x "CUDA_HOME=/usr/local/cuda" \
trtllm-llmapi-launch trtllm-bench \
--model deepseek-ai/DeepSeek-V3 \
--model_path /models/DeepSeek-V3 \
throughput \
--backend pytorch \
--max_batch_size ${MAX_BATCH_SIZE} \
--max_num_tokens ${MAX_NUM_TOKENS} \
--dataset dataset.txt \
--tp 16 \
--ep 16 \
--kv_cache_free_gpu_mem_fraction 0.9 \
--extra_llm_api_options /workspace/extra-llm-api-config.yml \
--concurrency ${CONCURRENCY} \
--num_requests ${NUM_REQUESTS} \
--streaming \
--report_json "${OUTPUT_FILENAME}.json"
```
The CUDA kernels of DeepGEMM are JIT-compiled using NVCC. You need to install CUDA Toolkit 12.3 or above and point the environment variable `CUDA_HOME` at the CUDA Toolkit installation; we recommend using the latest version of the CUDA Toolkit. In case of compilation errors, you can set the environment variable `TRTLLM_DG_JIT_DEBUG` to 1 to print debug information from the JIT compilation, as sketched below.
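For example, a quick way to verify the toolkit and enable the debug output before launching the benchmark (the install path shown is a common default; adjust it to your system):

```bash
# Point CUDA_HOME at a CUDA Toolkit >= 12.3 (default location shown; adjust as needed).
export CUDA_HOME=/usr/local/cuda
"${CUDA_HOME}/bin/nvcc" --version    # check that the reported release is 12.3 or newer

# If DeepGEMM kernels fail to build, print JIT compilation debug information.
export TRTLLM_DG_JIT_DEBUG=1
```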
---
## Notes and Troubleshooting
@@ -287,4 +339,4 @@ trtllm-llmapi-launch trtllm-bench --model deepseek-ai/DeepSeek-V3 --model_path /
---
By following these steps, you should be able to successfully run the Deepseekv3 benchmark using TRTLLM with the PyTorch backend.
By following these steps, you should be able to successfully run the DeepSeekv3 benchmark using TensorRT-LLM with the PyTorch backend.