mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-4279] feat: Multistream initial support for torch compile flow (#5847)
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
This commit is contained in:
parent aea91b2541
commit 3e0fb60e50
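Editor's note: this commit threads a new multi-stream scheduling pass through the torch.compile backend and exposes it through a new TorchCompileConfig.max_num_streams field (see the llmapi changes near the end of the diff). A minimal usage sketch, assuming the llm-api accepts torch_compile_config as a keyword the way the accuracy tests below do, and with a placeholder model path:

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import TorchCompileConfig  # import path assumed

    # Hypothetical sketch: compile with torch.compile and allow up to 3 CUDA streams.
    llm = LLM(
        model="/path/to/model",  # placeholder
        torch_compile_config=TorchCompileConfig(
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=True,
            max_num_streams=3,  # new field added in this commit
        ),
    )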
@@ -12,6 +12,7 @@ from torch.fx import GraphModule

import tensorrt_llm
from tensorrt_llm import logger

from .multi_stream.auto_multi_stream import multi_stream_schedule
from .patterns.ar_residual_norm import register_ar_residual_norm
from .patterns.residual_add_norm import register_add_norm
from .patterns.ub_allreduce import register_ub_patterns

@@ -25,12 +26,20 @@ class Backend:

    _custom_pass_instances: List[PatternMatcherPass] = None
    _graph_pool_handle: tuple[int, int] = None

    # The following classes let weakref reference the stream and event list objects.
    class Streams(list):
        pass

    class Events(list):
        pass

    def __init__(
        self,
        enable_inductor=True,
        enable_userbuffers=False,
        enable_piecewise_cuda_graph: bool = False,
        cuda_graph_batch_sizes: Optional[List[int]] = None,
        max_num_streams: int = 1,
    ) -> None:
        super().__init__()
        self.elapsed_time = 0

@@ -45,6 +54,10 @@ class Backend:
            else [])
        self.piecewise_cuda_graph = enable_piecewise_cuda_graph
        self.no_optimization = False
        # We only need to create the aux streams.
        self.aux_streams = Backend.Streams(
            [torch.cuda.Stream() for i in range(max_num_streams - 1)])
        self.events = Backend.Events()
        inductor_config.enable_auto_functionalized_v2 = False

        if Backend._graph_pool_handle is None:

@@ -77,6 +90,12 @@ class Backend:
    def enable_optimization(self):
        self.no_optimization = False

    def generate_events(self, num_events: int):
        if num_events > len(self.events):
            self.events += [
                torch.cuda.Event() for _ in range(num_events - len(self.events))
            ]

    def optimize(
        self,
        gm: GraphModule,

@@ -90,17 +109,30 @@ class Backend:
        graph.eliminate_dead_code()
        # After this pass, cannot run any dce!!!
        remove_copy_for_mutates_args(graph)

        # Do not apply multi-stream if piecewise CUDA graph or inductor is enabled.
        # For piecewise CUDA graph, the multi-stream optimization is applied in piecewise_optimizer.
        # For inductor, we do not control the passes inside inductor.
        if len(
                self.aux_streams
        ) > 0 and not self.piecewise_cuda_graph and not self.enable_inductor:
            num_events = multi_stream_schedule(gm, len(self.aux_streams) + 1)
            self.generate_events(num_events)

        gm.recompile()

        if self.piecewise_cuda_graph:
            return piecewise_optimizer(
            gm, num_events = piecewise_optimizer(
                gm,
                example_inputs,
                self.enable_inductor,
                self.input_num_tokens,
                self.cuda_graph_batch_sizes,
                self._graph_pool_handle,
                len(self.aux_streams) + 1,
            )
            self.generate_events(num_events)
            return gm
        elif self.enable_inductor:
            return compile_fx(gm, example_inputs)
        else:
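Editor's note: the Backend class above plays the role of a torch.compile custom backend, which is where multi_stream_schedule gets a chance to rewrite the captured graph. A generic, stripped-down illustration of that hook (standalone sketch, not TensorRT-LLM code; the pass call in the comment is where the scheduler would plug in):

    import torch

    class TinyMLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(64, 64)
            self.fc2 = torch.nn.Linear(64, 64)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        # A graph rewrite such as multi_stream_schedule(gm, max_num_streams) would run here.
        gm.recompile()
        return gm.forward

    compiled = torch.compile(TinyMLP(), backend=my_backend)
    out = compiled(torch.randn(8, 64))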
@@ -0,0 +1,456 @@
import time
from dataclasses import dataclass, field
from operator import getitem
from queue import PriorityQueue
from typing import Dict, List

import torch
from torch.fx import Graph, GraphModule, Node

from tensorrt_llm.logger import logger

from ..utils import inplace_info


def is_symint_node(node: Node) -> bool:
    if node is not None and 'val' in node.meta:
        # This is a symint call that happens on the host. No need to count time on a stream.
        if isinstance(node.meta['val'], torch.SymInt):
            return True
    return False


def estimate_time(node: Node) -> int:
    if node is None:
        return 0
    if is_symint_node(node):
        # This is a symint call that happens on the host. No need to count time on a stream.
        return 0

    # Cost model for ops that need special handling.
    # We can start with a rough estimation and refine it later.

    no_cost_ops = {
        getitem, torch.ops.aten.view.default, torch.ops.aten.view.dtype,
        torch.ops.aten.alias.default, torch.ops.aten.empty.memory_format,
        torch.ops.aten.permute.default
    }

    moe_ops = {
        torch.ops.trtllm.fp4_block_scale_moe_runner.default,
        torch.ops.trtllm.fused_moe.default,
    }

    gemm_ops = {
        torch.ops.aten.mm.default,
        torch.ops.trtllm.nvfp4_gemm.default,
        torch.ops.trtllm.fp8_batched_gemm_trtllmgen.default,
        torch.ops.trtllm.w4a8_mxfp4_fp8_gemm.default,
        torch.ops.trtllm.finegrained_mixed_dtype_gemm.default,
        torch.ops.trtllm.bmm_out.default,
        torch.ops.trtllm.cublas_scaled_mm.default,
        torch.ops.trtllm.cublas_mm.default,
        torch.ops.trtllm.dsv3_router_gemm_op.default,
        torch.ops.trtllm.dsv3_fused_a_gemm_op.default,
        torch.ops.trtllm.fp4_gemm.default,
        torch.ops.trtllm.fp4_bmm.default,
        torch.ops.trtllm.fp8_block_scaling_gemm.default,
        torch.ops.trtllm.matmul_to_ub.default,
    }

    # These ops are not counted in the time estimation.
    if node.op == "call_function" and node.target in no_cost_ops:
        return 0

    # Add estimation below. With accurate estimation, the stream assignment
    # can give the best performance, but accurate estimation is hard to get.
    #
    # So currently these estimations are not accurate; they just make sure the key path
    # is correctly scheduled. Adjust the estimation or add new entries
    # if the stream assignment is not the desired one.

    MOE_OP_COST = 20
    GEMM_OP_COST = 10
    DEFAULT_OP_COST = 1

    # Increase the MoE weight so that router -> MoE becomes the key path.
    if node.op == "call_function" and node.target in moe_ops:
        return MOE_OP_COST

    # GEMM ops
    if node.op == "call_function" and node.target in gemm_ops:
        return GEMM_OP_COST

    # Refine the time estimation for remaining nodes here if needed.
    return DEFAULT_OP_COST
@dataclass
class Stream:
    # Stream id
    id: int

    # Nodes running on the stream
    nodes: List['MultiStreamNode'] = field(init=False, default_factory=list)

    # Current elapsed time of the stream
    current_time: int = field(init=False, default=0)


class MultiStreamNode:

    def __init__(self, node: Node, in_edges: Dict[Node, 'MultiStreamNode']):
        # The node in the original graph
        self.node = node

        # The distance to the exit of the DAG
        self.distance = 0

        # Weight of the node, which represents the computation cost
        self.weight = estimate_time(node)

        # The in edges of the node
        self.in_edges = in_edges

        # The out edges of the node
        self.out_edges = []

        # End time of the node
        self.end_time = 0

        # Assigned stream for the node
        self.stream = None

        # Events to wait on
        self.wait_on = []

        # Event to trigger
        self.event = None


class MultiStreamDAG:

    def __init__(self, gm: GraphModule):
        self.gm = gm
        self.node_to_id = {}
        self.node_in_degrees = {}
        self.output_nodes = []
        self.placeholders = []
        self.nodes = {}
        self.in_degrees = {}
        self.work_list = []
        self.entry_node = None
        self.exit_node = None

        self.create_dag_from_gm(gm)
        assert self.entry_node is not None
        assert self.exit_node is not None

    def create_dag_from_gm(self, gm: GraphModule) -> None:
        """
        Create a DAG from the graph module.
        """
        # Create the node-to-id mapping
        for node in gm.graph.nodes:
            self.node_to_id[node] = len(self.node_to_id)

        # Fake entry node.
        # All nodes without in edges will be connected to this node.
        self.entry_node = MultiStreamNode(None, dict())

        latest_inplace_stat = {}
        inplace_map = inplace_info()

        def flatten_args(args):
            """Recursively flatten nested arguments into a flat list."""
            args_new = []
            stack = list(args)
            while stack:
                arg = stack.pop()
                if isinstance(arg, dict):
                    stack.extend(arg.values())
                elif isinstance(arg, (list, tuple)):
                    stack.extend(arg)
                else:
                    args_new.append(arg)
            return args_new

        # Pop all the placeholders from gm.
        # We know that the nodes are already in topological order.
        for node in gm.graph.nodes:
            # We assume that all the placeholders are already synced with the base stream.
            if node.op == "placeholder":
                self.placeholders.append(node)
                continue

            args = flatten_args([a for a in node.args] +
                                [a for a in node.kwargs.values()])

            in_edges = dict()
            for arg in args:
                if arg in latest_inplace_stat:
                    in_edges[arg] = latest_inplace_stat[arg]
                elif isinstance(arg, torch.fx.Node) and arg.op != "placeholder":
                    in_edges[arg] = self.nodes[arg]

            # For a node without in edges, connect it to the entry.
            if len(in_edges) == 0:
                in_edges[None] = self.entry_node

            vertex = MultiStreamNode(node, in_edges)
            if node.op == "output":
                self.exit_node = vertex
                vertex.distance = 0
            self.nodes[node] = vertex
            self.in_degrees[vertex] = len(in_edges)
            if node.op == "call_function":
                func = node.target
                if func in inplace_map:
                    for inplace_arg in inplace_map[func].values():
                        # At this stage, all inplace ops must use kwargs for all params.
                        assert inplace_arg in node.kwargs
                        latest_inplace_stat[node.kwargs[inplace_arg]] = vertex

            for edge in in_edges.values():
                edge.out_edges.append(vertex)
        self.compute_distance()

    def compute_distance(self) -> None:
        """
        Compute the distance to the exit node for each node.
        """
        # Reverse topological sort to compute the distance to the exit node.
        work_list = [self.exit_node]
        out_degrees = {
            node: len(node.out_edges)
            for node in self.nodes.values()
        }
        out_degrees[self.entry_node] = len(self.entry_node.out_edges)

        while len(work_list) > 0:
            node = work_list.pop()
            for in_edge in node.in_edges.values():
                out_degrees[in_edge] -= 1
                in_edge.distance = max(in_edge.distance,
                                       node.weight + node.distance)
                if out_degrees[in_edge] == 0:
                    work_list.append(in_edge)

    def assign_streams(self, max_num_streams: int) -> int:
        """
        Assign streams to the nodes in the DAG.
        Return the number of events created.
        """
        worklist = PriorityQueue()
        num_nodes = len(self.node_to_id)

        # When picking a node, the distance to the exit node is the main priority.
        # The node with the largest distance is currently the bottleneck of the whole graph.
        def calc_priority(node_id: int, distance: int) -> int:
            # We keep the node order by default.
            # It also gives a deterministic order for the priority queue.
            return (-distance) * num_nodes + node_id

        streams = [Stream(i) for i in range(max_num_streams)]

        def pick_stream(start_time, node) -> Stream:
            if node.weight == 0:
                # This is a symint node or a getitem node.
                # It is always assigned to the stream that produces the node.
                for n in node.in_edges.values():
                    if is_symint_node(n.node):
                        continue
                    return n.stream
                return streams[0]

            closest_stream = None
            least_time = float('inf')
            for st in streams:
                if st.current_time <= start_time:
                    return st
                else:
                    if st.current_time < least_time:
                        least_time = st.current_time
                        closest_stream = st
            return closest_stream

        # We just start from the out edges of the entry node. The entry node is a fake node.
        # The entry node is assigned to the primary stream.
        self.entry_node.stream = streams[0]
        streams[0].nodes.append(self.entry_node)
        for out_edge in self.entry_node.out_edges:
            worklist.put((calc_priority(self.node_to_id[out_edge.node],
                                        out_edge.distance), out_edge))

        sync_event_id = 0

        while not worklist.empty():
            _, node = worklist.get()
            assert node.stream is None

            # Get when the current node can start.
            # The start time is the max of the end times of all the in edges.
            start_time = max(
                [in_edge.end_time for in_edge in node.in_edges.values()])
            node.stream = pick_stream(start_time, node)
            node.end_time = max(start_time,
                                node.stream.current_time) + node.weight
            node.stream.current_time = node.end_time
            node.stream.nodes.append(node)

            # For any in edge running on a different stream, we need to create a sync event.
            for in_edge_tensor, in_edge in node.in_edges.items():
                if in_edge.stream != node.stream and not is_symint_node(
                        in_edge.node):
                    if in_edge.event is None:
                        in_edge.event = sync_event_id
                        sync_event_id += 1
                    node.wait_on.append((in_edge, in_edge_tensor))

            for out_edge in node.out_edges:
                self.in_degrees[out_edge] -= 1
                if self.in_degrees[out_edge] == 0:
                    worklist.put((calc_priority(self.node_to_id[out_edge.node],
                                                out_edge.distance), out_edge))
        self.streams = streams
        return sync_event_id

    def create_new_graph(self) -> Graph:
        """
        Create a new graph with the nodes assigned to the streams.
        """
        # Now each node should have been assigned a stream. We will create a new graph and insert all nodes.
        # As torch needs to create a node for every stream switch, group nodes as much as possible.
        remap = {}
        new_graph = Graph()

        for st in self.streams:
            logger.debug(f"{len(st.nodes)} nodes running on stream {st.id}")

        # First, push all placeholders to the new graph.
        for placeholder in self.placeholders:
            remap[placeholder] = new_graph.node_copy(placeholder,
                                                     lambda n: remap[n])

        # Then, we will push all the nodes into the new graph.
        # Build in_degrees again as we need to check whether a stream is ready to run.
        self.in_degrees = {
            node: len(node.in_edges)
            for node in self.nodes.values()
        }
        self.in_degrees[self.entry_node] = 0

        stream_pos = [0] * len(self.streams)

        def has_more_nodes() -> bool:
            for st in self.streams:
                if len(st.nodes) > stream_pos[st.id]:
                    return True
            return False

        last_stream = 0

        # The nodes in each stream are already in topological order.
        while has_more_nodes():
            for st in self.streams:
                if len(st.nodes) == stream_pos[st.id]:
                    continue
                node = st.nodes[stream_pos[st.id]]
                if self.in_degrees[node] != 0:
                    # This stream is not ready to run now.
                    continue

                # Any time the stream is changed, set the stream.
                if node.stream.id != last_stream:
                    # Change stream
                    new_graph.create_node("call_function",
                                          torch.ops.trtllm.set_stream,
                                          args=(node.stream.id, ))
                    last_stream = node.stream.id

                for _ in range(stream_pos[st.id], len(st.nodes)):
                    node = st.nodes[stream_pos[st.id]]
                    if self.in_degrees[node] != 0:
                        break
                    for out_edge in node.out_edges:
                        self.in_degrees[out_edge] -= 1
                    stream_pos[st.id] += 1
                    # It could be the fake entry node.
                    if node.node is not None:
                        # Wait on all the events that the node is waiting on.
                        for wait in node.wait_on:
                            new_graph.create_node("call_function",
                                                  torch.ops.trtllm.wait_event,
                                                  args=(wait[0].event, ))
                        remap[node.node] = new_graph.node_copy(
                            node.node, lambda n: remap[n])
                        for wait in node.wait_on:
                            # wait[1] is the actual tensor that the op is waiting on.
                            # We need to record the stream for that tensor.
                            if wait[1] is None:
                                continue
                            new_graph.create_node(
                                "call_function",
                                torch.ops.trtllm.record_stream,
                                args=(remap[wait[1]], st.id))
                        if node.event is not None:
                            new_graph.create_node("call_function",
                                                  torch.ops.trtllm.record_event,
                                                  args=(node.event, ))

                # After each handling, start again to make sure the primary stream is pushed first.
                break
        return new_graph

    def optimize(self, max_num_streams: int) -> int:
        """
        Run the multi-stream optimization for the MultiStreamDAG. The graph module used to create the DAG will be updated.
        Return the number of events created.
        """
        num_events = self.assign_streams(max_num_streams)
        new_graph = self.create_new_graph()
        self.gm.graph = new_graph
        return num_events


def multi_stream_schedule(gm: GraphModule, max_num_streams: int) -> int:
    """
    Schedule the graph module for multi-stream execution.
    gm is the graph module to be scheduled. It will be updated by this function.
    max_num_streams is the maximum number of streams to be used. The scheduler may not use all the streams.
    Return the number of events created.
    """
    dag = MultiStreamDAG(gm)
    return dag.optimize(max_num_streams)


# The following code is for debug purposes. Use dump_dag_as_dot to print a MultiStreamDAG to a dot file.


def dump_dag_as_dot(dag: MultiStreamDAG, max_num_nodes: int = 500) -> None:
    COLORS = [
        "red", "chocolate", "cyan", "gold", "coral", "green", "blue", "orange",
        "purple", "brown"
    ]
    filename = f"dag_{int(time.time())}.dot"
    with open(filename, 'w') as f:
        f.write("digraph G {\n")
        f.write(
            f"id_entry [label=\"node=entry, distance={dag.entry_node.distance}\"]\n"
        )
        cnt = 0
        for node in dag.nodes.values():
            color = "white" if node.stream is None else COLORS[node.stream.id]
            f.write(
                f"id_{dag.node_to_id[node.node]} [label=\"node={node.node}, "
                f"distance={node.distance}, weight={node.weight}\", "
                f"color={color}, shape=oval]\n")
            for in_edge in node.in_edges.values():
                id = str(dag.node_to_id[
                    in_edge.node]) if in_edge.node is not None else "entry"
                f.write(f"id_{id} -> id_{dag.node_to_id[node.node]}\n")
            if cnt > max_num_nodes:
                break
            cnt += 1
        f.write("}\n")
        f.flush()
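Editor's note: a quick way to see the scheduler in action is to run it on a toy torch.fx graph with two independent branches. This is a hedged sketch: the import path is assumed from the relative imports above, it assumes a TensorRT-LLM build where the trtllm custom ops referenced by the cost model are registered, and the rewritten graph is only runnable once the stream/event registry (see the custom ops added later in this diff) is populated.

    import torch
    from torch import fx
    from tensorrt_llm._torch.compilation.multi_stream.auto_multi_stream import (  # path assumed
        multi_stream_schedule)

    class TwoBranch(torch.nn.Module):
        """Two independent matmuls that could overlap on separate streams."""
        def forward(self, x, w1, w2):
            a = x @ w1
            b = x @ w2
            return a + b

    gm = fx.symbolic_trace(TwoBranch())
    # Rewrite the graph for up to 2 streams; the return value is the number of
    # cross-stream sync events the scheduler inserted.
    num_events = multi_stream_schedule(gm, max_num_streams=2)
    gm.recompile()  # the backend does this after scheduling
    print(num_events)
    print(gm.graph)  # now contains trtllm.set_stream / wait_event / record_event nodes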
@@ -12,7 +12,9 @@ from torch.fx.passes.split_module import split_module
from tensorrt_llm.llmapi.utils import enable_llm_debug
from tensorrt_llm.logger import logger

from ..utils import get_piecewise_cuda_graph_flag, make_weak_ref
from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
                     make_weak_ref)
from .multi_stream.auto_multi_stream import multi_stream_schedule
from .utils import (get_enable_piecewise_cuda_graph_capture_flag,
                    is_call_function)

@@ -29,6 +31,7 @@ class PiecewiseInterpreter(Interpreter):
        graph_pool_handle: tuple[int, int],
        garbage_collect_values: bool = True,
        graph=None,
        max_num_streams: int = 1,
    ):
        super().__init__(module, garbage_collect_values, graph)

@@ -39,6 +42,8 @@ class PiecewiseInterpreter(Interpreter):
        self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
        self.graph_pool_handle = graph_pool_handle
        self.enable_inductor = enable_inductor
        self.num_events = 0
        self.max_num_streams = max_num_streams

    def run(self, *args):
        fake_args = [

@@ -72,6 +77,11 @@ class PiecewiseInterpreter(Interpreter):
                found_dynamic_shape = True
                break

        if self.max_num_streams > 1 and not self.enable_inductor:
            num_events = multi_stream_schedule(submod, self.max_num_streams)
            self.num_events = max(self.num_events, num_events)
            submod.recompile()

        self.module.__dict__[target] = PiecewiseRunner(
            submod,
            target,

@@ -179,8 +189,12 @@ class PiecewiseRunner(object):
            with patch("gc.collect", lambda: None):
                # TODO: consider using `make_graphed_callables()` when
                # it's ready rather than capturing it ourselves.
                # Graph capture would override the stream. We need to set up the stream correctly.
                extra_attrs = get_model_extra_attrs()
                with torch.cuda.graph(graph, pool=self.graph_pool_handle):
                    extra_attrs["global_stream"] = torch.cuda.current_stream()
                    output = entry.callable(*args)
                extra_attrs["global_stream"] = torch.cuda.current_stream()

            entry.cuda_graph = graph
            # Mark the weak ref here. The intermediate activation tensors should be freed properly.

@@ -218,7 +232,8 @@ def piecewise_optimizer(
    input_num_tokens: Union[int | torch.SymInt],
    cuda_graph_batch_sizes: Sequence[int],
    graph_pool_handle: tuple[int, int],
) -> GraphModule:
    max_num_streams: int = 1,
) -> tuple[GraphModule, int]:
    graph_pool_handle = torch.cuda.graph_pool_handle()
    graph = gm.graph

@@ -253,13 +268,16 @@ def piecewise_optimizer(
        lambda node: node_to_graph_id[node],
        keep_original_order=True)

    PiecewiseInterpreter(
    interpreter = PiecewiseInterpreter(
        gm,
        enable_inductor,
        input_num_tokens,
        cuda_graph_batch_sizes,
        exclude_modules_id,
        graph_pool_handle,
    ).run(*example_inputs)
        max_num_streams=max_num_streams,
    )

    return gm
    interpreter.run(*example_inputs)

    return gm, interpreter.num_events
@@ -5,7 +5,7 @@ from torch._higher_order_ops.auto_functionalize import (auto_functionalized,
                                                         auto_functionalized_v2)
from torch.fx import Graph, Node

from .utils import is_call_function
from .utils import inplace_info, is_call_function

aten = torch.ops.aten

@@ -46,19 +46,12 @@ def remove_copy_for_mutates_args(graph: Graph):

        inplace_func = node.args[0]

        if inplace_func == torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default:
            remove_functionalize_inner(
                node,
                {
                    1: "input",
                    2: "residual"
                },
                is_v2=node.target == auto_functionalized_v2,
            )
        if inplace_func == torch.ops.trtllm.attention_inplace.default:
            remove_functionalize_inner(node, {1: "output", 2: "output_sf"})
        if inplace_func == torch.ops.trtllm.mla_custom_op_inplace.default:
            remove_functionalize_inner(node, {1: "output"})
        inplace_map = inplace_info()
        if inplace_func not in inplace_map:
            # We do not know the inplace op
            continue

        remove_functionalize_inner(node, inplace_map[inplace_func])

    for node in nodes_to_remove:
        graph.erase_node(node)
@@ -41,3 +41,20 @@ def set_enable_piecewise_cuda_graph_capture_flag(enable: bool):
def get_enable_piecewise_cuda_graph_capture_flag() -> bool:
    global _enable_piecewise_cuda_graph_capture
    return _enable_piecewise_cuda_graph_capture


def inplace_info():
    inplace_map = {
        torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default: {
            1: "input",
            2: "residual"
        },
        torch.ops.trtllm.attention_inplace.default: {
            1: "output",
            2: "output_sf"
        },
        torch.ops.trtllm.mla_custom_op_inplace.default: {
            1: "output"
        }
    }
    return inplace_map
@@ -501,51 +501,6 @@ def _register_fake():
        shape[0] = sizes[local_rank]
        return input.new_empty(shape)

    @torch.library.register_fake("trtllm::fp4_block_scale_moe_runner")
    def _(
        routing_logits,
        routing_bias,
        hidden_states,
        hidden_states_scale,
        gemm1_weights,
        gemm1_weights_scale,
        gemm2_weights,
        gemm2_weights_scale,
        output1_scale_scalar,
        output1_scale_gate_scalar,
        output2_scale_scalar,
        num_experts,
        top_k,
        n_group,
        topk_group,
        intermediate_size,
        local_expert_offset,
        local_num_experts,
        routed_scaling_factor,
        tile_tokens_dim,
        routing_method_type,
        do_finalize,
    ) -> List[torch.Tensor]:
        num_tokens = hidden_states.shape[0]
        hidden_size = hidden_states.shape[1] * 2
        if do_finalize:
            return [
                hidden_states.new_empty((num_tokens, hidden_size),
                                        dtype=torch.bfloat16)
            ]

        expanded_row_count = num_tokens * top_k
        max_padding_required = (tile_tokens_dim - 1) * num_experts
        max_num_padded_tokens = fp4_utils.pad_up(
            expanded_row_count + max_padding_required, tile_tokens_dim)
        wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
        return [
            hidden_states.new_empty((max_num_padded_tokens, hidden_size),
                                    dtype=torch.bfloat16),
            hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
            hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
        ]

    @torch.library.register_fake("trtllm::nvfp4_block_scale_interleave")
    def _(sf: torch.Tensor):
        rows = sf.shape[-2]

@@ -559,3 +514,12 @@ def _register_fake():
    @torch.library.register_fake("trtllm::nvfp4_block_scale_interleave_reverse")
    def _(sf: torch.Tensor):
        return torch.empty_like(sf, dtype=torch.uint8)

    @torch.library.register_fake("trtllm::moe_finalize_allreduce")
    def _(input, residual, norm_weight, expanded_idx_to_permuted_idx,
          shared_expert_output, expert_scale_factor, workspace, rank, nranks,
          eps) -> List[torch.Tensor]:
        return [
            torch.empty_like(residual),
            torch.empty_like(residual),
        ]

@@ -1056,3 +1056,45 @@ def _(
    output_sf = torch.empty(())  # Create a placeholder, which is not used.

    return output_act, output_sf


def get_event(event_idx: int):
    from ..utils import get_model_extra_attrs
    extra_attrs = get_model_extra_attrs()
    assert "events" in extra_attrs, "Missing Event Book"
    return extra_attrs["events"]()[event_idx]


def get_stream(stream_id: int):
    from ..utils import get_model_extra_attrs
    extra_attrs = get_model_extra_attrs()
    if stream_id == 0:
        return extra_attrs["global_stream"]
    assert "aux_streams" in extra_attrs, "Missing Aux Streams"
    return extra_attrs["aux_streams"]()[stream_id - 1]


@torch.library.custom_op("trtllm::set_stream", mutates_args=())
def set_stream(stream_id: int) -> None:
    stream = get_stream(stream_id)
    assert stream is not None
    torch.cuda.set_stream(stream)


@torch.library.custom_op("trtllm::record_event", mutates_args=())
def record_event(event_idx: int) -> None:
    event = get_event(event_idx)
    event.record()


@torch.library.custom_op("trtllm::wait_event", mutates_args=())
def wait_event(event_idx: int) -> None:
    event = get_event(event_idx)
    event.wait()


@torch.library.custom_op("trtllm::record_stream", mutates_args=())
def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
    stream = get_stream(stream_id)
    assert stream is not None
    tensor.record_stream(stream)
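Editor's note: the four custom ops above are thin wrappers over the standard PyTorch stream and event primitives, so the scheduled graph can express the usual producer/consumer handshake. A minimal eager-mode sketch of the same pattern (standalone, not TensorRT-LLM code):

    import torch

    assert torch.cuda.is_available()
    main = torch.cuda.current_stream()
    aux = torch.cuda.Stream()
    event = torch.cuda.Event()

    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x           # producer work on the main stream
    event.record(main)  # record_event: mark the point the consumer must wait for

    with torch.cuda.stream(aux):  # set_stream: switch the current stream
        event.wait()              # wait_event: aux waits until y is ready
        z = y.relu()              # consumer work can overlap other main-stream work
        y.record_stream(aux)      # record_stream: keep y alive while aux still uses it

    torch.cuda.synchronize()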
@@ -4,13 +4,28 @@ from typing import List, Optional, Tuple

import torch

from tensorrt_llm._torch.utils import (get_last_power_of_2_num_tokens_buckets,
                                       last_positive_power_of_2)
from tensorrt_llm._torch.utils import (fp4_utils,
                                       get_last_power_of_2_num_tokens_buckets,
                                       last_positive_power_of_2,
                                       next_positive_power_of_2)

from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                         OptimizationProfile, TunableRunner, TuningConfig)


def calculate_tile_tokens_dim(num_tokens: int, num_experts: int,
                              top_k: int) -> int:
    # Guess tokens per expert assuming perfect expert distribution first.
    num_tokens_per_expert = num_tokens * top_k // num_experts

    # And pad the number to the next power of 2.
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

    return tile_tokens_dim
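Editor's note: a quick worked example of the helper above, with made-up values. With 100 tokens, 8 experts, and top_k=2, the perfect-distribution guess is 100 * 2 // 8 = 25 tokens per expert, the next power of two is 32, and the 8-64 clamp leaves it at 32.

    # Illustrative only; the numbers are not from the commit.
    calculate_tile_tokens_dim(num_tokens=100, num_experts=8, top_k=2)   # 25  -> 32
    calculate_tile_tokens_dim(num_tokens=16, num_experts=64, top_k=4)   # 1   -> clamped up to 8
    calculate_tile_tokens_dim(num_tokens=8192, num_experts=8, top_k=2)  # 2048 -> clamped down to 64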
@dataclass(frozen=True)
class FP4BlockScaleMoEInputs:
@@ -220,11 +235,14 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor,
                               intermediate_size: int, local_expert_offset: int,
                               local_num_experts: int,
                               routed_scaling_factor: Optional[float],
                               tile_tokens_dim: int, routing_method_type: int,
                               routing_method_type: int,
                               do_finalize: bool) -> List[torch.Tensor]:

    tuner = AutoTuner.get()

    num_tokens = hidden_states.shape[0]
    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)

    kernel_runner = FP4BlockScaleMoERunner(
        num_experts, top_k, n_group, topk_group, intermediate_size,
        local_expert_offset, local_num_experts, routed_scaling_factor,

@@ -254,6 +272,53 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor,
    return kernel_runner(inputs, tactic=best_tactic)


@fp4_block_scale_moe_runner.register_fake
def _(
    routing_logits,
    routing_bias,
    hidden_states,
    hidden_states_scale,
    gemm1_weights,
    gemm1_weights_scale,
    gemm2_weights,
    gemm2_weights_scale,
    output1_scale_scalar,
    output1_scale_gate_scalar,
    output2_scale_scalar,
    num_experts,
    top_k,
    n_group,
    topk_group,
    intermediate_size,
    local_expert_offset,
    local_num_experts,
    routed_scaling_factor,
    routing_method_type,
    do_finalize,
) -> List[torch.Tensor]:
    num_tokens = hidden_states.shape[0]
    hidden_size = hidden_states.shape[1] * 2
    if do_finalize:
        return [
            hidden_states.new_empty((num_tokens, hidden_size),
                                    dtype=torch.bfloat16)
        ]

    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)

    expanded_row_count = num_tokens * top_k
    max_padding_required = (tile_tokens_dim - 1) * num_experts
    max_num_padded_tokens = fp4_utils.pad_up(
        expanded_row_count + max_padding_required, tile_tokens_dim)
    wt_dtype = routing_bias.dtype if routing_bias is not None else torch.bfloat16
    return [
        hidden_states.new_empty((max_num_padded_tokens, hidden_size),
                                dtype=torch.bfloat16),
        hidden_states.new_empty((num_tokens, top_k), dtype=wt_dtype),
        hidden_states.new_empty((num_tokens, top_k), dtype=torch.int32)
    ]


@dataclass(frozen=True)
class FP8BlockScaleMoEInputs:

@@ -420,23 +485,31 @@ class FP8BlockScaleMoERunner(TunableRunner):


@torch.library.custom_op("trtllm::fp8_block_scale_moe_runner", mutates_args=())
def fp8_block_scale_moe_runner(routing_logits: torch.Tensor,
                               routing_bias: torch.Tensor,
                               hidden_states: torch.Tensor,
                               hidden_states_scale: torch.Tensor,
                               gemm1_weights: torch.Tensor,
                               gemm1_weights_scale: torch.Tensor,
                               gemm2_weights: torch.Tensor,
                               gemm2_weights_scale: torch.Tensor,
                               num_experts: int, top_k: int, n_group: int,
                               topk_group: int, intermediate_size: int,
                               local_expert_offset: int, local_num_experts: int,
                               routed_scaling_factor: float,
                               tile_tokens_dim: int,
                               routing_method_type: int) -> torch.Tensor:
def fp8_block_scale_moe_runner(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor,
    hidden_states: torch.Tensor,
    hidden_states_scale: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm1_weights_scale: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm2_weights_scale: torch.Tensor,
    num_experts: int,
    top_k: int,
    n_group: int,
    topk_group: int,
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    routed_scaling_factor: float,
    routing_method_type: int,
) -> torch.Tensor:

    tuner = AutoTuner.get()

    num_tokens = hidden_states.shape[0]
    tile_tokens_dim = calculate_tile_tokens_dim(num_tokens, num_experts, top_k)

    kernel_runner = FP8BlockScaleMoERunner(num_experts, top_k, n_group,
                                           topk_group, intermediate_size,
                                           local_expert_offset,

@@ -463,3 +536,30 @@ def fp8_block_scale_moe_runner(routing_logits: torch.Tensor,
    )

    return kernel_runner(inputs, tactic=best_tactic)


@fp8_block_scale_moe_runner.register_fake
def _(
    routing_logits: torch.Tensor,
    routing_bias: torch.Tensor,
    hidden_states: torch.Tensor,
    hidden_states_scale: torch.Tensor,
    gemm1_weights: torch.Tensor,
    gemm1_weights_scale: torch.Tensor,
    gemm2_weights: torch.Tensor,
    gemm2_weights_scale: torch.Tensor,
    num_experts: int,
    top_k: int,
    n_group: int,
    topk_group: int,
    intermediate_size: int,
    local_expert_offset: int,
    local_num_experts: int,
    routed_scaling_factor: float,
    routing_method_type: int,
) -> torch.Tensor:
    num_tokens = hidden_states.shape[0]
    hidden_size = hidden_states.shape[1] * 2

    return hidden_states.new_empty((num_tokens, hidden_size),
                                   dtype=torch.bfloat16)
@@ -4,7 +4,7 @@ import torch

from ...distributed.ops import reducescatter
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor, next_positive_power_of_2
from ...utils import Fp4QuantizedTensor
from .interface import MoE, MoEWeightLoadingMode
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                           NVFP4TRTLLMGenFusedMoEMethod)

@@ -91,19 +91,6 @@ class TRTLLMGenFusedMoE(MoE):
    def _check_configs(self):
        assert self.has_deepseek_fp8_block_scales or self.has_nvfp4, "TRTLLMGenFusedMoE only supports fp8_block_scaling and nvfp4 dtypes."

    def _get_tile_tokens_dim(self, x: torch.Tensor):
        top_k = self.routing_method.top_k
        # Number of tokens in the input tensor.
        num_tokens = x.shape[0]
        # Guess tokens per expert assuming perfect expert distribution first.
        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
        # And pad the number to the next power of 2.
        tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
        # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

        return tile_tokens_dim

    def _get_quant_method(self):
        if self.quant_config is not None:
            if self.quant_config.layer_quant_mode.has_fp8_block_scales():

@@ -204,7 +191,6 @@ class TRTLLMGenFusedMoE(MoE):
                slot_start,  # local_expert_start; use ep_rank if stride!=1
                self.expert_size_per_partition,  # local_expert_size
                routed_scaling_factor,
                self._get_tile_tokens_dim(x),
                self.routing_method.routing_method_type,
            )
        elif self.has_nvfp4:

@@ -240,7 +226,6 @@ class TRTLLMGenFusedMoE(MoE):
                slot_start,  # local_expert_start; use ep_rank if stride!=1
                self.expert_size_per_partition,  # local_expert_size
                routed_scaling_factor,
                self._get_tile_tokens_dim(x),
                self.routing_method.routing_method_type,
                do_finalize=do_finalize,
            )
@@ -73,6 +73,7 @@ class PyTorchConfig:
    torch_compile_piecewise_cuda_graph: bool = False
    # When torch compile is enabled, userbuffers is enabled by default.
    torch_compile_enable_userbuffers: bool = True
    torch_compile_max_num_streams: int = 1

    # Enable the autotuner only when torch compile is enabled.
    # TODO: enable it more broadly once it works stably in the warmup stage.
@@ -323,7 +323,9 @@ class PyTorchModelEngine(ModelEngine):
                enable_piecewise_cuda_graph=pytorch_backend_config.
                torch_compile_piecewise_cuda_graph,
                cuda_graph_batch_sizes=pytorch_backend_config.
                cuda_graph_batch_sizes)
                cuda_graph_batch_sizes,
                max_num_streams=pytorch_backend_config.
                torch_compile_max_num_streams)
            if isinstance(self.model, DecoderModelForCausalLM):
                self.model.model = torch.compile(
                    self.model.model,

@@ -2093,6 +2095,14 @@ class PyTorchModelEngine(ModelEngine):
            attrs["attention_metadata"] = weakref.ref(kwargs['attn_metadata'])
            attrs.update(self.model.model_config.extra_attrs)

            if self._torch_compile_backend is not None:
                # Register aux streams and events in the model extra attrs.
                # The streams and events are lists that could be updated during compilation.
                attrs["aux_streams"] = weakref.ref(
                    self._torch_compile_backend.aux_streams)
                attrs["events"] = weakref.ref(self._torch_compile_backend.events)
                attrs["global_stream"] = torch.cuda.current_stream()

            if is_trace_enabled("TLLM_TRACE_MODEL_FORWARD"):
                return trace_func(self.model.forward)(**kwargs)
            else:
@@ -196,7 +196,17 @@ def next_positive_power_of_2(x: int) -> int:
    if x < 1:
        return 1

    return 1 << (x - 1).bit_length()
    # The following code is equivalent to 1 << (x - 1).bit_length(),
    # but it does not call bit_length(), so it can be used by torch compile.
    # It correctly handles 64-bit numbers, which should be enough for now.
    n = x - 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32
    return n + 1

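Editor's note: the bit-twiddling above propagates the highest set bit of x - 1 into every lower position, so n + 1 rounds up to the next power of two. A quick standalone sanity check against the bit_length() form (illustrative values only):

    def next_pow2_reference(x: int) -> int:
        return 1 if x < 1 else 1 << (x - 1).bit_length()

    # The branch-free variant should agree with the reference on small and large values.
    for x in [1, 2, 3, 7, 8, 9, 1000, 2**40 + 1]:
        assert next_positive_power_of_2(x) == next_pow2_reference(x)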

def last_positive_power_of_2(x: int) -> int:
@@ -1792,6 +1792,20 @@ class TorchCompileConfig(BaseModel):
        description=
        "When torch compile is enabled, userbuffers is enabled by default.")

    max_num_streams: int = Field(
        default=1,
        description=
        "The maximum number of CUDA streams to use for torch.compile.")

    @field_validator('max_num_streams')
    @classmethod
    def validate_torch_compile_max_num_streams(cls, v):
        """Validate torch_compile_config.max_num_streams >= 1."""
        if v < 1:
            raise ValueError(
                "torch_compile_config.max_num_streams must be >= 1")
        return v


class TorchLlmArgs(BaseLlmArgs):
    # Just a dummy BuildConfig to allow code reuse with the TrtLlmArgs

@@ -2116,6 +2130,9 @@ class TorchLlmArgs(BaseLlmArgs):
            torch_compile_enable_userbuffers=self.torch_compile_config.
            enable_userbuffers if self.torch_compile_config is not None else
            TorchCompileConfig.model_fields['enable_userbuffers'].default,
            torch_compile_max_num_streams=self.torch_compile_config.
            max_num_streams if self.torch_compile_config is not None else
            TorchCompileConfig.model_fields['max_num_streams'].default,
            enable_autotuner=self.enable_autotuner,
            enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
            load_format=self.load_format,
@@ -661,7 +661,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
        torch_compile_config = TorchCompileConfig(
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
            enable_piecewise_cuda_graph=cuda_graph,
            max_num_streams=3) if torch_compile else None
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

@@ -702,8 +703,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
        torch_compile_config = TorchCompileConfig(
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=cuda_graph
            and not attention_dp) if torch_compile else None
            enable_piecewise_cuda_graph=cuda_graph and not attention_dp,
            max_num_streams=3) if torch_compile else None
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

@@ -742,7 +743,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
        torch_compile_config = TorchCompileConfig(
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
            enable_piecewise_cuda_graph=cuda_graph,
            max_num_streams=3) if torch_compile else None
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

@@ -793,8 +795,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
            pytest.skip("https://nvbugs/5252559")
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
        torch_compile_config = (TorchCompileConfig(
            enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph)
                                if torch_compile else None)
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=cuda_graph,
            max_num_streams=3) if torch_compile else None)
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            use_cuda_graph=cuda_graph,

@@ -896,8 +899,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
        torch_compile_config = TorchCompileConfig(
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=cuda_graph
            and not attention_dp) if torch_compile else None
            enable_piecewise_cuda_graph=cuda_graph and not attention_dp,
            max_num_streams=3) if torch_compile else None
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

@@ -958,8 +961,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
            pytest.skip("PP with torch.compile is not supported yet.")
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
        torch_compile_config = (TorchCompileConfig(
            enable_fullgraph=True, enable_piecewise_cuda_graph=cuda_graph)
                                if torch_compile else None)
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=cuda_graph,
            max_num_streams=3) if torch_compile else None)
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            use_cuda_graph=cuda_graph,

@@ -1088,7 +1092,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
        torch_compile_config = TorchCompileConfig(
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=cuda_graph) if torch_compile else None
            enable_piecewise_cuda_graph=cuda_graph,
            max_num_streams=3) if torch_compile else None
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

@@ -1141,8 +1146,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
        # Piecewise CUDA graph cannot be enabled for nvfp4 attention dp.
        torch_compile_config = TorchCompileConfig(
            enable_fullgraph=True,
            enable_piecewise_cuda_graph=cuda_graph
            and not attention_dp) if torch_compile else None
            enable_piecewise_cuda_graph=cuda_graph and not attention_dp,
            max_num_streams=3) if torch_compile else None
        pytorch_config = dict(
            disable_overlap_scheduler=not overlap_scheduler,
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,

@@ -621,7 +621,6 @@ class TestMoeFP8:
        padding = 8
        routed_scaling = 2.5
        routing_method_type = RoutingMethodType.DeepSeekV3
        tile_tokens_dim = 8 if num_tokens < 1024 else 32

        assert top_k <= num_experts
        assert top_k <= 8

@@ -670,8 +669,7 @@ class TestMoeFP8:
            expert_logits, routing_bias, hidden_states, hidden_states_scale,
            gemm1_weights, gemm1_scales, gemm2_weights, gemm2_scales,
            num_experts, top_k, n_groups, top_k_groups, intermediate_size,
            0, num_experts, routed_scaling, tile_tokens_dim,
            routing_method_type)
            0, num_experts, routed_scaling, routing_method_type)

        output_dequant_actual = output.to(torch.float)
        #

@@ -1033,7 +1031,6 @@ class TestMoeFp4:
            0,
            num_experts,
            routed_scaling,
            tile_tokens_dim,
            routing_method_type,
            do_finalize=True)