chore: bump version to 0.19.0 (#3598) (#3841)

test: add test cases for 0.19 release (#3608)

* fix test name



* add quickstart test for nemotron-ultra



* add rcca multi-node test case for deepseek-v3



* add rcca info



---------




squash (#3642)



fix: nvbugs/5187237: fix deterministic mode crash (#3448)

* nvbugs/5187237 nvbugs/5112075: fix deterministic mode error

* remove waive


* Revert "remove waive"

This reverts commit 0bf5486d19906d692bfb7a6262333c296b0087ac.



* revert ar fusion



---------



update fp8 doc (#3647)




tests: change qa perf test to trtllm-bench (#3619)




 fix: FP8 quantized lm_head (NvBug 5214229) (#3567)



infra: Add PR approval protection for the release branch (#3634)



fix: nvbugs/5231298: pytorch allreduce issue (#3673)



Fix: nvbugs/5222698 variable not defined (#3630)

* Fix: nvbugs/5222698 variable not defined



* Tidy code



---------



test:sync waives.txt from main branch by disabling test_perf/gpt_350m-cppmanager case (#3685)



test:restore fp8 kv cache testing for L0 (#3671)



doc: Update DeepSeek perf docs (#3693)

* Update DeepSeek perf docs



* update



* Apply suggestions from code review




---------




tests: waive test_llm_multi_node (#3664)



fix: update test_user_buffers_mm_add_prologue atol (#3711)



Fix: cherry-pick hmac encryption from main branch (#3635)

* security fix cherry-pick changes from main



* fix hmac in remote mpi session (#3649)



---------





Un-waive DS-V3-Lite tests. (#3621)



fix: FP8 kv accuracy (#3675)

* fix FP8 kv accuracy



* update doc



---------



Fix script options for engines. (#3622)



unwaive multi-node test (#3721)



chore : Split more tests out of gpt tests (#3524) (#3674)



doc:add torch examples link into torch backend documentation (#3749)




test: Get Eagle tests working (#3593) (#3722)




Waive L0 test (#3756)



waive failed case in perf test, change default max_batch_size to 512 and write config.json to output log (#3656)





Update ds v3 parameters in stress test. (#3676)

waive gemma on L20 (#3766)



https://nvbugs/5141291: Fix convert.py script for Qwen model. (#3758)

Include Qwen2VLDecoderLayer in the smooth_qwen2_model function.



fix: PP4 fixes and cleanup (#3688)




remove benchmark test list (#3643)



skip disagg deepseek test if sm!=90 (#3720)



test: skip failed cases on B200 (#3710)

* add skip condition to tests



* fix error



---------



test: [nvbug: 5234494] skip_pre_ada for fp8 cases (#3718)

* skip_pre_ada for fp8 cases



* update



* update after rebase



---------



add known issue to deepseek doc. (#3800)



Fix ModelOpt Mixtral AWQ OOM (#3714) (#3761)




Waive L0 tests (#3826)



fix: Reduce memory usage in fused moe op associated with AutoTuning and fix moe fallback issue. (#3793)

* Reduce memory usage in fused moe op associated with AutoTuning.
* Replace pre-defined bucket size strategy with a generating function based on the tune_max_num_tokens.
* Add free_memory logic of workspace in min_latency_mode fused moe path.



* Fix fused_moe fallback issue. (#3652)

min_latency_mode is only set to False during the warmup phase, so when it becomes True during inference, all tactics fall back to the default one, causing a perf regression.
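To illustrate the failure mode described above, here is a minimal, self-contained sketch (hypothetical names, not the actual TensorRT-LLM implementation). It models profiled tactics as a cache keyed by (token bucket, min_latency_mode), with buckets generated from `tune_max_num_tokens`; profiling only the False case during warmup leaves every True lookup empty at inference time, so the code falls back to the default tactic.

```python
# Hypothetical sketch of the tactic-cache fallback, not TensorRT-LLM code.
DEFAULT_TACTIC = -1


def make_buckets(tune_max_num_tokens: int):
    """Generate power-of-2 token buckets down from tune_max_num_tokens."""
    buckets, m = [], tune_max_num_tokens
    while m >= 1:
        buckets.append(m)
        m //= 2
    return tuple(buckets)


class TacticCache:
    def __init__(self):
        self.best = {}  # (bucket, min_latency_mode) -> tactic id

    def profile(self, buckets, min_latency_mode):
        for b in buckets:
            # Pretend tactic 0 was found best for every bucket.
            self.best[(b, min_latency_mode)] = 0

    def lookup(self, bucket, min_latency_mode):
        # A missing key falls back to the default tactic (the perf regression).
        return self.best.get((bucket, min_latency_mode), DEFAULT_TACTIC)


cache = TacticCache()
cache.profile(make_buckets(8192), min_latency_mode=False)  # warmup only profiles False
print(cache.lookup(128, False))  # 0: tuned tactic is found
print(cache.lookup(128, True))   # -1: falls back to the default tactic
```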



---------



[doc] Better document for Draft-Target-Model (DTM) speculative decoding (#3797)




Fix pre-commit



Fix again



Address some review comments for the MI

Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Co-authored-by: Zhanrui Sun <184402041+ZhanruiSunCh@users.noreply.github.com>
Dom Brown 2025-04-29 09:57:22 +01:00 committed by GitHub
parent 94e6167879
commit 8709fe8b53
45 changed files with 446 additions and 186 deletions

.github/CODEOWNERS (new file)
View File

@ -0,0 +1,5 @@
# This file defines code ownership rules for the repository.
# The rule below requires that any PR to release/**/* branches must be approved by at least one member
# of the NVIDIA/trt-llm-release-branch-approval team, regardless of who else approves the PR.
# Without approval from a member of this team, PRs cannot be merged to release branches.
* @NVIDIA/trt-llm-release-branch-approval

View File

@ -28,7 +28,8 @@ __global__ void lamport_initialize_kernel(float* ptr, int size)
void lamport_initialize(void* ptr, int bytes, cudaStream_t stream)
{
lamport_initialize_kernel<<<bytes / 128, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
int grid_size = (bytes + 127) / 128;
lamport_initialize_kernel<<<grid_size, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
}
Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,

View File

@ -1989,6 +1989,10 @@ void residualRmsNorm(
void lamportInitialize(void* buffer, size_t size, nvinfer1::DataType dataType, cudaStream_t stream)
{
sync_check_cuda_error(stream);
if (size == 0)
{
return;
}
switch (dataType)
{
case nvinfer1::DataType::kFLOAT:

View File

@ -162,18 +162,13 @@ public:
int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool min_latency_mode,
torch::optional<c10::ArrayRef<int64_t>> profile_ids)
{
// Free the profile workspace to save memory
if (mProfileWorkspace != nullptr)
{
auto const cu_free_status = cudaFree(mProfileWorkspace);
TORCH_CHECK(
cu_free_status == cudaSuccess, "Can't free profile workspace for MoE GEMM profile before runMoe.");
mProfileWorkspace = nullptr;
}
std::lock_guard<std::mutex> lock(mMutex);
// Free the profile workspace to save memory
freeProfileWorkspace();
TORCH_CHECK(cluster_size == 1 && cluster_rank == 0, "smart_router is supported in min_latency mode");
CHECK_INPUT(input, mActivationDtype)
CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
if (token_final_scales)
@ -251,6 +246,9 @@ public:
{
std::lock_guard<std::mutex> lock(mMutex);
// Free the profile workspace to save memory
freeProfileWorkspace();
CHECK_INPUT(input, mActivationDtype)
CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
if (token_final_scales)
@ -381,13 +379,7 @@ public:
hidden_size, inter_size, GROUP_SIZE, tensorrt_llm::ActivationType::Swiglu, USE_BIAS, USE_LORA,
min_latency_mode, parallelism_config);
if (mProfileWorkspace != nullptr)
{
auto const cu_free_status = cudaFree(mProfileWorkspace);
TORCH_CHECK(cu_free_status == cudaSuccess,
"Can't free profile workspace for MoE GEMM profile during memory reallocation.");
mProfileWorkspace = nullptr;
}
freeProfileWorkspace();
size_t profile_workspace_size = mProfiler->getWorkspaceSize(num_rows);
auto const cu_malloc_status = cudaMalloc(&mProfileWorkspace, profile_workspace_size);
TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile.");
@ -422,6 +414,17 @@ private:
using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
std::vector<Profile> mAllProfiles;
void freeProfileWorkspace()
{
if (mProfileWorkspace != nullptr)
{
auto const cu_free_status = cudaFree(mProfileWorkspace);
TORCH_CHECK(cu_free_status == cudaSuccess,
"Can't free profile workspace for MoE GEMM profile during memory reallocation.");
mProfileWorkspace = nullptr;
}
}
void setRunnerProfiles(torch::optional<c10::ArrayRef<int64_t>> profile_ids)
{
if (mUseFp8BlockScaling)

View File

@ -4,6 +4,33 @@ NVIDIA has announced world-record DeepSeek-R1 inference performance at NVIDIA GT
In this blog, we share the configurations and procedures for reproducing these numbers on both B200 and H200 with the PyTorch workflow.
## Table of Contents
- [How to get best performance on DeepSeek-R1 in TensorRT-LLM](#how-to-get-best-performance-on-deepseek-r1-in-tensorrt-llm)
- [Table of Contents](#table-of-contents)
- [Prerequisites: Install TensorRT-LLM and download models](#prerequisites-install-tensorrt-llm-and-download-models)
- [1. Download TensorRT-LLM](#1-download-tensorrt-llm)
- [2. Download the DeepSeek R1 models](#2-download-the-deepseek-r1-models)
- [3. Build and run TensorRT-LLM container](#3-build-and-run-tensorrt-llm-container)
- [4. Compile and Install TensorRT-LLM](#4-compile-and-install-tensorrt-llm)
- [5. Optional: Tune GPU clocks](#5-optional-tune-gpu-clocks)
- [6. Dataset preparation](#6-dataset-preparation)
- [Reproducing steps](#reproducing-steps)
- [B200 min-latency](#b200-min-latency)
- [Expected Results](#expected-results)
- [B200 max-throughput](#b200-max-throughput)
- [Benchmark](#benchmark)
- [Expected Result Format](#expected-result-format)
- [H200 min-latency](#h200-min-latency)
- [Expected Result Format](#expected-result-format-1)
- [H200 max-throughput](#h200-max-throughput)
- [Expected Result Format](#expected-result-format-2)
- [Exploring more ISL/OSL combinations](#exploring-more-islosl-combinations)
- [WIP: Enable more features by default](#wip-enable-more-features-by-default)
- [WIP: Chunked context support on DeepSeek models](#wip-chunked-context-support-on-deepseek-models)
- [Out of memory issues](#out-of-memory-issues)
## Prerequisites: Install TensorRT-LLM and download models
This section can be skipped if you already have TensorRT-LLM installed and have already downloaded the DeepSeek R1 model checkpoint.
@ -324,3 +351,25 @@ Total Token Throughput (tokens/sec): 15707.0888
Total Latency (ms): 993548.8470
Average request latency (ms): 197768.0434
```
## Exploring more ISL/OSL combinations
To benchmark TensorRT-LLM on DeepSeek models with more ISL/OSL combinations, you can use `prepare_dataset.py` to generate the dataset and run commands similar to those in the previous sections. The TensorRT-LLM team is working on enhancements to make the benchmarking process smoother.
### WIP: Enable more features by default
Currently, some features need to be enabled through a user-defined file `extra-llm-api-config.yml`, such as CUDA graphs, the overlap scheduler, and attention DP. We're working on enabling those features by default so that users get good out-of-the-box performance on DeepSeek models.
Note that `max_batch_size` and `max_num_tokens` can easily affect performance. Their default values are carefully chosen and should deliver good performance in most cases; however, you may still need to tune them for peak performance.
Generally, make sure that `max_batch_size` is not so low that it bottlenecks throughput, and that `max_num_tokens` is large enough to cover the maximum input sequence length of the samples in the dataset, as mentioned in the section "WIP: Chunked context support on DeepSeek models" below.
For more details on `max_batch_size` and `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md).
### WIP: Chunked context support on DeepSeek models
The TensorRT-LLM team is actively working on chunked context support for DeepSeek models. Until that feature is available, `max_num_tokens` must be at least as large as the maximum input sequence length of the samples in the dataset.
For more details on `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md).
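To make that constraint concrete, the sketch below checks a dataset against `max_num_tokens` before benchmarking. It assumes a JSONL dataset whose records store the prompt token IDs under an `input_ids` field; the field name and file layout are assumptions and may differ from what `prepare_dataset.py` actually emits.

```python
# Sketch: verify max_num_tokens covers the longest input sequence in the
# dataset. Assumes JSONL records with an "input_ids" list (an assumption).
import json


def max_input_len(dataset_path: str) -> int:
    longest = 0
    with open(dataset_path) as f:
        for line in f:
            record = json.loads(line)
            longest = max(longest, len(record["input_ids"]))
    return longest


max_num_tokens = 8192  # example value from your benchmark config
longest = max_input_len("synthetic_dataset.jsonl")
if max_num_tokens < longest:
    raise ValueError(
        f"max_num_tokens={max_num_tokens} is smaller than the longest "
        f"input ({longest} tokens); raise it until chunked context lands.")
```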
### Out of memory issues
Out-of-memory (OOM) issues may occur in some cases. Consider reducing `kv_cache_free_gpu_mem_fraction` to a smaller value as a workaround. We're investigating and working to address the problem.
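For context on what the workaround trades off, the sketch below estimates how much memory a given fraction would hand to the KV cache versus leave as headroom for activations and workspaces. It assumes the documented semantics that the fraction applies to the GPU memory still free after the model weights are loaded; the numbers are illustrative only.

```python
# Sketch: rough budget reasoning for picking a lower
# kv_cache_free_gpu_mem_fraction. Illustrative only.
import torch


def kv_cache_budget(fraction):
    free_bytes, _total = torch.cuda.mem_get_info()
    kv_bytes = fraction * free_bytes
    headroom = free_bytes - kv_bytes  # left for activations, workspaces, etc.
    return kv_bytes / 2**30, headroom / 2**30


for fraction in (0.9, 0.8, 0.7):
    kv_gib, headroom_gib = kv_cache_budget(fraction)
    print(f"fraction={fraction}: ~{kv_gib:.1f} GiB KV cache, "
          f"~{headroom_gib:.1f} GiB headroom")
```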

View File

@ -41,6 +41,7 @@ scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --ex
- [Architecture Overview](./torch/arch_overview.md)
- [Adding a New Model](./torch/adding_new_model.md)
- [Examples](../../examples/pytorch/README.md)
## Key Components
@ -50,4 +51,4 @@ scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --ex
## Known Issues
- The PyTorch workflow on SBSA is incompatible with bare metal environments like Ubuntu 24.04. Please use the [PyTorch NGC Container (https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) for optimal support on SBSA platforms.
- The PyTorch workflow on SBSA is incompatible with bare metal environments like Ubuntu 24.04. Please use the [PyTorch NGC Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) for optimal support on SBSA platforms.

View File

@ -19,10 +19,8 @@ We provide two styles of running DTM now: using TensorRT-LLM-BLS in Triton Infer
+ We use the open-source `llama-7B/13B` as the draft and target models in this example, assuming the paths to the models' repositories are `DRAFT_MODEL_PATH` and `TARGET_MODEL_PATH`.
+ `--use_paged_context_fmha=enable` must be specified since we need KV-Cache reuse in this approach.
+ `--speculative_decoding_mode=draft_tokens_external` and `--max_draft_len` must be specified for target model.
+ `--use_paged_context_fmha=enable` is optional, but recommended for performance.
+ `--gather_generation_logits` is necessary if using generation logits for selecting tokens in the target model.
+ `--tp_size` can be set if using TP mode for the draft / target model.
+ A `--max_batch_size` greater than 1 is acceptable in general usage, but we use 1 in this example.
```bash
cd examples/models/core/llama
@ -97,9 +95,9 @@ python3 examples/run.py \
### Triton Inference Server workflow
+ This example is based on TensorRT-LLM-0.18.0 and TRTLLM-backend-0.18.0 with docker image `nvcr.io/nvidia/tritonserver:25.03-trtllm-python-py3`.
+ Draft model approach is supported since TensorRT-LLM-0.7.0 (using two separate Tritonserver to maintain draft and target model respectively), but has significant optimization in TensorRT-LLM-0.10.0 (using one Tritonserver with [Business Logic Scripting](https://github.com/triton-inference-server/python_backend?tab=readme-ov-file#business-logic-scripting), BLS).
+ The DTM approach has been supported since TensorRT-LLM-0.7.0 (using two separate Triton servers to maintain the draft and target models respectively), but was significantly optimized in TensorRT-LLM-0.10.0 (using one Triton server with [Business Logic Scripting](https://github.com/triton-inference-server/python_backend?tab=readme-ov-file#business-logic-scripting), BLS).
1. Get related repository inside the container
#### Get related repository inside the container
```bash
git clone https://github.com/triton-inference-server/tensorrtllm_backend.git
@ -109,50 +107,42 @@ git lfs pull
git submodule update --init --recursive
pip install -r requirements.txt
pip install SentencePiece tritonclient
```
2. Set necessary variables
+ If the draft and target models fit on one GPU (llama-7B-FP8 + llama-30B-FP8, 40GiB in total on one H100-80GiB GPU, for example), `DRAFT_GPU_DEVICE_IDS` and `TARGET_GPU_DEVICE_IDS` can be the same (`0`, for example). This appears to give better performance than placing them on two separate GPUs.
+ Otherwise, the draft and target models can be placed on different GPUs, e.g., `DRAFT_GPU_DEVICE_IDS="0"` and `TARGET_GPU_DEVICE_IDS="1"`.
+ Furthermore, if TP mode is used, the device IDs can be a list, e.g., `DRAFT_GPU_DEVICE_IDS="0"` and `TARGET_GPU_DEVICE_IDS="1,2,3,4"`.
```bash
export DRAFT_MODEL_NAME="tensorrt_llm_draft"
export TARGET_MODEL_NAME="tensorrt_llm"
export DRAFT_DEVICE_IDS="0"
export TARGET_DEVICE_IDS="1"
export TRITON_MODEL_REPO=llama_dtm
```
3. Edit model configuration
#### Simple deploy
+ Edit model configuration.
```bash
export DRAFT_DEVICE_IDS="0"
export TARGET_DEVICE_IDS="1"
rm -rf ${TRITON_MODEL_REPO}
cp -r all_models/inflight_batcher_llm/ ${TRITON_MODEL_REPO}
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/ensemble/config.pbtxt triton_max_batch_size:${MAX_BATCH_SIZE},logits_datatype:TYPE_FP32
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/preprocessing/config.pbtxt triton_max_batch_size:${MAX_BATCH_SIZE},tokenizer_dir:${TARGET_MODEL_PATH},preprocessing_instance_count:1
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/postprocessing/config.pbtxt triton_max_batch_size:${MAX_BATCH_SIZE},tokenizer_dir:${TARGET_MODEL_PATH},postprocessing_instance_count:1
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:${MAX_BATCH_SIZE},decoupled_mode:False,bls_instance_count:1,accumulate_tokens:False,tensorrt_llm_model_name:${TARGET_MODEL_NAME},tensorrt_llm_draft_model_name:${DRAFT_MODEL_NAME},logits_datatype:TYPE_FP32
cp -r ${TRITON_MODEL_REPO}/tensorrt_llm ${TRITON_MODEL_REPO}/tensorrt_llm_draft
sed -i 's/name: "tensorrt_llm"/name: "tensorrt_llm_draft"/g' ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${MAX_BATCH_SIZE},decoupled_mode:False,max_beam_width:1,engine_dir:${TARGET_ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:True,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,gpu_device_ids:${TARGET_DEVICE_IDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${MAX_BATCH_SIZE},decoupled_mode:False,max_beam_width:1,engine_dir:${DRAFT_ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:True,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,gpu_device_ids:${DRAFT_DEVICE_IDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/ensemble/config.pbtxt triton_max_batch_size:4,logits_datatype:TYPE_FP32
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/preprocessing/config.pbtxt triton_max_batch_size:4,tokenizer_dir:${HF_MODEL},preprocessing_instance_count:1
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/postprocessing/config.pbtxt triton_max_batch_size:4,tokenizer_dir:${HF_MODEL},postprocessing_instance_count:1
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:4,decoupled_mode:False,logits_datatype:TYPE_FP32,bls_instance_count:1,accumulate_tokens:False,tensorrt_llm_model_name:${TARGET_MODEL_NAME},tensorrt_llm_draft_model_name:${DRAFT_MODEL_NAME}
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm/config.pbtxt triton_max_batch_size:4,decoupled_mode:False,logits_datatype:TYPE_FP32,triton_backend:tensorrtllm,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:True,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,encoder_input_features_data_type:TYPE_FP16,engine_dir:${TARGET_ENGINE_PATH},gpu_device_ids:${TARGET_DEVICE_IDS}
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt triton_max_batch_size:4,decoupled_mode:False,logits_datatype:TYPE_FP32,triton_backend:tensorrtllm,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:True,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,encoder_input_features_data_type:TYPE_FP16,engine_dir:${DRAFT_ENGINE_PATH},gpu_device_ids:${DRAFT_DEVICE_IDS}
```
4. Start the triton inference server
+ `--multi-model` is necessary if TP mode is used.
+ Verbose log will be written in to file `triton_log.txt`.
+ Start the Triton inference server.
+ A verbose log will be written to the file `triton_log.txt` if `--log` is specified.
```bash
python3 scripts/launch_triton_server.py \
--model_repo=${TRITON_MODEL_REPO} \
--multi-model \
--world_size=1 \
--log &
--log
```
+ You should see the output below in the log file if the Triton server launches successfully:
@ -163,7 +153,7 @@ Started GRPCInferenceService at 0.0.0.0:8001
Started Metrics Service at 0.0.0.0:8002
```
5. Send a request for inference
+ Send a request for inference.
```bash
python3 inflight_batcher_llm/client/e2e_grpc_speculative_decoding_client.py \
@ -186,8 +176,8 @@ Ubuntu is a free and open source operating system that runs from the desktop, to
Ubuntu is a community developed operating system that is perfect for laptops, desktops, servers, and cloud. It is used by millions of people around the world who want to explore new ideas and discover new opportunities.
```
6. Test DTM with a script
+ Prepare a JSON file `input_data.json` containing input data as below (more requests are acceptable).
+ Test DTM with a script.
+ Prepare a JSON file `input_data.json` containing input data as below (more requests are acceptable).
```json
[
@ -199,9 +189,10 @@ Ubuntu is a community developed operating system that is perfect for laptops, de
]
```
+ Use command below to launch requests for inference.
+ Use the command below to launch the test.
```bash
### Use BLS speculative decoding
python3 tools/inflight_batcher_llm/speculative_decoding_test.py \
--max-input-len 2500 \
--dataset input_data.json \
@ -214,8 +205,21 @@ python3 tools/inflight_batcher_llm/speculative_decoding_test.py \
--execute-bls-speculative-decoding \
--disable-output-comparison \
--num-draft-tokens=4 \
--use-draft-logits \
--verbose
--use-draft-logits
### Use client-side speculative decoding
python3 tools/inflight_batcher_llm/speculative_decoding_test.py \
--max-input-len 2500 \
--dataset input_data.json \
--url-control=localhost:8001 \
--url-target=localhost:8001 \
--url-draft=localhost:8001 \
--draft-tensorrt-llm-model-name="${DRAFT_MODEL_NAME}" \
--target-tensorrt-llm-model-name="${TARGET_MODEL_NAME}" \
--bls-speculative-tensorrt-llm-model-name="tensorrt_llm_bls" \
--disable-output-comparison \
--num-draft-tokens=4 \
--use-draft-logits
```
+ You should receive the following results if everything goes smoothly.
@ -224,84 +228,105 @@ python3 tools/inflight_batcher_llm/speculative_decoding_test.py \
Ubuntu is a free and open source operating system. It is a Linux based operating system. ...
```
7. Stop triton inference server after all work is done
+ Stop the Triton inference server after all work is done.
```bash
pkill tritonserver
```
8. Advanced usage: Fast logits D2D transfer.
+ In addition, better performance appears achievable with both the draft and target engines deployed on a single GPU (llama-7B-FP8 + llama-30B-FP8, for a total of 40GiB on one H100-80GiB GPU, for example).
+ Fast logits, supported since TensorRT-LLM-0.15.0, boosts performance (TPS) by hiding the latency of the logits transfer from the draft engine to the target engine.
#### Usage of Tensor Parallelism (TP) mode
+ Modify the `participant_ids` entry in `tensorrt_llm/config.pbtxt` and `tensorrt_llm_draft/config.pbtxt` to suitable MPI ranks. Usually in this setting, rank 0 is reserved for the orchestrator; rank 1 is for the draft engine; the remaining ranks are for the target engine. In this example, `participant_ids` can be set as in the snippet below. The same logic also applies to a TP>1 target engine.
+ In this example, we use a draft engine with TP=1 and a target engine with TP=2 (both symmetric and asymmetric TP sizes are acceptable), and we want to place the draft engine on GPU0 and the target engine on GPU1 and GPU2.
+ Edit model configuration.
```txt
### In tensorrt_llm_draft/config.pbtxt
parameters: {
key: "gpu_device_ids"
value: {
string_value: "0"
}
}
parameters: {
key: "participant_ids"
value: {
string_value: "1"
}
}
### In tensorrt_llm/config.pbtxt
parameters: {
key: "gpu_device_ids"
value: {
string_value: "1"
}
}
parameters: {
key: "participant_ids"
value: {
string_value: "2"
}
}
```bash
export DRAFT_DEVICE_IDS="0"
export TARGET_DEVICE_IDS="1,2"
rm -rf ${TRITON_MODEL_REPO}
cp -r all_models/inflight_batcher_llm/ ${TRITON_MODEL_REPO}
cp -r ${TRITON_MODEL_REPO}/tensorrt_llm ${TRITON_MODEL_REPO}/tensorrt_llm_draft
sed -i 's/name: "tensorrt_llm"/name: "tensorrt_llm_draft"/g' ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/ensemble/config.pbtxt triton_max_batch_size:4,logits_datatype:TYPE_FP32
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/preprocessing/config.pbtxt triton_max_batch_size:4,tokenizer_dir:${HF_MODEL},preprocessing_instance_count:1
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/postprocessing/config.pbtxt triton_max_batch_size:4,tokenizer_dir:${HF_MODEL},postprocessing_instance_count:1
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:4,decoupled_mode:False,logits_datatype:TYPE_FP32,bls_instance_count:1,accumulate_tokens:False,tensorrt_llm_model_name:${TARGET_MODEL_NAME},tensorrt_llm_draft_model_name:${DRAFT_MODEL_NAME}
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm/config.pbtxt triton_max_batch_size:4,decoupled_mode:False,logits_datatype:TYPE_FP32,triton_backend:tensorrtllm,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:True,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,encoder_input_features_data_type:TYPE_FP16,engine_dir:${TARGET_ENGINE_PATH}
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt triton_max_batch_size:4,decoupled_mode:False,logits_datatype:TYPE_FP32,triton_backend:tensorrtllm,max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:True,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,encoder_input_features_data_type:TYPE_FP16,engine_dir:${DRAFT_ENGINE_PATH}
sed -i 's/\${gpu_device_ids}/'"${DRAFT_DEVICE_IDS}"'/g' ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt
sed -i 's/\${gpu_device_ids}/'"${TARGET_DEVICE_IDS}"'/g' ${TRITON_MODEL_REPO}/tensorrt_llm/config.pbtxt
```
+ Enable `speculative_decoding_fast_logits` in both `tensorrt_llm/config.pbtxt` and `tensorrt_llm_draft/config.pbtxt`.
+ As you can see, the only difference is `gpu_device_ids`, which must be fixed manually with `sed` since commas are not supported by the `tools/fill_template.py` script.
```txt
parameters: {
key: "speculative_decoding_fast_logits"
value: {
string_value: "1"
}
}
```
+ Launch the Triton Server
+ Use orchestrator mode with `--disable-spawn-process`. See [model config](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/model_config.md) for more information.
+ `--world_size` has to be set as 1 (orchestrator rank 0) + 1 (draft engine ranks) + 1 (target engine ranks).
+ Start the Triton inference server.
+ Use `--multi-model` to enable orchestrator mode in the TP>1 scenario. See [model config](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/model_config.md) for more information.
```bash
python3 scripts/launch_triton_server.py \
--model_repo=${TRITON_MODEL_REPO} \
--multi-model \
--disable-spawn-processes \
--world_size=3 \
--log &
--tensorrt_llm_model_name "tensorrt_llm,tensorrt_llm_draft" \
--multi-model
```
+ Send request with `use_draft_logits` to tritonserver BLS API:
+ All other operations are the same as in the `Simple deploy` section.
#### Usage of Fast logits D2D transfer
+ Fast logits, supported since TensorRT-LLM-0.15.0, boosts performance (TPS) by hiding the latency of the logits transfer from the draft engine to the target engine.
+ In this example, we use a draft engine with TP=1 and a target engine with TP=2 (both symmetric and asymmetric TP sizes are acceptable), and we want to place the draft engine on GPU0 and the target engine on GPU1 and GPU2.
+ For `participant_ids`, rank 0 is reserved for the orchestrator; ranks `1` ~ `tp_size_draft` are for the draft engine; ranks `tp_size_draft+1` ~ `tp_size_draft+tp_size_target` are for the target engine (a small sketch deriving these values appears at the end of this section).
+ Edit model configuration.
```bash
export DRAFT_DEVICE_IDS="0"
export TARGET_DEVICE_IDS="1,2"
export DRAFT_PARTICIPANT_IDS="1"
export TARGET_PARTICIPANT_IDS="2,3"
cd /work/tekit-backend
rm -rf ${TRITON_MODEL_REPO}
cp -r all_models/inflight_batcher_llm/ ${TRITON_MODEL_REPO}
cp -r ${TRITON_MODEL_REPO}/tensorrt_llm ${TRITON_MODEL_REPO}/tensorrt_llm_draft
sed -i 's/name: "tensorrt_llm"/name: "tensorrt_llm_draft"/g' ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/ensemble/config.pbtxt triton_max_batch_size:4,logits_datatype:TYPE_FP32
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/preprocessing/config.pbtxt triton_max_batch_size:4,tokenizer_dir:${HF_MODEL},preprocessing_instance_count:1
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/postprocessing/config.pbtxt triton_max_batch_size:4,tokenizer_dir:${HF_MODEL},postprocessing_instance_count:1
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm_bls/config.pbtxt triton_max_batch_size:4,decoupled_mode:False,bls_instance_count:1,accumulate_tokens:False,tensorrt_llm_model_name:${TARGET_MODEL_NAME},logits_datatype:TYPE_FP32,tensorrt_llm_draft_model_name:${DRAFT_MODEL_NAME}
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm/config.pbtxt triton_max_batch_size:4,triton_backend:tensorrtllm,decoupled_mode:False,max_beam_width:1,engine_dir:${TARGET_ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:True,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32,gpu_device_ids:${TARGET_DEVICE_IDS},participant_ids:2,3,speculative_decoding_fast_logits:1
python3 tools/fill_template.py -i ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt triton_max_batch_size:4,triton_backend:tensorrtllm,decoupled_mode:False,max_beam_width:1,engine_dir:${DRAFT_ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:True,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:0,encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32,gpu_device_ids:${DRAFT_DEVICE_IDS},participant_ids:1,speculative_decoding_fast_logits:1
sed -i 's/\${gpu_device_ids}/'"${DRAFT_DEVICE_IDS}"'/g' ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt
sed -i 's/\${participant_ids}/'"${DRAFT_PARTICIPANT_IDS}"'/g' ${TRITON_MODEL_REPO}/tensorrt_llm_draft/config.pbtxt
sed -i 's/\${gpu_device_ids}/'"${TARGET_DEVICE_IDS}"'/g' ${TRITON_MODEL_REPO}/tensorrt_llm/config.pbtxt
sed -i 's/\${participant_ids}/'"${TARGET_PARTICIPANT_IDS}"'/g' ${TRITON_MODEL_REPO}/tensorrt_llm/config.pbtxt
```
curl -X POST "http://localhost:8000/v2/models/tensorrt_llm_bls/generate" \
-H "Content-Type: application/json" \
-d '{
"text_input": "Continue writing the following story: James Best, best known for his",
"max_tokens": 128,
"num_draft_tokens": 10,
"use_draft_logits": true,
"stream": false
}'
+ As you can see, the differences are `participant_ids` and `speculative_decoding_fast_logits`.
+ Start the Triton inference server.
+ Use `--disable-spawn-process` to enable the pre-spawn variant of orchestrator mode.
+ `--world_size` must be equal to `1 + tp_size_draft + tp_size_target`, which is 4 in this example.
```bash
python3 scripts/launch_triton_server.py \
--model_repo ${TRITON_MODEL_REPO} \
--tensorrt_llm_model_name tensorrt_llm,tensorrt_llm_draft \
--multi-model \
--world_size 4 \
--disable-spawn-processes
```
+ All other operations are the same as in the `Simple deploy` section.
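The `gpu_device_ids`, `participant_ids`, and `--world_size` values used above follow mechanically from the draft and target TP sizes. The sketch below (a hypothetical helper, not part of the repo) spells out that bookkeeping for the rank convention described earlier: rank 0 is the orchestrator, the next `tp_size_draft` ranks belong to the draft engine, and the remaining ranks belong to the target engine.

```python
# Hypothetical helper: derive device IDs, participant IDs, and world size
# from the draft/target TP sizes used in this example.
def orchestrator_layout(tp_draft: int, tp_target: int, first_gpu: int = 0):
    draft_gpus = list(range(first_gpu, first_gpu + tp_draft))
    target_gpus = list(range(first_gpu + tp_draft,
                             first_gpu + tp_draft + tp_target))
    draft_ranks = list(range(1, 1 + tp_draft))
    target_ranks = list(range(1 + tp_draft, 1 + tp_draft + tp_target))
    return {
        "DRAFT_DEVICE_IDS": ",".join(map(str, draft_gpus)),
        "TARGET_DEVICE_IDS": ",".join(map(str, target_gpus)),
        "DRAFT_PARTICIPANT_IDS": ",".join(map(str, draft_ranks)),
        "TARGET_PARTICIPANT_IDS": ",".join(map(str, target_ranks)),
        "world_size": 1 + tp_draft + tp_target,  # orchestrator + all engine ranks
    }


# TP=1 draft on GPU0, TP=2 target on GPU1/GPU2, as in the example above.
print(orchestrator_layout(tp_draft=1, tp_target=2))
# {'DRAFT_DEVICE_IDS': '0', 'TARGET_DEVICE_IDS': '1,2',
#  'DRAFT_PARTICIPANT_IDS': '1', 'TARGET_PARTICIPANT_IDS': '2,3',
#  'world_size': 4}
```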
### Additional information
+ With fast logits enabled and by following the optimization tips in [model configuration](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/model_config.md#some-tips-for-model-configuration), speculative decoding with draft logits achieves 2.x throughput at BS1 and 1.x throughput at BS16 compared to auto-regressive decoding, using a Llama 3.2 1B draft and a Llama 3.1 70B target.
+ Streaming mode and batched-request mode are not supported in DTM yet.

View File

@ -543,3 +543,7 @@ pytorch_backend_config:
- **GPU Memory:** Adjust `--max_batch_size` and `--max_num_tokens` if you encounter out-of-memory errors.
- **Logs:** Check `/workspace/trt_bench.log` for detailed performance information and troubleshooting messages.
- **Configuration Files:** Verify that the configuration files are correctly formatted to avoid runtime issues.
## Known Issues
- MTP + attention DP + CUDA graph + overlap scheduler might have accuracy issues. We'll fix it later.

View File

@ -128,8 +128,8 @@ def fused_moe(
# TODO: only profile for min_latency_mode = False due to the error in the moe_kernels
tuning_config = TuningConfig(dynamic_tensors=(
# input, dim 0, all valid buckets, map a seq_len to power of 2 bucket index
(0, 0, ((16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4,
2, 1), next_positive_power_of_2)),
(0, 0, ((8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
next_positive_power_of_2)),
# min_latency_tensor, dim 0, (0 for False, 1 for True), map to it self
(2, 0, ((0, ), lambda x: x)),
))

View File

@ -8,7 +8,7 @@ import torch
import torch.distributed as dist
from tensorrt_llm._utils import (mpi_allgather, mpi_barrier, mpi_broadcast,
mpi_comm, mpi_isend, mpi_recv,
mpi_comm, mpi_isend, mpi_recv, mpi_send,
torch_dtype_to_np)
from tensorrt_llm.mapping import Mapping
@ -114,6 +114,10 @@ class MPIDist(Distributed):
# non-blocking send numpy buffer
return mpi_isend(buf, dest, tag)
def send(self, buf: np.ndarray, dest, tag=0):
# blocking send numpy buffer
mpi_send(buf, dest, tag)
def recv(self, buf: np.ndarray, src, tag=0):
# in-place recv numpy buffer
return mpi_recv(buf, src, tag)
@ -238,6 +242,7 @@ class PPComm:
# PP communication using torch.distributed with nccl backend
def __init__(self, global_mapping: Mapping):
self.mapping = global_mapping
self.send_event = torch.cuda.Event()
if not dist.is_initialized():
master_ip = os.getenv("MASTER_ADDR", "localhost")
master_port = os.getenv("MASTER_PORT", "6000")

View File

@ -19,7 +19,8 @@ def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor:
if mapping not in allreduce_workspaces:
ipc_buffers, workspace = CustomAllReduceHelper.allocate_allreduce_fusion_workspace(
mapping,
CustomAllReduceHelper.max_workspace_size_auto(mapping.tp_size),
CustomAllReduceHelper.max_workspace_size_auto(
mapping.tp_size, support_deterministic=False),
)
allreduce_workspaces[mapping] = (ipc_buffers, workspace)
return allreduce_workspaces[mapping][1]

View File

@ -679,8 +679,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
0) > 128
min_latency_mode = self._enable_latency_mode(hidden_states.size(0))
min_latency_mode = self._enable_latency_mode(
hidden_states.size(0)) and not using_prev_fusion
hidden_states_fp4 = None
if self.fusion_config.PRE_MOE_FUSION:
# Custom AR Fusion for DeepseekV3
if using_prev_fusion:

View File

@ -295,7 +295,6 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
layer for layer in self.layers[:config.num_hidden_layers]
if not layer.is_missing()
]
print(f"{self._local_layers=}, {self.pp_layer_list=}")
# add create_pipeline_interface method
pp_interface_keys = ["hidden_states", "residual"]

View File

@ -71,7 +71,11 @@ class PipelineInterface:
def send(self):
"""Send tensors to next rank."""
# pp_comm.send returns after nccl send kernel is enqueued. Event sync waits till prev kernel
# finishes and avoids earlier PP rank executing multiple microbatches ahead of later rank.
self._pp_comm.send_event.synchronize()
if self.hidden_states is not None:
self._pp_comm.send(self.hidden_states, tag=self.tag)
if self.residual is not None:
self._pp_comm.send(self.residual, tag=self.tag)
self._pp_comm.send_event.record()

View File

@ -7,8 +7,7 @@ import torch
import tensorrt_llm
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._utils import (mpi_broadcast, str_dtype_to_binding,
torch_dtype_to_str)
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig, load_torch_hf_lora
@ -205,11 +204,13 @@ def estimate_max_kv_cache_tokens(py_executor: PyExecutor,
py_executor.decoder) == TRTLLMDecoder else origin_seq_len
req = create_dummy_context_requests(max_num_tokens, seq_len, vocab_size)
req_ids = py_executor.enqueue_requests(req)
req_ids = mpi_broadcast(req_ids, root=0)
req_ids = py_executor.dist.broadcast(req_ids, root=0)
py_executor.is_warmup = True
py_executor.start_worker()
py_executor.await_responses(req_ids)
# TODO check why call mpi_barrier() here will hang-on, but call mpi_allgather(0) is fine.
# sync all ranks after processing dummy requests. mpi barrier causes hang, so allgather is used.
py_executor.dist.allgather(0)
torch_peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
@ -255,6 +256,8 @@ def estimate_max_kv_cache_tokens(py_executor: PyExecutor,
if py_executor.dist.mapping.rank == 0:
py_executor.shutdown()
py_executor.dist.allgather(0)
return kv_cache_max_tokens

View File

@ -755,7 +755,7 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata = None
pipeline_interface = None
if self.mapping.pp_rank > 0:
if not self.mapping.is_first_pp_rank():
pipeline_interface = self.model.create_pipeline_interface(
batch_size)
self._cuda_graphs[batch_size] = DecodingCUDAGraphRunner(
@ -1229,10 +1229,9 @@ class PyTorchModelEngine(ModelEngine):
if self.mapping.has_pp():
pipeline_interface = None
if self.mapping.pp_rank > 0:
if not self.mapping.is_first_pp_rank():
pipeline_interface = self.model.create_pipeline_interface(
inputs['input_ids'].shape[0])
pipeline_interface.recv()
inputs['pipeline_interface'] = pipeline_interface
num_generation_tokens = len(generation_requests) + len(
@ -1369,7 +1368,6 @@ class PyTorchModelEngine(ModelEngine):
if self.mapping.pp_rank > 0:
pipeline_interface = self.model.create_pipeline_interface(
inputs['input_ids'].shape[0])
pipeline_interface.recv()
inputs['pipeline_interface'] = pipeline_interface
return inputs, None
@ -1770,7 +1768,9 @@ class PyTorchModelEngine(ModelEngine):
inputs.update(extra_model_inputs)
self.last_spec_metadata = spec_metadata
if self.mapping.has_pp() and not self.mapping.is_last_pp_rank():
if not self.mapping.is_first_pp_rank():
inputs['pipeline_interface'].recv()
if not self.mapping.is_last_pp_rank():
pp_interface = self._forward_step_intermediate(inputs)
pp_interface.send()
return self._post_forward_intermediate(inputs, pp_interface,
@ -1803,7 +1803,9 @@ class PyTorchModelEngine(ModelEngine):
self.iter_counter += 1
if maybe_graph is None:
if self.mapping.has_pp() and not self.mapping.is_last_pp_rank():
if not self.mapping.is_first_pp_rank():
inputs['pipeline_interface'].recv()
if not self.mapping.is_last_pp_rank():
pp_interface = self._forward_step_intermediate(inputs)
pp_interface.send()
outputs = self._post_forward_intermediate(
@ -1824,6 +1826,8 @@ class PyTorchModelEngine(ModelEngine):
extra_model_inputs)
self._cuda_graph_mem_pool = pool
if not self.mapping.is_first_pp_rank():
inputs['pipeline_interface'].recv()
outputs = maybe_graph.run(inputs, extra_model_inputs)
if not self.mapping.is_last_pp_rank():
pp_interface = PipelineInterface(*outputs)

View File

@ -233,8 +233,6 @@ class PyExecutor:
self.micro_batches: List[BatchStatePP
| None] = [None] * self.num_micro_batches
self.send_handles = [None] * self.num_micro_batches
# one handle each for metadata and serialized new_reqs buffer
self.send_new_reqs_handle = [None] * 2
self.inflight_req_ids = ReqIdsSet()
self.canceled_req_ids = ReqIdsSet()
@ -1192,10 +1190,7 @@ class PyExecutor:
self.dist.recv(metadata_arr, self.dist.prev_pp_rank, tag)
if not self.dist.is_last_pp_rank:
if self.send_new_reqs_handle[0] is not None:
self.send_new_reqs_handle[0].Wait()
self.send_new_reqs_handle[0] = self.dist.isend(
metadata_arr, self.dist.next_pp_rank, tag)
self.dist.send(metadata_arr, self.dist.next_pp_rank, tag)
# 2. send serialized buffer when new requests is not empty
num_new_requests = metadata_arr[0]
@ -1206,10 +1201,7 @@ class PyExecutor:
self.dist.recv(buf, self.dist.prev_pp_rank, tag)
if not self.dist.is_last_pp_rank:
if self.send_new_reqs_handle[1] is not None:
self.send_new_reqs_handle[1].Wait()
self.send_new_reqs_handle[1] = self.dist.isend(
buf, self.dist.next_pp_rank, tag)
self.dist.send(buf, self.dist.next_pp_rank, tag)
if not self.dist.is_first_pp_rank:
new_requests = dill.loads(buf.tobytes()) # nosec B301

View File

@ -197,7 +197,7 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
num_token_buckets.append(m)
m //= 2
return num_token_buckets
return tuple(num_token_buckets)
def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:

View File

@ -531,6 +531,14 @@ def mpi_isend(buf, dest, tag=0):
return None
def mpi_send(buf, dest, tag=0):
# send in buf-like objects (e.g. numpy array)
# return request handle if ENABLE_MULTI_DEVICE
if ENABLE_MULTI_DEVICE:
mpi_comm().Send(buf, dest, tag=tag)
return None
def mpi_recv(buf, source, tag):
# recv in buf-like object (e.g. numpy array)
if ENABLE_MULTI_DEVICE:

View File

@ -527,6 +527,10 @@ class LLM:
if self.args.kv_cache_config is not None:
executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
self.args.kv_cache_config)
if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
# Disable KV cache reuse for deterministic mode
executor_config.kv_cache_config.enable_block_reuse = False
executor_config.kv_cache_config.enable_partial_reuse = False
if self.args.peft_cache_config is not None:
executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
self.args.peft_cache_config)

View File

@ -27,6 +27,8 @@ import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer
from transformers.pytorch_utils import Conv1D
from ..._utils import pad_vocab_size, str_dtype_to_torch
@ -101,9 +103,9 @@ def smooth_qwen_model(model, scales, alpha, qwen_qkv_para, qwen_smoother):
@torch.no_grad()
def smooth_qwen2_model(model, scales, alpha, qwen_qkv_para, qwen_smoother):
# Smooth the activation and weights with smoother = $\diag{s}$
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
for name, module in model.named_modules():
if not isinstance(module, Qwen2DecoderLayer):
if not isinstance(module, Qwen2DecoderLayer) and not isinstance(
module, Qwen2VLDecoderLayer):
continue
# qkv_proj
layer_name_q = name + ".self_attn.q_proj"

View File

@ -704,8 +704,9 @@ class CustomAllReduceHelper:
)
@staticmethod
def max_workspace_size_auto(tp_size: int) -> int:
if force_all_reduce_deterministic():
def max_workspace_size_auto(tp_size: int,
support_deterministic=True) -> int:
if force_all_reduce_deterministic() and support_deterministic:
workspace_size = os.getenv("FORCE_ALLREDUCE_KERNEL_WORKSPACE_SIZE",
"1000000000")
return int(workspace_size)
@ -746,7 +747,7 @@ class CustomAllReduceHelper:
lamport_buffers_0.local_ptr,
lamport_buffers_1.local_ptr,
lamport_buffers_2.local_ptr,
size * mapping.tp_size,
lamport_buffers_size,
)
buffers = [
ipc_buffers_ping, ipc_buffers_pong, ipc_barriers_in,

View File

@ -28,14 +28,16 @@ def quantize_layers(
quant_map,
preprocess_init_params=None,
):
exclude_modules = quant_config.exclude_modules or [
'*lm_head',
'*router',
'*vocab_embedding',
'*position_embedding',
'*block_embedding',
'*shared_expert_gate',
]
exclude_modules = quant_config.exclude_modules
if exclude_modules is None:
exclude_modules = [
'*lm_head',
'*router',
'*vocab_embedding',
'*position_embedding',
'*block_embedding',
'*shared_expert_gate',
]
for name, module, parent in model.named_modules_with_parent():
module_name = name.rsplit('.', 1)[-1]
@ -244,9 +246,12 @@ def fp8_rowwise_quantize(model, quant_config: QuantConfig):
Attention: Fp8RowwiseAttention,
}
exclude_modules = quant_config.exclude_modules
if exclude_modules is None:
exclude_modules = []
# Always exclude these modules for FP8 rowwise
exclude_modules = list(
set((quant_config.exclude_modules or []) +
['*ln_f', '*ln_embed', '*lm_head']))
set(exclude_modules + ['*ln_f', '*ln_embed', '*lm_head']))
def extract_layer_idx(name):
ss = name.split('.')

View File

@ -19,7 +19,8 @@ from tensorrt_llm.llmapi import (EagleDecodingConfig, LookaheadDecodingConfig,
from tensorrt_llm.quantization import QuantAlgo
from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_nvls,
skip_pre_ada, skip_pre_blackwell, skip_pre_hopper)
skip_post_blackwell, skip_pre_ada, skip_pre_blackwell,
skip_pre_hopper)
from .accuracy_core import (MMLU, CliFlowAccuracyTestHarness, CnnDailymail,
Humaneval, PassKeyRetrieval64k,
PassKeyRetrieval128k, SlimPajama6B, ZeroScrolls)
@ -57,6 +58,7 @@ class TestGpt2(CliFlowAccuracyTestHarness):
def test_int8_kv_cache(self):
self.run(kv_cache_quant_algo=QuantAlgo.INT8)
@skip_post_blackwell
@parametrize_with_ids("per_token,per_channel", [(False, False),
(True, True)])
def test_smooth_quant(self, per_token: bool, per_channel: bool):
@ -142,6 +144,7 @@ class TestStarcoder2_15B(CliFlowAccuracyTestHarness):
MODEL_PATH = f"{llm_models_root()}/starcoder2-model"
EXAMPLE_FOLDER = "models/core/gpt"
@skip_post_blackwell
def test_smooth_quant_ootb(self):
self.run(tasks=[Humaneval(self.MODEL_NAME)],
quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL)
@ -194,9 +197,11 @@ class TestPhi2(CliFlowAccuracyTestHarness):
MODEL_PATH = f"{llm_models_root()}/phi-2"
EXAMPLE_FOLDER = "models/core/phi"
@skip_post_blackwell
def test_auto_dtype(self):
self.run(dtype='auto')
@skip_post_blackwell
@pytest.mark.skip_less_device(2)
def test_tp2(self):
self.run(tp_size=2)
@ -316,6 +321,7 @@ class TestVicuna7B(CliFlowAccuracyTestHarness):
extra_build_args=["--speculative_decoding_mode=medusa"],
extra_summarize_args=extra_summarize_args)
@skip_post_blackwell
@parametrize_with_ids("cuda_graph,chunked_context,typical_acceptance",
[(False, False, False), (True, False, False),
(True, True, False), (True, False, True)])
@ -360,6 +366,7 @@ class TestLlama7B(CliFlowAccuracyTestHarness):
extra_build_args=["--max_beam_width=5"],
extra_summarize_args=["--num_beams=5"])
@skip_post_blackwell
def test_int4_gptq(self):
self.run(
quant_algo=QuantAlgo.W4A16_GPTQ,
@ -386,6 +393,7 @@ class TestLlama2_7B(CliFlowAccuracyTestHarness):
def test_auto_dtype(self):
self.run(dtype='auto')
@skip_post_blackwell
def test_smooth_quant(self):
self.run(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN)
@ -433,14 +441,17 @@ class TestLlama2_7B(CliFlowAccuracyTestHarness):
extra_build_args=["--low_latency_gemm_plugin=fp8"])
@pytest.mark.skip_less_device(2)
@skip_post_blackwell
def test_smooth_quant_ootb_tp2(self):
self.run(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL, tp_size=2)
@pytest.mark.skip_less_device(2)
@skip_post_blackwell
def test_int4_awq_tp2(self):
self.run(quant_algo=QuantAlgo.W4A16_AWQ, tp_size=2)
@pytest.mark.skip_less_device(2)
@skip_post_blackwell
def test_int4_awq_prequantized_tp2(self, mocker):
mocker.patch.object(
self.__class__, "MODEL_PATH",
@ -448,6 +459,7 @@ class TestLlama2_7B(CliFlowAccuracyTestHarness):
self.run(quant_algo=QuantAlgo.W4A16_AWQ, tp_size=2)
@pytest.mark.skip_less_device(2)
@skip_post_blackwell
def test_int4_gptq_prequantized_tp2(self, mocker):
mocker.patch.object(
self.__class__, "MODEL_PATH",
@ -469,16 +481,19 @@ class TestTinyLlama1_1BChat(CliFlowAccuracyTestHarness):
def test_float32(self):
self.run(dtype='float32')
@skip_post_blackwell
@pytest.mark.parametrize("precision", ["int8", "int4"])
def test_weight_only(self, precision: str):
quant_algo = QuantAlgo.W8A16 if precision == "int8" else QuantAlgo.W4A16
self.run(quant_algo=quant_algo)
@skip_post_blackwell
@pytest.mark.parametrize("precision", ["int8", "int4"])
def test_weight_only_int8_kv_cache(self, precision: str):
quant_algo = QuantAlgo.W8A16 if precision == "int8" else QuantAlgo.W4A16
self.run(quant_algo=quant_algo, kv_cache_quant_algo=QuantAlgo.INT8)
@skip_post_blackwell
@pytest.mark.parametrize("precision", ["int8", "int4"])
def test_weight_only_manage_weights(self, precision: str):
quant_algo = QuantAlgo.W8A16 if precision == "int8" else QuantAlgo.W4A16
@ -567,6 +582,7 @@ class TestLlama3_1_8B(CliFlowAccuracyTestHarness):
def test_auto_dtype(self):
self.run(dtype='auto')
@skip_post_blackwell
def test_smooth_quant(self):
self.run(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN)
@ -575,12 +591,14 @@ class TestLlama3_1_8B(CliFlowAccuracyTestHarness):
self.run(quant_algo=QuantAlgo.FP8, kv_cache_quant_algo=QuantAlgo.FP8)
@skip_pre_ada
@skip_post_blackwell
def test_fp8_rowwise(self):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN)
@skip_pre_ada
@skip_post_blackwell
def test_fp8_rowwise_meta_recipe(self):
self.run(quant_algo=QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN,
extra_acc_spec="meta_recipe",
@ -601,6 +619,7 @@ class TestLlama3_1_8B(CliFlowAccuracyTestHarness):
extra_build_args=extra_build_args)
@skip_pre_ada
@skip_post_blackwell
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize(
"gemm_allreduce", [False, pytest.param(True, marks=skip_no_nvls)],
@ -646,6 +665,7 @@ class TestLlama3_1_8BInstruct(CliFlowAccuracyTestHarness):
self.run(quant_algo=QuantAlgo.FP8, kv_cache_quant_algo=QuantAlgo.FP8)
@skip_pre_ada
@skip_post_blackwell
def test_medusa_fp8_prequantized(self, mocker):
# nvidia/Llama-3.1-8B-Medusa-FP8
mocker.patch.object(self.__class__, "MODEL_PATH",
@ -670,23 +690,29 @@ class TestLlama3_2_1B(CliFlowAccuracyTestHarness):
def test_auto_dtype(self):
self.run(dtype='auto')
@skip_post_blackwell
def test_smooth_quant(self):
self.run(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN)
@skip_post_blackwell
def test_smooth_quant_ootb(self):
self.run(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL)
@skip_post_blackwell
def test_smooth_quant_ootb_manage_weights(self):
self.run(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL,
extra_build_args=["--fast_build"])
@skip_post_blackwell
def test_int4_awq(self):
self.run(quant_algo=QuantAlgo.W4A16_AWQ)
@skip_post_blackwell
def test_int4_awq_int8_kv_cache(self):
self.run(quant_algo=QuantAlgo.W4A16_AWQ,
kv_cache_quant_algo=QuantAlgo.INT8)
@skip_post_blackwell
def test_int4_awq_manage_weights(self):
self.run(quant_algo=QuantAlgo.W4A16_AWQ,
extra_build_args=["--fast_build"])
@ -733,10 +759,12 @@ class TestLlama3_2_1B(CliFlowAccuracyTestHarness):
pp_size=2)
@skip_pre_ada
@skip_post_blackwell
def test_fp8_rowwise(self):
self.run(quant_algo=QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN)
@skip_pre_ada
@skip_post_blackwell
def test_fp8_rowwise_meta_recipe(self):
self.run(quant_algo=QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN,
extra_acc_spec="meta_recipe",
@ -830,6 +858,7 @@ class TestGemma2B(CliFlowAccuracyTestHarness):
quant_algo = QuantAlgo.W8A16 if precision == "int8" else QuantAlgo.W4A16
self.run(quant_algo=quant_algo, extra_convert_args=["--ckpt-type=hf"])
@skip_post_blackwell
def test_smooth_quant(self):
self.run(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
extra_convert_args=[
@ -841,6 +870,7 @@ class TestGemma2B(CliFlowAccuracyTestHarness):
def test_fp8(self):
self.run(quant_algo=QuantAlgo.FP8, kv_cache_quant_algo=QuantAlgo.FP8)
@skip_post_blackwell
def test_int4_awq(self):
self.run(quant_algo=QuantAlgo.W4A16_AWQ)
@ -859,6 +889,7 @@ class TestGemma7B(CliFlowAccuracyTestHarness):
quant_algo = QuantAlgo.W8A16 if precision == "int8" else QuantAlgo.W4A16
self.run(quant_algo=quant_algo, extra_convert_args=["--ckpt-type=hf"])
@skip_post_blackwell
@pytest.mark.skip_less_device_memory(50000)
def test_smooth_quant(self):
self.run(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
@ -871,6 +902,7 @@ class TestGemma7B(CliFlowAccuracyTestHarness):
def test_fp8(self):
self.run(quant_algo=QuantAlgo.FP8, kv_cache_quant_algo=QuantAlgo.FP8)
@skip_post_blackwell
def test_int4_awq(self):
self.run(quant_algo=QuantAlgo.W4A16_AWQ)
@ -887,6 +919,7 @@ class TestGemma2_9BIt(CliFlowAccuracyTestHarness):
dtype='auto',
extra_convert_args=["--ckpt-type=hf"])
@skip_post_blackwell
@pytest.mark.parametrize("precision", ["int8", "int4"])
def test_weight_only(self, precision: str):
quant_algo = QuantAlgo.W8A16 if precision == "int8" else QuantAlgo.W4A16
@ -910,6 +943,7 @@ class TestQwen7BChat(CliFlowAccuracyTestHarness):
def test_weight_only(self):
self.run(quant_algo=QuantAlgo.W8A16)
@skip_post_blackwell
def test_int4_gptq_prequantized(self, mocker):
mocker.patch.object(self.__class__, "MODEL_PATH",
f"{llm_models_root()}/Qwen-7B-Chat-Int4")
@ -938,6 +972,7 @@ class TestQwen2_0_5BInstruct(CliFlowAccuracyTestHarness):
def test_auto_dtype(self):
self.run(dtype='auto')
@skip_post_blackwell
def test_weight_only(self):
self.run(quant_algo=QuantAlgo.W8A16)
@ -956,9 +991,11 @@ class TestQwen2_7BInstruct(CliFlowAccuracyTestHarness):
def test_auto_dtype(self):
self.run(dtype='auto')
@skip_post_blackwell
def test_weight_only(self):
self.run(quant_algo=QuantAlgo.W8A16)
@skip_post_blackwell
def test_int4_awq_prequantized(self, mocker):
mocker.patch.object(self.__class__, "MODEL_PATH",
f"{llm_models_root()}/Qwen2-7B-Instruct-AWQ")
@ -990,6 +1027,7 @@ class TestQwen2_5_1_5BInstruct(CliFlowAccuracyTestHarness):
def test_auto_dtype(self):
self.run(dtype='auto')
@skip_post_blackwell
def test_weight_only(self):
self.run(quant_algo=QuantAlgo.W8A16)

View File

@ -18,7 +18,7 @@ from tensorrt_llm.llmapi import LLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo
from ..conftest import llm_models_root, skip_pre_ada
from ..conftest import llm_models_root, skip_post_blackwell, skip_pre_ada
from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness
@ -27,6 +27,7 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"
@skip_pre_ada
@skip_post_blackwell
def test_fp8_rowwise(self):
quant_config = QuantConfig(QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN)
@ -65,6 +66,7 @@ class TestQwen2_7BInstruct(LlmapiAccuracyTestHarness):
task.evaluate(llm,
extra_evaluator_kwargs=self.EXTRA_EVALUATOR_KWARGS)
@skip_post_blackwell
def test_weight_only(self):
quant_config = QuantConfig(QuantAlgo.W8A16)
with LLM(self.MODEL_PATH, quant_config=quant_config) as llm:
@ -133,6 +135,7 @@ class TestQwen2_5_1_5BInstruct(LlmapiAccuracyTestHarness):
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
@skip_post_blackwell
def test_weight_only(self):
quant_config = QuantConfig(QuantAlgo.W8A16)
with LLM(self.MODEL_PATH, quant_config=quant_config) as llm:

View File

@ -73,8 +73,8 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
@pytest.mark.parametrize("tp_size,pp_size", [(4, 1), (2, 2)],
ids=["tp4", "tp2pp2"])
@pytest.mark.parametrize("tp_size,pp_size", [(4, 1), (2, 2), (1, 4)],
ids=["tp4", "tp2pp2", "pp4"])
def test_bfloat16_4gpus(self, tp_size, pp_size, attn_backend,
torch_compile):
if torch_compile and pp_size > 1:
@ -130,8 +130,8 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
@parametrize_with_ids("fp8kv", [False, True])
@pytest.mark.parametrize("tp_size,pp_size", [(4, 1), (2, 2)],
ids=["tp4", "tp2pp2"])
@pytest.mark.parametrize("tp_size,pp_size", [(4, 1), (2, 2), (1, 4)],
ids=["tp4", "tp2pp2", "pp4"])
def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
torch_compile):
if pp_size > 1:

View File

@ -1921,6 +1921,9 @@ skip_post_blackwell = pytest.mark.skipif(
skip_no_nvls = pytest.mark.skipif(not ipc_nvls_supported(),
reason="NVLS is not supported")
skip_no_hopper = pytest.mark.skipif(
get_sm_version() != 90,
reason="This test is only supported in Hopper architecture")
def skip_fp8_pre_ada(use_fp8):

View File

@ -17,6 +17,7 @@ import os
import subprocess
import pytest
from defs.conftest import skip_no_hopper
def kill_disaggregated_processes():
@ -353,6 +354,7 @@ def test_disaggregated_load_balance(disaggregated_test_root, llm_venv,
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8(disaggregated_test_root,
@ -373,6 +375,7 @@ def test_disaggregated_deepseek_v3_lite_fp8(disaggregated_test_root,
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu(
@ -393,6 +396,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu(
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp(
@ -413,12 +417,14 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp(
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_ucx(disaggregated_test_root,
disaggregated_example_root,
llm_venv,
deepseek_v3_model_root):
src_dst_dict = {
deepseek_v3_model_root:
f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/fp8",
@ -436,6 +442,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_ucx(disaggregated_test_root,
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_ucx_tp1_single_gpu(
@ -459,6 +466,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_ucx_tp1_single_gpu(
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_attention_dp(
@ -480,6 +488,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_attention_dp(
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_attention_dp_overlap(
@ -500,6 +509,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_attention_dp_overlap(
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_attention_dp_overlap_cuda_graph(
@ -522,6 +532,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_attention_dp_overlap_cuda_graph(
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_overlap_cuda_graph(
@ -543,6 +554,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_overlap_cuda_graph(
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one(
@ -564,6 +576,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one(
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp(
@ -585,11 +598,13 @@ def test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp(
cwd=llm_venv.get_working_directory())
@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
indirect=True)
def test_disaggregated_deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp(
disaggregated_test_root, disaggregated_example_root, llm_venv,
deepseek_v3_model_root):
src_dst_dict = {
deepseek_v3_model_root:
f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/fp8",


@ -5,6 +5,7 @@ import sys
import cloudpickle
import pytest
from defs.conftest import skip_no_hopper
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
@ -199,6 +200,7 @@ def test_disaggregated_simple_llama(model, generation_overlap,
])
@skip_no_hopper
@pytest.mark.parametrize("model", ["DeepSeek-V3-Lite-fp8/fp8"])
@pytest.mark.parametrize("generation_overlap", [False, True])
@pytest.mark.parametrize("enable_cuda_graph", [False, True])


@ -22,6 +22,7 @@ from defs.conftest import skip_post_blackwell, skip_pre_ada
from defs.trt_test_alternative import check_call
@skip_post_blackwell
@pytest.mark.parametrize("use_dynamic_tree", [False, True],
ids=['eagle1', 'eagle2'])
@pytest.mark.parametrize("batch_size", [1, 8], ids=['bs1', 'bs8'])


@ -17,6 +17,7 @@
import pytest
from defs.common import (convert_weights, generate_summary_cmd, venv_check_call,
venv_mpi_check_call)
from defs.conftest import skip_post_blackwell
from defs.trt_test_alternative import check_call
@ -26,7 +27,8 @@ from defs.trt_test_alternative import check_call
@pytest.mark.parametrize("llm_exaone_model_root",
['exaone_3.0_7.8b_instruct', 'exaone_deep_2.4b'],
indirect=True)
@pytest.mark.parametrize("use_weight_only", [True, False],
@pytest.mark.parametrize("use_weight_only",
[pytest.param(True, marks=skip_post_blackwell), False],
ids=["enable_weight_only", "disable_weight_only"])
def test_llm_exaone_1gpu(data_type, exaone_example_root, llm_exaone_model_root,
llama_example_root, llm_datasets_root, llm_rouge_root,
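
(Illustration, not part of this diff.) The exaone hunk above skips only the weight-only variant on Blackwell rather than the whole test: just the True value of use_weight_only carries the skip_post_blackwell mark. A standalone example of per-value marks; the skip condition here is a dummy:

import pytest

# Dummy condition purely for illustration; the real marker checks the SM version.
skip_post_blackwell = pytest.mark.skipif(False, reason="illustration only")

@pytest.mark.parametrize(
    "use_weight_only",
    [pytest.param(True, marks=skip_post_blackwell), False],
    ids=["enable_weight_only", "disable_weight_only"])
def test_weight_only_toggle(use_weight_only):
    # Only the enable_weight_only case is skipped when the marker's condition holds.
    assert isinstance(use_weight_only, bool)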


@ -4656,6 +4656,7 @@ def test_llm_llama_lookahead_xqa_fp8_1gpu(llama_example_root, llama_model_root,
venv_check_call(llm_venv, summary_cmd)
@skip_pre_ada
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize("code_llama_model_root", ['CodeLlama-7b-Instruct'],
indirect=True)


@ -355,6 +355,7 @@ def test_llm_mistral_v1_smooth_quant_4gpus(llama_example_root,
summary_cmd)
@skip_pre_ada
@pytest.mark.parametrize("run_type", ['inference', 'summarization'])
@pytest.mark.parametrize("mistral_nemo_model_root", ['Mistral-Nemo-12b-Base'],
indirect=True)


@ -22,7 +22,8 @@ from defs.common import (convert_weights, generate_summary_cmd, quantize_data,
venv_check_call, venv_mpi_check_call)
from defs.conftest import (evaltool_mmlu_post_process,
evaltool_wikilingua_post_process, llm_models_root,
skip_pre_ada, skip_pre_blackwell)
skip_post_blackwell, skip_pre_ada,
skip_pre_blackwell)
from defs.trt_test_alternative import check_call
from evaltool.constants import (EVALTOOL_INFERENCE_SERVER_STARTUP_SCRIPT,
EVALTOOL_INFERENCE_SERVER_STOP_SCRIPT,
@ -876,6 +877,7 @@ def test_llm_mixtral_1gpu_fp4_llmapi(
check_call(" ".join(mmlu_cmd), shell=True, env=llm_venv._new_env)
@skip_post_blackwell
@pytest.mark.parametrize("model_name", ['mixtral-8x7b-v0.1-AWQ'])
def test_llm_mixtral_int4_awq_1gpu_summary(llama_example_root,
llm_datasets_root, model_name,
@ -916,6 +918,7 @@ def test_llm_mixtral_int4_awq_1gpu_summary(llama_example_root,
venv_check_call(llm_venv, summary_cmd)
@skip_post_blackwell
@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize(


@ -17,7 +17,7 @@ import os
import pytest
from defs.common import convert_weights, venv_check_call, venv_mpi_check_call
from defs.conftest import get_device_memory, skip_pre_ada
from defs.conftest import get_device_memory, skip_post_blackwell, skip_pre_ada
from defs.trt_test_alternative import check_call
@ -617,7 +617,7 @@ def _test_llm_multimodal_general(llm_venv,
'neva-22b',
'kosmos-2',
'video-neva',
'Phi-3-vision-128k-instruct',
pytest.param('Phi-3-vision-128k-instruct', marks=skip_post_blackwell),
'Phi-3.5-vision-instruct',
'Phi-4-multimodal-instruct',
'Llama-3.2-11B-Vision',


@ -24,7 +24,8 @@ from defs.conftest import (LLM_GATE_WAY_CLIENT_ID, LLM_GATE_WAY_TOKEN,
evaltool_mmlu_post_process,
evaltool_mtbench_post_process,
evaltool_wikilingua_post_process, get_device_memory,
skip_fp8_pre_ada, skip_pre_ada)
get_sm_version, skip_fp8_pre_ada,
skip_post_blackwell, skip_pre_ada)
from defs.trt_test_alternative import check_call
from evaltool.constants import (EVALTOOL_INFERENCE_SERVER_STARTUP_SCRIPT,
EVALTOOL_INFERENCE_SERVER_STOP_SCRIPT,
@ -421,6 +422,8 @@ def test_llm_phi_lora_1gpu(data_type, lora_data_type, phi_example_root,
model_name = 'phi-3-lora'
if data_type == 'fp8':
skip_fp8_pre_ada(use_fp8=True)
if get_sm_version() >= 100:
pytest.skip("FP8 is not supported on post-Blackwell architectures")
model_dir = quantize_data(
llm_venv,
phi_example_root,
@ -570,6 +573,7 @@ def test_llm_phi_quantization_1gpu(data_type, llm_phi_model_root, llm_venv,
@skip_pre_ada
@skip_post_blackwell
@pytest.mark.parametrize("llm_phi_model_root", [
"phi-2", "Phi-3-mini-128k-instruct", "Phi-3-small-128k-instruct",
"Phi-3.5-mini-instruct", "Phi-3.5-MoE-instruct", "Phi-4-mini-instruct"


@ -56,7 +56,7 @@ from defs.trt_test_alternative import (Popen, cleanup_process_tree, print_info,
# [sys.executable, "-m", "pip", "install", "-r", requirements_file])
# Define a constant for process termination timeouts
GRACEFUL_TERMINATION_TIMEOUT = 10 # seconds - set longer when stress large model
GRACEFUL_TERMINATION_TIMEOUT = 300 # seconds - set longer when stress large model
@dataclass(frozen=True)
@ -384,7 +384,34 @@ def stress_test(config, test_mode, server_config=None):
)
# Define test configurations
performance_config = PerformanceParams() if run_performance else None
performance_config = None
if run_performance:
performance_config = PerformanceParams()
# For ds v3 specific parameters
if "DeepSeek-V3" in config.model_dir:
performance_config = PerformanceParams(
test_timeout=
36000 # 10 hours for ds v3, change this value if needed
)
# For ds v3 specific server parameters
if "DeepSeek-V3" in config.model_dir:
test_server_config = ServerConfig(
port=test_server_config.port,
host=test_server_config.host,
pp_size=test_server_config.pp_size,
ep_size=8, # DeepSeek-V3 specific ep_size
max_batch_size=161, # DeepSeek-V3 specific max_batch_size
max_num_tokens=1160, # DeepSeek-V3 specific max_num_tokens
kv_cache_free_gpu_memory_fraction=
0.7, # DeepSeek-V3 specific kv_cache fraction
capacity_scheduler_policy=test_server_config.
capacity_scheduler_policy,
wait_interval=test_server_config.wait_interval,
max_wait_seconds=7200, # DeepSeek-V3 specific wait time (2 hours)
health_check_timeout=test_server_config.health_check_timeout)
stress_config = StressTestConfig(
model_config=config,
server_config=test_server_config) if run_stress else None
@ -405,7 +432,7 @@ def stress_test(config, test_mode, server_config=None):
if not os.path.exists(model_path):
raise RuntimeError(f"Model path does not exist: {model_path}")
# Create a temporary YAML file for 'capacity_scheduler_policy'
# Create a temporary YAML file for extra_llm_options
extra_llm_options = {
"scheduler_config": {
"capacity_scheduler_policy":
@ -413,6 +440,21 @@ def stress_test(config, test_mode, server_config=None):
}
}
# Add DeepSeek-V3 specific configuration
if "DeepSeek-V3" in config.model_dir:
extra_llm_options["enable_attention_dp"] = True
if config.backend == "pytorch":
extra_llm_options["pytorch_backend_config"] = {
"use_cuda_graph": True,
"cuda_graph_padding_enabled": True,
"cuda_graph_batch_sizes":
[1, 2, 4, 8, 16, 32, 64, 128, 256, 384],
"print_iter_log": True,
"enable_overlap_scheduler": True
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml',
delete=False) as temp_file:
yaml.dump(extra_llm_options, temp_file)
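
(Illustration, not part of this diff.) The stress-test hunks above build an extra_llm_options dictionary, add DeepSeek-V3-specific overrides when the model directory matches, and hand the result to the server as a temporary YAML file. A condensed sketch of that pattern with the same keys; the values and placeholder variables are illustrative, and the server flag that consumes the file is omitted:

import tempfile
import yaml

model_dir = "DeepSeek-V3-Lite-fp8"  # placeholder
backend = "pytorch"                 # placeholder

extra_llm_options = {
    "scheduler_config": {
        "capacity_scheduler_policy": "GUARANTEED_NO_EVICT",  # example policy
    },
}

if "DeepSeek-V3" in model_dir:
    extra_llm_options["enable_attention_dp"] = True
    if backend == "pytorch":
        extra_llm_options["pytorch_backend_config"] = {
            "use_cuda_graph": True,
            "cuda_graph_padding_enabled": True,
            "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256, 384],
            "print_iter_log": True,
            "enable_overlap_scheduler": True,
        }

with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml",
                                 delete=False) as temp_file:
    yaml.dump(extra_llm_options, temp_file)
    extra_options_path = temp_file.name  # later passed to the server process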


@ -1598,6 +1598,31 @@ def test_ptq_quickstart_advanced_mtp(llm_root, llm_venv, model_name,
])
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("model_name,model_path", [
pytest.param('DeepSeek-V3', 'DeepSeek-V3', marks=skip_pre_hopper),
])
def test_ptp_quickstart_advanced_deepseek_v3_2nodes_8gpus(
llm_root, llm_venv, model_name, model_path):
# "RCCA https://nvbugs/5163844"
print(f"Testing {model_name}.")
example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
llm_venv.run_cmd([
str(example_root / "quickstart_advanced.py"),
"--enable_overlap_scheduler",
"--model_dir",
f"{llm_models_root()}/{model_path}",
"--moe_ep_size=8",
"--tp_size=16",
"--use_cuda_graph",
"--kv_cache_fraction=0.5",
"--max_batch_size=32",
"--max_num_tokens=2048",
"--kv_cache_enable_block_reuse",
])
@pytest.mark.parametrize("model_name,model_path,eagle_model_path", [
("Llama-3.1-8b-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct",
"EAGLE3-LLaMA3.1-Instruct-8B"),
@ -1661,9 +1686,11 @@ def test_ptp_quickstart_advanced_deepseek_r1_8gpus(llm_root, llm_venv,
pytest.param('Mixtral-8x7B-NVFP4',
'nvfp4-quantized/Mixtral-8x7B-Instruct-v0.1',
marks=skip_pre_blackwell),
pytest.param('Nemotron-Ultra-253B',
'nemotron-nas/Llama-3_1-Nemotron-Ultra-253B-v1',
marks=skip_pre_hopper),
pytest.param(
'Nemotron-Ultra-253B',
'nemotron-nas/Llama-3_1-Nemotron-Ultra-253B-v1',
marks=[skip_pre_hopper,
pytest.mark.skip_less_device_memory(140000)]),
])
def test_ptp_quickstart_advanced_8gpus(llm_root, llm_venv, model_name,
model_path):


@ -421,12 +421,12 @@ accuracy/test_llm_api.py::TestQwen2_5_7BInstruct::test_fp8
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
accuracy/test_llm_api_pytorch.py::TestMistral7B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_fp8_tp2
accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_nvfp4_tp2
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[]
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[]
accuracy/test_llm_api_pytorch.py::TestMistral7B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestMinitron4BBaseInstruct::test_fp8_prequantized
accuracy/test_llm_api_pytorch.py::TestNemotronNas::test_auto_dtype_tp8
accuracy/test_llm_api_pytorch.py::TestQwen2_7BInstruct::test_auto_dtype


@ -12,5 +12,6 @@ examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-405b-fp8-disa
examples/test_llama.py::test_llm_llama_v3_1_2nodes_8gpus[llama-3.1-405b-fp8-disable_fp8-tp8pp2-infer]
examples/test_mixtral.py::test_llm_mixtral_2nodes_8gpus[Mixtral-8x22B-v0.1-plugin-renormalize-tensor_parallel-build]
examples/test_mixtral.py::test_llm_mixtral_2nodes_8gpus[Mixtral-8x22B-v0.1-plugin-renormalize-tensor_parallel-infer]
test_e2e.py::test_ptp_quickstart_advanced_deepseek_v3_2nodes_8gpus[DeepSeek-V3-DeepSeek-V3]
test_e2e.py::test_openai_multinodes_chat_tp16pp1
test_e2e.py::test_openai_multinodes_chat_tp8pp2


@ -19,7 +19,7 @@ l0_a10:
- disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
- stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-MAX_UTILIZATION-pytorch-stress-test]
- stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-GUARANTEED_NO_EVICT-pytorch-stress-stage-alone]
- stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-GUARANTEED_NO_EVICT-pytorch-stress-test]
- condition:
ranges:
system_gpu_count:
@ -113,7 +113,7 @@ l0_a10:
- examples/test_mamba.py::test_llm_mamba_1gpu[mamba2-130m-float16-enable_gemm_plugin]
- examples/test_mamba.py::test_llm_mamba_1gpu[mamba-codestral-7B-v0.1-float16-enable_gemm_plugin] # 3 mins
- stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-MAX_UTILIZATION-trt-stress-test]
- stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-GUARANTEED_NO_EVICT-trt-stress-stage-alone]
- stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-GUARANTEED_NO_EVICT-trt-stress-test]
- condition:
ranges:
system_gpu_count:


@ -29,6 +29,7 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-attn_backend=TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv-attn_backend=TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv-attn_backend=TRTLLM]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv-attn_backend=TRTLLM-torch_compile]
- disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
@ -207,6 +208,7 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=FLASHINFER]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=FLASHINFER]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-attn_backend=FLASHINFER]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv-attn_backend=FLASHINFER]
@ -215,6 +217,7 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv-attn_backend=FLASHINFER]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp2pp2-fp8kv-attn_backend=FLASHINFER-torch_compile]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv-attn_backend=FLASHINFER]
- condition:
ranges:
system_gpu_count:


@ -95,7 +95,7 @@ l0_l40s:
- accuracy/test_cli_flow.py::TestQwen1_5MoeA2_7BChat::test_weight_only
- examples/test_gpt.py::test_llm_gpt2_next_prompt_tuning[use_cpp_session-tp1] # 10 mins
- examples/test_gpt.py::test_llm_gpt2_next_prompt_tuning[use_py_session-tp1]
# - examples/test_llama.py::test_llm_llama_1gpu_fp8_kv_cache[llama-v2-7b-hf-bfloat16] #4 mins
- examples/test_llama.py::test_llm_llama_1gpu_fp8_kv_cache[llama-v2-7b-hf-bfloat16] #4 mins
- examples/test_multimodal.py::test_llm_multimodal_general[neva-22b-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1]
- examples/test_multimodal.py::test_llm_multimodal_general[video-neva-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1]
- examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:1-cpp_e2e:True-nb:1]


@ -346,7 +346,6 @@ examples/test_qwen.py::test_llm_qwen_single_gpu_summary[qwen2_vl_7b_instruct-ena
examples/test_qwen.py::test_llm_qwen_single_gpu_summary[qwen2_vl_7b_instruct-enable_paged_kv_cache-enable_remove_input_padding-enable_weight_only-enable_fmha_fp32_acc] SKIP (https://nvbugs/5141290)
examples/test_qwen.py::test_llm_qwen_awq_single_gpu_summary[qwen2_vl_7b_instruct-nb:4] SKIP (https://nvbugs/5141290)
examples/test_qwen.py::test_llm_hf_qwen_quantization_1gpu[qwen2_vl_7b_instruct-fp8-bfloat16] SKIP (https://nvbugs/5141290)
examples/test_qwen.py::test_llm_qwen_smooth_quant_single_gpu_summary[qwen2_vl_7b_instruct-enable_ptpc-nb:4] SKIP (https://nvbugs/5141291)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (https://nvbugs/5141400)
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5141400)
unittest/_torch/auto_deploy/integration/test_lm_eval.py SKIP (https://nvbugs/5144854)
@ -362,14 +361,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpu
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-cuda_graph] SKIP (https://nvbugs/5170160)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-overlap_scheduler] SKIP (https://nvbugs/5170160)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp2pp2-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5170160)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5201530)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5201530)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5201530)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5201530)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-cuda_graph] SKIP (https://nvbugs/5181511)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5181511)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-cuda_graph] SKIP (https://nvbugs/5181511)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5181511)
full:L40S/accuracy/test_cli_flow.py::TestGemma2_9BIt::test_auto_dtype SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_cli_flow.py::TestGemma2_9BIt::test_weight_only[int8] SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_cli_flow.py::TestGemma2_9BIt::test_weight_only[int4] SKIP (https://nvbugs/5176851)
@ -418,11 +409,9 @@ examples/test_multimodal.py::test_llm_multimodal_general[neva-22b-pp:1-tp:1-bflo
examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_py_session-recurrentgemma-2b-no_paged_cache-disable_quant-float16-disable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5214221)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_py_session-recurrentgemma-2b-no_paged_cache-disable_quant-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5214221)
examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_py_session-recurrentgemma-2b-use_paged_cache-disable_quant-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (https://nvbugs/5214221)
accuracy/test_cli_flow.py::TestGpt2Medium::test_fp8_lm_head SKIP (https://nvbugs/5214229)
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5214239)
examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-scienceqa-Llama-3.2-11B-Vision-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False] SKIP (https://nvbugs/5222697)
examples/test_gpt.py::test_llm_gpt2_santacoder_1node_4gpus[parallel_build-enable_fmha-enable_gemm_plugin-enable_attention_plugin] SKIP (https://nvbugs/5219531)
examples/test_llama.py::test_llm_llama_v3_1_1node_multi_gpus[enable_gemm_allreduce_plugin-llama-3.1-70b-disable_fp8] SKIP (https://nvbugs/5219533)
examples/test_medusa.py::test_llama_medusa_1gpu[llama-v2-7b-hf] SKIP (https://nvbugs/5219534)
examples/test_medusa.py::test_llama_medusa_1gpu[llama-3.2-1b] SKIP (https://nvbugs/5219534)
examples/test_medusa.py::test_llama_medusa_1gpu[llama-3.1-8b] SKIP (https://nvbugs/5219534)
@ -436,7 +425,6 @@ examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.1-8b-eagle1] SKIP (https:/
examples/test_eagle.py::test_mistral_eagle_1gpu[mistral-7b-v0.1-eagle1] SKIP (https://nvbugs/5219535)
examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.1-8b-eagle2] SKIP (https://nvbugs/5219535)
examples/test_eagle.py::test_mistral_eagle_1gpu[mistral-7b-v0.1-eagle2] SKIP (https://nvbugs/5219535)
examples/test_mixtral.py::test_llm_mixtral_fp8_4gpus_summary[Mixtral-8x22B-v0.1-nb:1] SKIP (https://nvbugs/5220758)
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5214239)
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:1-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5214239)
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5214239)
@ -486,6 +474,11 @@ test_e2e.py::test_ptp_quickstart_advanced_8gpus[Nemotron-Ultra-253B-nemotron-nas
test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image] SKIP (https://nvbugs/5233423)
accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_medusa_fp8_prequantized SKIP (https://nvbugs/5238599)
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4 SKIP (https://nvbugs/5238602)
accuracy/test_cli_flow.py::TestGpt2Medium::test_fp8_lm_head SKIP (https://nvbugs/5214229)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5239087)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[ep4-mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler] SKIP (https://nvbugs/5239087)
unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[tp8-trtllm-scout] SKIP (https://nvbugs/5244009)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM] SKIP (https://nvbugs/5241627)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=FLASHINFER] SKIP (https://nvbugs/5241627)


@ -285,7 +285,6 @@ def run_command(command: str):
@skip_single_gpu
def test_llm_multi_node(engine_from_checkpoint: tempfile.TemporaryDirectory):
# TODO[chunweiy]: reactivate this later
nworkers = 2
test_case_file = os.path.join(os.path.dirname(__file__), "run_llm.py")
os.path.join(os.path.dirname(__file__), "launch.py")