TensorRT-LLMs/tensorrt_llm/_torch/utils.py
Kaiyu Xie 2ea17cdad2
Update TensorRT-LLM (#2792)
* Update TensorRT-LLM

---------

Co-authored-by: jlee <jungmoolee@clika.io>
2025-02-18 21:27:39 +08:00

21 lines
649 B
Python

import torch
from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
def make_weak_ref(x):
    """Return a copy of ``x`` whose CUDA tensors do not own their memory.

    Every CUDA ``torch.Tensor`` inside ``x`` is replaced by a new tensor
    rebuilt from the same device pointer through ``TensorWrapper`` /
    ``convert_to_torch_tensor``. The new tensor is built from the raw
    pointer only, so it holds no reference to the original tensor. Whoever
    owns the allocation must keep it alive while the result is in use.

    Args:
        x: A tensor, ``None``, an ``int``/``float``/``bool``, or a
            (possibly nested) ``tuple``/``list``/``dict`` of those.

    Returns:
        An object with the same nesting:

        * CUDA tensors are replaced by pointer-aliasing views with the
          same shape, dtype and strides.
        * CPU tensors, ``None`` and scalars are returned unchanged.
        * Containers are rebuilt as plain ``tuple``/``list``/``dict``,
          so subclasses such as namedtuples lose their subclass type.

    Raises:
        TypeError: If ``x`` (or anything nested in it) has another type.
    """
    if isinstance(x, torch.Tensor):
        if not x.is_cuda:
            # CPU tensors are passed through as-is (a strong reference).
            return x
        if x.is_contiguous() or x.numel() == 0:
            return convert_to_torch_tensor(
                TensorWrapper(x.data_ptr(), x.dtype, x.shape))
        # TensorWrapper receives only (pointer, dtype, shape), with no
        # strides. The tensor it rebuilds is therefore presumably densely
        # packed (confirm against tensorrt_llm._utils). Wrapping a
        # non-contiguous source that way would silently read the wrong
        # elements. Instead, wrap the flat span of memory the source
        # addresses, then re-apply its original sizes and strides.
        sizes, strides = x.shape, x.stride()
        # numel > 0 here, so every size is >= 1, and PyTorch strides are
        # non-negative. The span therefore runs from data_ptr to the
        # farthest addressed element, inclusive.
        span = 1 + sum((n - 1) * s for n, s in zip(sizes, strides))
        flat = convert_to_torch_tensor(
            TensorWrapper(x.data_ptr(), x.dtype, (span, )))
        return flat.as_strided(sizes, strides)
    if isinstance(x, tuple):
        return tuple(make_weak_ref(i) for i in x)
    if isinstance(x, list):
        return [make_weak_ref(i) for i in x]
    if isinstance(x, dict):
        return {k: make_weak_ref(v) for k, v in x.items()}
    # None and scalar primitives hold no tensor memory; pass them through.
    if x is None or isinstance(x, (int, float, bool)):
        return x
    raise TypeError(f"Invalid type {type(x)} to make weak ref")