Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-04 18:21:52 +08:00
[None][fix] Align kv_scales with modelopt HF checkpoint (#10745)
Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
parent 20946554f6
commit 2532eb5adc
@@ -761,12 +761,10 @@ class FP8QDQLinearMethod(UnquantizedLinearMethod):
             # to avoid overflow when dequantizing NVFP4 in attention kernels.
             copy_weight(
                 module.kv_scales,
-                torch.tensor([
-                    1.0,
-                    max(k_scales).item() * 6.0,
-                    max(v_scales).item() * 6.0
-                ],
-                             dtype=torch.float32))
+                torch.tensor(
+                    [1.0, max(k_scales).item(),
+                     max(v_scales).item()],
+                    dtype=torch.float32))
             module.inv_kv_scales.data = 1.0 / module.kv_scales

             # Clean up temporary attributes
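For context, a minimal self-contained sketch of the convention change in this hunk (the scale values are made up, and the bare torch.tensor calls stand in for the repo's copy_weight helper): modelopt HF checkpoints already store the KV scales at the magnitude the attention kernels expect, so the loader now copies them through unchanged instead of multiplying by 6.0.

    import torch

    # Illustrative calibrated K/V scales as read from a modelopt HF
    # checkpoint (numbers invented for this example).
    k_scales = [torch.tensor(0.012), torch.tensor(0.015)]
    v_scales = [torch.tensor(0.020), torch.tensor(0.018)]

    # Before this commit: rescale by 6.0 at load time.
    old_kv_scales = torch.tensor(
        [1.0, max(k_scales).item() * 6.0, max(v_scales).item() * 6.0],
        dtype=torch.float32)

    # After this commit: take the checkpoint values as-is.
    new_kv_scales = torch.tensor(
        [1.0, max(k_scales).item(), max(v_scales).item()],
        dtype=torch.float32)

    print(old_kv_scales / new_kv_scales)  # tensor([1., 6., 6.])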
@@ -1367,14 +1365,10 @@ class NVFP4LinearMethod(LinearMethodBase):
         if os.environ.get("TRTLLM_LOAD_KV_SCALES", "1") == "1":
             if len(k_scale) != 0:
                 assert len(v_scale) != 0
-                # The calibrated KV scales are amax / (6 * 448), but the requested KV scales are amax / 448,
-                # to avoid overflow when dequantizing NVFP4 in attention kernels using FP8 math.
                 copy_weight(
                     module.kv_scales,
                     torch.tensor(
-                        [1.0, max(k_scale) * 6.0,
-                         max(v_scale) * 6.0],
-                        dtype=torch.float32))
+                        [1.0, max(k_scale), max(v_scale)], dtype=torch.float32))
                 module.inv_kv_scales.data = 1.0 / module.kv_scales

     def load_weights_fused_gate_up_linear(self, module: Linear,
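The NVFP4 path applies the same alignment, gated by the TRTLLM_LOAD_KV_SCALES environment variable. A hedged standalone sketch of that gating and the reciprocal bookkeeping (load_kv_scales and the SimpleNamespace module are hypothetical stand-ins for the repo's Linear module and copy_weight helper):

    import os
    from types import SimpleNamespace

    import torch

    def load_kv_scales(module, k_scale, v_scale):
        # Loading is enabled by default, matching the "1" fallback in the diff.
        if os.environ.get("TRTLLM_LOAD_KV_SCALES", "1") != "1":
            return
        if len(k_scale) == 0:
            return
        assert len(v_scale) != 0
        # Scales come straight from the checkpoint; no 6.0 factor anymore.
        module.kv_scales = torch.tensor(
            [1.0, max(k_scale), max(v_scale)], dtype=torch.float32)
        # Kernels that prefer a multiply keep the reciprocal around.
        module.inv_kv_scales = 1.0 / module.kv_scales

    module = SimpleNamespace()
    load_kv_scales(module, k_scale=[0.013, 0.011], v_scale=[0.017])
    print(module.kv_scales, module.inv_kv_scales)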
@@ -345,7 +345,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
                 disable_overlap_scheduler=torch_compile,
             )
             pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="nvfp4")
-            with LLM(f"{llm_models_root()}/Llama-3_1-8B-Instruct_nvfp4_fp8_hf",
+            with LLM(f"{llm_models_root()}/Llama-3_1-8B-Instruct_fp8_kv_nvfp4",
                      **pytorch_config) as llm:
                 assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
                 assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.NVFP4
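The updated accuracy test exercises exactly this combination: FP8 weights with an NVFP4 KV cache. A rough standalone equivalent of what the test does (the model path is a placeholder, and the import paths are my assumption about where these names live, mirroring how the test file uses them):

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import KvCacheConfig
    from tensorrt_llm.quantization import QuantAlgo

    # Placeholder path; the test resolves it with llm_models_root().
    model_dir = "/models/Llama-3_1-8B-Instruct_fp8_kv_nvfp4"

    with LLM(model_dir,
             kv_cache_config=KvCacheConfig(dtype="nvfp4")) as llm:
        # The checkpoint carries FP8 weight quantization plus an
        # NVFP4-quantized KV cache.
        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
        assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.NVFP4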