TensorRT-LLMs/tensorrt_llm/_torch/compilation/backend.py
Jin Li 028235404b
[TRTLLM-6633][feat] Padding for piecewise cudagraph (#6750)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
2025-08-26 18:31:33 -04:00

167 lines
5.9 KiB
Python

import os
from typing import List, Optional
import torch
import torch._inductor.config as inductor_config
from torch._functorch.aot_autograd import aot_module_simplified
from torch._inductor.compile_fx import compile_fx, select_decomp_table
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._subclasses import FakeTensor
from torch.fx import GraphModule
import tensorrt_llm
from tensorrt_llm import logger
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .patterns.ar_residual_norm import register_ar_fusions
from .patterns.residual_add_norm import register_add_norm
from .piecewise_optimizer import piecewise_optimizer
from .recover_pass import recover_pass
from .remove_copy_pass import remove_copy_for_mutates_args
class Backend:
_custom_pass_instances: List[PatternMatcherPass] = None
_graph_pool_handle: tuple[int, int] = None
# Following classes are used to let weakref ref the stream and eventlist objects.
class Streams(list):
pass
class Events(list):
pass
def __init__(
self,
enable_inductor=True,
enable_userbuffers=False,
enable_piecewise_cuda_graph: bool = False,
capture_num_tokens: Optional[List[int]] = None,
max_num_streams: int = 1,
) -> None:
super().__init__()
self.elapsed_time = 0
self.module_inference_event = []
self.module_inference_time = 0
self.call_count = 0
self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
self.rank = tensorrt_llm.mpi_rank()
self.enable_inductor = enable_inductor
self.capture_num_tokens = sorted(capture_num_tokens or [])
self.piecewise_cuda_graph = enable_piecewise_cuda_graph
self.no_optimization = False
# We only need to create aux streams.
self.aux_streams = Backend.Streams(
[torch.cuda.Stream() for _ in range(max_num_streams - 1)])
self.events = Backend.Events()
inductor_config.enable_auto_functionalized_v2 = False
if Backend._graph_pool_handle is None:
Backend._graph_pool_handle = torch.cuda.graph_pool_handle()
self.match_count = []
@classmethod
def get_custom_pass(cls, enable_userbuffers):
# TODO: add pp + tp support
world_size = tensorrt_llm.mpi_world_size()
if not cls._custom_pass_instances:
# Really naive pass manager here
cls._custom_pass_instances = [PatternMatcherPass()]
if world_size > 1:
# Currently torch compile cannot work properly with lamport fusion kernel
# TO-DO: Fix this issue
os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
)
register_ar_fusions(cls._custom_pass_instances, ub_enabled)
else:
register_add_norm(cls._custom_pass_instances[0])
return cls._custom_pass_instances
def bypass_optimization(self):
self.no_optimization = True
def enable_optimization(self):
self.no_optimization = False
def generate_events(self, num_events: int):
if num_events > len(self.events):
self.events += [
torch.cuda.Event() for _ in range(num_events - len(self.events))
]
def optimize(
self,
gm: GraphModule,
example_inputs: List[torch.Tensor],
):
graph = gm.graph
for custom_pass in self.custom_passes:
self.match_count.append(custom_pass.apply(graph))
while self.match_count[-1]:
self.match_count.append(custom_pass.apply(graph))
graph.eliminate_dead_code()
# After this pass, cannot run any dce!!!
remove_copy_for_mutates_args(graph)
# Do not apply multi-stream if enable piecewise cuda graph or inductor
# For piecewise cuda graph, we will apply the multi-stream optimization in piecewise_optimizer
# For inductor, we do not control the passes inside inductor.
if len(
self.aux_streams
) > 0 and not self.piecewise_cuda_graph and not self.enable_inductor:
num_events = multi_stream_schedule(gm, len(self.aux_streams) + 1)
self.generate_events(num_events)
gm.recompile()
if self.piecewise_cuda_graph:
gm, num_events = piecewise_optimizer(
gm,
example_inputs,
self.enable_inductor,
self.input_num_tokens,
self.capture_num_tokens,
self._graph_pool_handle,
len(self.aux_streams) + 1,
)
self.generate_events(num_events)
return gm
elif self.enable_inductor:
return compile_fx(gm, example_inputs)
else:
return gm
def __call__(self, gm: GraphModule,
example_inputs: List[torch.Tensor]) -> callable:
if self.no_optimization:
logger.warning(
"Bypassing torch.compile optimization and fallback to eager execution!"
)
return gm
for node in gm.graph.nodes:
if node.op == "placeholder":
if node.name == "l_input_ids_":
example_value = node.meta["example_value"]
assert isinstance(example_value, FakeTensor)
self.input_num_tokens = example_value.shape[0]
break
if self.piecewise_cuda_graph:
assert (
self.input_num_tokens is not None
), "Cannot detect input_num_tokens. Cannot use piecewise CUDA graph. What is the name of `input_ids`?"
gm = recover_pass(gm)
return aot_module_simplified(
gm,
example_inputs,
fw_compiler=self.optimize,
decompositions=select_decomp_table(),
)