mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com> Co-authored-by: Lanyu Liao <lancelly@users.noreply.github.com>
373 lines
12 KiB
Python
373 lines
12 KiB
Python
import contextlib
|
|
import os
|
|
import threading
|
|
from dataclasses import dataclass
|
|
from enum import Enum, IntEnum
|
|
from typing import Dict, List
|
|
|
|
import torch
|
|
|
|
from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
|
|
from tensorrt_llm.mapping import Mapping
|
|
from tensorrt_llm.math_utils import ceil_div, pad_up
|
|
from tensorrt_llm.quantization.utils import fp4_utils
|
|
|
|
# Module-level switches. They are written by set_torch_compiling() /
# set_piecewise_running() and read by is_torch_compiling() /
# is_piecewise_running(); both start out disabled.
is_torch_compiling_flag: bool = False
is_piecewise_running_flag: bool = False
|
|
|
|
# Names of the auxiliary streams. Both enums below are generated from this
# list, so the order here fixes their numeric values.
aux_stream_name_list = [
    'Attention',
    'MoeShared',
    'MoeChunkingOverlap',
    'MoeBalancer',
]

# Functional Enum API with the default start: members are numbered from 1.
AuxStreamType = Enum('AuxStreamType', names=aux_stream_name_list)

# 'Main' is prepended and numbering starts at 0, so every auxiliary entry has
# the same integer value here as it has in AuxStreamType.
EventType = Enum('EventType', names=['Main'] + aux_stream_name_list, start=0)
|
|
|
|
|
|
# IMPORTANT: Keep the same order of activation functions in this enum and the enum in
# cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
class ActivationType(IntEnum):
    """Integer codes for activation functions.

    The numeric values are explicit (not auto-generated) because, per the
    comment above, they must stay aligned with the enum declared on the C++
    side. Append new members; never renumber existing ones.
    """

    InvalidType = 0
    Identity = 1
    Gelu = 2
    Relu = 3
    Silu = 4
    Swiglu = 5
    Geglu = 6
    SwigluBias = 7
    Relu2 = 8
|
|
|
|
|
|
def set_torch_compiling(enable: bool):
    """Store `enable` in the module-level `is_torch_compiling_flag`.

    Args:
        enable: New value of the flag reported by `is_torch_compiling()`.
    """
    global is_torch_compiling_flag
    is_torch_compiling_flag = enable
|
|
|
|
|
|
def is_torch_compiling() -> bool:
    """Return the current value of the module-level `is_torch_compiling_flag`."""
    # Reading a module global needs no `global` declaration.
    return is_torch_compiling_flag
|
|
|
|
|
|
def set_piecewise_running(enable: bool):
    """Store `enable` in the module-level `is_piecewise_running_flag`.

    Args:
        enable: New value of the flag reported by `is_piecewise_running()`.
    """
    global is_piecewise_running_flag
    is_piecewise_running_flag = enable
|
|
|
|
|
|
def is_piecewise_running() -> bool:
    """Return the current value of the module-level `is_piecewise_running_flag`."""
    # Reading a module global needs no `global` declaration.
    return is_piecewise_running_flag
|
|
|
|
|
|
# Thread-local attribute namespace handed out by get_global_attrs(). It also
# holds the per-request piecewise CUDA graph flag written by
# set_per_request_piecewise_cuda_graph_flag(); each thread sees its own values.
_global_attrs = threading.local()
|
|
|
|
|
|
def get_global_attrs() -> threading.local:
    """Return the module's thread-local attribute namespace.

    Returns:
        The `threading.local` instance stored in `_global_attrs`; attributes
        set on it are only visible to the thread that set them.
    """
    return _global_attrs
