[fix] Refactor get_lr function to include min_lr calculation

这里的退火算法会让参数里的lr的起始值变成原来lr的1.1倍,作出如下修改
This commit is contained in:
dyhuachi 2025-12-06 17:09:51 +08:00 committed by GitHub
parent 5e1447b913
commit bf3878ace8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -23,7 +23,8 @@ def Logger(content):
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
min_lr = lr / 10
return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * current_step / total_steps))
def init_distributed_mode():