# TensorRT-LLM/tests/integration/defs/examples/test_ray.py
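"""Integration tests for the Ray orchestrator examples.

Covers single-GPU async inference, multi-GPU distributed inference
(TP/PP/MoE-EP layouts), and disaggregated serving behind an
OpenAI-compatible completions endpoint.
"""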

import os
import subprocess

import pytest

from defs.common import venv_check_call, wait_for_server
from defs.conftest import get_device_count, llm_models_root


@pytest.fixture(scope="module")
def ray_example_root(llm_root):
    """Path to the Ray orchestrator examples shipped with the repo."""
    return os.path.join(llm_root, "examples", "ray_orchestrator")


def test_llm_inference_async_ray(ray_example_root, llm_venv):
    """Smoke-test the single-GPU async inference example."""
    script_path = os.path.join(ray_example_root, "llm_inference_async_ray.py")
    venv_check_call(llm_venv, [script_path])


@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("tp_size,pp_size,ep_size", [
    (2, 1, -1),
    (1, 2, -1),
    (2, 2, -1),
    (2, 1, 2),
],
                         ids=["tp2", "pp2", "tp2pp2", "tep2"])
def test_llm_inference_distributed_ray(ray_example_root, llm_venv, tp_size,
                                       pp_size, ep_size):
    """Run the distributed inference example across TP/PP/MoE-EP layouts."""
    world_size = tp_size * pp_size
    if get_device_count() < world_size:
        pytest.skip(f"Need {world_size} GPUs.")

    script_path = os.path.join(ray_example_root,
                               "llm_inference_distributed_ray.py")
    cmd = [
        script_path, "--tp_size",
        str(tp_size), "--pp_size",
        str(pp_size), "--moe_ep_size",
        str(ep_size)
    ]
    if ep_size != -1:
        # Expert parallelism requires an MoE checkpoint, so point the
        # example at DeepSeek-V3-Lite instead of its default model.
        model_dir = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16"
        cmd.extend(["--model_dir", model_dir])
    venv_check_call(llm_venv, cmd)
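

# For reference, the distributed example can also be launched outside the
# test harness (assuming a working Ray setup and enough GPUs), e.g.:
#   python examples/ray_orchestrator/llm_inference_distributed_ray.py \
#       --tp_size 2 --pp_size 1 --moe_ep_size -1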


@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("tp_size", [1, 2], ids=["tp1", "tp2"])
def test_ray_disaggregated_serving(ray_example_root, llm_venv, tp_size):
    """Bring up the disaggregated serving example and probe its endpoint."""
    if get_device_count() < tp_size * 2:
        pytest.skip(f"Need {tp_size * 2} GPUs.")

    disagg_dir = os.path.join(ray_example_root, "disaggregated")
    script_path = os.path.join(disagg_dir, "disagg_serving_local.sh")

    # Clear any Ray cluster left over from a previous test.
    subprocess.run("ray stop --force", shell=True, check=False)

    proc = subprocess.Popen(
        ["bash", script_path, "--executor", "ray", "--tp_size",
         str(tp_size)],
        cwd=disagg_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    try:
        assert wait_for_server("localhost", 8000, timeout_seconds=180), \
            "Disaggregated server failed to start within 3 minutes"

        # `-w "\n%{http_code}"` makes curl append the HTTP status on its own
        # line, so body and status can be split apart below.
        result = subprocess.run([
            "curl", "-sS", "-w", "\n%{http_code}",
            "http://localhost:8000/v1/completions", "-H",
            "Content-Type: application/json", "-d",
            '{"model":"TinyLlama/TinyLlama-1.1B-Chat-v1.0","prompt":"NVIDIA is a great company because","max_tokens":16,"temperature":0}'
        ],
                                capture_output=True,
                                text=True,
                                timeout=30)
        *body_lines, status_line = result.stdout.strip().splitlines()
        body = "\n".join(body_lines)
        status = int(status_line)
        print("HTTP status:", status)
        print("Response body:", body)
        assert result.returncode == 0, f"curl exit {result.returncode}"
        assert status == 200, f"Expected 200, got {status}"
    finally:
        proc.terminate()
        try:
            proc.wait(timeout=10)
        except Exception:
            proc.kill()
        # Best-effort cleanup of Ray and any stray serve processes.
        subprocess.run("ray stop --force", shell=True, check=False)
        subprocess.run("pkill -9 -f trtllm-serve", shell=True, check=False)