Mirror of https://github.com/jingyaogong/minimind.git (synced 2026-01-14 04:07:17 +08:00)
Add rSVD-based adaptive LoRA rank estimation
parent be979ec9e7
commit bae81a2ce9
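In short: apply_lora() gains an optional rank_map so each q/k/v/o projection can use its own LoRA rank, and the LoRA trainer gains estimate_ranks(), which accumulates weight gradients over a few SFT batches and picks a per-layer rank with randomized SVD: the smallest rank whose cumulative squared-singular-value energy covers a fraction τ (--svd_tau) of the total.

Note: a minimal standalone sketch of that energy-coverage rule, using an exact (non-randomized) SVD as a stand-in; the function and variable names below are illustrative, not part of the diff:

    import torch

    def coverage_rank(matrix: torch.Tensor, tau: float = 0.95, k: int = 128) -> int:
        # Smallest rank whose cumulative squared-singular-value energy reaches tau
        # (same selection rule as _randomized_svd_rank in the diff, but with exact SVD).
        s = torch.linalg.svdvals(matrix.float())[:k]
        coverage = torch.cumsum(s.pow(2), dim=0) / s.pow(2).sum()
        idx = torch.nonzero(coverage >= tau, as_tuple=False)
        return int(idx[0, 0]) + 1 if idx.numel() > 0 else s.numel()

    # A nearly rank-4 matrix plus small noise should yield a small estimate.
    g = torch.randn(512, 4) @ torch.randn(4, 512) + 0.01 * torch.randn(512, 512)
    print(coverage_rank(g, tau=0.95))  # typically prints 3 or 4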
@@ -18,10 +18,19 @@ class LoRA(nn.Module):
         return self.B(self.A(x))
 
 
-def apply_lora(model, rank=8):
+def apply_lora(model, rank=8, rank_map=None):
+    """
+    Attach LoRA adapters to q/k/v/o projection matrices.
+
+    Args:
+        model: target model.
+        rank: default LoRA rank when ``rank_map`` is not provided.
+        rank_map: optional dict mapping module names to specific ranks.
+    """
     for name, module in model.named_modules():
         if 'lora' not in name and ('q_proj' in name or 'k_proj' in name or 'v_proj' in name or 'o_proj' in name):
-            lora = LoRA(module.weight.shape[1], module.weight.shape[0], rank=rank).to(model.device)
+            target_rank = rank_map.get(name, rank) if rank_map is not None else rank
+            lora = LoRA(module.weight.shape[1], module.weight.shape[0], rank=target_rank).to(model.device)
             setattr(module, "lora", lora)
             original_forward = module.forward
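Note: with the new signature, a caller can mix per-module overrides with the old default; a hypothetical usage sketch (the module names are illustrative and depend on the actual model):

    rank_map = {
        "layers.0.attention.q_proj": 4,    # illustrative module name
        "layers.0.attention.v_proj": 16,   # illustrative module name
    }
    apply_lora(model, rank=8, rank_map=rank_map)  # unlisted q/k/v/o layers keep rank=8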
@@ -21,6 +21,83 @@ from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint
 warnings.filterwarnings('ignore')
 
 
+def _iter_target_modules(model):
+    for name, module in model.named_modules():
+        if 'lora' not in name and ('q_proj' in name or 'k_proj' in name or 'v_proj' in name or 'o_proj' in name):
+            if hasattr(module, 'weight'):
+                yield name, module
+
+
+def _randomized_svd_rank(matrix: torch.Tensor, k: int, q: int, tau: float) -> int:
+    matrix_2d = matrix.detach().float()
+    effective_k = min(k, matrix_2d.shape[0], matrix_2d.shape[1])
+    if effective_k == 0:
+        return 0
+
+    omega = torch.randn(matrix_2d.shape[1], effective_k, device=matrix_2d.device, dtype=matrix_2d.dtype)
+    y = matrix_2d @ omega
+    for _ in range(max(1, q)):
+        y = matrix_2d @ (matrix_2d.transpose(0, 1) @ y)
+
+    q_matrix, _ = torch.linalg.qr(y, mode='reduced')
+    b_matrix = q_matrix.transpose(0, 1) @ matrix_2d
+    _, singular_values, _ = torch.linalg.svd(b_matrix, full_matrices=False)
+    singular_values = singular_values[:effective_k]
+
+    if singular_values.numel() == 0:
+        return 0
+
+    energy = singular_values.pow(2)
+    energy_sum = energy.sum()
+    if energy_sum == 0:
+        return 0
+
+    coverage = torch.cumsum(energy, dim=0) / energy_sum
+    rank_idx = torch.nonzero(coverage >= tau, as_tuple=False)
+    rank = rank_idx[0, 0].item() + 1 if rank_idx.numel() > 0 else singular_values.numel()
+    return min(rank, effective_k)
+
+
+def estimate_ranks(model, loader, steps, autocast_ctx, tau, k, q, device):
+    grad_sums = {name: torch.zeros_like(module.weight, device=device) for name, module in _iter_target_modules(model)}
+    loss_fct = nn.CrossEntropyLoss(reduction='none')
+
+    model.train()
+    for idx, (X, Y, loss_mask) in enumerate(loader):
+        if idx >= steps:
+            break
+
+        X = X.to(device)
+        Y = Y.to(device)
+        loss_mask = loss_mask.to(device)
+
+        with autocast_ctx:
+            res = model(X)
+            loss = loss_fct(
+                res.logits.view(-1, res.logits.size(-1)),
+                Y.view(-1)
+            ).view(Y.size())
+            loss = (loss * loss_mask).sum() / loss_mask.sum()
+            loss += res.aux_loss
+
+        loss.backward()
+
+        for name, module in _iter_target_modules(model):
+            if module.weight.grad is not None:
+                grad_sums[name] += module.weight.grad.detach()
+
+        model.zero_grad(set_to_none=True)
+        torch.cuda.empty_cache()
+
+    rank_map = {}
+    for name, grad in grad_sums.items():
+        rank = _randomized_svd_rank(grad, k=k, q=q, tau=tau)
+        rank_map[name] = rank if rank > 0 else 1
+        Logger(f"LoRA rank for {name}: {rank_map[name]}")
+
+    return rank_map
+
+
 def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
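Note: a quick sanity check of _randomized_svd_rank against an exact-SVD baseline on a synthetic gradient-like matrix (the import path is an assumption and may differ from the repo layout):

    import torch
    from trainer.train_lora import _randomized_svd_rank  # assumed module path

    # Near-rank-8 matrix plus small noise; the randomized and exact
    # energy-coverage ranks should roughly agree for tau=0.95.
    g = torch.randn(256, 8) @ torch.randn(8, 256) + 0.01 * torch.randn(256, 256)
    s = torch.linalg.svdvals(g)
    cov = torch.cumsum(s.pow(2), dim=0) / s.pow(2).sum()
    exact = int(torch.nonzero(cov >= 0.95)[0, 0]) + 1
    print(exact, _randomized_svd_rank(g, k=64, q=2, tau=0.95))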
@@ -99,6 +176,10 @@ if __name__ == "__main__":
     parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="Auto-detect & resume training (0=no, 1=yes)")
    parser.add_argument("--use_wandb", action="store_true", help="Whether to use wandb")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb project name")
+    parser.add_argument('--rank_eval_steps', default=10, type=int, help="Number of batches used for rSVD rank estimation (0 = skip)")
+    parser.add_argument('--svd_tau', default=0.95, type=float, help="rSVD energy coverage threshold τ")
+    parser.add_argument('--svd_k', default=128, type=int, help="Number of truncated singular values k for rSVD")
+    parser.add_argument('--svd_q', default=2, type=int, help="Number of rSVD power iterations q (1-2 recommended)")
    args = parser.parse_args()
 
    # ========== 1. Initialize environment and random seed ==========
@@ -125,9 +206,26 @@ if __name__ == "__main__":
         wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
         wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
 
-    # ========== 5. Define model, apply LoRA, freeze non-LoRA params ==========
+    # ========== 5. Define model ==========
     model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
-    apply_lora(model)
+
+    # ========== 6. Estimate per-projection-layer ranks with rSVD ==========
+    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
+    rank_sampler = DistributedSampler(train_ds, shuffle=False) if dist.is_initialized() else None
+    rank_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=False, sampler=rank_sampler, num_workers=args.num_workers, pin_memory=True)
+    svd_q = max(1, min(args.svd_q, 2))
+    rank_map = None
+    if args.rank_eval_steps > 0:
+        if dist.is_initialized():
+            if dist.get_rank() == 0:
+                rank_map = estimate_ranks(model, rank_loader, args.rank_eval_steps, autocast_ctx, args.svd_tau, args.svd_k, svd_q, args.device)
+            obj_list = [rank_map]
+            dist.broadcast_object_list(obj_list, src=0)
+            rank_map = obj_list[0]
+        else:
+            rank_map = estimate_ranks(model, rank_loader, args.rank_eval_steps, autocast_ctx, args.svd_tau, args.svd_k, svd_q, args.device)
+
+    apply_lora(model, rank_map=rank_map)
 
     # Count parameters
     total_params = sum(p.numel() for p in model.parameters())
@@ -145,13 +243,12 @@ if __name__ == "__main__":
         else:
             param.requires_grad = False
 
-    # ========== 6. Define dataset and optimizer ==========
-    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
+    # ========== 7. Define dataset and optimizer ==========
     train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
     scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
     optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
 
-    # ========== 7. Restore state from checkpoint ==========
+    # ========== 8. Restore state from checkpoint ==========
     start_epoch, start_step = 0, 0
     if ckp_data:
         model.load_state_dict(ckp_data['model'], strict=False)
@@ -160,12 +257,12 @@ if __name__ == "__main__":
         start_epoch = ckp_data['epoch']
         start_step = ckp_data.get('step', 0)
 
-    # ========== 8. Wrap model with DDP ==========
+    # ========== 9. Wrap model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[local_rank])
 
-    # ========== 9. Start training ==========
+    # ========== 10. Start training ==========
     for epoch in range(start_epoch, args.epochs):
         train_sampler and train_sampler.set_epoch(epoch)
         if epoch == start_epoch and start_step > 0:  # first epoch and a checkpoint exists