diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index 1ee535bdbd..f9dd3377bd 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -362,88 +362,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ int thread_idx = ThreadingPolicy::offset(); int local_token_idx = ThreadingPolicy::token_idx(); - if (local_token_idx >= local_num_tokens) + if (local_num_tokens == 0) { - return; - } - - // Prepare per-policy shared-memory tiles for this token - extern __shared__ int smem[]; - int* smem_topk_target_ranks; - int* smem_topk_send_indices; - int warps_per_block = blockDim.x / warpSize; - if constexpr (std::is_same::value) - { - int lane_id = threadIdx.x / warpSize; - smem_topk_target_ranks = smem + lane_id * TOP_K; - smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K; + // Special case: If local_num_tokens == 0, + // we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization. + // Other threads should return. + if (local_token_idx > 0) + return; } else { - smem_topk_target_ranks = smem; - smem_topk_send_indices = smem + TOP_K; - } + // Threads that do not have a token to process should return. + if (local_token_idx >= local_num_tokens) + return; - uint64_t already_copied = 0; - for (int k = 0; k < TOP_K; k++) - { - int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; - // Use contiguous partitioning to determine target rank - int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank); - - if (already_copied & (1ULL << target_rank)) + // Prepare per-policy shared-memory tiles for this token + extern __shared__ int smem[]; + int* smem_topk_target_ranks; + int* smem_topk_send_indices; + int warps_per_block = blockDim.x / warpSize; + if constexpr (std::is_same::value) { + int lane_id = threadIdx.x / warpSize; + smem_topk_target_ranks = smem + lane_id * TOP_K; + smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K; + } + else + { + smem_topk_target_ranks = smem; + smem_topk_send_indices = smem + TOP_K; + } + + uint64_t already_copied = 0; + for (int k = 0; k < TOP_K; k++) + { + int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; + // Use contiguous partitioning to determine target rank + int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank); + + if (already_copied & (1ULL << target_rank)) + { + if (thread_idx == 0) + { + ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1; + ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1; + // Mirror to shared memory immediately + smem_topk_target_ranks[k] = -1; + smem_topk_send_indices[k] = -1; + } + continue; + } + + // Only one thread per warp should increment the counter + int dst_token_idx; if (thread_idx == 0) { - ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1; - ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1; + dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1); + + ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank; + ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx; // Mirror to shared memory immediately - smem_topk_target_ranks[k] = -1; - smem_topk_send_indices[k] = -1; + smem_topk_target_ranks[k] = target_rank; + smem_topk_send_indices[k] = dst_token_idx; } - continue; + already_copied |= 1ULL 
<< target_rank; } + // Sync before dispatching data + ThreadingPolicy::sync(); - // Only one thread per warp should increment the counter - int dst_token_idx; - if (thread_idx == 0) - { - dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1); - - ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank; - ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx; - // Mirror to shared memory immediately - smem_topk_target_ranks[k] = target_rank; - smem_topk_send_indices[k] = dst_token_idx; - } - already_copied |= 1ULL << target_rank; - } - // Sync before dispatching data - ThreadingPolicy::sync(); - - // Read staged routing once into registers per thread - int topk_target_ranks[TOP_K]; - int topk_send_indices[TOP_K]; + // Read staged routing once into registers per thread + int topk_target_ranks[TOP_K]; + int topk_send_indices[TOP_K]; #pragma unroll - for (int k = 0; k < TOP_K; ++k) - { - topk_target_ranks[k] = smem_topk_target_ranks[k]; - topk_send_indices[k] = smem_topk_send_indices[k]; + for (int k = 0; k < TOP_K; ++k) + { + topk_target_ranks[k] = smem_topk_target_ranks[k]; + topk_send_indices[k] = smem_topk_send_indices[k]; + } + + // Perform a single source load and TOP_K fanout per payload + for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) + { + uint8_t const* src_data = static_cast(ptrs.src_data_ptrs[payload_idx]); + int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx]; + uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token; + + vectorized_dispatch(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, + payload_idx, ptrs, topk_target_ranks, topk_send_indices); + } + + ThreadingPolicy::sync(); } - // Perform a single source load and TOP_K fanout per payload - for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) - { - uint8_t const* src_data = static_cast(ptrs.src_data_ptrs[payload_idx]); - int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx]; - uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token; - - vectorized_dispatch(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx, - ptrs, topk_target_ranks, topk_send_indices); - } - - ThreadingPolicy::sync(); - bool is_first_warp = threadIdx.x / warpSize == 0; if (is_first_warp) { @@ -452,8 +462,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [ bool is_last_token = false; if (lane_id == 0) { - int cnt = atomicAdd(ptrs.local_token_counter, 1); - is_last_token = cnt + 1 == local_num_tokens; + if (local_num_tokens != 0) + { + int cnt = atomicAdd(ptrs.local_token_counter, 1); + is_last_token = cnt + 1 == local_num_tokens; + } + else + { + is_last_token = true; + } } is_last_token = __shfl_sync(0xffffffff, is_last_token, 0); @@ -523,7 +540,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) // Validate parameters TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); - TLLM_CHECK(params.local_num_tokens > 0); + TLLM_CHECK(params.local_num_tokens >= 0); TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads); // Prepare kernel pointers struct @@ -568,6 +585,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) if (params.one_block_per_token) { int grid_size = params.local_num_tokens; + // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization. 
+ if (grid_size == 0) + { + grid_size = 1; + } int shared_bytes = 2 * params.top_k * (int) sizeof(int); SWITCH_TOP_K(params.top_k, TOP_K, moeA2ADispatchKernel<<>>( @@ -577,6 +599,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) else { int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock); + // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization. + if (grid_size == 0) + { + grid_size = 1; + } int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int); SWITCH_TOP_K(params.top_k, TOP_K, moeA2ADispatchKernel<<>>( @@ -897,9 +924,19 @@ __global__ void moeA2ACombineKernel( int local_token_idx = ThreadingPolicy::token_idx(); int const size_per_token = elements_per_token * sizeof(T); - if (local_token_idx >= local_num_tokens) + if (local_num_tokens == 0) { - return; + // Special case: If local_num_tokens == 0, + // we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization. + // Other threads should return. + if (local_token_idx > 0) + return; + } + else + { + // Threads that do not have a token to process should return. + if (local_token_idx >= local_num_tokens) + return; } #if !DISABLE_SYNC_FOR_PROFILING @@ -951,6 +988,9 @@ __global__ void moeA2ACombineKernel( __syncthreads(); #endif + if (local_num_tokens == 0) + return; + // Get output location for this token (using src_data_ptrs[0] as output) T* token_output = static_cast(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token; @@ -1003,7 +1043,7 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params) // Validate parameters TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); - TLLM_CHECK(params.local_num_tokens > 0); + TLLM_CHECK(params.local_num_tokens >= 0); TLLM_CHECK(params.elements_per_token > 0); // Configure kernel launch @@ -1011,6 +1051,15 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params) int const kWarpsPerBlock = kBlockSize / 32; // warpSize int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock); int grid_size_block = params.local_num_tokens; + // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization. 
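Reviewer note: the grid-size clamps added in both launch paths exist so that a rank with zero local tokens still launches one block (or one warp group) and takes part in the cross-rank synchronization. A minimal host-side sketch of that sizing rule in Python (illustrative only; `dispatch_grid_size` is a hypothetical name, not part of the change):

```python
from math import ceil

def dispatch_grid_size(local_num_tokens: int, warps_per_block: int,
                       one_block_per_token: bool) -> int:
    """Mirror of the launch-side logic: with zero local tokens we still launch
    a single block/warp so this rank participates in the synchronization."""
    if one_block_per_token:
        grid = local_num_tokens
    else:
        grid = ceil(local_num_tokens / warps_per_block)
    return max(grid, 1)

assert dispatch_grid_size(0, 8, False) == 1   # empty rank still launches
assert dispatch_grid_size(17, 8, False) == 3  # ceil-div over warps per block
```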
+ if (grid_size_warp == 0) + { + grid_size_warp = 1; + } + if (grid_size_block == 0) + { + grid_size_block = 1; + } // Prepare kernel pointers struct for combine CombineKernelPointers kernel_ptrs = {}; // Zero-initialize diff --git a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp index e11135ddfb..af6d7cb37d 100644 --- a/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp +++ b/cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp @@ -186,7 +186,6 @@ std::tuple, int64_t> moeA2ADispatchOp(torch::Tensor c MoeA2ADataOffsets const& offsets = *reinterpret_cast(metainfo.data_ptr()); int64_t localNumTokens = tokenSelectedExperts.size(0); - TORCH_CHECK(localNumTokens > 0, "localNumTokens must be positive"); TORCH_CHECK(runtimeMaxTokensPerRank > 0, "runtimeMaxTokensPerRank must be positive"); TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)"); TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]"); diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py index 12b065c5e7..09e2f7ef7d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py @@ -1,29 +1,29 @@ -from typing import Callable, List, Optional +from typing import Callable, List import torch import torch.nn.functional as F +from tensorrt_llm._torch.utils import ActivationType -def _resolve_activation(name: Optional[str]) -> Callable[[torch.Tensor], torch.Tensor]: - """ - Returns an elementwise activation callable matching the given name. - Supported: "silu", "relu2". - Defaults to SiLU when name is None or empty. - """ - if not name: - name = "silu" - key = name.lower() - if key == "silu": - return F.silu - elif key == "relu2": +def _resolve_torch_fn(act_fn: ActivationType) -> Callable[[torch.Tensor], torch.Tensor]: + """ + Returns an elementwise activation callable matching the given activation function. + Supported: ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2 + """ + assert act_fn in [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2], ( + f"Unsupported activation '{ActivationType(act_fn).name}'. Use 'silu', 'swiglu' or 'relu2'." + ) + torch_fn = None + if act_fn == ActivationType.Silu or act_fn == ActivationType.Swiglu: + torch_fn = F.silu + elif act_fn == ActivationType.Relu2: def relu2(x: torch.Tensor) -> torch.Tensor: return torch.square(F.relu(x)) - return relu2 - else: - raise ValueError(f"Unsupported activation '{name}'. Use one of: silu, relu2.") + torch_fn = relu2 + return torch_fn def _template_moe( @@ -94,8 +94,8 @@ def torch_moe( w1_weight: List[torch.Tensor], w2_weight: List[torch.Tensor], w3_weight: List[torch.Tensor], - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), apply_routing_on_input: bool = False, ) -> torch.Tensor: """ @@ -117,8 +117,8 @@ def torch_moe( - Llama4 MoE: sigmoid activated weights w1_weight: For per-expert lists: - • mlp_style=="gated_mlp": List of W1 with shape (I, H) — "gate" projection. - • mlp_style=="mlp": List of W_up with shape (I, H) — up projection. + • is_gated_mlp==True: List of W1 with shape (I, H) — "gate" projection. + • is_gated_mlp==False: List of W_up with shape (I, H) — up projection. 
For stacked tensors (Llama4): • Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format w2_weight: @@ -129,17 +129,17 @@ def torch_moe( w3_weight: For per-expert lists with gated_mlp: • List of W3 with shape (I, H) — "up" (second) projection in gated MLP. - For mlp style or stacked tensors: + For is_gated_mlp==False or stacked tensors: • pass an empty list []; ignored. - mlp_style: + is_gated_mlp: Selects the per-expert MLP computation: - • "gated_mlp" (default, Mixtral/DeepSeek/Llama4-style): + • is_gated_mlp==True (default, Mixtral/DeepSeek/Llama4-style): y = W2( act(W1 x) * (W3 x) ) - • "mlp" (NemotronH-style 2-layer MLP): + • is_gated_mlp==False (NemotronH-style 2-layer MLP): y = W_down( act(W_up x) ) act_fn: Elementwise activation applied inside the expert MLP. - Supported: "silu" (default), "relu2" (ReLU then square). + Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square). apply_routing_on_input: If True (Llama4 pattern): multiply routing weights with INPUT before MLP Result: act(input * routing_weight) - routing affects activation @@ -148,55 +148,63 @@ def torch_moe( Returns: torch.Tensor: Output tensor with the same shape as the input x. """ - act_fn = _resolve_activation(act_fn) - style = mlp_style.lower() + torch_act_fn = _resolve_torch_fn(act_fn) # Detect if using stacked tensor format (Llama4) vs per-expert lists (standard) is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3 + # Todo: either change torch_moe to use a single condition, or refactor this code. + # it should be : + # is_gated_mlp: + # stacked: + # ... + # not stacked: + # . + # else: + # assert (not stacked) + # ... + # . if is_stacked: # Llama4 stacked tensor format - only supports gated_mlp - if style != "gated_mlp": - raise ValueError("Stacked tensor format only supports 'gated_mlp' style") + if not is_gated_mlp: + raise ValueError("Stacked tensor format only supports gated MLP style") w3_w1_stacked = w1_weight[0] # (E, 2*I, H) + intermediate_size = w3_w1_stacked.shape[1] // 2 w2_stacked = w2_weight[0] # (E, H, I) - def make_mlp(i: int): - gate_up = w3_w1_stacked[i] # (2*I, H) - intermediate_size = gate_up.shape[0] // 2 + def make_mlp(idx: int): + gate_up = w3_w1_stacked[idx] # (2*I, H) W3 = gate_up[:intermediate_size, :] # (I, H) W1 = gate_up[intermediate_size:, :] # (I, H) - W2 = w2_stacked[i] # (H, I) + W2 = w2_stacked[idx] # (H, I) weight_dtype = W1.dtype return lambda inp: F.linear( - act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3), + torch_act_fn(F.linear(inp.to(weight_dtype), W1)) + * F.linear(inp.to(weight_dtype), W3), W2, ) - mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])] + mlps = [make_mlp(idx) for idx in range(w3_w1_stacked.shape[0])] - elif style == "gated_mlp": + elif is_gated_mlp: # Standard per-expert list format with gated MLP def make_mlp(i: int): W1 = w1_weight[i] # (I, H) W2 = w2_weight[i] # (H, I) W3 = w3_weight[i] # (I, H) - return lambda inp: F.linear(act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2) - - mlps = [make_mlp(i) for i in range(len(w1_weight))] - - elif style == "mlp": - # Standard per-expert list format with simple MLP - def make_mlp(i: int): - W_up = w1_weight[i] # (I, H) - W_down = w2_weight[i] # (H, I) - return lambda inp: F.linear(act_fn(F.linear(inp, W_up)), W_down) + return lambda inp: F.linear(torch_act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2) mlps = [make_mlp(i) for i in range(len(w1_weight))] else: - raise ValueError(f"Unknown mlp_style 
'{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + # Standard per-expert list format with simple MLP + def make_mlp(i: int): + W_up = w1_weight[i] # (I, H) + W_down = w2_weight[i] # (H, I) + return lambda inp: F.linear(torch_act_fn(F.linear(inp, W_up)), W_down) + + mlps = [make_mlp(i) for i in range(len(w1_weight))] return _template_moe(x, selected_experts, routing_weights, mlps, apply_routing_on_input) @@ -209,8 +217,8 @@ def torch_moe_fake( w1_weight: List[torch.Tensor], w2_weight: List[torch.Tensor], w3_weight: List[torch.Tensor], - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), apply_routing_on_input: bool = False, ) -> torch.Tensor: return torch.empty_like(x) @@ -296,23 +304,20 @@ def torch_quant_fp8_moe( w1_weight_scale: List[torch.Tensor], w2_weight_scale: List[torch.Tensor], w3_weight_scale: List[torch.Tensor], - mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp" - act_fn: str = "silu", # silu or relu2 + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """ - FP8 MoE op using quantized linear operations. - - Computes a Mixture-of-Experts layer similar to the reference auto_deploy::torch_moe op, but uses the - quantized FP8 linear op for expert computations. + FP8 MoE op using quantized linear operations. Computes a Mixture-of-Experts layer similar to the reference + auto_deploy::torch_moe op, but uses the quantized FP8 linear op for expert computations. Args: x: Input tensor of shape (B, H) or (B, S, H). - selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices. - routing_weights: Tensor of normalized routing weights. - w1_weight: - List of per-expert weight tensors: - • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection. - • mlp_style=="mlp": W_up with shape (I, H) — up projection. + selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) + containing expert indices.routing_weights: Tensor of normalized routing weights. + w1_weight: List of per-expert weight tensors: + • is_gated_mlp==True: W1 with shape (I, H) — "gate" projection. + • is_gated_mlp==False: W_up with shape (I, H) — up projection. w2_weight: List of per-expert weight tensors: • gated_mlp: W2 with shape (H, I) — down projection. @@ -323,21 +328,20 @@ def torch_quant_fp8_moe( • mlp: pass an empty list []; ignored. w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops. w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops. - mlp_style: + is_gated_mlp: Selects the per-expert MLP computation: - • "gated_mlp" (default, Mixtral/DeepSeek-style): + • is_gated_mlp==True (default, Mixtral/DeepSeek-style): y = W2( act(W1 x) * (W3 x) ) - • "mlp" (NemotronH-style 2-layer MLP): + • is_gated_mlp==False (NemotronH-style 2-layer MLP): y = W_down( act(W_up x) ) act_fn: Elementwise activation applied inside the expert MLP. - Supported: "silu" (default), "relu2" (ReLU then square). + Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square). 
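Reviewer note: a usage sketch of the updated `torch_moe` signature, assuming the op stays registered as `torch.ops.auto_deploy.torch_moe` (the docstring above refers to it as `auto_deploy::torch_moe`) and that the custom-op module has been imported so the op is registered. The sizes `E, H, I, B, TOP_K` are illustrative.

```python
import torch
from tensorrt_llm._torch.utils import ActivationType

E, H, I, B, TOP_K = 4, 64, 128, 8, 2
x = torch.randn(B, H)
selected_experts = torch.randint(0, E, (B, TOP_K))
routing_weights = torch.softmax(torch.randn(B, TOP_K), dim=-1)
w_up = [torch.randn(I, H) for _ in range(E)]    # per-expert up projections (I, H)
w_down = [torch.randn(H, I) for _ in range(E)]  # per-expert down projections (H, I)

# NemotronH-style non-gated experts: y = W_down(relu(W_up x) ** 2)
out = torch.ops.auto_deploy.torch_moe(
    x, selected_experts, routing_weights,
    w_up, w_down, [],                  # w3_weight is unused for non-gated MLP
    is_gated_mlp=False,
    act_fn=int(ActivationType.Relu2),
)
```

For the gated default, pass per-expert `w1`, `w2`, `w3` lists and keep `is_gated_mlp=True` with `act_fn=int(ActivationType.Silu)`.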
""" - act_fn = _resolve_activation(act_fn) - style = mlp_style.lower() + torch_act_fn = _resolve_torch_fn(act_fn) - if style == "gated_mlp": + if is_gated_mlp: def make_fp8_mlp(i): def mlp(inp): @@ -355,7 +359,7 @@ def torch_quant_fp8_moe( input_scale=w3_input_scale[i], weight_scale=w3_weight_scale[i], ) - prod = act_fn(gate_out) * up_out + prod = torch_act_fn(gate_out) * up_out return torch.ops.auto_deploy.torch_quant_fp8_linear( prod, w2_weight[i], @@ -368,7 +372,7 @@ def torch_quant_fp8_moe( mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))] - elif style == "mlp": + else: def make_fp8_mlp(i): def mlp(inp): @@ -380,7 +384,7 @@ def torch_quant_fp8_moe( weight_scale=w1_weight_scale[i], ) return torch.ops.auto_deploy.torch_quant_fp8_linear( - act_fn(up_out), + torch_act_fn(up_out), w2_weight[i], bias=None, input_scale=w2_input_scale[i], @@ -391,9 +395,6 @@ def torch_quant_fp8_moe( mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))] - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") - return _template_moe(x, selected_experts, routing_weights, mlps) @@ -411,8 +412,8 @@ def torch_quant_fp8_moe_fake( w1_weight_scale: List[torch.Tensor], w2_weight_scale: List[torch.Tensor], w3_weight_scale: List[torch.Tensor], - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) @@ -434,8 +435,8 @@ def torch_quant_nvfp4_moe( w1_alpha: List[torch.Tensor], w2_alpha: List[torch.Tensor], w3_alpha: List[torch.Tensor], - mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp" - act_fn: str = "silu", # silu or relu2 + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """ FP4 MoE op using quantized linear operations. @@ -449,8 +450,8 @@ def torch_quant_nvfp4_moe( routing_weights: Tensor of normalized routing weights. w1_weight: List of per-expert weight tensors: - • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection. - • mlp_style=="mlp": W_up with shape (I, H) — up projection. + • is_gated_mlp==True: W1 with shape (I, H) — "gate" projection. + • is_gated_mlp==False: W_up with shape (I, H) — up projection. w2_weight: List of per-expert weight tensors: • gated_mlp: W2 with shape (H, I) — down projection. @@ -462,21 +463,20 @@ def torch_quant_nvfp4_moe( w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors. w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors. w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization. - mlp_style: + is_gated_mlp: Selects the per-expert MLP computation: - • "gated_mlp" (default, Mixtral/DeepSeek-style): + • is_gated_mlp==True (default, Mixtral/DeepSeek-style): y = W2( act(W1 x) * (W3 x) ) - • "mlp" (NemotronH-style 2-layer MLP): + • is_gated_mlp==False (NemotronH-style 2-layer MLP): y = W_down( act(W_up x) ) act_fn: Elementwise activation applied inside the expert MLP. - Supported: "silu" (default), "relu2" (ReLU then square). + Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square). 
""" - act_fn = _resolve_activation(act_fn) - style = mlp_style.lower() + torch_act_fn = _resolve_torch_fn(act_fn) - if style == "gated_mlp": + if is_gated_mlp: def make_fp4_mlp(i): def mlp(inp): @@ -498,7 +498,7 @@ def torch_quant_nvfp4_moe( weight_scale=w3_weight_scale[i], alpha=w3_alpha[i], ) - prod = act_fn(gate_out) * up_out + prod = torch_act_fn(gate_out) * up_out return torch.ops.auto_deploy.torch_quant_nvfp4_linear( prod, w2_weight[i], @@ -512,7 +512,7 @@ def torch_quant_nvfp4_moe( mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))] - elif style == "mlp": + else: def make_fp4_mlp(i): def mlp(inp): @@ -527,7 +527,7 @@ def torch_quant_nvfp4_moe( alpha=w1_alpha[i], ) return torch.ops.auto_deploy.torch_quant_nvfp4_linear( - act_fn(up_out), + torch_act_fn(up_out), w2_weight[i], bias=None, input_scale=w2_input_scale[i], @@ -539,9 +539,6 @@ def torch_quant_nvfp4_moe( mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))] - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") - return _template_moe(x, selected_experts, routing_weights, mlps) @@ -562,8 +559,8 @@ def torch_quant_nvfp4_moe_fake( w1_alpha: List[torch.Tensor], w2_alpha: List[torch.Tensor], w3_alpha: List[torch.Tensor], - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py index 9dcf544393..d33b752532 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py @@ -14,6 +14,8 @@ import torch.nn.functional as F import triton import triton.language as tl +from tensorrt_llm._torch.utils import ActivationType # noqa: F401 + from ...utils.logger import ad_logger @@ -601,15 +603,13 @@ def triton_fused_moe( routing_weights: torch.Tensor, w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "mlp", - act_fn: str = "relu2", + is_gated_mlp: bool = False, + act_fn: int = int(ActivationType.Relu2), ) -> torch.Tensor: """Triton unquantized MoE with 2-layer MLP and ReLU^2 activation.""" - mlp_style = mlp_style.lower() - act_fn = act_fn.lower() - assert mlp_style == "mlp", "Triton backend only supports mlp style." - assert act_fn == "relu2", "Triton backend only supports relu2 activation." + assert not is_gated_mlp, "Triton backend only supports non gated MLP style." + assert act_fn == ActivationType.Relu2, "Triton backend only supports relu2 activation." 
x_shape = x.shape x2d = x.view(-1, x_shape[-1]) @@ -661,12 +661,12 @@ def triton_quant_fp8_moe( w1_weight_scale: torch.Tensor, # [E] stacked weight scales w2_weight_scale: torch.Tensor, # [E] stacked weight scales w3_weight_scale: torch.Tensor, # unused - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = False, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation.""" - if mlp_style != "mlp": - raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only") + if is_gated_mlp: + raise NotImplementedError("triton_quant_fp8_moe currently supports mlp only") x_shape = x.shape x2d = x.view(-1, x_shape[-1]) @@ -760,7 +760,7 @@ def triton_quant_fp8_moe( w1_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor, w3_weight_scale: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = False, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index 827d47c44a..6fb5e560f3 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -26,8 +26,8 @@ def trtllm_moe_fused( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: x_shape = x.shape x = x.view(-1, x_shape[-1]) @@ -37,24 +37,24 @@ def trtllm_moe_fused( quant_scales = [] # Determine activation type - mlp_style = mlp_style.lower() - act_fn = act_fn.lower() activation_type = ActivationType.Swiglu - if mlp_style == "gated_mlp": + if is_gated_mlp: # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) - if act_fn == "silu": + if act_fn in [ActivationType.Silu, ActivationType.Swiglu]: activation_type = ActivationType.Swiglu else: - raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") - elif mlp_style == "mlp": + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'." + ) + else: # For non-gated MLP with ReLU^2 - if act_fn == "relu2": + if act_fn == ActivationType.Relu2: activation_type = ActivationType.Relu2 else: - raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'." 
+ ) return torch.ops.trtllm.fused_moe( x, @@ -77,8 +77,8 @@ def trtllm_moe_fused_fake( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) @@ -93,21 +93,12 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) -def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None: - supported_combinations = { - "gated_mlp": ["silu"], - "mlp": ["relu2"], - } - supported_act_fns = [ - act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list - ] - assert mlp_style in supported_combinations.keys(), ( - f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}." - ) - assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}." - assert act_fn in supported_combinations[mlp_style], ( - f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. " - f"Supported combinations: {supported_combinations}" +def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None: + assert (is_gated_mlp and act_fn == ActivationType.Silu) or ( + not is_gated_mlp and act_fn == ActivationType.Relu2 + ), ( + f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. " + f"Supported combinations: gated mlp with silu or mlp with relu2." ) @@ -128,8 +119,8 @@ def trtllm_quant_fp8_moe_fused( gemm1_dequant: torch.Tensor, # [E] gemm2_act_quant: torch.Tensor, # [E] gemm2_dequant: torch.Tensor, # [E] - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """ TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP. @@ -149,8 +140,8 @@ def trtllm_quant_fp8_moe_fused( gemm1_dequant: Precomputed gemm1 dequant scale [E] gemm2_act_quant: Precomputed gemm2 act quant scale [1] gemm2_dequant: Precomputed gemm2 dequant scale [E] - mlp_style: "gated_mlp" or "mlp" - act_fn: "silu" for gated_mlp, "relu2" for mlp + is_gated_mlp: True for gated_mlp, False for mlp + act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp Non-Gated MLP: activation_fn(expert_inputs @ w1_expert.t())@ w2_expert.t() @@ -159,7 +150,7 @@ def trtllm_quant_fp8_moe_fused( activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t() """ - _validate_mlp_style_and_act_fn(mlp_style, act_fn) + _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) # Store original shape and flatten to 2D x_shape = x.shape @@ -190,28 +181,27 @@ def trtllm_quant_fp8_moe_fused( # Todo: refactor this repeating code block # Determine activation type - mlp_style = mlp_style.lower() - act_fn = act_fn.lower() - activation_type = ActivationType.Swiglu - if mlp_style == "gated_mlp": + if is_gated_mlp: # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) # For gated MLP, concatenate w1 and w3 as [w3, w1] w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous() # [E, 2*I, H] fc1_expert_weights = w3_w1_stacked - if act_fn == "silu": + if act_fn in [ActivationType.Silu, ActivationType.Swiglu]: activation_type = ActivationType.Swiglu else: - raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. 
Use 'silu'.") - elif mlp_style == "mlp": + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'." + ) + else: # For non-gated MLP with ReLU^2 fc1_expert_weights = w1_weight.contiguous() - if act_fn == "relu2": + if act_fn == ActivationType.Relu2: activation_type = ActivationType.Relu2 else: - raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'." + ) # Note! Outputting Float8_e4m3fn directly is not currently supported output = torch.ops.trtllm.fused_moe( @@ -248,10 +238,10 @@ def trtllm_quant_fp8_moe_fused_fake( gemm1_dequant: torch.Tensor, gemm2_act_quant: torch.Tensor, gemm2_dequant: torch.Tensor, - mlp_style: str, - act_fn: str, + is_gated_mlp: bool, + act_fn: int, ) -> torch.Tensor: - _validate_mlp_style_and_act_fn(mlp_style, act_fn) + _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn) return torch.empty_like(x) @@ -268,8 +258,8 @@ def trtllm_quant_nvfp4_moe_fused( fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations fc1_alpha: torch.Tensor, # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8)) fc2_alpha: torch.Tensor, # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8)) - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: """TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP. @@ -285,22 +275,22 @@ def trtllm_quant_nvfp4_moe_fused( """ NVFP4_BLOCK_SIZE = 16 - mlp_style = mlp_style.lower() - act_fn = act_fn.lower() activation_type = ActivationType.Swiglu - if mlp_style == "gated_mlp": - if act_fn == "silu": + if is_gated_mlp: + if act_fn in [ActivationType.Silu, ActivationType.Swiglu]: activation_type = ActivationType.Swiglu else: - raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") - elif mlp_style == "mlp": - if act_fn == "relu2": + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'." + ) + else: + if act_fn == ActivationType.Relu2: activation_type = ActivationType.Relu2 else: - raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + raise ValueError( + f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'." 
+ ) # quant_scales is described by this code: # https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015 @@ -353,7 +343,7 @@ def trtllm_quant_nvfp4_moe_fused_fake( fc2_act_global_scale: torch.Tensor, fc1_alpha: torch.Tensor, fc2_alpha: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py index e4f73cb465..588eb82c33 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_flash.py @@ -13,6 +13,7 @@ from transformers.modeling_outputs import CausalLMOutput, MoeModelOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import BatchEncoding +from tensorrt_llm._torch.utils import ActivationType from tensorrt_llm.inputs.utils import HF_CHAT_TEMPLATE_EXCEPTIONS from ..nemotron_flash import NemotronFlashForCausalLMFactory @@ -182,6 +183,8 @@ class DeltaNet(nn.Module): self.qk_activation = qk_activation self.qk_norm = qk_norm + # can't use ActivationType enum here, + # because there is no Elu defined in cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h assert self.qk_activation in ["silu", "relu", "elu", "identity"] assert self.qk_norm in ["l2", "sum"] @@ -331,7 +334,7 @@ class NemotronFlashMamba2(nn.Module): self.num_heads = self.d_inner // self.headdim self.rmsnorm = rmsnorm self.dt_limit = dt_limit - self.activation = "silu" + self.activation = ActivationType.Silu self.chunk_size = chunk_size self.layer_idx = layer_idx diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py index 3756c054f7..15178b00f1 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py @@ -33,6 +33,7 @@ from transformers.utils import ModelOutput from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory +from tensorrt_llm._torch.utils import ActivationType class MambaRMSNormGated(torch.nn.Module): @@ -308,8 +309,8 @@ class NemotronHMOE(nn.Module): w1_weight=[e.up_proj.weight for e in self.experts], w2_weight=[e.down_proj.weight for e in self.experts], w3_weight=[], - act_fn="relu2", - mlp_style="mlp", + act_fn=ActivationType.Relu2, + is_gated_mlp=False, ) if has_latent_proj: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index af0865c183..754068f442 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -5,6 +5,8 @@ import torch from pydantic import Field from torch.fx import GraphModule, Node +from tensorrt_llm._torch.utils import ActivationType + from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.cuda_mem_tracker import cuda_memory_tracker @@ -70,20 +72,20 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t except (AttributeError, 
KeyError): pass + (is_gated_mlp, act_fn) = extract_op_args(node, "is_gated_mlp", "act_fn") + if is_stacked_moe: # Stacked MoE (Llama4 pattern): only supports gated MLP - (act_fn_val,) = extract_op_args(node, "act_fn") _process_llama4_stacked_moe_node( - gm, graph, node, replacement_op, act_fn_val, fused_key_counter + gm, graph, node, replacement_op, act_fn, fused_key_counter ) else: # Standard MoE with per-expert weight lists - (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn") - assert backend != "triton" or mlp_style_val == "mlp", ( + assert backend != "triton" or not is_gated_mlp, ( "Triton backend only supports mlp style." ) _process_regular_moe_node( - gm, graph, node, replacement_op, mlp_style_val, act_fn_val, fused_key_counter + gm, graph, node, replacement_op, is_gated_mlp, act_fn, fused_key_counter ) fused_key_counter += 1 @@ -102,8 +104,8 @@ def _process_regular_moe_node( graph: torch.fx.Graph, node: Node, replacement_op, - mlp_style_val: str, - act_fn_val: str, + is_gated_mlp: bool, + act_fn: ActivationType, fused_key_counter: int, ) -> None: """Process a single torch_moe node with per-expert weight lists. @@ -122,7 +124,7 @@ def _process_regular_moe_node( ) # Stack weights based on MLP style - if mlp_style_val == "gated_mlp": + if is_gated_mlp: # For gated MLP, concatenate w3 and w1 then stack across experts fused_w_up_experts = torch.stack( [ @@ -135,12 +137,10 @@ def _process_regular_moe_node( dim=0, ) new_key_w_up = f"fused_moe_w3_w1_stacked_{fused_key_counter}" - elif mlp_style_val == "mlp": + else: # For regular MLP, just stack w1 fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0) new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}" - else: - raise ValueError(f"Unknown mlp_style: {mlp_style_val}") # Stack w2/down weights fused_w_down_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0) @@ -162,8 +162,8 @@ def _process_regular_moe_node( replacement_op, args=(hidden_states, selected_experts, routing_weights, w_up_arg, w_down_arg), kwargs={ - "mlp_style": mlp_style_val, - "act_fn": act_fn_val, + "is_gated_mlp": is_gated_mlp, + "act_fn": act_fn, }, ) @@ -176,7 +176,7 @@ def _process_llama4_stacked_moe_node( graph: torch.fx.Graph, node: Node, replacement_op, - act_fn_val: str, + act_fn: ActivationType, fused_key_counter: int, ) -> None: """Process a single Llama4 MoE node with pre-stacked weight tensors. 
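Reviewer note: as a reference for the stacking branches in `_process_regular_moe_node` above — for E experts with per-expert `W1`/`W3` of shape (I, H) and `W2` of shape (H, I), the gated path concatenates `[w3, w1]` per expert and stacks to (E, 2*I, H), while the non-gated path stacks `w1` directly to (E, I, H). A minimal shape check with illustrative tensors:

```python
import torch

E, I, H = 4, 128, 64
w1 = [torch.randn(I, H) for _ in range(E)]
w2 = [torch.randn(H, I) for _ in range(E)]
w3 = [torch.randn(I, H) for _ in range(E)]

# Gated MLP: concatenate [w3, w1] per expert, then stack across experts.
w3_w1_stacked = torch.stack([torch.cat([w3[i], w1[i]], dim=-2) for i in range(E)], dim=0)
assert w3_w1_stacked.shape == (E, 2 * I, H)

# Non-gated MLP: stack the up projection directly.
w_up_stacked = torch.stack(w1, dim=0)
assert w_up_stacked.shape == (E, I, H)

# The down projection is stacked the same way in both cases.
w_down_stacked = torch.stack(w2, dim=0)
assert w_down_stacked.shape == (E, H, I)
```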
@@ -301,7 +301,8 @@ def _process_llama4_stacked_moe_node( replacement_op, args=(scaled_input, selected_experts, ones_node, w_up_arg, w_down_arg), kwargs={ - "act_fn": act_fn_val, + "act_fn": act_fn, + "is_gated_mlp": True, }, ) @@ -1240,7 +1241,7 @@ class MatchBmmMoePattern(BaseTransform): w3_list_node, ), kwargs={ - "mlp_style": "gated_mlp", + "is_gated_mlp": True, "apply_routing_on_input": apply_routing_on_input, }, ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py index 9dab55102e..f145ac5c5e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py @@ -6,6 +6,8 @@ from typing import Any, Callable, Dict, List, Tuple import torch from torch.fx import GraphModule +from tensorrt_llm._torch.utils import ActivationType + from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.logger import ad_logger @@ -123,8 +125,8 @@ def trtllm_moe_fused_aux( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: device = torch.cuda.current_device() with torch.cuda.stream( @@ -137,7 +139,7 @@ def trtllm_moe_fused_aux( routing_weights, w3_w1_stacked_weight, w2_stacked_weight, - mlp_style, + is_gated_mlp, act_fn, ) torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME) @@ -152,8 +154,8 @@ def trtllm_moe_fused_aux_fake( routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) @@ -213,8 +215,8 @@ def trtllm_quant_fp8_moe_fused_aux( gemm1_dequant: torch.Tensor, # [E] gemm2_act_quant: torch.Tensor, # [E] gemm2_dequant: torch.Tensor, # [E] - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: device = torch.cuda.current_device() with torch.cuda.stream( @@ -237,7 +239,7 @@ def trtllm_quant_fp8_moe_fused_aux( gemm1_dequant, gemm2_act_quant, gemm2_dequant, - mlp_style, + is_gated_mlp, act_fn, ) torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME) @@ -262,8 +264,8 @@ def trtllm_quant_fp8_moe_fused_aux_fake( gemm1_dequant: torch.Tensor, gemm2_act_quant: torch.Tensor, gemm2_dequant: torch.Tensor, - mlp_style: str = "gated_mlp", - act_fn: str = "silu", + is_gated_mlp: bool = True, + act_fn: int = int(ActivationType.Silu), ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py index a881c72fd7..d05c12825b 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn from torch.fx import GraphModule, Node +from tensorrt_llm._torch.utils import ActivationType + from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import is_op @@ -87,15 +89,15 @@ def _quantize_moe_node( s1, s2, s3 = 
collect_scales(idx) args.extend([s1, s2, s3]) - # Extract mlp_style and act_fn from the original node + # Extract is_gated_mlp and act_fn from the original node # These can be in args[6:] or in kwargs - mlp_style = "gated_mlp" # default - act_fn = "silu" # default + is_gated_mlp = True # default + act_fn = ActivationType.Silu # default if len(node.args) > 6: - mlp_style = node.args[6] - elif "mlp_style" in node.kwargs: - mlp_style = node.kwargs["mlp_style"] + is_gated_mlp = node.args[6] + elif "is_gated_mlp" in node.kwargs: + is_gated_mlp = node.kwargs["is_gated_mlp"] if len(node.args) > 7: act_fn = node.args[7] @@ -104,7 +106,7 @@ def _quantize_moe_node( # Prepare kwargs for the quantized op kwargs = { - "mlp_style": mlp_style, + "is_gated_mlp": is_gated_mlp, "act_fn": act_fn, } diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index 0fd3a9a510..844c0d8958 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -527,6 +527,8 @@ class LlavaNextModel(PreTrainedModel): return if not DISAGG: self.mm_encoder = LlavaNextVisionModel(model_config) + else: + self.mm_encoder = None llm_model_config = copy.deepcopy(model_config) llm_model_config.pretrained_config = model_config.pretrained_config.text_config @@ -545,7 +547,8 @@ class LlavaNextModel(PreTrainedModel): if isinstance(weight_mapper, LlavaNextHfWeightMapper): weights = weight_mapper.preprocess_weights(weights) - self.mm_encoder.load_weights(weights) + if self.mm_encoder is not None: + self.mm_encoder.load_weights(weights) def filter_weights(weights: Dict): transformed_weights = {} diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 6740188f3d..d421b31de5 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -32,7 +32,8 @@ from ...inputs import (BaseMultimodalDummyInputsBuilder, BaseMultimodalInputProcessor, ExtraProcessedInputs, MultimodalPlaceholderMetadata, MultimodalPlaceholderPlacement, TextPrompt, - register_input_processor) + register_input_processor, + support_multimodal_disaggregated) from ...logger import logger from ...sampling_params import SamplingParams from ..attention_backend import AttentionMetadata @@ -865,6 +866,8 @@ class Qwen2VLModelBase(PreTrainedModel): mm_encoder_config = copy.deepcopy(model_config) self.mm_encoder = Qwen2VisionModelBase( mm_encoder_config, kwargs.get('vision_model_class', None)) + else: + self.mm_encoder = None def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]): config = model_config.pretrained_config @@ -953,24 +956,21 @@ class Qwen2VLModelBase(PreTrainedModel): """ VLM forward logic with inflight batching support. """ - num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations + num_context_requests = attn_metadata.num_contexts multimodal_params = kwargs.get("multimodal_params", []) mm_embeds = [] mrope_config = {} - # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate the mm_multimodal_params from the text-only prompts. - mm_multimodal_params = [ - multimodal_param for multimodal_param in multimodal_params - if multimodal_param.multimodal_data.get("image", {}).get( - "pixel_values") is not None or multimodal_param.multimodal_data. 
- get("video", {}).get("pixel_values_videos") is not None - ] + # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate + # the entries that do have multimodal data from those that correspond to text-only prompts. + mm_multimodal_params = self._get_requests_with_mm_data( + multimodal_params) if len(mm_multimodal_params) > 0: if not _is_disagg(): mm_embeds = get_multimodal_embeddings( encoder_forward_fn=self.mm_encoder.forward, multimodal_params=mm_multimodal_params) - else: + elif not getattr(self, "support_mm_disagg", False): raise NotImplementedError( "Qwen2VLModel does not support disaggregated inference yet. Please unset " f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'." @@ -995,6 +995,21 @@ class Qwen2VLModelBase(PreTrainedModel): logger.debug(f'output shape: {output_prob.shape}') return output_prob + def _get_requests_with_mm_data(self, multimodal_params): + mm_multimodal_params = [] + for multimodal_param in multimodal_params: + data = multimodal_param.multimodal_data + if ( + # The first 2 conditions check whether there is input on which inference should be run. + data.get("image", {}).get("pixel_values") is not None or + data.get("video", {}).get("pixel_values_videos") is not None + # This condition corresponds to when the embeddings are already populated, as is e.g. + # the case in EPD disagg in the prefill worker. + or data.get("multimodal_embedding")): + mm_multimodal_params.append(multimodal_param) + + return mm_multimodal_params + @register_vision_encoder(Qwen2VisionModelBase, vlm_base_model=Qwen2VisionTransformerPretrainedModel) @@ -1032,11 +1047,89 @@ class Qwen2VLModel(Qwen2VLModelBase): self.llm.load_weights(weights, weight_mapper) +class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase): + + def get_prompt_token_ids( + self, inputs: TextPrompt, + mm_handles: List[Dict[str, + Any]]) -> Tuple[List[int], List[int], List[int]]: + """ + Build input token ids with multimodal placeholders expanded to the number of MM tokens. + + Args: + inputs: Text prompt input container. Must contain a non-empty prompt string. + mm_handles: List of multimodal embedding handles. Currently only a single handle is supported. 
+ + Returns: + Tuple[List[int], List[int], List[int]]: + - expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token + - mm_token_length: per-image MM token lengths + - mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids + """ + # TODO: Move this function to the base input processor class when extending for more models + text_prompt = inputs.get("prompt") + if not text_prompt: + raise ValueError("Text prompt is required but not provided") + + if not isinstance(mm_handles, list): + raise TypeError("mm_handles must be a list") + + if len(mm_handles) != 1: + # TODO: only support single multimodal item within a request for now + raise NotImplementedError( + "Only one mm_handle is supported for Qwen2.5 VL for now") + hidden_size = mm_handles[0]['tensor_size'][1] + assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size" + input_ids = self.tokenizer(text_prompt, + return_tensors="pt").input_ids[0] + + image_token_index = self.config.image_token_id + + image_mask = input_ids == image_token_index + image_positions = torch.where(image_mask)[0] + num_images = len(image_positions) + assert num_images == len( + mm_handles), "Number of images must match number of mm_handles" + total_mm_tokens = sum(mm_handle["tensor_size"][0] + for mm_handle in mm_handles) + final_length = len(input_ids) - num_images + total_mm_tokens + # Create output tensor + expanded_ids = torch.empty(final_length, dtype=input_ids.dtype) + placeholder_id = self.tllm_multimodal_token_id + + # Fill the expanded sequence + write_pos = 0 + image_cnt = 0 + mm_token_length = [] + mm_token_offsets = [] + for read_pos in range(len(input_ids)): + if input_ids[read_pos] == image_token_index: + # Replace with placeholder id + mm_token_num = mm_handles[image_cnt]["tensor_size"][0] + expanded_ids[write_pos:write_pos + mm_token_num] = \ + placeholder_id + mm_token_offsets.append(write_pos) + mm_token_length.append(mm_token_num) + write_pos += mm_token_num + image_cnt += 1 + else: + # Copy text token as-is + expanded_ids[write_pos] = input_ids[read_pos] + write_pos += 1 + + assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}" + assert mm_token_length[-1] + mm_token_offsets[ + -1] <= final_length, f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less than or equal to final_length ({final_length})" + return expanded_ids.to( + torch.int32).tolist(), mm_token_length, mm_token_offsets + + +@support_multimodal_disaggregated @register_vision_encoder(Qwen2VisionModelBase, vlm_base_model=Qwen2_5_VisionModel) @register_auto_model("Qwen2_5_VLForConditionalGeneration") @register_input_processor( - Qwen2VLInputProcessorBase, + Qwen2_5VLInputProcessorBase, model_type="qwen2_5_vl", placeholder_metadata=MultimodalPlaceholderMetadata( placeholder_map={ diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 5f3c39149c..b11cd11617 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -262,6 +262,8 @@ class PyResult: chunk_size=self._chunk_size) if return_generation_logits else None self._log_probs = LogProbStorage() if return_log_probs else None self._mm_embeddings = None + self._mrope_position_ids = None + self._mrope_position_deltas = None self._additional_context_outputs = { name: [] for name in additional_outputs @@ 
-293,6 +295,16 @@ class PyResult: self._mm_embeddings = SharedTensorContainer.from_tensor( mm_embeddings).dump_to_dict() + def set_mrope_position( + self, + mrope_position_ids: torch.Tensor, + mrope_position_deltas: torch.Tensor, + ): + self._mrope_position_ids = (SharedTensorContainer.from_tensor( + mrope_position_ids).dump_to_dict()) + self._mrope_position_deltas = (SharedTensorContainer.from_tensor( + mrope_position_deltas).dump_to_dict()) + def transfer_remaining_device_logits(self): """Finalize any remaining generation logits transfers (for chunked mode)""" if self._generation_logits: @@ -352,6 +364,18 @@ class PyResult: def mm_embedding_handle(self) -> Dict[str, Any] | None: return self._mm_embeddings + @property + def mrope_position_ids_handle(self) -> Dict[str, Any] | None: + # NOTE: when populated, the returned `dict` contains the information necessary to rebuild + # the `SharedTensorContainer` using the `from_dict` class method. + return self._mrope_position_ids + + @property + def mrope_position_deltas_handle(self) -> Dict[str, Any] | None: + # NOTE: when populated, the returned `dict` contains the information necessary to rebuild + # the `SharedTensorContainer` using the `from_dict` class method. + return self._mrope_position_deltas + @property def additional_context_outputs(self) -> Dict[str, torch.Tensor] | None: if self._additional_context_outputs is None: @@ -382,7 +406,8 @@ class LlmResult: py_result_properties = frozenset( ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs', 'mm_embedding_handle', 'additional_context_outputs', - 'additional_generation_outputs')) + 'additional_generation_outputs', 'mrope_position_ids_handle', + 'mrope_position_deltas_handle')) def __init__(self, result: Union[bytes, tensorrt_llm.bindings.executor.Result], diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 1a62c5beca..96ade56beb 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2213,13 +2213,14 @@ class PyTorchModelEngine(ModelEngine): mrope_position_deltas).expand( 3, 1, 1) mrope_position_ids.append(gen_mrope_position_ids) - multimodal_params.to_device( - "multimodal_data", - "cuda", - pin_memory=True, - target_keywords=[ - "mrope_config.mrope_position_deltas" - ]) + if mrope_position_deltas.device.type == "cpu": + multimodal_params.to_device( + "multimodal_data", + "cuda", + pin_memory=True, + target_keywords=[ + "mrope_config.mrope_position_deltas" + ]) multimodal_params_list.append(multimodal_params) request.py_batch_idx = request.py_seq_slot @@ -2448,8 +2449,9 @@ class PyTorchModelEngine(ModelEngine): # NOTE: self.use_mrope is enough for differentiating whether to use mrope_position_ids but # `_create_dummy_context_requests` from `kv_cache_creater` makes an exception that I can not add multimodal_data to the dummy_request # so that we only replace position_ids with mrope_position_ids when it is not a dummy request and for models who is using mrope. 
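Reviewer note on the device checks added around the mrope tensors below: pinning applies to host (CPU) memory, so tensors that already arrived on the GPU are passed through unchanged instead of being re-pinned. A minimal sketch of the guard pattern (the extra `cuda.is_available()` check only keeps this sketch runnable on CPU-only machines; it is not part of the change):

```python
import torch

def maybe_pin(t: torch.Tensor) -> torch.Tensor:
    # Pin host memory only when the tensor actually lives on the CPU; tensors
    # already on the GPU (e.g. mrope ids produced by an earlier stage) are
    # returned as-is.
    if t.device.type == "cpu" and torch.cuda.is_available():
        return t.pin_memory()
    return t
```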
- mrope_position_ids = torch.cat(mrope_position_ids, - dim=-1).pin_memory() + mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) + if mrope_position_ids.device.type == "cpu": + mrope_position_ids = mrope_position_ids.pin_memory() self.mrope_position_ids_cuda[:, :, :total_num_tokens].copy_( mrope_position_ids[:, :, :total_num_tokens], non_blocking=True) final_position_ids = self.mrope_position_ids_cuda[:, :, : @@ -3362,7 +3364,26 @@ class PyTorchModelEngine(ModelEngine): mm_embeddings = list( torch.split(mm_embeddings[0], multimodal_chunks, dim=0)) - return {'mm_embeddings': mm_embeddings, 'logits': None} + # Extract mrope position data from multimodal_params if available + mrope_position_ids_list = [] + mrope_position_deltas_list = [] + for multimodal_param in multimodal_params: + mrope_config = multimodal_param.multimodal_data.get( + 'mrope_config', {}) + mrope_position_ids = mrope_config.get('mrope_position_ids') + mrope_position_deltas = mrope_config.get('mrope_position_deltas') + if mrope_position_ids is not None: + mrope_position_ids_list.append(mrope_position_ids) + if mrope_position_deltas is not None: + mrope_position_deltas_list.append(mrope_position_deltas) + + result = {'mm_embeddings': mm_embeddings, 'logits': None} + if mrope_position_ids_list: + result['mrope_position_ids'] = mrope_position_ids_list + if mrope_position_deltas_list: + result['mrope_position_deltas'] = mrope_position_deltas_list + + return result def _init_userbuffers(self, hidden_size): if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1: diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index b9bbb7cbf5..62a43a50be 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -21,7 +21,7 @@ from concurrent import futures from dataclasses import dataclass from functools import cached_property from itertools import repeat -from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, cast +from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast import numpy as np import torch @@ -199,6 +199,8 @@ class EarlyStopSampler(Sampler): @dataclass(kw_only=True) class MultimodalResult: mm_embeddings: List[torch.Tensor] + # Can be used to include e.g. `mrope_position_ids`, etc. 
+ extra_data: Optional[Dict[str, Any]] = None def values(self): return vars(self).values() @@ -262,7 +264,10 @@ class EarlyStopWithMMResult(Sampler): resource_manager: Optional[ResourceManager] = None, ) -> SampleStateWithMMResult: # from model_outputs to MultimodalResult - data = MultimodalResult(mm_embeddings=model_outputs["mm_embeddings"]) + data = MultimodalResult( + mm_embeddings=model_outputs.pop("mm_embeddings"), + extra_data={**model_outputs}, + ) return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data) @override @@ -276,7 +281,12 @@ class EarlyStopWithMMResult(Sampler): scheduled_requests = state.scheduled_requests assert not scheduled_requests.generation_requests mm_embeddings = state.data.mm_embeddings - for request, mm_embedding in zip(scheduled_requests.context_requests, mm_embeddings): + extra_data = state.data.extra_data or {} + mrope_position_ids = extra_data.get("mrope_position_ids", None) + mrope_position_deltas = extra_data.get("mrope_position_deltas", None) + for i, (request, mm_embedding) in enumerate( + zip(scheduled_requests.context_requests, mm_embeddings) + ): request.state = LlmRequestState.GENERATION_COMPLETE # NOTE: This is a hack: set finish reason manually and set the beam 0 request.set_finished_reason(FinishReason.LENGTH, 0) @@ -287,6 +297,12 @@ class EarlyStopWithMMResult(Sampler): request.py_result.append_mm_embeddings(mm_embedding) + # Store mrope data if available + if mrope_position_ids is not None and mrope_position_deltas is not None: + request.py_result.set_mrope_position( + mrope_position_ids[i], mrope_position_deltas[i] + ) + @override def is_generation_model(self) -> bool: return False diff --git a/tensorrt_llm/disaggregated_params.py b/tensorrt_llm/disaggregated_params.py index 028bccbea0..4c0680bc94 100644 --- a/tensorrt_llm/disaggregated_params.py +++ b/tensorrt_llm/disaggregated_params.py @@ -40,6 +40,8 @@ class DisaggregatedParams: multimodal_hashes: Optional[List[List[int]]] = ( None # user provided mm hashes should be a list of 8 integers ) + mrope_position_ids_handle: Optional[Dict[str, Any]] = None + mrope_position_deltas_handle: Optional[Dict[str, Any]] = None def get_context_phase_params(self) -> tllme.ContextPhaseParams: return tllme.ContextPhaseParams( diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index 13ff28023e..f9c502f85d 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -1,9 +1,10 @@ import atexit import concurrent.futures +import json +import os import threading -import time import weakref -from typing import Dict, Optional, Union +from typing import Dict, List, Optional import torch import zmq @@ -22,9 +23,11 @@ from .ipc import FusedIpcQueue, IpcQueue from .postproc_worker import PostprocWorker, PostprocWorkerConfig from .request import CancellingRequest, GenerationRequest from .result import GenerationResult, IterationResult -from .utils import (ErrorResponse, IntraProcessQueue, WorkerCommIpcAddrs, - create_mpi_comm_session, get_spawn_proxy_process_env, - is_llm_response, print_alive_threads) +from .rpc import RPCClient +from .rpc.rpc_common import get_unique_ipc_addr +from .utils import (ErrorResponse, WorkerCommIpcAddrs, create_mpi_comm_session, + get_spawn_proxy_process_env, is_llm_response, + print_alive_threads) from .worker import GenerationExecutorWorker, worker_main __all__ = [ @@ -89,19 +92,27 @@ class GenerationExecutorProxy(GenerationExecutor): "llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get( "llm_args", None) 
is not None else None + # Generate RPC address and key for stats RPC + self.rpc_addr = get_unique_ipc_addr() + self.hmac_key = os.urandom(32) + worker_kwargs = dict(**worker_kwargs, worker_queues=self._setup_queues(), postproc_worker_config=postproc_worker_config, - is_llm_executor=False) + is_llm_executor=False, + rpc_addr=self.rpc_addr, + hmac_key=self.hmac_key) if "log_level" not in worker_kwargs: worker_kwargs["log_level"] = logger.level self.dispatch_result_thread: Optional[ManagedThread] = None - self.dispatch_stats_thread: Optional[ManagedThread] = None - self.dispatch_kv_cache_events_thread: Optional[ManagedThread] = None + self.rpc_client: Optional[RPCClient] = None self._start_executor_workers(worker_kwargs) + # Create RPC client after workers are started (worker starts RPC server) + self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key) + # MPI registers its joiner using threading._register_atexit if possible. # These functions run before atexit.register, so to avoid deadlock, # we have to notify workers to exit before MPI starts to wait them. @@ -128,19 +139,11 @@ class GenerationExecutorProxy(GenerationExecutor): socket_type=zmq.PULL if self.enable_postprocess_parallel else zmq.PAIR, name="proxy_result_queue") - self.mp_stats_queue = FusedIpcQueue(is_server=True, - fuse_message=False, - name="proxy_stats_queue") - self.kv_cache_events_queue = FusedIpcQueue( - is_server=True, - fuse_message=False, - name="proxy_kv_cache_events_queue") + # Stats and KV events are now fetched via RPC, not IPC queues. return WorkerCommIpcAddrs( request_queue_addr=self.request_queue.address, worker_init_status_queue_addr=self.worker_init_status_queue.address, result_queue_addr=self.result_queue.address, - stats_queue_addr=self.mp_stats_queue.address, - kv_cache_events_queue_addr=self.kv_cache_events_queue.address, ) def abort_request(self, request_id: int) -> None: @@ -204,71 +207,8 @@ class GenerationExecutorProxy(GenerationExecutor): return True # success - def _iteration_result_task(self, - queue: Union[FusedIpcQueue, IntraProcessQueue], - result_singleton: IterationResult, - urgent: bool = False) -> bool: - if not urgent: - time.sleep(0.2) - - try: - data = queue.get() - except: - logger.debug( - "proxy.py: Error in _iteration_result_task: queue.get()") - return False - - if data is None: - logger.debug("proxy.py: _iteration_result_task: data is None") - return False # shutdown the thread - - data = data if isinstance(data, list) else [data] - queue = result_singleton.queue - async_queues = [] - - while queue.full(): - queue.get() - - try: - for d in data: - if d is None: - logger.debug("proxy.py: _iteration_result_task: d is None") - return False - - if isinstance(queue, _SyncQueue): - queue.put_nowait(d) - async_queues.append(queue) - else: - queue.put(d) - - if async_queues: - _SyncQueue.notify_many(queue.loop, async_queues) - - except AsyncQueue.EventLoopShutdownError: - # This happens in the last loop while the generate workflow is - # stopped, or when get_stats() or aget_stats() are not called by users - # and therefore event loop can already be closed. - logger.debug("proxy.py: EventLoopShutdownError") - except Exception as e: - logger.debug(f"proxy.py: Error in _iteration_result_task: {e}") - raise e - - return True # success - - def dispatch_stats_task(self) -> bool: - if not self._iter_stats_result: - # This can happen temporarily because the WAR in tensorrt_llm/bench/benchmark/throughput.py - # is not synchronized with self.dispatch_stats_thread. 
- logger.debug( - f"Skipping stats dispatch while self._iter_stats_result=None") - return True # Intended behavior, not an error - return self._iteration_result_task(self.mp_stats_queue, - self._iter_stats_result) - - def dispatch_kv_cache_events_task(self) -> bool: - return self._iteration_result_task(self.kv_cache_events_queue, - self._iter_kv_events_result, - urgent=True) + # NOTE: _iteration_result_task, dispatch_stats_task, and dispatch_kv_cache_events_task + # have been removed as stats and kv_events are now fetched via RPC directly. def _start_dispatch_threads(self): if self.dispatch_result_thread is None: @@ -277,25 +217,9 @@ class GenerationExecutorProxy(GenerationExecutor): weakref.WeakMethod(self.dispatch_result_task), error_queue=self._error_queue, name="proxy_dispatch_result_thread") - self.dispatch_stats_thread = ManagedThread( - weakref.WeakMethod(self.dispatch_stats_task), - error_queue=self._error_queue, - name="proxy_dispatch_stats_thread") - self.dispatch_kv_cache_events_thread = ManagedThread( - weakref.WeakMethod(self.dispatch_kv_cache_events_task), - error_queue=self._error_queue, - name="proxy_dispatch_kv_cache_events_thread") self.dispatch_result_thread.start() - # Only collect stats when submission - # is via LLM API - if self._iter_stats_result: - self.dispatch_stats_thread.start() - - if self._iter_kv_events_result: - self.dispatch_kv_cache_events_thread.start() - self._handle_background_error() def _start_executor_workers(self, worker_kwargs): @@ -387,23 +311,18 @@ class GenerationExecutorProxy(GenerationExecutor): ): self.dispatch_result_thread.stop() self.dispatch_result_thread.join() - if self.dispatch_stats_thread is not None and self.dispatch_stats_thread.is_alive( - ): - self.dispatch_stats_thread.stop() - self.dispatch_stats_thread.join() - if self.dispatch_kv_cache_events_thread is not None and self.dispatch_kv_cache_events_thread.is_alive( - ): - self.dispatch_kv_cache_events_thread.stop() - self.dispatch_kv_cache_events_thread.join() # step3: finish all remaining work + # close the RPC client + if self.rpc_client is not None: + self.rpc_client.close() + self.rpc_client = None + # close all the sockets self.request_queue.close() self.worker_init_status_queue.close() self.result_queue.close() - self.mp_stats_queue.close() - self.kv_cache_events_queue.close() self.workers_started = False self.mpi_session.shutdown() @@ -441,6 +360,104 @@ class GenerationExecutorProxy(GenerationExecutor): return result + def get_stats(self, timeout: float) -> List[dict]: + """Get iteration statistics from the runtime via RPC. + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + List[dict]: A list of runtime stats as dict. + """ + if self.rpc_client is None: + logger.warning("RPC client not initialized, cannot get stats") + return [] + + stats = self.rpc_client.fetch_stats_wait_async(timeout=timeout).remote() + return [json.loads(s) if isinstance(s, str) else s for s in stats] + + def aget_stats(self, timeout: float) -> IterationResult: + """Get iteration statistics from the runtime via RPC (async). + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + IterationResult: An async iterable object containing runtime stats. 
+ """ + # Initialize iteration result if needed + self._maybe_initialize_iteration_results() + + if self._iter_stats_result is None: + logger.warning("Iteration statistics are not available yet.") + from .executor import empty_async_iterable + return empty_async_iterable() + + # Fetch stats via RPC and populate the result + try: + stats = self.rpc_client.fetch_stats_wait_async( + timeout=timeout).remote() + except Exception as e: + logger.debug(f"Error fetching stats via RPC: {e}") + stats = [] + + for stat in stats: + self._iter_stats_result.queue.put(stat) + + self._iter_stats_result.set_timeout(timeout) + return self._iter_stats_result + + def get_kv_events(self, timeout: float) -> List[dict]: + """Get iteration KV events from the runtime via RPC. + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + List[dict]: A list of runtime events as dict. + """ + if self.rpc_client is None: + logger.warning("RPC client not initialized, cannot get kv events") + return [] + + try: + events = self.rpc_client.fetch_kv_cache_events_wait_async( + timeout=timeout).remote() + return [json.loads(e) if isinstance(e, str) else e for e in events] + except Exception as e: + logger.error(f"Error fetching kv events via RPC: {e}") + return [] + + def aget_kv_events(self, timeout: float) -> IterationResult: + """Get iteration KV events from the runtime via RPC (async). + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + IterationResult: An async iterable object containing runtime events. + """ + # Initialize iteration result if needed + self._maybe_initialize_iteration_results() + + if self._iter_kv_events_result is None: + from .executor import empty_async_iterable + return empty_async_iterable() + + # Fetch kv events via RPC and populate the result + try: + events = self.rpc_client.fetch_kv_cache_events_wait_async( + timeout=timeout).remote() + except Exception as e: + logger.debug(f"Error fetching kv events via RPC: {e}") + events = [] + + for event in events: + self._iter_kv_events_result.queue.put(event) + + self._iter_kv_events_result.set_timeout(timeout) + return self._iter_kv_events_result + def __del__(self): self.shutdown() diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 603c567ed5..8d33d94a7f 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -1,4 +1,5 @@ import asyncio +import dataclasses import json import time import weakref @@ -415,12 +416,19 @@ class GenerationResultBase: self.cached_tokens = getattr(response_result, 'cached_tokens', 0) self.avg_decoded_tokens_per_iter = response_result.avg_decoded_tokens_per_iter if context_phase_params is not None: - self.disaggregated_params = DisaggregatedParams( + existing_disagg_params = self.disaggregated_params + # Use `replace` to preserve things like `mrope_position_ids_handle` and + # `mrope_position_deltas_handle`. However, explicitly set + # `multimodal_embedding_handles=None` since they should no longer be needed. 
+ self.disaggregated_params = dataclasses.replace( + existing_disagg_params or DisaggregatedParams(), request_type="context_only", first_gen_tokens=context_phase_params.first_gen_tokens, ctx_request_id=context_phase_params.req_id, opaque_state=context_phase_params.opaque_state, - draft_tokens=context_phase_params.draft_tokens) + draft_tokens=context_phase_params.draft_tokens, + multimodal_embedding_handles=None, + ) finish_reasons = response_result.finish_reasons # output_token_ids = (beams, tokens) @@ -440,6 +448,8 @@ class GenerationResultBase: if hasattr(response_result, 'mm_embedding_handle' ) and response_result.mm_embedding_handle is not None: self._mm_embedding_handle = response_result.mm_embedding_handle + mrope_position_ids_handle = response_result.mrope_position_ids_handle + mrope_position_deltas_handle = response_result.mrope_position_deltas_handle if self.disaggregated_params is not None: self.disaggregated_params.multimodal_embedding_handles = [ response_result.mm_embedding_handle @@ -451,6 +461,8 @@ class GenerationResultBase: response_result.mm_embedding_handle ], multimodal_hashes=self._multimodal_hashes) + self.disaggregated_params.mrope_position_ids_handle = mrope_position_ids_handle + self.disaggregated_params.mrope_position_deltas_handle = mrope_position_deltas_handle # Processing background errors here ASAF during generation. if self._background_error_handler and ( @@ -811,8 +823,12 @@ class GenerationResult(GenerationResultBase): def _repr_fields(self): return [ - 'request_id', 'prompt_token_ids', 'outputs', 'finished', - "context_logits", "mm_embedding_handle" + 'request_id', + 'prompt_token_ids', + 'outputs', + 'finished', + "context_logits", + "mm_embedding_handle", ] def __repr__(self) -> str: diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index 09f93afb80..722609dea6 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -1,11 +1,13 @@ +import json import threading -from typing import Optional +from typing import List, Optional from ..llmapi.mpi_session import MpiPoolSession, MpiSession -from ..llmapi.utils import logger_debug +from ..llmapi.utils import logger_debug, print_colored from ..logger import logger from .executor import GenerationExecutor from .postproc_worker import PostprocWorkerConfig +from .result import IterationResult from .rpc_proxy_mixin import RpcExecutorMixin from .rpc_worker import RpcWorker from .utils import create_mpi_comm_session, get_spawn_proxy_process_env @@ -69,20 +71,110 @@ class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor): **self.worker_kwargs) def _setup_mainloop_with_tasks(self): - """Setup mainloop with all tasks needed for RpcProxy.""" + """Setup mainloop with tasks needed for RpcProxy. + + Note: Stats and kv_events are now fetched on-demand via direct RPC calls + (get_stats, aget_stats, get_kv_events, aget_kv_events), not via streaming loops. + """ tasks = [ self._fetch_responses_loop_async, - self._fetch_stats_loop_async, ] - # Only add kv_cache_events loop if it's enabled - if self._iter_kv_events_result: - tasks.append(self._fetch_kv_cache_events_loop_async) - # Call mixin's setup_mainloop with custom tasks self.setup_mainloop(tasks=tasks, thread_name="rpc_proxy_main_loop") - def fetch_stats_remote(self): - return self.rpc_client.fetch_stats().remote() + def get_stats(self, timeout: float) -> List[dict]: + """Get iteration statistics from the runtime via RPC. + + Args: + timeout (float): Max wait time in seconds for the RPC call. 
+ + Returns: + List[dict]: A list of runtime stats as dict. + """ + try: + stats = self.rpc_client.fetch_stats_wait_async( + timeout=timeout).remote() + return [json.loads(s) if isinstance(s, str) else s for s in stats] + except Exception as e: + logger.debug(f"Error fetching stats via RPC: {e}") + return [] + + def aget_stats(self, timeout: float) -> IterationResult: + """Get iteration statistics from the runtime via RPC (async). + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + IterationResult: An async iterable object containing runtime stats. + """ + self._maybe_initialize_iteration_results() + + if self._iter_stats_result is None: + print_colored("Iteration statistics are not available yet.\n", + "yellow") + from .executor import empty_async_iterable + return empty_async_iterable() + + # Fetch stats via RPC and populate the result + try: + stats = self.rpc_client.fetch_stats_wait_async( + timeout=timeout).remote() + except Exception: + stats = [] + + for stat in stats: + self._iter_stats_result.queue.put(stat) + + self._iter_stats_result.set_timeout(timeout) + return self._iter_stats_result + + def get_kv_events(self, timeout: float) -> List[dict]: + """Get iteration KV events from the runtime via RPC. + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + List[dict]: A list of runtime events as dict. + """ + try: + # Events are already serialized by the worker's fetch_kv_cache_events_wait_async() + events = self.rpc_client.fetch_kv_cache_events_wait_async( + timeout=timeout).remote() + return [json.loads(e) if isinstance(e, str) else e for e in events] + except Exception as e: + logger.debug(f"Error fetching kv events via RPC: {e}") + return [] + + def aget_kv_events(self, timeout: float) -> IterationResult: + """Get iteration KV events from the runtime via RPC (async). + + Args: + timeout (float): Max wait time in seconds for the RPC call. + + Returns: + IterationResult: An async iterable object containing runtime events. + """ + # Initialize iteration result if needed + self._maybe_initialize_iteration_results() + + if self._iter_kv_events_result is None: + from .executor import empty_async_iterable + return empty_async_iterable() + + # Fetch kv events via RPC and populate the result + try: + events = self.rpc_client.fetch_kv_cache_events_wait_async( + timeout=timeout).remote() + except Exception: + events = [] + + for event in events: + self._iter_kv_events_result.queue.put(event) + + self._iter_kv_events_result.set_timeout(timeout) + return self._iter_kv_events_result def setup_engine_remote(self): return self.rpc_client.setup_engine().remote(need_response=True) diff --git a/tensorrt_llm/executor/rpc_proxy_mixin.py b/tensorrt_llm/executor/rpc_proxy_mixin.py index c7d7716f4f..ecbb86e25e 100644 --- a/tensorrt_llm/executor/rpc_proxy_mixin.py +++ b/tensorrt_llm/executor/rpc_proxy_mixin.py @@ -1,13 +1,12 @@ import asyncio import atexit -import json import os import threading from typing import Callable, List, Optional from .._utils import nvtx_range_debug from ..llmapi.tracer import global_tracer -from ..llmapi.utils import AsyncQueue, _SyncQueue +from ..llmapi.utils import _SyncQueue from ..logger import logger from .request import GenerationRequest from .result import GenerationResult @@ -47,15 +46,16 @@ class RpcExecutorMixin: Args: tasks: List of async coroutine functions to run. 
thread_name: Name for the main loop thread + + Note: Stats and kv_events are now fetched on-demand via direct RPC calls + (get_stats, aget_stats, get_kv_events, aget_kv_events), so the default + tasks only include the responses loop. Callers can still provide custom + tasks including stats/kv_events loops if needed for streaming use cases. """ if tasks is None: tasks = [ self._fetch_responses_loop_async, - self._fetch_stats_loop_async, ] - # Only add kv_cache_events loop if it's enabled - if self._iter_kv_events_result: - tasks.append(self._fetch_kv_cache_events_loop_async) async def main_loop_task(): await asyncio.gather(*[task() for task in tasks]) @@ -136,22 +136,6 @@ class RpcExecutorMixin: if async_queues: _SyncQueue.notify_many(event_loop, async_queues) - def handle_stats(self, stats): - """Handle stats received from RPC worker and put them into the stats result queue. - - Args: - stats: Statistics data from the RPC worker (can be dict, str, or list) - """ - self._handle_iteration_data(stats, self._iter_stats_result, "stats") - - def handle_kv_cache_events(self, events): - """Handle KV cache events received from RPC worker and put them into the events result queue. - - Args: - events: KV cache events data from the RPC worker (can be dict, str, or list) - """ - self._handle_iteration_data(events, self._iter_kv_events_result, "kv_cache_events") - async def _generic_fetch_loop_async( self, fetch_method_name: str, handler_method: Callable, method_name: str ): @@ -181,86 +165,6 @@ class RpcExecutorMixin: method_name="_fetch_responses_loop_async", ) - async def _fetch_stats_loop_async(self): - await self._generic_fetch_loop_async( - fetch_method_name="fetch_stats_loop_async", - handler_method=self.handle_stats, - method_name="_fetch_stats_loop_async", - ) - - async def _fetch_kv_cache_events_loop_async(self): - await self._generic_fetch_loop_async( - fetch_method_name="fetch_kv_cache_events_loop_async", - handler_method=self.handle_kv_cache_events, - method_name="_fetch_kv_cache_events_loop_async", - ) - - def _handle_iteration_data(self, data, result_singleton, data_type: str): - """Generic method to handle iteration data received from RPC worker. 
- - Args: - data: Data from the RPC worker (can be dict, str, or list) - result_singleton: The iteration result singleton to put data into - data_type: Type of data for logging (e.g., "stats", "kv_cache_events") - """ - # Make sure we have initialized the iteration results - self._maybe_initialize_iteration_results() - - if not result_singleton: - logger.debug(f"Skipping {data_type} handling while result_singleton=None") - return - - # Get the queue from the result singleton - queue = result_singleton.queue - async_queues = [] - - # Clear old data if queue is full (similar to _iteration_result_task) - while queue.full(): - queue.get() - - try: - # Handle different types of data - if isinstance(data, str): - # Already JSON serialized - data_json = data - elif isinstance(data, list): - # Skip empty lists to avoid putting nothing in the queue - if not data: - logger.debug(f"rpc_proxy.py: Skipping empty {data_type} list") - return - - # Handle list of data (multiple iterations) - for item in data: - if isinstance(item, str): - item_json = item - else: - item_json = json.dumps(item) - - if isinstance(queue, _SyncQueue): - queue.put_nowait(item_json) - async_queues.append(queue) - else: - queue.put(item_json) - - if async_queues: - _SyncQueue.notify_many(queue.loop, async_queues) - return - else: - # Convert dict/other to JSON string as expected by IterationResult - data_json = json.dumps(data) - - if isinstance(queue, _SyncQueue): - queue.put_nowait(data_json) - async_queues.append(queue) - else: - queue.put(data_json) - - if async_queues: - _SyncQueue.notify_many(queue.loop, async_queues) - - except AsyncQueue.EventLoopShutdownError: - # This happens when the event loop is already closed - logger.debug(f"rpc_proxy.py: EventLoopShutdownError in handle_{data_type}") - except Exception as e: - logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}") - raise e + # NOTE: _fetch_stats_loop_async and _fetch_kv_cache_events_loop_async have been removed. + # Stats and kv_events are now fetched on-demand via direct RPC calls + # (get_stats, aget_stats, get_kv_events, aget_kv_events) instead of streaming loops. 
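Note (illustrative sketch, not part of the patch): with the streaming stats/kv-events loops removed above, callers now pull iteration data on demand over RPC. A minimal client-side sketch of that pattern, assuming an RPCClient already connected to the worker's rpc_addr/hmac_key and that fetch_stats_wait_async(...).remote() returns a list of JSON strings or dicts, mirroring the get_stats implementations in this change; fetch_latest_stats is a hypothetical helper name:

import json
from typing import Any, Dict, List

def fetch_latest_stats(rpc_client, timeout: float = 2.0) -> List[Dict[str, Any]]:
    # Ask the worker to poll its stats queue until data arrives or `timeout` elapses;
    # the worker serializes IterationStats objects to JSON strings before sending them over RPC.
    raw = rpc_client.fetch_stats_wait_async(timeout=timeout).remote()
    # Entries may arrive as JSON strings or plain dicts depending on the serializer in use.
    return [json.loads(s) if isinstance(s, str) else s for s in raw]

Error handling is left to the caller here; the proxy implementations above wrap the RPC call in try/except and fall back to an empty list so that a transient RPC failure does not surface as an exception to the user.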
diff --git a/tensorrt_llm/executor/rpc_worker_mixin.py b/tensorrt_llm/executor/rpc_worker_mixin.py index cab53e6b1d..c5c201bd07 100644 --- a/tensorrt_llm/executor/rpc_worker_mixin.py +++ b/tensorrt_llm/executor/rpc_worker_mixin.py @@ -1,4 +1,5 @@ import asyncio +import time from queue import Queue from threading import Event from typing import AsyncGenerator, Optional @@ -50,8 +51,9 @@ class RpcWorkerMixin: """Submits a request to the worker.""" with nvtx_range_debug("RpcWorker.submit", color="blue", category="Worker"): logger_debug(f"[worker] Submitting request {request.id}", color="green") - super().submit(request) + result = super().submit(request) logger_debug(f"[worker] Submitted request {request.id}", color="green") + return result def fetch_responses(self, timeout: Optional[float] = None) -> list: """Fetch responses from the response queue (blocking).""" @@ -99,54 +101,58 @@ class RpcWorkerMixin: f"[worker] RpcWorker {self.rank} quitting fetch_responses_loop_async", color="yellow" ) - async def fetch_stats_async(self, timeout: Optional[float] = None) -> list: - """Async version of fetch_stats using asyncio.to_thread.""" - return await asyncio.to_thread(self.fetch_stats) - - async def fetch_kv_cache_events_async(self, timeout: Optional[float] = None) -> list: - """Async version of fetch_kv_cache_events using asyncio.to_thread.""" - return await asyncio.to_thread(self.fetch_kv_cache_events) - - async def fetch_stats_loop_async( - self, timeout: Optional[float] = None - ) -> AsyncGenerator[list, None]: - """Stream stats in a loop until shutdown.""" - async for data in self._generic_fetch_loop_async( - fetch_method=self.fetch_stats_async, - serializer=self._stats_serializer, - method_name="fetch_stats_loop_async", - timeout=timeout, - ): - yield data - - async def fetch_kv_cache_events_loop_async( - self, timeout: Optional[float] = None - ) -> AsyncGenerator[list, None]: - """Stream KV cache events in a loop until shutdown.""" - async for data in self._generic_fetch_loop_async( - fetch_method=self.fetch_kv_cache_events_async, - serializer=self._kv_cache_events_serializer, - method_name="fetch_kv_cache_events_loop_async", - timeout=timeout, - ): - yield data - - async def _generic_fetch_loop_async( - self, fetch_method, serializer, method_name: str, timeout: Optional[float] = None - ) -> AsyncGenerator[list, None]: - """Generic method for fetching data in a loop. + async def fetch_stats_wait_async(self, timeout: Optional[float] = None) -> list: + """Poll for stats until available or timeout. Args: - fetch_method: The async method to call for fetching data - serializer: The serializer function to apply to each item - method_name: Name of the method for logging - timeout: Optional timeout between fetches + timeout: Max wait time in seconds. If None, fetch once without waiting. 
""" - while not self.shutdown_event.is_set(): - timeout = timeout or 0.1 - await asyncio.sleep(timeout) - data = await fetch_method() - # Always yield data, even if empty, to prevent the client looks like hanging - # TODO: Remove the empty data to reduce the IPC overhead - yield [serializer(item) for item in data] - logger_debug(f"[worker] RpcWorker {self.rank} quitting {method_name}", color="yellow") + logger_debug( + f"[worker] RpcWorker {self.rank} is fetching stats with timeout {timeout}", + color="yellow", + ) + start = time.time() + while True: + stats = await asyncio.to_thread(self.fetch_stats) + if stats or timeout is None: + break + if (time.time() - start) >= timeout: + break + await asyncio.sleep(0.1) + return [self._stats_serializer(s) for s in stats] + + async def fetch_kv_cache_events_wait_async(self, timeout: Optional[float] = None) -> list: + """Poll for KV cache events until available or timeout. + + Args: + timeout: Max wait time in seconds. If None, fetch once without waiting. + """ + start = time.time() + while True: + events = await asyncio.to_thread(self.fetch_kv_cache_events) + if events or timeout is None: + break + if (time.time() - start) >= timeout: + break + await asyncio.sleep(0.1) + return [self._kv_cache_events_serializer(e) for e in events] + + async def fetch_stats_async(self, timeout: Optional[float] = None) -> list: + """Async version of fetch_stats using asyncio.to_thread. + + This method is exposed via RPC and can be called directly by the proxy. + Returns serialized stats (JSON strings) that can be sent over RPC. + """ + stats = await asyncio.to_thread(self.fetch_stats) + # Serialize stats before sending over RPC (IterationStats objects are not picklable) + return [self._stats_serializer(s) for s in stats] + + async def fetch_kv_cache_events_async(self, timeout: Optional[float] = None) -> list: + """Async version of fetch_kv_cache_events using asyncio.to_thread. + + This method is exposed via RPC and can be called directly by the proxy. + Returns serialized events (JSON strings) that can be sent over RPC. 
+ """ + events = await asyncio.to_thread(self.fetch_kv_cache_events) + # Serialize events before sending over RPC + return [self._kv_cache_events_serializer(e) for e in events] diff --git a/tensorrt_llm/executor/utils.py b/tensorrt_llm/executor/utils.py index 8a5f61bc36..e52ea481fb 100644 --- a/tensorrt_llm/executor/utils.py +++ b/tensorrt_llm/executor/utils.py @@ -142,8 +142,6 @@ class WorkerCommIpcAddrs(NamedTuple): request_queue_addr: tuple[str, Optional[bytes]] worker_init_status_queue_addr: tuple[str, Optional[bytes]] result_queue_addr: tuple[str, Optional[bytes]] - stats_queue_addr: tuple[str, Optional[bytes]] - kv_cache_events_queue_addr: tuple[str, Optional[bytes]] def is_llm_response(instance): diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 2199bee74a..c4917a86a5 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -1,11 +1,9 @@ import gc import os -import time import traceback from concurrent.futures import ProcessPoolExecutor from pathlib import Path -from queue import Queue -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import zmq @@ -18,25 +16,22 @@ from ..llmapi.llm_args import BaseLlmArgs from ..llmapi.mpi_session import set_mpi_session_cpp from ..llmapi.tokenizer import TokenizerBase from ..llmapi.tracer import VizTracer, set_global_tracer -from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, logger_debug, - print_traceback_on_error) +from ..llmapi.utils import ManagedThread, logger_debug, print_traceback_on_error from ..sampling_params import BatchedLogitsProcessor from .base_worker import BaseWorker, _init_hf_modules -from .executor import IterationResultQueue from .ipc import FusedIpcQueue, IpcQueue from .postproc_worker import (PostprocWorker, PostprocWorkerConfig, postproc_worker_main) from .request import CancellingRequest, GenerationRequest -from .result import IterationResult -from .utils import (ErrorResponse, RequestError, WorkerCommIpcAddrs, - has_event_loop) +from .rpc_worker_mixin import RpcWorkerMixin +from .utils import ErrorResponse, RequestError, WorkerCommIpcAddrs __all__ = [ "GenerationExecutorWorker", ] -class GenerationExecutorWorker(BaseWorker): +class GenerationExecutorWorker(RpcWorkerMixin, BaseWorker): def __init__( self, @@ -48,6 +43,8 @@ class GenerationExecutorWorker(BaseWorker): hf_model_dir: Optional[Path] = None, tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[BaseLlmArgs] = None, + rpc_addr: Optional[str] = None, + hmac_key: Optional[bytes] = None, ) -> None: super().__init__( engine=engine, @@ -62,35 +59,18 @@ class GenerationExecutorWorker(BaseWorker): self.setup_engine() + # Setup RPC server for stats (skip init_rpc_worker to keep IPC response queue) + # Only set up if rpc_addr is provided (for stats RPC support) + if rpc_addr is not None: + self.rpc_addr = rpc_addr + self.hmac_key = hmac_key + self.start_rpc_server() # Reuse from RpcWorkerMixin + self.await_response_thread = ManagedThread( self.await_response_task, error_queue=self._error_queue, name="await_response_thread") - self.dispatch_stats_thread = ManagedThread( - self.dispatch_stats_task, - error_queue=self._error_queue, - name="dispatch_stats_thread") - - self.dispatch_kv_cache_events_thread = ManagedThread( - self.dispatch_kv_cache_events_task, - error_queue=self._error_queue, - name="dispatch_kv_cache_events_thread") - - def _create_iteration_result_queue(self, - it_result_queue: IterationResultQueue): - if not 
it_result_queue.is_initialized: - # not yet initialized - it_result_queue.is_initialized = True - if has_event_loop(): - _queue = AsyncQueue() - it_result_queue.queue = _queue.sync_q - it_result_queue.aqueue = _queue - else: - _queue = Queue() - it_result_queue.queue = _queue - it_result_queue.aqueue = None - def start_thread(self, thread: ManagedThread): if self.engine.can_enqueue_requests() and not thread.is_alive(): thread.start() @@ -98,74 +78,10 @@ class GenerationExecutorWorker(BaseWorker): def await_response_task(self) -> bool: return self._await_response_helper() - def _iteration_result_task(self, it_result_queue: IterationResultQueue, - engine_get_result_api: Callable, - result_singleton: IterationResult, - result_serializer: Callable) -> bool: - time.sleep(0.2) - async_queues = [] - queue = result_singleton.queue if self._is_llm_executor and result_singleton else it_result_queue.queue - try: - for results in engine_get_result_api(): - res = result_serializer(results) - if self._is_llm_executor and result_singleton: - # In this case, there's no ExecutorBindingProxy. - # Worker needs to take care of putting to result queue. - while queue.full(): - queue.get() - if isinstance(queue, _SyncQueue): - queue.put_nowait(res) - async_queues.append(queue) - else: - queue.put(res) - else: - # Send to ExecutorBindingProxy via IPC - queue.put(res) - - if async_queues: - _SyncQueue.notify_many(queue.loop, async_queues) - except AsyncQueue.EventLoopShutdownError: - # This happens in the last results loop while the generate workflow is stopped. - logger.debug("worker.py: EventLoopShutdownError") - except Exception as e: - logger.error(f"worker.py: Error in _iteration_result_task: {e}") - raise e - - return True # success - - def dispatch_stats_task(self) -> bool: - return self._iteration_result_task(self.stats_queues, self.fetch_stats, - self._iter_stats_result, - self._stats_serializer) - - def dispatch_kv_cache_events_task(self) -> bool: - if isinstance(self.engine, tllm.Executor): - # Check if the engine has a kv cache event manager - # If not, return an empty list for the events which will cause the thread to exit early. 
- event_manager = self.engine.get_kv_cache_event_manager() - if event_manager is None: - events_api = lambda: [None] - else: - events_api = event_manager.get_latest_events - return self._iteration_result_task(self.kv_events_queues, - events_api, - self._iter_kv_events_result, - self._kv_cache_events_serializer) - else: - return self._iteration_result_task( - self.kv_events_queues, self.engine.get_latest_kv_cache_events, - self._iter_kv_events_result, self._kv_cache_events_serializer) - def start(self): - # create iteration result queues - self._create_iteration_result_queue(self.stats_queues) - self._create_iteration_result_queue(self.kv_events_queues) - - # start threads + # Stats and KV events are now fetched on-demand via RPC, + # so we only need to start the response thread self.start_thread(self.await_response_thread) - self.start_thread(self.dispatch_kv_cache_events_thread) - if mpi_rank() == 0: - self.start_thread(self.dispatch_stats_thread) def shutdown(self): @@ -178,16 +94,9 @@ class GenerationExecutorWorker(BaseWorker): if self.engine is not None: if self.engine.can_enqueue_requests(): - if self.await_response_thread.is_alive(): self.await_response_thread.stop() self.await_response_thread.join() - if self.dispatch_stats_thread.is_alive(): - self.dispatch_stats_thread.stop() - self.dispatch_stats_thread.join() - if self.dispatch_kv_cache_events_thread.is_alive(): - self.dispatch_kv_cache_events_thread.stop() - self.dispatch_kv_cache_events_thread.join() self.engine.shutdown() self.engine = None @@ -240,6 +149,8 @@ def worker_main( hf_model_dir: Optional[Path] = None, tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[BaseLlmArgs] = None, + rpc_addr: Optional[str] = None, + hmac_key: Optional[bytes] = None, ) -> None: mpi_comm().barrier() @@ -287,15 +198,6 @@ def worker_main( is_server=False, socket_type=zmq.DEALER, name="worker_init_status_queue") - mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr, - is_server=False, - fuse_message=True, - name="worker_stats_queue") - kv_cache_events_queue = FusedIpcQueue( - worker_queues.kv_cache_events_queue_addr, - is_server=False, - fuse_message=False, - name="worker_kv_cache_events_queue") if postproc_worker_config.enabled: # IPC queues for sending inputs to the postprocess parallel @@ -322,9 +224,6 @@ def worker_main( assert result_queues is not None for q in result_queues: q.put(None) - # Signal the stats thread in the proxy to quit - mp_stats_queue.put(None) - kv_cache_events_queue.put(None) postprocess_worker_futures = [] if is_leader and postproc_worker_config.enabled: @@ -370,7 +269,9 @@ def worker_main( is_llm_executor=is_llm_executor, hf_model_dir=hf_model_dir, tokenizer=tokenizer, - llm_args=llm_args) + llm_args=llm_args, + rpc_addr=rpc_addr, + hmac_key=hmac_key) except Exception as e: logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}") logger.error(traceback.format_exc()) @@ -396,11 +297,6 @@ def worker_main( else: worker.set_result_queue(result_queue) - # initialize the iteration result queues - worker._set_iteration_result_queue(worker.stats_queues, - mp_stats_queue) - worker._set_iteration_result_queue(worker.kv_events_queues, - kv_cache_events_queue) # Send ready signal with confirmation ready_msg = (ready_signal, None) if not worker_init_status_queue.notify_with_retry(ready_msg): diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 33774f0ed8..6d3410bf3c 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -89,8 +89,12 @@ class 
RequestOutput(DetokenizedGenerationResultBase, GenerationResult): def _repr_fields(self): return [ - "request_id", "prompt", "prompt_token_ids", "outputs", "finished", - "mm_embedding_handle" + "request_id", + "prompt", + "prompt_token_ids", + "outputs", + "finished", + "mm_embedding_handle", ] @@ -419,7 +423,7 @@ class BaseLLM: multimodal_params = None if is_mm_disagg: - if not self.input_processor.support_mm_disagg: + if not getattr(self.input_processor, "support_mm_disagg", False): raise ValueError( "Multimodal disaggregated inference is not supported for this model" ) @@ -436,14 +440,42 @@ class BaseLLM: mm_hashes = disaggregated_params.multimodal_hashes multimodal_input = MultimodalInput.from_components( mm_hashes, mm_token_positions, mm_token_length) + multimodal_data = {"multimodal_embedding": mm_handles} + if disaggregated_params.mrope_position_ids_handle is not None: + # NOTE: `PyTorchModelEngine` assumes both are present when using mrope. + assert disaggregated_params.mrope_position_deltas_handle is not None + mrope_config = {} + mrope_config[ + "mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle + mrope_config[ + "mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle + multimodal_data["mrope_config"] = mrope_config multimodal_params = MultimodalParams( multimodal_input=multimodal_input, - multimodal_data={"multimodal_embedding": mm_handles}) + multimodal_data=multimodal_data, + ) elif "prompt_token_ids" in inputs: prompt_token_ids = inputs['prompt_token_ids'] prompt = None query_token_ids = inputs.get("query_token_ids", None) + multimodal_data = {} + # NOTE: when running in `generation_only` for disagg, this is the code path we expect to hit. + if disaggregated_params is not None and disaggregated_params.mrope_position_ids_handle is not None: + # It looks like `PyTorchModelEngine` assumes both are present when using mrope? 
+ if disaggregated_params.mrope_position_deltas_handle is None: + raise RuntimeError( + "`mrope_position_ids_handle` and `mrope_position_deltas_handle` must both " + "be provided, or both `None`.") + mrope_config = {} + mrope_config[ + "mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle + mrope_config[ + "mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle + multimodal_data["mrope_config"] = mrope_config + if multimodal_data: + multimodal_params = MultimodalParams( + multimodal_data=multimodal_data) elif "prompt" in inputs: if 'multi_modal_data' in inputs: # TODO: The current design uses a wrapper for existing input processor (input_processor_with_hash) diff --git a/tensorrt_llm/llmapi/mm_encoder.py b/tensorrt_llm/llmapi/mm_encoder.py index 8553d4678e..3ff85dd42b 100644 --- a/tensorrt_llm/llmapi/mm_encoder.py +++ b/tensorrt_llm/llmapi/mm_encoder.py @@ -101,14 +101,8 @@ class MultimodalEncoder(_TorchLLM): inputs = [prompt_inputs(i) for i in inputs] - def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any: - if isinstance(maybe_batched, list): - return maybe_batched[pos] - else: - return maybe_batched - futures = [] - for i, request_inputs in enumerate(inputs): + for request_inputs in inputs: future = self.generate_async(request_inputs) futures.append(future) diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index c9699bb91f..644ae8e418 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -51,20 +51,22 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, MemoryUpdateRequest, ModelCard, ModelList, PromptTokensDetails, ResponsesRequest, + ResponsesResponse, UpdateWeightsRequest, UsageInfo, to_llm_disaggregated_params) from tensorrt_llm.serve.postprocess_handlers import ( ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs, - chat_harmony_post_processor, chat_harmony_streaming_post_processor, - chat_response_post_processor, chat_stream_post_processor, - completion_response_post_processor, completion_stream_post_processor) + ResponsesAPIPostprocArgs, chat_harmony_post_processor, + chat_harmony_streaming_post_processor, chat_response_post_processor, + chat_stream_post_processor, completion_response_post_processor, + completion_stream_post_processor, responses_api_post_processor, + responses_api_streaming_post_processor) from tensorrt_llm.serve.responses_utils import (ConversationHistoryStore, + ResponsesStreamingProcessor, ServerArrivalTimeMiddleware) from tensorrt_llm.serve.responses_utils import \ create_response as responses_api_create_response from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds -from tensorrt_llm.serve.responses_utils import \ - process_streaming_events as responses_api_process_streaming_events from tensorrt_llm.serve.responses_utils import \ request_preprocess as responses_api_request_preprocess from tensorrt_llm.version import __version__ as VERSION @@ -119,9 +121,8 @@ class OpenAIServer: self.model_config = None # Enable response storage for Responses API - self.enable_store = True - if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0: - self.enable_store = False + self.enable_store = (len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) < 1) and not self.postproc_worker_enabled + self.conversation_store = ConversationHistoryStore() model_dir = Path(model) @@ -942,19 +943,39 @@ class OpenAIServer: return self.create_error_response(message=str(e), 
err_type="internal_error") async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response: - async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]: - async for event_data in responses_api_process_streaming_events( - request=request, - sampling_params=sampling_params, - generator=generator, - model_name=self.model, - conversation_store=self.conversation_store, - use_harmony=self.use_harmony, - reasoning_parser=self.llm.args.reasoning_parser, - tool_parser=self.tool_parser, - enable_store=self.enable_store - ): - yield event_data + async def create_response( + promise: RequestOutput, postproc_params: PostprocParams) -> ResponsesResponse: + await promise.aresult() + if self.postproc_worker_enabled: + response = promise.outputs[0]._postprocess_result + else: + args = postproc_params.postproc_args + response = await responses_api_create_response( + generator=promise, + request=request, + sampling_params=args.sampling_params, + model_name=self.model, + conversation_store=self.conversation_store, + generation_result=None, + enable_store=self.enable_store and request.store, + use_harmony=self.use_harmony, + reasoning_parser=args.reasoning_parser, + tool_parser=args.tool_parser, + ) + + return response + + async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams): + post_processor, args = postproc_params.post_processor, postproc_params.postproc_args + streaming_processor = args.streaming_processor + initial_responses = streaming_processor.get_initial_responses() + for initial_response in initial_responses: + yield initial_response + + async for res in promise: + pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) + for pp_res in pp_results: + yield pp_res try: if request.background: @@ -977,38 +998,61 @@ class OpenAIServer: request=request, prev_response=prev_response, conversation_store=self.conversation_store, - enable_store=self.enable_store, + enable_store=self.enable_store and request.store, use_harmony=self.use_harmony, tokenizer=self.tokenizer if not self.use_harmony else None, model_config=self.model_config if not self.use_harmony else None, processor=self.processor if not self.use_harmony else None, ) + streaming_processor = None + if request.stream: + # Per-request streaming processor + streaming_processor = ResponsesStreamingProcessor( + request=request, + sampling_params=sampling_params, + model_name=self.model, + conversation_store=self.conversation_store, + enable_store=self.enable_store and request.store, + use_harmony=self.use_harmony, + reasoning_parser=self.llm.args.reasoning_parser, + tool_parser=self.tool_parser, + ) + + postproc_args = ResponsesAPIPostprocArgs( + model=self.model, + request=request, + sampling_params=sampling_params, + use_harmony=self.use_harmony, + reasoning_parser=self.llm.args.reasoning_parser, + tool_parser=self.tool_parser, + streaming_processor=streaming_processor, + ) + postproc_params = PostprocParams( + post_processor=responses_api_streaming_post_processor + if request.stream else responses_api_post_processor, + postproc_args=postproc_args, + ) promise = self.llm.generate_async( inputs=input_tokens, sampling_params=sampling_params, streaming=request.stream, + _postproc_params=postproc_params if self.postproc_worker_enabled else None, ) + if self.postproc_worker_enabled and request.store: + logger.warning("Postproc workers are enabled, request will not be stored!") + 
asyncio.create_task(self.await_disconnected(raw_request, promise)) if request.stream: return StreamingResponse( - create_stream_response(promise, request, sampling_params), + content=create_streaming_generator(promise, postproc_params), media_type="text/event-stream" ) else: - return await responses_api_create_response( - generator=promise, - request=request, - sampling_params=sampling_params, - model_name=self.model, - conversation_store=self.conversation_store, - generation_result=None, - enable_store=self.enable_store, - use_harmony=self.use_harmony, - reasoning_parser=self.llm.args.reasoning_parser, - tool_parser=self.tool_parser) + response = await create_response(promise, postproc_params) + return JSONResponse(content=response.model_dump()) except CppExecutorError: logger.error(traceback.format_exc()) # If internal executor error is raised, shutdown the server diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index aa56cc6e5b..01ffb648e2 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -1,11 +1,16 @@ from dataclasses import dataclass, field from typing import Any, List, Literal, Optional, Tuple, Union +from tensorrt_llm.serve.responses_utils import ResponsesStreamingProcessor +from tensorrt_llm.serve.responses_utils import \ + create_response_non_store as responses_api_create_response_non_store + from .._utils import nvtx_range_debug from ..executor import (DetokenizedGenerationResultBase, GenerationResult, GenerationResultBase) from ..executor.postproc_worker import PostprocArgs from ..executor.result import Logprob, TokenLogprobs +from ..llmapi import SamplingParams from ..llmapi.reasoning_parser import (BaseReasoningParser, ReasoningParserFactory) from ..llmapi.tokenizer import TransformersTokenizer @@ -26,7 +31,8 @@ from .openai_protocol import (ChatCompletionLogProbs, CompletionResponseStreamChoice, CompletionStreamResponse, DeltaFunctionCall, DeltaMessage, DeltaToolCall, FunctionCall, - PromptTokensDetails, StreamOptions, ToolCall, + PromptTokensDetails, ResponsesRequest, + ResponsesResponse, StreamOptions, ToolCall, UsageInfo, to_disaggregated_params) from .tool_parser.base_tool_parser import BaseToolParser from .tool_parser.core_types import ToolCallItem @@ -543,3 +549,42 @@ def chat_harmony_streaming_post_processor( num_prompt_tokens=args.num_prompt_tokens, ) return response + + +@dataclass(kw_only=True) +class ResponsesAPIPostprocArgs(PostprocArgs): + model: str + request: ResponsesRequest + sampling_params: SamplingParams + use_harmony: bool + reasoning_parser: Optional[str] = None + tool_parser: Optional[str] = None + streaming_processor: Optional[ResponsesStreamingProcessor] = None + + +@nvtx_range_debug("responses_api_post_processor") +def responses_api_post_processor( + rsp: GenerationResult, + args: ResponsesAPIPostprocArgs) -> ResponsesResponse: + return responses_api_create_response_non_store( + generation_result=rsp, + request=args.request, + sampling_params=args.sampling_params, + model_name=args.model, + use_harmony=args.use_harmony, + reasoning_parser=args.reasoning_parser, + tool_parser=args.tool_parser, + ) + + +@nvtx_range_debug("responses_api_streaming_post_processor") +def responses_api_streaming_post_processor( + rsp: GenerationResult, args: ResponsesAPIPostprocArgs) -> List[str]: + if args.streaming_processor is None: + raise ValueError( + "streaming_processor is required for streaming post-processing") + outputs = 
args.streaming_processor.process_single_output(rsp) + if rsp._done: + outputs.append( + args.streaming_processor.get_final_response_non_store(rsp)) + return outputs diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py index 4f0e4e55a6..9297422c6a 100644 --- a/tensorrt_llm/serve/responses_utils.py +++ b/tensorrt_llm/serve/responses_utils.py @@ -10,7 +10,7 @@ import uuid from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from copy import copy -from typing import Any, Literal, Optional, OrderedDict, Tuple, Union +from typing import Any, List, Literal, Optional, OrderedDict, Tuple, Union from openai.types.responses import (ResponseCompletedEvent, ResponseContentPartAddedEvent, @@ -41,6 +41,7 @@ from openai_harmony import (Author, Conversation, DeveloperContent, from transformers import AutoProcessor, PretrainedConfig from tensorrt_llm.bindings import steady_clock_now +from tensorrt_llm.executor import GenerationResult from tensorrt_llm.inputs.utils import apply_chat_template from tensorrt_llm.llmapi import SamplingParams from tensorrt_llm.llmapi.llm import RequestOutput @@ -962,7 +963,7 @@ def _apply_tool_parser( return normal_text, calls -async def _create_output_content( +def _create_output_content( final_res: RequestOutput, reasoning_parser: Optional[str] = None, tool_parser: Optional[str] = None, @@ -1040,7 +1041,7 @@ async def _create_output_content( return output_items, output_messages -async def _create_output_content_harmony( +def _create_output_content_harmony( final_res: RequestOutput ) -> Tuple[list[ResponseOutputItem], list[Message]]: output_messages = _parse_output_tokens(final_res.outputs[0].token_ids) @@ -1057,12 +1058,53 @@ async def _create_output_content_harmony( return output_content, output_messages +def _create_response( + final_res: GenerationResult, + use_harmony: bool, + request: ResponsesRequest, + model_name: str, + response_creation_time: int, + sampling_params: SamplingParams, + reasoning_parser: Optional[str] = None, + tool_parser: Optional[str] = None, +) -> tuple[ResponsesResponse, list[Message | ChatCompletionMessageParam]]: + _responses_debug_log("================================================") + _responses_debug_log("RAW MODEL OUTPUT:") + _responses_debug_log(final_res.outputs) + _responses_debug_log("================================================") + + # prepare responses output + output_content = [] + if use_harmony: + output_content, output_messages = _create_output_content_harmony( + final_res) + else: + output_content, output_messages = _create_output_content( + final_res, reasoning_parser, tool_parser, request.tools) + + response = ResponsesResponse.from_request( + request=request, + sampling_params=sampling_params, + model_name=model_name, + created_time=response_creation_time, + output=output_content, + status=finish_reason_mapping(final_res.outputs[0].finish_reason), + ) + + _responses_debug_log("========== Response ===========") + _responses_debug_log(response) + _responses_debug_log("===============================") + + # return output_messages for store_response + return response, output_messages + + async def create_response( - generator, request: ResponsesRequest, sampling_params: SamplingParams, model_name: str, conversation_store: ConversationHistoryStore, + generator: Optional[AsyncGenerator[RequestOutput, None]] = None, generation_result: Optional[RequestOutput] = None, enable_store: bool = False, use_harmony: bool = True, @@ -1078,33 +1120,22 @@ async def 
create_response( if generation_result is not None: final_res = generation_result - else: + elif generator is not None: final_res = await generator if final_res is None: raise RuntimeError("No output generated or provided") - _responses_debug_log("================================================") - _responses_debug_log("RAW MODEL OUTPUT:") - _responses_debug_log(final_res.outputs) - _responses_debug_log("================================================") - # prepare responses output - output_content = [] - if use_harmony: - output_content, output_messages = await _create_output_content_harmony( - final_res) - else: - output_content, output_messages = await _create_output_content( - final_res, reasoning_parser, tool_parser, request.tools) - - response = ResponsesResponse.from_request( + response, output_messages = _create_response( + final_res=final_res, + use_harmony=use_harmony, request=request, - sampling_params=sampling_params, model_name=model_name, - created_time=response_creation_time, - output=output_content, - status=finish_reason_mapping(final_res.outputs[0].finish_reason), + response_creation_time=response_creation_time, + sampling_params=sampling_params, + reasoning_parser=reasoning_parser, + tool_parser=tool_parser, ) if enable_store and request.store: @@ -1112,9 +1143,34 @@ async def create_response( resp_msgs=output_messages, prev_resp_id=prev_response_id) - _responses_debug_log("========== Response ===========") - _responses_debug_log(response) - _responses_debug_log("===============================") + return response + + +def create_response_non_store( + generation_result: RequestOutput, + request: ResponsesRequest, + sampling_params: SamplingParams, + model_name: str, + use_harmony: bool = True, + create_time: Optional[int] = None, + reasoning_parser: Optional[str] = None, + tool_parser: Optional[str] = None, +) -> ResponsesResponse: + response_creation_time = create_time if create_time is not None else int( + time.time()) + + # prepare responses output + response, _ = _create_response( + final_res=generation_result, + use_harmony=use_harmony, + request=request, + model_name=model_name, + response_creation_time=response_creation_time, + sampling_params=sampling_params, + reasoning_parser=reasoning_parser, + tool_parser=tool_parser, + ) + return response @@ -1649,6 +1705,143 @@ def _generate_streaming_event_harmony( parser.last_content_delta) +class ResponsesStreamingProcessor: + + def __init__( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + model_name: str, + create_time: Optional[int] = None, + conversation_store: Optional[ConversationHistoryStore] = None, + enable_store: bool = False, + use_harmony: bool = True, + reasoning_parser: Optional[str] = None, + tool_parser: Optional[str] = None, + ): + self.model_name = model_name + self.request = request + self.sampling_params = sampling_params + self.sequence_number = 0 + self.streaming_events_helper = ResponsesStreamingEventsHelper() + self.response_creation_time = create_time if create_time is not None else int( + time.time()) + self.final_res: Optional[RequestOutput] = None + self.reasoning_parser_dict: dict[int, BaseReasoningParser] = {} + self.tool_parser_dict: dict[int, BaseToolParser] = {} + self.stream_request_id = f"responses-api-{request.request_id}" + self.conversation_store = conversation_store + self.enable_store = enable_store + self.use_harmony = use_harmony + self.reasoning_parser = reasoning_parser + self.tool_parser = tool_parser + + def _send_event(self, event: OpenAIBaseModel): 
+ # Set sequence_number if the event has this attribute + if hasattr(event, 'sequence_number'): + event.sequence_number = self.sequence_number + self.sequence_number += 1 + # Get event type from the event's type field if it exists + event_type = getattr(event, 'type', 'unknown') + return (f"event: {event_type}\n" + f"data: {event.model_dump_json(indent=None)}\n\n") + + def get_initial_responses(self) -> List[str]: + initial_response = ResponsesResponse.from_request( + request=self.request, + sampling_params=self.sampling_params, + model_name=self.model_name, + created_time=self.response_creation_time, + output=[], + status="in_progress", + usage=None, + ).model_dump() + + resp_created = self._send_event( + self.streaming_events_helper.get_response_created_event( + initial_response)) + resp_in_progress = self._send_event( + self.streaming_events_helper.get_response_in_progress_event( + initial_response)) + return [resp_created, resp_in_progress] + + async def get_final_response( + self, + final_res: RequestOutput, + ) -> str: + final_response = await create_response( + generator=None, + request=self.request, + sampling_params=self.sampling_params, + model_name=self.model_name, + conversation_store=self.conversation_store, + generation_result=final_res, + enable_store=self.enable_store, + use_harmony=self.use_harmony, + create_time=self.response_creation_time, + reasoning_parser=self.reasoning_parser, + tool_parser=self.tool_parser, + ) + + return self._send_event( + ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response.model_dump(), + )) + + def get_final_response_non_store( + self, + final_res: RequestOutput, + ) -> str: + final_response = create_response_non_store( + generation_result=final_res, + request=self.request, + sampling_params=self.sampling_params, + model_name=self.model_name, + use_harmony=self.use_harmony, + create_time=self.response_creation_time, + reasoning_parser=self.reasoning_parser, + tool_parser=self.tool_parser, + ) + + return self._send_event( + ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response.model_dump(), + )) + + def process_single_output(self, res: GenerationResult) -> list[str]: + event_generator = None + output = res.outputs[0] + if self.use_harmony: + event_generator = _generate_streaming_event_harmony( + harmony_adapter=get_harmony_adapter(), + stream_request_id=self.stream_request_id, + output=output, + request=self.request, + streaming_events_helper=self.streaming_events_helper, + ) + + else: + event_generator = _generate_streaming_event( + output=output, + request=self.request, + finished_generation=res._done, + streaming_events_helper=self.streaming_events_helper, + reasoning_parser_id=self.reasoning_parser, + tool_parser_id=self.tool_parser, + reasoning_parser_dict=self.reasoning_parser_dict, + tool_parser_dict=self.tool_parser_dict, + ) + + if event_generator is None: + raise RuntimeError("Failed to generate streaming events") + + return [self._send_event(event) for event in event_generator] + + async def process_streaming_events( generator, request: ResponsesRequest, @@ -1661,97 +1854,31 @@ async def process_streaming_events( reasoning_parser: Optional[str] = None, tool_parser: Optional[str] = None, ) -> AsyncGenerator[str, None]: - sequence_number = 0 - response_creation_time = create_time if create_time is not None else int( - time.time()) - final_res: Optional[RequestOutput] = None - reasoning_parser_dict: dict[int, BaseReasoningParser] = {} - 
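Editorial aside: the ResponsesStreamingProcessor introduced above gathers the per-request streaming state (sequence counter, parser caches, creation time) that previously lived in locals of process_streaming_events. Below is a minimal sketch of how a caller might drive it on the non-store path; the method names come from the class above, but the wrapper itself (stream_without_store) is illustrative and not code from this patch.

# Sketch only: stream a response without touching the conversation store.
# get_initial_responses / process_single_output / get_final_response_non_store
# are the methods defined above; this wrapper function is hypothetical.
async def stream_without_store(generator, processor):
    # response.created and response.in_progress frames
    for frame in processor.get_initial_responses():
        yield frame

    final_res = None
    async for res in generator:
        final_res = res
        # per-output streaming events (harmony or parser-based)
        for frame in processor.process_single_output(res):
            yield frame

    # response.completed frame; the non-store variant is synchronous
    if final_res is not None:
        yield processor.get_final_response_non_store(final_res)

Each yielded frame is an SSE chunk of the form "event: <type>\ndata: <json>\n\n", with sequence_number assigned monotonically by _send_event.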
tool_parser_dict: dict[int, BaseToolParser] = {} - - def _send_event(event: OpenAIBaseModel): - nonlocal sequence_number - # Set sequence_number if the event has this attribute - if hasattr(event, 'sequence_number'): - event.sequence_number = sequence_number - sequence_number += 1 - # Get event type from the event's type field if it exists - event_type = getattr(event, 'type', 'unknown') - return (f"event: {event_type}\n" - f"data: {event.model_dump_json(indent=None)}\n\n") - - streaming_events_helper = ResponsesStreamingEventsHelper() - - initial_response = ResponsesResponse.from_request( - request, - sampling_params, - model_name=model_name, - created_time=response_creation_time, - output=[], - status="in_progress", - usage=None, - ).model_dump() - - yield _send_event( - streaming_events_helper.get_response_created_event(initial_response)) - yield _send_event( - streaming_events_helper.get_response_in_progress_event( - initial_response)) - - stream_request_id = f"responses-api-{request.request_id}" - harmony_adapter = get_harmony_adapter() - async for res in generator: - final_res = res - # TODO(JunyiXu-nv): handle multiple outputs - output = res.outputs[0] - - event_generator = None - if use_harmony: - event_generator = _generate_streaming_event_harmony( - harmony_adapter=harmony_adapter, - stream_request_id=stream_request_id, - output=output, - request=request, - streaming_events_helper=streaming_events_helper, - ) - - else: - event_generator = _generate_streaming_event( - output=output, - request=request, - finished_generation=res.finished, - streaming_events_helper=streaming_events_helper, - reasoning_parser_id=reasoning_parser, - tool_parser_id=tool_parser, - reasoning_parser_dict=reasoning_parser_dict, - tool_parser_dict=tool_parser_dict, - ) - - if event_generator is None: - raise RuntimeError("Failed to generate streaming events") - - for event in event_generator: - yield _send_event(event) - - final_response = await create_response( - generator=generator, + streaming_processor = ResponsesStreamingProcessor( request=request, sampling_params=sampling_params, model_name=model_name, + create_time=create_time, conversation_store=conversation_store, - generation_result=final_res, enable_store=enable_store, use_harmony=use_harmony, - create_time=response_creation_time, reasoning_parser=reasoning_parser, tool_parser=tool_parser, ) - yield _send_event( - ResponseCompletedEvent( - type="response.completed", - sequence_number=-1, - response=final_response.model_dump(), - )) + initial_responses = streaming_processor.get_initial_responses() + for initial_response in initial_responses: + yield initial_response + + async for res in generator: + final_res = res + events = streaming_processor.process_single_output(res) + for event in events: + yield event + + final_response = await streaming_processor.get_final_response(final_res) + + yield final_response class ServerArrivalTimeMiddleware: diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 813bd40533..ee5fec7cd5 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -479,3 +479,8 @@ triton_server/test_triton.py::test_gpt_disaggregated_serving_bls[gpt-disaggregat cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mooncake_kvcache-90] SKIP (https://nvbugs/5760737) unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py::test_allreduce_pg_op[seqlen:16-hidden:1024] SKIP (https://nvbugs/5760740) 
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler] SKIP (https://nvbugs/5760747) +unittest/_torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-strategy:8-dtype:bfloat16-hidden:8192-seqlen:[15]] SKIP (https://nvbugs/5761364) +triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854) +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822) +accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] SKIP (https://nvbugs/5762852) +accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] SKIP (https://nvbugs/5762852) diff --git a/tests/unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py b/tests/unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py index 54cf23d6cb..6e3a6415b9 100644 --- a/tests/unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py +++ b/tests/unittest/_torch/attention/test_trtllm_flashinfer_symbol_collision.py @@ -6,6 +6,7 @@ import torch import tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.torch_moe # noqa: F401 import tensorrt_llm._torch.custom_ops.torch_custom_ops as trt_ops # noqa: F401 +from tensorrt_llm._torch.utils import ActivationType def test_flashinfer_fused_moe_matches_torch_moe(): @@ -75,8 +76,8 @@ def test_flashinfer_fused_moe_matches_torch_moe(): w1_weight=w1_list, # gate projection w2_weight=w2_list, # down projection w3_weight=w3_list, # up projection - mlp_style="gated_mlp", - act_fn="silu", + is_gated_mlp=True, + act_fn=int(ActivationType.Silu), ) # Compare outputs diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 2d5e0bd8a5..8c034799ad 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -186,7 +186,7 @@ def test_llama4_stacked_moe_pattern_detection(): moe_node = graph.call_function( torch.ops.auto_deploy.torch_moe, args=(x, selected_experts, routing_weights, w1_list, w2_list, w3_list), - kwargs={"mlp_style": "gated_mlp", "apply_routing_on_input": True}, + kwargs={"is_gated_mlp": True, "apply_routing_on_input": True}, ) graph.output(moe_node) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py index 99fccfab30..09b7d65eef 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py @@ -7,6 +7,7 @@ from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_availab import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale from tensorrt_llm._torch.modules.fused_moe import MoE # noqa: F401 +from tensorrt_llm._torch.utils import ActivationType def setup_moe_test(dtype, num_experts): @@ -173,8 +174,8 @@ def test_bmm_based_moe_op_run(dtype): [fused_w3_w1_stacked_weight], # Wrap in list for unified interface [fused_w2_weight], # Wrap in list for unified interface [], # Empty w3_weight list for 
stacked gated MLP - mlp_style="gated_mlp", - act_fn="silu", + is_gated_mlp=True, + act_fn=ActivationType.Silu, apply_routing_on_input=True, ) output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused( diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py index c9aea8bc60..e6cf60b157 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -82,7 +82,7 @@ def compute_with_experts( alpha=None, beta=None, limit=None, - activation_func="silu", + activation_func: ActivationType = ActivationType.Silu, ): def relu2(x: torch.Tensor) -> torch.Tensor: return torch.square(F.relu(x)) @@ -110,7 +110,7 @@ def compute_with_experts( inter = x1_scaled * x2 else: - if activation_func == "swiglu" or activation_func == "silu": + if activation_func == ActivationType.Swiglu or activation_func == ActivationType.Silu: inter = F.silu(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) else: inter = relu2(expert_inputs @ w1_expert.t()) @@ -136,10 +136,6 @@ def _get_test_data( return x, router_logits, w31_weight, w2_weight, w31_empty_scales, w2_empty_scales -def _activation_type_from_str(activation_func: str) -> ActivationType: - return ActivationType.Swiglu if activation_func in ["swiglu", "silu"] else ActivationType.Relu2 - - def _print_diff_if( condition: Callable[[torch.Tensor], bool], diff: torch.Tensor, @@ -183,7 +179,7 @@ F16_TEST_DTYPES = [ @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize("itype, otype, wtype", F16_TEST_DTYPES) -@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2]) @skip_pre_hopper def test_trtllm_fused_moe( batch_size, @@ -201,13 +197,13 @@ def test_trtllm_fused_moe( pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})") torch.manual_seed(42) - if activation_func in ["swiglu", "silu"]: + if activation_func in [ActivationType.Swiglu, ActivationType.Silu]: X_GEN_SCALE = 1.0 else: X_GEN_SCALE = 0.5 W_GEN_SCALE = 0.1 - x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data( + x, router_logits, w31_weight, w2_weight, _, _ = _get_test_data( otype, wtype, batch_size, @@ -239,19 +235,17 @@ def test_trtllm_fused_moe( "F16 test only supports bfloat16 or float16" ) - activation_type = _activation_type_from_str(activation_func) - def get_fc1_expert_weights( - activation_func: str, w31_weight: torch.Tensor, w1_weight: torch.Tensor + activation_func: ActivationType, w31_weight: torch.Tensor, w1_weight: torch.Tensor ) -> torch.Tensor: - if activation_func == "relu2": + if activation_func == ActivationType.Relu2: return w1_weight.contiguous() else: return w31_weight # (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size) _, w1_weight = torch.chunk(w31_weight, 2, dim=1) - mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + is_gated_mlp = False if activation_func == ActivationType.Relu2 else True torch.cuda.synchronize() ad_test_output = torch.ops.auto_deploy.trtllm_moe_fused( @@ -260,9 +254,13 @@ def test_trtllm_fused_moe( routing_weights, w3_w1_stacked_weight=get_fc1_expert_weights(activation_func, w31_weight, w1_weight), w2_stacked_weight=w2_weight, - 
mlp_style=mlp_style, + is_gated_mlp=is_gated_mlp, act_fn=activation_func, ) + # Convert ActivationType.Silu to ActivationType.Swiglu for C++ op compatibility + cpp_activation_type = ( + ActivationType.Swiglu if activation_func == ActivationType.Silu else activation_func + ) trtllm_test_output = torch.ops.trtllm.fused_moe( x, selected_experts.to(torch.int), @@ -273,11 +271,11 @@ def test_trtllm_fused_moe( fc2_expert_biases=None, output_dtype=otype, quant_scales=[], - activation_type=activation_type, + activation_type=cpp_activation_type, )[0].view(x.shape) torch.cuda.synchronize() - if mlp_style == "mlp": + if not is_gated_mlp: with torch.inference_mode(): output_triton_moe = torch.ops.auto_deploy.triton_moe_fused( x, @@ -285,6 +283,7 @@ def test_trtllm_fused_moe( routing_weights, w1_weight.contiguous(), w2_weight.contiguous(), + is_gated_mlp=False, )[0].view(x.shape) torch.testing.assert_close(output_triton_moe, ad_test_output, rtol=1e-2, atol=1e-2) @@ -308,7 +307,7 @@ FP8_TEST_DTYPES = [ @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize("itype, otype, wtype", FP8_TEST_DTYPES) -@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2]) @pytest.mark.skipif( not fp8_compatible() or not trtllm_ops_available(), reason="Requires fp8 and trtllm support", @@ -336,7 +335,7 @@ def test_trtllm_fused_moe_fp8( ) torch.manual_seed(42) - if activation_func in ["swiglu", "silu"]: + if activation_func in [ActivationType.Swiglu, ActivationType.Silu]: X_GEN_SCALE = 1.0 else: X_GEN_SCALE = 0.5 @@ -399,7 +398,7 @@ def test_trtllm_fused_moe_fp8( # (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size) w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1) - mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" + is_gated_mlp = False if activation_func == ActivationType.Relu2 else True # compute quant_scales gemm1_dequant = (w1_scales * hidden_states_scale).contiguous().squeeze().to(torch.float32) @@ -424,13 +423,13 @@ def test_trtllm_fused_moe_fp8( gemm1_dequant=gemm1_dequant, gemm2_act_quant=gemm2_act_quant, gemm2_dequant=gemm2_dequant, - mlp_style=mlp_style, + is_gated_mlp=is_gated_mlp, act_fn=activation_func, ) torch.cuda.synchronize() - if mlp_style == "mlp": + if not is_gated_mlp: with torch.inference_mode(): output_triton_fp8_moe = torch.ops.auto_deploy.triton_quant_fp8_moe( x, @@ -445,7 +444,7 @@ def test_trtllm_fused_moe_fp8( w1_scales, w2_scales, w3_scales, - mlp_style=mlp_style, + is_gated_mlp=is_gated_mlp, act_fn=activation_func, ) torch.testing.assert_close(output_triton_fp8_moe, ref_output, rtol=1e-1, atol=1e-1) @@ -569,7 +568,7 @@ NVFP4_TEST_DTYPES = [ @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) @pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES) -@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2]) @pytest.mark.skipif( not fp4_compatible() or not trtllm_ops_available(), reason="Requires fp4 and trtllm support", @@ -693,25 +692,23 @@ def test_trtllm_fused_moe_nvfp4( fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs) fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs) - mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp" - if mlp_style == "gated_mlp": + is_gated_mlp = False if 
activation_func == ActivationType.Relu2 else True + if is_gated_mlp: # For gated MLP, concatenate w1 and w3 as [w3, w1] fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous() fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1) fc1_weight_gs = torch.max(w3_gs, w1_gs) - if activation_func != "silu": + if activation_func != ActivationType.Silu: raise ValueError( f"Unsupported activation '{activation_func}' for gated_mlp. Use 'silu'." ) - elif mlp_style == "mlp": + else: # For non-gated MLP with ReLU^2 fc1_expert_weights_fp4 = w1_q_fp4 fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long) fc1_weight_gs = w1_gs - if activation_func != "relu2": + if activation_func != ActivationType.Relu2: raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.") - else: - raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") fc2_expert_weights_fp4 = w2_q_fp4.view(torch.long) fc2_weight_blockscale_fp8 = w2_blockscale.view(torch.long) @@ -729,7 +726,7 @@ def test_trtllm_fused_moe_nvfp4( fc2_activation_gs, fc1_alpha, fc2_alpha, - mlp_style=mlp_style, + is_gated_mlp=is_gated_mlp, act_fn=activation_func, ) @@ -747,8 +744,7 @@ def test_trtllm_fused_moe_nvfp4( block_size=NVFP4_BLOCK_SIZE, ) - concat_w3_w1 = mlp_style == "gated_mlp" - if concat_w3_w1: + if is_gated_mlp: w1_gs = w3_gs = torch.max(w1_gs, w3_gs) w1_dq = torch.empty(w1.shape, device="cuda", dtype=otype) @@ -782,14 +778,18 @@ def test_trtllm_fused_moe_nvfp4( block_size=NVFP4_BLOCK_SIZE, ) + # Convert ActivationType.Silu to ActivationType.Swiglu for reference op compatibility + resolved_activation_type = ( + ActivationType.Swiglu if activation_func == ActivationType.Silu else activation_func + ) ref_output = torch_moe_nvfp4( x_dq, - torch.cat([w3_dq, w1_dq], dim=1) if concat_w3_w1 else w1_dq, + torch.cat([w3_dq, w1_dq], dim=1) if is_gated_mlp else w1_dq, w2_dq, top_k, routing_weights, selected_experts, - _activation_type_from_str(activation_func), + resolved_activation_type, ) return ref_output diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py index c639c355e8..490eb1d742 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/triton_kernels/test_triton_moe.py @@ -4,6 +4,7 @@ from utils.util import skip_pre_hopper import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.load_moe_align import moe_align_block_size +from tensorrt_llm._torch.utils import ActivationType # noqa: F401 def _pack_routed_tokens_reference( @@ -131,6 +132,7 @@ def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit): routing_weights, w_up_stacked, w_down_stacked, + is_gated_mlp=False, ) # Reference Torch MoE in mlp mode with relu2 activation @@ -141,8 +143,8 @@ def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit): w1_weight=w_up_list, w2_weight=w_down_list, w3_weight=[], - mlp_style="mlp", - act_fn="relu2", + is_gated_mlp=False, + act_fn=ActivationType.Relu2, ) torch.testing.assert_close(out_triton, out_torch, rtol=5e-2, atol=5e-2) @@ -364,8 +366,8 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit): w1_weight_scale, w2_weight_scale, w3_weight_scale_tensor, - mlp_style="mlp", - act_fn="relu2", + is_gated_mlp=False, + 
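Editorial aside on the interface change these MoE tests exercise: the string arguments mlp_style="gated_mlp"/"mlp" and act_fn="silu"/"relu2" are replaced by an is_gated_mlp flag plus ActivationType enum values. A minimal compatibility-shim sketch follows; the helper name translate_legacy_moe_args is hypothetical, while the enum members and the Silu-to-Swiglu aliasing mirror what the updated tests do before calling torch.ops.trtllm.fused_moe.

# Hypothetical shim (not part of this patch): map the legacy string interface
# onto the new is_gated_mlp / ActivationType arguments.
from tensorrt_llm._torch.utils import ActivationType

_LEGACY_ACT_FN = {
    "silu": ActivationType.Silu,
    "swiglu": ActivationType.Swiglu,
    "relu2": ActivationType.Relu2,
}

def translate_legacy_moe_args(mlp_style: str, act_fn: str):
    is_gated_mlp = mlp_style == "gated_mlp"
    act = _LEGACY_ACT_FN[act_fn]
    # The tests convert Silu to Swiglu before calling the C++ fused_moe op;
    # mirror that conversion for the value handed to torch.ops.trtllm.fused_moe.
    cpp_act = ActivationType.Swiglu if act == ActivationType.Silu else act
    return is_gated_mlp, act, cpp_act

Call sites that still hold the old string configuration could route through a helper like this before invoking the auto_deploy or trtllm ops.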
act_fn=ActivationType.Relu2, ) # Reference: Torch quantized FP8 MoE (uses lists of tensors and scales) @@ -382,8 +384,8 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit): w1_weight_scale=w1_weight_scale_list, w2_weight_scale=w2_weight_scale_list, w3_weight_scale=w3_weight_scale_list, - mlp_style="mlp", - act_fn="relu2", + is_gated_mlp=False, + act_fn=ActivationType.Relu2, ) torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2) diff --git a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py index 7defc0dae7..28b590d88e 100644 --- a/tests/unittest/_torch/multi_gpu/test_moe_a2a.py +++ b/tests/unittest/_torch/multi_gpu/test_moe_a2a.py @@ -566,6 +566,7 @@ class TestMoEAlltoAll: (4, [32, 32, 32, 32], 4), (4, [1, 1, 1, 1], 2), (8, [640, 640, 640, 640, 640, 640, 640, 640], 4), + (4, [32, 0, 16, 0], 2), ], indirect=["mpi_pool_executor"]) def test_combine(self, mpi_pool_executor, all_num_tokens, top_k): diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py index 4dc0564711..99154dd074 100644 --- a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py +++ b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py @@ -19,49 +19,23 @@ example_images = [ str(test_data_root / "61.jpg"), ] - -@pytest.fixture(scope="function") -def multimodal_model_config(): - """Get multimodal model configuration similar to integration tests""" - # You can extend this to support multiple models or get from environment - model_configs = { - 'llava-v1.6-mistral-7b-hf': { - 'model_name': - 'llava-v1.6-mistral-7b-hf', - 'hf_model_dir': - 'llava-hf/llava-v1.6-mistral-7b-hf', - 'model_dir': - llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf", - } - } - - return model_configs['llava-v1.6-mistral-7b-hf'] +_LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf" +_QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct" # TODO: Add multi-image in single chat test -@pytest.mark.parametrize("model_key", [ - "llava-v1.6-mistral-7b-hf", -]) +@pytest.mark.parametrize("model_dir", [_LLAVA_DIR, _QWEN_2_5_VL_DIR]) @pytest.mark.parametrize("pd_disagg", [False, True]) -def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): +def test_single_image_chat(model_dir, pd_disagg): """Test processing single image using encoder (pass mm_embeddings) + LLM API. This test verifies that encoder (pass mm_embeddings) + LLM API produces identical results to standard llm generation (pass raw image) by comparing outputs. 
""" - # Get model configuration - if model_key != "llava-v1.6-mistral-7b-hf": - #TODO: add more model tests progressively here - pytest.skip( - f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now" - ) - - # Extract model information from config - encoder_model_dir = multimodal_model_config['model_dir'] # Test configuration max_tokens = 64 - free_gpu_memory_fraction = 0.6 if not pd_disagg else 0.2 + free_gpu_memory_fraction = 0.2 max_batch_size = 1 # Test data - OpenAI chat completion format @@ -76,15 +50,14 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): ) # Process multimodal data using encoder (pass mm_embeddings) - encoder = MultimodalEncoder(model=encoder_model_dir, - max_batch_size=max_batch_size) + encoder = MultimodalEncoder(model=model_dir, max_batch_size=max_batch_size) cache_transceiver_cfg = CacheTransceiverConfig( backend="DEFAULT") if pd_disagg else None disable_overlap_scheduler = pd_disagg - llm = LLM(model=encoder_model_dir, + llm = LLM(model=model_dir, backend='pytorch', kv_cache_config=kv_cache_config, trust_remote_code=True, @@ -93,7 +66,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): llm_decode = None if pd_disagg: - llm_decode = LLM(model=encoder_model_dir, + llm_decode = LLM(model=model_dir, backend='pytorch', kv_cache_config=kv_cache_config, trust_remote_code=True, @@ -141,6 +114,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None" ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only" + outputs = llm.generate(inputs, sampling_params=sampling_params, disaggregated_params=ep_disaggregated_params) @@ -151,10 +125,10 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): pd_disaggregated_params = outputs[0].disaggregated_params pd_disaggregated_params.request_type = "generation_only" sampling_params = SamplingParams(max_tokens=max_tokens) - inputs[0][ - 'multi_modal_data'] = None # remove multimodal data from input as decoder worker doesn't need it - inputs[0]['prompt_token_ids'] = outputs[ - 0].prompt_token_ids # use prompt token ids from encoder output + # remove multimodal data from input as decoder worker doesn't need it + inputs[0]['multi_modal_data'] = None + # use prompt token ids from encoder output + inputs[0]['prompt_token_ids'] = outputs[0].prompt_token_ids outputs = llm_decode.generate( inputs, @@ -199,24 +173,23 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config): f"Log probabilities don't match for output {i}, generation {j}" -@pytest.mark.parametrize("model_key", [ - "llava-v1.6-mistral-7b-hf", -]) -def test_multi_request_batch_chat(model_key, multimodal_model_config): +@pytest.mark.parametrize( + "model_dir, encoder_max_batch_size", + [ + (_LLAVA_DIR, 3), + # Qwen2.5 VL's vision encoder seems to output different embeddings based on this value. + # The test only passes with this set to 1. + (_QWEN_2_5_VL_DIR, 1), + ], +) +def test_multi_request_batch_chat(model_dir, encoder_max_batch_size): """Test batching multiple multimodal requests and verify encoder path matches raw path. This mirrors test_single_image_chat but with a batch of size 3. 
""" - if model_key != "llava-v1.6-mistral-7b-hf": - pytest.skip( - f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now" - ) - - encoder_model_dir = multimodal_model_config['model_dir'] max_tokens = 64 free_gpu_memory_fraction = 0.6 - max_batch_size = 3 prompts = [ "Describe the natural environment in the image.", @@ -232,10 +205,10 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config): free_gpu_memory_fraction=free_gpu_memory_fraction, ) - encoder = MultimodalEncoder(model=encoder_model_dir, - max_batch_size=max_batch_size) + encoder = MultimodalEncoder(model=model_dir, + max_batch_size=encoder_max_batch_size) llm = LLM( - model=encoder_model_dir, + model=model_dir, backend='pytorch', kv_cache_config=kv_cache_config, max_batch_size=1, # fix batch size to reduce non-determinism in tests @@ -305,8 +278,7 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config): "Describe the weather in the image.", ], 2), ]) -def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates, - multimodal_model_config): +def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates): """Test mm_keys in KV cache events with cache reuse scenarios. This test verifies: @@ -316,7 +288,7 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates, - Same media + same prompts: full reuse (0 duplicate offsets) - Same media + different prompts: partial reuse (prefix blocks reused) """ - encoder_model_dir = multimodal_model_config['model_dir'] + encoder_model_dir = _LLAVA_DIR max_tokens = 16 free_gpu_memory_fraction = 0.6 diff --git a/tests/unittest/llmapi/apps/_test_openai_responses.py b/tests/unittest/llmapi/apps/_test_openai_responses.py index 18271f6b76..e6902127cb 100644 --- a/tests/unittest/llmapi/apps/_test_openai_responses.py +++ b/tests/unittest/llmapi/apps/_test_openai_responses.py @@ -21,11 +21,18 @@ def model(request): return request.param +@pytest.fixture(scope="module", + params=[0, 2], + ids=["disable_processpool", "enable_processpool"]) +def num_postprocess_workers(request): + return request.param + + @pytest.fixture(scope="module") -def server(model: str): +def server(model: str, num_postprocess_workers: int): model_path = get_model_path(model) - args = [] + args = ["--num_postprocess_workers", f"{num_postprocess_workers}"] if model.startswith("Qwen3"): args.extend(["--reasoning_parser", "qwen3"]) elif model.startswith("DeepSeek-R1"): diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py index d287e5e35e..dc95ecf292 100644 --- a/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py +++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_top_logprobs.py @@ -1,6 +1,3 @@ -import os -import tempfile - import openai import pytest import yaml @@ -22,34 +19,28 @@ def backend(request): @pytest.fixture(scope="module") -def temp_extra_llm_api_options_file(): - temp_dir = tempfile.gettempdir() - temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml") - try: - extra_llm_api_options_dict = { - "enable_chunked_prefill": False, - "gather_generation_logits": True, - "kv_cache_config": { - "enable_block_reuse": False, - } +def temp_extra_llm_api_options_file(tmp_path_factory): + extra_llm_api_options_dict = { + "enable_chunked_prefill": False, + "gather_generation_logits": True, + "kv_cache_config": { + "enable_block_reuse": False, } + } - with open(temp_file_path, 'w') as f: - yaml.dump(extra_llm_api_options_dict, f) - - yield 
temp_file_path - finally: - if os.path.exists(temp_file_path): - os.remove(temp_file_path) + temp_file_path = tmp_path_factory.mktemp( + "config") / "extra_llm_api_options.yaml" + with open(temp_file_path, 'w') as f: + yaml.dump(extra_llm_api_options_dict, f) + return temp_file_path @pytest.fixture(scope="module") def server(model_name: str, backend: str, temp_extra_llm_api_options_file: str): model_path = get_model_path(model_name) - args = [ - "--backend", f"{backend}", "--extra_llm_api_options", - temp_extra_llm_api_options_file - ] + args = ["--backend", f"{backend}"] + if backend == "trt": + args += ["--extra_llm_api_options", temp_extra_llm_api_options_file] with RemoteOpenAIServer(model_path, args) as remote_server: yield remote_server @@ -61,11 +52,7 @@ def async_client(server: RemoteOpenAIServer): @pytest.mark.asyncio(loop_scope="module") async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI, - model_name: str, backend: str): - # Skip if backend is PyTorch as it does not support topk logprobs when k > 1 - if backend == "pytorch": - pytest.skip("Topk logprobs is not supported") - + model_name: str): messages = [{ "role": "system", "content": "You are a helpful assistant." @@ -94,42 +81,3 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI, assert logprob_content.bytes is not None assert logprob_content.top_logprobs is not None assert len(logprob_content.top_logprobs) == 5 - - -@pytest.mark.asyncio(loop_scope="module") -async def test_chat_completion_top1_logprobs(async_client: openai.AsyncOpenAI, - model_name: str, backend: str): - # Skip if backend is TRT because it is tested in test_chat_completion_top5_logprobs - if backend == "trt": - pytest.skip( - "TRT top logprobs is already tested in test_chat_completion_top5_logprobs" - ) - - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is the capital of France?" 
- }] - # Test top_logprobs=1 - chat_completion = await async_client.chat.completions.create( - model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0, - logprobs=True, - top_logprobs=1, - extra_body={ - "ignore_eos": True, - }) - logprobs = chat_completion.choices[0].logprobs - assert logprobs is not None and logprobs.content is not None - assert len(logprobs.content) == 10 - for logprob_content in logprobs.content: - assert logprob_content.token is not None - assert logprob_content.logprob is not None - assert logprob_content.bytes is not None - assert logprob_content.top_logprobs is not None - # Check that the top_logprobs contains only one entry - assert len(logprob_content.top_logprobs) == 1 diff --git a/tests/unittest/llmapi/test_executor.py b/tests/unittest/llmapi/test_executor.py index 2e0ef5f65f..338f6903b7 100644 --- a/tests/unittest/llmapi/test_executor.py +++ b/tests/unittest/llmapi/test_executor.py @@ -213,21 +213,33 @@ def _test_sync_generation_tp_inner(llama_7b_tp2_path: Path): result.outputs[0].token_ids) == ", neural network," try: - stats = await executor.aget_stats() - stats = json.loads(stats) - assert stats["iter"] == 0 - assert stats["cpuMemUsage"] > 0 - assert stats["gpuMemUsage"] > 0 - assert stats["inflightBatchingStats"]["numCtxTokens"] == 3 - assert stats["inflightBatchingStats"]["numGenRequests"] == 0 - assert stats["kvCacheStats"]["usedNumBlocks"] == 1 + stats_result = executor.aget_stats(timeout=2) + # aget_stats now returns IterationResult, iterate to get stats + async for stats_str in stats_result: + stats = json.loads(stats_str) if isinstance(stats_str, + str) else stats_str + assert stats["iter"] >= 0 + assert stats["cpuMemUsage"] > 0 + assert stats["gpuMemUsage"] > 0 + break # Just check first result except AsyncQueue.EventLoopShutdownError: pass asyncio.run(async_stats_task()) - stats = executor.get_stats() - assert json.loads(stats)["iter"] == 1 + # Poll for stats since RPC calls return immediately + import time + stats_list = [] + for _ in range(10): + stats_list = executor.get_stats(timeout=0.5) + if stats_list: + break + time.sleep(0.1) + + assert len(stats_list) > 0 + stats = json.loads(stats_list[0]) if isinstance(stats_list[0], + str) else stats_list[0] + assert stats["iter"] == 1 executor.shutdown() diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index fb6e24b81a..f8ffe8fc7b 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -4,6 +4,7 @@ import gc import json import os import sys +import time # Required for test_generate_with_seed to pass. 
# See the discussion in https://github.com/NVIDIA/TensorRT-LLM/pull/4264#issuecomment-2943269891 @@ -2193,6 +2194,7 @@ def llm_get_stats_test_harness(tp_size: int = 1, sampling_params=sampling_params): print(output) + time.sleep(2) results = llm.get_stats(2) validate_stats(results=results, @@ -2203,7 +2205,7 @@ def llm_get_stats_test_harness(tp_size: int = 1, enable_chunked_prefill=enable_chunked_prefill, enable_iter_req_stats=enable_iter_req_stats) - assert not llm.get_stats(2) + assert not llm.get_stats(0.5) # test that IterationResult()._done is properly set _ = llm.generate(prompts, sampling_params=sampling_params) @@ -2340,8 +2342,9 @@ def llm_get_stats_async_test_harness(tp_size: int = 1, async def task1(repetition_index: int): results = [] await asyncio.sleep( - 3) # ensure there's stats to collect for the assertion - async for stats in llm.get_stats_async(timeout=2): + 4) # ensure there's stats to collect for the assertion + async for stats in llm.get_stats_async( + 10): # it will return immediately results.append(stats) assert results diff --git a/tests/unittest/llmapi/test_llm_multi_gpu.py b/tests/unittest/llmapi/test_llm_multi_gpu.py index dd175a4809..971f25f11e 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu.py @@ -487,11 +487,12 @@ def test_llm_get_kv_cache_events_tp2(): # created + stored events assert events and len(events) >= 2 for event in events: + print(f"event: {event}") if event: - if event[0]["event_id"] == 0: - assert event[0]["data"]["type"] == "created" - elif event[0]["event_id"] == 1: - assert event[0]["data"]["type"] == "stored" + if event["event_id"] == 0: + assert event["data"]["type"] == "created" + elif event["event_id"] == 1: + assert event["data"]["type"] == "stored" @pytest.fixture(scope="module") diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 04d653b842..d90d51cd49 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -1,3 +1,4 @@ +import json import random import time from contextlib import contextmanager, nullcontext @@ -976,6 +977,62 @@ async def test_llm_rpc_streaming(): print(f"get result: {outputs}") +@skip_ray +def test_llm_rpc_get_stats(): + """Test that get_stats works with RPC orchestrator.""" + + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + enable_iter_perf_stats=True, + orchestrator_type="rpc") as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + # Generate some output to produce stats + for output in llm.generate( + prompts, sampling_params=SamplingParams(max_tokens=5)): + print(output) + + stats = llm.get_stats(timeout=5) + + assert len(stats) > 0, "Should have at least one stats entry" + # Stats should be JSON strings that can be parsed + parsed = json.loads(stats[0]) if isinstance(stats[0], str) else stats[0] + assert "iter" in parsed, "Stats should contain 'iter' field" + assert "cpuMemUsage" in parsed, "Stats should contain 'cpuMemUsage' field" + + +@skip_ray +@pytest.mark.asyncio +async def test_llm_rpc_get_stats_async(): + """Test that get_stats_async works with RPC orchestrator.""" + import json + + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + enable_iter_perf_stats=True, + orchestrator_type="rpc") as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + # Generate some output to produce stats + async for output in llm.generate_async( + prompts[0], 
sampling_params=SamplingParams(max_tokens=5)): + print(output) + + # Get stats via async API + stats_result = llm.get_stats_async(timeout=2) + + # Should be able to iterate over results + stats_count = 0 + async for stat in stats_result: + parsed = json.loads(stat) if isinstance(stat, str) else stat + assert "iter" in parsed, "Stats should contain 'iter' field" + stats_count += 1 + if stats_count >= 1: + break # Just verify we can get at least one + + assert stats_count > 0, "Should have received at least one stat" + + @pytest.mark.threadleak(enabled=False) @pytest.mark.part0 @skip_ray