mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[update] minimind-3
This commit is contained in:
+14
-2
@@ -18,7 +18,7 @@ class LoRA(nn.Module):
|
||||
return self.B(self.A(x))
|
||||
|
||||
|
||||
def apply_lora(model, rank=8):
|
||||
def apply_lora(model, rank=16):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear) and module.weight.shape[0] == module.weight.shape[1]:
|
||||
lora = LoRA(module.weight.shape[0], module.weight.shape[1], rank=rank).to(model.device)
|
||||
@@ -48,6 +48,18 @@ def save_lora(model, path):
|
||||
for name, module in raw_model.named_modules():
|
||||
if hasattr(module, 'lora'):
|
||||
clean_name = name[7:] if name.startswith("module.") else name
|
||||
lora_state = {f'{clean_name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
|
||||
lora_state = {f'{clean_name}.lora.{k}': v.cpu().half() for k, v in module.lora.state_dict().items()}
|
||||
state_dict.update(lora_state)
|
||||
torch.save(state_dict, path)
|
||||
|
||||
|
||||
def merge_lora(model, lora_path, save_path):
|
||||
load_lora(model, lora_path)
|
||||
raw_model = getattr(model, '_orig_mod', model)
|
||||
state_dict = {k: v.cpu().half() for k, v in raw_model.state_dict().items() if '.lora.' not in k}
|
||||
for name, module in raw_model.named_modules():
|
||||
if isinstance(module, nn.Linear) and '.lora.' not in name:
|
||||
state_dict[f'{name}.weight'] = module.weight.data.clone().cpu().half()
|
||||
if hasattr(module, 'lora'):
|
||||
state_dict[f'{name}.weight'] += (module.lora.B.weight.data @ module.lora.A.weight.data).cpu().half()
|
||||
torch.save(state_dict, save_path)
|
||||
|
||||
+135
-319
@@ -1,60 +1,33 @@
|
||||
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
|
||||
# MiniMind Config
|
||||
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import math, torch, torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
|
||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||
|
||||
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||
# MiniMind Config
|
||||
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||
class MiniMindConfig(PretrainedConfig):
|
||||
model_type = "minimind"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dropout: float = 0.0,
|
||||
bos_token_id: int = 1,
|
||||
eos_token_id: int = 2,
|
||||
hidden_act: str = 'silu',
|
||||
hidden_size: int = 512,
|
||||
intermediate_size: int = None,
|
||||
max_position_embeddings: int = 32768,
|
||||
num_attention_heads: int = 8,
|
||||
num_hidden_layers: int = 8,
|
||||
num_key_value_heads: int = 2,
|
||||
vocab_size: int = 6400,
|
||||
rms_norm_eps: float = 1e-05,
|
||||
rope_theta: int = 1000000.0,
|
||||
inference_rope_scaling: bool = False,
|
||||
flash_attn: bool = True,
|
||||
####################################################
|
||||
# Here are the specific configurations of MOE
|
||||
# When use_moe is false, the following is invalid
|
||||
####################################################
|
||||
use_moe: bool = False,
|
||||
num_experts_per_tok: int = 2,
|
||||
n_routed_experts: int = 4,
|
||||
n_shared_experts: int = 1,
|
||||
scoring_func: str = 'softmax',
|
||||
aux_loss_alpha: float = 0.01,
|
||||
seq_aux: bool = True,
|
||||
norm_topk_prob: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
def __init__(self, hidden_size=768, num_hidden_layers=8, use_moe=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.dropout = dropout
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.vocab_size = vocab_size
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.rope_theta = rope_theta
|
||||
self.inference_rope_scaling = inference_rope_scaling
|
||||
# 外推长度 = factor * original_max_position_embeddings = 32768
|
||||
self.use_moe = use_moe
|
||||
self.dropout = kwargs.get("dropout", 0.0)
|
||||
self.vocab_size = kwargs.get("vocab_size", 6400)
|
||||
self.bos_token_id = kwargs.get("bos_token_id", 1)
|
||||
self.eos_token_id = kwargs.get("eos_token_id", 2)
|
||||
self.flash_attn = kwargs.get("flash_attn", True)
|
||||
self.num_attention_heads = kwargs.get("num_attention_heads", 8)
|
||||
self.num_key_value_heads = kwargs.get("num_key_value_heads", 4)
|
||||
self.head_dim = kwargs.get("head_dim", self.hidden_size // self.num_attention_heads)
|
||||
self.hidden_act = kwargs.get("hidden_act", 'silu')
|
||||
self.intermediate_size = kwargs.get("intermediate_size", math.ceil(hidden_size * math.pi / 64) * 64)
|
||||
self.max_position_embeddings = kwargs.get("max_position_embeddings", 32768)
|
||||
self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
|
||||
self.rope_theta = kwargs.get("rope_theta", 1e6)
|
||||
self.inference_rope_scaling = kwargs.get("inference_rope_scaling", False)
|
||||
self.rope_scaling = {
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
@@ -63,301 +36,147 @@ class MiniMindConfig(PretrainedConfig):
|
||||
"attention_factor": 1.0,
|
||||
"type": "yarn"
|
||||
} if self.inference_rope_scaling else None
|
||||
self.flash_attn = flash_attn
|
||||
####################################################
|
||||
# Here are the specific configurations of MOE
|
||||
# When use_moe is false, the following is invalid
|
||||
####################################################
|
||||
self.use_moe = use_moe
|
||||
self.num_experts_per_tok = num_experts_per_tok # 每个token选择的专家数量
|
||||
self.n_routed_experts = n_routed_experts # 总的专家数量
|
||||
self.n_shared_experts = n_shared_experts # 共享专家
|
||||
self.scoring_func = scoring_func # 评分函数,默认为'softmax'
|
||||
self.aux_loss_alpha = aux_loss_alpha # 辅助损失的alpha参数
|
||||
self.seq_aux = seq_aux # 是否在序列级别上计算辅助损失
|
||||
self.norm_topk_prob = norm_topk_prob # 是否标准化top-k概率
|
||||
|
||||
|
||||
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
|
||||
# MiniMind Model
|
||||
# 📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘📘
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.init as init
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, Tuple, List, Union
|
||||
from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
### MoE specific configs (ignored if use_moe = False)
|
||||
self.num_experts = kwargs.get("num_experts", 4)
|
||||
self.num_experts_per_tok = kwargs.get("num_experts_per_tok", 1)
|
||||
self.moe_intermediate_size = kwargs.get("moe_intermediate_size", self.intermediate_size)
|
||||
self.norm_topk_prob = kwargs.get("norm_topk_prob", True)
|
||||
self.router_aux_loss_coef = kwargs.get("router_aux_loss_coef", 5e-4)
|
||||
|
||||
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||
# MiniMind Model
|
||||
# 🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏🌎🌍🌏
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
def norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self.weight * self._norm(x.float()).type_as(x)
|
||||
return (self.weight * self.norm(x.float())).type_as(x)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
|
||||
rope_scaling: Optional[dict] = None):
|
||||
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
|
||||
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
|
||||
if rope_scaling is not None:
|
||||
if rope_scaling is not None: # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
|
||||
orig_max, factor, beta_fast, beta_slow, attn_factor = (
|
||||
rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
|
||||
rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
|
||||
)
|
||||
if end / orig_max > 1.0:
|
||||
# YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
|
||||
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
|
||||
low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
|
||||
ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
|
||||
freqs = freqs * (1 - ramp + ramp / factor)
|
||||
|
||||
t = torch.arange(end, device=freqs.device)
|
||||
freqs = torch.outer(t, freqs).float()
|
||||
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
|
||||
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
def rotate_half(x):
|
||||
return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
|
||||
|
||||
q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
|
||||
k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||
def rotate_half(x): return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
|
||||
q_embed = ((q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))).to(q.dtype)
|
||||
k_embed = ((k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))).to(k.dtype)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||
bs, slen, num_key_value_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
if n_rep == 1: return x
|
||||
return (x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim))
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: MiniMindConfig):
|
||||
def __init__(self, config: MiniMindConfig):
|
||||
super().__init__()
|
||||
self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
|
||||
assert args.num_attention_heads % self.num_key_value_heads == 0
|
||||
self.n_local_heads = args.num_attention_heads
|
||||
self.num_key_value_heads = config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||
self.n_local_heads = config.num_attention_heads
|
||||
self.n_local_kv_heads = self.num_key_value_heads
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.hidden_size // args.num_attention_heads
|
||||
self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
|
||||
self.attn_dropout = nn.Dropout(args.dropout)
|
||||
self.resid_dropout = nn.Dropout(args.dropout)
|
||||
self.dropout = args.dropout
|
||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
|
||||
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
||||
self.head_dim = config.head_dim
|
||||
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.attn_dropout = nn.Dropout(config.dropout)
|
||||
self.resid_dropout = nn.Dropout(config.dropout)
|
||||
self.dropout = config.dropout
|
||||
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and config.flash_attn
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor], # 修改为接收cos和sin
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache=False,
|
||||
attention_mask: Optional[torch.Tensor] = None):
|
||||
def forward(self, x, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
|
||||
bsz, seq_len, _ = x.shape
|
||||
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = self.q_norm(xq), self.k_norm(xk)
|
||||
cos, sin = position_embeddings
|
||||
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
|
||||
|
||||
# kv_cache实现
|
||||
if past_key_value is not None:
|
||||
xk = torch.cat([past_key_value[0], xk], dim=1)
|
||||
xv = torch.cat([past_key_value[1], xv], dim=1)
|
||||
past_kv = (xk, xv) if use_cache else None
|
||||
|
||||
xq, xk, xv = (
|
||||
xq.transpose(1, 2),
|
||||
repeat_kv(xk, self.n_rep).transpose(1, 2),
|
||||
repeat_kv(xv, self.n_rep).transpose(1, 2)
|
||||
)
|
||||
|
||||
xq, xk, xv = (xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2))
|
||||
if self.flash and (seq_len > 1) and (past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
|
||||
output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
|
||||
else:
|
||||
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||
scores[:, :, :, -seq_len:] += torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1)
|
||||
|
||||
if attention_mask is not None:
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
|
||||
scores = scores + extended_attention_mask
|
||||
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
scores = self.attn_dropout(scores)
|
||||
output = scores @ xv
|
||||
|
||||
scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1)
|
||||
if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9
|
||||
output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv
|
||||
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
|
||||
output = self.resid_dropout(self.o_proj(output))
|
||||
return output, past_kv
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, config: MiniMindConfig):
|
||||
def __init__(self, config: MiniMindConfig, intermediate_size: int = None):
|
||||
super().__init__()
|
||||
if config.intermediate_size is None:
|
||||
intermediate_size = int(config.hidden_size * 8 / 3)
|
||||
config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)
|
||||
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
||||
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
intermediate_size = intermediate_size or config.intermediate_size
|
||||
self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
|
||||
self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
return self.dropout(self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config: MiniMindConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
|
||||
self.scoring_func = config.scoring_func
|
||||
self.alpha = config.aux_loss_alpha
|
||||
self.seq_aux = config.seq_aux
|
||||
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.gating_dim = config.hidden_size
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states, self.weight, None)
|
||||
if self.scoring_func == 'softmax':
|
||||
scores = logits.softmax(dim=-1)
|
||||
else:
|
||||
raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
|
||||
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
|
||||
if self.training and self.alpha > 0.0:
|
||||
scores_for_aux = scores
|
||||
aux_topk = self.top_k
|
||||
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
|
||||
if self.seq_aux:
|
||||
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
|
||||
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
|
||||
ce.scatter_add_(1, topk_idx_for_aux_loss,
|
||||
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
|
||||
seq_len * aux_topk / self.n_routed_experts)
|
||||
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
|
||||
else:
|
||||
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
|
||||
ce = mask_ce.float().mean(0)
|
||||
Pi = scores_for_aux.mean(0)
|
||||
fi = ce * self.n_routed_experts
|
||||
aux_loss = (Pi * fi).sum() * self.alpha
|
||||
else:
|
||||
aux_loss = scores.new_zeros(1).squeeze()
|
||||
return topk_idx, topk_weight, aux_loss
|
||||
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
class MOEFeedForward(nn.Module):
|
||||
def __init__(self, config: MiniMindConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.experts = nn.ModuleList([
|
||||
FeedForward(config)
|
||||
for _ in range(config.n_routed_experts)
|
||||
])
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts > 0:
|
||||
self.shared_experts = nn.ModuleList([
|
||||
FeedForward(config)
|
||||
for _ in range(config.n_shared_experts)
|
||||
])
|
||||
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
|
||||
self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
orig_shape = x.shape
|
||||
bsz, seq_len, _ = x.shape
|
||||
# 使用门控机制选择专家
|
||||
topk_idx, topk_weight, aux_loss = self.gate(x)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
flat_topk_idx = topk_idx.view(-1)
|
||||
if self.training:
|
||||
x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
|
||||
y = torch.empty_like(x, dtype=x.dtype)
|
||||
for i, expert in enumerate(self.experts):
|
||||
expert_out = expert(x[flat_topk_idx == i])
|
||||
if expert_out.shape[0] > 0: y[flat_topk_idx == i] = expert_out.to(y.dtype)
|
||||
else: y[flat_topk_idx == i] = expert_out.to(y.dtype) + 0 * sum(p.sum() for p in expert.parameters())
|
||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
y = y.view(*orig_shape)
|
||||
batch_size, seq_len, hidden_dim = x.shape
|
||||
x_flat = x.view(-1, hidden_dim)
|
||||
scores = F.softmax(self.gate(x_flat), dim=-1)
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False)
|
||||
if self.config.norm_topk_prob: topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
|
||||
y = torch.zeros_like(x_flat)
|
||||
for i, expert in enumerate(self.experts):
|
||||
mask = (topk_idx == i)
|
||||
if mask.any():
|
||||
token_idx = mask.any(dim=-1).nonzero().flatten()
|
||||
weight = topk_weight[mask].view(-1, 1)
|
||||
y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
|
||||
elif self.training:
|
||||
y[0, 0] += 0 * sum(p.sum() for p in expert.parameters())
|
||||
if self.training and self.config.router_aux_loss_coef > 0:
|
||||
load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
|
||||
self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
|
||||
else:
|
||||
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||
if self.config.n_shared_experts > 0:
|
||||
for expert in self.shared_experts:
|
||||
y = y + expert(identity)
|
||||
self.aux_loss = aux_loss
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||
expert_cache = torch.zeros_like(x)
|
||||
idxs = flat_expert_indices.argsort()
|
||||
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
|
||||
token_idxs = idxs // self.config.num_experts_per_tok
|
||||
# 当tokens_per_expert = [6, 15, 20, 26],tokens_per_expert.shape[0]即为专家数量(此时为4)
|
||||
# 且token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12...] 时
|
||||
# 意味token_idxs[:6] -> [3, 7, 19, 21, 24, 25]这6个位置属于专家0处理的token(每个token有可能被多个专家处理,这取决于num_experts_per_tok)
|
||||
# 接下来9个位置token_idxs[6:15] -> [4, 5, 6, 10, 11, 12...]属于专家1处理的token...依此类推
|
||||
for i, end_idx in enumerate(tokens_per_expert):
|
||||
start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
|
||||
if start_idx == end_idx:
|
||||
continue
|
||||
expert = self.experts[i]
|
||||
exp_token_idx = token_idxs[start_idx:end_idx]
|
||||
expert_tokens = x[exp_token_idx]
|
||||
expert_out = expert(expert_tokens).to(expert_cache.dtype)
|
||||
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
|
||||
expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
|
||||
|
||||
return expert_cache
|
||||
|
||||
self.aux_loss = scores.new_zeros(1).squeeze()
|
||||
return y.view(batch_size, seq_len, hidden_dim)
|
||||
|
||||
class MiniMindBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, config: MiniMindConfig):
|
||||
super().__init__()
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
self.self_attn = Attention(config)
|
||||
|
||||
self.layer_id = layer_id
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
|
||||
@@ -372,7 +191,6 @@ class MiniMindBlock(nn.Module):
|
||||
hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
|
||||
return hidden_states, present_key_value
|
||||
|
||||
|
||||
class MiniMindModel(nn.Module):
|
||||
def __init__(self, config: MiniMindConfig):
|
||||
super().__init__()
|
||||
@@ -382,33 +200,19 @@ class MiniMindModel(nn.Module):
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
self.layers = nn.ModuleList([MiniMindBlock(l, config) for l in range(self.num_hidden_layers)])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads,
|
||||
end=config.max_position_embeddings, rope_base=config.rope_theta,
|
||||
rope_scaling=config.rope_scaling)
|
||||
freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.head_dim, end=config.max_position_embeddings, rope_base=config.rope_theta, rope_scaling=config.rope_scaling)
|
||||
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
||||
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
||||
|
||||
def forward(self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
use_cache: bool = False,
|
||||
**kwargs):
|
||||
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, **kwargs):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
if hasattr(past_key_values, 'layers'): past_key_values = None
|
||||
past_key_values = past_key_values or [None] * len(self.layers)
|
||||
start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
|
||||
|
||||
hidden_states = self.dropout(self.embed_tokens(input_ids))
|
||||
|
||||
position_embeddings = (
|
||||
self.freqs_cos[start_pos:start_pos + seq_length],
|
||||
self.freqs_sin[start_pos:start_pos + seq_length]
|
||||
)
|
||||
|
||||
position_embeddings = (self.freqs_cos[start_pos:start_pos + seq_length], self.freqs_sin[start_pos:start_pos + seq_length])
|
||||
presents = []
|
||||
for layer_idx, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
|
||||
for layer, past_key_value in zip(self.layers, past_key_values):
|
||||
hidden_states, present = layer(
|
||||
hidden_states,
|
||||
position_embeddings,
|
||||
@@ -417,47 +221,59 @@ class MiniMindModel(nn.Module):
|
||||
attention_mask=attention_mask
|
||||
)
|
||||
presents.append(present)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())
|
||||
return hidden_states, presents, aux_loss
|
||||
|
||||
|
||||
class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
|
||||
config_class = MiniMindConfig
|
||||
|
||||
def __init__(self, config: MiniMindConfig = None):
|
||||
self.config = config or MiniMindConfig()
|
||||
super().__init__(self.config)
|
||||
self.model = MiniMindModel(self.config)
|
||||
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
||||
self.model.embed_tokens.weight = self.lm_head.weight
|
||||
|
||||
def forward(self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
use_cache: bool = False,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**args):
|
||||
hidden_states, past_key_values, aux_loss = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
**args
|
||||
)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0, labels=None, **kwargs):
|
||||
hidden_states, past_key_values, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache, **kwargs)
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)
|
||||
|
||||
output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
|
||||
output.aux_loss = aux_loss
|
||||
return output
|
||||
x, y = logits[..., :-1, :].contiguous(), labels[..., 1:].contiguous()
|
||||
loss = F.cross_entropy(x.view(-1, x.size(-1)), y.view(-1), ignore_index=-100)
|
||||
return MoeCausalLMOutputWithPast(loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
|
||||
|
||||
# https://github.com/jingyaogong/minimind/discussions/611
|
||||
@torch.inference_mode()
|
||||
def generate(self, inputs=None, attention_mask=None, max_new_tokens=8192, temperature=0.85, top_p=0.85, top_k=50, eos_token_id=2, streamer=None, use_cache=True, num_return_sequences=1, do_sample=True, repetition_penalty=1.0, **kwargs):
|
||||
input_ids = kwargs.pop("input_ids", inputs).repeat(num_return_sequences, 1)
|
||||
attention_mask = attention_mask.repeat(num_return_sequences, 1) if attention_mask is not None else None
|
||||
past_key_values = kwargs.pop("past_key_values", None)
|
||||
finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
|
||||
if streamer: streamer.put(input_ids.cpu())
|
||||
for _ in range(max_new_tokens):
|
||||
past_len = past_key_values[0][0].shape[1] if past_key_values else 0
|
||||
outputs = self.forward(input_ids[:, past_len:], attention_mask, past_key_values, use_cache=use_cache, **kwargs)
|
||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[0], 1)], -1) if attention_mask is not None else None
|
||||
logits = outputs.logits[:, -1, :] / temperature
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(input_ids.shape[0]): logits[i, torch.unique(input_ids[i])] /= repetition_penalty
|
||||
if top_k > 0:
|
||||
logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float('inf')
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
|
||||
mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0
|
||||
logits[mask.scatter(1, sorted_indices, mask)] = -float('inf')
|
||||
next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1) if do_sample else torch.argmax(logits, dim=-1, keepdim=True)
|
||||
if eos_token_id is not None: next_token = torch.where(finished.unsqueeze(-1), next_token.new_full((next_token.shape[0], 1), eos_token_id), next_token)
|
||||
input_ids = torch.cat([input_ids, next_token], dim=-1)
|
||||
past_key_values = outputs.past_key_values if use_cache else None
|
||||
if streamer: streamer.put(next_token.cpu())
|
||||
if eos_token_id is not None:
|
||||
finished |= next_token.squeeze(-1).eq(eos_token_id)
|
||||
if finished.all(): break
|
||||
if streamer: streamer.end()
|
||||
if kwargs.get("return_kv"): return {'generated_ids': input_ids, 'past_kv': past_key_values}
|
||||
return input_ids
|
||||
+29931
-29766
File diff suppressed because it is too large
Load Diff
+332
-40
@@ -1,43 +1,335 @@
|
||||
{
|
||||
"add_bos_token": false,
|
||||
"add_eos_token": false,
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"add_bos_token": false,
|
||||
"add_eos_token": false,
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"3": {
|
||||
"content": "<|object_ref_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"4": {
|
||||
"content": "<|object_ref_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"5": {
|
||||
"content": "<|box_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"6": {
|
||||
"content": "<|box_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"7": {
|
||||
"content": "<|quad_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"8": {
|
||||
"content": "<|quad_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"9": {
|
||||
"content": "<|vision_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"10": {
|
||||
"content": "<|vision_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"11": {
|
||||
"content": "<|vision_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"12": {
|
||||
"content": "<|image_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"13": {
|
||||
"content": "<|video_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"14": {
|
||||
"content": "<|audio_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"15": {
|
||||
"content": "<|audio_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"16": {
|
||||
"content": "<|audio_pad|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"17": {
|
||||
"content": "<tts_pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"18": {
|
||||
"content": "<tts_text_bos>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"19": {
|
||||
"content": "<tts_text_eod>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"20": {
|
||||
"content": "<tts_text_bos_single>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"21": {
|
||||
"content": "<tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"22": {
|
||||
"content": "</tool_call>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"23": {
|
||||
"content": "<tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"24": {
|
||||
"content": "</tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"25": {
|
||||
"content": "<think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"26": {
|
||||
"content": "</think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"27": {
|
||||
"content": "<|buffer1|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"28": {
|
||||
"content": "<|buffer2|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"29": {
|
||||
"content": "<|buffer3|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"30": {
|
||||
"content": "<|buffer4|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"31": {
|
||||
"content": "<|buffer5|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"32": {
|
||||
"content": "<|buffer6|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"33": {
|
||||
"content": "<|buffer7|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"34": {
|
||||
"content": "<|buffer8|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
},
|
||||
"35": {
|
||||
"content": "<|buffer9|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": false
|
||||
}
|
||||
},
|
||||
"1": {
|
||||
"content": "<|im_start|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "<|im_end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [],
|
||||
"bos_token": "<|im_start|>",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|im_end|>",
|
||||
"legacy": true,
|
||||
"model_max_length": 32768,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": false,
|
||||
"tokenizer_class": "PreTrainedTokenizerFast",
|
||||
"unk_token": "<|endoftext|>",
|
||||
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
|
||||
"additional_special_tokens": [
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|object_ref_start|>",
|
||||
"<|object_ref_end|>",
|
||||
"<|box_start|>",
|
||||
"<|box_end|>",
|
||||
"<|quad_start|>",
|
||||
"<|quad_end|>",
|
||||
"<|vision_start|>",
|
||||
"<|vision_end|>",
|
||||
"<|vision_pad|>",
|
||||
"<|image_pad|>",
|
||||
"<|video_pad|>",
|
||||
"<|audio_start|>",
|
||||
"<|audio_end|>",
|
||||
"<|audio_pad|>",
|
||||
"<tts_pad>",
|
||||
"<tts_text_bos>",
|
||||
"<tts_text_eod>",
|
||||
"<tts_text_bos_single>"
|
||||
],
|
||||
"bos_token": "<|im_start|>",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "<|im_end|>",
|
||||
"legacy": true,
|
||||
"model_max_length": 131072,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": false,
|
||||
"unk_token": "<|endoftext|>",
|
||||
"image_token": "<|image_pad|>",
|
||||
"audio_token": "<|audio_pad|>",
|
||||
"video_token": "<|video_pad|>",
|
||||
"vision_bos_token": "<|vision_start|>",
|
||||
"vision_eos_token": "<|vision_end|>",
|
||||
"audio_bos_token": "<|audio_start|>",
|
||||
"audio_eos_token": "<|audio_end|>",
|
||||
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if true %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if open_thinking is defined and open_thinking is true %}\n {{- '<think>\\n' }}\n {%- else %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
||||
"tokenizer_class": "PreTrainedTokenizerFast"
|
||||
}
|
||||
Reference in New Issue
Block a user