[TRTLLM-5530][BREAKING CHANGE] refactor: unify KvCacheConfig in LLM class for pytorch backend (#5752)

Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Authored by Yan Chunwei on 2025-07-16 16:42:59 +08:00; committed by GitHub
parent 10349b54df
commit a02606a9e2
12 changed files with 108 additions and 158 deletions
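Migration note: for the PyTorch backend, the top-level kv_cache_dtype argument of LLM is removed and the KV cache data type now lives on KvCacheConfig.dtype (default "auto"). With the PyTorch backend, dtype="fp8" also syncs quant_config.kv_cache_quant_algo to FP8 through a validator; the TensorRT backend only accepts the default "auto". A minimal usage sketch, assuming TensorRT-LLM with the PyTorch backend is installed (the model path is a placeholder, not part of this commit):

# Before (removed in this commit): dtype passed as a top-level LLM argument.
# llm = LLM(model="/path/to/model", kv_cache_dtype="fp8")

# After: the dtype is carried by KvCacheConfig.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="/path/to/model",
    kv_cache_config=KvCacheConfig(dtype="fp8", free_gpu_memory_fraction=0.9),
)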


@ -149,6 +149,7 @@ def setup_llm(args):
kv_cache_config = KvCacheConfig(
enable_block_reuse=not args.disable_kv_cache_reuse,
free_gpu_memory_fraction=args.kv_cache_fraction,
dtype=args.kv_cache_dtype,
)
spec_decode_algo = args.spec_decode_algo.upper(
@ -194,7 +195,6 @@ def setup_llm(args):
model=args.model_dir,
backend='pytorch',
disable_overlap_scheduler=args.disable_overlap_scheduler,
kv_cache_dtype=args.kv_cache_dtype,
kv_cache_config=kv_cache_config,
attn_backend=args.attention_backend,
cuda_graph_config=cuda_graph_config,


@ -88,12 +88,14 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
enable_chunked_prefill = params.get("enable_chunked_prefill", False)
kv_cache_dtype = "auto"
kv_cache_config = {}
if extra_llm_api_options:
with open(extra_llm_api_options, 'r') as f:
llm_args_dict = yaml.safe_load(f)
if "kv_cache_dtype" in llm_args_dict:
kv_cache_dtype = llm_args_dict["kv_cache_dtype"]
kv_cache_config = llm_args_dict.get("kv_cache_config", {
"dtype": "auto",
})
kv_cache_dtype = kv_cache_config.get("dtype", "auto")
enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
enable_chunked_prefill)
@ -158,9 +160,11 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
"max_batch_size": max_batch_size
}
kv_cache_config["dtype"] = kv_cache_dtype
pyt_options = {
"cuda_graph_config": cuda_graph_config,
"kv_cache_dtype": kv_cache_dtype,
"kv_cache_config": kv_cache_config,
}
backend = params.get("backend", "pytorch")
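The benchmark settings now read the dtype from a nested kv_cache_config section of the extra_llm_api_options YAML instead of a top-level kv_cache_dtype key. A standalone sketch of the new parsing, assuming PyYAML; the sample values are illustrative only:

import yaml

extra_llm_api_options_yaml = """
enable_chunked_prefill: true
kv_cache_config:
  dtype: fp8
  enable_block_reuse: false
"""

llm_args_dict = yaml.safe_load(extra_llm_api_options_yaml)
# The nested kv_cache_config section replaces the old top-level kv_cache_dtype key.
kv_cache_config = llm_args_dict.get("kv_cache_config", {"dtype": "auto"})
kv_cache_dtype = kv_cache_config.get("dtype", "auto")
print(kv_cache_dtype)  # -> "fp8"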


@ -112,7 +112,6 @@ class PerformanceOptions:
def get_autodeploy_perf_config(self) -> Dict:
AutoDeployPerfConfig = dict
ad_config = AutoDeployPerfConfig()
ad_config["kv_cache_dtype"] = "auto"
ad_config["attn_backend"] = "flashinfer"
return ad_config


@ -11,6 +11,7 @@ from tensorrt_llm.bench.dataclasses.general import DatasetMetadata
from tensorrt_llm.bench.dataclasses.statistics import (BenchmarkStatistics,
PercentileStats,
RequestRecord)
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.logger import Logger
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
@ -275,8 +276,17 @@ class ReportUtility:
model = self.rt_cfg.model_path or self.rt_cfg.model
model_config = ModelConfig.from_pretrained(model,
trust_remote_code=True)
validate_and_set_kv_cache_quant(model_config,
self.kwargs["kv_cache_dtype"])
kv_cache_config = self.kwargs.get("kv_cache_config",
KvCacheConfig())
if isinstance(kv_cache_config, KvCacheConfig):
kv_cache_dtype = kv_cache_config.dtype
elif isinstance(kv_cache_config, dict):
kv_cache_dtype = kv_cache_config.get("dtype", "auto")
else:
raise ValueError(
f"Invalid kv_cache_config type: {type(kv_cache_config)}.")
validate_and_set_kv_cache_quant(model_config, kv_cache_dtype)
stats_dict["engine"] |= {
"backend":


@ -821,6 +821,10 @@ class KvCacheConfig(BaseModel, PybindMirror):
use_uvm: bool = Field(default=False,
description="Whether to use UVM for the KV cache.")
# This is a pure python field, not a pybind field. It is only for the Pytorch backend.
dtype: str = Field(default="auto",
description="The data type to use for the KV cache.")
def _to_pybind(self):
return _KvCacheConfig(
enable_block_reuse=self.enable_block_reuse,
@ -1024,10 +1028,6 @@ class BaseLlmArgs(BaseModel):
lora_config: Optional[LoraConfig] = Field(
default=None, description="LoRA configuration for the model.")
# Quantization and calibration configurations
quant_config: Optional[QuantConfig] = Field(
default=None, description="Quantization config.", validate_default=True)
# Several options from ExecutorConfig, expanded here for less hierarchy
kv_cache_config: KvCacheConfig = Field(default_factory=KvCacheConfig,
description="KV cache config.")
@ -1208,13 +1208,6 @@ class BaseLlmArgs(BaseModel):
raise RuntimeError("Pre SM 80 GPUs do not support bfloat16")
return v
@field_validator("quant_config", mode='before')
@classmethod
def validate_quant_config(cls, v, info):
if v is None:
v = QuantConfig()
return v
@field_validator("gpus_per_node", mode='before')
@classmethod
def validate_gpus_per_node(cls, v, info):
@ -1657,6 +1650,10 @@ class TrtLlmArgs(BaseLlmArgs):
calib_config: Optional[CalibConfig] = Field(
default=None, description="Calibration config.", validate_default=True)
# Quantization and calibration configurations
quant_config: Optional[QuantConfig] = Field(
default=None, description="Quantization config.", validate_default=True)
embedding_parallel_mode: str = Field(
default='SHARDING_ALONG_VOCAB',
description="The embedding parallel mode.")
@ -1694,6 +1691,13 @@ class TrtLlmArgs(BaseLlmArgs):
return CalibConfig()
return v
@field_validator("quant_config", mode='before')
@classmethod
def validate_quant_config(cls, v, info):
if v is None:
v = QuantConfig()
return v
@model_validator(mode="after")
def setup_embedding_parallel_mode(self):
if self.embedding_parallel_mode == 'NONE':
@ -1738,6 +1742,11 @@ class TrtLlmArgs(BaseLlmArgs):
f"Invalid build_cache_config: {self.enable_build_cache}")
return self
@model_validator(mode="after")
def validate_kv_cache_dtype(self):
assert self.kv_cache_config.dtype == "auto", "KvCacheConfig.dtype is not supported by the TensorRT backend."
return self
class LoadFormat(Enum):
AUTO = 0
@ -1811,9 +1820,6 @@ class TorchLlmArgs(BaseLlmArgs):
"If true, will use the TRTLLM sampler instead of the PyTorch sampler. The TRTLLM sampler has a wide coverage of sampling strategies."
)
kv_cache_dtype: str = Field(default="auto",
description="Data type for KV cache.")
enable_iter_perf_stats: bool = Field(
default=False, description="Enable iteration performance statistics.")
@ -1867,6 +1873,19 @@ class TorchLlmArgs(BaseLlmArgs):
'MNNVL']] = Field(default='AUTO',
description="Allreduce strategy to use.")
# PrivateVars
_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
@property
def quant_config(self) -> QuantConfig:
if self._quant_config is None:
self._quant_config = QuantConfig()
return self._quant_config
@quant_config.setter
def quant_config(self, value: QuantConfig):
self._quant_config = value
# TODO: remove backend later
@field_validator('backend', mode='before')
def init_backend(cls, v):
@ -1994,6 +2013,22 @@ class TorchLlmArgs(BaseLlmArgs):
return self
@model_validator(mode='after')
def sync_quant_config_with_kv_cache_config_dtype(self) -> 'TorchLlmArgs':
if self.kv_cache_config is None:
return self
assert self.quant_config is not None
if self.kv_cache_config.dtype == "auto":
return self
elif self.kv_cache_config.dtype == 'fp8':
self.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
else:
logger.warning(
f"Cannot sync quant_config.kv_cache_quant_algo with kv_cache_config.dtype of {self.kv_cache_config.dtype}, "
"please update the validator")
return self
# TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
def get_pytorch_backend_config(self) -> "PyTorchConfig":
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@ -2017,7 +2052,7 @@ class TorchLlmArgs(BaseLlmArgs):
moe_backend=self.moe_config.backend,
enable_mixed_sampler=self.enable_mixed_sampler,
enable_trtllm_sampler=self.enable_trtllm_sampler,
kv_cache_dtype=self.kv_cache_dtype,
kv_cache_dtype=self.kv_cache_config.dtype,
enable_iter_perf_stats=self.enable_iter_perf_stats,
enable_iter_req_stats=self.enable_iter_req_stats,
print_iter_log=self.print_iter_log,
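Condensed, the new wiring in llm_args.py is: KvCacheConfig grows a pure-Python dtype field, TorchLlmArgs drops the public kv_cache_dtype field and keeps quant_config as a lazily created private attribute behind a property, and an after-validator derives quant_config.kv_cache_quant_algo from kv_cache_config.dtype. A trimmed standalone sketch of that pattern, assuming pydantic v2 and using simplified stand-ins for QuantConfig/QuantAlgo (the real validator also warns on dtypes it cannot map):

from typing import Optional
from pydantic import BaseModel, Field, PrivateAttr, model_validator

class QuantConfig(BaseModel):  # simplified stand-in for tensorrt_llm's QuantConfig
    kv_cache_quant_algo: Optional[str] = None

class KvCacheConfig(BaseModel):
    enable_block_reuse: bool = True
    # Pure-Python field, only consumed by the PyTorch backend.
    dtype: str = Field(default="auto", description="The data type to use for the KV cache.")

class TorchLlmArgs(BaseModel):
    kv_cache_config: KvCacheConfig = Field(default_factory=KvCacheConfig)
    # quant_config is no longer a public field for the PyTorch backend.
    _quant_config: Optional[QuantConfig] = PrivateAttr(default=None)

    @property
    def quant_config(self) -> QuantConfig:
        if self._quant_config is None:
            self._quant_config = QuantConfig()
        return self._quant_config

    @model_validator(mode="after")
    def sync_quant_config_with_kv_cache_config_dtype(self) -> "TorchLlmArgs":
        if self.kv_cache_config.dtype == "fp8":
            self.quant_config.kv_cache_quant_algo = "FP8"
        return self

args = TorchLlmArgs(kv_cache_config=KvCacheConfig(dtype="fp8"))
assert args.quant_config.kv_cache_quant_algo == "FP8"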


@ -401,6 +401,9 @@ class ModelLoader:
logger.info(f"Setting {key}={value} from HF quant config.")
setattr(quant_config, key, value)
# Update the quant_config in llm_args for pytorch
self.llm_args.quant_config = quant_config
return True
hf_config_path = f"{self._model_dir}/config.json"


@ -23,7 +23,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, SamplingParams,
TorchCompileConfig)
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo
from ..conftest import (llm_models_root, parametrize_with_ids,
@ -51,7 +50,6 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
@ -67,7 +65,6 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
with LLM(f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B",
stream_interval=stream_interval) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
assert llm.args.stream_interval == stream_interval
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm, streaming=True)
@ -143,7 +140,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
@parametrize_with_ids("fp8kv", [False, True])
def test_fp8(self, fp8kv, attn_backend, torch_compile):
quant_config = QuantConfig(QuantAlgo.FP8)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
@ -154,15 +150,11 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=torch_compile,
)
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8")
with LLM(
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
quant_config=quant_config,
**pytorch_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -181,7 +173,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
"Pipeline parallel with torch.compile is not supported yet.\n"
"Issue: Unfusing flashinfer_fused_add_rmsnorm causes outputs to be "
"discarded at graph breaks.")
quant_config = QuantConfig(QuantAlgo.FP8)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
@ -192,17 +183,13 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=torch_compile,
)
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8")
with LLM(
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
quant_config=quant_config,
**pytorch_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -336,7 +323,6 @@ class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/llama-3.2-models/Llama-3.2-1B-FP8"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
@ -358,7 +344,6 @@ class TestLlama3_2_3B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/llama-3.2-models/Llama-3.2-3B-Instruct-FP8"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
@ -401,7 +386,6 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp4"
with LLM(model_path, tensor_parallel_size=4) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
@ -583,7 +567,6 @@ class TestMixtral8x7B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp8"
with LLM(model_path, tensor_parallel_size=2) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
@ -595,7 +578,6 @@ class TestMixtral8x7B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp4"
with LLM(model_path, tensor_parallel_size=2) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
@ -716,11 +698,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile_config=torch_compile_config,
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
mtp_config = None
mtp_nextn = 2
@ -733,13 +712,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -775,11 +751,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_config=MoeConfig(backend="CUTEDSL"),
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
mtp_config = None
if mtp_nextn > 0:
@ -789,14 +762,11 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config,
) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -837,14 +807,11 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(enable_padding=True),
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
tensor_parallel_size=4,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
@ -886,11 +853,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile_config=torch_compile_config,
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
mtp_config = None
if mtp_nextn > 0:
@ -902,13 +866,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -955,11 +916,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_config=MoeConfig(backend="CUTEDSL"),
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
mtp_config = None
if mtp_nextn > 0:
@ -972,13 +930,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config,
) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -1045,23 +1000,20 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
num_slots = 80
eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
layer_updates_per_iter=2)
pytorch_backend_options = dict(cuda_graph_config=CudaGraphConfig(),
moe_config=MoeConfig(
backend="WIDEEP",
load_balancer=eplb_config))
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
pytorch_config = dict(cuda_graph_config=CudaGraphConfig(),
moe_config=MoeConfig(backend="WIDEEP",
load_balancer=eplb_config))
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_backend_options["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
tensor_parallel_size=4,
moe_expert_parallel_size=4,
kv_cache_config=kv_cache_config,
**pytorch_backend_options,
enable_attention_dp=True,
quant_config=quant_config) as llm:
with LLM(
f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
tensor_parallel_size=4,
moe_expert_parallel_size=4,
kv_cache_config=kv_cache_config,
**pytorch_config,
enable_attention_dp=True,
) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -1095,21 +1047,15 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -1157,11 +1103,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
tensor_parallel_size=tp_size,
@ -1169,12 +1112,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -1215,21 +1155,13 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
if quant_dtype == "none":
assert not fp8kv
quant_config = None
else:
quant_config = QuantConfig()
if quant_dtype == "fp8":
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
elif quant_dtype == "nvfp4":
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
with LLM(model_path,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
if quant_dtype == "fp8":
@ -1237,8 +1169,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
elif quant_dtype == "nvfp4":
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -1275,23 +1205,15 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
if quant_dtype == "none":
assert not fp8kv
quant_config = None
else:
quant_config = QuantConfig()
if quant_dtype == "fp8":
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
elif quant_dtype == "nvfp4":
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
with LLM(model_path,
kv_cache_config=kv_cache_config,
enable_chunked_prefill=True,
max_num_tokens=512,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=True,
speculative_config=mtp_config) as llm:
@ -1300,9 +1222,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
elif quant_dtype == "nvfp4":
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@ -1388,11 +1307,8 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
moe_config=MoeConfig(backend=moe_backend))
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
mtp_config = None
if mtp_nextn > 0:
@ -1404,14 +1320,11 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.moe_backend == moe_backend
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
@ -1438,11 +1351,8 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"
mtp_config = None
if mtp_nextn > 0:
@ -1454,12 +1364,9 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
@ -1549,7 +1456,6 @@ class TestLlama3_1NemotronNano8Bv1(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/Llama-3.1-Nemotron-Nano-8B-v1-FP8"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
@ -1601,7 +1507,6 @@ class TestNemotronUltra(LlmapiAccuracyTestHarness):
kv_cache_config=KvCacheConfig(
free_gpu_memory_fraction=0.85)) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
@ -1634,7 +1539,6 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
kv_cache_config=kv_cache_config,
max_batch_size=256) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
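Across these tests the same replacement is applied: instead of building a QuantConfig with kv_cache_quant_algo=FP8 by hand and passing kv_cache_dtype in the PyTorch options, the tests set the dtype on KvCacheConfig and drop the now-redundant kv_cache_quant_algo assertions. A condensed sketch of the pattern (not a verbatim excerpt), assuming the imports already used in this file:

from tensorrt_llm.llmapi import KvCacheConfig

fp8kv = True
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
pytorch_config = dict(disable_overlap_scheduler=True)

# Before: quant_config = QuantConfig(QuantAlgo.FP8)
#         if fp8kv:
#             quant_config.kv_cache_quant_algo = QuantAlgo.FP8
#             pytorch_config["kv_cache_dtype"] = "fp8"

# After: the dtype lives on KvCacheConfig; quant_config.kv_cache_quant_algo
# is derived from it by TorchLlmArgs, so no manual QuantConfig is needed.
if fp8kv:
    kv_cache_config.dtype = "fp8"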


@ -110,14 +110,12 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
worker_pytorch_configs.append(
dict(
disable_overlap_scheduler=True,
kv_cache_dtype="auto",
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))
# Generation worker
worker_pytorch_configs.append(
dict(
disable_overlap_scheduler=not generation_overlap,
kv_cache_dtype="auto",
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))
kv_cache_configs = [KvCacheConfig(max_tokens=2048 * 8) for _ in range(2)]
@ -233,18 +231,16 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
worker_pytorch_configs.append(
dict(
disable_overlap_scheduler=True,
kv_cache_dtype="auto",
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))
# Generation worker
worker_pytorch_configs.append(
dict(
disable_overlap_scheduler=not generation_overlap,
kv_cache_dtype="auto",
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))
kv_cache_configs = [
KvCacheConfig(max_tokens=128, enable_block_reuse=False)
KvCacheConfig(max_tokens=128, enable_block_reuse=False, dtype="auto")
for _ in range(2)
]
model_names = [model_path(model) for _ in range(2)]


@ -17,6 +17,8 @@
Model pytorch yaml config for trtllm-bench perf tests
"""
from tensorrt_llm.llmapi import KvCacheConfig
def recursive_update(d, u):
for k, v in u.items():
@ -186,4 +188,9 @@ def get_model_yaml_config(model_label: str,
}
base_config.update(lora_config)
kv_cache_config = base_config.get('kv_cache_config', KvCacheConfig())
if 'kv_cache_dtype' in base_config:
kv_cache_config.dtype = base_config.pop('kv_cache_dtype', 'auto')
base_config.update({'kv_cache_config': kv_cache_config})
return base_config
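For back compatibility, per-model perf configs that still carry a top-level kv_cache_dtype key are folded into kv_cache_config. A small sketch with an illustrative legacy config (the dict contents are hypothetical):

from tensorrt_llm.llmapi import KvCacheConfig

base_config = {"kv_cache_dtype": "fp8", "enable_chunked_prefill": True}  # hypothetical legacy entry

kv_cache_config = base_config.get("kv_cache_config", KvCacheConfig())
if "kv_cache_dtype" in base_config:
    kv_cache_config.dtype = base_config.pop("kv_cache_dtype", "auto")
base_config["kv_cache_config"] = kv_cache_config
# base_config now carries KvCacheConfig(dtype="fp8") and no kv_cache_dtype key.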


@ -68,7 +68,6 @@ def test_deepseek_trtllmgen(model_name):
pytorch_config = dict(
disable_overlap_scheduler=True,
kv_cache_dtype="auto",
attn_backend="TRTLLM",
load_format="dummy",
moe_config=MoeConfig(backend="TRTLLM"),
@ -89,7 +88,8 @@ def test_deepseek_trtllmgen(model_name):
moe_tensor_parallel_size=-1,
enable_attention_dp=False,
speculative_config=spec_config,
kv_cache_config=KvCacheConfig(enable_block_reuse=False,
kv_cache_config=KvCacheConfig(dtype="auto",
enable_block_reuse=False,
free_gpu_memory_fraction=0.4))
sampling_params = SamplingParams(max_tokens=20)


@ -63,7 +63,6 @@ def test_deepseek_streaming(model_name, backend, quant, tp_size):
pytorch_config = dict(
disable_overlap_scheduler=True,
kv_cache_dtype="auto",
attn_backend=backend,
)
moe_config = MoeConfig(max_num_tokens=moe_max_num_tokens)


@ -57,13 +57,6 @@ methods:
guided_decoding_backend:
annotation: Optional[Literal["xgrammar", "llguidance"]]
default: null
# Quantization and calibration
quant_config:
annotation: Optional[tensorrt_llm.models.modeling_utils.QuantConfig]
default: null
calib_config:
annotation: Optional[tensorrt_llm.llmapi.llm_utils.CalibConfig]
default: null
# Speculative decoding
speculative_config:
annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, NoneType]