Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Merge 2f17268409 into 6df2c8a074
Commit 53019ad351
@@ -1,4 +1,5 @@
 import base64
+import io
 import pickle  # nosec B403
 from typing import Optional
 
@@ -10,6 +11,44 @@ from tensorrt_llm._torch.utils import get_device_uuid
 from tensorrt_llm.logger import logger
 
 
+class RestrictedUnpickler(pickle.Unpickler):
+    """Restricted unpickler that only allows safe types.
+
+    This prevents arbitrary code execution by restricting what can be deserialized.
+    Only allows: list, tuple, str, int, float, bool, bytes, and torch-related types
+    needed for tensor reconstruction.
+    """
+
+    # Allowed modules and their classes
+    ALLOWED_TYPES = {
+        "builtins": {
+            "list",
+            "tuple",
+            "str",
+            "int",
+            "float",
+            "bool",
+            "bytes",
+            "dict",
+            "NoneType",
+            "type",
+        },
+    }
+
+    def find_class(self, module, name):
+        """Override to restrict which classes can be unpickled."""
+        # Allow any torch.* module (needed to rebuild tensors from IPC handles)
+        # or an explicitly whitelisted builtin.
+        is_torch_module = module.startswith("torch")
+        if is_torch_module or (module in self.ALLOWED_TYPES and name in self.ALLOWED_TYPES[module]):
+            return super().find_class(module, name)
+
+        raise pickle.UnpicklingError(
+            f"Global '{module}.{name}' is forbidden for security reasons. "
+            f"Only basic types (list, tuple, str, int, etc.) and torch tensor types are allowed. "
+            f"Module: {module}, Name: {name}"
+        )
+
+
 class WorkerExtension:
     """Worker extension class for extending TensorRT-LLM Ray workers with custom functionality.
 
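For illustration (not part of the diff): a minimal sketch of the property the class enforces. A payload built from whitelisted builtins loads normally, while a payload that smuggles a callable through __reduce__ is rejected before any code runs. The Evil class and the echo command are hypothetical; the exact module name in the error ("posix" below) varies by platform.

    import io
    import os
    import pickle

    from tensorrt_llm.llmapi.rlhf_utils import RestrictedUnpickler

    # Whitelisted builtins round-trip fine.
    safe = pickle.dumps([("layer.weight", (1, 2, 3))])
    assert RestrictedUnpickler(io.BytesIO(safe)).load() == [("layer.weight", (1, 2, 3))]

    # A classic pickle exploit: __reduce__ asks the unpickler to call os.system.
    class Evil:
        def __reduce__(self):
            return (os.system, ("echo pwned",))

    try:
        RestrictedUnpickler(io.BytesIO(pickle.dumps(Evil()))).load()
    except pickle.UnpicklingError as exc:
        print(exc)  # Global 'posix.system' is forbidden for security reasons. ...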
@@ -61,9 +100,16 @@ class WorkerExtension:
         serialized_handles = ipc_handles[device_uuid]
         if isinstance(serialized_handles, str):
-            # Data is base64-encoded pickled bytes - deserialize it
+            # Data is base64-encoded pickled bytes - deserialize it using the restricted unpickler
             logger.info("Deserializing base64-encoded weight handles")
-            all_handles = pickle.loads(base64.b64decode(serialized_handles))  # nosec B301
+            decoded_data = base64.b64decode(serialized_handles)
+            all_handles = RestrictedUnpickler(io.BytesIO(decoded_data)).load()
+
+            # Verify the result is a list as expected
+            if not isinstance(all_handles, list):
+                raise ValueError(
+                    f"Deserialized data must be a list, got {type(all_handles).__name__} instead"
+                )
         else:
             # Data is already in the correct format (backward compatibility)
             all_handles = serialized_handles
 
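Not part of the diff: a sketch of the full producer/consumer round trip this hunk consumes, assuming a CUDA device and that reduce_tensor comes from torch.multiprocessing.reductions (the test changes below use the same helper). Unpickling resolves only builtins and torch.* globals; nothing is executed until the consumer calls the rebuild function.

    import base64
    import io
    import pickle

    import torch
    from torch.multiprocessing.reductions import reduce_tensor

    from tensorrt_llm.llmapi.rlhf_utils import RestrictedUnpickler

    # Producer: export a CUDA tensor as an IPC handle and wrap it for transport.
    weight = torch.randn(4, 4, device="cuda")
    handle = reduce_tensor(weight)  # (rebuild_fn, args) -- nothing is executed yet
    payload = base64.b64encode(pickle.dumps([("w", handle)])).decode("utf-8")

    # Consumer: decode and unpickle under the whitelist.
    rebuilt = RestrictedUnpickler(io.BytesIO(base64.b64decode(payload))).load()
    name, (rebuild_fn, args) = rebuilt[0]
    # In a separate consumer process, rebuild_fn(*args) would map the same CUDA
    # memory; CUDA forbids opening an IPC handle in the process that exported it.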
The remaining hunks are in the accompanying test module:

@@ -1,3 +1,6 @@
+import base64
+import io
+import pickle
 from typing import Callable, List, Optional
 
 import pytest
@@ -8,6 +11,7 @@ from utils.llm_data import llm_models_root
 
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
+from tensorrt_llm.llmapi.rlhf_utils import RestrictedUnpickler
 
 
 class HFModel:
@@ -71,6 +75,40 @@ class HFModel:
 
         return ret
 
+    def get_weight_ipc_handles_serialized(
+        self,
+        cuda_device: Optional[List[int]] = None,
+        weight_filter: Optional[Callable[[str], bool]] = None,
+    ):
+        """
+        Get base64-encoded serialized IPC handles for model weights.
+
+        Args:
+            cuda_device: List of CUDA device indices to get weights from
+            weight_filter: Optional function that takes a weight name and returns True if the weight should be included
+
+        Returns:
+            ret: Dictionary mapping device UUIDs to base64-encoded pickled handles
+        """
+        ret = {}
+        device_list = list(range(torch.cuda.device_count())) if cuda_device is None else cuda_device
+
+        for device in device_list:
+            all_handles = []
+            for item in self.all_weights[device]:
+                name, p = item
+                # Apply the filter if provided
+                if weight_filter is not None and not weight_filter(name):
+                    continue
+                handle = reduce_tensor(p)
+                all_handles.append((name, handle))
+
+            # Serialize with base64-encoded pickle
+            serialized = base64.b64encode(pickle.dumps(all_handles)).decode("utf-8")
+            ret[self.device_uuid[device]] = serialized
+
+        return ret
+
     def generate_batch_incremental(
         self, original_prompts: List[str], generated_token_ids_list: List[List[int]]
     ):
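A hypothetical usage sketch showing how the filter and the two-phase update_weights RPC fit together; `llm` and `hf_model` are set up as in the test below, and the "attn" substring filter is an assumption for illustration only:

    # Export only attention weights from device 0 as base64 strings.
    handles = hf_model.get_weight_ipc_handles_serialized(
        cuda_device=[0],
        weight_filter=lambda name: "attn" in name,  # hypothetical filter
    )

    # Each value is a plain str, so the mapping survives JSON/RPC transports
    # that reject raw bytes.
    llm._collective_rpc("update_weights", (handles,))  # stage the new weights
    llm._collective_rpc("update_weights", (None,))     # finalize the update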
@@ -153,6 +191,57 @@ def run_generate(llm, hf_model, prompts, sampling_params):
     return llm_logits, ref_logits
 
 
+@pytest.mark.parametrize(
+    "model_dir",
+    ["Qwen2.5-0.5B-Instruct"],
+)
+def test_llm_update_weights_with_serialized_handles(model_dir):
+    """Test LLM update_weights with base64-encoded serialized handles (RestrictedUnpickler)."""
+    model_dir = str(llm_models_root() / model_dir)
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, free_gpu_memory_fraction=0.1)
+
+    hf_model = HFModel(model_dir)
+
+    llm = LLM(
+        model=model_dir,
+        ray_worker_extension_cls="tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
+        tensor_parallel_size=1,
+        load_format="dummy",
+        pipeline_parallel_size=1,
+        kv_cache_config=kv_cache_config,
+    )
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+    ]
+
+    sampling_params = SamplingParams(temperature=0, return_generation_logits=True)
+
+    # Use the serialized format (base64-encoded pickle)
+    ipc_handles_serialized = hf_model.get_weight_ipc_handles_serialized([0])
+
+    # Verify the format is correct (should be base64-encoded strings)
+    for device_uuid, serialized_data in ipc_handles_serialized.items():
+        assert isinstance(serialized_data, str), "Should be a base64-encoded string"
+        # Verify it can be decoded
+        decoded = base64.b64decode(serialized_data)
+        # Verify it can be deserialized with RestrictedUnpickler
+        deserialized = RestrictedUnpickler(io.BytesIO(decoded)).load()
+        assert isinstance(deserialized, list), "Should deserialize to a list"
+
+    # Update weights using the serialized format
+    llm._collective_rpc("update_weights", (ipc_handles_serialized,))
+    # Finalize the weight update
+    llm._collective_rpc("update_weights", (None,))
+
+    # Verify generation works with the updated weights
+    llm_logits, ref_logits = run_generate(llm, hf_model, prompts, sampling_params)
+    compare_logits(llm_logits, ref_logits)
+
+    print("✓ LLM update_weights with serialized handles (RestrictedUnpickler) works!")
+
+
 @pytest.mark.parametrize(
     "model_dir",
    ["Qwen2.5-0.5B-Instruct", "Qwen3/Qwen3-8B", "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"],
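One design note on the whitelist: find_class reads ALLOWED_TYPES off the instance, so a caller who ever needs an extra type can widen it by subclassing rather than loosening the module-prefix check. A hypothetical sketch (the OrderedDict entry is an example, not something this diff needs):

    from tensorrt_llm.llmapi.rlhf_utils import RestrictedUnpickler

    class HandlesUnpickler(RestrictedUnpickler):
        # Inherit the builtins whitelist and add one extra, explicit entry.
        ALLOWED_TYPES = {
            **RestrictedUnpickler.ALLOWED_TYPES,
            "collections": {"OrderedDict"},  # hypothetical addition
        }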