mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-05 02:31:33 +08:00
201 lines
6.6 KiB
Python
Executable File
201 lines
6.6 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Script to vendor triton-kernels into TensorRT-LLM.
|
|
|
|
This script:
|
|
1. Clones the Triton repo at a specific tag to a temp directory
|
|
2. Copies the triton_kernels module to the repo root as a top-level package, adding any missing __init__.py files
|
|
3. Adds attribution headers to all Python files
|
|
4. Copies the LICENSE file from Triton
|
|
5. Creates a VERSION file to track the vendored version
|
|
6. Creates a README.md with clear copyright attribution
|
|
|
|
To update to a new version:
|
|
python scripts/vendor_triton_kernels.py --tag v3.6.0
|
|
"""
|
|
|
|
import argparse
|
|
import shutil
|
|
import subprocess
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
# Resolve the script path *before* walking up to the parent directories, so a
# symlinked script (or a relative ``__file__`` on older interpreters) still
# yields the real repository root.
REPO_ROOT = Path(__file__).resolve().parent.parent
TRITON_REPO_URL = "https://github.com/triton-lang/triton.git"
# Location of the importable ``triton_kernels`` package inside a Triton checkout.
TRITON_KERNELS_MODULE_PATH = "python/triton_kernels/triton_kernels"
# The vendored copy lives at the repository root as a top-level package.
DEST_PATH = REPO_ROOT / "triton_kernels"

VENDORED_NOTICE = "# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY."
# Per-file header template, filled in later via ``str.format(tag=..., original_file=...)``.
# The doubled braces survive f-string evaluation and become str.format placeholders.
ATTRIBUTION_HEADER = f"""\
{VENDORED_NOTICE}
# Source: https://github.com/triton-lang/triton/tree/{{tag}}/{{original_file}}
# Triton is licensed under the MIT License.
"""
|
|
|
|
|
|
def clone_triton(tag: str, dest_dir: str) -> tuple[Path, Path]:
    """Shallow-clone the Triton repo at ``tag`` into ``dest_dir``.

    Args:
        tag: Git tag (or branch) name passed to ``git clone --branch``.
        dest_dir: Directory to clone into.

    Returns:
        ``(module_path, repo_root)``, where ``module_path`` is the
        ``triton_kernels`` package directory inside the clone.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails (e.g. an unknown
            tag). git's captured stderr is printed before re-raising.
        RuntimeError: if the clone does not contain the expected module path.
    """
    import sys

    print(f"Cloning Triton repo at tag {tag}...")

    try:
        subprocess.run(
            ["git", "clone", "--depth", "1", "--branch", tag, TRITON_REPO_URL, dest_dir],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as err:
        # capture_output hides git's diagnostics, and the exception message only
        # reports the exit status, so surface the real reason before re-raising.
        print(f"git clone failed:\n{err.stderr}", file=sys.stderr)
        raise

    repo_root = Path(dest_dir)
    triton_kernels_module_path = repo_root / TRITON_KERNELS_MODULE_PATH
    # The module path is expected to be a package directory (it is later copied
    # with shutil.copytree).
    if not triton_kernels_module_path.is_dir():
        raise RuntimeError(f"triton_kernels module not found at {triton_kernels_module_path}")

    return triton_kernels_module_path, repo_root
|
|
|
|
|
|
def add_attribution_header(file_path: Path, tag: str, original_rel_path: str) -> None:
    """Prepend the vendoring attribution header to ``file_path`` in place.

    Leading shebang (``#!``) and encoding-declaration lines are kept at the very
    top; the header is inserted directly after them. A blank line separates the
    header from the original code unless the file has no non-whitespace content.

    Args:
        file_path: Python file to rewrite.
        tag: Triton tag the file was vendored from (used in the source URL).
        original_rel_path: Path of the file inside the Triton repo, '/'-separated.
    """
    # Use UTF-8 explicitly (Python's default source encoding) instead of the
    # locale encoding, which may be unable to decode or encode upstream sources.
    content = file_path.read_text(encoding="utf-8")
    lines = content.split("\n")

    # A shebang only works on line 1, and PEP 263 encoding declarations only work
    # on the first two lines, so these leading lines must stay above the header.
    insert_pos = 0
    for line in lines:
        if not line.startswith(("#!", "# -*-", "# coding")):
            break
        insert_pos += 1
    preserved_lines = lines[:insert_pos]

    header = ATTRIBUTION_HEADER.format(tag=tag, original_file=original_rel_path)

    parts = []
    if preserved_lines:
        parts.append("\n".join(preserved_lines) + "\n")
    parts.append(header)

    remaining_content = "\n".join(lines[insert_pos:])
    # Add a blank line between header and content if the file has content.
    if remaining_content.strip():
        parts.append("\n")
    parts.append(remaining_content)

    file_path.write_text("".join(parts), encoding="utf-8")
|
|
|
|
|
|
def copy_triton_kernels(src_path: Path, dest_path: Path, tag: str) -> list[str]:
    """Copy the triton_kernels package to ``dest_path`` and add attribution headers.

    Any existing ``dest_path`` is deleted first. After copying, every upstream
    ``.py`` file gets an attribution header, and every package directory
    (including ``dest_path`` itself) lacking ``__init__.py`` gets one containing
    only the vendored notice.

    Args:
        src_path: The ``triton_kernels`` package directory inside a Triton checkout.
        dest_path: Where the vendored package is written.
        tag: Triton tag being vendored (recorded in each header's source URL).

    Returns:
        Sorted paths, relative to ``dest_path``, of the upstream Python files that
        received a header. Generated ``__init__.py`` files are not included.
    """
    print(f"Copying triton_kernels to {dest_path}...")

    if dest_path.exists():
        print(f" Removing existing {dest_path}")
        shutil.rmtree(dest_path)

    shutil.copytree(src_path, dest_path)

    # Add attribution headers to all existing Python files. Sorting makes the
    # processing order and the returned list deterministic across filesystems.
    python_files = []
    for py_file in sorted(dest_path.rglob("*.py")):
        rel_path = py_file.relative_to(dest_path)
        # as_posix(): this path becomes part of a GitHub URL, which needs '/'
        # separators even when the script runs on Windows.
        original_rel_path = f"{TRITON_KERNELS_MODULE_PATH}/{rel_path.as_posix()}"
        python_files.append(str(rel_path))
        add_attribution_header(py_file, tag, original_rel_path)

    # Create __init__.py files in directories that don't have them.
    # Triton's upstream code relies on implicit namespace packages (PEP 420), but
    # setuptools' find_packages() requires __init__.py to discover subpackages.
    # The directory list is materialized up front so we never create files while
    # rglob is still walking the tree; the package root itself is included too.
    package_dirs = sorted([dest_path, *(p for p in dest_path.rglob("*") if p.is_dir())])
    for package_dir in package_dirs:
        init_file = package_dir / "__init__.py"
        if not init_file.exists():
            print(f" Creating {init_file.relative_to(dest_path)}")
            init_file.write_text(f"{VENDORED_NOTICE}\n", encoding="utf-8")

    print(f" Copied triton_kernels module to {dest_path}")
    return python_files
|
|
|
|
|
|
def copy_license(triton_repo_root: Path, dest_path: Path) -> None:
    """Place Triton's top-level LICENSE file inside the vendored package directory.

    Args:
        triton_repo_root: Root of the Triton checkout; must contain ``LICENSE``.
        dest_path: Vendored package directory that receives the copy.
    """
    print("Copying LICENSE file...")

    target = dest_path / "LICENSE"
    # copy2 copies file contents and also carries over metadata such as mtime.
    shutil.copy2(triton_repo_root / "LICENSE", target)

    print(f" Copied LICENSE to {target}")
|
|
|
|
|
|
def create_version_file(dest_path: Path, tag: str) -> None:
    """Write ``dest_path/VERSION`` recording which Triton tag was vendored.

    The tag occupies the first line; the following lines are explanatory
    ``#`` comments telling maintainers how to refresh the vendored copy.
    """
    print("Creating VERSION file...")

    target = dest_path / "VERSION"
    body_lines = [
        tag,
        "# This file tracks the version of triton-kernels that was vendored.",
        "# To update, run: python scripts/vendor_triton_kernels.py --tag <new-tag>",
    ]
    target.write_text("\n".join(body_lines) + "\n")

    print(f" Created {target}")
|
|
|
|
|
|
def create_readme(dest_path: Path, tag: str) -> None:
    """Write ``dest_path/README.md`` with copyright, license and source attribution.

    Args:
        dest_path: Vendored package directory that receives the README.
        tag: Triton tag that was vendored (shown in the source link and version).
    """
    print("Creating README.md...")

    readme_file = dest_path / "README.md"
    # The source link is built from TRITON_KERNELS_MODULE_PATH so it stays in sync
    # with the path actually copied. The "Do Not Edit" text describes the edits
    # this script makes (headers and generated __init__.py files) instead of
    # claiming the code is verbatim.
    readme_content = f"""# Vendored triton_kernels

This directory contains code vendored from the [Triton](https://github.com/triton-lang/triton) project.

| | |
|---|---|
| **Copyright** | The Triton Authors |
| **License** | MIT (see [LICENSE](LICENSE) file in this directory) |
| **Source** | https://github.com/triton-lang/triton/tree/{tag}/{TRITON_KERNELS_MODULE_PATH} |
| **Version** | `{tag}` |

## Attribution

This code is the work of the Triton authors and is included here under the MIT License.
Each upstream Python file includes an attribution header indicating its origin.

## Do Not Edit

This code should not be modified directly. Apart from the attribution headers and any
missing `__init__.py` files added by the vendoring script, it matches upstream exactly.
To update to a newer version, run:

```bash
python scripts/vendor_triton_kernels.py --tag <new-tag>
```
"""

    readme_file.write_text(readme_content, encoding="utf-8")
    print(f" Created {readme_file}")
|
|
|
|
|
|
def main():
    """Command-line entry point: vendor triton_kernels from the Triton tag given by ``--tag``."""
    parser = argparse.ArgumentParser(description="Vendor triton-kernels into TensorRT-LLM")
    parser.add_argument(
        "--tag",
        required=True,
        help="Triton git tag to vendor from. See the list of tags at https://github.com/triton-lang/triton/tags",
    )
    tag = parser.parse_args().tag

    print(f"Vendoring triton-kernels from Triton {tag}")
    print(f"Destination: {DEST_PATH}")

    # The upstream checkout only needs to exist while files are copied out of it;
    # the temporary directory is removed when the with-block exits.
    with tempfile.TemporaryDirectory() as scratch_dir:
        module_src, upstream_root = clone_triton(tag, scratch_dir)
        copy_triton_kernels(module_src, DEST_PATH, tag)
        copy_license(upstream_root, DEST_PATH)
        create_version_file(DEST_PATH, tag)
        create_readme(DEST_PATH, tag)

    print("SUCCESS: triton-kernels has been vendored.")


if __name__ == "__main__":
    main()
|