# Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13).
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
from ctypes import (
    CFUNCTYPE,
    POINTER,
    c_int,
    c_int64,
    c_size_t,
    c_uint8,
    c_uint16,
    c_void_p,
    pointer,
)

import torch
|
# ctypes mirrors of the DLPack C structs (see dlpack/dlpack.h).
class DLDataType(ctypes.Structure):
    """DLPack ``DLDataType``: describes the element type of a tensor."""

    _fields_ = [
        # DLDataTypeCode value (e.g. 0 = int, 1 = uint, 2 = float).
        ("code", c_uint8),
        # Width of one element in bits (e.g. 32 for float32).
        ("bits", c_uint8),
        # Vector lanes per element; 1 for ordinary scalar tensors.
        ("lanes", c_uint16),
    ]
|
class DLDevice(ctypes.Structure):
    """DLPack ``DLDevice``: identifies where a tensor's memory lives."""

    _fields_ = [
        # DLDeviceType enum value; 2 (kDLCUDA) for an NVIDIA GPU.
        ("device_type", c_int),
        # Ordinal of the device within its type (e.g. CUDA device index).
        ("device_id", c_int),
    ]
|
class DLTensor(ctypes.Structure):
    """DLPack ``DLTensor``: a plain, non-owning tensor description."""

    _fields_ = [
        # Raw pointer to the first element of the tensor data.
        ("data", c_void_p),
        # Device on which the data pointer is valid.
        ("device", DLDevice),
        # Number of dimensions.
        ("ndim", c_int),
        # Element type descriptor.
        ("dtype", DLDataType),
        # Pointer to an array of ``ndim`` dimension sizes.
        ("shape", POINTER(c_int64)),
        # Pointer to an array of ``ndim`` strides, expressed in elements;
        # NULL selects the default contiguous (row-major) layout.
        ("strides", POINTER(c_int64)),
        # Offset in bytes from ``data`` to the first element (usually 0).
        ("byte_offset", c_size_t),
    ]
|
# Deleter function type kept for reference; the actual field type below is
# parameterized over DLManagedTensor itself (hence the two-step definition).
DLManagedTensorDeleter = CFUNCTYPE(None, POINTER(ctypes.c_void_p))  # Not used directly here


# DLManagedTensor is declared empty first so its deleter field can refer to
# the struct's own type: void (*deleter)(DLManagedTensor*).
class DLManagedTensor(ctypes.Structure):
    """DLPack ``DLManagedTensor``: a DLTensor plus ownership metadata."""


DLManagedTensor._fields_ = [
    # The tensor description itself.
    ("dl_tensor", DLTensor),
    # Opaque context pointer owned by the producer.
    ("manager_ctx", c_void_p),
    # Callback the consumer invokes when it is done with the tensor.
    ("deleter", CFUNCTYPE(None, POINTER(DLManagedTensor))),
]
|
# A no-op deleter that doesn't perform any operation. It is installed as
# DLManagedTensor.deleter, so the consumer may call it when the tensor's
# lifecycle ends; memory ownership stays with whoever allocated `ptr`.
@CFUNCTYPE(None, POINTER(DLManagedTensor))
def no_op_deleter(dmt_ptr):
    # dmt_ptr: POINTER(DLManagedTensor) handed back by the consumer.
    # You can also call cudaFree here if you want to free memory when the tensor's lifecycle ends
    pass
|
# Keeps every Python object backing a DLPack capsule alive as long as the
# capsule itself is referenced.
class CapsuleWrapper:
    """Bundle a DLPack PyCapsule with the objects it points into.

    The capsule stores only raw C pointers, so the shape array and the
    DLManagedTensor must stay referenced from Python; otherwise the garbage
    collector could reclaim them while a consumer still dereferences them.
    This class is that reference holder.
    """

    def __init__(self, capsule, shape_array, managed_tensor):
        """Store the capsule and pin its backing objects.

        Parameters:
            capsule: PyCapsule object following the DLPack protocol
            shape_array: array containing the tensor shape information
            managed_tensor: DLManagedTensor instance the capsule points to
        """
        # Public handle; can be passed to e.g. torch.utils.dlpack.from_dlpack.
        self.capsule = capsule
        # The two below are held purely to prevent garbage collection.
        self._shape_array = shape_array
        self._managed_tensor = managed_tensor
