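Update inference: honor the --stream flag, add logits_to_keep to forward(), and add num_return_sequences to generate()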

This commit is contained in:
jingyaogong
2025-04-05 12:03:04 +08:00
parent 7fcc46b39a
commit ed01c5d84a
2 changed files with 20 additions and 11 deletions
+2 -2
View File
@@ -120,7 +120,7 @@ def main():
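# history_cnt must be an even number: each [user question, model answer] pair counts as one group; setting it to 0 means the current query carries no conversation history
# Without extrapolation fine-tuning, the model's quality inevitably degrades noticeably on longer chat_template contexts, so choose this setting with care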
parser.add_argument('--history_cnt', default=0, type=int)
parser.add_argument('--stream', default=True, type=bool)
parser.add_argument('--stream', default=False, type=bool)
parser.add_argument('--load', default=0, type=int, help="0: 原生torch权重,1: transformers加载")
parser.add_argument('--model_mode', default=1, type=int,
help="0: 预训练模型,1: SFT-Chat模型,2: RLHF-Chat模型,3: Reason模型")
@@ -154,7 +154,7 @@ def main():
max_new_tokens=args.max_seq_len,
temperature=args.temperature,
top_p=args.top_p,
stream=True,
stream=args.stream,
pad_token_id=tokenizer.pad_token_id
)
+18 -9
View File
@@ -4,7 +4,7 @@ import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
@@ -307,6 +307,7 @@ class MiniMindLM(PreTrainedModel):
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
past_key_values = past_key_values or [None] * len(self.layers)
start_pos = args.get('start_pos', 0)
@@ -320,7 +321,9 @@ class MiniMindLM(PreTrainedModel):
use_cache=use_cache
)
past_kvs.append(past_kv)
logits = self.output(self.norm(h))
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('aux_loss', aux_loss)
@@ -329,7 +332,7 @@ class MiniMindLM(PreTrainedModel):
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
# Streaming generation
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
@@ -338,11 +341,13 @@ class MiniMindLM(PreTrainedModel):
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
for _ in range(num_return_sequences):
out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1)
generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
@@ -350,7 +355,11 @@ class MiniMindLM(PreTrainedModel):
dim=-1)
for seq in generated
]
return torch.cat(generated, dim=0)
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0), num_return_sequences, -1)
res = res.squeeze(0) if input_ids.size(0) == 1 else res
res = res.squeeze(1) if num_return_sequences == 1 else res
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None