mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> open source f8c0381a2bc50ee2739c3d8c2be481b31e5f00bd (#2736) Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com> Add note for blackwell (#2742) Update the docs to workaround the extra-index-url issue (#2744) update README.md (#2751) Fix github io pages (#2761) Update
37 lines
1.1 KiB
Python
import torch
|
|
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
|
from torch._inductor.pattern_matcher import (PatternPrettyPrinter, fwd_only,
|
|
gen_pattern)
|
|
|
|
import tensorrt_llm
|
|
import tensorrt_llm._torch
|
|
import tensorrt_llm._torch.modules
|
|
import tensorrt_llm._torch.modules.rms_norm
|
|
|
|
# Build an RMSNorm with hidden_size=3 and dtype=torch.float16, then move it to
# the current CUDA device via `.cuda()`.
# NOTE(review): `norm` is not referenced anywhere else in the visible file; it
# may exist only for a construction-time side effect (e.g. making sure the
# trtllm custom ops are available) -- confirm before removing it.
norm = tensorrt_llm._torch.modules.rms_norm.RMSNorm(
    hidden_size=3,
    eps=1e-5,
    dtype=torch.float16,
).cuda()
|
|
|
|
|
|
def source_pattern(x: torch.Tensor, residual: torch.Tensor,
                   weight: torch.Tensor, eps: float):
    """Search pattern: one auto-functionalized fused add + RMSNorm call.

    The custom op ``torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default`` is
    invoked through ``auto_functionalized`` rather than directly, so the traced
    graph contains the functionalized form of the op.

    Args:
        x: Forwarded to the op as its ``input`` keyword argument.
        residual: Forwarded to the op as its ``residual`` keyword argument.
        weight: Forwarded to the op as its ``weight`` keyword argument.
        eps: Forwarded to the op as its ``eps`` keyword argument.

    Returns:
        A 2-tuple holding elements 1 and 2 of the ``auto_functionalized``
        result. NOTE(review): by PyTorch's ``auto_functionalized`` convention
        element 0 is the op's own output and the remaining elements are the
        new values of the mutated arguments, so these are presumably the
        updated ``input`` and ``residual`` -- confirm against the op schema.
    """
    functionalized = auto_functionalized(
        torch.ops.trtllm.flashinfer_fused_add_rmsnorm.default,
        input=x,
        residual=residual,
        weight=weight,
        eps=eps,
    )
    updated_input = functionalized[1]
    updated_residual = functionalized[2]
    return updated_input, updated_residual
|
|
|
|
|
|
# Example inputs used to trace `source_pattern`. The hidden size of 3 matches
# the `hidden_size=3` RMSNorm configured earlier in this file.
# torch.empty leaves the contents uninitialized. NOTE(review): only shape,
# dtype and device are expected to matter for the generated pattern -- confirm
# if the traced op ever becomes data-dependent.
# Allocate directly on the GPU in fp16 rather than building an fp32 CPU tensor
# and then copying (`.cuda()`) and casting (`.half()`) it.
x = torch.empty((1, 3), device="cuda", dtype=torch.float16)
res = x.clone()
weight = torch.empty((3,), device="cuda", dtype=torch.float16)
eps = 1e-5

# Trace `source_pattern` in forward-only mode and turn it into an Inductor
# pattern.
pattern = gen_pattern(source_pattern, [x, res, weight, eps], fwd_only)

# `run` is called on the class itself, so no PatternPrettyPrinter instance
# is needed.
print(PatternPrettyPrinter.run(pattern))
|