From 4e35fb9da8aa9de05704ef146656584f9e916a69 Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Fri, 17 Oct 2025 00:09:32 +0800 Subject: [PATCH] [fix] update model --- eval_llm.py => eval_model.py | 2 +- model/model_minimind.py | 54 ++++++++++++++++++++++++++---------- model/tokenizer_config.json | 2 +- 3 files changed, 41 insertions(+), 17 deletions(-) rename eval_llm.py => eval_model.py (99%) mode change 100644 => 100755 model/model_minimind.py diff --git a/eval_llm.py b/eval_model.py similarity index 99% rename from eval_llm.py rename to eval_model.py index a41b331..079bbb0 100755 --- a/eval_llm.py +++ b/eval_model.py @@ -29,7 +29,7 @@ def init_model(args): apply_lora(model) load_lora(model, f'./{args.out_dir}/lora/{args.lora_name}_{args.hidden_size}.pth') else: - transformers_model_path = './MiniMind2-MoE' + transformers_model_path = './MiniMind/MiniMind2' tokenizer = AutoTokenizer.from_pretrained(transformers_model_path) model = AutoModelForCausalLM.from_pretrained(transformers_model_path, trust_remote_code=True) print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)') diff --git a/model/model_minimind.py b/model/model_minimind.py old mode 100644 new mode 100755 index 4b43610..ecd99b6 --- a/model/model_minimind.py +++ b/model/model_minimind.py @@ -23,6 +23,7 @@ class MiniMindConfig(PretrainedConfig): vocab_size: int = 6400, rms_norm_eps: float = 1e-05, rope_theta: int = 1000000.0, + inference_rope_scaling: bool = False, flash_attn: bool = True, #################################################### # Here are the specific configurations of MOE @@ -52,6 +53,15 @@ class MiniMindConfig(PretrainedConfig): self.vocab_size = vocab_size self.rms_norm_eps = rms_norm_eps self.rope_theta = rope_theta + self.inference_rope_scaling = inference_rope_scaling + # 外推长度 = factor * original_max_position_embeddings + self.rope_scaling = { + "beta_fast": 4, + "beta_slow": 1, + "factor": 4, + "original_max_position_embeddings": 2048, + "type": "yarn" + } if self.inference_rope_scaling else None self.flash_attn = flash_attn #################################################### # Here are the specific configurations of MOE @@ -73,10 +83,11 @@ class MiniMindConfig(PretrainedConfig): import math import torch +import torch.nn.init as init +import torch.nn.functional as F from torch import nn from transformers.activations import ACT2FN from typing import Optional, Tuple, List, Union -import torch.nn.functional as F from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig from transformers.modeling_outputs import CausalLMOutputWithPast @@ -94,8 +105,22 @@ class RMSNorm(torch.nn.Module): return self.weight * self._norm(x.float()).type_as(x) -def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) +def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, + rope_scaling: Optional[dict] = None): + freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + if rope_scaling is not None: + orig_max, factor, beta_fast, beta_slow = ( + rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 4), + rope_scaling.get("beta_fast", 4.0), rope_scaling.get("beta_slow", 1.0) + ) + if end / orig_max > 1.0: + corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2) + power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1) + beta = beta_slow + (beta_fast - beta_slow) * power + # λ = (β·α - β + 1)/(β·α) YaRN标准公式 + scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim, (beta * factor - beta + 1) / (beta * factor), 1.0 / factor) + freqs = freqs * scale + t = torch.arange(end, device=freqs.device) freqs = torch.outer(t, freqs).float() freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) @@ -118,9 +143,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: if n_rep == 1: return x return ( - x[:, :, :, None, :] - .expand(bs, slen, num_key_value_heads, n_rep, head_dim) - .reshape(bs, slen, num_key_value_heads * n_rep, head_dim) + x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim) ) @@ -170,14 +193,14 @@ class Attention(nn.Module): repeat_kv(xv, self.n_rep).transpose(1, 2) ) - if self.flash and seq_len != 1: - dropout_p = self.dropout if self.training else 0.0 - attn_mask = None - if attention_mask is not None: - attn_mask = attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1) - attn_mask = attn_mask.bool() if attention_mask is not None else None + if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)): + attn_mask = ( + None + if attention_mask is None + else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool() + ) - output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True) + output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True) else: scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) scores = scores + torch.triu( @@ -232,7 +255,6 @@ class MoEGate(nn.Module): self.reset_parameters() def reset_parameters(self) -> None: - import torch.nn.init as init init.kaiming_uniform_(self.weight, a=math.sqrt(5)) def forward(self, hidden_states): @@ -369,7 +391,8 @@ class MiniMindModel(nn.Module): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads, - end=config.max_position_embeddings, theta=config.rope_theta) + end=config.max_position_embeddings, rope_base=config.rope_theta, + rope_scaling=config.rope_scaling) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) @@ -380,6 +403,7 @@ class MiniMindModel(nn.Module): use_cache: bool = False, **kwargs): batch_size, seq_length = input_ids.shape + if hasattr(past_key_values, 'layers'): past_key_values = None past_key_values = past_key_values or [None] * len(self.layers) start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0 diff --git a/model/tokenizer_config.json b/model/tokenizer_config.json index 1509bad..fc4e726 100644 --- a/model/tokenizer_config.json +++ b/model/tokenizer_config.json @@ -39,5 +39,5 @@ "spaces_between_special_tokens": false, "tokenizer_class": "PreTrainedTokenizerFast", "unk_token": "<|endoftext|>", - "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}" + "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}" } \ No newline at end of file