TensorRT-LLMs/tensorrt_llm/_torch/compilation/backend.py

import os
from typing import Callable, List, Optional, Union

import torch
from torch._functorch.aot_autograd import aot_module_simplified
from torch._inductor.compile_fx import compile_fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.fx import Graph, GraphModule

import tensorrt_llm

from .patterns.ar_residual_norm import register_ar_residual_norm
from .patterns.residual_add_norm import register_add_norm
from .recover_pass import recover_pass


class Backend:
    # Shared across instances so the pattern registry is built only once per process.
    _custom_pass_instance: Optional[PatternMatcherPass] = None

    def __init__(self, enable_inductor=True) -> None:
        super().__init__()
        self.elapsed_time = 0
        self.module_inference_event = []
        self.module_inference_time = 0
        self.call_count = 0
        self.custom_pass = Backend.get_custom_pass()
        self.rank = tensorrt_llm.mpi_rank()
        self.enable_inductor = enable_inductor
        self.match_count = []
        if enable_inductor:
            from torch._inductor import config

            # Run the custom pattern-matcher pass as Inductor's joint-graph post pass.
            self.inductor_config = config.get_config_copy()
            self.inductor_config["joint_custom_post_pass"] = self.optimize

    @classmethod
    def get_custom_pass(cls):
        world_size = tensorrt_llm.mpi_world_size()
        if cls._custom_pass_instance is None:
            # Really naive pass manager here
            cls._custom_pass_instance = PatternMatcherPass()
            if world_size > 1:
                # torch.compile currently cannot work properly with the Lamport
                # fusion kernel, so it is disabled here.
                # TODO: Fix this issue
                os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
                register_ar_residual_norm(cls._custom_pass_instance)
            else:
                register_add_norm(cls._custom_pass_instance)
        return cls._custom_pass_instance

    def optimize(
        self,
        gm: Union[GraphModule, Graph],
        example_inputs: Optional[List[torch.Tensor]] = None,
    ):
        graph = gm.graph if isinstance(gm, GraphModule) else gm
        # Apply the registered fusion patterns and record how many were matched.
        self.match_count.append(self.custom_pass.apply(graph))
        graph.eliminate_dead_code()
        if isinstance(gm, GraphModule):
            gm.recompile()
        return gm

    def __call__(self, gm: GraphModule,
                 example_inputs: List[torch.Tensor]) -> Callable:
        gm = recover_pass(gm)
        if self.enable_inductor:
            return compile_fx(gm,
                              example_inputs,
                              config_patches=self.inductor_config)
        else:
            return aot_module_simplified(gm,
                                         example_inputs,
                                         fw_compiler=self.optimize)
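
Below is a minimal usage sketch, assuming Backend is plugged in as a custom torch.compile backend. The toy nn.Linear model, the tensor shapes, and the import path (derived from the file location above) are illustrative assumptions, not part of this file.

import torch
import torch.nn as nn

from tensorrt_llm._torch.compilation.backend import Backend

# Hypothetical toy module; any nn.Module captured by TorchDynamo would work.
model = nn.Linear(16, 16).cuda()

# torch.compile hands each captured FX graph and its example inputs to
# Backend.__call__, which runs recover_pass and then dispatches either to
# Inductor (with the custom post pass) or to aot_module_simplified.
compiled = torch.compile(model, backend=Backend(enable_inductor=True))

out = compiled(torch.randn(2, 16, device="cuda"))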