#!/usr/bin/env python3
"""Script to vendor triton-kernels into TensorRT-LLM.

This script:
1. Clones the Triton repo at a specific tag to a temp directory
2. Copies the triton_kernels module to the repo root as a top-level package
3. Adds attribution headers to all Python files
4. Copies the LICENSE file from Triton
5. Creates a VERSION file to track the vendored version
6. Creates a README.md with clear copyright attribution

To update to a new version:
    python scripts/vendor_triton_kernels.py --tag v3.6.0
"""

import argparse
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path

REPO_ROOT = Path(__file__).parent.parent.resolve()
TRITON_REPO_URL = "https://github.com/triton-lang/triton.git"
TRITON_KERNELS_MODULE_PATH = "python/triton_kernels/triton_kernels"
DEST_PATH = REPO_ROOT / "triton_kernels"

# Vendored files are always read/written as UTF-8 regardless of the host locale,
# so non-ASCII characters in upstream sources survive the rewrite.
FILE_ENCODING = "utf-8"

VENDORED_NOTICE = "# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY."

# The doubled braces survive the f-string as ``{tag}`` / ``{original_file}``
# placeholders, which are filled in later with ``str.format``.
ATTRIBUTION_HEADER = f"""\
{VENDORED_NOTICE}
# Source: https://github.com/triton-lang/triton/tree/{{tag}}/{{original_file}}
# Triton is licensed under the MIT License.
"""

# Leading lines that must stay first in a Python file (shebang and PEP 263
# encoding declarations); the attribution header is inserted after them.
_PRESERVED_PREFIXES = ("#!", "# -*-", "# coding")

# Command shown in the "how to update" instructions written to VERSION/README.
_UPDATE_COMMAND = "python scripts/vendor_triton_kernels.py --tag <tag>"


