import contextlib
import os
import threading
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Dict, List, Tuple
import torch
from torch.nn import functional as F
from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import ceil_div, pad_up
from tensorrt_llm.quantization.utils import fp4_utils
is_torch_compiling_flag = False
is_piecewise_running_flag = False
aux_stream_name_list = [
'Attention',
'MoeShared',
'MoeChunkingOverlap',
'MoeBalancer',
'MoeOutputMemset',
]
AuxStreamType = Enum(
'AuxStreamType',
aux_stream_name_list,
)
EventType = Enum(
'EventType',
['Main', *aux_stream_name_list],
start=0,
)
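
# Illustrative sketch (not from the original source): the functional Enum API
# numbers EventType from 0, so Main is 0 and the aux-stream events follow in
# aux_stream_name_list order:
#
#   >>> EventType.Main.value
#   0
#   >>> EventType.Attention.value
#   1
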
# IMPORTANT: Keep the activation functions in this enum in the same order as the enum in
# cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
class ActivationType(IntEnum):
InvalidType = 0
Identity = 1
Gelu = 2
Relu = 3
Silu = 4
Swiglu = 5
Geglu = 6
SwigluBias = 7
Relu2 = 8
# IMPORTANT: when adding a new activation type, please update this function,
# and make sure it stays aligned with cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h::isGatedActivation.
def is_gated_activation(activation_type: ActivationType) -> bool:
return activation_type in [
ActivationType.Swiglu, ActivationType.SwigluBias, ActivationType.Geglu
]
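
# Illustrative sketch (not from the original source): only the *glu variants
# pair a gate projection with the activation, so they are the only gated ones:
#
#   >>> is_gated_activation(ActivationType.Swiglu)
#   True
#   >>> is_gated_activation(ActivationType.Relu2)
#   False
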
def set_torch_compiling(enable: bool):
global is_torch_compiling_flag
is_torch_compiling_flag = enable
def is_torch_compiling() -> bool:
global is_torch_compiling_flag
return is_torch_compiling_flag
def set_piecewise_running(enable: bool):
global is_piecewise_running_flag
is_piecewise_running_flag = enable
def is_piecewise_running() -> bool:
global is_piecewise_running_flag
return is_piecewise_running_flag
_global_attrs = threading.local()
def get_global_attrs():
return _global_attrs
_model_extra_attrs = threading.local()
def get_model_extra_attrs():
return getattr(_model_extra_attrs, 'attrs', None)
@contextlib.contextmanager
def model_extra_attrs(attrs: Dict):
old_attrs = getattr(_model_extra_attrs, 'attrs', None)
_model_extra_attrs.attrs = attrs
try:
yield
finally:
_model_extra_attrs.attrs = old_attrs
def with_model_extra_attrs(get_attrs):
def decorator(func):
def wrapper(self, *args, **kwargs):
with model_extra_attrs(get_attrs(self)):
return func(self, *args, **kwargs)
return wrapper
return decorator
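
# Illustrative usage sketch (the `MyModule` class and attrs dict are
# hypothetical, not from the original source): the context manager scopes
# thread-local extra attrs to a region, and the decorator applies the same
# scoping around a method:
#
#   with model_extra_attrs({"layer_idx": 0}):
#       assert get_model_extra_attrs() == {"layer_idx": 0}
#
#   class MyModule:
#       def __init__(self):
#           self.extra_attrs = {"layer_idx": 0}
#
#       @with_model_extra_attrs(lambda self: self.extra_attrs)
#       def forward(self, x):
#           return get_model_extra_attrs()["layer_idx"]
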
def make_weak_ref(x):
if isinstance(x, torch.Tensor):
return convert_to_torch_tensor(
TensorWrapper(x.data_ptr(), x.dtype, x.shape,
x.stride())) if x.is_cuda else x
elif isinstance(x, tuple):
return tuple(make_weak_ref(i) for i in x)
elif isinstance(x, list):
return [make_weak_ref(i) for i in x]
elif isinstance(x, dict):
return {k: make_weak_ref(v) for k, v in x.items()}
elif isinstance(x, (int, float, bool)):
return x
else:
raise TypeError(f"Invalid type {type(x)} to make weak ref")
@dataclass
class Fp4QuantizedTensor:
fp4_tensor: torch.Tensor
scaling_factor: torch.Tensor
is_sf_swizzled: bool = True
@property
def shape(self):
return self.fp4_tensor.shape
def compute_swizzled_sf_shape(row: int, col: int):
padded_row = pad_up(row, 128)
padded_col = pad_up(col, 4)
return padded_row, padded_col
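
# Illustrative sketch (not from the original source): the swizzled layout pads
# scaling-factor rows to a multiple of 128 and columns to a multiple of 4:
#
#   >>> compute_swizzled_sf_shape(100, 10)
#   (128, 12)
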
def swizzle_sf(sf: torch.Tensor,
rows: int,
cols: int,
scaling_vector_size: int = 16):
"""Swizzle FP4 scaling factors using C++ torch op implementation
Args:
sf: [b, rows, cols_sf] or [rows, cols_sf]. The original unswizzled scaling factors.
rows: rows of the original unquantized tensor
cols_sf: ceil_div(cols, scaling_vector_size) where cols is the number of columns of the original unquantized tensor
scaling_vector_size: the size of the scaling vector
Returns:
[b * pad_up(rows, 128) * pad_up(cols_sf, 4), ] 1D swizzled scaling factors, possibly with rows and cols padded.
"""
sf_cols = ceil_div(cols, scaling_vector_size)
sf = sf.view(-1, rows, sf_cols)
return torch.ops.trtllm.block_scale_interleave(sf)
def unswizzle_sf(sf: torch.Tensor,
rows: int,
cols: int,
scaling_vector_size: int = 16):
"""Swizzle FP4 scaling factors using C++ torch op implementation
Args:
sf: The (padded and) swizzled scaling factors.
rows: rows of the original unquantized tensor
cols: cols of the original unquantized tensor
scaling_vector_size: the size of the scaling vector
Returns:
2D unswizzled scaling factors
"""
sf_cols = ceil_div(cols, scaling_vector_size)
sf = sf.view(-1, rows, sf_cols)
return torch.ops.trtllm.block_scale_interleave_reverse(sf).view(-1, sf_cols)
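
# Illustrative round-trip sketch (assumes the trtllm extension is loaded and a
# CUDA device is available; the uint8 dtype is an assumption, not from the
# original source): for a 100x64 unquantized tensor with scaling_vector_size=16,
# sf has ceil_div(64, 16) = 4 columns; swizzling pads rows to 128 and flattens,
# and unswizzling (called with the *padded* row count, as reswizzle_sf does
# below) recovers a row-major view:
#
#   sf = torch.randint(0, 255, (100, 4), dtype=torch.uint8, device="cuda")
#   swizzled = swizzle_sf(sf, 100, 64)        # numel == 128 * 4
#   linear = unswizzle_sf(swizzled, 128, 64)  # shape (128, 4)
#   assert torch.equal(linear[:100], sf)
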
@torch.library.custom_op("trtllm::reswizzle_sf", mutates_args=())
def reswizzle_sf(sf: torch.Tensor,
rows: int,
cols: int,
scaling_vector_size: int = 16) -> torch.Tensor:
"""Reswizzle FP4 scaling factors using C++ torch op implementation.
It unswizzles the scaling factors in each partition first, then concatenates them together, and finally swizzles them back.
Args:
sf: The (padded and) swizzled scaling factors.
rows: rows of the original unquantized tensor
cols: cols of the original unquantized tensor
scaling_vector_size: the size of the scaling vector
Returns:
1D reswizzled scaling factors
"""
sf_cols = ceil_div(cols, scaling_vector_size)
padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
padded_cols = padded_sf_cols * scaling_vector_size
assert sf.numel() % (padded_rows * padded_sf_cols) == 0
num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
sf_reshaped = sf.view(num_partitions, padded_rows, padded_sf_cols)
# Unswizzle each partition
sf_unswizzled = unswizzle_sf(sf_reshaped, padded_rows, padded_cols,
scaling_vector_size)
    # Bring the unswizzled scaling factors from each partition together
total_rows = num_partitions * rows
sf_unswizzled = sf_unswizzled.view(num_partitions, padded_rows,
padded_sf_cols)
    sf_concatenated = sf_unswizzled[:, :rows, :sf_cols].contiguous(
    )  # TODO: This will incur an elementwise kernel
sf_concatenated = sf_concatenated.view(total_rows, sf_cols)
# Finally swizzle the concatenated scaling factors
return swizzle_sf(sf_concatenated, total_rows, cols, scaling_vector_size)
@torch.library.register_fake("trtllm::reswizzle_sf")
def _(sf, rows, cols, scaling_vector_size=16):
sf_cols = ceil_div(cols, scaling_vector_size)
padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
total_rows = num_partitions * rows
    sz = pad_up(total_rows, 128) * pad_up(sf_cols, 4)
return sf.new_empty(sz)
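
# Illustrative shape walk-through (hypothetical numbers, not from the original
# source): with rows=100, cols=64, scaling_vector_size=16, each swizzled
# partition holds pad_up(100, 128) * pad_up(4, 4) = 512 entries. An input with
# numel 1024 therefore has 2 partitions, total_rows = 200, and the reswizzled
# output has pad_up(200, 128) * pad_up(4, 4) = 1024 entries, matching the fake
# registration above.
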
def next_positive_power_of_2(x: int) -> int:
if x < 1:
return 1
    # The following code is equivalent to 1 << (x - 1).bit_length(),
    # but this implementation avoids bit_length(), so it works with torch.compile.
    # It correctly handles 64-bit numbers, which should be enough for now.
n = x - 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
return n + 1
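
# Illustrative sketch (not from the original source): the bit-smearing steps
# set every bit below the highest set bit of x - 1, so n + 1 is the smallest
# power of two >= x:
#
#   >>> next_positive_power_of_2(5)
#   8
#   >>> next_positive_power_of_2(8)
#   8
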
def last_positive_power_of_2(x: int) -> int:
    next_pow2 = next_positive_power_of_2(x)
    if next_pow2 == x:
        return next_pow2
    return next_pow2 // 2
def nearest_in_buckets(x: int, buckets: List[int]) -> int:
return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])
def get_power_of_2_num_tokens_buckets(max_num_tokens: int) -> Tuple[int, ...]:
max_num_tokens = next_positive_power_of_2(max_num_tokens)
num_token_buckets = []
m = max_num_tokens
while m >= 1:
num_token_buckets.append(m)
m //= 2
return tuple(num_token_buckets[::-1])
def get_last_power_of_2_num_tokens_buckets(max_num_tokens: int) -> Tuple[int, ...]:
max_num_tokens = last_positive_power_of_2(max_num_tokens)
num_token_buckets = []
m = max_num_tokens
while m >= 1:
num_token_buckets.append(m)
m //= 2
return tuple(num_token_buckets[::-1])
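
# Illustrative sketch (not from the original source): both helpers enumerate
# powers of two up to max_num_tokens, rounded up or down respectively:
#
#   >>> get_power_of_2_num_tokens_buckets(24)
#   (1, 2, 4, 8, 16, 32)
#   >>> get_last_power_of_2_num_tokens_buckets(24)
#   (1, 2, 4, 8, 16)
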
def fp4_scale_infer_shape(input_shapes: List[List[int]]):
"""Calculate the dimensions of the fp4 scale tensor.
"""
out_shape, scale_shape = fp4_utils.get_fp4_shape(input_shapes[0],
sf_vec_size=16)
return scale_shape * 2
def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]):
"""Calculate the dimensions of the fp4 scale tensor.
"""
out_shape, scale_shape = fp4_utils.get_fp4_shape(input_shapes[0],
sf_vec_size=16,
is_swizzled_layout=False)
return scale_shape * 2
_enable_piecewise_cuda_graph = True
def set_piecewise_cuda_graph_flag(enable: bool):
global _enable_piecewise_cuda_graph
_enable_piecewise_cuda_graph = enable
def get_piecewise_cuda_graph_flag() -> bool:
global _enable_piecewise_cuda_graph
return _enable_piecewise_cuda_graph
@contextlib.contextmanager
def piecewise_cuda_graph(enable: bool):
prev_enable = get_piecewise_cuda_graph_flag()
set_piecewise_cuda_graph_flag(enable)
try:
yield
finally:
set_piecewise_cuda_graph_flag(prev_enable)
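
# Illustrative usage sketch (not from the original source): the context manager
# restores the previous flag value even if the body raises; here we assume the
# flag starts at its default of True:
#
#   with piecewise_cuda_graph(False):
#       assert not get_piecewise_cuda_graph_flag()
#   assert get_piecewise_cuda_graph_flag()
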
def set_per_request_piecewise_cuda_graph_flag(enable: bool):
_global_attrs.per_request_piecewise_cuda_graph_flag = enable
def get_per_request_piecewise_cuda_graph_flag() -> bool:
return getattr(_global_attrs, 'per_request_piecewise_cuda_graph_flag', True)
def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
    # We use a heuristic to determine lm_head_tp_size:
    # since token_count=256 hits the boundary of the math-bound regime,
    # we use 256 // token_count to pick lm_head_tp_size.
    # For more details, refer to the blog: https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.md#mtp-lm-head-tensor-parallelism
lm_head_tp_size_raw = 256 // token_count
# TODO: On platforms like GB200, setting lm_head_tp_size_upper_bound to world_size could be more efficient when world_size > gpus_per_node, we need to do further investigation.
lm_head_tp_size_upper_bound = min(mapping.world_size, mapping.gpus_per_node)
lm_head_tp_size = int(
os.getenv(
'LM_HEAD_TP_SIZE',
nearest_in_buckets(lm_head_tp_size_raw,
[1, lm_head_tp_size_upper_bound])))
assert mapping.tp_size % lm_head_tp_size == 0, f"mapping.tp_size: {mapping.tp_size}, lm_head_tp_size: {lm_head_tp_size}"
lm_head_pp_size = mapping.pp_size * mapping.tp_size // lm_head_tp_size
return Mapping(
world_size=lm_head_tp_size * lm_head_pp_size,
rank=mapping.rank,
gpus_per_node=mapping.gpus_per_node,
tp_size=lm_head_tp_size,
pp_size=lm_head_pp_size,
enable_attention_dp=mapping.enable_attention_dp,
enable_lm_head_tp_in_adp=mapping.enable_lm_head_tp_in_adp,
)
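
# Illustrative walk-through (hypothetical numbers, not from the original
# source): with token_count=32, the heuristic proposes 256 // 32 = 8; on a
# single 8-GPU node (world_size=8, gpus_per_node=8) the upper bound is 8, so
# lm_head_tp_size = 8 and lm_head_pp_size = pp_size * tp_size // 8, unless the
# LM_HEAD_TP_SIZE environment variable overrides the heuristic.
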
def get_device_uuid(device_idx: int) -> str:
    """Get the UUID of a CUDA device using the torch CUDA API."""
    props = torch.cuda.get_device_properties(device_idx)
    return "GPU-" + str(props.uuid)
def maybe_compile(func=None, **compile_kwargs):
"""
Conditionally compile a function with torch.compile.
    If is_piecewise_running() is True, the function is not compiled, to avoid host overhead in the attention op.
    Args:
        func: The function to decorate (optional, for direct decoration).
        **compile_kwargs: Keyword arguments for torch.compile.
    Returns:
        The conditionally compiled function.
"""
def decorator(f):
compiled_func = torch.compile(f, **compile_kwargs)
def wrapper(*args, **kwargs):
if is_piecewise_running():
return f(*args, **kwargs)
return compiled_func(*args, **kwargs)
return wrapper
return decorator(func) if func else decorator
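
# Illustrative usage sketch (the helper below is hypothetical, not from the
# original source): decorate a small elementwise helper so it is compiled in
# normal runs but falls back to eager under piecewise execution:
#
#   @maybe_compile(dynamic=True)
#   def scale_and_add(x, y):
#       return x * 2.0 + y
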
def split(x: torch.Tensor,
          tp_size: int,
          idx: int,
          dim: int = 0) -> torch.Tensor:
    if tp_size == 1:
        return x
    assert x.shape[dim] % tp_size == 0
    split_size = x.shape[dim] // tp_size
    return torch.split(x, split_size, dim=dim)[idx]
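
# Illustrative sketch (not from the original source): split a tensor along dim
# 0 into tp_size equal shards and keep shard idx:
#
#   >>> split(torch.arange(8), tp_size=2, idx=1)
#   tensor([4, 5, 6, 7])
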
def relu2(x: torch.Tensor) -> torch.Tensor:
return torch.square(F.relu(x))
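
# Illustrative sketch (not from the original source; presumably the eager
# counterpart of ActivationType.Relu2):
#
#   >>> relu2(torch.tensor([-2.0, 3.0]))
#   tensor([0., 9.])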