fix pre-commit

Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
This commit is contained in:
Yibin Li 2026-01-13 06:58:34 +00:00
parent 14ad0cb4a7
commit 2f17268409
2 changed files with 22 additions and 11 deletions

View File

@ -13,21 +13,32 @@ from tensorrt_llm.logger import logger
class RestrictedUnpickler(pickle.Unpickler):
"""Restricted unpickler that only allows safe types.
This prevents arbitrary code execution by restricting what can be deserialized.
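Only allows the builtins listed in ALLOWED_TYPES (list, tuple, str, int, float,
bool, bytes, dict, NoneType, type) and any class from a module whose name starts
with "torch", which is needed for tensor reconstruction.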
"""
# Allowed modules and their classes
ALLOWED_TYPES = {
'builtins': {'list', 'tuple', 'str', 'int', 'float', 'bool', 'bytes', 'dict', 'NoneType', 'type'},
"builtins": {
"list",
"tuple",
"str",
"int",
"float",
"bool",
"bytes",
"dict",
"NoneType",
"type",
},
}
def find_class(self, module, name):
"""Override to restrict which classes can be unpickled."""
# Allow any module whose name starts with "torch", or a (module, name) pair listed in ALLOWED_TYPES
is_torch_module = module.startswith('torch')
is_torch_module = module.startswith("torch")
if is_torch_module or (module in self.ALLOWED_TYPES and name in self.ALLOWED_TYPES[module]):
return super().find_class(module, name)
@ -93,7 +104,7 @@ class WorkerExtension:
logger.info("Deserializing base64-encoded weight handles")
decoded_data = base64.b64decode(serialized_handles)
all_handles = RestrictedUnpickler(io.BytesIO(decoded_data)).load()
# Verify the result is a list as expected
if not isinstance(all_handles, list):
raise ValueError(

View File

@ -1,7 +1,7 @@
from typing import Callable, List, Optional
import base64
import io
import pickle
from typing import Callable, List, Optional
import pytest
import torch
@ -102,9 +102,9 @@ class HFModel:
continue
handle = reduce_tensor(p)
all_handles.append((name, handle))
# Serialize with base64-encoded pickle
serialized = base64.b64encode(pickle.dumps(all_handles)).decode('utf-8')
serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
ret[self.device_uuid[device]] = serialized
return ret
@ -220,7 +220,7 @@ def test_llm_update_weights_with_serialized_handles(model_dir):
# Use the serialized format (base64-encoded pickle)
ipc_handles_serialized = hf_model.get_weight_ipc_handles_serialized([0])
# Verify the format is correct (should be base64-encoded strings)
for device_uuid, serialized_data in ipc_handles_serialized.items():
assert isinstance(serialized_data, str), "Should be base64-encoded string"
@ -238,7 +238,7 @@ def test_llm_update_weights_with_serialized_handles(model_dir):
# Verify generation works with updated weights
llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
compare_logits(llm_logits, ref_logits)
print("✓ LLM update_weights with serialized handles (RestrictedUnpickler) works!")