TensorRT-LLMs/tests/_torch/pattern_watcher.py
Sharan Chetlur 258c7540c0 open source 09df54c0cc99354a60bbc0303e3e8ea33a96bef0 (#2725)
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

open source f8c0381a2bc50ee2739c3d8c2be481b31e5f00bd (#2736)

Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>

Add note for blackwell (#2742)

Update the docs to workaround the extra-index-url issue (#2744)

update README.md (#2751)

Fix github io pages (#2761)

Update
2025-02-11 02:21:51 +00:00

37 lines
1.1 KiB
Python

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (PatternPrettyPrinter, fwd_only,
gen_pattern)
import tensorrt_llm
import tensorrt_llm._torch
import tensorrt_llm._torch.modules
import tensorrt_llm._torch.modules.rms_norm
# Module-level RMSNorm (hidden_size=3, eps=1e-5, float16), moved to CUDA via
# `.cuda()`. It is not referenced again anywhere in this script.
# NOTE(review): presumably kept only for a side effect of constructing it
# (e.g. CUDA initialization or op registration) — confirm before removing.
norm = tensorrt_llm._torch.modules.rms_norm.RMSNorm(
    dtype=torch.float16,
    eps=1e-5,
    hidden_size=3,
).cuda()
def source_pattern(x: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, eps: float):
at = auto_functionalized(
torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default,
input=x,
residual=residual,
weight=weight,
eps=eps)
return at[1], at[2]
# Example inputs for tracing `source_pattern`. The trailing dimension matches
# the hidden_size=3 passed to the RMSNorm constructed above.
# NOTE(review): presumably only shape/dtype/device matter for the traced
# pattern, which is why uninitialized `torch.empty` storage is used — confirm.
hidden_size = 3
# Allocate directly as float16 on the CUDA device. This avoids the
# original's fp32 CPU allocation, host-to-device copy and cast.
x = torch.empty((1, hidden_size), dtype=torch.float16, device="cuda")
res = x.clone()
weight = torch.empty((hidden_size, ), dtype=torch.float16, device="cuda")
eps = 1e-5

# Trace the search function into an inductor PatternExpr and print it as
# Python source. `PatternPrettyPrinter.run` is a static method invoked on the
# class, so no printer instance is needed.
pattern = gen_pattern(source_pattern, [x, res, weight, eps], fwd_only)
print(PatternPrettyPrinter.run(pattern))