|
def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments, torch_dtype, dev_id):
    """
    Build a DLPack PyCapsule describing a strided 2-D view over raw GPU memory.

    Parameters:
        ptr: GPU memory address obtained from cudaMalloc (Python int)
        segment_size: Memory size of each segment in bytes
        segment_stride: Memory stride size between segments in bytes
        num_segments: Number of segments
        torch_dtype: torch dtype of the elements
        dev_id: device id the memory lives on

    Returns:
        A CapsuleWrapper whose ``.capsule`` is a PyCapsule compliant with the
        DLPack specification; it can be converted to a tensor with
        torch.utils.dlpack.from_dlpack. The wrapper also pins the ctypes
        objects the capsule's raw pointers refer to.

    Raises:
        NotImplementedError: if torch_dtype is not a supported float/int/uint dtype.
    """
    bits_per_elements = 0
    dldata_type_code = 0
    # Map torch dtype onto a DLPack (type code, bits) pair.
    # refer to https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h#L160
    if torch_dtype in [
        torch.float8_e5m2,
        torch.float8_e4m3fn,
        torch.bfloat16,
        torch.float16,
        torch.float32,
        torch.float64,
    ]:
        bits_per_elements = torch.finfo(torch_dtype).bits
        # NOTE(review): everything here is tagged kDLFloat (2), including
        # bfloat16 and the float8 types, which have distinct codes in newer
        # DLPack revisions — presumably intentional bit-reinterpretation;
        # confirm consumers expect this.
        dldata_type_code = 2
    elif torch_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
        bits_per_elements = torch.iinfo(torch_dtype).bits
        dldata_type_code = 0  # kDLInt
    elif torch_dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
        bits_per_elements = torch.iinfo(torch_dtype).bits
        dldata_type_code = 1  # kDLUInt
    else:
        raise NotImplementedError(torch_dtype)
    bytes_per_element = bits_per_elements // 8
    # The exported view is 2-D: (num_segments, elements_per_segment).
    ShapeArrayType = c_int64 * 2  # 2 dimensions
    shape_array = ShapeArrayType(num_segments, segment_size // bytes_per_element)
    # DLPack strides are expressed in elements, not bytes.
    stride_array = ShapeArrayType(segment_stride // bytes_per_element, 1)
    # Set device information: GPU (device_type=2) and device_id=dev_id (modify as needed)
    device = DLDevice(device_type=2, device_id=dev_id)
    # Set data type
    dtype = DLDataType(code=dldata_type_code, bits=bits_per_elements, lanes=1)
    # Construct DLTensor
    dltensor = DLTensor()
    dltensor.data = c_void_p(ptr)
    dltensor.device = device
    dltensor.ndim = 2
    dltensor.dtype = dtype
    dltensor.shape = ctypes.cast(shape_array, POINTER(c_int64))
    dltensor.strides = ctypes.cast(stride_array, POINTER(c_int64))
    dltensor.byte_offset = 0
    # Construct DLManagedTensor and set deleter to no-op: the GPU memory is
    # owned by the caller, so the consumer must not free it.
    managed_tensor = DLManagedTensor()
    managed_tensor.dl_tensor = dltensor
    managed_tensor.manager_ctx = None
    managed_tensor.deleter = no_op_deleter
    # shape_array, stride_array and managed_tensor must not be garbage
    # collected while the capsule is alive; they are pinned via the wrapper
    # below.
    PyCapsule_New = ctypes.pythonapi.PyCapsule_New
    PyCapsule_New.restype = c_void_p
    PyCapsule_New.argtypes = [c_void_p, ctypes.c_char_p, c_void_p]
    # pointer() yields the address of managed_tensor's buffer; the buffer
    # stays valid as long as managed_tensor itself is referenced.
    managed_tensor_ptr = pointer(managed_tensor)
    # The capsule name must be "dltensor", as required by the DLPack specification
    capsule_ptr = PyCapsule_New(managed_tensor_ptr, b"dltensor", None)
    # Convert the raw pointer back into a Python object.
    # NOTE(review): PyCapsule_New returns a new reference and this cast takes
    # another one, so the capsule carries one extra refcount — verify this is
    # the intended lifetime behavior.
    capsule = ctypes.cast(capsule_ptr, ctypes.py_object).value
    # Pin shape_array and managed_tensor by attaching them to the wrapper.
    capsule_wrapper = CapsuleWrapper(capsule, shape_array, managed_tensor)
    # Fix: also pin stride_array. DLTensor.strides holds only a raw pointer
    # into it, and the by-value copy into managed_tensor.dl_tensor does not
    # reliably keep the source array alive — without this reference the
    # strides pointer could dangle after this function returns.
    capsule_wrapper._stride_array = stride_array
    return capsule_wrapper
|
def pack_strided_memory(
    ptr: int, segment_size: int, segment_stride: int, num_segments: int, dtype: torch.dtype, dev_id
):
    """
    Pack GPU memory into a PyTorch tensor with the specified stride.

    Parameters:
        ptr: GPU memory address obtained from cudaMalloc
        segment_size: Memory size of each segment in bytes
        segment_stride: Memory stride size between segments in bytes
        num_segments: Number of segments
        dtype: PyTorch data type for the resulting tensor
        dev_id: CUDA device ID

    Returns:
        PyTorch tensor that references the provided memory

    Note:
        A fresh DLPack capsule is built on every call, even for the same
        pointer, because each capsule may be consumed only once.
    """
    wrapper = create_dlpack_capsule(
        ptr, segment_size, segment_stride, num_segments, dtype, dev_id
    )
    tensor = torch.utils.dlpack.from_dlpack(wrapper.capsule)
    # Anchor the wrapper on the tensor so the capsule's backing ctypes
    # objects live exactly as long as the tensor does.
    tensor._capsule_wrapper = wrapper
    return tensor