from dataclasses import dataclass

import torch

from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor


def make_weak_ref(x):
    """Return a non-owning ("weak") view of ``x``.

    CUDA tensors are re-wrapped around their raw data pointer, so the result
    aliases the same device memory without keeping the original storage
    alive. Containers are handled recursively; plain scalars pass through.
    """
    if isinstance(x, torch.Tensor):
        # Only CUDA tensors need re-wrapping; CPU tensors are returned as-is.
        return convert_to_torch_tensor(
            TensorWrapper(x.data_ptr(), x.dtype, x.shape)) if x.is_cuda else x
    elif isinstance(x, tuple):
        return tuple(make_weak_ref(i) for i in x)
    elif isinstance(x, list):
        return [make_weak_ref(i) for i in x]
    elif isinstance(x, dict):
        return {k: make_weak_ref(v) for k, v in x.items()}
    elif isinstance(x, (int, float, bool)):
        # Plain scalars carry no tensor storage; return them unchanged.
        return x
    else:
        raise TypeError(f"Invalid type {type(x)} to make weak ref")
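

# Usage sketch (added for illustration; not part of the original module).
# make_weak_ref is typically applied to tensors whose lifetime is managed
# elsewhere, e.g. buffers captured by a CUDA graph: the returned view aliases
# the same device memory without keeping it alive. The tensor names below are
# hypothetical.
def _demo_make_weak_ref() -> None:
    if not torch.cuda.is_available():
        return
    src = {"hidden": torch.randn(8, 16, device="cuda")}
    weak = make_weak_ref(src)
    # The weak view shares the source tensor's storage pointer exactly.
    assert weak["hidden"].data_ptr() == src["hidden"].data_ptr()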


@dataclass
class Fp4QuantizedTensor:
    """An FP4-quantized tensor paired with its quantization scaling factors."""
    fp4_tensor: torch.Tensor
    scaling_factor: torch.Tensor
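

# Usage sketch (illustrative only): constructing an Fp4QuantizedTensor. The
# shapes, dtypes, and packing below are assumptions for demonstration, not
# the layout required by TensorRT-LLM's FP4 kernels.
def _demo_fp4_quantized_tensor() -> Fp4QuantizedTensor:
    packed = torch.empty(128, 64, dtype=torch.uint8)   # 128 FP4 values per row, two packed per byte (assumed layout)
    scales = torch.empty(128, 8, dtype=torch.float32)  # one scale per 16-value block (assumed blocking)
    return Fp4QuantizedTensor(fp4_tensor=packed, scaling_factor=scales)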