diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py index 4ba48d8905..63f812a139 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py @@ -2,6 +2,7 @@ import torch from torch import nn from tensorrt_llm._torch.models.modeling_utils import register_mapper +from tensorrt_llm._torch.modules.linear import W4A16_AWQ_LinearMethod from ..base_weight_mapper import BaseWeightMapper @@ -60,7 +61,10 @@ class HfWeightMapper(BaseWeightMapper): weights: dict): if new_name in ['k_proj', 'v_proj']: # k_proj and v_proj shape is [num_kv_heads*head_dim, hidden_dim] - num_kv_heads = weights['weight'].shape[0] // self._head_dim + if isinstance(module.quant_method, W4A16_AWQ_LinearMethod): + num_kv_heads = weights['weight'].shape[0] * 2 // self._head_dim + else: + num_kv_heads = weights['weight'].shape[0] // self._head_dim processed_weights = { k: self._duplicate_kv(weight=v[:],