def clone_triton(tag: str, dest_dir: str) -> tuple[Path, Path]:
    """Clone the Triton repo at the specified tag.

    Performs a shallow (``--depth 1``) clone of ``TRITON_REPO_URL`` at ``tag``
    into ``dest_dir``.

    Returns:
        ``(module_path, repo_root)`` where ``module_path`` is the
        triton_kernels package directory inside the clone.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails. git's stderr is
            echoed to ``sys.stderr`` first because it is captured and would
            otherwise never reach the user.
        RuntimeError: if the clone does not contain the triton_kernels module.
    """
    print(f"Cloning Triton repo at tag {tag}...")
    try:
        subprocess.run(
            ["git", "clone", "--depth", "1", "--branch", tag, TRITON_REPO_URL, dest_dir],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as err:
        # Output was captured, so surface git's explanation (e.g. unknown tag)
        # before propagating the original error unchanged.
        print(f"git clone failed:\n{err.stderr}", file=sys.stderr)
        raise

    repo_root = Path(dest_dir)
    triton_kernels_module_path = repo_root / TRITON_KERNELS_MODULE_PATH
    if not triton_kernels_module_path.exists():
        raise RuntimeError(f"triton_kernels module not found at {triton_kernels_module_path}")
    return triton_kernels_module_path, repo_root


def add_attribution_header(file_path: Path, tag: str, original_rel_path: str) -> None:
    """Insert the vendored-code attribution header into ``file_path`` in place.

    Leading shebang / encoding-declaration lines are kept at the top of the
    file and the header is inserted right after them. A blank line separates
    the header from the remaining content when that content is not blank.

    Args:
        file_path: Python file to rewrite.
        tag: Triton git tag the file was vendored from (used in the source URL).
        original_rel_path: POSIX-style path of the file inside the Triton repo.
    """
    content = file_path.read_text(encoding=FILE_ENCODING)
    lines = content.split("\n")

    # Count the leading lines that must remain first in the file.
    insert_pos = 0
    for line in lines:
        if not line.startswith(_PRESERVED_PREFIXES):
            break
        insert_pos += 1
    preserved_lines = lines[:insert_pos]

    header = ATTRIBUTION_HEADER.format(tag=tag, original_file=original_rel_path)
    new_content = "\n".join(preserved_lines)
    if preserved_lines:
        new_content += "\n"
    new_content += header

    # Add blank line between header and content if file has content
    remaining_content = "\n".join(lines[insert_pos:])
    if remaining_content.strip():
        new_content += "\n"
    new_content += remaining_content

    file_path.write_text(new_content, encoding=FILE_ENCODING)


def copy_triton_kernels(src_path: Path, dest_path: Path, tag: str) -> list[str]:
    """Copy triton_kernels module to destination and add attribution headers.

    An existing ``dest_path`` is removed first so files from a previous
    vendoring do not linger. Python bytecode caches are not copied.

    Returns:
        Paths (relative to ``dest_path``, sorted) of the upstream Python files
        that received an attribution header. Generated ``__init__.py`` files
        are not included.
    """
    print(f"Copying triton_kernels to {dest_path}...")
    if dest_path.exists():
        print(f" Removing existing {dest_path}")
        shutil.rmtree(dest_path)
    shutil.copytree(src_path, dest_path, ignore=shutil.ignore_patterns("__pycache__", "*.pyc"))

    # Add attribution headers to all existing Python files
    python_files = []
    for py_file in sorted(dest_path.rglob("*.py")):
        rel_path = py_file.relative_to(dest_path)
        # as_posix() keeps the GitHub URL valid even when run on Windows.
        original_rel_path = f"{TRITON_KERNELS_MODULE_PATH}/{rel_path.as_posix()}"
        python_files.append(str(rel_path))
        add_attribution_header(py_file, tag, original_rel_path)

    # Create __init__.py files in the package root and every subdir that lacks one.
    # Triton's upstream code relies on implicit namespace packages (PEP 420), but
    # setuptools' find_packages() requires __init__.py to discover (sub)packages.
    # The directory list is built before any file is created so the tree is not
    # modified while it is being traversed.
    package_dirs = [dest_path, *sorted(p for p in dest_path.rglob("*") if p.is_dir())]
    for package_dir in package_dirs:
        init_file = package_dir / "__init__.py"
        if not init_file.exists():
            print(f" Creating {init_file.relative_to(dest_path)}")
            init_file.write_text(f"{VENDORED_NOTICE}\n", encoding=FILE_ENCODING)

    print(f" Copied triton_kernels module to {dest_path}")
    return python_files


def copy_license(triton_repo_root: Path, dest_path: Path) -> None:
    """Copy the Triton LICENSE file from the clone root into ``dest_path``."""
    print("Copying LICENSE file...")
    license_src = triton_repo_root / "LICENSE"
    license_dest = dest_path / "LICENSE"
    shutil.copy2(license_src, license_dest)
    print(f" Copied LICENSE to {license_dest}")


def create_version_file(dest_path: Path, tag: str) -> None:
    """Create a VERSION file to track which version was vendored.

    The first line is the bare tag; the comment lines after it explain how to
    update the vendored copy.
    """
    print("Creating VERSION file...")
    version_file = dest_path / "VERSION"
    version_content = f"""{tag}
# This file tracks the version of triton-kernels that was vendored.
# To update, run: {_UPDATE_COMMAND}
"""
    version_file.write_text(version_content, encoding=FILE_ENCODING)
    print(f" Created {version_file}")


def create_readme(dest_path: Path, tag: str) -> None:
    """Create a README.md with clear copyright attribution."""
    print("Creating README.md...")
    readme_file = dest_path / "README.md"
    source_url = f"https://github.com/triton-lang/triton/tree/{tag}/{TRITON_KERNELS_MODULE_PATH}"
    readme_content = f"""# Vendored triton_kernels

This directory contains code vendored from the [Triton](https://github.com/triton-lang/triton) project.

| | |
|---|---|
| **Copyright** | The Triton Authors |
| **License** | MIT (see [LICENSE](LICENSE) file in this directory) |
| **Source** | {source_url} |
| **Version** | `{tag}` |

## Attribution

This code is the work of the Triton authors and is included here under the MIT License.
Each Python file includes an attribution header indicating its origin.

## Do Not Edit

This code is vendored verbatim and should not be modified directly.
To update to a newer version, run:

```bash
{_UPDATE_COMMAND}
```
"""
    readme_file.write_text(readme_content, encoding=FILE_ENCODING)
    print(f" Created {readme_file}")


def main():
    """Parse ``--tag`` and vendor triton_kernels from that Triton tag into DEST_PATH."""
    parser = argparse.ArgumentParser(description="Vendor triton-kernels into TensorRT-LLM")
    parser.add_argument(
        "--tag",
        required=True,
        help="Triton git tag to vendor from. See the list of tags at https://github.com/triton-lang/triton/tags",
    )
    args = parser.parse_args()

    print(f"Vendoring triton-kernels from Triton {args.tag}")
    print(f"Destination: {DEST_PATH}")

    # Everything needed from the clone is copied out before the temporary
    # directory is deleted at the end of this context.
    with tempfile.TemporaryDirectory() as tmp_dir:
        triton_kernels_src, triton_repo_root = clone_triton(args.tag, tmp_dir)
        copy_triton_kernels(triton_kernels_src, DEST_PATH, args.tag)
        copy_license(triton_repo_root, DEST_PATH)
        create_version_file(DEST_PATH, args.tag)
        create_readme(DEST_PATH, args.tag)

    print("SUCCESS: triton-kernels has been vendored.")


if __name__ == "__main__":
    main()