mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
95 lines
2.6 KiB
Python
95 lines
2.6 KiB
Python
import contextlib
|
|
from typing import Callable, List, Union
|
|
|
|
import torch
|
|
from torch.fx import Node
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
|
|
def get_symint_val(i: Union[torch.SymInt | int]):
|
|
if isinstance(i, int):
|
|
return i
|
|
elif isinstance(i, torch.SymInt):
|
|
node = i.node
|
|
expr = node.expr
|
|
shape_env: ShapeEnv = node.shape_env
|
|
var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace(
|
|
shape_env.var_to_val)
|
|
return var_val
|
|
else:
|
|
raise Exception("Only support int or torch.SymInt")
|
|
|
|
|
|
def get_arg(node, idx, arg_name):
    """Fetch an argument of ``node`` by position, falling back to keyword.

    Returns ``node.args[idx]`` when ``node.args`` holds more than ``idx``
    entries; otherwise returns ``node.kwargs[arg_name]`` (a missing key
    propagates whatever lookup error ``node.kwargs`` raises).
    """
    positional = node.args
    if idx < len(positional):
        return positional[idx]
    return node.kwargs[arg_name]
|
|
|
|
|
|
def is_call_function(node: Node, target: Union[List[Callable], Callable]):
    """Tell whether ``node`` is a ``call_function`` node for ``target``.

    ``target`` may be a single callable, matched with ``==`` against
    ``node.target``, or a ``list`` of callables, matched with ``in``.
    Nodes whose ``op`` is not ``"call_function"`` never match.
    """
    if node.op != "call_function":
        return False
    if isinstance(target, list):
        return node.target in target
    return node.target == target
|
|
|
|
|
|
# Module-level flag, off by default. It is written by
# set_capture_piecewise_cuda_graph_flag() and read by
# get_capture_piecewise_cuda_graph_flag().
_enable_piecewise_cuda_graph_capture = False
|
|
|
|
|
|
def set_capture_piecewise_cuda_graph_flag(enable: bool):
    """Store ``enable`` in the module-level piecewise CUDA graph capture flag."""
    global _enable_piecewise_cuda_graph_capture
    _enable_piecewise_cuda_graph_capture = enable
|
|
|
|
|
|
def get_capture_piecewise_cuda_graph_flag() -> bool:
    """Return the current module-level piecewise CUDA graph capture flag."""
    # Reading a module global needs no ``global`` declaration.
    return _enable_piecewise_cuda_graph_capture
|
|
|
|
|
|
@contextlib.contextmanager
def capture_piecewise_cuda_graph(enable: bool):
    """Temporarily set the piecewise CUDA graph capture flag to ``enable``.

    The flag value seen on entry is remembered and written back on exit,
    including when the body of the ``with`` statement raises.
    """
    saved_flag = get_capture_piecewise_cuda_graph_flag()
    set_capture_piecewise_cuda_graph_flag(enable)
    try:
        yield
    finally:
        # Always restore, so nested or failing scopes do not leak state.
        set_capture_piecewise_cuda_graph_flag(saved_flag)
|
|
|
|
|
|
def inplace_info():
    """Build the lookup table of ``trtllm`` ops and their listed arguments.

    Returns a fresh dict on every call. Keys are ``torch.ops.trtllm.*``
    ``.default`` overloads; each value maps an integer to a string.
    NOTE(review): the integers look like argument positions and the strings
    like the names of arguments the op mutates in place -- confirm against
    the consumers of this table.
    """
    trtllm_ops = torch.ops.trtllm
    return {
        trtllm_ops.flashinfer_fused_add_rmsnorm.default: {
            1: "input",
            2: "residual",
        },
        trtllm_ops.attn_custom_op_inplace.default: {1: "output"},
        trtllm_ops.mla_custom_op_inplace.default: {1: "output"},
        trtllm_ops.fused_qk_norm_rope.default: {1: "qkv"},
        trtllm_ops.flashinfer_apply_rope_with_cos_sin_cache_inplace.default: {
            1: "query",
            2: "key",
        },
        trtllm_ops.logits_bitmask.default: {1: "logits"},
        trtllm_ops.moe_output_memset_inplace.default: {1: "input"},
        trtllm_ops.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell.default: {
            6: "output",
        },
        trtllm_ops.pp_recv_tensors.default: {1: "tensors"},
        trtllm_ops.pp_send_tensors.default: {1: "tensors"},
    }
|