import io

# pickle is not secure, but this whole file is a wrapper to make it
# possible to mitigate the primary risk of code injection via pickle.
import pickle  # nosec B403
from functools import partial

# This is an example whitelist showcasing how to guard deserialization with
# approved classes. If a class is needed routinely, it should be added to the
# whitelist below; if it is only needed in a single instance, it can be added
# at runtime using register_approved_class.
BASE_EXAMPLE_CLASSES = {
    "builtins": [
        "Exception", "ValueError", "NotImplementedError", "AttributeError",
        "AssertionError", "RuntimeError"
    ],  # each Exception/Error class needs to be added explicitly
    "collections": ["OrderedDict"],
    "datetime": ["timedelta"],
    "pathlib": ["PosixPath"],
    "llmapi.run_llm_with_postproc":
    ["perform_faked_oai_postprocess"],  # only used in tests
    ### starting import of torch models classes. They are used in test_llm_multi_gpu.py.
    "tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],
    "tensorrt_llm._torch.models.modeling_bert":
    ["BertForSequenceClassification"],
    "tensorrt_llm._torch.models.modeling_clip": ["CLIPVisionModel"],
    "tensorrt_llm._torch.models.modeling_deepseekv3": ["DeepseekV3ForCausalLM"],
    "tensorrt_llm._torch.models.modeling_gemma3": ["Gemma3ForCausalLM"],
    "tensorrt_llm._torch.models.modeling_hyperclovax": ["HCXVisionForCausalLM"],
    "tensorrt_llm._torch.models.modeling_llama": [
        "Eagle3LlamaForCausalLM",
        "LlamaForCausalLM",
        "Llama4ForCausalLM",
        "Llama4ForConditionalGeneration",
    ],
    "tensorrt_llm._torch.models.modeling_llava_next": ["LlavaNextModel"],
    "tensorrt_llm._torch.models.modeling_mistral": ["MistralForCausalLM"],
    "tensorrt_llm._torch.models.modeling_mixtral": ["MixtralForCausalLM"],
    "tensorrt_llm._torch.models.modeling_mllama":
    ["MllamaForConditionalGeneration"],
    "tensorrt_llm._torch.models.modeling_nemotron": ["NemotronForCausalLM"],
    "tensorrt_llm._torch.models.modeling_nemotron_h": ["NemotronHForCausalLM"],
    "tensorrt_llm._torch.models.modeling_nemotron_nas":
    ["NemotronNASForCausalLM"],
    "tensorrt_llm._torch.models.modeling_qwen":
    ["Qwen2ForCausalLM", "Qwen2ForProcessRewardModel", "Qwen2ForRewardModel"],
    "tensorrt_llm._torch.models.modeling_qwen2vl":
    ["Qwen2VLModel", "Qwen2_5_VLModel"],
    "tensorrt_llm._torch.models.modeling_qwen3": ["Qwen3ForCausalLM"],
    "tensorrt_llm._torch.models.modeling_qwen3_moe": ["Qwen3MoeForCausalLM"],
    "tensorrt_llm._torch.models.modeling_qwen_moe": ["Qwen2MoeForCausalLM"],
    "tensorrt_llm._torch.models.modeling_siglip": ["SiglipVisionModel"],
    "tensorrt_llm._torch.models.modeling_vila": ["VilaModel"],
    "tensorrt_llm._torch.models.modeling_gpt_oss": ["GptOssForCausalLM"],
    "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
    "tensorrt_llm._torch.pyexecutor.llm_request":
    ["LogitsStorage", "PyResult", "LlmResult", "LlmResponse", "LogProbStorage"],
    "tensorrt_llm._torch.speculative.mtp": ["MTPConfig"],
    "tensorrt_llm._torch.speculative.interface": ["SpeculativeDecodingMode"],
    ### ending import of torch models classes
    "tensorrt_llm.bindings.executor": [
        "BatchingType", "CacheTransceiverConfig", "CapacitySchedulerPolicy",
        "ContextPhaseParams", "ContextChunkingPolicy", "DynamicBatchConfig",
        "ExecutorConfig", "ExtendedRuntimePerfKnobConfig", "Response", "Result",
        "FinishReason", "KvCacheConfig", "KvCacheTransferMode",
        "KvCacheRetentionConfig",
        "KvCacheRetentionConfig.TokenRangeRetentionConfig", "PeftCacheConfig",
        "SchedulerConfig"
    ],
    "tensorrt_llm.builder": ["BuildConfig"],
    "tensorrt_llm.disaggregated_params": ["DisaggregatedParams"],
    "tensorrt_llm.inputs.multimodal": ["MultimodalInput"],
    "tensorrt_llm.executor.postproc_worker": [
        "PostprocArgs", "PostprocParams", "PostprocWorkerConfig",
        "PostprocWorker.Input", "PostprocWorker.Output"
    ],
    "tensorrt_llm.executor.request": [
        "CancellingRequest", "GenerationRequest", "LoRARequest",
        "PromptAdapterRequest"
    ],
    "tensorrt_llm.executor.result": [
        "CompletionOutput", "DetokenizedGenerationResultBase",
        "GenerationResult", "GenerationResultBase", "IterationResult",
        "Logprob", "LogProbsResult", "ResponseWrapper"
    ],
    "tensorrt_llm.executor.utils": ["ErrorResponse", "WorkerCommIpcAddrs"],
    "tensorrt_llm.executor.worker": ["GenerationExecutorWorker", "worker_main"],
    "tensorrt_llm.llmapi.llm_args": [
        "_ModelFormatKind", "_ParallelConfig", "CalibConfig",
        "CapacitySchedulerPolicy", "KvCacheConfig", "LookaheadDecodingConfig",
        "TrtLlmArgs", "SchedulerConfig", "LoadFormat", "DynamicBatchConfig"
    ],
    "tensorrt_llm.llmapi.mpi_session": ["RemoteTask"],
    "tensorrt_llm.llmapi.llm_utils":
    ["CachedModelLoader._node_build_task", "LlmBuildStats"],
    "tensorrt_llm.llmapi.tokenizer": ["TransformersTokenizer"],
    "tensorrt_llm.lora_manager": ["LoraConfig"],
    "tensorrt_llm.mapping": ["Mapping"],
    "tensorrt_llm.models.modeling_utils":
    ["QuantConfig", "SpeculativeDecodingMode"],
    "tensorrt_llm.plugin.plugin": ["PluginConfig"],
    "tensorrt_llm.sampling_params":
    ["SamplingParams", "GuidedDecodingParams", "GreedyDecodingParams"],
    "tensorrt_llm.serve.postprocess_handlers": [
        "chat_response_post_processor", "chat_stream_post_processor",
        "completion_stream_post_processor",
        "completion_response_post_processor", "CompletionPostprocArgs",
        "ChatPostprocArgs"
    ],
    "torch._utils": ["_rebuild_tensor_v2"],
    "torch.storage": ["_load_from_bytes"],
}

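# The mapping is module name -> list of approved class/function names. For
# example, the "builtins" entry above lets the Unpickler defined below rebuild
# a pickled ValueError, while any (module, name) pair absent from the mapping
# is rejected at load time. A minimal sketch:
#
#     loads(dumps(ValueError("boom")), approved_imports=BASE_EXAMPLE_CLASSES)
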
def _register_class(dict, obj):
    # Prefer the qualified name so nested classes are recorded correctly.
    name = getattr(obj, '__qualname__', None)
    if name is None:
        name = obj.__name__
    module = pickle.whichmodule(obj, name)
    if module not in dict.keys():
        dict[module] = []
    dict[module].append(name)


def register_approved_class(obj):
    _register_class(BASE_EXAMPLE_CLASSES, obj)

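# A minimal sketch of runtime registration (MyTask is a hypothetical user
# class, named here only for illustration):
#
#     class MyTask:
#         pass
#
#     register_approved_class(MyTask)
#     # BASE_EXAMPLE_CLASSES now maps MyTask's defining module to a list
#     # containing "MyTask", so the Unpickler below will accept it.
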
class Unpickler(pickle.Unpickler):

    def __init__(self, *args, approved_imports={}, **kwargs):
        super().__init__(*args, **kwargs)
        self.approved_imports = approved_imports

    # Only import approved classes; this is the security boundary.
    def find_class(self, module, name):
        if name not in self.approved_imports.get(module, []):
            # If this is triggered when it shouldn't be, then the module
            # and class should be added to the approved_imports. If the class
            # is being used as part of a routine scenario, then it should be
            # added to the appropriate base classes above.
            raise ValueError(f"Import {module} | {name} is not allowed")
        return super().find_class(module, name)

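# With an empty whitelist every import is blocked, so even a harmless payload
# fails closed rather than executing anything. A minimal sketch:
#
#     payload = dumps(timedelta(seconds=1))
#     Unpickler(io.BytesIO(payload), approved_imports={}).load()
#     # raises: ValueError: Import datetime | timedelta is not allowed
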
# These are taken from the pickle module to allow this to be a drop-in replacement.
# Source: https://github.com/python/cpython/blob/3.13/Lib/pickle.py
# dump and dumps are just aliases because the security controls are on the
# deserialization side. However, they are included here so that if a more
# secure serialization solution is identified in the future, it can be adopted
# with less impact to the rest of the application.
dump = partial(pickle.dump, protocol=pickle.HIGHEST_PROTOCOL)  # nosec B301
dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL)  # nosec B301

def load(file,
         *,
         fix_imports=True,
         encoding="ASCII",
         errors="strict",
         buffers=None,
         approved_imports={}):
    return Unpickler(file,
                     fix_imports=fix_imports,
                     buffers=buffers,
                     encoding=encoding,
                     errors=errors,
                     approved_imports=approved_imports).load()

def loads(s,
          /,
          *,
          fix_imports=True,
          encoding="ASCII",
          errors="strict",
          buffers=None,
          approved_imports={}):
    if isinstance(s, str):
        raise TypeError("Can't load pickle from unicode string")
    file = io.BytesIO(s)
    return Unpickler(file,
                     fix_imports=fix_imports,
                     buffers=buffers,
                     encoding=encoding,
                     errors=errors,
                     approved_imports=approved_imports).load()
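

if __name__ == "__main__":
    # A minimal, runnable sketch of the intended round trip (illustrative
    # only, not part of the module's API). BASE_EXAMPLE_CLASSES already
    # approves collections.OrderedDict but not datetime.date.
    from collections import OrderedDict
    from datetime import date

    payload = dumps(OrderedDict(a=1, b=2))
    print(loads(payload, approved_imports=BASE_EXAMPLE_CLASSES))

    try:
        loads(dumps(date(2024, 1, 1)), approved_imports=BASE_EXAMPLE_CLASSES)
    except ValueError as err:
        print(err)  # Import datetime | date is not allowed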