# Mirror of https://github.com/vllm-project/vllm.git
# (synced 2026-06-06 00:16:14 +00:00, commit e0a45f1455)
# Signed-off-by: ahao-anyscale <ahao@anyscale.com>
# Signed-off-by: hao-aaron <ahao@anyscale.com>
# 63 lines, 2.2 KiB, Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import sys
|
|
|
|
import regex as re
|
|
|
|
# --------------------------------------------------------------------------- #
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS = [
    # torch.cuda helpers that have torch.accelerator replacements. The ``\b``
    # guards only the bare names; it must not follow ``device\(`` because the
    # character after "(" is usually a quote or ")", and a trailing ``\b``
    # there would silently skip calls like ``torch.cuda.device("cuda:0")``.
    r"\btorch\.cuda\.(?:(?:empty_cache|synchronize|device_count|current_device"
    r"|memory_reserved|memory_allocated|max_memory_allocated|max_memory_reserved"
    r"|reset_peak_memory_stats|memory_stats|set_device)\b|device\()",
    # Seeding: scan_file points these at set_random_seed instead.
    r"\btorch\.cuda\.(manual_seed|manual_seed_all)\b",
    # ``with torch.cuda.device`` context managers (any amount of whitespace).
    r"\bwith\s+torch\.cuda\.device\b",
    # Calls torch.cuda.{_is_compiled/_device_count_amdsmi/_device_count_nvml}
    # internally. No trailing ``\b`` after "()": it would demand a word
    # character right after ")" and so miss ordinary calls. The lookbehind
    # keeps the function's own ``def`` line from being reported.
    r"(?<!def )\bcuda_device_count_stateless\(\)",
]
|
|
|
|
# Filename prefixes exempt from this check. ``main`` tests each command-line
# filename with ``str.startswith`` against these entries, so an entry ending
# in "/" exempts everything under that directory, while a full path exempts
# exactly that file.
_ALLOWED_DIRS = (
    "vllm/platforms/",
    "vllm/device_allocator/",
)
_ALLOWED_SINGLE_FILES = (
    "vllm/distributed/weight_transfer/ipc_engine.py",
    "tests/distributed/test_packed_tensor.py",
)
ALLOWED_FILES = {*_ALLOWED_DIRS, *_ALLOWED_SINGLE_FILES}
|
|
|
|
|
|
def scan_file(path: str) -> int:
    """Report disallowed ``torch.cuda`` usages found in one source file.

    Every pattern in ``_TORCH_CUDA_PATTERNS`` is searched over the whole file,
    and each offending line is printed as ``path:line: error: ...``. All
    violations are therefore shown in a single run, not just the first one.

    Args:
        path: File to scan, read as UTF-8 text.

    Returns:
        1 if at least one disallowed call was found, otherwise 0.
    """
    with open(path, encoding="utf-8") as f:
        content = f.read()

    # Map each 1-based line number to the first matched text on that line.
    # One call can be hit by several patterns (e.g. ``with
    # torch.cuda.device(...)``), so each line is reported only once.
    hits = {}
    for pattern in _TORCH_CUDA_PATTERNS:
        for match in re.finditer(pattern, content, re.MULTILINE):
            # Count the newlines before the match without copying the prefix.
            line_num = content.count("\n", 0, match.start()) + 1
            hits.setdefault(line_num, match.group(0))

    for line_num in sorted(hits):
        matched_text = hits[line_num]
        if "manual_seed" in matched_text:
            print(
                f"{path}:{line_num}: "
                "\033[91merror:\033[0m "
                f"Found {matched_text} API call. Use set_random_seed instead."
            )
        else:
            print(
                f"{path}:{line_num}: "
                "\033[91merror:\033[0m "  # red color
                "Found torch.cuda API call. Please refer RFC "
                "https://github.com/vllm-project/vllm/issues/30679, use "
                "torch.accelerator API instead."
            )

    return 1 if hits else 0
|
|
|
|
|
|
def main():
    """Scan every command-line file that is not under an exempt prefix.

    Files whose name starts with any entry of ``ALLOWED_FILES`` are skipped.
    All other files are passed to ``scan_file`` in command-line order.

    Returns:
        The bitwise OR of the per-file ``scan_file`` results, which is 0 when
        every scanned file is clean.
    """
    exempt_prefixes = tuple(ALLOWED_FILES)
    status = 0
    for path in sys.argv[1:]:
        if path.startswith(exempt_prefixes):
            continue
        status |= scan_file(path)
    return status
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s result as the process exit status. It is non-zero when
    # any scanned file contained a disallowed call.
    raise SystemExit(main())
|