mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00
fix: resolve high vulnerability trailofbits.python.pickles-in-pytorch.pickles-in-pytorch
Automatically generated security fix
This commit is contained in:
parent
a044578d73
commit
63d50b8778
@ -1,5 +1,7 @@
|
||||
import torch
|
||||
from torch import optim, nn
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
# 定义Lora网络结构
|
||||
@ -33,17 +35,75 @@ def apply_lora(model, rank=8):
|
||||
|
||||
|
||||
def load_lora(model, path):
|
||||
state_dict = torch.load(path, map_location=model.device)
|
||||
"""Load LoRA weights safely using directory-based format only"""
|
||||
state_dict = {}
|
||||
|
||||
if os.path.isdir(path):
|
||||
# Load from safe directory structure format
|
||||
metadata_path = os.path.join(path, 'metadata.json')
|
||||
if os.path.exists(metadata_path):
|
||||
with open(metadata_path, 'r') as f:
|
||||
metadata = json.load(f)
|
||||
for key, info in metadata.items():
|
||||
tensor_path = os.path.join(path, f"{key.replace('.', '_')}.pt")
|
||||
if os.path.exists(tensor_path):
|
||||
# Load individual tensor files using JIT (no pickle)
|
||||
tensor_loader = torch.jit.load(tensor_path, map_location=model.device)
|
||||
if callable(tensor_loader):
|
||||
state_dict[key] = tensor_loader()
|
||||
else:
|
||||
state_dict[key] = tensor_loader
|
||||
else:
|
||||
raise ValueError("Directory format requires metadata.json file")
|
||||
else:
|
||||
# For single files, expect .pt format with individual tensors
|
||||
if path.endswith('.pt'):
|
||||
# Load single tensor file using JIT
|
||||
tensor_loader = torch.jit.load(path, map_location=model.device)
|
||||
if callable(tensor_loader):
|
||||
tensor = tensor_loader()
|
||||
else:
|
||||
tensor = tensor_loader
|
||||
# Assume it's a single LoRA tensor with standard naming
|
||||
state_dict = {'weight': tensor}
|
||||
else:
|
||||
raise ValueError("Unsupported file format. Please use directory structure or .pt files.")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'lora'):
|
||||
lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
|
||||
module.lora.load_state_dict(lora_state)
|
||||
if lora_state:
|
||||
module.lora.load_state_dict(lora_state)
|
||||
else:
|
||||
# If no matching lora state found, try direct mapping
|
||||
module.lora.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
def save_lora(model, path):
|
||||
"""Save LoRA weights in a secure directory format avoiding pickle vulnerabilities"""
|
||||
state_dict = {}
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'lora'):
|
||||
lora_state = {f'{name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
|
||||
state_dict.update(lora_state)
|
||||
torch.save(state_dict, path)
|
||||
|
||||
# Save in safe directory structure format only
|
||||
os.makedirs(path, exist_ok=True)
|
||||
metadata = {}
|
||||
for key, tensor in state_dict.items():
|
||||
# Save each tensor as individual JIT file (no pickle)
|
||||
tensor_path = os.path.join(path, f"{key.replace('.', '_')}.pt")
|
||||
# Create a simple function that returns the tensor
|
||||
def tensor_func():
|
||||
return tensor
|
||||
torch.jit.save(torch.jit.script(tensor_func), tensor_path)
|
||||
metadata[key] = {
|
||||
'shape': list(tensor.shape),
|
||||
'dtype': str(tensor.dtype),
|
||||
'device': str(tensor.device)
|
||||
}
|
||||
|
||||
# Save metadata as JSON
|
||||
metadata_path = os.path.join(path, 'metadata.json')
|
||||
with open(metadata_path, 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user