TensorRT-LLMs/tensorrt_llm/_dlpack_utils.py
2ez4bz dc52b67492
linting(python): Enable ruff on more files (wave 1/N) (#5140)
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2025-06-14 19:19:34 +08:00

211 lines
8.0 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
from ctypes import (
CFUNCTYPE,
POINTER,
c_int,
c_int64,
c_size_t,
c_uint8,
c_uint16,
c_void_p,
pointer,
)
import torch
# Define data structures required for DLPack
class DLDataType(ctypes.Structure):
    """DLPack data-type descriptor (mirrors ``DLDataType`` in dlpack.h)."""
    # Field order and types mirror the C ABI layout and must not change.
    _fields_ = [
        ("code", c_uint8),    # Type-class code, e.g. 2 (kDLFloat) for float
        ("bits", c_uint8),    # Bits per element, e.g. 32 for float32
        ("lanes", c_uint16),  # Vector lanes; 1 for scalar types
    ]
class DLDevice(ctypes.Structure):
    """DLPack device descriptor (mirrors ``DLDevice`` in dlpack.h)."""
    # Field order and types mirror the C ABI layout and must not change.
    _fields_ = [
        ("device_type", c_int),  # Device class; this module uses 2 (GPU/CUDA)
        ("device_id", c_int),    # Device ordinal, e.g. the CUDA device index
    ]
class DLTensor(ctypes.Structure):
    """DLPack tensor descriptor (mirrors ``DLTensor`` in dlpack.h)."""
    # Field order and types mirror the C ABI layout and must not change.
    _fields_ = [
        ("data", c_void_p),           # Raw pointer to the first element
        ("device", DLDevice),         # Where the data lives
        ("ndim", c_int),              # Number of dimensions
        ("dtype", DLDataType),        # Element type descriptor
        ("shape", POINTER(c_int64)),  # Array of ndim dimension sizes
        # Strides in elements (not bytes); NULL means compact row-major layout.
        (
            "strides",
            POINTER(c_int64),
        ),
        ("byte_offset", c_size_t),    # Byte offset from `data` (usually 0)
    ]
# Callback type void (*)(void*); kept for reference only — the actual deleter
# field below is typed against DLManagedTensor directly.
DLManagedTensorDeleter = CFUNCTYPE(None, POINTER(ctypes.c_void_p))  # Not used directly here


class DLManagedTensor(ctypes.Structure):
    """DLPack managed tensor (mirrors ``DLManagedTensor`` in dlpack.h).

    Declared with an empty body first so the ``deleter`` field can reference
    the type itself: void (*deleter)(DLManagedTensor*).
    """
    pass


# _fields_ is assigned after the class statement to allow the self-reference
# in the deleter's function-pointer type.
DLManagedTensor._fields_ = [
    ("dl_tensor", DLTensor),    # The tensor being exposed
    ("manager_ctx", c_void_p),  # Opaque context owned by the producer
    ("deleter", CFUNCTYPE(None, POINTER(DLManagedTensor))),  # Invoked by the consumer when done
]
# Deleter callback that intentionally does nothing: the memory packed by this
# module stays owned by the caller, so the capsule's consumer must not free
# it. A real release hook (e.g. cudaFree) could be installed here instead if
# ownership were ever transferred along with the tensor.
@CFUNCTYPE(None, POINTER(DLManagedTensor))
def no_op_deleter(_managed_tensor_ptr):
    pass
# Wrapper class to prevent Python garbage collection of DLPack-related objects
class CapsuleWrapper:
    """Ties the lifetime of a DLPack PyCapsule to the ctypes objects it
    references.

    The capsule holds raw C pointers into ``shape_array`` and
    ``managed_tensor``; if Python collected either of them while the capsule
    was still in use, those pointers would dangle. Storing every piece on one
    wrapper object keeps them all alive together.
    """

    def __init__(self, capsule, shape_array, managed_tensor):
        """Bundle the capsule with the objects it points into.

        Parameters:
            capsule: The PyCapsule object that follows the DLPack protocol.
            shape_array: The array containing tensor shape information.
            managed_tensor: The DLManagedTensor instance the capsule points to.
        """
        # Public handle — hand this to e.g. torch.utils.dlpack.from_dlpack.
        self.capsule = capsule
        # Private back-references whose only job is pinning the ctypes memory.
        self._shape_array = shape_array
        self._managed_tensor = managed_tensor
def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments, torch_dtype, dev_id):
    """Create a DLPack capsule exposing strided device memory as a 2-D tensor.

    Parameters:
        ptr: GPU memory address obtained from cudaMalloc (Python int).
        segment_size: Memory size of each segment in bytes.
        segment_stride: Byte distance between the starts of consecutive segments.
        num_segments: Number of segments (first dimension of the view).
        torch_dtype: torch dtype of the elements in each segment.
        dev_id: CUDA device ordinal the memory lives on.

    Returns:
        A CapsuleWrapper whose ``capsule`` attribute is a PyCapsule compliant
        with the DLPack specification (consumable by
        ``torch.utils.dlpack.from_dlpack``); the wrapper also pins the ctypes
        objects the capsule points into.

    Raises:
        NotImplementedError: if ``torch_dtype`` has no DLPack mapping here.
    """
    # DLPack type codes, per
    # https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h:
    # kDLInt = 0, kDLUInt = 1, kDLFloat = 2, kDLBfloat = 4.
    if torch_dtype == torch.bfloat16:
        # Fix: bfloat16 must be advertised as kDLBfloat (4). The previous
        # mapping to kDLFloat with 16 bits made DLPack consumers reinterpret
        # the buffer as float16.
        bits_per_element = torch.finfo(torch_dtype).bits
        dldata_type_code = 4
    elif torch_dtype in [
        torch.float8_e5m2,
        torch.float8_e4m3fn,
        torch.float16,
        torch.float32,
        torch.float64,
    ]:
        # NOTE(review): the float8 dtypes are advertised as kDLFloat here;
        # newer DLPack revisions define dedicated float8 codes — confirm what
        # the consumer expects before changing this.
        bits_per_element = torch.finfo(torch_dtype).bits
        dldata_type_code = 2
    elif torch_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
        bits_per_element = torch.iinfo(torch_dtype).bits
        dldata_type_code = 0
    elif torch_dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]:
        bits_per_element = torch.iinfo(torch_dtype).bits
        dldata_type_code = 1
    else:
        raise NotImplementedError(torch_dtype)
    bytes_per_element = bits_per_element // 8

    # Logical 2-D view: (num_segments, elements_per_segment). DLPack strides
    # are expressed in elements, so the byte stride is converted accordingly.
    Int64Pair = c_int64 * 2
    shape_array = Int64Pair(num_segments, segment_size // bytes_per_element)
    stride_array = Int64Pair(segment_stride // bytes_per_element, 1)

    # Device information: device_type 2 == GPU (CUDA), at ordinal dev_id.
    device = DLDevice(device_type=2, device_id=dev_id)
    dtype = DLDataType(code=dldata_type_code, bits=bits_per_element, lanes=1)

    # Construct the DLTensor describing the raw memory.
    dltensor = DLTensor()
    dltensor.data = c_void_p(ptr)
    dltensor.device = device
    dltensor.ndim = 2
    dltensor.dtype = dtype
    dltensor.shape = ctypes.cast(shape_array, POINTER(c_int64))
    dltensor.strides = ctypes.cast(stride_array, POINTER(c_int64))
    dltensor.byte_offset = 0

    # The caller keeps ownership of the device memory, so the deleter is a
    # no-op (swap in a real release hook there if ownership should transfer).
    managed_tensor = DLManagedTensor()
    managed_tensor.dl_tensor = dltensor
    managed_tensor.manager_ctx = None
    managed_tensor.deleter = no_op_deleter

    PyCapsule_New = ctypes.pythonapi.PyCapsule_New
    PyCapsule_New.restype = c_void_p
    PyCapsule_New.argtypes = [c_void_p, ctypes.c_char_p, c_void_p]
    # The capsule name must be exactly "dltensor", as required by the DLPack
    # specification; consumers rename it to "used_dltensor" on consumption.
    capsule_ptr = PyCapsule_New(pointer(managed_tensor), b"dltensor", None)
    # Convert the raw capsule pointer back into a Python object.
    capsule = ctypes.cast(capsule_ptr, ctypes.py_object).value

    # Fix: pin *both* index arrays. Previously only shape_array was retained,
    # so stride_array could be garbage collected while the capsule still
    # pointed at it. CapsuleWrapper merely stores the reference, so bundling
    # the two arrays into one tuple is interface-compatible.
    return CapsuleWrapper(capsule, (shape_array, stride_array), managed_tensor)
def pack_strided_memory(
    ptr: int, segment_size: int, segment_stride: int, num_segments: int, dtype: torch.dtype, dev_id
):
    """Expose strided GPU memory as a PyTorch tensor via DLPack.

    Parameters:
        ptr: GPU memory address obtained from cudaMalloc.
        segment_size: Memory size of each segment in bytes.
        segment_stride: Memory stride size between segments in bytes.
        num_segments: Number of segments.
        dtype: PyTorch data type for the resulting tensor.
        dev_id: CUDA device ID.

    Returns:
        PyTorch tensor that references (does not copy) the provided memory.

    Note:
        A DLPack capsule may be consumed only once, so a fresh capsule is
        built on every call — even for the same pointer.
    """
    wrapper = create_dlpack_capsule(
        ptr, segment_size, segment_stride, num_segments, dtype, dev_id
    )
    tensor = torch.utils.dlpack.from_dlpack(wrapper.capsule)
    # Anchor the wrapper on the tensor so the ctypes objects backing the
    # capsule stay alive at least as long as the tensor does.
    tensor._capsule_wrapper = wrapper
    return tensor