TensorRT-LLMs/tensorrt_llm/_torch/utils.py
Yukun He c678774c99
feat: Apply the new torch-flow compatible AutoTuner to both Fused MoE and NVFP4 Linear operators. (#3151)
* Several optimizations and fixings on the Autotuner.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Apply the new Python side Autotuner on current linear for nvFP4 data type.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Apply the new Python side Autotuner on MoE op
* Remove routers from cache key to improve inference perf
* Prevent unnecessary code profiling. Use do_preparation keyword to select which part should be executed during before evaluating any tactic.
* Remove try-catch inside moe profiling process.
* Move default tactic -1 to 0 transforms in cpp runner.
* Revise relavant tests.
* Predefined the bucketizing strategy for fused_moe

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Add specific_profile support for AutoTuner to bypass the standard cache search process for perf optimization
* Add specific_profile for moe
* Add specific profile for linear

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Fixing and revising according to reviewer's suggestions.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Use lru_cache for inference pref optimization.
* Revert gen_custom_cache_key feature

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Replace runner with runner id to achieve a serializable cache.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Code clean up and minor fixings.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Move all tunable runners and custom ops into torch_custom_ops.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

* Treat min_latency_mode as a independent dynamic tensor. Modify get_valid_tactics to suit for it.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>

---------

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2025-04-08 14:28:36 +08:00

154 lines
4.7 KiB
Python

import os
from dataclasses import dataclass
from enum import Enum
from typing import List
import torch
from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
from .pipeline_interface import PipelineInterface
is_torch_compiling_flag = False
aux_stream_name_list = ['Attention', 'MoeShared', 'MoeChunkingOverlap']
AuxStreamType = Enum(
'AuxStreamType',
aux_stream_name_list,
)
EventType = Enum(
'EventType',
['Main', *aux_stream_name_list],
start=0,
)
def set_torch_compiling(enable: bool):
global is_torch_compiling_flag
is_torch_compiling_flag = enable
def is_torch_compiling() -> bool:
global is_torch_compiling_flag
return is_torch_compiling_flag
def make_weak_ref(x):
if isinstance(x, torch.Tensor):
return convert_to_torch_tensor(
TensorWrapper(x.data_ptr(), x.dtype, x.shape)) if x.is_cuda else x
elif isinstance(x, tuple):
return tuple(make_weak_ref(i) for i in x)
elif isinstance(x, list):
return [make_weak_ref(i) for i in x]
elif isinstance(x, dict):
return {k: make_weak_ref(v) for k, v in x.items()}
elif isinstance(x, (int, float, bool)):
return x
elif isinstance(x, PipelineInterface):
return tuple(make_weak_ref(tensor) for tensor in x)
else:
raise TypeError(f"Invalid type {type(x)} to make weak ref")
@dataclass
class Fp4QuantizedTensor:
fp4_tensor: torch.Tensor
scaling_factor: torch.Tensor
_disable_fp4_allgather = os.getenv("TLLM_DISABLE_FP4_ALLGATHER", "0") == "1"
def disable_fp4_allgather():
return _disable_fp4_allgather
def swizzle_sf(sf: torch.Tensor,
row: int,
col: int,
scaling_vector_size: int = 16):
factor = scaling_vector_size * 4
num_m_tiles = (row + 128 - 1) // 128
num_k_tiles = (col + factor - 1) // factor
# SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
sf_full = torch.zeros(num_m_tiles * 32 * 4,
num_k_tiles * 4,
dtype=sf.dtype,
device=sf.device)
sf_full[:row, :(col //
scaling_vector_size)] = sf[:row, :(col //
scaling_vector_size)]
sf_full_reshaped = sf_full.view(num_m_tiles, 4, 32, num_k_tiles, 4)
sf_full_swizzle = sf_full_reshaped.transpose(1, 3)
sf_swizzle = sf_full_swizzle.reshape(-1)
return sf_swizzle
def unswizzle_sf(sf: torch.Tensor,
row: int,
col: int,
scaling_vector_size: int = 16):
factor = scaling_vector_size * 4
num_m_tiles = (row + 128 - 1) // 128
num_k_tiles = (col + factor - 1) // factor
# SF layout [num_m_tiles, num_k_tiles, 32 (m_tile column major), 4 (m_tile column major), 4(k_tile)]
sf_reshaped = sf.view(num_m_tiles, num_k_tiles, 32, 4, 4)
sf_unswizzle = sf_reshaped.transpose(1, 3)
sf_unswizzle = sf_unswizzle.reshape(num_m_tiles * 32 * 4, num_k_tiles * 4)
sf_unswizzle_sliced = sf_unswizzle[:row, :(col // scaling_vector_size)]
return sf_unswizzle_sliced.contiguous()
def reswizzle_sf(sf: torch.Tensor,
row: int,
col: int,
scaling_vector_size: int = 16):
factor = scaling_vector_size * 4
num_m_tiles = (row + 128 - 1) // 128
num_k_tiles = (col + factor - 1) // factor
partition_size = num_m_tiles * num_k_tiles * 32 * 4 * 4
num_partitions = sf.numel() // partition_size
sf_reshaped = sf.view(num_partitions, num_m_tiles, num_k_tiles, 32, 4, 4)
sf_unswizzle = sf_reshaped.transpose(2, 4)
sf_unswizzle = sf_unswizzle.reshape(num_partitions, num_m_tiles * 32 * 4,
num_k_tiles * 4)
total_rows = num_partitions * row
num_m_tiles_out = (total_rows + 128 - 1) // 128
sf_out = torch.zeros(
num_m_tiles_out,
4,
32,
num_k_tiles,
4,
dtype=sf.dtype,
device=sf.device,
)
sf_out_reshaped = sf_out.view(num_m_tiles_out * 32 * 4, num_k_tiles * 4)
sf_out_reshaped[:total_rows] = sf_unswizzle[:, :row].reshape(total_rows, -1)
sf_out_swizzle = sf_out.transpose(1, 3).reshape(-1)
return sf_out_swizzle
def next_positive_power_of_2(x: int) -> int:
if x < 1:
return 1
return 1 << (x - 1).bit_length()
def nearest_in_buckets(x: int, buckets: List[int]) -> int:
return min(max(next_positive_power_of_2(x), buckets[0]), buckets[-1])
def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
max_num_tokens = next_positive_power_of_2(max_num_tokens)
num_token_buckets = []
m = max_num_tokens
while m >= 1:
num_token_buckets.append(m)
m //= 2
return num_token_buckets