[TRTLLM-9736][feat] AsyncLLM and verl integ (#9353)

Signed-off-by: Liwei Ma <liweim@nvidia.com>
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Co-authored-by: Liwei Ma <liweim@nvidia.com>
Co-authored-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
Co-authored-by: Superjomn <328693+Superjomn@users.noreply.github.com>
Erin 2025-12-11 09:33:25 -08:00 committed by GitHub
parent 02edb19f43
commit 89dabf5aa1
17 changed files with 629 additions and 70 deletions


@ -84,7 +84,7 @@ from ._utils import (default_gpus_per_node, local_mpi_rank, local_mpi_size,
from .builder import BuildConfig, Builder, BuilderConfig, build
from .disaggregated_params import DisaggregatedParams
from .functional import Tensor, constant
from .llmapi import LLM, MultimodalEncoder
from .llmapi import LLM, AsyncLLM, MultimodalEncoder
from .llmapi.llm_args import LlmArgs, TorchLlmArgs, TrtLlmArgs
from .logger import logger
from .mapping import Mapping
@ -136,6 +136,7 @@ __all__ = [
'quantization',
'tools',
'LLM',
'AsyncLLM',
'MultimodalEncoder',
'LlmArgs',
'TorchLlmArgs',


@ -0,0 +1,106 @@
from typing import Any, List, Optional
from ..llmapi.llm import LLM
from ..llmapi.llm_args import RayPlacementConfig
class AsyncLLM(LLM):
"""AsyncLLM is a subclass of LLM that supports asynchronous setup, release and
resume operations that are necessary for RL or agentic scenarios.
Currently, RL APIs are only supported with Ray orchestrator.
"""
def __init__(
self,
placement_groups: Optional[List[Any]] = None,
placement_bundle_indices: Optional[List[List[int]]] = None,
per_worker_gpu_share: Optional[float] = None,
*args,
**kwargs,
):
kwargs["orchestrator_type"] = "ray"
kwargs["ray_placement_config"] = RayPlacementConfig(
defer_workers_init=True,
placement_groups=placement_groups,
placement_bundle_indices=placement_bundle_indices,
per_worker_gpu_share=per_worker_gpu_share,
)
# WAR: RL integration needs to use NCCL AllReduce for TP>1 due to a bug in TRTLLM's AllReduce
# which causes convergence issues when using multiple rollout instances.
kwargs["allreduce_strategy"] = "NCCL"
if "ray_worker_extension_cls" not in kwargs:
kwargs["ray_worker_extension_cls"] = "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension"
super().__init__(*args, **kwargs)
self._async_initialized = False
async def setup_async(self):
"""Setup the LLM asynchronously."""
if not self._async_initialized:
await self._executor.init_workers_async()
await self._executor.setup_engine_remote_async()
self._async_initialized = True
return self
async def release(self, tags: list[str]):
"""Release the GPU memory used by the LLM asynchronously.
Args:
tags: List of memory tag strings to release (e.g., ["model", "kv_cache"]).
"""
await self.collective_rpc("sleep", args=(tags,))
async def resume(self, tags: list[str]):
"""Resume the GPU memory used by the LLM asynchronously.
Args:
tags: List of memory tag strings to resume (e.g., ["model", "kv_cache"]).
"""
await self.collective_rpc("wakeup", args=(tags,))
async def update_weights(self, weights: dict[str, str]):
"""Update the weights of the LLM asynchronously.
Args:
weights: Dictionary mapping device UUIDs to IPC handles for weight tensors.
"""
await self.collective_rpc("update_weights", args=(weights,))
async def collective_rpc(
self,
method: str,
args: tuple[Any, ...] = (),
kwargs: Optional[dict] = None,
unique_reply_rank: Optional[int] = None,
) -> list[Any]:
"""Execute an asynchronous RPC call on all GPU workers. Currently, this is only supported for RayExecutor.
Args:
method (str): The name of the worker method to execute.
args (tuple[Any, ...]): Positional arguments to pass to the worker method. Defaults to ().
kwargs (dict, optional): Keyword arguments to pass to the worker method. Defaults to None.
unique_reply_rank (int, optional): The rank of the worker that will be used to send the reply.
Returns:
list[Any]: A list of results from each worker.
"""
return await self._executor.collective_rpc_async(
method, args, kwargs, unique_reply_rank=unique_reply_rank
)
def __await__(self):
return self.setup_async().__await__()
def __enter__(self):
raise RuntimeError("Please use 'async with AsyncLLM' instead")
async def __aenter__(self):
await self.setup_async()
return super().__enter__()
async def __aexit__(self, exc_type, exc_val, exc_tb):
return super().__exit__(exc_type, exc_val, exc_tb)
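For orientation, here is a minimal lifecycle sketch for the new class, modeled on the tests added later in this PR. The model path, tensor-parallel size, and memory tags are illustrative, and the sketch assumes a Ray cluster with enough GPUs is available; enable_sleep is turned on because release/resume rely on it in those tests:

import asyncio

from tensorrt_llm import AsyncLLM
from tensorrt_llm.llmapi import SamplingParams


async def main():
    # "async with" runs setup_async() on entry and tears the engine down on exit.
    async with AsyncLLM(model="/path/to/TinyLlama-1.1B-Chat-v1.0",
                        tensor_parallel_size=2,
                        enable_sleep=True) as llm:
        out = await llm.generate_async("The future of AI is",
                                       SamplingParams(max_tokens=16))
        print(out.outputs[0].text)

        # Release GPU memory between rollout phases, then bring it back.
        await llm.release(["model", "kv_cache"])
        await llm.resume(["model", "kv_cache"])


if __name__ == "__main__":
    asyncio.run(main())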


@ -3010,7 +3010,7 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
new_tokens_host = state.host.new_tokens.flatten().tolist()
sequence_lengths_host_data = state.host.sequence_lengths.flatten().tolist()
finish_reasons = state.host.finish_reasons.flatten().tolist()
log_probs_host = state.host.log_probs.tolist() if state.host.log_probs is not None else None
log_probs_host_tensor = state.host.log_probs
cum_log_probs_host = (
state.host.cum_log_probs.tolist() if state.host.cum_log_probs is not None else None
)
@ -3032,24 +3032,31 @@ class TRTLLMSampler(Sampler, AsyncWorkerMixin):
add_new_tokens_to_requests(reqs_with_new_tokens, new_tokens, 0)
# Log probs
for request in reqs_with_new_tokens:
if request.py_return_log_probs:
seq_slot = request.py_seq_slot
seq_len = sequence_lengths_host_data[seq_slot]
begin_log_probs_offset = request.prompt_len
current_token = seq_len - request.prompt_len - 1
log_probs = [
{
new_tokens_host[seq_slot]: Logprob(
logprob=log_probs_host[seq_slot][0][
begin_log_probs_offset + current_token
],
rank=1,
)
}
]
cum_log_probs = [cum_log_probs_host[seq_slot]]
request.py_result.append_log_probs([log_probs], cum_log_probs)
if log_probs_host_tensor is not None:
# Log probs
seq_slots = []
seq_lens = []
for request in reqs_with_new_tokens:
if request.py_return_log_probs:
seq_slot = request.py_seq_slot
seq_slots.append(seq_slot)
seq_lens.append(sequence_lengths_host_data[seq_slot] - 1)
log_probs_host = log_probs_host_tensor[seq_slots, 0, seq_lens].tolist()
idx = 0
for request in reqs_with_new_tokens:
if request.py_return_log_probs:
log_probs = [
{
new_tokens_host[seq_slot]: Logprob(
logprob=log_probs_host[idx],
rank=1,
)
}
]
cum_log_probs = [cum_log_probs_host[seq_slot]]
request.py_result.append_log_probs([log_probs], cum_log_probs)
idx += 1
for request in reqs:
request.py_decoding_iter += 1
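As a standalone illustration of the batched gather introduced above (the tensor shape, slot indices, and lengths are made up), a single advanced-indexing lookup replaces converting the whole log-probs tensor to Python lists and indexing it per request:

import torch

log_probs = torch.randn(8, 1, 16)   # [seq_slot, beam, token]
seq_slots = [2, 5]                  # slots of requests that return log probs
seq_lens = [4, 9]                   # sequence_length - 1 for each of those requests

# One gather over the tensor ...
batched = log_probs[seq_slots, 0, seq_lens].tolist()

# ... matches the old per-request Python indexing.
reference = [log_probs[s, 0, t].item() for s, t in zip(seq_slots, seq_lens)]
assert all(abs(a - b) < 1e-9 for a, b in zip(batched, reference))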


@ -74,7 +74,8 @@ class ExecutorMemoryType(StrEnum):
SPEC_RESOURCES = "spec_resource_manager"
INIT_KV_CACHE = "_no_capture_init_kv_cache"
INIT_EXTRA_RESOURCES = "_no_capture_init_extra_resources"
MODEL_EXTRA = "_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache()
# MODEL_EXTRA = "_no_capture_model_extra" # TODO: remove _no_capture after torch fix crash on torch.cuda.empty_cache()
MODEL_EXTRA = "model_extra"
EXTRA_RESOURCES = "executor_extra"
KV_CACHE = "kv_cache"
MODEL_ENGINE_MAIN = "model"


@ -1,3 +1,4 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple
@ -7,8 +8,7 @@ except ModuleNotFoundError as e:
e.msg = """Cannot import Ray. Please install 'ray' package to use ray orchestrator"""
raise
from ray.util.placement_group import (PlacementGroup,
PlacementGroupSchedulingStrategy,
from ray.util.placement_group import (PlacementGroupSchedulingStrategy,
get_current_placement_group,
placement_group)
@ -23,6 +23,7 @@ from .ray_gpu_worker import RayGPUWorker, RayWorkerWrapper
from .request import GenerationRequest
from .result import GenerationResult
from .rpc_proxy_mixin import RpcExecutorMixin
from .utils import has_event_loop
__all__ = [
"RayExecutor",
@ -77,19 +78,30 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
self.master_address = ray.util.get_node_ip_address()
self.master_port = get_free_port()
worker_kwargs = dict(**worker_kwargs,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
self.worker_kwargs = dict(
**worker_kwargs,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor)
self.init_rpc_executor()
# Inject the generated HMAC key into worker_kwargs for workers
worker_kwargs['hmac_key'] = self.hmac_key
worker_kwargs['rpc_addr'] = self.rpc_addr
self.create_workers(RayGPUWorker, worker_kwargs)
self.setup_engine_remote()
self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
thread_name="ray_executor_main_loop")
logger.info(f"Connecting to RPC server at {self.rpc_addr}")
self.worker_kwargs['hmac_key'] = self.hmac_key
self.worker_kwargs['rpc_addr'] = self.rpc_addr
placement_config = getattr(self.worker_kwargs['llm_args'],
'ray_placement_config', None)
defer_workers_init = placement_config.defer_workers_init if placement_config else False
if defer_workers_init:
self.workers = [
] # Placeholder, will be initialized in setup_async
self._mainloop_started = False # DO NOT start mainloop until after setup_engine_remote_async is called
else:
if not has_event_loop():
self.init_workers_sync()
self.setup_engine_remote()
self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
thread_name="ray_executor_main_loop")
except Exception as e:
self.shutdown()
@ -97,9 +109,16 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
raise e
def create_workers(self, worker_cls, worker_kwargs):
llm_args = worker_kwargs.get("llm_args")
placement_config = getattr(llm_args, 'ray_placement_config',
None) if llm_args else None
# When set to a fraction, this allows Ray to schedule
# multiple actors on a single GPU for colocation use cases.
num_gpus = float(os.getenv("TRTLLM_RAY_PER_WORKER_GPUS", "1.0"))
if placement_config and placement_config.per_worker_gpu_share is not None:
num_gpus = placement_config.per_worker_gpu_share
logger.debug(f"{num_gpus=} for each worker.")
runtime_env = ray.runtime_env.RuntimeEnv()
@ -110,28 +129,40 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
"MASTER_PORT": str(self.master_port)
})
self.placement_group, self.bundle_indices = self._get_placement_group(
tp_size=self.tp_size)
placement_groups, self.bundle_indices = self._get_placement_group(
tp_size=self.tp_size, worker_kwargs=worker_kwargs)
self.workers = [
RayWorkerWrapper.options(
if isinstance(placement_groups, list):
self.placement_group = None
else:
self.placement_group = placement_groups
self.workers = []
for rank in range(self.world_size):
pg = placement_groups[rank] if isinstance(
placement_groups, list) else placement_groups
worker = RayWorkerWrapper.options(
num_gpus=num_gpus,
runtime_env=runtime_env, # per-actor env
runtime_env=runtime_env,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=self.placement_group,
placement_group=pg,
placement_group_bundle_index=self.bundle_indices[rank],
)).remote(worker_cls, worker_kwargs, self.world_size, rank)
for rank in range(self.world_size)
]
self.workers.append(worker)
def init_workers_sync(self):
self.create_workers(RayGPUWorker, self.worker_kwargs)
try:
ray.get([worker.__ray_ready__.remote() for worker in self.workers])
ray.get(self._get_worker_ready_futures())
except ray.exceptions.ActorDiedError as e:
if "The actor died because of an error raised in its creation task" in str(
e):
raise RuntimeError(
"RayGPUWorker died during initialization") from e
raise
raise RuntimeError("RayGPUWorker died during initialization") from e
async def init_workers_async(self):
self.create_workers(RayGPUWorker, self.worker_kwargs)
try:
await asyncio.gather(*self._get_worker_ready_futures())
except ray.exceptions.ActorDiedError as e:
raise RuntimeError("RayGPUWorker died during initialization") from e
@unwrap_ray_errors()
def call_all_ray_workers(self, func: str, leader_only: bool,
@ -171,6 +202,20 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
**kwargs))
return refs if non_block else ray.get(refs)
@unwrap_ray_errors()
async def collective_rpc_async(
self,
method: str,
args: tuple = (),
kwargs: Optional[dict] = None,
unique_reply_rank: Optional[int] = None) -> list[Any]:
refs = self.collective_rpc(method,
args,
kwargs,
non_block=True,
unique_reply_rank=unique_reply_rank)
return await asyncio.gather(*refs)
def submit(self, request: "GenerationRequest") -> "GenerationResult":
"""
Low-level API to the executor. Return a "future" GenerationResult
@ -198,6 +243,26 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
def setup_engine_remote(self):
return self.collective_rpc("setup_engine", non_block=False)
async def setup_engine_remote_async(self):
"""Async version of setup_engine_remote for use after async worker initialization."""
if not self.workers or len(self.workers) == 0:
raise RuntimeError(
"Workers must be initialized before calling setup_engine_remote_async"
)
# Setup engine on all workers
result = await self.collective_rpc_async("setup_engine")
logger.info("setup_engine_remote_async finished")
# Now that engine is set up, start the mainloop for fetching responses
if hasattr(self, '_mainloop_started') and not self._mainloop_started:
logger.info("Starting mainloop after engine setup")
self.setup_mainloop(tasks=[self._fetch_responses_loop_async],
thread_name="ray_executor_main_loop")
self._mainloop_started = True
return result
def report_device_ids(self) -> list[str]:
gpu_ids = self.call_all_ray_workers("report_device_id",
leader_only=False,
@ -265,15 +330,52 @@ class RayExecutor(RpcExecutorMixin, GenerationExecutor):
logger.debug("Shutting down Ray cluster")
ray.shutdown()
def _get_placement_group(self,
tp_size: int) -> Tuple[PlacementGroup, List[int]]:
def _get_worker_ready_futures(self):
return [worker.__ray_ready__.remote() for worker in self.workers]
def _get_placement_group(
self,
tp_size: int,
worker_kwargs: Optional[Dict] = None) -> Tuple[Any, List[int]]:
"""
Either use the existing placement group(s) from the driver script (e.g., in the case of RL framework integration),
or create a default PACK placement group where each bundle has tp_size GPUs.
- When tp_size <= GPUs per node, keep one TP group per node.
- When tp_size > GPUs per node, allow a TP group to span nodes.
- Rank 0 must be placed on the driver node.
Returns:
Tuple of (placement_group(s), bundle_indices)
- placement_group(s) can be a single PlacementGroup or a List[PlacementGroup]
- bundle_indices is always a List[int]
"""
llm_args = worker_kwargs.get("llm_args") if worker_kwargs else None
placement_config = getattr(llm_args, 'ray_placement_config',
None) if llm_args else None
if placement_config and placement_config.placement_groups is not None:
total_workers = sum(
len(indices)
for indices in placement_config.placement_bundle_indices)
if total_workers != self.world_size:
raise ValueError(
f"Total bundle indices ({total_workers}) must equal world_size ({self.world_size})"
)
logger.info(
f"Creating {self.world_size} workers with external placement groups"
)
flat_pgs = []
flat_indices = []
for pg, indices in zip(placement_config.placement_groups,
placement_config.placement_bundle_indices):
for idx in indices:
flat_pgs.append(pg)
flat_indices.append(idx)
return flat_pgs, flat_indices
bundle_indices = os.getenv("TRTLLM_RAY_BUNDLE_INDICES", None)
if bundle_indices:


@ -1,3 +1,4 @@
import gc
import importlib
import os
from pathlib import Path
@ -43,7 +44,6 @@ class RayWorkerWrapper:
def __init__(self, worker_cls, worker_kwargs, world_size, rank):
self.master_address = os.environ["MASTER_ADDR"]
self.master_port = os.environ["MASTER_PORT"]
# Ray can't pickle TensorRT logger
global logger
from tensorrt_llm.logger import logger
@ -218,6 +218,8 @@ class RayGPUWorker(RpcWorkerMixin, BaseWorker):
torch.cuda.synchronize()
release_with_tag(*tags)
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
except Exception as e:
logger.error(f"Encountered an error in sleep: {e}")
raise e


@ -1,3 +1,4 @@
from .._torch.async_llm import AsyncLLM
from ..disaggregated_params import DisaggregatedParams
from ..executor import CompletionOutput, LoRARequest, RequestError
from ..sampling_params import GuidedDecodingParams, SamplingParams
@ -23,6 +24,7 @@ from .mpi_session import MpiCommSession
__all__ = [
'LLM',
'AsyncLLM',
'MultimodalEncoder',
'CompletionOutput',
'RequestOutput',


@ -193,7 +193,7 @@ class BaseLLM:
self.mpi_session = self.args.mpi_session
if self.args.parallel_config.is_multi_gpu:
if get_device_count(
if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count(
) < self.args.parallel_config.world_size_per_node:
raise RuntimeError(
f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required."
@ -229,7 +229,6 @@ class BaseLLM:
self.runtime_context: Optional[_ModelRuntimeContext] = None
self.llm_build_stats = LlmBuildStats()
self._build_model()
except Exception:


@ -19,6 +19,11 @@ from pydantic import PrivateAttr, field_validator, model_validator
from strenum import StrEnum
from transformers import PreTrainedTokenizerBase
try:
from ray.util.placement_group import PlacementGroup
except ImportError:
PlacementGroup = None
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
@ -1086,6 +1091,65 @@ class AutoDecodingConfig(DecodingBaseConfig):
return backend == "pytorch"
class RayPlacementConfig(StrictBaseModel):
"""
Configuration for Ray GPU worker placement.
This config is only used with AsyncLLM for RL scenarios.
"""
defer_workers_init: bool = Field(
default=False,
description="Defer Ray worker initialization until async setup.")
placement_groups: Optional[List[Any]] = Field(
default=None,
description="List of Ray placement groups, one per node. "
"Each element must be a ray.util.placement_group.PlacementGroup instance."
)
placement_bundle_indices: Optional[List[List[int]]] = Field(
default=None,
description="List of bundle indices for each placement group. "
"Outer list corresponds to placement_groups, inner list contains bundle indices for that group."
)
per_worker_gpu_share: Optional[float] = Field(
default=None,
description="GPU fraction per worker for colocation scenarios. "
"Example: 0.1 means 10 actors can share one GPU. Defaults to 1.0 (one actor per GPU)."
)
@model_validator(mode='after')
def validate_ray_placement(self) -> 'RayPlacementConfig':
has_pgs = self.placement_groups is not None
has_indices = self.placement_bundle_indices is not None
if has_pgs != has_indices:
raise ValueError(
"placement_groups and placement_bundle_indices must be provided together"
)
if has_pgs:
if len(self.placement_groups) != len(self.placement_bundle_indices):
raise ValueError(
f"placement_groups length ({len(self.placement_groups)}) must equal "
f"placement_bundle_indices length ({len(self.placement_bundle_indices)})"
)
if PlacementGroup is not None:
for i, pg in enumerate(self.placement_groups):
if not isinstance(pg, PlacementGroup):
raise TypeError(
f"placement_groups[{i}] must be a Ray PlacementGroup, "
f"got {type(pg).__name__}")
if self.per_worker_gpu_share is not None:
if not (0 < self.per_worker_gpu_share <= 1.0):
raise ValueError(
f"per_worker_gpu_share must be between 0 and 1.0, "
f"got {self.per_worker_gpu_share}")
return self
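Based on the placement test added later in this PR, a sketch of handing externally created placement groups to the LLM; the model path, bundle layout, and GPU share are illustrative:

import ray
from ray.util.placement_group import placement_group

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import RayPlacementConfig

ray.init()
pg = placement_group([{"GPU": 1, "CPU": 1}] * 2)
ray.get(pg.ready())

llm = LLM(
    model="/path/to/TinyLlama-1.1B-Chat-v1.0",
    tensor_parallel_size=2,
    orchestrator_type="ray",
    ray_placement_config=RayPlacementConfig(
        placement_groups=[pg],              # one entry per placement group
        placement_bundle_indices=[[0, 1]],  # bundle indices within that group
        per_worker_gpu_share=0.5,           # lets two actors share one GPU
    ),
)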
class PybindMirror(ABC):
''' A class containing the utilities for mirroring Python classes to
pybinding classes.
@ -2032,6 +2096,8 @@ class BaseLlmArgs(StrictBaseModel):
@field_validator("gpus_per_node", mode='before')
@classmethod
def validate_gpus_per_node(cls, v, info):
if os.getenv("RAY_LOCAL_WORLD_SIZE") is not None:
return info.data.get("tensor_parallel_size")
if v is None:
logger.warning(
f"Using default gpus_per_node: {torch.cuda.device_count()}")
@ -2750,6 +2816,13 @@ class TorchLlmArgs(BaseLlmArgs):
"Allows users to extend the functions of the RayGPUWorker class.",
status="prototype")
ray_placement_config: Optional[RayPlacementConfig] = Field(
default=None,
description=
"Placement config for RayGPUWorker. Only used with AsyncLLM and orchestrator_type='ray'.",
exclude=True,
status="prototype")
enable_sleep: bool = Field(
default=False,
description=
@ -3059,6 +3132,14 @@ class TorchLlmArgs(BaseLlmArgs):
)
return self
@model_validator(mode='after')
def validate_ray_placement_config(self) -> 'TorchLlmArgs':
if self.ray_placement_config is not None and self.orchestrator_type != "ray":
raise ValueError(
"ray_placement_config is only supported with orchestrator_type='ray'"
)
return self
def get_executor_config(
self,
_hf_model_dir: Optional[Path] = None,


@ -1,3 +1,5 @@
import base64
import pickle # nosec B403
from typing import Optional
import torch
@ -56,12 +58,20 @@ class WorkerExtension:
raise ValueError(f"Device UUID {device_uuid} not found in ipc_handles")
weights = {}
all_handles = ipc_handles[device_uuid]
serialized_handles = ipc_handles[device_uuid]
if isinstance(serialized_handles, str):
# Data is base64-encoded pickled bytes - deserialize it
logger.info("Deserializing base64-encoded weight handles")
all_handles = pickle.loads(base64.b64decode(serialized_handles)) # nosec B301
else:
# Data is already in the correct format (backward compatibility)
all_handles = serialized_handles
for param_name, tensor_handle in all_handles:
func, args = tensor_handle
list_args = list(args)
list_args[6] = self.device_id # Set target device
list_args[6] = self.device_id
tensor = func(*list_args)
weights[param_name] = tensor
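For completeness, a sender-side sketch that matches the deserialization above: the training framework pickles its per-device list of (param_name, tensor_handle) pairs and base64-encodes the bytes so the payload survives JSON transport. The helper name is hypothetical:

import base64
import pickle


def encode_ipc_handles(handles_per_device: dict) -> dict:
    # handles_per_device maps a device UUID to a list of (param_name, tensor_handle)
    # pairs; the result maps the same UUIDs to base64-encoded pickled strings.
    return {
        uuid: base64.b64encode(pickle.dumps(handles)).decode("utf-8")
        for uuid, handles in handles_per_device.items()
    }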
@ -88,7 +98,7 @@ class WorkerExtension:
logger.error("Encountered an error in update_weights")
raise e
def check_weights_updated(self):
def check_weights_updated(self) -> bool:
"""Check if the weights are updated to 0."""
weights_updated = True
for name, p in self.engine.model_engine.model.named_parameters():


@ -968,6 +968,16 @@ class ResponsesStreamResponse(OpenAIBaseModel):
"response.incomplete"]
class MemoryUpdateRequest(OpenAIBaseModel):
tags: List[str] = Field(default=["model", "kv_cache"])
class UpdateWeightsRequest(OpenAIBaseModel):
weights: Optional[Dict[str, str]] = Field(
default=None,
description="Weight handles dict, or None to finalize update")
def encode_opaque_state(opaque_state: Optional[bytes]) -> Optional[str]:
if opaque_state is None:
return None


@ -21,6 +21,7 @@ from starlette.routing import Mount
from transformers import AutoProcessor
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._torch.async_llm import AsyncLLM
# yapf: disable
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.executor.postproc_worker import PostprocParams
@ -46,9 +47,11 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
ChatMessage, CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
ErrorResponse, ModelCard,
ErrorResponse,
MemoryUpdateRequest, ModelCard,
ModelList, PromptTokensDetails,
ResponsesRequest, UsageInfo,
ResponsesRequest,
UpdateWeightsRequest, UsageInfo,
to_llm_disaggregated_params)
from tensorrt_llm.serve.postprocess_handlers import (
ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
@ -262,6 +265,16 @@ class OpenAIServer:
self.app.add_api_route("/v1/responses",
self.openai_responses,
methods=["POST"])
# RL-only endpoints
self.app.add_api_route("/release_memory",
self.release_memory,
methods=["POST"])
self.app.add_api_route("/resume_memory",
self.resume_memory,
methods=["POST"])
self.app.add_api_route("/update_weights",
self.update_weights,
methods=["POST"])
if self.llm.args.return_perf_metrics:
# register /prometheus/metrics
self.mount_metrics()
@ -298,6 +311,16 @@ class OpenAIServer:
self.app.add_api_route("/v1/chat/completions",
self.openai_mm_encoder,
methods=["POST"])
# RL-only endpoints
self.app.add_api_route("/release_memory",
self.release_memory,
methods=["POST"])
self.app.add_api_route("/resume_memory",
self.resume_memory,
methods=["POST"])
self.app.add_api_route("/update_weights",
self.update_weights,
methods=["POST"])
async def health(self) -> Response:
if self._check_health():
@ -990,6 +1013,20 @@ class OpenAIServer:
return JSONResponse(content={"detail": "None"})
async def release_memory(self, request: MemoryUpdateRequest) -> JSONResponse:
assert isinstance(self.llm, AsyncLLM), "/release_memory endpoint is only supported with AsyncLLM()"
await self.llm.collective_rpc('sleep', args=(request.tags,))
return JSONResponse(content={"status": "success"})
async def resume_memory(self, request: MemoryUpdateRequest) -> JSONResponse:
assert isinstance(self.llm, AsyncLLM), "/resume_memory endpoint is only supported with AsyncLLM()"
await self.llm.collective_rpc('wakeup', args=(request.tags,))
return JSONResponse(content={"status": "success"})
async def update_weights(self, request: UpdateWeightsRequest) -> JSONResponse:
assert isinstance(self.llm, AsyncLLM), "/update_weights endpoint is only supported with AsyncLLM()"
await self.llm.collective_rpc('update_weights', args=(request.weights,))
return JSONResponse(content={"status": "success"})
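A client-side sketch for the three RL-only endpoints, assuming the default trtllm-serve address; the tags and the encoded-handles dict are illustrative:

import requests

base = "http://localhost:8000"

# Release model weights and KV cache between rollout phases.
requests.post(f"{base}/release_memory", json={"tags": ["model", "kv_cache"]})

# Bring the released memory back before the next generation phase.
requests.post(f"{base}/resume_memory", json={"tags": ["model", "kv_cache"]})

# Push new weights as base64-encoded IPC handles keyed by device UUID.
encoded_handles: dict[str, str] = {}  # device UUID -> base64-encoded pickled handles
requests.post(f"{base}/update_weights", json={"weights": encoded_handles})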
async def __call__(self, host, port, sockets: list[socket.socket] | None = None):
# Store the binding address for server registration


@ -279,8 +279,25 @@ l0_dgx_h100:
tests:
- unittest/_torch/ray_orchestrator/multi_gpu -m "gpu2"
- unittest/llmapi/test_llm_multi_gpu_pytorch.py -m "gpu2"
- unittest/llmapi/test_async_llm.py -m "gpu2"
- accuracy/test_llm_api_pytorch_ray.py::TestLlama3_1_8BInstruct::test_pp2_ray
- examples/test_ray.py::test_llm_inference_distributed_ray[tp2]
- examples/test_ray.py::test_llm_inference_distributed_ray[pp2]
- examples/test_ray.py::test_llm_inference_distributed_ray[tep2]
- examples/test_ray.py::test_ray_disaggregated_serving[tp1]
- condition:
ranges:
system_gpu_count:
gte: 4
lte: 4
wildcards:
gpu:
- '*h100*'
linux_distribution_name: ubuntu*
terms:
stage: pre_merge
backend: pytorch
orchestrator: ray
tests:
- unittest/_torch/ray_orchestrator/multi_gpu -m "gpu4"
- unittest/llmapi/test_async_llm.py -m "gpu4"


@ -144,6 +144,7 @@ l0_h100:
- unittest/_torch/executor
- unittest/_torch/ray_orchestrator/single_gpu
- unittest/llmapi/test_llm_pytorch.py
- unittest/llmapi/test_async_llm.py -m "not (gpu2 or gpu4)"
- examples/test_ray.py::test_llm_inference_async_ray
- condition:
ranges:


@ -9,27 +9,23 @@ from utils.llm_data import llm_models_root
from tensorrt_llm import LLM
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import RayPlacementConfig
class DummyWorkerExtension:
def additional_method(self):
return "SUCCESS"
@pytest.mark.gpu2
def test_worker_extension():
llm = LLM(model=llm_models_root() /
"llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
ray_worker_extension_cls="test_executor.DummyWorkerExtension",
orchestrator_type="ray")
result = llm._collective_rpc("additional_method")
assert result[0] == "SUCCESS"
ray_worker_extension_cls=
"tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
orchestrator_type="ray",
tensor_parallel_size=2)
result = llm._collective_rpc("check_weights_updated")
assert isinstance(result[0], bool)
@pytest.mark.gpu4
def test_bundle_indices(monkeypatch):
"""Placement via bundle indices"""
def test_placement_env_vars(monkeypatch):
monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
pg = None
@ -77,6 +73,52 @@ def test_bundle_indices(monkeypatch):
ray.shutdown()
@pytest.mark.gpu2
@pytest.mark.threadleak(enabled=False)
@pytest.mark.parametrize("n_gpus,bundle_indices", [
(2, [1]),
],
ids=["gpu2_tp1"])
def test_placement_api(monkeypatch, n_gpus, bundle_indices):
monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
tp_size = n_gpus // 2
pg = None
try:
ray.init()
pg = placement_group([{"GPU": 1, "CPU": 1}] * n_gpus)
ray.get(pg.ready())
print(f"Placement group ready with bundles {pg.bundle_specs}")
llm = LLM(
model=os.path.join(llm_models_root(), "llama-models-v2",
"TinyLlama-1.1B-Chat-v1.0"),
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
tensor_parallel_size=tp_size,
orchestrator_type="ray",
ray_placement_config=RayPlacementConfig(
placement_groups=[pg],
placement_bundle_indices=[bundle_indices],
per_worker_gpu_share=0.8,
),
)
inference_actor_uuids = llm._collective_rpc("report_device_id")
expected_uuids = [get_device_uuid(idx) for idx in bundle_indices]
print(
f"{inference_actor_uuids=}, all_uuids={[get_device_uuid(i) for i in range(n_gpus)]}"
)
assert sorted(inference_actor_uuids) == sorted(expected_uuids), \
f"Workers not placed on expected GPUs. Expected: {expected_uuids}, Got: {inference_actor_uuids}"
finally:
if pg is not None:
remove_placement_group(pg)
ray.shutdown()
@pytest.mark.gpu2
def test_cuda_visible_device(monkeypatch):
"""Placement via cuda_visible_device"""


@ -211,6 +211,10 @@ methods:
annotation: Optional[str]
default: null
status: prototype
ray_placement_config:
annotation: Optional[tensorrt_llm.llmapi.llm_args.RayPlacementConfig]
default: null
status: prototype
enable_sleep:
annotation: bool
default: False


@ -0,0 +1,137 @@
import os
import pytest
import ray
from ray.util.placement_group import placement_group, remove_placement_group
from utils.llm_data import llm_models_root
from utils.util import get_current_process_gpu_memory
from tensorrt_llm import AsyncLLM
from tensorrt_llm._torch.utils import get_device_uuid
from tensorrt_llm._torch.virtual_memory import ExecutorMemoryType
from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams
@pytest.mark.ray
@pytest.mark.asyncio
async def test_async_llm_awaitable():
llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0")
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
prompt = "The future of AI is"
sampling_params = SamplingParams(temperature=0, max_tokens=12)
llm = await AsyncLLM(
model=llama_model_path,
enable_sleep=True,
cuda_graph_config=None,
kv_cache_config=kv_cache_config,
)
output = await llm.generate_async(prompt, sampling_params)
assert output.outputs[0].text
print("Output text:", output.outputs[0].text)
del llm
@pytest.mark.ray
@pytest.mark.gpu2
@pytest.mark.asyncio
@pytest.mark.parametrize("num_cycles", [3], ids=lambda x: f"{x}_cycle")
async def test_async_llm_release_resume(process_gpu_memory_info_available, num_cycles):
llama_model_path = str(llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0")
kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=4096)
prompt = "The future of AI is"
sampling_params = SamplingParams(temperature=0, max_tokens=12)
tags = [tag.value for tag in ExecutorMemoryType]
async with AsyncLLM(
model=llama_model_path,
enable_sleep=True,
cuda_graph_config=None,
kv_cache_config=kv_cache_config,
tensor_parallel_size=2,
) as llm:
# Generate baseline
output_before = await llm.generate_async(prompt, sampling_params)
baseline_text = output_before.outputs[0].text
for cycle in range(num_cycles):
memory_usage_active = get_current_process_gpu_memory(True) / 1024**3
print(f"[Cycle {cycle + 1}] Memory usage before release: {memory_usage_active:.2f} GB")
await llm.release(tags)
memory_usage_released = get_current_process_gpu_memory(True) / 1024**3
if process_gpu_memory_info_available:
print(
f"[Cycle {cycle + 1}] Memory usage after release: {memory_usage_released:.2f} GB"
)
assert memory_usage_released < memory_usage_active, (
f"Released memory ({memory_usage_released:.2f} GB) should be < "
f"active memory ({memory_usage_active:.2f} GB)"
)
await llm.resume(tags)
memory_usage_resumed = get_current_process_gpu_memory(True) / 1024**3
print(f"[Cycle {cycle + 1}] Memory usage after resume: {memory_usage_resumed:.2f} GB")
if process_gpu_memory_info_available:
assert memory_usage_resumed > memory_usage_released, (
f"Resumed memory ({memory_usage_resumed:.2f} GB) should be > "
f"released memory ({memory_usage_released:.2f} GB)"
)
output_after = await llm.generate_async(prompt, sampling_params)
text_after = output_after.outputs[0].text
print(f"[Cycle {num_cycles}] Generated text after release/resume: {text_after}")
assert baseline_text == text_after, (
f"Generated text mismatch after {num_cycles} cycle(s): "
f"'{baseline_text}' != '{text_after}'"
)
@pytest.mark.ray
@pytest.mark.gpu4
@pytest.mark.asyncio
@pytest.mark.threadleak(enabled=False)
async def test_async_llm_placement_api(monkeypatch):
monkeypatch.setenv("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", "1")
n_gpus = 4
bundle_indices = [2, 3]
tp_size = len(bundle_indices)
pg = None
try:
ray.init()
pg = placement_group([{"GPU": 1, "CPU": 1}] * n_gpus)
ray.get(pg.ready())
print(f"Placement group ready with bundles {pg.bundle_specs}")
llm = await AsyncLLM(
model=os.path.join(
str(llm_models_root()), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0"
),
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
tensor_parallel_size=tp_size,
placement_groups=[pg],
placement_bundle_indices=[bundle_indices],
per_worker_gpu_share=0.8,
)
inference_actor_uuids = await llm.collective_rpc("report_device_id")
expected_uuids = [get_device_uuid(idx) for idx in bundle_indices]
print(f"{inference_actor_uuids=}, all_uuids={[get_device_uuid(i) for i in range(n_gpus)]}")
assert sorted(inference_actor_uuids) == sorted(expected_uuids), (
f"Workers not placed on expected GPUs. Expected: {expected_uuids}, Got: {inference_actor_uuids}"
)
finally:
if pg is not None:
remove_placement_group(pg)
ray.shutdown()