diff --git a/scripts/train_tokenizer.py b/scripts/train_tokenizer.py
index 78211a8..55c4999 100644
--- a/scripts/train_tokenizer.py
+++ b/scripts/train_tokenizer.py
@@ -99,7 +99,7 @@ def train_tokenizer():
"spaces_between_special_tokens": False,
"tokenizer_class": "PreTrainedTokenizerFast",
"unk_token": "<|endoftext|>",
- "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}"
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in content %}\n {%- set reasoning_content = content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- set content = content.split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}"
}
# 保存配置文件
diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py
index bbe3bec..ee1ceba 100644
--- a/trainer/train_distill_reason.py
+++ b/trainer/train_distill_reason.py
@@ -65,6 +65,7 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
+ torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index 2fbc1c9..1105e8d 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -96,6 +96,7 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
+ torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index 4c892ee..e6bfc40 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -90,6 +90,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
+ torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index 9d1c915..316ed74 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -52,6 +52,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
scaler.update()
optimizer.zero_grad(set_to_none=True)
+ torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py
index 1536965..52f6562 100755
--- a/trainer/train_grpo.py
+++ b/trainer/train_grpo.py
@@ -149,6 +149,7 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
optimizer.step()
scheduler.step()
optimizer.zero_grad()
+ torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item()
diff --git a/trainer/train_lora.py b/trainer/train_lora.py
index 1305c56..c98537d 100644
--- a/trainer/train_lora.py
+++ b/trainer/train_lora.py
@@ -53,6 +53,7 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
scaler.update()
optimizer.zero_grad(set_to_none=True)
+ torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py
index b8d904c..1cdb074 100644
--- a/trainer/train_ppo.py
+++ b/trainer/train_ppo.py
@@ -179,6 +179,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
critic_scheduler.step()
actor_optimizer.zero_grad()
critic_optimizer.zero_grad()
+ torch.cuda.empty_cache()
if is_main_process():
response_ids = gen_out[:, enc.input_ids.shape[1]:]
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index 8408bfb..c5acd3e 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -52,6 +52,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
scaler.update()
optimizer.zero_grad(set_to_none=True)
+ torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters - 1:
spend_time = time.time() - start_time
diff --git a/trainer/train_spo.py b/trainer/train_spo.py
index ac0443c..c250b14 100755
--- a/trainer/train_spo.py
+++ b/trainer/train_spo.py
@@ -192,6 +192,7 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
optimizer.step()
scheduler.step()
optimizer.zero_grad()
+ torch.cuda.empty_cache()
if step % args.log_interval == 0 or step == iters:
policy_loss_val = loss.item()