Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
linting(python): Enable ruff on more files (wave 1/N) (#5140)
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
parent 0b60da2c45
commit dc52b67492
@ -71,9 +71,9 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
files: ".*/auto_deploy/.*"
pass_filenames: false
- id: ruff-format
files: ".*/auto_deploy/.*"
pass_filenames: false
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17
hooks:

@ -10,14 +10,54 @@ build-backend = "setuptools.build_meta"
####################################################################################################
[tool.isort]
line_length = 80
extend_skip_glob = ["**/auto_deploy/**"]
# This should match the `include` in `[tool.ruff]`. See the comments in that section for why this
# is necessary.
extend_skip_glob = [
"**/auto_deploy/**",
"tensorrt_llm/_common.py",
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
"tensorrt_llm/logger.py",
"tensorrt_llm/lora_manager.py",
"tensorrt_llm/module.py",
"tensorrt_llm/moe_config.py",
"tensorrt_llm/profiler.py",
"tensorrt_llm/prompt_adapter_manager.py",
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
]

[tool.yapf]
based_on_style = "pep8"
column_limit = 80

[tool.yapfignore]
ignore_patterns = ["**/auto_deploy/**"]
# This should match the `include` in `[tool.ruff]`. See the comments in that section for why this
# is necessary.
ignore_patterns = [
"**/auto_deploy/**",
"tensorrt_llm/_common.py",
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
"tensorrt_llm/logger.py",
"tensorrt_llm/lora_manager.py",
"tensorrt_llm/module.py",
"tensorrt_llm/moe_config.py",
"tensorrt_llm/profiler.py",
"tensorrt_llm/prompt_adapter_manager.py",
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
]

[tool.codespell]
skip = ".git,3rdparty,tests/integration/test_input_files**,**.jsonl,**.json"
@ -28,7 +68,27 @@ ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te
in-place = true
remove_all_unused_imports = true
remove_unused_variables = true
exclude = ["**/auto_deploy/**"]
# This should match the `include` in `[tool.ruff]`. See the comments in that section for why this
# is necessary.
exclude = [
"**/auto_deploy/**",
"tensorrt_llm/_common.py",
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
"tensorrt_llm/logger.py",
"tensorrt_llm/lora_manager.py",
"tensorrt_llm/module.py",
"tensorrt_llm/moe_config.py",
"tensorrt_llm/profiler.py",
"tensorrt_llm/prompt_adapter_manager.py",
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
]


####################################################################################################
@ -44,6 +104,33 @@ include = [
"**/auto_deploy/**/*.py",
"**/auto_deploy/**/*.pyi",
"**/auto_deploy/**/*.ipynb",
# Progressively enable ruff on all the repo to keep individual changes reasonably-sized, and
# keep merge conflicts manageable.
# Since keeping both `yapf` and `ruff` makes no sense (given that their formatting philosophies
# are quite different), we should move towards removing one in favor of the other. ruff's
# formatting mirrors black's, and both are much more widely adopted than yapf. ruff is also
# orders of magnitude faster, so we should move to deprecate `yapf`.
# In the transition period, we should keep the `ignore_patterns` in `[tool.yapfignore]` in sync
# with the below, so that both pre-commit hooks can complete successfully.
"tensorrt_llm/_common.py",
"tensorrt_llm/_dlpack_utils.py",
"tensorrt_llm/_ipc_utils.py",
"tensorrt_llm/_mnnvl_utils.py",
"tensorrt_llm/disaggregated_params.py",
"tensorrt_llm/engine.py",
"tensorrt_llm/graph_rewriting.py",
"tensorrt_llm/logger.py",
"tensorrt_llm/lora_manager.py",
"tensorrt_llm/module.py",
"tensorrt_llm/moe_config.py",
"tensorrt_llm/profiler.py",
"tensorrt_llm/prompt_adapter_manager.py",
"tensorrt_llm/python_plugin.py",
"tensorrt_llm/sampling_params.py",
"tensorrt_llm/top_model_mixin.py",
]
exclude = [
"3rdparty/**",
]



@ -54,10 +54,10 @@ def _init(log_level: object = None) -> None:
logger.set_level(log_level)

if os.getenv("TRT_LLM_NO_LIB_INIT", "0") == "1":
logger.info('Skipping TensorRT-LLM init.')
logger.info("Skipping TensorRT-LLM init.")
return

logger.info('Starting TensorRT-LLM init.')
logger.info("Starting TensorRT-LLM init.")

# load plugin lib
_load_plugin_lib()
@ -65,24 +65,30 @@ def _init(log_level: object = None) -> None:
# load FT decoder layer and torch custom ops
project_dir = str(Path(__file__).parent.absolute())
if platform.system() == "Windows":
ft_decoder_lib = project_dir + '/libs/th_common.dll'
ft_decoder_lib = project_dir + "/libs/th_common.dll"
else:
ft_decoder_lib = project_dir + '/libs/libth_common.so'
ft_decoder_lib = project_dir + "/libs/libth_common.so"
try:
torch.classes.load_library(ft_decoder_lib)
from ._torch.custom_ops import _register_fake

_register_fake()
except Exception as e:
msg = '\nFATAL: Decoding operators failed to load. This may be caused by the incompatibility between PyTorch and TensorRT-LLM. Please rebuild and install TensorRT-LLM.'
msg = (
"\nFATAL: Decoding operators failed to load. This may be caused by an incompatibility "
"between PyTorch and TensorRT-LLM. Please rebuild and install TensorRT-LLM."
)
raise ImportError(str(e) + msg)

MpiComm.local_init()

logger.info('TensorRT-LLM inited.')
logger.info("TensorRT-LLM inited.")


def default_net() -> Network:
assert net, "Use builder to create network first, and use `set_network` or `net_guard` to set it to default"
assert net, (
"Use builder to create network first, and use `set_network` or `net_guard` to set it to default"
)
return net


@ -111,29 +117,29 @@ def precision(dtype):


def serialize_engine(engine, path):
logger.info(f'Serializing engine to {path}...')
logger.info(f"Serializing engine to {path}...")
tik = time.time()
if isinstance(engine, trt.ICudaEngine):
engine = engine.serialize()
with open(path, 'wb') as f:
with open(path, "wb") as f:
f.write(engine)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Engine serialized. Total time: {t}')
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
logger.info(f"Engine serialized. Total time: {t}")


def deserialize_engine(path):
runtime = trt.Runtime(logger.trt_logger)
with open(path, 'rb') as f:
logger.info(f'Loading engine from {path}...')
with open(path, "rb") as f:
logger.info(f"Loading engine from {path}...")
tik = time.time()

engine = runtime.deserialize_cuda_engine(f.read())
assert engine is not None

tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Engine loaded. Total time: {t}')
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
logger.info(f"Engine loaded. Total time: {t}")
return engine


@ -149,33 +155,32 @@ _field_dtype_to_np_dtype_dict = {

def field_dtype_to_np_dtype(dtype):
ret = _field_dtype_to_np_dtype_dict.get(dtype)
assert ret is not None, f'Unsupported dtype: {dtype}'
assert ret is not None, f"Unsupported dtype: {dtype}"
return ret


def convert_capsule_to_void_p(capsule):
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [
ctypes.py_object, ctypes.c_char_p
]
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
return ctypes.pythonapi.PyCapsule_GetPointer(capsule, None)


def get_nparray_from_void_p(void_pointer, elem_size, field_dtype):
ctypes.pythonapi.PyMemoryView_FromMemory.restype = ctypes.py_object
ctypes.pythonapi.PyMemoryView_FromMemory.argtypes = [
ctypes.c_char_p, ctypes.c_ssize_t, ctypes.c_int
ctypes.c_char_p,
ctypes.c_ssize_t,
ctypes.c_int,
]
logger.info(
f'get_nparray: pointer = {void_pointer}, elem_size = {elem_size}')
logger.info(f"get_nparray: pointer = {void_pointer}, elem_size = {elem_size}")
char_pointer = ctypes.cast(void_pointer, ctypes.POINTER(ctypes.c_char))
np_dtype = field_dtype_to_np_dtype(field_dtype)
buf_bytes = elem_size * np.dtype(np_dtype).itemsize
logger.info(f'get_nparray: buf_bytes = {buf_bytes}')
logger.info(f"get_nparray: buf_bytes = {buf_bytes}")
mem_view = ctypes.pythonapi.PyMemoryView_FromMemory(
char_pointer, buf_bytes, 0) # number 0 represents PyBUF_READ
logger.info(
f'get_nparray: mem_view = {mem_view}, field_dtype = {field_dtype}')
char_pointer, buf_bytes, 0
) # number 0 represents PyBUF_READ
logger.info(f"get_nparray: mem_view = {mem_view}, field_dtype = {field_dtype}")
buf = np.frombuffer(mem_view, np_dtype)
return buf

@ -187,18 +192,17 @@ def get_scalar_from_field(field):


class _BuildingFlag:

def __enter__(self):
os.environ['IS_BUILDING'] = '1'
os.environ["IS_BUILDING"] = "1"

def __exit__(self, type, value, tb):
del os.environ['IS_BUILDING']
del os.environ["IS_BUILDING"]


def _is_building(f):
'''Use this to decorate functions which are called during engine building/refitting process,
"""Use this to decorate functions which are called during engine building/refitting process,
otherwise, the plugin registration will fail.
'''
"""

@wraps(f)
def decorated(*args, **kwargs):
@ -208,15 +212,25 @@ def _is_building(f):
return decorated


def check_max_num_tokens(max_num_tokens, opt_num_tokens, max_batch_size,
max_input_len, max_seq_len, max_beam_width,
remove_input_padding, enable_context_fmha,
tokens_per_block, multiple_profiles):
def check_max_num_tokens(
max_num_tokens,
opt_num_tokens,
max_batch_size,
max_input_len,
max_seq_len,
max_beam_width,
remove_input_padding,
enable_context_fmha,
tokens_per_block,
multiple_profiles,
):
if not remove_input_padding:
if max_num_tokens is not None or opt_num_tokens is not None:
max_num_tokens = max_batch_size * max_seq_len
logger.warning("remove_input_padding is not enabled, the specified "
"max_num_tokens/opt_num_tokens will be ignored.")
logger.warning(
"remove_input_padding is not enabled, the specified "
"max_num_tokens/opt_num_tokens will be ignored."
)
return max_num_tokens, opt_num_tokens
else:
if max_num_tokens is None:
@ -228,20 +242,22 @@ def check_max_num_tokens(max_num_tokens, opt_num_tokens, max_batch_size,
"when remove_input_padding is enabled, because the number "
"of packed input tokens are very likely to be smaller, "
"we strongly recommend to set max_num_tokens according "
"to your workloads.")
"to your workloads."
)
if opt_num_tokens is None and not multiple_profiles:
opt_num_tokens = min(max_batch_size * max_beam_width,
max_num_tokens)
opt_num_tokens = min(max_batch_size * max_beam_width, max_num_tokens)
logger.warning(
"remove_input_padding is enabled, while opt_num_tokens "
"is not set, setting to max_batch_size*max_beam_width. \n")
"is not set, setting to max_batch_size*max_beam_width. \n"
)
if max_num_tokens > 16384:
logger.warning(
"Specifying a `max_num_tokens` larger than 16384 is usually "
"not recommended, we do not expect perf gain with that and too "
"large `max_num_tokens` could possibly exceed the TensorRT "
"tensor volume, causing runtime errors. "
f"Got `max_num_tokens` = {max_num_tokens}")
f"Got `max_num_tokens` = {max_num_tokens}"
)
if max_num_tokens > max_seq_len * max_batch_size:
logger.warning(
f"max_num_tokens ({max_num_tokens}) shouldn't be greater than "
@ -253,21 +269,24 @@ def check_max_num_tokens(max_num_tokens, opt_num_tokens, max_batch_size,
logger.warning(
f"When enable_context_fmha is not turned on, max_num_tokens ({max_num_tokens}) "
f"should be at least max_input_len ({max_input_len}), specifying to "
f"max_input_len ({max_input_len}).")
f"max_input_len ({max_input_len})."
)
max_num_tokens = max_input_len
elif max_num_tokens < tokens_per_block and enable_context_fmha:
logger.warning(
f"When enable_context_fmha is turned on, max_num_tokens ({max_num_tokens}) "
f"should be at least tokens_per_block ({tokens_per_block}), specifying to "
f"tokens_per_block ({tokens_per_block}). At this time, you also need to enable "
f"context chunking at runtime, otherwise you may encounter errors.")
f"context chunking at runtime, otherwise you may encounter errors."
)
max_num_tokens = tokens_per_block

if opt_num_tokens is not None and opt_num_tokens > max_num_tokens:
logger.warning(
f"opt_num_tokens ({opt_num_tokens}) shouldn't be greater than "
f"max_num_tokens ({max_num_tokens}), "
f"specifying to max_num_tokens ({max_num_tokens}).")
f"specifying to max_num_tokens ({max_num_tokens})."
)
opt_num_tokens = max_num_tokens

return max_num_tokens, opt_num_tokens

@ -13,8 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
from ctypes import (CFUNCTYPE, POINTER, c_int, c_int64, c_size_t, c_uint8,
c_uint16, c_void_p, pointer)
from ctypes import (
CFUNCTYPE,
POINTER,
c_int,
c_int64,
c_size_t,
c_uint8,
c_uint16,
c_void_p,
pointer,
)

import torch

@ -24,14 +33,14 @@ class DLDataType(ctypes.Structure):
_fields_ = [
("code", c_uint8), # Data type code, e.g., 2 for float
("bits", c_uint8), # Number of bits per element, e.g., 32
("lanes", c_uint16) # Number of lanes, usually 1
("lanes", c_uint16), # Number of lanes, usually 1
]


class DLDevice(ctypes.Structure):
_fields_ = [
("device_type", c_int), # Device type, typically 2 for GPU
("device_id", c_int) # Device ID, usually 0 for default GPU
("device_id", c_int), # Device ID, usually 0 for default GPU
]


@ -43,15 +52,15 @@ class DLTensor(ctypes.Structure):
("dtype", DLDataType), # Data type
("shape", POINTER(c_int64)), # Pointer to array of dimension sizes
(
"strides", POINTER(c_int64)
"strides",
POINTER(c_int64),
), # Pointer to strides array (can be NULL for default contiguous layout)
("byte_offset", c_size_t) # Byte offset (usually 0)
("byte_offset", c_size_t), # Byte offset (usually 0)
]


# Deleter type for DLManagedTensor
DLManagedTensorDeleter = CFUNCTYPE(None, POINTER(
ctypes.c_void_p)) # Not used directly here
DLManagedTensorDeleter = CFUNCTYPE(None, POINTER(ctypes.c_void_p)) # Not used directly here


# Define DLManagedTensor structure, with deleter prototype void(*deleter)(DLManagedTensor*)
@ -59,9 +68,11 @@ class DLManagedTensor(ctypes.Structure):
pass


DLManagedTensor._fields_ = [("dl_tensor", DLTensor), ("manager_ctx", c_void_p),
("deleter", CFUNCTYPE(None,
POINTER(DLManagedTensor)))]
DLManagedTensor._fields_ = [
("dl_tensor", DLTensor),
("manager_ctx", c_void_p),
("deleter", CFUNCTYPE(None, POINTER(DLManagedTensor))),
]


# A no-op deleter that doesn't perform any operation
@ -95,8 +106,7 @@ class CapsuleWrapper:
self._managed_tensor = managed_tensor # Keep reference to prevent garbage collection


def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments,
torch_dtype, dev_id):
def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments, torch_dtype, dev_id):
"""
Parameters:
ptr: GPU memory address obtained from cudaMalloc (Python int)
@ -106,14 +116,19 @@ def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments,
torch_dtype: torch dtype
dev_id: device id.
Returns:
A PyCapsule object compliant with DLPack specification, which can be directly converted to a tensor using torch.utils.dlpack.from_dlpack
A PyCapsule object compliant with DLPack specification, which can be directly converted to a
tensor using torch.utils.dlpack.from_dlpack
"""
bits_per_elements = 0
dldata_type_code = 0
# refer to https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h#L160
if torch_dtype in [
torch.float8_e5m2, torch.float8_e4m3fn, torch.bfloat16,
torch.float16, torch.float32, torch.float64
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.bfloat16,
torch.float16,
torch.float32,
torch.float64,
]:
bits_per_elements = torch.finfo(torch_dtype).bits
dldata_type_code = 2
@ -128,8 +143,7 @@ def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments,
bytes_per_element = bits_per_elements // 8
# Allocate space for shape (constructing a one-dimensional tensor here)
ShapeArrayType = c_int64 * 2 # 1 dimension
shape_array = ShapeArrayType(num_segments,
segment_size // bytes_per_element)
shape_array = ShapeArrayType(num_segments, segment_size // bytes_per_element)
stride_array = ShapeArrayType(segment_stride // bytes_per_element, 1)
# Set device information: GPU (device_type=2) and device_id=dev_id (modify as needed)
device = DLDevice(device_type=2, device_id=dev_id)
@ -166,8 +180,9 @@ def create_dlpack_capsule(ptr, segment_size, segment_stride, num_segments,
return capsule_wrapper


def pack_strided_memory(ptr: int, segment_size: int, segment_stride: int,
num_segments: int, dtype: torch.dtype, dev_id):
def pack_strided_memory(
ptr: int, segment_size: int, segment_stride: int, num_segments: int, dtype: torch.dtype, dev_id
):
"""
Pack GPU memory into a PyTorch tensor with specified stride.

@ -187,8 +202,9 @@ def pack_strided_memory(ptr: int, segment_size: int, segment_stride: int,
even with the same pointer. Each capsule is consumed only once.
"""
# Create a new capsule each time
capsule_wrapper = create_dlpack_capsule(ptr, segment_size, segment_stride,
num_segments, dtype, dev_id)
capsule_wrapper = create_dlpack_capsule(
ptr, segment_size, segment_stride, num_segments, dtype, dev_id
)
torch_tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule)
torch_tensor._capsule_wrapper = capsule_wrapper
return torch_tensor

@ -42,9 +42,7 @@ def can_access_peer(mapping: Mapping) -> bool:

# Early exit if devices are on different nodes
if mapping.get_node_rank(rank) != mapping.node_rank:
logger.info(
f"Detect inter-node TP between rank {mapping.rank} and rank {rank}"
)
logger.info(f"Detect inter-node TP between rank {mapping.rank} and rank {rank}")
return False

# Skip if same device
@ -63,8 +61,7 @@ def can_access_peer(mapping: Mapping) -> bool:
return True


class IpcMemory():

class IpcMemory:
# WARNING: Must in sync with FLAGS_SIZE in cpp/include/tensorrt_llm/runtime/ipcUtils.h
# (Max all reduce blocks + 1) * sizeof(int)
IPC_BARRIERS_SIZE_PER_GPU = (24 + 1) * 4
@ -73,8 +70,7 @@ class IpcMemory():
self.mapping = mapping
self.open_ipc = open_ipc and mapping.tp_size <= mapping.gpus_per_node
if self.open_ipc:
self.peer_ptrs, self.local_ptr = IpcMemory.open_ipc_memory(
self.mapping, size, True)
self.peer_ptrs, self.local_ptr = IpcMemory.open_ipc_memory(self.mapping, size, True)
else:
self.peer_ptrs = [0] * mapping.tp_size
self.local_ptr = 0
@ -91,10 +87,10 @@ class IpcMemory():
return array.array("Q", buffer).tolist()

@staticmethod
def open_ipc_memory(mapping: Mapping,
size: int,
set_to_zero: bool = False) -> Tuple[List[int], int]:
""" Allocates a buffer with the given *size* on each GPU. Then, enables IPC communication between TP groups.
def open_ipc_memory(
mapping: Mapping, size: int, set_to_zero: bool = False
) -> Tuple[List[int], int]:
"""Allocates a buffer with the given *size* on each GPU. Then, enables IPC communication between TP groups.
Returns a list of buffer pointers, buffers[i] is a handle to the corresponding buffer residing on GPU #i.
Call close_ipc_handle with the *buffer*.
"""
@ -105,8 +101,8 @@ class IpcMemory():
return size

comm = mpi_comm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
mapping.tp_rank)
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)

# see allocateIpcMemory in cpp/tensorrt_llm/runtime/ipcUtils.cpp for alignment reason
# 1 << 21 is 2MB
@ -131,7 +127,8 @@ class IpcMemory():
peer_ptrs.append(local_ptr)
else:
error, ptr = cudart.cudaIpcOpenMemHandle(
handle, cudart.cudaIpcMemLazyEnablePeerAccess)
handle, cudart.cudaIpcMemLazyEnablePeerAccess
)
_raise_if_error(error)
peer_ptrs.append(ptr)


@ -67,8 +67,7 @@ class MnnvlMemory:
def __init__(self, mapping: Mapping, size: int):
self.mapping = mapping
self.segment_size = size
self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory(
self.mapping, size)
self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory(self.mapping, size)

def __del__(self):
if not sys.is_finalizing():
@ -76,15 +75,15 @@ class MnnvlMemory:

def as_torch_strided_tensor(self, dtype):
num_segments = MnnvlMemory.comm.Get_size()
return pack_strided_memory(self.ptr, self.segment_size,
self.rank_stride, num_segments, dtype,
MnnvlMemory.dev_id)
return pack_strided_memory(
self.ptr, self.segment_size, self.rank_stride, num_segments, dtype, MnnvlMemory.dev_id
)

@staticmethod
def initialize():
if not MnnvlMemory.initialized:
# use a dummy torch CUDA tensor to trigger CUDA context initialization
_ = torch.empty(1, device='cuda')
_ = torch.empty(1, device="cuda")
# ensure nvml is initialized.
try:
pynvml.nvmlDeviceGetCount()
@ -97,8 +96,8 @@ class MnnvlMemory:
if MnnvlMemory.comm is not None:
return MnnvlMemory.comm
comm = mpi_comm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank,
mapping.tp_rank)
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
MnnvlMemory.comm = comm
return comm

@ -109,7 +108,9 @@ class MnnvlMemory:
location.id = dev_id
allocation_prop = cuda.CUmemAllocationProp()
allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
allocation_prop.requestedHandleTypes = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
allocation_prop.requestedHandleTypes = (
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
)
allocation_prop.location = location
return allocation_prop

@ -119,27 +120,25 @@ class MnnvlMemory:
return MnnvlMemory.allocation_granularity
allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
option = cuda.CUmemAllocationGranularity_flags(
cuda.CUmemAllocationGranularity_flags.
CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)
cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED
)
granularity = _check_cu_result(
cuda.cuMemGetAllocationGranularity(prop=allocation_prop,
option=option))
cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)
)
MnnvlMemory.allocation_granularity = granularity
return MnnvlMemory.allocation_granularity

@staticmethod
def new_mnnvl_memory_address(mapping: Mapping, size: int):
page_count = (size + MnnvlMemory.fabric_page_size -
1) // MnnvlMemory.fabric_page_size
page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size
current_rank_stride = page_count * MnnvlMemory.fabric_page_size
logger.info(
f"[MnnvlMemory] creating address with stride={current_rank_stride}")
logger.info(f"[MnnvlMemory] creating address with stride={current_rank_stride}")
comm = MnnvlMemory.get_comm(mapping)
comm_size = comm.Get_size()
address_size = current_rank_stride * comm_size
ptr = _check_cu_result(
cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size,
0, 0))
cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)
)
MnnvlMemory.current_start_address = int(ptr)
MnnvlMemory.current_rank_stride = current_rank_stride
MnnvlMemory.current_mem_offset = 0
@ -150,16 +149,15 @@ class MnnvlMemory:
dev_id = int(dev)
if MnnvlMemory.dev_id is None:
MnnvlMemory.dev_id = dev_id
assert dev_id == MnnvlMemory.dev_id,\
assert dev_id == MnnvlMemory.dev_id, (
f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}"
)
comm = MnnvlMemory.get_comm(mapping)
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()
all_rank_allocate_sizes = comm.allgather(size)
assert len(all_rank_allocate_sizes) == comm_size
assert all(
x == size for x in
all_rank_allocate_sizes), "Not all rank allocating same size."
assert all(x == size for x in all_rank_allocate_sizes), "Not all rank allocating same size."
granularity = MnnvlMemory.get_allocation_granularity(dev_id)
aligned_size = (size + granularity - 1) // granularity * granularity

@ -170,13 +168,15 @@ class MnnvlMemory:

allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
allocated_mem_handle = _check_cu_result(
cuda.cuMemCreate(aligned_size, allocation_prop, flags=0))
cuda.cuMemCreate(aligned_size, allocation_prop, flags=0)
)
exported_fabric_handle = _check_cu_result(
cuda.cuMemExportToShareableHandle(
allocated_mem_handle,
cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, 0))
allocated_mem_handle, cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, 0
)
)
all_handles_data = comm.allgather(exported_fabric_handle.data)
# all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
# all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # noqa: E501
# can use buf = memoryview(data) to import if using plain buffer for data.

madesc = cuda.CUmemAccessDesc()
@ -186,44 +186,49 @@ class MnnvlMemory:
mem_handles = [None] * comm_size

for i, remote_handle_data in enumerate(all_handles_data):
rank_ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_rank_stride * i + MnnvlMemory.current_mem_offset
rank_ptr = (
MnnvlMemory.current_start_address
+ MnnvlMemory.current_rank_stride * i
+ MnnvlMemory.current_mem_offset
)
if i == comm_rank:
# Local memory mapping
mem_handles[i] = allocated_mem_handle
_check_cu_result(
cuda.cuMemMap(rank_ptr, aligned_size, 0,
allocated_mem_handle, 0))
_check_cu_result(cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0))
else:
# Fabric memory mapping
imported_mem_handle = _check_cu_result(
cuda.cuMemImportFromShareableHandle(
remote_handle_data, cuda.CUmemAllocationHandleType.
CU_MEM_HANDLE_TYPE_FABRIC))
remote_handle_data, cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
)
)
mem_handles[i] = imported_mem_handle
_check_cu_result(
cuda.cuMemMap(rank_ptr, aligned_size, 0,
imported_mem_handle, 0))
_check_cu_result(cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0))

_check_cu_result(
cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1))
_check_cu_result(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1))

ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_mem_offset
stride = MnnvlMemory.current_rank_stride
MnnvlMemory.allocated_map[ptr] = (mapping, aligned_size, mem_handles,
MnnvlMemory.current_start_address,
MnnvlMemory.current_rank_stride,
MnnvlMemory.current_mem_offset)
MnnvlMemory.address_refcnt[
MnnvlMemory.current_start_address] = MnnvlMemory.address_refcnt.get(
MnnvlMemory.current_start_address, 0) + 1
MnnvlMemory.allocated_map[ptr] = (
mapping,
aligned_size,
mem_handles,
MnnvlMemory.current_start_address,
MnnvlMemory.current_rank_stride,
MnnvlMemory.current_mem_offset,
)
MnnvlMemory.address_refcnt[MnnvlMemory.current_start_address] = (
MnnvlMemory.address_refcnt.get(MnnvlMemory.current_start_address, 0) + 1
)

MnnvlMemory.current_mem_offset += aligned_size
return ptr, stride

@staticmethod
def close_mnnvl_memory(ptr: int):
mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset = MnnvlMemory.allocated_map.pop(
ptr)
mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset = (
MnnvlMemory.allocated_map.pop(ptr)
)
comm = MnnvlMemory.get_comm(mapping)
comm_size = comm.Get_size()
for i in range(comm_size):
@ -235,8 +240,7 @@ class MnnvlMemory:
if MnnvlMemory.address_refcnt[start_address] == 0:
MnnvlMemory.address_refcnt.pop(start_address)
device_ptr = cuda.CUdeviceptr(start_address)
_check_cu_result(
cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride))
_check_cu_result(cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride))
if start_address == MnnvlMemory.current_start_address:
MnnvlMemory.current_start_address = 0
MnnvlMemory.current_rank_stride = 0
@ -252,15 +256,19 @@ class MnnvlMemory:
for link_idx in range(link_count):
try:
if pynvml.nvmlDeviceGetNvLinkCapability(
handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED):
handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED
):
available_links += 1
is_active = pynvml.nvmlDeviceGetNvLinkState(
handle, link_idx)
is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx)
if is_active:
active_links += 1
except pynvml.NVMLError_NotSupported:
continue
return active_links == available_links and available_links > 0 if need_all_up else available_links > 0
return (
active_links == available_links and available_links > 0
if need_all_up
else available_links > 0
)

@staticmethod
def supports_mnnvl() -> bool:
@ -269,7 +277,7 @@ class MnnvlMemory:
# But it is not equivalent to MNNVL support.
# May need better support check.
arch = platform.machine().lower()
is_on_aarch64 = 'aarch64' in arch
is_on_aarch64 = "aarch64" in arch
support_nvlink_and_all_up = MnnvlMemory.support_nvlink(True)
return is_on_aarch64 and support_nvlink_and_all_up

@ -298,91 +306,137 @@ class MnnvlMoe:

MnnvlMoe.moe_mapping = mapping
workspace_size_per_rank = torch.ops.trtllm.get_moe_commworkspace_size_per_rank(
mapping.tp_size)
mapping.tp_size
)
MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
torch.uint64)
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(torch.uint64)
return MnnvlMoe.moe_workspace_tensor

@staticmethod
def compute_target_rank_id(token_selected_experts: torch.Tensor,
expert_count: int, ep_size: int):
def compute_target_rank_id(
token_selected_experts: torch.Tensor, expert_count: int, ep_size: int
):
assert expert_count % ep_size == 0, "expert_count should be divisible by ep_size"
expert_per_rank = expert_count // ep_size
token_target_rank_ids = token_selected_experts // expert_per_rank
return token_target_rank_ids

@staticmethod
def mnnvl_moe_alltoallv_prepare(gathered_target_rank_ids: torch.Tensor,
real_rank_token_count_cumsum: torch.Tensor,
gathered_expert_ids: torch.Tensor,
gathered_scales: torch.Tensor,
max_token_count_per_rank: int,
expert_count: int, top_k: int, ep_rank: int,
ep_size: int):
local_gather_indices, send_rank_count_cumsum, send_rank_local_indices, \
recv_rank_count_cumsum, recv_rank_local_indices, backward_recv_rank_local_indices = \
torch.ops.trtllm.moe_comm_prepare_indices(gathered_target_rank_ids, real_rank_token_count_cumsum,
max_token_count_per_rank, expert_count, top_k, ep_rank, ep_size)
def mnnvl_moe_alltoallv_prepare(
gathered_target_rank_ids: torch.Tensor,
real_rank_token_count_cumsum: torch.Tensor,
gathered_expert_ids: torch.Tensor,
gathered_scales: torch.Tensor,
max_token_count_per_rank: int,
expert_count: int,
top_k: int,
ep_rank: int,
ep_size: int,
):
(
local_gather_indices,
send_rank_count_cumsum,
send_rank_local_indices,
recv_rank_count_cumsum,
recv_rank_local_indices,
backward_recv_rank_local_indices,
) = torch.ops.trtllm.moe_comm_prepare_indices(
gathered_target_rank_ids,
real_rank_token_count_cumsum,
max_token_count_per_rank,
expert_count,
top_k,
ep_rank,
ep_size,
)

local_token_allocation_count = max_token_count_per_rank * ep_size

local_expert_ids = torch.empty(local_token_allocation_count,
top_k,
dtype=torch.int32,
device=torch.device('cuda'))
local_scales = torch.empty(local_token_allocation_count,
top_k,
dtype=torch.float32,
device=torch.device('cuda'))
local_expert_ids = torch.empty(
local_token_allocation_count, top_k, dtype=torch.int32, device=torch.device("cuda")
)
local_scales = torch.empty(
local_token_allocation_count, top_k, dtype=torch.float32, device=torch.device("cuda")
)

torch.ops.trtllm.moe_local_gather(recv_rank_count_cumsum,
local_gather_indices,
gathered_expert_ids, gathered_scales,
local_expert_ids, local_scales,
max_token_count_per_rank,
expert_count, top_k, ep_rank, ep_size)
torch.ops.trtllm.moe_local_gather(
recv_rank_count_cumsum,
local_gather_indices,
gathered_expert_ids,
gathered_scales,
local_expert_ids,
local_scales,
max_token_count_per_rank,
expert_count,
top_k,
ep_rank,
ep_size,
)

alltoall_info = MoEAlltoallInfo(
local_gather_indices, send_rank_count_cumsum,
send_rank_local_indices, recv_rank_count_cumsum,
recv_rank_local_indices, backward_recv_rank_local_indices,
local_token_allocation_count)
local_gather_indices,
send_rank_count_cumsum,
send_rank_local_indices,
recv_rank_count_cumsum,
recv_rank_local_indices,
backward_recv_rank_local_indices,
local_token_allocation_count,
)
return alltoall_info, local_expert_ids, local_scales

@staticmethod
def mnnvl_moe_alltoallv(x: torch.Tensor, alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor, ep_rank: int,
ep_size: int):
def mnnvl_moe_alltoallv(
x: torch.Tensor,
alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor,
ep_rank: int,
ep_size: int,
):
assert x.dim() == 2, "only 2D tensor supported, please reshape."
output_tensor = torch.empty(alltoall_info.local_token_allocation_count,
x.shape[1],
dtype=x.dtype,
device=torch.device('cuda'))
torch.ops.trtllm.moe_comm(x, alltoall_info.send_rank_count_cumsum,
alltoall_info.send_rank_local_indices,
output_tensor,
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
workspace, ep_rank, ep_size)
output_tensor = torch.empty(
alltoall_info.local_token_allocation_count,
x.shape[1],
dtype=x.dtype,
device=torch.device("cuda"),
)
torch.ops.trtllm.moe_comm(
x,
alltoall_info.send_rank_count_cumsum,
alltoall_info.send_rank_local_indices,
output_tensor,
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
workspace,
ep_rank,
ep_size,
)
return output_tensor

@staticmethod
def mnnvl_moe_alltoallv_combine(x: torch.Tensor,
alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor, ep_rank: int,
ep_size: int, top_k: int, token_count: int):
def mnnvl_moe_alltoallv_combine(
x: torch.Tensor,
alltoall_info: MoEAlltoallInfo,
workspace: torch.Tensor,
ep_rank: int,
ep_size: int,
top_k: int,
token_count: int,
):
assert x.dim() == 2, "2D tensor supported, please reshape."
output_tensor = torch.zeros(token_count * top_k,
x.shape[1],
dtype=x.dtype,
device=torch.device('cuda'))
output_tensor = torch.zeros(
token_count * top_k, x.shape[1], dtype=x.dtype, device=torch.device("cuda")
)
torch.ops.trtllm.moe_comm(
x, alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices, output_tensor,
x,
alltoall_info.recv_rank_count_cumsum,
alltoall_info.recv_rank_local_indices,
output_tensor,
alltoall_info.send_rank_count_cumsum,
alltoall_info.backward_recv_rank_local_indices, workspace, ep_rank,
ep_size)
return torch.sum(output_tensor.reshape(token_count, top_k, x.shape[1]),
dim=1,
keepdim=False)
alltoall_info.backward_recv_rank_local_indices,
workspace,
ep_rank,
ep_size,
)
return torch.sum(
output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False
)

@ -6,8 +6,7 @@ from tensorrt_llm.bindings import executor as tllme

@dataclass(slots=True, kw_only=True)
class DisaggregatedParams:
"""
Disaggregated seving parameters
"""Disaggregated seving parameters.

Args:
request_type (str): The type of request ("context_only" or "generation_only")
@ -23,10 +22,9 @@ class DisaggregatedParams:
draft_tokens: Optional[List[int]] = None

def get_context_phase_params(self) -> tllme.ContextPhaseParams:

return tllme.ContextPhaseParams(self.first_gen_tokens,
self.ctx_request_id, self.opaque_state,
self.draft_tokens)
return tllme.ContextPhaseParams(
self.first_gen_tokens, self.ctx_request_id, self.opaque_state, self.draft_tokens
)

def get_request_type(self) -> tllme.RequestType:
if self.request_type == "context_only":
@ -37,5 +35,6 @@ class DisaggregatedParams:
return tllme.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
else:
raise ValueError(
f"Unknown request type: {self.request_type}. Must be context_only, generation_only or context_and_generation"
f"Unknown request type: {self.request_type}. Must be context_only, generation_only or "
"context_and_generation"
)

@ -3,8 +3,7 @@ import weakref
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple,
|
||||
TypeVar)
|
||||
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, TypeVar
|
||||
|
||||
import tensorrt as trt
|
||||
|
||||
@ -14,9 +13,7 @@ from .network import Network
|
||||
|
||||
|
||||
class Layer:
|
||||
'''
|
||||
Layer is a wrapper for TensorRT's ILayer with several python-friendly helper functions.
|
||||
'''
|
||||
"""Layer is a wrapper for TensorRT's ILayer with several python-friendly helper functions."""
|
||||
|
||||
def __init__(self, network: Network, trt_layer: trt.ILayer):
|
||||
self._network = weakref.ref(network)
|
||||
@ -30,51 +27,50 @@ class Layer:
|
||||
return self._network()
|
||||
|
||||
def get_inputs(self, *indices: int):
|
||||
'''
|
||||
Get the input tensors of the layer.
|
||||
"""Get the input tensors of the layer.
|
||||
|
||||
Parameters:
|
||||
idx: the indices of the input tensor, will return all inputs if left empty
|
||||
|
||||
Returns:
|
||||
List[Tensor]
|
||||
'''
|
||||
"""
|
||||
from .functional import Tensor
|
||||
|
||||
indices = indices if indices else range(self.trt_layer.num_inputs)
|
||||
|
||||
ret = []
|
||||
for i in indices:
|
||||
assert i < self.trt_layer.num_inputs, f"Invalid input index {i} for layer {self.trt_layer.name}"
|
||||
assert i < self.trt_layer.num_inputs, (
|
||||
f"Invalid input index {i} for layer {self.trt_layer.name}"
|
||||
)
|
||||
|
||||
tensor = self.trt_layer.get_input(i)
|
||||
tensor = Tensor(trt_tensor=tensor,
|
||||
network=self.network,
|
||||
is_network_input=False)
|
||||
tensor = Tensor(trt_tensor=tensor, network=self.network, is_network_input=False)
|
||||
ret.append(tensor)
|
||||
return ret
|
||||
|
||||
def get_outputs(self, *indices: int):
|
||||
'''
|
||||
Get the output tensor of the layer.
|
||||
"""Get the output tensor of the layer.
|
||||
|
||||
Parameters:
|
||||
idx: the index of the output tensor
|
||||
|
||||
Returns:
|
||||
List[Tensor]
|
||||
'''
|
||||
"""
|
||||
from .functional import Tensor
|
||||
|
||||
indices = indices if indices else range(self.trt_layer.num_outputs)
|
||||
|
||||
ret = []
|
||||
for i in indices:
|
||||
assert i < self.trt_layer.num_outputs, f"Invalid output index {i} for layer {self.trt_layer.name}"
|
||||
assert i < self.trt_layer.num_outputs, (
|
||||
f"Invalid output index {i} for layer {self.trt_layer.name}"
|
||||
)
|
||||
|
||||
tensor = self.trt_layer.get_output(i)
|
||||
tensor = Tensor(trt_tensor=tensor,
|
||||
network=self.network,
|
||||
is_network_input=False)
|
||||
tensor = Tensor(trt_tensor=tensor, network=self.network, is_network_input=False)
|
||||
ret.append(tensor)
|
||||
return ret
|
||||
|
||||
@ -82,10 +78,9 @@ class Layer:
|
||||
return self.network.is_removed_layer(self)
|
||||
|
||||
def mark_as_removed(self):
|
||||
'''
|
||||
Mark the layer as removed, this will remove the layer from the network.
|
||||
'''
|
||||
# NOTE, since INetwork python API doesn't provide a way to remove a layer, we actually mark the layer as removed in the network.
|
||||
"""Mark the layer as removed, this will remove the layer from the network."""
|
||||
# NOTE, since INetwork python API doesn't provide a way to remove a layer, we actually mark
|
||||
# the layer as removed in the network.
|
||||
self.network.mark_removed_layer(self)
|
||||
|
||||
# remove the FLayerInfo if exists
|
||||
@ -101,8 +96,8 @@ class Layer:
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.trt_layer, name)
|
||||
|
||||
# Refer to https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html?highlight=elementwise#layers for a complete
|
||||
# list of TRT layers.
|
||||
# Refer to https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html?highlight=elementwise#layers
|
||||
# for a complete list of TRT layers.
|
||||
TRT_LAYER_TYPE_TO_LAYER = {
|
||||
trt.LayerType.CONVOLUTION: trt.IConvolutionLayer,
|
||||
trt.LayerType.ACTIVATION: trt.IActivationLayer,
|
||||
@ -153,14 +148,13 @@ class Layer:
|
||||
}
|
||||
|
||||
if trt_gte(10, 8):
|
||||
TRT_LAYER_TYPE_TO_LAYER[
|
||||
trt.LayerType.DYNAMIC_QUANTIZE] = trt.IQuantizeLayer
|
||||
TRT_LAYER_TYPE_TO_LAYER[trt.LayerType.DYNAMIC_QUANTIZE] = trt.IQuantizeLayer
|
||||
|
||||
def as_layer(self) -> Any:
|
||||
'''
|
||||
Convert to a actual TensorRT layer object, such as IPluginV2Layer or IConvolutionLayer so
|
||||
that we can access the actual layer information.
|
||||
'''
|
||||
"""Convert to a actual TensorRT layer object.
|
||||
|
||||
This can be IPluginV2Layer or IConvolutionLayer, so that we can access the actual layer information.
|
||||
"""
|
||||
if self.type in self.TRT_LAYER_TYPE_TO_LAYER:
|
||||
# bypass TRT's bug of retrieving a specific ILayer type in TensorRT
|
||||
self.trt_layer.__class__ = self.TRT_LAYER_TYPE_TO_LAYER[self.type]
|
||||
@ -188,22 +182,27 @@ class _Pattern:
|
||||
|
||||
|
||||
class PatternRewriter(_Pattern):
|
||||
'''
|
||||
A pattern rewriter is a class that can match a pattern in the graph and rewrite the matched pattern.
|
||||
"""A pattern rewriter is a class that can match a pattern in the graph and rewrite the matched pattern.
|
||||
|
||||
There are two ways to implement a pattern rewriter, either override match() and rewrite() separately, or override match_and_rewrite().
|
||||
'''
|
||||
There are two ways to implement a pattern rewriter, either override match() and rewrite() separately, or
|
||||
override match_and_rewrite().
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
root_layer: Optional[Set[trt.LayerType]] = None,
|
||||
separate_match_rewrite=False,
|
||||
):
|
||||
"""Constructor.
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
root_layer: Optional[Set[trt.LayerType]] = None,
|
||||
separate_match_rewrite=False):
|
||||
'''
|
||||
Parameters:
|
||||
name: the name of the rewrite pattern
|
||||
root_layer: the root layer types to start the pattern matching, if not provided, the pattern will traverse all the layers in the graph.
|
||||
separate_match_rewrite: if set to True, the pattern should override match() and rewrite() separately, otherwise, the pattern should override match_and_rewrite()
|
||||
'''
|
||||
root_layer: the root layer types to start the pattern matching, if not provided, the pattern
|
||||
will traverse all the layers in the graph.
|
||||
separate_match_rewrite: if set to True, the pattern should override match() and rewrite()
|
||||
separately, otherwise, the pattern should override match_and_rewrite()
|
||||
"""
|
||||
super().__init__(name)
|
||||
self.root_layer = root_layer
|
||||
self._separate_match_rewrite = separate_match_rewrite
|
||||
@ -219,9 +218,7 @@ class PatternRewriter(_Pattern):
|
||||
|
||||
|
||||
class PatternAnalyzer(_Pattern):
|
||||
|
||||
def __init__(self, name: str,
|
||||
root_layer: Optional[Set[trt.LayerType]]) -> None:
|
||||
def __init__(self, name: str, root_layer: Optional[Set[trt.LayerType]]) -> None:
|
||||
super().__init__(name)
|
||||
self.root_layer = root_layer
|
||||
|
||||
@ -233,16 +230,13 @@ class PatternAnalyzer(_Pattern):
|
||||
|
||||
|
||||
class _PatternManager:
|
||||
PatternType = TypeVar('PatternType')
|
||||
PatternType = TypeVar("PatternType")
|
||||
|
||||
def __init__(self):
|
||||
# records of (benefit, pattern, id)
|
||||
self.patterns: Dict[str, Tuple[int, _PatternManager.PatternType]] = {}
|
||||
|
||||
def add(self,
|
||||
label: str,
|
||||
pattern: "_PatternManager.PatternType",
|
||||
benefit: int = 0):
|
||||
def add(self, label: str, pattern: "_PatternManager.PatternType", benefit: int = 0):
|
||||
assert label not in self.patterns, f"Pattern {label} already exists"
|
||||
self.patterns[label] = (benefit, pattern)
|
||||
|
||||
@ -251,18 +245,18 @@ class _PatternManager:
|
||||
|
||||
|
||||
class RewritePatternManager(_PatternManager):
|
||||
|
||||
def rewrite(self, net: Network, args=None):
|
||||
modified = True
|
||||
# TODO: we can optimize this by asking TRT to expose a graph iterator consistent even after the graph is modified
|
||||
# TODO: we can optimize this by asking TRT to expose a graph iterator consistent even after
|
||||
# the graph is modified.
|
||||
while modified:
|
||||
modified = False
|
||||
# Since the graph iterator is hold by the underlying INetwork, we can only rebuild the graph cache and match the nodes again.
|
||||
# Since the graph iterator is hold by the underlying INetwork, we can only rebuild the
|
||||
# graph cache and match the nodes again.
|
||||
for layer in net.get_layers():
|
||||
if layer.is_removed():
|
||||
continue
|
||||
for (profit, pattern) in sorted(self.patterns.values(),
|
||||
key=lambda x: x[0]):
|
||||
for profit, pattern in sorted(self.patterns.values(), key=lambda x: x[0]):
|
||||
pattern.args = args
|
||||
|
||||
if pattern.root_layer is not None and layer.type not in pattern.root_layer:
|
||||
@ -281,13 +275,11 @@ class RewritePatternManager(_PatternManager):
|
||||
|
||||
|
||||
class AnalysisPatternManager(_PatternManager):
|
||||
|
||||
def analyze(self, graph: Network, args=None):
|
||||
for layer in graph.get_layers():
|
||||
if layer.name in graph.removed_layers:
|
||||
continue
|
||||
for (benefit, pattern) in sorted(self.patterns.values(),
|
||||
key=lambda x: x[0]):
|
||||
for benefit, pattern in sorted(self.patterns.values(), key=lambda x: x[0]):
|
||||
pattern.args = args
|
||||
|
||||
if pattern.root_layer is not None and layer.type not in pattern.root_layer:
|
||||
@ -303,22 +295,25 @@ class AnalysisPatternManager(_PatternManager):
|
||||
|
||||
@dataclass
|
||||
class FLayerInfo:
|
||||
'''
|
||||
The FLayerInfo is used to track the functional layers in the INetwork, and it is used to help the graph rewriting.
|
||||
"""The FLayerInfo is used to track the functional layers in the INetwork and help graph rewriting.
|
||||
|
||||
The lifetime of a FLayer is the same as the corresponding plugin instance in the INetwork. Once the
|
||||
plugin instance is removed by the graph rewriting, the FLayer will be removed as well.
|
||||
|
||||
WHY this is needed?
|
||||
In the current implementation, for functional methods, once it is called in Python, it will lower to a plugin instance in the INetwork. However,
|
||||
the plugin interface is black box with customized logic, we cannot retrieve necessary information from it, this is quite different from ILayer,
|
||||
which provides a set of APIs to retrieve the information. Therefore, we need to record the high level information in the FLayerInfo, and keep
|
||||
In the current implementation, for functional methods, once it is called in Python, it will lower
|
||||
to a plugin instance in the INetwork.
|
||||
However, the plugin interface is black box with customized logic, we cannot retrieve necessary
|
||||
information from it. This is quite different from ILayer, which provides a set of APIs to retrieve
|
||||
the information.
|
||||
Therefore, we need to record the high level information in the FLayerInfo, and keep
|
||||
it consistent during the graph rewriting.
|
||||
'''
|
||||
"""
|
||||
|
||||
layer_kind: str # the method name in the functional.py
|
||||
# Record the raw inputs of the functional layer to be used in the graph rewrite
|
||||
# NOTE: the raw inputs contains both the constants and Tensors, the Tensors will be also updated by graph rewriting
|
||||
# APIs such as `replace_all_uses_with`
|
||||
# NOTE: the raw inputs contains both the constants and Tensors, the Tensors will be also updated by
|
||||
# graph rewriting APIs such as `replace_all_uses_with`
|
||||
raw_inputs: Dict[str, Any]
|
||||
|
||||
raw_outputs: List[Any] = field(default_factory=list, init=False)
|
||||
@ -331,6 +326,7 @@ class FLayerInfo:
|
||||
|
||||
def __post_init__(self):
|
||||
from .functional import Tensor
|
||||
|
||||
assert self.layer_kind
|
||||
|
||||
def replace_with_symbols(arg) -> Any:
|
||||
@ -357,10 +353,10 @@ class FLayerInfo:
|
||||
|
||||
return arg
|
||||
|
||||
self.signature = self.layer_kind, {
|
||||
name: replace_with_symbols(value)
|
||||
for name, value in self.raw_inputs.items()
|
||||
}
|
||||
self.signature = (
|
||||
self.layer_kind,
|
||||
{name: replace_with_symbols(value) for name, value in self.raw_inputs.items()},
|
||||
)
|
||||
|
||||
amend_tensor(self.raw_inputs)
|
||||
|
||||
@ -371,18 +367,15 @@ class FLayerInfo:
|
||||
return self.raw_inputs[name]
|
||||
|
||||
def clone_inputs(self):
|
||||
'''
|
||||
Get a shallow copy of the inputs.
|
||||
'''
|
||||
"""Get a shallow copy of the inputs."""
|
||||
return copy(self.raw_inputs)
|
||||
|
||||
def replace_input_with(self, src, dst):
|
||||
'''
|
||||
Replace the input `src` with the input `dst` in the raw_inputs.
|
||||
"""Replace the input `src` with the input `dst` in the raw_inputs.
|
||||
|
||||
src: Tensor
|
||||
dst: Tensor
|
||||
'''
|
||||
"""
|
||||
from .functional import Tensor
|
||||
|
||||
def replace(arg: Any):
|
||||
@ -399,21 +392,22 @@ class FLayerInfo:
|
||||
replace(self.raw_inputs)
|
||||
|
||||
def replace_outputs_uses_with(self, net: Network, new_outs: List[Any]):
|
||||
'''
|
||||
Replace the output users with the new outputs.
|
||||
"""Replace the output users with the new outputs.
|
||||
|
||||
new_outs: List[Tensor], the new outputs to replace with
|
||||
'''
|
||||
"""
|
||||
from .functional import Tensor
|
||||
|
||||
assert len(self.raw_outputs) == len(new_outs)
|
||||
for old_out, new_out in zip(self.raw_outputs, new_outs):
|
||||
assert type(old_out) == type(
|
||||
new_out
|
||||
), f"rewrite error, the output type {type(old_out)} is different from the new output type {type(new_out)} not match the original output type {type(old_out)}"
|
||||
assert type(old_out) is type(new_out), (
|
||||
f"rewrite error, the output type {type(old_out)} is different from the new output "
|
||||
f"type {type(new_out)} not match the original output type {type(old_out)}"
|
||||
)
|
||||
|
||||
def _swap_tensor_info(new, deprecated):
|
||||
name = deprecated.trt_tensor.name
|
||||
deprecated.trt_tensor.name = name + '_deprecated'
|
||||
deprecated.trt_tensor.name = name + "_deprecated"
|
||||
from .functional import cast
|
||||
|
||||
new = cast(new, deprecated.dtype)
|
||||
@ -463,13 +457,11 @@ class FLayerInfo:
|
||||
return hash(self.signature)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return '<FLayer {}>'.format(self.signature)
|
||||
return "<FLayer {}>".format(self.signature)
|
||||
|
||||
@staticmethod
|
||||
def _get_spec(arg):
|
||||
'''
|
||||
Get the spec that could impact on the Module's topology in the `forward` method.
|
||||
'''
|
||||
"""Get the spec that could impact on the Module's topology in the `forward` method."""
|
||||
from .functional import Tensor
|
||||
|
||||
# For scalars, we track their value since they are constant
|
||||
@ -482,17 +474,17 @@ class FLayerInfo:
|
||||
return Tensor
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
return [FLayerInfo._get_spec(x) for x in arg]
|
||||
# NOTE Free to add more types here is broken, carefully note that, from the engine building angle, all the constants
|
||||
# should be captured while for the network variables, their types as placeholders are enough
|
||||
# NOTE Feel free to add more types here, but carefully note that, from the engine building angle,
# all the constants should be captured, while for the network variables, their types as placeholders
# are enough.
|
||||
else:
|
||||
raise TypeError(f"unsupported input type detected: {type(arg)}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FLayerInfoMemo:
|
||||
'''
|
||||
FLayerInfoMemo holds the FLayer of all the necessary functional layers.
|
||||
'''
|
||||
"""FLayerInfoMemo holds the FLayer of all the necessary functional layers."""
|
||||
|
||||
data: Dict[str, FLayerInfo] = field(default_factory=dict, init=False)
|
||||
|
||||
cur_flayer: ClassVar[Optional[FLayerInfo]] = None
|
||||
@ -502,11 +494,8 @@ class FLayerInfoMemo:
|
||||
self.data[layer_name] = layer
|
||||
|
||||
def create(self, fn: Callable, *args, **kwargs) -> FLayerInfo:
|
||||
'''
|
||||
Add a FLayer to the memo.
|
||||
'''
|
||||
return FLayerInfo(fn.__name__,
|
||||
self.get_function_arg_dict(fn, *args, **kwargs))
|
||||
"""Add a FLayer to the memo."""
|
||||
return FLayerInfo(fn.__name__, self.get_function_arg_dict(fn, *args, **kwargs))
|
||||
|
||||
def get(self, layer_name: str) -> Optional[FLayerInfo]:
|
||||
return self.data.get(layer_name, None)
|
||||
@ -517,29 +506,24 @@ class FLayerInfoMemo:
|
||||
|
||||
@staticmethod
|
||||
def instance() -> "FLayerInfoMemo":
|
||||
'''
|
||||
A singleton instance of FLayerMemo.
|
||||
'''
|
||||
"""A singleton instance of FLayerMemo."""
|
||||
from ._common import default_net
|
||||
|
||||
return default_net().flayer_memo
|
||||
|
||||
@staticmethod
|
||||
def get_function_arg_dict(f: Callable, *args, **kwargs):
|
||||
'''
|
||||
Get the input argument dict of a function.
|
||||
'''
|
||||
"""Get the input argument dict of a function."""
|
||||
sig = inspect.signature(f)
|
||||
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
return {k: v for k, v in bound_args.arguments.items() if k != 'self'}
|
||||
return {k: v for k, v in bound_args.arguments.items() if k != "self"}
|
||||
|
||||
|
||||
class FLayerScope:
|
||||
'''
|
||||
FLayerScope is used to capture the plugin within a functional method.
|
||||
'''
|
||||
"""FLayerScope is used to capture the plugin within a functional method."""
|
||||
|
||||
def __init__(self, fn, *args, **kwargs):
|
||||
self.layer = FLayerInfoMemo.instance().create(fn, *args, **kwargs)
|
||||
@ -552,14 +536,14 @@ class FLayerScope:
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
FLayerInfoMemo.cur_flayer = None
|
||||
if exc_type is None:
|
||||
assert self.layer.layer_name != "", f"FLayer {self.layer.layer_kind} without a plugin name detected"
|
||||
assert self.layer.layer_name != "", (
|
||||
f"FLayer {self.layer.layer_kind} without a plugin name detected"
|
||||
)
|
||||
FLayerInfoMemo.instance().add(self.layer.layer_name, self.layer)
|
||||
|
||||
|
||||
def record_signature(f):
|
||||
'''
|
||||
Helps to decorate a functional method and record its metadata with a FLayerInfo.
|
||||
'''
|
||||
"""Helps to decorate a functional method and record its metadata with a FLayerInfo."""
|
||||
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
@ -577,10 +561,8 @@ _global_analysis_pattern_manager = AnalysisPatternManager()
|
||||
|
||||
|
||||
class FuseAttentionWithBiasPass(PatternRewriter):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name="fuse_attention_with_bias",
|
||||
separate_match_rewrite=False)
|
||||
super().__init__(name="fuse_attention_with_bias", separate_match_rewrite=False)
|
||||
|
||||
@staticmethod
|
||||
def is_attention_plugin(layer: Layer) -> bool:
|
||||
@ -588,14 +570,15 @@ class FuseAttentionWithBiasPass(PatternRewriter):
|
||||
return False
|
||||
p = layer.as_layer().plugin
|
||||
conds = [
|
||||
p.plugin_namespace == 'tensorrt_llm',
|
||||
p.plugin_type == 'GPTAttention', p.num_outputs == 2
|
||||
p.plugin_namespace == "tensorrt_llm",
|
||||
p.plugin_type == "GPTAttention",
|
||||
p.num_outputs == 2,
|
||||
]
|
||||
return all(conds)
|
||||
|
||||
@staticmethod
|
||||
def is_elementwise_sum(layer: Layer) -> bool:
|
||||
l = layer.as_layer()
|
||||
l = layer.as_layer() # noqa: E741
|
||||
if l.type != trt.LayerType.ELEMENTWISE:
|
||||
return False
|
||||
return l.op == trt.ElementWiseOperation.SUM
|
||||
@ -612,11 +595,11 @@ class FuseAttentionWithBiasPass(PatternRewriter):
|
||||
layer = tensor.get_parent()
|
||||
if layer is None or depth > max_depth:
|
||||
return False
|
||||
if layer.type == trt.LayerType.CONSTANT and len(
|
||||
layer.get_inputs()) == 0:
|
||||
if layer.type == trt.LayerType.CONSTANT and len(layer.get_inputs()) == 0:
|
||||
return True
|
||||
for _ in layer.get_inputs():
|
||||
if not const_foldable(_, depth + 1): return False
|
||||
if not const_foldable(_, depth + 1):
|
||||
return False
|
||||
return True
|
||||
|
||||
for input in layer.get_inputs():
|
||||
@ -628,27 +611,26 @@ class FuseAttentionWithBiasPass(PatternRewriter):
|
||||
|
||||
def match_and_rewrite(self, layer: Layer) -> bool:
|
||||
from tensorrt_llm.network import net_guard
|
||||
|
||||
with net_guard(layer.network):
|
||||
if not self.is_attention_plugin(layer):
|
||||
return False
|
||||
plugin_flayer = FLayerInfoMemo.instance().get(layer.name)
|
||||
input = plugin_flayer.raw_inputs['qkv']
|
||||
if input is None or isinstance(input, list) or len(
|
||||
list(input.get_users())) != 1:
|
||||
input = plugin_flayer.raw_inputs["qkv"]
|
||||
if input is None or isinstance(input, list) or len(list(input.get_users())) != 1:
|
||||
return False
|
||||
parent_layer = input.get_parent()
|
||||
if not self.is_elementwise_sum(parent_layer):
|
||||
return False
|
||||
eltwise_const_inputs, eltwise_mutable_inputs = self.get_eltwise_inputs(
|
||||
parent_layer)
|
||||
if len(eltwise_const_inputs) != 1 or len(
|
||||
eltwise_mutable_inputs) != 1:
|
||||
eltwise_const_inputs, eltwise_mutable_inputs = self.get_eltwise_inputs(parent_layer)
|
||||
if len(eltwise_const_inputs) != 1 or len(eltwise_mutable_inputs) != 1:
|
||||
return False
|
||||
if plugin_flayer.raw_inputs['qkv_bias'] is not None:
|
||||
if plugin_flayer.raw_inputs["qkv_bias"] is not None:
|
||||
return False
|
||||
plugin_flayer.raw_inputs['qkv'] = eltwise_mutable_inputs[0]
|
||||
plugin_flayer.raw_inputs['qkv_bias'] = eltwise_const_inputs[0]
|
||||
plugin_flayer.raw_inputs["qkv"] = eltwise_mutable_inputs[0]
|
||||
plugin_flayer.raw_inputs["qkv_bias"] = eltwise_const_inputs[0]
|
||||
from .functional import gpt_attention
|
||||
|
||||
new_outputs = gpt_attention(**plugin_flayer.raw_inputs)
|
||||
plugin_flayer.replace_outputs_uses_with(layer.network, new_outputs)
|
||||
return True
|
||||
@ -657,7 +639,7 @@ class FuseAttentionWithBiasPass(PatternRewriter):
|
||||
def optimize(net):
|
||||
patterns = RewritePatternManager()
|
||||
patterns.add(
|
||||
label='fuse_attention_with_bias',
|
||||
label="fuse_attention_with_bias",
|
||||
pattern=FuseAttentionWithBiasPass(),
|
||||
)
|
||||
patterns.rewrite(net)
|
||||
|
||||
@ -30,23 +30,21 @@ class Singleton(type):
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super(Singleton,
|
||||
cls).__call__(*args, **kwargs)
|
||||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
|
||||
return cls._instances[cls]
|
||||
|
||||
|
||||
class Logger(metaclass=Singleton):
|
||||
ENV_VARIABLE = "TLLM_LOG_LEVEL"
|
||||
PREFIX = "TRT-LLM"
|
||||
DEFAULT_LEVEL = "error"
|
||||
|
||||
ENV_VARIABLE = 'TLLM_LOG_LEVEL'
|
||||
PREFIX = 'TRT-LLM'
|
||||
DEFAULT_LEVEL = 'error'
|
||||
|
||||
INTERNAL_ERROR = '[F]'
|
||||
ERROR = '[E]'
|
||||
WARNING = '[W]'
|
||||
INFO = '[I]'
|
||||
VERBOSE = '[V]'
|
||||
DEBUG = '[D]'
|
||||
INTERNAL_ERROR = "[F]"
|
||||
ERROR = "[E]"
|
||||
WARNING = "[W]"
|
||||
INFO = "[I]"
|
||||
VERBOSE = "[V]"
|
||||
DEBUG = "[D]"
|
||||
|
||||
def __init__(self):
|
||||
environ_severity = os.environ.get(self.ENV_VARIABLE)
|
||||
@ -54,8 +52,7 @@ class Logger(metaclass=Singleton):
|
||||
|
||||
self.rank: Optional[int] = None
|
||||
|
||||
min_severity = environ_severity.lower(
|
||||
) if self._set_from_env else self.DEFAULT_LEVEL
|
||||
min_severity = environ_severity.lower() if self._set_from_env else self.DEFAULT_LEVEL
|
||||
invalid_severity = min_severity not in severity_map
|
||||
if invalid_severity:
|
||||
min_severity = self.DEFAULT_LEVEL
|
||||
@ -66,14 +63,13 @@ class Logger(metaclass=Singleton):
|
||||
self._logger.propagate = False
|
||||
handler = logging.StreamHandler(stream=sys.stdout)
|
||||
handler.setFormatter(
|
||||
logging.Formatter(fmt='[%(asctime)s] %(message)s',
|
||||
datefmt='%m/%d/%Y-%H:%M:%S'))
|
||||
logging.Formatter(fmt="[%(asctime)s] %(message)s", datefmt="%m/%d/%Y-%H:%M:%S")
|
||||
)
|
||||
self._logger.addHandler(handler)
|
||||
self._logger.setLevel(severity_map[min_severity][1])
|
||||
self._polygraphy_logger = G_LOGGER
|
||||
if self._polygraphy_logger is not None:
|
||||
self._polygraphy_logger.module_severity = severity_map[
|
||||
min_severity][2]
|
||||
self._polygraphy_logger.module_severity = severity_map[min_severity][2]
|
||||
|
||||
# For log_once
|
||||
self._appeared_keys = set()
|
||||
@ -98,16 +94,16 @@ class Logger(metaclass=Singleton):
|
||||
elif severity == self.VERBOSE or severity == self.DEBUG:
|
||||
return self._logger.debug
|
||||
else:
|
||||
raise AttributeError(f'No such severity: {severity}')
|
||||
raise AttributeError(f"No such severity: {severity}")
|
||||
|
||||
@property
|
||||
def trt_logger(self) -> trt.ILogger:
|
||||
return self._trt_logger
|
||||
|
||||
def log(self, severity, *msg):
|
||||
parts = [f'[{self.PREFIX}]']
|
||||
parts = [f"[{self.PREFIX}]"]
|
||||
if self.rank is not None:
|
||||
parts.append(f'[RANK {self.rank}]')
|
||||
parts.append(f"[RANK {self.rank}]")
|
||||
parts.append(severity)
|
||||
parts.extend(map(str, msg))
|
||||
self._func_wrapper(severity)(" ".join(parts))
|
||||
@ -164,29 +160,28 @@ class Logger(metaclass=Singleton):
|
||||
self._trt_logger.min_severity = severity_map[min_severity][0]
|
||||
self._logger.setLevel(severity_map[min_severity][1])
|
||||
if self._polygraphy_logger is not None:
|
||||
self._polygraphy_logger.module_severity = severity_map[
|
||||
min_severity][2]
|
||||
self._polygraphy_logger.module_severity = severity_map[min_severity][2]
|
||||
|
||||
|
||||
severity_map = {
|
||||
'internal_error': [trt.Logger.INTERNAL_ERROR, logging.CRITICAL],
|
||||
'error': [trt.Logger.ERROR, logging.ERROR],
|
||||
'warning': [trt.Logger.WARNING, logging.WARNING],
|
||||
'info': [trt.Logger.INFO, logging.INFO],
|
||||
'verbose': [trt.Logger.VERBOSE, logging.DEBUG],
|
||||
'debug': [trt.Logger.VERBOSE, logging.DEBUG],
|
||||
'trace': [trt.Logger.VERBOSE, logging.DEBUG],
|
||||
"internal_error": [trt.Logger.INTERNAL_ERROR, logging.CRITICAL],
|
||||
"error": [trt.Logger.ERROR, logging.ERROR],
|
||||
"warning": [trt.Logger.WARNING, logging.WARNING],
|
||||
"info": [trt.Logger.INFO, logging.INFO],
|
||||
"verbose": [trt.Logger.VERBOSE, logging.DEBUG],
|
||||
"debug": [trt.Logger.VERBOSE, logging.DEBUG],
|
||||
"trace": [trt.Logger.VERBOSE, logging.DEBUG],
|
||||
}
|
||||
|
||||
if G_LOGGER is not None:
|
||||
g_logger_severity_map = {
|
||||
'internal_error': G_LOGGER.CRITICAL,
|
||||
'error': G_LOGGER.ERROR,
|
||||
'warning': G_LOGGER.WARNING,
|
||||
'info': G_LOGGER.INFO,
|
||||
'verbose': G_LOGGER.SUPER_VERBOSE,
|
||||
'debug': G_LOGGER.SUPER_VERBOSE,
|
||||
'trace': G_LOGGER.SUPER_VERBOSE,
|
||||
"internal_error": G_LOGGER.CRITICAL,
|
||||
"error": G_LOGGER.ERROR,
|
||||
"warning": G_LOGGER.WARNING,
|
||||
"info": G_LOGGER.INFO,
|
||||
"verbose": G_LOGGER.SUPER_VERBOSE,
|
||||
"debug": G_LOGGER.SUPER_VERBOSE,
|
||||
"trace": G_LOGGER.SUPER_VERBOSE,
|
||||
}
|
||||
for key, value in g_logger_severity_map.items():
|
||||
severity_map[key].append(value)
|
||||
|
||||
@ -11,12 +11,10 @@ import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from ._utils import (DictConversion, pad_vocab_size, release_gc,
|
||||
str_dtype_to_torch, torch_to_numpy)
|
||||
from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
|
||||
from .layers.linear import ColumnLinear
|
||||
from .mapping import Mapping
|
||||
from .models.convert_utils import (get_model_path, load_state_dict,
|
||||
split_matrix_tp)
|
||||
from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .runtime import ModelConfig
|
||||
@ -25,13 +23,13 @@ if TYPE_CHECKING:
|
||||
def get_all_nemo_lora_weights(lora_weights):
|
||||
layer_weights = defaultdict(dict)
|
||||
adapter_key = "self_attention.adapter_layer.lora_kqv_adapter"
|
||||
layer_pattern = re.compile(r'.*\.layers\.(\d+)\..*')
|
||||
layer_pattern = re.compile(r".*\.layers\.(\d+)\..*")
|
||||
for key, weights in lora_weights.items():
|
||||
if adapter_key in key:
|
||||
if key.endswith('linear_in.weight'):
|
||||
inout = 'in'
|
||||
elif key.endswith('linear_out.weight'):
|
||||
inout = 'out'
|
||||
if key.endswith("linear_in.weight"):
|
||||
inout = "in"
|
||||
elif key.endswith("linear_out.weight"):
|
||||
inout = "out"
|
||||
else:
|
||||
continue
|
||||
m = layer_pattern.match(key)
|
||||
@ -46,9 +44,9 @@ def get_all_nemo_lora_weights(lora_weights):
|
||||
return layer_weights
|
||||
|
||||
|
||||
# The pattern is {layer_prefix:1}.{layer_idx:2}.{module_prefix:3}.{module_name or {expert_name:5}.{expert_idx:6}.{module_name:7} :4}.lora_{A|B:8}.weight
|
||||
# The pattern is {layer_prefix:1}.{layer_idx:2}.{module_prefix:3}.{module_name or {expert_name:5}.{expert_idx:6}.{module_name:7} :4}.lora_{A|B:8}.weight # noqa: E501
|
||||
HF_LORA_PATTERN = re.compile(
|
||||
r'(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight)'
|
||||
r"(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight)"
|
||||
)
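As a quick illustration of the groups this pattern captures, matching one of the HF checkpoint key names listed in the load_from_hf docstring further below pulls out the layer index, module name and LoRA side:

import re

# Same pattern as HF_LORA_PATTERN above, repeated here so the snippet is self-contained.
pattern = re.compile(
    r"(.*)\.(\d+)\.(\w+)\.(\w+|\w+\.\w+|(\w+)\.(\d+)\.(\w+))\.(?:lora_(?:(A|B)\.weight|(magnitude)_vector)|weight_(m_wdecomp).weight)"
)

m = pattern.match("base_model.model.model.layers.0.mlp.down_proj.lora_B.weight")
print(m.group(2), m.group(3), m.group(4), m.group(8))
# -> 0 mlp down_proj B   (layer_idx, module_prefix, module_name, lora A/B side)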
|
||||
|
||||
|
||||
@ -77,7 +75,9 @@ def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None):
|
||||
hf_module = m.group(3) + "." + module_name
|
||||
if hf_module not in hf_modules:
|
||||
hf_module = module_name
|
||||
assert hf_module in hf_modules, f"hf_module {hf_module} is not in supported list {hf_modules}"
|
||||
assert hf_module in hf_modules, (
|
||||
f"hf_module {hf_module} is not in supported list {hf_modules}"
|
||||
)
|
||||
|
||||
is_lora_a_or_b = m.group(8) is not None
|
||||
if is_lora_a_or_b:
|
||||
@ -90,13 +90,11 @@ def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None):
|
||||
all_weights[layer_idx][hf_module][inout_or_mag] = weights
|
||||
else:
|
||||
all_weights[layer_idx][hf_module].setdefault(expert_idx, {})
|
||||
all_weights[layer_idx][hf_module][expert_idx][
|
||||
inout_or_mag] = weights
|
||||
all_weights[layer_idx][hf_module][expert_idx][inout_or_mag] = weights
|
||||
return all_weights
|
||||
|
||||
|
||||
def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):
|
||||
|
||||
def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
|
||||
if expert_idx is None:
|
||||
all_weights[layer_idx][hf_module][inout] = weights
|
||||
@ -110,7 +108,6 @@ def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):
|
||||
|
||||
|
||||
def get_hf_target_modules(lora_weights, hf_modules):
|
||||
|
||||
def iter_fn(layer_idx, hf_module, expert_idx, inout, weights):
|
||||
hf_target_modules.add(hf_module)
|
||||
|
||||
@ -130,11 +127,9 @@ def invert_module_mapping(trtllm_modules_to_hf_modules):
|
||||
return hf_modules_to_trtllm_modules
|
||||
|
||||
|
||||
def norm_dora_magnitude(W0: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
m: torch.Tensor,
|
||||
scaling: float = 1.0):
|
||||
def norm_dora_magnitude(
|
||||
W0: torch.Tensor, A: torch.Tensor, B: torch.Tensor, m: torch.Tensor, scaling: float = 1.0
|
||||
):
|
||||
new_weight_v = W0 + (B @ A) * scaling
|
||||
norm_m = m.view(-1) / (torch.linalg.norm(new_weight_v, dim=1)).detach()
|
||||
return norm_m
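A small usage sketch of the DoRA magnitude normalization above; the shapes and values are illustrative, with W0 the frozen base weight, A and B the LoRA factors, and m the learned magnitude vector:

import torch

out_dim, in_dim, r = 8, 16, 4
W0 = torch.randn(out_dim, in_dim)
A = torch.randn(r, in_dim)
B = torch.randn(out_dim, r)
m = torch.randn(out_dim, 1)

# Divides m by the per-row L2 norms of W0 + 0.5 * (B @ A).
norm_m = norm_dora_magnitude(W0, A, B, m, scaling=0.5)
print(norm_m.shape)  # torch.Size([8])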
|
||||
@ -143,7 +138,7 @@ def norm_dora_magnitude(W0: torch.Tensor,
|
||||
@dataclass
|
||||
class LoraConfig(DictConversion):
|
||||
lora_dir: List[str] = field(default_factory=list)
|
||||
lora_ckpt_source: str = 'hf'
|
||||
lora_ckpt_source: str = "hf"
|
||||
max_lora_rank: int = 64
|
||||
lora_target_modules: List[str] = field(default_factory=list)
|
||||
trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
|
||||
@ -151,9 +146,9 @@ class LoraConfig(DictConversion):
|
||||
max_cpu_loras: int = 4
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.lora_ckpt_source in [
|
||||
'hf', 'nemo'
|
||||
], f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
|
||||
assert self.lora_ckpt_source in ["hf", "nemo"], (
|
||||
f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
|
||||
)
|
||||
|
||||
@property
|
||||
def missing_qkv_modules(self) -> List[str]:
|
||||
@ -169,7 +164,6 @@ class LoraModelConfig:
|
||||
|
||||
|
||||
class HfLoraLoader:
|
||||
|
||||
def __init__(self, lora_dirs: List[str]):
|
||||
self.lora_target_modules = []
|
||||
self.is_valid = False
|
||||
@ -183,8 +177,7 @@ class HfLoraLoader:
|
||||
for lora_dir in lora_dirs:
|
||||
model_path = get_model_path(lora_dir, "adapter_model")
|
||||
if model_path is None:
|
||||
raise ValueError(
|
||||
f"adapter_model file does not exist in {lora_dir}")
|
||||
raise ValueError(f"adapter_model file does not exist in {lora_dir}")
|
||||
config_file = Path(f"{lora_dir}/adapter_config.json")
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"{config_file} does not exist")
|
||||
@ -207,12 +200,10 @@ class HfLoraLoader:
|
||||
self.vocab_size = self.lm_head.shape[0]
|
||||
|
||||
if "embed_tokens" in adapter_config["modules_to_save"]:
|
||||
self.embed_tokens = lora_weight[
|
||||
"base_model.model.model.embed_tokens.weight"]
|
||||
self.embed_tokens = lora_weight["base_model.model.model.embed_tokens.weight"]
|
||||
|
||||
def get_target_modules(self, trtllm_modules_to_hf_modules):
|
||||
hf_modules_to_trtllm_modules = invert_module_mapping(
|
||||
trtllm_modules_to_hf_modules)
|
||||
hf_modules_to_trtllm_modules = invert_module_mapping(trtllm_modules_to_hf_modules)
|
||||
lora_target_modules = set()
|
||||
if self.is_valid:
|
||||
hf_target_modules = get_hf_target_modules(
|
||||
@ -226,7 +217,6 @@ class HfLoraLoader:
|
||||
|
||||
|
||||
class NemoLoraLoader:
|
||||
|
||||
def __init__(self, lora_dirs: List[str]):
|
||||
self.lora_target_modules = []
|
||||
self.is_valid = False
|
||||
@ -269,21 +259,21 @@ def get_default_trtllm_modules_to_hf_modules():
|
||||
|
||||
|
||||
def load_torch_hf_lora(lora_config: LoraConfig):
|
||||
"""
|
||||
This is a shortned version of load_hf_lora that is used for torch models.
|
||||
"""This is a shortned version of load_hf_lora that is used for torch models.
|
||||
|
||||
Main problem is model.config in legacy code is custom (defined in the legacy code) whereas
|
||||
pivot model config is the transformer's one.
|
||||
"""
|
||||
# TODO smor- need to comibe with load_hf_lora
|
||||
lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules(
|
||||
)
|
||||
lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules()
|
||||
|
||||
assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir"
|
||||
lora_loader = HfLoraLoader(lora_config.lora_dir)
|
||||
|
||||
if len(lora_config.lora_target_modules) == 0:
|
||||
lora_config.lora_target_modules = lora_loader.get_target_modules(
|
||||
lora_config.trtllm_modules_to_hf_modules)
|
||||
lora_config.trtllm_modules_to_hf_modules
|
||||
)
|
||||
|
||||
if len(lora_config.lora_target_modules) == 0:
|
||||
raise ValueError(
|
||||
@ -291,8 +281,7 @@ def load_torch_hf_lora(lora_config: LoraConfig):
|
||||
"Please specify lora_target_modules or provide lora_dir to infer lora_target_modules."
|
||||
)
|
||||
|
||||
missing_qkv_modules = LoraManager.get_missing_qkv_modules(
|
||||
lora_config.lora_target_modules)
|
||||
missing_qkv_modules = LoraManager.get_missing_qkv_modules(lora_config.lora_target_modules)
|
||||
lora_config.lora_target_modules.extend(missing_qkv_modules)
|
||||
|
||||
|
||||
@ -301,7 +290,8 @@ def load_hf_lora(
|
||||
lora_config: LoraConfig,
|
||||
trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules or get_default_trtllm_modules_to_hf_modules(
|
||||
trtllm_modules_to_hf_modules = (
|
||||
trtllm_modules_to_hf_modules or get_default_trtllm_modules_to_hf_modules()
|
||||
)
|
||||
lora_config.trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules
|
||||
|
||||
@ -309,15 +299,15 @@ def load_hf_lora(
|
||||
|
||||
if len(lora_config.lora_target_modules) == 0:
|
||||
lora_config.lora_target_modules = lora_loader.get_target_modules(
|
||||
trtllm_modules_to_hf_modules)
|
||||
trtllm_modules_to_hf_modules
|
||||
)
|
||||
if len(lora_config.lora_target_modules) == 0:
|
||||
raise ValueError(
|
||||
"lora_target_modules is empty. "
|
||||
"Please specify lora_target_modules or provide lora_dir to infer lora_target_modules."
|
||||
)
|
||||
|
||||
missing_qkv_modules = LoraManager.get_missing_qkv_modules(
|
||||
lora_config.lora_target_modules)
|
||||
missing_qkv_modules = LoraManager.get_missing_qkv_modules(lora_config.lora_target_modules)
|
||||
lora_config.lora_target_modules.extend(missing_qkv_modules)
|
||||
|
||||
if lora_loader.is_valid:
|
||||
@ -341,15 +331,12 @@ def load_hf_lora(
|
||||
num_embeddings=config.vocab_size,
|
||||
embedding_dim=config.hidden_size,
|
||||
dtype=config.dtype,
|
||||
tp_size=mapping.tp_size
|
||||
if config.use_parallel_embedding else 1,
|
||||
tp_group=mapping.tp_group
|
||||
if config.use_parallel_embedding else None,
|
||||
tp_size=mapping.tp_size if config.use_parallel_embedding else 1,
|
||||
tp_group=mapping.tp_group if config.use_parallel_embedding else None,
|
||||
sharding_dim=config.embedding_sharding_dim,
|
||||
tp_rank=mapping.tp_rank,
|
||||
)
|
||||
model.transformer.vocab_embedding.weight.value = weight.to(
|
||||
torch_dtype)
|
||||
model.transformer.vocab_embedding.weight.value = weight.to(torch_dtype)
|
||||
if mapping.is_last_pp_rank() and lora_loader.lm_head is not None:
|
||||
weight = lora_loader.lm_head
|
||||
vocab_size = lora_loader.vocab_size
|
||||
@ -359,9 +346,13 @@ def load_hf_lora(
|
||||
pad_width = vocab_size_padded - vocab_size
|
||||
|
||||
weight = torch.from_numpy(
|
||||
np.pad(torch_to_numpy(weight), ((0, pad_width), (0, 0)),
|
||||
'constant',
|
||||
constant_values=0))
|
||||
np.pad(
|
||||
torch_to_numpy(weight),
|
||||
((0, pad_width), (0, 0)),
|
||||
"constant",
|
||||
constant_values=0,
|
||||
)
|
||||
)
|
||||
else:
|
||||
vocab_size_padded = vocab_size
|
||||
if model.lm_head.weight.raw_value.shape != weight.shape:
|
||||
@ -392,8 +383,7 @@ def use_lora(
|
||||
elif lora_config.lora_ckpt_source == "hf":
|
||||
load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
|
||||
raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
|
||||
|
||||
|
||||
def unpack_nemo_weights(nemo_archive_path):
|
||||
@ -416,8 +406,9 @@ def unpack_nemo_weights(nemo_archive_path):
|
||||
model_config_dict = yaml.safe_load(model_config_content)
|
||||
|
||||
model_weights_bytes = model_weights_file.read()
|
||||
model_weights_dict = torch.load(io.BytesIO(model_weights_bytes),
|
||||
map_location=torch.device("cpu"))
|
||||
model_weights_dict = torch.load(
|
||||
io.BytesIO(model_weights_bytes), map_location=torch.device("cpu")
|
||||
)
|
||||
|
||||
return model_config_dict, model_weights_dict
|
||||
|
||||
@ -446,58 +437,55 @@ class LoraManager(object):
|
||||
}
|
||||
|
||||
def __init__(self):
'''
_lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
{
uid: {
0: {
lora_module: int
}, # layer_0_rank,
1: {
lora_module: int
}, # layer_1_rank,
...
}
}
"""Constructor."""
# _lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
# {
# uid: {
# 0: {
# lora_module: int
# }, # layer_0_rank,
# 1: {
# lora_module: int
# }, # layer_1_rank,
# ...
# }
# }

_lora_weights_pointers_list: dict[str -> dict[int -> dict[str -> [Tensor, Tensor]]]]
{
uid: {
0: {
lora_module: [t_in, t_out]
}, # layer_0,
1: {
lora_module: [t_in, t_out]
}, # layer_1,
...
}
}
# _lora_weights_pointers_list: dict[str -> dict[int -> dict[str -> [Tensor, Tensor]]]]
# {
# uid: {
# 0: {
# lora_module: [t_in, t_out]
# }, # layer_0,
# 1: {
# lora_module: [t_in, t_out]
# }, # layer_1,
# ...
# }
# }

'''
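Concretely, after loading one adapter these two structures might look like the following (module names follow the LORA_MODULE_IDS keys; the pointer values are purely illustrative):

# uid "0", two layers, rank 8 for attn_qkv, other modules disabled (rank 0)
_lora_uid_to_low_ranks = {
    "0": {
        0: {"attn_qkv": 8, "attn_dense": 0},
        1: {"attn_qkv": 8, "attn_dense": 0},
    }
}

# Per enabled module: [in-weights pointer, out-weights pointer, DoRA magnitude pointer or 0]
_lora_weights_pointers_list = {
    "0": {
        0: {"attn_qkv": [140100000000000, 140100000004096, 0]},
        1: {"attn_qkv": [140100000008192, 140100000012288, 0]},
    }
}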
|
||||
self._lora_uid_counter = 0
|
||||
self._lora_uid_to_low_ranks: Dict[str, Dict[int, Dict[str, int]]] = {}
|
||||
# hold the torch tensors and prevent them from being freed
|
||||
# TODO(enweiz): free device tensors if it's used for c++ runtime only
|
||||
self._lora_weights: List[torch.Tensor] = []
|
||||
self._lora_weights_pointers_list: Dict[str, Dict[int,
|
||||
Dict[str,
|
||||
List[int]]]] = {}
|
||||
self._lora_weights_pointers_list: Dict[str, Dict[int, Dict[str, List[int]]]] = {}
|
||||
self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu
|
||||
self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu
|
||||
self.lora_target_modules: List[str] = []
|
||||
|
||||
@staticmethod
|
||||
def get_missing_qkv_modules(lora_target_modules):
|
||||
# In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or all disabled at the same time.
|
||||
# However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor to fill the missing ones.
|
||||
# In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
|
||||
# all disabled at the same time.
|
||||
# However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor
|
||||
# to fill the missing ones.
|
||||
missing_qkv_modules = []
|
||||
if any(x in lora_target_modules
|
||||
for x in ["attn_q", "attn_k", "attn_v"]):
|
||||
if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
|
||||
for lora_module in ["attn_q", "attn_k", "attn_v"]:
|
||||
if lora_module not in lora_target_modules:
|
||||
missing_qkv_modules.append(lora_module)
|
||||
if any(x in lora_target_modules
|
||||
for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
|
||||
if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
|
||||
for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
|
||||
if lora_module not in lora_target_modules:
|
||||
missing_qkv_modules.append(lora_module)
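Following the rule described in the comment above, a checkpoint that only provides query-projection weights has the other two attention projections reported as missing (and later zero-filled):

missing = LoraManager.get_missing_qkv_modules(["attn_q", "mlp_4h_to_h"])
print(missing)  # ['attn_k', 'attn_v']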
|
||||
@ -507,30 +495,38 @@ class LoraManager(object):
|
||||
def missing_qkv_modules(self) -> List[str]:
|
||||
return LoraManager.get_missing_qkv_modules(self.lora_target_modules)
|
||||
|
||||
def load_from_ckpt(self,
|
||||
model_dirs_or_files: List[str],
|
||||
model_config: Union['ModelConfig', LoraModelConfig],
|
||||
runtime_mapping: Optional[Mapping] = None,
|
||||
uids: Optional[List[str]] = None,
|
||||
ckpt_source: str = 'hf'):
|
||||
if ckpt_source == 'hf':
|
||||
self.load_from_hf(model_dirs=model_dirs_or_files,
|
||||
model_config=model_config,
|
||||
runtime_mapping=runtime_mapping,
|
||||
uids=uids)
|
||||
elif ckpt_source == 'nemo':
|
||||
self.load_from_nemo(model_files=model_dirs_or_files,
|
||||
model_config=model_config,
|
||||
runtime_mapping=runtime_mapping,
|
||||
uids=uids)
|
||||
def load_from_ckpt(
|
||||
self,
|
||||
model_dirs_or_files: List[str],
|
||||
model_config: Union["ModelConfig", LoraModelConfig],
|
||||
runtime_mapping: Optional[Mapping] = None,
|
||||
uids: Optional[List[str]] = None,
|
||||
ckpt_source: str = "hf",
|
||||
):
|
||||
if ckpt_source == "hf":
|
||||
self.load_from_hf(
|
||||
model_dirs=model_dirs_or_files,
|
||||
model_config=model_config,
|
||||
runtime_mapping=runtime_mapping,
|
||||
uids=uids,
|
||||
)
|
||||
elif ckpt_source == "nemo":
|
||||
self.load_from_nemo(
|
||||
model_files=model_dirs_or_files,
|
||||
model_config=model_config,
|
||||
runtime_mapping=runtime_mapping,
|
||||
uids=uids,
|
||||
)
|
||||
else:
|
||||
assert False, f"{self.__class__.__name__} does not support source {ckpt_source}"
|
||||
|
||||
def load_from_nemo(self,
|
||||
model_files: List[str],
|
||||
model_config: Union['ModelConfig', LoraModelConfig],
|
||||
runtime_mapping: Optional[Mapping] = None,
|
||||
uids: Optional[List[str]] = None):
|
||||
def load_from_nemo(
|
||||
self,
|
||||
model_files: List[str],
|
||||
model_config: Union["ModelConfig", LoraModelConfig],
|
||||
runtime_mapping: Optional[Mapping] = None,
|
||||
uids: Optional[List[str]] = None,
|
||||
):
|
||||
if runtime_mapping is None:
|
||||
runtime_mapping = Mapping()
|
||||
tp_size = runtime_mapping.tp_size
|
||||
@ -554,11 +550,9 @@ class LoraManager(object):
|
||||
|
||||
def load_from_model_file(uid, model_file):
|
||||
if uid not in self._cpp_lora_weights:
|
||||
self._cpp_lora_weights[uid] = [
|
||||
] # Will be converted to tensor later
|
||||
self._cpp_lora_weights[uid] = [] # Will be converted to tensor later
|
||||
if uid not in self._cpp_lora_config:
|
||||
self._cpp_lora_config[uid] = [
|
||||
] # Will be converted to tensor later
|
||||
self._cpp_lora_config[uid] = [] # Will be converted to tensor later
|
||||
|
||||
_, nemo_weights = unpack_nemo_weights(model_file)
|
||||
all_lora_weights = get_all_nemo_lora_weights(nemo_weights)
|
||||
@ -571,72 +565,67 @@ class LoraManager(object):
|
||||
|
||||
for lora_module in self.lora_target_modules:
|
||||
if lora_module != "attn_qkv":
|
||||
self._lora_uid_to_low_ranks[uid][layer_idx][
|
||||
lora_module] = 0
|
||||
self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = 0
|
||||
continue
|
||||
|
||||
if lora_module == "attn_qkv":
|
||||
t_in = all_lora_weights[layer_idx]["in"]
|
||||
t_out = all_lora_weights[layer_idx]["out"]
|
||||
assert t_out.shape[0] % tp_size == 0
|
||||
t_out = torch.split(t_out,
|
||||
t_out.shape[0] // tp_size,
|
||||
dim=0)[tp_rank].contiguous()
|
||||
t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[
|
||||
tp_rank
|
||||
].contiguous()
|
||||
else:
|
||||
t_in = None
|
||||
t_out = None
|
||||
|
||||
if t_in is not None and t_out is not None:
|
||||
t_in = t_in.cuda().to(
|
||||
str_dtype_to_torch(
|
||||
model_config.dtype)).contiguous()
|
||||
t_out = t_out.cuda().to(
|
||||
str_dtype_to_torch(
|
||||
model_config.dtype)).contiguous()
|
||||
t_in = t_in.cuda().to(str_dtype_to_torch(model_config.dtype)).contiguous()
|
||||
t_out = t_out.cuda().to(str_dtype_to_torch(model_config.dtype)).contiguous()
|
||||
rank = t_in.shape[0]
|
||||
self._lora_uid_to_low_ranks[uid][layer_idx][
|
||||
lora_module] = int(rank)
|
||||
self._lora_weights_pointers_list[uid][layer_idx][
|
||||
lora_module] = [
|
||||
t_in.data_ptr(),
|
||||
t_out.data_ptr(), 0
|
||||
]
|
||||
self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = int(rank)
|
||||
self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [
|
||||
t_in.data_ptr(),
|
||||
t_out.data_ptr(),
|
||||
0,
|
||||
]
|
||||
|
||||
# prevent torch from freeing this buffer
|
||||
self._lora_weights.append(t_in)
|
||||
self._lora_weights.append(t_out)
|
||||
self._cpp_lora_weights[uid].append(
|
||||
torch.concatenate(
|
||||
[t_in.flatten().cpu(),
|
||||
t_out.flatten().cpu()]))
|
||||
torch.concatenate([t_in.flatten().cpu(), t_out.flatten().cpu()])
|
||||
)
|
||||
self._cpp_lora_config[uid].append(
|
||||
torch.tensor([
|
||||
self.LORA_MODULE_IDS[lora_module], layer_idx,
|
||||
int(rank)
|
||||
],
|
||||
dtype=torch.int32))
|
||||
torch.tensor(
|
||||
[self.LORA_MODULE_IDS[lora_module], layer_idx, int(rank)],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
)
|
||||
|
||||
max_weight_size = max(
|
||||
w.size(0) for w in self._cpp_lora_weights[uid])
|
||||
self._cpp_lora_weights[uid] = torch.stack([
|
||||
torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
|
||||
for w in self._cpp_lora_weights[uid]
|
||||
])
|
||||
self._cpp_lora_config[uid] = torch.stack(
|
||||
[c for c in self._cpp_lora_config[uid]])
|
||||
max_weight_size = max(w.size(0) for w in self._cpp_lora_weights[uid])
|
||||
self._cpp_lora_weights[uid] = torch.stack(
|
||||
[
|
||||
torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
|
||||
for w in self._cpp_lora_weights[uid]
|
||||
]
|
||||
)
|
||||
self._cpp_lora_config[uid] = torch.stack([c for c in self._cpp_lora_config[uid]])
|
||||
|
||||
for uid, model_file in zip(new_uids, new_model_files):
|
||||
load_from_model_file(uid, model_file)
|
||||
release_gc()
|
||||
|
||||
def load_from_hf(self,
|
||||
model_dirs: List[str],
|
||||
model_config: Union['ModelConfig', LoraModelConfig],
|
||||
runtime_mapping: Optional[Mapping] = None,
|
||||
uids: Optional[List[str]] = None,
|
||||
component: Optional[str] = None):
|
||||
'''
|
||||
lora config of https://huggingface.co/hfl/chinese-alpaca-2-lora-7b
|
||||
def load_from_hf(
|
||||
self,
|
||||
model_dirs: List[str],
|
||||
model_config: Union["ModelConfig", LoraModelConfig],
|
||||
runtime_mapping: Optional[Mapping] = None,
|
||||
uids: Optional[List[str]] = None,
|
||||
component: Optional[str] = None,
|
||||
):
|
||||
"""Lora config of https://huggingface.co/hfl/chinese-alpaca-2-lora-7b.
|
||||
|
||||
{
|
||||
"base_model_name_or_path": "/Llama-2-7b-hf",
|
||||
"bias": "none",
|
||||
@ -682,7 +671,7 @@ class LoraManager(object):
|
||||
base_model.model.model.layers.0.mlp.down_proj.lora_B.weight torch.Size([4096, 64])
|
||||
...
|
||||
|
||||
'''
|
||||
"""
|
||||
if runtime_mapping is None:
|
||||
runtime_mapping = Mapping()
|
||||
tp_size = runtime_mapping.tp_size
|
||||
@ -704,13 +693,14 @@ class LoraManager(object):
|
||||
|
||||
lora_hf_configs = []
|
||||
for model_dir in new_model_dirs:
|
||||
with open(f"{model_dir}/adapter_config.json", 'r') as f:
|
||||
with open(f"{model_dir}/adapter_config.json", "r") as f:
|
||||
config = json.load(f)
|
||||
lora_hf_configs.append(config)
|
||||
|
||||
self.lora_target_modules = model_config.lora_target_modules
|
||||
hf_modules_to_trtllm_modules = invert_module_mapping(
|
||||
model_config.trtllm_modules_to_hf_modules)
|
||||
model_config.trtllm_modules_to_hf_modules
|
||||
)
|
||||
hf_modules = set(hf_modules_to_trtllm_modules.keys())
|
||||
|
||||
def preprocess_lora_weights(lora_model):
|
||||
@ -727,20 +717,15 @@ class LoraManager(object):
|
||||
|
||||
def load_from_model_dir(uid, model_dir, hf_config):
|
||||
if uid not in self._cpp_lora_weights:
|
||||
self._cpp_lora_weights[uid] = [
|
||||
] # Will be converted to tensor later
|
||||
self._cpp_lora_weights[uid] = [] # Will be converted to tensor later
|
||||
if uid not in self._cpp_lora_config:
|
||||
self._cpp_lora_config[uid] = [
|
||||
] # Will be converted to tensor later
|
||||
self._cpp_lora_config[uid] = [] # Will be converted to tensor later
|
||||
|
||||
lora_model = load_state_dict(
|
||||
get_model_path(model_dir, "adapter_model"))
|
||||
lora_model = load_state_dict(get_model_path(model_dir, "adapter_model"))
|
||||
if lora_model is None:
|
||||
raise ValueError(
|
||||
f"Failed to load adapter_model from {model_dir}")
|
||||
raise ValueError(f"Failed to load adapter_model from {model_dir}")
|
||||
lora_model = preprocess_lora_weights(lora_model)
|
||||
all_weights = get_all_hf_lora_weights(lora_model, hf_modules,
|
||||
component)
|
||||
all_weights = get_all_hf_lora_weights(lora_model, hf_modules, component)
|
||||
rank = int(hf_config["r"])
|
||||
rs_lora = bool(hf_config.get("use_rslora", False))
|
||||
|
||||
@ -752,8 +737,7 @@ class LoraManager(object):
|
||||
self._lora_weights_pointers_list[uid][layer_idx] = {}
|
||||
|
||||
for lora_module in self.missing_qkv_modules:
|
||||
hf_module = model_config.trtllm_modules_to_hf_modules[
|
||||
lora_module]
|
||||
hf_module = model_config.trtllm_modules_to_hf_modules[lora_module]
|
||||
if isinstance(hf_module, list):
|
||||
hf_module = hf_module[0]
|
||||
layer_weights[hf_module] = {
|
||||
@ -764,24 +748,26 @@ class LoraManager(object):
|
||||
for hf_module, module_weights in layer_weights.items():
|
||||
lora_module = hf_modules_to_trtllm_modules[hf_module]
|
||||
if lora_module not in self.lora_target_modules:
|
||||
self._lora_uid_to_low_ranks[uid][layer_idx][
|
||||
lora_module] = 0
|
||||
self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = 0
|
||||
continue
|
||||
if "in" not in module_weights:
|
||||
is_moe = True
|
||||
t_in = torch.stack([
|
||||
module_weights[expert_idx]["in"]
|
||||
for expert_idx in sorted(module_weights.keys())
|
||||
])
|
||||
t_out = torch.stack([
|
||||
module_weights[expert_idx]["out"]
|
||||
for expert_idx in sorted(module_weights.keys())
|
||||
])
|
||||
t_in = torch.stack(
|
||||
[
|
||||
module_weights[expert_idx]["in"]
|
||||
for expert_idx in sorted(module_weights.keys())
|
||||
]
|
||||
)
|
||||
t_out = torch.stack(
|
||||
[
|
||||
module_weights[expert_idx]["out"]
|
||||
for expert_idx in sorted(module_weights.keys())
|
||||
]
|
||||
)
|
||||
for weights in module_weights.values():
|
||||
if "mag" in weights:
|
||||
# TODO(oargov): this might work, but I had no MoE DoRA models to test
|
||||
raise ValueError(
|
||||
"DoRA with MoE is not supported")
|
||||
raise ValueError("DoRA with MoE is not supported")
|
||||
t_mag = None
|
||||
else:
|
||||
is_moe = False
|
||||
@ -796,28 +782,28 @@ class LoraManager(object):
|
||||
elif "moe" in lora_module and runtime_mapping.has_moe_ep():
|
||||
pass
|
||||
elif lora_module in [
|
||||
"attn_dense",
|
||||
"cross_attn_dense",
|
||||
"mlp_4h_to_h",
|
||||
"moe_4h_to_h",
|
||||
"attn_dense",
|
||||
"cross_attn_dense",
|
||||
"mlp_4h_to_h",
|
||||
"moe_4h_to_h",
|
||||
]:
|
||||
# split by row
|
||||
dim = 2 if is_moe else 1
|
||||
assert t_in.shape[dim] % tp_size == 0
|
||||
t_in = torch.split(t_in,
|
||||
t_in.shape[dim] // tp_size,
|
||||
dim=dim)[tp_rank].contiguous()
|
||||
t_in = torch.split(t_in, t_in.shape[dim] // tp_size, dim=dim)[
|
||||
tp_rank
|
||||
].contiguous()
|
||||
else:
|
||||
# split by column
|
||||
dim = 1 if is_moe else 0
|
||||
assert t_out.shape[dim] % tp_size == 0
|
||||
t_out = torch.split(t_out,
|
||||
t_out.shape[dim] // tp_size,
|
||||
dim=dim)[tp_rank].contiguous()
|
||||
t_out = torch.split(t_out, t_out.shape[dim] // tp_size, dim=dim)[
|
||||
tp_rank
|
||||
].contiguous()
|
||||
if dim == 0 and is_dora and t_mag is not None:
|
||||
t_mag = torch.split(t_mag,
|
||||
t_mag.shape[0] // tp_size,
|
||||
dim=0)[tp_rank].contiguous()
|
||||
t_mag = torch.split(t_mag, t_mag.shape[0] // tp_size, dim=0)[
|
||||
tp_rank
|
||||
].contiguous()
|
||||
|
||||
rank_dim = 1 if is_moe else 0
|
||||
effective_rank = t_in.shape[rank_dim]
|
||||
@ -828,8 +814,7 @@ class LoraManager(object):
|
||||
t_mag = t_mag.cuda().contiguous()
|
||||
|
||||
if rs_lora:
|
||||
scale = float(
|
||||
hf_config["lora_alpha"]) / np.sqrt(effective_rank)
|
||||
scale = float(hf_config["lora_alpha"]) / np.sqrt(effective_rank)
|
||||
else:
|
||||
scale = float(hf_config["lora_alpha"]) / effective_rank
|
||||
t_out = t_out * scale
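The out-side (lora_B) weights are thus scaled once at load time; with rsLoRA the scale is lora_alpha / sqrt(r) rather than the classic lora_alpha / r. A tiny numeric illustration (alpha and rank values are assumed, not taken from a real checkpoint):

import numpy as np

lora_alpha, effective_rank = 16.0, 64
classic_scale = lora_alpha / effective_rank           # 0.25
rs_lora_scale = lora_alpha / np.sqrt(effective_rank)  # 2.0
print(classic_scale, rs_lora_scale)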
|
||||
@ -838,15 +823,12 @@ class LoraManager(object):
|
||||
if is_dora and t_mag is not None:
|
||||
t_mag = t_mag.to(str_dtype_to_torch(model_config.dtype))
|
||||
|
||||
self._lora_uid_to_low_ranks[uid][layer_idx][
|
||||
lora_module] = effective_rank
|
||||
self._lora_weights_pointers_list[uid][layer_idx][
|
||||
lora_module] = [
|
||||
t_in.data_ptr(),
|
||||
t_out.data_ptr(),
|
||||
t_mag.data_ptr() if
|
||||
(is_dora and t_mag is not None) else 0
|
||||
]
|
||||
self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = effective_rank
|
||||
self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [
|
||||
t_in.data_ptr(),
|
||||
t_out.data_ptr(),
|
||||
t_mag.data_ptr() if (is_dora and t_mag is not None) else 0,
|
||||
]
|
||||
|
||||
# prevent torch from freeing this buffer
|
||||
self._lora_weights.append(t_in)
|
||||
@ -862,26 +844,24 @@ class LoraManager(object):
|
||||
t_mag_cpu = t_mag.flatten().cpu()
|
||||
weights_to_concat.append(t_mag_cpu)
|
||||
|
||||
self._cpp_lora_weights[uid].append(
|
||||
torch.cat(weights_to_concat))
|
||||
self._cpp_lora_weights[uid].append(torch.cat(weights_to_concat))
|
||||
self._cpp_lora_config[uid].append(
|
||||
torch.tensor([
|
||||
self.LORA_MODULE_IDS[lora_module], layer_idx,
|
||||
effective_rank, is_dora
|
||||
],
|
||||
dtype=torch.int32))
|
||||
torch.tensor(
|
||||
[self.LORA_MODULE_IDS[lora_module], layer_idx, effective_rank, is_dora],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
)
|
||||
|
||||
max_weight_size = max(
|
||||
w.size(0) for w in self._cpp_lora_weights[uid])
|
||||
self._cpp_lora_weights[uid] = torch.stack([
|
||||
torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
|
||||
for w in self._cpp_lora_weights[uid]
|
||||
])
|
||||
self._cpp_lora_config[uid] = torch.stack(
|
||||
[c for c in self._cpp_lora_config[uid]])
|
||||
max_weight_size = max(w.size(0) for w in self._cpp_lora_weights[uid])
|
||||
self._cpp_lora_weights[uid] = torch.stack(
|
||||
[
|
||||
torch.nn.functional.pad(w, (0, max_weight_size - w.size(0)))
|
||||
for w in self._cpp_lora_weights[uid]
|
||||
]
|
||||
)
|
||||
self._cpp_lora_config[uid] = torch.stack([c for c in self._cpp_lora_config[uid]])
|
||||
|
||||
for uid, model_dir, hf_config in zip(new_uids, new_model_dirs,
|
||||
lora_hf_configs):
|
||||
for uid, model_dir, hf_config in zip(new_uids, new_model_dirs, lora_hf_configs):
|
||||
load_from_model_dir(uid, model_dir, hf_config)
|
||||
release_gc()
|
||||
|
||||
@ -914,10 +894,9 @@ class LoraManager(object):
|
||||
|
||||
@property
|
||||
def num_lora_adapters(self):
|
||||
return len([uid for uid in self._lora_uid_to_low_ranks if uid != '-1'])
|
||||
return len([uid for uid in self._lora_uid_to_low_ranks if uid != "-1"])
|
||||
|
||||
def save_lora_weights_to_bin(self, out_dir):
|
||||
|
||||
def save_val(val, dir, key, tp_num=None, write_npy=False):
|
||||
ext = "npy" if write_npy else "bin"
|
||||
suffix = ext if tp_num is None else f"{tp_num}.{ext}"
|
||||
@ -933,32 +912,21 @@ class LoraManager(object):
|
||||
else:
|
||||
assert False
|
||||
for uid in self.cpp_lora_weights:
|
||||
if uid == '-1':
|
||||
if uid == "-1":
|
||||
continue
|
||||
|
||||
all_weights = np.expand_dims(
|
||||
torch_to_numpy(self.cpp_lora_weights[uid]), 0)
|
||||
all_configs = np.expand_dims(
|
||||
torch_to_numpy(self.cpp_lora_config[uid]), 0)
|
||||
all_weights = np.expand_dims(torch_to_numpy(self.cpp_lora_weights[uid]), 0)
|
||||
all_configs = np.expand_dims(torch_to_numpy(self.cpp_lora_config[uid]), 0)
|
||||
|
||||
uid_path = out_dir_path / f"{uid}"
|
||||
uid_path.mkdir(parents=True, exist_ok=True)
|
||||
save_val(all_weights,
|
||||
uid_path,
|
||||
"lora_weights",
|
||||
tp_num=None,
|
||||
write_npy=True)
|
||||
save_val(all_configs,
|
||||
uid_path,
|
||||
"lora_config",
|
||||
tp_num=None,
|
||||
write_npy=True)
|
||||
save_val(all_weights, uid_path, "lora_weights", tp_num=None, write_npy=True)
|
||||
save_val(all_configs, uid_path, "lora_config", tp_num=None, write_npy=True)
|
||||
|
||||
def input_buffers(self, lora_uids, mapping: Mapping, num_layers: int):
|
||||
inputs = {}
|
||||
for layer_idx in mapping.pp_layers(num_layers):
|
||||
for lora_module in (self.lora_target_modules +
|
||||
self.missing_qkv_modules):
|
||||
for lora_module in self.lora_target_modules + self.missing_qkv_modules:
|
||||
lora_ranks_ = []
|
||||
lora_ptrs_ = []
|
||||
for lora_uid in lora_uids:
|
||||
@ -968,21 +936,21 @@ class LoraManager(object):
|
||||
if lora_uid != "-1":
|
||||
low_ranks = self.uid_to_low_ranks(lora_uid)
|
||||
|
||||
if (layer_idx in low_ranks
|
||||
and lora_module in low_ranks[layer_idx].keys()
|
||||
and low_ranks[layer_idx][lora_module] != 0):
|
||||
|
||||
if (
|
||||
layer_idx in low_ranks
|
||||
and lora_module in low_ranks[layer_idx].keys()
|
||||
and low_ranks[layer_idx][lora_module] != 0
|
||||
):
|
||||
lora_rank = low_ranks[layer_idx][lora_module]
|
||||
lora_ptrs = self.lora_weights_pointers_list[
|
||||
lora_uid][layer_idx][lora_module]
|
||||
lora_ptrs = self.lora_weights_pointers_list[lora_uid][layer_idx][
|
||||
lora_module
|
||||
]
|
||||
|
||||
lora_ranks_.append(lora_rank)
|
||||
lora_ptrs_.append(lora_ptrs)
|
||||
|
||||
inputs[
|
||||
f'{lora_module}_lora_ranks_{layer_idx}'] = torch.IntTensor(
|
||||
lora_ranks_)
|
||||
inputs[
|
||||
f'{lora_module}_lora_weights_pointers_{layer_idx}'] = torch.LongTensor(
|
||||
lora_ptrs_)
|
||||
inputs[f"{lora_module}_lora_ranks_{layer_idx}"] = torch.IntTensor(lora_ranks_)
|
||||
inputs[f"{lora_module}_lora_weights_pointers_{layer_idx}"] = torch.LongTensor(
|
||||
lora_ptrs_
|
||||
)
|
||||
return inputs
|
||||
|
||||
@ -19,19 +19,18 @@ from .logger import logger
|
||||
|
||||
|
||||
def _addindent(s_, numSpaces):
|
||||
s = s_.split('\n')
|
||||
s = s_.split("\n")
|
||||
# don't do anything for single-line stuff
|
||||
if len(s) == 1:
|
||||
return s_
|
||||
first = s.pop(0)
|
||||
s = [(numSpaces * ' ') + line for line in s]
|
||||
s = '\n'.join(s)
|
||||
s = first + '\n' + s
|
||||
s = [(numSpaces * " ") + line for line in s]
|
||||
s = "\n".join(s)
|
||||
s = first + "\n" + s
|
||||
return s
|
||||
|
||||
|
||||
class Module(object):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._modules = {}
|
||||
self._parameters = {}
|
||||
@ -52,20 +51,20 @@ class Module(object):
|
||||
output = self.forward(*args, **kwargs)
|
||||
end_layer_idx = current_net.trt_network.num_layers
|
||||
current_net._module_call_stack.set_layer_range(
|
||||
self, range(start_layer_idx, end_layer_idx))
|
||||
self, range(start_layer_idx, end_layer_idx)
|
||||
)
|
||||
return output
|
||||
|
||||
def __getattr__(self, name):
|
||||
parameters = self.__dict__.get('_parameters')
|
||||
parameters = self.__dict__.get("_parameters")
|
||||
if parameters is not None and name in parameters:
|
||||
return parameters[name]
|
||||
|
||||
modules = self.__dict__.get('_modules')
|
||||
modules = self.__dict__.get("_modules")
|
||||
if modules is not None and name in modules:
|
||||
return modules[name]
|
||||
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, name))
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
|
||||
|
||||
def __setattr__(self, name, value) -> None:
|
||||
from .parameter import Parameter
|
||||
@ -81,9 +80,9 @@ class Module(object):
|
||||
# - keep Parameter and Module attrs in this Module class
|
||||
# - leave all other attrs in base class
|
||||
if isinstance(value, Parameter):
|
||||
self.__dict__.get('_parameters')[name] = value
|
||||
self.__dict__.get("_parameters")[name] = value
|
||||
elif isinstance(value, Module):
|
||||
self.__dict__.get('_modules')[name] = value
|
||||
self.__dict__.get("_modules")[name] = value
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
@ -93,14 +92,14 @@ class Module(object):
|
||||
# - other types reset and remain in base class
|
||||
if isinstance(value, Parameter):
|
||||
super().__delattr__(name)
|
||||
self.__dict__.get('_parameters')[name] = value
|
||||
self.__dict__.get("_parameters")[name] = value
|
||||
elif isinstance(value, Module):
|
||||
super().__delattr__(name)
|
||||
self.__dict__.get('_modules')[name] = value
|
||||
self.__dict__.get("_modules")[name] = value
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def named_modules(self, memo=None, prefix='', remove_duplicate=True):
|
||||
def named_modules(self, memo=None, prefix="", remove_duplicate=True):
|
||||
if memo is None:
|
||||
memo = set()
|
||||
if self not in memo:
|
||||
@ -110,16 +109,11 @@ class Module(object):
|
||||
for name, module in self._modules.items():
|
||||
if module is None:
|
||||
continue
|
||||
submodule_prefix = prefix + ('.' if prefix else '') + name
|
||||
for m in module.named_modules(memo, submodule_prefix,
|
||||
remove_duplicate):
|
||||
submodule_prefix = prefix + ("." if prefix else "") + name
|
||||
for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
|
||||
yield m
|
||||
|
||||
def named_modules_with_parent(self,
|
||||
memo=None,
|
||||
prefix='',
|
||||
parent=None,
|
||||
remove_duplicate=True):
|
||||
def named_modules_with_parent(self, memo=None, prefix="", parent=None, remove_duplicate=True):
|
||||
if memo is None:
|
||||
memo = set()
|
||||
if self not in memo:
|
||||
@ -130,7 +124,7 @@ class Module(object):
|
||||
if parent:
|
||||
# Use the up-to-date module from the parent, to allow replacing
|
||||
# layers while iterating this generator.
|
||||
module_name = prefix.rsplit('.', 1)[-1]
|
||||
module_name = prefix.rsplit(".", 1)[-1]
|
||||
module = getattr(parent, module_name)
|
||||
if module is None:
|
||||
return
|
||||
@ -140,9 +134,10 @@ class Module(object):
|
||||
for child_name, child_module in module._modules.items():
|
||||
if child_module is None:
|
||||
continue
|
||||
submodule_prefix = prefix + ('.' if prefix else '') + child_name
|
||||
submodule_prefix = prefix + ("." if prefix else "") + child_name
|
||||
for m in child_module.named_modules_with_parent(
|
||||
memo, submodule_prefix, module, remove_duplicate):
|
||||
memo, submodule_prefix, module, remove_duplicate
|
||||
):
|
||||
yield m
|
||||
|
||||
def named_children(self):
|
||||
@ -152,27 +147,26 @@ class Module(object):
|
||||
memo.add(module)
|
||||
yield name, module
|
||||
|
||||
def _named_members(self, get_members_fn, prefix='', recurse=True):
|
||||
def _named_members(self, get_members_fn, prefix="", recurse=True):
|
||||
memo = set()
|
||||
modules = self.named_modules(prefix=prefix) if recurse else [(prefix,
|
||||
self)]
|
||||
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
|
||||
for module_prefix, module in modules:
|
||||
members = get_members_fn(module)
|
||||
for k, v in members:
|
||||
if v is None or v in memo:
|
||||
continue
|
||||
memo.add(v)
|
||||
name = module_prefix + ('.' if module_prefix else '') + k
|
||||
name = module_prefix + ("." if module_prefix else "") + k
|
||||
yield name, v
|
||||
|
||||
def parameters(self, recurse=True):
|
||||
for name, param in self.named_parameters():
|
||||
yield param
|
||||
|
||||
def named_parameters(self, prefix='', recurse=True):
|
||||
gen = self._named_members(lambda module: module._parameters.items(),
|
||||
prefix=prefix,
|
||||
recurse=recurse)
|
||||
def named_parameters(self, prefix="", recurse=True):
|
||||
gen = self._named_members(
|
||||
lambda module: module._parameters.items(), prefix=prefix, recurse=recurse
|
||||
)
|
||||
for elem in gen:
|
||||
yield elem
|
||||
|
||||
@ -201,15 +195,15 @@ class Module(object):
|
||||
def named_network_outputs(self):
|
||||
for name, module in self.named_modules():
|
||||
for n, output in module._network_outputs.items():
|
||||
yield name + ('.' if name else '') + n, output
|
||||
yield name + ("." if name else "") + n, output
|
||||
|
||||
def update_parameters(self, torch_module):
|
||||
m = {k: v for k, v in self.named_parameters()}
|
||||
tm = {k: v for k, v in torch_module.named_parameters()}
|
||||
|
||||
assert sorted(m.keys()) == sorted(
|
||||
tm.keys()
|
||||
), 'The parameter names of the tensorrt-llm module must be the same with the torch module'
|
||||
assert sorted(m.keys()) == sorted(tm.keys()), (
|
||||
"The parameter names of the tensorrt-llm module must be the same with the torch module"
|
||||
)
|
||||
|
||||
for k, v in self.named_parameters():
|
||||
v.value = tm[k].detach().cpu().numpy()
|
||||
@ -223,17 +217,16 @@ class Module(object):
|
||||
for key, module in self._modules.items():
|
||||
mod_str = repr(module)
|
||||
mod_str = _addindent(mod_str, 2)
|
||||
child_lines.append('(' + key + '): ' + mod_str)
|
||||
main_str = self._get_name() + '('
|
||||
child_lines.append("(" + key + "): " + mod_str)
|
||||
main_str = self._get_name() + "("
|
||||
if child_lines:
|
||||
# simple one-liner info, which most builtin Modules will use
|
||||
main_str += '\n ' + '\n '.join(child_lines) + '\n'
|
||||
main_str += ')'
|
||||
main_str += "\n " + "\n ".join(child_lines) + "\n"
|
||||
main_str += ")"
|
||||
return main_str
|
||||
|
||||
|
||||
class ModuleList(Module):
|
||||
|
||||
def __init__(self, modules) -> None:
|
||||
super(ModuleList, self).__init__()
|
||||
offset = len(self)
|
||||
@ -241,10 +234,10 @@ class ModuleList(Module):
|
||||
self._modules[str(offset + i)] = module
|
||||
|
||||
def _get_abs_string_index(self, idx):
|
||||
"""Get the absolute index for the list of modules"""
|
||||
"""Get the absolute index for the list of modules."""
|
||||
idx = operator.index(idx)
|
||||
if not (-len(self) <= idx < len(self)):
|
||||
raise IndexError('index {} is out of range'.format(idx))
|
||||
raise IndexError("index {} is out of range".format(idx))
|
||||
if idx < 0:
|
||||
idx += len(self)
|
||||
return str(idx)
|
||||
|
||||
@ -36,19 +36,21 @@ from tensorrt_llm.logger import logger
|
||||
from ._common import _is_building
|
||||
|
||||
if psutil is None:
|
||||
logger.warning("A required package 'psutil' is not installed. Will not "
|
||||
"monitor the host memory usages. Please install the package "
|
||||
"first, e.g, 'pip install psutil'.")
|
||||
logger.warning(
|
||||
"A required package 'psutil' is not installed. Will not "
|
||||
"monitor the host memory usages. Please install the package "
|
||||
"first, e.g, 'pip install psutil'."
|
||||
)
|
||||
|
||||
if pynvml is None:
|
||||
logger.warning(
|
||||
"A required package 'pynvml' is not installed. Will not "
|
||||
"monitor the device memory usages. Please install the package "
|
||||
"first, e.g, 'pip install nvidia-ml-py>=12'.")
|
||||
"first, e.g, 'pip install nvidia-ml-py>=12'."
|
||||
)
|
||||
|
||||
|
||||
class Timer:
|
||||
|
||||
def __init__(self):
|
||||
self._start_times = {}
|
||||
self._total_elapsed_times = {}
|
||||
@ -77,9 +79,9 @@ class Timer:
self._total_elapsed_times.pop(tag, None)

def summary(self):
logger.info('Profile Results')
logger.info("Profile Results")
for tag, elapsed_time in self._total_elapsed_times.items():
logger.info(f' - {tag.ljust(30, ".")}: {elapsed_time:.6f} (sec)')
logger.info(f" - {tag.ljust(30, '.')}: {elapsed_time:.6f} (sec)")


_default_timer = Timer()
@ -105,11 +107,10 @@ def summary():
_default_timer.summary()


MemUnitType = Literal['GiB', 'MiB', 'KiB']
MemUnitType = Literal["GiB", "MiB", "KiB"]


class PyNVMLContext:

def __enter__(self):
if pynvml is not None:
pynvml.nvmlInit()
@ -141,9 +142,7 @@ def host_memory_info(pid: Optional[int] = None) -> Tuple[int, int, int]:
return 0, 0, 0 # used, free, total


def device_memory_info(
device: Optional[Union[torch.device,
int]] = None) -> Tuple[int, int, int]:
def device_memory_info(device: Optional[Union[torch.device, int]] = None) -> Tuple[int, int, int]:
if pynvml is not None:
if device is None:
device = torch.cuda.current_device()
@ -156,8 +155,8 @@ def device_memory_info(


def bytes_to_target_unit(mem_bytes: int, unit: MemUnitType) -> float:
units = {'GiB': 1 << 30, 'MiB': 1 << 20, 'KiB': 1 << 10}
_rename_map = {'GB': 'GiB', 'MB': 'MiB', 'KB': 'KiB'}
units = {"GiB": 1 << 30, "MiB": 1 << 20, "KiB": 1 << 10}
_rename_map = {"GB": "GiB", "MB": "MiB", "KB": "KiB"}
if unit not in units:
unit = _rename_map[unit]
return float(mem_bytes) / units[unit]
@ -165,51 +164,60 @@ def bytes_to_target_unit(mem_bytes: int, unit: MemUnitType) -> float:
|
||||
|
||||
def _format(mem_bytes: int, unit: MemUnitType) -> str:
|
||||
mem_usage = bytes_to_target_unit(mem_bytes, unit)
|
||||
return f'{mem_usage:.4f} ({unit})'
|
||||
return f"{mem_usage:.4f} ({unit})"
|
||||
|
||||
|
||||
def _print_mem_message(msg: str, tag: Optional[str] = None):
|
||||
if tag:
|
||||
msg = f'{tag} - {msg}'
|
||||
logger.info(f'[MemUsage] {msg}')
|
||||
msg = f"{tag} - {msg}"
|
||||
logger.info(f"[MemUsage] {msg}")
|
||||
|
||||
|
||||
def print_host_memory_usage(tag: Optional[str] = None,
|
||||
unit: MemUnitType = 'GiB'):
|
||||
def print_host_memory_usage(tag: Optional[str] = None, unit: MemUnitType = "GiB"):
|
||||
if psutil is None:
|
||||
return
|
||||
alloc_mem, _, _ = host_memory_info()
|
||||
msg = f'Allocated Host Memory {_format(alloc_mem, unit)}'
|
||||
msg = f"Allocated Host Memory {_format(alloc_mem, unit)}"
|
||||
_print_mem_message(msg, tag)
|
||||
|
||||
|
||||
def print_device_memory_usage(
|
||||
tag: Optional[str] = None,
|
||||
unit: MemUnitType = 'GiB',
|
||||
unit: MemUnitType = "GiB",
|
||||
device: Optional[Union[torch.device, int]] = None,
|
||||
):
|
||||
alloc_mem, _, _ = device_memory_info(device)
|
||||
msg = f'Allocated Device Memory {_format(alloc_mem, unit)}'
|
||||
msg = f"Allocated Device Memory {_format(alloc_mem, unit)}"
|
||||
_print_mem_message(msg, tag)
|
||||
|
||||
|
||||
def print_memory_usage(
|
||||
tag: Optional[str] = None,
|
||||
unit: MemUnitType = 'GiB',
|
||||
unit: MemUnitType = "GiB",
|
||||
device: Optional[Union[torch.device, int]] = None,
|
||||
):
|
||||
alloc_host_mem, _, _ = host_memory_info()
|
||||
alloc_device_mem, _, _ = device_memory_info(device=device)
|
||||
msg = f'Allocated Memory: Host {_format(alloc_host_mem, unit)} '\
|
||||
f'Device {_format(alloc_device_mem, unit)}'
|
||||
msg = (
|
||||
f"Allocated Memory: Host {_format(alloc_host_mem, unit)} "
|
||||
f"Device {_format(alloc_device_mem, unit)}"
|
||||
)
|
||||
_print_mem_message(msg, tag)
|
||||
|
||||
|
||||
@_is_building
|
||||
def check_gpt_mem_usage(engine, kv_dtype, use_gpt_attention_plugin,
|
||||
paged_kv_cache, max_batch_size, max_beam_width,
|
||||
max_seq_len, local_num_kv_heads, head_size,
|
||||
num_layers) -> int:
|
||||
def check_gpt_mem_usage(
|
||||
engine,
|
||||
kv_dtype,
|
||||
use_gpt_attention_plugin,
|
||||
paged_kv_cache,
|
||||
max_batch_size,
|
||||
max_beam_width,
|
||||
max_seq_len,
|
||||
local_num_kv_heads,
|
||||
head_size,
|
||||
num_layers,
|
||||
) -> int:
|
||||
# Get the amount of memory
|
||||
runtime = trt.Runtime(logger.trt_logger)
|
||||
# 1. TensorRT engine activation memory
|
||||
@ -220,40 +228,51 @@ def check_gpt_mem_usage(engine, kv_dtype, use_gpt_attention_plugin,
|
||||
activation_size = cuda_engine.device_memory_size_v2 / 1024 / 1024
|
||||
del cuda_engine
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f'Exception when deserializing engine: {traceback.format_exc()}')
|
||||
logger.warning(f'Activation memory size will be regarded as 0.')
|
||||
logger.info(f'Activation memory size: {activation_size:.2f} MiB')
|
||||
logger.warning(f"Exception when deserializing engine: {traceback.format_exc()}")
|
||||
logger.warning("Activation memory size will be regarded as 0.")
|
||||
logger.info(f"Activation memory size: {activation_size:.2f} MiB")
|
||||
|
||||
# 2. Weights
|
||||
weights_size = bytes_to_target_unit(engine.nbytes, 'MiB')
|
||||
logger.info(f'Weights memory size: {weights_size:.2f} MiB')
|
||||
weights_size = bytes_to_target_unit(engine.nbytes, "MiB")
|
||||
logger.info(f"Weights memory size: {weights_size:.2f} MiB")
|
||||
|
||||
# 3. Estimated max KV Cache size
|
||||
kv_cache_size = max_batch_size * max_beam_width * 2 * local_num_kv_heads * max_seq_len * head_size * num_layers * kv_dtype.itemsize
|
||||
kv_cache_size = (
|
||||
max_batch_size
|
||||
* max_beam_width
|
||||
* 2
|
||||
* local_num_kv_heads
|
||||
* max_seq_len
|
||||
* head_size
|
||||
* num_layers
|
||||
* kv_dtype.itemsize
|
||||
)
|
||||
# without plugin, we need two set of kv cache buffers,
|
||||
# one for inputs, and the other for outputs.
|
||||
if not use_gpt_attention_plugin:
|
||||
kv_cache_size *= 2
|
||||
kv_cache_size = bytes_to_target_unit(kv_cache_size, 'MiB')
|
||||
logger.info(f'Max KV Cache memory size: {kv_cache_size:.2f} MiB')
|
||||
kv_cache_size = bytes_to_target_unit(kv_cache_size, "MiB")
|
||||
logger.info(f"Max KV Cache memory size: {kv_cache_size:.2f} MiB")
|
||||
|
||||
# Estimated total amount of memory
|
||||
est_memory_size = activation_size + weights_size + kv_cache_size
|
||||
logger.info(
|
||||
f'Estimated max memory usage on runtime: {est_memory_size:.2f} MiB')
|
||||
logger.info(f"Estimated max memory usage on runtime: {est_memory_size:.2f} MiB")
|
||||
_, _, total_mem = device_memory_info(torch.cuda.current_device())
|
||||
total_mem = bytes_to_target_unit(total_mem, 'MiB')
|
||||
total_mem = bytes_to_target_unit(total_mem, "MiB")
|
||||
if est_memory_size > total_mem:
|
||||
logger.warning(
|
||||
f'Engine is successfully built, but GPU Memory ({total_mem:.2f} MB)'
|
||||
' may not be enough when running inference on max shape.')
|
||||
f"Engine is successfully built, but GPU Memory ({total_mem:.2f} MB)"
|
||||
" may not be enough when running inference on max shape."
|
||||
)
|
||||
if paged_kv_cache:
|
||||
logger.warning(f'Since paged_kv_cache is enabled, the max KV Cache '
|
||||
'memory size is a estimate for very extreme cases, '
|
||||
'it\'s possible that most cases won\'t meet OOM.')
|
||||
logger.warning(
|
||||
"Since paged_kv_cache is enabled, the max KV Cache "
|
||||
"memory size is a estimate for very extreme cases, "
|
||||
"it's possible that most cases won't meet OOM."
|
||||
)
|
||||
else:
|
||||
logger.warning(f'Enabling `--paged_kv_cache` could help reduce the '
|
||||
'GPU memory usage on runtime.')
|
||||
logger.warning(
|
||||
"Enabling `--paged_kv_cache` could help reduce the GPU memory usage on runtime."
|
||||
)
|
||||
|
||||
return est_memory_size
|
||||
|
||||
@ -10,15 +10,13 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PromptAdapterManager:
|
||||
|
||||
def __init__(self):
|
||||
self._uid_counter = 0
|
||||
self._uid_to_weights: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def load_from_ckpt(self,
|
||||
model_dirs: List[str],
|
||||
model_config: 'ModelConfig',
|
||||
uids: Optional[List[str]] = None):
|
||||
def load_from_ckpt(
|
||||
self, model_dirs: List[str], model_config: "ModelConfig", uids: Optional[List[str]] = None
|
||||
):
|
||||
if uids is None:
|
||||
uids = [self._generate_uid() for _ in range(len(model_dirs))]
|
||||
assert len(uids) == len(model_dirs)
|
||||
@ -34,10 +32,10 @@ class PromptAdapterManager:
|
||||
return
|
||||
|
||||
for uid, model_dir in zip(new_uids, new_model_dirs):
|
||||
state_dict = load_state_dict(
|
||||
get_model_path(model_dir, 'adapter_model'))
|
||||
self._uid_to_weights[uid] = state_dict['prompt_embeddings'].to(
|
||||
str_dtype_to_torch(model_config.dtype))
|
||||
state_dict = load_state_dict(get_model_path(model_dir, "adapter_model"))
|
||||
self._uid_to_weights[uid] = state_dict["prompt_embeddings"].to(
|
||||
str_dtype_to_torch(model_config.dtype)
|
||||
)
|
||||
|
||||
@property
|
||||
def uid_to_weights(self):
|
||||
|
||||
@ -25,8 +25,13 @@ import tensorrt as trt
|
||||
import torch
|
||||
|
||||
from ._common import default_trtnet
|
||||
from ._utils import (TensorWrapper, np_dtype_to_trt, str_dtype_to_trt,
|
||||
torch_dtype_to_trt, trt_dtype_to_torch)
|
||||
from ._utils import (
|
||||
TensorWrapper,
|
||||
np_dtype_to_trt,
|
||||
str_dtype_to_trt,
|
||||
torch_dtype_to_trt,
|
||||
trt_dtype_to_torch,
|
||||
)
|
||||
from .functional import Tensor, _create_tensor
|
||||
from .plugin.plugin import TRT_LLM_PLUGIN_NAMESPACE
|
||||
|
||||
@ -42,31 +47,31 @@ class PluginInfo:
|
||||
plugin_num_outputs: int
|
||||
|
||||
def __hash__(self):
|
||||
return hash(
|
||||
(self.plugin_name, self.plugin_namespace, self.plugin_version))
|
||||
return hash((self.plugin_name, self.plugin_namespace, self.plugin_version))
|
||||
|
||||
def __eq__(self, obj):
|
||||
if not isinstance(obj, PluginInfo):
|
||||
return False
|
||||
return (self.plugin_name == obj.plugin_name
|
||||
and self.plugin_namespace == obj.plugin_namespace
|
||||
and self.plugin_version == obj.plugin_version)
|
||||
return (
|
||||
self.plugin_name == obj.plugin_name
|
||||
and self.plugin_namespace == obj.plugin_namespace
|
||||
and self.plugin_version == obj.plugin_version
|
||||
)
|
||||
|
||||
|
||||
def make_expr(
|
||||
exprBuilder: Union[trt.IExprBuilder, Type[None]],
|
||||
dim: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]],
|
||||
) -> Union[trt.IDimensionExpr, Type[None]]:
|
||||
"""
|
||||
Parameters:
|
||||
exprBuilder: Union[trt.IExprBuilder, Type[None]]
|
||||
The trt.exprBuilder object. Using it to check whether dim has the same exprBuilder or to create trt.IDimensionExpr if necessary.
|
||||
"""Make a dimension expression.
|
||||
|
||||
dim : Union["DimensionExpr", int, Type[None]]
|
||||
The input dim
|
||||
Parameters:
|
||||
exprBuilder: The trt.exprBuilder object. Using it to check whether dim has the same exprBuilder
|
||||
or to create trt.IDimensionExpr if necessary.
|
||||
dim: The input dim.
|
||||
|
||||
Returns:
|
||||
A trt.IDimensionExpr object
|
||||
A trt.IDimensionExpr object.
|
||||
"""
|
||||
if isinstance(dim, DimensionExpr):
|
||||
assert exprBuilder == dim.exprBuilder
|
||||
@ -87,9 +92,7 @@ def expr_operation(
|
||||
operation: trt.DimensionOperation,
|
||||
exprBuilder: trt.IExprBuilder,
|
||||
):
|
||||
"""
|
||||
The function to do expr operation with None support
|
||||
"""
|
||||
"""The function to do expr operation with None support."""
|
||||
if exprBuilder is None or a is None or b is None:
|
||||
expr = None
|
||||
else:
|
||||
@ -98,9 +101,7 @@ def expr_operation(
|
||||
|
||||
|
||||
class DimensionExpr:
|
||||
'''
|
||||
The class to wrap `trt.IDimensionExpr` to support more pythonic methods.
|
||||
'''
|
||||
"""The class to wrap `trt.IDimensionExpr` to support more pythonic methods."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -115,100 +116,70 @@ class DimensionExpr:
|
||||
return self._expr
|
||||
|
||||
@expr.setter
|
||||
def expr(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def expr(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
self._expr = make_expr(self.exprBuilder, expr)
|
||||
|
||||
def __add__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __add__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.SUM,
|
||||
self.exprBuilder)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.SUM, self.exprBuilder)
|
||||
|
||||
def __radd__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __radd__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
return self.__add__(expr)
|
||||
|
||||
def __mul__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __mul__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.PROD,
|
||||
self.exprBuilder)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.PROD, self.exprBuilder)
|
||||
|
||||
def __rmul__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __rmul__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
return self.__mul__(expr)
|
||||
|
||||
def __sub__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __sub__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.SUB,
|
||||
self.exprBuilder)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.SUB, self.exprBuilder)
|
||||
|
||||
def __rsub__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __rsub__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(expr, self.expr, trt.DimensionOperation.SUB,
|
||||
self.exprBuilder)
|
||||
return expr_operation(expr, self.expr, trt.DimensionOperation.SUB, self.exprBuilder)
|
||||
|
||||
def __eq__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __eq__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.EQUAL,
|
||||
self.exprBuilder)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.EQUAL, self.exprBuilder)
|
||||
|
||||
def __lt__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __lt__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.LESS,
|
||||
self.exprBuilder)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.LESS, self.exprBuilder)
|
||||
|
||||
def __floordiv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __floordiv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.FLOOR_DIV,
|
||||
self.exprBuilder)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.FLOOR_DIV, self.exprBuilder)
|
||||
|
||||
def __rfloordiv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr,
|
||||
int, Type[None]]):
|
||||
def __rfloordiv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(expr, self.expr, trt.DimensionOperation.FLOOR_DIV,
|
||||
self.exprBuilder)
|
||||
return expr_operation(expr, self.expr, trt.DimensionOperation.FLOOR_DIV, self.exprBuilder)
|
||||
|
||||
def __truediv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __truediv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.CEIL_DIV,
|
||||
self.exprBuilder)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.CEIL_DIV, self.exprBuilder)
|
||||
|
||||
def __rtruediv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def __rtruediv__(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(expr, self.expr, trt.DimensionOperation.CEIL_DIV,
|
||||
self.exprBuilder)
|
||||
return expr_operation(expr, self.expr, trt.DimensionOperation.CEIL_DIV, self.exprBuilder)
|
||||
|
||||
def max(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def max(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.MAX,
|
||||
self.exprBuilder)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.MAX, self.exprBuilder)
|
||||
|
||||
def min(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]):
|
||||
def min(self, expr: Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]):
|
||||
expr = make_expr(self.exprBuilder, expr)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.MIN,
|
||||
self.exprBuilder)
|
||||
return expr_operation(self.expr, expr, trt.DimensionOperation.MIN, self.exprBuilder)
|
||||
|
||||
|
||||
class ShapeExpr:
|
||||
'''
|
||||
The class to Wrap `trt.DimsExprs` to support more pythonic methods.
|
||||
'''
|
||||
"""The class to Wrap `trt.DimsExprs` to support more pythonic methods."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: Union[Sequence[trt.IDimensionExpr], Sequence[int],
|
||||
Sequence[type[None]]],
|
||||
dims: Union[Sequence[trt.IDimensionExpr], Sequence[int], Sequence[type[None]]],
|
||||
exprBuilder: Union[trt.IExprBuilder, type[None]],
|
||||
):
|
||||
self.exprBuilder = exprBuilder
|
||||
@ -221,13 +192,11 @@ class ShapeExpr:
|
||||
@dims.setter
|
||||
def dims(
|
||||
self,
|
||||
dims: Sequence[Union["DimensionExpr", trt.IDimensionExpr, int,
|
||||
Type[None]]],
|
||||
dims: Sequence[Union["DimensionExpr", trt.IDimensionExpr, int, Type[None]]],
|
||||
):
|
||||
if dims is not None:
|
||||
self._dims = [
|
||||
DimensionExpr(make_expr(self.exprBuilder, i), self.exprBuilder)
|
||||
for i in dims
|
||||
DimensionExpr(make_expr(self.exprBuilder, i), self.exprBuilder) for i in dims
|
||||
]
|
||||
else:
|
||||
self._dims = None
|
||||
@ -246,8 +215,7 @@ class ShapeExpr:
|
||||
if self._dims is None:
|
||||
return
|
||||
assert index < len(self._dims)
|
||||
value = DimensionExpr(make_expr(self.exprBuilder, value),
|
||||
self.exprBuilder)
|
||||
value = DimensionExpr(make_expr(self.exprBuilder, value), self.exprBuilder)
|
||||
self._dims[index] = value
|
||||
|
||||
def __len__(self):
|
||||
@ -261,9 +229,10 @@ class ShapeExpr:
|
||||
|
||||
|
||||
class SymTensor:
|
||||
'''
|
||||
The class to represent symbolic tensors. Only contains dtype and shape information for users to write their own shape/dtype inference function.
|
||||
'''
|
||||
"""The class to represent symbolic tensors.
|
||||
|
||||
Only contains dtype and shape information for users to write their own shape/dtype inference function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -290,8 +259,7 @@ class SymTensor:
|
||||
return self._dtype
|
||||
|
||||
@dtype.setter
|
||||
def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType,
|
||||
Type[None]]):
|
||||
def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType, Type[None]]):
|
||||
if isinstance(dtype, torch.dtype):
|
||||
self._dtype = torch_dtype_to_trt(dtype)
|
||||
elif isinstance(dtype, str):
|
||||
@ -313,13 +281,15 @@ def _convert_return_value_to_list(ret):
|
||||
return ret
|
||||
|
||||
|
||||
class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
|
||||
trt.IPluginV3OneRuntime):
|
||||
'''
|
||||
The base class of TRT-LLM plugin.
|
||||
class PluginBase(
|
||||
trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime
|
||||
):
|
||||
"""The base class of TRT-LLM plugin.
|
||||
|
||||
All TRT-LLM plugin should inherit this class and at least rewrite `forward` and `shape_dtype_inference` function. `forward` defines the plugin's compute flow while `shape_dtype_inference` defines how would the output tensor's shape and dtype be inferenced from the input tensor.
|
||||
'''
|
||||
All TRT-LLM plugin should inherit this class and at least rewrite `forward` and `shape_dtype_inference`
|
||||
function. `forward` defines the plugin's compute flow while `shape_dtype_inference` defines how would
|
||||
the output tensor's shape and dtype be inferenced from the input tensor.
|
||||
"""
|
||||
|
||||
_plugin_creator = None
|
||||
_no_serialize_attr = {"_current_stream", "_workspace"}
|
||||
@ -327,9 +297,9 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
|
||||
def __init__(self):
|
||||
cls = type(self)
|
||||
# Runtime check for plugin decorator
|
||||
assert (
|
||||
cls._plugin_creator is not None
|
||||
), "Please make sure the plugin is registered through `@trtllm_plugin`"
|
||||
assert cls._plugin_creator is not None, (
|
||||
"Please make sure the plugin is registered through `@trtllm_plugin`"
|
||||
)
|
||||
assert cls != PluginBase
|
||||
|
||||
trt.IPluginV3.__init__(self)
|
||||
@ -381,8 +351,7 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
|
||||
pass
|
||||
|
||||
def get_output_data_types(self, input_types):
|
||||
ret = self.shape_dtype_inference(
|
||||
[SymTensor(i, ShapeExpr(None, None)) for i in input_types])
|
||||
ret = self.shape_dtype_inference([SymTensor(i, ShapeExpr(None, None)) for i in input_types])
|
||||
|
||||
ret = _convert_return_value_to_list(ret)
|
||||
assert len(ret) == self.num_outputs
|
||||
@ -392,11 +361,11 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
|
||||
return [i.dtype for i in ret]
|
||||
|
||||
def get_output_shapes(self, inputs, shape_inputs, exprBuilder):
|
||||
assert len(
|
||||
shape_inputs) == 0, "Currently we do not support shape inputs"
|
||||
assert len(shape_inputs) == 0, "Currently we do not support shape inputs"
|
||||
|
||||
ret = self.shape_dtype_inference(
|
||||
[SymTensor(None, ShapeExpr(i, exprBuilder)) for i in inputs])
|
||||
[SymTensor(None, ShapeExpr(i, exprBuilder)) for i in inputs]
|
||||
)
|
||||
|
||||
ret = _convert_return_value_to_list(ret)
|
||||
assert len(ret) == self.num_outputs
|
||||
@ -406,8 +375,9 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
|
||||
return [i.shape.to_trt() for i in ret]
|
||||
|
||||
def supports_format_combination(self, pos, in_out, num_inputs):
|
||||
"""
|
||||
By default, TRT-LLM plugin supports all dtype and linear format. It is the users responsibility to check the dtype the plugin supported in `forward` function.
|
||||
"""By default, TRT-LLM plugin supports all dtype and linear format.
|
||||
|
||||
It is the users responsibility to check the dtype the plugin supported in `forward` function.
|
||||
"""
|
||||
assert pos < len(in_out)
|
||||
|
||||
@ -422,13 +392,11 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
|
||||
|
||||
def get_fields_to_serialize(self):
|
||||
buffer = pickle.dumps(self._get_dict_to_serialize())
|
||||
return trt.PluginFieldCollection([
|
||||
trt.PluginField("__plugin_pickle_obj__", buffer,
|
||||
trt.PluginFieldType.UNKNOWN)
|
||||
])
|
||||
return trt.PluginFieldCollection(
|
||||
[trt.PluginField("__plugin_pickle_obj__", buffer, trt.PluginFieldType.UNKNOWN)]
|
||||
)
|
||||
|
||||
def enqueue(self, input_desc, output_desc, inputs, outputs, workspace,
|
||||
stream):
|
||||
def enqueue(self, input_desc, output_desc, inputs, outputs, workspace, stream):
|
||||
torch_stream = torch.cuda.ExternalStream(stream_ptr=stream)
|
||||
self.workspace = workspace
|
||||
self.current_stream = stream
|
||||
@ -437,31 +405,32 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
|
||||
self.forward(
|
||||
tuple(
|
||||
TensorWrapper.from_trt_desc(input_desc[i], inputs[i])
|
||||
for i in range(len(input_desc))),
|
||||
for i in range(len(input_desc))
|
||||
),
|
||||
tuple(
|
||||
TensorWrapper.from_trt_desc(output_desc[i], outputs[i])
|
||||
for i in range(len(output_desc))),
|
||||
for i in range(len(output_desc))
|
||||
),
|
||||
)
|
||||
|
||||
self.current_stream = -1
|
||||
|
||||
def __call__(self, *args: Union[Sequence[TensorWrapper],
|
||||
Sequence[torch.Tensor]]):
|
||||
def __call__(self, *args: Union[Sequence[TensorWrapper], Sequence[torch.Tensor]]):
|
||||
is_trtllm = True
|
||||
for i in args:
|
||||
is_trtllm &= isinstance(i, Tensor)
|
||||
|
||||
if not is_trtllm:
|
||||
for i in args:
|
||||
assert isinstance(
|
||||
i, torch.Tensor
|
||||
), "Plugin inputs must be `tensorrt_llm.Tensor`s or `torch.Tensor`s"
|
||||
assert isinstance(i, torch.Tensor), (
|
||||
"Plugin inputs must be `tensorrt_llm.Tensor`s or `torch.Tensor`s"
|
||||
)
|
||||
sym_tensors = self.shape_dtype_inference(
|
||||
[SymTensor(i.dtype, [j for j in i.shape]) for i in args])
|
||||
[SymTensor(i.dtype, [j for j in i.shape]) for i in args]
|
||||
)
|
||||
sym_tensors = _convert_return_value_to_list(sym_tensors)
|
||||
ret = [
|
||||
torch.empty(sym_tensor.shape,
|
||||
dtype=trt_dtype_to_torch(sym_tensor.dtype))
|
||||
torch.empty(sym_tensor.shape, dtype=trt_dtype_to_torch(sym_tensor.dtype))
|
||||
for sym_tensor in sym_tensors
|
||||
]
|
||||
self.current_stream = torch.cuda.current_stream().cuda_stream
|
||||
@ -492,20 +461,20 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
|
||||
"By default TRT should not set tactics since PluginBase do not provide custom tactic."
|
||||
)
|
||||
|
||||
def forward(self, inputs: Sequence[TensorWrapper],
|
||||
outputs: Sequence[TensorWrapper]):
|
||||
'''
|
||||
Expect users to rewrite this function to define the compute flow. There are a few special attributes for users to get access to some resources.
|
||||
def forward(self, inputs: Sequence[TensorWrapper], outputs: Sequence[TensorWrapper]):
|
||||
"""Expect users to rewrite this function to define the compute flow.
|
||||
|
||||
There are a few special attributes for users to get access to some resources.
|
||||
|
||||
`self.workspace`: The workspace address of TRT managed workspace.
|
||||
`self.current_stream`: The CUDA stream this plugin is expected to execute on. By default `PluginBase` set the torch.cuda.current_stream() to this stream. This attribute is for the toolkit that doesn't work with torch's stream.
|
||||
'''
|
||||
`self.current_stream`: The CUDA stream this plugin is expected to execute on. By default
|
||||
`PluginBase` set the torch.cuda.current_stream() to this stream. This attribute is for the
|
||||
toolkit that doesn't work with torch's stream.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def shape_dtype_inference(self, inputs: Sequence[SymTensor]):
|
||||
'''
|
||||
Expect users to rewrite this function to define the shape dtype inference for output tensors.
|
||||
'''
|
||||
"""Expect users to rewrite this function to define the shape dtype inference for output tensors."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_dict_to_serialize(self):
|
||||
@ -517,7 +486,6 @@ class PluginBase(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild,
|
||||
|
||||
|
||||
class PluginCreatorBase(trt.IPluginCreatorV3One):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@ -535,50 +503,53 @@ class PluginCreatorBase(trt.IPluginCreatorV3One):
|
||||
|
||||
|
||||
def trtllm_plugin(
|
||||
plugin_name: str,
|
||||
*,
|
||||
plugin_version: str = "1",
|
||||
plugin_namespace: str = TRT_LLM_PLUGIN_NAMESPACE,
|
||||
plugin_num_outputs: Union[int, Type[None]] = None,
|
||||
deepcopy_clone: bool = True,
|
||||
no_serialize_attr: Sequence[str] = set(),
|
||||
plugin_name: str,
|
||||
*,
|
||||
plugin_version: str = "1",
|
||||
plugin_namespace: str = TRT_LLM_PLUGIN_NAMESPACE,
|
||||
plugin_num_outputs: Union[int, Type[None]] = None,
|
||||
deepcopy_clone: bool = True,
|
||||
no_serialize_attr: Sequence[str] = set(),
|
||||
):
|
||||
|
||||
def plugin_registration(plugin_cls):
|
||||
assert issubclass(plugin_cls, PluginBase)
|
||||
assert hasattr(
|
||||
plugin_cls,
|
||||
"__dict__"), "Plugin wrapper uses `__dict__` to track plugin states"
|
||||
assert hasattr(plugin_cls, "__dict__"), (
|
||||
"Plugin wrapper uses `__dict__` to track plugin states"
|
||||
)
|
||||
nonlocal plugin_num_outputs
|
||||
|
||||
annotation = inspect.signature(
|
||||
plugin_cls.shape_dtype_inference).return_annotation
|
||||
annotation = inspect.signature(plugin_cls.shape_dtype_inference).return_annotation
|
||||
origin_annotation = typing.get_origin(annotation)
|
||||
if origin_annotation is tuple or annotation is SymTensor:
|
||||
if origin_annotation is tuple:
|
||||
element_types = typing.get_args(annotation)
|
||||
for ty in element_types:
|
||||
assert (
|
||||
ty == SymTensor
|
||||
), f"Plugin {plugin_name}'s `shape_dtype_inference` return annotation must be SymTensor or a tuple of SymTensor"
|
||||
assert ty == SymTensor, (
|
||||
f"Plugin {plugin_name}'s `shape_dtype_inference` return annotation must be SymTensor "
|
||||
"or a tuple of SymTensor"
|
||||
)
|
||||
infered_num_outputs = len(element_types)
|
||||
else:
|
||||
infered_num_outputs = 1
|
||||
if plugin_num_outputs is not None:
|
||||
assert (
|
||||
plugin_num_outputs == infered_num_outputs
|
||||
), f"Plugin {plugin_name}'s `_num_outputs` and return annotation mismatch, {plugin_cls._num_outputs} != {infered_num_outputs}"
|
||||
assert plugin_num_outputs == infered_num_outputs, (
|
||||
f"Plugin {plugin_name}'s `_num_outputs` and return annotation mismatch, "
|
||||
f"{plugin_cls._num_outputs} != {infered_num_outputs}"
|
||||
)
|
||||
plugin_num_outputs = infered_num_outputs
|
||||
else:
|
||||
assert (
|
||||
plugin_num_outputs is not None
|
||||
), f"Must specify `num_outputs` or valid `shape_dtype_inference` return annotation for {plugin_name}. The valid types are SymTensor or a tuple of SymTensor, got {annotation}."
|
||||
assert plugin_num_outputs is not None, (
|
||||
"Must specify `num_outputs` or valid `shape_dtype_inference` return annotation for "
|
||||
f"{plugin_name}. The valid types are SymTensor or a tuple of SymTensor, got {annotation}."
|
||||
)
|
||||
|
||||
plugin_info = PluginInfo(3, plugin_namespace, plugin_name,
|
||||
plugin_version, plugin_num_outputs)
|
||||
assert (
|
||||
plugin_info not in _plugin_registered
|
||||
), f"Redefine plugin with info: {plugin_info} which is previously defined as {_plugin_registered[plugin_info]}"
|
||||
plugin_info = PluginInfo(
|
||||
3, plugin_namespace, plugin_name, plugin_version, plugin_num_outputs
|
||||
)
|
||||
assert plugin_info not in _plugin_registered, (
|
||||
f"Redefine plugin with info: {plugin_info} which is previously defined as "
|
||||
f"{_plugin_registered[plugin_info]}"
|
||||
)
|
||||
|
||||
_plugin_registered[plugin_info] = plugin_info
|
||||
plugin_cls._plugin_name = plugin_name
|
||||
@ -598,8 +569,7 @@ def trtllm_plugin(
|
||||
plugin_creator.plugin_cls = plugin_cls
|
||||
|
||||
plugin_cls._plugin_creator = plugin_creator
|
||||
ret = plugin_registry.register_creator(plugin_creator,
|
||||
plugin_cls._plugin_namespace)
|
||||
ret = plugin_registry.register_creator(plugin_creator, plugin_cls._plugin_namespace)
|
||||
|
||||
assert ret, f"Plugin: {plugin_cls} register failed, please check the error log."
|
||||
|
||||
|
||||
@ -12,8 +12,7 @@ from tensorrt_llm.bindings import executor as tllme
|
||||
|
||||
@dataclass(slots=True, kw_only=True)
|
||||
class GuidedDecodingParams:
|
||||
"""
|
||||
Guided decoding parameters for text generation. Only one of the fields could be effective.
|
||||
"""Guided decoding parameters for text generation. Only one of the fields could be effective.
|
||||
|
||||
Args:
|
||||
json (str, pydantic.main.BaseModel, dict, optional): The generated text is amenable to json format with additional user-specified restrictions, namely schema. Defaults to None.
|
||||
@ -21,7 +20,8 @@ class GuidedDecodingParams:
|
||||
grammar (str, optional): The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar. Defaults to None.
|
||||
json_object (bool): If True, the generated text is amenable to json format. Defaults to False.
|
||||
structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Defaults to None.
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
json: Optional[Union[str, BaseModel, dict]] = None
|
||||
regex: Optional[str] = None
|
||||
grammar: Optional[str] = None
|
||||
@ -30,12 +30,10 @@ class GuidedDecodingParams:
|
||||
|
||||
def _validate(self):
|
||||
num_guides = 0
|
||||
for field in fields(self):
|
||||
num_guides += bool(getattr(self, field.name))
|
||||
for _field in fields(self):
|
||||
num_guides += bool(getattr(self, _field.name))
|
||||
if num_guides > 1:
|
||||
raise ValueError(
|
||||
f"Only one guide can be used for a request, but got {num_guides}."
|
||||
)
|
||||
raise ValueError(f"Only one guide can be used for a request, but got {num_guides}.")
|
||||
|
||||
|
||||
class LogprobParams(NamedTuple):
|
||||
@ -57,16 +55,23 @@ class LogitsProcessor(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, req_id: int, logits: torch.Tensor,
|
||||
token_ids: List[List[int]], stream_ptr: Optional[int],
|
||||
client_id: Optional[int]) -> None:
|
||||
def __call__(
|
||||
self,
|
||||
req_id: int,
|
||||
logits: torch.Tensor,
|
||||
token_ids: List[List[int]],
|
||||
stream_ptr: Optional[int],
|
||||
client_id: Optional[int],
|
||||
) -> None:
|
||||
"""Logits processing callback. The callback is expected to inplace modify the logits.
|
||||
|
||||
Args:
|
||||
req_id (int): Request id.
|
||||
logits (torch.Tensor): Logits tensor to be modified.
|
||||
token_ids (List[List[int]]): Token ids produced by the request so far. The shape is beam_width * sequence_length.
|
||||
stream_ptr (int, optional): The operation stream used by the logits tensor. Not required for PyTorch backend.
|
||||
token_ids (List[List[int]]): Token ids produced by the request so far.
|
||||
The shape is beam_width * sequence_length.
|
||||
stream_ptr (int, optional): The operation stream used by the logits tensor.
|
||||
Not required for PyTorch backend.
|
||||
client_id (int, optional): An optional client id.
|
||||
"""
|
||||
pass # noqa
|
||||
@ -82,15 +87,21 @@ class BatchedLogitsProcessor(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, req_ids: List[int], logits: List[torch.Tensor],
|
||||
token_ids: List[List[List[int]]], stream_ptr: int,
|
||||
client_ids: List[Optional[int]]) -> None:
|
||||
def __call__(
|
||||
self,
|
||||
req_ids: List[int],
|
||||
logits: List[torch.Tensor],
|
||||
token_ids: List[List[List[int]]],
|
||||
stream_ptr: int,
|
||||
client_ids: List[Optional[int]],
|
||||
) -> None:
|
||||
"""Batched logits processing callback. The callback is expected to inplace modify the logits.
|
||||
|
||||
Args:
|
||||
req_ids (List[int]): A batch of request ids.
|
||||
logits (List[torch.Tensor]): A batch of the logits tensors.
|
||||
token_ids (List[List[List[int]]]): A batch of the token ids produced by the requests so far. The shape is batch * beam_width * sequence_length.
|
||||
token_ids (List[List[List[int]]]): A batch of the token ids produced by the requests so far.
|
||||
The shape is batch * beam_width * sequence_length.
|
||||
stream_ptr (int): The operation stream used by the logits tensors.
|
||||
client_ids (List[Optional[int]]): A batch of optional client ids.
|
||||
"""
|
||||
@ -99,21 +110,20 @@ class BatchedLogitsProcessor(ABC):
|
||||
|
||||
@dataclass(slots=True, kw_only=True)
|
||||
class AdditionalModelOutput:
|
||||
"""
|
||||
An additional output to gather from the model.
|
||||
"""An additional output to gather from the model.
|
||||
|
||||
Args:
|
||||
name (str): The name of the additional output to gather from the model.
|
||||
gather_context (bool): A value indicating whether or not to gather the additional output from the context too. Defaults to False.
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
name: str
|
||||
gather_context: bool
|
||||
|
||||
|
||||
@dataclass(slots=True, kw_only=True)
|
||||
class SamplingParams:
|
||||
"""
|
||||
Sampling parameters for text generation.
|
||||
"""Sampling parameters for text generation.
|
||||
|
||||
Usage Examples:
|
||||
|
||||
@ -179,7 +189,8 @@ class SamplingParams:
|
||||
truncate_prompt_tokens (int, optional): If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). Defaults to None.
|
||||
skip_special_tokens (bool): Whether to skip special tokens in the output. Defaults to True.
|
||||
spaces_between_special_tokens (bool): Whether to add spaces between special tokens in the output. Defaults to True.
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
# [TO DEVELOPER] This class provides an interface to LLMAPI users.
|
||||
# Internally, it manages and dispatches fields to Python bindings of C++ objects, currently including:
|
||||
# (1) all fields of tllme.SamplingConfig;
|
||||
@ -194,19 +205,14 @@ class SamplingParams:
|
||||
max_tokens: int = 32
|
||||
bad: Optional[Union[str, List[str]]] = None
|
||||
bad_token_ids: Optional[List[int]] = None
|
||||
_bad_word_ids: Optional[List[List[int]]] = field(default=None,
|
||||
init=False,
|
||||
repr=False)
|
||||
_bad_word_ids: Optional[List[List[int]]] = field(default=None, init=False, repr=False)
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
include_stop_str_in_output: bool = False
|
||||
_stop_word_ids: Optional[List[List[int]]] = field(default=None,
|
||||
init=False,
|
||||
repr=False)
|
||||
_stop_word_ids: Optional[List[List[int]]] = field(default=None, init=False, repr=False)
|
||||
|
||||
embedding_bias: Optional[torch.Tensor] = None
|
||||
logits_processor: Optional[Union[LogitsProcessor,
|
||||
List[LogitsProcessor]]] = None
|
||||
logits_processor: Optional[Union[LogitsProcessor, List[LogitsProcessor]]] = None
|
||||
apply_batched_logits_processor: bool = False
|
||||
|
||||
n: int = 1
|
||||
@ -273,25 +279,28 @@ class SamplingParams:
|
||||
self._validate()
|
||||
|
||||
def _validate(self):
|
||||
''' Verify the sampling parameters.
|
||||
"""Verify the sampling parameters.
|
||||
|
||||
This function verifies the sampling parameters in the LLM API, which
|
||||
may have stricter requirements than the Executor class of C++ runtime.
|
||||
For instance, while the greedy decoding with n > 1 is capable in the
|
||||
Executor class of C++ runtime, the LLM API disallows such combination.
|
||||
'''
|
||||
"""
|
||||
if self.best_of < self.n:
|
||||
raise ValueError(
|
||||
f"best_of ({self.best_of}) cannot be less than n ({self.n})")
|
||||
raise ValueError(f"best_of ({self.best_of}) cannot be less than n ({self.n})")
|
||||
|
||||
if (self.best_of > 1 and self._greedy_decoding
|
||||
and not os.environ.get('TLLM_ALLOW_N_GREEDY_DECODING', None)):
|
||||
if (
|
||||
self.best_of > 1
|
||||
and self._greedy_decoding
|
||||
and not os.environ.get("TLLM_ALLOW_N_GREEDY_DECODING", None)
|
||||
):
|
||||
raise ValueError(
|
||||
f'Greedy decoding in the LLM API does not allow multiple '
|
||||
f'returns. Please set to best_of=1, got best_of={self.best_of}. '
|
||||
f'Please set to best_of=1 or set an environment variable '
|
||||
f'TLLM_ALLOW_N_GREEDY_DECODING=1 to allow best_of > 1 '
|
||||
f'under the greedy decoding.')
|
||||
f"Greedy decoding in the LLM API does not allow multiple "
|
||||
f"returns. Please set to best_of=1, got best_of={self.best_of}. "
|
||||
f"Please set to best_of=1 or set an environment variable "
|
||||
f"TLLM_ALLOW_N_GREEDY_DECODING=1 to allow best_of > 1 "
|
||||
f"under the greedy decoding."
|
||||
)
|
||||
|
||||
if self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1:
|
||||
raise ValueError(
|
||||
@ -303,14 +312,15 @@ class SamplingParams:
|
||||
|
||||
# correct types as users might pass in logprob=True for Top-1 logprobs
|
||||
self.logprobs = self.logprobs and int(self.logprobs)
|
||||
self.prompt_logprobs = self.prompt_logprobs and int(
|
||||
self.prompt_logprobs)
|
||||
self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs)
|
||||
|
||||
@property
|
||||
def _greedy_decoding(self) -> bool:
|
||||
return (not self.use_beam_search
|
||||
and (self.top_k is None or self.top_k == 1)
|
||||
and (self.top_p is None or self.top_p == 0.0))
|
||||
return (
|
||||
not self.use_beam_search
|
||||
and (self.top_k is None or self.top_k == 1)
|
||||
and (self.top_p is None or self.top_p == 0.0)
|
||||
)
|
||||
|
||||
@property
|
||||
def _need_return_context_logits(self) -> bool:
|
||||
@ -320,9 +330,7 @@ class SamplingParams:
|
||||
def _need_return_generation_logits(self) -> bool:
|
||||
return self.return_generation_logits and not self._generation_logits_auto_enabled
|
||||
|
||||
def _setup(self,
|
||||
tokenizer,
|
||||
add_special_tokens: bool = False) -> 'SamplingParams':
|
||||
def _setup(self, tokenizer, add_special_tokens: bool = False) -> "SamplingParams":
|
||||
if self.end_id is None:
|
||||
self.end_id = tokenizer.eos_token_id
|
||||
self.pad_id = tokenizer.pad_token_id
|
||||
@ -332,15 +340,13 @@ class SamplingParams:
|
||||
if self.bad is not None:
|
||||
strs = [self.bad] if isinstance(self.bad, str) else self.bad
|
||||
self._bad_word_ids = [
|
||||
tokenizer.encode(s, add_special_tokens=add_special_tokens)
|
||||
for s in strs
|
||||
tokenizer.encode(s, add_special_tokens=add_special_tokens) for s in strs
|
||||
]
|
||||
|
||||
if self.stop is not None:
|
||||
strs = [self.stop] if isinstance(self.stop, str) else self.stop
|
||||
self._stop_word_ids = [
|
||||
tokenizer.encode(s, add_special_tokens=add_special_tokens)
|
||||
for s in strs
|
||||
tokenizer.encode(s, add_special_tokens=add_special_tokens) for s in strs
|
||||
]
|
||||
|
||||
return self
|
||||
@ -356,7 +362,8 @@ class SamplingParams:
|
||||
if self._bad_word_ids is None:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__}.bad ({self.bad}) is not processed by tokenizer, "
|
||||
"please call the setup method.")
|
||||
"please call the setup method."
|
||||
)
|
||||
return words + self._bad_word_ids
|
||||
|
||||
def _get_stop_words(self) -> List[List[int]]:
|
||||
@ -370,11 +377,11 @@ class SamplingParams:
|
||||
if self._stop_word_ids is None:
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__}.stop ({self.stop}) is not processed by tokenizer, "
|
||||
"please call the setup method.")
|
||||
"please call the setup method."
|
||||
)
|
||||
return words + self._stop_word_ids
|
||||
|
||||
def _get_stop_reasons_and_words(
|
||||
self) -> List[Tuple[Union[str, int], List[List[int]]]]:
|
||||
def _get_stop_reasons_and_words(self) -> List[Tuple[Union[str, int], List[List[int]]]]:
|
||||
stop_reasons = []
|
||||
if self.stop_token_ids is not None:
|
||||
stop_reasons.extend(self.stop_token_ids)
|
||||
@ -400,37 +407,30 @@ class SamplingParams:
|
||||
# | Sampling | use_beam_search | beam_width == 1 |
|
||||
# | Sampling | n | num_return_sequences |
|
||||
# | Sampling | best_of | no corresponding param |
|
||||
fields = {
|
||||
f
|
||||
for f in dir(tllme.SamplingConfig) if not f.startswith('__')
|
||||
}
|
||||
fields = {f for f in dir(tllme.SamplingConfig) if not f.startswith("__")}
|
||||
unmatched_params = [
|
||||
'num_return_sequences',
|
||||
'beam_width',
|
||||
'n',
|
||||
'best_of',
|
||||
'use_beam_search',
|
||||
"num_return_sequences",
|
||||
"beam_width",
|
||||
"n",
|
||||
"best_of",
|
||||
"use_beam_search",
|
||||
]
|
||||
llmapi_to_rt_param_map = {
|
||||
f: getattr(self, f)
|
||||
for f in fields if f not in unmatched_params
|
||||
}
|
||||
llmapi_to_rt_param_map = {f: getattr(self, f) for f in fields if f not in unmatched_params}
|
||||
if self.use_beam_search:
|
||||
llmapi_to_rt_param_map['num_return_sequences'] = self.n
|
||||
llmapi_to_rt_param_map['beam_width'] = self.best_of
|
||||
llmapi_to_rt_param_map["num_return_sequences"] = self.n
|
||||
llmapi_to_rt_param_map["beam_width"] = self.best_of
|
||||
else:
|
||||
llmapi_to_rt_param_map['num_return_sequences'] = self.best_of
|
||||
llmapi_to_rt_param_map['beam_width'] = 1
|
||||
llmapi_to_rt_param_map["num_return_sequences"] = self.best_of
|
||||
llmapi_to_rt_param_map["beam_width"] = 1
|
||||
|
||||
return tllme.SamplingConfig(**llmapi_to_rt_param_map)
|
||||
|
||||
def _get_output_config(self,
|
||||
is_pytorch_backend: bool = False
|
||||
) -> tllme.OutputConfig:
|
||||
def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputConfig:
|
||||
sampling_param_fields = set(dir(SamplingParams))
|
||||
fields = [
|
||||
f for f in dir(tllme.OutputConfig)
|
||||
if not f.startswith('__') and f in sampling_param_fields
|
||||
f
|
||||
for f in dir(tllme.OutputConfig)
|
||||
if not f.startswith("__") and f in sampling_param_fields
|
||||
]
|
||||
|
||||
config_kwargs = {f: getattr(self, f) for f in fields}
|
||||
@ -447,8 +447,7 @@ class SamplingParams:
|
||||
return None
|
||||
|
||||
if self.guided_decoding.json_object:
|
||||
return tllme.GuidedDecodingParams(
|
||||
tllme.GuidedDecodingParams.GuideType.JSON)
|
||||
return tllme.GuidedDecodingParams(tllme.GuidedDecodingParams.GuideType.JSON)
|
||||
elif self.guided_decoding.json is not None:
|
||||
json_schema = self.guided_decoding.json
|
||||
if isinstance(json_schema, BaseModel):
|
||||
@ -456,18 +455,20 @@ class SamplingParams:
|
||||
if isinstance(json_schema, dict):
|
||||
json_schema = json.dumps(json_schema)
|
||||
return tllme.GuidedDecodingParams(
|
||||
tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema)
|
||||
tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema
|
||||
)
|
||||
elif self.guided_decoding.regex is not None:
|
||||
return tllme.GuidedDecodingParams(
|
||||
tllme.GuidedDecodingParams.GuideType.REGEX,
|
||||
self.guided_decoding.regex)
|
||||
tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex
|
||||
)
|
||||
elif self.guided_decoding.grammar is not None:
|
||||
return tllme.GuidedDecodingParams(
|
||||
tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR,
|
||||
self.guided_decoding.grammar)
|
||||
tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar
|
||||
)
|
||||
elif self.guided_decoding.structural_tag is not None:
|
||||
return tllme.GuidedDecodingParams(
|
||||
tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG,
|
||||
self.guided_decoding.structural_tag)
|
||||
self.guided_decoding.structural_tag,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -21,46 +21,52 @@ from .plugin.plugin import PluginConfig


class TopModelMixin:
'''
The Module class are reused between building blocks (like Attention, MLP) and the top level model (like LLaMAForCausalLM)
While there are some functions, like the loading hf/ft weights, or build/load trt engines are only supported by the top level model, not the building blocks.
So top level model class like: LLaMAForCausalLM shall inherit this class.
'''
"""Top model mixin.

The Module classes are reused between building blocks (like Attention, MLP) and the top level models
(like LLaMAForCausalLM).
While there are some functions, like the loading hf/ft weights, or build/load trt engines, that are
only supported by the top level model, not the building blocks.

So top level model class like: LLaMAForCausalLM shall inherit this class.
"""

def __init__(self) -> None:
pass

@classmethod
def from_hugging_face(cls,
hf_model_dir: str,
dtype: Optional[str] = 'float16',
mapping: Optional[Mapping] = None,
**kwargs):
'''
Create LLM object and load weights from hugging face
def from_hugging_face(
cls,
hf_model_dir: str,
dtype: Optional[str] = "float16",
mapping: Optional[Mapping] = None,
**kwargs,
):
"""Create LLM object and load weights from hugging face.

Parameters:
hf_model_dir: the hugging face model directory
dtype: str, the default weights data type when loading from the hugging face model
mapping: Mapping, specify the multi-gpu parallel strategy, when it's None, single GPU is used
'''
"""
raise NotImplementedError("Subclass shall override this")

def use_lora(self, lora_config: LoraConfig):
'''
Load lora weights from the give config to the module
"""Load lora weights from the give config to the module.

Parameters:
lora_config: the lora config
'''
"""
raise NotImplementedError("Subclass shall override this")

def use_prompt_tuning(self, max_prompt_embedding_table_size: str,
prompt_table_path: str):
'''Enable p tuning when build the TRT engine, call this before to_trt
'''
def use_prompt_tuning(self, max_prompt_embedding_table_size: str, prompt_table_path: str):
"""Enable p tuning when build the TRT engine, call this before to_trt."""
raise NotImplementedError

def default_plugin_config(self, **kwargs) -> PluginConfig:
'''Return the default plugin config for this model, when the plugin_config value is not given in to_trt() call.
If users need to set different plugin configs, they can start from the return object and change it.
'''
"""Return the default plugin config for this model.

This is used when the plugin_config value is not given in to_trt() call.
If users need to set different plugin configs, they can start from the return object and change it.
"""
return PluginConfig.from_dict(kwargs)