From bf3878ace83b001199bfed5d9b165591138432ec Mon Sep 17 00:00:00 2001 From: dyhuachi <95952350+dyhuachi@users.noreply.github.com> Date: Sat, 6 Dec 2025 17:09:51 +0800 Subject: [PATCH] [fix] Refactor get_lr function to include min_lr calculation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 这里的退火算法会让参数里的lr的起始值变成原来lr的1.1倍,作出如下修改 --- trainer/trainer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trainer/trainer_utils.py b/trainer/trainer_utils.py index 93c9bfc..0ad1b00 100644 --- a/trainer/trainer_utils.py +++ b/trainer/trainer_utils.py @@ -23,7 +23,8 @@ def Logger(content): def get_lr(current_step, total_steps, lr): - return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps)) + min_lr = lr / 10 + return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * current_step / total_steps)) def init_distributed_mode():