import datetime
import logging
import logging.handlers
import mimetypes
import os
import sys

import torch
import torch.distributed as dist

import requests
import transformers
from loguru import logger


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


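# Usage sketch (illustrative, not part of the original file): disable_torch_init() is
# meant to be called right before building a model whose weights are loaded from a
# checkpoint anyway, so skipping torch's default Linear/LayerNorm init is safe.
# The helper name and model path below are placeholders.
def _load_pretrained_without_default_init(model_path: str):
    disable_torch_init()
    return transformers.AutoModelForCausalLM.from_pretrained(model_path)

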
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def rank0_print(*args):
    if is_main_process():
        print(*args)


def is_image(url_or_path):
    # Keep only the first whitespace-separated token of the string.
    url_or_path = url_or_path.split(" ")[0]
    # If it's a local file path, convert it to an absolute path.
    if os.path.exists(url_or_path):
        url_or_path = os.path.abspath(url_or_path)

    mimetype, _ = mimetypes.guess_type(url_or_path)
    return (mimetype and mimetype.startswith("image")) or url_or_path.endswith("webp")


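# Examples (illustrative): is_image("images/cat 512 512") only inspects "images/cat",
# and is_image("demo.webp") is truthy even if the MIME registry does not know webp,
# thanks to the explicit endswith("webp") fallback.

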
def load_conn_weights(conn_model_path, model_conn, module_key="conn_ve_llm"):
    mm_projector_weights = torch.load(conn_model_path, map_location="cpu")

    def get_w(weights, keyword):
        # Strip the "<keyword>." prefix so the keys match model_conn's own state dict.
        return {
            k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k
        }

    mm_projector_weights = get_w(mm_projector_weights, module_key)
    try:
        model_conn.load_state_dict(mm_projector_weights, strict=False)
        if is_main_process():
            logger.info(f"conn weights loaded from: {conn_model_path}")
    except Exception as e:
        print(f"got error loading state dict: {e}")
        # Fall back to loading everything except the first two projector layers.
        model_conn.load_state_dict(
            {
                k: v
                for k, v in mm_projector_weights.items()
                if "layers.1" not in k and "layers.0" not in k
            },
            strict=False,
        )
        print(f"{module_key} partially loaded!")


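# Usage sketch (illustrative; the checkpoint path and the connector attribute on the
# model are assumptions, not taken from this file):
#
#   load_conn_weights("checkpoints/conn_ve_llm.bin", model.conn_ve_llm)

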
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a CPU copy of a parameter, gathering it first if ZeRO-3 partitioned it."""
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(
                    f"{name}: param.ds_status is {param.ds_status}; gathering it anyway"
                )
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the bias tensors whose module also carries LoRA weights.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return


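# Sketch of how the two getters above are typically combined when checkpointing a LoRA
# run (illustrative; the helper name, output file names, and default lora_bias value are
# assumptions, not taken from this file):
def _save_lora_states_sketch(model, output_dir, lora_bias="none"):
    lora_state = get_peft_state_maybe_zero_3(model.named_parameters(), lora_bias)
    non_lora_state = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
    os.makedirs(output_dir, exist_ok=True)
    torch.save(lora_state, os.path.join(output_dir, "adapter_model.bin"))
    torch.save(non_lora_state, os.path.join(output_dir, "non_lora_trainables.bin"))

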
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {
        k: t
        for k, t in named_params
        if any(key_match in k for key_match in keys_to_match)
    }
    to_return = {
        k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
    }
    return to_return


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    # e.g. {'up_proj', 'v_proj', 'gate_proj', 'k_proj', 'down_proj', 'q_proj', 'o_proj', 'lm_head'}
    multimodal_keywords = ["conn_ve_llm", "ve", "vision_resampler"]
    # multimodal_keywords = ['mm_projector', 'vision_resampler']
    for name, module in model.named_modules():
        # Skip vision-side modules so LoRA targets only the language model's linears.
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    rank0_print(f"==> lora modules: {lora_module_names}")
    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)


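# Sketch of feeding the result into peft (illustrative; the helper name and hyperparameter
# defaults are placeholders, and peft is not imported elsewhere in this file):
def _wrap_with_lora_sketch(model, r=64, lora_alpha=16, lora_dropout=0.05):
    from peft import LoraConfig, get_peft_model

    config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=find_all_linear_names(model),
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, config)

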
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collect the model state dict and dump it to disk."""

    if getattr(trainer.args, "tune_conn_ve_llm", False):
        # Only save the adapter (connector) weights.
        keys_to_match = ["conn_ve_llm"]
        if getattr(trainer.args, "use_im_start_end", False):
            keys_to_match.extend(["embed_tokens", "embed_in"])

        weight_to_save = get_mm_adapter_state_maybe_zero_3(
            trainer.model.named_parameters(), keys_to_match
        )
        trainer.model.config.save_pretrained(output_dir)

        current_folder = output_dir.split("/")[-1]
        parent_folder = os.path.dirname(output_dir)
        if trainer.args.local_rank in (0, -1):
            if current_folder.startswith("checkpoint-"):
                mm_projector_folder = os.path.join(parent_folder, "conn_ve_llm")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(
                    weight_to_save,
                    os.path.join(mm_projector_folder, f"{current_folder}.bin"),
                )
            else:
                torch.save(weight_to_save, os.path.join(output_dir, "conn_ve_llm.bin"))
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
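

# Usage sketch (illustrative; argument values are placeholders): call this once after
# training finishes, e.g.
#
#   trainer.train()
#   safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)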