# tensorrt_llm/serialization.py

import io
# pickle is not secure, but this whole file is a wrapper that makes it
# possible to mitigate the primary risk of code injection via pickle.
import pickle # nosec B403
from functools import partial

# This is an example allowlist to showcase how to guard deserialization with
# approved classes. If a class is needed routinely, it should be added to the
# allowlist below; if it is only needed in a single instance, it can be added
# at runtime using register_approved_class.
BASE_EXAMPLE_CLASSES = {
"builtins": [
"Exception", "ValueError", "NotImplementedError", "AttributeError",
"AssertionError", "RuntimeError"
    ],  # each Exception/Error class needs to be added explicitly
    "collections": ["OrderedDict"],
    "datetime": ["timedelta"],
    "pathlib": ["PosixPath"],
    "llmapi.run_llm_with_postproc":
    ["perform_faked_oai_postprocess"],  # only used in tests
    ### starting import of torch model classes. They are used in test_llm_multi_gpu.py.
"tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],
"tensorrt_llm._torch.models.modeling_bert":
["BertForSequenceClassification"],
"tensorrt_llm._torch.models.modeling_clip": ["CLIPVisionModel"],
"tensorrt_llm._torch.models.modeling_deepseekv3": ["DeepseekV3ForCausalLM"],
"tensorrt_llm._torch.models.modeling_gemma3": ["Gemma3ForCausalLM"],
"tensorrt_llm._torch.models.modeling_hyperclovax": ["HCXVisionForCausalLM"],
"tensorrt_llm._torch.models.modeling_llama": [
"Eagle3LlamaForCausalLM",
"LlamaForCausalLM",
"Llama4ForCausalLM",
"Llama4ForConditionalGeneration",
],
"tensorrt_llm._torch.models.modeling_llava_next": ["LlavaNextModel"],
"tensorrt_llm._torch.models.modeling_mistral": ["MistralForCausalLM"],
"tensorrt_llm._torch.models.modeling_mixtral": ["MixtralForCausalLM"],
"tensorrt_llm._torch.models.modeling_mllama":
["MllamaForConditionalGeneration"],
"tensorrt_llm._torch.models.modeling_nemotron": ["NemotronForCausalLM"],
"tensorrt_llm._torch.models.modeling_nemotron_h": ["NemotronHForCausalLM"],
"tensorrt_llm._torch.models.modeling_nemotron_nas":
["NemotronNASForCausalLM"],
"tensorrt_llm._torch.models.modeling_qwen":
["Qwen2ForCausalLM", "Qwen2ForProcessRewardModel", "Qwen2ForRewardModel"],
"tensorrt_llm._torch.models.modeling_qwen2vl":
["Qwen2VLModel", "Qwen2_5_VLModel"],
"tensorrt_llm._torch.models.modeling_qwen3": ["Qwen3ForCausalLM"],
"tensorrt_llm._torch.models.modeling_qwen3_moe": ["Qwen3MoeForCausalLM"],
"tensorrt_llm._torch.models.modeling_qwen_moe": ["Qwen2MoeForCausalLM"],
"tensorrt_llm._torch.models.modeling_siglip": ["SiglipVisionModel"],
"tensorrt_llm._torch.models.modeling_vila": ["VilaModel"],
"tensorrt_llm._torch.models.modeling_gpt_oss": ["GptOssForCausalLM"],
"tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
"tensorrt_llm._torch.pyexecutor.llm_request":
["LogitsStorage", "PyResult", "LlmResult", "LlmResponse", "LogProbStorage"],
"tensorrt_llm._torch.speculative.mtp": ["MTPConfig"],
"tensorrt_llm._torch.speculative.interface": ["SpeculativeDecodingMode"],
    ### ending import of torch model classes
"tensorrt_llm.bindings.executor": [
"BatchingType", "CacheTransceiverConfig", "CapacitySchedulerPolicy",
"ContextPhaseParams", "ContextChunkingPolicy", "DynamicBatchConfig",
"ExecutorConfig", "ExtendedRuntimePerfKnobConfig", "Response", "Result",
"FinishReason", "KvCacheConfig", "KvCacheTransferMode",
"KvCacheRetentionConfig",
"KvCacheRetentionConfig.TokenRangeRetentionConfig", "PeftCacheConfig",
"SchedulerConfig"
],
"tensorrt_llm.builder": ["BuildConfig"],
"tensorrt_llm.disaggregated_params": ["DisaggregatedParams"],
"tensorrt_llm.inputs.multimodal": ["MultimodalInput"],
"tensorrt_llm.executor.postproc_worker": [
"PostprocArgs", "PostprocParams", "PostprocWorkerConfig",
"PostprocWorker.Input", "PostprocWorker.Output"
],
"tensorrt_llm.executor.request": [
"CancellingRequest", "GenerationRequest", "LoRARequest",
"PromptAdapterRequest"
],
"tensorrt_llm.executor.result": [
"CompletionOutput", "DetokenizedGenerationResultBase",
"GenerationResult", "GenerationResultBase", "IterationResult",
"Logprob", "LogProbsResult", "ResponseWrapper"
],
"tensorrt_llm.executor.utils": ["ErrorResponse", "WorkerCommIpcAddrs"],
"tensorrt_llm.executor.worker": ["GenerationExecutorWorker", "worker_main"],
"tensorrt_llm.llmapi.llm_args": [
"_ModelFormatKind", "_ParallelConfig", "CalibConfig",
"CapacitySchedulerPolicy", "KvCacheConfig", "LookaheadDecodingConfig",
"TrtLlmArgs", "SchedulerConfig", "LoadFormat", "DynamicBatchConfig"
],
"tensorrt_llm.llmapi.mpi_session": ["RemoteTask"],
"tensorrt_llm.llmapi.llm_utils":
["CachedModelLoader._node_build_task", "LlmBuildStats"],
"tensorrt_llm.llmapi.tokenizer": ["TransformersTokenizer"],
"tensorrt_llm.lora_manager": ["LoraConfig"],
"tensorrt_llm.mapping": ["Mapping"],
"tensorrt_llm.models.modeling_utils":
["QuantConfig", "SpeculativeDecodingMode"],
"tensorrt_llm.plugin.plugin": ["PluginConfig"],
"tensorrt_llm.sampling_params":
["SamplingParams", "GuidedDecodingParams", "GreedyDecodingParams"],
"tensorrt_llm.serve.postprocess_handlers": [
"chat_response_post_processor", "chat_stream_post_processor",
"completion_stream_post_processor",
"completion_response_post_processor", "CompletionPostprocArgs",
"ChatPostprocArgs"
],
"torch._utils": ["_rebuild_tensor_v2"],
"torch.storage": ["_load_from_bytes"],
}


def _register_class(registry, obj):
    name = getattr(obj, '__qualname__', None)
    if name is None:
        name = obj.__name__
    module = pickle.whichmodule(obj, name)
    if module not in registry:
        registry[module] = []
    registry[module].append(name)


def register_approved_class(obj):
    _register_class(BASE_EXAMPLE_CLASSES, obj)
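
# A minimal usage sketch (MyJobSpec is a hypothetical class): approve a
# one-off type at runtime before deserializing a payload that contains it.
#
#   register_approved_class(MyJobSpec)
#   spec = loads(payload, approved_imports=BASE_EXAMPLE_CLASSES)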


class Unpickler(pickle.Unpickler):

    def __init__(self, *args, approved_imports={}, **kwargs):
        super().__init__(*args, **kwargs)
        self.approved_imports = approved_imports

    # Only import approved classes; this is the security boundary.
    def find_class(self, module, name):
        if name not in self.approved_imports.get(module, []):
            # If this is triggered when it shouldn't be, add the module and
            # class to approved_imports. If the class is used as part of a
            # routine scenario, add it to BASE_EXAMPLE_CLASSES above.
            raise ValueError(f"Import {module} | {name} is not allowed")
        return super().find_class(module, name)
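
# A quick sketch of the boundary in action: a pickled builtins.ValueError
# round-trips only when "builtins": ["ValueError"] is in the allowlist; with
# an empty allowlist, find_class raises before any object is constructed.
#
#   payload = dumps(ValueError("boom"))
#   loads(payload, approved_imports={"builtins": ["ValueError"]})  # ok
#   loads(payload, approved_imports={})  # ValueError: import not allowed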


# These are taken from the pickle module so that this can be a drop-in
# replacement. Source: https://github.com/python/cpython/blob/3.13/Lib/pickle.py
# dump and dumps are plain aliases because the security controls are on the
# deserialization side. They are included here so that if a more secure
# serialization solution is identified in the future, it can be adopted with
# less impact on the rest of the application.
dump = partial(pickle.dump, protocol=pickle.HIGHEST_PROTOCOL) # nosec B301
dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL) # nosec B301


def load(file,
*,
fix_imports=True,
encoding="ASCII",
errors="strict",
buffers=None,
approved_imports={}):
return Unpickler(file,
fix_imports=fix_imports,
buffers=buffers,
encoding=encoding,
errors=errors,
approved_imports=approved_imports).load()


def loads(s,
/,
*,
fix_imports=True,
encoding="ASCII",
errors="strict",
buffers=None,
approved_imports={}):
if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
return Unpickler(file,
fix_imports=fix_imports,
buffers=buffers,
encoding=encoding,
errors=errors,
approved_imports=approved_imports).load()
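

# A minimal smoke test, added as an illustrative sketch rather than part of
# the module API: demonstrates an approved round trip and a rejected import.
if __name__ == "__main__":
    from collections import OrderedDict

    payload = dumps(OrderedDict(a=1))
    # Allowed: collections.OrderedDict is in the base allowlist.
    print(loads(payload, approved_imports=BASE_EXAMPLE_CLASSES))
    # Blocked: an empty allowlist rejects every class import.
    try:
        loads(payload, approved_imports={})
    except ValueError as err:
        print(f"rejected as expected: {err}")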