From ea6cd76c55faadcdae154c815489b60b1860a07f Mon Sep 17 00:00:00 2001 From: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Date: Mon, 22 Dec 2025 18:23:43 +0800 Subject: [PATCH 1/7] [None][refactor] simplify get_stats and get_kvcache_events with rpc (#9980) Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/proxy.py | 233 +++++++++++--------- tensorrt_llm/executor/rpc_proxy.py | 112 +++++++++- tensorrt_llm/executor/rpc_proxy_mixin.py | 114 +--------- tensorrt_llm/executor/rpc_worker_mixin.py | 104 +++++---- tensorrt_llm/executor/utils.py | 2 - tensorrt_llm/executor/worker.py | 146 ++---------- tests/unittest/llmapi/test_executor.py | 32 ++- tests/unittest/llmapi/test_llm.py | 9 +- tests/unittest/llmapi/test_llm_multi_gpu.py | 9 +- tests/unittest/llmapi/test_llm_pytorch.py | 57 +++++ 10 files changed, 402 insertions(+), 416 deletions(-) diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index 13ff28023e..f9c502f85d 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -1,9 +1,10 @@ import atexit import concurrent.futures +import json +import os import threading -import time import weakref -from typing import Dict, Optional, Union +from typing import Dict, List, Optional import torch import zmq @@ -22,9 +23,11 @@ from .ipc import FusedIpcQueue, IpcQueue from .postproc_worker import PostprocWorker, PostprocWorkerConfig from .request import CancellingRequest, GenerationRequest from .result import GenerationResult, IterationResult -from .utils import (ErrorResponse, IntraProcessQueue, WorkerCommIpcAddrs, - create_mpi_comm_session, get_spawn_proxy_process_env, - is_llm_response, print_alive_threads) +from .rpc import RPCClient +from .rpc.rpc_common import get_unique_ipc_addr +from .utils import (ErrorResponse, WorkerCommIpcAddrs, create_mpi_comm_session, + get_spawn_proxy_process_env, is_llm_response, + print_alive_threads) from .worker import GenerationExecutorWorker, worker_main __all__ = [ @@ -89,19 +92,27 @@ class GenerationExecutorProxy(GenerationExecutor): "llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get( "llm_args", None) is not None else None + # Generate RPC address and key for stats RPC + self.rpc_addr = get_unique_ipc_addr() + self.hmac_key = os.urandom(32) + worker_kwargs = dict(**worker_kwargs, worker_queues=self._setup_queues(), postproc_worker_config=postproc_worker_config, - is_llm_executor=False) + is_llm_executor=False, + rpc_addr=self.rpc_addr, + hmac_key=self.hmac_key) if "log_level" not in worker_kwargs: worker_kwargs["log_level"] = logger.level self.dispatch_result_thread: Optional[ManagedThread] = None - self.dispatch_stats_thread: Optional[ManagedThread] = None - self.dispatch_kv_cache_events_thread: Optional[ManagedThread] = None + self.rpc_client: Optional[RPCClient] = None self._start_executor_workers(worker_kwargs) + # Create RPC client after workers are started (worker starts RPC server) + self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key) + # MPI registers its joiner using threading._register_atexit if possible. # These functions run before atexit.register, so to avoid deadlock, # we have to notify workers to exit before MPI starts to wait them. 
@@ -128,19 +139,11 @@ class GenerationExecutorProxy(GenerationExecutor): socket_type=zmq.PULL if self.enable_postprocess_parallel else zmq.PAIR, name="proxy_result_queue") - self.mp_stats_queue = FusedIpcQueue(is_server=True, - fuse_message=False, - name="proxy_stats_queue") - self.kv_cache_events_queue = FusedIpcQueue( - is_server=True, - fuse_message=False, - name="proxy_kv_cache_events_queue") + # Stats and KV events are now fetched via RPC, not IPC queues. return WorkerCommIpcAddrs( request_queue_addr=self.request_queue.address, worker_init_status_queue_addr=self.worker_init_status_queue.address, result_queue_addr=self.result_queue.address, - stats_queue_addr=self.mp_stats_queue.address, - kv_cache_events_queue_addr=self.kv_cache_events_queue.address, ) def abort_request(self, request_id: int) -> None: @@ -204,71 +207,8 @@ class GenerationExecutorProxy(GenerationExecutor): return True # success - def _iteration_result_task(self, - queue: Union[FusedIpcQueue, IntraProcessQueue], - result_singleton: IterationResult, - urgent: bool = False) -> bool: - if not urgent: - time.sleep(0.2) - - try: - data = queue.get() - except: - logger.debug( - "proxy.py: Error in _iteration_result_task: queue.get()") - return False - - if data is None: - logger.debug("proxy.py: _iteration_result_task: data is None") - return False # shutdown the thread - - data = data if isinstance(data, list) else [data] - queue = result_singleton.queue - async_queues = [] - - while queue.full(): - queue.get() - - try: - for d in data: - if d is None: - logger.debug("proxy.py: _iteration_result_task: d is None") - return False - - if isinstance(queue, _SyncQueue): - queue.put_nowait(d) - async_queues.append(queue) - else: - queue.put(d) - - if async_queues: - _SyncQueue.notify_many(queue.loop, async_queues) - - except AsyncQueue.EventLoopShutdownError: - # This happens in the last loop while the generate workflow is - # stopped, or when get_stats() or aget_stats() are not called by users - # and therefore event loop can already be closed. - logger.debug("proxy.py: EventLoopShutdownError") - except Exception as e: - logger.debug(f"proxy.py: Error in _iteration_result_task: {e}") - raise e - - return True # success - - def dispatch_stats_task(self) -> bool: - if not self._iter_stats_result: - # This can happen temporarily because the WAR in tensorrt_llm/bench/benchmark/throughput.py - # is not synchronized with self.dispatch_stats_thread. - logger.debug( - f"Skipping stats dispatch while self._iter_stats_result=None") - return True # Intended behavior, not an error - return self._iteration_result_task(self.mp_stats_queue, - self._iter_stats_result) - - def dispatch_kv_cache_events_task(self) -> bool: - return self._iteration_result_task(self.kv_cache_events_queue, - self._iter_kv_events_result, - urgent=True) + # NOTE: _iteration_result_task, dispatch_stats_task, and dispatch_kv_cache_events_task + # have been removed as stats and kv_events are now fetched via RPC directly. 
def _start_dispatch_threads(self): if self.dispatch_result_thread is None: @@ -277,25 +217,9 @@ class GenerationExecutorProxy(GenerationExecutor): weakref.WeakMethod(self.dispatch_result_task), error_queue=self._error_queue, name="proxy_dispatch_result_thread") - self.dispatch_stats_thread = ManagedThread( - weakref.WeakMethod(self.dispatch_stats_task), - error_queue=self._error_queue, - name="proxy_dispatch_stats_thread") - self.dispatch_kv_cache_events_thread = ManagedThread( - weakref.WeakMethod(self.dispatch_kv_cache_events_task), - error_queue=self._error_queue, - name="proxy_dispatch_kv_cache_events_thread") self.dispatch_result_thread.start() - # Only collect stats when submission - # is via LLM API - if self._iter_stats_result: - self.dispatch_stats_thread.start() - - if self._iter_kv_events_result: - self.dispatch_kv_cache_events_thread.start() - self._handle_background_error() def _start_executor_workers(self, worker_kwargs): @@ -387,23 +311,18 @@ class GenerationExecutorProxy(GenerationExecutor): ): self.dispatch_result_thread.stop() self.dispatch_result_thread.join() - if self.dispatch_stats_thread is not None and self.dispatch_stats_thread.is_alive( - ): - self.dispatch_stats_thread.stop() - self.dispatch_stats_thread.join() - if self.dispatch_kv_cache_events_thread is not None and self.dispatch_kv_cache_events_thread.is_alive( - ): - self.dispatch_kv_cache_events_thread.stop() - self.dispatch_kv_cache_events_thread.join() # step3: finish all remaining work + # close the RPC client + if self.rpc_client is not None: + self.rpc_client.close() + self.rpc_client = None + # close all the sockets self.request_queue.close() self.worker_init_status_queue.close() self.result_queue.close() - self.mp_stats_queue.close() - self.kv_cache_events_queue.close() self.workers_started = False self.mpi_session.shutdown() @@ -441,6 +360,104 @@ class GenerationExecutorProxy(GenerationExecutor): return result + def get_stats(self, timeout: float) -> List[dict]: + """Get iteration statistics from the runtime via RPC. + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + List[dict]: A list of runtime stats as dict. + """ + if self.rpc_client is None: + logger.warning("RPC client not initialized, cannot get stats") + return [] + + stats = self.rpc_client.fetch_stats_wait_async(timeout=timeout).remote() + return [json.loads(s) if isinstance(s, str) else s for s in stats] + + def aget_stats(self, timeout: float) -> IterationResult: + """Get iteration statistics from the runtime via RPC (async). + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + IterationResult: An async iterable object containing runtime stats. + """ + # Initialize iteration result if needed + self._maybe_initialize_iteration_results() + + if self._iter_stats_result is None: + logger.warning("Iteration statistics are not available yet.") + from .executor import empty_async_iterable + return empty_async_iterable() + + # Fetch stats via RPC and populate the result + try: + stats = self.rpc_client.fetch_stats_wait_async( + timeout=timeout).remote() + except Exception as e: + logger.debug(f"Error fetching stats via RPC: {e}") + stats = [] + + for stat in stats: + self._iter_stats_result.queue.put(stat) + + self._iter_stats_result.set_timeout(timeout) + return self._iter_stats_result + + def get_kv_events(self, timeout: float) -> List[dict]: + """Get iteration KV events from the runtime via RPC. + + Args: + timeout (float): Max wait time in seconds for the RPC call. 
+ + Returns: + List[dict]: A list of runtime events as dict. + """ + if self.rpc_client is None: + logger.warning("RPC client not initialized, cannot get kv events") + return [] + + try: + events = self.rpc_client.fetch_kv_cache_events_wait_async( + timeout=timeout).remote() + return [json.loads(e) if isinstance(e, str) else e for e in events] + except Exception as e: + logger.error(f"Error fetching kv events via RPC: {e}") + return [] + + def aget_kv_events(self, timeout: float) -> IterationResult: + """Get iteration KV events from the runtime via RPC (async). + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + IterationResult: An async iterable object containing runtime events. + """ + # Initialize iteration result if needed + self._maybe_initialize_iteration_results() + + if self._iter_kv_events_result is None: + from .executor import empty_async_iterable + return empty_async_iterable() + + # Fetch kv events via RPC and populate the result + try: + events = self.rpc_client.fetch_kv_cache_events_wait_async( + timeout=timeout).remote() + except Exception as e: + logger.debug(f"Error fetching kv events via RPC: {e}") + events = [] + + for event in events: + self._iter_kv_events_result.queue.put(event) + + self._iter_kv_events_result.set_timeout(timeout) + return self._iter_kv_events_result + def __del__(self): self.shutdown() diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index 09f93afb80..722609dea6 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -1,11 +1,13 @@ +import json import threading -from typing import Optional +from typing import List, Optional from ..llmapi.mpi_session import MpiPoolSession, MpiSession -from ..llmapi.utils import logger_debug +from ..llmapi.utils import logger_debug, print_colored from ..logger import logger from .executor import GenerationExecutor from .postproc_worker import PostprocWorkerConfig +from .result import IterationResult from .rpc_proxy_mixin import RpcExecutorMixin from .rpc_worker import RpcWorker from .utils import create_mpi_comm_session, get_spawn_proxy_process_env @@ -69,20 +71,110 @@ class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor): **self.worker_kwargs) def _setup_mainloop_with_tasks(self): - """Setup mainloop with all tasks needed for RpcProxy.""" + """Setup mainloop with tasks needed for RpcProxy. + + Note: Stats and kv_events are now fetched on-demand via direct RPC calls + (get_stats, aget_stats, get_kv_events, aget_kv_events), not via streaming loops. + """ tasks = [ self._fetch_responses_loop_async, - self._fetch_stats_loop_async, ] - # Only add kv_cache_events loop if it's enabled - if self._iter_kv_events_result: - tasks.append(self._fetch_kv_cache_events_loop_async) - # Call mixin's setup_mainloop with custom tasks self.setup_mainloop(tasks=tasks, thread_name="rpc_proxy_main_loop") - def fetch_stats_remote(self): - return self.rpc_client.fetch_stats().remote() + def get_stats(self, timeout: float) -> List[dict]: + """Get iteration statistics from the runtime via RPC. + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + List[dict]: A list of runtime stats as dict. 
+ """ + try: + stats = self.rpc_client.fetch_stats_wait_async( + timeout=timeout).remote() + return [json.loads(s) if isinstance(s, str) else s for s in stats] + except Exception as e: + logger.debug(f"Error fetching stats via RPC: {e}") + return [] + + def aget_stats(self, timeout: float) -> IterationResult: + """Get iteration statistics from the runtime via RPC (async). + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + IterationResult: An async iterable object containing runtime stats. + """ + self._maybe_initialize_iteration_results() + + if self._iter_stats_result is None: + print_colored("Iteration statistics are not available yet.\n", + "yellow") + from .executor import empty_async_iterable + return empty_async_iterable() + + # Fetch stats via RPC and populate the result + try: + stats = self.rpc_client.fetch_stats_wait_async( + timeout=timeout).remote() + except Exception: + stats = [] + + for stat in stats: + self._iter_stats_result.queue.put(stat) + + self._iter_stats_result.set_timeout(timeout) + return self._iter_stats_result + + def get_kv_events(self, timeout: float) -> List[dict]: + """Get iteration KV events from the runtime via RPC. + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + List[dict]: A list of runtime events as dict. + """ + try: + # Events are already serialized by the worker's fetch_kv_cache_events_wait_async() + events = self.rpc_client.fetch_kv_cache_events_wait_async( + timeout=timeout).remote() + return [json.loads(e) if isinstance(e, str) else e for e in events] + except Exception as e: + logger.debug(f"Error fetching kv events via RPC: {e}") + return [] + + def aget_kv_events(self, timeout: float) -> IterationResult: + """Get iteration KV events from the runtime via RPC (async). + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + IterationResult: An async iterable object containing runtime events. + """ + # Initialize iteration result if needed + self._maybe_initialize_iteration_results() + + if self._iter_kv_events_result is None: + from .executor import empty_async_iterable + return empty_async_iterable() + + # Fetch kv events via RPC and populate the result + try: + events = self.rpc_client.fetch_kv_cache_events_wait_async( + timeout=timeout).remote() + except Exception: + events = [] + + for event in events: + self._iter_kv_events_result.queue.put(event) + + self._iter_kv_events_result.set_timeout(timeout) + return self._iter_kv_events_result def setup_engine_remote(self): return self.rpc_client.setup_engine().remote(need_response=True) diff --git a/tensorrt_llm/executor/rpc_proxy_mixin.py b/tensorrt_llm/executor/rpc_proxy_mixin.py index c7d7716f4f..ecbb86e25e 100644 --- a/tensorrt_llm/executor/rpc_proxy_mixin.py +++ b/tensorrt_llm/executor/rpc_proxy_mixin.py @@ -1,13 +1,12 @@ import asyncio import atexit -import json import os import threading from typing import Callable, List, Optional from .._utils import nvtx_range_debug from ..llmapi.tracer import global_tracer -from ..llmapi.utils import AsyncQueue, _SyncQueue +from ..llmapi.utils import _SyncQueue from ..logger import logger from .request import GenerationRequest from .result import GenerationResult @@ -47,15 +46,16 @@ class RpcExecutorMixin: Args: tasks: List of async coroutine functions to run. 
thread_name: Name for the main loop thread + + Note: Stats and kv_events are now fetched on-demand via direct RPC calls + (get_stats, aget_stats, get_kv_events, aget_kv_events), so the default + tasks only include the responses loop. Callers can still provide custom + tasks including stats/kv_events loops if needed for streaming use cases. """ if tasks is None: tasks = [ self._fetch_responses_loop_async, - self._fetch_stats_loop_async, ] - # Only add kv_cache_events loop if it's enabled - if self._iter_kv_events_result: - tasks.append(self._fetch_kv_cache_events_loop_async) async def main_loop_task(): await asyncio.gather(*[task() for task in tasks]) @@ -136,22 +136,6 @@ class RpcExecutorMixin: if async_queues: _SyncQueue.notify_many(event_loop, async_queues) - def handle_stats(self, stats): - """Handle stats received from RPC worker and put them into the stats result queue. - - Args: - stats: Statistics data from the RPC worker (can be dict, str, or list) - """ - self._handle_iteration_data(stats, self._iter_stats_result, "stats") - - def handle_kv_cache_events(self, events): - """Handle KV cache events received from RPC worker and put them into the events result queue. - - Args: - events: KV cache events data from the RPC worker (can be dict, str, or list) - """ - self._handle_iteration_data(events, self._iter_kv_events_result, "kv_cache_events") - async def _generic_fetch_loop_async( self, fetch_method_name: str, handler_method: Callable, method_name: str ): @@ -181,86 +165,6 @@ class RpcExecutorMixin: method_name="_fetch_responses_loop_async", ) - async def _fetch_stats_loop_async(self): - await self._generic_fetch_loop_async( - fetch_method_name="fetch_stats_loop_async", - handler_method=self.handle_stats, - method_name="_fetch_stats_loop_async", - ) - - async def _fetch_kv_cache_events_loop_async(self): - await self._generic_fetch_loop_async( - fetch_method_name="fetch_kv_cache_events_loop_async", - handler_method=self.handle_kv_cache_events, - method_name="_fetch_kv_cache_events_loop_async", - ) - - def _handle_iteration_data(self, data, result_singleton, data_type: str): - """Generic method to handle iteration data received from RPC worker. 
- - Args: - data: Data from the RPC worker (can be dict, str, or list) - result_singleton: The iteration result singleton to put data into - data_type: Type of data for logging (e.g., "stats", "kv_cache_events") - """ - # Make sure we have initialized the iteration results - self._maybe_initialize_iteration_results() - - if not result_singleton: - logger.debug(f"Skipping {data_type} handling while result_singleton=None") - return - - # Get the queue from the result singleton - queue = result_singleton.queue - async_queues = [] - - # Clear old data if queue is full (similar to _iteration_result_task) - while queue.full(): - queue.get() - - try: - # Handle different types of data - if isinstance(data, str): - # Already JSON serialized - data_json = data - elif isinstance(data, list): - # Skip empty lists to avoid putting nothing in the queue - if not data: - logger.debug(f"rpc_proxy.py: Skipping empty {data_type} list") - return - - # Handle list of data (multiple iterations) - for item in data: - if isinstance(item, str): - item_json = item - else: - item_json = json.dumps(item) - - if isinstance(queue, _SyncQueue): - queue.put_nowait(item_json) - async_queues.append(queue) - else: - queue.put(item_json) - - if async_queues: - _SyncQueue.notify_many(queue.loop, async_queues) - return - else: - # Convert dict/other to JSON string as expected by IterationResult - data_json = json.dumps(data) - - if isinstance(queue, _SyncQueue): - queue.put_nowait(data_json) - async_queues.append(queue) - else: - queue.put(data_json) - - if async_queues: - _SyncQueue.notify_many(queue.loop, async_queues) - - except AsyncQueue.EventLoopShutdownError: - # This happens when the event loop is already closed - logger.debug(f"rpc_proxy.py: EventLoopShutdownError in handle_{data_type}") - except Exception as e: - logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}") - raise e + # NOTE: _fetch_stats_loop_async and _fetch_kv_cache_events_loop_async have been removed. + # Stats and kv_events are now fetched on-demand via direct RPC calls + # (get_stats, aget_stats, get_kv_events, aget_kv_events) instead of streaming loops. 
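The on-demand flow that replaces the removed streaming loops, condensed into a minimal sketch. The method names (fetch_stats_wait_async, fetch_stats, rpc_client, _stats_serializer) mirror identifiers added elsewhere in this series; the two standalone classes wrapping them are illustrative stand-ins, not the actual proxy/worker classes.

    import asyncio
    import json
    import time

    class WorkerSketch:
        """Worker side: poll the engine until stats appear or the timeout expires."""

        def __init__(self, engine, stats_serializer):
            self.engine = engine                       # exposes fetch_stats() -> list
            self._stats_serializer = stats_serializer  # turns a stats object into a JSON string

        async def fetch_stats_wait_async(self, timeout=None):
            start = time.time()
            while True:
                stats = await asyncio.to_thread(self.engine.fetch_stats)
                if stats or timeout is None or (time.time() - start) >= timeout:
                    break
                await asyncio.sleep(0.1)
            return [self._stats_serializer(s) for s in stats]

    class ProxySketch:
        """Proxy side: one blocking RPC per get_stats() call, no background dispatch thread."""

        def __init__(self, rpc_client):
            self.rpc_client = rpc_client               # RPC client bound to the worker's RPC address

        def get_stats(self, timeout):
            raw = self.rpc_client.fetch_stats_wait_async(timeout=timeout).remote()
            return [json.loads(s) if isinstance(s, str) else s for s in raw]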
diff --git a/tensorrt_llm/executor/rpc_worker_mixin.py b/tensorrt_llm/executor/rpc_worker_mixin.py index cab53e6b1d..c5c201bd07 100644 --- a/tensorrt_llm/executor/rpc_worker_mixin.py +++ b/tensorrt_llm/executor/rpc_worker_mixin.py @@ -1,4 +1,5 @@ import asyncio +import time from queue import Queue from threading import Event from typing import AsyncGenerator, Optional @@ -50,8 +51,9 @@ class RpcWorkerMixin: """Submits a request to the worker.""" with nvtx_range_debug("RpcWorker.submit", color="blue", category="Worker"): logger_debug(f"[worker] Submitting request {request.id}", color="green") - super().submit(request) + result = super().submit(request) logger_debug(f"[worker] Submitted request {request.id}", color="green") + return result def fetch_responses(self, timeout: Optional[float] = None) -> list: """Fetch responses from the response queue (blocking).""" @@ -99,54 +101,58 @@ class RpcWorkerMixin: f"[worker] RpcWorker {self.rank} quitting fetch_responses_loop_async", color="yellow" ) - async def fetch_stats_async(self, timeout: Optional[float] = None) -> list: - """Async version of fetch_stats using asyncio.to_thread.""" - return await asyncio.to_thread(self.fetch_stats) - - async def fetch_kv_cache_events_async(self, timeout: Optional[float] = None) -> list: - """Async version of fetch_kv_cache_events using asyncio.to_thread.""" - return await asyncio.to_thread(self.fetch_kv_cache_events) - - async def fetch_stats_loop_async( - self, timeout: Optional[float] = None - ) -> AsyncGenerator[list, None]: - """Stream stats in a loop until shutdown.""" - async for data in self._generic_fetch_loop_async( - fetch_method=self.fetch_stats_async, - serializer=self._stats_serializer, - method_name="fetch_stats_loop_async", - timeout=timeout, - ): - yield data - - async def fetch_kv_cache_events_loop_async( - self, timeout: Optional[float] = None - ) -> AsyncGenerator[list, None]: - """Stream KV cache events in a loop until shutdown.""" - async for data in self._generic_fetch_loop_async( - fetch_method=self.fetch_kv_cache_events_async, - serializer=self._kv_cache_events_serializer, - method_name="fetch_kv_cache_events_loop_async", - timeout=timeout, - ): - yield data - - async def _generic_fetch_loop_async( - self, fetch_method, serializer, method_name: str, timeout: Optional[float] = None - ) -> AsyncGenerator[list, None]: - """Generic method for fetching data in a loop. + async def fetch_stats_wait_async(self, timeout: Optional[float] = None) -> list: + """Poll for stats until available or timeout. Args: - fetch_method: The async method to call for fetching data - serializer: The serializer function to apply to each item - method_name: Name of the method for logging - timeout: Optional timeout between fetches + timeout: Max wait time in seconds. If None, fetch once without waiting. 
""" - while not self.shutdown_event.is_set(): - timeout = timeout or 0.1 - await asyncio.sleep(timeout) - data = await fetch_method() - # Always yield data, even if empty, to prevent the client looks like hanging - # TODO: Remove the empty data to reduce the IPC overhead - yield [serializer(item) for item in data] - logger_debug(f"[worker] RpcWorker {self.rank} quitting {method_name}", color="yellow") + logger_debug( + f"[worker] RpcWorker {self.rank} is fetching stats with timeout {timeout}", + color="yellow", + ) + start = time.time() + while True: + stats = await asyncio.to_thread(self.fetch_stats) + if stats or timeout is None: + break + if (time.time() - start) >= timeout: + break + await asyncio.sleep(0.1) + return [self._stats_serializer(s) for s in stats] + + async def fetch_kv_cache_events_wait_async(self, timeout: Optional[float] = None) -> list: + """Poll for KV cache events until available or timeout. + + Args: + timeout: Max wait time in seconds. If None, fetch once without waiting. + """ + start = time.time() + while True: + events = await asyncio.to_thread(self.fetch_kv_cache_events) + if events or timeout is None: + break + if (time.time() - start) >= timeout: + break + await asyncio.sleep(0.1) + return [self._kv_cache_events_serializer(e) for e in events] + + async def fetch_stats_async(self, timeout: Optional[float] = None) -> list: + """Async version of fetch_stats using asyncio.to_thread. + + This method is exposed via RPC and can be called directly by the proxy. + Returns serialized stats (JSON strings) that can be sent over RPC. + """ + stats = await asyncio.to_thread(self.fetch_stats) + # Serialize stats before sending over RPC (IterationStats objects are not picklable) + return [self._stats_serializer(s) for s in stats] + + async def fetch_kv_cache_events_async(self, timeout: Optional[float] = None) -> list: + """Async version of fetch_kv_cache_events using asyncio.to_thread. + + This method is exposed via RPC and can be called directly by the proxy. + Returns serialized events (JSON strings) that can be sent over RPC. 
+ """ + events = await asyncio.to_thread(self.fetch_kv_cache_events) + # Serialize events before sending over RPC + return [self._kv_cache_events_serializer(e) for e in events] diff --git a/tensorrt_llm/executor/utils.py b/tensorrt_llm/executor/utils.py index 8a5f61bc36..e52ea481fb 100644 --- a/tensorrt_llm/executor/utils.py +++ b/tensorrt_llm/executor/utils.py @@ -142,8 +142,6 @@ class WorkerCommIpcAddrs(NamedTuple): request_queue_addr: tuple[str, Optional[bytes]] worker_init_status_queue_addr: tuple[str, Optional[bytes]] result_queue_addr: tuple[str, Optional[bytes]] - stats_queue_addr: tuple[str, Optional[bytes]] - kv_cache_events_queue_addr: tuple[str, Optional[bytes]] def is_llm_response(instance): diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 2199bee74a..c4917a86a5 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -1,11 +1,9 @@ import gc import os -import time import traceback from concurrent.futures import ProcessPoolExecutor from pathlib import Path -from queue import Queue -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import zmq @@ -18,25 +16,22 @@ from ..llmapi.llm_args import BaseLlmArgs from ..llmapi.mpi_session import set_mpi_session_cpp from ..llmapi.tokenizer import TokenizerBase from ..llmapi.tracer import VizTracer, set_global_tracer -from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, logger_debug, - print_traceback_on_error) +from ..llmapi.utils import ManagedThread, logger_debug, print_traceback_on_error from ..sampling_params import BatchedLogitsProcessor from .base_worker import BaseWorker, _init_hf_modules -from .executor import IterationResultQueue from .ipc import FusedIpcQueue, IpcQueue from .postproc_worker import (PostprocWorker, PostprocWorkerConfig, postproc_worker_main) from .request import CancellingRequest, GenerationRequest -from .result import IterationResult -from .utils import (ErrorResponse, RequestError, WorkerCommIpcAddrs, - has_event_loop) +from .rpc_worker_mixin import RpcWorkerMixin +from .utils import ErrorResponse, RequestError, WorkerCommIpcAddrs __all__ = [ "GenerationExecutorWorker", ] -class GenerationExecutorWorker(BaseWorker): +class GenerationExecutorWorker(RpcWorkerMixin, BaseWorker): def __init__( self, @@ -48,6 +43,8 @@ class GenerationExecutorWorker(BaseWorker): hf_model_dir: Optional[Path] = None, tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[BaseLlmArgs] = None, + rpc_addr: Optional[str] = None, + hmac_key: Optional[bytes] = None, ) -> None: super().__init__( engine=engine, @@ -62,35 +59,18 @@ class GenerationExecutorWorker(BaseWorker): self.setup_engine() + # Setup RPC server for stats (skip init_rpc_worker to keep IPC response queue) + # Only set up if rpc_addr is provided (for stats RPC support) + if rpc_addr is not None: + self.rpc_addr = rpc_addr + self.hmac_key = hmac_key + self.start_rpc_server() # Reuse from RpcWorkerMixin + self.await_response_thread = ManagedThread( self.await_response_task, error_queue=self._error_queue, name="await_response_thread") - self.dispatch_stats_thread = ManagedThread( - self.dispatch_stats_task, - error_queue=self._error_queue, - name="dispatch_stats_thread") - - self.dispatch_kv_cache_events_thread = ManagedThread( - self.dispatch_kv_cache_events_task, - error_queue=self._error_queue, - name="dispatch_kv_cache_events_thread") - - def _create_iteration_result_queue(self, - it_result_queue: IterationResultQueue): - if not 
it_result_queue.is_initialized: - # not yet initialized - it_result_queue.is_initialized = True - if has_event_loop(): - _queue = AsyncQueue() - it_result_queue.queue = _queue.sync_q - it_result_queue.aqueue = _queue - else: - _queue = Queue() - it_result_queue.queue = _queue - it_result_queue.aqueue = None - def start_thread(self, thread: ManagedThread): if self.engine.can_enqueue_requests() and not thread.is_alive(): thread.start() @@ -98,74 +78,10 @@ class GenerationExecutorWorker(BaseWorker): def await_response_task(self) -> bool: return self._await_response_helper() - def _iteration_result_task(self, it_result_queue: IterationResultQueue, - engine_get_result_api: Callable, - result_singleton: IterationResult, - result_serializer: Callable) -> bool: - time.sleep(0.2) - async_queues = [] - queue = result_singleton.queue if self._is_llm_executor and result_singleton else it_result_queue.queue - try: - for results in engine_get_result_api(): - res = result_serializer(results) - if self._is_llm_executor and result_singleton: - # In this case, there's no ExecutorBindingProxy. - # Worker needs to take care of putting to result queue. - while queue.full(): - queue.get() - if isinstance(queue, _SyncQueue): - queue.put_nowait(res) - async_queues.append(queue) - else: - queue.put(res) - else: - # Send to ExecutorBindingProxy via IPC - queue.put(res) - - if async_queues: - _SyncQueue.notify_many(queue.loop, async_queues) - except AsyncQueue.EventLoopShutdownError: - # This happens in the last results loop while the generate workflow is stopped. - logger.debug("worker.py: EventLoopShutdownError") - except Exception as e: - logger.error(f"worker.py: Error in _iteration_result_task: {e}") - raise e - - return True # success - - def dispatch_stats_task(self) -> bool: - return self._iteration_result_task(self.stats_queues, self.fetch_stats, - self._iter_stats_result, - self._stats_serializer) - - def dispatch_kv_cache_events_task(self) -> bool: - if isinstance(self.engine, tllm.Executor): - # Check if the engine has a kv cache event manager - # If not, return an empty list for the events which will cause the thread to exit early. 
- event_manager = self.engine.get_kv_cache_event_manager() - if event_manager is None: - events_api = lambda: [None] - else: - events_api = event_manager.get_latest_events - return self._iteration_result_task(self.kv_events_queues, - events_api, - self._iter_kv_events_result, - self._kv_cache_events_serializer) - else: - return self._iteration_result_task( - self.kv_events_queues, self.engine.get_latest_kv_cache_events, - self._iter_kv_events_result, self._kv_cache_events_serializer) - def start(self): - # create iteration result queues - self._create_iteration_result_queue(self.stats_queues) - self._create_iteration_result_queue(self.kv_events_queues) - - # start threads + # Stats and KV events are now fetched on-demand via RPC, + # so we only need to start the response thread self.start_thread(self.await_response_thread) - self.start_thread(self.dispatch_kv_cache_events_thread) - if mpi_rank() == 0: - self.start_thread(self.dispatch_stats_thread) def shutdown(self): @@ -178,16 +94,9 @@ class GenerationExecutorWorker(BaseWorker): if self.engine is not None: if self.engine.can_enqueue_requests(): - if self.await_response_thread.is_alive(): self.await_response_thread.stop() self.await_response_thread.join() - if self.dispatch_stats_thread.is_alive(): - self.dispatch_stats_thread.stop() - self.dispatch_stats_thread.join() - if self.dispatch_kv_cache_events_thread.is_alive(): - self.dispatch_kv_cache_events_thread.stop() - self.dispatch_kv_cache_events_thread.join() self.engine.shutdown() self.engine = None @@ -240,6 +149,8 @@ def worker_main( hf_model_dir: Optional[Path] = None, tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[BaseLlmArgs] = None, + rpc_addr: Optional[str] = None, + hmac_key: Optional[bytes] = None, ) -> None: mpi_comm().barrier() @@ -287,15 +198,6 @@ def worker_main( is_server=False, socket_type=zmq.DEALER, name="worker_init_status_queue") - mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr, - is_server=False, - fuse_message=True, - name="worker_stats_queue") - kv_cache_events_queue = FusedIpcQueue( - worker_queues.kv_cache_events_queue_addr, - is_server=False, - fuse_message=False, - name="worker_kv_cache_events_queue") if postproc_worker_config.enabled: # IPC queues for sending inputs to the postprocess parallel @@ -322,9 +224,6 @@ def worker_main( assert result_queues is not None for q in result_queues: q.put(None) - # Signal the stats thread in the proxy to quit - mp_stats_queue.put(None) - kv_cache_events_queue.put(None) postprocess_worker_futures = [] if is_leader and postproc_worker_config.enabled: @@ -370,7 +269,9 @@ def worker_main( is_llm_executor=is_llm_executor, hf_model_dir=hf_model_dir, tokenizer=tokenizer, - llm_args=llm_args) + llm_args=llm_args, + rpc_addr=rpc_addr, + hmac_key=hmac_key) except Exception as e: logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}") logger.error(traceback.format_exc()) @@ -396,11 +297,6 @@ def worker_main( else: worker.set_result_queue(result_queue) - # initialize the iteration result queues - worker._set_iteration_result_queue(worker.stats_queues, - mp_stats_queue) - worker._set_iteration_result_queue(worker.kv_events_queues, - kv_cache_events_queue) # Send ready signal with confirmation ready_msg = (ready_signal, None) if not worker_init_status_queue.notify_with_retry(ready_msg): diff --git a/tests/unittest/llmapi/test_executor.py b/tests/unittest/llmapi/test_executor.py index 2e0ef5f65f..338f6903b7 100644 --- a/tests/unittest/llmapi/test_executor.py +++ 
b/tests/unittest/llmapi/test_executor.py @@ -213,21 +213,33 @@ def _test_sync_generation_tp_inner(llama_7b_tp2_path: Path): result.outputs[0].token_ids) == ", neural network," try: - stats = await executor.aget_stats() - stats = json.loads(stats) - assert stats["iter"] == 0 - assert stats["cpuMemUsage"] > 0 - assert stats["gpuMemUsage"] > 0 - assert stats["inflightBatchingStats"]["numCtxTokens"] == 3 - assert stats["inflightBatchingStats"]["numGenRequests"] == 0 - assert stats["kvCacheStats"]["usedNumBlocks"] == 1 + stats_result = executor.aget_stats(timeout=2) + # aget_stats now returns IterationResult, iterate to get stats + async for stats_str in stats_result: + stats = json.loads(stats_str) if isinstance(stats_str, + str) else stats_str + assert stats["iter"] >= 0 + assert stats["cpuMemUsage"] > 0 + assert stats["gpuMemUsage"] > 0 + break # Just check first result except AsyncQueue.EventLoopShutdownError: pass asyncio.run(async_stats_task()) - stats = executor.get_stats() - assert json.loads(stats)["iter"] == 1 + # Poll for stats since RPC calls return immediately + import time + stats_list = [] + for _ in range(10): + stats_list = executor.get_stats(timeout=0.5) + if stats_list: + break + time.sleep(0.1) + + assert len(stats_list) > 0 + stats = json.loads(stats_list[0]) if isinstance(stats_list[0], + str) else stats_list[0] + assert stats["iter"] == 1 executor.shutdown() diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index fb6e24b81a..f8ffe8fc7b 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -4,6 +4,7 @@ import gc import json import os import sys +import time # Required for test_generate_with_seed to pass. # See the discussion in https://github.com/NVIDIA/TensorRT-LLM/pull/4264#issuecomment-2943269891 @@ -2193,6 +2194,7 @@ def llm_get_stats_test_harness(tp_size: int = 1, sampling_params=sampling_params): print(output) + time.sleep(2) results = llm.get_stats(2) validate_stats(results=results, @@ -2203,7 +2205,7 @@ def llm_get_stats_test_harness(tp_size: int = 1, enable_chunked_prefill=enable_chunked_prefill, enable_iter_req_stats=enable_iter_req_stats) - assert not llm.get_stats(2) + assert not llm.get_stats(0.5) # test that IterationResult()._done is properly set _ = llm.generate(prompts, sampling_params=sampling_params) @@ -2340,8 +2342,9 @@ def llm_get_stats_async_test_harness(tp_size: int = 1, async def task1(repetition_index: int): results = [] await asyncio.sleep( - 3) # ensure there's stats to collect for the assertion - async for stats in llm.get_stats_async(timeout=2): + 4) # ensure there's stats to collect for the assertion + async for stats in llm.get_stats_async( + 10): # it will return immediately results.append(stats) assert results diff --git a/tests/unittest/llmapi/test_llm_multi_gpu.py b/tests/unittest/llmapi/test_llm_multi_gpu.py index dd175a4809..971f25f11e 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu.py @@ -487,11 +487,12 @@ def test_llm_get_kv_cache_events_tp2(): # created + stored events assert events and len(events) >= 2 for event in events: + print(f"event: {event}") if event: - if event[0]["event_id"] == 0: - assert event[0]["data"]["type"] == "created" - elif event[0]["event_id"] == 1: - assert event[0]["data"]["type"] == "stored" + if event["event_id"] == 0: + assert event["data"]["type"] == "created" + elif event["event_id"] == 1: + assert event["data"]["type"] == "stored" @pytest.fixture(scope="module") diff --git 
a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 04d653b842..d90d51cd49 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -1,3 +1,4 @@ +import json import random import time from contextlib import contextmanager, nullcontext @@ -976,6 +977,62 @@ async def test_llm_rpc_streaming(): print(f"get result: {outputs}") +@skip_ray +def test_llm_rpc_get_stats(): + """Test that get_stats works with RPC orchestrator.""" + + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + enable_iter_perf_stats=True, + orchestrator_type="rpc") as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + # Generate some output to produce stats + for output in llm.generate( + prompts, sampling_params=SamplingParams(max_tokens=5)): + print(output) + + stats = llm.get_stats(timeout=5) + + assert len(stats) > 0, "Should have at least one stats entry" + # Stats should be JSON strings that can be parsed + parsed = json.loads(stats[0]) if isinstance(stats[0], str) else stats[0] + assert "iter" in parsed, "Stats should contain 'iter' field" + assert "cpuMemUsage" in parsed, "Stats should contain 'cpuMemUsage' field" + + +@skip_ray +@pytest.mark.asyncio +async def test_llm_rpc_get_stats_async(): + """Test that get_stats_async works with RPC orchestrator.""" + import json + + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + enable_iter_perf_stats=True, + orchestrator_type="rpc") as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + # Generate some output to produce stats + async for output in llm.generate_async( + prompts[0], sampling_params=SamplingParams(max_tokens=5)): + print(output) + + # Get stats via async API + stats_result = llm.get_stats_async(timeout=2) + + # Should be able to iterate over results + stats_count = 0 + async for stat in stats_result: + parsed = json.loads(stat) if isinstance(stat, str) else stat + assert "iter" in parsed, "Stats should contain 'iter' field" + stats_count += 1 + if stats_count >= 1: + break # Just verify we can get at least one + + assert stats_count > 0, "Should have received at least one stat" + + @pytest.mark.threadleak(enabled=False) @pytest.mark.part0 @skip_ray From 472fe497dc2f9e0d9c4c34305cc1bae5c39f55a6 Mon Sep 17 00:00:00 2001 From: Bo Li <22713281+bobboli@users.noreply.github.com> Date: Mon, 22 Dec 2025 18:57:12 +0800 Subject: [PATCH 2/7] [None][chore] NVLinkOneSided AlltoAll Support zero local_num_tokens. 
(#9822) Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com> --- .../moeAlltoAllKernels.cu | 195 +++++++++++------- cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp | 1 - .../unittest/_torch/multi_gpu/test_moe_a2a.py | 1 + 3 files changed, 123 insertions(+), 74 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index 1ee535bdbd..f9dd3377bd 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -362,88 +362,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ int thread_idx = ThreadingPolicy::offset(); int local_token_idx = ThreadingPolicy::token_idx(); - if (local_token_idx >= local_num_tokens) + if (local_num_tokens == 0) { - return; - } - - // Prepare per-policy shared-memory tiles for this token - extern __shared__ int smem[]; - int* smem_topk_target_ranks; - int* smem_topk_send_indices; - int warps_per_block = blockDim.x / warpSize; - if constexpr (std::is_same::value) - { - int lane_id = threadIdx.x / warpSize; - smem_topk_target_ranks = smem + lane_id * TOP_K; - smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K; + // Special case: If local_num_tokens == 0, + // we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization. + // Other threads should return. + if (local_token_idx > 0) + return; } else { - smem_topk_target_ranks = smem; - smem_topk_send_indices = smem + TOP_K; - } + // Threads that do not have a token to process should return. + if (local_token_idx >= local_num_tokens) + return; - uint64_t already_copied = 0; - for (int k = 0; k < TOP_K; k++) - { - int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; - // Use contiguous partitioning to determine target rank - int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank); - - if (already_copied & (1ULL << target_rank)) + // Prepare per-policy shared-memory tiles for this token + extern __shared__ int smem[]; + int* smem_topk_target_ranks; + int* smem_topk_send_indices; + int warps_per_block = blockDim.x / warpSize; + if constexpr (std::is_same::value) { + int lane_id = threadIdx.x / warpSize; + smem_topk_target_ranks = smem + lane_id * TOP_K; + smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K; + } + else + { + smem_topk_target_ranks = smem; + smem_topk_send_indices = smem + TOP_K; + } + + uint64_t already_copied = 0; + for (int k = 0; k < TOP_K; k++) + { + int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; + // Use contiguous partitioning to determine target rank + int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank); + + if (already_copied & (1ULL << target_rank)) + { + if (thread_idx == 0) + { + ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1; + ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1; + // Mirror to shared memory immediately + smem_topk_target_ranks[k] = -1; + smem_topk_send_indices[k] = -1; + } + continue; + } + + // Only one thread per warp should increment the counter + int dst_token_idx; if (thread_idx == 0) { - ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1; - ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1; + dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1); + + ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank; + 
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx; // Mirror to shared memory immediately - smem_topk_target_ranks[k] = -1; - smem_topk_send_indices[k] = -1; + smem_topk_target_ranks[k] = target_rank; + smem_topk_send_indices[k] = dst_token_idx; } - continue; + already_copied |= 1ULL << target_rank; } + // Sync before dispatching data + ThreadingPolicy::sync(); - // Only one thread per warp should increment the counter - int dst_token_idx; - if (thread_idx == 0) - { - dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1); - - ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank; - ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx; - // Mirror to shared memory immediately - smem_topk_target_ranks[k] = target_rank; - smem_topk_send_indices[k] = dst_token_idx; - } - already_copied |= 1ULL << target_rank; - } - // Sync before dispatching data - ThreadingPolicy::sync(); - - // Read staged routing once into registers per thread - int topk_target_ranks[TOP_K]; - int topk_send_indices[TOP_K]; + // Read staged routing once into registers per thread + int topk_target_ranks[TOP_K]; + int topk_send_indices[TOP_K]; #pragma unroll - for (int k = 0; k < TOP_K; ++k) - { - topk_target_ranks[k] = smem_topk_target_ranks[k]; - topk_send_indices[k] = smem_topk_send_indices[k]; + for (int k = 0; k < TOP_K; ++k) + { + topk_target_ranks[k] = smem_topk_target_ranks[k]; + topk_send_indices[k] = smem_topk_send_indices[k]; + } + + // Perform a single source load and TOP_K fanout per payload + for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) + { + uint8_t const* src_data = static_cast(ptrs.src_data_ptrs[payload_idx]); + int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx]; + uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token; + + vectorized_dispatch(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, + payload_idx, ptrs, topk_target_ranks, topk_send_indices); + } + + ThreadingPolicy::sync(); } - // Perform a single source load and TOP_K fanout per payload - for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) - { - uint8_t const* src_data = static_cast(ptrs.src_data_ptrs[payload_idx]); - int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx]; - uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token; - - vectorized_dispatch(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx, - ptrs, topk_target_ranks, topk_send_indices); - } - - ThreadingPolicy::sync(); - bool is_first_warp = threadIdx.x / warpSize == 0; if (is_first_warp) { @@ -452,8 +462,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ bool is_last_token = false; if (lane_id == 0) { - int cnt = atomicAdd(ptrs.local_token_counter, 1); - is_last_token = cnt + 1 == local_num_tokens; + if (local_num_tokens != 0) + { + int cnt = atomicAdd(ptrs.local_token_counter, 1); + is_last_token = cnt + 1 == local_num_tokens; + } + else + { + is_last_token = true; + } } is_last_token = __shfl_sync(0xffffffff, is_last_token, 0); @@ -523,7 +540,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) // Validate parameters TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); - TLLM_CHECK(params.local_num_tokens > 0); + TLLM_CHECK(params.local_num_tokens >= 0); TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads); // Prepare kernel pointers struct @@ -568,6 +585,11 @@ void 
moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) if (params.one_block_per_token) { int grid_size = params.local_num_tokens; + // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization. + if (grid_size == 0) + { + grid_size = 1; + } int shared_bytes = 2 * params.top_k * (int) sizeof(int); SWITCH_TOP_K(params.top_k, TOP_K, moeA2ADispatchKernel<<>>( @@ -577,6 +599,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) else { int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock); + // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization. + if (grid_size == 0) + { + grid_size = 1; + } int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int); SWITCH_TOP_K(params.top_k, TOP_K, moeA2ADispatchKernel<<>>( @@ -897,9 +924,19 @@ __global__ void moeA2ACombineKernel( int local_token_idx = ThreadingPolicy::token_idx(); int const size_per_token = elements_per_token * sizeof(T); - if (local_token_idx >= local_num_tokens) + if (local_num_tokens == 0) { - return; + // Special case: If local_num_tokens == 0, + // we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization. + // Other threads should return. + if (local_token_idx > 0) + return; + } + else + { + // Threads that do not have a token to process should return. + if (local_token_idx >= local_num_tokens) + return; } #if !DISABLE_SYNC_FOR_PROFILING @@ -951,6 +988,9 @@ __global__ void moeA2ACombineKernel( __syncthreads(); #endif + if (local_num_tokens == 0) + return; + // Get output location for this token (using src_data_ptrs[0] as output) T* token_output = static_cast(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token; @@ -1003,7 +1043,7 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params) // Validate parameters TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); - TLLM_CHECK(params.local_num_tokens > 0); + TLLM_CHECK(params.local_num_tokens >= 0); TLLM_CHECK(params.elements_per_token > 0); // Configure kernel launch @@ -1011,6 +1051,15 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params) int const kWarpsPerBlock = kBlockSize / 32; // warpSize int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock); int grid_size_block = params.local_num_tokens; + // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization. 
+ if (grid_size_warp == 0) + { + grid_size_warp = 1; + } + if (grid_size_block == 0) + { + grid_size_block = 1; + } // Prepare kernel pointers struct for combine CombineKernelPointers kernel_ptrs = {}; // Zero-initialize diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp index e11135ddfb..af6d7cb37d 100644 --- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp +++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp @@ -186,7 +186,6 @@ std::tuple, int64_t> moeA2ADispatchOp(torch::Tensor c MoeA2ADataOffsets const& offsets = *reinterpret_cast(metainfo.data_ptr()); int64_t localNumTokens = tokenSelectedExperts.size(0); - TORCH_CHECK(localNumTokens > 0, "localNumTokens must be positive"); TORCH_CHECK(runtimeMaxTokensPerRank > 0, "runtimeMaxTokensPerRank must be positive"); TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]"); diff --git a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py index 7defc0dae7..28b590d88e 100644 --- a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py +++ b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py @@ -566,6 +566,7 @@ class TestMoEAlltoAll: (4, [32, 32, 32, 32], 4), (4, [1, 1, 1, 1], 2), (8, [640, 640, 640, 640, 640, 640, 640, 640], 4), + (4, [32, 0, 16, 0], 2), ], indirect=["mpi_pool_executor"]) def test_combine(self, mpi_pool_executor, all_num_tokens, top_k): From a6a88985cf94a326bd0f8d6208c2ad1f7740199c Mon Sep 17 00:00:00 2001 From: William Zhang <133824995+2ez4bz@users.noreply.github.com> Date: Mon, 22 Dec 2025 03:32:49 -0800 Subject: [PATCH 3/7] [TRTLLM-9409][feat] Pass MRoPE tensors for EPD disagg (#9758) * Why? Certain VLMs like the Qwen family need more than just the multimodal embeddings in the language model, and need MRoPE position IDs and deltas. Prior to this commit, only the embeddings could be communicated from the encoder worker to the prefill worker. * What? This commit extends the `DisaggregatedParams` to include the MRoPE information. It also adjusts several pieces of code required to communicate that between E, P and D workers. Closes TRTLLM-9409. 
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com> --- .../_torch/models/modeling_llava_next.py | 5 +- .../_torch/models/modeling_qwen2vl.py | 115 ++++++++++++++++-- tensorrt_llm/_torch/pyexecutor/llm_request.py | 27 +++- .../_torch/pyexecutor/model_engine.py | 41 +++++-- tensorrt_llm/_torch/pyexecutor/sampler.py | 22 +++- tensorrt_llm/disaggregated_params.py | 2 + tensorrt_llm/executor/result.py | 24 +++- tensorrt_llm/llmapi/llm.py | 40 +++++- tensorrt_llm/llmapi/mm_encoder.py | 8 +- .../multimodal/test_mm_encoder_standalone.py | 84 +++++-------- 10 files changed, 271 insertions(+), 97 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 0fd3a9a510..844c0d8958 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -527,6 +527,8 @@ class LlavaNextModel(PreTrainedModel): return if not DISAGG: self.mm_encoder = LlavaNextVisionModel(model_config) + else: + self.mm_encoder = None llm_model_config = copy.deepcopy(model_config) llm_model_config.pretrained_config = model_config.pretrained_config.text_config @@ -545,7 +547,8 @@ class LlavaNextModel(PreTrainedModel): if isinstance(weight_mapper, LlavaNextHfWeightMapper): weights = weight_mapper.preprocess_weights(weights) - self.mm_encoder.load_weights(weights) + if self.mm_encoder is not None: + self.mm_encoder.load_weights(weights) def filter_weights(weights: Dict): transformed_weights = {} diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 6740188f3d..d421b31de5 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -32,7 +32,8 @@ from ...inputs import (BaseMultimodalDummyInputsBuilder, BaseMultimodalInputProcessor, ExtraProcessedInputs, MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, - register_input_processor) + register_input_processor, + support_multimodal_disaggregated) from ...logger import logger from ...sampling_params import SamplingParams from ..attention_backend import AttentionMetadata @@ -865,6 +866,8 @@ class Qwen2VLModelBase(PreTrainedModel): mm_encoder_config = copy.deepcopy(model_config) self.mm_encoder = Qwen2VisionModelBase( mm_encoder_config, kwargs.get('vision_model_class', None)) + else: + self.mm_encoder = None def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]): config = model_config.pretrained_config @@ -953,24 +956,21 @@ class Qwen2VLModelBase(PreTrainedModel): """ VLM forward logic with inflight batching support. """ - num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations + num_context_requests = attn_metadata.num_contexts multimodal_params = kwargs.get("multimodal_params", []) mm_embeds = [] mrope_config = {} - # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate the mm_multimodal_params from the text-only prompts. - mm_multimodal_params = [ - multimodal_param for multimodal_param in multimodal_params - if multimodal_param.multimodal_data.get("image", {}).get( - "pixel_values") is not None or multimodal_param.multimodal_data. - get("video", {}).get("pixel_values_videos") is not None - ] + # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate + # the entries that do have multimodal data from those that correspond to text-only prompts. 
+ mm_multimodal_params = self._get_requests_with_mm_data( + multimodal_params) if len(mm_multimodal_params) > 0: if not _is_disagg(): mm_embeds = get_multimodal_embeddings( encoder_forward_fn=self.mm_encoder.forward, multimodal_params=mm_multimodal_params) - else: + elif not getattr(self, "support_mm_disagg", False): raise NotImplementedError( "Qwen2VLModel does not support disaggregated inference yet. Please unset " f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'." @@ -995,6 +995,21 @@ class Qwen2VLModelBase(PreTrainedModel): logger.debug(f'output shape: {output_prob.shape}') return output_prob + def _get_requests_with_mm_data(self, multimodal_params): + mm_multimodal_params = [] + for multimodal_param in multimodal_params: + data = multimodal_param.multimodal_data + if ( + # The first 2 conditions check whether there is input on which inference should be run. + data.get("image", {}).get("pixel_values") is not None or + data.get("video", {}).get("pixel_values_videos") is not None + # This condition corresponds to when the embeddings are already populated, as is e.g. + # the case in EPD disagg in the prefill worker. + or data.get("multimodal_embedding")): + mm_multimodal_params.append(multimodal_param) + + return mm_multimodal_params + @register_vision_encoder(Qwen2VisionModelBase, vlm_base_model=Qwen2VisionTransformerPretrainedModel) @@ -1032,11 +1047,89 @@ class Qwen2VLModel(Qwen2VLModelBase): self.llm.load_weights(weights, weight_mapper) +class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase): + + def get_prompt_token_ids( + self, inputs: TextPrompt, + mm_handles: List[Dict[str, + Any]]) -> Tuple[List[int], List[int], List[int]]: + """ + Build input token ids with multimodal placeholders expanded to the number of MM tokens. + + Args: + inputs: Text prompt input container. Must contain a non-empty prompt string. + mm_handles: List of multimodal embedding handles. Currently only a single handle is supported. 
+ + Returns: + Tuple[List[int], List[int], List[int]]: + - expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token + - mm_token_length: per-image MM token lengths + - mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids + """ + # TODO: Move this function to the base input processor class when extending for more models + text_prompt = inputs.get("prompt") + if not text_prompt: + raise ValueError("Text prompt is required but not provided") + + if not isinstance(mm_handles, list): + raise TypeError("mm_handles must be a list") + + if len(mm_handles) != 1: + # TODO: only support single multimodal item within a request for now + raise NotImplementedError( + "Only one mm_handle is supported for Qwen2.5 VL for now") + hidden_size = mm_handles[0]['tensor_size'][1] + assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size" + input_ids = self.tokenizer(text_prompt, + return_tensors="pt").input_ids[0] + + image_token_index = self.config.image_token_id + + image_mask = input_ids == image_token_index + image_positions = torch.where(image_mask)[0] + num_images = len(image_positions) + assert num_images == len( + mm_handles), "Number of images must match number of mm_handles" + total_mm_tokens = sum(mm_handle["tensor_size"][0] + for mm_handle in mm_handles) + final_length = len(input_ids) - num_images + total_mm_tokens + # Create output tensor + expanded_ids = torch.empty(final_length, dtype=input_ids.dtype) + placeholder_id = self.tllm_multimodal_token_id + + # Fill the expanded sequence + write_pos = 0 + image_cnt = 0 + mm_token_length = [] + mm_token_offsets = [] + for read_pos in range(len(input_ids)): + if input_ids[read_pos] == image_token_index: + # Replace with placeholder id + mm_token_num = mm_handles[image_cnt]["tensor_size"][0] + expanded_ids[write_pos:write_pos + mm_token_num] = \ + placeholder_id + mm_token_offsets.append(write_pos) + mm_token_length.append(mm_token_num) + write_pos += mm_token_num + image_cnt += 1 + else: + # Copy text token as-is + expanded_ids[write_pos] = input_ids[read_pos] + write_pos += 1 + + assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}" + assert mm_token_length[-1] + mm_token_offsets[ + -1] <= final_length, f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less than or equal to final_length ({final_length})" + return expanded_ids.to( + torch.int32).tolist(), mm_token_length, mm_token_offsets + + +@support_multimodal_disaggregated @register_vision_encoder(Qwen2VisionModelBase, vlm_base_model=Qwen2_5_VisionModel) @register_auto_model("Qwen2_5_VLForConditionalGeneration") @register_input_processor( - Qwen2VLInputProcessorBase, + Qwen2_5VLInputProcessorBase, model_type="qwen2_5_vl", placeholder_metadata=MultimodalPlaceholderMetadata( placeholder_map={ diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 5f3c39149c..b11cd11617 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -262,6 +262,8 @@ class PyResult: chunk_size=self._chunk_size) if return_generation_logits else None self._log_probs = LogProbStorage() if return_log_probs else None self._mm_embeddings = None + self._mrope_position_ids = None + self._mrope_position_deltas = None self._additional_context_outputs = { name: [] for name in additional_outputs @@ 
-293,6 +295,16 @@ class PyResult: self._mm_embeddings = SharedTensorContainer.from_tensor( mm_embeddings).dump_to_dict() + def set_mrope_position( + self, + mrope_position_ids: torch.Tensor, + mrope_position_deltas: torch.Tensor, + ): + self._mrope_position_ids = (SharedTensorContainer.from_tensor( + mrope_position_ids).dump_to_dict()) + self._mrope_position_deltas = (SharedTensorContainer.from_tensor( + mrope_position_deltas).dump_to_dict()) + def transfer_remaining_device_logits(self): """Finalize any remaining generation logits transfers (for chunked mode)""" if self._generation_logits: @@ -352,6 +364,18 @@ class PyResult: def mm_embedding_handle(self) -> Dict[str, Any] | None: return self._mm_embeddings + @property + def mrope_position_ids_handle(self) -> Dict[str, Any] | None: + # NOTE: when populated, the returned `dict` contains the information necessary to rebuild + # the `SharedTensorContainer` using the `from_dict` class method. + return self._mrope_position_ids + + @property + def mrope_position_deltas_handle(self) -> Dict[str, Any] | None: + # NOTE: when populated, the returned `dict` contains the information necessary to rebuild + # the `SharedTensorContainer` using the `from_dict` class method. + return self._mrope_position_deltas + @property def additional_context_outputs(self) -> Dict[str, torch.Tensor] | None: if self._additional_context_outputs is None: @@ -382,7 +406,8 @@ class LlmResult: py_result_properties = frozenset( ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs', 'mm_embedding_handle', 'additional_context_outputs', - 'additional_generation_outputs')) + 'additional_generation_outputs', 'mrope_position_ids_handle', + 'mrope_position_deltas_handle')) def __init__(self, result: Union[bytes, tensorrt_llm.bindings.executor.Result], diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 1a62c5beca..96ade56beb 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2213,13 +2213,14 @@ class PyTorchModelEngine(ModelEngine): mrope_position_deltas).expand( 3, 1, 1) mrope_position_ids.append(gen_mrope_position_ids) - multimodal_params.to_device( - "multimodal_data", - "cuda", - pin_memory=True, - target_keywords=[ - "mrope_config.mrope_position_deltas" - ]) + if mrope_position_deltas.device.type == "cpu": + multimodal_params.to_device( + "multimodal_data", + "cuda", + pin_memory=True, + target_keywords=[ + "mrope_config.mrope_position_deltas" + ]) multimodal_params_list.append(multimodal_params) request.py_batch_idx = request.py_seq_slot @@ -2448,8 +2449,9 @@ class PyTorchModelEngine(ModelEngine): # NOTE: self.use_mrope is enough for differentiating whether to use mrope_position_ids but # `_create_dummy_context_requests` from `kv_cache_creater` makes an exception that I can not add multimodal_data to the dummy_request # so that we only replace position_ids with mrope_position_ids when it is not a dummy request and for models who is using mrope. 
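# The device guard introduced below (and in the mrope_position_deltas path above)
# follows a simple rule: pinning is only defined for dense CPU tensors, so values
# that already live on the GPU, such as tensors rebuilt from a shared-tensor
# handle in the disaggregated path, must skip `pin_memory()`. A minimal sketch of
# that guard with generic tensor names (pinning requires a CUDA-enabled build):
import torch

def _maybe_pin(t: torch.Tensor) -> torch.Tensor:
    # Only pin host-side tensors; pinned memory lets the later
    # `copy_(..., non_blocking=True)` host-to-device transfer overlap with compute.
    return t.pin_memory() if t.device.type == "cpu" else t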
- mrope_position_ids = torch.cat(mrope_position_ids, - dim=-1).pin_memory() + mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) + if mrope_position_ids.device.type == "cpu": + mrope_position_ids = mrope_position_ids.pin_memory() self.mrope_position_ids_cuda[:, :, :total_num_tokens].copy_( mrope_position_ids[:, :, :total_num_tokens], non_blocking=True) final_position_ids = self.mrope_position_ids_cuda[:, :, : @@ -3362,7 +3364,26 @@ class PyTorchModelEngine(ModelEngine): mm_embeddings = list( torch.split(mm_embeddings[0], multimodal_chunks, dim=0)) - return {'mm_embeddings': mm_embeddings, 'logits': None} + # Extract mrope position data from multimodal_params if available + mrope_position_ids_list = [] + mrope_position_deltas_list = [] + for multimodal_param in multimodal_params: + mrope_config = multimodal_param.multimodal_data.get( + 'mrope_config', {}) + mrope_position_ids = mrope_config.get('mrope_position_ids') + mrope_position_deltas = mrope_config.get('mrope_position_deltas') + if mrope_position_ids is not None: + mrope_position_ids_list.append(mrope_position_ids) + if mrope_position_deltas is not None: + mrope_position_deltas_list.append(mrope_position_deltas) + + result = {'mm_embeddings': mm_embeddings, 'logits': None} + if mrope_position_ids_list: + result['mrope_position_ids'] = mrope_position_ids_list + if mrope_position_deltas_list: + result['mrope_position_deltas'] = mrope_position_deltas_list + + return result def _init_userbuffers(self, hidden_size): if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1: diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index b9bbb7cbf5..62a43a50be 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -21,7 +21,7 @@ from concurrent import futures from dataclasses import dataclass from functools import cached_property from itertools import repeat -from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, cast +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast import numpy as np import torch @@ -199,6 +199,8 @@ class EarlyStopSampler(Sampler): @dataclass(kw_only=True) class MultimodalResult: mm_embeddings: List[torch.Tensor] + # Can be used to include e.g. `mrope_position_ids`, etc. 
+ extra_data: Optional[Dict[str, Any]] = None def values(self): return vars(self).values() @@ -262,7 +264,10 @@ class EarlyStopWithMMResult(Sampler): resource_manager: Optional[ResourceManager] = None, ) -> SampleStateWithMMResult: # from model_outputs to MultimodalResult - data = MultimodalResult(mm_embeddings=model_outputs["mm_embeddings"]) + data = MultimodalResult( + mm_embeddings=model_outputs.pop("mm_embeddings"), + extra_data={**model_outputs}, + ) return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data) @override @@ -276,7 +281,12 @@ class EarlyStopWithMMResult(Sampler): scheduled_requests = state.scheduled_requests assert not scheduled_requests.generation_requests mm_embeddings = state.data.mm_embeddings - for request, mm_embedding in zip(scheduled_requests.context_requests, mm_embeddings): + extra_data = state.data.extra_data or {} + mrope_position_ids = extra_data.get("mrope_position_ids", None) + mrope_position_deltas = extra_data.get("mrope_position_deltas", None) + for i, (request, mm_embedding) in enumerate( + zip(scheduled_requests.context_requests, mm_embeddings) + ): request.state = LlmRequestState.GENERATION_COMPLETE # NOTE: This is a hack: set finish reason manually and set the beam 0 request.set_finished_reason(FinishReason.LENGTH, 0) @@ -287,6 +297,12 @@ class EarlyStopWithMMResult(Sampler): request.py_result.append_mm_embeddings(mm_embedding) + # Store mrope data if available + if mrope_position_ids is not None and mrope_position_deltas is not None: + request.py_result.set_mrope_position( + mrope_position_ids[i], mrope_position_deltas[i] + ) + @override def is_generation_model(self) -> bool: return False diff --git a/tensorrt_llm/disaggregated_params.py b/tensorrt_llm/disaggregated_params.py index 028bccbea0..4c0680bc94 100644 --- a/tensorrt_llm/disaggregated_params.py +++ b/tensorrt_llm/disaggregated_params.py @@ -40,6 +40,8 @@ class DisaggregatedParams: multimodal_hashes: Optional[List[List[int]]] = ( None # user provided mm hashes should be a list of 8 integers ) + mrope_position_ids_handle: Optional[Dict[str, Any]] = None + mrope_position_deltas_handle: Optional[Dict[str, Any]] = None def get_context_phase_params(self) -> tllme.ContextPhaseParams: return tllme.ContextPhaseParams( diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 603c567ed5..8d33d94a7f 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import json import time import weakref @@ -415,12 +416,19 @@ class GenerationResultBase: self.cached_tokens = getattr(response_result, 'cached_tokens', 0) self.avg_decoded_tokens_per_iter = response_result.avg_decoded_tokens_per_iter if context_phase_params is not None: - self.disaggregated_params = DisaggregatedParams( + existing_disagg_params = self.disaggregated_params + # Use `replace` to preserve things like `mrope_position_ids_handle` and + # `mrope_position_deltas_handle`. However, explicitly set + # `multimodal_embedding_handles=None` since they should no longer be needed. 
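# How `dataclasses.replace` behaves here, shown on a small stand-in dataclass
# (field names mirror DisaggregatedParams, but this is an illustration only):
import dataclasses
from typing import Optional

@dataclasses.dataclass
class _Params:
    request_type: Optional[str] = None
    mrope_position_ids_handle: Optional[dict] = None
    multimodal_embedding_handles: Optional[list] = None

existing = _Params(mrope_position_ids_handle={"shape": [3, 1, 8]},
                   multimodal_embedding_handles=[{"shape": [16, 4096]}])
updated = dataclasses.replace(existing,
                              request_type="context_only",
                              multimodal_embedding_handles=None)
assert updated.mrope_position_ids_handle == {"shape": [3, 1, 8]}  # preserved
assert updated.multimodal_embedding_handles is None  # explicitly cleared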
+ self.disaggregated_params = dataclasses.replace( + existing_disagg_params or DisaggregatedParams(), request_type="context_only", first_gen_tokens=context_phase_params.first_gen_tokens, ctx_request_id=context_phase_params.req_id, opaque_state=context_phase_params.opaque_state, - draft_tokens=context_phase_params.draft_tokens) + draft_tokens=context_phase_params.draft_tokens, + multimodal_embedding_handles=None, + ) finish_reasons = response_result.finish_reasons # output_token_ids = (beams, tokens) @@ -440,6 +448,8 @@ class GenerationResultBase: if hasattr(response_result, 'mm_embedding_handle' ) and response_result.mm_embedding_handle is not None: self._mm_embedding_handle = response_result.mm_embedding_handle + mrope_position_ids_handle = response_result.mrope_position_ids_handle + mrope_position_deltas_handle = response_result.mrope_position_deltas_handle if self.disaggregated_params is not None: self.disaggregated_params.multimodal_embedding_handles = [ response_result.mm_embedding_handle @@ -451,6 +461,8 @@ class GenerationResultBase: response_result.mm_embedding_handle ], multimodal_hashes=self._multimodal_hashes) + self.disaggregated_params.mrope_position_ids_handle = mrope_position_ids_handle + self.disaggregated_params.mrope_position_deltas_handle = mrope_position_deltas_handle # Processing background errors here ASAF during generation. if self._background_error_handler and ( @@ -811,8 +823,12 @@ class GenerationResult(GenerationResultBase): def _repr_fields(self): return [ - 'request_id', 'prompt_token_ids', 'outputs', 'finished', - "context_logits", "mm_embedding_handle" + 'request_id', + 'prompt_token_ids', + 'outputs', + 'finished', + "context_logits", + "mm_embedding_handle", ] def __repr__(self) -> str: diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 33774f0ed8..6d3410bf3c 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -89,8 +89,12 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult): def _repr_fields(self): return [ - "request_id", "prompt", "prompt_token_ids", "outputs", "finished", - "mm_embedding_handle" + "request_id", + "prompt", + "prompt_token_ids", + "outputs", + "finished", + "mm_embedding_handle", ] @@ -419,7 +423,7 @@ class BaseLLM: multimodal_params = None if is_mm_disagg: - if not self.input_processor.support_mm_disagg: + if not getattr(self.input_processor, "support_mm_disagg", False): raise ValueError( "Multimodal disaggregated inference is not supported for this model" ) @@ -436,14 +440,42 @@ class BaseLLM: mm_hashes = disaggregated_params.multimodal_hashes multimodal_input = MultimodalInput.from_components( mm_hashes, mm_token_positions, mm_token_length) + multimodal_data = {"multimodal_embedding": mm_handles} + if disaggregated_params.mrope_position_ids_handle is not None: + # NOTE: `PyTorchModelEngine` assumes both are present when using mrope. 
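# The "both handles or neither" contract stated in the NOTE above, written as a
# small standalone check (the generation_only path further below raises
# RuntimeError for the same condition; the helper name is illustrative):
def _check_mrope_handles(ids_handle, deltas_handle) -> None:
    if (ids_handle is None) != (deltas_handle is None):
        raise RuntimeError(
            "`mrope_position_ids_handle` and `mrope_position_deltas_handle` must "
            "both be provided, or both be None.")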
+ assert disaggregated_params.mrope_position_deltas_handle is not None + mrope_config = {} + mrope_config[ + "mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle + mrope_config[ + "mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle + multimodal_data["mrope_config"] = mrope_config multimodal_params = MultimodalParams( multimodal_input=multimodal_input, - multimodal_data={"multimodal_embedding": mm_handles}) + multimodal_data=multimodal_data, + ) elif "prompt_token_ids" in inputs: prompt_token_ids = inputs['prompt_token_ids'] prompt = None query_token_ids = inputs.get("query_token_ids", None) + multimodal_data = {} + # NOTE: when running in `generation_only` for disagg, this is the code path we expect to hit. + if disaggregated_params is not None and disaggregated_params.mrope_position_ids_handle is not None: + # It looks like `PyTorchModelEngine` assumes both are present when using mrope? + if disaggregated_params.mrope_position_deltas_handle is None: + raise RuntimeError( + "`mrope_position_ids_handle` and `mrope_position_deltas_handle` must both " + "be provided, or both `None`.") + mrope_config = {} + mrope_config[ + "mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle + mrope_config[ + "mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle + multimodal_data["mrope_config"] = mrope_config + if multimodal_data: + multimodal_params = MultimodalParams( + multimodal_data=multimodal_data) elif "prompt" in inputs: if 'multi_modal_data' in inputs: # TODO: The current design uses a wrapper for existing input processor (input_processor_with_hash) diff --git a/tensorrt_llm/llmapi/mm_encoder.py b/tensorrt_llm/llmapi/mm_encoder.py index 8553d4678e..3ff85dd42b 100644 --- a/tensorrt_llm/llmapi/mm_encoder.py +++ b/tensorrt_llm/llmapi/mm_encoder.py @@ -101,14 +101,8 @@ class MultimodalEncoder(_TorchLLM): inputs = [prompt_inputs(i) for i in inputs] - def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any: - if isinstance(maybe_batched, list): - return maybe_batched[pos] - else: - return maybe_batched - futures = [] - for i, request_inputs in enumerate(inputs): + for request_inputs in inputs: future = self.generate_async(request_inputs) futures.append(future) diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py index 4dc0564711..99154dd074 100644 --- a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py +++ b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py @@ -19,49 +19,23 @@ example_images = [ str(test_data_root / "61.jpg"), ] - -@pytest.fixture(scope="function") -def multimodal_model_config(): - """Get multimodal model configuration similar to integration tests""" - # You can extend this to support multiple models or get from environment - model_configs = { - 'llava-v1.6-mistral-7b-hf': { - 'model_name': - 'llava-v1.6-mistral-7b-hf', - 'hf_model_dir': - 'llava-hf/llava-v1.6-mistral-7b-hf', - 'model_dir': - llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf", - } - } - - return model_configs['llava-v1.6-mistral-7b-hf'] +_LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf" +_QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct" # TODO: Add multi-image in single chat test -@pytest.mark.parametrize("model_key", [ - "llava-v1.6-mistral-7b-hf", -]) +@pytest.mark.parametrize("model_dir", [_LLAVA_DIR, _QWEN_2_5_VL_DIR]) @pytest.mark.parametrize("pd_disagg", 
[False, True]) -def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): +def test_single_image_chat(model_dir, pd_disagg): """Test processing single image using encoder (pass mm_embeddings) + LLM API. This test verifies that encoder (pass mm_embeddings) + LLM API produces identical results to standard llm generation (pass raw image) by comparing outputs. """ - # Get model configuration - if model_key != "llava-v1.6-mistral-7b-hf": - #TODO: add more model tests progressively here - pytest.skip( - f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now" - ) - - # Extract model information from config - encoder_model_dir = multimodal_model_config['model_dir'] # Test configuration max_tokens = 64 - free_gpu_memory_fraction = 0.6 if not pd_disagg else 0.2 + free_gpu_memory_fraction = 0.2 max_batch_size = 1 # Test data - OpenAI chat completion format @@ -76,15 +50,14 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): ) # Process multimodal data using encoder (pass mm_embeddings) - encoder = MultimodalEncoder(model=encoder_model_dir, - max_batch_size=max_batch_size) + encoder = MultimodalEncoder(model=model_dir, max_batch_size=max_batch_size) cache_transceiver_cfg = CacheTransceiverConfig( backend="DEFAULT") if pd_disagg else None disable_overlap_scheduler = pd_disagg - llm = LLM(model=encoder_model_dir, + llm = LLM(model=model_dir, backend='pytorch', kv_cache_config=kv_cache_config, trust_remote_code=True, @@ -93,7 +66,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): llm_decode = None if pd_disagg: - llm_decode = LLM(model=encoder_model_dir, + llm_decode = LLM(model=model_dir, backend='pytorch', kv_cache_config=kv_cache_config, trust_remote_code=True, @@ -141,6 +114,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None" ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only" + outputs = llm.generate(inputs, sampling_params=sampling_params, disaggregated_params=ep_disaggregated_params) @@ -151,10 +125,10 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): pd_disaggregated_params = outputs[0].disaggregated_params pd_disaggregated_params.request_type = "generation_only" sampling_params = SamplingParams(max_tokens=max_tokens) - inputs[0][ - 'multi_modal_data'] = None # remove multimodal data from input as decoder worker doesn't need it - inputs[0]['prompt_token_ids'] = outputs[ - 0].prompt_token_ids # use prompt token ids from encoder output + # remove multimodal data from input as decoder worker doesn't need it + inputs[0]['multi_modal_data'] = None + # use prompt token ids from encoder output + inputs[0]['prompt_token_ids'] = outputs[0].prompt_token_ids outputs = llm_decode.generate( inputs, @@ -199,24 +173,23 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): f"Log probabilities don't match for output {i}, generation {j}" -@pytest.mark.parametrize("model_key", [ - "llava-v1.6-mistral-7b-hf", -]) -def test_multi_request_batch_chat(model_key, multimodal_model_config): +@pytest.mark.parametrize( + "model_dir, encoder_max_batch_size", + [ + (_LLAVA_DIR, 3), + # Qwen2.5 VL's vision encoder seems to output different embeddings based on this value. + # The test only passes with this set to 1. 
+ (_QWEN_2_5_VL_DIR, 1), + ], +) +def test_multi_request_batch_chat(model_dir, encoder_max_batch_size): """Test batching multiple multimodal requests and verify encoder path matches raw path. This mirrors test_single_image_chat but with a batch of size 3. """ - if model_key != "llava-v1.6-mistral-7b-hf": - pytest.skip( - f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now" - ) - - encoder_model_dir = multimodal_model_config['model_dir'] max_tokens = 64 free_gpu_memory_fraction = 0.6 - max_batch_size = 3 prompts = [ "Describe the natural environment in the image.", @@ -232,10 +205,10 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config): free_gpu_memory_fraction=free_gpu_memory_fraction, ) - encoder = MultimodalEncoder(model=encoder_model_dir, - max_batch_size=max_batch_size) + encoder = MultimodalEncoder(model=model_dir, + max_batch_size=encoder_max_batch_size) llm = LLM( - model=encoder_model_dir, + model=model_dir, backend='pytorch', kv_cache_config=kv_cache_config, max_batch_size=1, # fix batch size to reduce non-determinism in tests @@ -305,8 +278,7 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config): "Describe the weather in the image.", ], 2), ]) -def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates, - multimodal_model_config): +def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates): """Test mm_keys in KV cache events with cache reuse scenarios. This test verifies: @@ -316,7 +288,7 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates, - Same media + same prompts: full reuse (0 duplicate offsets) - Same media + different prompts: partial reuse (prefix blocks reused) """ - encoder_model_dir = multimodal_model_config['model_dir'] + encoder_model_dir = _LLAVA_DIR max_tokens = 16 free_gpu_memory_fraction = 0.6 From 0f308e95f9253feb1950cc6720a2b87769139b01 Mon Sep 17 00:00:00 2001 From: Pengyun Lin <81065165+LinPoly@users.noreply.github.com> Date: Mon, 22 Dec 2025 21:37:22 +0800 Subject: [PATCH 4/7] [None][chore] Remove logprobs constraint on trtllm-serve pytorch backend (#9911) Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com> --- .../apps/_test_trtllm_serve_top_logprobs.py | 84 ++++--------------- 1 file changed, 16 insertions(+), 68 deletions(-) diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py index d287e5e35e..dc95ecf292 100644 --- a/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py +++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py @@ -1,6 +1,3 @@ -import os -import tempfile - import openai import pytest import yaml @@ -22,34 +19,28 @@ def backend(request): @pytest.fixture(scope="module") -def temp_extra_llm_api_options_file(): - temp_dir = tempfile.gettempdir() - temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml") - try: - extra_llm_api_options_dict = { - "enable_chunked_prefill": False, - "gather_generation_logits": True, - "kv_cache_config": { - "enable_block_reuse": False, - } +def temp_extra_llm_api_options_file(tmp_path_factory): + extra_llm_api_options_dict = { + "enable_chunked_prefill": False, + "gather_generation_logits": True, + "kv_cache_config": { + "enable_block_reuse": False, } + } - with open(temp_file_path, 'w') as f: - yaml.dump(extra_llm_api_options_dict, f) - - yield temp_file_path - finally: - if os.path.exists(temp_file_path): - os.remove(temp_file_path) + temp_file_path = 
tmp_path_factory.mktemp( + "config") / "extra_llm_api_options.yaml" + with open(temp_file_path, 'w') as f: + yaml.dump(extra_llm_api_options_dict, f) + return temp_file_path @pytest.fixture(scope="module") def server(model_name: str, backend: str, temp_extra_llm_api_options_file: str): model_path = get_model_path(model_name) - args = [ - "--backend", f"{backend}", "--extra_llm_api_options", - temp_extra_llm_api_options_file - ] + args = ["--backend", f"{backend}"] + if backend == "trt": + args += ["--extra_llm_api_options", temp_extra_llm_api_options_file] with RemoteOpenAIServer(model_path, args) as remote_server: yield remote_server @@ -61,11 +52,7 @@ def async_client(server: RemoteOpenAIServer): @pytest.mark.asyncio(loop_scope="module") async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI, - model_name: str, backend: str): - # Skip if backend is PyTorch as it does not support topk logprobs when k > 1 - if backend == "pytorch": - pytest.skip("Topk logprobs is not supported") - + model_name: str): messages = [{ "role": "system", "content": "You are a helpful assistant." @@ -94,42 +81,3 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI, assert logprob_content.bytes is not None assert logprob_content.top_logprobs is not None assert len(logprob_content.top_logprobs) == 5 - - -@pytest.mark.asyncio(loop_scope="module") -async def test_chat_completion_top1_logprobs(async_client: openai.AsyncOpenAI, - model_name: str, backend: str): - # Skip if backend is TRT because it is tested in test_chat_completion_top5_logprobs - if backend == "trt": - pytest.skip( - "TRT top logprobs is already tested in test_chat_completion_top5_logprobs" - ) - - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is the capital of France?" 
- }] - # Test top_logprobs=1 - chat_completion = await async_client.chat.completions.create( - model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0, - logprobs=True, - top_logprobs=1, - extra_body={ - "ignore_eos": True, - }) - logprobs = chat_completion.choices[0].logprobs - assert logprobs is not None and logprobs.content is not None - assert len(logprobs.content) == 10 - for logprob_content in logprobs.content: - assert logprob_content.token is not None - assert logprob_content.logprob is not None - assert logprob_content.bytes is not None - assert logprob_content.top_logprobs is not None - # Check that the top_logprobs contains only one entry - assert len(logprob_content.top_logprobs) == 1 From ba14a9308e2a38dd41523fe0e2c16c74c0cbe678 Mon Sep 17 00:00:00 2001 From: Emma Qiao Date: Tue, 23 Dec 2025 00:05:45 +0800 Subject: [PATCH 5/7] [None][infra] Waive failed cases on 12/22 (#10200) Signed-off-by: qqiao Signed-off-by: Yanchao Lu Co-authored-by: Yanchao Lu --- tests/integration/test_lists/waives.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 46137d5d18..7b89d12901 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -480,3 +480,8 @@ triton_server/test_triton.py::test_gpt_disaggregated_serving_bls[gpt-disaggregat cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mooncake_kvcache-90] SKIP (https://nvbugs/5760737) unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py::test_allreduce_pg_op[seqlen:16-hidden:1024] SKIP (https://nvbugs/5760740) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler] SKIP (https://nvbugs/5760747) +unittest/_torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-strategy:8-dtype:bfloat16-hidden:8192-seqlen:[15]] SKIP (https://nvbugs/5761364) +triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854) +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822) +accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] SKIP (https://nvbugs/5762852) +accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] SKIP (https://nvbugs/5762852) From aaa87abf417bf4429c9c1c1abb37b80a41772659 Mon Sep 17 00:00:00 2001 From: JunyiXu-nv <219237550+JunyiXu-nv@users.noreply.github.com> Date: Tue, 23 Dec 2025 00:33:34 +0800 Subject: [PATCH 6/7] [TRTLLM-7906][feat] Support multiple post process for Responses API (#9908) Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com> --- tensorrt_llm/serve/openai_server.py | 112 ++++-- tensorrt_llm/serve/postprocess_handlers.py | 47 ++- tensorrt_llm/serve/responses_utils.py | 341 ++++++++++++------ .../llmapi/apps/_test_openai_responses.py | 11 +- 4 files changed, 367 insertions(+), 144 deletions(-) diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index c9699bb91f..644ae8e418 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -51,20 +51,22 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, MemoryUpdateRequest, ModelCard, ModelList, PromptTokensDetails, ResponsesRequest, + ResponsesResponse, UpdateWeightsRequest, UsageInfo, to_llm_disaggregated_params) from 
tensorrt_llm.serve.postprocess_handlers import ( ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs, - chat_harmony_post_processor, chat_harmony_streaming_post_processor, - chat_response_post_processor, chat_stream_post_processor, - completion_response_post_processor, completion_stream_post_processor) + ResponsesAPIPostprocArgs, chat_harmony_post_processor, + chat_harmony_streaming_post_processor, chat_response_post_processor, + chat_stream_post_processor, completion_response_post_processor, + completion_stream_post_processor, responses_api_post_processor, + responses_api_streaming_post_processor) from tensorrt_llm.serve.responses_utils import (ConversationHistoryStore, + ResponsesStreamingProcessor, ServerArrivalTimeMiddleware) from tensorrt_llm.serve.responses_utils import \ create_response as responses_api_create_response from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds -from tensorrt_llm.serve.responses_utils import \ - process_streaming_events as responses_api_process_streaming_events from tensorrt_llm.serve.responses_utils import \ request_preprocess as responses_api_request_preprocess from tensorrt_llm.version import __version__ as VERSION @@ -119,9 +121,8 @@ class OpenAIServer: self.model_config = None # Enable response storage for Responses API - self.enable_store = True - if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0: - self.enable_store = False + self.enable_store = (len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) < 1) and not self.postproc_worker_enabled + self.conversation_store = ConversationHistoryStore() model_dir = Path(model) @@ -942,19 +943,39 @@ class OpenAIServer: return self.create_error_response(message=str(e), err_type="internal_error") async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response: - async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]: - async for event_data in responses_api_process_streaming_events( - request=request, - sampling_params=sampling_params, - generator=generator, - model_name=self.model, - conversation_store=self.conversation_store, - use_harmony=self.use_harmony, - reasoning_parser=self.llm.args.reasoning_parser, - tool_parser=self.tool_parser, - enable_store=self.enable_store - ): - yield event_data + async def create_response( + promise: RequestOutput, postproc_params: PostprocParams) -> ResponsesResponse: + await promise.aresult() + if self.postproc_worker_enabled: + response = promise.outputs[0]._postprocess_result + else: + args = postproc_params.postproc_args + response = await responses_api_create_response( + generator=promise, + request=request, + sampling_params=args.sampling_params, + model_name=self.model, + conversation_store=self.conversation_store, + generation_result=None, + enable_store=self.enable_store and request.store, + use_harmony=self.use_harmony, + reasoning_parser=args.reasoning_parser, + tool_parser=args.tool_parser, + ) + + return response + + async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams): + post_processor, args = postproc_params.post_processor, postproc_params.postproc_args + streaming_processor = args.streaming_processor + initial_responses = streaming_processor.get_initial_responses() + for initial_response in initial_responses: + yield initial_response + + async for res in promise: + pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, 
args) + for pp_res in pp_results: + yield pp_res try: if request.background: @@ -977,38 +998,61 @@ class OpenAIServer: request=request, prev_response=prev_response, conversation_store=self.conversation_store, - enable_store=self.enable_store, + enable_store=self.enable_store and request.store, use_harmony=self.use_harmony, tokenizer=self.tokenizer if not self.use_harmony else None, model_config=self.model_config if not self.use_harmony else None, processor=self.processor if not self.use_harmony else None, ) + streaming_processor = None + if request.stream: + # Per-request streaming processor + streaming_processor = ResponsesStreamingProcessor( + request=request, + sampling_params=sampling_params, + model_name=self.model, + conversation_store=self.conversation_store, + enable_store=self.enable_store and request.store, + use_harmony=self.use_harmony, + reasoning_parser=self.llm.args.reasoning_parser, + tool_parser=self.tool_parser, + ) + + postproc_args = ResponsesAPIPostprocArgs( + model=self.model, + request=request, + sampling_params=sampling_params, + use_harmony=self.use_harmony, + reasoning_parser=self.llm.args.reasoning_parser, + tool_parser=self.tool_parser, + streaming_processor=streaming_processor, + ) + postproc_params = PostprocParams( + post_processor=responses_api_streaming_post_processor + if request.stream else responses_api_post_processor, + postproc_args=postproc_args, + ) promise = self.llm.generate_async( inputs=input_tokens, sampling_params=sampling_params, streaming=request.stream, + _postproc_params=postproc_params if self.postproc_worker_enabled else None, ) + if self.postproc_worker_enabled and request.store: + logger.warning("Postproc workers are enabled, request will not be stored!") + asyncio.create_task(self.await_disconnected(raw_request, promise)) if request.stream: return StreamingResponse( - create_stream_response(promise, request, sampling_params), + content=create_streaming_generator(promise, postproc_params), media_type="text/event-stream" ) else: - return await responses_api_create_response( - generator=promise, - request=request, - sampling_params=sampling_params, - model_name=self.model, - conversation_store=self.conversation_store, - generation_result=None, - enable_store=self.enable_store, - use_harmony=self.use_harmony, - reasoning_parser=self.llm.args.reasoning_parser, - tool_parser=self.tool_parser) + response = await create_response(promise, postproc_params) + return JSONResponse(content=response.model_dump()) except CppExecutorError: logger.error(traceback.format_exc()) # If internal executor error is raised, shutdown the server diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index aa56cc6e5b..01ffb648e2 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -1,11 +1,16 @@ from dataclasses import dataclass, field from typing import Any, List, Literal, Optional, Tuple, Union +from tensorrt_llm.serve.responses_utils import ResponsesStreamingProcessor +from tensorrt_llm.serve.responses_utils import \ + create_response_non_store as responses_api_create_response_non_store + from .._utils import nvtx_range_debug from ..executor import (DetokenizedGenerationResultBase, GenerationResult, GenerationResultBase) from ..executor.postproc_worker import PostprocArgs from ..executor.result import Logprob, TokenLogprobs +from ..llmapi import SamplingParams from ..llmapi.reasoning_parser import (BaseReasoningParser, ReasoningParserFactory) from 
..llmapi.tokenizer import TransformersTokenizer @@ -26,7 +31,8 @@ from .openai_protocol import (ChatCompletionLogProbs, CompletionResponseStreamChoice, CompletionStreamResponse, DeltaFunctionCall, DeltaMessage, DeltaToolCall, FunctionCall, - PromptTokensDetails, StreamOptions, ToolCall, + PromptTokensDetails, ResponsesRequest, + ResponsesResponse, StreamOptions, ToolCall, UsageInfo, to_disaggregated_params) from .tool_parser.base_tool_parser import BaseToolParser from .tool_parser.core_types import ToolCallItem @@ -543,3 +549,42 @@ def chat_harmony_streaming_post_processor( num_prompt_tokens=args.num_prompt_tokens, ) return response + + +@dataclass(kw_only=True) +class ResponsesAPIPostprocArgs(PostprocArgs): + model: str + request: ResponsesRequest + sampling_params: SamplingParams + use_harmony: bool + reasoning_parser: Optional[str] = None + tool_parser: Optional[str] = None + streaming_processor: Optional[ResponsesStreamingProcessor] = None + + +@nvtx_range_debug("responses_api_post_processor") +def responses_api_post_processor( + rsp: GenerationResult, + args: ResponsesAPIPostprocArgs) -> ResponsesResponse: + return responses_api_create_response_non_store( + generation_result=rsp, + request=args.request, + sampling_params=args.sampling_params, + model_name=args.model, + use_harmony=args.use_harmony, + reasoning_parser=args.reasoning_parser, + tool_parser=args.tool_parser, + ) + + +@nvtx_range_debug("responses_api_streaming_post_processor") +def responses_api_streaming_post_processor( + rsp: GenerationResult, args: ResponsesAPIPostprocArgs) -> List[str]: + if args.streaming_processor is None: + raise ValueError( + "streaming_processor is required for streaming post-processing") + outputs = args.streaming_processor.process_single_output(rsp) + if rsp._done: + outputs.append( + args.streaming_processor.get_final_response_non_store(rsp)) + return outputs diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py index 4f0e4e55a6..9297422c6a 100644 --- a/tensorrt_llm/serve/responses_utils.py +++ b/tensorrt_llm/serve/responses_utils.py @@ -10,7 +10,7 @@ import uuid from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from copy import copy -from typing import Any, Literal, Optional, OrderedDict, Tuple, Union +from typing import Any, List, Literal, Optional, OrderedDict, Tuple, Union from openai.types.responses import (ResponseCompletedEvent, ResponseContentPartAddedEvent, @@ -41,6 +41,7 @@ from openai_harmony import (Author, Conversation, DeveloperContent, from transformers import AutoProcessor, PretrainedConfig from tensorrt_llm.bindings import steady_clock_now +from tensorrt_llm.executor import GenerationResult from tensorrt_llm.inputs.utils import apply_chat_template from tensorrt_llm.llmapi import SamplingParams from tensorrt_llm.llmapi.llm import RequestOutput @@ -962,7 +963,7 @@ def _apply_tool_parser( return normal_text, calls -async def _create_output_content( +def _create_output_content( final_res: RequestOutput, reasoning_parser: Optional[str] = None, tool_parser: Optional[str] = None, @@ -1040,7 +1041,7 @@ async def _create_output_content( return output_items, output_messages -async def _create_output_content_harmony( +def _create_output_content_harmony( final_res: RequestOutput ) -> Tuple[list[ResponseOutputItem], list[Message]]: output_messages = _parse_output_tokens(final_res.outputs[0].token_ids) @@ -1057,12 +1058,53 @@ async def _create_output_content_harmony( return output_content, output_messages +def 
_create_response( + final_res: GenerationResult, + use_harmony: bool, + request: ResponsesRequest, + model_name: str, + response_creation_time: int, + sampling_params: SamplingParams, + reasoning_parser: Optional[str] = None, + tool_parser: Optional[str] = None, +) -> tuple[ResponsesResponse, list[Message | ChatCompletionMessageParam]]: + _responses_debug_log("================================================") + _responses_debug_log("RAW MODEL OUTPUT:") + _responses_debug_log(final_res.outputs) + _responses_debug_log("================================================") + + # prepare responses output + output_content = [] + if use_harmony: + output_content, output_messages = _create_output_content_harmony( + final_res) + else: + output_content, output_messages = _create_output_content( + final_res, reasoning_parser, tool_parser, request.tools) + + response = ResponsesResponse.from_request( + request=request, + sampling_params=sampling_params, + model_name=model_name, + created_time=response_creation_time, + output=output_content, + status=finish_reason_mapping(final_res.outputs[0].finish_reason), + ) + + _responses_debug_log("========== Response ===========") + _responses_debug_log(response) + _responses_debug_log("===============================") + + # return output_messages for store_response + return response, output_messages + + async def create_response( - generator, request: ResponsesRequest, sampling_params: SamplingParams, model_name: str, conversation_store: ConversationHistoryStore, + generator: Optional[AsyncGenerator[RequestOutput, None]] = None, generation_result: Optional[RequestOutput] = None, enable_store: bool = False, use_harmony: bool = True, @@ -1078,33 +1120,22 @@ async def create_response( if generation_result is not None: final_res = generation_result - else: + elif generator is not None: final_res = await generator if final_res is None: raise RuntimeError("No output generated or provided") - _responses_debug_log("================================================") - _responses_debug_log("RAW MODEL OUTPUT:") - _responses_debug_log(final_res.outputs) - _responses_debug_log("================================================") - # prepare responses output - output_content = [] - if use_harmony: - output_content, output_messages = await _create_output_content_harmony( - final_res) - else: - output_content, output_messages = await _create_output_content( - final_res, reasoning_parser, tool_parser, request.tools) - - response = ResponsesResponse.from_request( + response, output_messages = _create_response( + final_res=final_res, + use_harmony=use_harmony, request=request, - sampling_params=sampling_params, model_name=model_name, - created_time=response_creation_time, - output=output_content, - status=finish_reason_mapping(final_res.outputs[0].finish_reason), + response_creation_time=response_creation_time, + sampling_params=sampling_params, + reasoning_parser=reasoning_parser, + tool_parser=tool_parser, ) if enable_store and request.store: @@ -1112,9 +1143,34 @@ async def create_response( resp_msgs=output_messages, prev_resp_id=prev_response_id) - _responses_debug_log("========== Response ===========") - _responses_debug_log(response) - _responses_debug_log("===============================") + return response + + +def create_response_non_store( + generation_result: RequestOutput, + request: ResponsesRequest, + sampling_params: SamplingParams, + model_name: str, + use_harmony: bool = True, + create_time: Optional[int] = None, + reasoning_parser: Optional[str] = None, + 
tool_parser: Optional[str] = None, +) -> ResponsesResponse: + response_creation_time = create_time if create_time is not None else int( + time.time()) + + # prepare responses output + response, _ = _create_response( + final_res=generation_result, + use_harmony=use_harmony, + request=request, + model_name=model_name, + response_creation_time=response_creation_time, + sampling_params=sampling_params, + reasoning_parser=reasoning_parser, + tool_parser=tool_parser, + ) + return response @@ -1649,6 +1705,143 @@ def _generate_streaming_event_harmony( parser.last_content_delta) +class ResponsesStreamingProcessor: + + def __init__( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + model_name: str, + create_time: Optional[int] = None, + conversation_store: Optional[ConversationHistoryStore] = None, + enable_store: bool = False, + use_harmony: bool = True, + reasoning_parser: Optional[str] = None, + tool_parser: Optional[str] = None, + ): + self.model_name = model_name + self.request = request + self.sampling_params = sampling_params + self.sequence_number = 0 + self.streaming_events_helper = ResponsesStreamingEventsHelper() + self.response_creation_time = create_time if create_time is not None else int( + time.time()) + self.final_res: Optional[RequestOutput] = None + self.reasoning_parser_dict: dict[int, BaseReasoningParser] = {} + self.tool_parser_dict: dict[int, BaseToolParser] = {} + self.stream_request_id = f"responses-api-{request.request_id}" + self.conversation_store = conversation_store + self.enable_store = enable_store + self.use_harmony = use_harmony + self.reasoning_parser = reasoning_parser + self.tool_parser = tool_parser + + def _send_event(self, event: OpenAIBaseModel): + # Set sequence_number if the event has this attribute + if hasattr(event, 'sequence_number'): + event.sequence_number = self.sequence_number + self.sequence_number += 1 + # Get event type from the event's type field if it exists + event_type = getattr(event, 'type', 'unknown') + return (f"event: {event_type}\n" + f"data: {event.model_dump_json(indent=None)}\n\n") + + def get_initial_responses(self) -> List[str]: + initial_response = ResponsesResponse.from_request( + request=self.request, + sampling_params=self.sampling_params, + model_name=self.model_name, + created_time=self.response_creation_time, + output=[], + status="in_progress", + usage=None, + ).model_dump() + + resp_created = self._send_event( + self.streaming_events_helper.get_response_created_event( + initial_response)) + resp_in_progress = self._send_event( + self.streaming_events_helper.get_response_in_progress_event( + initial_response)) + return [resp_created, resp_in_progress] + + async def get_final_response( + self, + final_res: RequestOutput, + ) -> str: + final_response = await create_response( + generator=None, + request=self.request, + sampling_params=self.sampling_params, + model_name=self.model_name, + conversation_store=self.conversation_store, + generation_result=final_res, + enable_store=self.enable_store, + use_harmony=self.use_harmony, + create_time=self.response_creation_time, + reasoning_parser=self.reasoning_parser, + tool_parser=self.tool_parser, + ) + + return self._send_event( + ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response.model_dump(), + )) + + def get_final_response_non_store( + self, + final_res: RequestOutput, + ) -> str: + final_response = create_response_non_store( + generation_result=final_res, + request=self.request, + 
sampling_params=self.sampling_params, + model_name=self.model_name, + use_harmony=self.use_harmony, + create_time=self.response_creation_time, + reasoning_parser=self.reasoning_parser, + tool_parser=self.tool_parser, + ) + + return self._send_event( + ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response.model_dump(), + )) + + def process_single_output(self, res: GenerationResult) -> list[str]: + event_generator = None + output = res.outputs[0] + if self.use_harmony: + event_generator = _generate_streaming_event_harmony( + harmony_adapter=get_harmony_adapter(), + stream_request_id=self.stream_request_id, + output=output, + request=self.request, + streaming_events_helper=self.streaming_events_helper, + ) + + else: + event_generator = _generate_streaming_event( + output=output, + request=self.request, + finished_generation=res._done, + streaming_events_helper=self.streaming_events_helper, + reasoning_parser_id=self.reasoning_parser, + tool_parser_id=self.tool_parser, + reasoning_parser_dict=self.reasoning_parser_dict, + tool_parser_dict=self.tool_parser_dict, + ) + + if event_generator is None: + raise RuntimeError("Failed to generate streaming events") + + return [self._send_event(event) for event in event_generator] + + async def process_streaming_events( generator, request: ResponsesRequest, @@ -1661,97 +1854,31 @@ async def process_streaming_events( reasoning_parser: Optional[str] = None, tool_parser: Optional[str] = None, ) -> AsyncGenerator[str, None]: - sequence_number = 0 - response_creation_time = create_time if create_time is not None else int( - time.time()) - final_res: Optional[RequestOutput] = None - reasoning_parser_dict: dict[int, BaseReasoningParser] = {} - tool_parser_dict: dict[int, BaseToolParser] = {} - - def _send_event(event: OpenAIBaseModel): - nonlocal sequence_number - # Set sequence_number if the event has this attribute - if hasattr(event, 'sequence_number'): - event.sequence_number = sequence_number - sequence_number += 1 - # Get event type from the event's type field if it exists - event_type = getattr(event, 'type', 'unknown') - return (f"event: {event_type}\n" - f"data: {event.model_dump_json(indent=None)}\n\n") - - streaming_events_helper = ResponsesStreamingEventsHelper() - - initial_response = ResponsesResponse.from_request( - request, - sampling_params, - model_name=model_name, - created_time=response_creation_time, - output=[], - status="in_progress", - usage=None, - ).model_dump() - - yield _send_event( - streaming_events_helper.get_response_created_event(initial_response)) - yield _send_event( - streaming_events_helper.get_response_in_progress_event( - initial_response)) - - stream_request_id = f"responses-api-{request.request_id}" - harmony_adapter = get_harmony_adapter() - async for res in generator: - final_res = res - # TODO(JunyiXu-nv): handle multiple outputs - output = res.outputs[0] - - event_generator = None - if use_harmony: - event_generator = _generate_streaming_event_harmony( - harmony_adapter=harmony_adapter, - stream_request_id=stream_request_id, - output=output, - request=request, - streaming_events_helper=streaming_events_helper, - ) - - else: - event_generator = _generate_streaming_event( - output=output, - request=request, - finished_generation=res.finished, - streaming_events_helper=streaming_events_helper, - reasoning_parser_id=reasoning_parser, - tool_parser_id=tool_parser, - reasoning_parser_dict=reasoning_parser_dict, - tool_parser_dict=tool_parser_dict, - ) - - if event_generator is 
None: - raise RuntimeError("Failed to generate streaming events") - - for event in event_generator: - yield _send_event(event) - - final_response = await create_response( - generator=generator, + streaming_processor = ResponsesStreamingProcessor( request=request, sampling_params=sampling_params, model_name=model_name, + create_time=create_time, conversation_store=conversation_store, - generation_result=final_res, enable_store=enable_store, use_harmony=use_harmony, - create_time=response_creation_time, reasoning_parser=reasoning_parser, tool_parser=tool_parser, ) - yield _send_event( - ResponseCompletedEvent( - type="response.completed", - sequence_number=-1, - response=final_response.model_dump(), - )) + initial_responses = streaming_processor.get_initial_responses() + for initial_response in initial_responses: + yield initial_response + + async for res in generator: + final_res = res + events = streaming_processor.process_single_output(res) + for event in events: + yield event + + final_response = await streaming_processor.get_final_response(final_res) + + yield final_response class ServerArrivalTimeMiddleware: diff --git a/tests/unittest/llmapi/apps/_test_openai_responses.py b/tests/unittest/llmapi/apps/_test_openai_responses.py index 18271f6b76..e6902127cb 100644 --- a/tests/unittest/llmapi/apps/_test_openai_responses.py +++ b/tests/unittest/llmapi/apps/_test_openai_responses.py @@ -21,11 +21,18 @@ def model(request): return request.param +@pytest.fixture(scope="module", + params=[0, 2], + ids=["disable_processpool", "enable_processpool"]) +def num_postprocess_workers(request): + return request.param + + @pytest.fixture(scope="module") -def server(model: str): +def server(model: str, num_postprocess_workers: int): model_path = get_model_path(model) - args = [] + args = ["--num_postprocess_workers", f"{num_postprocess_workers}"] if model.startswith("Qwen3"): args.extend(["--reasoning_parser", "qwen3"]) elif model.startswith("DeepSeek-R1"): From 12e1cb8d7e15d8ed11fe2284787fd3a774491e3f Mon Sep 17 00:00:00 2001 From: tcherckez-nvidia <127761168+tcherckez-nvidia@users.noreply.github.com> Date: Mon, 22 Dec 2025 22:14:56 +0200 Subject: [PATCH 7/7] [#9717][chore] Refactor MoE code to use enums (#9910) Signed-off-by: Tal Cherckez <127761168+tcherckez-nvidia@users.noreply.github.com> --- .../custom_ops/fused_moe/torch_moe.py | 187 +++++++++--------- .../custom_ops/fused_moe/triton_moe.py | 24 +-- .../custom_ops/fused_moe/trtllm_moe.py | 114 +++++------ .../models/custom/modeling_nemotron_flash.py | 5 +- .../models/custom/modeling_nemotron_h.py | 5 +- .../transform/library/fused_moe.py | 33 ++-- .../transform/library/multi_stream_moe.py | 22 ++- .../transform/library/quantize_moe.py | 16 +- ...test_trtllm_flashinfer_symbol_collision.py | 5 +- .../library/test_ep_sharding.py | 2 +- .../singlegpu/custom_ops/test_ad_moe_op.py | 5 +- .../singlegpu/custom_ops/test_trtllm_moe.py | 72 +++---- .../triton_kernels/test_triton_moe.py | 14 +- 13 files changed, 252 insertions(+), 252 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py index 12b065c5e7..09e2f7ef7d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py @@ -1,29 +1,29 @@ -from typing import Callable, List, Optional +from typing import Callable, List import torch import torch.nn.functional as F +from tensorrt_llm._torch.utils import 
ActivationType -def _resolve_activation(name: Optional[str]) -> Callable[[torch.Tensor], torch.Tensor]: - """ - Returns an elementwise activation callable matching the given name. - Supported: "silu", "relu2". - Defaults to SiLU when name is None or empty. - """ - if not name: - name = "silu" - key = name.lower() - if key == "silu": - return F.silu - elif key == "relu2": +def _resolve_torch_fn(act_fn: ActivationType) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Returns an elementwise activation callable matching the given activation function. + Supported: ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2 + """ + assert act_fn in [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2], ( + f"Unsupported activation '{ActivationType(act_fn).name}'. Use 'silu', 'swiglu' or 'relu2'." + ) + torch_fn = None + if act_fn == ActivationType.Silu or act_fn == ActivationType.Swiglu: + torch_fn = F.silu + elif act_fn == ActivationType.Relu2: def relu2(x: torch.Tensor) -> torch.Tensor: return torch.square(F.relu(x)) - return relu2 - else: - raise ValueError(f"Unsupported activation '{name}'. Use one of: silu, relu2.") + torch_fn = relu2 + return torch_fn def _template_moe( @@ -94,8 +94,8 @@ def torch_moe( w1_weight: List[torch.Tensor], w2_weight: List[torch.Tensor], w3_weight: List[torch.Tensor], - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), apply_routing_on_input: bool = False, ) -> torch.Tensor: """ @@ -117,8 +117,8 @@ def torch_moe( - Llama4 MoE: sigmoid activated weights w1_weight: For per-expert lists: - • mlp_style=="gated_mlp": List of W1 with shape (I, H) — "gate" projection. - • mlp_style=="mlp": List of W_up with shape (I, H) — up projection. + • is_gated_mlp==True: List of W1 with shape (I, H) — "gate" projection. + • is_gated_mlp==False: List of W_up with shape (I, H) — up projection. For stacked tensors (Llama4): • Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format w2_weight: @@ -129,17 +129,17 @@ def torch_moe( w3_weight: For per-expert lists with gated_mlp: • List of W3 with shape (I, H) — "up" (second) projection in gated MLP. - For mlp style or stacked tensors: + For is_gated_mlp==False or stacked tensors: • pass an empty list []; ignored. - mlp_style: + is_gated_mlp: Selects the per-expert MLP computation: - • "gated_mlp" (default, Mixtral/DeepSeek/Llama4-style): + • is_gated_mlp==True (default, Mixtral/DeepSeek/Llama4-style): y = W2( act(W1 x) * (W3 x) ) - • "mlp" (NemotronH-style 2-layer MLP): + • is_gated_mlp==False (NemotronH-style 2-layer MLP): y = W_down( act(W_up x) ) act_fn: Elementwise activation applied inside the expert MLP. - Supported: "silu" (default), "relu2" (ReLU then square). + Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square). apply_routing_on_input: If True (Llama4 pattern): multiply routing weights with INPUT before MLP Result: act(input * routing_weight) - routing affects activation @@ -148,55 +148,63 @@ def torch_moe( Returns: torch.Tensor: Output tensor with the same shape as the input x. """ - act_fn = _resolve_activation(act_fn) - style = mlp_style.lower() + torch_act_fn = _resolve_torch_fn(act_fn) # Detect if using stacked tensor format (Llama4) vs per-expert lists (standard) is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3 + # Todo: either change torch_moe to use a single condition, or refactor this code. + # it should be : + # is_gated_mlp: + # stacked: + # ... 
+ # not stacked: + # . + # else: + # assert (not stacked) + # ... + # . if is_stacked: # Llama4 stacked tensor format - only supports gated_mlp - if style != "gated_mlp": - raise ValueError("Stacked tensor format only supports 'gated_mlp' style") + if not is_gated_mlp: + raise ValueError("Stacked tensor format only supports gated MLP style") w3_w1_stacked = w1_weight[0] # (E, 2*I, H) + intermediate_size = w3_w1_stacked.shape[1] // 2 w2_stacked = w2_weight[0] # (E, H, I) - def make_mlp(i: int): - gate_up = w3_w1_stacked[i] # (2*I, H) - intermediate_size = gate_up.shape[0] // 2 + def make_mlp(idx: int): + gate_up = w3_w1_stacked[idx] # (2*I, H) W3 = gate_up[:intermediate_size, :] # (I, H) W1 = gate_up[intermediate_size:, :] # (I, H) - W2 = w2_stacked[i] # (H, I) + W2 = w2_stacked[idx] # (H, I) weight_dtype = W1.dtype return lambda inp: F.linear( - act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3), + torch_act_fn(F.linear(inp.to(weight_dtype), W1)) + * F.linear(inp.to(weight_dtype), W3), W2, ) - mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])] + mlps = [make_mlp(idx) for idx in range(w3_w1_stacked.shape[0])] - elif style == "gated_mlp": + elif is_gated_mlp: # Standard per-expert list format with gated MLP def make_mlp(i: int): W1 = w1_weight[i] # (I, H) W2 = w2_weight[i] # (H, I) W3 = w3_weight[i] # (I, H) - return lambda inp: F.linear(act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2) - - mlps = [make_mlp(i) for i in range(len(w1_weight))] - - elif style == "mlp": - # Standard per-expert list format with simple MLP - def make_mlp(i: int): - W_up = w1_weight[i] # (I, H) - W_down = w2_weight[i] # (H, I) - return lambda inp: F.linear(act_fn(F.linear(inp, W_up)), W_down) + return lambda inp: F.linear(torch_act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2) mlps = [make_mlp(i) for i in range(len(w1_weight))] else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + # Standard per-expert list format with simple MLP + def make_mlp(i: int): + W_up = w1_weight[i] # (I, H) + W_down = w2_weight[i] # (H, I) + return lambda inp: F.linear(torch_act_fn(F.linear(inp, W_up)), W_down) + + mlps = [make_mlp(i) for i in range(len(w1_weight))] return _template_moe(x, selected_experts, routing_weights, mlps, apply_routing_on_input) @@ -209,8 +217,8 @@ def torch_moe_fake( w1_weight: List[torch.Tensor], w2_weight: List[torch.Tensor], w3_weight: List[torch.Tensor], - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), apply_routing_on_input: bool = False, ) -> torch.Tensor: return torch.empty_like(x) @@ -296,23 +304,20 @@ def torch_quant_fp8_moe( w1_weight_scale: List[torch.Tensor], w2_weight_scale: List[torch.Tensor], w3_weight_scale: List[torch.Tensor], - mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp" - act_fn: str = "silu", # silu or relu2 + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """ - FP8 MoE op using quantized linear operations. - - Computes a Mixture-of-Experts layer similar to the reference auto_deploy::torch_moe op, but uses the - quantized FP8 linear op for expert computations. + FP8 MoE op using quantized linear operations. Computes a Mixture-of-Experts layer similar to the reference + auto_deploy::torch_moe op, but uses the quantized FP8 linear op for expert computations. Args: x: Input tensor of shape (B, H) or (B, S, H). 
- selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices. - routing_weights: Tensor of normalized routing weights. - w1_weight: - List of per-expert weight tensors: - • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection. - • mlp_style=="mlp": W_up with shape (I, H) — up projection. + selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) + containing expert indices.routing_weights: Tensor of normalized routing weights. + w1_weight: List of per-expert weight tensors: + • is_gated_mlp==True: W1 with shape (I, H) — "gate" projection. + • is_gated_mlp==False: W_up with shape (I, H) — up projection. w2_weight: List of per-expert weight tensors: • gated_mlp: W2 with shape (H, I) — down projection. @@ -323,21 +328,20 @@ def torch_quant_fp8_moe( • mlp: pass an empty list []; ignored. w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops. w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops. - mlp_style: + is_gated_mlp: Selects the per-expert MLP computation: - • "gated_mlp" (default, Mixtral/DeepSeek-style): + • is_gated_mlp==True (default, Mixtral/DeepSeek-style): y = W2( act(W1 x) * (W3 x) ) - • "mlp" (NemotronH-style 2-layer MLP): + • is_gated_mlp==False (NemotronH-style 2-layer MLP): y = W_down( act(W_up x) ) act_fn: Elementwise activation applied inside the expert MLP. - Supported: "silu" (default), "relu2" (ReLU then square). + Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square). """ - act_fn = _resolve_activation(act_fn) - style = mlp_style.lower() + torch_act_fn = _resolve_torch_fn(act_fn) - if style == "gated_mlp": + if is_gated_mlp: def make_fp8_mlp(i): def mlp(inp): @@ -355,7 +359,7 @@ def torch_quant_fp8_moe( input_scale=w3_input_scale[i], weight_scale=w3_weight_scale[i], ) - prod = act_fn(gate_out) * up_out + prod = torch_act_fn(gate_out) * up_out return torch.ops.auto_deploy.torch_quant_fp8_linear( prod, w2_weight[i], @@ -368,7 +372,7 @@ def torch_quant_fp8_moe( mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))] - elif style == "mlp": + else: def make_fp8_mlp(i): def mlp(inp): @@ -380,7 +384,7 @@ def torch_quant_fp8_moe( weight_scale=w1_weight_scale[i], ) return torch.ops.auto_deploy.torch_quant_fp8_linear( - act_fn(up_out), + torch_act_fn(up_out), w2_weight[i], bias=None, input_scale=w2_input_scale[i], @@ -391,9 +395,6 @@ def torch_quant_fp8_moe( mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))] - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") - return _template_moe(x, selected_experts, routing_weights, mlps) @@ -411,8 +412,8 @@ def torch_quant_fp8_moe_fake( w1_weight_scale: List[torch.Tensor], w2_weight_scale: List[torch.Tensor], w3_weight_scale: List[torch.Tensor], - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) @@ -434,8 +435,8 @@ def torch_quant_nvfp4_moe( w1_alpha: List[torch.Tensor], w2_alpha: List[torch.Tensor], w3_alpha: List[torch.Tensor], - mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp" - act_fn: str = "silu", # silu or relu2 + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """ FP4 MoE op using quantized linear operations. @@ -449,8 +450,8 @@ def torch_quant_nvfp4_moe( routing_weights: Tensor of normalized routing weights. 
w1_weight: List of per-expert weight tensors: - • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection. - • mlp_style=="mlp": W_up with shape (I, H) — up projection. + • is_gated_mlp==True: W1 with shape (I, H) — "gate" projection. + • is_gated_mlp==False: W_up with shape (I, H) — up projection. w2_weight: List of per-expert weight tensors: • gated_mlp: W2 with shape (H, I) — down projection. @@ -462,21 +463,20 @@ def torch_quant_nvfp4_moe( w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors. w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors. w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization. - mlp_style: + is_gated_mlp: Selects the per-expert MLP computation: - • "gated_mlp" (default, Mixtral/DeepSeek-style): + • is_gated_mlp==True (default, Mixtral/DeepSeek-style): y = W2( act(W1 x) * (W3 x) ) - • "mlp" (NemotronH-style 2-layer MLP): + • is_gated_mlp==False (NemotronH-style 2-layer MLP): y = W_down( act(W_up x) ) act_fn: Elementwise activation applied inside the expert MLP. - Supported: "silu" (default), "relu2" (ReLU then square). + Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square). """ - act_fn = _resolve_activation(act_fn) - style = mlp_style.lower() + torch_act_fn = _resolve_torch_fn(act_fn) - if style == "gated_mlp": + if is_gated_mlp: def make_fp4_mlp(i): def mlp(inp): @@ -498,7 +498,7 @@ def torch_quant_nvfp4_moe( weight_scale=w3_weight_scale[i], alpha=w3_alpha[i], ) - prod = act_fn(gate_out) * up_out + prod = torch_act_fn(gate_out) * up_out return torch.ops.auto_deploy.torch_quant_nvfp4_linear( prod, w2_weight[i], @@ -512,7 +512,7 @@ def torch_quant_nvfp4_moe( mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))] - elif style == "mlp": + else: def make_fp4_mlp(i): def mlp(inp): @@ -527,7 +527,7 @@ def torch_quant_nvfp4_moe( alpha=w1_alpha[i], ) return torch.ops.auto_deploy.torch_quant_nvfp4_linear( - act_fn(up_out), + torch_act_fn(up_out), w2_weight[i], bias=None, input_scale=w2_input_scale[i], @@ -539,9 +539,6 @@ def torch_quant_nvfp4_moe( mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))] - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. 
Use 'gated_mlp' or 'mlp'.") - return _template_moe(x, selected_experts, routing_weights, mlps) @@ -562,8 +559,8 @@ def torch_quant_nvfp4_moe_fake( w1_alpha: List[torch.Tensor], w2_alpha: List[torch.Tensor], w3_alpha: List[torch.Tensor], - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py index 9dcf544393..d33b752532 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py @@ -14,6 +14,8 @@ import torch.nn.functional as F import triton import triton.language as tl +from tensorrt_llm._torch.utils import ActivationType # noqa: F401 + from ...utils.logger import ad_logger @@ -601,15 +603,13 @@ def triton_fused_moe( routing_weights: torch.Tensor, w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "mlp", - act_fn: str = "relu2", + is_gated_mlp: bool = False, + act_fn: int = int(ActivationType.Relu2), ) -> torch.Tensor: """Triton unquantized MoE with 2-layer MLP and ReLU^2 activation.""" - mlp_style = mlp_style.lower() - act_fn = act_fn.lower() - assert mlp_style == "mlp", "Triton backend only supports mlp style." - assert act_fn == "relu2", "Triton backend only supports relu2 activation." + assert not is_gated_mlp, "Triton backend only supports non gated MLP style." + assert act_fn == ActivationType.Relu2, "Triton backend only supports relu2 activation." x_shape = x.shape x2d = x.view(-1, x_shape[-1]) @@ -661,12 +661,12 @@ def triton_quant_fp8_moe( w1_weight_scale: torch.Tensor, # [E] stacked weight scales w2_weight_scale: torch.Tensor, # [E] stacked weight scales w3_weight_scale: torch.Tensor, # unused - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = False, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation.""" - if mlp_style != "mlp": - raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only") + if is_gated_mlp: + raise NotImplementedError("triton_quant_fp8_moe currently supports mlp only") x_shape = x.shape x2d = x.view(-1, x_shape[-1]) @@ -760,7 +760,7 @@ def triton_quant_fp8_moe( w1_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor, w3_weight_scale: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = False, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 827d47c44a..6fb5e560f3 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -26,8 +26,8 @@ def trtllm_moe_fused( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: x_shape = x.shape x = x.view(-1, x_shape[-1]) @@ -37,24 +37,24 @@ def trtllm_moe_fused( quant_scales = [] # Determine activation type - mlp_style = mlp_style.lower() - act_fn = act_fn.lower() activation_type = 
ActivationType.Swiglu - if mlp_style == "gated_mlp": + if is_gated_mlp: # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) - if act_fn == "silu": + if act_fn in [ActivationType.Silu, ActivationType.Swiglu]: activation_type = ActivationType.Swiglu else: - raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") - elif mlp_style == "mlp": + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'." + ) + else: # For non-gated MLP with ReLU^2 - if act_fn == "relu2": + if act_fn == ActivationType.Relu2: activation_type = ActivationType.Relu2 else: - raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'." + ) return torch.ops.trtllm.fused_moe( x, @@ -77,8 +77,8 @@ def trtllm_moe_fused_fake( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) @@ -93,21 +93,12 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) -def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None: - supported_combinations = { - "gated_mlp": ["silu"], - "mlp": ["relu2"], - } - supported_act_fns = [ - act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list - ] - assert mlp_style in supported_combinations.keys(), ( - f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}." - ) - assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}." - assert act_fn in supported_combinations[mlp_style], ( - f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. " - f"Supported combinations: {supported_combinations}" +def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None: + assert (is_gated_mlp and act_fn == ActivationType.Silu) or ( + not is_gated_mlp and act_fn == ActivationType.Relu2 + ), ( + f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. " + f"Supported combinations: gated mlp with silu or mlp with relu2." ) @@ -128,8 +119,8 @@ def trtllm_quant_fp8_moe_fused( gemm1_dequant: torch.Tensor, # [E] gemm2_act_quant: torch.Tensor, # [E] gemm2_dequant: torch.Tensor, # [E] - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """ TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP. 
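
For reference, a minimal sketch of the mapping that trtllm_moe_fused and _validate_mlp_style_and_act_fn now perform on the (is_gated_mlp, act_fn) pair. The ActivationType class below is a local stand-in for tensorrt_llm._torch.utils.ActivationType with illustrative member values (the real values come from the C++ header), and resolve_kernel_activation is a hypothetical helper name, not part of the patch.

from enum import IntEnum


class ActivationType(IntEnum):
    # Stand-in for tensorrt_llm._torch.utils.ActivationType; member values are illustrative only.
    Relu2 = 1
    Silu = 2
    Swiglu = 3


def resolve_kernel_activation(is_gated_mlp: bool, act_fn: int) -> ActivationType:
    """Map the op-level (is_gated_mlp, act_fn) pair to the fused-kernel activation type."""
    if is_gated_mlp:
        # Gated MLP: silu(x @ W1.T) * (x @ W3.T); the fused kernel calls this Swiglu.
        if act_fn in (ActivationType.Silu, ActivationType.Swiglu):
            return ActivationType.Swiglu
        raise ValueError(f"Unsupported activation '{ActivationType(act_fn).name}' for gated MLP.")
    # Non-gated MLP: relu(x @ W1.T) squared.
    if act_fn == ActivationType.Relu2:
        return ActivationType.Relu2
    raise ValueError(f"Unsupported activation '{ActivationType(act_fn).name}' for plain MLP.")


assert resolve_kernel_activation(True, int(ActivationType.Silu)) == ActivationType.Swiglu
assert resolve_kernel_activation(False, int(ActivationType.Relu2)) == ActivationType.Relu2
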
@@ -149,8 +140,8 @@ def trtllm_quant_fp8_moe_fused( gemm1_dequant: Precomputed gemm1 dequant scale [E] gemm2_act_quant: Precomputed gemm2 act quant scale [1] gemm2_dequant: Precomputed gemm2 dequant scale [E] - mlp_style: "gated_mlp" or "mlp" - act_fn: "silu" for gated_mlp, "relu2" for mlp + is_gated_mlp: True for gated_mlp, False for mlp + act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp Non-Gated MLP: activation_fn(expert_inputs @ w1_expert.t())@ w2_expert.t() @@ -159,7 +150,7 @@ def trtllm_quant_fp8_moe_fused( activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t() """ - _validate_mlp_style_and_act_fn(mlp_style, act_fn) + _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) # Store original shape and flatten to 2D x_shape = x.shape @@ -190,28 +181,27 @@ def trtllm_quant_fp8_moe_fused( # Todo: refactor this repeating code block # Determine activation type - mlp_style = mlp_style.lower() - act_fn = act_fn.lower() - activation_type = ActivationType.Swiglu - if mlp_style == "gated_mlp": + if is_gated_mlp: # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) # For gated MLP, concatenate w1 and w3 as [w3, w1] w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous() # [E, 2*I, H] fc1_expert_weights = w3_w1_stacked - if act_fn == "silu": + if act_fn in [ActivationType.Silu, ActivationType.Swiglu]: activation_type = ActivationType.Swiglu else: - raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") - elif mlp_style == "mlp": + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'." + ) + else: # For non-gated MLP with ReLU^2 fc1_expert_weights = w1_weight.contiguous() - if act_fn == "relu2": + if act_fn == ActivationType.Relu2: activation_type = ActivationType.Relu2 else: - raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'." + ) # Note! Outputting Float8_e4m3fn directly is not currently supported output = torch.ops.trtllm.fused_moe( @@ -248,10 +238,10 @@ def trtllm_quant_fp8_moe_fused_fake( gemm1_dequant: torch.Tensor, gemm2_act_quant: torch.Tensor, gemm2_dequant: torch.Tensor, - mlp_style: str, - act_fn: str, + is_gated_mlp: bool, + act_fn: int, ) -> torch.Tensor: - _validate_mlp_style_and_act_fn(mlp_style, act_fn) + _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) return torch.empty_like(x) @@ -268,8 +258,8 @@ def trtllm_quant_nvfp4_moe_fused( fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations fc1_alpha: torch.Tensor, # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8)) fc2_alpha: torch.Tensor, # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8)) - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP. 
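
The gated and non-gated paths also differ in how the fc1 expert weights are laid out before calling the fused op: gated MLP concatenates W3 and W1 per expert into a single [E, 2*I, H] tensor, while the plain MLP passes W1 as [E, I, H]. A shape-only sketch with random tensors and illustrative sizes:

import torch

E, H, I = 4, 32, 64  # experts, hidden size, intermediate size (illustrative)
w1_weight = torch.randn(E, I, H)  # gate / up projection per expert
w3_weight = torch.randn(E, I, H)  # second projection, used only by the gated MLP

# is_gated_mlp=True: fc1 carries both projections, concatenated as [w3, w1] along dim 1.
fc1_gated = torch.cat([w3_weight, w1_weight], dim=1).contiguous()
assert fc1_gated.shape == (E, 2 * I, H)

# is_gated_mlp=False: fc1 is just W_up.
fc1_plain = w1_weight.contiguous()
assert fc1_plain.shape == (E, I, H)
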
@@ -285,22 +275,22 @@ def trtllm_quant_nvfp4_moe_fused( """ NVFP4_BLOCK_SIZE = 16 - mlp_style = mlp_style.lower() - act_fn = act_fn.lower() activation_type = ActivationType.Swiglu - if mlp_style == "gated_mlp": - if act_fn == "silu": + if is_gated_mlp: + if act_fn in [ActivationType.Silu, ActivationType.Swiglu]: activation_type = ActivationType.Swiglu else: - raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") - elif mlp_style == "mlp": - if act_fn == "relu2": + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'." + ) + else: + if act_fn == ActivationType.Relu2: activation_type = ActivationType.Relu2 else: - raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'." + ) # quant_scales is described by this code: # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015 @@ -353,7 +343,7 @@ def trtllm_quant_nvfp4_moe_fused_fake( fc2_act_global_scale: torch.Tensor, fc1_alpha: torch.Tensor, fc2_alpha: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py index e4f73cb465..588eb82c33 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py @@ -13,6 +13,7 @@ from transformers.modeling_outputs import CausalLMOutput, MoeModelOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import BatchEncoding +from tensorrt_llm._torch.utils import ActivationType from tensorrt_llm.inputs.utils import HF_CHAT_TEMPLATE_EXCEPTIONS from ..nemotron_flash import NemotronFlashForCausalLMFactory @@ -182,6 +183,8 @@ class DeltaNet(nn.Module): self.qk_activation = qk_activation self.qk_norm = qk_norm + # can't use ActivationType enum here, + # because there is no Elu defined in cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h assert self.qk_activation in ["silu", "relu", "elu", "identity"] assert self.qk_norm in ["l2", "sum"] @@ -331,7 +334,7 @@ class NemotronFlashMamba2(nn.Module): self.num_heads = self.d_inner // self.headdim self.rmsnorm = rmsnorm self.dt_limit = dt_limit - self.activation = "silu" + self.activation = ActivationType.Silu self.chunk_size = chunk_size self.layer_idx = layer_idx diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py index 3756c054f7..15178b00f1 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py @@ -33,6 +33,7 @@ from transformers.utils import ModelOutput from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory +from tensorrt_llm._torch.utils import ActivationType class MambaRMSNormGated(torch.nn.Module): @@ -308,8 +309,8 @@ class NemotronHMOE(nn.Module): 
w1_weight=[e.up_proj.weight for e in self.experts], w2_weight=[e.down_proj.weight for e in self.experts], w3_weight=[], - act_fn="relu2", - mlp_style="mlp", + act_fn=ActivationType.Relu2, + is_gated_mlp=False, ) if has_latent_proj: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index af0865c183..754068f442 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -5,6 +5,8 @@ import torch from pydantic import Field from torch.fx import GraphModule, Node +from tensorrt_llm._torch.utils import ActivationType + from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.cuda_mem_tracker import cuda_memory_tracker @@ -70,20 +72,20 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t except (AttributeError, KeyError): pass + (is_gated_mlp, act_fn) = extract_op_args(node, "is_gated_mlp", "act_fn") + if is_stacked_moe: # Stacked MoE (Llama4 pattern): only supports gated MLP - (act_fn_val,) = extract_op_args(node, "act_fn") _process_llama4_stacked_moe_node( - gm, graph, node, replacement_op, act_fn_val, fused_key_counter + gm, graph, node, replacement_op, act_fn, fused_key_counter ) else: # Standard MoE with per-expert weight lists - (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn") - assert backend != "triton" or mlp_style_val == "mlp", ( + assert backend != "triton" or not is_gated_mlp, ( "Triton backend only supports mlp style." ) _process_regular_moe_node( - gm, graph, node, replacement_op, mlp_style_val, act_fn_val, fused_key_counter + gm, graph, node, replacement_op, is_gated_mlp, act_fn, fused_key_counter ) fused_key_counter += 1 @@ -102,8 +104,8 @@ def _process_regular_moe_node( graph: torch.fx.Graph, node: Node, replacement_op, - mlp_style_val: str, - act_fn_val: str, + is_gated_mlp: bool, + act_fn: ActivationType, fused_key_counter: int, ) -> None: """Process a single torch_moe node with per-expert weight lists. @@ -122,7 +124,7 @@ def _process_regular_moe_node( ) # Stack weights based on MLP style - if mlp_style_val == "gated_mlp": + if is_gated_mlp: # For gated MLP, concatenate w3 and w1 then stack across experts fused_w_up_experts = torch.stack( [ @@ -135,12 +137,10 @@ def _process_regular_moe_node( dim=0, ) new_key_w_up = f"fused_moe_w3_w1_stacked_{fused_key_counter}" - elif mlp_style_val == "mlp": + else: # For regular MLP, just stack w1 fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0) new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}" - else: - raise ValueError(f"Unknown mlp_style: {mlp_style_val}") # Stack w2/down weights fused_w_down_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) @@ -162,8 +162,8 @@ def _process_regular_moe_node( replacement_op, args=(hidden_states, selected_experts, routing_weights, w_up_arg, w_down_arg), kwargs={ - "mlp_style": mlp_style_val, - "act_fn": act_fn_val, + "is_gated_mlp": is_gated_mlp, + "act_fn": act_fn, }, ) @@ -176,7 +176,7 @@ def _process_llama4_stacked_moe_node( graph: torch.fx.Graph, node: Node, replacement_op, - act_fn_val: str, + act_fn: ActivationType, fused_key_counter: int, ) -> None: """Process a single Llama4 MoE node with pre-stacked weight tensors. 
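
As a reminder of what the two settings compute per expert (following the formulas in the torch_moe docstring), here is a minimal PyTorch reference. Tensor names and sizes are illustrative; this is a sketch, not the library implementation.

import torch
import torch.nn.functional as F


def relu2(x: torch.Tensor) -> torch.Tensor:
    # ActivationType.Relu2: ReLU followed by square.
    return torch.square(F.relu(x))


def expert_gated_mlp(x, w1, w2, w3, act=F.silu):
    # is_gated_mlp=True: y = W2( act(W1 x) * (W3 x) )
    return F.linear(act(F.linear(x, w1)) * F.linear(x, w3), w2)


def expert_plain_mlp(x, w_up, w_down, act=relu2):
    # is_gated_mlp=False (NemotronH-style): y = W_down( act(W_up x) )
    return F.linear(act(F.linear(x, w_up)), w_down)


H, I = 16, 32
x = torch.randn(3, H)
w1, w3, w2 = torch.randn(I, H), torch.randn(I, H), torch.randn(H, I)
assert expert_gated_mlp(x, w1, w2, w3).shape == (3, H)
assert expert_plain_mlp(x, w1, w2).shape == (3, H)
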
@@ -301,7 +301,8 @@ def _process_llama4_stacked_moe_node( replacement_op, args=(scaled_input, selected_experts, ones_node, w_up_arg, w_down_arg), kwargs={ - "act_fn": act_fn_val, + "act_fn": act_fn, + "is_gated_mlp": True, }, ) @@ -1240,7 +1241,7 @@ class MatchBmmMoePattern(BaseTransform): w3_list_node, ), kwargs={ - "mlp_style": "gated_mlp", + "is_gated_mlp": True, "apply_routing_on_input": apply_routing_on_input, }, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py index 9dab55102e..f145ac5c5e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py @@ -6,6 +6,8 @@ from typing import Any, Callable, Dict, List, Tuple import torch from torch.fx import GraphModule +from tensorrt_llm._torch.utils import ActivationType + from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.logger import ad_logger @@ -123,8 +125,8 @@ def trtllm_moe_fused_aux( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: device = torch.cuda.current_device() with torch.cuda.stream( @@ -137,7 +139,7 @@ def trtllm_moe_fused_aux( routing_weights, w3_w1_stacked_weight, w2_stacked_weight, - mlp_style, + is_gated_mlp, act_fn, ) torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME) @@ -152,8 +154,8 @@ def trtllm_moe_fused_aux_fake( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) @@ -213,8 +215,8 @@ def trtllm_quant_fp8_moe_fused_aux( gemm1_dequant: torch.Tensor, # [E] gemm2_act_quant: torch.Tensor, # [E] gemm2_dequant: torch.Tensor, # [E] - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: device = torch.cuda.current_device() with torch.cuda.stream( @@ -237,7 +239,7 @@ def trtllm_quant_fp8_moe_fused_aux( gemm1_dequant, gemm2_act_quant, gemm2_dequant, - mlp_style, + is_gated_mlp, act_fn, ) torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME) @@ -262,8 +264,8 @@ def trtllm_quant_fp8_moe_fused_aux_fake( gemm1_dequant: torch.Tensor, gemm2_act_quant: torch.Tensor, gemm2_dequant: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py index a881c72fd7..d05c12825b 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn from torch.fx import GraphModule, Node +from tensorrt_llm._torch.utils import ActivationType + from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import is_op @@ -87,15 +89,15 @@ def _quantize_moe_node( s1, s2, s3 = 
collect_scales(idx) args.extend([s1, s2, s3]) - # Extract mlp_style and act_fn from the original node + # Extract is_gated_mlp and act_fn from the original node # These can be in args[6:] or in kwargs - mlp_style = "gated_mlp" # default - act_fn = "silu" # default + is_gated_mlp = True # default + act_fn = ActivationType.Silu # default if len(node.args) > 6: - mlp_style = node.args[6] - elif "mlp_style" in node.kwargs: - mlp_style = node.kwargs["mlp_style"] + is_gated_mlp = node.args[6] + elif "is_gated_mlp" in node.kwargs: + is_gated_mlp = node.kwargs["is_gated_mlp"] if len(node.args) > 7: act_fn = node.args[7] @@ -104,7 +106,7 @@ def _quantize_moe_node( # Prepare kwargs for the quantized op kwargs = { - "mlp_style": mlp_style, + "is_gated_mlp": is_gated_mlp, "act_fn": act_fn, } diff --git a/tests/unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py b/tests/unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py index 54cf23d6cb..6e3a6415b9 100644 --- a/tests/unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py +++ b/tests/unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py @@ -6,6 +6,7 @@ import torch import tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.torch_moe # noqa: F401 import tensorrt_llm._torch.custom_ops.torch_custom_ops as trt_ops # noqa: F401 +from tensorrt_llm._torch.utils import ActivationType def test_flashinfer_fused_moe_matches_torch_moe(): @@ -75,8 +76,8 @@ def test_flashinfer_fused_moe_matches_torch_moe(): w1_weight=w1_list, # gate projection w2_weight=w2_list, # down projection w3_weight=w3_list, # up projection - mlp_style="gated_mlp", - act_fn="silu", + is_gated_mlp=True, + act_fn=int(ActivationType.Silu), ) # Compare outputs diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 2d5e0bd8a5..8c034799ad 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -186,7 +186,7 @@ def test_llama4_stacked_moe_pattern_detection(): moe_node = graph.call_function( torch.ops.auto_deploy.torch_moe, args=(x, selected_experts, routing_weights, w1_list, w2_list, w3_list), - kwargs={"mlp_style": "gated_mlp", "apply_routing_on_input": True}, + kwargs={"is_gated_mlp": True, "apply_routing_on_input": True}, ) graph.output(moe_node) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py index 99fccfab30..09b7d65eef 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py @@ -7,6 +7,7 @@ from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_availab import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale from tensorrt_llm._torch.modules.fused_moe import MoE # noqa: F401 +from tensorrt_llm._torch.utils import ActivationType def setup_moe_test(dtype, num_experts): @@ -173,8 +174,8 @@ def test_bmm_based_moe_op_run(dtype): [fused_w3_w1_stacked_weight], # Wrap in list for unified interface [fused_w2_weight], # Wrap in list for unified interface [], # Empty w3_weight list for stacked gated MLP 
- mlp_style="gated_mlp", - act_fn="silu", + is_gated_mlp=True, + act_fn=ActivationType.Silu, apply_routing_on_input=True, ) output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused( diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index c9aea8bc60..e6cf60b157 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -82,7 +82,7 @@ def compute_with_experts( alpha=None, beta=None, limit=None, - activation_func="silu", + activation_func: ActivationType = ActivationType.Silu, ): def relu2(x: torch.Tensor) -> torch.Tensor: return torch.square(F.relu(x)) @@ -110,7 +110,7 @@ def compute_with_experts( inter = x1_scaled * x2 else: - if activation_func == "swiglu" or activation_func == "silu": + if activation_func == ActivationType.Swiglu or activation_func == ActivationType.Silu: inter = F.silu(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) else: inter = relu2(expert_inputs @ w1_expert.t()) @@ -136,10 +136,6 @@ def _get_test_data( return x, router_logits, w31_weight, w2_weight, w31_empty_scales, w2_empty_scales -def _activation_type_from_str(activation_func: str) -> ActivationType: - return ActivationType.Swiglu if activation_func in ["swiglu", "silu"] else ActivationType.Relu2 - - def _print_diff_if( condition: Callable[[torch.Tensor], bool], diff: torch.Tensor, @@ -183,7 +179,7 @@ F16_TEST_DTYPES = [ @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize("itype, otype, wtype", F16_TEST_DTYPES) -@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2]) @skip_pre_hopper def test_trtllm_fused_moe( batch_size, @@ -201,13 +197,13 @@ def test_trtllm_fused_moe( pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})") torch.manual_seed(42) - if activation_func in ["swiglu", "silu"]: + if activation_func in [ActivationType.Swiglu, ActivationType.Silu]: X_GEN_SCALE = 1.0 else: X_GEN_SCALE = 0.5 W_GEN_SCALE = 0.1 - x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data( + x, router_logits, w31_weight, w2_weight, _, _ = _get_test_data( otype, wtype, batch_size, @@ -239,19 +235,17 @@ def test_trtllm_fused_moe( "F16 test only supports bfloat16 or float16" ) - activation_type = _activation_type_from_str(activation_func) - def get_fc1_expert_weights( - activation_func: str, w31_weight: torch.Tensor, w1_weight: torch.Tensor + activation_func: ActivationType, w31_weight: torch.Tensor, w1_weight: torch.Tensor ) -> torch.Tensor: - if activation_func == "relu2": + if activation_func == ActivationType.Relu2: return w1_weight.contiguous() else: return w31_weight # (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size) _, w1_weight = torch.chunk(w31_weight, 2, dim=1) - mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + is_gated_mlp = False if activation_func == ActivationType.Relu2 else True torch.cuda.synchronize() ad_test_output = torch.ops.auto_deploy.trtllm_moe_fused( @@ -260,9 +254,13 @@ def test_trtllm_fused_moe( routing_weights, w3_w1_stacked_weight=get_fc1_expert_weights(activation_func, w31_weight, w1_weight), w2_stacked_weight=w2_weight, - mlp_style=mlp_style, 
+ is_gated_mlp=is_gated_mlp, act_fn=activation_func, ) + # Convert ActivationType.Silu to ActivationType.Swiglu for C++ op compatibility + cpp_activation_type = ( + ActivationType.Swiglu if activation_func == ActivationType.Silu else activation_func + ) trtllm_test_output = torch.ops.trtllm.fused_moe( x, selected_experts.to(torch.int), @@ -273,11 +271,11 @@ def test_trtllm_fused_moe( fc2_expert_biases=None, output_dtype=otype, quant_scales=[], - activation_type=activation_type, + activation_type=cpp_activation_type, )[0].view(x.shape) torch.cuda.synchronize() - if mlp_style == "mlp": + if not is_gated_mlp: with torch.inference_mode(): output_triton_moe = torch.ops.auto_deploy.triton_moe_fused( x, @@ -285,6 +283,7 @@ def test_trtllm_fused_moe( routing_weights, w1_weight.contiguous(), w2_weight.contiguous(), + is_gated_mlp=False, )[0].view(x.shape) torch.testing.assert_close(output_triton_moe, ad_test_output, rtol=1e-2, atol=1e-2) @@ -308,7 +307,7 @@ FP8_TEST_DTYPES = [ @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize("itype, otype, wtype", FP8_TEST_DTYPES) -@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2]) @pytest.mark.skipif( not fp8_compatible() or not trtllm_ops_available(), reason="Requires fp8 and trtllm support", @@ -336,7 +335,7 @@ def test_trtllm_fused_moe_fp8( ) torch.manual_seed(42) - if activation_func in ["swiglu", "silu"]: + if activation_func in [ActivationType.Swiglu, ActivationType.Silu]: X_GEN_SCALE = 1.0 else: X_GEN_SCALE = 0.5 @@ -399,7 +398,7 @@ def test_trtllm_fused_moe_fp8( # (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size) w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1) - mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + is_gated_mlp = False if activation_func == ActivationType.Relu2 else True # compute quant_scales gemm1_dequant = (w1_scales * hidden_states_scale).contiguous().squeeze().to(torch.float32) @@ -424,13 +423,13 @@ def test_trtllm_fused_moe_fp8( gemm1_dequant=gemm1_dequant, gemm2_act_quant=gemm2_act_quant, gemm2_dequant=gemm2_dequant, - mlp_style=mlp_style, + is_gated_mlp=is_gated_mlp, act_fn=activation_func, ) torch.cuda.synchronize() - if mlp_style == "mlp": + if not is_gated_mlp: with torch.inference_mode(): output_triton_fp8_moe = torch.ops.auto_deploy.triton_quant_fp8_moe( x, @@ -445,7 +444,7 @@ def test_trtllm_fused_moe_fp8( w1_scales, w2_scales, w3_scales, - mlp_style=mlp_style, + is_gated_mlp=is_gated_mlp, act_fn=activation_func, ) torch.testing.assert_close(output_triton_fp8_moe, ref_output, rtol=1e-1, atol=1e-1) @@ -569,7 +568,7 @@ NVFP4_TEST_DTYPES = [ @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES) -@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2]) @pytest.mark.skipif( not fp4_compatible() or not trtllm_ops_available(), reason="Requires fp4 and trtllm support", @@ -693,25 +692,23 @@ def test_trtllm_fused_moe_nvfp4( fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs) fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs) - mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" - if mlp_style == "gated_mlp": + is_gated_mlp = False if activation_func == 
ActivationType.Relu2 else True + if is_gated_mlp: # For gated MLP, concatenate w1 and w3 as [w3, w1] fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous() fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1) fc1_weight_gs = torch.max(w3_gs, w1_gs) - if activation_func != "silu": + if activation_func != ActivationType.Silu: raise ValueError( f"Unsupported activation '{activation_func}' for gated_mlp. Use 'silu'." ) - elif mlp_style == "mlp": + else: # For non-gated MLP with ReLU^2 fc1_expert_weights_fp4 = w1_q_fp4 fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long) fc1_weight_gs = w1_gs - if activation_func != "relu2": + if activation_func != ActivationType.Relu2: raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.") - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") fc2_expert_weights_fp4 = w2_q_fp4.view(torch.long) fc2_weight_blockscale_fp8 = w2_blockscale.view(torch.long) @@ -729,7 +726,7 @@ def test_trtllm_fused_moe_nvfp4( fc2_activation_gs, fc1_alpha, fc2_alpha, - mlp_style=mlp_style, + is_gated_mlp=is_gated_mlp, act_fn=activation_func, ) @@ -747,8 +744,7 @@ def test_trtllm_fused_moe_nvfp4( block_size=NVFP4_BLOCK_SIZE, ) - concat_w3_w1 = mlp_style == "gated_mlp" - if concat_w3_w1: + if is_gated_mlp: w1_gs = w3_gs = torch.max(w1_gs, w3_gs) w1_dq = torch.empty(w1.shape, device="cuda", dtype=otype) @@ -782,14 +778,18 @@ def test_trtllm_fused_moe_nvfp4( block_size=NVFP4_BLOCK_SIZE, ) + # Convert ActivationType.Silu to ActivationType.Swiglu for reference op compatibility + resolved_activation_type = ( + ActivationType.Swiglu if activation_func == ActivationType.Silu else activation_func + ) ref_output = torch_moe_nvfp4( x_dq, - torch.cat([w3_dq, w1_dq], dim=1) if concat_w3_w1 else w1_dq, + torch.cat([w3_dq, w1_dq], dim=1) if is_gated_mlp else w1_dq, w2_dq, top_k, routing_weights, selected_experts, - _activation_type_from_str(activation_func), + resolved_activation_type, ) return ref_output diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py index c639c355e8..490eb1d742 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py @@ -4,6 +4,7 @@ from utils.util import skip_pre_hopper import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.load_moe_align import moe_align_block_size +from tensorrt_llm._torch.utils import ActivationType # noqa: F401 def _pack_routed_tokens_reference( @@ -131,6 +132,7 @@ def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit): routing_weights, w_up_stacked, w_down_stacked, + is_gated_mlp=False, ) # Reference Torch MoE in mlp mode with relu2 activation @@ -141,8 +143,8 @@ def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit): w1_weight=w_up_list, w2_weight=w_down_list, w3_weight=[], - mlp_style="mlp", - act_fn="relu2", + is_gated_mlp=False, + act_fn=ActivationType.Relu2, ) torch.testing.assert_close(out_triton, out_torch, rtol=5e-2, atol=5e-2) @@ -364,8 +366,8 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit): w1_weight_scale, w2_weight_scale, w3_weight_scale_tensor, - mlp_style="mlp", - act_fn="relu2", + is_gated_mlp=False, + 
act_fn=ActivationType.Relu2, ) # Reference: Torch quantized FP8 MoE (uses lists of tensors and scales) @@ -382,8 +384,8 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit): w1_weight_scale=w1_weight_scale_list, w2_weight_scale=w2_weight_scale_list, w3_weight_scale=w3_weight_scale_list, - mlp_style="mlp", - act_fn="relu2", + is_gated_mlp=False, + act_fn=ActivationType.Relu2, ) torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2)
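
For new tests, parametrizing directly over the enum keeps ids readable and removes the need for the old string-to-enum conversion helper. A minimal pytest sketch, assuming (as the act_fn=int(ActivationType.Silu) defaults in this patch suggest) that ActivationType is an integer-valued enum; the test name and ids lambda are illustrative.

import pytest

from tensorrt_llm._torch.utils import ActivationType


@pytest.mark.parametrize(
    "activation_func",
    [ActivationType.Silu, ActivationType.Relu2],
    ids=lambda a: a.name.lower(),
)
def test_activation_param_roundtrip(activation_func):
    # Round-trips through int because ActivationType is integer-valued, which is what
    # allows the ops above to accept act_fn as a plain int default.
    assert ActivationType(int(activation_func)) == activation_func
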