Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-5530][BREAKING CHANGE] refactor: unify KvCacheConfig in LLM class for pytorch backend (#5752)
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Parent: 10349b54df
Commit: a02606a9e2
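This is a breaking change for the PyTorch backend of the LLM API: the standalone kv_cache_dtype argument is removed and the KV-cache data type now lives on KvCacheConfig. A minimal migration sketch, assuming a placeholder model path and illustrative values; the argument names follow the diff below:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Before this commit (PyTorch backend): dtype was a separate LLM argument.
# llm = LLM(model="/path/to/model", kv_cache_dtype="fp8")

# After this commit: dtype is a field of KvCacheConfig.
llm = LLM(model="/path/to/model",
          kv_cache_config=KvCacheConfig(dtype="fp8",
                                        enable_block_reuse=True,
                                        free_gpu_memory_fraction=0.9))
```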
@@ -149,6 +149,7 @@ def setup_llm(args):
kv_cache_config = KvCacheConfig(
enable_block_reuse=not args.disable_kv_cache_reuse,
free_gpu_memory_fraction=args.kv_cache_fraction,
dtype=args.kv_cache_dtype,
)

spec_decode_algo = args.spec_decode_algo.upper(

@@ -194,7 +195,6 @@ def setup_llm(args):
model=args.model_dir,
backend='pytorch',
disable_overlap_scheduler=args.disable_overlap_scheduler,
kv_cache_dtype=args.kv_cache_dtype,
kv_cache_config=kv_cache_config,
attn_backend=args.attention_backend,
cuda_graph_config=cuda_graph_config,

@@ -88,12 +88,14 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
enable_chunked_prefill = params.get("enable_chunked_prefill", False)

kv_cache_dtype = "auto"
kv_cache_config = {}
if extra_llm_api_options:
with open(extra_llm_api_options, 'r') as f:
llm_args_dict = yaml.safe_load(f)

if "kv_cache_dtype" in llm_args_dict:
kv_cache_dtype = llm_args_dict["kv_cache_dtype"]
kv_cache_config = llm_args_dict.get("kv_cache_config", {
"dtype": "auto",
})
kv_cache_dtype = kv_cache_config.get("dtype", "auto")

enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
enable_chunked_prefill)

@@ -158,9 +160,11 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
"max_batch_size": max_batch_size
}

kv_cache_config["dtype"] = kv_cache_dtype

pyt_options = {
"cuda_graph_config": cuda_graph_config,
"kv_cache_dtype": kv_cache_dtype,
"kv_cache_config": kv_cache_config,
}

backend = params.get("backend", "pytorch")
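With this change, trtllm-bench reads the KV-cache dtype from the kv_cache_config section of the extra LLM API options file instead of a top-level kv_cache_dtype key. A small sketch of that parsing path; the YAML content and values are illustrative, not taken from this commit:

```python
import yaml

# Hypothetical extra-llm-api-options content.
extra_llm_api_options_text = """
kv_cache_config:
  dtype: fp8            # replaces the old top-level `kv_cache_dtype` key
  enable_block_reuse: true
enable_chunked_prefill: true
"""

llm_args_dict = yaml.safe_load(extra_llm_api_options_text)
kv_cache_config = llm_args_dict.get("kv_cache_config", {"dtype": "auto"})
kv_cache_dtype = kv_cache_config.get("dtype", "auto")
print(kv_cache_dtype)  # -> "fp8"
```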
@@ -112,7 +112,6 @@ class PerformanceOptions:
def get_autodeploy_perf_config(self) -> Dict:
AutoDeployPerfConfig = dict
ad_config = AutoDeployPerfConfig()
ad_config["kv_cache_dtype"] = "auto"
ad_config["attn_backend"] = "flashinfer"
return ad_config

@@ -11,6 +11,7 @@ from tensorrt_llm.bench.dataclasses.general import DatasetMetadata
from tensorrt_llm.bench.dataclasses.statistics import (BenchmarkStatistics,
PercentileStats,
RequestRecord)
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.logger import Logger
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode

@@ -275,8 +276,17 @@ class ReportUtility:
model = self.rt_cfg.model_path or self.rt_cfg.model
model_config = ModelConfig.from_pretrained(model,
trust_remote_code=True)
validate_and_set_kv_cache_quant(model_config,
self.kwargs["kv_cache_dtype"])
kv_cache_config = self.kwargs.get("kv_cache_config",
KvCacheConfig())
if isinstance(kv_cache_config, KvCacheConfig):
kv_cache_dtype = kv_cache_config.dtype
elif isinstance(kv_cache_config, dict):
kv_cache_dtype = kv_cache_config.get("dtype", "auto")
else:
raise ValueError(
f"Invalid kv_cache_config type: {type(kv_cache_config)}.")

validate_and_set_kv_cache_quant(model_config, kv_cache_dtype)

stats_dict["engine"] |= {
"backend":

@@ -821,6 +821,10 @@ class KvCacheConfig(BaseModel, PybindMirror):
use_uvm: bool = Field(default=False,
description="Whether to use UVM for the KV cache.")

# This is a pure python field, not a pybind field. It is only for the Pytorch backend.
dtype: str = Field(default="auto",
description="The data type to use for the KV cache.")

def _to_pybind(self):
return _KvCacheConfig(
enable_block_reuse=self.enable_block_reuse,

@@ -1024,10 +1028,6 @@ class BaseLlmArgs(BaseModel):
lora_config: Optional[LoraConfig] = Field(
default=None, description="LoRA configuration for the model.")

# Quantization and calibration configurations
quant_config: Optional[QuantConfig] = Field(
default=None, description="Quantization config.", validate_default=True)

# Several options from ExecutorConfig, expanded here for less hierarchy
kv_cache_config: KvCacheConfig = Field(default_factory=KvCacheConfig,
description="KV cache config.")

@@ -1208,13 +1208,6 @@ class BaseLlmArgs(BaseModel):
raise RuntimeError("Pre SM 80 GPUs do not support bfloat16")
return v

@field_validator("quant_config", mode='before')
@classmethod
def validate_quant_config(cls, v, info):
if v is None:
v = QuantConfig()
return v

@field_validator("gpus_per_node", mode='before')
@classmethod
def validate_gpus_per_node(cls, v, info):

@@ -1657,6 +1650,10 @@ class TrtLlmArgs(BaseLlmArgs):
calib_config: Optional[CalibConfig] = Field(
default=None, description="Calibration config.", validate_default=True)

# Quantization and calibration configurations
quant_config: Optional[QuantConfig] = Field(
default=None, description="Quantization config.", validate_default=True)

embedding_parallel_mode: str = Field(
default='SHARDING_ALONG_VOCAB',
description="The embedding parallel mode.")

@@ -1694,6 +1691,13 @@ class TrtLlmArgs(BaseLlmArgs):
return CalibConfig()
return v

@field_validator("quant_config", mode='before')
@classmethod
def validate_quant_config(cls, v, info):
if v is None:
v = QuantConfig()
return v

@model_validator(mode="after")
def setup_embedding_parallel_mode(self):
if self.embedding_parallel_mode == 'NONE':

@@ -1738,6 +1742,11 @@ class TrtLlmArgs(BaseLlmArgs):
f"Invalid build_cache_config: {self.enable_build_cache}")
return self

@model_validator(mode="after")
def validate_kv_cache_dtype(self):
assert self.kv_cache_config.dtype == "auto", "KvCacheConfig.dtype is not supported by the TensorRT backend."
return self

class LoadFormat(Enum):
AUTO = 0
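The new TrtLlmArgs validator rejects any non-default KvCacheConfig.dtype, so the dtype override is effectively PyTorch-only. Below is a self-contained, simplified model of that validator pattern using plain pydantic; the *Sketch class names are hypothetical and this is not the real TrtLlmArgs:

```python
from pydantic import BaseModel, model_validator

class KvCacheConfigSketch(BaseModel):
    dtype: str = "auto"

class TrtLlmArgsSketch(BaseModel):
    kv_cache_config: KvCacheConfigSketch = KvCacheConfigSketch()

    @model_validator(mode="after")
    def validate_kv_cache_dtype(self):
        # Mirrors the assert added in the diff above.
        assert self.kv_cache_config.dtype == "auto", \
            "KvCacheConfig.dtype is not supported by the TensorRT backend."
        return self

TrtLlmArgsSketch()  # fine: dtype defaults to "auto"
try:
    TrtLlmArgsSketch(kv_cache_config=KvCacheConfigSketch(dtype="fp8"))
except Exception as err:  # pydantic surfaces the assertion as a ValidationError
    print(type(err).__name__)
```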
@@ -1811,9 +1820,6 @@ class TorchLlmArgs(BaseLlmArgs):
"If true, will use the TRTLLM sampler instead of the PyTorch sampler. The TRTLLM sampler has a wide coverage of sampling strategies."
)

kv_cache_dtype: str = Field(default="auto",
description="Data type for KV cache.")

enable_iter_perf_stats: bool = Field(
default=False, description="Enable iteration performance statistics.")

@@ -1867,6 +1873,19 @@ class TorchLlmArgs(BaseLlmArgs):
'MNNVL']] = Field(default='AUTO',
description="Allreduce strategy to use.")

# PrivateVars
_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)

@property
def quant_config(self) -> QuantConfig:
if self._quant_config is None:
self._quant_config = QuantConfig()
return self._quant_config

@quant_config.setter
def quant_config(self, value: QuantConfig):
self._quant_config = value

# TODO: remove backend later
@field_validator('backend', mode='before')
def init_backend(cls, v):

@@ -1994,6 +2013,22 @@ class TorchLlmArgs(BaseLlmArgs):

return self

@model_validator(mode='after')
def sync_quant_config_with_kv_cache_config_dtype(self) -> 'TorchLlmArgs':
if self.kv_cache_config is None:
return self

assert self.quant_config is not None
if self.kv_cache_config.dtype == "auto":
return self
elif self.kv_cache_config.dtype == 'fp8':
self.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
else:
logger.warning(
f"Cannot sync quant_config.kv_cache_quant_algo with kv_cache_config.dtype of {self.kv_cache_config.dtype}, "
"please update the validator")
return self

# TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
def get_pytorch_backend_config(self) -> "PyTorchConfig":
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
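The sync_quant_config_with_kv_cache_config_dtype validator is what ties the two configuration surfaces together on the PyTorch backend: KvCacheConfig(dtype="fp8") is mirrored into quant_config.kv_cache_quant_algo, which the accuracy tests later in this diff assert. A hedged usage sketch; the model path is a placeholder:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.quantization import QuantAlgo

with LLM("/path/to/Llama-3.1-8B-Instruct-FP8",
         kv_cache_config=KvCacheConfig(dtype="fp8")) as llm:
    # The dtype override is reflected in the derived quantization config.
    assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
```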
@@ -2017,7 +2052,7 @@ class TorchLlmArgs(BaseLlmArgs):
moe_backend=self.moe_config.backend,
enable_mixed_sampler=self.enable_mixed_sampler,
enable_trtllm_sampler=self.enable_trtllm_sampler,
kv_cache_dtype=self.kv_cache_dtype,
kv_cache_dtype=self.kv_cache_config.dtype,
enable_iter_perf_stats=self.enable_iter_perf_stats,
enable_iter_req_stats=self.enable_iter_req_stats,
print_iter_log=self.print_iter_log,

@@ -401,6 +401,9 @@ class ModelLoader:
logger.info(f"Setting {key}={value} from HF quant config.")
setattr(quant_config, key, value)

# Update the quant_config in llm_args for pytorch
self.llm_args.quant_config = quant_config

return True

hf_config_path = f"{self._model_dir}/config.json"
@@ -23,7 +23,6 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, SamplingParams,
TorchCompileConfig)
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo

from ..conftest import (llm_models_root, parametrize_with_ids,

@@ -51,7 +50,6 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)

@@ -67,7 +65,6 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
with LLM(f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B",
stream_interval=stream_interval) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
assert llm.args.stream_interval == stream_interval
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm, streaming=True)

@@ -143,7 +140,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
@parametrize_with_ids("attn_backend", ["TRTLLM", "FLASHINFER"])
@parametrize_with_ids("fp8kv", [False, True])
def test_fp8(self, fp8kv, attn_backend, torch_compile):
quant_config = QuantConfig(QuantAlgo.FP8)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(

@@ -154,15 +150,11 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=torch_compile,
)
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8")
with LLM(
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
quant_config=quant_config,
**pytorch_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@@ -181,7 +173,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
"Pipeline parallel with torch.compile is not supported yet.\n"
"Issue: Unfusing flashinfer_fused_add_rmsnorm causes outputs to be "
"discarded at graph breaks.")
quant_config = QuantConfig(QuantAlgo.FP8)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(

@@ -192,17 +183,13 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=torch_compile,
)
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8")
with LLM(
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
quant_config=quant_config,
**pytorch_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@@ -336,7 +323,6 @@ class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/llama-3.2-models/Llama-3.2-1B-FP8"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)

@@ -358,7 +344,6 @@ class TestLlama3_2_3B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/llama-3.2-models/Llama-3.2-3B-Instruct-FP8"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)

@@ -401,7 +386,6 @@ class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp4"
with LLM(model_path, tensor_parallel_size=4) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)

@@ -583,7 +567,6 @@ class TestMixtral8x7B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp8"
with LLM(model_path, tensor_parallel_size=2) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)

@@ -595,7 +578,6 @@ class TestMixtral8x7B(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp4"
with LLM(model_path, tensor_parallel_size=2) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
@@ -716,11 +698,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile_config=torch_compile_config,
)

quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

mtp_config = None
mtp_nextn = 2

@@ -733,13 +712,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:

assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8

task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@@ -775,11 +751,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_config=MoeConfig(backend="CUTEDSL"),
)

quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

mtp_config = None
if mtp_nextn > 0:

@@ -789,14 +762,11 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config,
) as llm:

assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8

task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@@ -837,14 +807,11 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(enable_padding=True),
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES

with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
tensor_parallel_size=4,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES

@@ -886,11 +853,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
torch_compile_config=torch_compile_config,
)

quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

mtp_config = None
if mtp_nextn > 0:

@@ -902,13 +866,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:

assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8

task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@@ -955,11 +916,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_config=MoeConfig(backend="CUTEDSL"),
)

quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

mtp_config = None
if mtp_nextn > 0:

@@ -972,13 +930,10 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config,
) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8

task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@@ -1045,23 +1000,20 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
num_slots = 80
eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
layer_updates_per_iter=2)
pytorch_backend_options = dict(cuda_graph_config=CudaGraphConfig(),
moe_config=MoeConfig(
backend="WIDEEP",
load_balancer=eplb_config))
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
pytorch_config = dict(cuda_graph_config=CudaGraphConfig(),
moe_config=MoeConfig(backend="WIDEEP",
load_balancer=eplb_config))
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_backend_options["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
tensor_parallel_size=4,
moe_expert_parallel_size=4,
kv_cache_config=kv_cache_config,
**pytorch_backend_options,
enable_attention_dp=True,
quant_config=quant_config) as llm:
with LLM(
f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
tensor_parallel_size=4,
moe_expert_parallel_size=4,
kv_cache_config=kv_cache_config,
**pytorch_config,
enable_attention_dp=True,
) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@@ -1095,21 +1047,15 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)

quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8

task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@@ -1157,11 +1103,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)

quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
tensor_parallel_size=tp_size,

@@ -1169,12 +1112,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8

task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@@ -1215,21 +1155,13 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):

if quant_dtype == "none":
assert not fp8kv
quant_config = None
else:
quant_config = QuantConfig()
if quant_dtype == "fp8":
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
elif quant_dtype == "nvfp4":
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

with LLM(model_path,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
if quant_dtype == "fp8":

@@ -1237,8 +1169,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
elif quant_dtype == "nvfp4":
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4

if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@@ -1275,23 +1205,15 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):

if quant_dtype == "none":
assert not fp8kv
quant_config = None
else:
quant_config = QuantConfig()
if quant_dtype == "fp8":
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
elif quant_dtype == "nvfp4":
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

with LLM(model_path,
kv_cache_config=kv_cache_config,
enable_chunked_prefill=True,
max_num_tokens=512,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=True,
speculative_config=mtp_config) as llm:

@@ -1300,9 +1222,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
elif quant_dtype == "nvfp4":
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4

if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8

task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@@ -1388,11 +1307,8 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
moe_config=MoeConfig(backend=moe_backend))

quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

mtp_config = None
if mtp_nextn > 0:

@@ -1404,14 +1320,11 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:

assert llm.args.moe_backend == moe_backend
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8

task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@@ -1438,11 +1351,8 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
)

quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_config["kv_cache_dtype"] = "fp8"
kv_cache_config.dtype = "fp8"

mtp_config = None
if mtp_nextn > 0:

@@ -1454,12 +1364,9 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
if fp8kv:
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8

task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@@ -1549,7 +1456,6 @@ class TestLlama3_1NemotronNano8Bv1(LlmapiAccuracyTestHarness):
model_path = f"{llm_models_root()}/Llama-3.1-Nemotron-Nano-8B-v1-FP8"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)

@@ -1601,7 +1507,6 @@ class TestNemotronUltra(LlmapiAccuracyTestHarness):
kv_cache_config=KvCacheConfig(
free_gpu_memory_fraction=0.85)) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)

@@ -1634,7 +1539,6 @@ class TestNemotronH(LlmapiAccuracyTestHarness):
kv_cache_config=kv_cache_config,
max_batch_size=256) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
@@ -110,14 +110,12 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
worker_pytorch_configs.append(
dict(
disable_overlap_scheduler=True,
kv_cache_dtype="auto",
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))

# Generation worker
worker_pytorch_configs.append(
dict(
disable_overlap_scheduler=not generation_overlap,
kv_cache_dtype="auto",
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))

kv_cache_configs = [KvCacheConfig(max_tokens=2048 * 8) for _ in range(2)]

@@ -233,18 +231,16 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
worker_pytorch_configs.append(
dict(
disable_overlap_scheduler=True,
kv_cache_dtype="auto",
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))

# Generation worker
worker_pytorch_configs.append(
dict(
disable_overlap_scheduler=not generation_overlap,
kv_cache_dtype="auto",
cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))

kv_cache_configs = [
KvCacheConfig(max_tokens=128, enable_block_reuse=False)
KvCacheConfig(max_tokens=128, enable_block_reuse=False, dtype="auto")
for _ in range(2)
]
model_names = [model_path(model) for _ in range(2)]

@@ -17,6 +17,8 @@
Model pytorch yaml config for trtllm-bench perf tests
"""

from tensorrt_llm.llmapi import KvCacheConfig


def recursive_update(d, u):
for k, v in u.items():

@@ -186,4 +188,9 @@ def get_model_yaml_config(model_label: str,
}
base_config.update(lora_config)

kv_cache_config = base_config.get('kv_cache_config', KvCacheConfig())
if 'kv_cache_dtype' in base_config:
kv_cache_config.dtype = base_config.pop('kv_cache_dtype', 'auto')
base_config.update({'kv_cache_config': kv_cache_config})

return base_config
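get_model_yaml_config now folds a legacy top-level kv_cache_dtype entry from the per-model perf config into kv_cache_config. A small sketch of that mapping with an illustrative input dict:

```python
# Illustrative only: mirrors the kv_cache_dtype -> kv_cache_config.dtype migration above.
from tensorrt_llm.llmapi import KvCacheConfig

base_config = {"kv_cache_dtype": "fp8", "max_batch_size": 256}

kv_cache_config = base_config.get("kv_cache_config", KvCacheConfig())
if "kv_cache_dtype" in base_config:
    kv_cache_config.dtype = base_config.pop("kv_cache_dtype", "auto")
base_config.update({"kv_cache_config": kv_cache_config})

assert base_config["kv_cache_config"].dtype == "fp8"
assert "kv_cache_dtype" not in base_config
```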
@@ -68,7 +68,6 @@ def test_deepseek_trtllmgen(model_name):

pytorch_config = dict(
disable_overlap_scheduler=True,
kv_cache_dtype="auto",
attn_backend="TRTLLM",
load_format="dummy",
moe_config=MoeConfig(backend="TRTLLM"),

@@ -89,7 +88,8 @@ def test_deepseek_trtllmgen(model_name):
moe_tensor_parallel_size=-1,
enable_attention_dp=False,
speculative_config=spec_config,
kv_cache_config=KvCacheConfig(enable_block_reuse=False,
kv_cache_config=KvCacheConfig(dtype="auto",
enable_block_reuse=False,
free_gpu_memory_fraction=0.4))

sampling_params = SamplingParams(max_tokens=20)

@@ -63,7 +63,6 @@ def test_deepseek_streaming(model_name, backend, quant, tp_size):

pytorch_config = dict(
disable_overlap_scheduler=True,
kv_cache_dtype="auto",
attn_backend=backend,
)
moe_config = MoeConfig(max_num_tokens=moe_max_num_tokens)

@@ -57,13 +57,6 @@ methods:
guided_decoding_backend:
annotation: Optional[Literal["xgrammar", "llguidance"]]
default: null
# Quantization and calibration
quant_config:
annotation: Optional[tensorrt_llm.models.modeling_utils.QuantConfig]
default: null
calib_config:
annotation: Optional[tensorrt_llm.llmapi.llm_utils.CalibConfig]
default: null
# Speculative decoding
speculative_config:
annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, NoneType]