mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
121 lines
4.4 KiB
Python
121 lines
4.4 KiB
Python
#!/usr/bin/env python3
|
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
from typing import List
|
|
|
|
# One Markdown task-list line, e.g. "- [ ] Add tests" or "* [x] Update docs".
#
# The bullet ("-" or "*") may be indented (nested lists) and must be followed
# by at least one whitespace character; whitespace between "]" and the
# description is optional.
#   group(1): checkbox state -- " " when open, "x" or "X" when ticked.
#   group(2): the remainder of the line after the checkbox (leading
#             whitespace consumed), i.e. the task description.
TASK_PATTERN = re.compile(
    r"^\s*[-*]\s+\[( |x|X)\]\s*(.*)"
)
|
|
|
|
|
|
def find_all_tasks(pr_body: str) -> List[str]:
    """Collect every checklist line in *pr_body*, ticked or not.

    The body is split with ``str.splitlines`` and each line is tested against
    ``TASK_PATTERN``. Matching lines are returned in order of appearance, with
    surrounding whitespace stripped.
    """
    return [
        hit.group(0).strip()
        for hit in map(TASK_PATTERN.match, pr_body.splitlines())
        if hit
    ]
|
|
|
|
|
|
def _is_struck_through(text: str) -> bool:
    """Return True if all of *text* (ignoring whitespace) is inside ``~~`` spans.

    Every ``~~...~~`` span whose inner text neither starts nor ends with
    whitespace is removed; *text* counts as struck through only if it was
    non-blank and nothing but whitespace is left afterwards. This rejects
    ``"~~"`` / ``"~~~"`` (nothing actually struck) and
    ``"~~a~~ b ~~c~~"`` (``b`` is still visible), both of which a plain
    ``startswith('~~') and endswith('~~')`` test would accept.
    """
    stripped = text.strip()
    # Non-greedy so that separate spans ("~~a~~ ~~b~~") are removed one by one;
    # the look-arounds keep whitespace from hugging the delimiters.
    leftover = re.sub(r'~~(?!\s).+?(?<!\s)~~', '', stripped)
    return bool(stripped) and not leftover.strip()


def find_unresolved_tasks(pr_body: str) -> List[str]:
    """Return list of unresolved task list items.

    A task is considered resolved if it is checked (``[x]`` or ``[X]``)
    or if its entire text is struck through with ``~~`` markers (see
    ``_is_struck_through``). Partially struck text such as
    ``~~a~~ b ~~c~~`` still counts as unresolved.

    Args:
        pr_body: Raw Markdown body of the pull request.

    Returns:
        The unresolved checklist lines in order of appearance, each stripped
        of surrounding whitespace.
    """
    unresolved: List[str] = []
    for line in pr_body.splitlines():
        match = TASK_PATTERN.match(line)
        if not match:
            continue
        state, content = match.groups()
        if state.lower() == 'x':
            continue
        if _is_struck_through(content):
            continue
        unresolved.append(match.group(0).strip())
    return unresolved
|
|
|
|
|
|
def check_pr_checklist_section(pr_body: str) -> tuple[bool, str]:
    """Verify that the PR template's checklist section is still in the body.

    Two elements are required, and they are checked in this order:

    1. A ``## PR Checklist`` heading at the start of any line (matched
       case-insensitively).
    2. The final "Please check this after reviewing the above items"
       checkbox at the start of any line, either ticked or unticked (the
       checkbox text itself is matched case-sensitively).

    Args:
        pr_body: Raw Markdown body of the pull request.

    Returns:
        tuple: ``(is_valid, error_message)``. ``error_message`` is ``""`` when
        valid; otherwise it describes the first missing element.
    """
    requirements = (
        (
            re.compile(r"^##\s+PR\s+Checklist", re.IGNORECASE | re.MULTILINE),
            "Missing '## PR Checklist' header. Please ensure you haven't "
            "removed the PR template section.",
        ),
        (
            re.compile(
                r"^\s*[-*]\s+\[( |x|X)\]\s+"
                r"Please check this after reviewing the above items",
                re.MULTILINE,
            ),
            "Missing the required final checkbox '- [ ] Please check this after "
            "reviewing the above items as appropriate for this PR.' Please ensure "
            "you haven't removed this from the PR template.",
        ),
    )
    for pattern, problem in requirements:
        if pattern.search(pr_body) is None:
            return False, problem
    return True, ""
|
|
|
|
|
|
def main() -> None:
    """Validate the PR body's checklist; exit with status 1 on any failure.

    The body is read from the ``PR_BODY`` environment variable (empty when
    unset). When ``ENFORCE_PR_HAS_CHECKLIST`` equals ``"true"``
    (case-insensitive), the PR-template checklist section must be present
    and at least one checklist item must exist. In every mode, any
    unresolved item is a failure. All messages are printed to stdout.
    """
    body = os.environ.get("PR_BODY", "")
    enforce = os.environ.get("ENFORCE_PR_HAS_CHECKLIST",
                             "false").lower() == "true"

    # The template section is only mandatory when enforcement is switched on.
    if enforce:
        ok, reason = check_pr_checklist_section(body)
        if not ok:
            print(f"Error: {reason}")
            sys.exit(1)

    tasks = find_all_tasks(body)
    if enforce and not tasks:
        print("Error: PR body must contain at least one checklist item when "
              "ENFORCE_PR_HAS_CHECKLIST is enabled.")
        print(
            "Expected format: - [ ] Task description or * [ ] Task description")
        sys.exit(1)

    pending = find_unresolved_tasks(body)
    if pending:
        print("Unresolved checklist items found:")
        for entry in pending:
            print(entry)
        sys.exit(1)

    print("All checklist items resolved."
          if tasks else "No checklist items found in PR body.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the checklist validation only when executed as a script, not on import.
    main()
|