[TRTLLM-9737][chore] Add rl perf reproduce script and enhance the robustness of Ray tests (#9939)

Signed-off-by: Shuyi Xiong <219646547+shuyixiong@users.noreply.github.com>

commit f4f0fe85e9 (parent 534700ecd9)
@@ -3054,6 +3054,7 @@ def launchTestJobs(pipeline, testFilter)
        "DGX_H100-4_GPUs-PyTorch-DeepSeek-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4],
        "DGX_H100-4_GPUs-PyTorch-GptOss-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4],
        "DGX_H100-4_GPUs-PyTorch-Others-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4],
        "DGX_H100-4_GPUs-PyTorch-Ray-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4],
        "B300-PyTorch-1": ["b300-single", "l0_b300", 1, 1],
        "DGX_B200-4_GPUs-PyTorch-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4],
        "DGX_B200-4_GPUs-PyTorch-Ray-1": ["b200-x4-lbd", "l0_dgx_b200", 1, 1, 4, 1, true],
@@ -91,6 +91,8 @@ class RayWorkerWrapper:
                store=self.store,
                world_size=self.world_size,
                rank=self.rank)
            assert torch.distributed.get_world_size(
            ) == self.world_size, "Process group world size must match the expected world size"
            logger.info(
                f"[Rank {self.rank}] Finished PG init. Global GPU ID: {self.gpu}, local GPU ID: {self.local_gpu}"
            )
@@ -37,7 +37,6 @@ import tqdm
import yaml
from _pytest.mark import ParameterSet

from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings import ipc_nvls_supported
from tensorrt_llm.llmapi.mpi_session import get_mpi_world_size

@@ -2362,6 +2361,7 @@ def pytest_configure(config):
    tqdm.tqdm.monitor_interval = 0
    if config.getoption("--run-ray"):
        os.environ["TLLM_DISABLE_MPI"] = "1"
        os.environ["TLLM_RAY_FORCE_LOCAL_CLUSTER"] = "1"

    # Initialize PeriodicJUnitXML reporter if enabled
    periodic = config.getoption("--periodic-junit", default=False)

@@ -2825,15 +2825,3 @@ def torch_empty_cache() -> None:
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()


@pytest.fixture(autouse=True)
def ray_cleanup(llm_venv) -> None:
    yield

    if mpi_disabled():
        llm_venv.run_cmd([
            "-m",
            "ray.scripts.scripts",
            "stop",
        ])
@@ -21,6 +21,12 @@ import tempfile
from typing import Callable

import pytest

try:
    import ray
except ImportError:
    import tensorrt_llm.ray_stub as ray

import yaml
from defs.common import (revise_disagg_config_file_with_free_ports,
                         wait_for_server)
@@ -30,7 +36,7 @@ from defs.trt_test_alternative import check_call, check_output, popen
from test_common.perf_metrics_utils import (get_timing_metrics,
                                            validate_timing_metrics)

from tensorrt_llm._utils import get_free_port, mpi_disabled
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.logger import logger

@@ -357,8 +363,6 @@ def run_disaggregated_test(example_dir,

    extra_config_files = []
    workers_cmds = []
    subprocess.run(['ray', 'start', '--head', '--disable-usage-stats'],
                   check=True)

    # Generate ctx and gen server worker commands
    ctx_extra_config_file = get_extra_llm_config(config['context_servers'],
@@ -415,6 +419,21 @@ def run_disaggregated_test(example_dir,
                                     use_ray=False)

    else:
        runtime_env = {
            "env_vars": {
                "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"
            }
        }
        ray.init(address="local",
                 include_dashboard=False,
                 ignore_reinit_error=True,
                 runtime_env=runtime_env)
        gcs_addr = ray.get_runtime_context().gcs_address
        ray_port = str(gcs_addr.split(":")[1])
        run_env.update({
            "RAY_ADDRESS": f"localhost:{ray_port}",
            "TLLM_RAY_FORCE_LOCAL_CLUSTER": "0"
        })
    workers_proc = []
    with contextlib.ExitStack() as stack:
        workers_log = stack.enter_context(
@@ -470,16 +489,16 @@ def run_disaggregated_test(example_dir,
            logger.error(f.read())
        raise
    finally:
        if use_ray:
            subprocess.run(['ray', 'stop', '--force'], check=False)
            for extra_file in extra_config_files:
                if os.path.exists(extra_file):
                    os.remove(extra_file)
        elif 'server_proc' in locals() and 'workers_proc' in locals():
        if 'server_proc' in locals() and 'workers_proc' in locals():
            server_proc.terminate()
            workers_proc.terminate()
            server_proc.wait()
            workers_proc.wait()
        if use_ray:
            ray.shutdown()
            for extra_file in extra_config_files:
                if os.path.exists(extra_file):
                    os.remove(extra_file)


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
@@ -1,9 +1,15 @@
import os
import subprocess

try:
    import ray
except ImportError:
    import tensorrt_llm.ray_stub as ray

import pytest
from defs.common import venv_check_call, wait_for_server
from defs.conftest import get_device_count, llm_models_root
from defs.trt_test_alternative import popen


@pytest.fixture(scope="module")
@@ -65,48 +71,57 @@ def test_ray_disaggregated_serving(ray_example_root, llm_venv, tp_size):
    disagg_dir = os.path.join(ray_example_root, "disaggregated")
    script_path = os.path.join(disagg_dir, "disagg_serving_local.sh")
    model_dir = f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
    subprocess.run("ray stop --force", shell=True, check=False)

    proc = subprocess.Popen(
        [
            "bash", script_path, "--executor", "ray", "--model", model_dir,
            "--tp_size",
            str(tp_size)
        ],
        cwd=disagg_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    try:
        assert wait_for_server("localhost", 8000, timeout_seconds=180), \
            "Disaggregated server failed to start within 3 minutes"
    runtime_env = {
        "env_vars": {
            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"
        }
    }
    ray.init(address="local",
             include_dashboard=False,
             ignore_reinit_error=True,
             runtime_env=runtime_env)
    gcs_addr = ray.get_runtime_context().gcs_address
    ray_port = str(gcs_addr.split(":")[1])

        result = subprocess.run([
            "curl", "-sS", "-w", "\n%{http_code}",
            "http://localhost:8000/v1/completions", "-H",
            "Content-Type: application/json", "-d",
            '{"model":"TinyLlama-1.1B-Chat-v1.0","prompt":"NVIDIA is a great company because","max_tokens":16,"temperature":0}'
        ],
                                capture_output=True,
                                text=True,
                                timeout=30)
    env_copy = os.environ.copy()
    env_copy.update({
        "RAY_ADDRESS": f"localhost:{ray_port}",
        "TLLM_RAY_FORCE_LOCAL_CLUSTER": "0"
    })
    with popen(
        [
            "bash", script_path, "--executor", "ray", "--attach", "--model",
            model_dir, "--tp_size",
            str(tp_size)
        ],
        cwd=disagg_dir,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=env_copy,
    ):
        assert wait_for_server("localhost", 8000, timeout_seconds=180), \
            "Disaggregated server failed to start within 3 minutes"

        *body_lines, status_line = result.stdout.strip().splitlines()
        body = "\n".join(body_lines)
        status = int(status_line)
        result = subprocess.run([
            "curl", "-sS", "-w", "\n%{http_code}",
            "http://localhost:8000/v1/completions", "-H",
            "Content-Type: application/json", "-d",
            '{"model":"TinyLlama-1.1B-Chat-v1.0","prompt":"NVIDIA is a great company because","max_tokens":16,"temperature":0}'
        ],
                                capture_output=True,
                                text=True,
                                timeout=30)

        print("HTTP status:", status)
        print("Response body:", body)
        *body_lines, status_line = result.stdout.strip().splitlines()
        body = "\n".join(body_lines)
        status = int(status_line)

        assert result.returncode == 0, f"curl exit {result.returncode}"
        assert status == 200, f"Expected 200, got {status}"
        print("HTTP status:", status)
        print("Response body:", body)

        assert result.returncode == 0, f"curl exit {result.returncode}"
        assert status == 200, f"Expected 200, got {status}"
    finally:
        proc.terminate()
        try:
            proc.wait(timeout=10)
        except Exception:
            proc.kill()

        subprocess.run("ray stop --force", shell=True, check=False)
        subprocess.run("pkill -9 -f trtllm-serve", shell=True, check=False)
        ray.shutdown()
tests/integration/defs/ray_orchestrator/RL/README.md (new file, 59 lines)
@@ -0,0 +1,59 @@
# RL Framework Integration Tests

This directory contains integration tests for TensorRT-LLM with the [Ray orchestrator](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/ray_orchestrator), specifically designed to cover usage patterns from various RL (Reinforcement Learning) frameworks such as VeRL and NeMo RL.

## Available Scripts

| Script | Description |
|--------|-------------|
| `run_rl_perf_reproduce.py` | Emulates RL workload performance with multiple AsyncLLM instances distributed across GPUs using Ray placement groups |

## Usage Examples

### RL Performance Reproduction

The `run_rl_perf_reproduce.py` script creates multiple TensorRT-LLM instances in parallel to simulate RL rollout workloads.

**TP=4 with 2 instances (8 GPUs total):**

```bash
python run_rl_perf_reproduce.py \
    --model_dir /path/to/model_dir \
    --data_path /path/to/prompts.json \
    --num_instances 2 \
    --tp_size 4 \
    --top_p 1 \
    --logprobs 1 \
    --max_batch_size 1024 \
    --enable_cuda_graph_padding
```

**TP=1 with 8 instances (8 GPUs total):**

```bash
python run_rl_perf_reproduce.py \
    --model_dir /path/to/model_dir \
    --data_path /path/to/prompts.json \
    --num_instances 8 \
    --tp_size 1 \
    --top_p 1 \
    --logprobs 1 \
    --max_batch_size 384 \
    --enable_cuda_graph_padding
```

## Data Format

The `--data_path` argument should point to a JSON file containing a list of prompts, where each prompt is a list of token IDs:

```json
[
    [1, 2345, 6789, ...],
    [1, 3456, 7890, ...],
    ...
]
```
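
A file in this format can be produced from plain-text prompts with a Hugging Face tokenizer, mirroring what the integration test in this PR does. The snippet below is a minimal sketch; the model path and prompt strings are placeholders, not assets shipped with this test suite:

```python
# Minimal sketch: build a prompts.json compatible with --data_path.
# Assumes a Hugging Face-compatible tokenizer checkpoint; the model path
# and prompt texts below are placeholders.
import json

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/path/to/model_dir", trust_remote_code=True)

prompt_texts = [
    "The president of the United States is",
    "NVIDIA is a great company because",
]

# Each entry is a list of token IDs, matching the expected data format.
prompts = [tokenizer.encode(text, add_special_tokens=False) for text in prompt_texts]

with open("prompts.json", "w") as f:
    json.dump(prompts, f)
```

The resulting `prompts.json` can then be passed to the script via `--data_path`.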

## Notes

- RL Perf reproduction scripts support single-node execution only (max 8 GPUs)
@@ -0,0 +1,336 @@
import argparse
import asyncio
import json
import os
import time
from pathlib import Path

import ray
import torch
from ray.util.placement_group import (
    PlacementGroupSchedulingStrategy,
    placement_group,
    remove_placement_group,
)
from transformers import AutoConfig

from tensorrt_llm import AsyncLLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SamplingParams


@ray.remote
class TRTLLMInstance:
    """Ray actor wrapping an AsyncLLM instance for distributed RL workloads.

    This actor manages a single TensorRT-LLM instance that can be scheduled
    on specific GPUs using Ray placement groups. Multiple instances can run
    in parallel for high-throughput RL generation.

    Attributes:
        async_llm_kwargs: Configuration dict for AsyncLLM initialization
        sampling_kwargs: Configuration dict for SamplingParams
        llm: The underlying AsyncLLM instance (initialized via init_llm)
        sampling_params: SamplingParams object for generation
    """

    def __init__(self, async_llm_kwargs: dict, sampling_kwargs: dict):
        self.async_llm_kwargs = async_llm_kwargs
        self.sampling_kwargs = sampling_kwargs
        self.llm = None
        self.sampling_params = None

    async def init_llm(self):
        """Initialize the AsyncLLM instance with configured parameters."""
        self.llm = await AsyncLLM(
            model=self.async_llm_kwargs["model"],
            backend="pytorch",
            orchestrator_type=self.async_llm_kwargs["orchestrator_type"],
            ray_worker_extension_cls=self.async_llm_kwargs["ray_worker_extension_cls"],
            kv_cache_config=KvCacheConfig(**self.async_llm_kwargs["kv_cache_config"]),
            cuda_graph_config=CudaGraphConfig(**self.async_llm_kwargs["cuda_graph_config"]),
            max_seq_len=self.async_llm_kwargs["max_seq_len"],
            max_batch_size=self.async_llm_kwargs["max_batch_size"],
            max_num_tokens=self.async_llm_kwargs["max_num_tokens"],
            tensor_parallel_size=self.async_llm_kwargs["tensor_parallel_size"],
            trust_remote_code=self.async_llm_kwargs["trust_remote_code"],
            enable_sleep=True,
            sampler_type=self.async_llm_kwargs["sampler_type"],
            placement_groups=self.async_llm_kwargs["placement_groups"],
            placement_bundle_indices=self.async_llm_kwargs["placement_bundle_indices"],
            per_worker_gpu_share=self.async_llm_kwargs["per_worker_gpu_share"],
            batch_wait_timeout_iters=32,
            batch_wait_max_tokens_ratio=0.5,
        )
        self.sampling_params = SamplingParams(
            temperature=self.sampling_kwargs["temperature"],
            top_p=self.sampling_kwargs["top_p"],
            top_k=self.sampling_kwargs["top_k"],
            max_tokens=self.sampling_kwargs["max_tokens"],
            logprobs=self.sampling_kwargs["logprobs"],
            detokenize=self.sampling_kwargs["detokenize"],
            end_id=self.sampling_kwargs["end_id"],
            pad_id=self.sampling_kwargs["pad_id"],
            stop_token_ids=self.sampling_kwargs["stop_token_ids"],
            include_stop_str_in_output=self.sampling_kwargs["include_stop_str_in_output"],
        )

    async def generate(self, prompt: list[int]):
        """Generate output tokens for a single prompt.

        Args:
            prompt: List of input token IDs

        Returns:
            Tuple of (token_ids, log_probs):
                - token_ids: List of generated token IDs
                - log_probs: List of log probabilities (if logprobs enabled, else None)
        """
        outputs = await self.llm.generate_async(inputs=prompt, sampling_params=self.sampling_params)
        token_ids = outputs.outputs[0].token_ids
        log_probs = None
        if self.sampling_kwargs["logprobs"] is not None:
            log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs]
        return token_ids, log_probs


async def setup_rl_llm(args):
    """Main setup and execution function for RL LLM workloads.

    This function:
    1. Loads prompts from the input JSON file
    2. Initializes Ray with placement groups for GPU allocation
    3. Creates multiple TRTLLMInstance actors distributed across GPUs
    4. Distributes prompts round-robin across instances
    5. Runs async generation and reports throughput metrics

    Args:
        args: Parsed command-line arguments
    """
    # Load prompts from JSON file (expected format: list of token ID lists)
    data_path = Path(args.data_path)
    with open(data_path, "r") as f:
        prompts = json.load(f)

    hf_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)

    num_instances = args.num_instances
    num_gpus = args.tp_size * num_instances
    available_gpus = torch.cuda.device_count()
    if num_gpus > 8:
        raise ValueError(
            f"Number of GPUs ({num_gpus}) is greater than 8. This script only supports single node."
        )
    if available_gpus < num_gpus:
        raise ValueError(
            f"Number of GPUs ({available_gpus}) is less than number of GPUs required ({num_gpus})."
        )

    # Prevent Ray from setting CUDA_VISIBLE_DEVICES automatically
    # This allows TensorRT-LLM to manage GPU visibility internally
    os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
    runtime_env = {"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}}
    pg = None

    try:
        ray.init(address="local")
        gcs_addr = ray.get_runtime_context().gcs_address
        port = int(gcs_addr.split(":")[1])
        # Force ray.init("auto") to attach to a specific cluster via RAY_ADDRESS
        os.environ["RAY_ADDRESS"] = f"localhost:{port}"

        # Create placement group with one bundle per GPU
        # STRICT_PACK ensures all bundles are on the same node
        pg = placement_group(
            [{"GPU": 1, "CPU": 2} for _ in range(num_gpus)], strategy="STRICT_PACK"
        )

        # Wait for placement group to be ready
        ray.get(pg.ready())

        # Configure placement groups for each instance
        # Each instance gets a contiguous range of GPU bundles for tensor parallelism
        # Example with num_instances=2, tp_size=2:
        #   Instance 0: bundles [0, 1] -> GPUs 0, 1
        #   Instance 1: bundles [2, 3] -> GPUs 2, 3
        tp_size = args.tp_size
        placement_group_list = [[pg] for _ in range(num_instances)]
        placement_bundle_indices_list = [
            [list(range(i * tp_size, (i + 1) * tp_size))] for i in range(num_instances)
        ]

        # Create TRTLLMInstance actors for each parallel instance
        llm_instances = []
        for i in range(num_instances):
            llm_instances.append(
                TRTLLMInstance.options(
                    num_cpus=0,
                    num_gpus=0,
                    runtime_env=runtime_env,
                    scheduling_strategy=PlacementGroupSchedulingStrategy(
                        placement_group=pg,
                        placement_group_capture_child_tasks=True,
                    ),
                ).remote(
                    async_llm_kwargs={
                        "model": args.model_dir,
                        "backend": "pytorch",
                        "orchestrator_type": "ray",
                        "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
                        "kv_cache_config": {
                            "enable_block_reuse": args.enable_block_reuse,
                            "free_gpu_memory_fraction": args.kv_cache_fraction,
                        },
                        "cuda_graph_config": {
                            "enable_padding": args.enable_cuda_graph_padding,
                            "batch_sizes": args.batch_sizes,
                            "max_batch_size": 0 if args.batch_sizes else args.max_batch_size,
                        },
                        "max_seq_len": args.max_seq_len,
                        "max_batch_size": args.max_batch_size,
                        "max_num_tokens": args.max_num_tokens,
                        "tensor_parallel_size": args.tp_size,
                        "trust_remote_code": args.trust_remote_code,
                        "enable_sleep": True,
                        "sampler_type": args.sampler_type,
                        "placement_groups": placement_group_list[i],
                        "placement_bundle_indices": placement_bundle_indices_list[i],
                        "per_worker_gpu_share": 0.5,
                    },
                    sampling_kwargs={
                        "temperature": args.temperature,
                        "top_p": args.top_p,
                        "top_k": args.top_k,
                        "max_tokens": args.max_tokens,
                        "logprobs": args.logprobs,
                        "detokenize": False,
                        "end_id": -1,
                        "pad_id": hf_config.pad_token_id,
                        "stop_token_ids": [hf_config.eos_token_id],
                        "include_stop_str_in_output": True,
                    },
                )
            )
        # Wait for all Ray actors to be ready, then initialize LLM instances
        ray.get([llm.__ray_ready__.remote() for llm in llm_instances])
        ray.get([llm.init_llm.remote() for llm in llm_instances])

        total_prompts = len(prompts)

        print(
            f"Starting generation for {total_prompts} prompts across {num_instances} instances..."
        )
        start_time = time.time()

        # Helper function to wrap Ray remote call as async coroutine
        async def generate_single_prompt(instance, prompt):
            """Generate a single prompt asynchronously."""
            object_ref = instance.generate.remote(prompt=prompt)
            result = await asyncio.to_thread(ray.get, object_ref)
            return result

        # Create tasks with round-robin distribution
        tasks = [
            generate_single_prompt(llm_instances[idx % num_instances], prompt)
            for idx, prompt in enumerate(prompts)
        ]

        await asyncio.gather(*tasks)
        end_time = time.time()

        print(f"Time taken: {end_time - start_time:.2f} seconds")
        print(f"Total prompts: {total_prompts}")
        print(f"Throughput: {total_prompts / (end_time - start_time):.2f} prompts/sec")
    finally:
        if pg is not None:
            remove_placement_group(pg)
        ray.shutdown()


def add_rl_llm_args(parser):
    """Add command-line arguments for RL LLM configuration."""
    # Required arguments
    parser.add_argument("--model_dir", type=str, required=True, help="Model checkpoint directory.")
    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="Input data file path, expected format: list of token ID lists.",
    )
    parser.add_argument(
        "--num_instances", type=int, required=True, help="Number of TRTLLM instances."
    )

    # AsyncLLM parameters
    parser.add_argument("--tp_size", type=int, required=True, help="Tensor parallel size.")
    parser.add_argument("--max_seq_len", type=int, default=2048, help="Maximum sequence length.")
    parser.add_argument("--max_batch_size", type=int, default=384, help="Maximum batch size.")
    parser.add_argument(
        "--max_num_tokens", type=int, default=32768, help="Maximum number of tokens."
    )
    parser.add_argument(
        "--sampler_type",
        type=str,
        default="TRTLLMSampler",
        choices=["TRTLLMSampler", "TorchSampler"],
        help="Sampler type.",
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
        help="Whether to trust remote code.",
    )

    # KV Cache Config parameters
    parser.add_argument(
        "--kv_cache_fraction",
        type=float,
        default=0.6,
        help="The fraction of GPU memory to be used for KV cache.",
    )
    parser.add_argument(
        "--enable_block_reuse",
        action="store_true",
        default=False,
        help="Whether to enable block reuse for KV cache.",
    )

    # Cuda Graph Config parameters
    parser.add_argument(
        "--enable_cuda_graph_padding",
        action="store_true",
        default=False,
        help="Whether to enable padding for CUDA graphs.",
    )
    parser.add_argument(
        "--batch_sizes",
        type=int,
        nargs="+",
        default=None,
        help="The batch sizes to be used for CUDA graphs. Example: --batch_sizes 16 32 64 128 256",
    )

    # Sampling parameters
    parser.add_argument("--max_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=1)
    parser.add_argument("--top_k", type=int, default=None)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--logprobs", type=int, default=None)

    return parser


def parse_arguments():
    parser = argparse.ArgumentParser(description="RL flow performance reproduction.")
    parser = add_rl_llm_args(parser)
    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()
    asyncio.run(setup_rl_llm(args))


if __name__ == "__main__":
    main()
@@ -0,0 +1,59 @@
import json
import tempfile
from pathlib import Path

import pytest
from defs.common import venv_check_call
from defs.conftest import integration_path, llm_models_root
from transformers import AutoTokenizer


@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize(
    "tp_size, num_instances", [(2, 2), (1, 4)], ids=["tp2_2instances", "tp1_4instances"]
)
def test_rl_perf_reproduce(llm_venv, tp_size, num_instances):
    script_path = (
        integration_path() / "defs" / "ray_orchestrator" / "RL" / "run_rl_perf_reproduce.py"
    )
    model_dir = f"{llm_models_root()}/Qwen2-7B-Instruct"

    if tp_size == 2:
        max_batch_size = 512
    else:
        max_batch_size = 256

    with tempfile.TemporaryDirectory() as tmpdir:
        prompt_text = "The president of the United States is"

        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        token_ids = tokenizer.encode(prompt_text, add_special_tokens=False)

        # Replicate to create batch of 1024 prompts
        batch_size = 1024
        prompts = [token_ids for _ in range(batch_size)]

        data_path = Path(tmpdir) / "prompts.json"
        with open(data_path, "w") as f:
            json.dump(prompts, f)

        venv_check_call(
            llm_venv,
            [
                str(script_path),
                "--model_dir",
                model_dir,
                "--data_path",
                str(data_path),
                "--num_instances",
                str(num_instances),
                "--tp_size",
                str(tp_size),
                "--logprobs",
                "1",
                "--max_batch_size",
                str(max_batch_size),
                "--enable_block_reuse",
                "--enable_cuda_graph_padding",
            ],
        )
@@ -77,6 +77,7 @@ def test_unittests_v2(llm_root, llm_venv, case: str, output_dir, request):
    test_prefix = "unittest"

    waives_file = request.config.getoption("--waives-file")
    run_ray = request.config.getoption("--run-ray")

    num_workers = 1

@@ -138,6 +139,9 @@ def test_unittests_v2(llm_root, llm_venv, case: str, output_dir, request):
        waives_file = os.path.abspath(waives_file)
        command += [f"--waives-file={waives_file}"]

    if run_ray:
        command += ["--run-ray"]

    command += arg_list

    print(f"Running unit test:\"python {' '.join(command)}\"")
@@ -303,3 +303,5 @@ l0_dgx_h100:
      tests:
      - unittest/_torch/ray_orchestrator/multi_gpu -m "gpu4"
      - unittest/llmapi/test_async_llm.py -m "gpu4"
      - ray_orchestrator/RL/test_rl_perf_reproduce.py::test_rl_perf_reproduce[tp2_2instances]
      - ray_orchestrator/RL/test_rl_perf_reproduce.py::test_rl_perf_reproduce[tp1_4instances]
@@ -3,11 +3,6 @@ import sys

import pytest

try:
    import ray
except ModuleNotFoundError:
    from tensorrt_llm import ray_stub as ray

from tensorrt_llm._utils import mpi_disabled


@@ -26,24 +21,3 @@ if not mpi_disabled():
    pytest.skip(
        "Ray tests are only tested in Ray CI stage or with --run-ray flag",
        allow_module_level=True)


@pytest.fixture(scope="function")
def setup_ray_cluster():
    runtime_env = {
        "env_vars": {
            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"
        }
    }
    ray_init_args = {
        "include_dashboard": False,
        "namespace": "test",
        "ignore_reinit_error": True,
        "runtime_env": runtime_env
    }
    try:
        ray.init(address="local", **ray_init_args)
        yield
    finally:
        if ray.is_initialized():
            ray.shutdown()

@@ -25,12 +25,12 @@ def test_worker_extension():


@pytest.mark.gpu4
def test_placement_env_vars(monkeypatch):
def test_placement_env_vars(setup_ray_cluster, monkeypatch):
    port = setup_ray_cluster
    monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")

    pg = None
    try:
        ray.init()
        pg = placement_group([{"GPU": 1, "CPU": 1}] * 4)
        ray.get(pg.ready())
        print(f"Placement group ready with bundles {pg.bundle_specs}")
@@ -39,7 +39,9 @@ def test_placement_env_vars(monkeypatch):
        runtime_env = {
            "env_vars": {
                "TRTLLM_RAY_PER_WORKER_GPUS": "0.8",
                "TRTLLM_RAY_BUNDLE_INDICES": ",".join(map(str, bundle_indices))
                "TRTLLM_RAY_BUNDLE_INDICES": ",".join(map(str, bundle_indices)),
                "RAY_ADDRESS": f"localhost:{port}",
                "TLLM_RAY_FORCE_LOCAL_CLUSTER": "0"
            }
        }

@@ -70,7 +72,6 @@ def test_placement_env_vars(monkeypatch):
    finally:
        if pg is not None:
            remove_placement_group(pg)
        ray.shutdown()


@pytest.mark.gpu2
@@ -79,13 +80,14 @@ def test_placement_env_vars(monkeypatch):
    (2, [1]),
],
    ids=["gpu2_tp1"])
def test_placement_api(monkeypatch, n_gpus, bundle_indices):
    monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
def test_placement_api(setup_ray_cluster, monkeypatch, n_gpus, bundle_indices):

    port = setup_ray_cluster
    monkeypatch.setenv("RAY_ADDRESS", f"localhost:{port}")
    monkeypatch.setenv("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0")
    tp_size = n_gpus // 2
    pg = None
    try:
        ray.init()
        pg = placement_group([{"GPU": 1, "CPU": 1}] * n_gpus)
        ray.get(pg.ready())
        print(f"Placement group ready with bundles {pg.bundle_specs}")
@@ -116,7 +118,6 @@ def test_placement_api(monkeypatch, n_gpus, bundle_indices):
    finally:
        if pg is not None:
            remove_placement_group(pg)
        ray.shutdown()


@pytest.mark.gpu2
@@ -0,0 +1,120 @@
import os

import pytest
import ray
import torch
from ray.util.placement_group import (
    PlacementGroupSchedulingStrategy,
    placement_group,
    remove_placement_group,
)
from utils.llm_data import llm_models_root

from tensorrt_llm import AsyncLLM
from tensorrt_llm.llmapi import KvCacheConfig


@ray.remote
class TRTLLMInstance:
    def __init__(self, async_llm_kwargs: dict):
        self.llm = AsyncLLM(
            model=async_llm_kwargs["model"],
            backend="pytorch",
            orchestrator_type=async_llm_kwargs["orchestrator_type"],
            kv_cache_config=KvCacheConfig(**async_llm_kwargs["kv_cache_config"]),
            tensor_parallel_size=async_llm_kwargs["tensor_parallel_size"],
            placement_groups=async_llm_kwargs["placement_groups"],
            placement_bundle_indices=async_llm_kwargs["placement_bundle_indices"],
            per_worker_gpu_share=async_llm_kwargs["per_worker_gpu_share"],
        )

    async def init_llm(self):
        await self.llm.setup_async()

    def shutdown_llm(self):
        self.llm.shutdown()
        self.llm = None


@pytest.mark.gpu4
@pytest.mark.parametrize(
    "tp_size, num_instances", [(2, 2), (1, 4)], ids=["tp2_2instances", "tp1_4instances"]
)
def test_multi_instance(setup_ray_cluster, tp_size, num_instances):
    """Test that multiple TRTLLMInstance actors can be started without port conflicts.

    This test guards against port conflict failures when launching multiple
    TensorRT-LLM instances concurrently. It runs multiple iterations to ensure
    reliable instance creation and teardown.
    """
    port = setup_ray_cluster
    num_gpus = tp_size * num_instances
    available_gpus = torch.cuda.device_count()
    if num_gpus > 8:
        raise ValueError(
            f"Number of GPUs ({num_gpus}) is greater than 8. This script only supports single node."
        )
    if available_gpus < num_gpus:
        raise ValueError(
            f"Number of GPUs ({available_gpus}) is less than number of GPUs required ({num_gpus})."
        )
    runtime_env = ray.runtime_env.RuntimeEnv()
    runtime_env["env_vars"] = os.environ.copy()
    runtime_env["env_vars"].update(
        {
            "TLLM_RAY_FORCE_LOCAL_CLUSTER": "0",
            "RAY_ADDRESS": f"localhost:{port}",
        }
    )

    # Run multiple iterations to guard against port conflict issues
    execution_times = 5
    for iteration in range(execution_times):
        pg = None
        llm_instances = []
        try:
            pg = placement_group(
                [{"GPU": 1, "CPU": 2} for _ in range(num_gpus)], strategy="STRICT_PACK"
            )

            ray.get(pg.ready())

            placement_group_list = [[pg] for _ in range(num_instances)]
            placement_bundle_indices_list = [
                [list(range(i * tp_size, (i + 1) * tp_size))] for i in range(num_instances)
            ]

            for i in range(num_instances):
                llm_instances.append(
                    TRTLLMInstance.options(
                        num_cpus=0,
                        num_gpus=0,
                        scheduling_strategy=PlacementGroupSchedulingStrategy(
                            placement_group=pg,
                            placement_group_capture_child_tasks=True,
                        ),
                        runtime_env=runtime_env,
                    ).remote(
                        async_llm_kwargs={
                            "model": os.path.join(
                                llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0"
                            ),
                            "kv_cache_config": {
                                "free_gpu_memory_fraction": 0.1,
                            },
                            "tensor_parallel_size": tp_size,
                            "orchestrator_type": "ray",
                            "placement_groups": placement_group_list[i],
                            "placement_bundle_indices": placement_bundle_indices_list[i],
                            "per_worker_gpu_share": 0.5,
                        }
                    )
                )
            ray.get([llm.__ray_ready__.remote() for llm in llm_instances])
            ray.get([llm.init_llm.remote() for llm in llm_instances])
        finally:
            # Clean up actors before removing placement group
            for llm in llm_instances:
                ray.get(llm.shutdown_llm.remote())
            if pg is not None:
                remove_placement_group(pg)
@@ -1,5 +1,6 @@
import os
from operator import attrgetter
from typing import Optional

import pytest
import torch
@@ -10,7 +11,6 @@ except ModuleNotFoundError:
    from tensorrt_llm import ray_stub as ray

from tensorrt_llm._torch.distributed.communicator import TorchDist
from tensorrt_llm._utils import get_free_port
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
from tensorrt_llm.mapping import Mapping

@@ -22,7 +22,6 @@ class PgOpTest:
        self.rank = rank
        self.world_size = world_size
        self.master_address = os.environ["MASTER_ADDR"]
        self.master_port = os.environ["MASTER_PORT"]

        assert len(ray.get_gpu_ids()) == 1
        self.gpu = int(ray.get_gpu_ids()[0])
@@ -30,12 +29,30 @@ class PgOpTest:
        local_gpu = RayWorkerWrapper.physical_to_local_id(self.gpu)
        torch.cuda.set_device(local_gpu)

        torch.distributed.init_process_group(
            backend="cuda:nccl,cpu:gloo",
            init_method=f"tcp://{self.master_address}:{self.master_port}",
            world_size=world_size,
            rank=rank)
    def _create_tcp_store(self,
                          port: Optional[int] = None
                          ) -> torch.distributed.TCPStore:
        actual_port = port if port is not None else 0
        return torch.distributed.TCPStore(host_name=self.master_address,
                                          port=actual_port,
                                          world_size=self.world_size,
                                          is_master=(self.rank == 0),
                                          wait_for_workers=False)

    def setup_tcp_store(self):
        if self.rank != 0:
            raise RuntimeError("Only the master worker can setup TCP store")
        self.store = self._create_tcp_store()
        return self.store.port

    def setup_distributed_env(self, port: int):
        if self.rank != 0:
            self.store = self._create_tcp_store(port)

        torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo",
                                             store=self.store,
                                             world_size=self.world_size,
                                             rank=self.rank)
        self.mapping = Mapping(world_size=self.world_size,
                               gpus_per_node=self.world_size,
                               tp_size=self.world_size,
@@ -91,13 +108,11 @@ def test_allgather_pg_op(setup_ray_cluster, seq_len, hidden_size, var_len):
        sizes = None

    remotePGTests = []
    master_port = get_free_port()
    runtime_env = ray.runtime_env.RuntimeEnv()
    runtime_env["env_vars"] = os.environ.copy()
    runtime_env["env_vars"].update({
        "TLLM_DISABLE_MPI": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(master_port)
    })

    for rank in range(world_size):
@@ -107,6 +122,12 @@ def test_allgather_pg_op(setup_ray_cluster, seq_len, hidden_size, var_len):
    ray.get(
        [remotePGTest.__ray_ready__.remote() for remotePGTest in remotePGTests])

    port = ray.get(remotePGTests[0].setup_tcp_store.remote())
    ray.get([
        remotePGTest.setup_distributed_env.remote(port)
        for remotePGTest in remotePGTests
    ])

    if var_len:
        results = ray.get([
            remotePGTest.run.remote("allgather_pg",
@@ -155,13 +176,11 @@ def test_reducescatter_pg_op(setup_ray_cluster, seq_len, hidden_size, var_len):
        ]
        sizes = None

    master_port = get_free_port()
    runtime_env = ray.runtime_env.RuntimeEnv()
    runtime_env["env_vars"] = os.environ.copy()
    runtime_env["env_vars"].update({
        "TLLM_DISABLE_MPI": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(master_port)
    })

    remotePGTests = []
@@ -172,6 +191,12 @@ def test_reducescatter_pg_op(setup_ray_cluster, seq_len, hidden_size, var_len):
    ray.get(
        [remotePGTest.__ray_ready__.remote() for remotePGTest in remotePGTests])

    port = ray.get(remotePGTests[0].setup_tcp_store.remote())
    ray.get([
        remotePGTest.setup_distributed_env.remote(port)
        for remotePGTest in remotePGTests
    ])

    results = ray.get([
        remotePGTest.run.remote("reducescatter_pg",
                                test_tensor,
@@ -194,13 +219,11 @@ def test_allreduce_pg_op(setup_ray_cluster, seq_len, hidden_size):
    test_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
    expected_result = test_tensor * world_size

    master_port = get_free_port()
    runtime_env = ray.runtime_env.RuntimeEnv()
    runtime_env["env_vars"] = os.environ.copy()
    runtime_env["env_vars"].update({
        "TLLM_DISABLE_MPI": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(master_port)
    })

    remotePGTests = []
@@ -211,6 +234,12 @@ def test_allreduce_pg_op(setup_ray_cluster, seq_len, hidden_size):
    ray.get(
        [remotePGTest.__ray_ready__.remote() for remotePGTest in remotePGTests])

    port = ray.get(remotePGTests[0].setup_tcp_store.remote())
    ray.get([
        remotePGTest.setup_distributed_env.remote(port)
        for remotePGTest in remotePGTests
    ])

    results = ray.get([
        remotePGTest.run.remote("allreduce_pg",
                                test_tensor,
@@ -18,7 +18,12 @@ import sys
import traceback
import warnings
from functools import partial
from typing import Any
from typing import Any, Generator

try:
    import ray
except ModuleNotFoundError:
    from tensorrt_llm import ray_stub as ray

import _pytest.outcomes
import pytest
@@ -344,3 +349,26 @@ def process_gpu_memory_info_available():
        return False

    return True


@pytest.fixture(scope="function")
def setup_ray_cluster() -> Generator[int, None, None]:
    runtime_env = {
        "env_vars": {
            "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"
        }
    }
    ray_init_args = {
        "include_dashboard": False,
        "namespace": "test",
        "ignore_reinit_error": True,
        "runtime_env": runtime_env
    }
    try:
        ray.init(address="local", **ray_init_args)
        gcs_addr = ray.get_runtime_context().gcs_address
        port = int(gcs_addr.split(":")[1])
        yield port
    finally:
        if ray.is_initialized():
            ray.shutdown()

@@ -96,9 +96,10 @@ async def test_async_llm_release_resume(process_gpu_memory_info_available, num_c
@pytest.mark.ray
@pytest.mark.gpu4
@pytest.mark.asyncio
@pytest.mark.threadleak(enabled=False)
async def test_async_llm_placement_api(monkeypatch):
    monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
async def test_async_llm_placement_api(setup_ray_cluster, monkeypatch):
    port = setup_ray_cluster
    monkeypatch.setenv("RAY_ADDRESS", f"localhost:{port}")
    monkeypatch.setenv("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0")

    n_gpus = 4
    bundle_indices = [2, 3]
@@ -106,7 +107,6 @@ async def test_async_llm_placement_api(monkeypatch):

    pg = None
    try:
        ray.init()
        pg = placement_group([{"GPU": 1, "CPU": 1}] * n_gpus)
        ray.get(pg.ready())
        print(f"Placement group ready with bundles {pg.bundle_specs}")
@@ -132,6 +132,6 @@ async def test_async_llm_placement_api(monkeypatch):
        )

    finally:
        llm.shutdown()
        if pg is not None:
            remove_placement_group(pg)
        ray.shutdown()