commit 53019ad351
Yibin Li, 2026-01-13 05:17:55 -08:00, committed by GitHub
2 changed files with 137 additions and 2 deletions

tensorrt_llm/llmapi/rlhf_utils.py

@@ -1,4 +1,5 @@
 import base64
+import io
 import pickle  # nosec B403
 from typing import Optional
@@ -10,6 +11,44 @@ from tensorrt_llm._torch.utils import get_device_uuid
 from tensorrt_llm.logger import logger
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+    """Restricted unpickler that only allows safe types.
+
+    This prevents arbitrary code execution by restricting what can be deserialized.
+    Only allows: list, tuple, str, int, float, bool, bytes, and torch-related types
+    needed for tensor reconstruction.
+    """
+
+    # Allowed modules and their classes
+    ALLOWED_TYPES = {
+        "builtins": {
+            "list",
+            "tuple",
+            "str",
+            "int",
+            "float",
+            "bool",
+            "bytes",
+            "dict",
+            "NoneType",
+            "type",
+        },
+    }
+
+    def find_class(self, module, name):
+        """Override to restrict which classes can be unpickled."""
+        # Check if the module is in our allowed list
+        is_torch_module = module.startswith("torch")
+        if is_torch_module or (module in self.ALLOWED_TYPES and name in self.ALLOWED_TYPES[module]):
+            return super().find_class(module, name)
+        raise pickle.UnpicklingError(
+            f"Global '{module}.{name}' is forbidden for security reasons. "
+            "Only basic types (list, tuple, str, int, etc.) and torch tensor types are allowed. "
+            f"Module: {module}, Name: {name}"
+        )
 class WorkerExtension:
     """Worker extension class for extending TensorRT-LLM Ray workers with custom functionality.
@@ -61,9 +100,16 @@ class WorkerExtension:
         serialized_handles = ipc_handles[device_uuid]
         if isinstance(serialized_handles, str):
-            # Data is base64-encoded pickled bytes - deserialize it
+            # Data is base64-encoded pickled bytes - deserialize it using restricted unpickler
             logger.info("Deserializing base64-encoded weight handles")
-            all_handles = pickle.loads(base64.b64decode(serialized_handles))  # nosec B301
+            decoded_data = base64.b64decode(serialized_handles)
+            all_handles = RestrictedUnpickler(io.BytesIO(decoded_data)).load()
+            # Verify the result is a list as expected
+            if not isinstance(all_handles, list):
+                raise ValueError(
+                    f"Deserialized data must be a list, got {type(all_handles).__name__} instead"
+                )
         else:
             # Data is already in the correct format (backward compatibility)
             all_handles = serialized_handles
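
For context, the producer side that pairs with the deserialization branch above can be sketched as follows (assumptions: a CUDA device is available, and reduce_tensor is torch.multiprocessing.reductions.reduce_tensor; the rebuild callables inside the resulting handles live under torch.*, so the allowlist admits them via the torch prefix check).

import base64
import pickle

import torch
from torch.multiprocessing.reductions import reduce_tensor

# One hypothetical weight standing in for a full set of model parameters.
weight = torch.zeros(16, device="cuda")
all_handles = [("model.embed_tokens.weight", reduce_tensor(weight))]
serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
ipc_handles = {"<device-uuid>": serialized}  # keyed by device UUID, as above

The second changed file, the test below, drives this same round trip end to end.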


@@ -1,3 +1,6 @@
+import base64
+import io
+import pickle
 from typing import Callable, List, Optional

 import pytest
@@ -8,6 +11,7 @@ from utils.llm_data import llm_models_root
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
+from tensorrt_llm.llmapi.rlhf_utils import RestrictedUnpickler


 class HFModel:
@@ -71,6 +75,40 @@ class HFModel:
         return ret

+    def get_weight_ipc_handles_serialized(
+        self,
+        cuda_device: Optional[List[int]] = None,
+        weight_filter: Optional[Callable[[str], bool]] = None,
+    ):
+        """
+        Get base64-encoded serialized IPC handles for model weights.
+
+        Args:
+            cuda_device: List of CUDA device indices to get weights from
+            weight_filter: Optional function that takes a weight name and returns True if the weight should be included
+
+        Returns:
+            ret: Dictionary mapping device UUIDs to base64-encoded pickled handles
+        """
+        ret = {}
+        device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device
+        for device in device_list:
+            all_handles = []
+            for item in self.all_weights[device]:
+                name, p = item
+                # Apply filter if provided
+                if weight_filter is not None and not weight_filter(name):
+                    continue
+                handle = reduce_tensor(p)
+                all_handles.append((name, handle))
+            # Serialize with base64-encoded pickle
+            serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
+            ret[self.device_uuid[device]] = serialized
+        return ret
+
     def generate_batch_incremental(
         self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
     ):
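
A usage sketch for the weight_filter hook added above (the filter predicate and variable names here are hypothetical):

# Serialize only the attention projection weights on CUDA device 0.
handles = hf_model.get_weight_ipc_handles_serialized(
    cuda_device=[0],
    weight_filter=lambda name: "q_proj" in name or "k_proj" in name,
)

Passing cuda_device=None instead serializes weights for every visible device.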
@@ -153,6 +191,57 @@ def run_generate(llm, hf_model, prompts, sampling_params):
     return llm_logits, ref_logits

+@pytest.mark.parametrize(
+    "model_dir",
+    ["Qwen2.5-0.5B-Instruct"],
+)
+def test_llm_update_weights_with_serialized_handles(model_dir):
+    """Test LLM update_weights with base64-encoded serialized handles (RestrictedUnpickler)."""
+    model_dir = str(llm_models_root() / model_dir)
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
+    hf_model = HFModel(model_dir)
+    llm = LLM(
+        model=model_dir,
+        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
+        tensor_parallel_size=1,
+        load_format="dummy",
+        pipeline_parallel_size=1,
+        kv_cache_config=kv_cache_config,
+    )
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+    ]
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True)
+
+    # Use the serialized format (base64-encoded pickle)
+    ipc_handles_serialized = hf_model.get_weight_ipc_handles_serialized([0])
+
+    # Verify the format is correct (should be base64-encoded strings)
+    for device_uuid, serialized_data in ipc_handles_serialized.items():
+        assert isinstance(serialized_data, str), "Should be a base64-encoded string"
+        # Verify it can be decoded
+        decoded = base64.b64decode(serialized_data)
+        # Verify it can be deserialized with RestrictedUnpickler
+        deserialized = RestrictedUnpickler(io.BytesIO(decoded)).load()
+        assert isinstance(deserialized, list), "Should deserialize to a list"
+
+    # Update weights using the serialized format
+    llm._collective_rpc("update_weights", (ipc_handles_serialized,))
+    # Finalize the weight update
+    llm._collective_rpc("update_weights", (None,))
+
+    # Verify generation works with updated weights
+    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
+    compare_logits(llm_logits, ref_logits)
+    print("✓ LLM update_weights with serialized handles (RestrictedUnpickler) works!")

 @pytest.mark.parametrize(
     "model_dir",
     ["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],