TensorRT-LLMs/tensorrt_llm/_torch/compilation/utils.py
Enwei Zhu 1745102e72
[TRTLLM-7027][feat] Fuse d2t to logitsBitmaskKernel and fix a race condition in one-model spec (#7481)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Co-authored-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
2025-09-04 23:30:14 +08:00

82 lines
2.2 KiB
Python

import contextlib
from typing import Callable, List, Union
import torch
from torch.fx import Node
from torch.fx.experimental.symbolic_shapes import ShapeEnv
def get_symint_val(i: Union[torch.SymInt | int]):
if isinstance(i, int):
return i
elif isinstance(i, torch.SymInt):
node = i.node
expr = node.expr
shape_env: ShapeEnv = node.shape_env
var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace(
shape_env.var_to_val)
return var_val
else:
raise Exception("Only support int or torch.SymInt")
def get_arg(node, idx, arg_name):
return node.args[idx] if len(node.args) > idx else node.kwargs[arg_name]
def is_call_function(node: Node, target: Union[List[Callable], Callable]):
if isinstance(target, list):
return node.op == "call_function" and node.target in target
else:
return node.op == "call_function" and node.target == target
_enable_piecewise_cuda_graph_capture = False
def set_capture_piecewise_cuda_graph_flag(enable: bool):
global _enable_piecewise_cuda_graph_capture
_enable_piecewise_cuda_graph_capture = enable
def get_capture_piecewise_cuda_graph_flag() -> bool:
global _enable_piecewise_cuda_graph_capture
return _enable_piecewise_cuda_graph_capture
@contextlib.contextmanager
def capture_piecewise_cuda_graph(enable: bool):
prev_enable = get_capture_piecewise_cuda_graph_flag()
set_capture_piecewise_cuda_graph_flag(enable)
try:
yield
finally:
set_capture_piecewise_cuda_graph_flag(prev_enable)
def inplace_info():
inplace_map = {
torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default: {
1: "input",
2: "residual"
},
torch.ops.trtllm.attn_custom_op_inplace.default: {
1: "output",
},
torch.ops.trtllm.mla_custom_op_inplace.default: {
1: "output"
},
torch.ops.trtllm.fused_qk_norm_rope.default: {
1: "qkv"
},
torch.ops.trtllm.flashinfer_apply_rope_with_cos_sin_cache_inplace.default:
{
1: "query",
2: "key"
},
torch.ops.trtllm.logits_bitmask.default: {
1: "logits"
}
}
return inplace_map