# Namo-R1/namo/utils/utils.py

import datetime
import logging
import logging.handlers
import mimetypes
import os
import sys

import requests
import torch
import torch.distributed as dist
import transformers
from loguru import logger


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def rank0_print(*args):
if is_main_process():
print(*args)
def is_image(url_or_path):
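    """Heuristically check whether `url_or_path` points to an image (by MIME type or a .webp suffix)."""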
    # Keep only the first whitespace-separated token (the actual path or URL).
    url_or_path = url_or_path.split(" ")[0]
    # If it's a local file path, convert it to an absolute path.
    if os.path.exists(url_or_path):
        url_or_path = os.path.abspath(url_or_path)
mimetype, encoding = mimetypes.guess_type(url_or_path)
return (mimetype and mimetype.startswith("image")) or url_or_path.endswith("webp")
def load_conn_weights(conn_model_path, model_conn, module_key="conn_ve_llm"):
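    """Load connector weights from `conn_model_path` into `model_conn`, keeping only keys that contain `module_key`."""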
mm_projector_weights = torch.load(conn_model_path, map_location="cpu")
def get_w(weights, keyword):
return {
k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k
}
mm_projector_weights = get_w(mm_projector_weights, module_key)
try:
model_conn.load_state_dict(mm_projector_weights, strict=False)
if is_main_process():
logger.info(f"conn weights loaded from: {conn_model_path}")
except Exception as e:
        print(f"error loading state dict: {e}")
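        # Fall back to a partial load that skips keys for layers.0 and layers.1.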
model_conn.load_state_dict(
{
k: v
for k, v in mm_projector_weights.items()
if "layers.1" not in k and "layers.0" not in k
},
strict=False,
)
print(f"{module_key} partially loaded!")
def maybe_zero_3(param, ignore_status=False, name=None):
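    """Return a CPU copy of `param`, gathering it first if it is a DeepSpeed ZeRO-3 partitioned parameter."""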
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(
                    f"{name}: param.ds_status is {param.ds_status}; gathering anyway"
                )
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
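    """Collect LoRA weights (and, depending on `bias`, bias terms) from `named_params`, gathered to CPU."""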
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
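    """Collect all non-LoRA parameters (optionally only trainable ones), gathered to CPU."""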
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {
k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
}
return to_return
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
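    """Collect parameters whose names match any entry in `keys_to_match`, gathered to CPU."""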
to_return = {
k: t
for k, t in named_params
if any(key_match in k for key_match in keys_to_match)
}
to_return = {
k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()
}
return to_return
def find_all_linear_names(model):
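    """Return the names of all torch.nn.Linear modules in the language model that should receive LoRA adapters."""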
cls = torch.nn.Linear
lora_module_names = set()
    # Typical result: {'up_proj', 'v_proj', 'gate_proj', 'k_proj', 'down_proj', 'q_proj', 'o_proj', 'lm_head'}
    # Skip multimodal modules so LoRA only targets the language model's linear layers.
    multimodal_keywords = ["conn_ve_llm", "ve", "vision_resampler"]
    # multimodal_keywords = ['mm_projector', 'vision_resampler']
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
rank0_print(f"==> lora modules: {lora_module_names}")
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collect the state dict and dump it to disk."""
if getattr(trainer.args, "tune_conn_ve_llm", False):
# Only save Adapter
keys_to_match = ["conn_ve_llm"]
if getattr(trainer.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(
trainer.model.named_parameters(), keys_to_match
)
trainer.model.config.save_pretrained(output_dir)
current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "conn_ve_llm")
os.makedirs(mm_projector_folder, exist_ok=True)
torch.save(
weight_to_save,
os.path.join(mm_projector_folder, f"{current_folder}.bin"),
)
else:
                torch.save(weight_to_save, os.path.join(output_dir, "conn_ve_llm.bin"))
return
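    # Under DeepSpeed, synchronize and let the trainer handle gathering and saving the full checkpoint.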
if trainer.deepspeed:
torch.cuda.synchronize()
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa