diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py
index f1e3526..3c882cd 100644
--- a/trainer/train_distill_reason.py
+++ b/trainer/train_distill_reason.py
@@ -15,7 +15,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
 from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import SFTDataset
-from trainer.trainer_utils import *
+from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
 
 warnings.filterwarnings('ignore')
 
diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py
index f1de5d2..848bf52 100644
--- a/trainer/train_distillation.py
+++ b/trainer/train_distillation.py
@@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
 from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import SFTDataset
-from trainer.trainer_utils import *
+from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
 
 warnings.filterwarnings('ignore')
 
diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py
index b4e7b37..97716a5 100644
--- a/trainer/train_dpo.py
+++ b/trainer/train_dpo.py
@@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
 from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import DPODataset
-from trainer.trainer_utils import *
+from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
 
 warnings.filterwarnings('ignore')
 
diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py
index 09fa941..6702159 100644
--- a/trainer/train_full_sft.py
+++ b/trainer/train_full_sft.py
@@ -15,7 +15,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
 from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import SFTDataset
-from trainer.trainer_utils import *
+from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
 
 warnings.filterwarnings('ignore')
 
diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py
index 22a4e26..13727fe 100755
--- a/trainer/train_grpo.py
+++ b/trainer/train_grpo.py
@@ -19,7 +19,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
 from transformers import AutoModel
 from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
 from dataset.lm_dataset import RLAIFDataset
-from trainer.trainer_utils import *
+from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler
 
 warnings.filterwarnings('ignore')
 
diff --git a/trainer/train_lora.py b/trainer/train_lora.py
index b6fc2b0..1473d60 100644
--- a/trainer/train_lora.py
+++ b/trainer/train_lora.py
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, DistributedSampler
 from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import SFTDataset
 from model.model_lora import save_lora, apply_lora
-from trainer.trainer_utils import *
+from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
 
 warnings.filterwarnings('ignore')
 
diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py
index 27a82b7..51b0813 100644
--- a/trainer/train_ppo.py
+++ b/trainer/train_ppo.py
@@ -20,7 +20,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
 from transformers import AutoModel
 from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
 from dataset.lm_dataset import RLAIFDataset
-from trainer.trainer_utils import *
+from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler
 
 warnings.filterwarnings('ignore')
 
diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py
index 18cc445..87c79a7 100644
--- a/trainer/train_pretrain.py
+++ b/trainer/train_pretrain.py
@@ -15,7 +15,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
 from model.model_minimind import MiniMindConfig
 from dataset.lm_dataset import PretrainDataset
-from trainer.trainer_utils import *
+from trainer.trainer_utils import get_lr, Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, init_model, SkipBatchSampler
 
 warnings.filterwarnings('ignore')
 
diff --git a/trainer/train_spo.py b/trainer/train_spo.py
index 74dc72c..64c4e9f 100755
--- a/trainer/train_spo.py
+++ b/trainer/train_spo.py
@@ -19,7 +19,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
 from transformers import AutoModel
 from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
 from dataset.lm_dataset import RLAIFDataset
-from trainer.trainer_utils import *
+from trainer.trainer_utils import Logger, is_main_process, lm_checkpoint, init_distributed_mode, setup_seed, SkipBatchSampler
 
 warnings.filterwarnings('ignore')
 