Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
* Instead of allocating UserBuffers at the beginning of the runtime, UB buffers are now managed by a global allocator. The allocator dynamically assigns a free UB buffer, or allocates a new one, for a torch tensor. This makes UserBuffers easier to use (a sketch of the idea follows below).
* In the common use case, the UserBuffers are allocated during the warm-up stage, so there is no dynamic allocation during inference.
* The UB fusion pattern is rewritten using the new UB allocator. It contains the following passes:
  1. Fuse quant with allreduce, replace it with the UB implementation, and insert a copy_to_userbuffers. The normal allreduce currently does not support FP8 quant, so this has to be done in the UB pass.
  2. Convert all supported allreduces to UB and insert copy_to_userbuffers.
  3. Fuse the op preceding the allreduce with copy_to_userbuffers, so that op writes directly into the userbuffer.
  4. Remove the userbuffers finalize if the output is connected to another UB allreduce.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
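The allocator described above behaves like a simple buffer pool: a tensor that needs a UserBuffer first tries to reuse a free registered buffer and only falls back to allocating a new one. Below is a minimal sketch of that idea in plain Python; the class and method names (UBPoolSketch, get_buffer, release) are hypothetical and are not the TensorRT-LLM API.

import torch


class UBPoolSketch:
    """Toy pool: reuse a free buffer that is large enough, else allocate."""

    def __init__(self):
        self._free: list[torch.Tensor] = []

    def get_buffer(self, nbytes: int) -> torch.Tensor:
        # Reuse the first free buffer that can hold the request.
        for i, buf in enumerate(self._free):
            if buf.numel() >= nbytes:
                return self._free.pop(i)
        # No free buffer fits: allocate a new one. The real allocator would
        # also register it with UserBuffers so collective kernels can use it.
        return torch.empty(nbytes, dtype=torch.uint8)

    def release(self, buf: torch.Tensor) -> None:
        # Return the buffer to the pool for reuse by later tensors.
        self._free.append(buf)

With this scheme, buffers requested during warm-up populate the pool, so steady-state inference reuses them instead of allocating dynamically.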
88 lines
3.1 KiB
Python
import os
from typing import List, Optional, Union

import torch
from torch._functorch.aot_autograd import aot_module_simplified
from torch._inductor.compile_fx import compile_fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.fx import Graph, GraphModule

import tensorrt_llm

from .patterns.ar_residual_norm import register_ar_residual_norm
from .patterns.residual_add_norm import register_add_norm
from .patterns.ub_allreduce import register_ub_patterns
from .recover_pass import recover_pass

class Backend:
    """torch.compile backend that applies TensorRT-LLM graph rewrites
    (allreduce/residual-norm fusion and optional UserBuffers patterns)
    before handing the graph to Inductor or AOTAutograd."""

    _custom_pass_instances: Optional[List[PatternMatcherPass]] = None

    def __init__(self, enable_inductor=True, enable_userbuffers=False) -> None:
        super().__init__()
        self.elapsed_time = 0
        self.module_inference_event = []
        self.module_inference_time = 0
        self.call_count = 0
        self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
        self.rank = tensorrt_llm.mpi_rank()
        self.enable_inductor = enable_inductor

        self.match_count = []

        if enable_inductor:
            from torch._inductor import config

            # Run the custom pattern-matcher passes as Inductor's
            # joint-graph post pass when Inductor compilation is enabled.
            self.inductor_config = config.get_config_copy()
            self.inductor_config["joint_custom_post_pass"] = self.optimize

    @classmethod
    def get_custom_pass(cls, enable_userbuffers):
        # TODO: add pp + tp support
        world_size = tensorrt_llm.mpi_world_size()
        if not cls._custom_pass_instances:
            # Really naive pass manager here
            cls._custom_pass_instances = [PatternMatcherPass()]
            if world_size > 1:
                # Currently torch.compile cannot work properly with the
                # Lamport fusion kernel.
                # TODO: Fix this issue
                os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
                register_ar_residual_norm(cls._custom_pass_instances[0])
                if enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
                ):
                    register_ub_patterns(cls._custom_pass_instances)
            else:
                register_add_norm(cls._custom_pass_instances[0])
        return cls._custom_pass_instances

    def optimize(
        self,
        gm: Union[GraphModule, Graph],
        example_inputs: Optional[List[torch.Tensor]] = None,
    ):
        graph = gm.graph if isinstance(gm, GraphModule) else gm
        for custom_pass in self.custom_passes:
            # Re-apply each pass until it stops matching, so patterns enabled
            # by earlier rewrites are also fused.
            self.match_count.append(custom_pass.apply(graph))
            while self.match_count[-1]:
                self.match_count.append(custom_pass.apply(graph))
        graph.eliminate_dead_code()
        if isinstance(gm, GraphModule):
            gm.recompile()

        return gm

    def __call__(self, gm: GraphModule,
                 example_inputs: List[torch.Tensor]) -> callable:

        gm = recover_pass(gm)

        if self.enable_inductor:
            # Inductor path: the custom passes run via joint_custom_post_pass.
            return compile_fx(gm,
                              example_inputs,
                              config_patches=self.inductor_config)
        else:
            # Fallback path: run the custom passes as the forward compiler.
            return aot_module_simplified(gm,
                                         example_inputs,
                                         fw_compiler=self.optimize)
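A minimal usage sketch follows, assuming Backend is wired up as a custom torch.compile backend; the import path and the model are placeholders for illustration, not part of the file above.

import torch
import torch.nn as nn

# Import path assumed; adjust to where Backend lives in the installed package.
from tensorrt_llm._torch.compilation.backend import Backend


class TinyMLP(nn.Module):
    # Placeholder model, just to have something to compile.

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1024, 1024)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = TinyMLP().cuda()
# Backend.__call__ takes (gm, example_inputs) and returns a compiled callable,
# which is the contract torch.compile expects from a custom backend.
compiled = torch.compile(model, backend=Backend(enable_inductor=True))
out = compiled(torch.randn(8, 1024, device="cuda"))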