[TRTLLM-4279] feat: Multistream initial support for torch compile flow (#5847)

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
This commit is contained in:
liji-nv 2025-07-21 19:10:22 +08:00 committed by GitHub
parent aea91b2541
commit 3e0fb60e50
GPG Key ID: B5690EEEBB952194
16 changed files with 764 additions and 117 deletions

View File

@ -12,6 +12,7 @@ from torch.fx import GraphModule
import tensorrt_llm
from tensorrt_llm import logger
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .patterns.ar_residual_norm import register_ar_residual_norm
from .patterns.residual_add_norm import register_add_norm
from .patterns.ub_allreduce import register_ub_patterns
@ -25,12 +26,20 @@ class Backend:
_custom_pass_instances: List[PatternMatcherPass] = None
_graph_pool_handle: tuple[int, int] = None
# The following classes let weakref reference the stream and event list objects (weakref cannot reference a plain list).
class Streams(list):
pass
class Events(list):
pass
def __init__(
self,
enable_inductor=True,
enable_userbuffers=False,
enable_piecewise_cuda_graph: bool = False,
cuda_graph_batch_sizes: Optional[List[int]] = None,
max_num_streams: int = 1,
) -> None:
super().__init__()
self.elapsed_time = 0
@ -45,6 +54,10 @@ class Backend:
else [])
self.piecewise_cuda_graph = enable_piecewise_cuda_graph
self.no_optimization = False
# Only the auxiliary streams need to be created; stream 0 is the current (primary) stream.
self.aux_streams = Backend.Streams(
[torch.cuda.Stream() for i in range(max_num_streams - 1)])
self.events = Backend.Events()
inductor_config.enable_auto_functionalized_v2 = False
if Backend._graph_pool_handle is None:
@ -77,6 +90,12 @@ class Backend:
def enable_optimization(self):
self.no_optimization = False
def generate_events(self, num_events: int):
if num_events > len(self.events):
self.events += [
torch.cuda.Event() for _ in range(num_events - len(self.events))
]
def optimize(
self,
gm: GraphModule,
@ -90,17 +109,30 @@ class Backend:
graph.eliminate_dead_code()
# After this pass, no dead-code elimination (DCE) may be run!
remove_copy_for_mutates_args(graph)
# Do not apply multi-stream scheduling here if piecewise CUDA graph or inductor is enabled.
# For piecewise CUDA graph, the multi-stream optimization is applied inside piecewise_optimizer.
# For inductor, we do not control the passes inside inductor.
if len(
self.aux_streams
) > 0 and not self.piecewise_cuda_graph and not self.enable_inductor:
num_events = multi_stream_schedule(gm, len(self.aux_streams) + 1)
self.generate_events(num_events)
gm.recompile()
if self.piecewise_cuda_graph:
return piecewise_optimizer(
gm, num_events = piecewise_optimizer(
gm,
example_inputs,
self.enable_inductor,
self.input_num_tokens,
self.cuda_graph_batch_sizes,
self._graph_pool_handle,
len(self.aux_streams) + 1,
)
self.generate_events(num_events)
return gm
elif self.enable_inductor:
return compile_fx(gm, example_inputs)
else:

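For orientation, here is a minimal sketch of how this backend might be driven directly. The import path and wiring are assumed from the model-engine hunk later in this diff, and it presumes the Backend instance itself is the torch.compile backend callable; the real integration also registers the aux streams and events into the model extra attrs before running.

import torch

# Import path assumed for illustration.
from tensorrt_llm._torch.compilation.backend import Backend

# max_num_streams=3 -> one primary stream plus two auxiliary streams.
backend = Backend(
    enable_inductor=False,              # multi-stream is skipped under inductor
    enable_piecewise_cuda_graph=False,  # here it is applied directly in optimize()
    max_num_streams=3,
)

def f(x, w_a, w_b):
    # Two independent matmuls: candidates for different streams.
    return torch.mm(x, w_a) + torch.mm(x, w_b)

# At runtime the emitted trtllm.set_stream/record_event ops resolve their
# streams and events through the model extra attrs, which PyTorchModelEngine
# registers before calling forward (see the model_engine hunk below).
compiled = torch.compile(f, backend=backend, fullgraph=True)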
View File

@ -0,0 +1,456 @@
import time
from dataclasses import dataclass, field
from operator import getitem
from queue import PriorityQueue
from typing import Dict, List
import torch
from torch.fx import Graph, GraphModule, Node
from tensorrt_llm.logger import logger
from ..utils import inplace_info
def is_symint_node(node: Node) -> bool:
if node is not None and 'val' in node.meta:
# This is a symint call that happens on host. No need to count time on stream.
if isinstance(node.meta['val'], torch.SymInt):
return True
return False
def estimate_time(node: Node) -> int:
if node is None:
return 0
if is_symint_node(node):
# This is a symint call that happens on host. No need to count time on stream.
return 0
# Add cost model for ops that need special handling.
# We can start with rough estimation and refine it later.
no_cost_ops = {
getitem, torch.ops.aten.view.default, torch.ops.aten.view.dtype,
torch.ops.aten.alias.default, torch.ops.aten.empty.memory_format,
torch.ops.aten.permute.default
}
moe_ops = {
torch.ops.trtllm.fp4_block_scale_moe_runner.default,
torch.ops.trtllm.fused_moe.default,
}
gemm_ops = {
torch.ops.aten.mm.default,
torch.ops.trtllm.nvfp4_gemm.default,
torch.ops.trtllm.fp8_batched_gemm_trtllmgen.default,
torch.ops.trtllm.w4a8_mxfp4_fp8_gemm.default,
torch.ops.trtllm.finegrained_mixed_dtype_gemm.default,
torch.ops.trtllm.bmm_out.default,
torch.ops.trtllm.cublas_scaled_mm.default,
torch.ops.trtllm.cublas_mm.default,
torch.ops.trtllm.dsv3_router_gemm_op.default,
torch.ops.trtllm.dsv3_fused_a_gemm_op.default,
torch.ops.trtllm.fp4_gemm.default,
torch.ops.trtllm.fp4_bmm.default,
torch.ops.trtllm.fp8_block_scaling_gemm.default,
torch.ops.trtllm.matmul_to_ub.default,
}
# These ops are not counted in the time estimation.
if node.op == "call_function" and node.target in no_cost_ops:
return 0
# Add cost estimates below. With accurate estimates the stream assignment
# gives the best performance, but accurate estimates are hard to obtain.
#
# The current estimates are therefore rough; they only ensure that the critical
# path is scheduled correctly. Adjust the estimates or add new ones if the
# resulting stream assignment is not the desired one.
MOE_OP_COST = 20
GEMM_OP_COST = 10
DEFAULT_OP_COST = 1
# Give MoE ops a larger weight so that router -> MoE stays on the critical path.
if node.op == "call_function" and node.target in moe_ops:
return MOE_OP_COST
# GEMM ops
if node.op == "call_function" and node.target in gemm_ops:
return GEMM_OP_COST
# Default cost for all remaining nodes; refine as needed.
return DEFAULT_OP_COST
@dataclass
class Stream:
# Stream id
id: int
# Nodes running on the stream
nodes: List['MultiStreamNode'] = field(init=False, default_factory=list)
# Current elapsed time of the stream
current_time: int = field(init=False, default=0)
class MultiStreamNode:
def __init__(self, node: Node, in_edges: Dict[Node, 'MultiStreamNode']):
# The node in the original graph
self.node = node
# The distance to the exit of the DAG
self.distance = 0
# Weight for the node which represents the computation cost
self.weight = estimate_time(node)
# The in edges of the node
self.in_edges = in_edges
# The out edges of the node
self.out_edges = []
# End time of the node on its assigned stream
self.end_time = 0
# Assigned stream for the node
self.stream = None
# (node, tensor) pairs this node must wait on via events
self.wait_on = []
# Event recorded after this node, if a consumer on another stream needs it
self.event = None
class MultiStreamDAG:
def __init__(self, gm: GraphModule):
self.gm = gm
self.node_to_id = {}
self.node_in_degrees = {}
self.output_nodes = []
self.placeholders = []
self.nodes = {}
self.in_degrees = {}
self.work_list = []
self.entry_node = None
self.exit_node = None
self.create_dag_from_gm(gm)
assert self.entry_node is not None
assert self.exit_node is not None
def create_dag_from_gm(self, gm: GraphModule) -> None:
"""
Create a DAG from the graph module.
"""
# Create node to id mapping
for node in gm.graph.nodes:
self.node_to_id[node] = len(self.node_to_id)
# Fake entry node.
# All nodes without in edges will be connected to this node.
self.entry_node = MultiStreamNode(None, dict())
latest_inplace_stat = {}
inplace_map = inplace_info()
def flatten_args(args):
"""Recursively flatten nested arguments into a flat list."""
args_new = []
stack = list(args)
while stack:
arg = stack.pop()
if isinstance(arg, dict):
stack.extend(arg.values())
elif isinstance(arg, (list, tuple)):
stack.extend(arg)
else:
args_new.append(arg)
return args_new
# Collect all the placeholders from gm first.
# The graph nodes are already in topological order.
for node in gm.graph.nodes:
# We assume that all the placeholders are already synced with the base stream
if node.op == "placeholder":
self.placeholders.append(node)
continue
args = flatten_args([a for a in node.args] +
[a for a in node.kwargs.values()])
in_edges = dict()
for arg in args:
if arg in latest_inplace_stat:
in_edges[arg] = latest_inplace_stat[arg]
elif isinstance(arg, torch.fx.Node) and arg.op != "placeholder":
in_edges[arg] = self.nodes[arg]
# Nodes without in edges are connected to the entry node.
if len(in_edges) == 0:
in_edges[None] = self.entry_node
vertex = MultiStreamNode(node, in_edges)
if node.op == "output":
self.exit_node = vertex
vertex.distance = 0
self.nodes[node] = vertex
self.in_degrees[vertex] = len(in_edges)
if node.op == "call_function":
func = node.target
if func in inplace_map:
for inplace_arg in inplace_map[func].values():
# At this stage, every inplace op must pass all of its params as kwargs.
assert inplace_arg in node.kwargs
latest_inplace_stat[node.kwargs[inplace_arg]] = vertex
for edge in in_edges.values():
edge.out_edges.append(vertex)
self.compute_distance()
def compute_distance(self) -> None:
"""
Compute the distance to the exit node for each node.
"""
# Reverse topological sort to compute distance to exit node
work_list = [self.exit_node]
out_degrees = {
node: len(node.out_edges)
for node in self.nodes.values()
}
out_degrees[self.entry_node] = len(self.entry_node.out_edges)
while len(work_list) > 0:
node = work_list.pop()
for in_edge in node.in_edges.values():
out_degrees[in_edge] -= 1
in_edge.distance = max(in_edge.distance,
node.weight + node.distance)
if out_degrees[in_edge] == 0:
work_list.append(in_edge)
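To make the recursion concrete, here is a tiny standalone check of the longest-remaining-path computation that compute_distance performs in reverse topological order. Node names and weights are illustrative, loosely based on the cost constants above.

from functools import lru_cache

# Toy DAG mirroring a router-GEMM -> MoE -> residual-add chain.
weights = {"router_gemm": 10, "moe": 20, "residual_add": 1, "out": 0}
successors = {"router_gemm": ("moe",), "moe": ("residual_add",),
              "residual_add": ("out",), "out": ()}

@lru_cache(maxsize=None)
def distance(n: str) -> int:
    # distance(n) = max over successors s of weights[s] + distance(s),
    # i.e. the longest remaining path to the exit excluding n's own weight.
    return max((weights[s] + distance(s) for s in successors[n]), default=0)

assert distance("router_gemm") == 21  # moe (20) + residual_add (1) + out (0)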
def assign_streams(self, max_num_streams: int) -> int:
"""
Assign streams to the nodes in the DAG.
Return the number of events created.
"""
worklist = PriorityQueue()
num_nodes = len(self.node_to_id)
# When popping nodes, the distance to the exit node is the main priority:
# the node with the largest distance is currently the bottleneck of the whole graph.
def calc_priority(node_id: int, distance: int) -> int:
# Ties are broken by the original node order, which also gives the
# priority queue a deterministic order.
return (-distance) * num_nodes + node_id
streams = [Stream(i) for i in range(max_num_streams)]
def pick_stream(start_time, node) -> Stream:
if node.weight == 0:
# This is a symint node or a zero-cost node such as getitem.
# Always assign it to the stream that produces its input.
for n in node.in_edges.values():
if is_symint_node(n.node):
continue
return n.stream
return streams[0]
closest_stream = None
least_time = float('inf')
for st in streams:
if st.current_time <= start_time:
return st
else:
if st.current_time < least_time:
least_time = st.current_time
closest_stream = st
return closest_stream
# Start from the out_edges of the entry node. The entry node is a fake node
# and is assigned to the primary stream.
self.entry_node.stream = streams[0]
streams[0].nodes.append(self.entry_node)
for out_edge in self.entry_node.out_edges:
worklist.put((calc_priority(self.node_to_id[out_edge.node],
out_edge.distance), out_edge))
sync_event_id = 0
while not worklist.empty():
_, node = worklist.get()
assert node.stream is None
# Determine when the current node can start: the start time is the max of
# the end times of all its in edges.
start_time = max(
[in_edge.end_time for in_edge in node.in_edges.values()])
node.stream = pick_stream(start_time, node)
node.end_time = max(start_time,
node.stream.current_time) + node.weight
node.stream.current_time = node.end_time
node.stream.nodes.append(node)
# For every in edge produced on a different stream, create a sync event
# on the producer (if needed) and remember to wait on it.
for in_edge_tensor, in_edge in node.in_edges.items():
if in_edge.stream != node.stream and not is_symint_node(
in_edge.node):
if in_edge.event is None:
in_edge.event = sync_event_id
sync_event_id += 1
node.wait_on.append((in_edge, in_edge_tensor))
# Release successor nodes whose dependencies have all been scheduled.
for out_edge in node.out_edges:
self.in_degrees[out_edge] -= 1
if self.in_degrees[out_edge] == 0:
worklist.put((calc_priority(self.node_to_id[out_edge.node],
out_edge.distance), out_edge))
self.streams = streams
return sync_event_id
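The single integer priority above encodes (largest distance first, then original node order), assuming node ids stay below num_nodes as they do here. A quick standalone check of that ordering:

num_nodes = 100  # any bound larger than every node id

def calc_priority(node_id: int, distance: int) -> int:
    # Smaller value = higher priority in PriorityQueue.
    return (-distance) * num_nodes + node_id

# A node farther from the exit always wins, regardless of its id ...
assert calc_priority(node_id=99, distance=30) < calc_priority(node_id=0, distance=5)
# ... and among equally distant nodes, the earlier node in the graph wins.
assert calc_priority(node_id=3, distance=5) < calc_priority(node_id=7, distance=5)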
def create_new_graph(self) -> Graph:
"""
Create new graph with the nodes assigned to the streams.
"""
# Each node has now been assigned a stream. Create a new graph and insert all nodes.
# Since torch needs an extra node for every stream switch, group nodes on the
# same stream together as much as possible.
remap = {}
new_graph = Graph()
for st in self.streams:
logger.debug(f"{len(st.nodes)} nodes running on stream {st.id}")
# First, push all placeholders to the new graph.
for placeholder in self.placeholders:
remap[placeholder] = new_graph.node_copy(placeholder,
lambda n: remap[n])
# Then push the remaining nodes into the new graph.
# Rebuild in_degrees, as we need to check whether a stream is ready to run.
self.in_degrees = {
node: len(node.in_edges)
for node in self.nodes.values()
}
self.in_degrees[self.entry_node] = 0
stream_pos = [0] * len(self.streams)
def has_more_nodes() -> bool:
for st in self.streams:
if len(st.nodes) > stream_pos[st.id]:
return True
return False
last_stream = 0
# The nodes within each stream are already in topological order.
while has_more_nodes():
for st in self.streams:
if len(st.nodes) == stream_pos[st.id]:
continue
node = st.nodes[stream_pos[st.id]]
if self.in_degrees[node] != 0:
# This stream is not ready to run now.
continue
# Whenever the assigned stream changes, emit a set_stream node.
if node.stream.id != last_stream:
# Change stream
new_graph.create_node("call_function",
torch.ops.trtllm.set_stream,
args=(node.stream.id, ))
last_stream = node.stream.id
for _ in range(stream_pos[st.id], len(st.nodes)):
node = st.nodes[stream_pos[st.id]]
if self.in_degrees[node] != 0:
break
for out_edge in node.out_edges:
self.in_degrees[out_edge] -= 1
stream_pos[st.id] += 1
# It could be the fake entry node.
if node.node is not None:
# Wait on all the events that the node is waiting on.
for wait in node.wait_on:
new_graph.create_node("call_function",
torch.ops.trtllm.wait_event,
args=(wait[0].event, ))
remap[node.node] = new_graph.node_copy(
node.node, lambda n: remap[n])
for wait in node.wait_on:
# wait[1] is the actual tensor that the op is waiting on.
# Need to record stream for that tensor.
if wait[1] is None:
continue
new_graph.create_node(
"call_function",
torch.ops.trtllm.record_stream,
args=(remap[wait[1]], st.id))
if node.event is not None:
new_graph.create_node("call_function",
torch.ops.trtllm.record_event,
args=(node.event, ))
# After handling one stream, start over from the first stream so that the
# primary stream is always pushed first.
break
return new_graph
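For intuition, here is a schematic of the node order create_new_graph would emit for a toy graph in which two independent matmuls feed an add and the second matmul is placed on auxiliary stream 1. The trtllm ops are stubbed out so the snippet runs standalone; it is illustrative only, not real output.

import torch

# Stand-ins for the trtllm stream/event custom ops, purely so this schematic
# executes standalone; the real ops are defined in the custom-ops hunk below.
def set_stream(stream_id): pass
def record_event(event_idx): pass
def wait_event(event_idx): pass
def record_stream(tensor, stream_id): pass

x, w_a, w_b = (torch.randn(4, 4) for _ in range(3))

a = torch.mm(x, w_a)     # stays on the primary stream (id 0)
set_stream(1)            # switch to auxiliary stream 1
b = torch.mm(x, w_b)
record_event(0)          # b has a consumer on another stream
set_stream(0)            # back to the primary stream
wait_event(0)            # wait for b before consuming it
out = a + b
record_stream(b, 0)      # tensor b is also used on stream 0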
def optimize(self, max_num_streams: int) -> int:
"""
Run the multi-stream optimization for this MultiStreamDAG. The graph module used to create the DAG is updated in place.
Return the number of events created.
"""
num_events = self.assign_streams(max_num_streams)
new_graph = self.create_new_graph()
self.gm.graph = new_graph
return num_events
def multi_stream_schedule(gm: GraphModule, max_num_streams: int) -> int:
"""
Schedule the graph module for multi-stream execution.
gm is the graph module to be scheduled; it is updated in place by this function.
max_num_streams is the maximum number of streams to use. The scheduler may not use all of them.
Return the number of events created.
"""
dag = MultiStreamDAG(gm)
return dag.optimize(max_num_streams)
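A hypothetical end-to-end sketch of scheduling a traced toy graph (the import path is assumed, and tensorrt_llm must be importable so the trtllm custom ops referenced by the new graph exist):

import torch
from torch.fx import symbolic_trace

import tensorrt_llm  # assumed to register the trtllm custom ops
# Import path assumed for illustration.
from tensorrt_llm._torch.compilation.multi_stream.auto_multi_stream import (
    multi_stream_schedule)

def two_branches(x, w_a, w_b):
    # Two independent matmuls feeding an add: the branches can overlap on
    # different streams.
    return torch.mm(x, w_a) + torch.mm(x, w_b)

gm = symbolic_trace(two_branches)
num_events = multi_stream_schedule(gm, 2)  # rewrites gm.graph in place
print(num_events)  # 1 cross-stream sync event is expected for this toy graph
print(gm.graph)    # contains trtllm.set_stream / wait_event / record_event nodes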
# The following code is for debugging. Use dump_dag_as_dot to dump a MultiStreamDAG to a dot file.
def dump_dag_as_dot(dag: MultiStreamDAG, max_num_nodes: int = 500) -> None:
COLORS = [
"red", "chocolate", "cyan", "gold", "coral", "green", "blue", "orange",
"purple", "brown"
]
filename = f"dag_{int(time.time())}.dot"
with open(filename, 'w') as f:
f.write("digraph G {\n")
f.write(
f"id_entry [label=\"node=entry, distance={dag.entry_node.distance}\"]\n"
)
cnt = 0
for node in dag.nodes.values():
color = "white" if node.stream is None else COLORS[node.stream.id]
f.write(
f"id_{dag.node_to_id[node.node]} [label=\"node={node.node}, "
f"distance={node.distance}, weight={node.weight}\", "
f"color={color}, shape=oval]\n")
for in_edge in node.in_edges.values():
id = str(dag.node_to_id[
in_edge.node]) if in_edge.node is not None else "entry"
f.write(f"id_{id} -> id_{dag.node_to_id[node.node]}\n")
if cnt > max_num_nodes:
break
cnt += 1
f.write("}\n")
f.flush()

View File

@ -12,7 +12,9 @@ from torch.fx.passes.split_module import split_module
from tensorrt_llm.llmapi.utils import enable_llm_debug
from tensorrt_llm.logger import logger
from ..utils import get_piecewise_cuda_graph_flag, make_weak_ref
from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
make_weak_ref)
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .utils import (get_enable_piecewise_cuda_graph_capture_flag,
is_call_function)
@ -29,6 +31,7 @@ class PiecewiseInterpreter(Interpreter):
graph_pool_handle: tuple[int, int],
garbage_collect_values: bool = True,
graph=None,
max_num_streams: int = 1,
):
super().__init__(module, garbage_collect_values, graph)
@ -39,6 +42,8 @@ class PiecewiseInterpreter(Interpreter):
self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
self.graph_pool_handle = graph_pool_handle
self.enable_inductor = enable_inductor
self.num_events = 0
self.max_num_streams = max_num_streams
def run(self, *args):
fake_args = [
@ -72,6 +77,11 @@ class PiecewiseInterpreter(Interpreter):
found_dynamic_shape = True
break
if self.max_num_streams > 1 and not self.enable_inductor:
num_events = multi_stream_schedule(submod, self.max_num_streams)
self.num_events = max(self.num_events, num_events)
submod.recompile()
self.module.__dict__[target] = PiecewiseRunner(
submod,
target,
@ -179,8 +189,12 @@ class PiecewiseRunner(object):
with patch("gc.collect", lambda: None):
# TODO: consider to use `make_graphed_callables()` when
# it's ready rather than capture it ourselves
# Graph capture overrides the current stream, so the global stream must be set up correctly before and after capture.
extra_attrs = get_model_extra_attrs()
with torch.cuda.graph(graph, pool=self.graph_pool_handle):
extra_attrs["global_stream"] = torch.cuda.current_stream()
output = entry.callable(*args)
extra_attrs["global_stream"] = torch.cuda.current_stream()
entry.cuda_graph = graph
# Mark weak refs here so the intermediate activation tensors can be freed properly.
@ -218,7 +232,8 @@ def piecewise_optimizer(
input_num_tokens: Union[int | torch.SymInt],
cuda_graph_batch_sizes: Sequence[int],
graph_pool_handle: tuple[int, int],
) -> GraphModule:
max_num_streams: int = 1,
) -> tuple[GraphModule, int]:
graph_pool_handle = torch.cuda.graph_pool_handle()
graph = gm.graph
@ -253,13 +268,16 @@ def piecewise_optimizer(
lambda node: node_to_graph_id[node],
keep_original_order=True)
PiecewiseInterpreter(
interpreter = PiecewiseInterpreter(
gm,
enable_inductor,
input_num_tokens,
cuda_graph_batch_sizes,
exclude_modules_id,
graph_pool_handle,
).run(*example_inputs)
max_num_streams=max_num_streams,
)
return gm
interpreter.run(*example_inputs)
return gm, interpreter.num_events

View File

@ -5,7 +5,7 @@ from torch._higher_order_ops.auto_functionalize import (auto_functionalized,
auto_functionalized_v2)
from torch.fx import Graph, Node
from .utils import is_call_function
from .utils import inplace_info, is_call_function
aten = torch.ops.aten
@ -46,19 +46,12 @@ def remove_copy_for_mutates_args(graph: Graph):
inplace_func = node.args[0]
if inplace_func == torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default:
remove_functionalize_inner(
node,
{
1: "input",
2: "residual"
},
is_v2=node.target == auto_functionalized_v2,
)
if inplace_func == torch.ops.trtllm.attention_inplace.default:
remove_functionalize_inner(node, {1: "output", 2: "output_sf"})
if inplace_func == torch.ops.trtllm.mla_custom_op_inplace.default:
remove_functionalize_inner(node, {1: "output"})
inplace_map = inplace_info()
if inplace_func not in inplace_map:
# Unknown inplace op; skip it.
continue
remove_functionalize_inner(node, inplace_map[inplace_func])
for node in nodes_to_remove:
graph.erase_node(node)

View File

@ -41,3 +41,20 @@ def set_enable_piecewise_cuda_graph_capture_flag(enable: bool):
def get_enable_piecewise_cuda_graph_capture_flag() -> bool:
global _enable_piecewise_cuda_graph_capture
return _enable_piecewise_cuda_graph_capture
def inplace_info():
inplace_map = {
torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default: {
1: "input",
2: "residual"
},
torch.ops.trtllm.attention_inplace.default: {
1: "output",
2: "output_sf"
},
torch.ops.trtllm.mla_custom_op_inplace.default: {
1: "output"
}
}
return inplace_map
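A small self-contained sketch of how the scheduler consumes this map (tensor names and the node label are made up): only the kwarg names matter to the DAG builder, which records the op as the latest writer of each mutated tensor so that later readers depend on it.

# Entry mirrored from inplace_info() for trtllm.attention_inplace.
mapping = {1: "output", 2: "output_sf"}
# Hypothetical kwargs of one attention_inplace call in the graph.
call_kwargs = {"q": "q_tensor", "output": "out_tensor", "output_sf": "sf_tensor"}

latest_writer = {}
for kwarg_name in mapping.values():
    # Same bookkeeping as MultiStreamDAG.create_dag_from_gm.
    latest_writer[call_kwargs[kwarg_name]] = "attention_node"

assert latest_writer == {"out_tensor": "attention_node",
                         "sf_tensor": "attention_node"}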

View File

@ -501,51 +501,6 @@ def _register_fake():
shape[0] = sizes[local_rank]
return input.new_empty(shape)
@torch.library.register_fake("trtllm::fp4_block_scale_moe_runner")
def _(
routing_logits,
routing_bias,
hidden_states,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm2_weights,
gemm2_weights_scale,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,
num_experts,
top_k,
n_group,
topk_group,
intermediate_size,
local_expert_offset,
local_num_experts,
routed_scaling_factor,
tile_tokens_dim,
routing_method_type,
do_finalize,
) -> List[torch.Tensor]:
num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[1] * 2
if do_finalize:
return [
hidden_states.new_empty((num_tokens, hidden_size),
dtype=torch.bfloat16)
]
expanded_row_count = num_tokens * top_k
max_padding_required = (tile_tokens_dim - 1) * num_experts
max_num_padded_tokens = fp4_utils.pad_up(
expanded_row_count + max_padding_required, tile_tokens_dim)
wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
return [
hidden_states.new_empty((max_num_padded_tokens, hidden_size),
dtype=torch.bfloat16),
hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
]
@torch.library.register_fake("trtllm::nvfp4_block_scale_interleave")
def _(sf: torch.Tensor):
rows = sf.shape[-2]
@ -559,3 +514,12 @@ def _register_fake():
@torch.library.register_fake("trtllm::nvfp4_block_scale_interleave_reverse")
def _(sf: torch.Tensor):
return torch.empty_like(sf, dtype=torch.uint8)
@torch.library.register_fake("trtllm::moe_finalize_allreduce")
def _(input, residual, norm_weight, expanded_idx_to_permuted_idx,
shared_expert_output, expert_scale_factor, workspace, rank, nranks,
eps) -> List[torch.Tensor]:
return [
torch.empty_like(residual),
torch.empty_like(residual),
]

View File

@ -1056,3 +1056,45 @@ def _(
output_sf = torch.empty(()) # Create a placeholder, which is not used.
return output_act, output_sf
def get_event(event_idx: int):
from ..utils import get_model_extra_attrs
extra_attrs = get_model_extra_attrs()
assert "events" in extra_attrs, "Missing Event Book"
return extra_attrs["events"]()[event_idx]
def get_stream(stream_id: int):
from ..utils import get_model_extra_attrs
extra_attrs = get_model_extra_attrs()
if stream_id == 0:
return extra_attrs["global_stream"]
assert "aux_streams" in extra_attrs, "Missing Aux Streams"
return extra_attrs["aux_streams"]()[stream_id - 1]
@torch.library.custom_op("trtllm::set_stream", mutates_args=())
def set_stream(stream_id: int) -> None:
stream = get_stream(stream_id)
assert stream is not None
torch.cuda.set_stream(stream)
@torch.library.custom_op("trtllm::record_event", mutates_args=())
def record_event(event_idx: int) -> None:
event = get_event(event_idx)
event.record()
@torch.library.custom_op("trtllm::wait_event", mutates_args=())
def wait_event(event_idx: int) -> None:
event = get_event(event_idx)
event.wait()
@torch.library.custom_op("trtllm::record_stream", mutates_args=())
def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
stream = get_stream(stream_id)
assert stream is not None
tensor.record_stream(stream)
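A standalone illustration (not production code) of why Streams/Events subclass list and why get_event() calls the stored object before indexing: weakref cannot reference a built-in list, and the weakref must be dereferenced with () to reach the live list.

import weakref

class Events(list):
    # weakref.ref cannot target a plain list, hence the trivial subclass,
    # mirroring Backend.Events above.
    pass

events = Events([object(), object()])
attrs = {"events": weakref.ref(events)}

# get_event(i) does extra_attrs["events"]()[i]: dereference the weakref first.
assert attrs["events"]()[1] is events[1]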

View File

@ -4,13 +4,28 @@ from typing import List, Optional, Tuple
import torch
from tensorrt_llm._torch.utils import (get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2)
from tensorrt_llm._torch.utils import (fp4_utils,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
next_positive_power_of_2)
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
OptimizationProfile, TunableRunner, TuningConfig)
def calculate_tile_tokens_dim(num_tokens: int, num_experts: int,
top_k: int) -> int:
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = num_tokens * top_k // num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
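Two worked examples of the heuristic, as a standalone restatement rather than an import of the function above:

def next_positive_power_of_2(x: int) -> int:
    return 1 if x < 1 else 1 << (x - 1).bit_length()

def tile_tokens_dim(num_tokens: int, num_experts: int, top_k: int) -> int:
    # Restatement of calculate_tile_tokens_dim above for a standalone check.
    tokens_per_expert = num_tokens * top_k // num_experts
    return min(max(next_positive_power_of_2(tokens_per_expert), 8), 64)

assert tile_tokens_dim(128, 256, 8) == 8    # 4 tokens/expert -> clamped up to 8
assert tile_tokens_dim(4096, 256, 8) == 64  # 128 tokens/expert -> capped at 64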
@dataclass(frozen=True)
class FP4BlockScaleMoEInputs:
@ -220,11 +235,14 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor,
intermediate_size: int, local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: Optional[float],
tile_tokens_dim: int, routing_method_type: int,
routing_method_type: int,
do_finalize: bool) -> List[torch.Tensor]:
tuner = AutoTuner.get()
num_tokens = hidden_states.shape[0]
tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
kernel_runner = FP4BlockScaleMoERunner(
num_experts, top_k, n_group, topk_group, intermediate_size,
local_expert_offset, local_num_experts, routed_scaling_factor,
@ -254,6 +272,53 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor,
return kernel_runner(inputs, tactic=best_tactic)
@fp4_block_scale_moe_runner.register_fake
def _(
routing_logits,
routing_bias,
hidden_states,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm2_weights,
gemm2_weights_scale,
output1_scale_scalar,
output1_scale_gate_scalar,
output2_scale_scalar,
num_experts,
top_k,
n_group,
topk_group,
intermediate_size,
local_expert_offset,
local_num_experts,
routed_scaling_factor,
routing_method_type,
do_finalize,
) -> List[torch.Tensor]:
num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[1] * 2
if do_finalize:
return [
hidden_states.new_empty((num_tokens, hidden_size),
dtype=torch.bfloat16)
]
tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
expanded_row_count = num_tokens * top_k
max_padding_required = (tile_tokens_dim - 1) * num_experts
max_num_padded_tokens = fp4_utils.pad_up(
expanded_row_count + max_padding_required, tile_tokens_dim)
wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
return [
hidden_states.new_empty((max_num_padded_tokens, hidden_size),
dtype=torch.bfloat16),
hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
]
@dataclass(frozen=True)
class FP8BlockScaleMoEInputs:
@ -420,23 +485,31 @@ class FP8BlockScaleMoERunner(TunableRunner):
@torch.library.custom_op("trtllm::fp8_block_scale_moe_runner", mutates_args=())
def fp8_block_scale_moe_runner(routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
hidden_states: torch.Tensor,
hidden_states_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
num_experts: int, top_k: int, n_group: int,
topk_group: int, intermediate_size: int,
local_expert_offset: int, local_num_experts: int,
routed_scaling_factor: float,
tile_tokens_dim: int,
routing_method_type: int) -> torch.Tensor:
def fp8_block_scale_moe_runner(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
hidden_states: torch.Tensor,
hidden_states_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
num_experts: int,
top_k: int,
n_group: int,
topk_group: int,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float,
routing_method_type: int,
) -> torch.Tensor:
tuner = AutoTuner.get()
num_tokens = hidden_states.shape[0]
tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)
kernel_runner = FP8BlockScaleMoERunner(num_experts, top_k, n_group,
topk_group, intermediate_size,
local_expert_offset,
@ -463,3 +536,30 @@ def fp8_block_scale_moe_runner(routing_logits: torch.Tensor,
)
return kernel_runner(inputs, tactic=best_tactic)
@fp8_block_scale_moe_runner.register_fake
def _(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor,
hidden_states: torch.Tensor,
hidden_states_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
num_experts: int,
top_k: int,
n_group: int,
topk_group: int,
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: float,
routing_method_type: int,
) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
hidden_size = hidden_states.shape[1] * 2
return hidden_states.new_empty((num_tokens, hidden_size),
dtype=torch.bfloat16)

View File

@ -4,7 +4,7 @@ import torch
from ...distributed.ops import reducescatter
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor, next_positive_power_of_2
from ...utils import Fp4QuantizedTensor
from .interface import MoE, MoEWeightLoadingMode
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
NVFP4TRTLLMGenFusedMoEMethod)
@ -91,19 +91,6 @@ class TRTLLMGenFusedMoE(MoE):
def _check_configs(self):
assert self.has_deepseek_fp8_block_scales or self.has_nvfp4, "TRTLLMGenFusedMoE only supports fp8_block_scaling and nvfp4 dtypes."
def _get_tile_tokens_dim(self, x: torch.Tensor):
top_k = self.routing_method.top_k
# Number of tokens in the input tensor.
num_tokens = x.shape[0]
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
def _get_quant_method(self):
if self.quant_config is not None:
if self.quant_config.layer_quant_mode.has_fp8_block_scales():
@ -204,7 +191,6 @@ class TRTLLMGenFusedMoE(MoE):
slot_start, # local_expert_start; use ep_rank if stride!=1
self.expert_size_per_partition, # local_expert_size
routed_scaling_factor,
self._get_tile_tokens_dim(x),
self.routing_method.routing_method_type,
)
elif self.has_nvfp4:
@ -240,7 +226,6 @@ class TRTLLMGenFusedMoE(MoE):
slot_start, # local_expert_start; use ep_rank if stride!=1
self.expert_size_per_partition, # local_expert_size
routed_scaling_factor,
self._get_tile_tokens_dim(x),
self.routing_method.routing_method_type,
do_finalize=do_finalize,
)

View File

@ -73,6 +73,7 @@ class PyTorchConfig:
torch_compile_piecewise_cuda_graph: bool = False
# When torch compile is enabled, userbuffers is enabled by default
torch_compile_enable_userbuffers: bool = True
torch_compile_max_num_streams: int = 1
# Enable autotuner only when torch compile is enabled
# TODO: enable by default once it works stably in the warmup stage

View File

@ -323,7 +323,9 @@ class PyTorchModelEngine(ModelEngine):
enable_piecewise_cuda_graph=pytorch_backend_config.
torch_compile_piecewise_cuda_graph,
cuda_graph_batch_sizes=pytorch_backend_config.
cuda_graph_batch_sizes)
cuda_graph_batch_sizes,
max_num_streams=pytorch_backend_config.
torch_compile_max_num_streams)
if isinstance(self.model, DecoderModelForCausalLM):
self.model.model = torch.compile(
self.model.model,
@ -2093,6 +2095,14 @@ class PyTorchModelEngine(ModelEngine):
attrs["attention_metadata"] = weakref.ref(kwargs['attn_metadata'])
attrs.update(self.model.model_config.extra_attrs)
if self._torch_compile_backend is not None:
# Register aux streams and events in the model extra attrs.
# The streams and events are lists that may be updated during compilation.
attrs["aux_streams"] = weakref.ref(
self._torch_compile_backend.aux_streams)
attrs["events"] = weakref.ref(self._torch_compile_backend.events)
attrs["global_stream"] = torch.cuda.current_stream()
if is_trace_enabled("TLLM_TRACE_MODEL_FORWARD"):
return trace_func(self.model.forward)(**kwargs)
else:

View File

@ -196,7 +196,17 @@ def next_positive_power_of_2(x: int) -> int:
if x < 1:
return 1
return 1 << (x - 1).bit_length()
# The following code is equivalent to 1 << (x - 1).bit_length(), but it avoids
# bit_length() so it can be handled by torch.compile.
# It correctly handles 64-bit numbers, which should be enough for now.
n = x - 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
return n + 1
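A standalone sanity check that the branch-free form matches the bit_length() expression it replaces:

def next_positive_power_of_2(x: int) -> int:
    # Standalone copy of the branch-free implementation above.
    if x < 1:
        return 1
    n = x - 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    return n + 1

for x in (1, 2, 3, 5, 64, 65, 1 << 40, (1 << 40) + 1):
    assert next_positive_power_of_2(x) == 1 << (x - 1).bit_length()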
def last_positive_power_of_2(x: int) -> int:

View File

@ -1792,6 +1792,20 @@ class TorchCompileConfig(BaseModel):
description=
"When torch compile is enabled, userbuffers is enabled by default.")
max_num_streams: int = Field(
default=1,
description=
"The maximum number of CUDA streams to use for torch.compile.")
@field_validator('max_num_streams')
@classmethod
def validate_torch_compile_max_num_streams(cls, v):
"""Validate torch_compile_config.max_num_streams >= 1."""
if v < 1:
raise ValueError(
"torch_compile_config.max_num_streams must be >= 1")
return v
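A hedged usage sketch of the new field (import path assumed from the test changes below): three streams total means one primary stream plus two auxiliary streams, and values below 1 are rejected by the validator.

from tensorrt_llm.llmapi import TorchCompileConfig  # import path assumed

cfg = TorchCompileConfig(
    enable_fullgraph=True,
    enable_piecewise_cuda_graph=True,
    max_num_streams=3,  # 1 primary + 2 auxiliary CUDA streams
)
# TorchCompileConfig(max_num_streams=0) would raise a validation error.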
class TorchLlmArgs(BaseLlmArgs):
# Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs
@ -2116,6 +2130,9 @@ class TorchLlmArgs(BaseLlmArgs):
torch_compile_enable_userbuffers=self.torch_compile_config.
enable_userbuffers if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['enable_userbuffers'].default,
torch_compile_max_num_streams=self.torch_compile_config.
max_num_streams if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['max_num_streams'].default,
enable_autotuner=self.enable_autotuner,
enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
load_format=self.load_format,

View File

@ -661,7 +661,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph,
max_num_streams=3) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@ -702,8 +703,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph
and not attention_dp) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph and not attention_dp,
max_num_streams=3) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@ -742,7 +743,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph,
max_num_streams=3) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@ -793,8 +795,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
pytest.skip("https://nvbugs/5252559")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = (TorchCompileConfig(
enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph)
if torch_compile else None)
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph,
max_num_streams=3) if torch_compile else None)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
@ -896,8 +899,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph
and not attention_dp) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph and not attention_dp,
max_num_streams=3) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@ -958,8 +961,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
pytest.skip("PP with torch.compile is not supported yet.")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = (TorchCompileConfig(
enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph)
if torch_compile else None)
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph,
max_num_streams=3) if torch_compile else None)
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
@ -1088,7 +1092,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph,
max_num_streams=3) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
@ -1141,8 +1146,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
# Piecewise CUDA Graph cannot be enabled for nvfp4 attention dp.
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=cuda_graph
and not attention_dp) if torch_compile else None
enable_piecewise_cuda_graph=cuda_graph and not attention_dp,
max_num_streams=3) if torch_compile else None
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

View File

@ -621,7 +621,6 @@ class TestMoeFP8:
padding = 8
routed_scaling = 2.5
routing_method_type = RoutingMethodType.DeepSeekV3
tile_tokens_dim = 8 if num_tokens < 1024 else 32
assert top_k <= num_experts
assert top_k <= 8
@ -670,8 +669,7 @@ class TestMoeFP8:
expert_logits, routing_bias, hidden_states, hidden_states_scale,
gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales,
num_experts, top_k, n_groups, top_k_groups, intermediate_size,
0, num_experts, routed_scaling, tile_tokens_dim,
routing_method_type)
0, num_experts, routed_scaling, routing_method_type)
output_dequant_actual = output.to(torch.float)
#
@ -1033,7 +1031,6 @@ class TestMoeFp4:
0,
num_experts,
routed_scaling,
tile_tokens_dim,
routing_method_type,
do_finalize=True)