diff --git a/examples/pytorch/out_of_tree_example/modeling_opt.py b/examples/pytorch/out_of_tree_example/modeling_opt.py
index 320a431bc7..58715c44c3 100644
--- a/examples/pytorch/out_of_tree_example/modeling_opt.py
+++ b/examples/pytorch/out_of_tree_example/modeling_opt.py
@@ -239,7 +239,7 @@ class OPTForCausalLM(DecoderModelForCausalLM[OPTModel, OPTConfig]):

     def load_weights(self, weights: dict):
         tp_size = self.model_config.mapping.tp_size
-        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        num_kv_heads = self.model_config.pretrained_config.num_attention_heads

         def filter_weights(prefix: str, weights: dict):
             result = {}
@@ -280,7 +280,7 @@ class OPTForCausalLM(DecoderModelForCausalLM[OPTModel, OPTConfig]):
                         k:
                         duplicate_kv_weight(
                             weight=v[:],
-                            head_dim=head_dim,
+                            num_kv_heads=num_kv_heads,
                             tensor_parallel_size=tp_size)
                         if k in ['weight', 'bias'] else v
                         for k, v in fw.items()
diff --git a/tensorrt_llm/_torch/models/modeling_gemma3.py b/tensorrt_llm/_torch/models/modeling_gemma3.py
index 5e1fca9d69..ea68e18984 100644
--- a/tensorrt_llm/_torch/models/modeling_gemma3.py
+++ b/tensorrt_llm/_torch/models/modeling_gemma3.py
@@ -328,9 +328,7 @@ class Gemma3ForCausalLM(DecoderModelForCausalLM[Gemma3TextModel,
     # minor change for Gemma3 RMSNorm.
     def load_weights(self, weights: Dict):
         tp_size = self.model_config.mapping.tp_size
-        head_dim = getattr(
-            self.config, "head_dim",
-            self.config.hidden_size // self.config.num_attention_heads)
+        num_kv_heads = self.config.num_key_value_heads

         params_map = {
             'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -364,7 +362,7 @@
                         k:
                         duplicate_kv_weight(
                             weight=v[:],
-                            head_dim=head_dim,
+                            num_kv_heads=num_kv_heads,
                             tensor_parallel_size=tp_size)
                         if k in ["weight", "bias"] else v
                         for k, v in fw.items()
diff --git a/tensorrt_llm/_torch/models/modeling_mllama.py b/tensorrt_llm/_torch/models/modeling_mllama.py
index ffa5693f17..16ec672539 100644
--- a/tensorrt_llm/_torch/models/modeling_mllama.py
+++ b/tensorrt_llm/_torch/models/modeling_mllama.py
@@ -333,8 +333,6 @@ class MllamaForConditionalGeneration(nn.Module):
         tp_size = self.config.mapping.tp_size
         vision_config = self.config.pretrained_config.vision_config
         text_config = self.config.pretrained_config.text_config
-        text_head_dim = text_config.hidden_size // text_config.num_attention_heads
-        vision_head_dim = vision_config.hidden_size // vision_config.attention_heads

         params_map = {
             'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -347,7 +345,7 @@
             # skip load weights if tie word embeddings is enabled and layer is lm_head
             if text_config.tie_word_embeddings and "lm_head" in name:
                 continue
-            head_dim = vision_head_dim if "vision_model" in name else text_head_dim
+            num_kv_heads = vision_config.num_key_value_heads if "vision_model" in name else text_config.num_key_value_heads

             names = name.split('.')
             if names[-1] in params_map:
@@ -360,7 +358,7 @@
                         k:
                         duplicate_kv_weight(
                             weight=v[:],
-                            head_dim=head_dim,
+                            num_kv_heads=num_kv_heads,
                             tensor_parallel_size=tp_size)
                         if k in ["weight", "bias"] else v
                         for k, v in fw.items()
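Editor note on the recurring `head_dim` → `num_kv_heads` signature change above and below: deriving the KV head count as `weight.shape[0] // head_dim` is only valid when a tensor's first dimension equals `num_kv_heads * head_dim`. Quantization scale tensors, such as the FP8 `weight_scale_inv` that the Qwen3-MoE hunk below starts duplicating, have a block-scaled first dimension, so that derivation miscounts heads. A minimal sketch with invented shapes; the block size of 128 is an assumption for illustration, not taken from this diff:

```python
import torch

# Invented GQA shapes: 4 KV heads of head_dim 128, hidden size 2048,
# FP8 block scales with an ASSUMED block size of 128.
num_kv_heads, head_dim, hidden, block = 4, 128, 2048, 128

kv_weight = torch.empty(num_kv_heads * head_dim, hidden)   # (512, 2048)
scale_inv = torch.empty(num_kv_heads * head_dim // block,  # (4, 16)
                        hidden // block)

# Old derivation: correct for the weight, nonsense for the scale tensor.
assert kv_weight.shape[0] // head_dim == num_kv_heads
assert scale_inv.shape[0] // head_dim != num_kv_heads      # 4 // 128 == 0

# New derivation: the per-head row count comes from the tensor itself,
# so the same duplication code can handle weights and scales alike.
assert kv_weight.shape[0] % num_kv_heads == 0              # 128 rows per head
assert scale_inv.shape[0] % num_kv_heads == 0              # 1 row per head
```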
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
index 1598a1bb03..258dcee96d 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -394,9 +394,7 @@ class Qwen3MoeForCausalLM(SpecDecOneEngineForCausalLM[Qwen3MoEModel,
         tp_size = self.model_config.mapping.tp_size
         enable_attention_dp = self.model_config.mapping.enable_attention_dp

-        head_dim = getattr(
-            self.config, "head_dim",
-            self.config.hidden_size // self.config.num_attention_heads)
+        num_kv_heads = self.config.num_key_value_heads

         params_map = {
             "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -419,11 +417,14 @@
                     tensors_need_duplication = ["weight", "bias"]
                     if module.quant_config.quant_mode.has_nvfp4():
                         tensors_need_duplication.append("weight_scale")
+                    if module.quant_config.quant_mode.has_fp8_block_scales(
+                    ):
+                        tensors_need_duplication.append("weight_scale_inv")
                     if new_name in ["k_proj", "v_proj"]:
                         fw = {
                             k: (duplicate_kv_weight(
                                 weight=v[:],
-                                head_dim=head_dim,
+                                num_kv_heads=num_kv_heads,
                                 tensor_parallel_size=tp_size
                                 if not enable_attention_dp else 1)
                                 if k in tensors_need_duplication else v)
diff --git a/tensorrt_llm/_torch/models/modeling_qwen_moe.py b/tensorrt_llm/_torch/models/modeling_qwen_moe.py
index 7eff895278..ed19a65d3c 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen_moe.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen_moe.py
@@ -256,7 +256,7 @@ class Qwen2MoeForCausalLM(DecoderModelForCausalLM[QwenMoeModel,

     def load_weights(self, weights: Dict):
         tp_size = self.model_config.mapping.tp_size
-        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        num_kv_heads = self.config.num_key_value_heads

         params_map = {
             'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -281,7 +281,7 @@
                         k:
                         duplicate_kv_weight(
                             weight=v[:],
-                            head_dim=head_dim,
+                            num_kv_heads=num_kv_heads,
                             tensor_parallel_size=tp_size)
                         if k in ["weight", "bias"] else v
                         for k, v in fw.items()
diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index 1782a374cc..b2dc71a541 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -93,11 +93,9 @@ class MetaInitMode(TorchDispatchMode):
             return func(*args, **kwargs)


-def duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
+def duplicate_kv_weight(weight: torch.Tensor, num_kv_heads: int,
                         tensor_parallel_size: int):

-    num_kv_heads = weight.shape[0] // head_dim
-
     if num_kv_heads >= tensor_parallel_size:
         assert num_kv_heads % tensor_parallel_size == 0
         return weight
@@ -109,11 +107,15 @@ def duplicate_kv_weight(weight: torch.Tensor, head_dim: int,
     if weight.ndim == 1:
         return weight.repeat_interleave(reps)

-    # weight
-    weight = weight.reshape(num_kv_heads, head_dim,
+    # weight and scale
+    assert weight.shape[0] % num_kv_heads == 0
+    size_per_kv_head = weight.shape[0] // num_kv_heads
+    weight = weight.reshape(num_kv_heads, size_per_kv_head,
                             -1)[:, None, :, :].expand(num_kv_heads, reps,
-                                                      head_dim, weight.shape[1])
-    return weight.reshape(num_kv_heads * reps * head_dim, -1).clone().detach()
+                                                      size_per_kv_head,
+                                                      weight.shape[1])
+    return weight.reshape(num_kv_heads * reps * size_per_kv_head,
+                          -1).clone().detach()


 def iter_modules(
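A toy check of the reshape/expand duplication path introduced above (shapes and values invented for illustration; this is a standalone sketch, not library code):

```python
import torch

# Two KV heads, three rows per head, hidden size 2; TP=4 means reps=2.
num_kv_heads, size_per_kv_head, reps = 2, 3, 2
w = torch.arange(12.).reshape(6, 2)  # rows 0-2 = head 0, rows 3-5 = head 1

dup = w.reshape(num_kv_heads, size_per_kv_head,
                -1)[:, None, :, :].expand(num_kv_heads, reps,
                                          size_per_kv_head, w.shape[1])
dup = dup.reshape(num_kv_heads * reps * size_per_kv_head, -1)

# Layout is [head0, head0, head1, head1] in row blocks, so each TP rank
# slicing a contiguous 3-row block receives a full copy of one head.
assert torch.equal(dup[0:3], w[0:3]) and torch.equal(dup[3:6], w[0:3])
assert torch.equal(dup[6:9], w[3:6]) and torch.equal(dup[9:12], w[3:6])
```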
@@ -648,9 +650,9 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
         logger.info(f"Renamed weights with params_map: {params_map}")

     tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
-    head_dim = getattr(
-        model.config, "head_dim",
-        model.config.hidden_size // model.config.num_attention_heads)
+    num_kv_heads = model.config.num_key_value_heads if hasattr(
+        model.config, 'num_key_value_heads'
+    ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads

     params_map = {
         'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
@@ -687,13 +689,18 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
                 fw = filter_weights('.'.join(names[:-1] + [new_name]),
                                     weights)
                 if new_name in ['k_proj', 'v_proj']:
+                    num_kv_heads_list = [num_kv_heads
+                                         ] * len(fw) if isinstance(
+                                             num_kv_heads,
+                                             int) else num_kv_heads
                     fw = {
                         k:
-                        duplicate_kv_weight(weight=v[:],
-                                            head_dim=head_dim,
-                                            tensor_parallel_size=tp_size)
+                        duplicate_kv_weight(
+                            weight=v[:],
+                            num_kv_heads=num_kv_heads_list[i],
+                            tensor_parallel_size=tp_size)
                         if k in ["weight", "bias"] else v
-                        for k, v in fw.items()
+                        for i, (k, v) in enumerate(fw.items())
                     }
                     module_weights.append(fw)
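The `hasattr`/`is not None` guard in the hunk above falls back to `num_attention_heads` for plain multi-head-attention configs that do not define `num_key_value_heads`. A minimal standalone sketch of that resolution logic, using `SimpleNamespace` as a stand-in for a real pretrained config:

```python
from types import SimpleNamespace

def resolve_num_kv_heads(config) -> int:
    # Configs without num_key_value_heads (or with it set to None) are
    # plain MHA: the KV head count equals the attention head count.
    if getattr(config, "num_key_value_heads", None) is not None:
        return config.num_key_value_heads
    return config.num_attention_heads

gqa = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)
mha = SimpleNamespace(num_attention_heads=32)

assert resolve_num_kv_heads(gqa) == 8   # grouped-query attention
assert resolve_num_kv_heads(mha) == 32  # MHA fallback
```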