From 2532eb5adc8af6e4c5a62fe25680ab233b7d25a0 Mon Sep 17 00:00:00 2001
From: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
Date: Tue, 3 Feb 2026 05:03:42 -0800
Subject: [PATCH] [None][fix] Align kv_scales with modelopt HF checkpoint
 (#10745)

Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/linear.py     | 16 +++++-----------
 .../defs/accuracy/test_llm_api_pytorch.py |  2 +-
 2 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 6f601455f4..f697ea5e15 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -761,12 +761,10 @@ class FP8QDQLinearMethod(UnquantizedLinearMethod):
         # to avoid overflow when dequantizing NVFP4 in attention kernels.
         copy_weight(
             module.kv_scales,
-            torch.tensor([
-                1.0,
-                max(k_scales).item() * 6.0,
-                max(v_scales).item() * 6.0
-            ],
-                         dtype=torch.float32))
+            torch.tensor(
+                [1.0, max(k_scales).item(),
+                 max(v_scales).item()],
+                dtype=torch.float32))
         module.inv_kv_scales.data = 1.0 / module.kv_scales
 
         # Clean up temporary attributes
@@ -1367,14 +1365,10 @@ class NVFP4LinearMethod(LinearMethodBase):
         if os.environ.get("TRTLLM_LOAD_KV_SCALES", "1") == "1":
             if len(k_scale) != 0:
                 assert len(v_scale) != 0
-                # The calibrated KV scales are amax / (6 * 448), but the requested KV scales are amax / 448,
-                # to avoid overflow when dequantizing NVFP4 in attention kernels using FP8 math.
                 copy_weight(
                     module.kv_scales,
-                    torch.tensor(
-                        [1.0, max(k_scale) * 6.0,
-                         max(v_scale) * 6.0],
-                        dtype=torch.float32))
+                    torch.tensor(
+                        [1.0, max(k_scale), max(v_scale)], dtype=torch.float32))
                 module.inv_kv_scales.data = 1.0 / module.kv_scales
 
     def load_weights_fused_gate_up_linear(self, module: Linear,
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 969e26e4e1..cef8dbdd27 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -345,7 +345,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
             disable_overlap_scheduler=torch_compile,
         )
         pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="nvfp4")
-        with LLM(f"{llm_models_root()}/Llama-3_1-8B-Instruct_nvfp4_fp8_hf",
+        with LLM(f"{llm_models_root()}/Llama-3_1-8B-Instruct_fp8_kv_nvfp4",
                  **pytorch_config) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.NVFP4
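
A minimal sketch of the scale arithmetic this patch touches, assuming the
convention stated in the removed comments: modelopt calibrates the per-tensor
KV scale as amax / (6 * 448), where 6 is the FP4 (E2M1) max and 448 the FP8
(E4M3) max. The helper name and amax values below are illustrative, not
TensorRT-LLM API. The fix makes the loader take the checkpoint's calibrated
scale verbatim instead of rescaling it by 6.0.

    import torch

    FP4_MAX = 6.0    # largest representable E2M1 (FP4) magnitude
    FP8_MAX = 448.0  # largest representable E4M3 (FP8) magnitude

    def per_tensor_kv_scale(amax: float) -> float:
        # Calibrated scale as assumed to be stored in the modelopt HF checkpoint.
        return amax / (FP4_MAX * FP8_MAX)

    k_scales = [per_tensor_kv_scale(a) for a in (0.9, 1.2)]  # example amax values
    v_scales = [per_tensor_kv_scale(a) for a in (0.7, 1.0)]

    # After this patch the loader builds kv_scales from the checkpoint values
    # directly (previously it multiplied the k/v entries by 6.0):
    kv_scales = torch.tensor([1.0, max(k_scales), max(v_scales)],
                             dtype=torch.float32)
    inv_kv_scales = 1.0 / kv_scales

The leading 1.0 matches the first element of module.kv_scales in both hunks
and is left unchanged by the patch.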