diff --git a/jenkins/L0_Test.groovy b/jenkins/L0_Test.groovy index faeb05c508..0699976f4d 100644 --- a/jenkins/L0_Test.groovy +++ b/jenkins/L0_Test.groovy @@ -3054,6 +3054,7 @@ def launchTestJobs(pipeline, testFilter) "DGX_H100-4_GPUs-PyTorch-DeepSeek-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4], "DGX_H100-4_GPUs-PyTorch-GptOss-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4], "DGX_H100-4_GPUs-PyTorch-Others-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4], + "DGX_H100-4_GPUs-PyTorch-Ray-1": ["dgx-h100-x4-oci", "l0_dgx_h100", 1, 1, 4], "B300-PyTorch-1": ["b300-single", "l0_b300", 1, 1], "DGX_B200-4_GPUs-PyTorch-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4], "DGX_B200-4_GPUs-PyTorch-Ray-1": ["b200-x4-lbd", "l0_dgx_b200", 1, 1, 4, 1, true], diff --git a/tensorrt_llm/executor/ray_gpu_worker.py b/tensorrt_llm/executor/ray_gpu_worker.py index 864d23d3af..b8a22af472 100644 --- a/tensorrt_llm/executor/ray_gpu_worker.py +++ b/tensorrt_llm/executor/ray_gpu_worker.py @@ -91,6 +91,8 @@ class RayWorkerWrapper: store=self.store, world_size=self.world_size, rank=self.rank) + assert torch.distributed.get_world_size( + ) == self.world_size, "Process group world size must match the expected world size" logger.info( f"[Rank {self.rank}] Finished PG init. Global GPU ID: {self.gpu}, local GPU ID: {self.local_gpu}" ) diff --git a/tests/integration/defs/conftest.py b/tests/integration/defs/conftest.py index 561d9bf327..92d593bf20 100644 --- a/tests/integration/defs/conftest.py +++ b/tests/integration/defs/conftest.py @@ -37,7 +37,6 @@ import tqdm import yaml from _pytest.mark import ParameterSet -from tensorrt_llm._utils import mpi_disabled from tensorrt_llm.bindings import ipc_nvls_supported from tensorrt_llm.llmapi.mpi_session import get_mpi_world_size @@ -2362,6 +2361,7 @@ def pytest_configure(config): tqdm.tqdm.monitor_interval = 0 if config.getoption("--run-ray"): os.environ["TLLM_DISABLE_MPI"] = "1" + os.environ["TLLM_RAY_FORCE_LOCAL_CLUSTER"] = "1" # Initialize PeriodicJUnitXML reporter if enabled periodic = config.getoption("--periodic-junit", default=False) @@ -2825,15 +2825,3 @@ def torch_empty_cache() -> None: gc.collect() torch.cuda.empty_cache() gc.collect() - - -@pytest.fixture(autouse=True) -def ray_cleanup(llm_venv) -> None: - yield - - if mpi_disabled(): - llm_venv.run_cmd([ - "-m", - "ray.scripts.scripts", - "stop", - ]) diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 72b42cd592..0267965f11 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -21,6 +21,12 @@ import tempfile from typing import Callable import pytest + +try: + import ray +except ImportError: + import tensorrt_llm.ray_stub as ray + import yaml from defs.common import (revise_disagg_config_file_with_free_ports, wait_for_server) @@ -30,7 +36,7 @@ from defs.trt_test_alternative import check_call, check_output, popen from test_common.perf_metrics_utils import (get_timing_metrics, validate_timing_metrics) -from tensorrt_llm._utils import get_free_port, mpi_disabled +from tensorrt_llm._utils import mpi_disabled from tensorrt_llm.logger import logger @@ -357,8 +363,6 @@ def run_disaggregated_test(example_dir, extra_config_files = [] workers_cmds = [] - subprocess.run(['ray', 'start', '--head', '--disable-usage-stats'], - check=True) # Generate ctx and gen server worker commands ctx_extra_config_file = get_extra_llm_config(config['context_servers'], @@ -415,6 +419,21 @@ def run_disaggregated_test(example_dir, use_ray=False) else: + runtime_env = { + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1" + } + } + ray.init(address="local", + include_dashboard=False, + ignore_reinit_error=True, + runtime_env=runtime_env) + gcs_addr = ray.get_runtime_context().gcs_address + ray_port = str(gcs_addr.split(":")[1]) + run_env.update({ + "RAY_ADDRESS": f"localhost:{ray_port}", + "TLLM_RAY_FORCE_LOCAL_CLUSTER": "0" + }) workers_proc = [] with contextlib.ExitStack() as stack: workers_log = stack.enter_context( @@ -470,16 +489,16 @@ def run_disaggregated_test(example_dir, logger.error(f.read()) raise finally: - if use_ray: - subprocess.run(['ray', 'stop', '--force'], check=False) - for extra_file in extra_config_files: - if os.path.exists(extra_file): - os.remove(extra_file) - elif 'server_proc' in locals() and 'workers_proc' in locals(): + if 'server_proc' in locals() and 'workers_proc' in locals(): server_proc.terminate() workers_proc.terminate() server_proc.wait() workers_proc.wait() + if use_ray: + ray.shutdown() + for extra_file in extra_config_files: + if os.path.exists(extra_file): + os.remove(extra_file) @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'], diff --git a/tests/integration/defs/examples/test_ray.py b/tests/integration/defs/examples/test_ray.py index 9df844d223..44743f030d 100644 --- a/tests/integration/defs/examples/test_ray.py +++ b/tests/integration/defs/examples/test_ray.py @@ -1,9 +1,15 @@ import os import subprocess +try: + import ray +except ImportError: + import tensorrt_llm.ray_stub as ray + import pytest from defs.common import venv_check_call, wait_for_server from defs.conftest import get_device_count, llm_models_root +from defs.trt_test_alternative import popen @pytest.fixture(scope="module") @@ -65,48 +71,57 @@ def test_ray_disaggregated_serving(ray_example_root, llm_venv, tp_size): disagg_dir = os.path.join(ray_example_root, "disaggregated") script_path = os.path.join(disagg_dir, "disagg_serving_local.sh") model_dir = f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0" - subprocess.run("ray stop --force", shell=True, check=False) - proc = subprocess.Popen( - [ - "bash", script_path, "--executor", "ray", "--model", model_dir, - "--tp_size", - str(tp_size) - ], - cwd=disagg_dir, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) try: - assert wait_for_server("localhost", 8000, timeout_seconds=180), \ - "Disaggregated server failed to start within 3 minutes" + runtime_env = { + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1" + } + } + ray.init(address="local", + include_dashboard=False, + ignore_reinit_error=True, + runtime_env=runtime_env) + gcs_addr = ray.get_runtime_context().gcs_address + ray_port = str(gcs_addr.split(":")[1]) - result = subprocess.run([ - "curl", "-sS", "-w", "\n%{http_code}", - "http://localhost:8000/v1/completions", "-H", - "Content-Type: application/json", "-d", - '{"model":"TinyLlama-1.1B-Chat-v1.0","prompt":"NVIDIA is a great company because","max_tokens":16,"temperature":0}' - ], - capture_output=True, - text=True, - timeout=30) + env_copy = os.environ.copy() + env_copy.update({ + "RAY_ADDRESS": f"localhost:{ray_port}", + "TLLM_RAY_FORCE_LOCAL_CLUSTER": "0" + }) + with popen( + [ + "bash", script_path, "--executor", "ray", "--attach", "--model", + model_dir, "--tp_size", + str(tp_size) + ], + cwd=disagg_dir, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env_copy, + ): + assert wait_for_server("localhost", 8000, timeout_seconds=180), \ + "Disaggregated server failed to start within 3 minutes" - *body_lines, status_line = result.stdout.strip().splitlines() - body = "\n".join(body_lines) - status = int(status_line) + result = subprocess.run([ + "curl", "-sS", "-w", "\n%{http_code}", + "http://localhost:8000/v1/completions", "-H", + "Content-Type: application/json", "-d", + '{"model":"TinyLlama-1.1B-Chat-v1.0","prompt":"NVIDIA is a great company because","max_tokens":16,"temperature":0}' + ], + capture_output=True, + text=True, + timeout=30) - print("HTTP status:", status) - print("Response body:", body) + *body_lines, status_line = result.stdout.strip().splitlines() + body = "\n".join(body_lines) + status = int(status_line) - assert result.returncode == 0, f"curl exit {result.returncode}" - assert status == 200, f"Expected 200, got {status}" + print("HTTP status:", status) + print("Response body:", body) + assert result.returncode == 0, f"curl exit {result.returncode}" + assert status == 200, f"Expected 200, got {status}" finally: - proc.terminate() - try: - proc.wait(timeout=10) - except Exception: - proc.kill() - - subprocess.run("ray stop --force", shell=True, check=False) - subprocess.run("pkill -9 -f trtllm-serve", shell=True, check=False) + ray.shutdown() diff --git a/tests/integration/defs/ray_orchestrator/RL/README.md b/tests/integration/defs/ray_orchestrator/RL/README.md new file mode 100644 index 0000000000..0ffce77f00 --- /dev/null +++ b/tests/integration/defs/ray_orchestrator/RL/README.md @@ -0,0 +1,59 @@ +# RL Framework Integration Tests + +This directory contains integration tests for TensorRT-LLM with [Ray orchestrator](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/ray_orchestrator), specifically designed to cover usage patterns from various RL (Reinforcement Learning) frameworks such as VeRL and NeMo RL. + +## Available Scripts + +| Script | Description | +|--------|-------------| +| `run_rl_perf_reproduce.py` | Emulates RL workload performance with multiple AsyncLLM instances distributed across GPUs using Ray placement groups | + +## Usage Examples + +### RL Performance Reproduction + +The `run_rl_perf_reproduce.py` script creates multiple TensorRT-LLM instances in parallel to simulate RL rollout workloads. + +**TP=4 with 2 instances (8 GPUs total):** + +```bash +python run_rl_perf_reproduce.py \ + --model_dir /path/to/model_dir \ + --data_path /path/to/prompts.json \ + --num_instances 2 \ + --tp_size 4 \ + --top_p 1 \ + --logprobs 1 \ + --max_batch_size 1024 \ + --enable_cuda_graph_padding +``` + +**TP=1 with 8 instances (8 GPUs total):** + +```bash +python run_rl_perf_reproduce.py \ + --model_dir /path/to/model_dir \ + --data_path /path/to/prompts.json \ + --num_instances 8 \ + --tp_size 1 \ + --top_p 1 \ + --logprobs 1 \ + --max_batch_size 384 \ + --enable_cuda_graph_padding +``` + +## Data Format + +The `--data_path` should point to a JSON file containing a list of prompts, where each prompt is a list of token IDs: + +```json +[ + [1, 2345, 6789, ...], + [1, 3456, 7890, ...], + ... +] +``` + +## Notes + +- RL Perf reproduction scripts support single-node execution only (max 8 GPUs) diff --git a/tests/integration/defs/ray_orchestrator/RL/run_rl_perf_reproduce.py b/tests/integration/defs/ray_orchestrator/RL/run_rl_perf_reproduce.py new file mode 100644 index 0000000000..fddc043861 --- /dev/null +++ b/tests/integration/defs/ray_orchestrator/RL/run_rl_perf_reproduce.py @@ -0,0 +1,336 @@ +import argparse +import asyncio +import json +import os +import time +from pathlib import Path + +import ray +import torch +from ray.util.placement_group import ( + PlacementGroupSchedulingStrategy, + placement_group, + remove_placement_group, +) +from transformers import AutoConfig + +from tensorrt_llm import AsyncLLM +from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SamplingParams + + +@ray.remote +class TRTLLMInstance: + """Ray actor wrapping an AsyncLLM instance for distributed RL workloads. + + This actor manages a single TensorRT-LLM instance that can be scheduled + on specific GPUs using Ray placement groups. Multiple instances can run + in parallel for high-throughput RL generation. + + Attributes: + async_llm_kwargs: Configuration dict for AsyncLLM initialization + sampling_kwargs: Configuration dict for SamplingParams + llm: The underlying AsyncLLM instance (initialized via init_llm) + sampling_params: SamplingParams object for generation + """ + + def __init__(self, async_llm_kwargs: dict, sampling_kwargs: dict): + self.async_llm_kwargs = async_llm_kwargs + self.sampling_kwargs = sampling_kwargs + self.llm = None + self.sampling_params = None + + async def init_llm(self): + """Initialize the AsyncLLM instance with configured parameters.""" + self.llm = await AsyncLLM( + model=self.async_llm_kwargs["model"], + backend="pytorch", + orchestrator_type=self.async_llm_kwargs["orchestrator_type"], + ray_worker_extension_cls=self.async_llm_kwargs["ray_worker_extension_cls"], + kv_cache_config=KvCacheConfig(**self.async_llm_kwargs["kv_cache_config"]), + cuda_graph_config=CudaGraphConfig(**self.async_llm_kwargs["cuda_graph_config"]), + max_seq_len=self.async_llm_kwargs["max_seq_len"], + max_batch_size=self.async_llm_kwargs["max_batch_size"], + max_num_tokens=self.async_llm_kwargs["max_num_tokens"], + tensor_parallel_size=self.async_llm_kwargs["tensor_parallel_size"], + trust_remote_code=self.async_llm_kwargs["trust_remote_code"], + enable_sleep=True, + sampler_type=self.async_llm_kwargs["sampler_type"], + placement_groups=self.async_llm_kwargs["placement_groups"], + placement_bundle_indices=self.async_llm_kwargs["placement_bundle_indices"], + per_worker_gpu_share=self.async_llm_kwargs["per_worker_gpu_share"], + batch_wait_timeout_iters=32, + batch_wait_max_tokens_ratio=0.5, + ) + self.sampling_params = SamplingParams( + temperature=self.sampling_kwargs["temperature"], + top_p=self.sampling_kwargs["top_p"], + top_k=self.sampling_kwargs["top_k"], + max_tokens=self.sampling_kwargs["max_tokens"], + logprobs=self.sampling_kwargs["logprobs"], + detokenize=self.sampling_kwargs["detokenize"], + end_id=self.sampling_kwargs["end_id"], + pad_id=self.sampling_kwargs["pad_id"], + stop_token_ids=self.sampling_kwargs["stop_token_ids"], + include_stop_str_in_output=self.sampling_kwargs["include_stop_str_in_output"], + ) + + async def generate(self, prompt: list[int]): + """Generate output tokens for a single prompt. + + Args: + prompt: List of input token IDs + + Returns: + Tuple of (token_ids, log_probs): + - token_ids: List of generated token IDs + - log_probs: List of log probabilities (if logprobs enabled, else None) + """ + outputs = await self.llm.generate_async(inputs=prompt, sampling_params=self.sampling_params) + token_ids = outputs.outputs[0].token_ids + log_probs = None + if self.sampling_kwargs["logprobs"] is not None: + log_probs = [list(d.values())[0].logprob for d in outputs.outputs[0].logprobs] + return token_ids, log_probs + + +async def setup_rl_llm(args): + """Main setup and execution function for RL LLM workloads. + + This function: + 1. Loads prompts from the input JSON file + 2. Initializes Ray with placement groups for GPU allocation + 3. Creates multiple TRTLLMInstance actors distributed across GPUs + 4. Distributes prompts round-robin across instances + 5. Runs async generation and reports throughput metrics + + Args: + args: Parsed command-line arguments + """ + # Load prompts from JSON file (expected format: list of token ID lists) + data_path = Path(args.data_path) + with open(data_path, "r") as f: + prompts = json.load(f) + + hf_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code) + + num_instances = args.num_instances + num_gpus = args.tp_size * num_instances + available_gpus = torch.cuda.device_count() + if num_gpus > 8: + raise ValueError( + f"Number of GPUs ({num_gpus}) is greater than 8. This script only supports single node." + ) + if available_gpus < num_gpus: + raise ValueError( + f"Number of GPUs ({available_gpus}) is less than number of GPUs required ({num_gpus})." + ) + + # Prevent Ray from setting CUDA_VISIBLE_DEVICES automatically + # This allows TensorRT-LLM to manage GPU visibility internally + os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" + runtime_env = {"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}} + pg = None + + try: + ray.init(address="local") + gcs_addr = ray.get_runtime_context().gcs_address + port = int(gcs_addr.split(":")[1]) + # Force ray.init("auto") to attach to a specific cluster via RAY_ADDRESS + os.environ["RAY_ADDRESS"] = f"localhost:{port}" + + # Create placement group with one bundle per GPU + # STRICT_PACK ensures all bundles are on the same node + pg = placement_group( + [{"GPU": 1, "CPU": 2} for _ in range(num_gpus)], strategy="STRICT_PACK" + ) + + # Wait for placement group to be ready + ray.get(pg.ready()) + + # Configure placement groups for each instance + # Each instance gets a contiguous range of GPU bundles for tensor parallelism + # Example with num_instances=2, tp_size=2: + # Instance 0: bundles [0, 1] -> GPUs 0, 1 + # Instance 1: bundles [2, 3] -> GPUs 2, 3 + tp_size = args.tp_size + placement_group_list = [[pg] for _ in range(num_instances)] + placement_bundle_indices_list = [ + [list(range(i * tp_size, (i + 1) * tp_size))] for i in range(num_instances) + ] + + # Create TRTLLMInstance actors for each parallel instance + llm_instances = [] + for i in range(num_instances): + llm_instances.append( + TRTLLMInstance.options( + num_cpus=0, + num_gpus=0, + runtime_env=runtime_env, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_capture_child_tasks=True, + ), + ).remote( + async_llm_kwargs={ + "model": args.model_dir, + "backend": "pytorch", + "orchestrator_type": "ray", + "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", + "kv_cache_config": { + "enable_block_reuse": args.enable_block_reuse, + "free_gpu_memory_fraction": args.kv_cache_fraction, + }, + "cuda_graph_config": { + "enable_padding": args.enable_cuda_graph_padding, + "batch_sizes": args.batch_sizes, + "max_batch_size": 0 if args.batch_sizes else args.max_batch_size, + }, + "max_seq_len": args.max_seq_len, + "max_batch_size": args.max_batch_size, + "max_num_tokens": args.max_num_tokens, + "tensor_parallel_size": args.tp_size, + "trust_remote_code": args.trust_remote_code, + "enable_sleep": True, + "sampler_type": args.sampler_type, + "placement_groups": placement_group_list[i], + "placement_bundle_indices": placement_bundle_indices_list[i], + "per_worker_gpu_share": 0.5, + }, + sampling_kwargs={ + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + "max_tokens": args.max_tokens, + "logprobs": args.logprobs, + "detokenize": False, + "end_id": -1, + "pad_id": hf_config.pad_token_id, + "stop_token_ids": [hf_config.eos_token_id], + "include_stop_str_in_output": True, + }, + ) + ) + # Wait for all Ray actors to be ready, then initialize LLM instances + ray.get([llm.__ray_ready__.remote() for llm in llm_instances]) + ray.get([llm.init_llm.remote() for llm in llm_instances]) + + total_prompts = len(prompts) + + print( + f"Starting generation for {total_prompts} prompts across {num_instances} instances..." + ) + start_time = time.time() + + # Helper function to wrap Ray remote call as async coroutine + async def generate_single_prompt(instance, prompt): + """Generate a single prompt asynchronously.""" + object_ref = instance.generate.remote(prompt=prompt) + result = await asyncio.to_thread(ray.get, object_ref) + return result + + # Create tasks with round-robin distribution + tasks = [ + generate_single_prompt(llm_instances[idx % num_instances], prompt) + for idx, prompt in enumerate(prompts) + ] + + await asyncio.gather(*tasks) + end_time = time.time() + + print(f"Time taken: {end_time - start_time:.2f} seconds") + print(f"Total prompts: {total_prompts}") + print(f"Throughput: {total_prompts / (end_time - start_time):.2f} prompts/sec") + finally: + if pg is not None: + remove_placement_group(pg) + ray.shutdown() + + +def add_rl_llm_args(parser): + """Add command-line arguments for RL LLM configuration.""" + # Required arguments + parser.add_argument("--model_dir", type=str, required=True, help="Model checkpoint directory.") + parser.add_argument( + "--data_path", + type=str, + required=True, + help="Input data file path, expected format: list of token ID lists.", + ) + parser.add_argument( + "--num_instances", type=int, required=True, help="Number of TRTLLM instances." + ) + + # AsyncLLM parameters + parser.add_argument("--tp_size", type=int, required=True, help="Tensor parallel size.") + parser.add_argument("--max_seq_len", type=int, default=2048, help="Maximum sequence length.") + parser.add_argument("--max_batch_size", type=int, default=384, help="Maximum batch size.") + parser.add_argument( + "--max_num_tokens", type=int, default=32768, help="Maximum number of tokens." + ) + parser.add_argument( + "--sampler_type", + type=str, + default="TRTLLMSampler", + choices=["TRTLLMSampler", "TorchSampler"], + help="Sampler type.", + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + default=False, + help="Whether to trust remote code.", + ) + + # KV Cache Config parameters + parser.add_argument( + "--kv_cache_fraction", + type=float, + default=0.6, + help="The fraction of GPU memory to be used for KV cache.", + ) + parser.add_argument( + "--enable_block_reuse", + action="store_true", + default=False, + help="Whether to enable block reuse for KV cache.", + ) + + # Cuda Graph Config parameters + parser.add_argument( + "--enable_cuda_graph_padding", + action="store_true", + default=False, + help="Whether to enable padding for CUDA graphs.", + ) + parser.add_argument( + "--batch_sizes", + type=int, + nargs="+", + default=None, + help="The batch sizes to be used for CUDA graphs. Example: --batch_sizes 16 32 64 128 256", + ) + + # Sampling parameters + parser.add_argument("--max_tokens", type=int, default=1024) + parser.add_argument("--temperature", type=float, default=1) + parser.add_argument("--top_k", type=int, default=None) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--logprobs", type=int, default=None) + + return parser + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="RL flow performance reproduction.") + parser = add_rl_llm_args(parser) + args = parser.parse_args() + return args + + +def main(): + args = parse_arguments() + asyncio.run(setup_rl_llm(args)) + + +if __name__ == "__main__": + main() diff --git a/tests/integration/defs/ray_orchestrator/RL/test_rl_perf_reproduce.py b/tests/integration/defs/ray_orchestrator/RL/test_rl_perf_reproduce.py new file mode 100644 index 0000000000..4cf6a83933 --- /dev/null +++ b/tests/integration/defs/ray_orchestrator/RL/test_rl_perf_reproduce.py @@ -0,0 +1,59 @@ +import json +import tempfile +from pathlib import Path + +import pytest +from defs.common import venv_check_call +from defs.conftest import integration_path, llm_models_root +from transformers import AutoTokenizer + + +@pytest.mark.skip_less_device(4) +@pytest.mark.parametrize( + "tp_size, num_instances", [(2, 2), (1, 4)], ids=["tp2_2instances", "tp1_4instances"] +) +def test_rl_perf_reproduce(llm_venv, tp_size, num_instances): + script_path = ( + integration_path() / "defs" / "ray_orchestrator" / "RL" / "run_rl_perf_reproduce.py" + ) + model_dir = f"{llm_models_root()}/Qwen2-7B-Instruct" + + if tp_size == 2: + max_batch_size = 512 + else: + max_batch_size = 256 + + with tempfile.TemporaryDirectory() as tmpdir: + prompt_text = "The president of the United States is" + + tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + token_ids = tokenizer.encode(prompt_text, add_special_tokens=False) + + # Replicate to create batch of 1024 prompts + batch_size = 1024 + prompts = [token_ids for _ in range(batch_size)] + + data_path = Path(tmpdir) / "prompts.json" + with open(data_path, "w") as f: + json.dump(prompts, f) + + venv_check_call( + llm_venv, + [ + str(script_path), + "--model_dir", + model_dir, + "--data_path", + str(data_path), + "--num_instances", + str(num_instances), + "--tp_size", + str(tp_size), + "--logprobs", + "1", + "--max_batch_size", + str(max_batch_size), + "--enable_block_reuse", + "--enable_cuda_graph_padding", + ], + ) diff --git a/tests/integration/defs/test_unittests.py b/tests/integration/defs/test_unittests.py index 190ea5111e..d2cd80692c 100644 --- a/tests/integration/defs/test_unittests.py +++ b/tests/integration/defs/test_unittests.py @@ -77,6 +77,7 @@ def test_unittests_v2(llm_root, llm_venv, case: str, output_dir, request): test_prefix = "unittest" waives_file = request.config.getoption("--waives-file") + run_ray = request.config.getoption("--run-ray") num_workers = 1 @@ -138,6 +139,9 @@ def test_unittests_v2(llm_root, llm_venv, case: str, output_dir, request): waives_file = os.path.abspath(waives_file) command += [f"--waives-file={waives_file}"] + if run_ray: + command += ["--run-ray"] + command += arg_list print(f"Running unit test:\"python {' '.join(command)}\"") diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index 4e90db0050..bcda6ba3f8 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -303,3 +303,5 @@ l0_dgx_h100: tests: - unittest/_torch/ray_orchestrator/multi_gpu -m "gpu4" - unittest/llmapi/test_async_llm.py -m "gpu4" + - ray_orchestrator/RL/test_rl_perf_reproduce.py::test_rl_perf_reproduce[tp2_2instances] + - ray_orchestrator/RL/test_rl_perf_reproduce.py::test_rl_perf_reproduce[tp1_4instances] diff --git a/tests/unittest/_torch/ray_orchestrator/conftest.py b/tests/unittest/_torch/ray_orchestrator/conftest.py index f6145aa98d..7d663d5fb7 100644 --- a/tests/unittest/_torch/ray_orchestrator/conftest.py +++ b/tests/unittest/_torch/ray_orchestrator/conftest.py @@ -3,11 +3,6 @@ import sys import pytest -try: - import ray -except ModuleNotFoundError: - from tensorrt_llm import ray_stub as ray - from tensorrt_llm._utils import mpi_disabled @@ -26,24 +21,3 @@ if not mpi_disabled(): pytest.skip( "Ray tests are only tested in Ray CI stage or with --run-ray flag", allow_module_level=True) - - -@pytest.fixture(scope="function") -def setup_ray_cluster(): - runtime_env = { - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1" - } - } - ray_init_args = { - "include_dashboard": False, - "namespace": "test", - "ignore_reinit_error": True, - "runtime_env": runtime_env - } - try: - ray.init(address="local", **ray_init_args) - yield - finally: - if ray.is_initialized(): - ray.shutdown() diff --git a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py index 578be1f6dd..bd5cb47c0a 100644 --- a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py +++ b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_executor.py @@ -25,12 +25,12 @@ def test_worker_extension(): @pytest.mark.gpu4 -def test_placement_env_vars(monkeypatch): +def test_placement_env_vars(setup_ray_cluster, monkeypatch): + port = setup_ray_cluster monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1") pg = None try: - ray.init() pg = placement_group([{"GPU": 1, "CPU": 1}] * 4) ray.get(pg.ready()) print(f"Placement group ready with bundles {pg.bundle_specs}") @@ -39,7 +39,9 @@ def test_placement_env_vars(monkeypatch): runtime_env = { "env_vars": { "TRTLLM_RAY_PER_WORKER_GPUS": "0.8", - "TRTLLM_RAY_BUNDLE_INDICES": ",".join(map(str, bundle_indices)) + "TRTLLM_RAY_BUNDLE_INDICES": ",".join(map(str, bundle_indices)), + "RAY_ADDRESS": f"localhost:{port}", + "TLLM_RAY_FORCE_LOCAL_CLUSTER": "0" } } @@ -70,7 +72,6 @@ def test_placement_env_vars(monkeypatch): finally: if pg is not None: remove_placement_group(pg) - ray.shutdown() @pytest.mark.gpu2 @@ -79,13 +80,14 @@ def test_placement_env_vars(monkeypatch): (2, [1]), ], ids=["gpu2_tp1"]) -def test_placement_api(monkeypatch, n_gpus, bundle_indices): - monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1") +def test_placement_api(setup_ray_cluster, monkeypatch, n_gpus, bundle_indices): + port = setup_ray_cluster + monkeypatch.setenv("RAY_ADDRESS", f"localhost:{port}") + monkeypatch.setenv("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") tp_size = n_gpus // 2 pg = None try: - ray.init() pg = placement_group([{"GPU": 1, "CPU": 1}] * n_gpus) ray.get(pg.ready()) print(f"Placement group ready with bundles {pg.bundle_specs}") @@ -116,7 +118,6 @@ def test_placement_api(monkeypatch, n_gpus, bundle_indices): finally: if pg is not None: remove_placement_group(pg) - ray.shutdown() @pytest.mark.gpu2 diff --git a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_multi_instance.py b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_multi_instance.py new file mode 100644 index 0000000000..7018aca856 --- /dev/null +++ b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_multi_instance.py @@ -0,0 +1,120 @@ +import os + +import pytest +import ray +import torch +from ray.util.placement_group import ( + PlacementGroupSchedulingStrategy, + placement_group, + remove_placement_group, +) +from utils.llm_data import llm_models_root + +from tensorrt_llm import AsyncLLM +from tensorrt_llm.llmapi import KvCacheConfig + + +@ray.remote +class TRTLLMInstance: + def __init__(self, async_llm_kwargs: dict): + self.llm = AsyncLLM( + model=async_llm_kwargs["model"], + backend="pytorch", + orchestrator_type=async_llm_kwargs["orchestrator_type"], + kv_cache_config=KvCacheConfig(**async_llm_kwargs["kv_cache_config"]), + tensor_parallel_size=async_llm_kwargs["tensor_parallel_size"], + placement_groups=async_llm_kwargs["placement_groups"], + placement_bundle_indices=async_llm_kwargs["placement_bundle_indices"], + per_worker_gpu_share=async_llm_kwargs["per_worker_gpu_share"], + ) + + async def init_llm(self): + await self.llm.setup_async() + + def shutdown_llm(self): + self.llm.shutdown() + self.llm = None + + +@pytest.mark.gpu4 +@pytest.mark.parametrize( + "tp_size, num_instances", [(2, 2), (1, 4)], ids=["tp2_2instances", "tp1_4instances"] +) +def test_multi_instance(setup_ray_cluster, tp_size, num_instances): + """Test that multiple TRTLLMInstance actors can be started without port conflicts. + + This test guards against port conflict failures when launching multiple + TensorRT-LLM instances concurrently. It runs multiple iterations to ensure + reliable instance creation and teardown. + """ + port = setup_ray_cluster + num_gpus = tp_size * num_instances + available_gpus = torch.cuda.device_count() + if num_gpus > 8: + raise ValueError( + f"Number of GPUs ({num_gpus}) is greater than 8. This script only supports single node." + ) + if available_gpus < num_gpus: + raise ValueError( + f"Number of GPUs ({available_gpus}) is less than number of GPUs required ({num_gpus})." + ) + runtime_env = ray.runtime_env.RuntimeEnv() + runtime_env["env_vars"] = os.environ.copy() + runtime_env["env_vars"].update( + { + "TLLM_RAY_FORCE_LOCAL_CLUSTER": "0", + "RAY_ADDRESS": f"localhost:{port}", + } + ) + + # Run multiple iterations to guard against port conflict issues + execution_times = 5 + for iteration in range(execution_times): + pg = None + llm_instances = [] + try: + pg = placement_group( + [{"GPU": 1, "CPU": 2} for _ in range(num_gpus)], strategy="STRICT_PACK" + ) + + ray.get(pg.ready()) + + placement_group_list = [[pg] for _ in range(num_instances)] + placement_bundle_indices_list = [ + [list(range(i * tp_size, (i + 1) * tp_size))] for i in range(num_instances) + ] + + for i in range(num_instances): + llm_instances.append( + TRTLLMInstance.options( + num_cpus=0, + num_gpus=0, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_capture_child_tasks=True, + ), + runtime_env=runtime_env, + ).remote( + async_llm_kwargs={ + "model": os.path.join( + llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0" + ), + "kv_cache_config": { + "free_gpu_memory_fraction": 0.1, + }, + "tensor_parallel_size": tp_size, + "orchestrator_type": "ray", + "placement_groups": placement_group_list[i], + "placement_bundle_indices": placement_bundle_indices_list[i], + "per_worker_gpu_share": 0.5, + } + ) + ) + ray.get([llm.__ray_ready__.remote() for llm in llm_instances]) + ray.get([llm.init_llm.remote() for llm in llm_instances]) + finally: + # Clean up actors before removing placement group + for llm in llm_instances: + ray.get(llm.shutdown_llm.remote()) + if pg is not None: + remove_placement_group(pg) diff --git a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py index 42fa65ffc2..18ec8c81cf 100644 --- a/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py +++ b/tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py @@ -1,5 +1,6 @@ import os from operator import attrgetter +from typing import Optional import pytest import torch @@ -10,7 +11,6 @@ except ModuleNotFoundError: from tensorrt_llm import ray_stub as ray from tensorrt_llm._torch.distributed.communicator import TorchDist -from tensorrt_llm._utils import get_free_port from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy from tensorrt_llm.mapping import Mapping @@ -22,7 +22,6 @@ class PgOpTest: self.rank = rank self.world_size = world_size self.master_address = os.environ["MASTER_ADDR"] - self.master_port = os.environ["MASTER_PORT"] assert len(ray.get_gpu_ids()) == 1 self.gpu = int(ray.get_gpu_ids()[0]) @@ -30,12 +29,30 @@ class PgOpTest: local_gpu = RayWorkerWrapper.physical_to_local_id(self.gpu) torch.cuda.set_device(local_gpu) - torch.distributed.init_process_group( - backend="cuda:nccl,cpu:gloo", - init_method=f"tcp://{self.master_address}:{self.master_port}", - world_size=world_size, - rank=rank) + def _create_tcp_store(self, + port: Optional[int] = None + ) -> torch.distributed.TCPStore: + actual_port = port if port is not None else 0 + return torch.distributed.TCPStore(host_name=self.master_address, + port=actual_port, + world_size=self.world_size, + is_master=(self.rank == 0), + wait_for_workers=False) + def setup_tcp_store(self): + if self.rank != 0: + raise RuntimeError("Only the master worker can setup TCP store") + self.store = self._create_tcp_store() + return self.store.port + + def setup_distributed_env(self, port: int): + if self.rank != 0: + self.store = self._create_tcp_store(port) + + torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo", + store=self.store, + world_size=self.world_size, + rank=self.rank) self.mapping = Mapping(world_size=self.world_size, gpus_per_node=self.world_size, tp_size=self.world_size, @@ -91,13 +108,11 @@ def test_allgather_pg_op(setup_ray_cluster, seq_len, hidden_size, var_len): sizes = None remotePGTests = [] - master_port = get_free_port() runtime_env = ray.runtime_env.RuntimeEnv() runtime_env["env_vars"] = os.environ.copy() runtime_env["env_vars"].update({ "TLLM_DISABLE_MPI": "1", "MASTER_ADDR": "127.0.0.1", - "MASTER_PORT": str(master_port) }) for rank in range(world_size): @@ -107,6 +122,12 @@ def test_allgather_pg_op(setup_ray_cluster, seq_len, hidden_size, var_len): ray.get( [remotePGTest.__ray_ready__.remote() for remotePGTest in remotePGTests]) + port = ray.get(remotePGTests[0].setup_tcp_store.remote()) + ray.get([ + remotePGTest.setup_distributed_env.remote(port) + for remotePGTest in remotePGTests + ]) + if var_len: results = ray.get([ remotePGTest.run.remote("allgather_pg", @@ -155,13 +176,11 @@ def test_reducescatter_pg_op(setup_ray_cluster, seq_len, hidden_size, var_len): ] sizes = None - master_port = get_free_port() runtime_env = ray.runtime_env.RuntimeEnv() runtime_env["env_vars"] = os.environ.copy() runtime_env["env_vars"].update({ "TLLM_DISABLE_MPI": "1", "MASTER_ADDR": "127.0.0.1", - "MASTER_PORT": str(master_port) }) remotePGTests = [] @@ -172,6 +191,12 @@ def test_reducescatter_pg_op(setup_ray_cluster, seq_len, hidden_size, var_len): ray.get( [remotePGTest.__ray_ready__.remote() for remotePGTest in remotePGTests]) + port = ray.get(remotePGTests[0].setup_tcp_store.remote()) + ray.get([ + remotePGTest.setup_distributed_env.remote(port) + for remotePGTest in remotePGTests + ]) + results = ray.get([ remotePGTest.run.remote("reducescatter_pg", test_tensor, @@ -194,13 +219,11 @@ def test_allreduce_pg_op(setup_ray_cluster, seq_len, hidden_size): test_tensor = torch.randn((seq_len, hidden_size), dtype=dtype) expected_result = test_tensor * world_size - master_port = get_free_port() runtime_env = ray.runtime_env.RuntimeEnv() runtime_env["env_vars"] = os.environ.copy() runtime_env["env_vars"].update({ "TLLM_DISABLE_MPI": "1", "MASTER_ADDR": "127.0.0.1", - "MASTER_PORT": str(master_port) }) remotePGTests = [] @@ -211,6 +234,12 @@ def test_allreduce_pg_op(setup_ray_cluster, seq_len, hidden_size): ray.get( [remotePGTest.__ray_ready__.remote() for remotePGTest in remotePGTests]) + port = ray.get(remotePGTests[0].setup_tcp_store.remote()) + ray.get([ + remotePGTest.setup_distributed_env.remote(port) + for remotePGTest in remotePGTests + ]) + results = ray.get([ remotePGTest.run.remote("allreduce_pg", test_tensor, diff --git a/tests/unittest/conftest.py b/tests/unittest/conftest.py index 97c48b6b96..2c75d30023 100644 --- a/tests/unittest/conftest.py +++ b/tests/unittest/conftest.py @@ -18,7 +18,12 @@ import sys import traceback import warnings from functools import partial -from typing import Any +from typing import Any, Generator + +try: + import ray +except ModuleNotFoundError: + from tensorrt_llm import ray_stub as ray import _pytest.outcomes import pytest @@ -344,3 +349,26 @@ def process_gpu_memory_info_available(): return False return True + + +@pytest.fixture(scope="function") +def setup_ray_cluster() -> Generator[int, None, None]: + runtime_env = { + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1" + } + } + ray_init_args = { + "include_dashboard": False, + "namespace": "test", + "ignore_reinit_error": True, + "runtime_env": runtime_env + } + try: + ray.init(address="local", **ray_init_args) + gcs_addr = ray.get_runtime_context().gcs_address + port = int(gcs_addr.split(":")[1]) + yield port + finally: + if ray.is_initialized(): + ray.shutdown() diff --git a/tests/unittest/llmapi/test_async_llm.py b/tests/unittest/llmapi/test_async_llm.py index e0e7dd6d0f..abd84976af 100644 --- a/tests/unittest/llmapi/test_async_llm.py +++ b/tests/unittest/llmapi/test_async_llm.py @@ -96,9 +96,10 @@ async def test_async_llm_release_resume(process_gpu_memory_info_available, num_c @pytest.mark.ray @pytest.mark.gpu4 @pytest.mark.asyncio -@pytest.mark.threadleak(enabled=False) -async def test_async_llm_placement_api(monkeypatch): - monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1") +async def test_async_llm_placement_api(setup_ray_cluster, monkeypatch): + port = setup_ray_cluster + monkeypatch.setenv("RAY_ADDRESS", f"localhost:{port}") + monkeypatch.setenv("TLLM_RAY_FORCE_LOCAL_CLUSTER", "0") n_gpus = 4 bundle_indices = [2, 3] @@ -106,7 +107,6 @@ async def test_async_llm_placement_api(monkeypatch): pg = None try: - ray.init() pg = placement_group([{"GPU": 1, "CPU": 1}] * n_gpus) ray.get(pg.ready()) print(f"Placement group ready with bundles {pg.bundle_specs}") @@ -132,6 +132,6 @@ async def test_async_llm_placement_api(monkeypatch): ) finally: + llm.shutdown() if pg is not None: remove_placement_group(pg) - ray.shutdown()