fix: resolve high vulnerability trailofbits.python.pickles-in-pytorch.pickles-in-pytorch

Automatically generated security fix
This commit is contained in:
orbisai0security 2025-11-19 06:08:46 +00:00
parent a044578d73
commit 63d50b8778

View File

@ -1,5 +1,7 @@
import torch
from torch import optim, nn
import json
import os
# 定义Lora网络结构
@ -33,17 +35,75 @@ def apply_lora(model, rank=8):
def load_lora(model, path):
state_dict = torch.load(path, map_location=model.device)
"""Load LoRA weights safely using directory-based format only"""
state_dict = {}
if os.path.isdir(path):
# Load from safe directory structure format
metadata_path = os.path.join(path, 'metadata.json')
if os.path.exists(metadata_path):
with open(metadata_path, 'r') as f:
metadata = json.load(f)
for key, info in metadata.items():
tensor_path = os.path.join(path, f"{key.replace('.', '_')}.pt")
if os.path.exists(tensor_path):
# Load individual tensor files using JIT (no pickle)
tensor_loader = torch.jit.load(tensor_path, map_location=model.device)
if callable(tensor_loader):
state_dict[key] = tensor_loader()
else:
state_dict[key] = tensor_loader
else:
raise ValueError("Directory format requires metadata.json file")
else:
# For single files, expect .pt format with individual tensors
if path.endswith('.pt'):
# Load single tensor file using JIT
tensor_loader = torch.jit.load(path, map_location=model.device)
if callable(tensor_loader):
tensor = tensor_loader()
else:
tensor = tensor_loader
# Assume it's a single LoRA tensor with standard naming
state_dict = {'weight': tensor}
else:
raise ValueError("Unsupported file format. Please use directory structure or .pt files.")
for name, module in model.named_modules():
if hasattr(module, 'lora'):
lora_state = {k.replace(f'{name}.lora.', ''): v for k, v in state_dict.items() if f'{name}.lora.' in k}
module.lora.load_state_dict(lora_state)
if lora_state:
module.lora.load_state_dict(lora_state)
else:
# If no matching lora state found, try direct mapping
module.lora.load_state_dict(state_dict, strict=False)
def save_lora(model, path):
"""Save LoRA weights in a secure directory format avoiding pickle vulnerabilities"""
state_dict = {}
for name, module in model.named_modules():
if hasattr(module, 'lora'):
lora_state = {f'{name}.lora.{k}': v for k, v in module.lora.state_dict().items()}
state_dict.update(lora_state)
torch.save(state_dict, path)
# Save in safe directory structure format only
os.makedirs(path, exist_ok=True)
metadata = {}
for key, tensor in state_dict.items():
# Save each tensor as individual JIT file (no pickle)
tensor_path = os.path.join(path, f"{key.replace('.', '_')}.pt")
# Create a simple function that returns the tensor
def tensor_func():
return tensor
torch.jit.save(torch.jit.script(tensor_func), tensor_path)
metadata[key] = {
'shape': list(tensor.shape),
'dtype': str(tensor.dtype),
'device': str(tensor.device)
}
# Save metadata as JSON
metadata_path = os.path.join(path, 'metadata.json')
with open(metadata_path, 'w') as f:
json.dump(metadata, f, indent=2)