diff --git a/scripts/train_tokenizer.py b/scripts/train_tokenizer.py index 78211a8..55c4999 100644 --- a/scripts/train_tokenizer.py +++ b/scripts/train_tokenizer.py @@ -99,7 +99,7 @@ def train_tokenizer(): "spaces_between_special_tokens": False, "tokenizer_class": "PreTrainedTokenizerFast", "unk_token": "<|endoftext|>", - "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}" + "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in content %}\n {%- set reasoning_content = content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- set content = content.split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}" } # 保存配置文件 diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py index bbe3bec..ee1ceba 100644 --- a/trainer/train_distill_reason.py +++ b/trainer/train_distill_reason.py @@ -65,6 +65,7 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb= scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) + torch.cuda.empty_cache() if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py index 2fbc1c9..1105e8d 100644 --- a/trainer/train_distillation.py +++ b/trainer/train_distillation.py @@ -96,6 +96,7 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) + torch.cuda.empty_cache() if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py index 4c892ee..e6bfc40 100644 --- a/trainer/train_dpo.py +++ b/trainer/train_dpo.py @@ -90,6 +90,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb= scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) + torch.cuda.empty_cache() if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index 9d1c915..316ed74 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -52,6 +52,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): scaler.update() optimizer.zero_grad(set_to_none=True) + torch.cuda.empty_cache() if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index 1536965..52f6562 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -149,6 +149,7 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token optimizer.step() scheduler.step() optimizer.zero_grad() + torch.cuda.empty_cache() if step % args.log_interval == 0 or step == iters: policy_loss_val = loss.item() diff --git a/trainer/train_lora.py b/trainer/train_lora.py index 1305c56..c98537d 100644 --- a/trainer/train_lora.py +++ b/trainer/train_lora.py @@ -53,6 +53,7 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None): scaler.update() optimizer.zero_grad(set_to_none=True) + torch.cuda.empty_cache() if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index b8d904c..1cdb074 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -179,6 +179,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche critic_scheduler.step() actor_optimizer.zero_grad() critic_optimizer.zero_grad() + torch.cuda.empty_cache() if is_main_process(): response_ids = gen_out[:, enc.input_ids.shape[1]:] diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 8408bfb..c5acd3e 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -52,6 +52,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None): scaler.update() optimizer.zero_grad(set_to_none=True) + torch.cuda.empty_cache() if step % args.log_interval == 0 or step == iters - 1: spend_time = time.time() - start_time diff --git a/trainer/train_spo.py b/trainer/train_spo.py index ac0443c..c250b14 100755 --- a/trainer/train_spo.py +++ b/trainer/train_spo.py @@ -192,6 +192,7 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni optimizer.step() scheduler.step() optimizer.zero_grad() + torch.cuda.empty_cache() if step % args.log_interval == 0 or step == iters: policy_loss_val = loss.item()