diff --git a/scripts/convert_model.py b/scripts/convert_model.py index 60cb953..01b1d51 100644 --- a/scripts/convert_model.py +++ b/scripts/convert_model.py @@ -64,9 +64,8 @@ def convert_torch2transformers_llama(torch_path, transformers_path, dtype=torch. def convert_transformers2torch(transformers_path, torch_path): model = AutoModelForCausalLM.from_pretrained(transformers_path, trust_remote_code=True) - torch.save(model.state_dict(), torch_path) - print(f"模型已保存为 PyTorch 格式: {torch_path}") - + torch.save({k: v.cpu().half() for k, v in model.state_dict().items()}, torch_path) + print(f"模型已保存为 PyTorch 格式 (half精度): {torch_path}") if __name__ == '__main__':