Add rSVD-based adaptive LoRA rank estimation

Ltimbe 2025-12-16 19:36:14 +08:00
parent be979ec9e7
commit bae81a2ce9
2 changed files with 115 additions and 9 deletions

@@ -18,10 +18,19 @@ class LoRA(nn.Module):
         return self.B(self.A(x))
 
-def apply_lora(model, rank=8):
+def apply_lora(model, rank=8, rank_map=None):
+    """
+    Attach LoRA adapters to q/k/v/o projection matrices.
+
+    Args:
+        model: target model.
+        rank: default LoRA rank when ``rank_map`` is not provided.
+        rank_map: optional dict mapping module names to specific ranks.
+    """
     for name, module in model.named_modules():
         if 'lora' not in name and ('q_proj' in name or 'k_proj' in name or 'v_proj' in name or 'o_proj' in name):
-            lora = LoRA(module.weight.shape[1], module.weight.shape[0], rank=rank).to(model.device)
+            target_rank = rank_map.get(name, rank) if rank_map is not None else rank
+            lora = LoRA(module.weight.shape[1], module.weight.shape[0], rank=target_rank).to(model.device)
             setattr(module, "lora", lora)
             original_forward = module.forward
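
For orientation, a minimal sketch of how the new rank_map argument is consumed; the module-name filter mirrors the loop above, and the specific layer names are hypothetical:

# Hypothetical usage sketch: pin all q_proj layers to rank 4 and let the
# remaining projections fall back to the default rank of 8.
rank_map = {name: 4 for name, _ in model.named_modules() if 'q_proj' in name}
apply_lora(model, rank=8, rank_map=rank_map)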

@@ -21,6 +21,83 @@ from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint
 warnings.filterwarnings('ignore')
 
+def _iter_target_modules(model):
+    # Yield the q/k/v/o projection modules that LoRA will later wrap.
+    for name, module in model.named_modules():
+        if 'lora' not in name and ('q_proj' in name or 'k_proj' in name or 'v_proj' in name or 'o_proj' in name):
+            if hasattr(module, 'weight'):
+                yield name, module
+
+def _randomized_svd_rank(matrix: torch.Tensor, k: int, q: int, tau: float) -> int:
+    matrix_2d = matrix.detach().float()
+    effective_k = min(k, matrix_2d.shape[0], matrix_2d.shape[1])
+    if effective_k == 0:
+        return 0
+    # Randomized range finder: sketch the matrix onto a k-dimensional subspace.
+    omega = torch.randn(matrix_2d.shape[1], effective_k, device=matrix_2d.device, dtype=matrix_2d.dtype)
+    y = matrix_2d @ omega
+    # Power iterations (A A^T applied q times) sharpen the leading spectrum.
+    for _ in range(max(1, q)):
+        y = matrix_2d @ (matrix_2d.transpose(0, 1) @ y)
+    q_matrix, _ = torch.linalg.qr(y, mode='reduced')
+    b_matrix = q_matrix.transpose(0, 1) @ matrix_2d
+    _, singular_values, _ = torch.linalg.svd(b_matrix, full_matrices=False)
+    singular_values = singular_values[:effective_k]
+    if singular_values.numel() == 0:
+        return 0
+    # Pick the smallest rank whose squared-singular-value energy covers tau.
+    energy = singular_values.pow(2)
+    energy_sum = energy.sum()
+    if energy_sum == 0:
+        return 0
+    coverage = torch.cumsum(energy, dim=0) / energy_sum
+    rank_idx = torch.nonzero(coverage >= tau, as_tuple=False)
+    rank = rank_idx[0, 0].item() + 1 if rank_idx.numel() > 0 else singular_values.numel()
+    return min(rank, effective_k)
+
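
The cutoff rule selects the smallest rank r whose squared singular values cover a fraction τ of the total energy, i.e. the smallest r with (σ₁² + … + σ_r²) / (σ₁² + … + σ_k²) ≥ τ. A quick sanity check, assuming the function above is importable (shapes and values are illustrative, not from the commit):

import torch

# A matrix of exact rank 3: with tau near 1 the estimate should return 3,
# since sigma_4 onward are numerically zero.
A = torch.randn(256, 3) @ torch.randn(3, 512)
print(_randomized_svd_rank(A, k=32, q=2, tau=0.99))  # expected: 3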
+def estimate_ranks(model, loader, steps, autocast_ctx, tau, k, q, device):
+    grad_sums = {name: torch.zeros_like(module.weight, device=device)
+                 for name, module in _iter_target_modules(model)}
+    loss_fct = nn.CrossEntropyLoss(reduction='none')
+    model.train()
+    for idx, (X, Y, loss_mask) in enumerate(loader):
+        if idx >= steps:
+            break
+        X = X.to(device)
+        Y = Y.to(device)
+        loss_mask = loss_mask.to(device)
+        with autocast_ctx:
+            res = model(X)
+            loss = loss_fct(
+                res.logits.view(-1, res.logits.size(-1)),
+                Y.view(-1)
+            ).view(Y.size())
+            loss = (loss * loss_mask).sum() / loss_mask.sum()
+            loss += res.aux_loss
+        loss.backward()
+        # Accumulate full-weight gradients over a few calibration batches.
+        for name, module in _iter_target_modules(model):
+            if module.weight.grad is not None:
+                grad_sums[name] += module.weight.grad.detach()
+        model.zero_grad(set_to_none=True)
+    torch.cuda.empty_cache()
+    rank_map = {}
+    for name, grad in grad_sums.items():
+        rank = _randomized_svd_rank(grad, k=k, q=q, tau=tau)
+        rank_map[name] = rank if rank > 0 else 1  # never assign rank 0
+        Logger(f"LoRA rank for {name}: {rank_map[name]}")
+    return rank_map
+
 def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
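
Taken together, the expected call pattern for a single-process run looks like the following; autocast_ctx is whatever context the trainer already uses, and the argument values simply mirror the new CLI defaults:

from contextlib import nullcontext

# Illustrative single-process call; on GPU the trainer would pass its
# torch.cuda.amp.autocast() context instead of nullcontext().
rank_map = estimate_ranks(model, rank_loader, steps=10,
                          autocast_ctx=nullcontext(), tau=0.95, k=128, q=2,
                          device='cuda')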
@@ -99,6 +176,10 @@ if __name__ == "__main__":
     parser.add_argument('--from_resume', default=0, type=int, choices=[0, 1], help="auto-detect & resume training (0 = no, 1 = yes)")
     parser.add_argument("--use_wandb", action="store_true", help="whether to use wandb")
     parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="wandb project name")
+    parser.add_argument('--rank_eval_steps', default=10, type=int, help="number of batches used for rSVD rank estimation (0 = skip)")
+    parser.add_argument('--svd_tau', default=0.95, type=float, help="rSVD energy-coverage threshold τ")
+    parser.add_argument('--svd_k', default=128, type=int, help="number of truncated singular values k for rSVD")
+    parser.add_argument('--svd_q', default=2, type=int, help="rSVD power-iteration count q (1-2 recommended)")
     args = parser.parse_args()
 
     # ========== 1. Initialize environment and random seeds ==========
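
A typical invocation with the new flags might look like this (the script name and any omitted flags are hypothetical; only --rank_eval_steps and the three --svd_* options come from this commit):

python train_lora.py --rank_eval_steps 10 --svd_tau 0.95 --svd_k 128 --svd_q 2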
@@ -125,9 +206,26 @@ if __name__ == "__main__":
         wandb_run_name = f"MiniMind-LoRA-{args.lora_name}-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LR-{args.learning_rate}"
         wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)
 
-    # ========== 5. Define the model, apply LoRA, freeze non-LoRA params ==========
+    # ========== 5. Define the model ==========
     model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
-    apply_lora(model)
+
+    # ========== 6. Estimate per-projection ranks via rSVD ==========
+    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
+    rank_sampler = DistributedSampler(train_ds, shuffle=False) if dist.is_initialized() else None
+    rank_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=False, sampler=rank_sampler, num_workers=args.num_workers, pin_memory=True)
+    svd_q = max(1, min(args.svd_q, 2))  # clamp power iterations to [1, 2]
+    rank_map = None
+    if args.rank_eval_steps > 0:
+        if dist.is_initialized():
+            # Only rank 0 estimates; the result is broadcast so every process
+            # builds LoRA adapters with identical shapes before DDP wrapping.
+            if dist.get_rank() == 0:
+                rank_map = estimate_ranks(model, rank_loader, args.rank_eval_steps, autocast_ctx, args.svd_tau, args.svd_k, svd_q, args.device)
+            obj_list = [rank_map]
+            dist.broadcast_object_list(obj_list, src=0)
+            rank_map = obj_list[0]
+        else:
+            rank_map = estimate_ranks(model, rank_loader, args.rank_eval_steps, autocast_ctx, args.svd_tau, args.svd_k, svd_q, args.device)
+    apply_lora(model, rank_map=rank_map)
 
     # Parameter statistics
     total_params = sum(p.numel() for p in model.parameters())
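
One synchronization detail worth noting: every process must enter broadcast_object_list holding a placeholder, and all processes must end up with the same rank_map so their adapters match shape-for-shape before DDP wraps the model. Stripped to its essentials, the exchange is:

import torch.distributed as dist

obj_list = [rank_map]                        # rank 0: the dict; others: None
dist.broadcast_object_list(obj_list, src=0)  # rank 0's object overwrites the rest
rank_map = obj_list[0]                       # now identical on every rank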
@@ -145,13 +243,12 @@ if __name__ == "__main__":
         else:
             param.requires_grad = False
 
-    # ========== 6. Define data and optimizer ==========
-    train_ds = SFTDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
+    # ========== 7. Define data and optimizer ==========
     train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
     scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
     optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
 
-    # ========== 7. Restore state from checkpoint ==========
+    # ========== 8. Restore state from checkpoint ==========
     start_epoch, start_step = 0, 0
     if ckp_data:
         model.load_state_dict(ckp_data['model'], strict=False)
@@ -160,12 +257,12 @@ if __name__ == "__main__":
         start_epoch = ckp_data['epoch']
         start_step = ckp_data.get('step', 0)
 
-    # ========== 8. Wrap the model with DDP ==========
+    # ========== 9. Wrap the model with DDP ==========
     if dist.is_initialized():
         model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[local_rank])
 
-    # ========== 9. Start training ==========
+    # ========== 10. Start training ==========
     for epoch in range(start_epoch, args.epochs):
         train_sampler and train_sampler.set_epoch(epoch)
         if epoch == start_epoch and start_step > 0:  # first epoch and a checkpoint exists