Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00.
Commit message: "fix pre-commit".
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
This commit (2f17268409) is contained in this branch; its parent commit is 14ad0cb4a7.
@ -13,21 +13,32 @@ from tensorrt_llm.logger import logger
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
    """Restricted unpickler that only allows safe types.

    This prevents arbitrary code execution by restricting what can be
    deserialized. Only allows: list, tuple, str, int, float, bool, bytes,
    dict, NoneType, type, and torch-related types needed for tensor
    reconstruction.

    NOTE(review): allowing everything under ``torch.*`` is still a broad
    trust boundary — torch reducers are needed for IPC-handle tensor
    rebuilding, but confirm no torch entry point on the allowed paths can
    execute arbitrary code on untrusted input.
    """

    # Allowed modules and the class names permitted within each.
    ALLOWED_TYPES = {
        "builtins": {
            "list",
            "tuple",
            "str",
            "int",
            "float",
            "bool",
            "bytes",
            "dict",
            "NoneType",
            "type",
        },
    }

    def find_class(self, module, name):
        """Override to restrict which classes can be unpickled.

        Args:
            module: Dotted module path requested by the pickle stream.
            name: Attribute name requested within that module.

        Returns:
            The resolved object, if the (module, name) pair is allowed.

        Raises:
            pickle.UnpicklingError: If the (module, name) pair is not on
                the allow list.
        """
        # Accept "torch" itself or true submodules ("torch.storage", ...),
        # but NOT unrelated packages whose names merely start with "torch"
        # (e.g. "torchvision", or a planted "torch_evil" module) — a bare
        # startswith("torch") check would let those through.
        is_torch_module = module == "torch" or module.startswith("torch.")
        if is_torch_module or (module in self.ALLOWED_TYPES
                               and name in self.ALLOWED_TYPES[module]):
            return super().find_class(module, name)
        # Fail closed: anything not explicitly allowed is rejected instead
        # of silently resolving to None.
        raise pickle.UnpicklingError(
            f"Unpickling of {module}.{name} is forbidden")
|
||||
|
||||
@ -93,7 +104,7 @@ class WorkerExtension:
|
||||
logger.info("Deserializing base64-encoded weight handles")
|
||||
decoded_data = base64.b64decode(serialized_handles)
|
||||
all_handles = RestrictedUnpickler(io.BytesIO(decoded_data)).load()
|
||||
|
||||
|
||||
# Verify the result is a list as expected
|
||||
if not isinstance(all_handles, list):
|
||||
raise ValueError(
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Callable, List, Optional
|
||||
import base64
|
||||
import io
|
||||
import pickle
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -102,9 +102,9 @@ class HFModel:
|
||||
continue
|
||||
handle = reduce_tensor(p)
|
||||
all_handles.append((name, handle))
|
||||
|
||||
|
||||
# Serialize with base64-encoded pickle
|
||||
serialized = base64.b64encode(pickle.dumps(all_handles)).decode('utf-8')
|
||||
serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
|
||||
ret[self.device_uuid[device]] = serialized
|
||||
|
||||
return ret
|
||||
@ -220,7 +220,7 @@ def test_llm_update_weights_with_serialized_handles(model_dir):
|
||||
|
||||
# Use the serialized format (base64-encoded pickle)
|
||||
ipc_handles_serialized = hf_model.get_weight_ipc_handles_serialized([0])
|
||||
|
||||
|
||||
# Verify the format is correct (should be base64-encoded strings)
|
||||
for device_uuid, serialized_data in ipc_handles_serialized.items():
|
||||
assert isinstance(serialized_data, str), "Should be base64-encoded string"
|
||||
@ -238,7 +238,7 @@ def test_llm_update_weights_with_serialized_handles(model_dir):
|
||||
# Verify generation works with updated weights
|
||||
llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
|
||||
compare_logits(llm_logits, ref_logits)
|
||||
|
||||
|
||||
print("✓ LLM update_weights with serialized handles (RestrictedUnpickler) works!")
|
||||
|
||||
|
||||
|
||||
Page actions: Reference in New Issue · Block a user.