linting(python): Enable ruff on more files (wave 1/N) (#5140)

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
2ez4bz 2025-06-14 04:19:34 -07:00 committed by GitHub
parent 0b60da2c45
commit dc52b67492
16 changed files with 1114 additions and 1010 deletions

.pre-commit-config.yaml

@ -71,9 +71,9 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
files: ".*/auto_deploy/.*"
pass_filenames: false
- id: ruff-format
files: ".*/auto_deploy/.*"
pass_filenames: false
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17
hooks:

pyproject.toml

@ -10,14 +10,54 @@ build-backend = "setuptools.build_meta"
####################################################################################################
[tool.isort]
line_length = 80
extend_skip_glob = ["**/auto_deploy/**"]
# This should match the `include` in `[tool.ruff]`. See the comments in that section for why this
# is necessary.
extend_skip_glob = [
"**/auto_deploy/**",
"tensorrt_llm/_common.py",
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
"tensorrt_llm/logger.py",
"tensorrt_llm/lora_manager.py",
"tensorrt_llm/module.py",
"tensorrt_llm/moe_config.py",
"tensorrt_llm/profiler.py",
"tensorrt_llm/prompt_adapter_manager.py",
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
]
[tool.yapf]
based_on_style = "pep8"
column_limit = 80
[tool.yapfignore]
ignore_patterns = ["**/auto_deploy/**"]
# This should match the `include` in `[tool.ruff]`. See the comments in that section for why this
# is necessary.
ignore_patterns = [
"**/auto_deploy/**",
"tensorrt_llm/_common.py",
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
"tensorrt_llm/logger.py",
"tensorrt_llm/lora_manager.py",
"tensorrt_llm/module.py",
"tensorrt_llm/moe_config.py",
"tensorrt_llm/profiler.py",
"tensorrt_llm/prompt_adapter_manager.py",
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
]
[tool.codespell]
skip = ".git,3rdparty,tests/integration/test_input_files**,**.jsonl,**.json"
@ -28,7 +68,27 @@ ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te
in-place = true
remove_all_unused_imports = true
remove_unused_variables = true
exclude = ["**/auto_deploy/**"]
# This should match the `include` in `[tool.ruff]`. See the comments in that section for why this
# is necessary.
exclude = [
"**/auto_deploy/**",
"tensorrt_llm/_common.py",
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
"tensorrt_llm/logger.py",
"tensorrt_llm/lora_manager.py",
"tensorrt_llm/module.py",
"tensorrt_llm/moe_config.py",
"tensorrt_llm/profiler.py",
"tensorrt_llm/prompt_adapter_manager.py",
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
]
####################################################################################################
@ -44,6 +104,33 @@ include = [
"**/auto_deploy/**/*.py",
"**/auto_deploy/**/*.pyi",
"**/auto_deploy/**/*.ipynb",
# Progressively enable ruff on all the repo to keep individual changes reasonably-sized, and
# keep merge conflicts manageable.
# Since keeping both `yapf` and `ruff` makes no sense (given that their formatting philosophies
# are quite different), we should move towards removing one in favor of the other. ruff's
# formatting mirrors black's, and both are much more widely adopted than yapf. ruff is also
# orders of magnitude faster, so we should move to deprecate `yapf`.
# In the transition period, we should keep the `ignore_patterns` in `[tool.yapfignore]` in sync
# with the below, so that both pre-commit hooks can complete successfully.
"tensorrt_llm/_common.py",
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
"tensorrt_llm/logger.py",
"tensorrt_llm/lora_manager.py",
"tensorrt_llm/module.py",
"tensorrt_llm/moe_config.py",
"tensorrt_llm/profiler.py",
"tensorrt_llm/prompt_adapter_manager.py",
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
]
exclude = [
"3rdparty/**",
]
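
The comment block above requires the per-file opt-in list to be mirrored across `[tool.isort]`, `[tool.yapfignore]`, autoflake, and `[tool.ruff]` during the transition. A small, hypothetical helper (not part of this commit; assumes Python 3.11+ for `tomllib`) can check that the lists stay in sync:

```python
# Hypothetical sync check, not part of this commit: compare the tensorrt_llm/ entries
# of [tool.ruff].include against [tool.yapfignore].ignore_patterns, which the comments
# in pyproject.toml ask to keep identical while both formatters coexist.
import tomllib  # Python 3.11+

with open("pyproject.toml", "rb") as f:
    cfg = tomllib.load(f)

ruff_files = {p for p in cfg["tool"]["ruff"]["include"] if p.startswith("tensorrt_llm/")}
yapf_ignored = {p for p in cfg["tool"]["yapfignore"]["ignore_patterns"] if p.startswith("tensorrt_llm/")}

# Any mismatch means one pre-commit hook reformats files the other expects to skip.
assert ruff_files == yapf_ignored, sorted(ruff_files ^ yapf_ignored)
```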

tensorrt_llm/_common.py

@ -54,10 +54,10 @@ def _init(log_level: object = None) -> None:
logger.set_level(log_level)
if os.getenv("TRT_LLM_NO_LIB_INIT", "0") == "1":
logger.info('Skipping TensorRT-LLM init.')
logger.info("Skipping TensorRT-LLM init.")
return
logger.info('Starting TensorRT-LLM init.')
logger.info("Starting TensorRT-LLM init.")
# load plugin lib
_load_plugin_lib()
@ -65,24 +65,30 @@ def _init(log_level: object = None) -> None:
# load FT decoder layer and torch custom ops
project_dir = str(Path(__file__).parent.absolute())
if platform.system() == "Windows":
ft_decoder_lib = project_dir + '/libs/th_common.dll'
ft_decoder_lib = project_dir + "/libs/th_common.dll"
else:
ft_decoder_lib = project_dir + '/libs/libth_common.so'
ft_decoder_lib = project_dir + "/libs/libth_common.so"
try:
torch.classes.load_library(ft_decoder_lib)
from ._torch.custom_ops import _register_fake
_register_fake()
except Exception as e:
msg = '\nFATAL: Decoding operators failed to load. This may be caused by the incompatibility between PyTorch and TensorRT-LLM. Please rebuild and install TensorRT-LLM.'
msg = (
"\nFATAL: Decoding operators failed to load. This may be caused by an incompatibility "
"between PyTorch and TensorRT-LLM. Please rebuild and install TensorRT-LLM."
)
raise ImportError(str(e) + msg)
MpiComm.local_init()
logger.info('TensorRT-LLM inited.')
logger.info("TensorRT-LLM inited.")
def default_net() -> Network:
assert net, "Use builder to create network first, and use `set_network` or `net_guard` to set it to default"
assert net, (
"Use builder to create network first, and use `set_network` or `net_guard` to set it to default"
)
return net
@ -111,29 +117,29 @@ def precision(dtype):
def serialize_engine(engine, path):
logger.info(f'Serializing engine to {path}...')
logger.info(f"Serializing engine to {path}...")
tik = time.time()
if isinstance(engine, trt.ICudaEngine):
engine = engine.serialize()
with open(path, 'wb') as f:
with open(path, "wb") as f:
f.write(engine)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Engine serialized. Total time: {t}')
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
logger.info(f"Engine serialized. Total time: {t}")
def deserialize_engine(path):
runtime = trt.Runtime(logger.trt_logger)
with open(path, 'rb') as f:
logger.info(f'Loading engine from {path}...')
with open(path, "rb") as f:
logger.info(f"Loading engine from {path}...")
tik = time.time()
engine = runtime.deserialize_cuda_engine(f.read())
assert engine is not None
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Engine loaded. Total time: {t}')
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
logger.info(f"Engine loaded. Total time: {t}")
return engine
@ -149,33 +155,32 @@ _field_dtype_to_np_dtype_dict = {
def field_dtype_to_np_dtype(dtype):
ret = _field_dtype_to_np_dtype_dict.get(dtype)
assert ret is not None, f'Unsupported dtype: {dtype}'
assert ret is not None, f"Unsupported dtype: {dtype}"
return ret
def convert_capsule_to_void_p(capsule):
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [
ctypes.py_object, ctypes.c_char_p
]
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
return ctypes.pythonapi.PyCapsule_GetPointer(capsule, None)
def get_nparray_from_void_p(void_pointer, elem_size, field_dtype):
ctypes.pythonapi.PyMemoryView_FromMemory.restype = ctypes.py_object
ctypes.pythonapi.PyMemoryView_FromMemory.argtypes = [
ctypes.c_char_p, ctypes.c_ssize_t, ctypes.c_int
ctypes.c_char_p,
ctypes.c_ssize_t,
ctypes.c_int,
]
logger.info(
f'get_nparray: pointer = {void_pointer}, elem_size = {elem_size}')
logger.info(f"get_nparray: pointer = {void_pointer}, elem_size = {elem_size}")
char_pointer = ctypes.cast(void_pointer, ctypes.POINTER(ctypes.c_char))
np_dtype = field_dtype_to_np_dtype(field_dtype)
buf_bytes = elem_size * np.dtype(np_dtype).itemsize
logger.info(f'get_nparray: buf_bytes = {buf_bytes}')
logger.info(f"get_nparray: buf_bytes = {buf_bytes}")
mem_view = ctypes.pythonapi.PyMemoryView_FromMemory(
char_pointer, buf_bytes, 0) # number 0 represents PyBUF_READ
logger.info(
f'get_nparray: mem_view = {mem_view}, field_dtype = {field_dtype}')
char_pointer, buf_bytes, 0
) # number 0 represents PyBUF_READ
logger.info(f"get_nparray: mem_view = {mem_view}, field_dtype = {field_dtype}")
buf = np.frombuffer(mem_view, np_dtype)
return buf
@ -187,18 +192,17 @@ def get_scalar_from_field(field):
class _BuildingFlag:
def __enter__(self):
os.environ['IS_BUILDING'] = '1'
os.environ["IS_BUILDING"] = "1"
def __exit__(self, type, value, tb):
del os.environ['IS_BUILDING']
del os.environ["IS_BUILDING"]
def _is_building(f):
'''Use this to decorate functions which are called during engine building/refitting process,
"""Use this to decorate functions which are called during engine building/refitting process,
otherwise, the plugin registration will fail.
'''
"""
@wraps(f)
def decorated(*args, **kwargs):
@ -208,15 +212,25 @@ def _is_building(f):
return decorated
def check_max_num_tokens(max_num_tokens, opt_num_tokens, max_batch_size,
max_input_len, max_seq_len, max_beam_width,
remove_input_padding, enable_context_fmha,
tokens_per_block, multiple_profiles):
def check_max_num_tokens(
max_num_tokens,
opt_num_tokens,
max_batch_size,
max_input_len,
max_seq_len,
max_beam_width,
remove_input_padding,
enable_context_fmha,
tokens_per_block,
multiple_profiles,
):
if not remove_input_padding:
if max_num_tokens is not None or opt_num_tokens is not None:
max_num_tokens = max_batch_size * max_seq_len
logger.warning("remove_input_padding is not enabled, the specified "
"max_num_tokens/opt_num_tokens will be ignored.")
logger.warning(
"remove_input_padding is not enabled, the specified "
"max_num_tokens/opt_num_tokens will be ignored."
)
return max_num_tokens, opt_num_tokens
else:
if max_num_tokens is None:
@ -228,20 +242,22 @@ def check_max_num_tokens(max_num_tokens, opt_num_tokens, max_batch_size,
"when remove_input_padding is enabled, because the number "
"of packed input tokens are very likely to be smaller, "
"we strongly recommend to set max_num_tokens according "
"to your workloads.")
"to your workloads."
)
if opt_num_tokens is None and not multiple_profiles:
opt_num_tokens = min(max_batch_size * max_beam_width,
max_num_tokens)
opt_num_tokens = min(max_batch_size * max_beam_width, max_num_tokens)
logger.warning(
"remove_input_padding is enabled, while opt_num_tokens "
"is not set, setting to max_batch_size*max_beam_width. \n")
"is not set, setting to max_batch_size*max_beam_width. \n"
)
if max_num_tokens > 16384:
logger.warning(
"Specifying a `max_num_tokens` larger than 16384 is usually "
"not recommended, we do not expect perf gain with that and too "
"large `max_num_tokens` could possibly exceed the TensorRT "
"tensor volume, causing runtime errors. "
f"Got `max_num_tokens` = {max_num_tokens}")
f"Got `max_num_tokens` = {max_num_tokens}"
)
if max_num_tokens > max_seq_len * max_batch_size:
logger.warning(
f"max_num_tokens ({max_num_tokens}) shouldn't be greater than "
@ -253,21 +269,24 @@ def check_max_num_tokens(max_num_tokens, opt_num_tokens, max_batch_size,
logger.warning(
f"When enable_context_fmha is not turned on, max_num_tokens ({max_num_tokens}) "
f"should be at least max_input_len ({max_input_len}), specifying to "
f"max_input_len ({max_input_len}).")
f"max_input_len ({max_input_len})."
)
max_num_tokens = max_input_len
elif max_num_tokens < tokens_per_block and enable_context_fmha:
logger.warning(
f"When enable_context_fmha is turned on, max_num_tokens ({max_num_tokens}) "
f"should be at least tokens_per_block ({tokens_per_block}), specifying to "
f"tokens_per_block ({tokens_per_block}). At this time, you also need to enable "
f"context chunking at runtime, otherwise you may encounter errors.")
f"context chunking at runtime, otherwise you may encounter errors."
)
max_num_tokens = tokens_per_block
if opt_num_tokens is not None and opt_num_tokens > max_num_tokens:
logger.warning(
f"opt_num_tokens ({opt_num_tokens}) shouldn't be greater than "
f"max_num_tokens ({max_num_tokens}), "
f"specifying to max_num_tokens ({max_num_tokens}).")
f"specifying to max_num_tokens ({max_num_tokens})."
)
opt_num_tokens = max_num_tokens
return max_num_tokens, opt_num_tokens
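
The `check_max_num_tokens` rewrite above is purely cosmetic; the logic that derives `max_num_tokens` and clamps `opt_num_tokens` is unchanged. An illustrative call (invented values, and a working TensorRT-LLM install is assumed) looks like:

```python
# Illustrative values only; the helper may log warnings while it derives
# max_num_tokens and clamps opt_num_tokens against it.
from tensorrt_llm._common import check_max_num_tokens

max_num_tokens, opt_num_tokens = check_max_num_tokens(
    max_num_tokens=None,       # let the helper derive a value
    opt_num_tokens=None,
    max_batch_size=8,
    max_input_len=1024,
    max_seq_len=2048,
    max_beam_width=1,
    remove_input_padding=True,
    enable_context_fmha=True,
    tokens_per_block=64,
    multiple_profiles=False,
)
print(max_num_tokens, opt_num_tokens)
```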

tensorrt_llm/_dlpack_utils.py

@ -13,8 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
from ctypes import (CFUNCTYPE, POINTER, c_int, c_int64, c_size_t, c_uint8,
c_uint16, c_void_p, pointer)
from ctypes import (
CFUNCTYPE,
POINTER,
c_int,
c_int64,
c_size_t,
c_uint8,
c_uint16,
c_void_p,
pointer,
)
import torch
@ -24,14 +33,14 @@ class DLDataType(ctypes.Structure):
_fields_ = [
("code", c_uint8), # Data type code, e.g., 2 for float
("bits", c_uint8), # Number of bits per element, e.g., 32
("lanes", c_uint16) # Number of lanes, usually 1
("lanes", c_uint16), # Number of lanes, usually 1
]
class DLDevice(ctypes.Structure):
_fields_ = [
("device_type", c_int), # Device type, typically 2 for GPU
("device_id", c_int) # Device ID, usually 0 for default GPU
("device_id", c_int), # Device ID, usually 0 for default GPU
]
@ -43,15 +52,15 @@ class DLTensor(ctypes.Structure):
("dtype", DLDataType), # Data type
("shape", POINTER(c_int64)), # Pointer to array of dimension sizes
(
"strides", POINTER(c_int64)
"strides",
POINTER(c_int64),
), # Pointer to strides array (can be NULL for default contiguous layout)
("byte_offset", c_size_t) # Byte offset (usually 0)
("byte_offset", c_size_t), # Byte offset (usually 0)
]
# Deleter type for DLManagedTensor
DLManagedTensorDeleter = CFUNCTYPE(None, POINTER(
ctypes.c_void_p)) # Not used directly here
DLManagedTensorDeleter = CFUNCTYPE(None, POINTER(ctypes.c_void_p)) # Not used directly here
# Define DLManagedTensor structure, with deleter prototype void(*deleter)(DLManagedTensor*)
@ -59,9 +68,11 @@ class DLManagedTensor(ctypes.Structure):
pass
DLManagedTensor._fields_ = [("dl_tensor", DLTensor), ("manager_ctx", c_void_p),
("deleter", CFUNCTYPE(None,
POINTER(DLManagedTensor)))]
DLManagedTensor._fields_ = [
("dl_tensor", DLTensor),
("manager_ctx", c_void_p),
("deleter", CFUNCTYPE(None, POINTER(DLManagedTensor))),
]
# A no-op deleter that doesn't perform any operation
@ -95,8 +106,7 @@ class CapsuleWrapper:
self._managed_tensor = managed_tensor # Keep reference to prevent garbage collection
def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments,
torch_dtype, dev_id):
def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments, torch_dtype, dev_id):
"""
Parameters:
ptr: GPU memory address obtained from cudaMalloc (Python int)
@ -106,14 +116,19 @@ def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments,
torch_dtype: torch dtype
dev_id: device id.
Returns:
A PyCapsule object compliant with DLPack specification, which can be directly converted to a tensor using torch.utils.dlpack.from_dlpack
A PyCapsule object compliant with DLPack specification, which can be directly converted to a
tensor using torch.utils.dlpack.from_dlpack
"""
bits_per_elements = 0
dldata_type_code = 0
# refer to https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h#L160
if torch_dtype in [
torch.float8_e5m2, torch.float8_e4m3fn, torch.bfloat16,
torch.float16, torch.float32, torch.float64
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.bfloat16,
torch.float16,
torch.float32,
torch.float64,
]:
bits_per_elements = torch.finfo(torch_dtype).bits
dldata_type_code = 2
@ -128,8 +143,7 @@ def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments,
bytes_per_element = bits_per_elements // 8
# Allocate space for shape (constructing a one-dimensional tensor here)
ShapeArrayType = c_int64 * 2 # 1 dimension
shape_array = ShapeArrayType(num_segments,
segment_size // bytes_per_element)
shape_array = ShapeArrayType(num_segments, segment_size // bytes_per_element)
stride_array = ShapeArrayType(segment_stride // bytes_per_element, 1)
# Set device information: GPU (device_type=2) and device_id=dev_id (modify as needed)
device = DLDevice(device_type=2, device_id=dev_id)
@ -166,8 +180,9 @@ def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments,
return capsule_wrapper
def pack_strided_memory(ptr: int, segment_size: int, segment_stride: int,
num_segments: int, dtype: torch.dtype, dev_id):
def pack_strided_memory(
ptr: int, segment_size: int, segment_stride: int, num_segments: int, dtype: torch.dtype, dev_id
):
"""
Pack GPU memory into a PyTorch tensor with specified stride.
@ -187,8 +202,9 @@ def pack_strided_memory(ptr: int, segment_size: int, segment_stride: int,
even with the same pointer. Each capsule is consumed only once.
"""
# Create a new capsule each time
capsule_wrapper = create_dlpack_capsule(ptr, segment_size, segment_stride,
num_segments, dtype, dev_id)
capsule_wrapper = create_dlpack_capsule(
ptr, segment_size, segment_stride, num_segments, dtype, dev_id
)
torch_tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule)
torch_tensor._capsule_wrapper = capsule_wrapper
return torch_tensor
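
As the docstrings above describe, `pack_strided_memory` exposes a raw, strided GPU allocation as a torch tensor through a DLPack capsule, without copying. A minimal usage sketch (illustrative sizes; assumes a CUDA device and the cuda-python bindings already used by this repo):

```python
# Minimal sketch: view a strided cudaMalloc allocation as a (num_segments, N) float32 tensor.
import torch
from cuda import cudart
from tensorrt_llm._dlpack_utils import pack_strided_memory

segment_size = 256    # payload bytes per segment
segment_stride = 512  # bytes between segment starts (stride >= size)
num_segments = 4

err, ptr = cudart.cudaMalloc(segment_stride * num_segments)
assert err == cudart.cudaError_t.cudaSuccess

# 256 bytes / 4 bytes per float32 element -> 64 elements per row, one row per segment.
t = pack_strided_memory(ptr, segment_size, segment_stride, num_segments, torch.float32, dev_id=0)
print(t.shape)  # torch.Size([4, 64])
```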

tensorrt_llm/_ipc_utils.py

@ -42,9 +42,7 @@ def can_access_peer(mapping: Mapping) -> bool:
# Early exit if devices are on different nodes
if mapping.get_node_rank(rank) != mapping.node_rank:
logger.info(
f"Detect inter-node TP between rank {mapping.rank} and rank {rank}"
)
logger.info(f"Detect inter-node TP between rank {mapping.rank} and rank {rank}")
return False
# Skip if same device
@ -63,8 +61,7 @@ def can_access_peer(mapping: Mapping) -> bool:
return True
class IpcMemory():
class IpcMemory:
# WARNING: Must in sync with FLAGS_SIZE in cpp/include/tensorrt_llm/runtime/ipcUtils.h
# (Max all reduce blocks + 1) * sizeof(int)
IPC_BARRIERS_SIZE_PER_GPU = (24 + 1) * 4
@ -73,8 +70,7 @@ class IpcMemory():
self.mapping = mapping
self.open_ipc = open_ipc and mapping.tp_size <= mapping.gpus_per_node
if self.open_ipc:
self.peer_ptrs, self.local_ptr = IpcMemory.open_ipc_memory(
self.mapping, size, True)
self.peer_ptrs, self.local_ptr = IpcMemory.open_ipc_memory(self.mapping, size, True)
else:
self.peer_ptrs = [0] * mapping.tp_size
self.local_ptr = 0
@ -91,10 +87,10 @@ class IpcMemory():
return array.array("Q", buffer).tolist()
@staticmethod
def open_ipc_memory(mapping: Mapping,
size: int,
set_to_zero: bool = False) -> Tuple[List[int], int]:
""" Allocates a buffer with the given *size* on each GPU. Then, enables IPC communication between TP groups.
def open_ipc_memory(
mapping: Mapping, size: int, set_to_zero: bool = False
) -> Tuple[List[int], int]:
"""Allocates a buffer with the given *size* on each GPU. Then, enables IPC communication between TP groups.
Returns a list of buffer pointers, buffers[i] is a handle to the corresponding buffer residing on GPU #i.
Call close_ipc_handle with the *buffer*.
"""
@ -105,8 +101,8 @@ class IpcMemory():
return size
comm = mpi_comm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
mapping.tp_rank)
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
# see allocateIpcMemory in cpp/tensorrt_llm/runtime/ipcUtils.cpp for alignment reason
# 1 << 21 is 2MB
@ -131,7 +127,8 @@ class IpcMemory():
peer_ptrs.append(local_ptr)
else:
error, ptr = cudart.cudaIpcOpenMemHandle(
handle, cudart.cudaIpcMemLazyEnablePeerAccess)
handle, cudart.cudaIpcMemLazyEnablePeerAccess
)
_raise_if_error(error)
peer_ptrs.append(ptr)
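
A rough sketch of how `IpcMemory` is used for the all-reduce barrier buffers (not runnable standalone: it assumes an MPI launch with one rank per GPU, and that the `open_ipc` constructor flag defaults to enabled):

```python
# Rough sketch under the assumptions above; sizes follow IPC_BARRIERS_SIZE_PER_GPU.
from tensorrt_llm._ipc_utils import IpcMemory
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=2, tp_size=2, rank=0)  # illustrative values
barrier_bytes = IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * mapping.tp_size

# peer_ptrs holds one device pointer per TP rank; local_ptr is this rank's buffer.
barriers = IpcMemory(mapping, barrier_bytes)
print(barriers.peer_ptrs, barriers.local_ptr)
```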

tensorrt_llm/_mnnvl_utils.py

@ -67,8 +67,7 @@ class MnnvlMemory:
def __init__(self, mapping: Mapping, size: int):
self.mapping = mapping
self.segment_size = size
self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory(
self.mapping, size)
self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory(self.mapping, size)
def __del__(self):
if not sys.is_finalizing():
@ -76,15 +75,15 @@ class MnnvlMemory:
def as_torch_strided_tensor(self, dtype):
num_segments = MnnvlMemory.comm.Get_size()
return pack_strided_memory(self.ptr, self.segment_size,
self.rank_stride, num_segments, dtype,
MnnvlMemory.dev_id)
return pack_strided_memory(
self.ptr, self.segment_size, self.rank_stride, num_segments, dtype, MnnvlMemory.dev_id
)
@staticmethod
def initialize():
if not MnnvlMemory.initialized:
# use a dummy torch CUDA tensor to trigger CUDA context initialization
_ = torch.empty(1, device='cuda')
_ = torch.empty(1, device="cuda")
# ensure nvml is initialized.
try:
pynvml.nvmlDeviceGetCount()
@ -97,8 +96,8 @@ class MnnvlMemory:
if MnnvlMemory.comm is not None:
return MnnvlMemory.comm
comm = mpi_comm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
mapping.tp_rank)
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
MnnvlMemory.comm = comm
return comm
@ -109,7 +108,9 @@ class MnnvlMemory:
location.id = dev_id
allocation_prop = cuda.CUmemAllocationProp()
allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
allocation_prop.requestedHandleTypes = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
allocation_prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
)
allocation_prop.location = location
return allocation_prop
@ -119,27 +120,25 @@ class MnnvlMemory:
return MnnvlMemory.allocation_granularity
allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
option = cuda.CUmemAllocationGranularity_flags(
cuda.CUmemAllocationGranularity_flags.
CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)
cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED
)
granularity = _check_cu_result(
cuda.cuMemGetAllocationGranularity(prop=allocation_prop,
option=option))
cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)
)
MnnvlMemory.allocation_granularity = granularity
return MnnvlMemory.allocation_granularity
@staticmethod
def new_mnnvl_memory_address(mapping: Mapping, size: int):
page_count = (size + MnnvlMemory.fabric_page_size -
1) // MnnvlMemory.fabric_page_size
page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size
current_rank_stride = page_count * MnnvlMemory.fabric_page_size
logger.info(
f"[MnnvlMemory] creating address with stride={current_rank_stride}")
logger.info(f"[MnnvlMemory] creating address with stride={current_rank_stride}")
comm = MnnvlMemory.get_comm(mapping)
comm_size = comm.Get_size()
address_size = current_rank_stride * comm_size
ptr = _check_cu_result(
cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size,
0, 0))
cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)
)
MnnvlMemory.current_start_address = int(ptr)
MnnvlMemory.current_rank_stride = current_rank_stride
MnnvlMemory.current_mem_offset = 0
@ -150,16 +149,15 @@ class MnnvlMemory:
dev_id = int(dev)
if MnnvlMemory.dev_id is None:
MnnvlMemory.dev_id = dev_id
assert dev_id == MnnvlMemory.dev_id,\
assert dev_id == MnnvlMemory.dev_id, (
f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}"
)
comm = MnnvlMemory.get_comm(mapping)
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()
all_rank_allocate_sizes = comm.allgather(size)
assert len(all_rank_allocate_sizes) == comm_size
assert all(
x == size for x in
all_rank_allocate_sizes), "Not all rank allocating same size."
assert all(x == size for x in all_rank_allocate_sizes), "Not all rank allocating same size."
granularity = MnnvlMemory.get_allocation_granularity(dev_id)
aligned_size = (size + granularity - 1) // granularity * granularity
@ -170,13 +168,15 @@ class MnnvlMemory:
allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
allocated_mem_handle = _check_cu_result(
cuda.cuMemCreate(aligned_size, allocation_prop, flags=0))
cuda.cuMemCreate(aligned_size, allocation_prop, flags=0)
)
exported_fabric_handle = _check_cu_result(
cuda.cuMemExportToShareableHandle(
allocated_mem_handle,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, 0))
allocated_mem_handle, cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, 0
)
)
all_handles_data = comm.allgather(exported_fabric_handle.data)
# all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
# all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # noqa: E501
# can use buf = memoryview(data) to import if using plain buffer for data.
madesc = cuda.CUmemAccessDesc()
@ -186,44 +186,49 @@ class MnnvlMemory:
mem_handles = [None] * comm_size
for i, remote_handle_data in enumerate(all_handles_data):
rank_ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_rank_stride * i + MnnvlMemory.current_mem_offset
rank_ptr = (
MnnvlMemory.current_start_address
+ MnnvlMemory.current_rank_stride * i
+ MnnvlMemory.current_mem_offset
)
if i == comm_rank:
# Local memory mapping
mem_handles[i] = allocated_mem_handle
_check_cu_result(
cuda.cuMemMap(rank_ptr, aligned_size, 0,
allocated_mem_handle, 0))
_check_cu_result(cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0))
else:
# Fabric memory mapping
imported_mem_handle = _check_cu_result(
cuda.cuMemImportFromShareableHandle(
remote_handle_data, cuda.CUmemAllocationHandleType.
CU_MEM_HANDLE_TYPE_FABRIC))
remote_handle_data, cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
)
)
mem_handles[i] = imported_mem_handle
_check_cu_result(
cuda.cuMemMap(rank_ptr, aligned_size, 0,
imported_mem_handle, 0))
_check_cu_result(cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0))
_check_cu_result(
cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1))
_check_cu_result(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1))
ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_mem_offset
stride = MnnvlMemory.current_rank_stride
MnnvlMemory.allocated_map[ptr] = (mapping, aligned_size, mem_handles,
MnnvlMemory.current_start_address,
MnnvlMemory.current_rank_stride,
MnnvlMemory.current_mem_offset)
MnnvlMemory.address_refcnt[
MnnvlMemory.current_start_address] = MnnvlMemory.address_refcnt.get(
MnnvlMemory.current_start_address, 0) + 1
MnnvlMemory.allocated_map[ptr] = (
mapping,
aligned_size,
mem_handles,
MnnvlMemory.current_start_address,
MnnvlMemory.current_rank_stride,
MnnvlMemory.current_mem_offset,
)
MnnvlMemory.address_refcnt[MnnvlMemory.current_start_address] = (
MnnvlMemory.address_refcnt.get(MnnvlMemory.current_start_address, 0) + 1
)
MnnvlMemory.current_mem_offset += aligned_size
return ptr, stride
@staticmethod
def close_mnnvl_memory(ptr: int):
mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset = MnnvlMemory.allocated_map.pop(
ptr)
mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset = (
MnnvlMemory.allocated_map.pop(ptr)
)
comm = MnnvlMemory.get_comm(mapping)
comm_size = comm.Get_size()
for i in range(comm_size):
@ -235,8 +240,7 @@ class MnnvlMemory:
if MnnvlMemory.address_refcnt[start_address] == 0:
MnnvlMemory.address_refcnt.pop(start_address)
device_ptr = cuda.CUdeviceptr(start_address)
_check_cu_result(
cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride))
_check_cu_result(cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride))
if start_address == MnnvlMemory.current_start_address:
MnnvlMemory.current_start_address = 0
MnnvlMemory.current_rank_stride = 0
@ -252,15 +256,19 @@ class MnnvlMemory:
for link_idx in range(link_count):
try:
if pynvml.nvmlDeviceGetNvLinkCapability(
handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED):
handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED
):
available_links += 1
is_active = pynvml.nvmlDeviceGetNvLinkState(
handle, link_idx)
is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx)
if is_active:
active_links += 1
except pynvml.NVMLError_NotSupported:
continue
return active_links == available_links and available_links > 0 if need_all_up else available_links > 0
return (
active_links == available_links and available_links > 0
if need_all_up
else available_links > 0
)
@staticmethod
def supports_mnnvl() -> bool:
@ -269,7 +277,7 @@ class MnnvlMemory:
# But it is not equivalent to MNNVL support.
# May need better support check.
arch = platform.machine().lower()
is_on_aarch64 = 'aarch64' in arch
is_on_aarch64 = "aarch64" in arch
support_nvlink_and_all_up = MnnvlMemory.support_nvlink(True)
return is_on_aarch64 and support_nvlink_and_all_up
@ -298,91 +306,137 @@ class MnnvlMoe:
MnnvlMoe.moe_mapping = mapping
workspace_size_per_rank = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(
mapping.tp_size)
mapping.tp_size
)
MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
torch.uint64)
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(torch.uint64)
return MnnvlMoe.moe_workspace_tensor
@staticmethod
def compute_target_rank_id(token_selected_experts: torch.Tensor,
expert_count: int, ep_size: int):
def compute_target_rank_id(
token_selected_experts: torch.Tensor, expert_count: int, ep_size: int
):
assert expert_count % ep_size == 0, "expert_count should be divisible by ep_size"
expert_per_rank = expert_count // ep_size
token_target_rank_ids = token_selected_experts // expert_per_rank
return token_target_rank_ids
@staticmethod
def mnnvl_moe_alltoallv_prepare(gathered_target_rank_ids: torch.Tensor,
real_rank_token_count_cumsum: torch.Tensor,
gathered_expert_ids: torch.Tensor,
gathered_scales: torch.Tensor,
max_token_count_per_rank: int,
expert_count: int, top_k: int, ep_rank: int,
ep_size: int):
local_gather_indices, send_rank_count_cumsum, send_rank_local_indices, \
recv_rank_count_cumsum, recv_rank_local_indices, backward_recv_rank_local_indices = \
torch.ops.trtllm.moe_comm_prepare_indices(gathered_target_rank_ids, real_rank_token_count_cumsum,
max_token_count_per_rank, expert_count, top_k, ep_rank, ep_size)
def mnnvl_moe_alltoallv_prepare(
gathered_target_rank_ids: torch.Tensor,
real_rank_token_count_cumsum: torch.Tensor,
gathered_expert_ids: torch.Tensor,
gathered_scales: torch.Tensor,
max_token_count_per_rank: int,
expert_count: int,
top_k: int,
ep_rank: int,
ep_size: int,
):
(
local_gather_indices,
send_rank_count_cumsum,
send_rank_local_indices,
recv_rank_count_cumsum,
recv_rank_local_indices,
backward_recv_rank_local_indices,
) = torch.ops.trtllm.moe_comm_prepare_indices(
gathered_target_rank_ids,
real_rank_token_count_cumsum,
max_token_count_per_rank,
expert_count,
top_k,
ep_rank,
ep_size,
)
local_token_allocation_count = max_token_count_per_rank * ep_size
local_expert_ids = torch.empty(local_token_allocation_count,
top_k,
dtype=torch.int32,
device=torch.device('cuda'))
local_scales = torch.empty(local_token_allocation_count,
top_k,
dtype=torch.float32,
device=torch.device('cuda'))
local_expert_ids = torch.empty(
local_token_allocation_count, top_k, dtype=torch.int32, device=torch.device("cuda")
)
local_scales = torch.empty(
local_token_allocation_count, top_k, dtype=torch.float32, device=torch.device("cuda")
)
torch.ops.trtllm.moe_local_gather(recv_rank_count_cumsum,
local_gather_indices,
gathered_expert_ids, gathered_scales,
local_expert_ids, local_scales,
max_token_count_per_rank,
expert_count, top_k, ep_rank, ep_size)
torch.ops.trtllm.moe_local_gather(
recv_rank_count_cumsum,
local_gather_indices,
gathered_expert_ids,
gathered_scales,
local_expert_ids,
local_scales,
max_token_count_per_rank,
expert_count,
top_k,
ep_rank,
ep_size,
)
alltoall_info = MoEAlltoallInfo(
local_gather_indices, send_rank_count_cumsum,
send_rank_local_indices, recv_rank_count_cumsum,
recv_rank_local_indices, backward_recv_rank_local_indices,
local_token_allocation_count)
local_gather_indices,
send_rank_count_cumsum,
send_rank_local_indices,
recv_rank_count_cumsum,
recv_rank_local_indices,
backward_recv_rank_local_indices,
local_token_allocation_count,
)
return alltoall_info, local_expert_ids, local_scales
@staticmethod
def mnnvl_moe_alltoallv(x: torch.Tensor, alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor, ep_rank: int,
ep_size: int):
def mnnvl_moe_alltoallv(
x: torch.Tensor,
alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor,
ep_rank: int,
ep_size: int,
):
assert x.dim() == 2, "only 2D tensor supported, please reshape."
output_tensor = torch.empty(alltoall_info.local_token_allocation_count,
x.shape[1],
dtype=x.dtype,
device=torch.device('cuda'))
torch.ops.trtllm.moe_comm(x, alltoall_info.send_rank_count_cumsum,
alltoall_info.send_rank_local_indices,
output_tensor,
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
workspace, ep_rank, ep_size)
output_tensor = torch.empty(
alltoall_info.local_token_allocation_count,
x.shape[1],
dtype=x.dtype,
device=torch.device("cuda"),
)
torch.ops.trtllm.moe_comm(
x,
alltoall_info.send_rank_count_cumsum,
alltoall_info.send_rank_local_indices,
output_tensor,
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
workspace,
ep_rank,
ep_size,
)
return output_tensor
@staticmethod
def mnnvl_moe_alltoallv_combine(x: torch.Tensor,
alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor, ep_rank: int,
ep_size: int, top_k: int, token_count: int):
def mnnvl_moe_alltoallv_combine(
x: torch.Tensor,
alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor,
ep_rank: int,
ep_size: int,
top_k: int,
token_count: int,
):
assert x.dim() == 2, "2D tensor supported, please reshape."
output_tensor = torch.zeros(token_count * top_k,
x.shape[1],
dtype=x.dtype,
device=torch.device('cuda'))
output_tensor = torch.zeros(
token_count * top_k, x.shape[1], dtype=x.dtype, device=torch.device("cuda")
)
torch.ops.trtllm.moe_comm(
x, alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices, output_tensor,
x,
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
output_tensor,
alltoall_info.send_rank_count_cumsum,
alltoall_info.backward_recv_rank_local_indices, workspace, ep_rank,
ep_size)
return torch.sum(output_tensor.reshape(token_count, top_k, x.shape[1]),
dim=1,
keepdim=False)
alltoall_info.backward_recv_rank_local_indices,
workspace,
ep_rank,
ep_size,
)
return torch.sum(
output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False
)
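
The MNNVL path above only applies to multi-node NVLink systems (`supports_mnnvl` requires aarch64 plus fully active NVLinks). Purely as a sketch of the intended call sequence under those assumptions:

```python
# Call-sequence sketch; requires a fabric-capable multi-node NVLink system,
# an MPI launch, and a CUDA device per rank. Values are illustrative.
import torch
from tensorrt_llm._mnnvl_utils import MnnvlMemory
from tensorrt_llm.mapping import Mapping

mapping = Mapping(world_size=4, tp_size=4, rank=0)  # illustrative

MnnvlMemory.initialize()                 # triggers CUDA context + NVML init
if MnnvlMemory.supports_mnnvl():
    mem = MnnvlMemory(mapping, 1 << 20)  # 1 MiB per rank, rounded up to the fabric page size
    # One row per rank in the communicator, viewed via the strided DLPack helper above.
    workspace = mem.as_torch_strided_tensor(torch.uint64)
    print(workspace.shape)
```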

tensorrt_llm/disaggregated_params.py

@ -6,8 +6,7 @@ from tensorrt_llm.bindings import executor as tllme
@dataclass(slots=True, kw_only=True)
class DisaggregatedParams:
"""
Disaggregated serving parameters
"""Disaggregated serving parameters.
Args:
request_type (str): The type of request ("context_only" or "generation_only")
@ -23,10 +22,9 @@ class DisaggregatedParams:
draft_tokens: Optional[List[int]] = None
def get_context_phase_params(self) -> tllme.ContextPhaseParams:
return tllme.ContextPhaseParams(self.first_gen_tokens,
self.ctx_request_id, self.opaque_state,
self.draft_tokens)
return tllme.ContextPhaseParams(
self.first_gen_tokens, self.ctx_request_id, self.opaque_state, self.draft_tokens
)
def get_request_type(self) -> tllme.RequestType:
if self.request_type == "context_only":
@ -37,5 +35,6 @@ class DisaggregatedParams:
return tllme.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
else:
raise ValueError(
f"Unknown request type: {self.request_type}. Must be context_only, generation_only or context_and_generation"
f"Unknown request type: {self.request_type}. Must be context_only, generation_only or "
"context_and_generation"
)
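
For reference, the dataclass above is used roughly like this on the context side of a disaggregated deployment (field values are invented; the opaque state normally comes from the context executor):

```python
# Invented values for illustration only.
from tensorrt_llm.disaggregated_params import DisaggregatedParams

params = DisaggregatedParams(
    request_type="context_only",
    first_gen_tokens=[101],
    ctx_request_id=42,
    opaque_state=b"\x00",
    draft_tokens=None,
)
print(params.get_request_type())  # maps "context_only" onto the executor request-type enum
# params.get_context_phase_params() bundles these fields for the generation server.
```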

tensorrt_llm/graph_rewriting.py

@ -3,8 +3,7 @@ import weakref
from copy import copy
from dataclasses import dataclass, field
from functools import wraps
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple,
TypeVar)
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, TypeVar
import tensorrt as trt
@ -14,9 +13,7 @@ from .network import Network
class Layer:
'''
Layer is a wrapper for TensorRT's ILayer with several python-friendly helper functions.
'''
"""Layer is a wrapper for TensorRT's ILayer with several python-friendly helper functions."""
def __init__(self, network: Network, trt_layer: trt.ILayer):
self._network = weakref.ref(network)
@ -30,51 +27,50 @@ class Layer:
return self._network()
def get_inputs(self, *indices: int):
'''
Get the input tensors of the layer.
"""Get the input tensors of the layer.
Parameters:
idx: the indices of the input tensor, will return all inputs if left empty
Returns:
List[Tensor]
'''
"""
from .functional import Tensor
indices = indices if indices else range(self.trt_layer.num_inputs)
ret = []
for i in indices:
assert i < self.trt_layer.num_inputs, f"Invalid input index {i} for layer {self.trt_layer.name}"
assert i < self.trt_layer.num_inputs, (
f"Invalid input index {i} for layer {self.trt_layer.name}"
)
tensor = self.trt_layer.get_input(i)
tensor = Tensor(trt_tensor=tensor,
network=self.network,
is_network_input=False)
tensor = Tensor(trt_tensor=tensor, network=self.network, is_network_input=False)
ret.append(tensor)
return ret
def get_outputs(self, *indices: int):
'''
Get the output tensor of the layer.
"""Get the output tensor of the layer.
Parameters:
idx: the index of the output tensor
Returns:
List[Tensor]
'''
"""
from .functional import Tensor
indices = indices if indices else range(self.trt_layer.num_outputs)
ret = []
for i in indices:
assert i < self.trt_layer.num_outputs, f"Invalid output index {i} for layer {self.trt_layer.name}"
assert i < self.trt_layer.num_outputs, (
f"Invalid output index {i} for layer {self.trt_layer.name}"
)
tensor = self.trt_layer.get_output(i)
tensor = Tensor(trt_tensor=tensor,
network=self.network,
is_network_input=False)
tensor = Tensor(trt_tensor=tensor, network=self.network, is_network_input=False)
ret.append(tensor)
return ret
@ -82,10 +78,9 @@ class Layer:
return self.network.is_removed_layer(self)
def mark_as_removed(self):
'''
Mark the layer as removed, this will remove the layer from the network.
'''
# NOTE, since INetwork python API doesn't provide a way to remove a layer, we actually mark the layer as removed in the network.
"""Mark the layer as removed, this will remove the layer from the network."""
# NOTE, since INetwork python API doesn't provide a way to remove a layer, we actually mark
# the layer as removed in the network.
self.network.mark_removed_layer(self)
# remove the FLayerInfo if exists
@ -101,8 +96,8 @@ class Layer:
def __getattr__(self, name: str) -> Any:
return getattr(self.trt_layer, name)
# Refer to https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html?highlight=elementwise#layers for a complete
# list of TRT layers.
# Refer to https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html?highlight=elementwise#layers
# for a complete list of TRT layers.
TRT_LAYER_TYPE_TO_LAYER = {
trt.LayerType.CONVOLUTION: trt.IConvolutionLayer,
trt.LayerType.ACTIVATION: trt.IActivationLayer,
@ -153,14 +148,13 @@ class Layer:
}
if trt_gte(10, 8):
TRT_LAYER_TYPE_TO_LAYER[
trt.LayerType.DYNAMIC_QUANTIZE] = trt.IQuantizeLayer
TRT_LAYER_TYPE_TO_LAYER[trt.LayerType.DYNAMIC_QUANTIZE] = trt.IQuantizeLayer
def as_layer(self) -> Any:
'''
Convert to a actual TensorRT layer object, such as IPluginV2Layer or IConvolutionLayer so
that we can access the actual layer information.
'''
"""Convert to a actual TensorRT layer object.
This can be IPluginV2Layer or IConvolutionLayer, so that we can access the actual layer information.
"""
if self.type in self.TRT_LAYER_TYPE_TO_LAYER:
# bypass TRT's bug of retrieving a specific ILayer type in TensorRT
self.trt_layer.__class__ = self.TRT_LAYER_TYPE_TO_LAYER[self.type]
@ -188,22 +182,27 @@ class _Pattern:
class PatternRewriter(_Pattern):
'''
A pattern rewriter is a class that can match a pattern in the graph and rewrite the matched pattern.
"""A pattern rewriter is a class that can match a pattern in the graph and rewrite the matched pattern.
There are two ways to implement a pattern rewriter, either override match() and rewrite() separately, or override match_and_rewrite().
'''
There are two ways to implement a pattern rewriter, either override match() and rewrite() separately, or
override match_and_rewrite().
"""
def __init__(
self,
name: str,
root_layer: Optional[Set[trt.LayerType]] = None,
separate_match_rewrite=False,
):
"""Constructor.
def __init__(self,
name: str,
root_layer: Optional[Set[trt.LayerType]] = None,
separate_match_rewrite=False):
'''
Parameters:
name: the name of the rewrite pattern
root_layer: the root layer types to start the pattern matching, if not provided, the pattern will traverse all the layers in the graph.
separate_match_rewrite: if set to True, the pattern should override match() and rewrite() separately, otherwise, the pattern should override match_and_rewrite()
'''
root_layer: the root layer types to start the pattern matching, if not provided, the pattern
will traverse all the layers in the graph.
separate_match_rewrite: if set to True, the pattern should override match() and rewrite()
separately, otherwise, the pattern should override match_and_rewrite()
"""
super().__init__(name)
self.root_layer = root_layer
self._separate_match_rewrite = separate_match_rewrite
@ -219,9 +218,7 @@ class PatternRewriter(_Pattern):
class PatternAnalyzer(_Pattern):
def __init__(self, name: str,
root_layer: Optional[Set[trt.LayerType]]) -> None:
def __init__(self, name: str, root_layer: Optional[Set[trt.LayerType]]) -> None:
super().__init__(name)
self.root_layer = root_layer
@ -233,16 +230,13 @@ class PatternAnalyzer(_Pattern):
class _PatternManager:
PatternType = TypeVar('PatternType')
PatternType = TypeVar("PatternType")
def __init__(self):
# records of (benefit, pattern, id)
self.patterns: Dict[str, Tuple[int, _PatternManager.PatternType]] = {}
def add(self,
label: str,
pattern: "_PatternManager.PatternType",
benefit: int = 0):
def add(self, label: str, pattern: "_PatternManager.PatternType", benefit: int = 0):
assert label not in self.patterns, f"Pattern {label} already exists"
self.patterns[label] = (benefit, pattern)
@ -251,18 +245,18 @@ class _PatternManager:
class RewritePatternManager(_PatternManager):
def rewrite(self, net: Network, args=None):
modified = True
# TODO: we can optimize this by asking TRT to expose a graph iterator consistent even after the graph is modified
# TODO: we can optimize this by asking TRT to expose a graph iterator consistent even after
# the graph is modified.
while modified:
modified = False
# Since the graph iterator is hold by the underlying INetwork, we can only rebuild the graph cache and match the nodes again.
# Since the graph iterator is hold by the underlying INetwork, we can only rebuild the
# graph cache and match the nodes again.
for layer in net.get_layers():
if layer.is_removed():
continue
for (profit, pattern) in sorted(self.patterns.values(),
key=lambda x: x[0]):
for profit, pattern in sorted(self.patterns.values(), key=lambda x: x[0]):
pattern.args = args
if pattern.root_layer is not None and layer.type not in pattern.root_layer:
@ -281,13 +275,11 @@ class RewritePatternManager(_PatternManager):
class AnalysisPatternManager(_PatternManager):
def analyze(self, graph: Network, args=None):
for layer in graph.get_layers():
if layer.name in graph.removed_layers:
continue
for (benefit, pattern) in sorted(self.patterns.values(),
key=lambda x: x[0]):
for benefit, pattern in sorted(self.patterns.values(), key=lambda x: x[0]):
pattern.args = args
if pattern.root_layer is not None and layer.type not in pattern.root_layer:
@ -303,22 +295,25 @@ class AnalysisPatternManager(_PatternManager):
@dataclass
class FLayerInfo:
'''
The FLayerInfo is used to track the functional layers in the INetwork, and it is used to help the graph rewriting.
"""The FLayerInfo is used to track the functional layers in the INetwork and help graph rewriting.
The lifetime of a FLayer is the same as the corresponding plugin instance in the INetwork. Once the
plugin instance is removed by the graph rewriting, the FLayer will be removed as well.
WHY this is needed?
In the current implementation, for functional methods, once it is called in Python, it will lower to a plugin instance in the INetwork. However,
the plugin interface is black box with customized logic, we cannot retrieve necessary information from it, this is quite different from ILayer,
which provides a set of APIs to retrieve the information. Therefore, we need to record the high level information in the FLayerInfo, and keep
In the current implementation, for functional methods, once it is called in Python, it will lower
to a plugin instance in the INetwork.
However, the plugin interface is black box with customized logic, we cannot retrieve necessary
information from it. This is quite different from ILayer, which provides a set of APIs to retrieve
the information.
Therefore, we need to record the high level information in the FLayerInfo, and keep
it consistent during the graph rewriting.
'''
"""
layer_kind: str # the method name in the functional.py
# Record the raw inputs of the functional layer to be used in the graph rewrite
# NOTE: the raw inputs contains both the constants and Tensors, the Tensors will be also updated by graph rewriting
# APIs such as `replace_all_uses_with`
# NOTE: the raw inputs contains both the constants and Tensors, the Tensors will be also updated by
# graph rewriting APIs such as `replace_all_uses_with`
raw_inputs: Dict[str, Any]
raw_outputs: List[Any] = field(default_factory=list, init=False)
@ -331,6 +326,7 @@ class FLayerInfo:
def __post_init__(self):
from .functional import Tensor
assert self.layer_kind
def replace_with_symbols(arg) -> Any:
@ -357,10 +353,10 @@ class FLayerInfo:
return arg
self.signature = self.layer_kind, {
name: replace_with_symbols(value)
for name, value in self.raw_inputs.items()
}
self.signature = (
self.layer_kind,
{name: replace_with_symbols(value) for name, value in self.raw_inputs.items()},
)
amend_tensor(self.raw_inputs)
@ -371,18 +367,15 @@ class FLayerInfo:
return self.raw_inputs[name]
def clone_inputs(self):
'''
Get a shallow copy of the inputs.
'''
"""Get a shallow copy of the inputs."""
return copy(self.raw_inputs)
def replace_input_with(self, src, dst):
'''
Replace the input `src` with the input `dst` in the raw_inputs.
"""Replace the input `src` with the input `dst` in the raw_inputs.
src: Tensor
dst: Tensor
'''
"""
from .functional import Tensor
def replace(arg: Any):
@ -399,21 +392,22 @@ class FLayerInfo:
replace(self.raw_inputs)
def replace_outputs_uses_with(self, net: Network, new_outs: List[Any]):
'''
Replace the output users with the new outputs.
"""Replace the output users with the new outputs.
new_outs: List[Tensor], the new outputs to replace with
'''
"""
from .functional import Tensor
assert len(self.raw_outputs) == len(new_outs)
for old_out, new_out in zip(self.raw_outputs, new_outs):
assert type(old_out) == type(
new_out
), f"rewrite error, the output type {type(old_out)} is different from the new output type {type(new_out)} not match the original output type {type(old_out)}"
assert type(old_out) is type(new_out), (
f"rewrite error, the output type {type(old_out)} is different from the new output "
f"type {type(new_out)} not match the original output type {type(old_out)}"
)
def _swap_tensor_info(new, deprecated):
name = deprecated.trt_tensor.name
deprecated.trt_tensor.name = name + '_deprecated'
deprecated.trt_tensor.name = name + "_deprecated"
from .functional import cast
new = cast(new, deprecated.dtype)
@ -463,13 +457,11 @@ class FLayerInfo:
return hash(self.signature)
def __repr__(self) -> str:
return '<FLayer {}>'.format(self.signature)
return "<FLayer {}>".format(self.signature)
@staticmethod
def _get_spec(arg):
'''
Get the spec that could impact on the Module's topology in the `forward` method.
'''
"""Get the spec that could impact on the Module's topology in the `forward` method."""
from .functional import Tensor
# For scalars, we track their value since they are constant
@ -482,17 +474,17 @@ class FLayerInfo:
return Tensor
elif isinstance(arg, (list, tuple)):
return [FLayerInfo._get_spec(x) for x in arg]
# NOTE Free to add more types here is broken, carefully note that, from the engine building angle, all the constants
# should be captured while for the network variables, their types as placeholders are enough
# NOTE Free to add more types here is broken, carefully note that, from the engine building angle,
# all the constants should be captured while for the network variables, their types as placeholders
# are enough.
else:
raise TypeError(f"unsupported input type detected: {type(arg)}")
@dataclass
class FLayerInfoMemo:
'''
FLayerInfoMemo holds the FLayer of all the necessary functional layers.
'''
"""FLayerInfoMemo holds the FLayer of all the necessary functional layers."""
data: Dict[str, FLayerInfo] = field(default_factory=dict, init=False)
cur_flayer: ClassVar[Optional[FLayerInfo]] = None
@ -502,11 +494,8 @@ class FLayerInfoMemo:
self.data[layer_name] = layer
def create(self, fn: Callable, *args, **kwargs) -> FLayerInfo:
'''
Add a FLayer to the memo.
'''
return FLayerInfo(fn.__name__,
self.get_function_arg_dict(fn, *args, **kwargs))
"""Add a FLayer to the memo."""
return FLayerInfo(fn.__name__, self.get_function_arg_dict(fn, *args, **kwargs))
def get(self, layer_name: str) -> Optional[FLayerInfo]:
return self.data.get(layer_name, None)
@ -517,29 +506,24 @@ class FLayerInfoMemo:
@staticmethod
def instance() -> "FLayerInfoMemo":
'''
A singleton instance of FLayerMemo.
'''
"""A singleton instance of FLayerMemo."""
from ._common import default_net
return default_net().flayer_memo
@staticmethod
def get_function_arg_dict(f: Callable, *args, **kwargs):
'''
Get the input argument dict of a function.
'''
"""Get the input argument dict of a function."""
sig = inspect.signature(f)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
return {k: v for k, v in bound_args.arguments.items() if k != 'self'}
return {k: v for k, v in bound_args.arguments.items() if k != "self"}
class FLayerScope:
'''
FLayerScope is used to capture the plugin within a functional method.
'''
"""FLayerScope is used to capture the plugin within a functional method."""
def __init__(self, fn, *args, **kwargs):
self.layer = FLayerInfoMemo.instance().create(fn, *args, **kwargs)
@ -552,14 +536,14 @@ class FLayerScope:
def __exit__(self, exc_type, exc_val, exc_tb):
FLayerInfoMemo.cur_flayer = None
if exc_type is None:
assert self.layer.layer_name != "", f"FLayer {self.layer.layer_kind} without a plugin name detected"
assert self.layer.layer_name != "", (
f"FLayer {self.layer.layer_kind} without a plugin name detected"
)
FLayerInfoMemo.instance().add(self.layer.layer_name, self.layer)
def record_signature(f):
'''
Helps to decorate a functional method and record its metadata with a FLayerInfo.
'''
"""Helps to decorate a functional method and record its metadata with a FLayerInfo."""
@wraps(f)
def wrapper(*args, **kwargs):
@ -577,10 +561,8 @@ _global_analysis_pattern_manager = AnalysisPatternManager()
class FuseAttentionWithBiasPass(PatternRewriter):
def __init__(self):
super().__init__(name="fuse_attention_with_bias",
separate_match_rewrite=False)
super().__init__(name="fuse_attention_with_bias", separate_match_rewrite=False)
@staticmethod
def is_attention_plugin(layer: Layer) -> bool:
@ -588,14 +570,15 @@ class FuseAttentionWithBiasPass(PatternRewriter):
return False
p = layer.as_layer().plugin
conds = [
p.plugin_namespace == 'tensorrt_llm',
p.plugin_type == 'GPTAttention', p.num_outputs == 2
p.plugin_namespace == "tensorrt_llm",
p.plugin_type == "GPTAttention",
p.num_outputs == 2,
]
return all(conds)
@staticmethod
def is_elementwise_sum(layer: Layer) -> bool:
l = layer.as_layer()
l = layer.as_layer() # noqa: E741
if l.type != trt.LayerType.ELEMENTWISE:
return False
return l.op == trt.ElementWiseOperation.SUM
@ -612,11 +595,11 @@ class FuseAttentionWithBiasPass(PatternRewriter):
layer = tensor.get_parent()
if layer is None or depth > max_depth:
return False
if layer.type == trt.LayerType.CONSTANT and len(
layer.get_inputs()) == 0:
if layer.type == trt.LayerType.CONSTANT and len(layer.get_inputs()) == 0:
return True
for _ in layer.get_inputs():
if not const_foldable(_, depth + 1): return False
if not const_foldable(_, depth + 1):
return False
return True
for input in layer.get_inputs():
@ -628,27 +611,26 @@ class FuseAttentionWithBiasPass(PatternRewriter):
def match_and_rewrite(self, layer: Layer) -> bool:
from tensorrt_llm.network import net_guard
with net_guard(layer.network):
if not self.is_attention_plugin(layer):
return False
plugin_flayer = FLayerInfoMemo.instance().get(layer.name)
input = plugin_flayer.raw_inputs['qkv']
if input is None or isinstance(input, list) or len(
list(input.get_users())) != 1:
input = plugin_flayer.raw_inputs["qkv"]
if input is None or isinstance(input, list) or len(list(input.get_users())) != 1:
return False
parent_layer = input.get_parent()
if not self.is_elementwise_sum(parent_layer):
return False
eltwise_const_inputs, eltwise_mutable_inputs = self.get_eltwise_inputs(
parent_layer)
if len(eltwise_const_inputs) != 1 or len(
eltwise_mutable_inputs) != 1:
eltwise_const_inputs, eltwise_mutable_inputs = self.get_eltwise_inputs(parent_layer)
if len(eltwise_const_inputs) != 1 or len(eltwise_mutable_inputs) != 1:
return False
if plugin_flayer.raw_inputs['qkv_bias'] is not None:
if plugin_flayer.raw_inputs["qkv_bias"] is not None:
return False
plugin_flayer.raw_inputs['qkv'] = eltwise_mutable_inputs[0]
plugin_flayer.raw_inputs['qkv_bias'] = eltwise_const_inputs[0]
plugin_flayer.raw_inputs["qkv"] = eltwise_mutable_inputs[0]
plugin_flayer.raw_inputs["qkv_bias"] = eltwise_const_inputs[0]
from .functional import gpt_attention
new_outputs = gpt_attention(**plugin_flayer.raw_inputs)
plugin_flayer.replace_outputs_uses_with(layer.network, new_outputs)
return True
@ -657,7 +639,7 @@ class FuseAttentionWithBiasPass(PatternRewriter):
def optimize(net):
patterns = RewritePatternManager()
patterns.add(
label='fuse_attention_with_bias',
label="fuse_attention_with_bias",
pattern=FuseAttentionWithBiasPass(),
)
patterns.rewrite(net)
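
The `PatternRewriter` docstring above describes two ways to write a pass; `FuseAttentionWithBiasPass` uses the combined `match_and_rewrite()` form. A toy pass (hypothetical, it never modifies the graph) registered the same way `optimize()` does:

```python
# Toy pass for illustration: visits every layer and reports "no modification",
# so RewritePatternManager.rewrite() finishes after a single sweep.
from tensorrt_llm.graph_rewriting import PatternRewriter, RewritePatternManager

class CountLayersPass(PatternRewriter):
    def __init__(self):
        super().__init__(name="count_layers", separate_match_rewrite=False)
        self.visited = 0

    def match_and_rewrite(self, layer) -> bool:
        self.visited += 1
        return False  # True would mean "graph modified, rescan from the top"

patterns = RewritePatternManager()
patterns.add(label="count_layers", pattern=CountLayersPass(), benefit=0)
# patterns.rewrite(net)  # `net` is a tensorrt_llm.network.Network
```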

tensorrt_llm/logger.py

@ -30,23 +30,21 @@ class Singleton(type):
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton,
cls).__call__(*args, **kwargs)
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class Logger(metaclass=Singleton):
ENV_VARIABLE = "TLLM_LOG_LEVEL"
PREFIX = "TRT-LLM"
DEFAULT_LEVEL = "error"
ENV_VARIABLE = 'TLLM_LOG_LEVEL'
PREFIX = 'TRT-LLM'
DEFAULT_LEVEL = 'error'
INTERNAL_ERROR = '[F]'
ERROR = '[E]'
WARNING = '[W]'
INFO = '[I]'
VERBOSE = '[V]'
DEBUG = '[D]'
INTERNAL_ERROR = "[F]"
ERROR = "[E]"
WARNING = "[W]"
INFO = "[I]"
VERBOSE = "[V]"
DEBUG = "[D]"
def __init__(self):
environ_severity = os.environ.get(self.ENV_VARIABLE)
@ -54,8 +52,7 @@ class Logger(metaclass=Singleton):
self.rank: Optional[int] = None
min_severity = environ_severity.lower(
) if self._set_from_env else self.DEFAULT_LEVEL
min_severity = environ_severity.lower() if self._set_from_env else self.DEFAULT_LEVEL
invalid_severity = min_severity not in severity_map
if invalid_severity:
min_severity = self.DEFAULT_LEVEL
@ -66,14 +63,13 @@ class Logger(metaclass=Singleton):
self._logger.propagate = False
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(
logging.Formatter(fmt='[%(asctime)s] %(message)s',
datefmt='%m/%d/%Y-%H:%M:%S'))
logging.Formatter(fmt="[%(asctime)s] %(message)s", datefmt="%m/%d/%Y-%H:%M:%S")
)
self._logger.addHandler(handler)
self._logger.setLevel(severity_map[min_severity][1])
self._polygraphy_logger = G_LOGGER
if self._polygraphy_logger is not None:
self._polygraphy_logger.module_severity = severity_map[
min_severity][2]
self._polygraphy_logger.module_severity = severity_map[min_severity][2]
# For log_once
self._appeared_keys = set()
@ -98,16 +94,16 @@ class Logger(metaclass=Singleton):
elif severity == self.VERBOSE or severity == self.DEBUG:
return self._logger.debug
else:
raise AttributeError(f'No such severity: {severity}')
raise AttributeError(f"No such severity: {severity}")
@property
def trt_logger(self) -> trt.ILogger:
return self._trt_logger
def log(self, severity, *msg):
parts = [f'[{self.PREFIX}]']
parts = [f"[{self.PREFIX}]"]
if self.rank is not None:
parts.append(f'[RANK {self.rank}]')
parts.append(f"[RANK {self.rank}]")
parts.append(severity)
parts.extend(map(str, msg))
self._func_wrapper(severity)(" ".join(parts))
@ -164,29 +160,28 @@ class Logger(metaclass=Singleton):
self._trt_logger.min_severity = severity_map[min_severity][0]
self._logger.setLevel(severity_map[min_severity][1])
if self._polygraphy_logger is not None:
self._polygraphy_logger.module_severity = severity_map[
min_severity][2]
self._polygraphy_logger.module_severity = severity_map[min_severity][2]
severity_map = {
'internal_error': [trt.Logger.INTERNAL_ERROR, logging.CRITICAL],
'error': [trt.Logger.ERROR, logging.ERROR],
'warning': [trt.Logger.WARNING, logging.WARNING],
'info': [trt.Logger.INFO, logging.INFO],
'verbose': [trt.Logger.VERBOSE, logging.DEBUG],
'debug': [trt.Logger.VERBOSE, logging.DEBUG],
'trace': [trt.Logger.VERBOSE, logging.DEBUG],
"internal_error": [trt.Logger.INTERNAL_ERROR, logging.CRITICAL],
"error": [trt.Logger.ERROR, logging.ERROR],
"warning": [trt.Logger.WARNING, logging.WARNING],
"info": [trt.Logger.INFO, logging.INFO],
"verbose": [trt.Logger.VERBOSE, logging.DEBUG],
"debug": [trt.Logger.VERBOSE, logging.DEBUG],
"trace": [trt.Logger.VERBOSE, logging.DEBUG],
}
if G_LOGGER is not None:
g_logger_severity_map = {
'internal_error': G_LOGGER.CRITICAL,
'error': G_LOGGER.ERROR,
'warning': G_LOGGER.WARNING,
'info': G_LOGGER.INFO,
'verbose': G_LOGGER.SUPER_VERBOSE,
'debug': G_LOGGER.SUPER_VERBOSE,
'trace': G_LOGGER.SUPER_VERBOSE,
"internal_error": G_LOGGER.CRITICAL,
"error": G_LOGGER.ERROR,
"warning": G_LOGGER.WARNING,
"info": G_LOGGER.INFO,
"verbose": G_LOGGER.SUPER_VERBOSE,
"debug": G_LOGGER.SUPER_VERBOSE,
"trace": G_LOGGER.SUPER_VERBOSE,
}
for key, value in g_logger_severity_map.items():
severity_map[key].append(value)
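
For reference, an illustrative sketch (not from this commit) of how the singleton logger above is typically driven; the module-level `logger` instance is the one imported elsewhere in this diff, and the exact output prefix is inferred from `log()` and the handler's formatter, so treat it as approximate.

import os

# TLLM_LOG_LEVEL (Logger.ENV_VARIABLE) is read when the singleton is first created,
# so it must be set before the first import of tensorrt_llm.logger.
os.environ["TLLM_LOG_LEVEL"] = "verbose"

from tensorrt_llm.logger import logger  # Singleton metaclass: every import shares one instance

logger.info("engine build started")   # roughly "[<timestamp>] [TRT-LLM] [I] engine build started"
logger.warning("low GPU memory")      # roughly "[<timestamp>] [TRT-LLM] [W] low GPU memory"
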


@ -11,12 +11,10 @@ import numpy as np
import torch
import yaml
from ._utils import (DictConversion, pad_vocab_size, release_gc,
str_dtype_to_torch, torch_to_numpy)
from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
from .layers.linear import ColumnLinear
from .mapping import Mapping
from .models.convert_utils import (get_model_path, load_state_dict,
split_matrix_tp)
from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp
if TYPE_CHECKING:
from .runtime import ModelConfig
@ -25,13 +23,13 @@ if TYPE_CHECKING:
def get_all_nemo_lora_weights(lora_weights):
layer_weights = defaultdict(dict)
adapter_key = "self_attention.adapter_layer.lora_kqv_adapter"
layer_pattern = re.compile(r'.*\.layers\.(\d+)\..*')
layer_pattern = re.compile(r".*\.layers\.(\d+)\..*")
for key, weights in lora_weights.items():
if adapter_key in key:
if key.endswith('linear_in.weight'):
inout = 'in'
elif key.endswith('linear_out.weight'):
inout = 'out'
if key.endswith("linear_in.weight"):
inout = "in"
elif key.endswith("linear_out.weight"):
inout = "out"
else:
continue
m = layer_pattern.match(key)
@ -46,9 +44,9 @@ def get_all_nemo_lora_weights(lora_weights):
return layer_weights
# The pattern is {layer_prefix:1}.{layer_idx:2}.{module_prefix:3}.{module_name or {expert_name:5}.{expert_idx:6}.{module_name:7} :4}.lora_{A|B:8}.weight
# The pattern is {layer_prefix:1}.{layer_idx:2}.{module_prefix:3}.{module_name or {expert_name:5}.{expert_idx:6}.{module_name:7} :4}.lora_{A|B:8}.weight # noqa: E501
HF_LORA_PATTERN = re.compile(
r'(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight)'
r"(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight)"
)
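
To make the capture groups of HF_LORA_PATTERN above concrete, here is a small worked example (not part of the commit) against one of the checkpoint keys listed later in the load_from_hf docstring; the regex is copied verbatim from the hunk above.

import re

HF_LORA_PATTERN = re.compile(
    r"(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight)"
)

m = HF_LORA_PATTERN.match("base_model.model.model.layers.0.mlp.down_proj.lora_B.weight")
# group 2 = layer_idx, group 3 = module_prefix, group 4 = module_name, group 8 = lora A/B side
print(m.group(2), m.group(3), m.group(4), m.group(8))  # -> 0 mlp down_proj B
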
@ -77,7 +75,9 @@ def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None):
hf_module = m.group(3) + "." + module_name
if hf_module not in hf_modules:
hf_module = module_name
assert hf_module in hf_modules, f"hf_module {hf_module} is not in supported list {hf_modules}"
assert hf_module in hf_modules, (
f"hf_module {hf_module} is not in supported list {hf_modules}"
)
is_lora_a_or_b = m.group(8) is not None
if is_lora_a_or_b:
@ -90,13 +90,11 @@ def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None):
all_weights[layer_idx][hf_module][inout_or_mag] = weights
else:
all_weights[layer_idx][hf_module].setdefault(expert_idx, {})
all_weights[layer_idx][hf_module][expert_idx][
inout_or_mag] = weights
all_weights[layer_idx][hf_module][expert_idx][inout_or_mag] = weights
return all_weights
def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):
def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
if expert_idx is None:
all_weights[layer_idx][hf_module][inout] = weights
@ -110,7 +108,6 @@ def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):
def get_hf_target_modules(lora_weights, hf_modules):
def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
hf_target_modules.add(hf_module)
@ -130,11 +127,9 @@ def invert_module_mapping(trtllm_modules_to_hf_modules):
return hf_modules_to_trtllm_modules
def norm_dora_magnitude(W0: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
m: torch.Tensor,
scaling: float = 1.0):
def norm_dora_magnitude(
W0: torch.Tensor, A: torch.Tensor, B: torch.Tensor, m: torch.Tensor, scaling: float = 1.0
):
new_weight_v = W0 + (B @ A) * scaling
norm_m = m.view(-1) / (torch.linalg.norm(new_weight_v, dim=1)).detach()
return norm_m
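
A small shape sketch (not part of the commit) for norm_dora_magnitude above; the tensor sizes are hypothetical and simply follow the usual LoRA convention where B @ A has the same shape as W0, and the function itself is assumed to be in scope from this module.

import torch

out_features, in_features, rank = 8, 16, 4
W0 = torch.randn(out_features, in_features)  # frozen base weight
A = torch.randn(rank, in_features)           # LoRA "in" factor
B = torch.randn(out_features, rank)          # LoRA "out" factor
m = torch.randn(out_features)                # DoRA magnitude, one entry per output row

norm_m = norm_dora_magnitude(W0, A, B, m, scaling=1.0)
print(norm_m.shape)  # torch.Size([8]) -- one normalized magnitude per output row
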
@ -143,7 +138,7 @@ def norm_dora_magnitude(W0: torch.Tensor,
@dataclass
class LoraConfig(DictConversion):
lora_dir: List[str] = field(default_factory=list)
lora_ckpt_source: str = 'hf'
lora_ckpt_source: str = "hf"
max_lora_rank: int = 64
lora_target_modules: List[str] = field(default_factory=list)
trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
@ -151,9 +146,9 @@ class LoraConfig(DictConversion):
max_cpu_loras: int = 4
def __post_init__(self):
assert self.lora_ckpt_source in [
'hf', 'nemo'
], f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
assert self.lora_ckpt_source in ["hf", "nemo"], (
f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
)
@property
def missing_qkv_modules(self) -> List[str]:
@ -169,7 +164,6 @@ class LoraModelConfig:
class HfLoraLoader:
def __init__(self, lora_dirs: List[str]):
self.lora_target_modules = []
self.is_valid = False
@ -183,8 +177,7 @@ class HfLoraLoader:
for lora_dir in lora_dirs:
model_path = get_model_path(lora_dir, "adapter_model")
if model_path is None:
raise ValueError(
f"adapter_model file does not exist in {lora_dir}")
raise ValueError(f"adapter_model file does not exist in {lora_dir}")
config_file = Path(f"{lora_dir}/adapter_config.json")
if not config_file.exists():
raise ValueError(f"{config_file} does not exist")
@ -207,12 +200,10 @@ class HfLoraLoader:
self.vocab_size = self.lm_head.shape[0]
if "embed_tokens" in adapter_config["modules_to_save"]:
self.embed_tokens = lora_weight[
"base_model.model.model.embed_tokens.weight"]
self.embed_tokens = lora_weight["base_model.model.model.embed_tokens.weight"]
def get_target_modules(self, trtllm_modules_to_hf_modules):
hf_modules_to_trtllm_modules = invert_module_mapping(
trtllm_modules_to_hf_modules)
hf_modules_to_trtllm_modules = invert_module_mapping(trtllm_modules_to_hf_modules)
lora_target_modules = set()
if self.is_valid:
hf_target_modules = get_hf_target_modules(
@ -226,7 +217,6 @@ class HfLoraLoader:
class NemoLoraLoader:
def __init__(self, lora_dirs: List[str]):
self.lora_target_modules = []
self.is_valid = False
@ -269,21 +259,21 @@ def get_default_trtllm_modules_to_hf_modules():
def load_torch_hf_lora(lora_config: LoraConfig):
"""
This is a shortened version of load_hf_lora that is used for torch models.
"""This is a shortened version of load_hf_lora that is used for torch models.
The main problem is that model.config in the legacy code is custom (defined in the legacy code),
whereas the pivot model config is the transformers one.
"""
# TODO smor- need to combine with load_hf_lora
lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules(
)
lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules()
assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir"
lora_loader = HfLoraLoader(lora_config.lora_dir)
if len(lora_config.lora_target_modules) == 0:
lora_config.lora_target_modules = lora_loader.get_target_modules(
lora_config.trtllm_modules_to_hf_modules)
lora_config.trtllm_modules_to_hf_modules
)
if len(lora_config.lora_target_modules) == 0:
raise ValueError(
@ -291,8 +281,7 @@ def load_torch_hf_lora(lora_config: LoraConfig):
"Please specify lora_target_modules or provide lora_dir to infer lora_target_modules."
)
missing_qkv_modules = LoraManager.get_missing_qkv_modules(
lora_config.lora_target_modules)
missing_qkv_modules = LoraManager.get_missing_qkv_modules(lora_config.lora_target_modules)
lora_config.lora_target_modules.extend(missing_qkv_modules)
@ -301,7 +290,8 @@ def load_hf_lora(
lora_config: LoraConfig,
trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules or get_default_trtllm_modules_to_hf_modules(
trtllm_modules_to_hf_modules = (
trtllm_modules_to_hf_modules or get_default_trtllm_modules_to_hf_modules()
)
lora_config.trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules
@ -309,15 +299,15 @@ def load_hf_lora(
if len(lora_config.lora_target_modules) == 0:
lora_config.lora_target_modules = lora_loader.get_target_modules(
trtllm_modules_to_hf_modules)
trtllm_modules_to_hf_modules
)
if len(lora_config.lora_target_modules) == 0:
raise ValueError(
"lora_target_modules is empty. "
"Please specify lora_target_modules or provide lora_dir to infer lora_target_modules."
)
missing_qkv_modules = LoraManager.get_missing_qkv_modules(
lora_config.lora_target_modules)
missing_qkv_modules = LoraManager.get_missing_qkv_modules(lora_config.lora_target_modules)
lora_config.lora_target_modules.extend(missing_qkv_modules)
if lora_loader.is_valid:
@ -341,15 +331,12 @@ def load_hf_lora(
num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
dtype=config.dtype,
tp_size=mapping.tp_size
if config.use_parallel_embedding else 1,
tp_group=mapping.tp_group
if config.use_parallel_embedding else None,
tp_size=mapping.tp_size if config.use_parallel_embedding else 1,
tp_group=mapping.tp_group if config.use_parallel_embedding else None,
sharding_dim=config.embedding_sharding_dim,
tp_rank=mapping.tp_rank,
)
model.transformer.vocab_embedding.weight.value = weight.to(
torch_dtype)
model.transformer.vocab_embedding.weight.value = weight.to(torch_dtype)
if mapping.is_last_pp_rank() and lora_loader.lm_head is not None:
weight = lora_loader.lm_head
vocab_size = lora_loader.vocab_size
@ -359,9 +346,13 @@ def load_hf_lora(
pad_width = vocab_size_padded - vocab_size
weight = torch.from_numpy(
np.pad(torch_to_numpy(weight), ((0, pad_width), (0, 0)),
'constant',
constant_values=0))
np.pad(
torch_to_numpy(weight),
((0, pad_width), (0, 0)),
"constant",
constant_values=0,
)
)
else:
vocab_size_padded = vocab_size
if model.lm_head.weight.raw_value.shape != weight.shape:
@ -392,8 +383,7 @@ def use_lora(
elif lora_config.lora_ckpt_source == "hf":
load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
else:
raise ValueError(
f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
def unpack_nemo_weights(nemo_archive_path):
@ -416,8 +406,9 @@ def unpack_nemo_weights(nemo_archive_path):
model_config_dict = yaml.safe_load(model_config_content)
model_weights_bytes = model_weights_file.read()
model_weights_dict = torch.load(io.BytesIO(model_weights_bytes),
map_location=torch.device("cpu"))
model_weights_dict = torch.load(
io.BytesIO(model_weights_bytes), map_location=torch.device("cpu")
)
return model_config_dict, model_weights_dict
@ -446,58 +437,55 @@ class LoraManager(object):
}
def __init__(self):
'''
_lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
{
uid: {
0: {
lora_module: int
}, # layer_0_rank,
1: {
lora_module: int
}, # layer_1_rank,
...
}
}
"""Constructor."""
# _lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
# {
# uid: {
# 0: {
# lora_module: int
# }, # layer_0_rank,
# 1: {
# lora_module: int
# }, # layer_1_rank,
# ...
# }
# }
_lora_weights_pointers_list: dict[str -> dict[int -> dict[str -> [Tensor, Tensor]]]]
{
uid: {
0: {
lora_module: [t_in, t_out]
}, # layer_0,
1: {
lora_module: [t_in, t_out]
}, # layer_1,
...
}
}
# _lora_weights_pointers_list: dict[str -> dict[int -> dict[str -> [Tensor, Tensor]]]]
# {
# uid: {
# 0: {
# lora_module: [t_in, t_out]
# }, # layer_0,
# 1: {
# lora_module: [t_in, t_out]
# }, # layer_1,
# ...
# }
# }
'''
self._lora_uid_counter = 0
self._lora_uid_to_low_ranks: Dict[str, Dict[int, Dict[str, int]]] = {}
# hold the torch tensors and prevent them from being freed
# TODO(enweiz): free device tensors if it's used for c++ runtime only
self._lora_weights: List[torch.Tensor] = []
self._lora_weights_pointers_list: Dict[str, Dict[int,
Dict[str,
List[int]]]] = {}
self._lora_weights_pointers_list: Dict[str, Dict[int, Dict[str, List[int]]]] = {}
self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu
self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu
self.lora_target_modules: List[str] = []
@staticmethod
def get_missing_qkv_modules(lora_target_modules):
# In the current design, q_lora_params, k_lora_params and v_lora_params should all be enabled or all disabled at the same time.
# However, some LoRA checkpoints (e.g. BART) only contain two of them, so we use zero tensors to fill in the missing ones.
# In the current design, q_lora_params, k_lora_params and v_lora_params should all be enabled or
# all disabled at the same time.
# However, some LoRA checkpoints (e.g. BART) only contain two of them, so we use zero tensors
# to fill in the missing ones.
missing_qkv_modules = []
if any(x in lora_target_modules
for x in ["attn_q", "attn_k", "attn_v"]):
if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
for lora_module in ["attn_q", "attn_k", "attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
if any(x in lora_target_modules
for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
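
The comments above are easiest to see with a tiny example (not part of the commit); it assumes this lora_manager module is importable as tensorrt_llm.lora_manager, and the module names are the ones that appear in LORA_MODULE_IDS.

from tensorrt_llm.lora_manager import LoraManager

# Only "attn_q" is present, so its zero-filled companions are reported as missing
# and later appended to lora_target_modules by the callers shown earlier.
print(LoraManager.get_missing_qkv_modules(["attn_q"]))       # -> ['attn_k', 'attn_v']
print(LoraManager.get_missing_qkv_modules(["mlp_4h_to_h"]))  # -> [] (no attention modules involved)
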
@ -507,30 +495,38 @@ class LoraManager(object):
def missing_qkv_modules(self) -> List[str]:
return LoraManager.get_missing_qkv_modules(self.lora_target_modules)
def load_from_ckpt(self,
model_dirs_or_files: List[str],
model_config: Union['ModelConfig', LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
ckpt_source: str = 'hf'):
if ckpt_source == 'hf':
self.load_from_hf(model_dirs=model_dirs_or_files,
model_config=model_config,
runtime_mapping=runtime_mapping,
uids=uids)
elif ckpt_source == 'nemo':
self.load_from_nemo(model_files=model_dirs_or_files,
model_config=model_config,
runtime_mapping=runtime_mapping,
uids=uids)
def load_from_ckpt(
self,
model_dirs_or_files: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
ckpt_source: str = "hf",
):
if ckpt_source == "hf":
self.load_from_hf(
model_dirs=model_dirs_or_files,
model_config=model_config,
runtime_mapping=runtime_mapping,
uids=uids,
)
elif ckpt_source == "nemo":
self.load_from_nemo(
model_files=model_dirs_or_files,
model_config=model_config,
runtime_mapping=runtime_mapping,
uids=uids,
)
else:
assert False, f"{self.__class__.__name__} does not support source {ckpt_source}"
def load_from_nemo(self,
model_files: List[str],
model_config: Union['ModelConfig', LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None):
def load_from_nemo(
self,
model_files: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
):
if runtime_mapping is None:
runtime_mapping = Mapping()
tp_size = runtime_mapping.tp_size
@ -554,11 +550,9 @@ class LoraManager(object):
def load_from_model_file(uid, model_file):
if uid not in self._cpp_lora_weights:
self._cpp_lora_weights[uid] = [
] # Will be converted to tensor later
self._cpp_lora_weights[uid] = [] # Will be converted to tensor later
if uid not in self._cpp_lora_config:
self._cpp_lora_config[uid] = [
] # Will be converted to tensor later
self._cpp_lora_config[uid] = [] # Will be converted to tensor later
_, nemo_weights = unpack_nemo_weights(model_file)
all_lora_weights = get_all_nemo_lora_weights(nemo_weights)
@ -571,72 +565,67 @@ class LoraManager(object):
for lora_module in self.lora_target_modules:
if lora_module != "attn_qkv":
self._lora_uid_to_low_ranks[uid][layer_idx][
lora_module] = 0
self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = 0
continue
if lora_module == "attn_qkv":
t_in = all_lora_weights[layer_idx]["in"]
t_out = all_lora_weights[layer_idx]["out"]
assert t_out.shape[0] % tp_size == 0
t_out = torch.split(t_out,
t_out.shape[0] // tp_size,
dim=0)[tp_rank].contiguous()
t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[
tp_rank
].contiguous()
else:
t_in = None
t_out = None
if t_in is not None and t_out is not None:
t_in = t_in.cuda().to(
str_dtype_to_torch(
model_config.dtype)).contiguous()
t_out = t_out.cuda().to(
str_dtype_to_torch(
model_config.dtype)).contiguous()
t_in = t_in.cuda().to(str_dtype_to_torch(model_config.dtype)).contiguous()
t_out = t_out.cuda().to(str_dtype_to_torch(model_config.dtype)).contiguous()
rank = t_in.shape[0]
self._lora_uid_to_low_ranks[uid][layer_idx][
lora_module] = int(rank)
self._lora_weights_pointers_list[uid][layer_idx][
lora_module] = [
t_in.data_ptr(),
t_out.data_ptr(), 0
]
self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = int(rank)
self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [
t_in.data_ptr(),
t_out.data_ptr(),
0,
]
# prevent torch free this buffer
self._lora_weights.append(t_in)
self._lora_weights.append(t_out)
self._cpp_lora_weights[uid].append(
torch.concatenate(
[t_in.flatten().cpu(),
t_out.flatten().cpu()]))
torch.concatenate([t_in.flatten().cpu(), t_out.flatten().cpu()])
)
self._cpp_lora_config[uid].append(
torch.tensor([
self.LORA_MODULE_IDS[lora_module], layer_idx,
int(rank)
],
dtype=torch.int32))
torch.tensor(
[self.LORA_MODULE_IDS[lora_module], layer_idx, int(rank)],
dtype=torch.int32,
)
)
max_weight_size = max(
w.size(0) for w in self._cpp_lora_weights[uid])
self._cpp_lora_weights[uid] = torch.stack([
torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
for w in self._cpp_lora_weights[uid]
])
self._cpp_lora_config[uid] = torch.stack(
[c for c in self._cpp_lora_config[uid]])
max_weight_size = max(w.size(0) for w in self._cpp_lora_weights[uid])
self._cpp_lora_weights[uid] = torch.stack(
[
torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
for w in self._cpp_lora_weights[uid]
]
)
self._cpp_lora_config[uid] = torch.stack([c for c in self._cpp_lora_config[uid]])
for uid, model_file in zip(new_uids, new_model_files):
load_from_model_file(uid, model_file)
release_gc()
def load_from_hf(self,
model_dirs: List[str],
model_config: Union['ModelConfig', LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
component: Optional[str] = None):
'''
lora config of https://huggingface.co/hfl/chinese-alpaca-2-lora-7b
def load_from_hf(
self,
model_dirs: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
component: Optional[str] = None,
):
"""Lora config of https://huggingface.co/hfl/chinese-alpaca-2-lora-7b.
{
"base_model_name_or_path": "/Llama-2-7b-hf",
"bias": "none",
@ -682,7 +671,7 @@ class LoraManager(object):
base_model.model.model.layers.0.mlp.down_proj.lora_B.weight torch.Size([4096, 64])
...
'''
"""
if runtime_mapping is None:
runtime_mapping = Mapping()
tp_size = runtime_mapping.tp_size
@ -704,13 +693,14 @@ class LoraManager(object):
lora_hf_configs = []
for model_dir in new_model_dirs:
with open(f"{model_dir}/adapter_config.json", 'r') as f:
with open(f"{model_dir}/adapter_config.json", "r") as f:
config = json.load(f)
lora_hf_configs.append(config)
self.lora_target_modules = model_config.lora_target_modules
hf_modules_to_trtllm_modules = invert_module_mapping(
model_config.trtllm_modules_to_hf_modules)
model_config.trtllm_modules_to_hf_modules
)
hf_modules = set(hf_modules_to_trtllm_modules.keys())
def preprocess_lora_weights(lora_model):
@ -727,20 +717,15 @@ class LoraManager(object):
def load_from_model_dir(uid, model_dir, hf_config):
if uid not in self._cpp_lora_weights:
self._cpp_lora_weights[uid] = [
] # Will be converted to tensor later
self._cpp_lora_weights[uid] = [] # Will be converted to tensor later
if uid not in self._cpp_lora_config:
self._cpp_lora_config[uid] = [
] # Will be converted to tensor later
self._cpp_lora_config[uid] = [] # Will be converted to tensor later
lora_model = load_state_dict(
get_model_path(model_dir, "adapter_model"))
lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
if lora_model is None:
raise ValueError(
f"Failed to load adapter_model from {model_dir}")
raise ValueError(f"Failed to load adapter_model from {model_dir}")
lora_model = preprocess_lora_weights(lora_model)
all_weights = get_all_hf_lora_weights(lora_model, hf_modules,
component)
all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
rank = int(hf_config["r"])
rs_lora = bool(hf_config.get("use_rslora", False))
@ -752,8 +737,7 @@ class LoraManager(object):
self._lora_weights_pointers_list[uid][layer_idx] = {}
for lora_module in self.missing_qkv_modules:
hf_module = model_config.trtllm_modules_to_hf_modules[
lora_module]
hf_module = model_config.trtllm_modules_to_hf_modules[lora_module]
if isinstance(hf_module, list):
hf_module = hf_module[0]
layer_weights[hf_module] = {
@ -764,24 +748,26 @@ class LoraManager(object):
for hf_module, module_weights in layer_weights.items():
lora_module = hf_modules_to_trtllm_modules[hf_module]
if lora_module not in self.lora_target_modules:
self._lora_uid_to_low_ranks[uid][layer_idx][
lora_module] = 0
self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = 0
continue
if "in" not in module_weights:
is_moe = True
t_in = torch.stack([
module_weights[expert_idx]["in"]
for expert_idx in sorted(module_weights.keys())
])
t_out = torch.stack([
module_weights[expert_idx]["out"]
for expert_idx in sorted(module_weights.keys())
])
t_in = torch.stack(
[
module_weights[expert_idx]["in"]
for expert_idx in sorted(module_weights.keys())
]
)
t_out = torch.stack(
[
module_weights[expert_idx]["out"]
for expert_idx in sorted(module_weights.keys())
]
)
for weights in module_weights.values():
if "mag" in weights:
# TODO(oargov): this might work, but I had no MoE DoRA models to test
raise ValueError(
"DoRA with MoE is not supported")
raise ValueError("DoRA with MoE is not supported")
t_mag = None
else:
is_moe = False
@ -796,28 +782,28 @@ class LoraManager(object):
elif "moe" in lora_module and runtime_mapping.has_moe_ep():
pass
elif lora_module in [
"attn_dense",
"cross_attn_dense",
"mlp_4h_to_h",
"moe_4h_to_h",
"attn_dense",
"cross_attn_dense",
"mlp_4h_to_h",
"moe_4h_to_h",
]:
# split by row
dim = 2 if is_moe else 1
assert t_in.shape[dim] % tp_size == 0
t_in = torch.split(t_in,
t_in.shape[dim] // tp_size,
dim=dim)[tp_rank].contiguous()
t_in = torch.split(t_in, t_in.shape[dim] // tp_size, dim=dim)[
tp_rank
].contiguous()
else:
# split by column
dim = 1 if is_moe else 0
assert t_out.shape[dim] % tp_size == 0
t_out = torch.split(t_out,
t_out.shape[dim] // tp_size,
dim=dim)[tp_rank].contiguous()
t_out = torch.split(t_out, t_out.shape[dim] // tp_size, dim=dim)[
tp_rank
].contiguous()
if dim == 0 and is_dora and t_mag is not None:
t_mag = torch.split(t_mag,
t_mag.shape[0] // tp_size,
dim=0)[tp_rank].contiguous()
t_mag = torch.split(t_mag, t_mag.shape[0] // tp_size, dim=0)[
tp_rank
].contiguous()
rank_dim = 1 if is_moe else 0
effective_rank = t_in.shape[rank_dim]
@ -828,8 +814,7 @@ class LoraManager(object):
t_mag = t_mag.cuda().contiguous()
if rs_lora:
scale = float(
hf_config["lora_alpha"]) / np.sqrt(effective_rank)
scale = float(hf_config["lora_alpha"]) / np.sqrt(effective_rank)
else:
scale = float(hf_config["lora_alpha"]) / effective_rank
t_out = t_out * scale
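
A quick numeric illustration (not part of the commit) of the two scaling rules above, using hypothetical adapter values lora_alpha=16 and effective_rank=64; the rsLoRA branch divides by the square root of the rank instead of the rank itself.

import numpy as np

lora_alpha, effective_rank = 16, 64
print(lora_alpha / effective_rank)           # 0.25 -> classic LoRA scale applied to t_out
print(lora_alpha / np.sqrt(effective_rank))  # 2.0  -> rsLoRA scale when use_rslora is set
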
@ -838,15 +823,12 @@ class LoraManager(object):
if is_dora and t_mag is not None:
t_mag = t_mag.to(str_dtype_to_torch(model_config.dtype))
self._lora_uid_to_low_ranks[uid][layer_idx][
lora_module] = effective_rank
self._lora_weights_pointers_list[uid][layer_idx][
lora_module] = [
t_in.data_ptr(),
t_out.data_ptr(),
t_mag.data_ptr() if
(is_dora and t_mag is not None) else 0
]
self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = effective_rank
self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [
t_in.data_ptr(),
t_out.data_ptr(),
t_mag.data_ptr() if (is_dora and t_mag is not None) else 0,
]
# prevent torch free this buffer
self._lora_weights.append(t_in)
@ -862,26 +844,24 @@ class LoraManager(object):
t_mag_cpu = t_mag.flatten().cpu()
weights_to_concat.append(t_mag_cpu)
self._cpp_lora_weights[uid].append(
torch.cat(weights_to_concat))
self._cpp_lora_weights[uid].append(torch.cat(weights_to_concat))
self._cpp_lora_config[uid].append(
torch.tensor([
self.LORA_MODULE_IDS[lora_module], layer_idx,
effective_rank, is_dora
],
dtype=torch.int32))
torch.tensor(
[self.LORA_MODULE_IDS[lora_module], layer_idx, effective_rank, is_dora],
dtype=torch.int32,
)
)
max_weight_size = max(
w.size(0) for w in self._cpp_lora_weights[uid])
self._cpp_lora_weights[uid] = torch.stack([
torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
for w in self._cpp_lora_weights[uid]
])
self._cpp_lora_config[uid] = torch.stack(
[c for c in self._cpp_lora_config[uid]])
max_weight_size = max(w.size(0) for w in self._cpp_lora_weights[uid])
self._cpp_lora_weights[uid] = torch.stack(
[
torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
for w in self._cpp_lora_weights[uid]
]
)
self._cpp_lora_config[uid] = torch.stack([c for c in self._cpp_lora_config[uid]])
for uid, model_dir, hf_config in zip(new_uids, new_model_dirs,
lora_hf_configs):
for uid, model_dir, hf_config in zip(new_uids, new_model_dirs, lora_hf_configs):
load_from_model_dir(uid, model_dir, hf_config)
release_gc()
@ -914,10 +894,9 @@ class LoraManager(object):
@property
def num_lora_adapters(self):
return len([uid for uid in self._lora_uid_to_low_ranks if uid != '-1'])
return len([uid for uid in self._lora_uid_to_low_ranks if uid != "-1"])
def save_lora_weights_to_bin(self, out_dir):
def save_val(val, dir, key, tp_num=None, write_npy=False):
ext = "npy" if write_npy else "bin"
suffix = ext if tp_num is None else f"{tp_num}.{ext}"
@ -933,32 +912,21 @@ class LoraManager(object):
else:
assert False
for uid in self.cpp_lora_weights:
if uid == '-1':
if uid == "-1":
continue
all_weights = np.expand_dims(
torch_to_numpy(self.cpp_lora_weights[uid]), 0)
all_configs = np.expand_dims(
torch_to_numpy(self.cpp_lora_config[uid]), 0)
all_weights = np.expand_dims(torch_to_numpy(self.cpp_lora_weights[uid]), 0)
all_configs = np.expand_dims(torch_to_numpy(self.cpp_lora_config[uid]), 0)
uid_path = out_dir_path / f"{uid}"
uid_path.mkdir(parents=True, exist_ok=True)
save_val(all_weights,
uid_path,
"lora_weights",
tp_num=None,
write_npy=True)
save_val(all_configs,
uid_path,
"lora_config",
tp_num=None,
write_npy=True)
save_val(all_weights, uid_path, "lora_weights", tp_num=None, write_npy=True)
save_val(all_configs, uid_path, "lora_config", tp_num=None, write_npy=True)
def input_buffers(self, lora_uids, mapping: Mapping, num_layers: int):
inputs = {}
for layer_idx in mapping.pp_layers(num_layers):
for lora_module in (self.lora_target_modules +
self.missing_qkv_modules):
for lora_module in self.lora_target_modules + self.missing_qkv_modules:
lora_ranks_ = []
lora_ptrs_ = []
for lora_uid in lora_uids:
@ -968,21 +936,21 @@ class LoraManager(object):
if lora_uid != "-1":
low_ranks = self.uid_to_low_ranks(lora_uid)
if (layer_idx in low_ranks
and lora_module in low_ranks[layer_idx].keys()
and low_ranks[layer_idx][lora_module] != 0):
if (
layer_idx in low_ranks
and lora_module in low_ranks[layer_idx].keys()
and low_ranks[layer_idx][lora_module] != 0
):
lora_rank = low_ranks[layer_idx][lora_module]
lora_ptrs = self.lora_weights_pointers_list[
lora_uid][layer_idx][lora_module]
lora_ptrs = self.lora_weights_pointers_list[lora_uid][layer_idx][
lora_module
]
lora_ranks_.append(lora_rank)
lora_ptrs_.append(lora_ptrs)
inputs[
f'{lora_module}_lora_ranks_{layer_idx}'] = torch.IntTensor(
lora_ranks_)
inputs[
f'{lora_module}_lora_weights_pointers_{layer_idx}'] = torch.LongTensor(
lora_ptrs_)
inputs[f"{lora_module}_lora_ranks_{layer_idx}"] = torch.IntTensor(lora_ranks_)
inputs[f"{lora_module}_lora_weights_pointers_{layer_idx}"] = torch.LongTensor(
lora_ptrs_
)
return inputs


@ -19,19 +19,18 @@ from .logger import logger
def _addindent(s_, numSpaces):
s = s_.split('\n')
s = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
s = [(numSpaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
class Module(object):
def __init__(self) -> None:
self._modules = {}
self._parameters = {}
@ -52,20 +51,20 @@ class Module(object):
output = self.forward(*args, **kwargs)
end_layer_idx = current_net.trt_network.num_layers
current_net._module_call_stack.set_layer_range(
self, range(start_layer_idx, end_layer_idx))
self, range(start_layer_idx, end_layer_idx)
)
return output
def __getattr__(self, name):
parameters = self.__dict__.get('_parameters')
parameters = self.__dict__.get("_parameters")
if parameters is not None and name in parameters:
return parameters[name]
modules = self.__dict__.get('_modules')
modules = self.__dict__.get("_modules")
if modules is not None and name in modules:
return modules[name]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
def __setattr__(self, name, value) -> None:
from .parameter import Parameter
@ -81,9 +80,9 @@ class Module(object):
# - keep Parameter and Module attrs in this Module class
# - leave all other attrs in base class
if isinstance(value, Parameter):
self.__dict__.get('_parameters')[name] = value
self.__dict__.get("_parameters")[name] = value
elif isinstance(value, Module):
self.__dict__.get('_modules')[name] = value
self.__dict__.get("_modules")[name] = value
else:
super().__setattr__(name, value)
@ -93,14 +92,14 @@ class Module(object):
# - other types reset and remain in base class
if isinstance(value, Parameter):
super().__delattr__(name)
self.__dict__.get('_parameters')[name] = value
self.__dict__.get("_parameters")[name] = value
elif isinstance(value, Module):
super().__delattr__(name)
self.__dict__.get('_modules')[name] = value
self.__dict__.get("_modules")[name] = value
else:
super().__setattr__(name, value)
def named_modules(self, memo=None, prefix='', remove_duplicate=True):
def named_modules(self, memo=None, prefix="", remove_duplicate=True):
if memo is None:
memo = set()
if self not in memo:
@ -110,16 +109,11 @@ class Module(object):
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in module.named_modules(memo, submodule_prefix,
remove_duplicate):
submodule_prefix = prefix + ("." if prefix else "") + name
for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
yield m
def named_modules_with_parent(self,
memo=None,
prefix='',
parent=None,
remove_duplicate=True):
def named_modules_with_parent(self, memo=None, prefix="", parent=None, remove_duplicate=True):
if memo is None:
memo = set()
if self not in memo:
@ -130,7 +124,7 @@ class Module(object):
if parent:
# Use the up-to-date module from the parent, to allow replacing
# layers while iterating this generator.
module_name = prefix.rsplit('.', 1)[-1]
module_name = prefix.rsplit(".", 1)[-1]
module = getattr(parent, module_name)
if module is None:
return
@ -140,9 +134,10 @@ class Module(object):
for child_name, child_module in module._modules.items():
if child_module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + child_name
submodule_prefix = prefix + ("." if prefix else "") + child_name
for m in child_module.named_modules_with_parent(
memo, submodule_prefix, module, remove_duplicate):
memo, submodule_prefix, module, remove_duplicate
):
yield m
def named_children(self):
@ -152,27 +147,26 @@ class Module(object):
memo.add(module)
yield name, module
def _named_members(self, get_members_fn, prefix='', recurse=True):
def _named_members(self, get_members_fn, prefix="", recurse=True):
memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix,
self)]
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k
name = module_prefix + ("." if module_prefix else "") + k
yield name, v
def parameters(self, recurse=True):
for name, param in self.named_parameters():
yield param
def named_parameters(self, prefix='', recurse=True):
gen = self._named_members(lambda module: module._parameters.items(),
prefix=prefix,
recurse=recurse)
def named_parameters(self, prefix="", recurse=True):
gen = self._named_members(
lambda module: module._parameters.items(), prefix=prefix, recurse=recurse
)
for elem in gen:
yield elem
@ -201,15 +195,15 @@ class Module(object):
def named_network_outputs(self):
for name, module in self.named_modules():
for n, output in module._network_outputs.items():
yield name + ('.' if name else '') + n, output
yield name + ("." if name else "") + n, output
def update_parameters(self, torch_module):
m = {k: v for k, v in self.named_parameters()}
tm = {k: v for k, v in torch_module.named_parameters()}
assert sorted(m.keys()) == sorted(
tm.keys()
), 'The parameter names of the tensorrt-llm module must be the same as those of the torch module'
assert sorted(m.keys()) == sorted(tm.keys()), (
"The parameter names of the tensorrt-llm module must be the same as those of the torch module"
)
for k, v in self.named_parameters():
v.value = tm[k].detach().cpu().numpy()
@ -223,17 +217,16 @@ class Module(object):
for key, module in self._modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
main_str = self._get_name() + '('
child_lines.append("(" + key + "): " + mod_str)
main_str = self._get_name() + "("
if child_lines:
# simple one-liner info, which most builtin Modules will use
main_str += '\n ' + '\n '.join(child_lines) + '\n'
main_str += ')'
main_str += "\n " + "\n ".join(child_lines) + "\n"
main_str += ")"
return main_str
class ModuleList(Module):
def __init__(self, modules) -> None:
super(ModuleList, self).__init__()
offset = len(self)
@ -241,10 +234,10 @@ class ModuleList(Module):
self._modules[str(offset + i)] = module
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules"""
"""Get the absolute index for the list of modules."""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError('index {} is out of range'.format(idx))
raise IndexError("index {} is out of range".format(idx))
if idx < 0:
idx += len(self)
return str(idx)
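
For reference, a small sketch (not part of the commit) of the registration behaviour the __setattr__ / named_modules code above preserves; it assumes Module and ModuleList are importable from tensorrt_llm.module and that named_modules yields (name, module) pairs root-first, as the torch-style code above suggests.

from tensorrt_llm.module import Module, ModuleList

class Block(Module):
    def __init__(self):
        super().__init__()

class Net(Module):
    def __init__(self):
        super().__init__()
        # Assigning a Module lands it in self._modules via __setattr__.
        self.blocks = ModuleList([Block(), Block()])

net = Net()
print([name for name, _ in net.named_modules()])
# -> ['', 'blocks', 'blocks.0', 'blocks.1']  (root first, then children by registration name)
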


@ -36,19 +36,21 @@ from tensorrt_llm.logger import logger
from ._common import _is_building
if psutil is None:
logger.warning("A required package 'psutil' is not installed. Will not "
"monitor the host memory usages. Please install the package "
"first, e.g, 'pip install psutil'.")
logger.warning(
"A required package 'psutil' is not installed. Will not "
"monitor the host memory usages. Please install the package "
"first, e.g, 'pip install psutil'."
)
if pynvml is None:
logger.warning(
"A required package 'pynvml' is not installed. Will not "
"monitor the device memory usages. Please install the package "
"first, e.g, 'pip install nvidia-ml-py>=12'.")
"first, e.g, 'pip install nvidia-ml-py>=12'."
)
class Timer:
def __init__(self):
self._start_times = {}
self._total_elapsed_times = {}
@ -77,9 +79,9 @@ class Timer:
self._total_elapsed_times.pop(tag, None)
def summary(self):
logger.info('Profile Results')
logger.info("Profile Results")
for tag, elapsed_time in self._total_elapsed_times.items():
logger.info(f' - {tag.ljust(30, ".")}: {elapsed_time:.6f} (sec)')
logger.info(f" - {tag.ljust(30, '.')}: {elapsed_time:.6f} (sec)")
_default_timer = Timer()
@ -105,11 +107,10 @@ def summary():
_default_timer.summary()
MemUnitType = Literal['GiB', 'MiB', 'KiB']
MemUnitType = Literal["GiB", "MiB", "KiB"]
class PyNVMLContext:
def __enter__(self):
if pynvml is not None:
pynvml.nvmlInit()
@ -141,9 +142,7 @@ def host_memory_info(pid: Optional[int] = None) -> Tuple[int, int, int]:
return 0, 0, 0 # used, free, total
def device_memory_info(
device: Optional[Union[torch.device,
int]] = None) -> Tuple[int, int, int]:
def device_memory_info(device: Optional[Union[torch.device, int]] = None) -> Tuple[int, int, int]:
if pynvml is not None:
if device is None:
device = torch.cuda.current_device()
@ -156,8 +155,8 @@ def device_memory_info(
def bytes_to_target_unit(mem_bytes: int, unit: MemUnitType) -> float:
units = {'GiB': 1 << 30, 'MiB': 1 << 20, 'KiB': 1 << 10}
_rename_map = {'GB': 'GiB', 'MB': 'MiB', 'KB': 'KiB'}
units = {"GiB": 1 << 30, "MiB": 1 << 20, "KiB": 1 << 10}
_rename_map = {"GB": "GiB", "MB": "MiB", "KB": "KiB"}
if unit not in units:
unit = _rename_map[unit]
return float(mem_bytes) / units[unit]
@ -165,51 +164,60 @@ def bytes_to_target_unit(mem_bytes: int, unit: MemUnitType) -> float:
def _format(mem_bytes: int, unit: MemUnitType) -> str:
mem_usage = bytes_to_target_unit(mem_bytes, unit)
return f'{mem_usage:.4f} ({unit})'
return f"{mem_usage:.4f} ({unit})"
def _print_mem_message(msg: str, tag: Optional[str] = None):
if tag:
msg = f'{tag} - {msg}'
logger.info(f'[MemUsage] {msg}')
msg = f"{tag} - {msg}"
logger.info(f"[MemUsage] {msg}")
def print_host_memory_usage(tag: Optional[str] = None,
unit: MemUnitType = 'GiB'):
def print_host_memory_usage(tag: Optional[str] = None, unit: MemUnitType = "GiB"):
if psutil is None:
return
alloc_mem, _, _ = host_memory_info()
msg = f'Allocated Host Memory {_format(alloc_mem, unit)}'
msg = f"Allocated Host Memory {_format(alloc_mem, unit)}"
_print_mem_message(msg, tag)
def print_device_memory_usage(
tag: Optional[str] = None,
unit: MemUnitType = 'GiB',
unit: MemUnitType = "GiB",
device: Optional[Union[torch.device, int]] = None,
):
alloc_mem, _, _ = device_memory_info(device)
msg = f'Allocated Device Memory {_format(alloc_mem, unit)}'
msg = f"Allocated Device Memory {_format(alloc_mem, unit)}"
_print_mem_message(msg, tag)
def print_memory_usage(
tag: Optional[str] = None,
unit: MemUnitType = 'GiB',
unit: MemUnitType = "GiB",
device: Optional[Union[torch.device, int]] = None,
):
alloc_host_mem, _, _ = host_memory_info()
alloc_device_mem, _, _ = device_memory_info(device=device)
msg = f'Allocated Memory: Host {_format(alloc_host_mem, unit)} '\
f'Device {_format(alloc_device_mem, unit)}'
msg = (
f"Allocated Memory: Host {_format(alloc_host_mem, unit)} "
f"Device {_format(alloc_device_mem, unit)}"
)
_print_mem_message(msg, tag)
@_is_building
def check_gpt_mem_usage(engine, kv_dtype, use_gpt_attention_plugin,
paged_kv_cache, max_batch_size, max_beam_width,
max_seq_len, local_num_kv_heads, head_size,
num_layers) -> int:
def check_gpt_mem_usage(
engine,
kv_dtype,
use_gpt_attention_plugin,
paged_kv_cache,
max_batch_size,
max_beam_width,
max_seq_len,
local_num_kv_heads,
head_size,
num_layers,
) -> int:
# Get the amount of memory
runtime = trt.Runtime(logger.trt_logger)
# 1. TensorRT engine activation memory
@ -220,40 +228,51 @@ def check_gpt_mem_usage(engine, kv_dtype, use_gpt_attention_plugin,
activation_size = cuda_engine.device_memory_size_v2 / 1024 / 1024
del cuda_engine
except Exception:
logger.warning(
f'Exception when deserializing engine: {traceback.format_exc()}')
logger.warning(f'Activation memory size will be regarded as 0.')
logger.info(f'Activation memory size: {activation_size:.2f} MiB')
logger.warning(f"Exception when deserializing engine: {traceback.format_exc()}")
logger.warning("Activation memory size will be regarded as 0.")
logger.info(f"Activation memory size: {activation_size:.2f} MiB")
# 2. Weights
weights_size = bytes_to_target_unit(engine.nbytes, 'MiB')
logger.info(f'Weights memory size: {weights_size:.2f} MiB')
weights_size = bytes_to_target_unit(engine.nbytes, "MiB")
logger.info(f"Weights memory size: {weights_size:.2f} MiB")
# 3. Estimated max KV Cache size
kv_cache_size = max_batch_size * max_beam_width * 2 * local_num_kv_heads * max_seq_len * head_size * num_layers * kv_dtype.itemsize
kv_cache_size = (
max_batch_size
* max_beam_width
* 2
* local_num_kv_heads
* max_seq_len
* head_size
* num_layers
* kv_dtype.itemsize
)
# without the plugin, we need two sets of kv cache buffers,
# one for inputs, and the other for outputs.
if not use_gpt_attention_plugin:
kv_cache_size *= 2
kv_cache_size = bytes_to_target_unit(kv_cache_size, 'MiB')
logger.info(f'Max KV Cache memory size: {kv_cache_size:.2f} MiB')
kv_cache_size = bytes_to_target_unit(kv_cache_size, "MiB")
logger.info(f"Max KV Cache memory size: {kv_cache_size:.2f} MiB")
# Estimated total amount of memory
est_memory_size = activation_size + weights_size + kv_cache_size
logger.info(
f'Estimated max memory usage on runtime: {est_memory_size:.2f} MiB')
logger.info(f"Estimated max memory usage on runtime: {est_memory_size:.2f} MiB")
_, _, total_mem = device_memory_info(torch.cuda.current_device())
total_mem = bytes_to_target_unit(total_mem, 'MiB')
total_mem = bytes_to_target_unit(total_mem, "MiB")
if est_memory_size > total_mem:
logger.warning(
f'Engine is successfully built, but GPU Memory ({total_mem:.2f} MiB)'
' may not be enough when running inference on max shape.')
f"Engine is successfully built, but GPU Memory ({total_mem:.2f} MiB)"
" may not be enough when running inference on max shape."
)
if paged_kv_cache:
logger.warning(f'Since paged_kv_cache is enabled, the max KV Cache '
'memory size is an estimate for very extreme cases; '
'it\'s possible that most cases won\'t hit OOM.')
logger.warning(
"Since paged_kv_cache is enabled, the max KV Cache "
"memory size is an estimate for very extreme cases; "
"it's possible that most cases won't hit OOM."
)
else:
logger.warning(f'Enabling `--paged_kv_cache` could help reduce the '
'GPU memory usage on runtime.')
logger.warning(
"Enabling `--paged_kv_cache` could help reduce the GPU memory usage on runtime."
)
return est_memory_size
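
To ground the KV-cache estimate above, a worked example (not part of the commit) with hypothetical settings: an fp16 cache (2-byte elements), max_batch_size=8, max_beam_width=1, 32 local KV heads, max_seq_len=4096, head_size=128, 32 layers, and the GPT attention plugin enabled (so no doubling).

# batch * beam * (K and V) * kv_heads * seq_len * head_size * layers * itemsize
kv_cache_bytes = 8 * 1 * 2 * 32 * 4096 * 128 * 32 * 2
print(kv_cache_bytes / (1 << 30))  # -> 16.0 GiB, i.e. 16384.00 MiB in the log message above
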


@ -10,15 +10,13 @@ if TYPE_CHECKING:
class PromptAdapterManager:
def __init__(self):
self._uid_counter = 0
self._uid_to_weights: Dict[str, torch.Tensor] = {}
def load_from_ckpt(self,
model_dirs: List[str],
model_config: 'ModelConfig',
uids: Optional[List[str]] = None):
def load_from_ckpt(
self, model_dirs: List[str], model_config: "ModelConfig", uids: Optional[List[str]] = None
):
if uids is None:
uids = [self._generate_uid() for _ in range(len(model_dirs))]
assert len(uids) == len(model_dirs)
@ -34,10 +32,10 @@ class PromptAdapterManager:
return
for uid, model_dir in zip(new_uids, new_model_dirs):
state_dict = load_state_dict(
get_model_path(model_dir, 'adapter_model'))
self._uid_to_weights[uid] = state_dict['prompt_embeddings'].to(
str_dtype_to_torch(model_config.dtype))
state_dict = load_state_dict(get_model_path(model_dir, "adapter_model"))
self._uid_to_weights[uid] = state_dict["prompt_embeddings"].to(
str_dtype_to_torch(model_config.dtype)
)
@property
def uid_to_weights(self):


@ -25,8 +25,13 @@ import tensorrt as trt
import torch
from ._common import default_trtnet
from ._utils import (TensorWrapper, np_dtype_to_trt, str_dtype_to_trt,
torch_dtype_to_trt, trt_dtype_to_torch)
from ._utils import (
TensorWrapper,
np_dtype_to_trt,
str_dtype_to_trt,
torch_dtype_to_trt,
trt_dtype_to_torch,
)
from .functional import Tensor, _create_tensor
from .plugin.plugin import TRT_LLM_PLUGIN_NAMESPACE
@ -42,31 +47,31 @@ class PluginInfo:
plugin_num_outputs: int
def __hash__(self):
return hash(
(self.plugin_name, self.plugin_namespace, self.plugin_version))
return hash((self.plugin_name, self.plugin_namespace, self.plugin_version))
def __eq__(self, obj):
if not isinstance(obj, PluginInfo):
return False
return (self.plugin_name == obj.plugin_name
and self.plugin_namespace == obj.plugin_namespace
and self.plugin_version == obj.plugin_version)
return (
self.plugin_name == obj.plugin_name
and self.plugin_namespace == obj.plugin_namespace
and self.plugin_version == obj.plugin_version
)
def make_expr(
exprBuilder: Union[trt.IExprBuilder, Type[None]],
dim: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]],
) -> Union[trt.IDimensionExpr, Type[None]]:
"""
Parameters:
exprBuilder: Union[trt.IExprBuilder, Type[None]]
The trt.exprBuilder object. Using it to check whether dim has the same exprBuilder or to create trt.IDimensionExpr if necessary.
"""Make a dimension expression.
dim : Union["DimensionExpr", int, Type[None]]
The input dim
Parameters:
exprBuilder: The trt.IExprBuilder object. Used to check whether dim has the same exprBuilder,
or to create a trt.IDimensionExpr if necessary.
dim: The input dim.
Returns:
A trt.IDimensionExpr object
A trt.IDimensionExpr object.
"""
if isinstance(dim, DimensionExpr):
assert exprBuilder == dim.exprBuilder
@ -87,9 +92,7 @@ def expr_operation(
operation: trt.DimensionOperation,
exprBuilder: trt.IExprBuilder,
):
"""
The function to do expr operation with None support
"""
"""The function to do expr operation with None support."""
if exprBuilder is None or a is None or b is None:
expr = None
else:
@ -98,9 +101,7 @@ def expr_operation(
class DimensionExpr:
'''
The class to wrap `trt.IDimensionExpr` to support more pythonic methods.
'''
"""The class to wrap `trt.IDimensionExpr` to support more pythonic methods."""
def __init__(
self,
@ -115,100 +116,70 @@ class DimensionExpr:
return self._expr
@expr.setter
def expr(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def expr(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
self._expr = make_expr(self.exprBuilder, expr)
def __add__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __add__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(self.expr, expr, trt.DimensionOperation.SUM,
self.exprBuilder)
return expr_operation(self.expr, expr, trt.DimensionOperation.SUM, self.exprBuilder)
def __radd__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __radd__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
return self.__add__(expr)
def __mul__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __mul__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(self.expr, expr, trt.DimensionOperation.PROD,
self.exprBuilder)
return expr_operation(self.expr, expr, trt.DimensionOperation.PROD, self.exprBuilder)
def __rmul__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __rmul__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
return self.__mul__(expr)
def __sub__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __sub__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(self.expr, expr, trt.DimensionOperation.SUB,
self.exprBuilder)
return expr_operation(self.expr, expr, trt.DimensionOperation.SUB, self.exprBuilder)
def __rsub__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __rsub__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(expr, self.expr, trt.DimensionOperation.SUB,
self.exprBuilder)
return expr_operation(expr, self.expr, trt.DimensionOperation.SUB, self.exprBuilder)
def __eq__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __eq__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(self.expr, expr, trt.DimensionOperation.EQUAL,
self.exprBuilder)
return expr_operation(self.expr, expr, trt.DimensionOperation.EQUAL, self.exprBuilder)
def __lt__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __lt__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(self.expr, expr, trt.DimensionOperation.LESS,
self.exprBuilder)
return expr_operation(self.expr, expr, trt.DimensionOperation.LESS, self.exprBuilder)
def __floordiv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __floordiv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(self.expr, expr, trt.DimensionOperation.FLOOR_DIV,
self.exprBuilder)
return expr_operation(self.expr, expr, trt.DimensionOperation.FLOOR_DIV, self.exprBuilder)
def __rfloordiv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr,
int, Type[None]]):
def __rfloordiv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(expr, self.expr, trt.DimensionOperation.FLOOR_DIV,
self.exprBuilder)
return expr_operation(expr, self.expr, trt.DimensionOperation.FLOOR_DIV, self.exprBuilder)
def __truediv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __truediv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(self.expr, expr, trt.DimensionOperation.CEIL_DIV,
self.exprBuilder)
return expr_operation(self.expr, expr, trt.DimensionOperation.CEIL_DIV, self.exprBuilder)
def __rtruediv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def __rtruediv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(expr, self.expr, trt.DimensionOperation.CEIL_DIV,
self.exprBuilder)
return expr_operation(expr, self.expr, trt.DimensionOperation.CEIL_DIV, self.exprBuilder)
def max(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def max(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(self.expr, expr, trt.DimensionOperation.MAX,
self.exprBuilder)
return expr_operation(self.expr, expr, trt.DimensionOperation.MAX, self.exprBuilder)
def min(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]):
def min(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
expr = make_expr(self.exprBuilder, expr)
return expr_operation(self.expr, expr, trt.DimensionOperation.MIN,
self.exprBuilder)
return expr_operation(self.expr, expr, trt.DimensionOperation.MIN, self.exprBuilder)
class ShapeExpr:
'''
The class to Wrap `trt.DimsExprs` to support more pythonic methods.
'''
"""The class to Wrap `trt.DimsExprs` to support more pythonic methods."""
def __init__(
self,
dims: Union[Sequence[trt.IDimensionExpr], Sequence[int],
Sequence[type[None]]],
dims: Union[Sequence[trt.IDimensionExpr], Sequence[int], Sequence[type[None]]],
exprBuilder: Union[trt.IExprBuilder, type[None]],
):
self.exprBuilder = exprBuilder
@ -221,13 +192,11 @@ class ShapeExpr:
@dims.setter
def dims(
self,
dims: Sequence[Union["DimensionExpr", trt.IDimensionExpr, int,
Type[None]]],
dims: Sequence[Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]],
):
if dims is not None:
self._dims = [
DimensionExpr(make_expr(self.exprBuilder, i), self.exprBuilder)
for i in dims
DimensionExpr(make_expr(self.exprBuilder, i), self.exprBuilder) for i in dims
]
else:
self._dims = None
@ -246,8 +215,7 @@ class ShapeExpr:
if self._dims is None:
return
assert index < len(self._dims)
value = DimensionExpr(make_expr(self.exprBuilder, value),
self.exprBuilder)
value = DimensionExpr(make_expr(self.exprBuilder, value), self.exprBuilder)
self._dims[index] = value
def __len__(self):
@ -261,9 +229,10 @@ class ShapeExpr:
class SymTensor:
'''
The class to represent symbolic tensors. Only contains dtype and shape information for users to write their own shape/dtype inference function.
'''
"""The class to represent symbolic tensors.
Only contains dtype and shape information for users to write their own shape/dtype inference function.
"""
def __init__(
self,
@ -290,8 +259,7 @@ class SymTensor:
return self._dtype
@dtype.setter
def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType,
Type[None]]):
def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType, Type[None]]):
if isinstance(dtype, torch.dtype):
self._dtype = torch_dtype_to_trt(dtype)
elif isinstance(dtype, str):
@ -313,13 +281,15 @@ def _convert_return_value_to_list(ret):
return ret
class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
trt.IPluginV3OneRuntime):
'''
The base class of TRT-LLM plugin.
class PluginBase(
trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime
):
"""The base class of TRT-LLM plugin.
All TRT-LLM plugins should inherit from this class and at least override the `forward` and `shape_dtype_inference` functions. `forward` defines the plugin's compute flow, while `shape_dtype_inference` defines how the output tensors' shapes and dtypes are inferred from the input tensors.
'''
All TRT-LLM plugins should inherit from this class and at least override the `forward` and
`shape_dtype_inference` functions. `forward` defines the plugin's compute flow, while
`shape_dtype_inference` defines how the output tensors' shapes and dtypes are inferred from the
input tensors.
"""
_plugin_creator = None
_no_serialize_attr = {"_current_stream", "_workspace"}
@ -327,9 +297,9 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
def __init__(self):
cls = type(self)
# Runtime check for plugin decorator
assert (
cls._plugin_creator is not None
), "Please make sure the plugin is registered through `@trtllm_plugin`"
assert cls._plugin_creator is not None, (
"Please make sure the plugin is registered through `@trtllm_plugin`"
)
assert cls != PluginBase
trt.IPluginV3.__init__(self)
@ -381,8 +351,7 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
pass
def get_output_data_types(self, input_types):
ret = self.shape_dtype_inference(
[SymTensor(i, ShapeExpr(None, None)) for i in input_types])
ret = self.shape_dtype_inference([SymTensor(i, ShapeExpr(None, None)) for i in input_types])
ret = _convert_return_value_to_list(ret)
assert len(ret) == self.num_outputs
@ -392,11 +361,11 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
return [i.dtype for i in ret]
def get_output_shapes(self, inputs, shape_inputs, exprBuilder):
assert len(
shape_inputs) == 0, "Currently we do not support shape inputs"
assert len(shape_inputs) == 0, "Currently we do not support shape inputs"
ret = self.shape_dtype_inference(
[SymTensor(None, ShapeExpr(i, exprBuilder)) for i in inputs])
[SymTensor(None, ShapeExpr(i, exprBuilder)) for i in inputs]
)
ret = _convert_return_value_to_list(ret)
assert len(ret) == self.num_outputs
@ -406,8 +375,9 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
return [i.shape.to_trt() for i in ret]
def supports_format_combination(self, pos, in_out, num_inputs):
"""
By default, a TRT-LLM plugin supports all dtypes and the linear format. It is the user's responsibility to check the dtypes the plugin supports in the `forward` function.
"""By default, a TRT-LLM plugin supports all dtypes and the linear format.
It is the user's responsibility to check the dtypes the plugin supports in the `forward` function.
"""
assert pos < len(in_out)
@ -422,13 +392,11 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
def get_fields_to_serialize(self):
buffer = pickle.dumps(self._get_dict_to_serialize())
return trt.PluginFieldCollection([
trt.PluginField("__plugin_pickle_obj__", buffer,
trt.PluginFieldType.UNKNOWN)
])
return trt.PluginFieldCollection(
[trt.PluginField("__plugin_pickle_obj__", buffer, trt.PluginFieldType.UNKNOWN)]
)
def enqueue(self, input_desc, output_desc, inputs, outputs, workspace,
stream):
def enqueue(self, input_desc, output_desc, inputs, outputs, workspace, stream):
torch_stream = torch.cuda.ExternalStream(stream_ptr=stream)
self.workspace = workspace
self.current_stream = stream
@ -437,31 +405,32 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
self.forward(
tuple(
TensorWrapper.from_trt_desc(input_desc[i], inputs[i])
for i in range(len(input_desc))),
for i in range(len(input_desc))
),
tuple(
TensorWrapper.from_trt_desc(output_desc[i], outputs[i])
for i in range(len(output_desc))),
for i in range(len(output_desc))
),
)
self.current_stream = -1
def __call__(self, *args: Union[Sequence[TensorWrapper],
Sequence[torch.Tensor]]):
def __call__(self, *args: Union[Sequence[TensorWrapper], Sequence[torch.Tensor]]):
is_trtllm = True
for i in args:
is_trtllm &= isinstance(i, Tensor)
if not is_trtllm:
for i in args:
assert isinstance(
i, torch.Tensor
), "Plugin inputs must be `tensorrt_llm.Tensor`s or `torch.Tensor`s"
assert isinstance(i, torch.Tensor), (
"Plugin inputs must be `tensorrt_llm.Tensor`s or `torch.Tensor`s"
)
sym_tensors = self.shape_dtype_inference(
[SymTensor(i.dtype, [j for j in i.shape]) for i in args])
[SymTensor(i.dtype, [j for j in i.shape]) for i in args]
)
sym_tensors = _convert_return_value_to_list(sym_tensors)
ret = [
torch.empty(sym_tensor.shape,
dtype=trt_dtype_to_torch(sym_tensor.dtype))
torch.empty(sym_tensor.shape, dtype=trt_dtype_to_torch(sym_tensor.dtype))
for sym_tensor in sym_tensors
]
self.current_stream = torch.cuda.current_stream().cuda_stream
@ -492,20 +461,20 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
"By default TRT should not set tactics since PluginBase do not provide custom tactic."
)
def forward(self, inputs: Sequence[TensorWrapper],
outputs: Sequence[TensorWrapper]):
'''
Expect users to rewrite this function to define the compute flow. There are a few special attributes for users to get access to some resources.
def forward(self, inputs: Sequence[TensorWrapper], outputs: Sequence[TensorWrapper]):
"""Expect users to rewrite this function to define the compute flow.
There are a few special attributes that give users access to certain resources.
`self.workspace`: The workspace address of TRT managed workspace.
`self.current_stream`: The CUDA stream this plugin is expected to execute on. By default `PluginBase` set the torch.cuda.current_stream() to this stream. This attribute is for the toolkit that doesn't work with torch's stream.
'''
`self.current_stream`: The CUDA stream this plugin is expected to execute on. By default,
`PluginBase` sets torch.cuda.current_stream() to this stream. This attribute is for
toolkits that do not work with torch streams.
"""
raise NotImplementedError
def shape_dtype_inference(self, inputs: Sequence[SymTensor]):
'''
Expect users to rewrite this function to define the shape dtype inference for output tensors.
'''
"""Expect users to rewrite this function to define the shape dtype inference for output tensors."""
raise NotImplementedError
def _get_dict_to_serialize(self):
@ -517,7 +486,6 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
class PluginCreatorBase(trt.IPluginCreatorV3One):
def __init__(self):
super().__init__()
@ -535,50 +503,53 @@ class PluginCreatorBase(trt.IPluginCreatorV3One):
def trtllm_plugin(
plugin_name: str,
*,
plugin_version: str = "1",
plugin_namespace: str = TRT_LLM_PLUGIN_NAMESPACE,
plugin_num_outputs: Union[int, Type[None]] = None,
deepcopy_clone: bool = True,
no_serialize_attr: Sequence[str] = set(),
plugin_name: str,
*,
plugin_version: str = "1",
plugin_namespace: str = TRT_LLM_PLUGIN_NAMESPACE,
plugin_num_outputs: Union[int, Type[None]] = None,
deepcopy_clone: bool = True,
no_serialize_attr: Sequence[str] = set(),
):
def plugin_registration(plugin_cls):
assert issubclass(plugin_cls, PluginBase)
assert hasattr(
plugin_cls,
"__dict__"), "Plugin wrapper uses `__dict__` to track plugin states"
assert hasattr(plugin_cls, "__dict__"), (
"Plugin wrapper uses `__dict__` to track plugin states"
)
nonlocal plugin_num_outputs
annotation = inspect.signature(
plugin_cls.shape_dtype_inference).return_annotation
annotation = inspect.signature(plugin_cls.shape_dtype_inference).return_annotation
origin_annotation = typing.get_origin(annotation)
if origin_annotation is tuple or annotation is SymTensor:
if origin_annotation is tuple:
element_types = typing.get_args(annotation)
for ty in element_types:
assert (
ty == SymTensor
), f"Plugin {plugin_name}'s `shape_dtype_inference` return annotation must be SymTensor or a tuple of SymTensor"
assert ty == SymTensor, (
f"Plugin {plugin_name}'s `shape_dtype_inference` return annotation must be SymTensor "
"or a tuple of SymTensor"
)
infered_num_outputs = len(element_types)
else:
infered_num_outputs = 1
if plugin_num_outputs is not None:
assert (
plugin_num_outputs == infered_num_outputs
), f"Plugin {plugin_name}'s `_num_outputs` and return annotation mismatch, {plugin_cls._num_outputs} != {infered_num_outputs}"
assert plugin_num_outputs == infered_num_outputs, (
f"Plugin {plugin_name}'s `_num_outputs` and return annotation mismatch, "
f"{plugin_cls._num_outputs} != {infered_num_outputs}"
)
plugin_num_outputs = infered_num_outputs
else:
assert (
plugin_num_outputs is not None
), f"Must specify `num_outputs` or valid `shape_dtype_inference` return annotation for {plugin_name}. The valid types are SymTensor or a tuple of SymTensor, got {annotation}."
assert plugin_num_outputs is not None, (
"Must specify `num_outputs` or valid `shape_dtype_inference` return annotation for "
f"{plugin_name}. The valid types are SymTensor or a tuple of SymTensor, got {annotation}."
)
plugin_info = PluginInfo(3, plugin_namespace, plugin_name,
plugin_version, plugin_num_outputs)
assert (
plugin_info not in _plugin_registered
), f"Redefine plugin with info: {plugin_info} which is previously defined as {_plugin_registered[plugin_info]}"
plugin_info = PluginInfo(
3, plugin_namespace, plugin_name, plugin_version, plugin_num_outputs
)
assert plugin_info not in _plugin_registered, (
f"Redefine plugin with info: {plugin_info} which is previously defined as "
f"{_plugin_registered[plugin_info]}"
)
_plugin_registered[plugin_info] = plugin_info
plugin_cls._plugin_name = plugin_name
@ -598,8 +569,7 @@ def trtllm_plugin(
plugin_creator.plugin_cls = plugin_cls
plugin_cls._plugin_creator = plugin_creator
ret = plugin_registry.register_creator(plugin_creator,
plugin_cls._plugin_namespace)
ret = plugin_registry.register_creator(plugin_creator, plugin_cls._plugin_namespace)
assert ret, f"Plugin: {plugin_cls} register failed, please check the error log."

View File: tensorrt_llm/sampling_params.py

@ -12,8 +12,7 @@ from tensorrt_llm.bindings import executor as tllme
@dataclass(slots=True, kw_only=True)
class GuidedDecodingParams:
"""
Guided decoding parameters for text generation. Only one of the fields could be effective.
"""Guided decoding parameters for text generation. Only one of the fields could be effective.
Args:
json (str, pydantic.main.BaseModel, dict, optional): The generated text is amenable to json format with additional user-specified restrictions, namely schema. Defaults to None.
@ -21,7 +20,8 @@ class GuidedDecodingParams:
grammar (str, optional): The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar. Defaults to None.
json_object (bool): If True, the generated text is amenable to json format. Defaults to False.
structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Defaults to None.
"""
""" # noqa: E501
json: Optional[Union[str, BaseModel, dict]] = None
regex: Optional[str] = None
grammar: Optional[str] = None
@ -30,12 +30,10 @@ class GuidedDecodingParams:
def _validate(self):
num_guides = 0
for field in fields(self):
num_guides += bool(getattr(self, field.name))
for _field in fields(self):
num_guides += bool(getattr(self, _field.name))
if num_guides > 1:
raise ValueError(
f"Only one guide can be used for a request, but got {num_guides}."
)
raise ValueError(f"Only one guide can be used for a request, but got {num_guides}.")
class LogprobParams(NamedTuple):
@ -57,16 +55,23 @@ class LogitsProcessor(ABC):
"""
@abstractmethod
def __call__(self, req_id: int, logits: torch.Tensor,
token_ids: List[List[int]], stream_ptr: Optional[int],
client_id: Optional[int]) -> None:
def __call__(
self,
req_id: int,
logits: torch.Tensor,
token_ids: List[List[int]],
stream_ptr: Optional[int],
client_id: Optional[int],
) -> None:
"""Logits processing callback. The callback is expected to inplace modify the logits.
Args:
req_id (int): Request id.
logits (torch.Tensor): Logits tensor to be modified.
token_ids (List[List[int]]): Token ids produced by the request so far. The shape is beam_width * sequence_length.
stream_ptr (int, optional): The operation stream used by the logits tensor. Not required for PyTorch backend.
token_ids (List[List[int]]): Token ids produced by the request so far.
The shape is beam_width * sequence_length.
stream_ptr (int, optional): The operation stream used by the logits tensor.
Not required for PyTorch backend.
client_id (int, optional): An optional client id.
"""
pass # noqa
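A minimal sketch of a concrete processor, assuming the PyTorch backend (where `stream_ptr` can be ignored); the class name and the banned-token logic are hypothetical:

from typing import List, Optional

import torch

from tensorrt_llm.sampling_params import LogitsProcessor


class BanTokenProcessor(LogitsProcessor):
    def __init__(self, banned_token_id: int):
        self.banned_token_id = banned_token_id

    def __call__(
        self,
        req_id: int,
        logits: torch.Tensor,
        token_ids: List[List[int]],
        stream_ptr: Optional[int],
        client_id: Optional[int],
    ) -> None:
        # The contract above requires in-place modification of the logits.
        logits[..., self.banned_token_id] = float("-inf")

Such a processor would be attached per request via `SamplingParams(logits_processor=BanTokenProcessor(...))`, matching the `logits_processor` field further down in this file.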
@ -82,15 +87,21 @@ class BatchedLogitsProcessor(ABC):
"""
@abstractmethod
def __call__(self, req_ids: List[int], logits: List[torch.Tensor],
token_ids: List[List[List[int]]], stream_ptr: int,
client_ids: List[Optional[int]]) -> None:
def __call__(
self,
req_ids: List[int],
logits: List[torch.Tensor],
token_ids: List[List[List[int]]],
stream_ptr: int,
client_ids: List[Optional[int]],
) -> None:
"""Batched logits processing callback. The callback is expected to inplace modify the logits.
Args:
req_ids (List[int]): A batch of request ids.
logits (List[torch.Tensor]): A batch of the logits tensors.
token_ids (List[List[List[int]]]): A batch of the token ids produced by the requests so far. The shape is batch * beam_width * sequence_length.
token_ids (List[List[List[int]]]): A batch of the token ids produced by the requests so far.
The shape is batch * beam_width * sequence_length.
stream_ptr (int): The operation stream used by the logits tensors.
client_ids (List[Optional[int]]): A batch of optional client ids.
"""
@ -99,21 +110,20 @@ class BatchedLogitsProcessor(ABC):
@dataclass(slots=True, kw_only=True)
class AdditionalModelOutput:
"""
An additional output to gather from the model.
"""An additional output to gather from the model.
Args:
name (str): The name of the additional output to gather from the model.
gather_context (bool): A value indicating whether or not to gather the additional output from the context too. Defaults to False.
"""
""" # noqa: E501
name: str
gather_context: bool
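A tiny construction sketch; both fields are keyword-only, and the output name here is hypothetical:

from tensorrt_llm.sampling_params import AdditionalModelOutput

extra_output = AdditionalModelOutput(name="hidden_states", gather_context=True)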
@dataclass(slots=True, kw_only=True)
class SamplingParams:
"""
Sampling parameters for text generation.
"""Sampling parameters for text generation.
Usage Examples:
@ -179,7 +189,8 @@ class SamplingParams:
truncate_prompt_tokens (int, optional): If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). Defaults to None.
skip_special_tokens (bool): Whether to skip special tokens in the output. Defaults to True.
spaces_between_special_tokens (bool): Whether to add spaces between special tokens in the output. Defaults to True.
"""
""" # noqa: E501
# [TO DEVELOPER] This class provides an interface to LLMAPI users.
# Internally, it manages and dispatches fields to Python bindings of C++ objects, currently including:
# (1) all fields of tllme.SamplingConfig;
@ -194,19 +205,14 @@ class SamplingParams:
max_tokens: int = 32
bad: Optional[Union[str, List[str]]] = None
bad_token_ids: Optional[List[int]] = None
_bad_word_ids: Optional[List[List[int]]] = field(default=None,
init=False,
repr=False)
_bad_word_ids: Optional[List[List[int]]] = field(default=None, init=False, repr=False)
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
include_stop_str_in_output: bool = False
_stop_word_ids: Optional[List[List[int]]] = field(default=None,
init=False,
repr=False)
_stop_word_ids: Optional[List[List[int]]] = field(default=None, init=False, repr=False)
embedding_bias: Optional[torch.Tensor] = None
logits_processor: Optional[Union[LogitsProcessor,
List[LogitsProcessor]]] = None
logits_processor: Optional[Union[LogitsProcessor, List[LogitsProcessor]]] = None
apply_batched_logits_processor: bool = False
n: int = 1
@ -273,25 +279,28 @@ class SamplingParams:
self._validate()
def _validate(self):
''' Verify the sampling parameters.
"""Verify the sampling parameters.
This function verifies the sampling parameters in the LLM API, which
may impose stricter requirements than the Executor class of the C++ runtime.
For instance, while greedy decoding with n > 1 is supported by the
Executor class of the C++ runtime, the LLM API disallows such a combination.
'''
"""
if self.best_of < self.n:
raise ValueError(
f"best_of ({self.best_of}) cannot be less than n ({self.n})")
raise ValueError(f"best_of ({self.best_of}) cannot be less than n ({self.n})")
if (self.best_of > 1 and self._greedy_decoding
and not os.environ.get('TLLM_ALLOW_N_GREEDY_DECODING', None)):
if (
self.best_of > 1
and self._greedy_decoding
and not os.environ.get("TLLM_ALLOW_N_GREEDY_DECODING", None)
):
raise ValueError(
f'Greedy decoding in the LLM API does not allow multiple '
f'returns. Please set to best_of=1, got best_of={self.best_of}. '
f'Please set to best_of=1 or set an environment variable '
f'TLLM_ALLOW_N_GREEDY_DECODING=1 to allow best_of > 1 '
f'under the greedy decoding.')
f"Greedy decoding in the LLM API does not allow multiple "
f"returns. Please set to best_of=1, got best_of={self.best_of}. "
f"Please set to best_of=1 or set an environment variable "
f"TLLM_ALLOW_N_GREEDY_DECODING=1 to allow best_of > 1 "
f"under the greedy decoding."
)
if self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1:
raise ValueError(
@ -303,14 +312,15 @@ class SamplingParams:
# correct types as users might pass in logprob=True for Top-1 logprobs
self.logprobs = self.logprobs and int(self.logprobs)
self.prompt_logprobs = self.prompt_logprobs and int(
self.prompt_logprobs)
self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs)
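A hedged sketch of the constraint enforced in `_validate` above, assuming validation runs at construction time and that `top_k`/`top_p` default to None (i.e. greedy decoding):

from tensorrt_llm.sampling_params import SamplingParams

# Fine: sampling (top_p > 0) with several returned sequences, best_of >= n.
params = SamplingParams(max_tokens=64, n=2, best_of=4, top_p=0.9)

# Expected to raise ValueError: greedy decoding with best_of > 1, unless the
# TLLM_ALLOW_N_GREEDY_DECODING=1 environment variable is set.
try:
    SamplingParams(n=2, best_of=2)
except ValueError as err:
    print(err)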
@property
def _greedy_decoding(self) -> bool:
return (not self.use_beam_search
and (self.top_k is None or self.top_k == 1)
and (self.top_p is None or self.top_p == 0.0))
return (
not self.use_beam_search
and (self.top_k is None or self.top_k == 1)
and (self.top_p is None or self.top_p == 0.0)
)
@property
def _need_return_context_logits(self) -> bool:
@ -320,9 +330,7 @@ class SamplingParams:
def _need_return_generation_logits(self) -> bool:
return self.return_generation_logits and not self._generation_logits_auto_enabled
def _setup(self,
tokenizer,
add_special_tokens: bool = False) -> 'SamplingParams':
def _setup(self, tokenizer, add_special_tokens: bool = False) -> "SamplingParams":
if self.end_id is None:
self.end_id = tokenizer.eos_token_id
self.pad_id = tokenizer.pad_token_id
@ -332,15 +340,13 @@ class SamplingParams:
if self.bad is not None:
strs = [self.bad] if isinstance(self.bad, str) else self.bad
self._bad_word_ids = [
tokenizer.encode(s, add_special_tokens=add_special_tokens)
for s in strs
tokenizer.encode(s, add_special_tokens=add_special_tokens) for s in strs
]
if self.stop is not None:
strs = [self.stop] if isinstance(self.stop, str) else self.stop
self._stop_word_ids = [
tokenizer.encode(s, add_special_tokens=add_special_tokens)
for s in strs
tokenizer.encode(s, add_special_tokens=add_special_tokens) for s in strs
]
return self
@ -356,7 +362,8 @@ class SamplingParams:
if self._bad_word_ids is None:
raise RuntimeError(
f"{self.__class__.__name__}.bad ({self.bad}) is not processed by tokenizer, "
"please call the setup method.")
"please call the setup method."
)
return words + self._bad_word_ids
def _get_stop_words(self) -> List[List[int]]:
@ -370,11 +377,11 @@ class SamplingParams:
if self._stop_word_ids is None:
raise RuntimeError(
f"{self.__class__.__name__}.stop ({self.stop}) is not processed by tokenizer, "
"please call the setup method.")
"please call the setup method."
)
return words + self._stop_word_ids
def _get_stop_reasons_and_words(
self) -> List[Tuple[Union[str, int], List[List[int]]]]:
def _get_stop_reasons_and_words(self) -> List[Tuple[Union[str, int], List[List[int]]]]:
stop_reasons = []
if self.stop_token_ids is not None:
stop_reasons.extend(self.stop_token_ids)
@ -400,37 +407,30 @@ class SamplingParams:
# | Sampling | use_beam_search | beam_width == 1 |
# | Sampling | n | num_return_sequences |
# | Sampling | best_of | no corresponding param |
fields = {
f
for f in dir(tllme.SamplingConfig) if not f.startswith('__')
}
fields = {f for f in dir(tllme.SamplingConfig) if not f.startswith("__")}
unmatched_params = [
'num_return_sequences',
'beam_width',
'n',
'best_of',
'use_beam_search',
"num_return_sequences",
"beam_width",
"n",
"best_of",
"use_beam_search",
]
llmapi_to_rt_param_map = {
f: getattr(self, f)
for f in fields if f not in unmatched_params
}
llmapi_to_rt_param_map = {f: getattr(self, f) for f in fields if f not in unmatched_params}
if self.use_beam_search:
llmapi_to_rt_param_map['num_return_sequences'] = self.n
llmapi_to_rt_param_map['beam_width'] = self.best_of
llmapi_to_rt_param_map["num_return_sequences"] = self.n
llmapi_to_rt_param_map["beam_width"] = self.best_of
else:
llmapi_to_rt_param_map['num_return_sequences'] = self.best_of
llmapi_to_rt_param_map['beam_width'] = 1
llmapi_to_rt_param_map["num_return_sequences"] = self.best_of
llmapi_to_rt_param_map["beam_width"] = 1
return tllme.SamplingConfig(**llmapi_to_rt_param_map)
def _get_output_config(self,
is_pytorch_backend: bool = False
) -> tllme.OutputConfig:
def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputConfig:
sampling_param_fields = set(dir(SamplingParams))
fields = [
f for f in dir(tllme.OutputConfig)
if not f.startswith('__') and f in sampling_param_fields
f
for f in dir(tllme.OutputConfig)
if not f.startswith("__") and f in sampling_param_fields
]
config_kwargs = {f: getattr(self, f) for f in fields}
@ -447,8 +447,7 @@ class SamplingParams:
return None
if self.guided_decoding.json_object:
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.JSON)
return tllme.GuidedDecodingParams(tllme.GuidedDecodingParams.GuideType.JSON)
elif self.guided_decoding.json is not None:
json_schema = self.guided_decoding.json
if isinstance(json_schema, BaseModel):
@ -456,18 +455,20 @@ class SamplingParams:
if isinstance(json_schema, dict):
json_schema = json.dumps(json_schema)
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema)
tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema
)
elif self.guided_decoding.regex is not None:
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.REGEX,
self.guided_decoding.regex)
tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex
)
elif self.guided_decoding.grammar is not None:
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR,
self.guided_decoding.grammar)
tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar
)
elif self.guided_decoding.structural_tag is not None:
return tllme.GuidedDecodingParams(
tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG,
self.guided_decoding.structural_tag)
self.guided_decoding.structural_tag,
)
else:
return None

View File: tensorrt_llm/top_model_mixin.py

@ -21,46 +21,52 @@ from .plugin.plugin import PluginConfig
class TopModelMixin:
'''
The Module class are reused between building blocks (like Attention, MLP) and the top level model (like LLaMAForCausalLM)
While there are some functions, like the loading hf/ft weights, or build/load trt engines are only supported by the top level model, not the building blocks.
So top level model class like: LLaMAForCausalLM shall inherit this class.
'''
"""Top model mixin.
The Module classes are reused between building blocks (like Attention, MLP) and the top level models
(like LLaMAForCausalLM).
While there are some functions, like the loading hf/ft weights, or build/load trt engines, that are
only supported by the top level model, not the building blocks.
So top level model class like: LLaMAForCausalLM shall inherit this class.
"""
def __init__(self) -> None:
pass
@classmethod
def from_hugging_face(cls,
hf_model_dir: str,
dtype: Optional[str] = 'float16',
mapping: Optional[Mapping] = None,
**kwargs):
'''
Create LLM object and load weights from hugging face
def from_hugging_face(
cls,
hf_model_dir: str,
dtype: Optional[str] = "float16",
mapping: Optional[Mapping] = None,
**kwargs,
):
"""Create LLM object and load weights from hugging face.
Parameters:
hf_model_dir: the hugging face model directory
dtype: str, the default weights data type when loading from the hugging face model
mapping: Mapping, specify the multi-gpu parallel strategy, when it's None, single GPU is used
'''
"""
raise NotImplementedError("Subclass shall override this")
def use_lora(self, lora_config: LoraConfig):
'''
Load lora weights from the give config to the module
"""Load lora weights from the give config to the module.
Parameters:
lora_config: the lora config
'''
"""
raise NotImplementedError("Subclass shall override this")
def use_prompt_tuning(self, max_prompt_embedding_table_size: str,
prompt_table_path: str):
'''Enable p tuning when build the TRT engine, call this before to_trt
'''
def use_prompt_tuning(self, max_prompt_embedding_table_size: str, prompt_table_path: str):
"""Enable p tuning when build the TRT engine, call this before to_trt."""
raise NotImplementedError
def default_plugin_config(self, **kwargs) -> PluginConfig:
'''Return the default plugin config for this model, when the plugin_config value is not given in to_trt() call.
If users need to set different plugin configs, they can start from the return object and change it.
'''
"""Return the default plugin config for this model.
This is used when the plugin_config value is not given in the to_trt() call.
If users need different plugin configs, they can start from the returned object and modify it.
"""
return PluginConfig.from_dict(kwargs)
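To make the mixin's contract concrete, a hedged usage sketch for a subclass such as LLaMAForCausalLM; the import paths, checkpoint directory, and Mapping arguments are assumptions for illustration:

from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import LLaMAForCausalLM

# from_hugging_face() is overridden by the top-level model, not by building blocks.
model = LLaMAForCausalLM.from_hugging_face(
    "/checkpoints/llama-7b-hf",                # hf_model_dir
    dtype="float16",
    mapping=Mapping(world_size=2, tp_size=2),  # two-way tensor parallelism
)

# default_plugin_config() simply forwards its kwargs to PluginConfig.from_dict,
# so per-build overrides can be passed straight through.
plugin_config = model.default_plugin_config()

LoRA weights would likewise be attached through `use_lora(lora_config)` before building the engine.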