From 3e0fb60e5007c4d6855c0e86c51df7c579728277 Mon Sep 17 00:00:00 2001
From: liji-nv <59594262+liji-nv@users.noreply.github.com>
Date: Mon, 21 Jul 2025 19:10:22 +0800
Subject: [PATCH] [TRTLLM-4279] feat: Multistream initial support for torch compile flow (#5847)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
---
 tensorrt_llm/_torch/compilation/backend.py    |  34 +-
 .../compilation/multi_stream/__init__.py      |   0
 .../multi_stream/auto_multi_stream.py         | 456 ++++++++++++++++++
 .../_torch/compilation/piecewise_optimizer.py |  28 +-
 .../_torch/compilation/remove_copy_pass.py    |  21 +-
 tensorrt_llm/_torch/compilation/utils.py      |  17 +
 .../_torch/custom_ops/cpp_custom_ops.py       |  54 +--
 .../_torch/custom_ops/torch_custom_ops.py     |  42 ++
 .../custom_ops/trtllm_gen_custom_ops.py       | 134 ++++-
 .../modules/fused_moe/fused_moe_trtllm_gen.py |  17 +-
 tensorrt_llm/_torch/pyexecutor/config.py      |   1 +
 .../_torch/pyexecutor/model_engine.py         |  12 +-
 tensorrt_llm/_torch/utils.py                  |  12 +-
 tensorrt_llm/llmapi/llm_args.py               |  17 +
 .../defs/accuracy/test_llm_api_pytorch.py     |  31 +-
 tests/unittest/_torch/thop/test_moe.py        |   5 +-
 16 files changed, 764 insertions(+), 117 deletions(-)
 create mode 100644 tensorrt_llm/_torch/compilation/multi_stream/__init__.py
 create mode 100644 tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py

diff --git a/tensorrt_llm/_torch/compilation/backend.py b/tensorrt_llm/_torch/compilation/backend.py
index 1e06d553dc..ec76ea5238 100644
--- a/tensorrt_llm/_torch/compilation/backend.py
+++ b/tensorrt_llm/_torch/compilation/backend.py
@@ -12,6 +12,7 @@ from torch.fx import GraphModule
 import tensorrt_llm
 from tensorrt_llm import logger
 
+from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .patterns.ar_residual_norm import register_ar_residual_norm
 from .patterns.residual_add_norm import register_add_norm
 from .patterns.ub_allreduce import register_ub_patterns
@@ -25,12 +26,20 @@ class Backend:
     _custom_pass_instances: List[PatternMatcherPass] = None
     _graph_pool_handle: tuple[int, int] = None
 
+    # The following classes let weakref reference the stream and event list objects.
+    class Streams(list):
+        pass
+
+    class Events(list):
+        pass
+
     def __init__(
         self,
         enable_inductor=True,
         enable_userbuffers=False,
         enable_piecewise_cuda_graph: bool = False,
         cuda_graph_batch_sizes: Optional[List[int]] = None,
+        max_num_streams: int = 1,
     ) -> None:
         super().__init__()
         self.elapsed_time = 0
@@ -45,6 +54,10 @@ class Backend:
             else [])
         self.piecewise_cuda_graph = enable_piecewise_cuda_graph
         self.no_optimization = False
+        # We only need to create the auxiliary streams; the primary stream already exists.
+        self.aux_streams = Backend.Streams(
+            [torch.cuda.Stream() for i in range(max_num_streams - 1)])
+        self.events = Backend.Events()
 
         inductor_config.enable_auto_functionalized_v2 = False
         if Backend._graph_pool_handle is None:
@@ -77,6 +90,12 @@ class Backend:
     def enable_optimization(self):
         self.no_optimization = False
 
+    def generate_events(self, num_events: int):
+        if num_events > len(self.events):
+            self.events += [
+                torch.cuda.Event() for _ in range(num_events - len(self.events))
+            ]
+
     def optimize(
         self,
         gm: GraphModule,
@@ -90,17 +109,30 @@ class Backend:
             graph.eliminate_dead_code()
         # After this pass, cannot run any dce!!!
         remove_copy_for_mutates_args(graph)
+
+        # Do not apply multi-stream scheduling when piecewise CUDA graph or inductor is enabled.
+        # For piecewise CUDA graph, the multi-stream optimization is applied in piecewise_optimizer.
+        # For inductor, we do not control the passes inside inductor.
+ if len( + self.aux_streams + ) > 0 and not self.piecewise_cuda_graph and not self.enable_inductor: + num_events = multi_stream_schedule(gm, len(self.aux_streams) + 1) + self.generate_events(num_events) + gm.recompile() if self.piecewise_cuda_graph: - return piecewise_optimizer( + gm, num_events = piecewise_optimizer( gm, example_inputs, self.enable_inductor, self.input_num_tokens, self.cuda_graph_batch_sizes, self._graph_pool_handle, + len(self.aux_streams) + 1, ) + self.generate_events(num_events) + return gm elif self.enable_inductor: return compile_fx(gm, example_inputs) else: diff --git a/tensorrt_llm/_torch/compilation/multi_stream/__init__.py b/tensorrt_llm/_torch/compilation/multi_stream/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py b/tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py new file mode 100644 index 0000000000..c2d3cf012a --- /dev/null +++ b/tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py @@ -0,0 +1,456 @@ +import time +from dataclasses import dataclass, field +from operator import getitem +from queue import PriorityQueue +from typing import Dict, List + +import torch +from torch.fx import Graph, GraphModule, Node + +from tensorrt_llm.logger import logger + +from ..utils import inplace_info + + +def is_symint_node(node: Node) -> bool: + if node is not None and 'val' in node.meta: + # This is a symint call that happens on host. No need to count time on stream. + if isinstance(node.meta['val'], torch.SymInt): + return True + return False + + +def estimate_time(node: Node) -> int: + if node is None: + return 0 + if is_symint_node(node): + # This is a symint call that happens on host. No need to count time on stream. + return 0 + + # Add cost model for ops that need special handling. + # We can start with rough estimation and refine it later. + + no_cost_ops = { + getitem, torch.ops.aten.view.default, torch.ops.aten.view.dtype, + torch.ops.aten.alias.default, torch.ops.aten.empty.memory_format, + torch.ops.aten.permute.default + } + + moe_ops = { + torch.ops.trtllm.fp4_block_scale_moe_runner.default, + torch.ops.trtllm.fused_moe.default, + } + + gemm_ops = { + torch.ops.aten.mm.default, + torch.ops.trtllm.nvfp4_gemm.default, + torch.ops.trtllm.fp8_batched_gemm_trtllmgen.default, + torch.ops.trtllm.w4a8_mxfp4_fp8_gemm.default, + torch.ops.trtllm.finegrained_mixed_dtype_gemm.default, + torch.ops.trtllm.bmm_out.default, + torch.ops.trtllm.cublas_scaled_mm.default, + torch.ops.trtllm.cublas_mm.default, + torch.ops.trtllm.dsv3_router_gemm_op.default, + torch.ops.trtllm.dsv3_fused_a_gemm_op.default, + torch.ops.trtllm.fp4_gemm.default, + torch.ops.trtllm.fp4_bmm.default, + torch.ops.trtllm.fp8_block_scaling_gemm.default, + torch.ops.trtllm.matmul_to_ub.default, + } + + # These ops are not counted in the time estimation. + if node.op == "call_function" and node.target in no_cost_ops: + return 0 + + # Add estimation below. With accurate estimation, the stream assignment + # can give the best performance. But it is hard to get accurate estimation. + # + # So currently, these estimations are not accurate. They just make sure the key path + # is correctly scheduled. Adjust the estimation or add new ones + # if the stream assignment is not desired. 
+ + MOE_OP_COST = 20 + GEMM_OP_COST = 10 + DEFAULT_OP_COST = 1 + + # Adjust MOE weight to make the router -> MOE key path + if node.op == "call_function" and node.target in moe_ops: + return MOE_OP_COST + + # GEMM ops + if node.op == "call_function" and node.target in gemm_ops: + return GEMM_OP_COST + + # Refine the estimation of time for nodes. + return DEFAULT_OP_COST + + +@dataclass +class Stream: + # Stream id + id: int + + # Nodes running on the stream + nodes: List['MultiStreamNode'] = field(init=False, default_factory=list) + + # Current elapsed time of the stream + current_time: int = field(init=False, default=0) + + +class MultiStreamNode: + + def __init__(self, node: Node, in_edges: Dict[Node, 'MultiStreamNode']): + # The node in the original graph + self.node = node + + # The distance to the exit of DAG + self.distance = 0 + + # Weight for the node which represents the computation cost + self.weight = estimate_time(node) + + # The in edges of the node + self.in_edges = in_edges + + # The out edges of the node + self.out_edges = [] + + # end time of the node + self.end_time = 0 + + # Assigned stream for the node + self.stream = None + + # wait on events + self.wait_on = [] + + # trigger event + self.event = None + + +class MultiStreamDAG: + + def __init__(self, gm: GraphModule): + self.gm = gm + self.node_to_id = {} + self.node_in_degrees = {} + self.output_nodes = [] + self.placeholders = [] + self.nodes = {} + self.in_degrees = {} + self.work_list = [] + self.entry_node = None + self.exit_node = None + + self.create_dag_from_gm(gm) + assert self.entry_node is not None + assert self.exit_node is not None + + def create_dag_from_gm(self, gm: GraphModule) -> None: + """ + Create a DAG from the graph module. + """ + # Create node to id mapping + for node in gm.graph.nodes: + self.node_to_id[node] = len(self.node_to_id) + + # Fake entry node. + # All nodes without in edges will be connected to this node. 
+ self.entry_node = MultiStreamNode(None, dict()) + + latest_inplace_stat = {} + inplace_map = inplace_info() + + def flatten_args(args): + """Recursively flatten nested arguments into a flat list.""" + args_new = [] + stack = list(args) + while stack: + arg = stack.pop() + if isinstance(arg, dict): + stack.extend(arg.values()) + elif isinstance(arg, (list, tuple)): + stack.extend(arg) + else: + args_new.append(arg) + return args_new + + # Pop all the placeholders from gm + # We know that the node is already in topological order + for node in gm.graph.nodes: + # We assume that all the placeholders are already synced with the base stream + if node.op == "placeholder": + self.placeholders.append(node) + continue + + args = flatten_args([a for a in node.args] + + [a for a in node.kwargs.values()]) + + in_edges = dict() + for arg in args: + if arg in latest_inplace_stat: + in_edges[arg] = latest_inplace_stat[arg] + elif isinstance(arg, torch.fx.Node) and arg.op != "placeholder": + in_edges[arg] = self.nodes[arg] + + # For node without in edge, connect it to the entry + if len(in_edges) == 0: + in_edges[None] = self.entry_node + + vertex = MultiStreamNode(node, in_edges) + if node.op == "output": + self.exit_node = vertex + vertex.distance = 0 + self.nodes[node] = vertex + self.in_degrees[vertex] = len(in_edges) + if node.op == "call_function": + func = node.target + if func in inplace_map: + for inplace_arg in inplace_map[func].values(): + # At this stage, all inplace op must be using kwargs for all params + assert inplace_arg in node.kwargs + latest_inplace_stat[node.kwargs[inplace_arg]] = vertex + + for edge in in_edges.values(): + edge.out_edges.append(vertex) + self.compute_distance() + + def compute_distance(self) -> None: + """ + Compute the distance to the exit node for each node. + """ + # Reverse topological sort to compute distance to exit node + work_list = [self.exit_node] + out_degrees = { + node: len(node.out_edges) + for node in self.nodes.values() + } + out_degrees[self.entry_node] = len(self.entry_node.out_edges) + + while len(work_list) > 0: + node = work_list.pop() + for in_edge in node.in_edges.values(): + out_degrees[in_edge] -= 1 + in_edge.distance = max(in_edge.distance, + node.weight + node.distance) + if out_degrees[in_edge] == 0: + work_list.append(in_edge) + + def assign_streams(self, max_num_streams: int) -> int: + """ + Assign streams to the nodes in the DAG. + Return the number of events created. + """ + worklist = PriorityQueue() + num_nodes = len(self.node_to_id) + + # When accessing node, the distance to the exit node is main priority. + # The node with largest distance means currently this is the bottleneck of the whole graph. + def calc_priority(node_id: int, distance: int) -> int: + # We keep the node order by default. + # It also gives deterministic order for priority queue. + return (-distance) * num_nodes + node_id + + streams = [Stream(i) for i in range(max_num_streams)] + + def pick_stream(start_time, node) -> Stream: + if node.weight == 0: + # This is a symint node or a getitem node. + # It always assigns to the stream that produce the node. + for n in node.in_edges.values(): + if is_symint_node(n.node): + continue + return n.stream + return streams[0] + + closest_stream = None + least_time = float('inf') + for st in streams: + if st.current_time <= start_time: + return st + else: + if st.current_time < least_time: + least_time = st.current_time + closest_stream = st + return closest_stream + + # We just start from the out_edges of the entry node. 
Entry node is just a fake node + # For entry, we assign to the primary stream. + self.entry_node.stream = streams[0] + streams[0].nodes.append(self.entry_node) + for out_edge in self.entry_node.out_edges: + worklist.put((calc_priority(self.node_to_id[out_edge.node], + out_edge.distance), out_edge)) + + sync_event_id = 0 + + while not worklist.empty(): + _, node = worklist.get() + assert node.stream is None + + # Get when current node can start. + # Start time is the max of the end time of all the in edges. + start_time = max( + [in_edge.end_time for in_edge in node.in_edges.values()]) + node.stream = pick_stream(start_time, node) + node.end_time = max(start_time, + node.stream.current_time) + node.weight + node.stream.current_time = node.end_time + node.stream.nodes.append(node) + + for in_edge_tensor, in_edge in node.in_edges.items(): + if in_edge.stream != node.stream and not is_symint_node( + in_edge.node): + if in_edge.event is None: + in_edge.event = sync_event_id + sync_event_id += 1 + node.wait_on.append((in_edge, in_edge_tensor)) + + # Now, for any in edge running on different stream, we need to create a sync event. + for out_edge in node.out_edges: + self.in_degrees[out_edge] -= 1 + if self.in_degrees[out_edge] == 0: + worklist.put((calc_priority(self.node_to_id[out_edge.node], + out_edge.distance), out_edge)) + self.streams = streams + return sync_event_id + + def create_new_graph(self) -> Graph: + """ + Create new graph with the nodes assigned to the streams. + """ + # Now each node should have been assigned a stream. We will now create a new graph and insert all nodes + # As torch need to create node for switching stream, need to group nodes as much as possible. + remap = {} + new_graph = Graph() + + for st in self.streams: + logger.debug(f"{len(st.nodes)} nodes running on stream {st.id}") + + # First, push all placeholders to the new graph. + for placeholder in self.placeholders: + remap[placeholder] = new_graph.node_copy(placeholder, + lambda n: remap[n]) + + # Then, we will push all the nodes into the new graph. + # Build in_degrees again as we need to check whether a stream is ready to run. + self.in_degrees = { + node: len(node.in_edges) + for node in self.nodes.values() + } + self.in_degrees[self.entry_node] = 0 + + stream_pos = [0] * len(self.streams) + + def has_more_nodes() -> bool: + for st in self.streams: + if len(st.nodes) > stream_pos[st.id]: + return True + return False + + last_stream = 0 + + # The nodes in stream are already in topological order. + while has_more_nodes(): + for st in self.streams: + if len(st.nodes) == stream_pos[st.id]: + continue + node = st.nodes[stream_pos[st.id]] + if self.in_degrees[node] != 0: + # This stream is not ready to run now. + continue + + # Any time the stream is changed, set the stream. + if node.stream.id != last_stream: + # Change stream + new_graph.create_node("call_function", + torch.ops.trtllm.set_stream, + args=(node.stream.id, )) + last_stream = node.stream.id + + for _ in range(stream_pos[st.id], len(st.nodes)): + node = st.nodes[stream_pos[st.id]] + if self.in_degrees[node] != 0: + break + for out_edge in node.out_edges: + self.in_degrees[out_edge] -= 1 + stream_pos[st.id] += 1 + # It could be the fake entry node. + if node.node is not None: + # Wait on all the events that the node is waiting on. 
+ for wait in node.wait_on: + new_graph.create_node("call_function", + torch.ops.trtllm.wait_event, + args=(wait[0].event, )) + remap[node.node] = new_graph.node_copy( + node.node, lambda n: remap[n]) + for wait in node.wait_on: + # wait[1] is the actual tensor that the op is waiting on. + # Need to record stream for that tensor. + if wait[1] is None: + continue + new_graph.create_node( + "call_function", + torch.ops.trtllm.record_stream, + args=(remap[wait[1]], st.id)) + if node.event is not None: + new_graph.create_node("call_function", + torch.ops.trtllm.record_event, + args=(node.event, )) + + # After each handling, start again to make sure primary stream is pushed first. + break + return new_graph + + def optimize(self, max_num_streams: int) -> int: + """ + Run multistream optimize for MultiStreamDAG. The graph module that used to create the DAG will be updated. + Return the number of events created. + """ + num_events = self.assign_streams(max_num_streams) + new_graph = self.create_new_graph() + self.gm.graph = new_graph + return num_events + + +def multi_stream_schedule(gm: GraphModule, max_num_streams: int) -> int: + """ + Schedule the graph module for multi stream execution. + gm is the graph module to be scheduled. The gm will be updated by this function. + max_num_streams is the maximum number of streams to be used. The scheduler may not use all the streams. + Return the number of events created. + """ + dag = MultiStreamDAG(gm) + return dag.optimize(max_num_streams) + + +# Following code is for debug purpose. Use print_dag_to_dot to print a MultiStreamDAG to dot file. + + +def dump_dag_as_dot(dag: MultiStreamDAG, max_num_nodes: int = 500) -> None: + COLORS = [ + "red", "chocolate", "cyan", "gold", "coral", "green", "blue", "orange", + "purple", "brown" + ] + filename = f"dag_{int(time.time())}.dot" + with open(filename, 'w') as f: + f.write("digraph G {\n") + f.write( + f"id_entry [label=\"node=entry, distance={dag.entry_node.distance}\"]\n" + ) + cnt = 0 + for node in dag.nodes.values(): + color = "white" if node.stream is None else COLORS[node.stream.id] + f.write( + f"id_{dag.node_to_id[node.node]} [label=\"node={node.node}, " + f"distance={node.distance}, weight={node.weight}\", " + f"color={color}, shape=oval]\n") + for in_edge in node.in_edges.values(): + id = str(dag.node_to_id[ + in_edge.node]) if in_edge.node is not None else "entry" + f.write(f"id_{id} -> id_{dag.node_to_id[node.node]}\n") + if cnt > max_num_nodes: + break + cnt += 1 + f.write("}\n") + f.flush() diff --git a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py index 75a9aeff8e..8e60b6bd36 100644 --- a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py +++ b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py @@ -12,7 +12,9 @@ from torch.fx.passes.split_module import split_module from tensorrt_llm.llmapi.utils import enable_llm_debug from tensorrt_llm.logger import logger -from ..utils import get_piecewise_cuda_graph_flag, make_weak_ref +from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag, + make_weak_ref) +from .multi_stream.auto_multi_stream import multi_stream_schedule from .utils import (get_enable_piecewise_cuda_graph_capture_flag, is_call_function) @@ -29,6 +31,7 @@ class PiecewiseInterpreter(Interpreter): graph_pool_handle: tuple[int, int], garbage_collect_values: bool = True, graph=None, + max_num_streams: int = 1, ): super().__init__(module, garbage_collect_values, graph) @@ -39,6 +42,8 @@ class 
PiecewiseInterpreter(Interpreter): self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id] self.graph_pool_handle = graph_pool_handle self.enable_inductor = enable_inductor + self.num_events = 0 + self.max_num_streams = max_num_streams def run(self, *args): fake_args = [ @@ -72,6 +77,11 @@ class PiecewiseInterpreter(Interpreter): found_dynamic_shape = True break + if self.max_num_streams > 1 and not self.enable_inductor: + num_events = multi_stream_schedule(submod, self.max_num_streams) + self.num_events = max(self.num_events, num_events) + submod.recompile() + self.module.__dict__[target] = PiecewiseRunner( submod, target, @@ -179,8 +189,12 @@ class PiecewiseRunner(object): with patch("gc.collect", lambda: None): # TODO: consider to use `make_graphed_callables()` when # it's ready rather than capture it ourselves + # Graph Capture would override the stream. We need to setup the stream correctly. + extra_attrs = get_model_extra_attrs() with torch.cuda.graph(graph, pool=self.graph_pool_handle): + extra_attrs["global_stream"] = torch.cuda.current_stream() output = entry.callable(*args) + extra_attrs["global_stream"] = torch.cuda.current_stream() entry.cuda_graph = graph # Mark weak ref here. The intermediate activation tensor should be freed properly. @@ -218,7 +232,8 @@ def piecewise_optimizer( input_num_tokens: Union[int | torch.SymInt], cuda_graph_batch_sizes: Sequence[int], graph_pool_handle: tuple[int, int], -) -> GraphModule: + max_num_streams: int = 1, +) -> tuple[GraphModule, int]: graph_pool_handle = torch.cuda.graph_pool_handle() graph = gm.graph @@ -253,13 +268,16 @@ def piecewise_optimizer( lambda node: node_to_graph_id[node], keep_original_order=True) - PiecewiseInterpreter( + interpreter = PiecewiseInterpreter( gm, enable_inductor, input_num_tokens, cuda_graph_batch_sizes, exclude_modules_id, graph_pool_handle, - ).run(*example_inputs) + max_num_streams=max_num_streams, + ) - return gm + interpreter.run(*example_inputs) + + return gm, interpreter.num_events diff --git a/tensorrt_llm/_torch/compilation/remove_copy_pass.py b/tensorrt_llm/_torch/compilation/remove_copy_pass.py index fe968f020b..8e5eb7a811 100644 --- a/tensorrt_llm/_torch/compilation/remove_copy_pass.py +++ b/tensorrt_llm/_torch/compilation/remove_copy_pass.py @@ -5,7 +5,7 @@ from torch._higher_order_ops.auto_functionalize import (auto_functionalized, auto_functionalized_v2) from torch.fx import Graph, Node -from .utils import is_call_function +from .utils import inplace_info, is_call_function aten = torch.ops.aten @@ -46,19 +46,12 @@ def remove_copy_for_mutates_args(graph: Graph): inplace_func = node.args[0] - if inplace_func == torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default: - remove_functionalize_inner( - node, - { - 1: "input", - 2: "residual" - }, - is_v2=node.target == auto_functionalized_v2, - ) - if inplace_func == torch.ops.trtllm.attention_inplace.default: - remove_functionalize_inner(node, {1: "output", 2: "output_sf"}) - if inplace_func == torch.ops.trtllm.mla_custom_op_inplace.default: - remove_functionalize_inner(node, {1: "output"}) + inplace_map = inplace_info() + if inplace_func not in inplace_map: + # We do not know the inplace op + continue + + remove_functionalize_inner(node, inplace_map[inplace_func]) for node in nodes_to_remove: graph.erase_node(node) diff --git a/tensorrt_llm/_torch/compilation/utils.py b/tensorrt_llm/_torch/compilation/utils.py index 6e900b9e3f..f00d689458 100644 --- a/tensorrt_llm/_torch/compilation/utils.py +++ b/tensorrt_llm/_torch/compilation/utils.py 
@@ -41,3 +41,20 @@ def set_enable_piecewise_cuda_graph_capture_flag(enable: bool): def get_enable_piecewise_cuda_graph_capture_flag() -> bool: global _enable_piecewise_cuda_graph_capture return _enable_piecewise_cuda_graph_capture + + +def inplace_info(): + inplace_map = { + torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default: { + 1: "input", + 2: "residual" + }, + torch.ops.trtllm.attention_inplace.default: { + 1: "output", + 2: "output_sf" + }, + torch.ops.trtllm.mla_custom_op_inplace.default: { + 1: "output" + } + } + return inplace_map diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 35eb19acf5..31fa33d308 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -501,51 +501,6 @@ def _register_fake(): shape[0] = sizes[local_rank] return input.new_empty(shape) - @torch.library.register_fake("trtllm::fp4_block_scale_moe_runner") - def _( - routing_logits, - routing_bias, - hidden_states, - hidden_states_scale, - gemm1_weights, - gemm1_weights_scale, - gemm2_weights, - gemm2_weights_scale, - output1_scale_scalar, - output1_scale_gate_scalar, - output2_scale_scalar, - num_experts, - top_k, - n_group, - topk_group, - intermediate_size, - local_expert_offset, - local_num_experts, - routed_scaling_factor, - tile_tokens_dim, - routing_method_type, - do_finalize, - ) -> List[torch.Tensor]: - num_tokens = hidden_states.shape[0] - hidden_size = hidden_states.shape[1] * 2 - if do_finalize: - return [ - hidden_states.new_empty((num_tokens, hidden_size), - dtype=torch.bfloat16) - ] - - expanded_row_count = num_tokens * top_k - max_padding_required = (tile_tokens_dim - 1) * num_experts - max_num_padded_tokens = fp4_utils.pad_up( - expanded_row_count + max_padding_required, tile_tokens_dim) - wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16 - return [ - hidden_states.new_empty((max_num_padded_tokens, hidden_size), - dtype=torch.bfloat16), - hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype), - hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32) - ] - @torch.library.register_fake("trtllm::nvfp4_block_scale_interleave") def _(sf: torch.Tensor): rows = sf.shape[-2] @@ -559,3 +514,12 @@ def _register_fake(): @torch.library.register_fake("trtllm::nvfp4_block_scale_interleave_reverse") def _(sf: torch.Tensor): return torch.empty_like(sf, dtype=torch.uint8) + + @torch.library.register_fake("trtllm::moe_finalize_allreduce") + def _(input, residual, norm_weight, expanded_idx_to_permuted_idx, + shared_expert_output, expert_scale_factor, workspace, rank, nranks, + eps) -> List[torch.Tensor]: + return [ + torch.empty_like(residual), + torch.empty_like(residual), + ] diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 873f15a3a3..c2ba7f077a 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -1056,3 +1056,45 @@ def _( output_sf = torch.empty(()) # Create a placeholder, which is not used. 
return output_act, output_sf + + +def get_event(event_idx: int): + from ..utils import get_model_extra_attrs + extra_attrs = get_model_extra_attrs() + assert "events" in extra_attrs, "Missing Event Book" + return extra_attrs["events"]()[event_idx] + + +def get_stream(stream_id: int): + from ..utils import get_model_extra_attrs + extra_attrs = get_model_extra_attrs() + if stream_id == 0: + return extra_attrs["global_stream"] + assert "aux_streams" in extra_attrs, "Missing Aux Streams" + return extra_attrs["aux_streams"]()[stream_id - 1] + + +@torch.library.custom_op("trtllm::set_stream", mutates_args=()) +def set_stream(stream_id: int) -> None: + stream = get_stream(stream_id) + assert stream is not None + torch.cuda.set_stream(stream) + + +@torch.library.custom_op("trtllm::record_event", mutates_args=()) +def record_event(event_idx: int) -> None: + event = get_event(event_idx) + event.record() + + +@torch.library.custom_op("trtllm::wait_event", mutates_args=()) +def wait_event(event_idx: int) -> None: + event = get_event(event_idx) + event.wait() + + +@torch.library.custom_op("trtllm::record_stream", mutates_args=()) +def record_stream(tensor: torch.Tensor, stream_id: int) -> None: + stream = get_stream(stream_id) + assert stream is not None + tensor.record_stream(stream) diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index a8d3b7e7ce..622fa12c51 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -4,13 +4,28 @@ from typing import List, Optional, Tuple import torch -from tensorrt_llm._torch.utils import (get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2) +from tensorrt_llm._torch.utils import (fp4_utils, + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + next_positive_power_of_2) from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, OptimizationProfile, TunableRunner, TuningConfig) +def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, + top_k: int) -> int: + # Guess tokens per expert assuming perfect expert distribution first. + num_tokens_per_expert = num_tokens * top_k // num_experts + + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. 
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + + return tile_tokens_dim + + @dataclass(frozen=True) class FP4BlockScaleMoEInputs: @@ -220,11 +235,14 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor, intermediate_size: int, local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: int, routing_method_type: int, + routing_method_type: int, do_finalize: bool) -> List[torch.Tensor]: tuner = AutoTuner.get() + num_tokens = hidden_states.shape[0] + tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k) + kernel_runner = FP4BlockScaleMoERunner( num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, @@ -254,6 +272,53 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor, return kernel_runner(inputs, tactic=best_tactic) +@fp4_block_scale_moe_runner.register_fake +def _( + routing_logits, + routing_bias, + hidden_states, + hidden_states_scale, + gemm1_weights, + gemm1_weights_scale, + gemm2_weights, + gemm2_weights_scale, + output1_scale_scalar, + output1_scale_gate_scalar, + output2_scale_scalar, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + routing_method_type, + do_finalize, +) -> List[torch.Tensor]: + num_tokens = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] * 2 + if do_finalize: + return [ + hidden_states.new_empty((num_tokens, hidden_size), + dtype=torch.bfloat16) + ] + + tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k) + + expanded_row_count = num_tokens * top_k + max_padding_required = (tile_tokens_dim - 1) * num_experts + max_num_padded_tokens = fp4_utils.pad_up( + expanded_row_count + max_padding_required, tile_tokens_dim) + wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16 + return [ + hidden_states.new_empty((max_num_padded_tokens, hidden_size), + dtype=torch.bfloat16), + hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype), + hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32) + ] + + @dataclass(frozen=True) class FP8BlockScaleMoEInputs: @@ -420,23 +485,31 @@ class FP8BlockScaleMoERunner(TunableRunner): @torch.library.custom_op("trtllm::fp8_block_scale_moe_runner", mutates_args=()) -def fp8_block_scale_moe_runner(routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - hidden_states: torch.Tensor, - hidden_states_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, - num_experts: int, top_k: int, n_group: int, - topk_group: int, intermediate_size: int, - local_expert_offset: int, local_num_experts: int, - routed_scaling_factor: float, - tile_tokens_dim: int, - routing_method_type: int) -> torch.Tensor: +def fp8_block_scale_moe_runner( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + hidden_states: torch.Tensor, + hidden_states_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: float, + routing_method_type: int, +) -> torch.Tensor: tuner = AutoTuner.get() + num_tokens = hidden_states.shape[0] + tile_tokens_dim = 
calculate_tile_tokens_dim(num_tokens, num_experts, top_k) + kernel_runner = FP8BlockScaleMoERunner(num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, @@ -463,3 +536,30 @@ def fp8_block_scale_moe_runner(routing_logits: torch.Tensor, ) return kernel_runner(inputs, tactic=best_tactic) + + +@fp8_block_scale_moe_runner.register_fake +def _( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + hidden_states: torch.Tensor, + hidden_states_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: int, + topk_group: int, + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: float, + routing_method_type: int, +) -> torch.Tensor: + num_tokens = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] * 2 + + return hidden_states.new_empty((num_tokens, hidden_size), + dtype=torch.bfloat16) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index b5f93ab250..94e082a667 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -4,7 +4,7 @@ import torch from ...distributed.ops import reducescatter from ...model_config import ModelConfig -from ...utils import Fp4QuantizedTensor, next_positive_power_of_2 +from ...utils import Fp4QuantizedTensor from .interface import MoE, MoEWeightLoadingMode from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEMethod) @@ -91,19 +91,6 @@ class TRTLLMGenFusedMoE(MoE): def _check_configs(self): assert self.has_deepseek_fp8_block_scales or self.has_nvfp4, "TRTLLMGenFusedMoE only supports fp8_block_scaling and nvfp4 dtypes." - def _get_tile_tokens_dim(self, x: torch.Tensor): - top_k = self.routing_method.top_k - # Number of tokens in the input tensor. - num_tokens = x.shape[0] - # Guess tokens per expert assuming perfect expert distribution first. - num_tokens_per_expert = (num_tokens * top_k) // self.num_experts - # And pad the number to the next power of 2. - tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. 
- tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) - - return tile_tokens_dim - def _get_quant_method(self): if self.quant_config is not None: if self.quant_config.layer_quant_mode.has_fp8_block_scales(): @@ -204,7 +191,6 @@ class TRTLLMGenFusedMoE(MoE): slot_start, # local_expert_start; use ep_rank if stride!=1 self.expert_size_per_partition, # local_expert_size routed_scaling_factor, - self._get_tile_tokens_dim(x), self.routing_method.routing_method_type, ) elif self.has_nvfp4: @@ -240,7 +226,6 @@ class TRTLLMGenFusedMoE(MoE): slot_start, # local_expert_start; use ep_rank if stride!=1 self.expert_size_per_partition, # local_expert_size routed_scaling_factor, - self._get_tile_tokens_dim(x), self.routing_method.routing_method_type, do_finalize=do_finalize, ) diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py index 181f2b0bdc..483d220c2e 100644 --- a/tensorrt_llm/_torch/pyexecutor/config.py +++ b/tensorrt_llm/_torch/pyexecutor/config.py @@ -73,6 +73,7 @@ class PyTorchConfig: torch_compile_piecewise_cuda_graph: bool = False # When torch compile is enabled, userbuffers is enabled by default torch_compile_enable_userbuffers: bool = True + torch_compile_max_num_streams: int = 1 # Enable autotuner only when torch compile is enabled # TODO: after it can be work stable in warmup stage diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 98eb2e870d..1c8b418ff9 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -323,7 +323,9 @@ class PyTorchModelEngine(ModelEngine): enable_piecewise_cuda_graph=pytorch_backend_config. torch_compile_piecewise_cuda_graph, cuda_graph_batch_sizes=pytorch_backend_config. - cuda_graph_batch_sizes) + cuda_graph_batch_sizes, + max_num_streams=pytorch_backend_config. + torch_compile_max_num_streams) if isinstance(self.model, DecoderModelForCausalLM): self.model.model = torch.compile( self.model.model, @@ -2093,6 +2095,14 @@ class PyTorchModelEngine(ModelEngine): attrs["attention_metadata"] = weakref.ref(kwargs['attn_metadata']) attrs.update(self.model.model_config.extra_attrs) + if self._torch_compile_backend is not None: + # Register aux streams and events to model extra attrs. + # The streams and events are list which could be updated during compilation. + attrs["aux_streams"] = weakref.ref( + self._torch_compile_backend.aux_streams) + attrs["events"] = weakref.ref(self._torch_compile_backend.events) + attrs["global_stream"] = torch.cuda.current_stream() + if is_trace_enabled("TLLM_TRACE_MODEL_FORWARD"): return trace_func(self.model.forward)(**kwargs) else: diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index 59cbb214f8..5710dbdc6a 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -196,7 +196,17 @@ def next_positive_power_of_2(x: int) -> int: if x < 1: return 1 - return 1 << (x - 1).bit_length() + # Following code is equivalent to 1 << (x - 1).bit_length() + # But this impl does not contain bit_length() so can be used by torch compile. + # It can correctly handle 64bit number which should be enough for now. 
+ n = x - 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + return n + 1 def last_positive_power_of_2(x: int) -> int: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index f8d525c6a0..1636476ccd 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1792,6 +1792,20 @@ class TorchCompileConfig(BaseModel): description= "When torch compile is enabled, userbuffers is enabled by default.") + max_num_streams: int = Field( + default=1, + description= + "The maximum number of CUDA streams to use for torch.compile.") + + @field_validator('max_num_streams') + @classmethod + def validate_torch_compile_max_num_streams(cls, v): + """Validate torch_compile_config.max_num_streams >= 1.""" + if v < 1: + raise ValueError( + "torch_compile_config.max_num_streams must be >= 1") + return v + class TorchLlmArgs(BaseLlmArgs): # Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs @@ -2116,6 +2130,9 @@ class TorchLlmArgs(BaseLlmArgs): torch_compile_enable_userbuffers=self.torch_compile_config. enable_userbuffers if self.torch_compile_config is not None else TorchCompileConfig.model_fields['enable_userbuffers'].default, + torch_compile_max_num_streams=self.torch_compile_config. + max_num_streams if self.torch_compile_config is not None else + TorchCompileConfig.model_fields['max_num_streams'].default, enable_autotuner=self.enable_autotuner, enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker, load_format=self.load_format, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 45c67a6311..f0461ac91c 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -661,7 +661,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) torch_compile_config = TorchCompileConfig( enable_fullgraph=True, - enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None + enable_piecewise_cuda_graph=cuda_graph, + max_num_streams=3) if torch_compile else None pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None, @@ -702,8 +703,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) torch_compile_config = TorchCompileConfig( enable_fullgraph=True, - enable_piecewise_cuda_graph=cuda_graph - and not attention_dp) if torch_compile else None + enable_piecewise_cuda_graph=cuda_graph and not attention_dp, + max_num_streams=3) if torch_compile else None pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None, @@ -742,7 +743,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) torch_compile_config = TorchCompileConfig( enable_fullgraph=True, - enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None + enable_piecewise_cuda_graph=cuda_graph, + max_num_streams=3) if torch_compile else None pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None, @@ -793,8 +795,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): pytest.skip("https://nvbugs/5252559") kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) torch_compile_config = 
(TorchCompileConfig( - enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph) - if torch_compile else None) + enable_fullgraph=True, + enable_piecewise_cuda_graph=cuda_graph, + max_num_streams=3) if torch_compile else None) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, use_cuda_graph=cuda_graph, @@ -896,8 +899,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) torch_compile_config = TorchCompileConfig( enable_fullgraph=True, - enable_piecewise_cuda_graph=cuda_graph - and not attention_dp) if torch_compile else None + enable_piecewise_cuda_graph=cuda_graph and not attention_dp, + max_num_streams=3) if torch_compile else None pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None, @@ -958,8 +961,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): pytest.skip("PP with torch.compile is not supported yet.") kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) torch_compile_config = (TorchCompileConfig( - enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph) - if torch_compile else None) + enable_fullgraph=True, + enable_piecewise_cuda_graph=cuda_graph, + max_num_streams=3) if torch_compile else None) pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, use_cuda_graph=cuda_graph, @@ -1088,7 +1092,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) torch_compile_config = TorchCompileConfig( enable_fullgraph=True, - enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None + enable_piecewise_cuda_graph=cuda_graph, + max_num_streams=3) if torch_compile else None pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None, @@ -1141,8 +1146,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): # Picewise Cuda Graph cannot be enabled for nvfp4 attention dp. torch_compile_config = TorchCompileConfig( enable_fullgraph=True, - enable_piecewise_cuda_graph=cuda_graph - and not attention_dp) if torch_compile else None + enable_piecewise_cuda_graph=cuda_graph and not attention_dp, + max_num_streams=3) if torch_compile else None pytorch_config = dict( disable_overlap_scheduler=not overlap_scheduler, cuda_graph_config=CudaGraphConfig() if cuda_graph else None, diff --git a/tests/unittest/_torch/thop/test_moe.py b/tests/unittest/_torch/thop/test_moe.py index 953c8cd268..8f70ecebeb 100644 --- a/tests/unittest/_torch/thop/test_moe.py +++ b/tests/unittest/_torch/thop/test_moe.py @@ -621,7 +621,6 @@ class TestMoeFP8: padding = 8 routed_scaling = 2.5 routing_method_type = RoutingMethodType.DeepSeekV3 - tile_tokens_dim = 8 if num_tokens < 1024 else 32 assert top_k <= num_experts assert top_k <= 8 @@ -670,8 +669,7 @@ class TestMoeFP8: expert_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales, num_experts, top_k, n_groups, top_k_groups, intermediate_size, - 0, num_experts, routed_scaling, tile_tokens_dim, - routing_method_type) + 0, num_experts, routed_scaling, routing_method_type) output_dequant_actual = output.to(torch.float) # @@ -1033,7 +1031,6 @@ class TestMoeFp4: 0, num_experts, routed_scaling, - tile_tokens_dim, routing_method_type, do_finalize=True)
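
A few reference sketches of the mechanisms this patch relies on follow; none of them are part of the diff above. The scheduler in auto_multi_stream.py is a greedy list scheduler: it computes each node's distance to the DAG exit, then repeatedly pops the ready node with the largest distance and places it on the stream where it can start earliest. A toy illustration of the priority key used by assign_streams (standalone sketch; a five-node id space is assumed):

    # Toy illustration (not part of the patch) of the priority key used by
    # MultiStreamDAG.assign_streams: nodes with a larger distance to the DAG exit
    # are scheduled first, and the original node id breaks ties deterministically.
    num_nodes = 5

    def calc_priority(node_id: int, distance: int) -> int:
        return (-distance) * num_nodes + node_id

    ready = sorted((calc_priority(nid, dist), nid)
                   for nid, dist in [(0, 3), (1, 7), (2, 7), (3, 1)])
    print([nid for _, nid in ready])  # [1, 2, 0, 3]: the two distance-7 nodes first, in id order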
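
The new trtllm::set_stream, record_event, wait_event, and record_stream custom ops look up streams and events registered in the model's extra attrs, so the rewritten FX graph can express cross-stream dependencies. Below is a plain-PyTorch sketch of the same synchronization pattern, using only stock CUDA stream/event APIs rather than the patch's custom ops:

    # A plain-PyTorch sketch (illustrative only, not the patch's custom ops) of the
    # synchronization pattern the rewritten graph encodes: produce on an auxiliary
    # stream, record an event, and make the consumer stream wait before using the result.
    import torch

    assert torch.cuda.is_available()
    primary = torch.cuda.current_stream()
    aux = torch.cuda.Stream()
    done = torch.cuda.Event()

    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    aux.wait_stream(primary)      # inputs were produced on the primary stream
    with torch.cuda.stream(aux):
        c = a @ b                 # producer runs on the auxiliary stream
        done.record(aux)          # analogous to trtllm::record_event

    done.wait(primary)            # analogous to trtllm::wait_event on the consumer stream
    c.record_stream(primary)      # analogous to trtllm::record_stream: keep c alive for `primary`
    out = c.sum()                 # consumer runs on the primary stream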
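
The utils.py change replaces (x - 1).bit_length() with a shift-OR cascade so that next_positive_power_of_2 stays traceable by torch.compile while still handling 64-bit values. A quick equivalence check, as a standalone sketch:

    # Quick equivalence check (sketch, not part of the patch) for the shift-OR
    # rewrite of next_positive_power_of_2.
    def next_positive_power_of_2(x: int) -> int:
        if x < 1:
            return 1
        n = x - 1
        n |= n >> 1
        n |= n >> 2
        n |= n >> 4
        n |= n >> 8
        n |= n >> 16
        n |= n >> 32
        return n + 1

    for x in (0, 1, 2, 3, 7, 8, 9, 1000, 2**40 + 1):
        assert next_positive_power_of_2(x) == (1 if x < 1 else 1 << (x - 1).bit_length())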
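
Because tile_tokens_dim is now derived from runtime shapes inside the TRTLLM-Gen MoE ops (and inside their register_fake implementations, which need it to size the padded outputs), callers such as TRTLLMGenFusedMoE no longer pass it. A worked example of calculate_tile_tokens_dim as a standalone sketch:

    # Worked example (sketch, not part of the patch) of calculate_tile_tokens_dim.
    def next_positive_power_of_2(x: int) -> int:
        return 1 if x < 1 else 1 << (x - 1).bit_length()

    def calculate_tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) -> int:
        # Guess tokens per expert assuming a perfect distribution, round up to a
        # power of two, then clamp to the 8-64 range supported by the kernel.
        num_tokens_per_expert = num_tokens * top_k // num_experts
        return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)

    print(calculate_tile_tokens_dim(num_tokens=4096, num_experts=256, top_k=8))  # 128 -> clamped to 64
    print(calculate_tile_tokens_dim(num_tokens=16, num_experts=256, top_k=8))    # 0 -> clamped to 8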
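
Finally, the new TorchCompileConfig.max_num_streams knob is plumbed through PyTorchConfig.torch_compile_max_num_streams down to the compile backend, which creates max_num_streams - 1 auxiliary CUDA streams. A hedged usage sketch of enabling it from the LLM API; the model id is a placeholder and the import paths are assumed from the tests above:

    # Hedged usage sketch (not part of the patch): enabling the multi-stream
    # torch.compile flow from the LLM API.
    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, TorchCompileConfig

    llm = LLM(
        model="deepseek-ai/DeepSeek-V3-Lite",  # hypothetical model id
        torch_compile_config=TorchCompileConfig(
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=True,
            max_num_streams=3,  # one primary stream plus two auxiliary streams
        ),
        cuda_graph_config=CudaGraphConfig(),
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.9),
    )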