Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

chore [BREAKING CHANGE]: Flatten PyTorchConfig knobs into TorchLlmArgs (#4603)

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>

Parent: fbe4db207d
Commit: 5506f60037
@ -134,9 +134,8 @@ To do the benchmark, run the following command:
 YOUR_DATA_PATH=<your dataset file following the format>

 cat >./extra-llm-api-config.yml<<EOF
-pytorch_backend_config:
-  use_cuda_graph: true
-  moe_backend: TRTLLM
+use_cuda_graph: true
+moe_backend: TRTLLM
 speculative_config:
   decoding_type: MTP
   num_nextn_predict_layers: 3
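
The hunk above is the core of the breaking change: knobs formerly nested under `pytorch_backend_config` now sit at the top level of the extra LLM API options. As a rough illustration (not part of this commit; it assumes the flattened keys map directly onto `LLM` constructor keyword arguments, which is what `TorchLlmArgs` now provides, and uses a placeholder model id), such a file can be consumed like this:

```python
# Minimal sketch: load the flattened extra-llm-api-config.yml and forward the
# knobs straight to LLM(); no nested pytorch_backend_config wrapper is needed.
import yaml

from tensorrt_llm import LLM

with open("extra-llm-api-config.yml") as f:
    # e.g. {"use_cuda_graph": True, "moe_backend": "TRTLLM", ...}
    extra_options = yaml.safe_load(f)

# Nested sections such as speculative_config may still need dedicated config
# objects, so this sketch only forwards the flat scalar knobs.
extra_options.pop("speculative_config", None)

# Placeholder model id.
llm = LLM(model="deepseek-ai/DeepSeek-R1", backend="pytorch", **extra_options)
```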
@ -202,21 +201,20 @@ python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \
 YOUR_DATA_PATH=./dataset.txt

 cat >./extra-llm-api-config.yml <<EOF
-pytorch_backend_config:
-  use_cuda_graph: true
-  cuda_graph_padding_enabled: true
-  cuda_graph_batch_sizes:
-  - 1
-  - 2
-  - 4
-  - 8
-  - 16
-  - 32
-  - 64
-  - 128
-  - 256
-  - 384
-  print_iter_log: true
+use_cuda_graph: true
+cuda_graph_padding_enabled: true
+cuda_graph_batch_sizes:
+- 1
+- 2
+- 4
+- 8
+- 16
+- 32
+- 64
+- 128
+- 256
+- 384
+print_iter_log: true
 enable_attention_dp: true
 EOF

@ -257,8 +255,7 @@ To do the benchmark, run the following command:
 YOUR_DATA_PATH=<your dataset file following the format>

 cat >./extra-llm-api-config.yml<<EOF
-pytorch_backend_config:
-  use_cuda_graph: true
+use_cuda_graph: true
 speculative_config:
   decoding_type: MTP
   num_nextn_predict_layers: 3
@ -307,10 +304,9 @@ python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \
 YOUR_DATA_PATH=./dataset.txt

 cat >./extra-llm-api-config.yml<<EOF
-pytorch_backend_config:
-  use_cuda_graph: true
-  cuda_graph_batch_sizes:
-  - 128
+use_cuda_graph: true
+cuda_graph_batch_sizes:
+- 128
 enable_attention_dp: true
 EOF

@ -121,9 +121,8 @@ To benchmark min-latency performance with MTP, you need to follow [this document
 YOUR_DATA_PATH=<your dataset file following the format>

 cat >./extra-llm-api-config.yml<<EOF
-pytorch_backend_config:
-  use_cuda_graph: true
-  moe_backend: TRTLLM
+use_cuda_graph: true
+moe_backend: TRTLLM
 speculative_config:
   decoding_type: MTP
   num_nextn_predict_layers: 3
@ -177,9 +176,8 @@ To benchmark min-latency performance with MTP Relaxed Acceptance, you need to fo
 YOUR_DATA_PATH=<your dataset file following the format>

 cat >./extra-llm-api-config.yml<<EOF
-pytorch_backend_config:
-  use_cuda_graph: true
-  moe_backend: TRTLLM
+use_cuda_graph: true
+moe_backend: TRTLLM
 speculative_config:
   decoding_type: MTP
   num_nextn_predict_layers: 3

@ -628,8 +628,7 @@ If you would like to force the KV cache quantizaton, you can specify the followi
 when the checkpoint precision is `null`:

 ```yaml
-pytorch_backend_config:
-  kv_cache_dtype: "fp8"
+kv_cache_dtype: "fp8"
 ```

 ```{tip}

@ -200,11 +200,9 @@ trtllm-bench --model $model_name throughput --dataset $dataset_file --backend py

 `llm_options.yml`
 ```yaml
-pytorch_backend_config:
-  use_cuda_graph: true
-  cuda_graph_padding_enabled: true
-  cuda_graph_batch_sizes:
+use_cuda_graph: true
+cuda_graph_padding_enabled: true
+cuda_graph_batch_sizes:
 - 1
 - 2
 - 4

@ -16,7 +16,7 @@ The following sections explain how to use these implementations and provide a br


 There are currently three available attention backends: the vanilla backend, the TRT-LLM backend, and the Flashinfer backend.
-You can specify the desired attention backend using `PyTorchConfig.attn_backend`. For instance, to utilize the Flashinfer backend, you can create a `PyTorchConfig` with `attn_backend = "flashinfer"` and then pass it to the `LLM` constructor as follows: `LLM(pytorch_backend_config=pytorch_config)`. This will enable the use of the Flashinfer backend for your model.
+You can specify the desired attention backend using `PyTorchConfig.attn_backend`. For instance, to utilize the Flashinfer backend, you can pass `attn_backend="flashinfer"` to the `LLM` constructor as follows: `LLM(attn_backend="flashinfer")`. This will enable the use of the Flashinfer backend for your model.

 The vanilla backend, `VanillaAttention`, is a reference implementation designed primarily for inflight batching and linear KV cache support. While it serves as a useful baseline, it is not recommended for production use due to its limited optimizations.

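
A minimal usage sketch of the updated sentence above (the model id is a placeholder, the lowercase backend name follows the documentation text, and `backend="pytorch"` is assumed to be needed to select the PyTorch flow):

```python
# Sketch: select the Flashinfer attention backend through the flattened knob
# instead of building a PyTorchConfig.
from tensorrt_llm import LLM

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model
          backend="pytorch",
          attn_backend="flashinfer")
```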
@ -265,7 +265,7 @@ llm = LLM(
|
||||
model=<HF_MODEL_CARD_OR_DIR>,
|
||||
backend="autodeploy",
|
||||
build_config=build_config,
|
||||
pytorch_backend_config=ad_config,
|
||||
auto_deploy_config=ad_config,
|
||||
tensor_parallel_size=<NUM_WORLD_RANK>,
|
||||
)
|
||||
|
||||
|
||||
@ -73,7 +73,7 @@ def build_llm_from_config(config: SimpleConfig) -> LLM:
|
||||
model=factory.model,
|
||||
backend="autodeploy",
|
||||
build_config=build_config,
|
||||
pytorch_backend_config=ad_config,
|
||||
auto_deploy_config=ad_config,
|
||||
tensor_parallel_size=config.world_size,
|
||||
tokenizer=factory.init_tokenizer() if config.customize_tokenizer else None,
|
||||
)
|
||||
|
||||
@ -9,7 +9,7 @@ You can use multiple `trtllm-serve` commands to launch the context and generatio
 for disaggregated serving. For example, you could launch two context servers and one generation servers as follows:

 ```
-echo -e "pytorch_backend_config:\n disable_overlap_scheduler: True\ncache_transceiver_config:\n max_num_tokens: 2048" > context_extra-llm-api-config.yml
+echo -e "disable_overlap_scheduler: True\ncache_transceiver_config:\nmax_num_tokens: 2048" > context_extra-llm-api-config.yml
 echo -e "cache_transceiver_config:\n max_num_tokens: 2048" > gen_extra-llm-api-config.yml

 export TRTLLM_USE_UCX_KVCACHE=1
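
For scripts that prefer not to hand-write `echo -e` strings, a hedged alternative (file name and keys taken from the snippet above; the nesting of `max_num_tokens` under `cache_transceiver_config` follows the generation-server line) is to emit the flattened context config with `yaml`:

```python
# Sketch: write the flattened context-server config; disable_overlap_scheduler
# is now a top-level key rather than a child of pytorch_backend_config.
import yaml

context_cfg = {
    "disable_overlap_scheduler": True,
    "cache_transceiver_config": {"max_num_tokens": 2048},
}

with open("context_extra-llm-api-config.yml", "w") as f:
    yaml.safe_dump(context_cfg, f)
```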
@ -63,9 +63,8 @@ hostname: localhost
|
||||
port: 8000
|
||||
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 1
|
||||
|
||||
@ -3,9 +3,8 @@ port: 8000
|
||||
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
free_gpu_memory_fraction: 0.25
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 1
|
||||
|
||||
@ -1,17 +1,15 @@
 ### Get KV Cache Events

 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 from tensorrt_llm.llmapi import KvCacheConfig


 def main():
-    pytorch_config = PyTorchConfig(autotuner_enabled=False,
-                                   kv_cache_dtype='auto')

     llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
               tensor_parallel_size=2,
-              pytorch_backend_config=pytorch_config,
+              autotuner_enabled=False,
+              kv_cache_dtype='auto',
               kv_cache_config=KvCacheConfig(enable_block_reuse=True,
                                             event_buffer_max_size=1024),
               backend="pytorch")

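
For reference, a self-contained version of the updated example (a sketch only: the single-GPU `tensor_parallel_size` and the `generate()` loop follow the usual LLM API quickstart pattern and are not part of this hunk):

```python
# Sketch of the post-change API: the former PyTorchConfig knobs are plain
# keyword arguments on LLM, so no PyTorchConfig import is needed.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig


def main():
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              backend="pytorch",
              tensor_parallel_size=1,  # reduced from 2 so the sketch fits one GPU
              autotuner_enabled=False,
              kv_cache_dtype="auto",
              kv_cache_config=KvCacheConfig(enable_block_reuse=True,
                                            event_buffer_max_size=1024))
    for output in llm.generate(["Hello, my name is"]):
        print(output.outputs[0].text)


if __name__ == "__main__":
    main()
```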
@ -74,10 +74,9 @@ srun -l \
|
||||
|
||||
# This is optional
|
||||
cat > /tmp/pytorch_extra_args.txt << EOF
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: false
|
||||
cuda_graph_padding_enabled: false
|
||||
print_iter_log: true
|
||||
use_cuda_graph: false
|
||||
cuda_graph_padding_enabled: false
|
||||
print_iter_log: true
|
||||
enable_attention_dp: false
|
||||
EOF
|
||||
|
||||
|
||||
@ -100,7 +100,6 @@ class TRTLLMEvalBase(TemplateLM):
|
||||
if hasattr(PyTorchConfig, "moe_backend"):
|
||||
pytorch_config_params["moe_backend"] = self.moe_backend
|
||||
print(f"Info: moe_backend is set to {self.moe_backend}")
|
||||
pytorch_config = PyTorchConfig(**pytorch_config_params)
|
||||
|
||||
# stop words not currently supported by torch backend
|
||||
self.use_stop_words = False
|
||||
@ -110,7 +109,7 @@ class TRTLLMEvalBase(TemplateLM):
|
||||
tensor_parallel_size=tp,
|
||||
trust_remote_code=trust_remote_code,
|
||||
enable_chunked_prefill=False,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config_params,
|
||||
tokenizer=self.tokenizer,
|
||||
kv_cache_config=trt_kv_cache_config,
|
||||
moe_expert_parallel_size=self.moe_expert_parallel_size,
|
||||
|
||||
@ -140,10 +140,9 @@ python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \
|
||||
--num-requests 24 > /tmp/benchmarking_64k.txt
|
||||
|
||||
cat <<EOF > /tmp/extra-llm-api-config.yml
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes: [1, 4, 8, 12]
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes: [1, 4, 8, 12]
|
||||
EOF
|
||||
|
||||
trtllm-bench -m deepseek-ai/DeepSeek-R1 --model_path ${DS_R1_NVFP4_MODEL_PATH} throughput \
|
||||
@ -168,11 +167,10 @@ python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \
|
||||
--num-requests 4 > /tmp/benchmarking_128k.txt
|
||||
|
||||
cat <<EOF > /tmp/extra-llm-api-config.yml
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes: [1, 2]
|
||||
moe_max_num_tokens: 16384
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes: [1, 2]
|
||||
moe_max_num_tokens: 16384
|
||||
EOF
|
||||
|
||||
trtllm-bench -m deepseek-ai/DeepSeek-R1 --model_path ${DS_R1_NVFP4_MODEL_PATH} throughput \
|
||||
@ -193,8 +191,7 @@ Evaluate the model accuracy using `trtllm-eval`.
|
||||
1. (Optional) Prepare an advanced configuration file:
|
||||
```bash
|
||||
cat >./extra-llm-api-config.yml <<EOF
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: true
|
||||
use_cuda_graph: true
|
||||
enable_attention_dp: true
|
||||
EOF
|
||||
```
|
||||
@ -236,21 +233,20 @@ To serve the model using `trtllm-serve`:
|
||||
|
||||
```bash
|
||||
cat >./extra-llm-api-config.yml <<EOF
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 8
|
||||
- 16
|
||||
- 32
|
||||
- 64
|
||||
- 128
|
||||
- 256
|
||||
- 384
|
||||
print_iter_log: true
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 8
|
||||
- 16
|
||||
- 32
|
||||
- 64
|
||||
- 128
|
||||
- 256
|
||||
- 384
|
||||
print_iter_log: true
|
||||
enable_attention_dp: true
|
||||
EOF
|
||||
|
||||
@ -427,21 +423,20 @@ python3 /path/to/TensorRT-LLM/benchmarks/cpp/prepare_dataset.py \
|
||||
--input-mean=1024 --output-mean=2048 --input-stdev=0 --output-stdev=0 > /tmp/dataset.txt
|
||||
|
||||
cat >/path/to/TensorRT-LLM/extra-llm-api-config.yml <<EOF
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 8
|
||||
- 16
|
||||
- 32
|
||||
- 64
|
||||
- 128
|
||||
- 256
|
||||
- 384
|
||||
print_iter_log: true
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 8
|
||||
- 16
|
||||
- 32
|
||||
- 64
|
||||
- 128
|
||||
- 256
|
||||
- 384
|
||||
print_iter_log: true
|
||||
enable_attention_dp: true
|
||||
EOF
|
||||
```
|
||||
@ -605,9 +600,8 @@ To enable FP8 MLA, modify the `kv_cache_quant_algo` property. The following show
|
||||
Alternatively, configure FP8 MLA through the `kv_cache_dtype` of the PyTorch backend config. An example is to use `--kv_cache_dtype` of `quickstart_advanced.py`. Also, you can edit `extra-llm-api-config.yml` consumed by `--extra_llm_api_options` of `trtllm-serve`, `trtllm-bench` and so on:
|
||||
```yaml
|
||||
# ...
|
||||
pytorch_backend_config:
|
||||
kv_cache_dtype: fp8
|
||||
# ...
|
||||
kv_cache_dtype: fp8
|
||||
# ...
|
||||
```
|
||||
|
||||
### W4AFP8
|
||||
|
||||
@ -653,21 +653,20 @@ To serve the model using `trtllm-serve`:
|
||||
|
||||
```bash
|
||||
cat >./extra-llm-api-config.yml <<EOF
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 8
|
||||
- 16
|
||||
- 32
|
||||
- 64
|
||||
- 128
|
||||
- 256
|
||||
- 384
|
||||
print_iter_log: true
|
||||
use_cuda_graph: true
|
||||
cuda_graph_padding_enabled: true
|
||||
cuda_graph_batch_sizes:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 8
|
||||
- 16
|
||||
- 32
|
||||
- 64
|
||||
- 128
|
||||
- 256
|
||||
- 384
|
||||
print_iter_log: true
|
||||
enable_attention_dp: true
|
||||
EOF
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@ import argparse
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
|
||||
MTPDecodingConfig, NGramDecodingConfig)
|
||||
|
||||
@ -124,19 +123,6 @@ def parse_arguments():
|
||||
|
||||
|
||||
def setup_llm(args):
|
||||
pytorch_config = PyTorchConfig(
|
||||
disable_overlap_scheduler=args.disable_overlap_scheduler,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
attn_backend=args.attention_backend,
|
||||
use_cuda_graph=args.use_cuda_graph,
|
||||
load_format=args.load_format,
|
||||
print_iter_log=args.print_iter_log,
|
||||
enable_iter_perf_stats=args.print_iter_log,
|
||||
torch_compile_enabled=args.use_torch_compile,
|
||||
torch_compile_piecewise_cuda_graph=args.use_piecewise_cuda_graph,
|
||||
moe_backend=args.moe_backend,
|
||||
enable_trtllm_sampler=args.enable_trtllm_sampler)
|
||||
|
||||
kv_cache_config = KvCacheConfig(
|
||||
enable_block_reuse=not args.disable_kv_cache_reuse,
|
||||
free_gpu_memory_fraction=args.kv_cache_fraction,
|
||||
@ -168,13 +154,22 @@ def setup_llm(args):
|
||||
spec_config = None
|
||||
|
||||
llm = LLM(model=args.model_dir,
|
||||
backend='pytorch',
|
||||
disable_overlap_scheduler=args.disable_overlap_scheduler,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
kv_cache_config=kv_cache_config,
|
||||
attn_backend=args.attention_backend,
|
||||
use_cuda_graph=args.use_cuda_graph,
|
||||
load_format=args.load_format,
|
||||
print_iter_log=args.print_iter_log,
|
||||
enable_iter_perf_stats=args.print_iter_log,
|
||||
torch_compile_enabled=args.use_torch_compile,
|
||||
torch_compile_piecewise_cuda_graph=args.use_piecewise_cuda_graph,
|
||||
moe_backend=args.moe_backend,
|
||||
enable_trtllm_sampler=args.enable_trtllm_sampler,
|
||||
max_seq_len=args.max_seq_len,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
tensor_parallel_size=args.tp_size,
|
||||
pipeline_parallel_size=args.pp_size,
|
||||
enable_attention_dp=args.enable_attention_dp,
|
||||
moe_expert_parallel_size=args.moe_ep_size,
|
||||
moe_tensor_parallel_size=args.moe_tp_size,
|
||||
|
||||
@ -8,7 +8,6 @@ import torch
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
|
||||
|
||||
|
||||
@ -66,8 +65,6 @@ def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False):
|
||||
"block_size": args.sa_block_size
|
||||
}
|
||||
|
||||
pytorch_backend_config = PyTorchConfig(
|
||||
attn_backend='FLASHINFER_STAR_ATTENTION')
|
||||
llm = LLM(model=args.model_path,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_input_len=args.max_input_len,
|
||||
@ -77,7 +74,7 @@ def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False):
|
||||
tensor_parallel_size=1,
|
||||
context_parallel_size=args.num_procs,
|
||||
cp_config=cp_config,
|
||||
pytorch_backend_config=pytorch_backend_config)
|
||||
attn_backend='FLASHINFER_STAR_ATTENTION')
|
||||
|
||||
sampling_params = SamplingParams(add_special_tokens=False,
|
||||
max_tokens=args.max_new_tokens)
|
||||
|
||||
@ -352,7 +352,7 @@ class DemoLLM(LLM):
|
||||
**kwargs: Any,
|
||||
):
|
||||
try:
|
||||
self.pytorch_backend_config = kwargs.pop("pytorch_backend_config", None)
|
||||
self.pytorch_backend_config = kwargs.pop("auto_deploy_config", None)
|
||||
self.args = LlmArgs.from_kwargs(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
@ -106,8 +106,6 @@ class AutoDeployConfig(PyTorchConfig):
|
||||
free_mem_ratio: float = 0.8
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# we don't want to loose the default values for model_kwargs unless explicitly set by the
|
||||
# user. They are not preserved by the standard initialization process since they whole dict
|
||||
# gets replaced by the user provided one. We don't want that though.
|
||||
|
||||
@ -1,15 +1,10 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from tensorrt_llm.bindings.executor import ExecutorConfig
|
||||
|
||||
from ...builder import BuildConfig
|
||||
from ...llmapi.llm_args import LoadFormat
|
||||
from ...logger import logger
|
||||
from ...mapping import Mapping
|
||||
from ..model_config import MoeLoadBalancerConfig
|
||||
@ -17,12 +12,6 @@ from ..speculative import SpecConfig
|
||||
from .resource_manager import BaseResourceManager
|
||||
|
||||
|
||||
class LoadFormat(Enum):
|
||||
AUTO = 0
|
||||
# Initialize all weights randomly.
|
||||
DUMMY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class PyTorchConfig:
|
||||
"""
|
||||
@ -95,62 +84,6 @@ class PyTorchConfig:
|
||||
# from the model checkpoint.
|
||||
load_format: Union[str, LoadFormat] = 'auto'
|
||||
|
||||
def _convert_load_format(self) -> None:
|
||||
if isinstance(self.load_format, LoadFormat):
|
||||
return
|
||||
load_format = self.load_format.upper()
|
||||
if load_format not in LoadFormat.__members__:
|
||||
raise NotImplementedError(f"Invalid LoadFormat: {self.load_format}")
|
||||
self.load_format = LoadFormat[load_format]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.torch_compile_enabled and self.torch_compile_piecewise_cuda_graph:
|
||||
assert self.torch_compile_fullgraph, "Fullgraph must be enabled for piecewise CUDA graph."
|
||||
|
||||
if self.cuda_graph_batch_sizes is not None:
|
||||
assert self.cuda_graph_max_batch_size == 0, (
|
||||
"Please don't set both cuda_graph_batch_sizes "
|
||||
"and cuda_graph_max_batch_size.")
|
||||
self.cuda_graph_batch_sizes = sorted(self.cuda_graph_batch_sizes)
|
||||
else:
|
||||
self.cuda_graph_max_batch_size = self.cuda_graph_max_batch_size or 128
|
||||
if self.cuda_graph_padding_enabled:
|
||||
self.cuda_graph_batch_sizes = [1, 2, 4] + [
|
||||
i * 8 for i in range(1, 17)
|
||||
]
|
||||
else:
|
||||
self.cuda_graph_batch_sizes = list(range(1, 32)) + [32, 64, 128]
|
||||
self.cuda_graph_batch_sizes += [
|
||||
2**i for i in range(
|
||||
8, math.floor(math.log(self.cuda_graph_max_batch_size, 2)))
|
||||
]
|
||||
self.cuda_graph_batch_sizes = [
|
||||
size for size in self.cuda_graph_batch_sizes
|
||||
if size <= self.cuda_graph_max_batch_size
|
||||
]
|
||||
if self.cuda_graph_max_batch_size != self.cuda_graph_batch_sizes[
|
||||
-1]:
|
||||
self.cuda_graph_batch_sizes.append(
|
||||
self.cuda_graph_max_batch_size)
|
||||
|
||||
if isinstance(self.moe_load_balancer, str):
|
||||
assert os.path.exists(self.moe_load_balancer)
|
||||
if self.moe_load_balancer.endswith(".json"):
|
||||
with open(self.moe_load_balancer) as f:
|
||||
self.moe_load_balancer = json.load(f)
|
||||
elif self.moe_load_balancer.endswith((".yaml", ".yml")):
|
||||
with open(self.moe_load_balancer) as f:
|
||||
self.moe_load_balancer = yaml.safe_load(f)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported moe load balancer config file: {self.moe_load_balancer}"
|
||||
)
|
||||
if isinstance(self.moe_load_balancer, dict):
|
||||
self.moe_load_balancer = MoeLoadBalancerConfig(
|
||||
**self.moe_load_balancer)
|
||||
|
||||
self._convert_load_format()
|
||||
|
||||
|
||||
EXETENDED_EXECUTOR_CONFIG_FIELDS = [
|
||||
'backend',
|
||||
|
||||
@ -310,11 +310,6 @@ class PyTorchModelEngine(ModelEngine):
|
||||
)
|
||||
|
||||
attn_backend = pytorch_backend_config.attn_backend
|
||||
# _convert_load_format should already be called by
|
||||
# __post_init__, but call it again just in case.
|
||||
# The config object is not a frozen data class, so it's
|
||||
# possible the user changed it after initialization.
|
||||
pytorch_backend_config._convert_load_format()
|
||||
self.model = self._load_model(
|
||||
model_path,
|
||||
mapping=self.mapping,
|
||||
|
||||
@ -340,8 +340,8 @@ def throughput_command(
|
||||
kwargs = kwargs | runtime_config.get_llm_args()
|
||||
kwargs['backend'] = backend
|
||||
|
||||
if "pytorch_backend_config" in kwargs and iteration_log is not None:
|
||||
kwargs["pytorch_backend_config"].enable_iter_perf_stats = True
|
||||
if backend == "pytorch":
|
||||
kwargs["enable_iter_perf_stats"] = True
|
||||
|
||||
if runtime_config.backend == 'pytorch':
|
||||
llm = PyTorchLLM(**kwargs)
|
||||
|
||||
@ -89,10 +89,8 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     if extra_llm_api_options:
         with open(extra_llm_api_options, 'r') as f:
             llm_args_dict = yaml.safe_load(f)
-        if "pytorch_backend_config" in llm_args_dict:
-            if "kv_cache_dtype" in llm_args_dict["pytorch_backend_config"]:
-                kv_cache_dtype = llm_args_dict["pytorch_backend_config"][
-                    "kv_cache_dtype"]
+        if "kv_cache_dtype" in llm_args_dict:
+            kv_cache_dtype = llm_args_dict["kv_cache_dtype"]

         enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
                                                    enable_chunked_prefill)
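
Downstream tooling that still parses user-provided option files can accept both layouts during the transition; the helper below is illustrative only, with key names taken from the hunk above:

```python
# Sketch: read kv_cache_dtype from an extra-llm-api-options YAML file,
# accepting both the old nested layout and the new flattened layout.
import yaml


def read_kv_cache_dtype(path: str, default: str = "auto") -> str:
    with open(path) as f:
        llm_args_dict = yaml.safe_load(f) or {}
    # New flattened layout introduced by this change.
    if "kv_cache_dtype" in llm_args_dict:
        return llm_args_dict["kv_cache_dtype"]
    # Old nested layout, kept only for backward compatibility of user files.
    nested = llm_args_dict.get("pytorch_backend_config", {})
    return nested.get("kv_cache_dtype", default)
```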
@ -81,8 +81,7 @@ class RuntimeConfig(BaseModel):
|
||||
}
|
||||
|
||||
if self.backend in backend_config_map:
|
||||
llm_args["pytorch_backend_config"] = backend_config_map[
|
||||
self.backend]()
|
||||
llm_args.update(backend_config_map[self.backend]())
|
||||
|
||||
return update_llm_args_with_extra_options(llm_args,
|
||||
self.extra_llm_api_options)
|
||||
@ -109,7 +108,7 @@ class PerformanceOptions:
|
||||
return config
|
||||
|
||||
def get_pytorch_perf_config(self) -> PyTorchConfig:
|
||||
return PyTorchConfig(**self.pytorch_config)
|
||||
return self.pytorch_config
|
||||
|
||||
def get_autodeploy_perf_config(self) -> AutoDeployConfig:
|
||||
ad_config = AutoDeployConfig(**self.pytorch_config)
|
||||
|
||||
@ -264,9 +264,8 @@ class ReportUtility:
|
||||
model = self.rt_cfg.model_path or self.rt_cfg.model
|
||||
model_config = ModelConfig.from_pretrained(model,
|
||||
trust_remote_code=True)
|
||||
validate_and_set_kv_cache_quant(
|
||||
model_config,
|
||||
self.kwargs["pytorch_backend_config"].kv_cache_dtype)
|
||||
validate_and_set_kv_cache_quant(model_config,
|
||||
self.kwargs["kv_cache_dtype"])
|
||||
|
||||
stats_dict["engine"] |= {
|
||||
"backend":
|
||||
|
||||
@ -19,7 +19,6 @@ import click
|
||||
import tensorrt_llm.profiler as profiler
|
||||
|
||||
from .._torch.llm import LLM as PyTorchLLM
|
||||
from .._torch.pyexecutor.config import PyTorchConfig
|
||||
from ..evaluate import (GSM8K, MMLU, CnnDailymail, GPQADiamond, GPQAExtended,
|
||||
GPQAMain)
|
||||
from ..llmapi import LLM, BuildConfig, KvCacheConfig
|
||||
@ -113,9 +112,6 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
|
||||
|
||||
if backend == "tensorrt":
|
||||
backend = None
|
||||
pytorch_backend_config = None
|
||||
if backend == "pytorch":
|
||||
pytorch_backend_config = PyTorchConfig()
|
||||
|
||||
llm_args = {
|
||||
"model": model,
|
||||
@ -128,7 +124,6 @@ def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
|
||||
"build_config": build_config,
|
||||
"kv_cache_config": kv_cache_config,
|
||||
"backend": backend,
|
||||
"pytorch_backend_config": pytorch_backend_config,
|
||||
}
|
||||
|
||||
if extra_llm_api_options is not None:
|
||||
|
||||
@ -8,7 +8,6 @@ import yaml
|
||||
from torch.cuda import device_count
|
||||
|
||||
from tensorrt_llm._torch.llm import LLM as PyTorchLLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import (LLM, BuildConfig, CapacitySchedulerPolicy,
|
||||
DynamicBatchConfig, KvCacheConfig,
|
||||
SchedulerConfig)
|
||||
@ -48,7 +47,6 @@ def get_llm_args(model: str,
|
||||
kv_cache_config = KvCacheConfig(
|
||||
free_gpu_memory_fraction=free_gpu_memory_fraction)
|
||||
|
||||
pytorch_backend_config = PyTorchConfig() if backend == "pytorch" else None
|
||||
dynamic_batch_config = DynamicBatchConfig(
|
||||
enable_batch_size_tuning=True,
|
||||
enable_max_num_tokens_tuning=False,
|
||||
@ -74,7 +72,6 @@ def get_llm_args(model: str,
|
||||
"max_seq_len": max_seq_len,
|
||||
"kv_cache_config": kv_cache_config,
|
||||
"backend": backend if backend == "pytorch" else None,
|
||||
"pytorch_backend_config": pytorch_backend_config,
|
||||
"_num_postprocess_workers": num_postprocess_workers,
|
||||
"_postprocess_tokenizer_dir": tokenizer or model,
|
||||
"_reasoning_parser": reasoning_parser,
|
||||
|
||||
@ -8,8 +8,10 @@ import pickle # nosec B403
|
||||
# it is only needed in a single instance the class can be added at runtime
|
||||
# using register_approved_ipc_class.
|
||||
BASE_ZMQ_CLASSES = {
|
||||
"builtins": ["Exception", "ValueError"
|
||||
], # each Exception Error class needs to be added explicitly
|
||||
"builtins": [
|
||||
"Exception", "ValueError", "NotImplementedError", "AttributeError",
|
||||
"AssertionError"
|
||||
], # each Exception Error class needs to be added explicitly
|
||||
"collections": ["OrderedDict"],
|
||||
"datetime": ["timedelta"],
|
||||
"pathlib": ["PosixPath"],
|
||||
@ -57,6 +59,7 @@ BASE_ZMQ_CLASSES = {
|
||||
"KvCacheRetentionConfig.TokenRangeRetentionConfig", "PeftCacheConfig",
|
||||
"SchedulerConfig", "DynamicBatchConfig"
|
||||
],
|
||||
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig"],
|
||||
"tensorrt_llm.builder": ["BuildConfig"],
|
||||
"tensorrt_llm.disaggregated_params": ["DisaggregatedParams"],
|
||||
"tensorrt_llm.executor.postproc_worker": [
|
||||
@ -77,7 +80,7 @@ BASE_ZMQ_CLASSES = {
|
||||
"tensorrt_llm.llmapi.llm_args": [
|
||||
"_ModelFormatKind", "_ParallelConfig", "CalibConfig",
|
||||
"CapacitySchedulerPolicy", "KvCacheConfig", "LookaheadDecodingConfig",
|
||||
"TrtLlmArgs", "SchedulerConfig"
|
||||
"TrtLlmArgs", "SchedulerConfig", "LoadFormat"
|
||||
],
|
||||
"tensorrt_llm.llmapi.mpi_session": ["RemoteTask"],
|
||||
"tensorrt_llm.llmapi.llm_utils":
|
||||
|
||||
@ -112,9 +112,6 @@ class LLM:
|
||||
self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor)
|
||||
|
||||
try:
|
||||
self.pytorch_backend_config = kwargs.pop('pytorch_backend_config',
|
||||
None)
|
||||
|
||||
llm_args_cls = TorchLlmArgs if kwargs.get(
|
||||
'backend', None) == 'pytorch' else TrtLlmArgs
|
||||
|
||||
@ -625,7 +622,8 @@ class LLM:
|
||||
update_executor_config(
|
||||
executor_config,
|
||||
backend=self.args.backend,
|
||||
pytorch_backend_config=self.pytorch_backend_config,
|
||||
pytorch_backend_config=self.args.get_pytorch_backend_config()
|
||||
if self.args.backend == "pytorch" else None,
|
||||
mapping=self.args.parallel_config.to_mapping(),
|
||||
build_config=self.args.build_config
|
||||
if self._on_trt_backend else None,
|
||||
|
||||
@ -1,14 +1,17 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field, fields
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, EnumMeta
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
|
||||
Union)
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, PrivateAttr, validator
|
||||
from pydantic import (BaseModel, Field, PrivateAttr, field_validator,
|
||||
model_validator)
|
||||
from strenum import StrEnum
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
@ -18,6 +21,9 @@ from tensorrt_llm.lora_manager import (LoraConfig,
|
||||
from .._utils import mpi_rank
|
||||
from ..auto_parallel import AutoParallelConfig, infer_cluster_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
|
||||
# yapf: disable
|
||||
# isort: off
|
||||
from ..bindings.executor import (
|
||||
@ -36,6 +42,8 @@ from ..bindings.executor import (
|
||||
PeftCacheConfig as _PeftCacheConfig,
|
||||
SchedulerConfig as _SchedulerConfig) # isort: skip
|
||||
# isort: on
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
# yapf: enable
|
||||
from ..builder import BuildConfig, EngineConfig
|
||||
from ..logger import logger
|
||||
@ -549,7 +557,9 @@ class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror):
|
||||
get_default_lookahead_decoding_verification_set(),
|
||||
description="Number of NGrams in verification branch per step.")
|
||||
|
||||
@validator('max_window_size', 'max_ngram_size', 'max_verification_set_size')
|
||||
@field_validator('max_window_size', 'max_ngram_size',
|
||||
'max_verification_set_size')
|
||||
@classmethod
|
||||
def validate_positive_values(cls, v):
|
||||
if v <= 0:
|
||||
raise ValueError(f"Value must be positive, got {v}")
|
||||
@ -848,8 +858,8 @@ class BaseLlmArgs(BaseModel):
|
||||
default=None, description="Quantization config.")
|
||||
|
||||
# Several options from ExecutorConfig, expanded here for less hierarchy
|
||||
kv_cache_config: Optional[KvCacheConfig] = Field(
|
||||
default=None, description="KV cache config.")
|
||||
kv_cache_config: KvCacheConfig = Field(default_factory=KvCacheConfig,
|
||||
description="KV cache config.")
|
||||
|
||||
enable_chunked_prefill: bool = Field(default=False,
|
||||
description="Enable chunked prefill.")
|
||||
@ -876,8 +886,8 @@ class BaseLlmArgs(BaseModel):
|
||||
peft_cache_config: Optional[PeftCacheConfig] = Field(
|
||||
default=None, description="PEFT cache config.")
|
||||
|
||||
scheduler_config: Optional[SchedulerConfig] = Field(
|
||||
default=None, description="Scheduler config.")
|
||||
scheduler_config: SchedulerConfig = Field(default_factory=SchedulerConfig,
|
||||
description="Scheduler config.")
|
||||
|
||||
cache_transceiver_config: Optional[CacheTransceiverConfig] = Field(
|
||||
default=None, description="Cache transceiver config.")
|
||||
@ -991,10 +1001,6 @@ class BaseLlmArgs(BaseModel):
|
||||
enable_attention_dp=self.enable_attention_dp,
|
||||
cp_config=self.cp_config)
|
||||
|
||||
self.kv_cache_config = self.kv_cache_config or KvCacheConfig()
|
||||
|
||||
self.scheduler_config = self.scheduler_config or SchedulerConfig()
|
||||
|
||||
@classmethod
|
||||
def from_kwargs(cls, **kwargs: Any) -> "BaseLlmArgs":
|
||||
"""Create `LlmArgs` instance from kwargs.
|
||||
@ -1016,8 +1022,7 @@ class BaseLlmArgs(BaseModel):
|
||||
Returns:
|
||||
dict: The dict that contains all fields of the `LlmArgs` instance.
|
||||
"""
|
||||
return dict(
|
||||
(field.name, getattr(self, field.name)) for field in fields(self))
|
||||
return self.model_dump()
|
||||
|
||||
@staticmethod
|
||||
def _maybe_update_config_for_consistency(
|
||||
@ -1444,6 +1449,12 @@ LLMARGS_EXPLICIT_DOCSTRING = generate_api_docs_as_docstring(LlmArgs,
|
||||
indent=' ' * 4)
|
||||
|
||||
|
||||
class LoadFormat(Enum):
|
||||
AUTO = 0
|
||||
# Initialize all weights randomly.
|
||||
DUMMY = 1
|
||||
|
||||
|
||||
class TorchLlmArgs(BaseLlmArgs):
|
||||
|
||||
# Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs
|
||||
@ -1453,12 +1464,275 @@ class TorchLlmArgs(BaseLlmArgs):
|
||||
exclude_from_json=True,
|
||||
json_schema_extra={"type": f"Optional[{get_type_repr(BuildConfig)}]"})
|
||||
|
||||
# PyTorch backend specific configurations
|
||||
|
||||
use_cuda_graph: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
"If true, use CUDA graphs for decoding. CUDA graphs are only created for the batch sizes in cuda_graph_batch_sizes, and are enabled for batches that consist of decoding requests *only* (the reason is that it's hard to capture a single graph with prefill requests since the input shapes are a function of the sequence lengths). Note that each CUDA graph can use up to 200 MB of extra memory."
|
||||
)
|
||||
|
||||
cuda_graph_batch_sizes: Optional[List[int]] = Field(
|
||||
default=None,
|
||||
description="List of batch sizes to create CUDA graphs for.")
|
||||
|
||||
cuda_graph_max_batch_size: int = Field(
|
||||
default=0, description="Maximum batch size for CUDA graphs.")
|
||||
|
||||
cuda_graph_padding_enabled: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
"If true, batches are rounded up to the nearest cuda_graph_batch_size. This is usually a net win for performance."
|
||||
)
|
||||
|
||||
disable_overlap_scheduler: bool = Field(
|
||||
default=False, description="Disable the overlap scheduler.")
|
||||
|
||||
moe_max_num_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
description=
|
||||
"If set, at most moe_max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time. If the number of tokens exceeds moe_max_num_tokens, the input tensors will be split into chunks and a for loop will be used."
|
||||
)
|
||||
|
||||
moe_load_balancer: Optional[Union[object, dict, str]] = Field(
|
||||
default=None,
|
||||
description="Configuration for MoE load balancing.",
|
||||
json_schema_extra={"type": f"Union[MoeLoadBalancerConfig, dict, str]"})
|
||||
|
||||
attn_backend: str = Field(default='TRTLLM',
|
||||
description="Attention backend to use.")
|
||||
|
||||
moe_backend: str = Field(default='CUTLASS',
|
||||
description="MoE backend to use.")
|
||||
|
||||
mixed_sampler: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
"If true, will iterate over sampling_params of each request and use the corresponding sampling strategy, e.g. top-k, top-p, etc."
|
||||
)
|
||||
|
||||
enable_trtllm_sampler: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
"If true, will use the TRTLLM sampler instead of the PyTorch sampler. The TRTLLM sampler has a wide coverage of sampling strategies."
|
||||
)
|
||||
|
||||
kv_cache_dtype: str = Field(default="auto",
|
||||
description="Data type for KV cache.")
|
||||
|
||||
use_kv_cache: bool = Field(default=True,
|
||||
description="Whether to use KV cache.")
|
||||
|
||||
enable_iter_perf_stats: bool = Field(
|
||||
default=False, description="Enable iteration performance statistics.")
|
||||
|
||||
enable_iter_req_stats: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
"If true, enables per request stats per iteration. Must also set enable_iter_perf_stats to true to get request stats."
|
||||
)
|
||||
|
||||
print_iter_log: bool = Field(default=False,
|
||||
description="Print iteration logs.")
|
||||
|
||||
torch_compile_enabled: bool = Field(
|
||||
default=False, description="Enable torch.compile optimization.")
|
||||
|
||||
torch_compile_fullgraph: bool = Field(
|
||||
default=True,
|
||||
description="Enable full graph compilation in torch.compile.")
|
||||
|
||||
torch_compile_inductor_enabled: bool = Field(
|
||||
default=False, description="Enable inductor backend in torch.compile.")
|
||||
|
||||
torch_compile_piecewise_cuda_graph: bool = Field(
|
||||
default=False,
|
||||
description="Enable piecewise CUDA graph in torch.compile.")
|
||||
|
||||
torch_compile_enable_userbuffers: bool = Field(
|
||||
default=True,
|
||||
description=
|
||||
"When torch compile is enabled, userbuffers is enabled by default.")
|
||||
|
||||
autotuner_enabled: bool = Field(
|
||||
default=True,
|
||||
description="Enable autotuner only when torch compile is enabled.")
|
||||
|
||||
enable_layerwise_nvtx_marker: bool = Field(
|
||||
default=False, description="If true, enable layerwise nvtx marker.")
|
||||
|
||||
auto_deploy_config: Optional[object] = Field(
|
||||
default=None,
|
||||
description="Auto deploy config.",
|
||||
exclude_from_json=True,
|
||||
json_schema_extra={"type": f"Optional[AutoDeployConfig]"})
|
||||
|
||||
load_format: Union[str, LoadFormat] = Field(
|
||||
default=LoadFormat.AUTO,
|
||||
description=
|
||||
"How to load the model weights. By default, detect the weight type from the model checkpoint."
|
||||
)
|
||||
|
||||
@field_validator('load_format', mode='before')
@classmethod
def convert_load_format(cls, v):
    if isinstance(v, LoadFormat):
        return v
    load_format = v.upper()
    if load_format not in LoadFormat.__members__:
        raise ValueError(f"Invalid LoadFormat: {v}")
    return LoadFormat[load_format]

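The validator's behaviour can be reproduced with a small standalone sketch (the mirrored enum and function here are illustrative, not additional library API):

```python
# Sketch mirroring convert_load_format: case-insensitive string -> LoadFormat,
# passing enum values through unchanged and rejecting unknown names.
from enum import Enum


class LoadFormat(Enum):  # mirrors the enum defined earlier in this file
    AUTO = 0
    DUMMY = 1            # initialize all weights randomly


def convert_load_format(v):
    if isinstance(v, LoadFormat):
        return v
    name = v.upper()
    if name not in LoadFormat.__members__:
        raise ValueError(f"Invalid LoadFormat: {v}")
    return LoadFormat[name]


assert convert_load_format("dummy") is LoadFormat.DUMMY
assert convert_load_format(LoadFormat.AUTO) is LoadFormat.AUTO
```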
# Extra resource managers to use in addition to the KV cache manager.
|
||||
# Each manager's prepare_resources method is called before the forward pass,
|
||||
# and update_resources() is called after the pass finishes. free_resources()
|
||||
# is called when a request finishes. The KV cache manager is guaranteed to
|
||||
# be invoked after all of these extra managers in all stages.
|
||||
_extra_resource_managers: Dict[str,
|
||||
object] = PrivateAttr(default_factory=dict, )
|
||||
|
||||
@property
|
||||
def extra_resource_managers(self) -> Dict[str, object]:
|
||||
return self._extra_resource_managers
|
||||
|
||||
@extra_resource_managers.setter
|
||||
def extra_resource_managers(self, value: Dict[str, object]) -> None:
|
||||
self._extra_resource_managers = value
|
||||
|
||||
@print_traceback_on_error
|
||||
def model_post_init(self, __context):
|
||||
super().model_post_init(__context)
|
||||
from .._torch.model_config import MoeLoadBalancerConfig
|
||||
|
||||
super().model_post_init(__context)
|
||||
self.model_format = _ModelFormatKind.HF
|
||||
|
||||
if isinstance(self.moe_load_balancer, str):
|
||||
assert os.path.exists(self.moe_load_balancer)
|
||||
if self.moe_load_balancer.endswith(".json"):
|
||||
with open(self.moe_load_balancer) as f:
|
||||
self.moe_load_balancer = json.load(f)
|
||||
elif self.moe_load_balancer.endswith((".yaml", ".yml")):
|
||||
with open(self.moe_load_balancer) as f:
|
||||
self.moe_load_balancer = yaml.safe_load(f)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported moe load balancer config file: {self.moe_load_balancer}"
|
||||
)
|
||||
if isinstance(self.moe_load_balancer, dict):
|
||||
self.moe_load_balancer = MoeLoadBalancerConfig(
|
||||
**self.moe_load_balancer)
|
||||
|
||||
# TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
|
||||
def get_pytorch_backend_config(self) -> "PyTorchConfig":
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
|
||||
# TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
|
||||
# Just a WAR to support the auto_deploy
|
||||
if self.auto_deploy_config is not None:
|
||||
return self.auto_deploy_config
|
||||
|
||||
return PyTorchConfig(
|
||||
extra_resource_managers=self.extra_resource_managers,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
cuda_graph_batch_sizes=self.cuda_graph_batch_sizes,
|
||||
cuda_graph_max_batch_size=self.cuda_graph_max_batch_size,
|
||||
cuda_graph_padding_enabled=self.cuda_graph_padding_enabled,
|
||||
disable_overlap_scheduler=self.disable_overlap_scheduler,
|
||||
moe_max_num_tokens=self.moe_max_num_tokens,
|
||||
moe_load_balancer=self.moe_load_balancer,
|
||||
attn_backend=self.attn_backend,
|
||||
moe_backend=self.moe_backend,
|
||||
mixed_sampler=self.mixed_sampler,
|
||||
enable_trtllm_sampler=self.enable_trtllm_sampler,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
use_kv_cache=self.use_kv_cache,
|
||||
enable_iter_perf_stats=self.enable_iter_perf_stats,
|
||||
enable_iter_req_stats=self.enable_iter_req_stats,
|
||||
print_iter_log=self.print_iter_log,
|
||||
torch_compile_enabled=self.torch_compile_enabled,
|
||||
torch_compile_fullgraph=self.torch_compile_fullgraph,
|
||||
torch_compile_inductor_enabled=self.torch_compile_inductor_enabled,
|
||||
torch_compile_piecewise_cuda_graph=self.
|
||||
torch_compile_piecewise_cuda_graph,
|
||||
torch_compile_enable_userbuffers=self.
|
||||
torch_compile_enable_userbuffers,
|
||||
autotuner_enabled=self.autotuner_enabled,
|
||||
enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
|
||||
load_format=self.load_format)
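
A hedged round-trip sketch of the method above (it assumes `TorchLlmArgs.from_kwargs` accepts the flattened knobs directly, which is what this change introduces, and uses a placeholder model id):

```python
# Sketch: the flattened knobs live on TorchLlmArgs and are converted back into
# the internal PyTorchConfig only when the executor needs it.
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

args = TorchLlmArgs.from_kwargs(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model
    use_cuda_graph=True,
    attn_backend="TRTLLM",
    kv_cache_dtype="auto",
)
pytorch_config = args.get_pytorch_backend_config()
print(pytorch_config.use_cuda_graph, pytorch_config.attn_backend)
```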
@field_validator('cuda_graph_max_batch_size')
|
||||
@classmethod
|
||||
def validate_cuda_graph_max_batch_size(cls, v):
|
||||
"""Validate cuda_graph_max_batch_size is non-negative."""
|
||||
if v < 0:
|
||||
raise ValueError("cuda_graph_max_batch_size must be non-negative")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def _generate_cuda_graph_batch_sizes(max_batch_size: int,
|
||||
padding_enabled: bool) -> List[int]:
|
||||
"""Generate a list of batch sizes for CUDA graphs.
|
||||
|
||||
Args:
|
||||
max_batch_size: Maximum batch size to generate up to
|
||||
padding_enabled: Whether padding is enabled, which affects the batch size distribution
|
||||
|
||||
Returns:
|
||||
List of batch sizes to create CUDA graphs for
|
||||
"""
|
||||
if padding_enabled:
|
||||
batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
|
||||
else:
|
||||
batch_sizes = list(range(1, 32)) + [32, 64, 128]
|
||||
|
||||
# Add powers of 2 up to max_batch_size
|
||||
batch_sizes += [
|
||||
2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
|
||||
]
|
||||
|
||||
# Filter and sort batch sizes
|
||||
batch_sizes = sorted(
|
||||
[size for size in batch_sizes if size <= max_batch_size])
|
||||
|
||||
# Add max_batch_size if not already included
|
||||
if max_batch_size != batch_sizes[-1]:
|
||||
batch_sizes.append(max_batch_size)
|
||||
|
||||
return batch_sizes
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
|
||||
"""Validate CUDA graph configuration.
|
||||
|
||||
Ensures that:
|
||||
1. If cuda_graph_batch_sizes is provided, cuda_graph_max_batch_size must be 0
|
||||
2. If cuda_graph_batch_sizes is not provided, it is generated based on cuda_graph_max_batch_size
|
||||
3. If both are provided, cuda_graph_batch_sizes must match the generated values
|
||||
"""
|
||||
if self.cuda_graph_batch_sizes is not None:
|
||||
self.cuda_graph_batch_sizes = sorted(self.cuda_graph_batch_sizes)
|
||||
if self.cuda_graph_max_batch_size != 0:
|
||||
if self.cuda_graph_batch_sizes != self._generate_cuda_graph_batch_sizes(
|
||||
self.cuda_graph_max_batch_size,
|
||||
self.cuda_graph_padding_enabled):
|
||||
raise ValueError(
|
||||
"Please don't set both cuda_graph_batch_sizes "
|
||||
"and cuda_graph_max_batch_size.\n"
|
||||
f"cuda_graph_batch_sizes: {self.cuda_graph_batch_sizes}, "
|
||||
f"cuda_graph_max_batch_size: {self.cuda_graph_max_batch_size}"
|
||||
)
|
||||
else:
|
||||
self.cuda_graph_max_batch_size = max(
|
||||
self.cuda_graph_batch_sizes)
|
||||
else:
|
||||
max_batch_size = self.cuda_graph_max_batch_size or 128
|
||||
generated_sizes = self._generate_cuda_graph_batch_sizes(
|
||||
max_batch_size, self.cuda_graph_padding_enabled)
|
||||
self.cuda_graph_batch_sizes = generated_sizes
|
||||
self.cuda_graph_max_batch_size = max_batch_size
|
||||
|
||||
return self
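
To see what the defaults produce, the helper can be exercised standalone; the sketch below copies the logic of `_generate_cuda_graph_batch_sizes` above (the printed list is what padding-enabled generation yields for the default maximum of 128):

```python
# Standalone sketch of _generate_cuda_graph_batch_sizes, copied from the
# method above, to show what the default configuration produces.
import math
from typing import List


def generate_cuda_graph_batch_sizes(max_batch_size: int,
                                    padding_enabled: bool) -> List[int]:
    if padding_enabled:
        batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
    else:
        batch_sizes = list(range(1, 32)) + [32, 64, 128]

    # Add powers of two starting at 256 (a no-op when max_batch_size <= 256).
    batch_sizes += [
        2**i for i in range(8, math.floor(math.log(max_batch_size, 2)))
    ]

    # Keep only sizes within the cap, sorted, and make sure the cap itself
    # is always the last entry.
    batch_sizes = sorted(size for size in batch_sizes
                         if size <= max_batch_size)
    if max_batch_size != batch_sizes[-1]:
        batch_sizes.append(max_batch_size)
    return batch_sizes


# With the default maximum of 128 and padding enabled this prints
# [1, 2, 4, 8, 16, 24, 32, ..., 120, 128].
print(generate_cuda_graph_batch_sizes(128, padding_enabled=True))
```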
def update_llm_args_with_extra_dict(
|
||||
llm_args: Dict,
|
||||
|
||||
@ -507,6 +507,9 @@ def generate_api_docs_as_docstring(model: Type[BaseModel],
|
||||
elif field_name in type_hints:
|
||||
type_str = str(type_hints[field_name])
|
||||
type_str = type_str.replace("typing.", "")
|
||||
# Extract just the class name from full class path
|
||||
if "<class '" in type_str:
|
||||
type_str = type_str[8:-2]
|
||||
else:
|
||||
type_str = field_type or 'Any'
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import Callable
|
||||
import openai
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.executor import GenerationExecutor
|
||||
from tensorrt_llm.llmapi.llm import LLM
|
||||
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
|
||||
@ -138,10 +137,6 @@ class TRTLLMWorker(Worker):
|
||||
kv_cache_free_gpu_memory_fraction: float = 0.9,
|
||||
disable_overlap_scheduler: bool = False,
|
||||
):
|
||||
pytorch_backend_config = PyTorchConfig(
|
||||
mixed_sampler=True,
|
||||
disable_overlap_scheduler=disable_overlap_scheduler,
|
||||
)
|
||||
kv_cache_config = KvCacheConfig(
|
||||
free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction, )
|
||||
|
||||
@ -157,7 +152,8 @@ class TRTLLMWorker(Worker):
|
||||
llm = LLM(model_dir,
|
||||
backend=backend,
|
||||
tokenizer=tokenizer,
|
||||
pytorch_backend_config=pytorch_backend_config,
|
||||
mixed_sampler=True,
|
||||
disable_overlap_scheduler=disable_overlap_scheduler,
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_batch_size=max_batch_size,
|
||||
max_num_tokens=max_num_tokens)
|
||||
|
||||
@ -181,15 +181,9 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
@pytest.mark.skip_device_not_contain(["H100", "H200"])
|
||||
@pytest.mark.parametrize("disable_overlap_scheduler", [False, True])
|
||||
def test_auto_dtype(self, disable_overlap_scheduler):
|
||||
ctx_server_config = {
|
||||
"pytorch_backend_config": {
|
||||
"disable_overlap_scheduler": True
|
||||
}
|
||||
}
|
||||
ctx_server_config = {"disable_overlap_scheduler": True}
|
||||
gen_server_config = {
|
||||
"pytorch_backend_config": {
|
||||
"disable_overlap_scheduler": disable_overlap_scheduler
|
||||
}
|
||||
"disable_overlap_scheduler": disable_overlap_scheduler
|
||||
}
|
||||
disaggregated_server_config = {
|
||||
"hostname": "localhost",
|
||||
@ -220,16 +214,8 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
|
||||
@pytest.mark.parametrize("overlap_scheduler", [False, True])
|
||||
def test_auto_dtype(self, overlap_scheduler):
|
||||
pytest.skip("https://nvbugs/5297821")
|
||||
ctx_server_config = {
|
||||
"pytorch_backend_config": {
|
||||
"disable_overlap_scheduler": True
|
||||
}
|
||||
}
|
||||
gen_server_config = {
|
||||
"pytorch_backend_config": {
|
||||
"disable_overlap_scheduler": overlap_scheduler
|
||||
}
|
||||
}
|
||||
ctx_server_config = {"disable_overlap_scheduler": True}
|
||||
gen_server_config = {"disable_overlap_scheduler": overlap_scheduler}
|
||||
disaggregated_server_config = {
|
||||
"hostname": "localhost",
|
||||
"port": 8000,
|
||||
|
||||
@ -15,8 +15,7 @@
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import (MoeLoadBalancerConfig,
|
||||
PyTorchConfig)
|
||||
from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig
|
||||
from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig, SamplingParams
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from tensorrt_llm.quantization import QuantAlgo
|
||||
@ -59,11 +58,11 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
@pytest.mark.skip_less_device_memory(32000)
|
||||
@parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
|
||||
def test_chunked_prefill(self, attn_backend):
|
||||
pytorch_config = PyTorchConfig(attn_backend=attn_backend, )
|
||||
pytorch_config = dict(attn_backend=attn_backend, )
|
||||
llm = LLM(self.MODEL_PATH,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_tokens=64,
|
||||
pytorch_backend_config=pytorch_config)
|
||||
**pytorch_config)
|
||||
with llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
@ -74,7 +73,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
@parametrize_with_ids("torch_compile", [False, True])
|
||||
@parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
|
||||
def test_bfloat16(self, attn_backend, torch_compile):
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True,
|
||||
cuda_graph_padding_enabled=torch_compile,
|
||||
@ -82,7 +81,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
attn_backend=attn_backend,
|
||||
disable_overlap_scheduler=torch_compile,
|
||||
)
|
||||
llm = LLM(self.MODEL_PATH, pytorch_backend_config=pytorch_config)
|
||||
llm = LLM(self.MODEL_PATH, **pytorch_config)
|
||||
with llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
@ -100,7 +99,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
"Pipeline parallel with torch.compile is not supported yet.\n"
|
||||
"Issue: Unfusing flashinfer_fused_add_rmsnorm causes outputs to be "
|
||||
"discarded at graph breaks.")
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True,
|
||||
cuda_graph_padding_enabled=torch_compile,
|
||||
@ -111,7 +110,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
llm = LLM(self.MODEL_PATH,
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
pytorch_backend_config=pytorch_config)
|
||||
**pytorch_config)
|
||||
with llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
@ -124,7 +123,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
@parametrize_with_ids("fp8kv", [False, True])
|
||||
def test_fp8(self, fp8kv, attn_backend, torch_compile):
|
||||
quant_config = QuantConfig(QuantAlgo.FP8)
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True,
|
||||
cuda_graph_padding_enabled=torch_compile,
|
||||
@ -134,11 +133,11 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
)
|
||||
if fp8kv:
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
pytorch_config.kv_cache_dtype = "fp8"
|
||||
pytorch_config["kv_cache_dtype"] = "fp8"
|
||||
llm = LLM(
|
||||
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
|
||||
quant_config=quant_config,
|
||||
pytorch_backend_config=pytorch_config)
|
||||
**pytorch_config)
|
||||
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
|
||||
if fp8kv:
|
||||
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
|
||||
@ -162,7 +161,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
"Issue: Unfusing flashinfer_fused_add_rmsnorm causes outputs to be "
|
||||
"discarded at graph breaks.")
|
||||
quant_config = QuantConfig(QuantAlgo.FP8)
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True,
|
||||
cuda_graph_padding_enabled=torch_compile,
|
||||
@ -172,13 +171,13 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
)
|
||||
if fp8kv:
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
pytorch_config.kv_cache_dtype = "fp8"
|
||||
pytorch_config["kv_cache_dtype"] = "fp8"
|
||||
llm = LLM(
|
||||
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
quant_config=quant_config,
|
||||
pytorch_backend_config=pytorch_config)
|
||||
**pytorch_config)
|
||||
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
|
||||
if fp8kv:
|
||||
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
|
||||
@ -192,8 +191,8 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
|
||||
@skip_pre_hopper
|
||||
def test_fp8_llm_sampler(self):
|
||||
model_path = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
|
||||
pytorch_config = PyTorchConfig(enable_trtllm_sampler=True)
|
||||
llm = LLM(model_path, pytorch_backend_config=pytorch_config)
|
||||
pytorch_config = dict(enable_trtllm_sampler=True)
|
||||
llm = LLM(model_path, **pytorch_config)
|
||||
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
@ -453,17 +452,18 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
pytest.skip("https://nvbugs/5252559")
|
||||
# OOM on H100 with default free_gpu_memory_fraction=0.9
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph,
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True)
|
||||
torch_compile_fullgraph=True,
|
||||
)
|
||||
mtp_config = None
|
||||
if mtp_nextn > 0:
|
||||
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
|
||||
llm = LLM(self.MODEL_PATH,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
speculative_config=mtp_config)
|
||||
with llm:
|
||||
@ -497,11 +497,12 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
pytest.skip("PP with torch.compile is not supported yet.")
|
||||
# OOM on H100 with default free_gpu_memory_fraction=0.9
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph,
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True)
|
||||
torch_compile_fullgraph=True,
|
||||
)
|
||||
mtp_config = None
|
||||
if mtp_nextn > 0:
|
||||
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
|
||||
@ -510,7 +511,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
speculative_config=mtp_config)
|
||||
with llm:
|
||||
@ -539,17 +540,18 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
pytest.skip("https://nvbugs/5252559")
|
||||
# OOM on H100 with default free_gpu_memory_fraction=0.9
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph,
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True)
|
||||
torch_compile_fullgraph=True,
|
||||
)
|
||||
|
||||
quant_config = QuantConfig()
|
||||
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
|
||||
if fp8kv:
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
pytorch_config.kv_cache_dtype = "fp8"
|
||||
pytorch_config["kv_cache_dtype"] = "fp8"
|
||||
|
||||
mtp_config = None
|
||||
if mtp_nextn > 0:
|
||||
@ -557,7 +559,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
|
||||
llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
quant_config=quant_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
speculative_config=mtp_config)
|
||||
@ -578,13 +580,15 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
def test_fp8_block_scales_cuda_graph_padding(self):
|
||||
# OOM on H100 with default free_gpu_memory_fraction=0.9
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
|
||||
pytorch_config = PyTorchConfig(disable_overlap_scheduler=False,
|
||||
use_cuda_graph=True,
|
||||
cuda_graph_max_batch_size=512,
|
||||
cuda_graph_padding_enabled=True)
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=False,
|
||||
use_cuda_graph=True,
|
||||
cuda_graph_max_batch_size=512,
|
||||
cuda_graph_padding_enabled=True,
|
||||
)
|
||||
llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config)
|
||||
**pytorch_config)
|
||||
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
|
||||
with llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
@ -620,17 +624,18 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
pytest.skip("PP with torch.compile is not supported yet.")
|
||||
# OOM on H100 with default free_gpu_memory_fraction=0.9
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph,
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True)
|
||||
torch_compile_fullgraph=True,
|
||||
)
|
||||
|
||||
quant_config = QuantConfig()
|
||||
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
|
||||
if fp8kv:
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
pytorch_config.kv_cache_dtype = "fp8"
|
||||
pytorch_config["kv_cache_dtype"] = "fp8"
|
||||
|
||||
mtp_config = None
|
||||
if mtp_nextn > 0:
|
||||
@ -641,7 +646,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
quant_config=quant_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
speculative_config=mtp_config)
|
||||
@ -676,13 +681,13 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
num_slots=num_slots,
|
||||
initial_global_assignments=initial_global_assignments,
|
||||
layer_updates_per_iter=0)
|
||||
pytorch_config = PyTorchConfig(use_cuda_graph=True,
|
||||
pytorch_backend_options = dict(use_cuda_graph=True,
|
||||
moe_load_balancer=eplb_config)
|
||||
llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
|
||||
tensor_parallel_size=4,
|
||||
moe_expert_parallel_size=4,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_backend_options,
|
||||
enable_attention_dp=True)
|
||||
with llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
@ -705,21 +710,22 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
pytest.skip("https://nvbugs/5252559")
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
|
||||
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph,
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True)
|
||||
torch_compile_fullgraph=True,
|
||||
)
|
||||
|
||||
quant_config = QuantConfig()
|
||||
quant_config.quant_algo = QuantAlgo.NVFP4
|
||||
if fp8kv:
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
pytorch_config.kv_cache_dtype = "fp8"
|
||||
pytorch_config["kv_cache_dtype"] = "fp8"
|
||||
|
||||
llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
quant_config=quant_config,
|
||||
enable_attention_dp=attention_dp)
|
||||
|
||||
@ -756,24 +762,25 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
if torch_compile and pp_size > 1:
|
||||
pytest.skip("PP with torch.compile is not supported yet.")
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph,
|
||||
torch_compile_enabled=torch_compile,
|
||||
torch_compile_fullgraph=True)
|
||||
torch_compile_fullgraph=True,
|
||||
)
|
||||
|
||||
quant_config = QuantConfig()
|
||||
quant_config.quant_algo = QuantAlgo.NVFP4
|
||||
if fp8kv:
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
pytorch_config.kv_cache_dtype = "fp8"
|
||||
pytorch_config["kv_cache_dtype"] = "fp8"
|
||||
|
||||
llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
quant_config=quant_config,
|
||||
enable_attention_dp=attention_dp)
|
||||
|
||||
@ -815,9 +822,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
|
||||
enable_block_reuse=False)
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
use_cuda_graph=cuda_graph,
|
||||
)
|
||||
mtp_config = None
|
||||
if mtp_nextn > 0:
|
||||
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
|
||||
@ -833,11 +841,11 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
|
||||
quant_config.quant_algo = QuantAlgo.NVFP4
|
||||
if fp8kv:
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
pytorch_config.kv_cache_dtype = "fp8"
|
||||
pytorch_config["kv_cache_dtype"] = "fp8"
|
||||
|
||||
llm = LLM(model_path,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
quant_config=quant_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
speculative_config=mtp_config)
|
||||
@ -883,16 +891,15 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
pytest.skip("https://nvbugs/5302441")
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
|
||||
pytorch_config = PyTorchConfig(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph,
|
||||
moe_backend=moe_backend)
|
||||
pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph,
|
||||
moe_backend=moe_backend)
|
||||
|
||||
quant_config = QuantConfig()
|
||||
quant_config.quant_algo = QuantAlgo.NVFP4
|
||||
if fp8kv:
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
pytorch_config.kv_cache_dtype = "fp8"
|
||||
pytorch_config["kv_cache_dtype"] = "fp8"
|
||||
|
||||
mtp_config = None
|
||||
if mtp_nextn > 0:
|
||||
@ -903,12 +910,12 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
quant_config=quant_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
speculative_config=mtp_config)
|
||||
|
||||
assert llm.pytorch_backend_config.moe_backend == moe_backend
|
||||
assert llm.args.moe_backend == moe_backend
|
||||
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
|
||||
if fp8kv:
|
||||
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
|
||||
@ -933,15 +940,16 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
attention_dp, cuda_graph, overlap_scheduler,
|
||||
max_batch_size):
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
use_cuda_graph=cuda_graph,
|
||||
)
|
||||
|
||||
quant_config = QuantConfig()
|
||||
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
|
||||
if fp8kv:
|
||||
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
|
||||
pytorch_config.kv_cache_dtype = "fp8"
|
||||
pytorch_config["kv_cache_dtype"] = "fp8"
|
||||
|
||||
mtp_config = None
|
||||
if mtp_nextn > 0:
|
||||
@ -952,7 +960,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
quant_config=quant_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
speculative_config=mtp_config)
|
||||
@ -986,12 +994,12 @@ class TestNemotronNas(LlmapiAccuracyTestHarness):
|
||||
@pytest.mark.skip_less_device(8)
|
||||
def test_auto_dtype_tp8(self):
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
|
||||
pytorch_config = PyTorchConfig()
|
||||
pytorch_config = dict()
|
||||
|
||||
with LLM(self.MODEL_PATH,
|
||||
tensor_parallel_size=8,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pytorch_backend_config=pytorch_config) as llm:
|
||||
**pytorch_config) as llm:
|
||||
|
||||
task = CnnDailymail(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
@ -1134,15 +1142,14 @@ class TestQwen3_8B(LlmapiAccuracyTestHarness):
|
||||
ids=["latency"])
|
||||
def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
|
||||
cuda_graph, overlap_scheduler):
|
||||
pytorch_config = PyTorchConfig(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
|
||||
llm = LLM(f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp)
|
||||
with llm:
|
||||
task = CnnDailymail(self.MODEL_NAME)
|
||||
@ -1162,15 +1169,14 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
|
||||
ids=["latency"])
|
||||
def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
|
||||
cuda_graph, overlap_scheduler):
|
||||
pytorch_config = PyTorchConfig(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
|
||||
llm = LLM(f"{llm_models_root()}/Qwen3/Qwen3-30B-A3B-FP8",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp)
|
||||
with llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
@ -1185,16 +1191,15 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
|
||||
ids=["latency"])
|
||||
def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
|
||||
overlap_scheduler):
|
||||
pytorch_config = PyTorchConfig(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
|
||||
llm = LLM(
|
||||
f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_fp8_hf",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp)
|
||||
with llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
@ -1225,7 +1230,7 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
|
||||
overlap_scheduler,
|
||||
moe_backend,
|
||||
):
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph,
|
||||
moe_backend=moe_backend,
|
||||
@ -1236,7 +1241,7 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp)
|
||||
with llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
@ -1255,15 +1260,14 @@ class TestQwen3_32B(LlmapiAccuracyTestHarness):
|
||||
ids=["latency"])
|
||||
def test_fp8_block_scales(self, tp_size, pp_size, ep_size, attention_dp,
|
||||
cuda_graph, overlap_scheduler):
|
||||
pytorch_config = PyTorchConfig(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
|
||||
llm = LLM(f"{llm_models_root()}/Qwen3/Qwen3-32B-FP8",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp)
|
||||
with llm:
|
||||
task = CnnDailymail(self.MODEL_NAME)
|
||||
@ -1282,9 +1286,8 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
|
||||
ids=["latency", "throughput_latency"])
|
||||
def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
|
||||
overlap_scheduler):
|
||||
pytorch_config = PyTorchConfig(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
|
||||
llm = LLM(
|
||||
@ -1292,7 +1295,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
kv_cache_config=kv_cache_config)
|
||||
with llm:
|
||||
@ -1308,16 +1311,15 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
|
||||
ids=["latency", "throughput_latency"])
|
||||
def test_nvfp4(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
|
||||
overlap_scheduler):
|
||||
pytorch_config = PyTorchConfig(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
|
||||
use_cuda_graph=cuda_graph)
|
||||
|
||||
llm = LLM(
|
||||
f"{llm_models_root()}/Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp)
|
||||
with llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
|
||||
@ -3,10 +3,9 @@ hostname: localhost
|
||||
port: 8000
|
||||
backend: "pytorch"
|
||||
free_gpu_memory_fraction: 0.1
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
autotuner_enabled: False
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
autotuner_enabled: False
|
||||
context_servers:
|
||||
num_instances: 2
|
||||
router:
|
||||
|
||||
@ -3,10 +3,9 @@ port: 8000
|
||||
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
free_gpu_memory_fraction: 0.15
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
autotuner_enabled: False
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
autotuner_enabled: False
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 1
|
||||
|
||||
@ -5,10 +5,9 @@ backend: "pytorch"
|
||||
free_gpu_memory_fraction: 0.15
|
||||
conditional_disagg_config:
|
||||
max_local_prefill_length: 100
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
autotuner_enabled: False
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
autotuner_enabled: False
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 1
|
||||
|
||||
@ -3,9 +3,8 @@ port: 8000
|
||||
model: DeepSeek-V3-Lite/fp8
|
||||
free_gpu_memory_fraction: 0.1
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 1
|
||||
|
||||
@ -3,9 +3,8 @@ port: 8000
|
||||
model: DeepSeek-V3-Lite/fp8
|
||||
free_gpu_memory_fraction: 0.1
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
speculative_config:
|
||||
decoding_type: MTP
|
||||
num_nextn_predict_layers: 1
|
||||
|
||||
@ -11,9 +11,8 @@ context_servers:
|
||||
tensor_parallel_size: 1
|
||||
pipeline_parallel_size: 1
|
||||
enable_attention_dp: true
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
urls:
|
||||
- "localhost:8001"
|
||||
generation_servers:
|
||||
@ -21,8 +20,7 @@ generation_servers:
|
||||
tensor_parallel_size: 1
|
||||
pipeline_parallel_size: 1
|
||||
enable_attention_dp: true
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: False
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: False
|
||||
urls:
|
||||
- "localhost:8002"
|
||||
|
||||
@ -3,9 +3,8 @@ port: 8000
|
||||
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
free_gpu_memory_fraction: 0.25
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 2
|
||||
|
||||
@ -3,9 +3,8 @@ port: 8000
|
||||
model: DeepSeek-V3-Lite/fp8
|
||||
free_gpu_memory_fraction: 0.25
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 2
|
||||
|
||||
@ -3,9 +3,8 @@ port: 8000
|
||||
model: DeepSeek-V3-Lite/fp8
|
||||
free_gpu_memory_fraction: 0.25
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 2
|
||||
|
||||
@ -3,9 +3,8 @@ port: 8000
|
||||
model: DeepSeek-V3-Lite/fp8
|
||||
free_gpu_memory_fraction: 0.25
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 2
|
||||
|
||||
@ -3,9 +3,8 @@ port: 8000
|
||||
model: DeepSeek-V3-Lite/fp8
|
||||
free_gpu_memory_fraction: 0.25
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
speculative_config:
|
||||
decoding_type: MTP
|
||||
num_nextn_predict_layers: 1
|
||||
|
||||
@ -8,9 +8,8 @@ context_servers:
|
||||
tensor_parallel_size: 2
|
||||
pipeline_parallel_size: 1
|
||||
enable_attention_dp: True
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
urls:
|
||||
- "localhost:8001"
|
||||
generation_servers:
|
||||
@ -18,8 +17,7 @@ generation_servers:
|
||||
tensor_parallel_size: 2
|
||||
pipeline_parallel_size: 1
|
||||
enable_attention_dp: True
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: False
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: False
|
||||
urls:
|
||||
- "localhost:8002"
|
||||
|
||||
@ -8,9 +8,8 @@ context_servers:
|
||||
tensor_parallel_size: 2
|
||||
pipeline_parallel_size: 1
|
||||
enable_attention_dp: true
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
urls:
|
||||
- "localhost:8001"
|
||||
generation_servers:
|
||||
@ -18,8 +17,7 @@ generation_servers:
|
||||
tensor_parallel_size: 2
|
||||
pipeline_parallel_size: 1
|
||||
enable_attention_dp: true
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: True
|
||||
disable_overlap_scheduler: False
|
||||
use_cuda_graph: True
|
||||
disable_overlap_scheduler: False
|
||||
urls:
|
||||
- "localhost:8002"
|
||||
|
||||
@ -7,17 +7,15 @@ context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 2
|
||||
pipeline_parallel_size: 1
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
urls:
|
||||
- "localhost:8001"
|
||||
generation_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 2
|
||||
pipeline_parallel_size: 1
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: True
|
||||
disable_overlap_scheduler: False
|
||||
use_cuda_graph: True
|
||||
disable_overlap_scheduler: False
|
||||
urls:
|
||||
- "localhost:8002"
|
||||
|
||||
@ -12,10 +12,9 @@ context_servers:
|
||||
kv_cache_config:
|
||||
free_gpu_memory_fraction: 0.2
|
||||
enable_partial_reuse: False
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: True
|
||||
cuda_graph_batch_sizes: [1,3000]
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: True
|
||||
cuda_graph_batch_sizes: [1,3000]
|
||||
disable_overlap_scheduler: True
|
||||
urls:
|
||||
- "localhost:8001"
|
||||
generation_servers:
|
||||
@ -28,10 +27,9 @@ generation_servers:
|
||||
kv_cache_config:
|
||||
free_gpu_memory_fraction: 0.2
|
||||
enable_partial_reuse: False
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: True
|
||||
disable_overlap_scheduler: True
|
||||
cuda_graph_padding_enabled: True
|
||||
cuda_graph_batch_sizes: [1,4,8,16,24,32]
|
||||
use_cuda_graph: True
|
||||
disable_overlap_scheduler: True
|
||||
cuda_graph_padding_enabled: True
|
||||
cuda_graph_batch_sizes: [1,4,8,16,24,32]
|
||||
urls:
|
||||
- "localhost:8002"
|
||||
|
||||
@ -12,8 +12,7 @@ generation_servers:
|
||||
free_gpu_memory_fraction: 0.2
|
||||
enable_block_reuse: False
|
||||
enable_partial_reuse: False
|
||||
pytorch_backend_config:
|
||||
print_iter_log: True
|
||||
print_iter_log: True
|
||||
urls:
|
||||
- "localhost:8002"
|
||||
- "localhost:8003"
|
||||
|
||||
@ -16,9 +16,8 @@ context_servers:
|
||||
kv_cache_config:
|
||||
free_gpu_memory_fraction: 0.15
|
||||
enable_partial_reuse: False
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
urls:
|
||||
- "localhost:8001"
|
||||
- "localhost:8002"
|
||||
@ -35,9 +34,8 @@ generation_servers:
|
||||
kv_cache_config:
|
||||
free_gpu_memory_fraction: 0.15
|
||||
enable_partial_reuse: False
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: False
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: False
|
||||
urls:
|
||||
- "localhost:8003"
|
||||
- "localhost:8004"
|
||||
|
||||
@ -3,9 +3,8 @@ port: 8000
|
||||
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
free_gpu_memory_fraction: 0.25
|
||||
backend: "pytorch"
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
context_servers:
|
||||
num_instances: 1
|
||||
tensor_parallel_size: 1
|
||||
|
||||
@ -13,9 +13,8 @@ context_servers:
|
||||
kv_cache_config:
|
||||
free_gpu_memory_fraction: 0.2
|
||||
enable_partial_reuse: False
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: True
|
||||
urls:
|
||||
- "localhost:8001"
|
||||
generation_servers:
|
||||
@ -28,8 +27,7 @@ generation_servers:
|
||||
kv_cache_config:
|
||||
free_gpu_memory_fraction: 0.2
|
||||
enable_partial_reuse: False
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: False
|
||||
use_cuda_graph: False
|
||||
disable_overlap_scheduler: False
|
||||
urls:
|
||||
- "localhost:8002"
|
||||
|
||||
@ -11,7 +11,6 @@ from mpi4py.futures import MPIPoolExecutor
|
||||
|
||||
from tensorrt_llm import DisaggregatedParams, SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm._utils import set_mpi_comm
|
||||
from tensorrt_llm.llmapi import KvCacheConfig, MpiCommSession
|
||||
|
||||
@ -40,6 +39,7 @@ def model_path(model_name):
|
||||
|
||||
|
||||
async def run_worker(kv_cache_config, pytorch_config, model_name, rank):
|
||||
assert isinstance(pytorch_config, dict)
|
||||
print(f"Running worker {rank}")
|
||||
port_name = MPI.Lookup_name('my_port')
|
||||
intercomm = MPI.COMM_WORLD.Connect(port_name)
|
||||
@ -53,7 +53,7 @@ async def run_worker(kv_cache_config, pytorch_config, model_name, rank):
|
||||
auto_parallel=False,
|
||||
model=model_name,
|
||||
enable_chunked_prefill=False,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
_mpi_session=mpi_session,
|
||||
kv_cache_config=kv_cache_config)
|
||||
print(f"LLM created")
|
||||
@ -110,15 +110,15 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
|
||||
|
||||
# Context worker
|
||||
worker_pytorch_configs.append(
|
||||
PyTorchConfig(disable_overlap_scheduler=True,
|
||||
kv_cache_dtype="auto",
|
||||
use_cuda_graph=enable_cuda_graph))
|
||||
dict(disable_overlap_scheduler=True,
|
||||
kv_cache_dtype="auto",
|
||||
use_cuda_graph=enable_cuda_graph))
|
||||
|
||||
# Generation worker
|
||||
worker_pytorch_configs.append(
|
||||
PyTorchConfig(disable_overlap_scheduler=not generation_overlap,
|
||||
kv_cache_dtype="auto",
|
||||
use_cuda_graph=enable_cuda_graph))
|
||||
dict(disable_overlap_scheduler=not generation_overlap,
|
||||
kv_cache_dtype="auto",
|
||||
use_cuda_graph=enable_cuda_graph))
|
||||
|
||||
kv_cache_configs = [KvCacheConfig(max_tokens=2048 * 8) for _ in range(2)]
|
||||
model_names = [model_path(model) for _ in range(2)]
|
||||
@ -231,15 +231,15 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
|
||||
|
||||
# Context worker
|
||||
worker_pytorch_configs.append(
|
||||
PyTorchConfig(disable_overlap_scheduler=True,
|
||||
kv_cache_dtype="auto",
|
||||
use_cuda_graph=enable_cuda_graph))
|
||||
dict(disable_overlap_scheduler=True,
|
||||
kv_cache_dtype="auto",
|
||||
use_cuda_graph=enable_cuda_graph))
|
||||
|
||||
# Generation worker
|
||||
worker_pytorch_configs.append(
|
||||
PyTorchConfig(disable_overlap_scheduler=not generation_overlap,
|
||||
kv_cache_dtype="auto",
|
||||
use_cuda_graph=enable_cuda_graph))
|
||||
dict(disable_overlap_scheduler=not generation_overlap,
|
||||
kv_cache_dtype="auto",
|
||||
use_cuda_graph=enable_cuda_graph))
|
||||
|
||||
kv_cache_configs = [
|
||||
KvCacheConfig(max_tokens=128, enable_block_reuse=False)
|
||||
|
||||
@ -28,18 +28,14 @@ def get_model_yaml_config(model_label: str, input_lens: list[int]) -> dict:
|
||||
"""
|
||||
base_config = {
|
||||
'enable_attention_dp': True,
|
||||
'pytorch_backend_config': {
|
||||
'print_iter_log': True,
|
||||
'use_cuda_graph': True,
|
||||
'cuda_graph_padding_enabled': True,
|
||||
}
|
||||
'print_iter_log': True,
|
||||
'use_cuda_graph': True,
|
||||
'cuda_graph_padding_enabled': True,
|
||||
}
|
||||
model_configs = {
|
||||
'deepseek_r1-bench-pytorch-float16-maxbs:1-maxnt:8192-input_output_len:1000,2000-quant:fp8-reqs:10-ep:4-gpus:8':
|
||||
{
|
||||
'pytorch_backend_config': {
|
||||
'use_cuda_graph': True,
|
||||
},
|
||||
'use_cuda_graph': True,
|
||||
'speculative_config': {
|
||||
'decoding_type': 'MTP',
|
||||
'num_nextn_predict_layers': 3
|
||||
@ -47,9 +43,7 @@ def get_model_yaml_config(model_label: str, input_lens: list[int]) -> dict:
|
||||
},
|
||||
'deepseek_r1_nvfp4-bench-pytorch-float16-maxbs:1-maxnt:8192-input_output_len:1000,2000-quant:nvfp4-reqs:10-ep:4-tp:8-gpus:8':
|
||||
{
|
||||
'pytorch_backend_config': {
|
||||
'use_cuda_graph': True,
|
||||
},
|
||||
'use_cuda_graph': True,
|
||||
'speculative_config': {
|
||||
'decoding_type': 'MTP',
|
||||
'num_nextn_predict_layers': 3
|
||||
@ -57,25 +51,17 @@ def get_model_yaml_config(model_label: str, input_lens: list[int]) -> dict:
|
||||
},
|
||||
'deepseek_r1-bench-pytorch-float16-maxbs:128-maxnt:1127-input_output_len:1000,2000-quant:fp8-reqs:5120-con:1024-ep:8-gpus:8':
|
||||
{
|
||||
'pytorch_backend_config': {
|
||||
'cuda_graph_batch_sizes': [128]
|
||||
},
|
||||
'cuda_graph_batch_sizes': [128]
|
||||
},
|
||||
'deepseek_r1-bench-pytorch-float16-maxbs:384-maxnt:1536-input_output_len:1000,2000-quant:nvfp4-reqs:49152-con:3072-ep:8-gpus:8':
|
||||
{
|
||||
'pytorch_backend_config': {
|
||||
'cuda_graph_padding_enabled': True,
|
||||
'cuda_graph_batch_sizes':
|
||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 384]
|
||||
},
|
||||
'cuda_graph_padding_enabled': True,
|
||||
'cuda_graph_batch_sizes': [1, 2, 4, 8, 16, 32, 64, 128, 256, 384]
|
||||
},
|
||||
'deepseek_r1_nvfp4-bench-pytorch-float16-maxbs:384-maxnt:1536-input_output_len:1000,2000-quant:nvfp4-reqs:49152-con:3072-ep:8-gpus:8':
|
||||
{
|
||||
'pytorch_backend_config': {
|
||||
'cuda_graph_padding_enabled': True,
|
||||
'cuda_graph_batch_sizes':
|
||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 384]
|
||||
},
|
||||
'cuda_graph_padding_enabled': True,
|
||||
'cuda_graph_batch_sizes': [1, 2, 4, 8, 16, 32, 64, 128, 256, 384]
|
||||
}
|
||||
}
|
||||
# get model name from model_label
|
||||
|
||||
@ -510,13 +510,16 @@ def stress_test(config,
|
||||
extra_llm_options["enable_attention_dp"] = True
|
||||
|
||||
if config.backend == "pytorch":
|
||||
extra_llm_options["pytorch_backend_config"] = {
|
||||
"use_cuda_graph": True,
|
||||
"cuda_graph_padding_enabled": True,
|
||||
extra_llm_options.update({
|
||||
"use_cuda_graph":
|
||||
True,
|
||||
"cuda_graph_padding_enabled":
|
||||
True,
|
||||
"cuda_graph_batch_sizes":
|
||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 384],
|
||||
"print_iter_log": True,
|
||||
}
|
||||
"print_iter_log":
|
||||
True,
|
||||
})
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml',
|
||||
delete=False) as temp_file:
|
||||
|
||||
@ -625,11 +625,17 @@ def temp_extra_llm_api_options_file(request):
|
||||
}
|
||||
}
|
||||
|
||||
pytorch_backend_config = {}
|
||||
if request.node.callspec.params['pytorch_backend_config']:
|
||||
extra_llm_api_options_dict["pytorch_backend_config"] = {
|
||||
pytorch_backend_config = {
|
||||
"use_cuda_graph": True,
|
||||
"cuda_graph_batch_sizes": [1, 2, 3],
|
||||
# trtllm-bench will set cuda_max_batch_size to
|
||||
# max_batch_size, so the cuda_graph_batch_sizes is not
|
||||
# needed.
|
||||
# "cuda_graph_batch_sizes": [1, 2, 3],
|
||||
}
|
||||
# Flatten the pytorch_backend_config
|
||||
extra_llm_api_options_dict.update(pytorch_backend_config)
|
||||
|
||||
with open(temp_file_path, 'w') as f:
|
||||
yaml.dump(extra_llm_api_options_dict, f)
|
||||
@ -1981,7 +1987,6 @@ def test_ptp_quickstart_bert(llm_root, llm_venv, model_name, model_path,
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.sampling_params import SamplingParams
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
@ -1994,8 +1999,8 @@ def test_ptp_quickstart_bert(llm_root, llm_venv, model_name, model_path,
|
||||
sampling_param = SamplingParams(max_tokens=32, return_context_logits=True)
|
||||
with LLM(
|
||||
model=model_dir,
|
||||
pytorch_backend_config=PyTorchConfig(
|
||||
attn_backend=backend, disable_overlap_scheduler=True),
|
||||
attn_backend=backend,
|
||||
disable_overlap_scheduler=True,
|
||||
) as llm:
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params=sampling_param)
|
||||
|
||||
@ -9,7 +9,6 @@ from utils.util import getSMVersion
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import KvCacheConfig
|
||||
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
|
||||
|
||||
@ -56,7 +55,7 @@ def test_deepseek_trtllmgen(model_name):
|
||||
"The president of the United States is",
|
||||
] * 4
|
||||
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=True,
|
||||
use_cuda_graph=False,
|
||||
kv_cache_dtype="auto",
|
||||
@ -73,7 +72,7 @@ def test_deepseek_trtllmgen(model_name):
|
||||
llm = LLM(model=tmp_model_dir,
|
||||
tensor_parallel_size=1,
|
||||
enable_chunked_prefill=False,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
moe_expert_parallel_size=-1,
|
||||
moe_tensor_parallel_size=-1,
|
||||
enable_attention_dp=False,
|
||||
|
||||
@ -3,7 +3,6 @@ import unittest
|
||||
from parameterized import parameterized
|
||||
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import KvCacheConfig
|
||||
from tensorrt_llm.sampling_params import SamplingParams
|
||||
|
||||
@ -42,8 +41,7 @@ class TestOutOfTree(unittest.TestCase):
|
||||
llm = LLM(model=model_dir,
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_num_tokens=2048,
|
||||
pytorch_backend_config=PyTorchConfig(
|
||||
disable_overlap_scheduler=True))
|
||||
disable_overlap_scheduler=True)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
@ -7,7 +7,6 @@ from utils.llm_data import llm_models_root
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import KvCacheConfig
|
||||
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
|
||||
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
|
||||
@ -61,9 +60,8 @@ def test_model(backend, model_name, quant, sp_size, sa_block_size,
|
||||
max_batch_size = 20
|
||||
max_output_tokens = 128
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
|
||||
pytorch_backend_config = PyTorchConfig(
|
||||
attn_backend='FLASHINFER_STAR_ATTENTION',
|
||||
disable_overlap_scheduler=True)
|
||||
pytorch_backend_options = dict(attn_backend='FLASHINFER_STAR_ATTENTION',
|
||||
disable_overlap_scheduler=True)
|
||||
|
||||
llm = LLM(model=model_dir,
|
||||
backend=backend,
|
||||
@ -72,7 +70,7 @@ def test_model(backend, model_name, quant, sp_size, sa_block_size,
|
||||
quant_config=quant_config,
|
||||
context_parallel_size=sp_size,
|
||||
cp_config=cp_config,
|
||||
pytorch_backend_config=pytorch_backend_config,
|
||||
**pytorch_backend_options,
|
||||
max_batch_size=max_batch_size,
|
||||
max_input_len=MAX_SEQ_LEN - max_output_tokens,
|
||||
max_seq_len=MAX_SEQ_LEN,
|
||||
|
||||
@ -9,7 +9,6 @@ from utils.util import getSMVersion
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import KvCacheConfig
|
||||
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
|
||||
|
||||
@ -63,7 +62,7 @@ def test_deepseek_streaming(model_name, backend, quant, tp_size):
|
||||
" the head of state and head of government of the",
|
||||
] * 32
|
||||
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=True,
|
||||
use_cuda_graph=False,
|
||||
kv_cache_dtype="auto",
|
||||
@ -78,7 +77,7 @@ def test_deepseek_streaming(model_name, backend, quant, tp_size):
|
||||
llm = LLM(model=model_dir,
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_chunked_prefill=False,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
moe_expert_parallel_size=-1,
|
||||
moe_tensor_parallel_size=-1,
|
||||
enable_attention_dp=enable_attention_dp,
|
||||
|
||||
@ -7,7 +7,6 @@ from utils.llm_data import llm_models_root
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -48,8 +47,7 @@ def test_llama4(model_name, backend, tp_size, use_cuda_graph,
|
||||
" the head of state and head of government of the", " solid white"
|
||||
]
|
||||
|
||||
pytorch_config = PyTorchConfig(attn_backend=backend,
|
||||
use_cuda_graph=use_cuda_graph)
|
||||
pytorch_config = dict(attn_backend=backend, use_cuda_graph=use_cuda_graph)
|
||||
model_dir = str(llm_models_root() / "llama4-models" / model_name)
|
||||
|
||||
llm = LLM(
|
||||
@ -57,7 +55,7 @@ def test_llama4(model_name, backend, tp_size, use_cuda_graph,
|
||||
tensor_parallel_size=tp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
moe_tensor_parallel_size=tp_size // ep_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
pipeline_parallel_size=pp_size,
|
||||
enable_attention_dp=enable_attention_dp,
|
||||
)
|
||||
|
||||
@ -7,7 +7,6 @@ import torch
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import BuildConfig, EagleDecodingConfig, KvCacheConfig
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
@ -24,7 +23,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
|
||||
|
||||
models_path = llm_models_root()
|
||||
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=True,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
# Only create a single CUDA graph to prevent OOM in CI
|
||||
@ -49,7 +48,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
|
||||
build_config = BuildConfig(max_seq_len=2048)
|
||||
|
||||
llm_spec = LLM(model=target_model_dir,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
speculative_config=spec_config,
|
||||
build_config=build_config)
|
||||
@ -89,7 +88,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str):
|
||||
llm_spec.shutdown()
|
||||
|
||||
llm_ref = LLM(model=target_model_dir,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
build_config=build_config)
|
||||
|
||||
|
||||
@ -7,7 +7,6 @@ import torch
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import KvCacheConfig, NGramDecodingConfig
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
||||
@ -26,7 +25,7 @@ def test_llama_ngram(use_cuda_graph: bool, attn_backend: str):
|
||||
|
||||
models_path = llm_models_root()
|
||||
|
||||
pytorch_config = PyTorchConfig(
|
||||
pytorch_config = dict(
|
||||
enable_overlap_scheduler=False,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
# Only create a single CUDA graph to prevent OOM in CI
|
||||
@ -54,7 +53,7 @@ def test_llama_ngram(use_cuda_graph: bool, attn_backend: str):
|
||||
)
|
||||
llm_spec = LLM(model=target_model_dir,
|
||||
max_batch_size=max_batch_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
speculative_config=spec_config)
|
||||
|
||||
@ -67,7 +66,7 @@ def test_llama_ngram(use_cuda_graph: bool, attn_backend: str):
|
||||
|
||||
llm_ref = LLM(model=target_model_dir,
|
||||
max_batch_size=max_batch_size,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
kv_cache_config=kv_cache_config)
|
||||
|
||||
results_ref = llm_ref.generate(prompts, sampling_params)
|
||||
|
||||
@ -6,7 +6,6 @@ from utils.llm_data import llm_models_root
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import KvCacheConfig as TRT_KvCacheConfig
|
||||
|
||||
|
||||
@ -24,10 +23,9 @@ def model_path():
|
||||
|
||||
def create_llm(model_dir, disable_overlap_scheduler, enable_trtllm_sampler):
|
||||
"""Create LLM with specific overlap scheduler setting"""
|
||||
pytorch_config = PyTorchConfig(
|
||||
use_cuda_graph=True,
|
||||
disable_overlap_scheduler=disable_overlap_scheduler,
|
||||
enable_trtllm_sampler=enable_trtllm_sampler)
|
||||
pytorch_config = dict(use_cuda_graph=True,
|
||||
disable_overlap_scheduler=disable_overlap_scheduler,
|
||||
enable_trtllm_sampler=enable_trtllm_sampler)
|
||||
|
||||
trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)
|
||||
|
||||
@ -36,7 +34,7 @@ def create_llm(model_dir, disable_overlap_scheduler, enable_trtllm_sampler):
|
||||
tensor_parallel_size=1,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
kv_cache_config=trt_kv_cache_config,
|
||||
max_num_tokens=
|
||||
128 # Only one request longer than max_num_tokens is required to test chunked prefill
|
||||
|
||||
@ -7,7 +7,6 @@ from utils.util import similar
|
||||
|
||||
from tensorrt_llm import SamplingParams
|
||||
from tensorrt_llm._torch import LLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import KvCacheConfig as TRT_KvCacheConfig
|
||||
|
||||
|
||||
@ -25,8 +24,7 @@ def model_path():
|
||||
|
||||
def create_llm(model_dir):
|
||||
"""Create LLM with specific overlap scheduler setting"""
|
||||
pytorch_config = PyTorchConfig(use_cuda_graph=True,
|
||||
enable_trtllm_sampler=True)
|
||||
pytorch_config = dict(use_cuda_graph=True, enable_trtllm_sampler=True)
|
||||
|
||||
trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)
|
||||
|
||||
@ -35,7 +33,7 @@ def create_llm(model_dir):
|
||||
tensor_parallel_size=1,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True,
|
||||
pytorch_backend_config=pytorch_config,
|
||||
**pytorch_config,
|
||||
kv_cache_config=trt_kv_cache_config,
|
||||
max_num_tokens=
|
||||
128 # Only one request longer than max_num_tokens is required to test chunked prefill
|
||||
|
||||
@ -35,7 +35,7 @@ methods:
|
||||
annotation: Optional[tensorrt_llm.llmapi.llm_args.PeftCacheConfig]
|
||||
default: null
|
||||
scheduler_config:
|
||||
annotation: Optional[tensorrt_llm.llmapi.llm_args.SchedulerConfig]
|
||||
annotation: tensorrt_llm.llmapi.llm_args.SchedulerConfig
|
||||
default: null
|
||||
extended_runtime_perf_knob_config:
|
||||
annotation: Optional[tensorrt_llm.llmapi.llm_args.ExtendedRuntimePerfKnobConfig]
|
||||
|
||||
@ -103,7 +103,7 @@ methods:
|
||||
annotation: bool
|
||||
default: false
|
||||
kv_cache_config:
|
||||
annotation: Optional[tensorrt_llm.llmapi.llm_args.KvCacheConfig]
|
||||
annotation: tensorrt_llm.llmapi.llm_args.KvCacheConfig
|
||||
default: null
|
||||
return_annotation: None
|
||||
generate:
|
||||
|
||||
@ -25,9 +25,7 @@ def temp_extra_llm_api_options_file(request):
|
||||
try:
|
||||
extra_llm_api_options_dict = {
|
||||
"guided_decoding_backend": "xgrammar",
|
||||
"pytorch_backend_config": {
|
||||
"disable_overlap_scheduler": True,
|
||||
}
|
||||
"disable_overlap_scheduler": True,
|
||||
}
|
||||
|
||||
with open(temp_file_path, 'w') as f:
|
||||
|
||||
@ -5,7 +5,6 @@ from fastapi.testclient import TestClient
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tensorrt_llm._torch.llm import LLM as PyTorchLLM
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig
|
||||
from tensorrt_llm.serve.openai_server import OpenAIServer
|
||||
|
||||
@ -23,8 +22,7 @@ def client():
|
||||
build_config=build_config,
|
||||
kv_cache_config=KvCacheConfig(),
|
||||
backend="pytorch",
|
||||
pytorch_backend_config=PyTorchConfig(
|
||||
enable_iter_perf_stats=True, ))
|
||||
enable_iter_perf_stats=True)
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(llama_model_path)
|
||||
|
||||
app_instance = OpenAIServer(llm,
|
||||
|
||||
@ -1879,11 +1879,10 @@ def llm_get_stats_test_harness(tp_size: int = 1,
|
||||
|
||||
if pytorch_backend:
|
||||
from tensorrt_llm._torch import LLM as LLM_torch
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
llm_args_extra["pytorch_backend_config"] = PyTorchConfig(
|
||||
enable_iter_perf_stats=True,
|
||||
enable_iter_req_stats=enable_iter_req_stats,
|
||||
disable_overlap_scheduler=not use_overlap)
|
||||
llm_args_extra.update(
|
||||
dict(enable_iter_perf_stats=True,
|
||||
enable_iter_req_stats=enable_iter_req_stats,
|
||||
disable_overlap_scheduler=not use_overlap))
|
||||
LLM_CLASS = LLM_torch
|
||||
else:
|
||||
LLM_CLASS = LLM
|
||||
@ -1949,11 +1948,10 @@ def llm_get_stats_async_test_harness(tp_size: int = 1,
|
||||
|
||||
if pytorch_backend:
|
||||
from tensorrt_llm._torch import LLM as LLM_torch
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
llm_args_extra["pytorch_backend_config"] = PyTorchConfig(
|
||||
enable_iter_perf_stats=True,
|
||||
enable_iter_req_stats=enable_iter_req_stats,
|
||||
disable_overlap_scheduler=not use_overlap)
|
||||
llm_args_extra.update(
|
||||
dict(enable_iter_perf_stats=True,
|
||||
enable_iter_req_stats=enable_iter_req_stats,
|
||||
disable_overlap_scheduler=not use_overlap))
|
||||
LLM_CLASS = LLM_torch
|
||||
else:
|
||||
LLM_CLASS = LLM
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
import tensorrt_llm.bindings.executor as tle
|
||||
from tensorrt_llm.llmapi.llm_args import *
|
||||
from tensorrt_llm.llmapi.llm_utils import *
|
||||
|
||||
from .test_llm import llama_model_path
|
||||
|
||||
@ -50,7 +50,7 @@ speculative_config:
|
||||
f.seek(0)
|
||||
dict_content = yaml.safe_load(f)
|
||||
|
||||
llm_args = LlmArgs(llama_model_path)
|
||||
llm_args = LlmArgs(model=llama_model_path)
|
||||
llm_args_dict = update_llm_args_with_extra_dict(llm_args.to_dict(),
|
||||
dict_content)
|
||||
llm_args = LlmArgs(**llm_args_dict)
|
||||
@ -173,3 +173,46 @@ def test_PeftCacheConfig_declaration():
|
||||
assert pybind_config.device_cache_percent == 0.5
|
||||
assert pybind_config.host_cache_size == 1024
|
||||
assert pybind_config.lora_prefetch_dir == "."
|
||||
|
||||
|
||||
class TestTorchLlmArgsCudaGraphSettings:
|
||||
|
||||
def test_cuda_graph_batch_sizes_case_0(self):
|
||||
# set both cuda_graph_batch_sizes and cuda_graph_max_batch_size, and
|
||||
# cuda_graph_batch_sizes is not equal to generated
|
||||
with pytest.raises(ValueError):
|
||||
TorchLlmArgs(model=llama_model_path,
|
||||
use_cuda_graph=True,
|
||||
cuda_graph_batch_sizes=[1, 2, 3],
|
||||
cuda_graph_max_batch_size=128)
|
||||
|
||||
def test_cuda_graph_batch_sizes_case_0_1(self):
|
||||
# set both cuda_graph_batch_sizes and cuda_graph_max_batch_size, and
|
||||
# cuda_graph_batch_sizes is equal to generated
|
||||
args = TorchLlmArgs(model=llama_model_path,
|
||||
use_cuda_graph=True,
|
||||
cuda_graph_padding_enabled=True,
|
||||
cuda_graph_batch_sizes=TorchLlmArgs.
|
||||
_generate_cuda_graph_batch_sizes(128, True),
|
||||
cuda_graph_max_batch_size=128)
|
||||
assert args.cuda_graph_batch_sizes == TorchLlmArgs._generate_cuda_graph_batch_sizes(
|
||||
128, True)
|
||||
assert args.cuda_graph_max_batch_size == 128
|
||||
|
||||
def test_cuda_graph_batch_sizes_case_1(self):
|
||||
# set cuda_graph_batch_sizes only
|
||||
args = TorchLlmArgs(model=llama_model_path,
|
||||
use_cuda_graph=True,
|
||||
cuda_graph_padding_enabled=True,
|
||||
cuda_graph_batch_sizes=[1, 2, 4])
|
||||
assert args.cuda_graph_batch_sizes == [1, 2, 4]
|
||||
|
||||
def test_cuda_graph_batch_sizes_case_2(self):
|
||||
# set cuda_graph_max_batch_size only
|
||||
args = TorchLlmArgs(model=llama_model_path,
|
||||
use_cuda_graph=True,
|
||||
cuda_graph_padding_enabled=True,
|
||||
cuda_graph_max_batch_size=128)
|
||||
assert args.cuda_graph_batch_sizes == TorchLlmArgs._generate_cuda_graph_batch_sizes(
|
||||
128, True)
|
||||
assert args.cuda_graph_max_batch_size == 128
|
||||
|
||||
@ -2,7 +2,6 @@ import asyncio
|
||||
import time
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm._utils import KVCacheEventSerializer
|
||||
@ -48,7 +47,7 @@ def create_llm(tensor_parallel_size=1):
|
||||
return LLM(model=llama_model_path,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
kv_cache_config=global_kvcache_config,
|
||||
pytorch_backend_config=PyTorchConfig(autotuner_enabled=False),
|
||||
autotuner_enabled=False,
|
||||
backend="pytorch")
|
||||
|
||||
|
||||
|
||||
@ -90,10 +90,9 @@ def test_llm_reward_model():
|
||||
tokenized_input = tokenizer(prompts, return_tensors="pt")["input_ids"]
|
||||
|
||||
from tensorrt_llm._torch import LLM as LLM_torch
|
||||
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
|
||||
llm = LLM_torch(model=rm_model_path,
|
||||
pytorch_backend_config=PyTorchConfig(
|
||||
attn_backend="VANILLA", disable_overlap_scheduler=True))
|
||||
attn_backend="VANILLA",
|
||||
disable_overlap_scheduler=True)
|
||||
|
||||
sampling_params = SamplingParams(return_context_logits=True)
|
||||
|
||||
|
||||
@ -8,8 +8,7 @@ backend: "pytorch"
|
||||
tensor_parallel_size: 1
|
||||
pipeline_parallel_size: 1
|
||||
|
||||
pytorch_backend_config:
|
||||
use_cuda_graph: False
|
||||
use_cuda_graph: False
|
||||
|
||||
# ======= Triton Server Configurations =======
|
||||
# Triton Configurations to override the default values in config.pbtxt
|
||||
|
||||
Loading…
Reference in New Issue
Block a user