TensorRT-LLMs/scripts/vendor_triton_kernels.py
Anish Shanbhag 24ac86c485
[https://nvbugs/5761391][fix] Include triton-kernels as a packaged dependency (#10471)
Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
2026-01-28 19:56:32 -08:00

201 lines
6.6 KiB
Python
Executable File

#!/usr/bin/env python3
"""Script to vendor triton-kernels into TensorRT-LLM.
This script:
1. Clones the Triton repo at a specific tag to a temp directory
2. Copies the triton_kernels module to the repo root as a top-level package, adding any missing __init__.py files
3. Adds attribution headers to all Python files
4. Copies the LICENSE file from Triton
5. Creates a VERSION file to track the vendored version
6. Creates a README.md with clear copyright attribution
To update to a new version:
python scripts/vendor_triton_kernels.py --tag v3.6.0
"""
import argparse
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
# Root of the TensorRT-LLM checkout (this script is run as scripts/vendor_triton_kernels.py).
REPO_ROOT = Path(__file__).parent.parent.resolve()
# Upstream repository, and where the triton_kernels package lives inside it.
TRITON_REPO_URL = "https://github.com/triton-lang/triton.git"
TRITON_KERNELS_MODULE_PATH = "python/triton_kernels/triton_kernels"
# The vendored copy is installed as a top-level package at the repo root.
DEST_PATH = REPO_ROOT / "triton_kernels"
# Marker line written into every vendored header and every generated __init__.py.
VENDORED_NOTICE = "# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY."
# str.format() template: only the first piece is an f-string, so {tag} and
# {original_file} stay literal placeholders for callers to fill in later.
ATTRIBUTION_HEADER = (
    f"{VENDORED_NOTICE}\n"
    "# Source: https://github.com/triton-lang/triton/tree/{tag}/{original_file}\n"
    "# Triton is licensed under the MIT License.\n"
)
def clone_triton(tag: str, dest_dir: str) -> tuple[Path, Path]:
    """Shallow-clone the Triton repo at ``tag`` into ``dest_dir``.

    Args:
        tag: Git tag to check out (passed to ``git clone --branch``).
        dest_dir: Directory to clone into.

    Returns:
        ``(module_path, repo_root)``: the ``triton_kernels`` package directory
        inside the clone, and the root of the clone.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails (e.g. unknown tag,
            no network access). git's stderr is printed before re-raising.
        RuntimeError: if the clone does not contain the expected module path.
    """
    print(f"Cloning Triton repo at tag {tag}...")
    try:
        subprocess.run(
            ["git", "clone", "--depth", "1", "--branch", tag, TRITON_REPO_URL, dest_dir],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as err:
        # Output is captured, so without this the user would only see
        # "returned non-zero exit status ..." and not git's actual reason.
        print(f"git clone failed:\n{(err.stderr or '').strip()}", file=sys.stderr)
        raise
    repo_root = Path(dest_dir)
    triton_kernels_module_path = repo_root / TRITON_KERNELS_MODULE_PATH
    if not triton_kernels_module_path.exists():
        raise RuntimeError(f"triton_kernels module not found at {triton_kernels_module_path}")
    return triton_kernels_module_path, repo_root
def add_attribution_header(file_path: Path, tag: str, original_rel_path: str) -> None:
    """Insert the vendoring attribution header into a Python file, in place.

    A leading run of shebang (``#!``) and encoding-declaration lines is kept
    at the very top, and the header goes right after it. If the rest of the
    file is non-blank, a blank line separates the header from it.

    Args:
        file_path: Python file to rewrite.
        tag: Triton git tag the file came from (used in the source URL).
        original_rel_path: File path relative to the Triton repo root (used in
            the source URL).
    """
    # Explicit UTF-8: Python source defaults to UTF-8 (PEP 3120), but the
    # locale's default encoding may differ and break on non-ASCII characters.
    content = file_path.read_text(encoding="utf-8")
    lines = content.split("\n")

    # A shebang only works on line 1 and an encoding declaration only on
    # line 1 or 2 (PEP 263), so those lines must stay ahead of the header.
    insert_pos = 0
    for line in lines:
        if not line.startswith(("#!", "# -*-", "# coding")):
            break
        insert_pos += 1
    preserved_lines = lines[:insert_pos]

    header = ATTRIBUTION_HEADER.format(tag=tag, original_file=original_rel_path)
    new_content = "\n".join(preserved_lines)
    if preserved_lines:
        new_content += "\n"
    new_content += header

    # Add a blank line between header and content only if there is content.
    remaining_content = "\n".join(lines[insert_pos:])
    if remaining_content.strip():
        new_content += "\n"
    new_content += remaining_content
    file_path.write_text(new_content, encoding="utf-8")
def copy_triton_kernels(src_path: Path, dest_path: Path, tag: str) -> list[str]:
    """Copy the triton_kernels package to ``dest_path`` and add attribution headers.

    Any existing ``dest_path`` is deleted first, then replaced by a copy of
    ``src_path``. Every copied ``.py`` file gets the attribution header. Every
    directory, including ``dest_path`` itself, that lacks an ``__init__.py``
    gets one containing only the vendored notice.

    Args:
        src_path: The ``triton_kernels`` package directory inside a Triton clone.
        dest_path: Destination package directory (replaced wholesale).
        tag: Triton git tag, recorded in each header's source URL.

    Returns:
        Sorted paths, relative to ``dest_path`` and as strings, of the copied
        Python files that received a header. Generated ``__init__.py`` files
        are not included.
    """
    print(f"Copying triton_kernels to {dest_path}...")
    if dest_path.exists():
        print(f" Removing existing {dest_path}")
        shutil.rmtree(dest_path)
    # Bytecode caches are never part of the vendored source. Skipping them also
    # keeps __pycache__ from being turned into a "package" below.
    shutil.copytree(src_path, dest_path, ignore=shutil.ignore_patterns("__pycache__", "*.pyc"))

    # Add attribution headers to all copied Python files, in a deterministic
    # order so the returned list is reproducible.
    python_files = []
    for py_file in sorted(dest_path.rglob("*.py")):
        rel_path = py_file.relative_to(dest_path)
        # as_posix(): this becomes part of a URL, so it needs forward slashes
        # on every platform.
        original_rel_path = f"{TRITON_KERNELS_MODULE_PATH}/{rel_path.as_posix()}"
        python_files.append(str(rel_path))
        add_attribution_header(py_file, tag, original_rel_path)

    # Create __init__.py files in directories that don't have them.
    # Triton's upstream code relies on implicit namespace packages (PEP 420), but
    # setuptools' find_packages() requires __init__.py to discover subpackages.
    # The directory list is materialized before any files are written so the
    # tree is not mutated mid-walk. dest_path itself is included because
    # rglob("*") does not yield it.
    package_dirs = [dest_path, *sorted(p for p in dest_path.rglob("*") if p.is_dir())]
    for package_dir in package_dirs:
        init_file = package_dir / "__init__.py"
        if not init_file.exists():
            print(f" Creating {init_file.relative_to(dest_path)}")
            init_file.write_text(f"{VENDORED_NOTICE}\n", encoding="utf-8")
    print(f" Copied triton_kernels module to {dest_path}")
    return python_files
def copy_license(triton_repo_root: Path, dest_path: Path) -> None:
    """Put a copy of Triton's top-level ``LICENSE`` file into ``dest_path``.

    Uses ``shutil.copy2``, which copies file metadata along with the contents.
    """
    print("Copying LICENSE file...")
    target = dest_path / "LICENSE"
    shutil.copy2(triton_repo_root / "LICENSE", target)
    print(f" Copied LICENSE to {target}")
def create_version_file(dest_path: Path, tag: str) -> None:
    """Write ``dest_path/VERSION`` recording which Triton tag was vendored.

    The tag is on the first line, followed by comment lines explaining how to
    re-run the vendoring.
    """
    print("Creating VERSION file...")
    target = dest_path / "VERSION"
    body = "\n".join(
        [
            tag,
            "# This file tracks the version of triton-kernels that was vendored.",
            "# To update, run: python scripts/vendor_triton_kernels.py --tag <new-tag>",
            "",  # trailing newline
        ]
    )
    target.write_text(body)
    print(f" Created {target}")
def create_readme(dest_path: Path, tag: str) -> None:
    """Write ``dest_path/README.md`` attributing the vendored code to Triton.

    The README names the copyright holder and license, links to the upstream
    source tree for ``tag``, and explains how to re-vendor.
    """
    print("Creating README.md...")
    readme_file = dest_path / "README.md"
    # Blank lines between Markdown blocks are deliberate. Without the one
    # before the table, renderers that don't let a table interrupt a paragraph
    # merge the table rows into the intro paragraph.
    readme_content = f"""# Vendored triton_kernels

This directory contains code vendored from the [Triton](https://github.com/triton-lang/triton) project.

| | |
|---|---|
| **Copyright** | The Triton Authors |
| **License** | MIT (see [LICENSE](LICENSE) file in this directory) |
| **Source** | https://github.com/triton-lang/triton/tree/{tag}/python/triton_kernels/triton_kernels |
| **Version** | `{tag}` |

## Attribution

This code is the work of the Triton authors and is included here under the MIT License.
Each upstream Python file includes an attribution header indicating its origin.

## Do Not Edit

This code is vendored verbatim (apart from the added attribution headers and any missing
`__init__.py` files) and should not be modified directly.
To update to a newer version, run:

```bash
python scripts/vendor_triton_kernels.py --tag <new-tag>
```
"""
    readme_file.write_text(readme_content, encoding="utf-8")
    print(f" Created {readme_file}")
def main():
    """Command-line entry point: vendor triton_kernels from the Triton tag given by ``--tag``."""
    parser = argparse.ArgumentParser(description="Vendor triton-kernels into TensorRT-LLM")
    parser.add_argument(
        "--tag",
        required=True,
        help="Triton git tag to vendor from. See the list of tags at https://github.com/triton-lang/triton/tags",
    )
    tag = parser.parse_args().tag

    print(f"Vendoring triton-kernels from Triton {tag}")
    print(f"Destination: {DEST_PATH}")

    # The clone is only needed while files are copied out of it; the temporary
    # directory is removed when the with-block exits.
    with tempfile.TemporaryDirectory() as scratch_dir:
        kernels_src, checkout_root = clone_triton(tag, scratch_dir)
        copy_triton_kernels(kernels_src, DEST_PATH, tag)
        copy_license(checkout_root, DEST_PATH)

    # These two files depend only on the tag, not on the clone.
    create_version_file(DEST_PATH, tag)
    create_readme(DEST_PATH, tag)

    print("SUCCESS: triton-kernels has been vendored.")


if __name__ == "__main__":
    main()