[None][fix] Align kv_scales with modelopt HF checkpoint (#10745)

Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
Chenjie Luo 2026-02-03 05:03:42 -08:00 committed by GitHub
parent 20946554f6
commit 2532eb5adc
2 changed files with 6 additions and 12 deletions
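
The functional change: the loaders no longer multiply the calibrated K/V scales by 6.0 when filling module.kv_scales, so the values exported in the modelopt HF checkpoint are used as-is. A minimal before/after sketch, assuming k_scales and v_scales are lists of per-head calibrated scale tensors read from the checkpoint (the names mirror the diff; the surrounding module and copy_weight plumbing is omitted, and the numeric values are invented):

import torch

# Invented example values standing in for per-head calibrated scales
# loaded from the modelopt HF checkpoint (names mirror the diff).
k_scales = [torch.tensor(0.0125), torch.tensor(0.0130)]
v_scales = [torch.tensor(0.0100), torch.tensor(0.0095)]

# Old behavior (removed by this commit): the checkpoint scale was assumed
# to be amax / (6 * 448), so it was multiplied by 6.0 to obtain amax / 448.
old_kv_scales = torch.tensor(
    [1.0, max(k_scales).item() * 6.0, max(v_scales).item() * 6.0],
    dtype=torch.float32)

# New behavior: the checkpoint value is copied through unchanged.
new_kv_scales = torch.tensor(
    [1.0, max(k_scales).item(), max(v_scales).item()],
    dtype=torch.float32)

# The inverse scales are derived the same way in both versions.
inv_kv_scales = 1.0 / new_kv_scales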


@@ -761,12 +761,10 @@ class FP8QDQLinearMethod(UnquantizedLinearMethod):
             # to avoid overflow when dequantizing NVFP4 in attention kernels.
             copy_weight(
                 module.kv_scales,
-                torch.tensor([
-                    1.0,
-                    max(k_scales).item() * 6.0,
-                    max(v_scales).item() * 6.0
-                ],
-                    dtype=torch.float32))
+                torch.tensor(
+                    [1.0, max(k_scales).item(),
+                     max(v_scales).item()],
+                    dtype=torch.float32))
             module.inv_kv_scales.data = 1.0 / module.kv_scales
             # Clean up temporary attributes
@@ -1367,14 +1365,10 @@ class NVFP4LinearMethod(LinearMethodBase):
         if os.environ.get("TRTLLM_LOAD_KV_SCALES", "1") == "1":
             if len(k_scale) != 0:
                 assert len(v_scale) != 0
-                # The calibrated KV scales are amax / (6 * 448), but the requested KV scales are amax / 448,
-                # to avoid overflow when dequantizing NVFP4 in attention kernels using FP8 math.
                 copy_weight(
                     module.kv_scales,
                     torch.tensor(
-                        [1.0, max(k_scale) * 6.0,
-                         max(v_scale) * 6.0],
-                        dtype=torch.float32))
+                        [1.0, max(k_scale), max(v_scale)], dtype=torch.float32))
                 module.inv_kv_scales.data = 1.0 / module.kv_scales

    def load_weights_fused_gate_up_linear(self, module: Linear,
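
The comment removed in the hunk above records the old rationale: the calibrated KV scales were stored as amax / (6 * 448) while the attention kernels expect amax / 448, hence the 6.0 multiplier; with the loaders now aligned to the modelopt HF checkpoint, the stored value is presumably already in the convention the kernels expect. A rough arithmetic check of that factor-of-6 relationship (amax here is a made-up calibration maximum; 448 is the largest finite FP8 E4M3 value and 6 the largest FP4 E2M1 value):

amax = 44.8  # invented calibration maximum, for illustration only

calibrated_scale = amax / (6 * 448)  # convention cited in the removed comment
requested_scale = amax / 448         # what the kernels expect per that comment

# The old loader recovered the requested scale by multiplying by 6.0;
# the two conventions differ by exactly the FP4 (E2M1) maximum of 6.
assert abs(calibrated_scale * 6.0 - requested_scale) < 1e-9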


@@ -345,7 +345,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
             disable_overlap_scheduler=torch_compile,
         )
         pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="nvfp4")
-        with LLM(f"{llm_models_root()}/Llama-3_1-8B-Instruct_nvfp4_fp8_hf",
+        with LLM(f"{llm_models_root()}/Llama-3_1-8B-Instruct_fp8_kv_nvfp4",
                  **pytorch_config) as llm:
             assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.NVFP4