TensorRT-LLMs/triton_kernels/specialize.py
Anish Shanbhag 24ac86c485
[https://nvbugs/5761391][fix] Include triton-kernels as a packaged dependency (#10471)
Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
2026-01-28 19:56:32 -08:00

140 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/specialize.py
# Triton is licensed under the MIT License.
import inspect
import re
import textwrap
import types
import triton
def cacheable(f):
    """
    Decorator for factories that build a kernel dynamically, e.g.:

        @cacheable
        def my_kernel(): return (expression dynamically defining a kernel)

    The factory `f` is called once, immediately. The object it returns, and
    the function stored on its `fn` attribute, are both relabelled with the
    name, module and qualified name of `f`, and `_fn_name` is set to
    "<module>.<qualname>", so that the result interacts gracefully with the
    triton cache and preload.
    """
    kernel = f()
    # Relabel the wrapped function first, then the wrapper itself.
    for target in (kernel.fn, kernel):
        for attr in ("__name__", "__module__", "__qualname__"):
            setattr(target, attr, getattr(f, attr))
    kernel._fn_name = ".".join((f.__module__, f.__qualname__))
    return kernel
def define_kernel(src, module, attrs=None, **extra_globals):
    """
    Dynamically create a Triton function or kernel from a src string,
    linking any symbols in the kernel to objects specified by extra_globals.

    Args:
        src: Python source containing a function definition. Everything
            before the first line that starts with ``def`` (for instance
            decorator lines) is discarded. The source is executed with
            ``exec``, so it must come from a trusted origin.
        module: object whose ``__name__`` is used as the ``__module__`` of
            the generated function.
        attrs: optional dict of keyword arguments forwarded to
            ``triton.JITFunction``; ``None`` means no extra arguments.
        **extra_globals: extra names added to the global namespace the
            generated function is linked against.

    Returns:
        The ``triton.JITFunction`` built from ``src``.

    Raises:
        ValueError: if ``src`` contains no ``def`` statement.
    """
    # create template function
    def _empty_fn():
        pass

    # Work on a copy of this module's globals so extra symbols stay local
    # to the generated kernel.
    gdict = dict(_empty_fn.__globals__)
    gdict.update(extra_globals)
    f = types.FunctionType(_empty_fn.__code__, gdict)
    f.__module__ = module.__name__
    src = textwrap.dedent(src)
    # Anchor on a line that starts with `def`, so "def " inside an earlier
    # comment or decorator argument is not mistaken for the definition.
    def_match = re.search(r"^[ \t]*(def[ \t]+(\w+))", src, re.MULTILINE)
    if def_match is None:
        raise ValueError("Could not find a function definition in src")
    src = src[def_match.start(1):]
    function_name = def_match.group(2)
    # Executing the definition binds `function_name` inside gdict; read it
    # back from there instead of adding a helper list to the kernel globals.
    exec(src, gdict)
    compiled_fn = gdict[function_name]
    f.__signature__ = inspect.signature(compiled_fn)
    f.__name__ = function_name
    f.__doc__ = compiled_fn.__doc__
    if attrs is None:
        attrs = {}
    f = triton.JITFunction(f, **attrs)
    f._unsafe_update_src(src)
    return f
def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
    """
    Create a new kernel from `fn` by rewriting its source code.

    The source of `fn.fn` is fetched with `inspect.getsource`, its signature
    is rebuilt, and extra assignment lines are inserted at the top of the
    body before the result is handed to `define_kernel`.

    Args:
        fn: the `triton.runtime.jit.JITFunction` to specialize.
        module: forwarded to `define_kernel`.
        constants: dict mapping parameter names to values. Each such
            parameter is removed from the signature and a line
            `<name>: tl.constexpr = <value>` is inserted in the body.
            Callable values are written as their `__name__`; values that are
            JITFunctions are also exposed to the new kernel under that name.
        tuples: dict mapping parameter names to sequences of strings. Each
            such parameter is replaced in the signature by the listed
            entries, and a line `<name> = (<entries>,)` is inserted in the
            body.
        name: name of the generated kernel; defaults to `fn.__name__`.
        do_not_specialize: when non-empty, overrides the `do_not_specialize`
            attribute passed to the new JITFunction.

    Returns:
        The JITFunction returned by `define_kernel`.

    Raises:
        ValueError: if the function header cannot be located or parsed.

    Note:
        Header parsing is textual: arguments are split on "," and comments
        are cut at the first "#", so defaults or annotations containing
        those characters are not supported.
    """
    assert isinstance(fn, triton.runtime.jit.JITFunction)
    if name is None:
        name = f"{fn.__name__}"
    # Get original source code
    src = inspect.getsource(fn.fn)
    src = textwrap.dedent(src)
    lines = src.split("\n")

    def _code(line):
        # keep code, discard trailing comment
        return line.split("#", 1)[0].strip()

    # Skip decorators: the header starts at the first `def` statement.
    # Require whitespace after `def` so identifiers such as `default` are
    # not mistaken for it.
    def_idx = next((i for i, line in enumerate(lines) if re.match(r"\s*def\s", line)), None)
    if def_idx is None:
        raise ValueError("Could not parse function header")
    # separate header vs body LOC
    # The header ends on the first line whose code part ends with ":".
    # Comments are stripped first so `def f(a):  # note` is handled.
    header_end = def_idx
    while header_end < len(lines) and not _code(lines[header_end]).endswith(":"):
        header_end += 1
    if header_end == len(lines):
        raise ValueError("Could not parse function header")
    body_lines = lines[header_end + 1:]
    header_lines = lines[def_idx:header_end + 1]
    # clean-up header: skip lines that are blank once the comment is removed
    header_clean = [code for code in map(_code, header_lines) if code]
    # decompose arguments
    header_src = " ".join(header_clean)  # turn it into a single line
    # An optional return annotation (`-> T`) may sit between ")" and ":".
    m = re.search(r"\((.*)\)\s*(?:->.*)?:", header_src)
    if not m:
        raise ValueError("Could not parse function header")
    args_str = m.group(1)
    args = [arg.strip() for arg in args_str.split(",") if arg.strip()]
    non_specialized_args = []
    for arg in args:
        arg_key = arg.split(":")[0].split("=")[0].strip()
        if arg_key in constants:
            # specialized away: becomes a constexpr assignment in the body
            continue
        # tuple parameters are expanded into their individual entries
        non_specialized_args.extend(tuples.get(arg_key, [arg]))
    # add global symbols
    spec_fns = {v.__name__: v for v in constants.values() if isinstance(v, triton.runtime.jit.JITFunction)}
    kernel_globals = spec_fns | fn.get_capture_scope()
    # Inserted lines must use exactly the indentation of the existing body,
    # otherwise the generated source does not compile. Take it from the first
    # line of the body that holds code; fall back to four spaces.
    body_indent = next(
        (
            line[:len(line) - len(line.lstrip())]
            for line in body_lines
            if line.strip() and not line.lstrip().startswith("#")
        ),
        "    ",
    )
    # build new source code and define kernel dynamically
    new_signature = f"def {name}({', '.join(non_specialized_args)}):"
    constexpr_lines = [
        f"{body_indent}{key}: tl.constexpr = {value.__name__ if callable(value) else value}"
        for key, value in constants.items()
    ]
    tuple_lines = [
        f"{body_indent}{key} = ({','.join(value)}{',' if len(value) >= 1 else ''})"
        for key, value in tuples.items()
    ]
    new_src = "\n".join(["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines)
    # find function parameters: every JITFunction.__init__ parameter after
    # the first two is copied from `fn`, falling back to the default
    sig = inspect.signature(triton.runtime.jit.JITFunction.__init__)
    params = list(sig.parameters.values())[2:]
    attrs = {param.name: getattr(fn, param.name, param.default) for param in params}
    # make a new repr which appends the repr of the specialized functions.
    base_repr = attrs["repr"]

    def new_repr(specialization):
        ret = base_repr(specialization)
        for spec_fn in spec_fns.values():
            spec_repr = spec_fn.repr(None)
            if spec_repr:
                spec_repr = spec_repr.strip("_")
            if spec_repr:
                ret += f"_{spec_repr}"
        return ret

    attrs["repr"] = new_repr
    if do_not_specialize:
        attrs["do_not_specialize"] = do_not_specialize
    ret = define_kernel(new_src, module, attrs, **kernel_globals)
    return ret