mirror of https://github.com/jingyaogong/minimind.git (synced 2026-01-13 19:57:20 +08:00)
[fix] update model
This commit is contained in:
parent 36159fb2ab
commit 4e35fb9da8
@@ -29,7 +29,7 @@ def init_model(args):
             apply_lora(model)
             load_lora(model, f'./{args.out_dir}/lora/{args.lora_name}_{args.hidden_size}.pth')
     else:
-        transformers_model_path = './MiniMind2-MoE'
+        transformers_model_path = './MiniMind/MiniMind2'
         tokenizer = AutoTokenizer.from_pretrained(transformers_model_path)
         model = AutoModelForCausalLM.from_pretrained(transformers_model_path, trust_remote_code=True)
     print(f'MiniMind model parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
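
Note: this first hunk only swaps the default Hugging Face checkpoint directory. A minimal standalone sketch of what the updated else-branch amounts to (illustrative, not part of the commit; the path is the one added above):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = './MiniMind/MiniMind2'  # new default introduced by this commit
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
    # same count the script prints, in millions of trainable parameters
    print(f'{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M')
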
54 model/model_minimind.py Normal file → Executable file
@@ -23,6 +23,7 @@ class MiniMindConfig(PretrainedConfig):
             vocab_size: int = 6400,
             rms_norm_eps: float = 1e-05,
             rope_theta: int = 1000000.0,
+            inference_rope_scaling: bool = False,
             flash_attn: bool = True,
             ####################################################
             # Here are the specific configurations of MOE
@@ -52,6 +53,15 @@ class MiniMindConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.rms_norm_eps = rms_norm_eps
         self.rope_theta = rope_theta
+        self.inference_rope_scaling = inference_rope_scaling
+        # extrapolation length = factor * original_max_position_embeddings
+        self.rope_scaling = {
+            "beta_fast": 4,
+            "beta_slow": 1,
+            "factor": 4,
+            "original_max_position_embeddings": 2048,
+            "type": "yarn"
+        } if self.inference_rope_scaling else None
         self.flash_attn = flash_attn
         ####################################################
         # Here are the specific configurations of MOE
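
Note: with inference_rope_scaling=True the config carries the YaRN dict above, giving an extrapolated window of factor * original_max_position_embeddings = 4 * 2048 = 8192 positions. A hedged usage sketch (class and field names are from this diff; the import path is an assumption about the repo layout):

    from model.model_minimind import MiniMindConfig  # assumed import path

    cfg = MiniMindConfig(inference_rope_scaling=True)
    assert cfg.rope_scaling is not None and cfg.rope_scaling["type"] == "yarn"
    # 4 * 2048 = 8192 usable positions at inference time
    print(cfg.rope_scaling["factor"] * cfg.rope_scaling["original_max_position_embeddings"])
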
@@ -73,10 +83,11 @@ class MiniMindConfig(PretrainedConfig):

 import math
 import torch
+import torch.nn.init as init
+import torch.nn.functional as F
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, Tuple, List, Union
-import torch.nn.functional as F
 from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast

@@ -94,8 +105,22 @@ class RMSNorm(torch.nn.Module):
         return self.weight * self._norm(x.float()).type_as(x)


-def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
+                         rope_scaling: Optional[dict] = None):
+    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    if rope_scaling is not None:
+        orig_max, factor, beta_fast, beta_slow = (
+            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 4),
+            rope_scaling.get("beta_fast", 4.0), rope_scaling.get("beta_slow", 1.0)
+        )
+        if end / orig_max > 1.0:
+            corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
+            power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
+            beta = beta_slow + (beta_fast - beta_slow) * power
+            # λ = (β·α - β + 1)/(β·α), the standard YaRN formula
+            scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim, (beta * factor - beta + 1) / (beta * factor), 1.0 / factor)
+            freqs = freqs * scale
+
     t = torch.arange(end, device=freqs.device)
     freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
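
Note: the rescaling divides long-wavelength frequency dimensions by the full factor while barely touching dimensions whose wavelength already fits inside the 2048-token training window; the beta_fast/beta_slow ramp controls the blend in between. A quick sketch of driving the new signature (return shapes follow from the code above; the import path is assumed):

    from model.model_minimind import precompute_freqs_cis  # assumed import path

    yarn = {"type": "yarn", "factor": 4, "original_max_position_embeddings": 2048,
            "beta_fast": 4, "beta_slow": 1}
    cos, sin = precompute_freqs_cis(dim=64, end=8192, rope_base=1e6, rope_scaling=yarn)
    print(cos.shape, sin.shape)  # torch.Size([8192, 64]) for each table
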
@@ -118,9 +143,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     if n_rep == 1:
         return x
     return (
-        x[:, :, :, None, :]
-        .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
-        .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
+        x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
     )

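
Note: this is a formatting-only change; expand plus reshape still tiles every KV head n_rep times for grouped-query attention. A standalone check with toy shapes (numbers are assumptions, not from the repo):

    import torch

    bs, slen, kv_heads, n_rep, head_dim = 2, 5, 2, 4, 8
    x = torch.randn(bs, slen, kv_heads, head_dim)
    y = x[:, :, :, None, :].expand(bs, slen, kv_heads, n_rep, head_dim) \
         .reshape(bs, slen, kv_heads * n_rep, head_dim)
    assert torch.equal(y[:, :, 0], y[:, :, 1])  # query heads 0..3 all mirror KV head 0
    print(y.shape)  # torch.Size([2, 5, 8, 8])
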
@@ -170,14 +193,14 @@ class Attention(nn.Module):
             repeat_kv(xv, self.n_rep).transpose(1, 2)
         )

-        if self.flash and seq_len != 1:
-            dropout_p = self.dropout if self.training else 0.0
-            attn_mask = None
-            if attention_mask is not None:
-                attn_mask = attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1)
-            attn_mask = attn_mask.bool() if attention_mask is not None else None
+        if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
+            attn_mask = (
+                None
+                if attention_mask is None
+                else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
+            )

-            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True)
+            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
             scores = scores + torch.triu(
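
Note: the tightened guard restricts the flash (SDPA) path to unpadded batches, where the is_causal mask alone is correct; any batch with real padding now falls through to the explicit-scores branch. A small sketch of the new gating condition (the expression is copied from the + line above; the tensors are toy assumptions):

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0],   # padded sequence
                                   [1, 1, 1, 1]])
    flash, seq_len = True, 4
    use_flash = flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1))
    print(use_flash)  # tensor(False): padding forces the manual attention path
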
@@ -232,7 +255,6 @@ class MoEGate(nn.Module):
         self.reset_parameters()

     def reset_parameters(self) -> None:
-        import torch.nn.init as init
         init.kaiming_uniform_(self.weight, a=math.sqrt(5))

     def forward(self, hidden_states):
@@ -369,7 +391,8 @@ class MiniMindModel(nn.Module):
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads,
-                                                    end=config.max_position_embeddings, theta=config.rope_theta)
+                                                    end=config.max_position_embeddings, rope_base=config.rope_theta,
+                                                    rope_scaling=config.rope_scaling)
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)

@@ -380,6 +403,7 @@ class MiniMindModel(nn.Module):
                 use_cache: bool = False,
                 **kwargs):
         batch_size, seq_length = input_ids.shape
+        if hasattr(past_key_values, 'layers'): past_key_values = None
         past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0

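
Note: the added guard appears to target newer transformers releases, where generate() can hand in a Cache object instead of the legacy list of per-layer (key, value) tuples this forward loop expects; such objects are discarded so the legacy path is rebuilt from None. A sketch of the check (assuming a transformers version whose DynamicCache exposes a .layers attribute):

    from transformers import DynamicCache

    past_key_values = DynamicCache()          # what generate() may pass in
    if hasattr(past_key_values, 'layers'):    # Cache object, not a list of tuples
        past_key_values = None                # fall back to the legacy cache format
    print(past_key_values)                    # None
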
@@ -39,5 +39,5 @@
     "spaces_between_special_tokens": false,
     "tokenizer_class": "PreTrainedTokenizerFast",
     "unk_token": "<|endoftext|>",
-    "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"
+    "chat_template": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0]['role'] == 'system' -%}\n        {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n    {%- else -%}\n        {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if message.content is string %}\n        {%- set content = message.content %}\n    {%- else %}\n        {%- set content = '' %}\n    {%- endif %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}"
 }
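
Note: the rewritten chat_template adds Qwen-style tool calling (function signatures advertised inside <tools>...</tools> in the system turn, <tool_call> JSON blocks on assistant turns, <tool_response> grouping for tool turns) plus an enable_thinking switch that emits an empty <think></think> block when thinking is explicitly disabled. A hedged sketch of exercising it (the tool definition and checkpoint path are made-up examples):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('./MiniMind/MiniMind2')  # path assumed
    tools = [{"type": "function", "function": {
        "name": "get_weather", "description": "Look up current weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}]
    messages = [{"role": "user", "content": "Weather in Beijing?"}]
    prompt = tokenizer.apply_chat_template(messages, tools=tools, tokenize=False,
                                           add_generation_prompt=True)
    print(prompt)  # system turn now embeds the function signature inside <tools>...</tools>
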