|
|
|
|
|
|
# Thread-local holder whose `attrs` attribute is installed by the
# model_extra_attrs() context manager and read by get_model_extra_attrs().
_model_extra_attrs = threading.local()
|
|
|
|
|
|
def get_model_extra_attrs():
    """Return the extra attrs currently installed for the calling thread.

    Returns:
        The value stored in `_model_extra_attrs.attrs`, or None when no
        `attrs` attribute has been set on this thread yet.
    """
    try:
        return _model_extra_attrs.attrs
    except AttributeError:
        # Nothing has been installed on this thread.
        return None
|
|
|
|
|
|
@contextlib.contextmanager
def model_extra_attrs(attrs: Dict):
    """Temporarily install `attrs` as the calling thread's model extra attrs.

    The previously installed value (None if there was none) is restored when
    the `with` block exits, including when it exits through an exception, so
    nested uses unwind correctly.

    Args:
        attrs: Value to expose through `get_model_extra_attrs()` while the
            context is active.
    """
    previous_attrs = getattr(_model_extra_attrs, 'attrs', None)
    _model_extra_attrs.attrs = attrs
    try:
        yield
    finally:
        # Always put back whatever was installed before entering.
        _model_extra_attrs.attrs = previous_attrs
|
|
|
|
|
|
def with_model_extra_attrs(get_attrs):
    """Build a method decorator that runs the method inside `model_extra_attrs`.

    Args:
        get_attrs: Callable invoked as `get_attrs(self)` on every call of the
            decorated method. Its result is installed through
            `model_extra_attrs` for the duration of that call.

    Returns:
        A decorator for methods whose first positional argument is `self`.
        The decorated method keeps the original's `__name__`, `__doc__` and
        other metadata.
    """
    # Imported locally so the helper does not rely on a module-level import.
    import functools

    def decorator(func):

        # Without functools.wraps every decorated method would be reported as
        # "wrapper" in tracebacks, logs and introspection.
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            with model_extra_attrs(get_attrs(self)):
                return func(self, *args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
def make_weak_ref(x):
    """Recursively replace CUDA tensors in `x` with re-wrapped tensors.

    A CUDA tensor is described by a `TensorWrapper` built from its data
    pointer, dtype, shape and strides, and converted back with
    `convert_to_torch_tensor`. Presumably the result aliases the same device
    memory without owning it -- confirm against `TensorWrapper`.

    Args:
        x: A tensor, an int/float/bool, or a tuple/list/dict nesting of those.

    Returns:
        The same structure with each CUDA tensor replaced. Non-CUDA tensors
        and scalars are returned unchanged. Tuples, lists and dicts are
        rebuilt as plain `tuple`, `list` and `dict` objects.

    Raises:
        TypeError: If `x` (or a nested element) is of any other type.
    """
    if isinstance(x, torch.Tensor):
        if not x.is_cuda:
            return x
        wrapped = TensorWrapper(x.data_ptr(), x.dtype, x.shape, x.stride())
        return convert_to_torch_tensor(wrapped)

    if isinstance(x, tuple):
        return tuple(make_weak_ref(item) for item in x)

    if isinstance(x, list):
        return [make_weak_ref(item) for item in x]

    if isinstance(x, dict):
        # Only the values are converted; keys are kept as they are.
        return {key: make_weak_ref(value) for key, value in x.items()}

    if isinstance(x, (int, float, bool)):
        return x

    raise TypeError(f"Invalid type {type(x)} to make weak ref")
|
|
|
|
|
|
@dataclass
class Fp4QuantizedTensor:
    """Bundle of an FP4 tensor and its scaling factors."""

    # The quantized payload; `shape` is taken from this tensor.
    fp4_tensor: torch.Tensor
    # Scaling factors that accompany `fp4_tensor`.
    scaling_factor: torch.Tensor
    # Whether `scaling_factor` is stored in the swizzled layout (default: yes).
    is_sf_swizzled: bool = True

    @property
    def shape(self):
        """Shape of the underlying `fp4_tensor`."""
        return self.fp4_tensor.shape
|
|
|
|
|
|
def compute_swizzled_sf_shape(row: int, col: int):
    """Return the padded (rows, cols) of a swizzled scaling-factor matrix.

    Args:
        row: Number of scaling-factor rows.
        col: Number of scaling-factor columns.

    Returns:
        A tuple `(pad_up(row, 128), pad_up(col, 4))`.
    """
    return pad_up(row, 128), pad_up(col, 4)
|
|
|
|
|
|
def swizzle_sf(sf: torch.Tensor,
               rows: int,
               cols: int,
               scaling_vector_size: int = 16):
    """Swizzle FP4 scaling factors using C++ torch op implementation.

    Args:
        sf: [b, rows, cols_sf] or [rows, cols_sf]. The original unswizzled
            scaling factors, where
            cols_sf = ceil_div(cols, scaling_vector_size).
        rows: rows of the original unquantized tensor
        cols: cols of the original unquantized tensor
        scaling_vector_size: the size of the scaling vector

    Returns:
        [b * pad_up(rows, 128) * pad_up(cols_sf, 4), ] 1D swizzled scaling
        factors, possibly with rows and cols padded.
    """
    # The op works on a batched 3D view; a 2D input becomes a batch of one.
    batched_sf = sf.view(-1, rows, ceil_div(cols, scaling_vector_size))
    return torch.ops.trtllm.block_scale_interleave(batched_sf)
|
|
|
|
|
|
def unswizzle_sf(sf: torch.Tensor,
                 rows: int,
                 cols: int,
                 scaling_vector_size: int = 16):
    """Unswizzle FP4 scaling factors using C++ torch op implementation.

    Args:
        sf: The (padded and) swizzled scaling factors. Must be viewable as
            [-1, rows, ceil_div(cols, scaling_vector_size)].
        rows: rows of the original unquantized tensor
        cols: cols of the original unquantized tensor
        scaling_vector_size: the size of the scaling vector

    Returns:
        2D unswizzled scaling factors with
        ceil_div(cols, scaling_vector_size) columns.
    """
    sf_cols = ceil_div(cols, scaling_vector_size)
    batched_sf = sf.view(-1, rows, sf_cols)
    unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(batched_sf)
    # Collapse the batch dimension into the rows.
    return unswizzled.view(-1, sf_cols)
|
|
|
|
|
|
@torch.library.custom_op("trtllm::reswizzle_sf", mutates_args=())
def reswizzle_sf(sf: torch.Tensor,
                 rows: int,
                 cols: int,
                 scaling_vector_size: int = 16) -> torch.Tensor:
    """Reswizzle FP4 scaling factors using C++ torch op implementation.

    `sf` is treated as a sequence of independently swizzled partitions. Each
    partition is unswizzled, its padding is stripped, the partitions are
    stacked along the rows, and the stacked matrix is swizzled as a whole.

    Args:
        sf: The (padded and) swizzled scaling factors. Its element count must
            be a multiple of one padded partition.
        rows: rows of the original unquantized tensor (per partition)
        cols: cols of the original unquantized tensor
        scaling_vector_size: the size of the scaling vector

    Returns:
        1D reswizzled scaling factors
    """
    sf_cols = ceil_div(cols, scaling_vector_size)
    padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
    partition_numel = padded_rows * padded_sf_cols

    assert sf.numel() % partition_numel == 0
    num_partitions = sf.numel() // partition_numel
    total_rows = num_partitions * rows

    # Unswizzle each partition. unswizzle_sf derives the scaling-factor column
    # count from unquantized columns, so the padded width is scaled back up.
    unswizzled = unswizzle_sf(
        sf.view(num_partitions, padded_rows, padded_sf_cols), padded_rows,
        padded_sf_cols * scaling_vector_size, scaling_vector_size)

    # Drop the row/column padding of every partition and stack what is left.
    per_partition = unswizzled.view(num_partitions, padded_rows,
                                    padded_sf_cols)
    # TODO: This will incur a elementwise kernel
    stacked = per_partition[:, :rows, :sf_cols].contiguous()
    stacked = stacked.view(total_rows, sf_cols)

    # Finally swizzle the stacked scaling factors.
    return swizzle_sf(stacked, total_rows, cols, scaling_vector_size)
|
|
|
|
|
|
@torch.library.register_fake("trtllm::reswizzle_sf")
def _(sf, rows, cols, scaling_vector_size=16):
    """Fake (shape-only) implementation of `trtllm::reswizzle_sf`.

    Mirrors the shape arithmetic of the real op: the partitions are stacked
    into `total_rows` rows and swizzled, which yields a 1D tensor of
    `pad_up(total_rows, 128) * pad_up(sf_cols, 4)` elements.

    Args:
        sf: The (padded and) swizzled scaling factors.
        rows: rows of the original unquantized tensor (per partition)
        cols: cols of the original unquantized tensor
        scaling_vector_size: the size of the scaling vector

    Returns:
        An uninitialized 1D tensor with the real op's output size, created on
        `sf`'s device with `sf`'s dtype.
    """
    sf_cols = ceil_div(cols, scaling_vector_size)
    padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
    num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
    total_rows = num_partitions * rows
    # The column padding applies to scaling-factor columns (sf_cols), not to
    # the unquantized column count; using `cols` here would over-report the
    # output size by roughly a factor of scaling_vector_size.
    sz = pad_up(total_rows, 128) * pad_up(sf_cols, 4)
    return sf.new_empty(sz)
|
|
|
|
|
|
def next_positive_power_of_2(x: int) -> int:
    """Return the smallest power of two that is >= `x`, and at least 1.

    Args:
        x: Integer input. Any value below 1 yields 1.

    Returns:
        The power of two described above. The bit trick covers values that
        fit in 64 bits.
    """
    if x < 1:
        return 1

    # Following code is equivalent to 1 << (x - 1).bit_length()
    # But this impl does not contain bit_length() so can be used by torch compile.
    # It can correctly handle 64bit number which should be enough for now.
    # Each step copies the highest set bit into the positions below it, so the
    # value ends up as 2**k - 1.
    n = x - 1
    for shift in (1, 2, 4, 8, 16, 32):
        n |= n >> shift
    return n + 1
|
|
|
|
|
|
def last_positive_power_of_2(x: int) -> int:
    """Return the largest power of two that is <= `x`.

    Args:
        x: Integer input.

    Returns:
        `x` itself when it is a power of two, otherwise half of
        `next_positive_power_of_2(x)`. Note that for `x < 1` this evaluates
        to `1 // 2 == 0`.
    """
    upper = next_positive_power_of_2(x)
    # `upper` is already the answer when x is an exact power of two; otherwise
    # the answer is the power of two just below it.
    return upper if upper == x else upper // 2
|
|
|
|
|
|
def nearest_in_buckets(x: int, buckets: List[int]) -> int:
    """Round `x` up to a power of two and clamp it to the bucket range.

    Args:
        x: Value to bucket.
        buckets: Non-empty sequence; only its first and last elements are
            used, as the lower and upper clamp bounds.

    Returns:
        `next_positive_power_of_2(x)` limited to
        `[buckets[0], buckets[-1]]`. The result is not checked for
        membership in `buckets`.
    """
    rounded = next_positive_power_of_2(x)
    at_least_first = max(rounded, buckets[0])
    return min(at_least_first, buckets[-1])
|
|
|
|
|
|
def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
    """List the power-of-two buckets up to `max_num_tokens` rounded up.

    Args:
        max_num_tokens: Upper limit; it is first rounded up with
            `next_positive_power_of_2`.

    Returns:
        The values obtained by repeatedly halving the rounded limit down to
        1, in ascending order. Note that the result is a tuple even though
        the annotation says `List[int]`.
    """
    current = next_positive_power_of_2(max_num_tokens)
    descending = []
    while current >= 1:
        descending.append(current)
        current //= 2

    return tuple(reversed(descending))
|
|
|
|
|
|
def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
    """List the power-of-two buckets up to `max_num_tokens` rounded down.

    Args:
        max_num_tokens: Upper limit; it is first rounded down with
            `last_positive_power_of_2`.

    Returns:
        The values obtained by repeatedly halving the rounded limit down to
        1, in ascending order (empty when the rounded limit is below 1).
        Note that the result is a tuple even though the annotation says
        `List[int]`.
    """
    current = last_positive_power_of_2(max_num_tokens)
    descending = []
    while current >= 1:
        descending.append(current)
        current //= 2
    return tuple(reversed(descending))
|
|
|
|
|
|
def fp4_scale_infer_shape(input_shapes: List[List[int]]):
    """Calculate the dimensions of the fp4 scale tensor.

    Args:
        input_shapes: Shapes of the op inputs; only `input_shapes[0]` is used.

    Returns:
        The scale shape reported by `fp4_utils.get_fp4_shape` for a scaling
        vector size of 16, multiplied by 2.
        NOTE(review): the reason for the factor of 2 is not visible here --
        confirm against `get_fp4_shape`.
    """
    # get_fp4_shape also returns the output shape, which is not needed here.
    _, scale_shape = fp4_utils.get_fp4_shape(input_shapes[0], sf_vec_size=16)
    return scale_shape * 2
|
|
|
|
|
|
# Module-level switch read by get_piecewise_cuda_graph_flag() and written by
# set_piecewise_cuda_graph_flag(); enabled by default.
_enable_piecewise_cuda_graph = True
|
|
|
|
|
|
def set_piecewise_cuda_graph_flag(enable: bool):
    """Store `enable` in the module-level `_enable_piecewise_cuda_graph`.

    Args:
        enable: New value reported by `get_piecewise_cuda_graph_flag()`.
    """
    global _enable_piecewise_cuda_graph
    _enable_piecewise_cuda_graph = enable
|
|
|
|
|
|
def get_piecewise_cuda_graph_flag() -> bool:
    """Return the current value of the module-level `_enable_piecewise_cuda_graph`."""
    # Reading a module global needs no `global` declaration.
    return _enable_piecewise_cuda_graph
|
|
|
|
|
|
@contextlib.contextmanager
def piecewise_cuda_graph(enable: bool):
    """Temporarily set the piecewise CUDA graph flag to `enable`.

    The value that was in effect on entry is restored when the `with` block
    exits, including when it exits through an exception.

    Args:
        enable: Flag value to use inside the context.
    """
    saved_flag = get_piecewise_cuda_graph_flag()
    set_piecewise_cuda_graph_flag(enable)
    try:
        yield
    finally:
        # Always restore the value seen on entry.
        set_piecewise_cuda_graph_flag(saved_flag)
|
|
|
|
|
|
def set_per_request_piecewise_cuda_graph_flag(enable: bool):
    """Store the per-request piecewise CUDA graph flag for the calling thread.

    Args:
        enable: Value saved on the thread-local `_global_attrs` under
            `per_request_piecewise_cuda_graph_flag`.
    """
    setattr(_global_attrs, 'per_request_piecewise_cuda_graph_flag', enable)
|
|
|
|
|
|
def get_per_request_piecewise_cuda_graph_flag() -> bool:
    """Return the calling thread's per-request piecewise CUDA graph flag.

    Returns:
        The value stored on the thread-local `_global_attrs`, or True when
        the flag has not been set on this thread.
    """
    try:
        return _global_attrs.per_request_piecewise_cuda_graph_flag
    except AttributeError:
        # Never set on this thread: the flag defaults to enabled.
        return True
|
|
|
|
|
|
def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
    """Derive the `Mapping` used for LM-head tensor parallelism.

    The LM-head TP size is chosen heuristically from `token_count` and can be
    overridden with the `LM_HEAD_TP_SIZE` environment variable. The PP size
    of the new mapping is set so that `tp_size * pp_size` equals
    `mapping.tp_size * mapping.pp_size`.

    Args:
        mapping: The model's existing parallel mapping.
        token_count: Number of tokens used by the heuristic; must be non-zero
            because it is used as a divisor.

    Returns:
        A new `Mapping` with the LM-head TP/PP sizes and the rank,
        gpus_per_node and attention-DP settings copied from `mapping`.

    Raises:
        AssertionError: If `mapping.tp_size` is not divisible by the chosen
            LM-head TP size.
    """
    # We use heuristic to determine the lm_head_tp_size
    # Since token_count=256 will hit the boundary of math-bound problem
    # We use 256 // token_count to determine the lm_head_tp_size
    # For more details, refer to the blog: https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.md#mtp-lm-head-tensor-parallelism
    raw_tp_size = 256 // token_count
    # TODO: On platforms like GB200, setting lm_head_tp_size_upper_bound to world_size could be more efficient when world_size > gpus_per_node, we need to do further investigation.
    tp_size_upper_bound = min(mapping.world_size, mapping.gpus_per_node)
    heuristic_tp_size = nearest_in_buckets(raw_tp_size,
                                           [1, tp_size_upper_bound])

    # The environment variable, when set, takes precedence over the heuristic.
    lm_head_tp_size = int(os.getenv('LM_HEAD_TP_SIZE', heuristic_tp_size))
    assert mapping.tp_size % lm_head_tp_size == 0, f"mapping.tp_size: {mapping.tp_size}, lm_head_tp_size: {lm_head_tp_size}"

    # Keep the tp * pp product of the original mapping.
    lm_head_pp_size = mapping.pp_size * mapping.tp_size // lm_head_tp_size

    return Mapping(
        world_size=lm_head_tp_size * lm_head_pp_size,
        rank=mapping.rank,
        gpus_per_node=mapping.gpus_per_node,
        tp_size=lm_head_tp_size,
        pp_size=lm_head_pp_size,
        enable_attention_dp=mapping.enable_attention_dp,
        enable_lm_head_tp_in_adp=mapping.enable_lm_head_tp_in_adp,
    )
|
|
|
|
|
|
def get_device_uuid(device_idx: int) -> str:
    """Get the UUID of a CUDA device using torch cuda api.

    Args:
        device_idx: Index of the CUDA device to query.

    Returns:
        The string "GPU-" followed by the string form of the device's `uuid`
        property.
    """
    device_props = torch.cuda.get_device_properties(device_idx)
    return "GPU-" + str(device_props.uuid)
|
|
|
|
|
|
def maybe_compile(func=None, **compile_kwargs):
    """
    Conditionally compile a function with torch.compile.

    If is_piecewise_running() is True, the function will not be compiled to avoid host overhead in attention op.

    Can be used bare (`@maybe_compile`) or with arguments
    (`@maybe_compile(dynamic=True)`).

    Args:
        func: The function to decorate (optional, for direct decoration).
        **compile_kwargs: Keyword arguments for torch.compile.

    Returns:
        The conditionally compiled function when `func` is given, otherwise a
        decorator that produces it. The returned function keeps the
        original's `__name__`, `__doc__` and other metadata.
    """
    # Imported locally so the helper does not rely on a module-level import.
    import functools

    def decorator(f):
        # torch.compile is applied once, at decoration time; the choice
        # between the eager and the compiled version is made on every call.
        compiled_func = torch.compile(f, **compile_kwargs)

        # Without functools.wraps every decorated function would be reported
        # as "wrapper" in tracebacks, logs and introspection.
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if is_piecewise_running():
                return f(*args, **kwargs)
            return compiled_func(*args, **kwargs)

        return wrapper

    # Test against None rather than truthiness so that a callable object that
    # happens to be falsy is still decorated directly.
    if func is None:
        return decorator
    return decorator(func)
|