Mirror of https://github.com/jingyaogong/minimind.git (synced 2026-01-13 19:57:20 +08:00)

Commit fe3dde125d: Merge 4fd7b5a65b into 05d0b216f6
.gitignore (vendored, 8 changed lines)
@@ -1,4 +1,10 @@
 model/__pycache__
 out
 website/
-docs-minimind/
+docs-minimind/
+cli.txt
+dataset/out/
+dataset/__pycache__
+dataset/corpus/__pycache__
+dataset/scrapers/__pycache__
+dataset/utils/__pycache__
dataset/corpus/__init__.py (new file, 0 lines)

dataset/corpus/build_corpus.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import argparse
import os
import sys
from typing import Dict, Iterable, List

from tqdm import tqdm

if __package__ is None or __package__ == "":
    current_dir = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir))
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)

from dataset.corpus.sources import multiplex_sources
from dataset.corpus.cleaning import Deduper, clean_record
from dataset.utils.util import write_jsonl


def _aggregate_text(rec: Dict) -> str:
    parts: List[str] = []
    for v in rec.values():
        if isinstance(v, str):
            if v.strip():
                parts.append(v)
        elif isinstance(v, list):
            # flatten list of strings (best-effort)
            parts.extend([str(x) for x in v if isinstance(x, (str, int, float)) and str(x).strip()])
        elif isinstance(v, (int, float)):
            parts.append(str(v))
    text = " \n".join(parts)
    # cap extremely long text to avoid excessive memory
    if len(text) > 20000:
        text = text[:20000]
    return text


def shard_records(records: Iterable[Dict], shard_size: int, out_dir: str, prefix: str, start_shard: int = 0) -> int:
    """Write records into JSONL shards of at most shard_size records each.

    Returns the next unused shard id so repeated calls continue the numbering
    instead of restarting at 0 and overwriting earlier shards.
    """
    os.makedirs(out_dir, exist_ok=True)
    buf: List[Dict] = []
    shard_id = start_shard
    for rec in records:
        if not rec:
            continue
        buf.append(rec)
        if len(buf) >= shard_size:
            out_path = os.path.join(out_dir, f"{prefix}-{shard_id:05d}.jsonl")
            write_jsonl(out_path, buf)
            buf.clear()
            shard_id += 1
    if buf:
        out_path = os.path.join(out_dir, f"{prefix}-{shard_id:05d}.jsonl")
        write_jsonl(out_path, buf)
        shard_id += 1
    return shard_id


def build_pipeline(target_lang: str = "zh", max_items: int = 5000, shard_size: int = 1000, out_dir: str = "dataset/out/corpus_out"):
    deduper = Deduper(threshold=0.88)

    def cleaned_stream():
        for idx, rec in enumerate(tqdm(multiplex_sources(max_items=max_items, language=target_lang), total=max_items, desc=f"ingest[{target_lang}]")):
            key = f"{rec.get('source', 'unknown')}:{rec.get('title', '')}-{idx}"
            agg_text = _aggregate_text(rec)
            if not agg_text:
                continue
            # Use cleaning on the aggregated text only for filtering, not for reformatting
            tmp = {"text": agg_text, "source": rec.get("source"), "title": rec.get("title")}
            cleaned = clean_record(tmp, target_lang=target_lang)
            if not cleaned:
                continue
            if deduper.is_duplicate(key, agg_text):
                continue
            # Keep original record shape as requested
            yield rec

    count_in = 0
    count_out = 0
    next_shard = 0  # carried across shard_records calls so shard files are never overwritten
    records_iter = cleaned_stream()
    # Materialize in small chunks to measure counts
    chunk: List[Dict] = []
    for rec in tqdm(records_iter, desc="clean+dedup"):
        count_in += 1
        chunk.append(rec)
        if len(chunk) >= shard_size:
            next_shard = shard_records(chunk, shard_size=shard_size, out_dir=out_dir, prefix=f"corpus-{target_lang}", start_shard=next_shard)
            count_out += len(chunk)
            chunk.clear()
    if chunk:
        next_shard = shard_records(chunk, shard_size=shard_size, out_dir=out_dir, prefix=f"corpus-{target_lang}", start_shard=next_shard)
        count_out += len(chunk)
    print(f"[build_corpus] input_records={count_in} written_records={count_out} out_dir={out_dir}")


def main():
    parser = argparse.ArgumentParser(description="Build cleaned corpus from open sources.")
    parser.add_argument("--lang", default="en", help="Target language code, e.g., zh, en")
    parser.add_argument("--max-items", type=int, default=5000, help="Max items to ingest across sources")
    parser.add_argument("--shard-size", type=int, default=1000, help="Number of records per shard")
    parser.add_argument("--out-dir", default="dataset/out/corpus_out", help="Output directory for JSONL shards")
    args = parser.parse_args()

    build_pipeline(target_lang=args.lang, max_items=args.max_items, shard_size=args.shard_size, out_dir=args.out_dir)


if __name__ == "__main__":
    main()
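As a usage sketch (not part of the commit), the pipeline above can also be driven directly from Python; the small counts below are illustrative assumptions, not recommended settings.

# Sketch: call build_pipeline directly instead of via the CLI.
# Assumes the repo root is on sys.path, as the guard at the top of build_corpus.py arranges.
from dataset.corpus.build_corpus import build_pipeline

build_pipeline(
    target_lang="en",   # same as --lang
    max_items=200,      # illustrative small run
    shard_size=100,     # records per output shard
    out_dir="dataset/out/corpus_out",
)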
dataset/corpus/cleaning.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import re
from typing import Dict, Set

import ftfy
from langdetect import detect, DetectorFactory
from datasketch import MinHash, MinHashLSH


# Make langdetect deterministic across runs
DetectorFactory.seed = 0


_CONTROL_CHARS = re.compile(r"[\u0000-\u001F\u007F]")
_MULTI_SPACE = re.compile(r"\s{2,}")
_URL = re.compile(r"https?://\S+")


def normalize_text(text: str) -> str:
    """Basic normalization: fix encoding, strip control chars, canonical whitespace."""
    text = ftfy.fix_text(text)
    text = _CONTROL_CHARS.sub(" ", text)
    text = text.replace("\t", " ").replace("\r", " ")
    text = _MULTI_SPACE.sub(" ", text)
    return text.strip()


def is_language(text: str, lang: str = "zh", min_chars: int = 64) -> bool:
    """Heuristic language detection using langdetect.

    Requires a minimal length to avoid misclassification.
    """
    if len(text) < min_chars:
        return False
    try:
        return detect(text) == lang
    except Exception:
        return False


def filter_text(text: str) -> bool:
    """Simple quality filters: drop extremely short, URL-only, or symbol-heavy content."""
    if len(text) < 64:
        return False
    if len(text) < 256 and _URL.fullmatch(text):
        return False
    # Drop text with excessive symbols/noise
    noise_ratio = sum(ch in "{}[]<>|~^`" for ch in text) / max(1, len(text))
    if noise_ratio > 0.2:
        return False
    return True


def build_minhash(text: str, num_perm: int = 128) -> MinHash:
    mh = MinHash(num_perm=num_perm)
    # Use 5-gram character shingles
    for i in range(max(0, len(text) - 4)):
        shingle = text[i : i + 5]
        mh.update(shingle.encode("utf-8", errors="ignore"))
    return mh


class Deduper:
    def __init__(self, threshold: float = 0.85, num_perm: int = 128):
        self.lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        self.num_perm = num_perm
        self._keys: Set[str] = set()

    def is_duplicate(self, key: str, text: str) -> bool:
        mh = build_minhash(text, num_perm=self.num_perm)
        if key in self._keys:
            return True
        if self.lsh.query(mh):
            return True
        self.lsh.insert(key, mh)
        self._keys.add(key)
        return False


def clean_record(rec: Dict, target_lang: str = "zh") -> Dict:
    """Clean a single record: normalize, filter, language check."""
    text = normalize_text(rec.get("text", ""))
    if not filter_text(text):
        return {}
    # Keep only records detected as the target language
    if not is_language(text, target_lang):
        return {}
    return {
        "source": rec.get("source"),
        "title": rec.get("title"),
        "text": text,
    }
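A minimal sketch of how these primitives compose (the sample text is invented for illustration; it has to clear the 64-character and language checks to survive clean_record):

from dataset.corpus.cleaning import Deduper, clean_record

sample = {
    "source": "demo",
    "title": "example",
    "text": "MinHash-based deduplication keeps the first copy of a passage and drops "
            "near-duplicates whose 5-gram shingle signatures collide in the LSH index.",
}
cleaned = clean_record(sample, target_lang="en")  # {} if filtered out
deduper = Deduper(threshold=0.85)
if cleaned and not deduper.is_duplicate("demo:example-0", cleaned["text"]):
    print("kept:", cleaned["title"])
# Submitting a near-identical text under a different key would now be reported as a duplicate.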
dataset/corpus/sources.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from typing import Dict, Iterable, List, Optional

from datasets import load_dataset


def load_open_corpora(split: str = "train") -> Iterable[Dict]:
    """Unified loader: iterate over a list of HF datasets and yield raw rows.

    Keeps original keys. Adds a 'source' field if missing, for provenance.
    """
    datasets_to_try: List[str] = [
        # Wikipedia
        "wikimedia/wikipedia",
        # arXiv (abstract-related)
        "ccdv/arxiv-summarization",
        "MMInstruction/ArxivQA",
        # StackOverflow
        "DmitriyGA/DPO-StackOverflow",
    ]
    for ds_name in datasets_to_try:
        try:
            ds = load_dataset(ds_name, split=split, trust_remote_code=True)
        except Exception:
            # Best-effort: skip datasets that fail to load (missing config, network, etc.)
            continue
        for row in ds:
            rec = dict(row)
            if "source" not in rec:
                rec["source"] = ds_name
            yield rec


def multiplex_sources(max_items: Optional[int] = None, **kwargs) -> Iterable[Dict]:
    """Multiplex multiple open sources in a single iterator.

    kwargs can carry filters such as language or categories.
    """
    count = 0
    for item in load_open_corpora(split=kwargs.get("split", "train")):
        yield item
        count += 1
        if max_items and count >= max_items:
            return
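A quick sanity check of the multiplexer (a sketch assuming network access to the Hugging Face Hub; which datasets actually load is best-effort by design):

from dataset.corpus.sources import multiplex_sources

# Peek at a handful of rows across whichever listed datasets load successfully.
for i, rec in enumerate(multiplex_sources(max_items=5)):
    print(i, rec.get("source"), sorted(rec)[:5])  # provenance plus a few key names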
(deleted: the dataset directory README, 5 lines)
@@ -1,5 +0,0 @@
-# MiniMind Datasets
-
-将所有下载的数据集文件放置到当前目录.
-
-Place the downloaded dataset file in the current directory.
dataset/scrapers/__init__.py (new file, 0 lines)

dataset/scrapers/core.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import time
import re
import urllib.parse
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import requests
from bs4 import BeautifulSoup


USER_AGENT = "MiniMindScraper/0.1 (+https://github.com/DiracSeas/minimind)"


@dataclass
class ScrapeResult:
    url: str
    title: Optional[str]
    text: str
    source: str


class RobotsCache:
    def __init__(self):
        self.cache: Dict[str, Optional[requests.Response]] = {}

    def allowed(self, base: str, path: str) -> bool:
        # Simple robots.txt respect: fetch robots.txt once per host and deny disallowed paths.
        # For robust compliance, integrate robotexclusionrulesparser; kept lightweight here.
        try:
            parsed = urllib.parse.urlparse(base)
            robots_url = f"{parsed.scheme}://{parsed.netloc}/robots.txt"
            if robots_url not in self.cache:
                resp = requests.get(robots_url, timeout=10, headers={"User-Agent": USER_AGENT})
                self.cache[robots_url] = resp if resp.status_code == 200 else None
            resp = self.cache.get(robots_url)
            if not resp or not resp.text:
                return True
            # Naive blocking: honor "Disallow: /path" lines as if they applied to all agents
            disallows = []
            for line in resp.text.splitlines():
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                if line.lower().startswith("user-agent:"):
                    # ignore agent filters: act as a generic agent
                    continue
                if line.lower().startswith("disallow:"):
                    rule = line.split(":", 1)[1].strip()
                    disallows.append(rule)
            for rule in disallows:
                if rule and path.startswith(rule):
                    return False
            return True
        except Exception:
            return True


class Scraper:
    def __init__(self, base_url: str, rate_limit_sec: float = 1.0):
        self.base_url = base_url.rstrip("/")
        self.rate_limit_sec = rate_limit_sec
        self.session = requests.Session()
        self.session.headers.update({"User-Agent": USER_AGENT})
        self.robots = RobotsCache()
        self._last_fetch = 0.0

    def _throttle(self):
        now = time.time()
        wait = self.rate_limit_sec - (now - self._last_fetch)
        if wait > 0:
            time.sleep(wait)
        self._last_fetch = time.time()

    def fetch(self, url_path: str) -> Optional[requests.Response]:
        if not url_path.startswith("/"):
            url_path = "/" + url_path
        if not self.robots.allowed(self.base_url, url_path):
            return None
        self._throttle()
        full_url = f"{self.base_url}{url_path}"
        try:
            resp = self.session.get(full_url, timeout=20, allow_redirects=True)
            ctype = resp.headers.get("Content-Type", "")
            if resp.status_code == 200 and ("text/html" in ctype or "text" in ctype or ctype == ""):
                return resp
            return None
        except Exception:
            return None

    @staticmethod
    def extract_text(html: str) -> Dict[str, Optional[str]]:
        soup = BeautifulSoup(html, "lxml")
        # Remove script/style
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        title = None
        if soup.title and soup.title.string:
            title = soup.title.string.strip()
        # Prefer article-like containers if present
        article = soup.find("article") or soup.find(id=re.compile("content|main", re.I))
        body = article.get_text(" ", strip=True) if article else soup.get_text(" ", strip=True)
        return {"title": title, "text": body}

    def crawl_paths(self, paths: Iterable[str], source_name: str) -> Iterable[ScrapeResult]:
        for p in paths:
            resp = self.fetch(p)
            if not resp:
                continue
            extracted = self.extract_text(resp.text)
            text = extracted.get("text") or ""
            if len(text) < 64:
                continue
            yield ScrapeResult(url=f"{self.base_url}{p}", title=extracted.get("title"), text=text, source=source_name)
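A minimal crawl with the classes above (a sketch; the host and paths are placeholders, while throttling and the robots.txt check come from Scraper itself):

from dataset.scrapers.core import Scraper

scraper = Scraper(base_url="https://example.com", rate_limit_sec=1.0)
for result in scraper.crawl_paths(["/", "/about"], source_name="example"):
    print(result.url, result.title, len(result.text))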
dataset/scrapers/run_scraper.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import argparse
import os
import sys
import urllib.parse

if __package__ is None or __package__ == "":
    current_dir = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir))
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)

from dataset.scrapers.core import Scraper
from dataset.utils.util import write_jsonl


def main():
    parser = argparse.ArgumentParser(description="Run targeted web scraping respecting robots.txt and rate limits.")
    parser.add_argument("--base-url", default="https://github.com/jingyaogong/minimind", help="Base site URL, e.g., https://example.com")
    parser.add_argument("--paths", default="/", help="Comma-separated relative paths to crawl, e.g., /,/posts,/about")
    parser.add_argument("--urls", default="", help="Comma-separated full URLs to crawl, e.g., https://example.com/page1,https://example.com/page2")
    parser.add_argument("--rate", type=float, default=1.0, help="Minimum seconds between requests")
    parser.add_argument("--out", default="dataset/out/scraper_out/scraped.jsonl", help="Output JSONL path")
    parser.add_argument("--source", default="web", help="Source label for records")
    args = parser.parse_args()

    scraper = Scraper(base_url=args.base_url, rate_limit_sec=args.rate)
    records_iter = []
    paths = [p.strip() for p in args.paths.split(",") if p.strip()]
    if paths:
        records_iter.append(scraper.crawl_paths(paths, source_name=args.source))
    # Crawl any full URLs as well, each via a scraper for its own host
    urls = [u.strip() for u in args.urls.split(",") if u.strip()]
    for u in urls:
        parsed = urllib.parse.urlparse(u)
        url_scraper = Scraper(base_url=f"{parsed.scheme}://{parsed.netloc}", rate_limit_sec=args.rate)
        records_iter.append(url_scraper.crawl_paths([parsed.path or "/"], source_name=args.source))

    def recs():
        for it in records_iter:
            for s in it:
                yield {"source": s.source, "title": s.title, "text": s.text, "url": s.url}

    write_jsonl(args.out, recs())


if __name__ == "__main__":
    main()


# python dataset/scrapers/run_scraper.py \
#   --base-url https://en.wikipedia.org/wiki/Nvidia \
#   --paths / \
#   --out dataset/dataset_out/scraped.jsonl
dataset/utils/util.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import json
import os
from typing import Dict, Iterable


def write_jsonl(path: str, records: Iterable[Dict]):
    # Guard against bare filenames, where dirname is "" and makedirs would fail
    dirpath = os.path.dirname(path)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            if not rec:
                continue
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
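write_jsonl accepts any iterable of dicts, including generators; a tiny sketch (the path and records here are invented):

from dataset.utils.util import write_jsonl

write_jsonl("dataset/out/demo/sample.jsonl", ({"i": i, "text": f"row {i}"} for i in range(3)))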
requirements.txt
@@ -28,4 +28,13 @@ streamlit==1.50.0
 einops==0.8.1
 swanlab==0.6.8
 torch==2.6.0
-torchvision==0.21.0
+torchvision==0.21.0
+ftfy==6.3.0
+langdetect==1.0.9
+wikipedia-api==0.6.0
+requests==2.32.3
+beautifulsoup4==4.12.3
+lxml==5.3.0
+ftfy==6.3.0
+langdetect==1.0.9
+wikipedia-api==0.6.0
scripts/inference.py (new file, 243 lines)
@@ -0,0 +1,243 @@
import argparse
import os
import sys

project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from typing import List, Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.utils import gather_object
from peft import PeftModel
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import apply_lora, load_lora


def parse_args():
    parser = argparse.ArgumentParser(
        description="Universal inference: CPU, single-GPU, single-node multi-GPU"
    )
    parser.add_argument("--model-path", default="/mnt/share/djx/minimind/out", help="Path or HF repo id of the model")
    parser.add_argument("--dtype", default="auto", choices=["auto", "fp32", "fp16", "bf16"], help="Model dtype")
    parser.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda"], help="Device kind to use")
    parser.add_argument("--num-gpus", type=int, default=0, help="Number of GPUs to use; 0 means CPU. If >1 and device=cuda, uses device_map=auto")
    parser.add_argument("--prompt", default=None, help="Prompt text; if omitted, read from stdin or --input-file")
    parser.add_argument("--input-file", default=None, help="Path to a text file; one prompt per line. Ignored if --prompt is provided.")
    parser.add_argument("--output-file", default=None, help="Write generations to file. With --input-file, outputs JSONL with fields {prompt, generation}.")
    parser.add_argument("--max-new-tokens", type=int, default=128, help="Max new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p nucleus sampling")
    parser.add_argument("--do-sample", action="store_true", help="Enable sampling; greedy decoding if not set")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for parallel generation. Effective on GPU and multi-GPU with device_map=auto.")
    parser.add_argument("--use-lora", action="store_true", help="Enable LoRA adapter loading (requires PEFT)")
    parser.add_argument("--lora-path", default=None, help="Path to LoRA adapter weights (PEFT). If not set, --use-lora is ignored.")
    # Align with eval_llm.py for native MiniMind loading
    parser.add_argument('--save-dir', default='out', type=str, help="Model weights directory (native MiniMind)")
    parser.add_argument('--weight', default='full_sft', type=str, help="Weight prefix (pretrain, full_sft, rlhf, reason, ppo_actor, grpo, spo)")
    parser.add_argument('--lora-weight', default='None', type=str, help="LoRA weight name for native MiniMind (None to disable)")
    parser.add_argument('--hidden-size', default=512, type=int, help="Hidden size (512=Small-26M, 640=MoE-145M, 768=Base-104M)")
    parser.add_argument('--num-hidden-layers', default=8, type=int, help="Number of hidden layers (Small/MoE=8, Base=16)")
    parser.add_argument('--use-moe', default=0, type=int, choices=[0, 1], help="Use MoE architecture (0/1)")
    return parser.parse_args()


def select_dtype(dtype_str: str):
    if dtype_str == "auto":
        return None
    if dtype_str == "fp32":
        return torch.float32
    if dtype_str == "fp16":
        return torch.float16
    if dtype_str == "bf16":
        return torch.bfloat16
    return None


def build_device_map(device: str, num_gpus: int):
    if device == "cpu" or (device == "auto" and not torch.cuda.is_available()):
        return None  # no device_map; model stays on CPU
    if num_gpus <= 1:
        return {"": 0}  # place everything on cuda:0
    # multi-GPU: let transformers shard automatically across available GPUs
    return "auto"


def load_model_and_tokenizer(args, dtype_opt, device_map_opt, accelerator: Optional[Accelerator] = None):
    # Align behavior with eval_llm.py
    tok = AutoTokenizer.from_pretrained(args.model_path, use_fast=True)
    if tok.pad_token is None:
        # Batched generation pads to the longest prompt; fall back to EOS as the pad token
        tok.pad_token = tok.eos_token
    # Native MiniMind path if 'out' is in the model path
    if 'out' in args.model_path:
        print("Loading native MiniMind model...")
        model = MiniMindForCausalLM(MiniMindConfig(
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_hidden_layers,
            use_moe=bool(args.use_moe),
        ))
        moe_suffix = '_moe' if args.use_moe else ''
        ckp = f'{args.model_path}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
        # Load on CPU first ("auto" is not a valid map_location); the caller moves the model
        state = torch.load(ckp, map_location="cpu")
        model.load_state_dict(state, strict=True)
        if args.lora_weight != 'None':
            apply_lora(model)
            load_lora(model, f'./{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
    else:
        kwargs = {}
        if dtype_opt is not None:
            kwargs["torch_dtype"] = dtype_opt
        if device_map_opt is not None:
            kwargs["device_map"] = device_map_opt
        kwargs["low_cpu_mem_usage"] = True
        model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True, **kwargs)
        # PEFT LoRA only if a proper adapter folder is provided
        if args.use_lora and args.lora_path:
            try:
                model = PeftModel.from_pretrained(model, args.lora_path)
            except Exception as e:
                raise RuntimeError(f"Failed to load PEFT LoRA at '{args.lora_path}': {e}")

    if accelerator is not None:
        # Move (rather than accelerator.prepare) so .generate stays accessible:
        # prepare() would wrap the model in DDP, which does not expose generate()
        model = model.to(accelerator.device)
    model.eval()
    return model, tok


def get_prompts(args) -> List[str]:
    if args.prompt is not None:
        return [args.prompt]
    if args.input_file:
        with open(args.input_file, "r", encoding="utf-8") as f:
            lines = [ln.strip() for ln in f.readlines()]
        return [ln for ln in lines if ln]
    # read from stdin (single prompt)
    print("Enter prompt. Press Ctrl-D to finish:")
    return [sys.stdin.read().strip()]


def main():
    args = parse_args()
    if args.device == "cpu" or (args.device == "auto" and not torch.cuda.is_available()):
        device = "cpu"
    else:
        device = "cuda"

    # Optional Accelerate setup
    accelerator: Optional[Accelerator] = None
    if device == "cuda" and args.num_gpus > 1:
        accelerator = Accelerator()

    dtype = select_dtype(args.dtype)
    device_map = build_device_map(device, args.num_gpus)
    model, tok = load_model_and_tokenizer(args, dtype, device_map, accelerator)

    if accelerator is None and device == "cuda" and device_map in (None, {"": 0}):
        # single-GPU: move to cuda:0 if not already placed via device_map
        model.to("cuda")

    prompts = get_prompts(args)
    gen_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "do_sample": args.do_sample,
    }

    if device == "cuda":
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except Exception:
            pass

    outputs: List[str] = []
    bs = max(1, args.batch_size)

    if accelerator is None:
        # Single-process path with optional multi-GPU device_map
        for i in range(0, len(prompts), bs):
            batch_prompts = prompts[i:i + bs]
            inputs = tok(batch_prompts, return_tensors="pt", padding=True, truncation=False)
            # Remove unused keys (e.g., token_type_ids) to avoid generate() validation errors
            inputs.pop("token_type_ids", None)
            if device == "cuda" and device_map in (None, {"": 0}):
                inputs = {k: v.to("cuda") for k, v in inputs.items()}
            with torch.inference_mode():
                out = model.generate(**inputs, **gen_kwargs)
            for j, seq in enumerate(out):
                text = tok.decode(seq, skip_special_tokens=True)
                prompt_j = batch_prompts[j]
                gen = text[len(prompt_j):] if text.startswith(prompt_j) else text
                outputs.append(gen)
    else:
        # Accelerate multi-process path: shard prompts, gather results, keep stable order
        all_items: List[Tuple[int, str]] = [(i, p) for i, p in enumerate(prompts)]
        local_items = [it for idx, it in enumerate(all_items) if idx % accelerator.num_processes == accelerator.process_index]

        local_results: List[Tuple[int, str]] = []
        for i in range(0, len(local_items), bs):
            batch = local_items[i:i + bs]
            idxs = [it[0] for it in batch]
            batch_prompts = [it[1] for it in batch]
            inputs = tok(batch_prompts, return_tensors="pt", padding=True)
            inputs.pop("token_type_ids", None)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.inference_mode():
                out = model.generate(**inputs, **gen_kwargs)
            for j, seq in enumerate(out):
                text = tok.decode(seq, skip_special_tokens=True)
                prompt_j = batch_prompts[j]
                gen = text[len(prompt_j):] if text.startswith(prompt_j) else text
                local_results.append((idxs[j], gen))

        # gather_object handles arbitrary picklable objects; Accelerator.gather is tensors-only
        gathered = gather_object(local_results)
        if accelerator.is_main_process:
            # restore original prompt order
            gathered_sorted = sorted(gathered, key=lambda x: x[0])
            outputs = [gen for _, gen in gathered_sorted]
        else:
            return

    if args.output_file:
        # If multiple prompts, write JSONL; else write plain text
        import json
        dirpath = os.path.dirname(args.output_file)
        if dirpath:
            os.makedirs(dirpath, exist_ok=True)
        with open(args.output_file, "w", encoding="utf-8") as f:
            if len(prompts) > 1:
                for p, g in zip(prompts, outputs):
                    f.write(json.dumps({"prompt": p, "generation": g}, ensure_ascii=False) + "\n")
            else:
                f.write(outputs[0])
    else:
        if len(outputs) == 1:
            print(outputs[0])
        else:
            for p, g in zip(prompts, outputs):
                print("===== PROMPT =====")
                print(p)
                print("=== GENERATION ===")
                print(g)


if __name__ == "__main__":
    main()


# python scripts/inference.py \
#   --model-path /mnt/share/djx/minimind/out/full_sft_512.pth \
#   --save-dir out \
#   --weight full_sft \
#   --hidden-size 512 \
#   --num-hidden-layers 8 \
#   --use-moe 0 \
#   --device cuda --num-gpus 1 \
#   --prompt "你好,介绍一下LoRA是什么?"
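When --num-gpus > 1 an Accelerator is created, so the script has to run under a multi-process launcher; a plausible invocation in the style of the comment above (the process count and file names are assumptions):

# accelerate launch --num_processes 2 scripts/inference.py \
#   --model-path /mnt/share/djx/minimind/out \
#   --device cuda --num-gpus 2 \
#   --input-file prompts.txt \
#   --output-file gens.jsonl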