Merge branch 'main' into fix_spec_gate

This commit is contained in:
Zheyu Fu 2025-12-22 12:22:51 -08:00 committed by GitHub
commit 9b33ea751b
42 changed files with 1436 additions and 1051 deletions

View File

@ -362,88 +362,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
int thread_idx = ThreadingPolicy::offset();
int local_token_idx = ThreadingPolicy::token_idx();
if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
}
// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;
uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
if (already_copied & (1ULL << target_rank))
// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}
uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
if (already_copied & (1ULL << target_rank))
{
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
}
continue;
}
// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
continue;
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();
// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();
// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
#pragma unroll
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
}
// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank,
payload_idx, ptrs, topk_target_ranks, topk_send_indices);
}
ThreadingPolicy::sync();
}
// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx,
ptrs, topk_target_ranks, topk_send_indices);
}
ThreadingPolicy::sync();
bool is_first_warp = threadIdx.x / warpSize == 0;
if (is_first_warp)
{
@ -452,8 +462,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
bool is_last_token = false;
if (lane_id == 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
if (local_num_tokens != 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
}
else
{
is_last_token = true;
}
}
is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
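For readers following the routing loop in the hunk above: the sketch below is a hedged Python restatement of what each token's TOP_K pass does (the `already_copied` bitmask plus the `atomicAdd` on `send_counters`). `compute_target_rank_id` is assumed to be the contiguous partitioning `expert_id // num_experts_per_rank` referenced in the comment; all names here are illustrative and not part of the diff.

def route_token(selected_experts, num_experts_per_rank, send_counters):
    # Map each of the TOP_K selected experts to a target rank and deduplicate,
    # so a token's payload is sent to each rank at most once.
    already_copied = set()
    topk_target_ranks, topk_send_indices = [], []
    for expert_id in selected_experts:                       # len == TOP_K
        target_rank = expert_id // num_experts_per_rank      # contiguous partitioning (assumed)
        if target_rank in already_copied:
            topk_target_ranks.append(-1)                     # duplicate rank: mark as skipped
            topk_send_indices.append(-1)
            continue
        topk_send_indices.append(send_counters[target_rank]) # slot reserved via atomicAdd in the kernel
        send_counters[target_rank] += 1
        topk_target_ranks.append(target_rank)
        already_copied.add(target_rank)
    return topk_target_ranks, topk_send_indices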
@ -523,7 +540,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);
// Prepare kernel pointers struct
@ -568,6 +585,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
if (params.one_block_per_token)
{
int grid_size = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@ -577,6 +599,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
else
{
int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@ -897,9 +924,19 @@ __global__ void moeA2ACombineKernel(
int local_token_idx = ThreadingPolicy::token_idx();
int const size_per_token = elements_per_token * sizeof(T);
if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;
}
#if !DISABLE_SYNC_FOR_PROFILING
@ -951,6 +988,9 @@ __global__ void moeA2ACombineKernel(
__syncthreads();
#endif
if (local_num_tokens == 0)
return;
// Get output location for this token (using src_data_ptrs[0] as output)
T* token_output = static_cast<T*>(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token;
@ -1003,7 +1043,7 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.elements_per_token > 0);
// Configure kernel launch
@ -1011,6 +1051,15 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
int const kWarpsPerBlock = kBlockSize / 32; // warpSize
int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
int grid_size_block = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size_warp == 0)
{
grid_size_warp = 1;
}
if (grid_size_block == 0)
{
grid_size_block = 1;
}
// Prepare kernel pointers struct for combine
CombineKernelPointers kernel_ptrs = {}; // Zero-initialize
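Both the dispatch and combine launchers above clamp the grid size to at least one block, so a rank with zero local tokens still launches a kernel and participates in the cross-rank synchronization. A minimal host-side sketch of that computation, with illustrative names and warps_per_block assumed to be kBlockSize / 32:

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def launch_grid_size(local_num_tokens: int, one_block_per_token: bool, warps_per_block: int) -> int:
    grid = local_num_tokens if one_block_per_token else ceil_div(local_num_tokens, warps_per_block)
    # Never skip the launch: even with zero tokens the kernel must join the synchronization.
    return max(grid, 1)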

View File

@ -186,7 +186,6 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
MoeA2ADataOffsets const& offsets = *reinterpret_cast<MoeA2ADataOffsets const*>(metainfo.data_ptr<int64_t>());
int64_t localNumTokens = tokenSelectedExperts.size(0);
TORCH_CHECK(localNumTokens > 0, "localNumTokens must be positive");
TORCH_CHECK(runtimeMaxTokensPerRank > 0, "runtimeMaxTokensPerRank must be positive");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]");

View File

@ -1,29 +1,29 @@
from typing import Callable, List, Optional
from typing import Callable, List
import torch
import torch.nn.functional as F
from tensorrt_llm._torch.utils import ActivationType
def _resolve_activation(name: Optional[str]) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Returns an elementwise activation callable matching the given name.
Supported: "silu", "relu2".
Defaults to SiLU when name is None or empty.
"""
if not name:
name = "silu"
key = name.lower()
if key == "silu":
return F.silu
elif key == "relu2":
def _resolve_torch_fn(act_fn: ActivationType) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Returns an elementwise activation callable matching the given activation function.
Supported: ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2
"""
assert act_fn in [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2], (
f"Unsupported activation '{ActivationType(act_fn).name}'. Use 'silu', 'swiglu' or 'relu2'."
)
torch_fn = None
if act_fn == ActivationType.Silu or act_fn == ActivationType.Swiglu:
torch_fn = F.silu
elif act_fn == ActivationType.Relu2:
def relu2(x: torch.Tensor) -> torch.Tensor:
return torch.square(F.relu(x))
return relu2
else:
raise ValueError(f"Unsupported activation '{name}'. Use one of: silu, relu2.")
torch_fn = relu2
return torch_fn
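For reference, a minimal usage sketch of the resolver above; it assumes the `_resolve_torch_fn` helper and `ActivationType` import defined in this file.

import torch

# Illustrative only: Relu2 resolves to x -> square(relu(x)).
torch_act_fn = _resolve_torch_fn(ActivationType.Relu2)
y = torch_act_fn(torch.randn(4, 8))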
def _template_moe(
@ -94,8 +94,8 @@ def torch_moe(
w1_weight: List[torch.Tensor],
w2_weight: List[torch.Tensor],
w3_weight: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
apply_routing_on_input: bool = False,
) -> torch.Tensor:
"""
@ -117,8 +117,8 @@ def torch_moe(
- Llama4 MoE: sigmoid activated weights
w1_weight:
For per-expert lists:
mlp_style=="gated_mlp": List of W1 with shape (I, H) "gate" projection.
mlp_style=="mlp": List of W_up with shape (I, H) up projection.
is_gated_mlp==True: List of W1 with shape (I, H) "gate" projection.
is_gated_mlp==False: List of W_up with shape (I, H) up projection.
For stacked tensors (Llama4):
Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format
w2_weight:
@ -129,17 +129,17 @@ def torch_moe(
w3_weight:
For per-expert lists with gated_mlp:
List of W3 with shape (I, H) "up" (second) projection in gated MLP.
For mlp style or stacked tensors:
For is_gated_mlp==False or stacked tensors:
pass an empty list []; ignored.
mlp_style:
is_gated_mlp:
Selects the per-expert MLP computation:
"gated_mlp" (default, Mixtral/DeepSeek/Llama4-style):
is_gated_mlp==True (default, Mixtral/DeepSeek/Llama4-style):
y = W2( act(W1 x) * (W3 x) )
"mlp" (NemotronH-style 2-layer MLP):
is_gated_mlp==False (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
apply_routing_on_input:
If True (Llama4 pattern): multiply routing weights with INPUT before MLP
Result: act(input * routing_weight) - routing affects activation
@ -148,55 +148,63 @@ def torch_moe(
Returns:
torch.Tensor: Output tensor with the same shape as the input x.
"""
act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()
torch_act_fn = _resolve_torch_fn(act_fn)
# Detect if using stacked tensor format (Llama4) vs per-expert lists (standard)
is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3
# TODO: either change torch_moe to use a single condition, or refactor this code
# so the branching reads:
#   if is_gated_mlp:
#       if is_stacked: ...
#       else: ...
#   else:
#       assert not is_stacked
#       ...
if is_stacked:
# Llama4 stacked tensor format - only supports gated_mlp
if style != "gated_mlp":
raise ValueError("Stacked tensor format only supports 'gated_mlp' style")
if not is_gated_mlp:
raise ValueError("Stacked tensor format only supports gated MLP style")
w3_w1_stacked = w1_weight[0] # (E, 2*I, H)
intermediate_size = w3_w1_stacked.shape[1] // 2
w2_stacked = w2_weight[0] # (E, H, I)
def make_mlp(i: int):
gate_up = w3_w1_stacked[i] # (2*I, H)
intermediate_size = gate_up.shape[0] // 2
def make_mlp(idx: int):
gate_up = w3_w1_stacked[idx] # (2*I, H)
W3 = gate_up[:intermediate_size, :] # (I, H)
W1 = gate_up[intermediate_size:, :] # (I, H)
W2 = w2_stacked[i] # (H, I)
W2 = w2_stacked[idx] # (H, I)
weight_dtype = W1.dtype
return lambda inp: F.linear(
act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3),
torch_act_fn(F.linear(inp.to(weight_dtype), W1))
* F.linear(inp.to(weight_dtype), W3),
W2,
)
mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])]
mlps = [make_mlp(idx) for idx in range(w3_w1_stacked.shape[0])]
elif style == "gated_mlp":
elif is_gated_mlp:
# Standard per-expert list format with gated MLP
def make_mlp(i: int):
W1 = w1_weight[i] # (I, H)
W2 = w2_weight[i] # (H, I)
W3 = w3_weight[i] # (I, H)
return lambda inp: F.linear(act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2)
mlps = [make_mlp(i) for i in range(len(w1_weight))]
elif style == "mlp":
# Standard per-expert list format with simple MLP
def make_mlp(i: int):
W_up = w1_weight[i] # (I, H)
W_down = w2_weight[i] # (H, I)
return lambda inp: F.linear(act_fn(F.linear(inp, W_up)), W_down)
return lambda inp: F.linear(torch_act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2)
mlps = [make_mlp(i) for i in range(len(w1_weight))]
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
# Standard per-expert list format with simple MLP
def make_mlp(i: int):
W_up = w1_weight[i] # (I, H)
W_down = w2_weight[i] # (H, I)
return lambda inp: F.linear(torch_act_fn(F.linear(inp, W_up)), W_down)
mlps = [make_mlp(i) for i in range(len(w1_weight))]
return _template_moe(x, selected_experts, routing_weights, mlps, apply_routing_on_input)
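A hedged, plain-torch sketch of the two per-expert MLP variants that `is_gated_mlp` selects in the docstring above. The function names are illustrative; weights follow the (I, H) / (H, I) shapes documented there.

import torch
import torch.nn.functional as F

def gated_expert(x, W1, W2, W3, act=F.silu):
    # y = W2( act(W1 x) * (W3 x) )  -- Mixtral/DeepSeek/Llama4-style gated MLP
    return F.linear(act(F.linear(x, W1)) * F.linear(x, W3), W2)

def plain_expert(x, W_up, W_down):
    # y = W_down( relu2(W_up x) )  -- NemotronH-style 2-layer MLP
    return F.linear(torch.square(F.relu(F.linear(x, W_up))), W_down)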
@ -209,8 +217,8 @@ def torch_moe_fake(
w1_weight: List[torch.Tensor],
w2_weight: List[torch.Tensor],
w3_weight: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
apply_routing_on_input: bool = False,
) -> torch.Tensor:
return torch.empty_like(x)
@ -296,23 +304,20 @@ def torch_quant_fp8_moe(
w1_weight_scale: List[torch.Tensor],
w2_weight_scale: List[torch.Tensor],
w3_weight_scale: List[torch.Tensor],
mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp"
act_fn: str = "silu", # silu or relu2
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""
FP8 MoE op using quantized linear operations.
Computes a Mixture-of-Experts layer similar to the reference auto_deploy::torch_moe op, but uses the
quantized FP8 linear op for expert computations.
FP8 MoE op using quantized linear operations. Computes a Mixture-of-Experts layer similar to the reference
auto_deploy::torch_moe op, but uses the quantized FP8 linear op for expert computations.
Args:
x: Input tensor of shape (B, H) or (B, S, H).
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
routing_weights: Tensor of normalized routing weights.
w1_weight:
List of per-expert weight tensors:
mlp_style=="gated_mlp": W1 with shape (I, H) "gate" projection.
mlp_style=="mlp": W_up with shape (I, H) up projection.
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
routing_weights: Tensor of normalized routing weights.
w1_weight: List of per-expert weight tensors:
is_gated_mlp==True: W1 with shape (I, H) "gate" projection.
is_gated_mlp==False: W_up with shape (I, H) up projection.
w2_weight:
List of per-expert weight tensors:
gated_mlp: W2 with shape (H, I) down projection.
@ -323,21 +328,20 @@ def torch_quant_fp8_moe(
mlp: pass an empty list []; ignored.
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops.
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops.
mlp_style:
is_gated_mlp:
Selects the per-expert MLP computation:
"gated_mlp" (default, Mixtral/DeepSeek-style):
is_gated_mlp==True (default, Mixtral/DeepSeek-style):
y = W2( act(W1 x) * (W3 x) )
"mlp" (NemotronH-style 2-layer MLP):
is_gated_mlp==False (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
"""
act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()
torch_act_fn = _resolve_torch_fn(act_fn)
if style == "gated_mlp":
if is_gated_mlp:
def make_fp8_mlp(i):
def mlp(inp):
@ -355,7 +359,7 @@ def torch_quant_fp8_moe(
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
)
prod = act_fn(gate_out) * up_out
prod = torch_act_fn(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_fp8_linear(
prod,
w2_weight[i],
@ -368,7 +372,7 @@ def torch_quant_fp8_moe(
mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
elif style == "mlp":
else:
def make_fp8_mlp(i):
def mlp(inp):
@ -380,7 +384,7 @@ def torch_quant_fp8_moe(
weight_scale=w1_weight_scale[i],
)
return torch.ops.auto_deploy.torch_quant_fp8_linear(
act_fn(up_out),
torch_act_fn(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
@ -391,9 +395,6 @@ def torch_quant_fp8_moe(
mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
return _template_moe(x, selected_experts, routing_weights, mlps)
@ -411,8 +412,8 @@ def torch_quant_fp8_moe_fake(
w1_weight_scale: List[torch.Tensor],
w2_weight_scale: List[torch.Tensor],
w3_weight_scale: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)
@ -434,8 +435,8 @@ def torch_quant_nvfp4_moe(
w1_alpha: List[torch.Tensor],
w2_alpha: List[torch.Tensor],
w3_alpha: List[torch.Tensor],
mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp"
act_fn: str = "silu", # silu or relu2
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""
FP4 MoE op using quantized linear operations.
@ -449,8 +450,8 @@ def torch_quant_nvfp4_moe(
routing_weights: Tensor of normalized routing weights.
w1_weight:
List of per-expert weight tensors:
mlp_style=="gated_mlp": W1 with shape (I, H) "gate" projection.
mlp_style=="mlp": W_up with shape (I, H) up projection.
is_gated_mlp==True: W1 with shape (I, H) "gate" projection.
is_gated_mlp==False: W_up with shape (I, H) up projection.
w2_weight:
List of per-expert weight tensors:
gated_mlp: W2 with shape (H, I) down projection.
@ -462,21 +463,20 @@ def torch_quant_nvfp4_moe(
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors.
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors.
w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization.
mlp_style:
is_gated_mlp:
Selects the per-expert MLP computation:
"gated_mlp" (default, Mixtral/DeepSeek-style):
is_gated_mlp==True (default, Mixtral/DeepSeek-style):
y = W2( act(W1 x) * (W3 x) )
"mlp" (NemotronH-style 2-layer MLP):
is_gated_mlp==False (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
"""
act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()
torch_act_fn = _resolve_torch_fn(act_fn)
if style == "gated_mlp":
if is_gated_mlp:
def make_fp4_mlp(i):
def mlp(inp):
@ -498,7 +498,7 @@ def torch_quant_nvfp4_moe(
weight_scale=w3_weight_scale[i],
alpha=w3_alpha[i],
)
prod = act_fn(gate_out) * up_out
prod = torch_act_fn(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
prod,
w2_weight[i],
@ -512,7 +512,7 @@ def torch_quant_nvfp4_moe(
mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
elif style == "mlp":
else:
def make_fp4_mlp(i):
def mlp(inp):
@ -527,7 +527,7 @@ def torch_quant_nvfp4_moe(
alpha=w1_alpha[i],
)
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
act_fn(up_out),
torch_act_fn(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
@ -539,9 +539,6 @@ def torch_quant_nvfp4_moe(
mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
return _template_moe(x, selected_experts, routing_weights, mlps)
@ -562,8 +559,8 @@ def torch_quant_nvfp4_moe_fake(
w1_alpha: List[torch.Tensor],
w2_alpha: List[torch.Tensor],
w3_alpha: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)

View File

@ -14,6 +14,8 @@ import torch.nn.functional as F
import triton
import triton.language as tl
from tensorrt_llm._torch.utils import ActivationType # noqa: F401
from ...utils.logger import ad_logger
@ -601,15 +603,13 @@ def triton_fused_moe(
routing_weights: torch.Tensor,
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "mlp",
act_fn: str = "relu2",
is_gated_mlp: bool = False,
act_fn: int = int(ActivationType.Relu2),
) -> torch.Tensor:
"""Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()
assert mlp_style == "mlp", "Triton backend only supports mlp style."
assert act_fn == "relu2", "Triton backend only supports relu2 activation."
assert not is_gated_mlp, "Triton backend only supports non-gated MLP style."
assert act_fn == ActivationType.Relu2, "Triton backend only supports relu2 activation."
x_shape = x.shape
x2d = x.view(-1, x_shape[-1])
@ -661,12 +661,12 @@ def triton_quant_fp8_moe(
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
w3_weight_scale: torch.Tensor, # unused
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = False,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation."""
if mlp_style != "mlp":
raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only")
if is_gated_mlp:
raise NotImplementedError("triton_quant_fp8_moe currently supports non-gated MLP only")
x_shape = x.shape
x2d = x.view(-1, x_shape[-1])
@ -760,7 +760,7 @@ def triton_quant_fp8_moe(
w1_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = False,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)

View File

@ -26,8 +26,8 @@ def trtllm_moe_fused(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
x_shape = x.shape
x = x.view(-1, x_shape[-1])
@ -37,24 +37,24 @@ def trtllm_moe_fused(
quant_scales = []
# Determine activation type
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()
activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
if is_gated_mlp:
# Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
if act_fn == "silu":
if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
activation_type = ActivationType.Swiglu
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
)
else:
# For non-gated MLP with ReLU^2
if act_fn == "relu2":
if act_fn == ActivationType.Relu2:
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
)
return torch.ops.trtllm.fused_moe(
x,
@ -77,8 +77,8 @@ def trtllm_moe_fused_fake(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)
@ -93,21 +93,12 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)
def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None:
supported_combinations = {
"gated_mlp": ["silu"],
"mlp": ["relu2"],
}
supported_act_fns = [
act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list
]
assert mlp_style in supported_combinations.keys(), (
f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}."
)
assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}."
assert act_fn in supported_combinations[mlp_style], (
f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. "
f"Supported combinations: {supported_combinations}"
def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None:
assert (is_gated_mlp and act_fn == ActivationType.Silu) or (
not is_gated_mlp and act_fn == ActivationType.Relu2
), (
f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. "
"Supported combinations: gated MLP with silu, or non-gated MLP with relu2."
)
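The combinations accepted by the check above correspond to the fused kernel's ActivationType. The sketch below is a hedged restatement of the mapping that the launchers further down perform inline; it only assumes the ActivationType enum already imported in this file.

from tensorrt_llm._torch.utils import ActivationType

def resolve_activation_type(is_gated_mlp: bool, act_fn: ActivationType) -> ActivationType:
    if is_gated_mlp:
        # Gated MLP: silu(x @ w1.T) * (x @ w3.T), lowered to Swiglu.
        if act_fn in (ActivationType.Silu, ActivationType.Swiglu):
            return ActivationType.Swiglu
        raise ValueError("gated MLP supports only silu/swiglu")
    # Non-gated MLP: relu(x @ w1.T) ** 2, lowered to Relu2.
    if act_fn == ActivationType.Relu2:
        return ActivationType.Relu2
    raise ValueError("non-gated MLP supports only relu2")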
@ -128,8 +119,8 @@ def trtllm_quant_fp8_moe_fused(
gemm1_dequant: torch.Tensor, # [E]
gemm2_act_quant: torch.Tensor, # [E]
gemm2_dequant: torch.Tensor, # [E]
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""
TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP.
@ -149,8 +140,8 @@ def trtllm_quant_fp8_moe_fused(
gemm1_dequant: Precomputed gemm1 dequant scale [E]
gemm2_act_quant: Precomputed gemm2 act quant scale [1]
gemm2_dequant: Precomputed gemm2 dequant scale [E]
mlp_style: "gated_mlp" or "mlp"
act_fn: "silu" for gated_mlp, "relu2" for mlp
is_gated_mlp: True for gated_mlp, False for mlp
act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp
Non-Gated MLP:
activation_fn(expert_inputs @ w1_expert.t())@ w2_expert.t()
@ -159,7 +150,7 @@ def trtllm_quant_fp8_moe_fused(
activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t()
"""
_validate_mlp_style_and_act_fn(mlp_style, act_fn)
_validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
# Store original shape and flatten to 2D
x_shape = x.shape
@ -190,28 +181,27 @@ def trtllm_quant_fp8_moe_fused(
# Todo: refactor this repeating code block
# Determine activation type
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()
activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
if is_gated_mlp:
# Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
# For gated MLP, concatenate w1 and w3 as [w3, w1]
w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous() # [E, 2*I, H]
fc1_expert_weights = w3_w1_stacked
if act_fn == "silu":
if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
activation_type = ActivationType.Swiglu
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
)
else:
# For non-gated MLP with ReLU^2
fc1_expert_weights = w1_weight.contiguous()
if act_fn == "relu2":
if act_fn == ActivationType.Relu2:
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
)
# Note! Outputting Float8_e4m3fn directly is not currently supported
output = torch.ops.trtllm.fused_moe(
@ -248,10 +238,10 @@ def trtllm_quant_fp8_moe_fused_fake(
gemm1_dequant: torch.Tensor,
gemm2_act_quant: torch.Tensor,
gemm2_dequant: torch.Tensor,
mlp_style: str,
act_fn: str,
is_gated_mlp: bool,
act_fn: int,
) -> torch.Tensor:
_validate_mlp_style_and_act_fn(mlp_style, act_fn)
_validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
return torch.empty_like(x)
@ -268,8 +258,8 @@ def trtllm_quant_nvfp4_moe_fused(
fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations
fc1_alpha: torch.Tensor, # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8))
fc2_alpha: torch.Tensor, # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8))
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP.
@ -285,22 +275,22 @@ def trtllm_quant_nvfp4_moe_fused(
"""
NVFP4_BLOCK_SIZE = 16
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()
activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
if act_fn == "silu":
if is_gated_mlp:
if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
activation_type = ActivationType.Swiglu
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
if act_fn == "relu2":
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
)
else:
if act_fn == ActivationType.Relu2:
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
)
# quant_scales is described by this code:
# https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
@ -353,7 +343,7 @@ def trtllm_quant_nvfp4_moe_fused_fake(
fc2_act_global_scale: torch.Tensor,
fc1_alpha: torch.Tensor,
fc2_alpha: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)

View File

@ -13,6 +13,7 @@ from transformers.modeling_outputs import CausalLMOutput, MoeModelOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding
from tensorrt_llm._torch.utils import ActivationType
from tensorrt_llm.inputs.utils import HF_CHAT_TEMPLATE_EXCEPTIONS
from ..nemotron_flash import NemotronFlashForCausalLMFactory
@ -182,6 +183,8 @@ class DeltaNet(nn.Module):
self.qk_activation = qk_activation
self.qk_norm = qk_norm
# can't use ActivationType enum here,
# because there is no Elu defined in cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
assert self.qk_activation in ["silu", "relu", "elu", "identity"]
assert self.qk_norm in ["l2", "sum"]
@ -331,7 +334,7 @@ class NemotronFlashMamba2(nn.Module):
self.num_heads = self.d_inner // self.headdim
self.rmsnorm = rmsnorm
self.dt_limit = dt_limit
self.activation = "silu"
self.activation = ActivationType.Silu
self.chunk_size = chunk_size
self.layer_idx = layer_idx

View File

@ -33,6 +33,7 @@ from transformers.utils import ModelOutput
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
from tensorrt_llm._torch.utils import ActivationType
class MambaRMSNormGated(torch.nn.Module):
@ -308,8 +309,8 @@ class NemotronHMOE(nn.Module):
w1_weight=[e.up_proj.weight for e in self.experts],
w2_weight=[e.down_proj.weight for e in self.experts],
w3_weight=[],
act_fn="relu2",
mlp_style="mlp",
act_fn=ActivationType.Relu2,
is_gated_mlp=False,
)
if has_latent_proj:

View File

@ -5,6 +5,8 @@ import torch
from pydantic import Field
from torch.fx import GraphModule, Node
from tensorrt_llm._torch.utils import ActivationType
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.cuda_mem_tracker import cuda_memory_tracker
@ -70,20 +72,20 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t
except (AttributeError, KeyError):
pass
(is_gated_mlp, act_fn) = extract_op_args(node, "is_gated_mlp", "act_fn")
if is_stacked_moe:
# Stacked MoE (Llama4 pattern): only supports gated MLP
(act_fn_val,) = extract_op_args(node, "act_fn")
_process_llama4_stacked_moe_node(
gm, graph, node, replacement_op, act_fn_val, fused_key_counter
gm, graph, node, replacement_op, act_fn, fused_key_counter
)
else:
# Standard MoE with per-expert weight lists
(mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn")
assert backend != "triton" or mlp_style_val == "mlp", (
assert backend != "triton" or not is_gated_mlp, (
"Triton backend only supports the non-gated MLP style."
)
_process_regular_moe_node(
gm, graph, node, replacement_op, mlp_style_val, act_fn_val, fused_key_counter
gm, graph, node, replacement_op, is_gated_mlp, act_fn, fused_key_counter
)
fused_key_counter += 1
@ -102,8 +104,8 @@ def _process_regular_moe_node(
graph: torch.fx.Graph,
node: Node,
replacement_op,
mlp_style_val: str,
act_fn_val: str,
is_gated_mlp: bool,
act_fn: ActivationType,
fused_key_counter: int,
) -> None:
"""Process a single torch_moe node with per-expert weight lists.
@ -122,7 +124,7 @@ def _process_regular_moe_node(
)
# Stack weights based on MLP style
if mlp_style_val == "gated_mlp":
if is_gated_mlp:
# For gated MLP, concatenate w3 and w1 then stack across experts
fused_w_up_experts = torch.stack(
[
@ -135,12 +137,10 @@ def _process_regular_moe_node(
dim=0,
)
new_key_w_up = f"fused_moe_w3_w1_stacked_{fused_key_counter}"
elif mlp_style_val == "mlp":
else:
# For regular MLP, just stack w1
fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}"
else:
raise ValueError(f"Unknown mlp_style: {mlp_style_val}")
# Stack w2/down weights
fused_w_down_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0)
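A hedged sketch of the stacking performed above, assuming plain per-expert tensors w1[i] (I, H), w2[i] (H, I), and w3[i] (I, H) in place of the graph parameters the transform actually reads:

import torch

def stack_expert_weights(w1, w2, w3, is_gated_mlp: bool):
    if is_gated_mlp:
        # Per expert, concatenate [w3; w1] along the output dim, then stack -> (E, 2*I, H).
        w_up = torch.stack([torch.cat([w3_i, w1_i], dim=0) for w3_i, w1_i in zip(w3, w1)], dim=0)
    else:
        # Plain MLP: just stack w1 -> (E, I, H).
        w_up = torch.stack(list(w1), dim=0)
    w_down = torch.stack(list(w2), dim=0)  # (E, H, I)
    return w_up, w_down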
@ -162,8 +162,8 @@ def _process_regular_moe_node(
replacement_op,
args=(hidden_states, selected_experts, routing_weights, w_up_arg, w_down_arg),
kwargs={
"mlp_style": mlp_style_val,
"act_fn": act_fn_val,
"is_gated_mlp": is_gated_mlp,
"act_fn": act_fn,
},
)
@ -176,7 +176,7 @@ def _process_llama4_stacked_moe_node(
graph: torch.fx.Graph,
node: Node,
replacement_op,
act_fn_val: str,
act_fn: ActivationType,
fused_key_counter: int,
) -> None:
"""Process a single Llama4 MoE node with pre-stacked weight tensors.
@ -301,7 +301,8 @@ def _process_llama4_stacked_moe_node(
replacement_op,
args=(scaled_input, selected_experts, ones_node, w_up_arg, w_down_arg),
kwargs={
"act_fn": act_fn_val,
"act_fn": act_fn,
"is_gated_mlp": True,
},
)
@ -1240,7 +1241,7 @@ class MatchBmmMoePattern(BaseTransform):
w3_list_node,
),
kwargs={
"mlp_style": "gated_mlp",
"is_gated_mlp": True,
"apply_routing_on_input": apply_routing_on_input,
},
)

View File

@ -6,6 +6,8 @@ from typing import Any, Callable, Dict, List, Tuple
import torch
from torch.fx import GraphModule
from tensorrt_llm._torch.utils import ActivationType
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
@ -123,8 +125,8 @@ def trtllm_moe_fused_aux(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
device = torch.cuda.current_device()
with torch.cuda.stream(
@ -137,7 +139,7 @@ def trtllm_moe_fused_aux(
routing_weights,
w3_w1_stacked_weight,
w2_stacked_weight,
mlp_style,
is_gated_mlp,
act_fn,
)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
@ -152,8 +154,8 @@ def trtllm_moe_fused_aux_fake(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)
@ -213,8 +215,8 @@ def trtllm_quant_fp8_moe_fused_aux(
gemm1_dequant: torch.Tensor, # [E]
gemm2_act_quant: torch.Tensor, # [E]
gemm2_dequant: torch.Tensor, # [E]
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
device = torch.cuda.current_device()
with torch.cuda.stream(
@ -237,7 +239,7 @@ def trtllm_quant_fp8_moe_fused_aux(
gemm1_dequant,
gemm2_act_quant,
gemm2_dequant,
mlp_style,
is_gated_mlp,
act_fn,
)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
@ -262,8 +264,8 @@ def trtllm_quant_fp8_moe_fused_aux_fake(
gemm1_dequant: torch.Tensor,
gemm2_act_quant: torch.Tensor,
gemm2_dequant: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)

View File

@ -5,6 +5,8 @@ import torch
import torch.nn as nn
from torch.fx import GraphModule, Node
from tensorrt_llm._torch.utils import ActivationType
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import is_op
@ -87,15 +89,15 @@ def _quantize_moe_node(
s1, s2, s3 = collect_scales(idx)
args.extend([s1, s2, s3])
# Extract mlp_style and act_fn from the original node
# Extract is_gated_mlp and act_fn from the original node
# These can be in args[6:] or in kwargs
mlp_style = "gated_mlp" # default
act_fn = "silu" # default
is_gated_mlp = True # default
act_fn = ActivationType.Silu # default
if len(node.args) > 6:
mlp_style = node.args[6]
elif "mlp_style" in node.kwargs:
mlp_style = node.kwargs["mlp_style"]
is_gated_mlp = node.args[6]
elif "is_gated_mlp" in node.kwargs:
is_gated_mlp = node.kwargs["is_gated_mlp"]
if len(node.args) > 7:
act_fn = node.args[7]
@ -104,7 +106,7 @@ def _quantize_moe_node(
# Prepare kwargs for the quantized op
kwargs = {
"mlp_style": mlp_style,
"is_gated_mlp": is_gated_mlp,
"act_fn": act_fn,
}

View File

@ -527,6 +527,8 @@ class LlavaNextModel(PreTrainedModel):
return
if not DISAGG:
self.mm_encoder = LlavaNextVisionModel(model_config)
else:
self.mm_encoder = None
llm_model_config = copy.deepcopy(model_config)
llm_model_config.pretrained_config = model_config.pretrained_config.text_config
@ -545,7 +547,8 @@ class LlavaNextModel(PreTrainedModel):
if isinstance(weight_mapper, LlavaNextHfWeightMapper):
weights = weight_mapper.preprocess_weights(weights)
self.mm_encoder.load_weights(weights)
if self.mm_encoder is not None:
self.mm_encoder.load_weights(weights)
def filter_weights(weights: Dict):
transformed_weights = {}

View File

@ -32,7 +32,8 @@ from ...inputs import (BaseMultimodalDummyInputsBuilder,
BaseMultimodalInputProcessor, ExtraProcessedInputs,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
register_input_processor,
support_multimodal_disaggregated)
from ...logger import logger
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
@ -865,6 +866,8 @@ class Qwen2VLModelBase(PreTrainedModel):
mm_encoder_config = copy.deepcopy(model_config)
self.mm_encoder = Qwen2VisionModelBase(
mm_encoder_config, kwargs.get('vision_model_class', None))
else:
self.mm_encoder = None
def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
config = model_config.pretrained_config
@ -953,24 +956,21 @@ class Qwen2VLModelBase(PreTrainedModel):
"""
VLM forward logic with inflight batching support.
"""
num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
num_context_requests = attn_metadata.num_contexts
multimodal_params = kwargs.get("multimodal_params", [])
mm_embeds = []
mrope_config = {}
# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate the mm_multimodal_params from the text-only prompts.
mm_multimodal_params = [
multimodal_param for multimodal_param in multimodal_params
if multimodal_param.multimodal_data.get("image", {}).get(
"pixel_values") is not None or multimodal_param.multimodal_data.
get("video", {}).get("pixel_values_videos") is not None
]
# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate
# the entries that do have multimodal data from those that correspond to text-only prompts.
mm_multimodal_params = self._get_requests_with_mm_data(
multimodal_params)
if len(mm_multimodal_params) > 0:
if not _is_disagg():
mm_embeds = get_multimodal_embeddings(
encoder_forward_fn=self.mm_encoder.forward,
multimodal_params=mm_multimodal_params)
else:
elif not getattr(self, "support_mm_disagg", False):
raise NotImplementedError(
"Qwen2VLModel does not support disaggregated inference yet. Please unset "
f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
@ -995,6 +995,21 @@ class Qwen2VLModelBase(PreTrainedModel):
logger.debug(f'output shape: {output_prob.shape}')
return output_prob
def _get_requests_with_mm_data(self, multimodal_params):
mm_multimodal_params = []
for multimodal_param in multimodal_params:
data = multimodal_param.multimodal_data
if (
# The first 2 conditions check whether there is input on which inference should be run.
data.get("image", {}).get("pixel_values") is not None or
data.get("video", {}).get("pixel_values_videos") is not None
# This condition corresponds to when the embeddings are already populated, as is e.g.
# the case in EPD disagg in the prefill worker.
or data.get("multimodal_embedding")):
mm_multimodal_params.append(multimodal_param)
return mm_multimodal_params
@register_vision_encoder(Qwen2VisionModelBase,
vlm_base_model=Qwen2VisionTransformerPretrainedModel)
@ -1032,11 +1047,89 @@ class Qwen2VLModel(Qwen2VLModelBase):
self.llm.load_weights(weights, weight_mapper)
class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):
def get_prompt_token_ids(
self, inputs: TextPrompt,
mm_handles: List[Dict[str,
Any]]) -> Tuple[List[int], List[int], List[int]]:
"""
Build input token ids with multimodal placeholders expanded to the number of MM tokens.
Args:
inputs: Text prompt input container. Must contain a non-empty prompt string.
mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
Returns:
Tuple[List[int], List[int], List[int]]:
- expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token
- mm_token_length: per-image MM token lengths
- mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids
"""
# TODO: Move this function to the base input processor class when extending for more models
text_prompt = inputs.get("prompt")
if not text_prompt:
raise ValueError("Text prompt is required but not provided")
if not isinstance(mm_handles, list):
raise TypeError("mm_handles must be a list")
if len(mm_handles) != 1:
# TODO: only a single multimodal item per request is supported for now
raise NotImplementedError(
"Only one mm_handle is supported for Qwen2.5 VL for now")
hidden_size = mm_handles[0]['tensor_size'][1]
assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
input_ids = self.tokenizer(text_prompt,
return_tensors="pt").input_ids[0]
image_token_index = self.config.image_token_id
image_mask = input_ids == image_token_index
image_positions = torch.where(image_mask)[0]
num_images = len(image_positions)
assert num_images == len(
mm_handles), "Number of images must match number of mm_handles"
total_mm_tokens = sum(mm_handle["tensor_size"][0]
for mm_handle in mm_handles)
final_length = len(input_ids) - num_images + total_mm_tokens
# Create output tensor
expanded_ids = torch.empty(final_length, dtype=input_ids.dtype)
placeholder_id = self.tllm_multimodal_token_id
# Fill the expanded sequence
write_pos = 0
image_cnt = 0
mm_token_length = []
mm_token_offsets = []
for read_pos in range(len(input_ids)):
if input_ids[read_pos] == image_token_index:
# Replace with placeholder id
mm_token_num = mm_handles[image_cnt]["tensor_size"][0]
expanded_ids[write_pos:write_pos + mm_token_num] = \
placeholder_id
mm_token_offsets.append(write_pos)
mm_token_length.append(mm_token_num)
write_pos += mm_token_num
image_cnt += 1
else:
# Copy text token as-is
expanded_ids[write_pos] = input_ids[read_pos]
write_pos += 1
assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}"
assert mm_token_length[-1] + mm_token_offsets[
-1] <= final_length, f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less than or equal to final_length ({final_length})"
return expanded_ids.to(
torch.int32).tolist(), mm_token_length, mm_token_offsets
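A hypothetical walk-through of the expansion implemented above; the token ids are made up (777 stands for image_token_index, 999 for tllm_multimodal_token_id):

input_ids = [11, 777, 22]
mm_token_nums = [3]                       # per-image tensor_size[0] from mm_handles
expanded_ids, mm_token_length, mm_token_offsets = [], [], []
image_cnt = 0
for tok in input_ids:
    if tok == 777:
        mm_token_offsets.append(len(expanded_ids))
        mm_token_length.append(mm_token_nums[image_cnt])
        expanded_ids.extend([999] * mm_token_nums[image_cnt])
        image_cnt += 1
    else:
        expanded_ids.append(tok)
assert expanded_ids == [11, 999, 999, 999, 22]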
@support_multimodal_disaggregated
@register_vision_encoder(Qwen2VisionModelBase,
vlm_base_model=Qwen2_5_VisionModel)
@register_auto_model("Qwen2_5_VLForConditionalGeneration")
@register_input_processor(
Qwen2VLInputProcessorBase,
Qwen2_5VLInputProcessorBase,
model_type="qwen2_5_vl",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={

View File

@ -262,6 +262,8 @@ class PyResult:
chunk_size=self._chunk_size) if return_generation_logits else None
self._log_probs = LogProbStorage() if return_log_probs else None
self._mm_embeddings = None
self._mrope_position_ids = None
self._mrope_position_deltas = None
self._additional_context_outputs = {
name: []
for name in additional_outputs
@ -293,6 +295,16 @@ class PyResult:
self._mm_embeddings = SharedTensorContainer.from_tensor(
mm_embeddings).dump_to_dict()
def set_mrope_position(
self,
mrope_position_ids: torch.Tensor,
mrope_position_deltas: torch.Tensor,
):
self._mrope_position_ids = (SharedTensorContainer.from_tensor(
mrope_position_ids).dump_to_dict())
self._mrope_position_deltas = (SharedTensorContainer.from_tensor(
mrope_position_deltas).dump_to_dict())
def transfer_remaining_device_logits(self):
"""Finalize any remaining generation logits transfers (for chunked mode)"""
if self._generation_logits:
@ -352,6 +364,18 @@ class PyResult:
def mm_embedding_handle(self) -> Dict[str, Any] | None:
return self._mm_embeddings
@property
def mrope_position_ids_handle(self) -> Dict[str, Any] | None:
# NOTE: when populated, the returned `dict` contains the information necessary to rebuild
# the `SharedTensorContainer` using the `from_dict` class method.
return self._mrope_position_ids
@property
def mrope_position_deltas_handle(self) -> Dict[str, Any] | None:
# NOTE: when populated, the returned `dict` contains the information necessary to rebuild
# the `SharedTensorContainer` using the `from_dict` class method.
return self._mrope_position_deltas
@property
def additional_context_outputs(self) -> Dict[str, torch.Tensor] | None:
if self._additional_context_outputs is None:
@ -382,7 +406,8 @@ class LlmResult:
py_result_properties = frozenset(
('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
'mm_embedding_handle', 'additional_context_outputs',
'additional_generation_outputs'))
'additional_generation_outputs', 'mrope_position_ids_handle',
'mrope_position_deltas_handle'))
def __init__(self,
result: Union[bytes, tensorrt_llm.bindings.executor.Result],

View File

@ -2213,13 +2213,14 @@ class PyTorchModelEngine(ModelEngine):
mrope_position_deltas).expand(
3, 1, 1)
mrope_position_ids.append(gen_mrope_position_ids)
multimodal_params.to_device(
"multimodal_data",
"cuda",
pin_memory=True,
target_keywords=[
"mrope_config.mrope_position_deltas"
])
if mrope_position_deltas.device.type == "cpu":
multimodal_params.to_device(
"multimodal_data",
"cuda",
pin_memory=True,
target_keywords=[
"mrope_config.mrope_position_deltas"
])
multimodal_params_list.append(multimodal_params)
request.py_batch_idx = request.py_seq_slot
@ -2448,8 +2449,9 @@ class PyTorchModelEngine(ModelEngine):
# NOTE: self.use_mrope is enough to decide whether to use mrope_position_ids, but
# `_create_dummy_context_requests` from `kv_cache_creater` cannot attach multimodal_data to the dummy request,
# so we only replace position_ids with mrope_position_ids for non-dummy requests on models that use mrope.
mrope_position_ids = torch.cat(mrope_position_ids,
dim=-1).pin_memory()
mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
if mrope_position_ids.device.type == "cpu":
mrope_position_ids = mrope_position_ids.pin_memory()
self.mrope_position_ids_cuda[:, :, :total_num_tokens].copy_(
mrope_position_ids[:, :, :total_num_tokens], non_blocking=True)
final_position_ids = self.mrope_position_ids_cuda[:, :, :
@ -3362,7 +3364,26 @@ class PyTorchModelEngine(ModelEngine):
mm_embeddings = list(
torch.split(mm_embeddings[0], multimodal_chunks, dim=0))
return {'mm_embeddings': mm_embeddings, 'logits': None}
# Extract mrope position data from multimodal_params if available
mrope_position_ids_list = []
mrope_position_deltas_list = []
for multimodal_param in multimodal_params:
mrope_config = multimodal_param.multimodal_data.get(
'mrope_config', {})
mrope_position_ids = mrope_config.get('mrope_position_ids')
mrope_position_deltas = mrope_config.get('mrope_position_deltas')
if mrope_position_ids is not None:
mrope_position_ids_list.append(mrope_position_ids)
if mrope_position_deltas is not None:
mrope_position_deltas_list.append(mrope_position_deltas)
result = {'mm_embeddings': mm_embeddings, 'logits': None}
if mrope_position_ids_list:
result['mrope_position_ids'] = mrope_position_ids_list
if mrope_position_deltas_list:
result['mrope_position_deltas'] = mrope_position_deltas_list
return result
def _init_userbuffers(self, hidden_size):
if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1:

View File

@ -21,7 +21,7 @@ from concurrent import futures
from dataclasses import dataclass
from functools import cached_property
from itertools import repeat
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, cast
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast
import numpy as np
import torch
@ -199,6 +199,8 @@ class EarlyStopSampler(Sampler):
@dataclass(kw_only=True)
class MultimodalResult:
mm_embeddings: List[torch.Tensor]
# Can be used to carry extra per-request outputs, e.g. `mrope_position_ids`.
extra_data: Optional[Dict[str, Any]] = None
def values(self):
return vars(self).values()
@ -262,7 +264,10 @@ class EarlyStopWithMMResult(Sampler):
resource_manager: Optional[ResourceManager] = None,
) -> SampleStateWithMMResult:
# from model_outputs to MultimodalResult
data = MultimodalResult(mm_embeddings=model_outputs["mm_embeddings"])
data = MultimodalResult(
mm_embeddings=model_outputs.pop("mm_embeddings"),
extra_data={**model_outputs},
)
return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)
@override
@ -276,7 +281,12 @@ class EarlyStopWithMMResult(Sampler):
scheduled_requests = state.scheduled_requests
assert not scheduled_requests.generation_requests
mm_embeddings = state.data.mm_embeddings
for request, mm_embedding in zip(scheduled_requests.context_requests, mm_embeddings):
extra_data = state.data.extra_data or {}
mrope_position_ids = extra_data.get("mrope_position_ids", None)
mrope_position_deltas = extra_data.get("mrope_position_deltas", None)
for i, (request, mm_embedding) in enumerate(
zip(scheduled_requests.context_requests, mm_embeddings)
):
request.state = LlmRequestState.GENERATION_COMPLETE
# NOTE: This is a hack: set finish reason manually and set the beam 0
request.set_finished_reason(FinishReason.LENGTH, 0)
@ -287,6 +297,12 @@ class EarlyStopWithMMResult(Sampler):
request.py_result.append_mm_embeddings(mm_embedding)
# Store mrope data if available
if mrope_position_ids is not None and mrope_position_deltas is not None:
request.py_result.set_mrope_position(
mrope_position_ids[i], mrope_position_deltas[i]
)
@override
def is_generation_model(self) -> bool:
return False

View File

@ -40,6 +40,8 @@ class DisaggregatedParams:
multimodal_hashes: Optional[List[List[int]]] = (
None # user provided mm hashes should be a list of 8 integers
)
mrope_position_ids_handle: Optional[Dict[str, Any]] = None
mrope_position_deltas_handle: Optional[Dict[str, Any]] = None
def get_context_phase_params(self) -> tllme.ContextPhaseParams:
return tllme.ContextPhaseParams(

View File

@ -1,9 +1,10 @@
import atexit
import concurrent.futures
import json
import os
import threading
import time
import weakref
from typing import Dict, Optional, Union
from typing import Dict, List, Optional
import torch
import zmq
@ -22,9 +23,11 @@ from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import PostprocWorker, PostprocWorkerConfig
from .request import CancellingRequest, GenerationRequest
from .result import GenerationResult, IterationResult
from .utils import (ErrorResponse, IntraProcessQueue, WorkerCommIpcAddrs,
create_mpi_comm_session, get_spawn_proxy_process_env,
is_llm_response, print_alive_threads)
from .rpc import RPCClient
from .rpc.rpc_common import get_unique_ipc_addr
from .utils import (ErrorResponse, WorkerCommIpcAddrs, create_mpi_comm_session,
get_spawn_proxy_process_env, is_llm_response,
print_alive_threads)
from .worker import GenerationExecutorWorker, worker_main
__all__ = [
@ -89,19 +92,27 @@ class GenerationExecutorProxy(GenerationExecutor):
"llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
"llm_args", None) is not None else None
# Generate RPC address and key for stats RPC
self.rpc_addr = get_unique_ipc_addr()
self.hmac_key = os.urandom(32)
worker_kwargs = dict(**worker_kwargs,
worker_queues=self._setup_queues(),
postproc_worker_config=postproc_worker_config,
is_llm_executor=False)
is_llm_executor=False,
rpc_addr=self.rpc_addr,
hmac_key=self.hmac_key)
if "log_level" not in worker_kwargs:
worker_kwargs["log_level"] = logger.level
self.dispatch_result_thread: Optional[ManagedThread] = None
self.dispatch_stats_thread: Optional[ManagedThread] = None
self.dispatch_kv_cache_events_thread: Optional[ManagedThread] = None
self.rpc_client: Optional[RPCClient] = None
self._start_executor_workers(worker_kwargs)
# Create RPC client after workers are started (worker starts RPC server)
self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key)
# MPI registers its joiner using threading._register_atexit if possible.
# These functions run before atexit.register, so to avoid deadlock,
# we have to notify workers to exit before MPI starts to wait for them.
@ -128,19 +139,11 @@ class GenerationExecutorProxy(GenerationExecutor):
socket_type=zmq.PULL
if self.enable_postprocess_parallel else zmq.PAIR,
name="proxy_result_queue")
self.mp_stats_queue = FusedIpcQueue(is_server=True,
fuse_message=False,
name="proxy_stats_queue")
self.kv_cache_events_queue = FusedIpcQueue(
is_server=True,
fuse_message=False,
name="proxy_kv_cache_events_queue")
# Stats and KV events are now fetched via RPC, not IPC queues.
return WorkerCommIpcAddrs(
request_queue_addr=self.request_queue.address,
worker_init_status_queue_addr=self.worker_init_status_queue.address,
result_queue_addr=self.result_queue.address,
stats_queue_addr=self.mp_stats_queue.address,
kv_cache_events_queue_addr=self.kv_cache_events_queue.address,
)
def abort_request(self, request_id: int) -> None:
@ -204,71 +207,8 @@ class GenerationExecutorProxy(GenerationExecutor):
return True # success
def _iteration_result_task(self,
queue: Union[FusedIpcQueue, IntraProcessQueue],
result_singleton: IterationResult,
urgent: bool = False) -> bool:
if not urgent:
time.sleep(0.2)
try:
data = queue.get()
except:
logger.debug(
"proxy.py: Error in _iteration_result_task: queue.get()")
return False
if data is None:
logger.debug("proxy.py: _iteration_result_task: data is None")
return False # shutdown the thread
data = data if isinstance(data, list) else [data]
queue = result_singleton.queue
async_queues = []
while queue.full():
queue.get()
try:
for d in data:
if d is None:
logger.debug("proxy.py: _iteration_result_task: d is None")
return False
if isinstance(queue, _SyncQueue):
queue.put_nowait(d)
async_queues.append(queue)
else:
queue.put(d)
if async_queues:
_SyncQueue.notify_many(queue.loop, async_queues)
except AsyncQueue.EventLoopShutdownError:
# This happens in the last loop while the generate workflow is
# stopped, or when get_stats() or aget_stats() are not called by users
# and therefore the event loop may already be closed.
logger.debug("proxy.py: EventLoopShutdownError")
except Exception as e:
logger.debug(f"proxy.py: Error in _iteration_result_task: {e}")
raise e
return True # success
def dispatch_stats_task(self) -> bool:
if not self._iter_stats_result:
# This can happen temporarily because the WAR in tensorrt_llm/bench/benchmark/throughput.py
# is not synchronized with self.dispatch_stats_thread.
logger.debug(
f"Skipping stats dispatch while self._iter_stats_result=None")
return True # Intended behavior, not an error
return self._iteration_result_task(self.mp_stats_queue,
self._iter_stats_result)
def dispatch_kv_cache_events_task(self) -> bool:
return self._iteration_result_task(self.kv_cache_events_queue,
self._iter_kv_events_result,
urgent=True)
# NOTE: _iteration_result_task, dispatch_stats_task, and dispatch_kv_cache_events_task
# have been removed as stats and kv_events are now fetched via RPC directly.
def _start_dispatch_threads(self):
if self.dispatch_result_thread is None:
@ -277,25 +217,9 @@ class GenerationExecutorProxy(GenerationExecutor):
weakref.WeakMethod(self.dispatch_result_task),
error_queue=self._error_queue,
name="proxy_dispatch_result_thread")
self.dispatch_stats_thread = ManagedThread(
weakref.WeakMethod(self.dispatch_stats_task),
error_queue=self._error_queue,
name="proxy_dispatch_stats_thread")
self.dispatch_kv_cache_events_thread = ManagedThread(
weakref.WeakMethod(self.dispatch_kv_cache_events_task),
error_queue=self._error_queue,
name="proxy_dispatch_kv_cache_events_thread")
self.dispatch_result_thread.start()
# Only collect stats when submission
# is via LLM API
if self._iter_stats_result:
self.dispatch_stats_thread.start()
if self._iter_kv_events_result:
self.dispatch_kv_cache_events_thread.start()
self._handle_background_error()
def _start_executor_workers(self, worker_kwargs):
@ -387,23 +311,18 @@ class GenerationExecutorProxy(GenerationExecutor):
):
self.dispatch_result_thread.stop()
self.dispatch_result_thread.join()
if self.dispatch_stats_thread is not None and self.dispatch_stats_thread.is_alive(
):
self.dispatch_stats_thread.stop()
self.dispatch_stats_thread.join()
if self.dispatch_kv_cache_events_thread is not None and self.dispatch_kv_cache_events_thread.is_alive(
):
self.dispatch_kv_cache_events_thread.stop()
self.dispatch_kv_cache_events_thread.join()
# step3: finish all remaining work
# close the RPC client
if self.rpc_client is not None:
self.rpc_client.close()
self.rpc_client = None
# close all the sockets
self.request_queue.close()
self.worker_init_status_queue.close()
self.result_queue.close()
self.mp_stats_queue.close()
self.kv_cache_events_queue.close()
self.workers_started = False
self.mpi_session.shutdown()
@ -441,6 +360,104 @@ class GenerationExecutorProxy(GenerationExecutor):
return result
def get_stats(self, timeout: float) -> List[dict]:
"""Get iteration statistics from the runtime via RPC.
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
List[dict]: A list of runtime stats as dict.
"""
if self.rpc_client is None:
logger.warning("RPC client not initialized, cannot get stats")
return []
stats = self.rpc_client.fetch_stats_wait_async(timeout=timeout).remote()
return [json.loads(s) if isinstance(s, str) else s for s in stats]
def aget_stats(self, timeout: float) -> IterationResult:
"""Get iteration statistics from the runtime via RPC (async).
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
IterationResult: An async iterable object containing runtime stats.
"""
# Initialize iteration result if needed
self._maybe_initialize_iteration_results()
if self._iter_stats_result is None:
logger.warning("Iteration statistics are not available yet.")
from .executor import empty_async_iterable
return empty_async_iterable()
# Fetch stats via RPC and populate the result
try:
stats = self.rpc_client.fetch_stats_wait_async(
timeout=timeout).remote()
except Exception as e:
logger.debug(f"Error fetching stats via RPC: {e}")
stats = []
for stat in stats:
self._iter_stats_result.queue.put(stat)
self._iter_stats_result.set_timeout(timeout)
return self._iter_stats_result
def get_kv_events(self, timeout: float) -> List[dict]:
"""Get iteration KV events from the runtime via RPC.
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
List[dict]: A list of runtime events as dict.
"""
if self.rpc_client is None:
logger.warning("RPC client not initialized, cannot get kv events")
return []
try:
events = self.rpc_client.fetch_kv_cache_events_wait_async(
timeout=timeout).remote()
return [json.loads(e) if isinstance(e, str) else e for e in events]
except Exception as e:
logger.error(f"Error fetching kv events via RPC: {e}")
return []
def aget_kv_events(self, timeout: float) -> IterationResult:
"""Get iteration KV events from the runtime via RPC (async).
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
IterationResult: An async iterable object containing runtime events.
"""
# Initialize iteration result if needed
self._maybe_initialize_iteration_results()
if self._iter_kv_events_result is None:
from .executor import empty_async_iterable
return empty_async_iterable()
# Fetch kv events via RPC and populate the result
try:
events = self.rpc_client.fetch_kv_cache_events_wait_async(
timeout=timeout).remote()
except Exception as e:
logger.debug(f"Error fetching kv events via RPC: {e}")
events = []
for event in events:
self._iter_kv_events_result.queue.put(event)
self._iter_kv_events_result.set_timeout(timeout)
return self._iter_kv_events_result
def __del__(self):
self.shutdown()
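To make the new on-demand stats path concrete, here is a minimal standalone sketch of the fetch-then-normalize pattern used by get_stats above. FakeRpcClient and its nested _Call are stand-ins for RPCClient and its remote-call handle; only the JSON normalization mirrors the code in the diff.

import json


class FakeRpcClient:
    """Stand-in for RPCClient; fetch_stats_wait_async(...).remote() is mimicked."""

    class _Call:

        def remote(self):
            return ['{"iter": 1, "num_active_requests": 2}']

    def fetch_stats_wait_async(self, timeout=None):
        return FakeRpcClient._Call()


def get_stats(rpc_client, timeout: float):
    # Fetch serialized stats via RPC, then normalize JSON strings into dicts.
    if rpc_client is None:
        return []
    stats = rpc_client.fetch_stats_wait_async(timeout=timeout).remote()
    return [json.loads(s) if isinstance(s, str) else s for s in stats]


print(get_stats(FakeRpcClient(), timeout=2.0))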

View File

@ -1,4 +1,5 @@
import asyncio
import dataclasses
import json
import time
import weakref
@ -415,12 +416,19 @@ class GenerationResultBase:
self.cached_tokens = getattr(response_result, 'cached_tokens', 0)
self.avg_decoded_tokens_per_iter = response_result.avg_decoded_tokens_per_iter
if context_phase_params is not None:
self.disaggregated_params = DisaggregatedParams(
existing_disagg_params = self.disaggregated_params
# Use `replace` to preserve things like `mrope_position_ids_handle` and
# `mrope_position_deltas_handle`. However, explicitly set
# `multimodal_embedding_handles=None` since they should no longer be needed.
self.disaggregated_params = dataclasses.replace(
existing_disagg_params or DisaggregatedParams(),
request_type="context_only",
first_gen_tokens=context_phase_params.first_gen_tokens,
ctx_request_id=context_phase_params.req_id,
opaque_state=context_phase_params.opaque_state,
draft_tokens=context_phase_params.draft_tokens)
draft_tokens=context_phase_params.draft_tokens,
multimodal_embedding_handles=None,
)
finish_reasons = response_result.finish_reasons
# output_token_ids = (beams, tokens)
@ -440,6 +448,8 @@ class GenerationResultBase:
if hasattr(response_result, 'mm_embedding_handle'
) and response_result.mm_embedding_handle is not None:
self._mm_embedding_handle = response_result.mm_embedding_handle
mrope_position_ids_handle = response_result.mrope_position_ids_handle
mrope_position_deltas_handle = response_result.mrope_position_deltas_handle
if self.disaggregated_params is not None:
self.disaggregated_params.multimodal_embedding_handles = [
response_result.mm_embedding_handle
@ -451,6 +461,8 @@ class GenerationResultBase:
response_result.mm_embedding_handle
],
multimodal_hashes=self._multimodal_hashes)
self.disaggregated_params.mrope_position_ids_handle = mrope_position_ids_handle
self.disaggregated_params.mrope_position_deltas_handle = mrope_position_deltas_handle
# Process background errors here ASAP during generation.
if self._background_error_handler and (
@ -811,8 +823,12 @@ class GenerationResult(GenerationResultBase):
def _repr_fields(self):
return [
'request_id', 'prompt_token_ids', 'outputs', 'finished',
"context_logits", "mm_embedding_handle"
'request_id',
'prompt_token_ids',
'outputs',
'finished',
"context_logits",
"mm_embedding_handle",
]
def __repr__(self) -> str:
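A minimal standalone sketch of the dataclasses.replace pattern introduced above: previously stored fields (such as the mrope handles) survive the update while selected fields are overridden. The Params dataclass here is a simplified stand-in, not the real DisaggregatedParams.

import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class Params:
    request_type: Optional[str] = None
    ctx_request_id: Optional[int] = None
    multimodal_embedding_handles: Optional[list] = None
    mrope_position_ids_handle: Optional[Any] = None


existing = Params(mrope_position_ids_handle={"ipc": "handle"},
                  multimodal_embedding_handles=[object()])
# Override the context-phase fields, drop the embedding handles, keep the rest.
updated = dataclasses.replace(existing,
                              request_type="context_only",
                              ctx_request_id=42,
                              multimodal_embedding_handles=None)
assert updated.mrope_position_ids_handle == {"ipc": "handle"}
assert updated.multimodal_embedding_handles is None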

View File

@ -1,11 +1,13 @@
import json
import threading
from typing import Optional
from typing import List, Optional
from ..llmapi.mpi_session import MpiPoolSession, MpiSession
from ..llmapi.utils import logger_debug
from ..llmapi.utils import logger_debug, print_colored
from ..logger import logger
from .executor import GenerationExecutor
from .postproc_worker import PostprocWorkerConfig
from .result import IterationResult
from .rpc_proxy_mixin import RpcExecutorMixin
from .rpc_worker import RpcWorker
from .utils import create_mpi_comm_session, get_spawn_proxy_process_env
@ -69,20 +71,110 @@ class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor):
**self.worker_kwargs)
def _setup_mainloop_with_tasks(self):
"""Setup mainloop with all tasks needed for RpcProxy."""
"""Setup mainloop with tasks needed for RpcProxy.
Note: Stats and kv_events are now fetched on-demand via direct RPC calls
(get_stats, aget_stats, get_kv_events, aget_kv_events), not via streaming loops.
"""
tasks = [
self._fetch_responses_loop_async,
self._fetch_stats_loop_async,
]
# Only add kv_cache_events loop if it's enabled
if self._iter_kv_events_result:
tasks.append(self._fetch_kv_cache_events_loop_async)
# Call mixin's setup_mainloop with custom tasks
self.setup_mainloop(tasks=tasks, thread_name="rpc_proxy_main_loop")
def fetch_stats_remote(self):
return self.rpc_client.fetch_stats().remote()
def get_stats(self, timeout: float) -> List[dict]:
"""Get iteration statistics from the runtime via RPC.
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
List[dict]: A list of runtime stats as dict.
"""
try:
stats = self.rpc_client.fetch_stats_wait_async(
timeout=timeout).remote()
return [json.loads(s) if isinstance(s, str) else s for s in stats]
except Exception as e:
logger.debug(f"Error fetching stats via RPC: {e}")
return []
def aget_stats(self, timeout: float) -> IterationResult:
"""Get iteration statistics from the runtime via RPC (async).
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
IterationResult: An async iterable object containing runtime stats.
"""
self._maybe_initialize_iteration_results()
if self._iter_stats_result is None:
print_colored("Iteration statistics are not available yet.\n",
"yellow")
from .executor import empty_async_iterable
return empty_async_iterable()
# Fetch stats via RPC and populate the result
try:
stats = self.rpc_client.fetch_stats_wait_async(
timeout=timeout).remote()
except Exception:
stats = []
for stat in stats:
self._iter_stats_result.queue.put(stat)
self._iter_stats_result.set_timeout(timeout)
return self._iter_stats_result
def get_kv_events(self, timeout: float) -> List[dict]:
"""Get iteration KV events from the runtime via RPC.
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
List[dict]: A list of runtime events as dict.
"""
try:
# Events are already serialized by the worker's fetch_kv_cache_events_wait_async()
events = self.rpc_client.fetch_kv_cache_events_wait_async(
timeout=timeout).remote()
return [json.loads(e) if isinstance(e, str) else e for e in events]
except Exception as e:
logger.debug(f"Error fetching kv events via RPC: {e}")
return []
def aget_kv_events(self, timeout: float) -> IterationResult:
"""Get iteration KV events from the runtime via RPC (async).
Args:
timeout (float): Max wait time in seconds for the RPC call.
Returns:
IterationResult: An async iterable object containing runtime events.
"""
# Initialize iteration result if needed
self._maybe_initialize_iteration_results()
if self._iter_kv_events_result is None:
from .executor import empty_async_iterable
return empty_async_iterable()
# Fetch kv events via RPC and populate the result
try:
events = self.rpc_client.fetch_kv_cache_events_wait_async(
timeout=timeout).remote()
except Exception:
events = []
for event in events:
self._iter_kv_events_result.queue.put(event)
self._iter_kv_events_result.set_timeout(timeout)
return self._iter_kv_events_result
def setup_engine_remote(self):
return self.rpc_client.setup_engine().remote(need_response=True)

View File

@ -1,13 +1,12 @@
import asyncio
import atexit
import json
import os
import threading
from typing import Callable, List, Optional
from .._utils import nvtx_range_debug
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue, _SyncQueue
from ..llmapi.utils import _SyncQueue
from ..logger import logger
from .request import GenerationRequest
from .result import GenerationResult
@ -47,15 +46,16 @@ class RpcExecutorMixin:
Args:
tasks: List of async coroutine functions to run.
thread_name: Name for the main loop thread
Note: Stats and kv_events are now fetched on-demand via direct RPC calls
(get_stats, aget_stats, get_kv_events, aget_kv_events), so the default
tasks only include the responses loop. Callers can still provide custom
tasks including stats/kv_events loops if needed for streaming use cases.
"""
if tasks is None:
tasks = [
self._fetch_responses_loop_async,
self._fetch_stats_loop_async,
]
# Only add kv_cache_events loop if it's enabled
if self._iter_kv_events_result:
tasks.append(self._fetch_kv_cache_events_loop_async)
async def main_loop_task():
await asyncio.gather(*[task() for task in tasks])
@ -136,22 +136,6 @@ class RpcExecutorMixin:
if async_queues:
_SyncQueue.notify_many(event_loop, async_queues)
def handle_stats(self, stats):
"""Handle stats received from RPC worker and put them into the stats result queue.
Args:
stats: Statistics data from the RPC worker (can be dict, str, or list)
"""
self._handle_iteration_data(stats, self._iter_stats_result, "stats")
def handle_kv_cache_events(self, events):
"""Handle KV cache events received from RPC worker and put them into the events result queue.
Args:
events: KV cache events data from the RPC worker (can be dict, str, or list)
"""
self._handle_iteration_data(events, self._iter_kv_events_result, "kv_cache_events")
async def _generic_fetch_loop_async(
self, fetch_method_name: str, handler_method: Callable, method_name: str
):
@ -181,86 +165,6 @@ class RpcExecutorMixin:
method_name="_fetch_responses_loop_async",
)
async def _fetch_stats_loop_async(self):
await self._generic_fetch_loop_async(
fetch_method_name="fetch_stats_loop_async",
handler_method=self.handle_stats,
method_name="_fetch_stats_loop_async",
)
async def _fetch_kv_cache_events_loop_async(self):
await self._generic_fetch_loop_async(
fetch_method_name="fetch_kv_cache_events_loop_async",
handler_method=self.handle_kv_cache_events,
method_name="_fetch_kv_cache_events_loop_async",
)
def _handle_iteration_data(self, data, result_singleton, data_type: str):
"""Generic method to handle iteration data received from RPC worker.
Args:
data: Data from the RPC worker (can be dict, str, or list)
result_singleton: The iteration result singleton to put data into
data_type: Type of data for logging (e.g., "stats", "kv_cache_events")
"""
# Make sure we have initialized the iteration results
self._maybe_initialize_iteration_results()
if not result_singleton:
logger.debug(f"Skipping {data_type} handling while result_singleton=None")
return
# Get the queue from the result singleton
queue = result_singleton.queue
async_queues = []
# Clear old data if queue is full (similar to _iteration_result_task)
while queue.full():
queue.get()
try:
# Handle different types of data
if isinstance(data, str):
# Already JSON serialized
data_json = data
elif isinstance(data, list):
# Skip empty lists to avoid putting nothing in the queue
if not data:
logger.debug(f"rpc_proxy.py: Skipping empty {data_type} list")
return
# Handle list of data (multiple iterations)
for item in data:
if isinstance(item, str):
item_json = item
else:
item_json = json.dumps(item)
if isinstance(queue, _SyncQueue):
queue.put_nowait(item_json)
async_queues.append(queue)
else:
queue.put(item_json)
if async_queues:
_SyncQueue.notify_many(queue.loop, async_queues)
return
else:
# Convert dict/other to JSON string as expected by IterationResult
data_json = json.dumps(data)
if isinstance(queue, _SyncQueue):
queue.put_nowait(data_json)
async_queues.append(queue)
else:
queue.put(data_json)
if async_queues:
_SyncQueue.notify_many(queue.loop, async_queues)
except AsyncQueue.EventLoopShutdownError:
# This happens when the event loop is already closed
logger.debug(f"rpc_proxy.py: EventLoopShutdownError in handle_{data_type}")
except Exception as e:
logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}")
raise e
# NOTE: _fetch_stats_loop_async and _fetch_kv_cache_events_loop_async have been removed.
# Stats and kv_events are now fetched on-demand via direct RPC calls
# (get_stats, aget_stats, get_kv_events, aget_kv_events) instead of streaming loops.
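A minimal standalone sketch of the reduced main-loop wiring: a background thread runs asyncio.gather over the provided coroutines, which now default to only the responses loop. The coroutine body and thread handling here are illustrative; the real mixin adds shutdown and error handling.

import asyncio
import threading


async def fetch_responses_loop():
    for _ in range(3):          # the real loop runs until shutdown
        await asyncio.sleep(0.01)
        print("fetched a batch of responses")


def setup_mainloop(tasks=None, thread_name="rpc_proxy_main_loop"):
    tasks = tasks or [fetch_responses_loop]

    async def main_loop_task():
        await asyncio.gather(*[task() for task in tasks])

    thread = threading.Thread(target=lambda: asyncio.run(main_loop_task()),
                              name=thread_name,
                              daemon=True)
    thread.start()
    return thread


setup_mainloop().join()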

View File

@ -1,4 +1,5 @@
import asyncio
import time
from queue import Queue
from threading import Event
from typing import AsyncGenerator, Optional
@ -50,8 +51,9 @@ class RpcWorkerMixin:
"""Submits a request to the worker."""
with nvtx_range_debug("RpcWorker.submit", color="blue", category="Worker"):
logger_debug(f"[worker] Submitting request {request.id}", color="green")
super().submit(request)
result = super().submit(request)
logger_debug(f"[worker] Submitted request {request.id}", color="green")
return result
def fetch_responses(self, timeout: Optional[float] = None) -> list:
"""Fetch responses from the response queue (blocking)."""
@ -99,54 +101,58 @@ class RpcWorkerMixin:
f"[worker] RpcWorker {self.rank} quitting fetch_responses_loop_async", color="yellow"
)
async def fetch_stats_async(self, timeout: Optional[float] = None) -> list:
"""Async version of fetch_stats using asyncio.to_thread."""
return await asyncio.to_thread(self.fetch_stats)
async def fetch_kv_cache_events_async(self, timeout: Optional[float] = None) -> list:
"""Async version of fetch_kv_cache_events using asyncio.to_thread."""
return await asyncio.to_thread(self.fetch_kv_cache_events)
async def fetch_stats_loop_async(
self, timeout: Optional[float] = None
) -> AsyncGenerator[list, None]:
"""Stream stats in a loop until shutdown."""
async for data in self._generic_fetch_loop_async(
fetch_method=self.fetch_stats_async,
serializer=self._stats_serializer,
method_name="fetch_stats_loop_async",
timeout=timeout,
):
yield data
async def fetch_kv_cache_events_loop_async(
self, timeout: Optional[float] = None
) -> AsyncGenerator[list, None]:
"""Stream KV cache events in a loop until shutdown."""
async for data in self._generic_fetch_loop_async(
fetch_method=self.fetch_kv_cache_events_async,
serializer=self._kv_cache_events_serializer,
method_name="fetch_kv_cache_events_loop_async",
timeout=timeout,
):
yield data
async def _generic_fetch_loop_async(
self, fetch_method, serializer, method_name: str, timeout: Optional[float] = None
) -> AsyncGenerator[list, None]:
"""Generic method for fetching data in a loop.
async def fetch_stats_wait_async(self, timeout: Optional[float] = None) -> list:
"""Poll for stats until available or timeout.
Args:
fetch_method: The async method to call for fetching data
serializer: The serializer function to apply to each item
method_name: Name of the method for logging
timeout: Optional timeout between fetches
timeout: Max wait time in seconds. If None, fetch once without waiting.
"""
while not self.shutdown_event.is_set():
timeout = timeout or 0.1
await asyncio.sleep(timeout)
data = await fetch_method()
# Always yield data, even if empty, so the client does not appear to hang
# TODO: Remove the empty data to reduce the IPC overhead
yield [serializer(item) for item in data]
logger_debug(f"[worker] RpcWorker {self.rank} quitting {method_name}", color="yellow")
logger_debug(
f"[worker] RpcWorker {self.rank} is fetching stats with timeout {timeout}",
color="yellow",
)
start = time.time()
while True:
stats = await asyncio.to_thread(self.fetch_stats)
if stats or timeout is None:
break
if (time.time() - start) >= timeout:
break
await asyncio.sleep(0.1)
return [self._stats_serializer(s) for s in stats]
async def fetch_kv_cache_events_wait_async(self, timeout: Optional[float] = None) -> list:
"""Poll for KV cache events until available or timeout.
Args:
timeout: Max wait time in seconds. If None, fetch once without waiting.
"""
start = time.time()
while True:
events = await asyncio.to_thread(self.fetch_kv_cache_events)
if events or timeout is None:
break
if (time.time() - start) >= timeout:
break
await asyncio.sleep(0.1)
return [self._kv_cache_events_serializer(e) for e in events]
async def fetch_stats_async(self, timeout: Optional[float] = None) -> list:
"""Async version of fetch_stats using asyncio.to_thread.
This method is exposed via RPC and can be called directly by the proxy.
Returns serialized stats (JSON strings) that can be sent over RPC.
"""
stats = await asyncio.to_thread(self.fetch_stats)
# Serialize stats before sending over RPC (IterationStats objects are not picklable)
return [self._stats_serializer(s) for s in stats]
async def fetch_kv_cache_events_async(self, timeout: Optional[float] = None) -> list:
"""Async version of fetch_kv_cache_events using asyncio.to_thread.
This method is exposed via RPC and can be called directly by the proxy.
Returns serialized events (JSON strings) that can be sent over RPC.
"""
events = await asyncio.to_thread(self.fetch_kv_cache_events)
# Serialize events before sending over RPC
return [self._kv_cache_events_serializer(e) for e in events]
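A minimal standalone sketch of the poll-until-data-or-timeout pattern used by fetch_stats_wait_async above. The fetch_stats stub becomes non-empty after a short delay and json.dumps stands in for the stats serializer; both are assumptions for illustration.

import asyncio
import json
import time

_start = time.time()


def fetch_stats():
    # Returns nothing for the first 0.3 s, then a single stats record.
    return [] if time.time() - _start < 0.3 else [{"iter": 7}]


async def fetch_stats_wait(timeout=None):
    start = time.time()
    while True:
        stats = await asyncio.to_thread(fetch_stats)
        if stats or timeout is None:
            break
        if (time.time() - start) >= timeout:
            break
        await asyncio.sleep(0.1)
    return [json.dumps(s) for s in stats]


print(asyncio.run(fetch_stats_wait(timeout=1.0)))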

View File

@ -142,8 +142,6 @@ class WorkerCommIpcAddrs(NamedTuple):
request_queue_addr: tuple[str, Optional[bytes]]
worker_init_status_queue_addr: tuple[str, Optional[bytes]]
result_queue_addr: tuple[str, Optional[bytes]]
stats_queue_addr: tuple[str, Optional[bytes]]
kv_cache_events_queue_addr: tuple[str, Optional[bytes]]
def is_llm_response(instance):

View File

@ -1,11 +1,9 @@
import gc
import os
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from queue import Queue
from typing import Callable, List, Optional, Union
from typing import List, Optional, Union
import zmq
@ -18,25 +16,22 @@ from ..llmapi.llm_args import BaseLlmArgs
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, set_global_tracer
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, logger_debug,
print_traceback_on_error)
from ..llmapi.utils import ManagedThread, logger_debug, print_traceback_on_error
from ..sampling_params import BatchedLogitsProcessor
from .base_worker import BaseWorker, _init_hf_modules
from .executor import IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocWorker, PostprocWorkerConfig,
postproc_worker_main)
from .request import CancellingRequest, GenerationRequest
from .result import IterationResult
from .utils import (ErrorResponse, RequestError, WorkerCommIpcAddrs,
has_event_loop)
from .rpc_worker_mixin import RpcWorkerMixin
from .utils import ErrorResponse, RequestError, WorkerCommIpcAddrs
__all__ = [
"GenerationExecutorWorker",
]
class GenerationExecutorWorker(BaseWorker):
class GenerationExecutorWorker(RpcWorkerMixin, BaseWorker):
def __init__(
self,
@ -48,6 +43,8 @@ class GenerationExecutorWorker(BaseWorker):
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
rpc_addr: Optional[str] = None,
hmac_key: Optional[bytes] = None,
) -> None:
super().__init__(
engine=engine,
@ -62,35 +59,18 @@ class GenerationExecutorWorker(BaseWorker):
self.setup_engine()
# Setup RPC server for stats (skip init_rpc_worker to keep IPC response queue)
# Only set up if rpc_addr is provided (for stats RPC support)
if rpc_addr is not None:
self.rpc_addr = rpc_addr
self.hmac_key = hmac_key
self.start_rpc_server() # Reuse from RpcWorkerMixin
self.await_response_thread = ManagedThread(
self.await_response_task,
error_queue=self._error_queue,
name="await_response_thread")
self.dispatch_stats_thread = ManagedThread(
self.dispatch_stats_task,
error_queue=self._error_queue,
name="dispatch_stats_thread")
self.dispatch_kv_cache_events_thread = ManagedThread(
self.dispatch_kv_cache_events_task,
error_queue=self._error_queue,
name="dispatch_kv_cache_events_thread")
def _create_iteration_result_queue(self,
it_result_queue: IterationResultQueue):
if not it_result_queue.is_initialized:
# not yet initialized
it_result_queue.is_initialized = True
if has_event_loop():
_queue = AsyncQueue()
it_result_queue.queue = _queue.sync_q
it_result_queue.aqueue = _queue
else:
_queue = Queue()
it_result_queue.queue = _queue
it_result_queue.aqueue = None
def start_thread(self, thread: ManagedThread):
if self.engine.can_enqueue_requests() and not thread.is_alive():
thread.start()
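A minimal standalone sketch of the conditional wiring above: the worker only brings up its stats RPC endpoint when an address (and key) is handed down from the proxy. start_rpc_server and the address format are stand-ins for the real mixin, not its API.

import os
from typing import Optional


class Worker:

    def __init__(self, rpc_addr: Optional[str] = None,
                 hmac_key: Optional[bytes] = None):
        self.rpc_addr = rpc_addr
        self.hmac_key = hmac_key
        if rpc_addr is not None:
            self.start_rpc_server()

    def start_rpc_server(self):
        print(f"serving stats/kv-events RPC on {self.rpc_addr}")


Worker()                                            # IPC-only worker, no RPC
Worker(rpc_addr="ipc:///tmp/trtllm_stats", hmac_key=os.urandom(32))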
@ -98,74 +78,10 @@ class GenerationExecutorWorker(BaseWorker):
def await_response_task(self) -> bool:
return self._await_response_helper()
def _iteration_result_task(self, it_result_queue: IterationResultQueue,
engine_get_result_api: Callable,
result_singleton: IterationResult,
result_serializer: Callable) -> bool:
time.sleep(0.2)
async_queues = []
queue = result_singleton.queue if self._is_llm_executor and result_singleton else it_result_queue.queue
try:
for results in engine_get_result_api():
res = result_serializer(results)
if self._is_llm_executor and result_singleton:
# In this case, there's no ExecutorBindingProxy.
# Worker needs to take care of putting to result queue.
while queue.full():
queue.get()
if isinstance(queue, _SyncQueue):
queue.put_nowait(res)
async_queues.append(queue)
else:
queue.put(res)
else:
# Send to ExecutorBindingProxy via IPC
queue.put(res)
if async_queues:
_SyncQueue.notify_many(queue.loop, async_queues)
except AsyncQueue.EventLoopShutdownError:
# This happens in the last results loop while the generate workflow is stopped.
logger.debug("worker.py: EventLoopShutdownError")
except Exception as e:
logger.error(f"worker.py: Error in _iteration_result_task: {e}")
raise e
return True # success
def dispatch_stats_task(self) -> bool:
return self._iteration_result_task(self.stats_queues, self.fetch_stats,
self._iter_stats_result,
self._stats_serializer)
def dispatch_kv_cache_events_task(self) -> bool:
if isinstance(self.engine, tllm.Executor):
# Check if the engine has a kv cache event manager
# If not, return an empty list for the events which will cause the thread to exit early.
event_manager = self.engine.get_kv_cache_event_manager()
if event_manager is None:
events_api = lambda: [None]
else:
events_api = event_manager.get_latest_events
return self._iteration_result_task(self.kv_events_queues,
events_api,
self._iter_kv_events_result,
self._kv_cache_events_serializer)
else:
return self._iteration_result_task(
self.kv_events_queues, self.engine.get_latest_kv_cache_events,
self._iter_kv_events_result, self._kv_cache_events_serializer)
def start(self):
# create iteration result queues
self._create_iteration_result_queue(self.stats_queues)
self._create_iteration_result_queue(self.kv_events_queues)
# start threads
# Stats and KV events are now fetched on-demand via RPC,
# so we only need to start the response thread
self.start_thread(self.await_response_thread)
self.start_thread(self.dispatch_kv_cache_events_thread)
if mpi_rank() == 0:
self.start_thread(self.dispatch_stats_thread)
def shutdown(self):
@ -178,16 +94,9 @@ class GenerationExecutorWorker(BaseWorker):
if self.engine is not None:
if self.engine.can_enqueue_requests():
if self.await_response_thread.is_alive():
self.await_response_thread.stop()
self.await_response_thread.join()
if self.dispatch_stats_thread.is_alive():
self.dispatch_stats_thread.stop()
self.dispatch_stats_thread.join()
if self.dispatch_kv_cache_events_thread.is_alive():
self.dispatch_kv_cache_events_thread.stop()
self.dispatch_kv_cache_events_thread.join()
self.engine.shutdown()
self.engine = None
@ -240,6 +149,8 @@ def worker_main(
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[BaseLlmArgs] = None,
rpc_addr: Optional[str] = None,
hmac_key: Optional[bytes] = None,
) -> None:
mpi_comm().barrier()
@ -287,15 +198,6 @@ def worker_main(
is_server=False,
socket_type=zmq.DEALER,
name="worker_init_status_queue")
mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr,
is_server=False,
fuse_message=True,
name="worker_stats_queue")
kv_cache_events_queue = FusedIpcQueue(
worker_queues.kv_cache_events_queue_addr,
is_server=False,
fuse_message=False,
name="worker_kv_cache_events_queue")
if postproc_worker_config.enabled:
# IPC queues for sending inputs to the postprocess parallel
@ -322,9 +224,6 @@ def worker_main(
assert result_queues is not None
for q in result_queues:
q.put(None)
# Signal the stats thread in the proxy to quit
mp_stats_queue.put(None)
kv_cache_events_queue.put(None)
postprocess_worker_futures = []
if is_leader and postproc_worker_config.enabled:
@ -370,7 +269,9 @@ def worker_main(
is_llm_executor=is_llm_executor,
hf_model_dir=hf_model_dir,
tokenizer=tokenizer,
llm_args=llm_args)
llm_args=llm_args,
rpc_addr=rpc_addr,
hmac_key=hmac_key)
except Exception as e:
logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
logger.error(traceback.format_exc())
@ -396,11 +297,6 @@ def worker_main(
else:
worker.set_result_queue(result_queue)
# initialize the iteration result queues
worker._set_iteration_result_queue(worker.stats_queues,
mp_stats_queue)
worker._set_iteration_result_queue(worker.kv_events_queues,
kv_cache_events_queue)
# Send ready signal with confirmation
ready_msg = (ready_signal, None)
if not worker_init_status_queue.notify_with_retry(ready_msg):

View File

@ -89,8 +89,12 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
def _repr_fields(self):
return [
"request_id", "prompt", "prompt_token_ids", "outputs", "finished",
"mm_embedding_handle"
"request_id",
"prompt",
"prompt_token_ids",
"outputs",
"finished",
"mm_embedding_handle",
]
@ -419,7 +423,7 @@ class BaseLLM:
multimodal_params = None
if is_mm_disagg:
if not self.input_processor.support_mm_disagg:
if not getattr(self.input_processor, "support_mm_disagg", False):
raise ValueError(
"Multimodal disaggregated inference is not supported for this model"
)
@ -436,14 +440,42 @@ class BaseLLM:
mm_hashes = disaggregated_params.multimodal_hashes
multimodal_input = MultimodalInput.from_components(
mm_hashes, mm_token_positions, mm_token_length)
multimodal_data = {"multimodal_embedding": mm_handles}
if disaggregated_params.mrope_position_ids_handle is not None:
# NOTE: `PyTorchModelEngine` assumes both are present when using mrope.
assert disaggregated_params.mrope_position_deltas_handle is not None
mrope_config = {}
mrope_config[
"mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
mrope_config[
"mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
multimodal_data["mrope_config"] = mrope_config
multimodal_params = MultimodalParams(
multimodal_input=multimodal_input,
multimodal_data={"multimodal_embedding": mm_handles})
multimodal_data=multimodal_data,
)
elif "prompt_token_ids" in inputs:
prompt_token_ids = inputs['prompt_token_ids']
prompt = None
query_token_ids = inputs.get("query_token_ids", None)
multimodal_data = {}
# NOTE: when running in `generation_only` for disagg, this is the code path we expect to hit.
if disaggregated_params is not None and disaggregated_params.mrope_position_ids_handle is not None:
# NOTE: `PyTorchModelEngine` assumes both are present when using mrope.
if disaggregated_params.mrope_position_deltas_handle is None:
raise RuntimeError(
"`mrope_position_ids_handle` and `mrope_position_deltas_handle` must both "
"be provided, or both `None`.")
mrope_config = {}
mrope_config[
"mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
mrope_config[
"mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
multimodal_data["mrope_config"] = mrope_config
if multimodal_data:
multimodal_params = MultimodalParams(
multimodal_data=multimodal_data)
elif "prompt" in inputs:
if 'multi_modal_data' in inputs:
# TODO: The current design uses a wrapper for existing input processor (input_processor_with_hash)
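A minimal standalone sketch of the mrope_config assembly above: both handles must be provided together, and the config is only attached when there is something to attach. The handle values and the helper name are placeholders, not the library API.

def build_multimodal_data(mrope_position_ids_handle,
                          mrope_position_deltas_handle,
                          mm_handles=None):
    multimodal_data = {}
    if mm_handles is not None:
        multimodal_data["multimodal_embedding"] = mm_handles
    if mrope_position_ids_handle is not None:
        if mrope_position_deltas_handle is None:
            raise RuntimeError(
                "`mrope_position_ids_handle` and `mrope_position_deltas_handle` "
                "must both be provided, or both `None`.")
        multimodal_data["mrope_config"] = {
            "mrope_position_ids": mrope_position_ids_handle,
            "mrope_position_deltas": mrope_position_deltas_handle,
        }
    return multimodal_data


print(build_multimodal_data({"ipc": "ids"}, {"ipc": "deltas"}))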

View File

@ -101,14 +101,8 @@ class MultimodalEncoder(_TorchLLM):
inputs = [prompt_inputs(i) for i in inputs]
def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any:
if isinstance(maybe_batched, list):
return maybe_batched[pos]
else:
return maybe_batched
futures = []
for i, request_inputs in enumerate(inputs):
for request_inputs in inputs:
future = self.generate_async(request_inputs)
futures.append(future)

View File

@ -51,20 +51,22 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
MemoryUpdateRequest, ModelCard,
ModelList, PromptTokensDetails,
ResponsesRequest,
ResponsesResponse,
UpdateWeightsRequest, UsageInfo,
to_llm_disaggregated_params)
from tensorrt_llm.serve.postprocess_handlers import (
ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
chat_harmony_post_processor, chat_harmony_streaming_post_processor,
chat_response_post_processor, chat_stream_post_processor,
completion_response_post_processor, completion_stream_post_processor)
ResponsesAPIPostprocArgs, chat_harmony_post_processor,
chat_harmony_streaming_post_processor, chat_response_post_processor,
chat_stream_post_processor, completion_response_post_processor,
completion_stream_post_processor, responses_api_post_processor,
responses_api_streaming_post_processor)
from tensorrt_llm.serve.responses_utils import (ConversationHistoryStore,
ResponsesStreamingProcessor,
ServerArrivalTimeMiddleware)
from tensorrt_llm.serve.responses_utils import \
create_response as responses_api_create_response
from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
from tensorrt_llm.serve.responses_utils import \
process_streaming_events as responses_api_process_streaming_events
from tensorrt_llm.serve.responses_utils import \
request_preprocess as responses_api_request_preprocess
from tensorrt_llm.version import __version__ as VERSION
@ -119,9 +121,8 @@ class OpenAIServer:
self.model_config = None
# Enable response storage for Responses API
self.enable_store = True
if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0:
self.enable_store = False
self.enable_store = (not os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE")
                     ) and not self.postproc_worker_enabled
self.conversation_store = ConversationHistoryStore()
model_dir = Path(model)
@ -942,19 +943,39 @@ class OpenAIServer:
return self.create_error_response(message=str(e), err_type="internal_error")
async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response:
async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]:
async for event_data in responses_api_process_streaming_events(
request=request,
sampling_params=sampling_params,
generator=generator,
model_name=self.model,
conversation_store=self.conversation_store,
use_harmony=self.use_harmony,
reasoning_parser=self.llm.args.reasoning_parser,
tool_parser=self.tool_parser,
enable_store=self.enable_store
):
yield event_data
async def create_response(
promise: RequestOutput, postproc_params: PostprocParams) -> ResponsesResponse:
await promise.aresult()
if self.postproc_worker_enabled:
response = promise.outputs[0]._postprocess_result
else:
args = postproc_params.postproc_args
response = await responses_api_create_response(
generator=promise,
request=request,
sampling_params=args.sampling_params,
model_name=self.model,
conversation_store=self.conversation_store,
generation_result=None,
enable_store=self.enable_store and request.store,
use_harmony=self.use_harmony,
reasoning_parser=args.reasoning_parser,
tool_parser=args.tool_parser,
)
return response
async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
streaming_processor = args.streaming_processor
initial_responses = streaming_processor.get_initial_responses()
for initial_response in initial_responses:
yield initial_response
async for res in promise:
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
for pp_res in pp_results:
yield pp_res
try:
if request.background:
@ -977,38 +998,61 @@ class OpenAIServer:
request=request,
prev_response=prev_response,
conversation_store=self.conversation_store,
enable_store=self.enable_store,
enable_store=self.enable_store and request.store,
use_harmony=self.use_harmony,
tokenizer=self.tokenizer if not self.use_harmony else None,
model_config=self.model_config if not self.use_harmony else None,
processor=self.processor if not self.use_harmony else None,
)
streaming_processor = None
if request.stream:
# Per-request streaming processor
streaming_processor = ResponsesStreamingProcessor(
request=request,
sampling_params=sampling_params,
model_name=self.model,
conversation_store=self.conversation_store,
enable_store=self.enable_store and request.store,
use_harmony=self.use_harmony,
reasoning_parser=self.llm.args.reasoning_parser,
tool_parser=self.tool_parser,
)
postproc_args = ResponsesAPIPostprocArgs(
model=self.model,
request=request,
sampling_params=sampling_params,
use_harmony=self.use_harmony,
reasoning_parser=self.llm.args.reasoning_parser,
tool_parser=self.tool_parser,
streaming_processor=streaming_processor,
)
postproc_params = PostprocParams(
post_processor=responses_api_streaming_post_processor
if request.stream else responses_api_post_processor,
postproc_args=postproc_args,
)
promise = self.llm.generate_async(
inputs=input_tokens,
sampling_params=sampling_params,
streaming=request.stream,
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
)
if self.postproc_worker_enabled and request.store:
logger.warning("Postproc workers are enabled, request will not be stored!")
asyncio.create_task(self.await_disconnected(raw_request, promise))
if request.stream:
return StreamingResponse(
create_stream_response(promise, request, sampling_params),
content=create_streaming_generator(promise, postproc_params),
media_type="text/event-stream"
)
else:
return await responses_api_create_response(
generator=promise,
request=request,
sampling_params=sampling_params,
model_name=self.model,
conversation_store=self.conversation_store,
generation_result=None,
enable_store=self.enable_store,
use_harmony=self.use_harmony,
reasoning_parser=self.llm.args.reasoning_parser,
tool_parser=self.tool_parser)
response = await create_response(promise, postproc_params)
return JSONResponse(content=response.model_dump())
except CppExecutorError:
logger.error(traceback.format_exc())
# If internal executor error is raised, shutdown the server

View File

@ -1,11 +1,16 @@
from dataclasses import dataclass, field
from typing import Any, List, Literal, Optional, Tuple, Union
from tensorrt_llm.serve.responses_utils import ResponsesStreamingProcessor
from tensorrt_llm.serve.responses_utils import \
create_response_non_store as responses_api_create_response_non_store
from .._utils import nvtx_range_debug
from ..executor import (DetokenizedGenerationResultBase, GenerationResult,
GenerationResultBase)
from ..executor.postproc_worker import PostprocArgs
from ..executor.result import Logprob, TokenLogprobs
from ..llmapi import SamplingParams
from ..llmapi.reasoning_parser import (BaseReasoningParser,
ReasoningParserFactory)
from ..llmapi.tokenizer import TransformersTokenizer
@ -26,7 +31,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
CompletionResponseStreamChoice,
CompletionStreamResponse, DeltaFunctionCall,
DeltaMessage, DeltaToolCall, FunctionCall,
PromptTokensDetails, StreamOptions, ToolCall,
PromptTokensDetails, ResponsesRequest,
ResponsesResponse, StreamOptions, ToolCall,
UsageInfo, to_disaggregated_params)
from .tool_parser.base_tool_parser import BaseToolParser
from .tool_parser.core_types import ToolCallItem
@ -543,3 +549,42 @@ def chat_harmony_streaming_post_processor(
num_prompt_tokens=args.num_prompt_tokens,
)
return response
@dataclass(kw_only=True)
class ResponsesAPIPostprocArgs(PostprocArgs):
model: str
request: ResponsesRequest
sampling_params: SamplingParams
use_harmony: bool
reasoning_parser: Optional[str] = None
tool_parser: Optional[str] = None
streaming_processor: Optional[ResponsesStreamingProcessor] = None
@nvtx_range_debug("responses_api_post_processor")
def responses_api_post_processor(
rsp: GenerationResult,
args: ResponsesAPIPostprocArgs) -> ResponsesResponse:
return responses_api_create_response_non_store(
generation_result=rsp,
request=args.request,
sampling_params=args.sampling_params,
model_name=args.model,
use_harmony=args.use_harmony,
reasoning_parser=args.reasoning_parser,
tool_parser=args.tool_parser,
)
@nvtx_range_debug("responses_api_streaming_post_processor")
def responses_api_streaming_post_processor(
rsp: GenerationResult, args: ResponsesAPIPostprocArgs) -> List[str]:
if args.streaming_processor is None:
raise ValueError(
"streaming_processor is required for streaming post-processing")
outputs = args.streaming_processor.process_single_output(rsp)
if rsp._done:
outputs.append(
args.streaming_processor.get_final_response_non_store(rsp))
return outputs
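A minimal standalone sketch of how a kw_only postproc-args dataclass pairs with the two post-processors above: the streaming variant requires a per-request streaming_processor while the non-streaming one does not. The dataclass, function names, and returned strings are stand-ins for illustration only.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass(kw_only=True)
class PostprocArgsSketch:
    model: str
    use_harmony: bool
    streaming_processor: Optional[Any] = None


def pick_post_processor(stream: bool, args: PostprocArgsSketch):
    if stream:
        if args.streaming_processor is None:
            raise ValueError(
                "streaming_processor is required for streaming post-processing")
        return "responses_api_streaming_post_processor"
    return "responses_api_post_processor"


args = PostprocArgsSketch(model="demo", use_harmony=False,
                          streaming_processor=object())
print(pick_post_processor(stream=True, args=args))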

View File

@ -10,7 +10,7 @@ import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from copy import copy
from typing import Any, Literal, Optional, OrderedDict, Tuple, Union
from typing import Any, List, Literal, Optional, OrderedDict, Tuple, Union
from openai.types.responses import (ResponseCompletedEvent,
ResponseContentPartAddedEvent,
@ -41,6 +41,7 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
from transformers import AutoProcessor, PretrainedConfig
from tensorrt_llm.bindings import steady_clock_now
from tensorrt_llm.executor import GenerationResult
from tensorrt_llm.inputs.utils import apply_chat_template
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.llm import RequestOutput
@ -962,7 +963,7 @@ def _apply_tool_parser(
return normal_text, calls
async def _create_output_content(
def _create_output_content(
final_res: RequestOutput,
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
@ -1040,7 +1041,7 @@ async def _create_output_content(
return output_items, output_messages
async def _create_output_content_harmony(
def _create_output_content_harmony(
final_res: RequestOutput
) -> Tuple[list[ResponseOutputItem], list[Message]]:
output_messages = _parse_output_tokens(final_res.outputs[0].token_ids)
@ -1057,12 +1058,53 @@ async def _create_output_content_harmony(
return output_content, output_messages
def _create_response(
final_res: GenerationResult,
use_harmony: bool,
request: ResponsesRequest,
model_name: str,
response_creation_time: int,
sampling_params: SamplingParams,
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
) -> tuple[ResponsesResponse, list[Message | ChatCompletionMessageParam]]:
_responses_debug_log("================================================")
_responses_debug_log("RAW MODEL OUTPUT:")
_responses_debug_log(final_res.outputs)
_responses_debug_log("================================================")
# prepare responses output
output_content = []
if use_harmony:
output_content, output_messages = _create_output_content_harmony(
final_res)
else:
output_content, output_messages = _create_output_content(
final_res, reasoning_parser, tool_parser, request.tools)
response = ResponsesResponse.from_request(
request=request,
sampling_params=sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=output_content,
status=finish_reason_mapping(final_res.outputs[0].finish_reason),
)
_responses_debug_log("========== Response ===========")
_responses_debug_log(response)
_responses_debug_log("===============================")
# return output_messages for store_response
return response, output_messages
async def create_response(
generator,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
conversation_store: ConversationHistoryStore,
generator: Optional[AsyncGenerator[RequestOutput, None]] = None,
generation_result: Optional[RequestOutput] = None,
enable_store: bool = False,
use_harmony: bool = True,
@ -1078,33 +1120,22 @@ async def create_response(
if generation_result is not None:
final_res = generation_result
else:
elif generator is not None:
final_res = await generator
if final_res is None:
raise RuntimeError("No output generated or provided")
_responses_debug_log("================================================")
_responses_debug_log("RAW MODEL OUTPUT:")
_responses_debug_log(final_res.outputs)
_responses_debug_log("================================================")
# prepare responses output
output_content = []
if use_harmony:
output_content, output_messages = await _create_output_content_harmony(
final_res)
else:
output_content, output_messages = await _create_output_content(
final_res, reasoning_parser, tool_parser, request.tools)
response = ResponsesResponse.from_request(
response, output_messages = _create_response(
final_res=final_res,
use_harmony=use_harmony,
request=request,
sampling_params=sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=output_content,
status=finish_reason_mapping(final_res.outputs[0].finish_reason),
response_creation_time=response_creation_time,
sampling_params=sampling_params,
reasoning_parser=reasoning_parser,
tool_parser=tool_parser,
)
if enable_store and request.store:
@ -1112,9 +1143,34 @@ async def create_response(
resp_msgs=output_messages,
prev_resp_id=prev_response_id)
_responses_debug_log("========== Response ===========")
_responses_debug_log(response)
_responses_debug_log("===============================")
return response
def create_response_non_store(
generation_result: RequestOutput,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
use_harmony: bool = True,
create_time: Optional[int] = None,
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
) -> ResponsesResponse:
response_creation_time = create_time if create_time is not None else int(
time.time())
# prepare responses output
response, _ = _create_response(
final_res=generation_result,
use_harmony=use_harmony,
request=request,
model_name=model_name,
response_creation_time=response_creation_time,
sampling_params=sampling_params,
reasoning_parser=reasoning_parser,
tool_parser=tool_parser,
)
return response
@ -1649,6 +1705,143 @@ def _generate_streaming_event_harmony(
parser.last_content_delta)
class ResponsesStreamingProcessor:
def __init__(
self,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
create_time: Optional[int] = None,
conversation_store: Optional[ConversationHistoryStore] = None,
enable_store: bool = False,
use_harmony: bool = True,
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
):
self.model_name = model_name
self.request = request
self.sampling_params = sampling_params
self.sequence_number = 0
self.streaming_events_helper = ResponsesStreamingEventsHelper()
self.response_creation_time = create_time if create_time is not None else int(
time.time())
self.final_res: Optional[RequestOutput] = None
self.reasoning_parser_dict: dict[int, BaseReasoningParser] = {}
self.tool_parser_dict: dict[int, BaseToolParser] = {}
self.stream_request_id = f"responses-api-{request.request_id}"
self.conversation_store = conversation_store
self.enable_store = enable_store
self.use_harmony = use_harmony
self.reasoning_parser = reasoning_parser
self.tool_parser = tool_parser
def _send_event(self, event: OpenAIBaseModel):
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
event.sequence_number = self.sequence_number
self.sequence_number += 1
# Get event type from the event's type field if it exists
event_type = getattr(event, 'type', 'unknown')
return (f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n")
def get_initial_responses(self) -> List[str]:
initial_response = ResponsesResponse.from_request(
request=self.request,
sampling_params=self.sampling_params,
model_name=self.model_name,
created_time=self.response_creation_time,
output=[],
status="in_progress",
usage=None,
).model_dump()
resp_created = self._send_event(
self.streaming_events_helper.get_response_created_event(
initial_response))
resp_in_progress = self._send_event(
self.streaming_events_helper.get_response_in_progress_event(
initial_response))
return [resp_created, resp_in_progress]
async def get_final_response(
self,
final_res: RequestOutput,
) -> str:
final_response = await create_response(
generator=None,
request=self.request,
sampling_params=self.sampling_params,
model_name=self.model_name,
conversation_store=self.conversation_store,
generation_result=final_res,
enable_store=self.enable_store,
use_harmony=self.use_harmony,
create_time=self.response_creation_time,
reasoning_parser=self.reasoning_parser,
tool_parser=self.tool_parser,
)
return self._send_event(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))
def get_final_response_non_store(
self,
final_res: RequestOutput,
) -> str:
final_response = create_response_non_store(
generation_result=final_res,
request=self.request,
sampling_params=self.sampling_params,
model_name=self.model_name,
use_harmony=self.use_harmony,
create_time=self.response_creation_time,
reasoning_parser=self.reasoning_parser,
tool_parser=self.tool_parser,
)
return self._send_event(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))
def process_single_output(self, res: GenerationResult) -> list[str]:
event_generator = None
output = res.outputs[0]
if self.use_harmony:
event_generator = _generate_streaming_event_harmony(
harmony_adapter=get_harmony_adapter(),
stream_request_id=self.stream_request_id,
output=output,
request=self.request,
streaming_events_helper=self.streaming_events_helper,
)
else:
event_generator = _generate_streaming_event(
output=output,
request=self.request,
finished_generation=res._done,
streaming_events_helper=self.streaming_events_helper,
reasoning_parser_id=self.reasoning_parser,
tool_parser_id=self.tool_parser,
reasoning_parser_dict=self.reasoning_parser_dict,
tool_parser_dict=self.tool_parser_dict,
)
if event_generator is None:
raise RuntimeError("Failed to generate streaming events")
return [self._send_event(event) for event in event_generator]
async def process_streaming_events(
generator,
request: ResponsesRequest,
@ -1661,97 +1854,31 @@ async def process_streaming_events(
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
) -> AsyncGenerator[str, None]:
sequence_number = 0
response_creation_time = create_time if create_time is not None else int(
time.time())
final_res: Optional[RequestOutput] = None
reasoning_parser_dict: dict[int, BaseReasoningParser] = {}
tool_parser_dict: dict[int, BaseToolParser] = {}
def _send_event(event: OpenAIBaseModel):
nonlocal sequence_number
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
event.sequence_number = sequence_number
sequence_number += 1
# Get event type from the event's type field if it exists
event_type = getattr(event, 'type', 'unknown')
return (f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n")
streaming_events_helper = ResponsesStreamingEventsHelper()
initial_response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=[],
status="in_progress",
usage=None,
).model_dump()
yield _send_event(
streaming_events_helper.get_response_created_event(initial_response))
yield _send_event(
streaming_events_helper.get_response_in_progress_event(
initial_response))
stream_request_id = f"responses-api-{request.request_id}"
harmony_adapter = get_harmony_adapter()
async for res in generator:
final_res = res
# TODO(JunyiXu-nv): handle multiple outputs
output = res.outputs[0]
event_generator = None
if use_harmony:
event_generator = _generate_streaming_event_harmony(
harmony_adapter=harmony_adapter,
stream_request_id=stream_request_id,
output=output,
request=request,
streaming_events_helper=streaming_events_helper,
)
else:
event_generator = _generate_streaming_event(
output=output,
request=request,
finished_generation=res.finished,
streaming_events_helper=streaming_events_helper,
reasoning_parser_id=reasoning_parser,
tool_parser_id=tool_parser,
reasoning_parser_dict=reasoning_parser_dict,
tool_parser_dict=tool_parser_dict,
)
if event_generator is None:
raise RuntimeError("Failed to generate streaming events")
for event in event_generator:
yield _send_event(event)
final_response = await create_response(
generator=generator,
streaming_processor = ResponsesStreamingProcessor(
request=request,
sampling_params=sampling_params,
model_name=model_name,
create_time=create_time,
conversation_store=conversation_store,
generation_result=final_res,
enable_store=enable_store,
use_harmony=use_harmony,
create_time=response_creation_time,
reasoning_parser=reasoning_parser,
tool_parser=tool_parser,
)
yield _send_event(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))
initial_responses = streaming_processor.get_initial_responses()
for initial_response in initial_responses:
yield initial_response
async for res in generator:
final_res = res
events = streaming_processor.process_single_output(res)
for event in events:
yield event
final_response = await streaming_processor.get_final_response(final_res)
yield final_response
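A minimal standalone sketch of the _send_event framing used by ResponsesStreamingProcessor above: each event is rendered as a server-sent-events chunk and tagged with a monotonically increasing sequence number. The Event model here is a plain-dataclass stand-in for the OpenAI pydantic event types.

import json
from dataclasses import dataclass, field


@dataclass
class Event:
    type: str
    sequence_number: int = -1
    data: dict = field(default_factory=dict)

    def model_dump_json(self, indent=None):
        return json.dumps(self.__dict__, indent=indent)


class SseEmitter:

    def __init__(self):
        self.sequence_number = 0

    def send_event(self, event: Event) -> str:
        # Stamp the sequence number, then frame the event as an SSE chunk.
        if hasattr(event, "sequence_number"):
            event.sequence_number = self.sequence_number
            self.sequence_number += 1
        event_type = getattr(event, "type", "unknown")
        return (f"event: {event_type}\n"
                f"data: {event.model_dump_json(indent=None)}\n\n")


emitter = SseEmitter()
print(emitter.send_event(Event(type="response.created")), end="")
print(emitter.send_event(Event(type="response.completed")), end="")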
class ServerArrivalTimeMiddleware:

View File

@ -479,3 +479,8 @@ triton_server/test_triton.py::test_gpt_disaggregated_serving_bls[gpt-disaggregat
cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mooncake_kvcache-90] SKIP (https://nvbugs/5760737)
unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py::test_allreduce_pg_op[seqlen:16-hidden:1024] SKIP (https://nvbugs/5760740)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler] SKIP (https://nvbugs/5760747)
unittest/_torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-strategy:8-dtype:bfloat16-hidden:8192-seqlen:[15]] SKIP (https://nvbugs/5761364)
triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] SKIP (https://nvbugs/5762852)
accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] SKIP (https://nvbugs/5762852)

View File

@ -6,6 +6,7 @@ import torch
import tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.torch_moe # noqa: F401
import tensorrt_llm._torch.custom_ops.torch_custom_ops as trt_ops # noqa: F401
from tensorrt_llm._torch.utils import ActivationType
def test_flashinfer_fused_moe_matches_torch_moe():
@ -75,8 +76,8 @@ def test_flashinfer_fused_moe_matches_torch_moe():
w1_weight=w1_list, # gate projection
w2_weight=w2_list, # down projection
w3_weight=w3_list, # up projection
mlp_style="gated_mlp",
act_fn="silu",
is_gated_mlp=True,
act_fn=int(ActivationType.Silu),
)
# Compare outputs

View File

@ -186,7 +186,7 @@ def test_llama4_stacked_moe_pattern_detection():
moe_node = graph.call_function(
torch.ops.auto_deploy.torch_moe,
args=(x, selected_experts, routing_weights, w1_list, w2_list, w3_list),
kwargs={"mlp_style": "gated_mlp", "apply_routing_on_input": True},
kwargs={"is_gated_mlp": True, "apply_routing_on_input": True},
)
graph.output(moe_node)

View File

@ -7,6 +7,7 @@ from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_availab
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
from tensorrt_llm._torch.modules.fused_moe import MoE # noqa: F401
from tensorrt_llm._torch.utils import ActivationType
def setup_moe_test(dtype, num_experts):
@ -173,8 +174,8 @@ def test_bmm_based_moe_op_run(dtype):
[fused_w3_w1_stacked_weight], # Wrap in list for unified interface
[fused_w2_weight], # Wrap in list for unified interface
[], # Empty w3_weight list for stacked gated MLP
mlp_style="gated_mlp",
act_fn="silu",
is_gated_mlp=True,
act_fn=ActivationType.Silu,
apply_routing_on_input=True,
)
output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused(

View File

@ -82,7 +82,7 @@ def compute_with_experts(
alpha=None,
beta=None,
limit=None,
activation_func="silu",
activation_func: ActivationType = ActivationType.Silu,
):
def relu2(x: torch.Tensor) -> torch.Tensor:
return torch.square(F.relu(x))
@ -110,7 +110,7 @@ def compute_with_experts(
inter = x1_scaled * x2
else:
if activation_func == "swiglu" or activation_func == "silu":
if activation_func == ActivationType.Swiglu or activation_func == ActivationType.Silu:
inter = F.silu(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t())
else:
inter = relu2(expert_inputs @ w1_expert.t())
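# Illustrative sketch, not part of the diff above: the two expert-MLP variants the
# reference path distinguishes, written as a standalone helper. w1/w2/w3 are per-expert
# weight matrices; the down projection through w2 is assumed to consume the intermediate
# in the same way as compute_with_experts.
import torch
import torch.nn.functional as F

def expert_mlp(x: torch.Tensor, w1, w2, w3=None, is_gated_mlp: bool = True) -> torch.Tensor:
    if is_gated_mlp:
        # Silu/Swiglu-style gated MLP: silu(x @ w1^T) * (x @ w3^T)
        inter = F.silu(x @ w1.t()) * (x @ w3.t())
    else:
        # Non-gated MLP with ReLU^2: relu(x @ w1^T) squared
        inter = torch.square(F.relu(x @ w1.t()))
    return inter @ w2.t()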
@ -136,10 +136,6 @@ def _get_test_data(
return x, router_logits, w31_weight, w2_weight, w31_empty_scales, w2_empty_scales
def _activation_type_from_str(activation_func: str) -> ActivationType:
return ActivationType.Swiglu if activation_func in ["swiglu", "silu"] else ActivationType.Relu2
def _print_diff_if(
condition: Callable[[torch.Tensor], bool],
diff: torch.Tensor,
@ -183,7 +179,7 @@ F16_TEST_DTYPES = [
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("itype, otype, wtype", F16_TEST_DTYPES)
@pytest.mark.parametrize("activation_func", ["silu", "relu2"])
@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2])
@skip_pre_hopper
def test_trtllm_fused_moe(
batch_size,
@ -201,13 +197,13 @@ def test_trtllm_fused_moe(
pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})")
torch.manual_seed(42)
if activation_func in ["swiglu", "silu"]:
if activation_func in [ActivationType.Swiglu, ActivationType.Silu]:
X_GEN_SCALE = 1.0
else:
X_GEN_SCALE = 0.5
W_GEN_SCALE = 0.1
x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data(
x, router_logits, w31_weight, w2_weight, _, _ = _get_test_data(
otype,
wtype,
batch_size,
@ -239,19 +235,17 @@ def test_trtllm_fused_moe(
"F16 test only supports bfloat16 or float16"
)
activation_type = _activation_type_from_str(activation_func)
def get_fc1_expert_weights(
activation_func: str, w31_weight: torch.Tensor, w1_weight: torch.Tensor
activation_func: ActivationType, w31_weight: torch.Tensor, w1_weight: torch.Tensor
) -> torch.Tensor:
if activation_func == "relu2":
if activation_func == ActivationType.Relu2:
return w1_weight.contiguous()
else:
return w31_weight
# (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size)
_, w1_weight = torch.chunk(w31_weight, 2, dim=1)
mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
is_gated_mlp = activation_func != ActivationType.Relu2
torch.cuda.synchronize()
ad_test_output = torch.ops.auto_deploy.trtllm_moe_fused(
@ -260,9 +254,13 @@ def test_trtllm_fused_moe(
routing_weights,
w3_w1_stacked_weight=get_fc1_expert_weights(activation_func, w31_weight, w1_weight),
w2_stacked_weight=w2_weight,
mlp_style=mlp_style,
is_gated_mlp=is_gated_mlp,
act_fn=activation_func,
)
# Convert ActivationType.Silu to ActivationType.Swiglu for C++ op compatibility
cpp_activation_type = (
ActivationType.Swiglu if activation_func == ActivationType.Silu else activation_func
)
trtllm_test_output = torch.ops.trtllm.fused_moe(
x,
selected_experts.to(torch.int),
@ -273,11 +271,11 @@ def test_trtllm_fused_moe(
fc2_expert_biases=None,
output_dtype=otype,
quant_scales=[],
activation_type=activation_type,
activation_type=cpp_activation_type,
)[0].view(x.shape)
torch.cuda.synchronize()
if mlp_style == "mlp":
if not is_gated_mlp:
with torch.inference_mode():
output_triton_moe = torch.ops.auto_deploy.triton_moe_fused(
x,
@ -285,6 +283,7 @@ def test_trtllm_fused_moe(
routing_weights,
w1_weight.contiguous(),
w2_weight.contiguous(),
is_gated_mlp=False,
)[0].view(x.shape)
torch.testing.assert_close(output_triton_moe, ad_test_output, rtol=1e-2, atol=1e-2)
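# Illustrative helper with a hypothetical name, not part of the change: the
# Silu -> Swiglu remapping that the test inlines above before calling
# torch.ops.trtllm.fused_moe, written once so the intent is explicit.
from tensorrt_llm._torch.utils import ActivationType

def to_cpp_activation(act: ActivationType) -> ActivationType:
    # The C++ fused-MoE op expects Swiglu where the auto_deploy-level op accepts Silu.
    return ActivationType.Swiglu if act == ActivationType.Silu else act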
@ -308,7 +307,7 @@ FP8_TEST_DTYPES = [
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("itype, otype, wtype", FP8_TEST_DTYPES)
@pytest.mark.parametrize("activation_func", ["silu", "relu2"])
@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2])
@pytest.mark.skipif(
not fp8_compatible() or not trtllm_ops_available(),
reason="Requires fp8 and trtllm support",
@ -336,7 +335,7 @@ def test_trtllm_fused_moe_fp8(
)
torch.manual_seed(42)
if activation_func in ["swiglu", "silu"]:
if activation_func in [ActivationType.Swiglu, ActivationType.Silu]:
X_GEN_SCALE = 1.0
else:
X_GEN_SCALE = 0.5
@ -399,7 +398,7 @@ def test_trtllm_fused_moe_fp8(
# (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size)
w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1)
mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
is_gated_mlp = activation_func != ActivationType.Relu2
# compute quant_scales
gemm1_dequant = (w1_scales * hidden_states_scale).contiguous().squeeze().to(torch.float32)
@ -424,13 +423,13 @@ def test_trtllm_fused_moe_fp8(
gemm1_dequant=gemm1_dequant,
gemm2_act_quant=gemm2_act_quant,
gemm2_dequant=gemm2_dequant,
mlp_style=mlp_style,
is_gated_mlp=is_gated_mlp,
act_fn=activation_func,
)
torch.cuda.synchronize()
if mlp_style == "mlp":
if not is_gated_mlp:
with torch.inference_mode():
output_triton_fp8_moe = torch.ops.auto_deploy.triton_quant_fp8_moe(
x,
@ -445,7 +444,7 @@ def test_trtllm_fused_moe_fp8(
w1_scales,
w2_scales,
w3_scales,
mlp_style=mlp_style,
is_gated_mlp=is_gated_mlp,
act_fn=activation_func,
)
torch.testing.assert_close(output_triton_fp8_moe, ref_output, rtol=1e-1, atol=1e-1)
@ -569,7 +568,7 @@ NVFP4_TEST_DTYPES = [
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES)
@pytest.mark.parametrize("activation_func", ["silu", "relu2"])
@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2])
@pytest.mark.skipif(
not fp4_compatible() or not trtllm_ops_available(),
reason="Requires fp4 and trtllm support",
@ -693,25 +692,23 @@ def test_trtllm_fused_moe_nvfp4(
fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs)
fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs)
mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
if mlp_style == "gated_mlp":
is_gated_mlp = activation_func != ActivationType.Relu2
if is_gated_mlp:
# For gated MLP, concatenate w1 and w3 as [w3, w1]
fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous()
fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1)
fc1_weight_gs = torch.max(w3_gs, w1_gs)
if activation_func != "silu":
if activation_func != ActivationType.Silu:
raise ValueError(
f"Unsupported activation '{activation_func}' for gated_mlp. Use 'silu'."
)
elif mlp_style == "mlp":
else:
# For non-gated MLP with ReLU^2
fc1_expert_weights_fp4 = w1_q_fp4
fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long)
fc1_weight_gs = w1_gs
if activation_func != "relu2":
if activation_func != ActivationType.Relu2:
raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
fc2_expert_weights_fp4 = w2_q_fp4.view(torch.long)
fc2_weight_blockscale_fp8 = w2_blockscale.view(torch.long)
@ -729,7 +726,7 @@ def test_trtllm_fused_moe_nvfp4(
fc2_activation_gs,
fc1_alpha,
fc2_alpha,
mlp_style=mlp_style,
is_gated_mlp=is_gated_mlp,
act_fn=activation_func,
)
@ -747,8 +744,7 @@ def test_trtllm_fused_moe_nvfp4(
block_size=NVFP4_BLOCK_SIZE,
)
concat_w3_w1 = mlp_style == "gated_mlp"
if concat_w3_w1:
if is_gated_mlp:
w1_gs = w3_gs = torch.max(w1_gs, w3_gs)
w1_dq = torch.empty(w1.shape, device="cuda", dtype=otype)
@ -782,14 +778,18 @@ def test_trtllm_fused_moe_nvfp4(
block_size=NVFP4_BLOCK_SIZE,
)
# Convert ActivationType.Silu to ActivationType.Swiglu for reference op compatibility
resolved_activation_type = (
ActivationType.Swiglu if activation_func == ActivationType.Silu else activation_func
)
ref_output = torch_moe_nvfp4(
x_dq,
torch.cat([w3_dq, w1_dq], dim=1) if concat_w3_w1 else w1_dq,
torch.cat([w3_dq, w1_dq], dim=1) if is_gated_mlp else w1_dq,
w2_dq,
top_k,
routing_weights,
selected_experts,
_activation_type_from_str(activation_func),
resolved_activation_type,
)
return ref_output

View File

@ -4,6 +4,7 @@ from utils.util import skip_pre_hopper
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.load_moe_align import moe_align_block_size
from tensorrt_llm._torch.utils import ActivationType # noqa: F401
def _pack_routed_tokens_reference(
@ -131,6 +132,7 @@ def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit):
routing_weights,
w_up_stacked,
w_down_stacked,
is_gated_mlp=False,
)
# Reference Torch MoE in mlp mode with relu2 activation
@ -141,8 +143,8 @@ def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit):
w1_weight=w_up_list,
w2_weight=w_down_list,
w3_weight=[],
mlp_style="mlp",
act_fn="relu2",
is_gated_mlp=False,
act_fn=ActivationType.Relu2,
)
torch.testing.assert_close(out_triton, out_torch, rtol=5e-2, atol=5e-2)
@ -364,8 +366,8 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
w1_weight_scale,
w2_weight_scale,
w3_weight_scale_tensor,
mlp_style="mlp",
act_fn="relu2",
is_gated_mlp=False,
act_fn=ActivationType.Relu2,
)
# Reference: Torch quantized FP8 MoE (uses lists of tensors and scales)
@ -382,8 +384,8 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
w1_weight_scale=w1_weight_scale_list,
w2_weight_scale=w2_weight_scale_list,
w3_weight_scale=w3_weight_scale_list,
mlp_style="mlp",
act_fn="relu2",
is_gated_mlp=False,
act_fn=ActivationType.Relu2,
)
torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2)

View File

@ -566,6 +566,7 @@ class TestMoEAlltoAll:
(4, [32, 32, 32, 32], 4),
(4, [1, 1, 1, 1], 2),
(8, [640, 640, 640, 640, 640, 640, 640, 640], 4),
(4, [32, 0, 16, 0], 2),
],
indirect=["mpi_pool_executor"])
def test_combine(self, mpi_pool_executor, all_num_tokens, top_k):

View File

@ -19,49 +19,23 @@ example_images = [
str(test_data_root / "61.jpg"),
]
@pytest.fixture(scope="function")
def multimodal_model_config():
"""Get multimodal model configuration similar to integration tests"""
# You can extend this to support multiple models or get from environment
model_configs = {
'llava-v1.6-mistral-7b-hf': {
'model_name':
'llava-v1.6-mistral-7b-hf',
'hf_model_dir':
'llava-hf/llava-v1.6-mistral-7b-hf',
'model_dir':
llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf",
}
}
return model_configs['llava-v1.6-mistral-7b-hf']
_LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf"
_QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct"
# TODO: Add multi-image in single chat test
@pytest.mark.parametrize("model_key", [
"llava-v1.6-mistral-7b-hf",
])
@pytest.mark.parametrize("model_dir", [_LLAVA_DIR, _QWEN_2_5_VL_DIR])
@pytest.mark.parametrize("pd_disagg", [False, True])
def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
def test_single_image_chat(model_dir, pd_disagg):
"""Test processing single image using encoder (pass mm_embeddings) + LLM API.
This test verifies that encoder (pass mm_embeddings) + LLM API produces identical
results to standard llm generation (pass raw image) by comparing outputs.
"""
# Get model configuration
if model_key != "llava-v1.6-mistral-7b-hf":
#TODO: add more model tests progressively here
pytest.skip(
f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now"
)
# Extract model information from config
encoder_model_dir = multimodal_model_config['model_dir']
# Test configuration
max_tokens = 64
free_gpu_memory_fraction = 0.6 if not pd_disagg else 0.2
free_gpu_memory_fraction = 0.2
max_batch_size = 1
# Test data - OpenAI chat completion format
@ -76,15 +50,14 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
)
# Process multimodal data using encoder (pass mm_embeddings)
encoder = MultimodalEncoder(model=encoder_model_dir,
max_batch_size=max_batch_size)
encoder = MultimodalEncoder(model=model_dir, max_batch_size=max_batch_size)
cache_transceiver_cfg = CacheTransceiverConfig(
backend="DEFAULT") if pd_disagg else None
disable_overlap_scheduler = pd_disagg
llm = LLM(model=encoder_model_dir,
llm = LLM(model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
trust_remote_code=True,
@ -93,7 +66,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
llm_decode = None
if pd_disagg:
llm_decode = LLM(model=encoder_model_dir,
llm_decode = LLM(model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
trust_remote_code=True,
@ -141,6 +114,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None"
ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only"
outputs = llm.generate(inputs,
sampling_params=sampling_params,
disaggregated_params=ep_disaggregated_params)
@ -151,10 +125,10 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
pd_disaggregated_params = outputs[0].disaggregated_params
pd_disaggregated_params.request_type = "generation_only"
sampling_params = SamplingParams(max_tokens=max_tokens)
inputs[0][
'multi_modal_data'] = None # remove multimodal data from input as decoder worker doesn't need it
inputs[0]['prompt_token_ids'] = outputs[
0].prompt_token_ids # use prompt token ids from encoder output
# remove multimodal data from input as decoder worker doesn't need it
inputs[0]['multi_modal_data'] = None
# use prompt token ids from encoder output
inputs[0]['prompt_token_ids'] = outputs[0].prompt_token_ids
outputs = llm_decode.generate(
inputs,
@ -199,24 +173,23 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
f"Log probabilities don't match for output {i}, generation {j}"
@pytest.mark.parametrize("model_key", [
"llava-v1.6-mistral-7b-hf",
])
def test_multi_request_batch_chat(model_key, multimodal_model_config):
@pytest.mark.parametrize(
"model_dir, encoder_max_batch_size",
[
(_LLAVA_DIR, 3),
# Qwen2.5-VL's vision encoder seems to produce different embeddings depending on this value;
# the test only passes when it is set to 1.
(_QWEN_2_5_VL_DIR, 1),
],
)
def test_multi_request_batch_chat(model_dir, encoder_max_batch_size):
"""Test batching multiple multimodal requests and verify encoder path matches raw path.
This mirrors test_single_image_chat but with a batch of size 3.
"""
if model_key != "llava-v1.6-mistral-7b-hf":
pytest.skip(
f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now"
)
encoder_model_dir = multimodal_model_config['model_dir']
max_tokens = 64
free_gpu_memory_fraction = 0.6
max_batch_size = 3
prompts = [
"Describe the natural environment in the image.",
@ -232,10 +205,10 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
free_gpu_memory_fraction=free_gpu_memory_fraction,
)
encoder = MultimodalEncoder(model=encoder_model_dir,
max_batch_size=max_batch_size)
encoder = MultimodalEncoder(model=model_dir,
max_batch_size=encoder_max_batch_size)
llm = LLM(
model=encoder_model_dir,
model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
max_batch_size=1, # fix batch size to reduce non-determinism in tests
@ -305,8 +278,7 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
"Describe the weather in the image.",
], 2),
])
def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates,
multimodal_model_config):
def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates):
"""Test mm_keys in KV cache events with cache reuse scenarios.
This test verifies:
@ -316,7 +288,7 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates,
- Same media + same prompts: full reuse (0 duplicate offsets)
- Same media + different prompts: partial reuse (prefix blocks reused)
"""
encoder_model_dir = multimodal_model_config['model_dir']
encoder_model_dir = _LLAVA_DIR
max_tokens = 16
free_gpu_memory_fraction = 0.6

View File

@ -21,11 +21,18 @@ def model(request):
return request.param
@pytest.fixture(scope="module",
params=[0, 2],
ids=["disable_processpool", "enable_processpool"])
def num_postprocess_workers(request):
return request.param
@pytest.fixture(scope="module")
def server(model: str):
def server(model: str, num_postprocess_workers: int):
model_path = get_model_path(model)
args = []
args = ["--num_postprocess_workers", f"{num_postprocess_workers}"]
if model.startswith("Qwen3"):
args.extend(["--reasoning_parser", "qwen3"])
elif model.startswith("DeepSeek-R1"):

View File

@ -1,6 +1,3 @@
import os
import tempfile
import openai
import pytest
import yaml
@ -22,34 +19,28 @@ def backend(request):
@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file():
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {
"enable_chunked_prefill": False,
"gather_generation_logits": True,
"kv_cache_config": {
"enable_block_reuse": False,
}
def temp_extra_llm_api_options_file(tmp_path_factory):
extra_llm_api_options_dict = {
"enable_chunked_prefill": False,
"gather_generation_logits": True,
"kv_cache_config": {
"enable_block_reuse": False,
}
}
with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
yield temp_file_path
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
temp_file_path = tmp_path_factory.mktemp(
"config") / "extra_llm_api_options.yaml"
with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
return temp_file_path
@pytest.fixture(scope="module")
def server(model_name: str, backend: str, temp_extra_llm_api_options_file: str):
model_path = get_model_path(model_name)
args = [
"--backend", f"{backend}", "--extra_llm_api_options",
temp_extra_llm_api_options_file
]
args = ["--backend", f"{backend}"]
if backend == "trt":
args += ["--extra_llm_api_options", temp_extra_llm_api_options_file]
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server
@ -61,11 +52,7 @@ def async_client(server: RemoteOpenAIServer):
@pytest.mark.asyncio(loop_scope="module")
async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
model_name: str, backend: str):
# Skip if backend is PyTorch as it does not support topk logprobs when k > 1
if backend == "pytorch":
pytest.skip("Topk logprobs is not supported")
model_name: str):
messages = [{
"role": "system",
"content": "You are a helpful assistant."
@ -94,42 +81,3 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
assert logprob_content.bytes is not None
assert logprob_content.top_logprobs is not None
assert len(logprob_content.top_logprobs) == 5
@pytest.mark.asyncio(loop_scope="module")
async def test_chat_completion_top1_logprobs(async_client: openai.AsyncOpenAI,
model_name: str, backend: str):
# Skip if backend is TRT because it is tested in test_chat_completion_top5_logprobs
if backend == "trt":
pytest.skip(
"TRT top logprobs is already tested in test_chat_completion_top5_logprobs"
)
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "What is the capital of France?"
}]
# Test top_logprobs=1
chat_completion = await async_client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
logprobs=True,
top_logprobs=1,
extra_body={
"ignore_eos": True,
})
logprobs = chat_completion.choices[0].logprobs
assert logprobs is not None and logprobs.content is not None
assert len(logprobs.content) == 10
for logprob_content in logprobs.content:
assert logprob_content.token is not None
assert logprob_content.logprob is not None
assert logprob_content.bytes is not None
assert logprob_content.top_logprobs is not None
# Check that the top_logprobs contains only one entry
assert len(logprob_content.top_logprobs) == 1

View File

@ -213,21 +213,33 @@ def _test_sync_generation_tp_inner(llama_7b_tp2_path: Path):
result.outputs[0].token_ids) == ", neural network,"
try:
stats = await executor.aget_stats()
stats = json.loads(stats)
assert stats["iter"] == 0
assert stats["cpuMemUsage"] > 0
assert stats["gpuMemUsage"] > 0
assert stats["inflightBatchingStats"]["numCtxTokens"] == 3
assert stats["inflightBatchingStats"]["numGenRequests"] == 0
assert stats["kvCacheStats"]["usedNumBlocks"] == 1
stats_result = executor.aget_stats(timeout=2)
# aget_stats now returns an IterationResult; iterate it to get the stats
async for stats_str in stats_result:
stats = json.loads(stats_str) if isinstance(stats_str,
str) else stats_str
assert stats["iter"] >= 0
assert stats["cpuMemUsage"] > 0
assert stats["gpuMemUsage"] > 0
break # Just check first result
except AsyncQueue.EventLoopShutdownError:
pass
asyncio.run(async_stats_task())
stats = executor.get_stats()
assert json.loads(stats)["iter"] == 1
# Poll for stats since RPC calls return immediately
import time
stats_list = []
for _ in range(10):
stats_list = executor.get_stats(timeout=0.5)
if stats_list:
break
time.sleep(0.1)
assert len(stats_list) > 0
stats = json.loads(stats_list[0]) if isinstance(stats_list[0],
str) else stats_list[0]
assert stats["iter"] == 1
executor.shutdown()

View File

@ -4,6 +4,7 @@ import gc
import json
import os
import sys
import time
# Required for test_generate_with_seed to pass.
# See the discussion in https://github.com/NVIDIA/TensorRT-LLM/pull/4264#issuecomment-2943269891
@ -2193,6 +2194,7 @@ def llm_get_stats_test_harness(tp_size: int = 1,
sampling_params=sampling_params):
print(output)
time.sleep(2)
results = llm.get_stats(2)
validate_stats(results=results,
@ -2203,7 +2205,7 @@ def llm_get_stats_test_harness(tp_size: int = 1,
enable_chunked_prefill=enable_chunked_prefill,
enable_iter_req_stats=enable_iter_req_stats)
assert not llm.get_stats(2)
assert not llm.get_stats(0.5)
# test that IterationResult()._done is properly set
_ = llm.generate(prompts, sampling_params=sampling_params)
@ -2340,8 +2342,9 @@ def llm_get_stats_async_test_harness(tp_size: int = 1,
async def task1(repetition_index: int):
results = []
await asyncio.sleep(
3) # ensure there's stats to collect for the assertion
async for stats in llm.get_stats_async(timeout=2):
4) # ensure there are stats to collect for the assertion
async for stats in llm.get_stats_async(
10): # it will return immediately
results.append(stats)
assert results

View File

@ -487,11 +487,12 @@ def test_llm_get_kv_cache_events_tp2():
# created + stored events
assert events and len(events) >= 2
for event in events:
print(f"event: {event}")
if event:
if event[0]["event_id"] == 0:
assert event[0]["data"]["type"] == "created"
elif event[0]["event_id"] == 1:
assert event[0]["data"]["type"] == "stored"
if event["event_id"] == 0:
assert event["data"]["type"] == "created"
elif event["event_id"] == 1:
assert event["data"]["type"] == "stored"
@pytest.fixture(scope="module")

View File

@ -1,3 +1,4 @@
import json
import random
import time
from contextlib import contextmanager, nullcontext
@ -976,6 +977,62 @@ async def test_llm_rpc_streaming():
print(f"get result: {outputs}")
@skip_ray
def test_llm_rpc_get_stats():
"""Test that get_stats works with RPC orchestrator."""
with LLM(model=llama_model_path,
kv_cache_config=global_kvcache_config,
enable_iter_perf_stats=True,
orchestrator_type="rpc") as llm:
assert isinstance(llm._executor, GenerationExecutorRpcProxy)
# Generate some output to produce stats
for output in llm.generate(
prompts, sampling_params=SamplingParams(max_tokens=5)):
print(output)
stats = llm.get_stats(timeout=5)
assert len(stats) > 0, "Should have at least one stats entry"
# Stats should be JSON strings that can be parsed
parsed = json.loads(stats[0]) if isinstance(stats[0], str) else stats[0]
assert "iter" in parsed, "Stats should contain 'iter' field"
assert "cpuMemUsage" in parsed, "Stats should contain 'cpuMemUsage' field"
@skip_ray
@pytest.mark.asyncio
async def test_llm_rpc_get_stats_async():
"""Test that get_stats_async works with RPC orchestrator."""
with LLM(model=llama_model_path,
kv_cache_config=global_kvcache_config,
enable_iter_perf_stats=True,
orchestrator_type="rpc") as llm:
assert isinstance(llm._executor, GenerationExecutorRpcProxy)
# Generate some output to produce stats
async for output in llm.generate_async(
prompts[0], sampling_params=SamplingParams(max_tokens=5)):
print(output)
# Get stats via async API
stats_result = llm.get_stats_async(timeout=2)
# Should be able to iterate over results
stats_count = 0
async for stat in stats_result:
parsed = json.loads(stat) if isinstance(stat, str) else stat
assert "iter" in parsed, "Stats should contain 'iter' field"
stats_count += 1
if stats_count >= 1:
break # Just verify we can get at least one
assert stats_count > 0, "Should have received at least one stat"
@pytest.mark.threadleak(enabled=False)
@pytest.mark.part0
@skip_ray