Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-11 21:43:24 +08:00)

commit 9b33ea751b
Merge branch 'main' into fix_spec_gate
@@ -362,88 +362,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
int thread_idx = ThreadingPolicy::offset();
int local_token_idx = ThreadingPolicy::token_idx();

if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
}

// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;

uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);

if (already_copied & (1ULL << target_rank))
// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}

uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);

if (already_copied & (1ULL << target_rank))
{
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
}
continue;
}

// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);

ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
continue;
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();

// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);

ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();

// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
#pragma unroll
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
}

// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;

vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank,
payload_idx, ptrs, topk_target_ranks, topk_send_indices);
}

ThreadingPolicy::sync();
}

// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;

vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx,
ptrs, topk_target_ranks, topk_send_indices);
}

ThreadingPolicy::sync();

bool is_first_warp = threadIdx.x / warpSize == 0;
if (is_first_warp)
{
@@ -452,8 +462,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
bool is_last_token = false;
if (lane_id == 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
if (local_num_tokens != 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
}
else
{
is_last_token = true;
}
}
is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
@@ -523,7 +540,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);

// Prepare kernel pointers struct
@@ -568,6 +585,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
if (params.one_block_per_token)
{
int grid_size = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@@ -577,6 +599,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
else
{
int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@@ -897,9 +924,19 @@ __global__ void moeA2ACombineKernel(
int local_token_idx = ThreadingPolicy::token_idx();
int const size_per_token = elements_per_token * sizeof(T);

if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;
}

#if !DISABLE_SYNC_FOR_PROFILING
@@ -951,6 +988,9 @@ __global__ void moeA2ACombineKernel(
__syncthreads();
#endif

if (local_num_tokens == 0)
return;

// Get output location for this token (using src_data_ptrs[0] as output)
T* token_output = static_cast<T*>(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token;

@@ -1003,7 +1043,7 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.elements_per_token > 0);

// Configure kernel launch
@@ -1011,6 +1051,15 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
int const kWarpsPerBlock = kBlockSize / 32; // warpSize
int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
int grid_size_block = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size_warp == 0)
{
grid_size_warp = 1;
}
if (grid_size_block == 0)
{
grid_size_block = 1;
}

// Prepare kernel pointers struct for combine
CombineKernelPointers kernel_ptrs = {}; // Zero-initialize

@@ -186,7 +186,6 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
MoeA2ADataOffsets const& offsets = *reinterpret_cast<MoeA2ADataOffsets const*>(metainfo.data_ptr<int64_t>());

int64_t localNumTokens = tokenSelectedExperts.size(0);
TORCH_CHECK(localNumTokens > 0, "localNumTokens must be positive");
TORCH_CHECK(runtimeMaxTokensPerRank > 0, "runtimeMaxTokensPerRank must be positive");
TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]");
@@ -1,29 +1,29 @@
from typing import Callable, List, Optional
from typing import Callable, List

import torch
import torch.nn.functional as F

from tensorrt_llm._torch.utils import ActivationType

def _resolve_activation(name: Optional[str]) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Returns an elementwise activation callable matching the given name.
Supported: "silu", "relu2".
Defaults to SiLU when name is None or empty.
"""
if not name:
name = "silu"
key = name.lower()

if key == "silu":
return F.silu
elif key == "relu2":
def _resolve_torch_fn(act_fn: ActivationType) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Returns an elementwise activation callable matching the given activation function.
Supported: ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2
"""
assert act_fn in [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2], (
f"Unsupported activation '{ActivationType(act_fn).name}'. Use 'silu', 'swiglu' or 'relu2'."
)
torch_fn = None
if act_fn == ActivationType.Silu or act_fn == ActivationType.Swiglu:
torch_fn = F.silu
elif act_fn == ActivationType.Relu2:

def relu2(x: torch.Tensor) -> torch.Tensor:
return torch.square(F.relu(x))

return relu2
else:
raise ValueError(f"Unsupported activation '{name}'. Use one of: silu, relu2.")
torch_fn = relu2
return torch_fn


def _template_moe(
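For illustration, a minimal usage sketch of the new enum-based resolver (it assumes ActivationType is importable exactly as in the diff above; the input tensor is arbitrary):

import torch
import torch.nn.functional as F

from tensorrt_llm._torch.utils import ActivationType

x = torch.randn(4, 8)
silu_fn = _resolve_torch_fn(ActivationType.Silu)    # ActivationType.Swiglu resolves to the same callable
relu2_fn = _resolve_torch_fn(ActivationType.Relu2)

assert torch.allclose(silu_fn(x), F.silu(x))
assert torch.allclose(relu2_fn(x), torch.square(F.relu(x)))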
@ -94,8 +94,8 @@ def torch_moe(
|
||||
w1_weight: List[torch.Tensor],
|
||||
w2_weight: List[torch.Tensor],
|
||||
w3_weight: List[torch.Tensor],
|
||||
mlp_style: str = "gated_mlp",
|
||||
act_fn: str = "silu",
|
||||
is_gated_mlp: bool = True,
|
||||
act_fn: int = int(ActivationType.Silu),
|
||||
apply_routing_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -117,8 +117,8 @@ def torch_moe(
|
||||
- Llama4 MoE: sigmoid activated weights
|
||||
w1_weight:
|
||||
For per-expert lists:
|
||||
• mlp_style=="gated_mlp": List of W1 with shape (I, H) — "gate" projection.
|
||||
• mlp_style=="mlp": List of W_up with shape (I, H) — up projection.
|
||||
• is_gated_mlp==True: List of W1 with shape (I, H) — "gate" projection.
|
||||
• is_gated_mlp==False: List of W_up with shape (I, H) — up projection.
|
||||
For stacked tensors (Llama4):
|
||||
• Single-element list containing stacked w3_w1 tensor with shape (E, 2*I, H) in TRT-LLM format
|
||||
w2_weight:
|
||||
@ -129,17 +129,17 @@ def torch_moe(
|
||||
w3_weight:
|
||||
For per-expert lists with gated_mlp:
|
||||
• List of W3 with shape (I, H) — "up" (second) projection in gated MLP.
|
||||
For mlp style or stacked tensors:
|
||||
For is_gated_mlp==False or stacked tensors:
|
||||
• pass an empty list []; ignored.
|
||||
mlp_style:
|
||||
is_gated_mlp:
|
||||
Selects the per-expert MLP computation:
|
||||
• "gated_mlp" (default, Mixtral/DeepSeek/Llama4-style):
|
||||
• is_gated_mlp==True (default, Mixtral/DeepSeek/Llama4-style):
|
||||
y = W2( act(W1 x) * (W3 x) )
|
||||
• "mlp" (NemotronH-style 2-layer MLP):
|
||||
• is_gated_mlp==False (NemotronH-style 2-layer MLP):
|
||||
y = W_down( act(W_up x) )
|
||||
act_fn:
|
||||
Elementwise activation applied inside the expert MLP.
|
||||
Supported: "silu" (default), "relu2" (ReLU then square).
|
||||
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
|
||||
apply_routing_on_input:
|
||||
If True (Llama4 pattern): multiply routing weights with INPUT before MLP
|
||||
Result: act(input * routing_weight) - routing affects activation
|
||||
@ -148,55 +148,63 @@ def torch_moe(
|
||||
Returns:
|
||||
torch.Tensor: Output tensor with the same shape as the input x.
|
||||
"""
|
||||
act_fn = _resolve_activation(act_fn)
|
||||
style = mlp_style.lower()
|
||||
torch_act_fn = _resolve_torch_fn(act_fn)
|
||||
|
||||
# Detect if using stacked tensor format (Llama4) vs per-expert lists (standard)
|
||||
is_stacked = len(w1_weight) == 1 and w1_weight[0].ndim == 3
|
||||
|
||||
# Todo: either change torch_moe to use a single condition, or refactor this code.
|
||||
# it should be :
|
||||
# is_gated_mlp:
|
||||
# stacked:
|
||||
# ...
|
||||
# not stacked:
|
||||
# .
|
||||
# else:
|
||||
# assert (not stacked)
|
||||
# ...
|
||||
# .
|
||||
if is_stacked:
|
||||
# Llama4 stacked tensor format - only supports gated_mlp
|
||||
if style != "gated_mlp":
|
||||
raise ValueError("Stacked tensor format only supports 'gated_mlp' style")
|
||||
if not is_gated_mlp:
|
||||
raise ValueError("Stacked tensor format only supports gated MLP style")
|
||||
|
||||
w3_w1_stacked = w1_weight[0] # (E, 2*I, H)
|
||||
intermediate_size = w3_w1_stacked.shape[1] // 2
|
||||
w2_stacked = w2_weight[0] # (E, H, I)
|
||||
|
||||
def make_mlp(i: int):
|
||||
gate_up = w3_w1_stacked[i] # (2*I, H)
|
||||
intermediate_size = gate_up.shape[0] // 2
|
||||
def make_mlp(idx: int):
|
||||
gate_up = w3_w1_stacked[idx] # (2*I, H)
|
||||
W3 = gate_up[:intermediate_size, :] # (I, H)
|
||||
W1 = gate_up[intermediate_size:, :] # (I, H)
|
||||
W2 = w2_stacked[i] # (H, I)
|
||||
W2 = w2_stacked[idx] # (H, I)
|
||||
weight_dtype = W1.dtype
|
||||
return lambda inp: F.linear(
|
||||
act_fn(F.linear(inp.to(weight_dtype), W1)) * F.linear(inp.to(weight_dtype), W3),
|
||||
torch_act_fn(F.linear(inp.to(weight_dtype), W1))
|
||||
* F.linear(inp.to(weight_dtype), W3),
|
||||
W2,
|
||||
)
|
||||
|
||||
mlps = [make_mlp(i) for i in range(w3_w1_stacked.shape[0])]
|
||||
mlps = [make_mlp(idx) for idx in range(w3_w1_stacked.shape[0])]
|
||||
|
||||
elif style == "gated_mlp":
|
||||
elif is_gated_mlp:
|
||||
# Standard per-expert list format with gated MLP
|
||||
def make_mlp(i: int):
|
||||
W1 = w1_weight[i] # (I, H)
|
||||
W2 = w2_weight[i] # (H, I)
|
||||
W3 = w3_weight[i] # (I, H)
|
||||
return lambda inp: F.linear(act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2)
|
||||
|
||||
mlps = [make_mlp(i) for i in range(len(w1_weight))]
|
||||
|
||||
elif style == "mlp":
|
||||
# Standard per-expert list format with simple MLP
|
||||
def make_mlp(i: int):
|
||||
W_up = w1_weight[i] # (I, H)
|
||||
W_down = w2_weight[i] # (H, I)
|
||||
return lambda inp: F.linear(act_fn(F.linear(inp, W_up)), W_down)
|
||||
return lambda inp: F.linear(torch_act_fn(F.linear(inp, W1)) * F.linear(inp, W3), W2)
|
||||
|
||||
mlps = [make_mlp(i) for i in range(len(w1_weight))]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
|
||||
# Standard per-expert list format with simple MLP
|
||||
def make_mlp(i: int):
|
||||
W_up = w1_weight[i] # (I, H)
|
||||
W_down = w2_weight[i] # (H, I)
|
||||
return lambda inp: F.linear(torch_act_fn(F.linear(inp, W_up)), W_down)
|
||||
|
||||
mlps = [make_mlp(i) for i in range(len(w1_weight))]
|
||||
|
||||
return _template_moe(x, selected_experts, routing_weights, mlps, apply_routing_on_input)
|
||||
|
||||
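For reference, a hedged sketch of the two per-expert computations the docstring above describes (shapes follow the docstring: W1/W3/W_up are (I, H), W2/W_down are (H, I); the sizes and random weights are illustrative only):

import torch
import torch.nn.functional as F

H, I = 16, 32
x = torch.randn(H)
W1, W3, W_up = torch.randn(I, H), torch.randn(I, H), torch.randn(I, H)
W2, W_down = torch.randn(H, I), torch.randn(H, I)

# is_gated_mlp == True: y = W2( act(W1 x) * (W3 x) ), with act = SiLU
y_gated = F.linear(F.silu(F.linear(x, W1)) * F.linear(x, W3), W2)

# is_gated_mlp == False: y = W_down( act(W_up x) ), with act = ReLU^2 (NemotronH-style)
y_mlp = F.linear(torch.square(F.relu(F.linear(x, W_up))), W_down)

assert y_gated.shape == y_mlp.shape == x.shape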
@@ -209,8 +217,8 @@ def torch_moe_fake(
w1_weight: List[torch.Tensor],
w2_weight: List[torch.Tensor],
w3_weight: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
apply_routing_on_input: bool = False,
) -> torch.Tensor:
return torch.empty_like(x)
@@ -296,23 +304,20 @@ def torch_quant_fp8_moe(
w1_weight_scale: List[torch.Tensor],
w2_weight_scale: List[torch.Tensor],
w3_weight_scale: List[torch.Tensor],
mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp"
act_fn: str = "silu", # silu or relu2
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""
FP8 MoE op using quantized linear operations.

Computes a Mixture-of-Experts layer similar to the reference auto_deploy::torch_moe op, but uses the
quantized FP8 linear op for expert computations.
FP8 MoE op using quantized linear operations. Computes a Mixture-of-Experts layer similar to the reference
auto_deploy::torch_moe op, but uses the quantized FP8 linear op for expert computations.

Args:
x: Input tensor of shape (B, H) or (B, S, H).
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
routing_weights: Tensor of normalized routing weights.
w1_weight:
List of per-expert weight tensors:
• mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
• mlp_style=="mlp": W_up with shape (I, H) — up projection.
selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K)
containing expert indices.routing_weights: Tensor of normalized routing weights.
w1_weight: List of per-expert weight tensors:
• is_gated_mlp==True: W1 with shape (I, H) — "gate" projection.
• is_gated_mlp==False: W_up with shape (I, H) — up projection.
w2_weight:
List of per-expert weight tensors:
• gated_mlp: W2 with shape (H, I) — down projection.
@@ -323,21 +328,20 @@ def torch_quant_fp8_moe(
• mlp: pass an empty list []; ignored.
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops.
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops.
mlp_style:
is_gated_mlp:
Selects the per-expert MLP computation:
• "gated_mlp" (default, Mixtral/DeepSeek-style):
• is_gated_mlp==True (default, Mixtral/DeepSeek-style):
y = W2( act(W1 x) * (W3 x) )
• "mlp" (NemotronH-style 2-layer MLP):
• is_gated_mlp==False (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
"""

act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()
torch_act_fn = _resolve_torch_fn(act_fn)

if style == "gated_mlp":
if is_gated_mlp:

def make_fp8_mlp(i):
def mlp(inp):
@@ -355,7 +359,7 @@ def torch_quant_fp8_moe(
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
)
prod = act_fn(gate_out) * up_out
prod = torch_act_fn(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_fp8_linear(
prod,
w2_weight[i],
@@ -368,7 +372,7 @@ def torch_quant_fp8_moe(

mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]

elif style == "mlp":
else:

def make_fp8_mlp(i):
def mlp(inp):
@@ -380,7 +384,7 @@ def torch_quant_fp8_moe(
weight_scale=w1_weight_scale[i],
)
return torch.ops.auto_deploy.torch_quant_fp8_linear(
act_fn(up_out),
torch_act_fn(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
@@ -391,9 +395,6 @@ def torch_quant_fp8_moe(

mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]

else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

return _template_moe(x, selected_experts, routing_weights, mlps)
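The signatures above switch act_fn from a string to an integer default, int(ActivationType.Silu), which keeps the argument representable at the custom-op boundary while still comparing cleanly against the enum. A brief illustration (assuming ActivationType behaves like an IntEnum, which is how the comparisons in these ops use it):

from tensorrt_llm._torch.utils import ActivationType

act_fn = int(ActivationType.Relu2)        # what callers pass across the op boundary
assert isinstance(act_fn, int)
assert act_fn == ActivationType.Relu2     # the comparisons in the ops above rely on this
assert ActivationType(act_fn) is ActivationType.Relu2  # and the name can be recovered for error messages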
@@ -411,8 +412,8 @@ def torch_quant_fp8_moe_fake(
w1_weight_scale: List[torch.Tensor],
w2_weight_scale: List[torch.Tensor],
w3_weight_scale: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)

@@ -434,8 +435,8 @@ def torch_quant_nvfp4_moe(
w1_alpha: List[torch.Tensor],
w2_alpha: List[torch.Tensor],
w3_alpha: List[torch.Tensor],
mlp_style: str = "gated_mlp", # "gated_mlp" (default) or "mlp"
act_fn: str = "silu", # silu or relu2
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""
FP4 MoE op using quantized linear operations.
@@ -449,8 +450,8 @@ def torch_quant_nvfp4_moe(
routing_weights: Tensor of normalized routing weights.
w1_weight:
List of per-expert weight tensors:
• mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
• mlp_style=="mlp": W_up with shape (I, H) — up projection.
• is_gated_mlp==True: W1 with shape (I, H) — "gate" projection.
• is_gated_mlp==False: W_up with shape (I, H) — up projection.
w2_weight:
List of per-expert weight tensors:
• gated_mlp: W2 with shape (H, I) — down projection.
@@ -462,21 +463,20 @@ def torch_quant_nvfp4_moe(
w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors.
w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors.
w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization.
mlp_style:
is_gated_mlp:
Selects the per-expert MLP computation:
• "gated_mlp" (default, Mixtral/DeepSeek-style):
• is_gated_mlp==True (default, Mixtral/DeepSeek-style):
y = W2( act(W1 x) * (W3 x) )
• "mlp" (NemotronH-style 2-layer MLP):
• is_gated_mlp==False (NemotronH-style 2-layer MLP):
y = W_down( act(W_up x) )
act_fn:
Elementwise activation applied inside the expert MLP.
Supported: "silu" (default), "relu2" (ReLU then square).
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
"""

act_fn = _resolve_activation(act_fn)
style = mlp_style.lower()
torch_act_fn = _resolve_torch_fn(act_fn)

if style == "gated_mlp":
if is_gated_mlp:

def make_fp4_mlp(i):
def mlp(inp):
@@ -498,7 +498,7 @@ def torch_quant_nvfp4_moe(
weight_scale=w3_weight_scale[i],
alpha=w3_alpha[i],
)
prod = act_fn(gate_out) * up_out
prod = torch_act_fn(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
prod,
w2_weight[i],
@@ -512,7 +512,7 @@ def torch_quant_nvfp4_moe(

mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]

elif style == "mlp":
else:

def make_fp4_mlp(i):
def mlp(inp):
@@ -527,7 +527,7 @@ def torch_quant_nvfp4_moe(
alpha=w1_alpha[i],
)
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
act_fn(up_out),
torch_act_fn(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
@@ -539,9 +539,6 @@ def torch_quant_nvfp4_moe(

mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]

else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

return _template_moe(x, selected_experts, routing_weights, mlps)

@@ -562,8 +559,8 @@ def torch_quant_nvfp4_moe_fake(
w1_alpha: List[torch.Tensor],
w2_alpha: List[torch.Tensor],
w3_alpha: List[torch.Tensor],
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)
@@ -14,6 +14,8 @@ import torch.nn.functional as F
import triton
import triton.language as tl

from tensorrt_llm._torch.utils import ActivationType  # noqa: F401

from ...utils.logger import ad_logger

@@ -601,15 +603,13 @@ def triton_fused_moe(
routing_weights: torch.Tensor,
w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "mlp",
act_fn: str = "relu2",
is_gated_mlp: bool = False,
act_fn: int = int(ActivationType.Relu2),
) -> torch.Tensor:
"""Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""

mlp_style = mlp_style.lower()
act_fn = act_fn.lower()
assert mlp_style == "mlp", "Triton backend only supports mlp style."
assert act_fn == "relu2", "Triton backend only supports relu2 activation."
assert not is_gated_mlp, "Triton backend only supports non gated MLP style."
assert act_fn == ActivationType.Relu2, "Triton backend only supports relu2 activation."

x_shape = x.shape
x2d = x.view(-1, x_shape[-1])
@@ -661,12 +661,12 @@ def triton_quant_fp8_moe(
w1_weight_scale: torch.Tensor, # [E] stacked weight scales
w2_weight_scale: torch.Tensor, # [E] stacked weight scales
w3_weight_scale: torch.Tensor, # unused
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = False,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""Triton FP8 W8A8 MoE with 2-layer MLP and ReLU^2 activation."""
if mlp_style != "mlp":
raise NotImplementedError("triton_quant_fp8_moe currently supports mlp_style=='mlp' only")
if is_gated_mlp:
raise NotImplementedError("triton_quant_fp8_moe currently supports mlp only")

x_shape = x.shape
x2d = x.view(-1, x_shape[-1])
@@ -760,7 +760,7 @@ def triton_quant_fp8_moe(
w1_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = False,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)
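A brief usage sketch of the updated Triton entry point with the boolean/enum arguments. The shapes, dtypes, and the direct call to the Python implementation are assumptions for illustration; only the keyword names and the "non-gated MLP with ReLU^2 only" restriction come from the signature and asserts above:

import torch

from tensorrt_llm._torch.utils import ActivationType

E, H, I, TOP_K, B = 8, 256, 512, 2, 4  # illustrative sizes
x = torch.randn(B, H, device="cuda", dtype=torch.float16)
selected_experts = torch.randint(0, E, (B, TOP_K), device="cuda")
routing_weights = torch.softmax(torch.randn(B, TOP_K, device="cuda"), dim=-1)
w1_stacked = torch.randn(E, I, H, device="cuda", dtype=torch.float16)  # stacked up projections
w2_stacked = torch.randn(E, H, I, device="cuda", dtype=torch.float16)  # stacked down projections

# Only the non-gated MLP with ReLU^2 is accepted, matching the asserts above.
out = triton_fused_moe(
    x, selected_experts, routing_weights, w1_stacked, w2_stacked,
    is_gated_mlp=False, act_fn=int(ActivationType.Relu2),
)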
@@ -26,8 +26,8 @@ def trtllm_moe_fused(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
x_shape = x.shape
x = x.view(-1, x_shape[-1])
@@ -37,24 +37,24 @@ def trtllm_moe_fused(
quant_scales = []

# Determine activation type
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()

activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
if is_gated_mlp:
# Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
if act_fn == "silu":
if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
activation_type = ActivationType.Swiglu
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
)
else:
# For non-gated MLP with ReLU^2
if act_fn == "relu2":
if act_fn == ActivationType.Relu2:
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
)

return torch.ops.trtllm.fused_moe(
x,
@@ -77,8 +77,8 @@ def trtllm_moe_fused_fake(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)

@@ -93,21 +93,12 @@ def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn)


def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None:
supported_combinations = {
"gated_mlp": ["silu"],
"mlp": ["relu2"],
}
supported_act_fns = [
act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list
]
assert mlp_style in supported_combinations.keys(), (
f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}."
)
assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}."
assert act_fn in supported_combinations[mlp_style], (
f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. "
f"Supported combinations: {supported_combinations}"
def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None:
assert (is_gated_mlp and act_fn == ActivationType.Silu) or (
not is_gated_mlp and act_fn == ActivationType.Relu2
), (
f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'. "
f"Supported combinations: gated mlp with silu or mlp with relu2."
)
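The same is_gated_mlp/act_fn-to-ActivationType mapping recurs in several ops above, and the diff itself carries a "refactor this repeating code block" TODO. A hedged sketch of the shared helper this could collapse into — the helper name is hypothetical:

from tensorrt_llm._torch.utils import ActivationType

def _to_activation_type(is_gated_mlp: bool, act_fn: int) -> ActivationType:
    # Hypothetical consolidation of the repeated mapping blocks above.
    if is_gated_mlp:
        if act_fn in (ActivationType.Silu, ActivationType.Swiglu):
            return ActivationType.Swiglu
        raise ValueError(
            f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
        )
    if act_fn == ActivationType.Relu2:
        return ActivationType.Relu2
    raise ValueError(
        f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
    )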
@@ -128,8 +119,8 @@ def trtllm_quant_fp8_moe_fused(
gemm1_dequant: torch.Tensor, # [E]
gemm2_act_quant: torch.Tensor, # [E]
gemm2_dequant: torch.Tensor, # [E]
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""
TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP.
@@ -149,8 +140,8 @@ def trtllm_quant_fp8_moe_fused(
gemm1_dequant: Precomputed gemm1 dequant scale [E]
gemm2_act_quant: Precomputed gemm2 act quant scale [1]
gemm2_dequant: Precomputed gemm2 dequant scale [E]
mlp_style: "gated_mlp" or "mlp"
act_fn: "silu" for gated_mlp, "relu2" for mlp
is_gated_mlp: True for gated_mlp, False for mlp
act_fn: ActivationType.Silu for gated_mlp, ActivationType.Relu2 for mlp

Non-Gated MLP:
activation_fn(expert_inputs @ w1_expert.t()) @ w2_expert.t()
@@ -159,7 +150,7 @@ def trtllm_quant_fp8_moe_fused(
activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t()
"""

_validate_mlp_style_and_act_fn(mlp_style, act_fn)
_validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)

# Store original shape and flatten to 2D
x_shape = x.shape
@@ -190,28 +181,27 @@ def trtllm_quant_fp8_moe_fused(
# Todo: refactor this repeating code block

# Determine activation type
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()

activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
if is_gated_mlp:
# Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T)
# For gated MLP, concatenate w1 and w3 as [w3, w1]
w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous() # [E, 2*I, H]
fc1_expert_weights = w3_w1_stacked
if act_fn == "silu":
if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
activation_type = ActivationType.Swiglu
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
)
else:
# For non-gated MLP with ReLU^2
fc1_expert_weights = w1_weight.contiguous()
if act_fn == "relu2":
if act_fn == ActivationType.Relu2:
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
)

# Note! Outputting Float8_e4m3fn directly is not currently supported
output = torch.ops.trtllm.fused_moe(
@@ -248,10 +238,10 @@ def trtllm_quant_fp8_moe_fused_fake(
gemm1_dequant: torch.Tensor,
gemm2_act_quant: torch.Tensor,
gemm2_dequant: torch.Tensor,
mlp_style: str,
act_fn: str,
is_gated_mlp: bool,
act_fn: int,
) -> torch.Tensor:
_validate_mlp_style_and_act_fn(mlp_style, act_fn)
_validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
return torch.empty_like(x)
@@ -268,8 +258,8 @@ def trtllm_quant_nvfp4_moe_fused(
fc2_act_global_scale: torch.Tensor, # Global scale for FC2 activations
fc1_alpha: torch.Tensor, # Precomputed FC1 alpha (1.0 / (fc1_act_global_scale * fc1_weight_blockscale_fp8))
fc2_alpha: torch.Tensor, # Precomputed FC2 alpha (1.0 / (fc2_act_global_scale * fc2_weight_blockscale_fp8))
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
"""TensorRT-LLM Cutlass NVFP4 W8A8 MoE for gated and non-gated MLP.

@@ -285,22 +275,22 @@ def trtllm_quant_nvfp4_moe_fused(

"""
NVFP4_BLOCK_SIZE = 16
mlp_style = mlp_style.lower()
act_fn = act_fn.lower()

activation_type = ActivationType.Swiglu
if mlp_style == "gated_mlp":
if act_fn == "silu":
if is_gated_mlp:
if act_fn in [ActivationType.Silu, ActivationType.Swiglu]:
activation_type = ActivationType.Swiglu
else:
raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.")
elif mlp_style == "mlp":
if act_fn == "relu2":
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for gated_mlp. Use 'silu'."
)
else:
if act_fn == ActivationType.Relu2:
activation_type = ActivationType.Relu2
else:
raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
raise ValueError(
f"Unsupported activation '{ActivationType(act_fn).name}' for mlp. Use 'relu2'."
)

# quant_scales is described by this code:
# https://github.com/NVIDIA/TensorRT-LLM/blob/c9771ebb997683c08b26bbba796a7fc6aff09d93/cpp/tensorrt_llm/thop/moeOp.cpp#L1015
@@ -353,7 +343,7 @@ def trtllm_quant_nvfp4_moe_fused_fake(
fc2_act_global_scale: torch.Tensor,
fc1_alpha: torch.Tensor,
fc2_alpha: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)
@@ -13,6 +13,7 @@ from transformers.modeling_outputs import CausalLMOutput, MoeModelOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding

from tensorrt_llm._torch.utils import ActivationType
from tensorrt_llm.inputs.utils import HF_CHAT_TEMPLATE_EXCEPTIONS

from ..nemotron_flash import NemotronFlashForCausalLMFactory
@@ -182,6 +183,8 @@ class DeltaNet(nn.Module):
self.qk_activation = qk_activation
self.qk_norm = qk_norm

# can't use ActivationType enum here,
# because there is no Elu defined in cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
assert self.qk_activation in ["silu", "relu", "elu", "identity"]
assert self.qk_norm in ["l2", "sum"]

@@ -331,7 +334,7 @@ class NemotronFlashMamba2(nn.Module):
self.num_heads = self.d_inner // self.headdim
self.rmsnorm = rmsnorm
self.dt_limit = dt_limit
self.activation = "silu"
self.activation = ActivationType.Silu
self.chunk_size = chunk_size
self.layer_idx = layer_idx

@@ -33,6 +33,7 @@ from transformers.utils import ModelOutput

from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import gated_rms_norm_ref
from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
from tensorrt_llm._torch.utils import ActivationType


class MambaRMSNormGated(torch.nn.Module):
@@ -308,8 +309,8 @@ class NemotronHMOE(nn.Module):
w1_weight=[e.up_proj.weight for e in self.experts],
w2_weight=[e.down_proj.weight for e in self.experts],
w3_weight=[],
act_fn="relu2",
mlp_style="mlp",
act_fn=ActivationType.Relu2,
is_gated_mlp=False,
)

if has_latent_proj:
@@ -5,6 +5,8 @@ import torch
from pydantic import Field
from torch.fx import GraphModule, Node

from tensorrt_llm._torch.utils import ActivationType

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.cuda_mem_tracker import cuda_memory_tracker
@@ -70,20 +72,20 @@ def _insert_fused_moe_ops(gm: GraphModule, backend: Literal["auto", "trtllm", "t
except (AttributeError, KeyError):
pass

(is_gated_mlp, act_fn) = extract_op_args(node, "is_gated_mlp", "act_fn")

if is_stacked_moe:
# Stacked MoE (Llama4 pattern): only supports gated MLP
(act_fn_val,) = extract_op_args(node, "act_fn")
_process_llama4_stacked_moe_node(
gm, graph, node, replacement_op, act_fn_val, fused_key_counter
gm, graph, node, replacement_op, act_fn, fused_key_counter
)
else:
# Standard MoE with per-expert weight lists
(mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn")
assert backend != "triton" or mlp_style_val == "mlp", (
assert backend != "triton" or not is_gated_mlp, (
"Triton backend only supports mlp style."
)
_process_regular_moe_node(
gm, graph, node, replacement_op, mlp_style_val, act_fn_val, fused_key_counter
gm, graph, node, replacement_op, is_gated_mlp, act_fn, fused_key_counter
)

fused_key_counter += 1
@@ -102,8 +104,8 @@ def _process_regular_moe_node(
graph: torch.fx.Graph,
node: Node,
replacement_op,
mlp_style_val: str,
act_fn_val: str,
is_gated_mlp: bool,
act_fn: ActivationType,
fused_key_counter: int,
) -> None:
"""Process a single torch_moe node with per-expert weight lists.
@@ -122,7 +124,7 @@ def _process_regular_moe_node(
)

# Stack weights based on MLP style
if mlp_style_val == "gated_mlp":
if is_gated_mlp:
# For gated MLP, concatenate w3 and w1 then stack across experts
fused_w_up_experts = torch.stack(
[
@@ -135,12 +137,10 @@ def _process_regular_moe_node(
dim=0,
)
new_key_w_up = f"fused_moe_w3_w1_stacked_{fused_key_counter}"
elif mlp_style_val == "mlp":
else:
# For regular MLP, just stack w1
fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}"
else:
raise ValueError(f"Unknown mlp_style: {mlp_style_val}")

# Stack w2/down weights
fused_w_down_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0)
@@ -162,8 +162,8 @@ def _process_regular_moe_node(
replacement_op,
args=(hidden_states, selected_experts, routing_weights, w_up_arg, w_down_arg),
kwargs={
"mlp_style": mlp_style_val,
"act_fn": act_fn_val,
"is_gated_mlp": is_gated_mlp,
"act_fn": act_fn,
},
)

@@ -176,7 +176,7 @@ def _process_llama4_stacked_moe_node(
graph: torch.fx.Graph,
node: Node,
replacement_op,
act_fn_val: str,
act_fn: ActivationType,
fused_key_counter: int,
) -> None:
"""Process a single Llama4 MoE node with pre-stacked weight tensors.
@@ -301,7 +301,8 @@ def _process_llama4_stacked_moe_node(
replacement_op,
args=(scaled_input, selected_experts, ones_node, w_up_arg, w_down_arg),
kwargs={
"act_fn": act_fn_val,
"act_fn": act_fn,
"is_gated_mlp": True,
},
)

@@ -1240,7 +1241,7 @@ class MatchBmmMoePattern(BaseTransform):
w3_list_node,
),
kwargs={
"mlp_style": "gated_mlp",
"is_gated_mlp": True,
"apply_routing_on_input": apply_routing_on_input,
},
)
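To make the weight layout in the gated branch above concrete, here is a hedged sketch of how per-expert (I, H) tensors end up in the (E, 2*I, H) w3_w1 stack consumed by the fused op (the tensors are illustrative; the w3-before-w1 ordering follows the comments in this diff):

import torch

E, I, H = 4, 32, 16
w1_list = [torch.randn(I, H) for _ in range(E)]  # per-expert gate projections
w3_list = [torch.randn(I, H) for _ in range(E)]  # per-expert up projections

# Concatenate w3 on top of w1 for each expert, then stack across experts,
# mirroring the "concatenate w3 and w1 then stack" step in _process_regular_moe_node.
w3_w1_stacked = torch.stack(
    [torch.cat([w3, w1], dim=0) for w3, w1 in zip(w3_list, w1_list)], dim=0
)
assert w3_w1_stacked.shape == (E, 2 * I, H)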
@@ -6,6 +6,8 @@ from typing import Any, Callable, Dict, List, Tuple
import torch
from torch.fx import GraphModule

from tensorrt_llm._torch.utils import ActivationType

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
@@ -123,8 +125,8 @@ def trtllm_moe_fused_aux(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
device = torch.cuda.current_device()
with torch.cuda.stream(
@@ -137,7 +139,7 @@ def trtllm_moe_fused_aux(
routing_weights,
w3_w1_stacked_weight,
w2_stacked_weight,
mlp_style,
is_gated_mlp,
act_fn,
)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
@@ -152,8 +154,8 @@ def trtllm_moe_fused_aux_fake(
routing_weights: torch.Tensor,
w3_w1_stacked_weight: torch.Tensor,
w2_stacked_weight: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)

@@ -213,8 +215,8 @@ def trtllm_quant_fp8_moe_fused_aux(
gemm1_dequant: torch.Tensor, # [E]
gemm2_act_quant: torch.Tensor, # [E]
gemm2_dequant: torch.Tensor, # [E]
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
device = torch.cuda.current_device()
with torch.cuda.stream(
@@ -237,7 +239,7 @@ def trtllm_quant_fp8_moe_fused_aux(
gemm1_dequant,
gemm2_act_quant,
gemm2_dequant,
mlp_style,
is_gated_mlp,
act_fn,
)
torch.ops.auto_deploy.record_event(device, cuda_stream_manager.AUX_STREAM_NAME)
@@ -262,8 +264,8 @@ def trtllm_quant_fp8_moe_fused_aux_fake(
gemm1_dequant: torch.Tensor,
gemm2_act_quant: torch.Tensor,
gemm2_dequant: torch.Tensor,
mlp_style: str = "gated_mlp",
act_fn: str = "silu",
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
) -> torch.Tensor:
return torch.empty_like(x)
@@ -5,6 +5,8 @@ import torch
import torch.nn as nn
from torch.fx import GraphModule, Node

from tensorrt_llm._torch.utils import ActivationType

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import is_op
@@ -87,15 +89,15 @@ def _quantize_moe_node(
s1, s2, s3 = collect_scales(idx)
args.extend([s1, s2, s3])

# Extract mlp_style and act_fn from the original node
# Extract is_gated_mlp and act_fn from the original node
# These can be in args[6:] or in kwargs
mlp_style = "gated_mlp"  # default
act_fn = "silu"  # default
is_gated_mlp = True  # default
act_fn = ActivationType.Silu  # default

if len(node.args) > 6:
mlp_style = node.args[6]
elif "mlp_style" in node.kwargs:
mlp_style = node.kwargs["mlp_style"]
is_gated_mlp = node.args[6]
elif "is_gated_mlp" in node.kwargs:
is_gated_mlp = node.kwargs["is_gated_mlp"]

if len(node.args) > 7:
act_fn = node.args[7]
@@ -104,7 +106,7 @@ def _quantize_moe_node(

# Prepare kwargs for the quantized op
kwargs = {
"mlp_style": mlp_style,
"is_gated_mlp": is_gated_mlp,
"act_fn": act_fn,
}
@@ -527,6 +527,8 @@ class LlavaNextModel(PreTrainedModel):
return
if not DISAGG:
self.mm_encoder = LlavaNextVisionModel(model_config)
else:
self.mm_encoder = None

llm_model_config = copy.deepcopy(model_config)
llm_model_config.pretrained_config = model_config.pretrained_config.text_config
@@ -545,7 +547,8 @@ class LlavaNextModel(PreTrainedModel):
if isinstance(weight_mapper, LlavaNextHfWeightMapper):
weights = weight_mapper.preprocess_weights(weights)

self.mm_encoder.load_weights(weights)
if self.mm_encoder is not None:
self.mm_encoder.load_weights(weights)

def filter_weights(weights: Dict):
transformed_weights = {}
@@ -32,7 +32,8 @@ from ...inputs import (BaseMultimodalDummyInputsBuilder,
BaseMultimodalInputProcessor, ExtraProcessedInputs,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
register_input_processor,
support_multimodal_disaggregated)
from ...logger import logger
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
@@ -865,6 +866,8 @@ class Qwen2VLModelBase(PreTrainedModel):
mm_encoder_config = copy.deepcopy(model_config)
self.mm_encoder = Qwen2VisionModelBase(
mm_encoder_config, kwargs.get('vision_model_class', None))
else:
self.mm_encoder = None

def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
config = model_config.pretrained_config
@@ -953,24 +956,21 @@ class Qwen2VLModelBase(PreTrainedModel):
"""
VLM forward logic with inflight batching support.
"""
num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
num_context_requests = attn_metadata.num_contexts

multimodal_params = kwargs.get("multimodal_params", [])
mm_embeds = []
mrope_config = {}
# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate the mm_multimodal_params from the text-only prompts.
mm_multimodal_params = [
multimodal_param for multimodal_param in multimodal_params
if multimodal_param.multimodal_data.get("image", {}).get(
"pixel_values") is not None or multimodal_param.multimodal_data.
get("video", {}).get("pixel_values_videos") is not None
]
# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate
# the entries that do have multimodal data from those that correspond to text-only prompts.
mm_multimodal_params = self._get_requests_with_mm_data(
multimodal_params)
if len(mm_multimodal_params) > 0:
if not _is_disagg():
mm_embeds = get_multimodal_embeddings(
encoder_forward_fn=self.mm_encoder.forward,
multimodal_params=mm_multimodal_params)
else:
elif not getattr(self, "support_mm_disagg", False):
raise NotImplementedError(
"Qwen2VLModel does not support disaggregated inference yet. Please unset "
f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
@@ -995,6 +995,21 @@ class Qwen2VLModelBase(PreTrainedModel):
logger.debug(f'output shape: {output_prob.shape}')
return output_prob

def _get_requests_with_mm_data(self, multimodal_params):
mm_multimodal_params = []
for multimodal_param in multimodal_params:
data = multimodal_param.multimodal_data
if (
# The first 2 conditions check whether there is input on which inference should be run.
data.get("image", {}).get("pixel_values") is not None or
data.get("video", {}).get("pixel_values_videos") is not None
# This condition corresponds to when the embeddings are already populated, as is e.g.
# the case in EPD disagg in the prefill worker.
or data.get("multimodal_embedding")):
mm_multimodal_params.append(multimodal_param)

return mm_multimodal_params
@register_vision_encoder(Qwen2VisionModelBase,
|
||||
vlm_base_model=Qwen2VisionTransformerPretrainedModel)
|
||||
@ -1032,11 +1047,89 @@ class Qwen2VLModel(Qwen2VLModelBase):
|
||||
self.llm.load_weights(weights, weight_mapper)
|
||||
|
||||
|
||||
class Qwen2_5VLInputProcessorBase(Qwen2VLInputProcessorBase):
|
||||
|
||||
def get_prompt_token_ids(
|
||||
self, inputs: TextPrompt,
|
||||
mm_handles: List[Dict[str,
|
||||
Any]]) -> Tuple[List[int], List[int], List[int]]:
|
||||
"""
|
||||
Build input token ids with multimodal placeholders expanded to the number of MM tokens.
|
||||
|
||||
Args:
|
||||
inputs: Text prompt input container. Must contain a non-empty prompt string.
|
||||
mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
|
||||
|
||||
Returns:
|
||||
Tuple[List[int], List[int], List[int]]:
|
||||
- expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token
|
||||
- mm_token_length: per-image MM token lengths
|
||||
- mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids
|
||||
"""
|
||||
# TODO: Move this function to the base input processor class when extending for more models
|
||||
text_prompt = inputs.get("prompt")
|
||||
if not text_prompt:
|
||||
raise ValueError("Text prompt is required but not provided")
|
||||
|
||||
if not isinstance(mm_handles, list):
|
||||
raise TypeError("mm_handles must be a list")
|
||||
|
||||
if len(mm_handles) != 1:
|
||||
# TODO: only support single multimodal item within a request for now
|
||||
raise NotImplementedError(
|
||||
"Only one mm_handle is supported for Qwen2.5 VL for now")
|
||||
hidden_size = mm_handles[0]['tensor_size'][1]
|
||||
assert hidden_size == self.config.text_config.hidden_size, "Multimodal embedding hidden size must match model hidden size"
|
||||
input_ids = self.tokenizer(text_prompt,
|
||||
return_tensors="pt").input_ids[0]
|
||||
|
||||
image_token_index = self.config.image_token_id
|
||||
|
||||
image_mask = input_ids == image_token_index
|
||||
image_positions = torch.where(image_mask)[0]
|
||||
num_images = len(image_positions)
|
||||
assert num_images == len(
|
||||
mm_handles), "Number of images must match number of mm_handles"
|
||||
total_mm_tokens = sum(mm_handle["tensor_size"][0]
|
||||
for mm_handle in mm_handles)
|
||||
final_length = len(input_ids) - num_images + total_mm_tokens
|
||||
# Create output tensor
|
||||
expanded_ids = torch.empty(final_length, dtype=input_ids.dtype)
|
||||
placeholder_id = self.tllm_multimodal_token_id
|
||||
|
||||
# Fill the expanded sequence
|
||||
write_pos = 0
|
||||
image_cnt = 0
|
||||
mm_token_length = []
|
||||
mm_token_offsets = []
|
||||
for read_pos in range(len(input_ids)):
|
||||
if input_ids[read_pos] == image_token_index:
|
||||
# Replace with placeholder id
|
||||
mm_token_num = mm_handles[image_cnt]["tensor_size"][0]
|
||||
expanded_ids[write_pos:write_pos + mm_token_num] = \
|
||||
placeholder_id
|
||||
mm_token_offsets.append(write_pos)
|
||||
mm_token_length.append(mm_token_num)
|
||||
write_pos += mm_token_num
|
||||
image_cnt += 1
|
||||
else:
|
||||
# Copy text token as-is
|
||||
expanded_ids[write_pos] = input_ids[read_pos]
|
||||
write_pos += 1
|
||||
|
||||
assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}"
|
||||
assert mm_token_length[-1] + mm_token_offsets[
|
||||
-1] <= final_length, f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less than or equal to final_length ({final_length})"
|
||||
return expanded_ids.to(
|
||||
torch.int32).tolist(), mm_token_length, mm_token_offsets
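# --- Illustrative sketch (not part of the patch) ---
# A minimal, self-contained version of the placeholder expansion implemented above, using
# plain Python lists and hypothetical token ids / placeholder id. It mirrors the loop above:
# each image token is replaced by `mm_token_num` copies of the placeholder, and the start
# offset plus length of each image's span are recorded.
def expand_mm_placeholders(input_ids, image_token_id, mm_token_nums, placeholder_id):
    expanded_ids, mm_token_length, mm_token_offsets = [], [], []
    image_cnt = 0
    for tok in input_ids:
        if tok == image_token_id:
            n = mm_token_nums[image_cnt]
            mm_token_offsets.append(len(expanded_ids))
            mm_token_length.append(n)
            expanded_ids.extend([placeholder_id] * n)
            image_cnt += 1
        else:
            expanded_ids.append(tok)
    return expanded_ids, mm_token_length, mm_token_offsets

# Example with made-up ids: one image token (151655) expanded to 4 MM tokens.
# expand_mm_placeholders([1, 151655, 2], 151655, [4], -100)
# -> ([1, -100, -100, -100, -100, 2], [4], [1])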
@support_multimodal_disaggregated
|
||||
@register_vision_encoder(Qwen2VisionModelBase,
|
||||
vlm_base_model=Qwen2_5_VisionModel)
|
||||
@register_auto_model("Qwen2_5_VLForConditionalGeneration")
|
||||
@register_input_processor(
|
||||
Qwen2VLInputProcessorBase,
|
||||
Qwen2_5VLInputProcessorBase,
|
||||
model_type="qwen2_5_vl",
|
||||
placeholder_metadata=MultimodalPlaceholderMetadata(
|
||||
placeholder_map={
|
||||
|
||||
@ -262,6 +262,8 @@ class PyResult:
|
||||
chunk_size=self._chunk_size) if return_generation_logits else None
|
||||
self._log_probs = LogProbStorage() if return_log_probs else None
|
||||
self._mm_embeddings = None
|
||||
self._mrope_position_ids = None
|
||||
self._mrope_position_deltas = None
|
||||
self._additional_context_outputs = {
|
||||
name: []
|
||||
for name in additional_outputs
|
||||
@ -293,6 +295,16 @@ class PyResult:
|
||||
self._mm_embeddings = SharedTensorContainer.from_tensor(
|
||||
mm_embeddings).dump_to_dict()
|
||||
|
||||
def set_mrope_position(
|
||||
self,
|
||||
mrope_position_ids: torch.Tensor,
|
||||
mrope_position_deltas: torch.Tensor,
|
||||
):
|
||||
self._mrope_position_ids = (SharedTensorContainer.from_tensor(
|
||||
mrope_position_ids).dump_to_dict())
|
||||
self._mrope_position_deltas = (SharedTensorContainer.from_tensor(
|
||||
mrope_position_deltas).dump_to_dict())
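# --- Illustrative sketch (not part of the patch) ---
# The mrope handles above follow the same pattern as mm_embeddings: the tensor is wrapped in
# a SharedTensorContainer and serialized to a plain dict handle so it can cross process
# boundaries, and the consumer rebuilds it with `from_dict`. The stand-in below is a
# hypothetical, simplified container used only to show the round trip; it is not the real
# SharedTensorContainer implementation.
import torch

class _ToyTensorContainer:
    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor) -> "_ToyTensorContainer":
        return cls(tensor)

    def dump_to_dict(self) -> dict:
        # Real code would export a process-shareable handle; here we simply copy the data.
        return {"data": self._tensor.cpu().tolist()}

    @classmethod
    def from_dict(cls, d: dict) -> "_ToyTensorContainer":
        return cls(torch.tensor(d["data"]))

# handle = _ToyTensorContainer.from_tensor(torch.arange(6).reshape(3, 2)).dump_to_dict()
# restored = _ToyTensorContainer.from_dict(handle)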
def transfer_remaining_device_logits(self):
|
||||
"""Finalize any remaining generation logits transfers (for chunked mode)"""
|
||||
if self._generation_logits:
|
||||
@ -352,6 +364,18 @@ class PyResult:
|
||||
def mm_embedding_handle(self) -> Dict[str, Any] | None:
|
||||
return self._mm_embeddings
|
||||
|
||||
@property
|
||||
def mrope_position_ids_handle(self) -> Dict[str, Any] | None:
|
||||
# NOTE: when populated, the returned `dict` contains the information necessary to rebuild
|
||||
# the `SharedTensorContainer` using the `from_dict` class method.
|
||||
return self._mrope_position_ids
|
||||
|
||||
@property
|
||||
def mrope_position_deltas_handle(self) -> Dict[str, Any] | None:
|
||||
# NOTE: when populated, the returned `dict` contains the information necessary to rebuild
|
||||
# the `SharedTensorContainer` using the `from_dict` class method.
|
||||
return self._mrope_position_deltas
|
||||
|
||||
@property
|
||||
def additional_context_outputs(self) -> Dict[str, torch.Tensor] | None:
|
||||
if self._additional_context_outputs is None:
|
||||
@ -382,7 +406,8 @@ class LlmResult:
|
||||
py_result_properties = frozenset(
|
||||
('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs',
|
||||
'mm_embedding_handle', 'additional_context_outputs',
|
||||
'additional_generation_outputs'))
|
||||
'additional_generation_outputs', 'mrope_position_ids_handle',
|
||||
'mrope_position_deltas_handle'))
|
||||
|
||||
def __init__(self,
|
||||
result: Union[bytes, tensorrt_llm.bindings.executor.Result],
|
||||
|
||||
@ -2213,13 +2213,14 @@ class PyTorchModelEngine(ModelEngine):
|
||||
mrope_position_deltas).expand(
|
||||
3, 1, 1)
|
||||
mrope_position_ids.append(gen_mrope_position_ids)
|
||||
multimodal_params.to_device(
|
||||
"multimodal_data",
|
||||
"cuda",
|
||||
pin_memory=True,
|
||||
target_keywords=[
|
||||
"mrope_config.mrope_position_deltas"
|
||||
])
|
||||
if mrope_position_deltas.device.type == "cpu":
|
||||
multimodal_params.to_device(
|
||||
"multimodal_data",
|
||||
"cuda",
|
||||
pin_memory=True,
|
||||
target_keywords=[
|
||||
"mrope_config.mrope_position_deltas"
|
||||
])
|
||||
multimodal_params_list.append(multimodal_params)
|
||||
|
||||
request.py_batch_idx = request.py_seq_slot
|
||||
@ -2448,8 +2449,9 @@ class PyTorchModelEngine(ModelEngine):
|
||||
# NOTE: self.use_mrope is enough to decide whether to use mrope_position_ids, but
# `_create_dummy_context_requests` from `kv_cache_creater` is an exception: multimodal_data cannot be added to the dummy_request,
# so position_ids is only replaced with mrope_position_ids for non-dummy requests on models that use mrope.
mrope_position_ids = torch.cat(mrope_position_ids,
|
||||
dim=-1).pin_memory()
|
||||
mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
|
||||
if mrope_position_ids.device.type == "cpu":
|
||||
mrope_position_ids = mrope_position_ids.pin_memory()
|
||||
self.mrope_position_ids_cuda[:, :, :total_num_tokens].copy_(
|
||||
mrope_position_ids[:, :, :total_num_tokens], non_blocking=True)
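# --- Illustrative sketch (not part of the patch) ---
# The change above pins the host tensor only when it is actually on CPU before an async
# host-to-device copy into a preallocated CUDA buffer. A minimal stand-alone version of the
# same pattern (buffer shapes and sizes are made up for illustration):
import torch

def copy_into_device_buffer(position_ids: torch.Tensor, device_buffer: torch.Tensor) -> None:
    n = position_ids.shape[-1]
    if position_ids.device.type == "cpu":
        # pin_memory() is what allows the subsequent copy to be truly non-blocking
        position_ids = position_ids.pin_memory()
    device_buffer[..., :n].copy_(position_ids, non_blocking=True)

# if torch.cuda.is_available():
#     buf = torch.empty(3, 1, 4096, dtype=torch.int64, device="cuda")
#     copy_into_device_buffer(torch.zeros(3, 1, 128, dtype=torch.int64), buf)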
final_position_ids = self.mrope_position_ids_cuda[:, :, :
|
||||
@ -3362,7 +3364,26 @@ class PyTorchModelEngine(ModelEngine):
|
||||
mm_embeddings = list(
|
||||
torch.split(mm_embeddings[0], multimodal_chunks, dim=0))
|
||||
|
||||
return {'mm_embeddings': mm_embeddings, 'logits': None}
|
||||
# Extract mrope position data from multimodal_params if available
|
||||
mrope_position_ids_list = []
|
||||
mrope_position_deltas_list = []
|
||||
for multimodal_param in multimodal_params:
|
||||
mrope_config = multimodal_param.multimodal_data.get(
|
||||
'mrope_config', {})
|
||||
mrope_position_ids = mrope_config.get('mrope_position_ids')
|
||||
mrope_position_deltas = mrope_config.get('mrope_position_deltas')
|
||||
if mrope_position_ids is not None:
|
||||
mrope_position_ids_list.append(mrope_position_ids)
|
||||
if mrope_position_deltas is not None:
|
||||
mrope_position_deltas_list.append(mrope_position_deltas)
|
||||
|
||||
result = {'mm_embeddings': mm_embeddings, 'logits': None}
|
||||
if mrope_position_ids_list:
|
||||
result['mrope_position_ids'] = mrope_position_ids_list
|
||||
if mrope_position_deltas_list:
|
||||
result['mrope_position_deltas'] = mrope_position_deltas_list
|
||||
|
||||
return result
|
||||
|
||||
def _init_userbuffers(self, hidden_size):
|
||||
if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1:
|
||||
|
||||
@ -21,7 +21,7 @@ from concurrent import futures
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from itertools import repeat
|
||||
from typing import Any, Callable, Generic, List, Optional, Type, TypeVar, cast
|
||||
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -199,6 +199,8 @@ class EarlyStopSampler(Sampler):
|
||||
@dataclass(kw_only=True)
|
||||
class MultimodalResult:
|
||||
mm_embeddings: List[torch.Tensor]
|
||||
# Can be used to include e.g. `mrope_position_ids`, etc.
|
||||
extra_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
def values(self):
|
||||
return vars(self).values()
|
||||
@ -262,7 +264,10 @@ class EarlyStopWithMMResult(Sampler):
|
||||
resource_manager: Optional[ResourceManager] = None,
|
||||
) -> SampleStateWithMMResult:
|
||||
# from model_outputs to MultimodalResult
|
||||
data = MultimodalResult(mm_embeddings=model_outputs["mm_embeddings"])
|
||||
data = MultimodalResult(
|
||||
mm_embeddings=model_outputs.pop("mm_embeddings"),
|
||||
extra_data={**model_outputs},
|
||||
)
|
||||
return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data)
|
||||
|
||||
@override
|
||||
@ -276,7 +281,12 @@ class EarlyStopWithMMResult(Sampler):
|
||||
scheduled_requests = state.scheduled_requests
|
||||
assert not scheduled_requests.generation_requests
|
||||
mm_embeddings = state.data.mm_embeddings
|
||||
for request, mm_embedding in zip(scheduled_requests.context_requests, mm_embeddings):
|
||||
extra_data = state.data.extra_data or {}
|
||||
mrope_position_ids = extra_data.get("mrope_position_ids", None)
|
||||
mrope_position_deltas = extra_data.get("mrope_position_deltas", None)
|
||||
for i, (request, mm_embedding) in enumerate(
|
||||
zip(scheduled_requests.context_requests, mm_embeddings)
|
||||
):
|
||||
request.state = LlmRequestState.GENERATION_COMPLETE
|
||||
# NOTE: This is a hack: set finish reason manually and set the beam 0
|
||||
request.set_finished_reason(FinishReason.LENGTH, 0)
|
||||
@ -287,6 +297,12 @@ class EarlyStopWithMMResult(Sampler):
|
||||
|
||||
request.py_result.append_mm_embeddings(mm_embedding)
|
||||
|
||||
# Store mrope data if available
|
||||
if mrope_position_ids is not None and mrope_position_deltas is not None:
|
||||
request.py_result.set_mrope_position(
|
||||
mrope_position_ids[i], mrope_position_deltas[i]
|
||||
)
|
||||
|
||||
@override
|
||||
def is_generation_model(self) -> bool:
|
||||
return False
|
||||
|
||||
@ -40,6 +40,8 @@ class DisaggregatedParams:
|
||||
multimodal_hashes: Optional[List[List[int]]] = (
|
||||
None # user provided mm hashes should be a list of 8 integers
|
||||
)
|
||||
mrope_position_ids_handle: Optional[Dict[str, Any]] = None
|
||||
mrope_position_deltas_handle: Optional[Dict[str, Any]] = None
|
||||
|
||||
def get_context_phase_params(self) -> tllme.ContextPhaseParams:
|
||||
return tllme.ContextPhaseParams(
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import atexit
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
@ -22,9 +23,11 @@ from .ipc import FusedIpcQueue, IpcQueue
|
||||
from .postproc_worker import PostprocWorker, PostprocWorkerConfig
|
||||
from .request import CancellingRequest, GenerationRequest
|
||||
from .result import GenerationResult, IterationResult
|
||||
from .utils import (ErrorResponse, IntraProcessQueue, WorkerCommIpcAddrs,
|
||||
create_mpi_comm_session, get_spawn_proxy_process_env,
|
||||
is_llm_response, print_alive_threads)
|
||||
from .rpc import RPCClient
|
||||
from .rpc.rpc_common import get_unique_ipc_addr
|
||||
from .utils import (ErrorResponse, WorkerCommIpcAddrs, create_mpi_comm_session,
|
||||
get_spawn_proxy_process_env, is_llm_response,
|
||||
print_alive_threads)
|
||||
from .worker import GenerationExecutorWorker, worker_main
|
||||
|
||||
__all__ = [
|
||||
@ -89,19 +92,27 @@ class GenerationExecutorProxy(GenerationExecutor):
|
||||
"llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
|
||||
"llm_args", None) is not None else None
|
||||
|
||||
# Generate RPC address and key for stats RPC
|
||||
self.rpc_addr = get_unique_ipc_addr()
|
||||
self.hmac_key = os.urandom(32)
|
||||
|
||||
worker_kwargs = dict(**worker_kwargs,
|
||||
worker_queues=self._setup_queues(),
|
||||
postproc_worker_config=postproc_worker_config,
|
||||
is_llm_executor=False)
|
||||
is_llm_executor=False,
|
||||
rpc_addr=self.rpc_addr,
|
||||
hmac_key=self.hmac_key)
|
||||
|
||||
if "log_level" not in worker_kwargs:
|
||||
worker_kwargs["log_level"] = logger.level
|
||||
|
||||
self.dispatch_result_thread: Optional[ManagedThread] = None
|
||||
self.dispatch_stats_thread: Optional[ManagedThread] = None
|
||||
self.dispatch_kv_cache_events_thread: Optional[ManagedThread] = None
|
||||
self.rpc_client: Optional[RPCClient] = None
|
||||
self._start_executor_workers(worker_kwargs)
|
||||
|
||||
# Create RPC client after workers are started (worker starts RPC server)
|
||||
self.rpc_client = RPCClient(self.rpc_addr, hmac_key=self.hmac_key)
|
||||
|
||||
# MPI registers its joiner using threading._register_atexit if possible.
|
||||
# These functions run before atexit.register, so to avoid deadlock,
|
||||
# we have to notify workers to exit before MPI starts to wait for them.
|
||||
@ -128,19 +139,11 @@ class GenerationExecutorProxy(GenerationExecutor):
|
||||
socket_type=zmq.PULL
|
||||
if self.enable_postprocess_parallel else zmq.PAIR,
|
||||
name="proxy_result_queue")
|
||||
self.mp_stats_queue = FusedIpcQueue(is_server=True,
|
||||
fuse_message=False,
|
||||
name="proxy_stats_queue")
|
||||
self.kv_cache_events_queue = FusedIpcQueue(
|
||||
is_server=True,
|
||||
fuse_message=False,
|
||||
name="proxy_kv_cache_events_queue")
|
||||
# Stats and KV events are now fetched via RPC, not IPC queues.
|
||||
return WorkerCommIpcAddrs(
|
||||
request_queue_addr=self.request_queue.address,
|
||||
worker_init_status_queue_addr=self.worker_init_status_queue.address,
|
||||
result_queue_addr=self.result_queue.address,
|
||||
stats_queue_addr=self.mp_stats_queue.address,
|
||||
kv_cache_events_queue_addr=self.kv_cache_events_queue.address,
|
||||
)
|
||||
|
||||
def abort_request(self, request_id: int) -> None:
|
||||
@ -204,71 +207,8 @@ class GenerationExecutorProxy(GenerationExecutor):
|
||||
|
||||
return True # success
|
||||
|
||||
def _iteration_result_task(self,
|
||||
queue: Union[FusedIpcQueue, IntraProcessQueue],
|
||||
result_singleton: IterationResult,
|
||||
urgent: bool = False) -> bool:
|
||||
if not urgent:
|
||||
time.sleep(0.2)
|
||||
|
||||
try:
|
||||
data = queue.get()
|
||||
except:
|
||||
logger.debug(
|
||||
"proxy.py: Error in _iteration_result_task: queue.get()")
|
||||
return False
|
||||
|
||||
if data is None:
|
||||
logger.debug("proxy.py: _iteration_result_task: data is None")
|
||||
return False # shutdown the thread
|
||||
|
||||
data = data if isinstance(data, list) else [data]
|
||||
queue = result_singleton.queue
|
||||
async_queues = []
|
||||
|
||||
while queue.full():
|
||||
queue.get()
|
||||
|
||||
try:
|
||||
for d in data:
|
||||
if d is None:
|
||||
logger.debug("proxy.py: _iteration_result_task: d is None")
|
||||
return False
|
||||
|
||||
if isinstance(queue, _SyncQueue):
|
||||
queue.put_nowait(d)
|
||||
async_queues.append(queue)
|
||||
else:
|
||||
queue.put(d)
|
||||
|
||||
if async_queues:
|
||||
_SyncQueue.notify_many(queue.loop, async_queues)
|
||||
|
||||
except AsyncQueue.EventLoopShutdownError:
|
||||
# This happens in the last loop while the generate workflow is
|
||||
# stopped, or when get_stats() or aget_stats() are not called by users
|
||||
# and therefore event loop can already be closed.
|
||||
logger.debug("proxy.py: EventLoopShutdownError")
|
||||
except Exception as e:
|
||||
logger.debug(f"proxy.py: Error in _iteration_result_task: {e}")
|
||||
raise e
|
||||
|
||||
return True # success
|
||||
|
||||
def dispatch_stats_task(self) -> bool:
|
||||
if not self._iter_stats_result:
|
||||
# This can happen temporarily because the WAR in tensorrt_llm/bench/benchmark/throughput.py
|
||||
# is not synchronized with self.dispatch_stats_thread.
|
||||
logger.debug(
|
||||
f"Skipping stats dispatch while self._iter_stats_result=None")
|
||||
return True # Intended behavior, not an error
|
||||
return self._iteration_result_task(self.mp_stats_queue,
|
||||
self._iter_stats_result)
|
||||
|
||||
def dispatch_kv_cache_events_task(self) -> bool:
|
||||
return self._iteration_result_task(self.kv_cache_events_queue,
|
||||
self._iter_kv_events_result,
|
||||
urgent=True)
|
||||
# NOTE: _iteration_result_task, dispatch_stats_task, and dispatch_kv_cache_events_task
|
||||
# have been removed as stats and kv_events are now fetched via RPC directly.
|
||||
|
||||
def _start_dispatch_threads(self):
|
||||
if self.dispatch_result_thread is None:
|
||||
@ -277,25 +217,9 @@ class GenerationExecutorProxy(GenerationExecutor):
|
||||
weakref.WeakMethod(self.dispatch_result_task),
|
||||
error_queue=self._error_queue,
|
||||
name="proxy_dispatch_result_thread")
|
||||
self.dispatch_stats_thread = ManagedThread(
|
||||
weakref.WeakMethod(self.dispatch_stats_task),
|
||||
error_queue=self._error_queue,
|
||||
name="proxy_dispatch_stats_thread")
|
||||
self.dispatch_kv_cache_events_thread = ManagedThread(
|
||||
weakref.WeakMethod(self.dispatch_kv_cache_events_task),
|
||||
error_queue=self._error_queue,
|
||||
name="proxy_dispatch_kv_cache_events_thread")
|
||||
|
||||
self.dispatch_result_thread.start()
|
||||
|
||||
# Only collect stats when submission
|
||||
# is via LLM API
|
||||
if self._iter_stats_result:
|
||||
self.dispatch_stats_thread.start()
|
||||
|
||||
if self._iter_kv_events_result:
|
||||
self.dispatch_kv_cache_events_thread.start()
|
||||
|
||||
self._handle_background_error()
|
||||
|
||||
def _start_executor_workers(self, worker_kwargs):
|
||||
@ -387,23 +311,18 @@ class GenerationExecutorProxy(GenerationExecutor):
|
||||
):
|
||||
self.dispatch_result_thread.stop()
|
||||
self.dispatch_result_thread.join()
|
||||
if self.dispatch_stats_thread is not None and self.dispatch_stats_thread.is_alive(
|
||||
):
|
||||
self.dispatch_stats_thread.stop()
|
||||
self.dispatch_stats_thread.join()
|
||||
if self.dispatch_kv_cache_events_thread is not None and self.dispatch_kv_cache_events_thread.is_alive(
|
||||
):
|
||||
self.dispatch_kv_cache_events_thread.stop()
|
||||
self.dispatch_kv_cache_events_thread.join()
|
||||
|
||||
# step3: finish all remaining work
|
||||
|
||||
# close the RPC client
|
||||
if self.rpc_client is not None:
|
||||
self.rpc_client.close()
|
||||
self.rpc_client = None
|
||||
|
||||
# close all the sockets
|
||||
self.request_queue.close()
|
||||
self.worker_init_status_queue.close()
|
||||
self.result_queue.close()
|
||||
self.mp_stats_queue.close()
|
||||
self.kv_cache_events_queue.close()
|
||||
|
||||
self.workers_started = False
|
||||
self.mpi_session.shutdown()
|
||||
@ -441,6 +360,104 @@ class GenerationExecutorProxy(GenerationExecutor):
|
||||
|
||||
return result
|
||||
|
||||
def get_stats(self, timeout: float) -> List[dict]:
|
||||
"""Get iteration statistics from the runtime via RPC.
|
||||
|
||||
Args:
|
||||
timeout (float): Max wait time in seconds for the RPC call.
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of runtime stats as dict.
|
||||
"""
|
||||
if self.rpc_client is None:
|
||||
logger.warning("RPC client not initialized, cannot get stats")
|
||||
return []
|
||||
|
||||
stats = self.rpc_client.fetch_stats_wait_async(timeout=timeout).remote()
|
||||
return [json.loads(s) if isinstance(s, str) else s for s in stats]
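# --- Illustrative sketch (not part of the patch) ---
# Hedged usage example for the RPC-backed stats path above, assuming `proxy` is an
# already-constructed GenerationExecutorProxy. `get_stats` returns plain dicts; the async
# variant returns an IterationResult that can be iterated with `async for`.
#
#   stats = proxy.get_stats(timeout=2.0)
#   for s in stats:
#       print(s)
#
#   async def dump_stats():
#       async for s in proxy.aget_stats(timeout=2.0):
#           print(s)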
def aget_stats(self, timeout: float) -> IterationResult:
|
||||
"""Get iteration statistics from the runtime via RPC (async).
|
||||
|
||||
Args:
|
||||
timeout (float): Max wait time in seconds for the RPC call.
|
||||
|
||||
Returns:
|
||||
IterationResult: An async iterable object containing runtime stats.
|
||||
"""
|
||||
# Initialize iteration result if needed
|
||||
self._maybe_initialize_iteration_results()
|
||||
|
||||
if self._iter_stats_result is None:
|
||||
logger.warning("Iteration statistics are not available yet.")
|
||||
from .executor import empty_async_iterable
|
||||
return empty_async_iterable()
|
||||
|
||||
# Fetch stats via RPC and populate the result
|
||||
try:
|
||||
stats = self.rpc_client.fetch_stats_wait_async(
|
||||
timeout=timeout).remote()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error fetching stats via RPC: {e}")
|
||||
stats = []
|
||||
|
||||
for stat in stats:
|
||||
self._iter_stats_result.queue.put(stat)
|
||||
|
||||
self._iter_stats_result.set_timeout(timeout)
|
||||
return self._iter_stats_result
|
||||
|
||||
def get_kv_events(self, timeout: float) -> List[dict]:
|
||||
"""Get iteration KV events from the runtime via RPC.
|
||||
|
||||
Args:
|
||||
timeout (float): Max wait time in seconds for the RPC call.
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of runtime events as dict.
|
||||
"""
|
||||
if self.rpc_client is None:
|
||||
logger.warning("RPC client not initialized, cannot get kv events")
|
||||
return []
|
||||
|
||||
try:
|
||||
events = self.rpc_client.fetch_kv_cache_events_wait_async(
|
||||
timeout=timeout).remote()
|
||||
return [json.loads(e) if isinstance(e, str) else e for e in events]
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching kv events via RPC: {e}")
|
||||
return []
|
||||
|
||||
def aget_kv_events(self, timeout: float) -> IterationResult:
|
||||
"""Get iteration KV events from the runtime via RPC (async).
|
||||
|
||||
Args:
|
||||
timeout (float): Max wait time in seconds for the RPC call.
|
||||
|
||||
Returns:
|
||||
IterationResult: An async iterable object containing runtime events.
|
||||
"""
|
||||
# Initialize iteration result if needed
|
||||
self._maybe_initialize_iteration_results()
|
||||
|
||||
if self._iter_kv_events_result is None:
|
||||
from .executor import empty_async_iterable
|
||||
return empty_async_iterable()
|
||||
|
||||
# Fetch kv events via RPC and populate the result
|
||||
try:
|
||||
events = self.rpc_client.fetch_kv_cache_events_wait_async(
|
||||
timeout=timeout).remote()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error fetching kv events via RPC: {e}")
|
||||
events = []
|
||||
|
||||
for event in events:
|
||||
self._iter_kv_events_result.queue.put(event)
|
||||
|
||||
self._iter_kv_events_result.set_timeout(timeout)
|
||||
return self._iter_kv_events_result
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
import weakref
|
||||
@ -415,12 +416,19 @@ class GenerationResultBase:
|
||||
self.cached_tokens = getattr(response_result, 'cached_tokens', 0)
|
||||
self.avg_decoded_tokens_per_iter = response_result.avg_decoded_tokens_per_iter
|
||||
if context_phase_params is not None:
|
||||
self.disaggregated_params = DisaggregatedParams(
|
||||
existing_disagg_params = self.disaggregated_params
|
||||
# Use `replace` to preserve things like `mrope_position_ids_handle` and
|
||||
# `mrope_position_deltas_handle`. However, explicitly set
|
||||
# `multimodal_embedding_handles=None` since they should no longer be needed.
|
||||
self.disaggregated_params = dataclasses.replace(
|
||||
existing_disagg_params or DisaggregatedParams(),
|
||||
request_type="context_only",
|
||||
first_gen_tokens=context_phase_params.first_gen_tokens,
|
||||
ctx_request_id=context_phase_params.req_id,
|
||||
opaque_state=context_phase_params.opaque_state,
|
||||
draft_tokens=context_phase_params.draft_tokens)
|
||||
draft_tokens=context_phase_params.draft_tokens,
|
||||
multimodal_embedding_handles=None,
|
||||
)
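# --- Illustrative sketch (not part of the patch) ---
# Why `dataclasses.replace` is used above: unspecified fields keep their existing values, so
# previously-set handles (e.g. the mrope handles) survive, while explicitly-overridden fields
# (like multimodal_embedding_handles=None) are reset. Toy example with a hypothetical
# dataclass:
import dataclasses
from typing import Optional

@dataclasses.dataclass
class _ToyParams:
    request_type: str = "context_and_generation"
    mrope_handle: Optional[dict] = None
    embedding_handles: Optional[list] = None

_p = _ToyParams(mrope_handle={"shape": [3, 1, 8]}, embedding_handles=["h0"])
_p2 = dataclasses.replace(_p, request_type="context_only", embedding_handles=None)
# _p2.mrope_handle is preserved; _p2.embedding_handles is now None.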
finish_reasons = response_result.finish_reasons
|
||||
# output_token_ids = (beams, tokens)
|
||||
@ -440,6 +448,8 @@ class GenerationResultBase:
|
||||
if hasattr(response_result, 'mm_embedding_handle'
|
||||
) and response_result.mm_embedding_handle is not None:
|
||||
self._mm_embedding_handle = response_result.mm_embedding_handle
|
||||
mrope_position_ids_handle = response_result.mrope_position_ids_handle
|
||||
mrope_position_deltas_handle = response_result.mrope_position_deltas_handle
|
||||
if self.disaggregated_params is not None:
|
||||
self.disaggregated_params.multimodal_embedding_handles = [
|
||||
response_result.mm_embedding_handle
|
||||
@ -451,6 +461,8 @@ class GenerationResultBase:
|
||||
response_result.mm_embedding_handle
|
||||
],
|
||||
multimodal_hashes=self._multimodal_hashes)
|
||||
self.disaggregated_params.mrope_position_ids_handle = mrope_position_ids_handle
|
||||
self.disaggregated_params.mrope_position_deltas_handle = mrope_position_deltas_handle
|
||||
|
||||
# Process background errors here ASAP during generation.
if self._background_error_handler and (
|
||||
@ -811,8 +823,12 @@ class GenerationResult(GenerationResultBase):
|
||||
|
||||
def _repr_fields(self):
|
||||
return [
|
||||
'request_id', 'prompt_token_ids', 'outputs', 'finished',
|
||||
"context_logits", "mm_embedding_handle"
|
||||
'request_id',
|
||||
'prompt_token_ids',
|
||||
'outputs',
|
||||
'finished',
|
||||
"context_logits",
|
||||
"mm_embedding_handle",
|
||||
]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
import json
|
||||
import threading
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from ..llmapi.mpi_session import MpiPoolSession, MpiSession
|
||||
from ..llmapi.utils import logger_debug
|
||||
from ..llmapi.utils import logger_debug, print_colored
|
||||
from ..logger import logger
|
||||
from .executor import GenerationExecutor
|
||||
from .postproc_worker import PostprocWorkerConfig
|
||||
from .result import IterationResult
|
||||
from .rpc_proxy_mixin import RpcExecutorMixin
|
||||
from .rpc_worker import RpcWorker
|
||||
from .utils import create_mpi_comm_session, get_spawn_proxy_process_env
|
||||
@ -69,20 +71,110 @@ class GenerationExecutorRpcProxy(RpcExecutorMixin, GenerationExecutor):
|
||||
**self.worker_kwargs)
|
||||
|
||||
def _setup_mainloop_with_tasks(self):
|
||||
"""Setup mainloop with all tasks needed for RpcProxy."""
|
||||
"""Setup mainloop with tasks needed for RpcProxy.
|
||||
|
||||
Note: Stats and kv_events are now fetched on-demand via direct RPC calls
|
||||
(get_stats, aget_stats, get_kv_events, aget_kv_events), not via streaming loops.
|
||||
"""
|
||||
tasks = [
|
||||
self._fetch_responses_loop_async,
|
||||
self._fetch_stats_loop_async,
|
||||
]
|
||||
# Only add kv_cache_events loop if it's enabled
|
||||
if self._iter_kv_events_result:
|
||||
tasks.append(self._fetch_kv_cache_events_loop_async)
|
||||
|
||||
# Call mixin's setup_mainloop with custom tasks
|
||||
self.setup_mainloop(tasks=tasks, thread_name="rpc_proxy_main_loop")
|
||||
|
||||
def fetch_stats_remote(self):
|
||||
return self.rpc_client.fetch_stats().remote()
|
||||
def get_stats(self, timeout: float) -> List[dict]:
|
||||
"""Get iteration statistics from the runtime via RPC.
|
||||
|
||||
Args:
|
||||
timeout (float): Max wait time in seconds for the RPC call.
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of runtime stats as dict.
|
||||
"""
|
||||
try:
|
||||
stats = self.rpc_client.fetch_stats_wait_async(
|
||||
timeout=timeout).remote()
|
||||
return [json.loads(s) if isinstance(s, str) else s for s in stats]
|
||||
except Exception as e:
|
||||
logger.debug(f"Error fetching stats via RPC: {e}")
|
||||
return []
|
||||
|
||||
def aget_stats(self, timeout: float) -> IterationResult:
|
||||
"""Get iteration statistics from the runtime via RPC (async).
|
||||
|
||||
Args:
|
||||
timeout (float): Max wait time in seconds for the RPC call.
|
||||
|
||||
Returns:
|
||||
IterationResult: An async iterable object containing runtime stats.
|
||||
"""
|
||||
self._maybe_initialize_iteration_results()
|
||||
|
||||
if self._iter_stats_result is None:
|
||||
print_colored("Iteration statistics are not available yet.\n",
|
||||
"yellow")
|
||||
from .executor import empty_async_iterable
|
||||
return empty_async_iterable()
|
||||
|
||||
# Fetch stats via RPC and populate the result
|
||||
try:
|
||||
stats = self.rpc_client.fetch_stats_wait_async(
|
||||
timeout=timeout).remote()
|
||||
except Exception:
|
||||
stats = []
|
||||
|
||||
for stat in stats:
|
||||
self._iter_stats_result.queue.put(stat)
|
||||
|
||||
self._iter_stats_result.set_timeout(timeout)
|
||||
return self._iter_stats_result
|
||||
|
||||
def get_kv_events(self, timeout: float) -> List[dict]:
|
||||
"""Get iteration KV events from the runtime via RPC.
|
||||
|
||||
Args:
|
||||
timeout (float): Max wait time in seconds for the RPC call.
|
||||
|
||||
Returns:
|
||||
List[dict]: A list of runtime events as dict.
|
||||
"""
|
||||
try:
|
||||
# Events are already serialized by the worker's fetch_kv_cache_events_wait_async()
|
||||
events = self.rpc_client.fetch_kv_cache_events_wait_async(
|
||||
timeout=timeout).remote()
|
||||
return [json.loads(e) if isinstance(e, str) else e for e in events]
|
||||
except Exception as e:
|
||||
logger.debug(f"Error fetching kv events via RPC: {e}")
|
||||
return []
|
||||
|
||||
def aget_kv_events(self, timeout: float) -> IterationResult:
|
||||
"""Get iteration KV events from the runtime via RPC (async).
|
||||
|
||||
Args:
|
||||
timeout (float): Max wait time in seconds for the RPC call.
|
||||
|
||||
Returns:
|
||||
IterationResult: An async iterable object containing runtime events.
|
||||
"""
|
||||
# Initialize iteration result if needed
|
||||
self._maybe_initialize_iteration_results()
|
||||
|
||||
if self._iter_kv_events_result is None:
|
||||
from .executor import empty_async_iterable
|
||||
return empty_async_iterable()
|
||||
|
||||
# Fetch kv events via RPC and populate the result
|
||||
try:
|
||||
events = self.rpc_client.fetch_kv_cache_events_wait_async(
|
||||
timeout=timeout).remote()
|
||||
except Exception:
|
||||
events = []
|
||||
|
||||
for event in events:
|
||||
self._iter_kv_events_result.queue.put(event)
|
||||
|
||||
self._iter_kv_events_result.set_timeout(timeout)
|
||||
return self._iter_kv_events_result
|
||||
|
||||
def setup_engine_remote(self):
|
||||
return self.rpc_client.setup_engine().remote(need_response=True)
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from .._utils import nvtx_range_debug
|
||||
from ..llmapi.tracer import global_tracer
|
||||
from ..llmapi.utils import AsyncQueue, _SyncQueue
|
||||
from ..llmapi.utils import _SyncQueue
|
||||
from ..logger import logger
|
||||
from .request import GenerationRequest
|
||||
from .result import GenerationResult
|
||||
@ -47,15 +46,16 @@ class RpcExecutorMixin:
|
||||
Args:
|
||||
tasks: List of async coroutine functions to run.
|
||||
thread_name: Name for the main loop thread
|
||||
|
||||
Note: Stats and kv_events are now fetched on-demand via direct RPC calls
|
||||
(get_stats, aget_stats, get_kv_events, aget_kv_events), so the default
|
||||
tasks only include the responses loop. Callers can still provide custom
|
||||
tasks including stats/kv_events loops if needed for streaming use cases.
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [
|
||||
self._fetch_responses_loop_async,
|
||||
self._fetch_stats_loop_async,
|
||||
]
|
||||
# Only add kv_cache_events loop if it's enabled
|
||||
if self._iter_kv_events_result:
|
||||
tasks.append(self._fetch_kv_cache_events_loop_async)
|
||||
|
||||
async def main_loop_task():
|
||||
await asyncio.gather(*[task() for task in tasks])
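# --- Illustrative sketch (not part of the patch) ---
# The main loop above simply gathers the selected fetch coroutines. A generic stand-alone
# version of that pattern (a background thread owning its own event loop) could look like
# this; names are made up for illustration:
import asyncio
import threading

def run_tasks_in_background(coroutine_fns, thread_name="rpc_proxy_main_loop"):
    async def _main():
        await asyncio.gather(*[fn() for fn in coroutine_fns])

    thread = threading.Thread(target=lambda: asyncio.run(_main()),
                              name=thread_name,
                              daemon=True)
    thread.start()
    return thread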
@ -136,22 +136,6 @@ class RpcExecutorMixin:
|
||||
if async_queues:
|
||||
_SyncQueue.notify_many(event_loop, async_queues)
|
||||
|
||||
def handle_stats(self, stats):
|
||||
"""Handle stats received from RPC worker and put them into the stats result queue.
|
||||
|
||||
Args:
|
||||
stats: Statistics data from the RPC worker (can be dict, str, or list)
|
||||
"""
|
||||
self._handle_iteration_data(stats, self._iter_stats_result, "stats")
|
||||
|
||||
def handle_kv_cache_events(self, events):
|
||||
"""Handle KV cache events received from RPC worker and put them into the events result queue.
|
||||
|
||||
Args:
|
||||
events: KV cache events data from the RPC worker (can be dict, str, or list)
|
||||
"""
|
||||
self._handle_iteration_data(events, self._iter_kv_events_result, "kv_cache_events")
|
||||
|
||||
async def _generic_fetch_loop_async(
|
||||
self, fetch_method_name: str, handler_method: Callable, method_name: str
|
||||
):
|
||||
@ -181,86 +165,6 @@ class RpcExecutorMixin:
|
||||
method_name="_fetch_responses_loop_async",
|
||||
)
|
||||
|
||||
async def _fetch_stats_loop_async(self):
|
||||
await self._generic_fetch_loop_async(
|
||||
fetch_method_name="fetch_stats_loop_async",
|
||||
handler_method=self.handle_stats,
|
||||
method_name="_fetch_stats_loop_async",
|
||||
)
|
||||
|
||||
async def _fetch_kv_cache_events_loop_async(self):
|
||||
await self._generic_fetch_loop_async(
|
||||
fetch_method_name="fetch_kv_cache_events_loop_async",
|
||||
handler_method=self.handle_kv_cache_events,
|
||||
method_name="_fetch_kv_cache_events_loop_async",
|
||||
)
|
||||
|
||||
def _handle_iteration_data(self, data, result_singleton, data_type: str):
|
||||
"""Generic method to handle iteration data received from RPC worker.
|
||||
|
||||
Args:
|
||||
data: Data from the RPC worker (can be dict, str, or list)
|
||||
result_singleton: The iteration result singleton to put data into
|
||||
data_type: Type of data for logging (e.g., "stats", "kv_cache_events")
|
||||
"""
|
||||
# Make sure we have initialized the iteration results
|
||||
self._maybe_initialize_iteration_results()
|
||||
|
||||
if not result_singleton:
|
||||
logger.debug(f"Skipping {data_type} handling while result_singleton=None")
|
||||
return
|
||||
|
||||
# Get the queue from the result singleton
|
||||
queue = result_singleton.queue
|
||||
async_queues = []
|
||||
|
||||
# Clear old data if queue is full (similar to _iteration_result_task)
|
||||
while queue.full():
|
||||
queue.get()
|
||||
|
||||
try:
|
||||
# Handle different types of data
|
||||
if isinstance(data, str):
|
||||
# Already JSON serialized
|
||||
data_json = data
|
||||
elif isinstance(data, list):
|
||||
# Skip empty lists to avoid putting nothing in the queue
|
||||
if not data:
|
||||
logger.debug(f"rpc_proxy.py: Skipping empty {data_type} list")
|
||||
return
|
||||
|
||||
# Handle list of data (multiple iterations)
|
||||
for item in data:
|
||||
if isinstance(item, str):
|
||||
item_json = item
|
||||
else:
|
||||
item_json = json.dumps(item)
|
||||
|
||||
if isinstance(queue, _SyncQueue):
|
||||
queue.put_nowait(item_json)
|
||||
async_queues.append(queue)
|
||||
else:
|
||||
queue.put(item_json)
|
||||
|
||||
if async_queues:
|
||||
_SyncQueue.notify_many(queue.loop, async_queues)
|
||||
return
|
||||
else:
|
||||
# Convert dict/other to JSON string as expected by IterationResult
|
||||
data_json = json.dumps(data)
|
||||
|
||||
if isinstance(queue, _SyncQueue):
|
||||
queue.put_nowait(data_json)
|
||||
async_queues.append(queue)
|
||||
else:
|
||||
queue.put(data_json)
|
||||
|
||||
if async_queues:
|
||||
_SyncQueue.notify_many(queue.loop, async_queues)
|
||||
|
||||
except AsyncQueue.EventLoopShutdownError:
|
||||
# This happens when the event loop is already closed
|
||||
logger.debug(f"rpc_proxy.py: EventLoopShutdownError in handle_{data_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}")
|
||||
raise e
|
||||
# NOTE: _fetch_stats_loop_async and _fetch_kv_cache_events_loop_async have been removed.
|
||||
# Stats and kv_events are now fetched on-demand via direct RPC calls
|
||||
# (get_stats, aget_stats, get_kv_events, aget_kv_events) instead of streaming loops.
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
from queue import Queue
|
||||
from threading import Event
|
||||
from typing import AsyncGenerator, Optional
|
||||
@ -50,8 +51,9 @@ class RpcWorkerMixin:
|
||||
"""Submits a request to the worker."""
|
||||
with nvtx_range_debug("RpcWorker.submit", color="blue", category="Worker"):
|
||||
logger_debug(f"[worker] Submitting request {request.id}", color="green")
|
||||
super().submit(request)
|
||||
result = super().submit(request)
|
||||
logger_debug(f"[worker] Submitted request {request.id}", color="green")
|
||||
return result
|
||||
|
||||
def fetch_responses(self, timeout: Optional[float] = None) -> list:
|
||||
"""Fetch responses from the response queue (blocking)."""
|
||||
@ -99,54 +101,58 @@ class RpcWorkerMixin:
|
||||
f"[worker] RpcWorker {self.rank} quitting fetch_responses_loop_async", color="yellow"
|
||||
)
|
||||
|
||||
async def fetch_stats_async(self, timeout: Optional[float] = None) -> list:
|
||||
"""Async version of fetch_stats using asyncio.to_thread."""
|
||||
return await asyncio.to_thread(self.fetch_stats)
|
||||
|
||||
async def fetch_kv_cache_events_async(self, timeout: Optional[float] = None) -> list:
|
||||
"""Async version of fetch_kv_cache_events using asyncio.to_thread."""
|
||||
return await asyncio.to_thread(self.fetch_kv_cache_events)
|
||||
|
||||
async def fetch_stats_loop_async(
|
||||
self, timeout: Optional[float] = None
|
||||
) -> AsyncGenerator[list, None]:
|
||||
"""Stream stats in a loop until shutdown."""
|
||||
async for data in self._generic_fetch_loop_async(
|
||||
fetch_method=self.fetch_stats_async,
|
||||
serializer=self._stats_serializer,
|
||||
method_name="fetch_stats_loop_async",
|
||||
timeout=timeout,
|
||||
):
|
||||
yield data
|
||||
|
||||
async def fetch_kv_cache_events_loop_async(
|
||||
self, timeout: Optional[float] = None
|
||||
) -> AsyncGenerator[list, None]:
|
||||
"""Stream KV cache events in a loop until shutdown."""
|
||||
async for data in self._generic_fetch_loop_async(
|
||||
fetch_method=self.fetch_kv_cache_events_async,
|
||||
serializer=self._kv_cache_events_serializer,
|
||||
method_name="fetch_kv_cache_events_loop_async",
|
||||
timeout=timeout,
|
||||
):
|
||||
yield data
|
||||
|
||||
async def _generic_fetch_loop_async(
|
||||
self, fetch_method, serializer, method_name: str, timeout: Optional[float] = None
|
||||
) -> AsyncGenerator[list, None]:
|
||||
"""Generic method for fetching data in a loop.
|
||||
async def fetch_stats_wait_async(self, timeout: Optional[float] = None) -> list:
|
||||
"""Poll for stats until available or timeout.
|
||||
|
||||
Args:
|
||||
fetch_method: The async method to call for fetching data
|
||||
serializer: The serializer function to apply to each item
|
||||
method_name: Name of the method for logging
|
||||
timeout: Optional timeout between fetches
|
||||
timeout: Max wait time in seconds. If None, fetch once without waiting.
|
||||
"""
|
||||
while not self.shutdown_event.is_set():
|
||||
timeout = timeout or 0.1
|
||||
await asyncio.sleep(timeout)
|
||||
data = await fetch_method()
|
||||
# Always yield data, even if empty, so the client does not appear to hang
# TODO: Remove the empty data to reduce the IPC overhead
|
||||
yield [serializer(item) for item in data]
|
||||
logger_debug(f"[worker] RpcWorker {self.rank} quitting {method_name}", color="yellow")
|
||||
logger_debug(
|
||||
f"[worker] RpcWorker {self.rank} is fetching stats with timeout {timeout}",
|
||||
color="yellow",
|
||||
)
|
||||
start = time.time()
|
||||
while True:
|
||||
stats = await asyncio.to_thread(self.fetch_stats)
|
||||
if stats or timeout is None:
|
||||
break
|
||||
if (time.time() - start) >= timeout:
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
return [self._stats_serializer(s) for s in stats]
|
||||
|
||||
async def fetch_kv_cache_events_wait_async(self, timeout: Optional[float] = None) -> list:
|
||||
"""Poll for KV cache events until available or timeout.
|
||||
|
||||
Args:
|
||||
timeout: Max wait time in seconds. If None, fetch once without waiting.
|
||||
"""
|
||||
start = time.time()
|
||||
while True:
|
||||
events = await asyncio.to_thread(self.fetch_kv_cache_events)
|
||||
if events or timeout is None:
|
||||
break
|
||||
if (time.time() - start) >= timeout:
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
return [self._kv_cache_events_serializer(e) for e in events]
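# --- Illustrative sketch (not part of the patch) ---
# Both *_wait_async methods above share the same poll-until-timeout shape: run the blocking
# fetch in a thread, return as soon as it yields data, and give up after `timeout` seconds
# (or immediately if timeout is None). A generic version of that helper:
import asyncio
import time
from typing import Callable, Optional

async def poll_until_data(fetch: Callable[[], list],
                          timeout: Optional[float],
                          poll_interval: float = 0.1) -> list:
    start = time.time()
    while True:
        data = await asyncio.to_thread(fetch)
        if data or timeout is None:
            return data
        if (time.time() - start) >= timeout:
            return data
        await asyncio.sleep(poll_interval)

# usage (hypothetical): stats = await poll_until_data(worker.fetch_stats, timeout=2.0)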
async def fetch_stats_async(self, timeout: Optional[float] = None) -> list:
|
||||
"""Async version of fetch_stats using asyncio.to_thread.
|
||||
|
||||
This method is exposed via RPC and can be called directly by the proxy.
|
||||
Returns serialized stats (JSON strings) that can be sent over RPC.
|
||||
"""
|
||||
stats = await asyncio.to_thread(self.fetch_stats)
|
||||
# Serialize stats before sending over RPC (IterationStats objects are not picklable)
|
||||
return [self._stats_serializer(s) for s in stats]
|
||||
|
||||
async def fetch_kv_cache_events_async(self, timeout: Optional[float] = None) -> list:
|
||||
"""Async version of fetch_kv_cache_events using asyncio.to_thread.
|
||||
|
||||
This method is exposed via RPC and can be called directly by the proxy.
|
||||
Returns serialized events (JSON strings) that can be sent over RPC.
|
||||
"""
|
||||
events = await asyncio.to_thread(self.fetch_kv_cache_events)
|
||||
# Serialize events before sending over RPC
|
||||
return [self._kv_cache_events_serializer(e) for e in events]
|
||||
|
||||
@ -142,8 +142,6 @@ class WorkerCommIpcAddrs(NamedTuple):
|
||||
request_queue_addr: tuple[str, Optional[bytes]]
|
||||
worker_init_status_queue_addr: tuple[str, Optional[bytes]]
|
||||
result_queue_addr: tuple[str, Optional[bytes]]
|
||||
stats_queue_addr: tuple[str, Optional[bytes]]
|
||||
kv_cache_events_queue_addr: tuple[str, Optional[bytes]]
|
||||
|
||||
|
||||
def is_llm_response(instance):
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import zmq
|
||||
|
||||
@ -18,25 +16,22 @@ from ..llmapi.llm_args import BaseLlmArgs
|
||||
from ..llmapi.mpi_session import set_mpi_session_cpp
|
||||
from ..llmapi.tokenizer import TokenizerBase
|
||||
from ..llmapi.tracer import VizTracer, set_global_tracer
|
||||
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, logger_debug,
|
||||
print_traceback_on_error)
|
||||
from ..llmapi.utils import ManagedThread, logger_debug, print_traceback_on_error
|
||||
from ..sampling_params import BatchedLogitsProcessor
|
||||
from .base_worker import BaseWorker, _init_hf_modules
|
||||
from .executor import IterationResultQueue
|
||||
from .ipc import FusedIpcQueue, IpcQueue
|
||||
from .postproc_worker import (PostprocWorker, PostprocWorkerConfig,
|
||||
postproc_worker_main)
|
||||
from .request import CancellingRequest, GenerationRequest
|
||||
from .result import IterationResult
|
||||
from .utils import (ErrorResponse, RequestError, WorkerCommIpcAddrs,
|
||||
has_event_loop)
|
||||
from .rpc_worker_mixin import RpcWorkerMixin
|
||||
from .utils import ErrorResponse, RequestError, WorkerCommIpcAddrs
|
||||
|
||||
__all__ = [
|
||||
"GenerationExecutorWorker",
|
||||
]
|
||||
|
||||
|
||||
class GenerationExecutorWorker(BaseWorker):
|
||||
class GenerationExecutorWorker(RpcWorkerMixin, BaseWorker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -48,6 +43,8 @@ class GenerationExecutorWorker(BaseWorker):
|
||||
hf_model_dir: Optional[Path] = None,
|
||||
tokenizer: Optional[TokenizerBase] = None,
|
||||
llm_args: Optional[BaseLlmArgs] = None,
|
||||
rpc_addr: Optional[str] = None,
|
||||
hmac_key: Optional[bytes] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine=engine,
|
||||
@ -62,35 +59,18 @@ class GenerationExecutorWorker(BaseWorker):
|
||||
|
||||
self.setup_engine()
|
||||
|
||||
# Setup RPC server for stats (skip init_rpc_worker to keep IPC response queue)
|
||||
# Only set up if rpc_addr is provided (for stats RPC support)
|
||||
if rpc_addr is not None:
|
||||
self.rpc_addr = rpc_addr
|
||||
self.hmac_key = hmac_key
|
||||
self.start_rpc_server() # Reuse from RpcWorkerMixin
|
||||
|
||||
self.await_response_thread = ManagedThread(
|
||||
self.await_response_task,
|
||||
error_queue=self._error_queue,
|
||||
name="await_response_thread")
|
||||
|
||||
self.dispatch_stats_thread = ManagedThread(
|
||||
self.dispatch_stats_task,
|
||||
error_queue=self._error_queue,
|
||||
name="dispatch_stats_thread")
|
||||
|
||||
self.dispatch_kv_cache_events_thread = ManagedThread(
|
||||
self.dispatch_kv_cache_events_task,
|
||||
error_queue=self._error_queue,
|
||||
name="dispatch_kv_cache_events_thread")
|
||||
|
||||
def _create_iteration_result_queue(self,
|
||||
it_result_queue: IterationResultQueue):
|
||||
if not it_result_queue.is_initialized:
|
||||
# not yet initialized
|
||||
it_result_queue.is_initialized = True
|
||||
if has_event_loop():
|
||||
_queue = AsyncQueue()
|
||||
it_result_queue.queue = _queue.sync_q
|
||||
it_result_queue.aqueue = _queue
|
||||
else:
|
||||
_queue = Queue()
|
||||
it_result_queue.queue = _queue
|
||||
it_result_queue.aqueue = None
|
||||
|
||||
def start_thread(self, thread: ManagedThread):
|
||||
if self.engine.can_enqueue_requests() and not thread.is_alive():
|
||||
thread.start()
|
||||
@ -98,74 +78,10 @@ class GenerationExecutorWorker(BaseWorker):
|
||||
def await_response_task(self) -> bool:
|
||||
return self._await_response_helper()
|
||||
|
||||
def _iteration_result_task(self, it_result_queue: IterationResultQueue,
|
||||
engine_get_result_api: Callable,
|
||||
result_singleton: IterationResult,
|
||||
result_serializer: Callable) -> bool:
|
||||
time.sleep(0.2)
|
||||
async_queues = []
|
||||
queue = result_singleton.queue if self._is_llm_executor and result_singleton else it_result_queue.queue
|
||||
try:
|
||||
for results in engine_get_result_api():
|
||||
res = result_serializer(results)
|
||||
if self._is_llm_executor and result_singleton:
|
||||
# In this case, there's no ExecutorBindingProxy.
|
||||
# Worker needs to take care of putting to result queue.
|
||||
while queue.full():
|
||||
queue.get()
|
||||
if isinstance(queue, _SyncQueue):
|
||||
queue.put_nowait(res)
|
||||
async_queues.append(queue)
|
||||
else:
|
||||
queue.put(res)
|
||||
else:
|
||||
# Send to ExecutorBindingProxy via IPC
|
||||
queue.put(res)
|
||||
|
||||
if async_queues:
|
||||
_SyncQueue.notify_many(queue.loop, async_queues)
|
||||
except AsyncQueue.EventLoopShutdownError:
|
||||
# This happens in the last results loop while the generate workflow is stopped.
|
||||
logger.debug("worker.py: EventLoopShutdownError")
|
||||
except Exception as e:
|
||||
logger.error(f"worker.py: Error in _iteration_result_task: {e}")
|
||||
raise e
|
||||
|
||||
return True # success
|
||||
|
||||
def dispatch_stats_task(self) -> bool:
|
||||
return self._iteration_result_task(self.stats_queues, self.fetch_stats,
|
||||
self._iter_stats_result,
|
||||
self._stats_serializer)
|
||||
|
||||
def dispatch_kv_cache_events_task(self) -> bool:
|
||||
if isinstance(self.engine, tllm.Executor):
|
||||
# Check if the engine has a kv cache event manager
|
||||
# If not, return an empty list for the events which will cause the thread to exit early.
|
||||
event_manager = self.engine.get_kv_cache_event_manager()
|
||||
if event_manager is None:
|
||||
events_api = lambda: [None]
|
||||
else:
|
||||
events_api = event_manager.get_latest_events
|
||||
return self._iteration_result_task(self.kv_events_queues,
|
||||
events_api,
|
||||
self._iter_kv_events_result,
|
||||
self._kv_cache_events_serializer)
|
||||
else:
|
||||
return self._iteration_result_task(
|
||||
self.kv_events_queues, self.engine.get_latest_kv_cache_events,
|
||||
self._iter_kv_events_result, self._kv_cache_events_serializer)
|
||||
|
||||
def start(self):
|
||||
# create iteration result queues
|
||||
self._create_iteration_result_queue(self.stats_queues)
|
||||
self._create_iteration_result_queue(self.kv_events_queues)
|
||||
|
||||
# start threads
|
||||
# Stats and KV events are now fetched on-demand via RPC,
|
||||
# so we only need to start the response thread
|
||||
self.start_thread(self.await_response_thread)
|
||||
self.start_thread(self.dispatch_kv_cache_events_thread)
|
||||
if mpi_rank() == 0:
|
||||
self.start_thread(self.dispatch_stats_thread)
|
||||
|
||||
def shutdown(self):
|
||||
|
||||
@ -178,16 +94,9 @@ class GenerationExecutorWorker(BaseWorker):
|
||||
|
||||
if self.engine is not None:
|
||||
if self.engine.can_enqueue_requests():
|
||||
|
||||
if self.await_response_thread.is_alive():
|
||||
self.await_response_thread.stop()
|
||||
self.await_response_thread.join()
|
||||
if self.dispatch_stats_thread.is_alive():
|
||||
self.dispatch_stats_thread.stop()
|
||||
self.dispatch_stats_thread.join()
|
||||
if self.dispatch_kv_cache_events_thread.is_alive():
|
||||
self.dispatch_kv_cache_events_thread.stop()
|
||||
self.dispatch_kv_cache_events_thread.join()
|
||||
|
||||
self.engine.shutdown()
|
||||
self.engine = None
|
||||
@ -240,6 +149,8 @@ def worker_main(
|
||||
hf_model_dir: Optional[Path] = None,
|
||||
tokenizer: Optional[TokenizerBase] = None,
|
||||
llm_args: Optional[BaseLlmArgs] = None,
|
||||
rpc_addr: Optional[str] = None,
|
||||
hmac_key: Optional[bytes] = None,
|
||||
) -> None:
|
||||
|
||||
mpi_comm().barrier()
|
||||
@ -287,15 +198,6 @@ def worker_main(
|
||||
is_server=False,
|
||||
socket_type=zmq.DEALER,
|
||||
name="worker_init_status_queue")
|
||||
mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr,
|
||||
is_server=False,
|
||||
fuse_message=True,
|
||||
name="worker_stats_queue")
|
||||
kv_cache_events_queue = FusedIpcQueue(
|
||||
worker_queues.kv_cache_events_queue_addr,
|
||||
is_server=False,
|
||||
fuse_message=False,
|
||||
name="worker_kv_cache_events_queue")
|
||||
|
||||
if postproc_worker_config.enabled:
|
||||
# IPC queues for sending inputs to the postprocess parallel
|
||||
@ -322,9 +224,6 @@ def worker_main(
|
||||
assert result_queues is not None
|
||||
for q in result_queues:
|
||||
q.put(None)
|
||||
# Signal the stats thread in the proxy to quit
|
||||
mp_stats_queue.put(None)
|
||||
kv_cache_events_queue.put(None)
|
||||
|
||||
postprocess_worker_futures = []
|
||||
if is_leader and postproc_worker_config.enabled:
|
||||
@ -370,7 +269,9 @@ def worker_main(
|
||||
is_llm_executor=is_llm_executor,
|
||||
hf_model_dir=hf_model_dir,
|
||||
tokenizer=tokenizer,
|
||||
llm_args=llm_args)
|
||||
llm_args=llm_args,
|
||||
rpc_addr=rpc_addr,
|
||||
hmac_key=hmac_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
@ -396,11 +297,6 @@ def worker_main(
|
||||
else:
|
||||
worker.set_result_queue(result_queue)
|
||||
|
||||
# initialize the iteration result queues
|
||||
worker._set_iteration_result_queue(worker.stats_queues,
|
||||
mp_stats_queue)
|
||||
worker._set_iteration_result_queue(worker.kv_events_queues,
|
||||
kv_cache_events_queue)
|
||||
# Send ready signal with confirmation
|
||||
ready_msg = (ready_signal, None)
|
||||
if not worker_init_status_queue.notify_with_retry(ready_msg):
|
||||
|
||||
@ -89,8 +89,12 @@ class RequestOutput(DetokenizedGenerationResultBase, GenerationResult):
|
||||
|
||||
def _repr_fields(self):
|
||||
return [
|
||||
"request_id", "prompt", "prompt_token_ids", "outputs", "finished",
|
||||
"mm_embedding_handle"
|
||||
"request_id",
|
||||
"prompt",
|
||||
"prompt_token_ids",
|
||||
"outputs",
|
||||
"finished",
|
||||
"mm_embedding_handle",
|
||||
]
|
||||
|
||||
|
||||
@ -419,7 +423,7 @@ class BaseLLM:
|
||||
multimodal_params = None
|
||||
|
||||
if is_mm_disagg:
|
||||
if not self.input_processor.support_mm_disagg:
|
||||
if not getattr(self.input_processor, "support_mm_disagg", False):
|
||||
raise ValueError(
|
||||
"Multimodal disaggregated inference is not supported for this model"
|
||||
)
|
||||
@ -436,14 +440,42 @@ class BaseLLM:
|
||||
mm_hashes = disaggregated_params.multimodal_hashes
|
||||
multimodal_input = MultimodalInput.from_components(
|
||||
mm_hashes, mm_token_positions, mm_token_length)
|
||||
multimodal_data = {"multimodal_embedding": mm_handles}
|
||||
if disaggregated_params.mrope_position_ids_handle is not None:
|
||||
# NOTE: `PyTorchModelEngine` assumes both are present when using mrope.
|
||||
assert disaggregated_params.mrope_position_deltas_handle is not None
|
||||
mrope_config = {}
|
||||
mrope_config[
|
||||
"mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
|
||||
mrope_config[
|
||||
"mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
|
||||
multimodal_data["mrope_config"] = mrope_config
|
||||
multimodal_params = MultimodalParams(
|
||||
multimodal_input=multimodal_input,
|
||||
multimodal_data={"multimodal_embedding": mm_handles})
|
||||
multimodal_data=multimodal_data,
|
||||
)
|
||||
|
||||
elif "prompt_token_ids" in inputs:
|
||||
prompt_token_ids = inputs['prompt_token_ids']
|
||||
prompt = None
|
||||
query_token_ids = inputs.get("query_token_ids", None)
|
||||
multimodal_data = {}
|
||||
# NOTE: when running in `generation_only` for disagg, this is the code path we expect to hit.
|
||||
if disaggregated_params is not None and disaggregated_params.mrope_position_ids_handle is not None:
|
||||
# NOTE: `PyTorchModelEngine` assumes both are present when using mrope.
if disaggregated_params.mrope_position_deltas_handle is None:
|
||||
raise RuntimeError(
|
||||
"`mrope_position_ids_handle` and `mrope_position_deltas_handle` must both "
|
||||
"be provided, or both `None`.")
|
||||
mrope_config = {}
|
||||
mrope_config[
|
||||
"mrope_position_ids"] = disaggregated_params.mrope_position_ids_handle
|
||||
mrope_config[
|
||||
"mrope_position_deltas"] = disaggregated_params.mrope_position_deltas_handle
|
||||
multimodal_data["mrope_config"] = mrope_config
|
||||
if multimodal_data:
|
||||
multimodal_params = MultimodalParams(
|
||||
multimodal_data=multimodal_data)
|
||||
elif "prompt" in inputs:
|
||||
if 'multi_modal_data' in inputs:
|
||||
# TODO: The current design uses a wrapper for existing input processor (input_processor_with_hash)
|
||||
|
||||
@ -101,14 +101,8 @@ class MultimodalEncoder(_TorchLLM):
|
||||
|
||||
inputs = [prompt_inputs(i) for i in inputs]
|
||||
|
||||
def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any:
|
||||
if isinstance(maybe_batched, list):
|
||||
return maybe_batched[pos]
|
||||
else:
|
||||
return maybe_batched
|
||||
|
||||
futures = []
|
||||
for i, request_inputs in enumerate(inputs):
|
||||
for request_inputs in inputs:
|
||||
future = self.generate_async(request_inputs)
|
||||
futures.append(future)
|
||||
|
||||
|
||||
@ -51,20 +51,22 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
|
||||
MemoryUpdateRequest, ModelCard,
|
||||
ModelList, PromptTokensDetails,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
UpdateWeightsRequest, UsageInfo,
|
||||
to_llm_disaggregated_params)
|
||||
from tensorrt_llm.serve.postprocess_handlers import (
|
||||
ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
|
||||
chat_harmony_post_processor, chat_harmony_streaming_post_processor,
|
||||
chat_response_post_processor, chat_stream_post_processor,
|
||||
completion_response_post_processor, completion_stream_post_processor)
|
||||
ResponsesAPIPostprocArgs, chat_harmony_post_processor,
|
||||
chat_harmony_streaming_post_processor, chat_response_post_processor,
|
||||
chat_stream_post_processor, completion_response_post_processor,
|
||||
completion_stream_post_processor, responses_api_post_processor,
|
||||
responses_api_streaming_post_processor)
|
||||
from tensorrt_llm.serve.responses_utils import (ConversationHistoryStore,
|
||||
ResponsesStreamingProcessor,
|
||||
ServerArrivalTimeMiddleware)
|
||||
from tensorrt_llm.serve.responses_utils import \
|
||||
create_response as responses_api_create_response
|
||||
from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
|
||||
from tensorrt_llm.serve.responses_utils import \
|
||||
process_streaming_events as responses_api_process_streaming_events
|
||||
from tensorrt_llm.serve.responses_utils import \
|
||||
request_preprocess as responses_api_request_preprocess
|
||||
from tensorrt_llm.version import __version__ as VERSION
|
||||
@ -119,9 +121,8 @@ class OpenAIServer:
|
||||
self.model_config = None
|
||||
|
||||
# Enable response storage for Responses API
|
||||
self.enable_store = True
|
||||
if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0:
|
||||
self.enable_store = False
|
||||
self.enable_store = (len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) < 1) and not self.postproc_worker_enabled
|
||||
|
||||
self.conversation_store = ConversationHistoryStore()
|
||||
|
||||
model_dir = Path(model)
|
||||
@ -942,19 +943,39 @@ class OpenAIServer:
return self.create_error_response(message=str(e), err_type="internal_error")

async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response:
async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]:
async for event_data in responses_api_process_streaming_events(
request=request,
sampling_params=sampling_params,
generator=generator,
model_name=self.model,
conversation_store=self.conversation_store,
use_harmony=self.use_harmony,
reasoning_parser=self.llm.args.reasoning_parser,
tool_parser=self.tool_parser,
enable_store=self.enable_store
):
yield event_data
async def create_response(
promise: RequestOutput, postproc_params: PostprocParams) -> ResponsesResponse:
await promise.aresult()
if self.postproc_worker_enabled:
response = promise.outputs[0]._postprocess_result
else:
args = postproc_params.postproc_args
response = await responses_api_create_response(
generator=promise,
request=request,
sampling_params=args.sampling_params,
model_name=self.model,
conversation_store=self.conversation_store,
generation_result=None,
enable_store=self.enable_store and request.store,
use_harmony=self.use_harmony,
reasoning_parser=args.reasoning_parser,
tool_parser=args.tool_parser,
)

return response

async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
streaming_processor = args.streaming_processor
initial_responses = streaming_processor.get_initial_responses()
for initial_response in initial_responses:
yield initial_response

async for res in promise:
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
for pp_res in pp_results:
yield pp_res

try:
if request.background:
@ -977,38 +998,61 @@ class OpenAIServer:
request=request,
prev_response=prev_response,
conversation_store=self.conversation_store,
enable_store=self.enable_store,
enable_store=self.enable_store and request.store,
use_harmony=self.use_harmony,
tokenizer=self.tokenizer if not self.use_harmony else None,
model_config=self.model_config if not self.use_harmony else None,
processor=self.processor if not self.use_harmony else None,
)

streaming_processor = None
if request.stream:
# Per-request streaming processor
streaming_processor = ResponsesStreamingProcessor(
request=request,
sampling_params=sampling_params,
model_name=self.model,
conversation_store=self.conversation_store,
enable_store=self.enable_store and request.store,
use_harmony=self.use_harmony,
reasoning_parser=self.llm.args.reasoning_parser,
tool_parser=self.tool_parser,
)

postproc_args = ResponsesAPIPostprocArgs(
model=self.model,
request=request,
sampling_params=sampling_params,
use_harmony=self.use_harmony,
reasoning_parser=self.llm.args.reasoning_parser,
tool_parser=self.tool_parser,
streaming_processor=streaming_processor,
)
postproc_params = PostprocParams(
post_processor=responses_api_streaming_post_processor
if request.stream else responses_api_post_processor,
postproc_args=postproc_args,
)
promise = self.llm.generate_async(
inputs=input_tokens,
sampling_params=sampling_params,
streaming=request.stream,
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
)

if self.postproc_worker_enabled and request.store:
logger.warning("Postproc workers are enabled, request will not be stored!")

asyncio.create_task(self.await_disconnected(raw_request, promise))

if request.stream:
return StreamingResponse(
create_stream_response(promise, request, sampling_params),
content=create_streaming_generator(promise, postproc_params),
media_type="text/event-stream"
)
else:
return await responses_api_create_response(
generator=promise,
request=request,
sampling_params=sampling_params,
model_name=self.model,
conversation_store=self.conversation_store,
generation_result=None,
enable_store=self.enable_store,
use_harmony=self.use_harmony,
reasoning_parser=self.llm.args.reasoning_parser,
tool_parser=self.tool_parser)
response = await create_response(promise, postproc_params)
return JSONResponse(content=response.model_dump())
except CppExecutorError:
logger.error(traceback.format_exc())
# If internal executor error is raised, shutdown the server
@ -1,11 +1,16 @@
from dataclasses import dataclass, field
from typing import Any, List, Literal, Optional, Tuple, Union

from tensorrt_llm.serve.responses_utils import ResponsesStreamingProcessor
from tensorrt_llm.serve.responses_utils import \
create_response_non_store as responses_api_create_response_non_store

from .._utils import nvtx_range_debug
from ..executor import (DetokenizedGenerationResultBase, GenerationResult,
GenerationResultBase)
from ..executor.postproc_worker import PostprocArgs
from ..executor.result import Logprob, TokenLogprobs
from ..llmapi import SamplingParams
from ..llmapi.reasoning_parser import (BaseReasoningParser,
ReasoningParserFactory)
from ..llmapi.tokenizer import TransformersTokenizer
@ -26,7 +31,8 @@ from .openai_protocol import (ChatCompletionLogProbs,
CompletionResponseStreamChoice,
CompletionStreamResponse, DeltaFunctionCall,
DeltaMessage, DeltaToolCall, FunctionCall,
PromptTokensDetails, StreamOptions, ToolCall,
PromptTokensDetails, ResponsesRequest,
ResponsesResponse, StreamOptions, ToolCall,
UsageInfo, to_disaggregated_params)
from .tool_parser.base_tool_parser import BaseToolParser
from .tool_parser.core_types import ToolCallItem
@ -543,3 +549,42 @@ def chat_harmony_streaming_post_processor(
num_prompt_tokens=args.num_prompt_tokens,
)
return response


@dataclass(kw_only=True)
class ResponsesAPIPostprocArgs(PostprocArgs):
model: str
request: ResponsesRequest
sampling_params: SamplingParams
use_harmony: bool
reasoning_parser: Optional[str] = None
tool_parser: Optional[str] = None
streaming_processor: Optional[ResponsesStreamingProcessor] = None


@nvtx_range_debug("responses_api_post_processor")
def responses_api_post_processor(
rsp: GenerationResult,
args: ResponsesAPIPostprocArgs) -> ResponsesResponse:
return responses_api_create_response_non_store(
generation_result=rsp,
request=args.request,
sampling_params=args.sampling_params,
model_name=args.model,
use_harmony=args.use_harmony,
reasoning_parser=args.reasoning_parser,
tool_parser=args.tool_parser,
)


@nvtx_range_debug("responses_api_streaming_post_processor")
def responses_api_streaming_post_processor(
rsp: GenerationResult, args: ResponsesAPIPostprocArgs) -> List[str]:
if args.streaming_processor is None:
raise ValueError(
"streaming_processor is required for streaming post-processing")
outputs = args.streaming_processor.process_single_output(rsp)
if rsp._done:
outputs.append(
args.streaming_processor.get_final_response_non_store(rsp))
return outputs
@ -10,7 +10,7 @@ import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from copy import copy
from typing import Any, Literal, Optional, OrderedDict, Tuple, Union
from typing import Any, List, Literal, Optional, OrderedDict, Tuple, Union

from openai.types.responses import (ResponseCompletedEvent,
ResponseContentPartAddedEvent,
@ -41,6 +41,7 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
from transformers import AutoProcessor, PretrainedConfig

from tensorrt_llm.bindings import steady_clock_now
from tensorrt_llm.executor import GenerationResult
from tensorrt_llm.inputs.utils import apply_chat_template
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.llm import RequestOutput
@ -962,7 +963,7 @@ def _apply_tool_parser(
return normal_text, calls


async def _create_output_content(
def _create_output_content(
final_res: RequestOutput,
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
@ -1040,7 +1041,7 @@ async def _create_output_content(
return output_items, output_messages


async def _create_output_content_harmony(
def _create_output_content_harmony(
final_res: RequestOutput
) -> Tuple[list[ResponseOutputItem], list[Message]]:
output_messages = _parse_output_tokens(final_res.outputs[0].token_ids)
@ -1057,12 +1058,53 @@ async def _create_output_content_harmony(
return output_content, output_messages


def _create_response(
final_res: GenerationResult,
use_harmony: bool,
request: ResponsesRequest,
model_name: str,
response_creation_time: int,
sampling_params: SamplingParams,
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
) -> tuple[ResponsesResponse, list[Message | ChatCompletionMessageParam]]:
_responses_debug_log("================================================")
_responses_debug_log("RAW MODEL OUTPUT:")
_responses_debug_log(final_res.outputs)
_responses_debug_log("================================================")

# prepare responses output
output_content = []
if use_harmony:
output_content, output_messages = _create_output_content_harmony(
final_res)
else:
output_content, output_messages = _create_output_content(
final_res, reasoning_parser, tool_parser, request.tools)

response = ResponsesResponse.from_request(
request=request,
sampling_params=sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=output_content,
status=finish_reason_mapping(final_res.outputs[0].finish_reason),
)

_responses_debug_log("========== Response ===========")
_responses_debug_log(response)
_responses_debug_log("===============================")

# return output_messages for store_response
return response, output_messages


async def create_response(
generator,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
conversation_store: ConversationHistoryStore,
generator: Optional[AsyncGenerator[RequestOutput, None]] = None,
generation_result: Optional[RequestOutput] = None,
enable_store: bool = False,
use_harmony: bool = True,
@ -1078,33 +1120,22 @@ async def create_response(

if generation_result is not None:
final_res = generation_result
else:
elif generator is not None:
final_res = await generator

if final_res is None:
raise RuntimeError("No output generated or provided")

_responses_debug_log("================================================")
_responses_debug_log("RAW MODEL OUTPUT:")
_responses_debug_log(final_res.outputs)
_responses_debug_log("================================================")

# prepare responses output
output_content = []
if use_harmony:
output_content, output_messages = await _create_output_content_harmony(
final_res)
else:
output_content, output_messages = await _create_output_content(
final_res, reasoning_parser, tool_parser, request.tools)

response = ResponsesResponse.from_request(
response, output_messages = _create_response(
final_res=final_res,
use_harmony=use_harmony,
request=request,
sampling_params=sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=output_content,
status=finish_reason_mapping(final_res.outputs[0].finish_reason),
response_creation_time=response_creation_time,
sampling_params=sampling_params,
reasoning_parser=reasoning_parser,
tool_parser=tool_parser,
)

if enable_store and request.store:
@ -1112,9 +1143,34 @@ async def create_response(
resp_msgs=output_messages,
prev_resp_id=prev_response_id)

_responses_debug_log("========== Response ===========")
_responses_debug_log(response)
_responses_debug_log("===============================")
return response


def create_response_non_store(
generation_result: RequestOutput,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
use_harmony: bool = True,
create_time: Optional[int] = None,
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
) -> ResponsesResponse:
response_creation_time = create_time if create_time is not None else int(
time.time())

# prepare responses output
response, _ = _create_response(
final_res=generation_result,
use_harmony=use_harmony,
request=request,
model_name=model_name,
response_creation_time=response_creation_time,
sampling_params=sampling_params,
reasoning_parser=reasoning_parser,
tool_parser=tool_parser,
)

return response
@ -1649,6 +1705,143 @@ def _generate_streaming_event_harmony(
parser.last_content_delta)


class ResponsesStreamingProcessor:

def __init__(
self,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
create_time: Optional[int] = None,
conversation_store: Optional[ConversationHistoryStore] = None,
enable_store: bool = False,
use_harmony: bool = True,
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
):
self.model_name = model_name
self.request = request
self.sampling_params = sampling_params
self.sequence_number = 0
self.streaming_events_helper = ResponsesStreamingEventsHelper()
self.response_creation_time = create_time if create_time is not None else int(
time.time())
self.final_res: Optional[RequestOutput] = None
self.reasoning_parser_dict: dict[int, BaseReasoningParser] = {}
self.tool_parser_dict: dict[int, BaseToolParser] = {}
self.stream_request_id = f"responses-api-{request.request_id}"
self.conversation_store = conversation_store
self.enable_store = enable_store
self.use_harmony = use_harmony
self.reasoning_parser = reasoning_parser
self.tool_parser = tool_parser

def _send_event(self, event: OpenAIBaseModel):
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
event.sequence_number = self.sequence_number
self.sequence_number += 1
# Get event type from the event's type field if it exists
event_type = getattr(event, 'type', 'unknown')
return (f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n")

def get_initial_responses(self) -> List[str]:
initial_response = ResponsesResponse.from_request(
request=self.request,
sampling_params=self.sampling_params,
model_name=self.model_name,
created_time=self.response_creation_time,
output=[],
status="in_progress",
usage=None,
).model_dump()

resp_created = self._send_event(
self.streaming_events_helper.get_response_created_event(
initial_response))
resp_in_progress = self._send_event(
self.streaming_events_helper.get_response_in_progress_event(
initial_response))
return [resp_created, resp_in_progress]

async def get_final_response(
self,
final_res: RequestOutput,
) -> str:
final_response = await create_response(
generator=None,
request=self.request,
sampling_params=self.sampling_params,
model_name=self.model_name,
conversation_store=self.conversation_store,
generation_result=final_res,
enable_store=self.enable_store,
use_harmony=self.use_harmony,
create_time=self.response_creation_time,
reasoning_parser=self.reasoning_parser,
tool_parser=self.tool_parser,
)

return self._send_event(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))

def get_final_response_non_store(
self,
final_res: RequestOutput,
) -> str:
final_response = create_response_non_store(
generation_result=final_res,
request=self.request,
sampling_params=self.sampling_params,
model_name=self.model_name,
use_harmony=self.use_harmony,
create_time=self.response_creation_time,
reasoning_parser=self.reasoning_parser,
tool_parser=self.tool_parser,
)

return self._send_event(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))

def process_single_output(self, res: GenerationResult) -> list[str]:
event_generator = None
output = res.outputs[0]
if self.use_harmony:
event_generator = _generate_streaming_event_harmony(
harmony_adapter=get_harmony_adapter(),
stream_request_id=self.stream_request_id,
output=output,
request=self.request,
streaming_events_helper=self.streaming_events_helper,
)

else:
event_generator = _generate_streaming_event(
output=output,
request=self.request,
finished_generation=res._done,
streaming_events_helper=self.streaming_events_helper,
reasoning_parser_id=self.reasoning_parser,
tool_parser_id=self.tool_parser,
reasoning_parser_dict=self.reasoning_parser_dict,
tool_parser_dict=self.tool_parser_dict,
)

if event_generator is None:
raise RuntimeError("Failed to generate streaming events")

return [self._send_event(event) for event in event_generator]


async def process_streaming_events(
generator,
request: ResponsesRequest,
@ -1661,97 +1854,31 @@ async def process_streaming_events(
reasoning_parser: Optional[str] = None,
tool_parser: Optional[str] = None,
) -> AsyncGenerator[str, None]:
sequence_number = 0
response_creation_time = create_time if create_time is not None else int(
time.time())
final_res: Optional[RequestOutput] = None
reasoning_parser_dict: dict[int, BaseReasoningParser] = {}
tool_parser_dict: dict[int, BaseToolParser] = {}

def _send_event(event: OpenAIBaseModel):
nonlocal sequence_number
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
event.sequence_number = sequence_number
sequence_number += 1
# Get event type from the event's type field if it exists
event_type = getattr(event, 'type', 'unknown')
return (f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n")

streaming_events_helper = ResponsesStreamingEventsHelper()

initial_response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=response_creation_time,
output=[],
status="in_progress",
usage=None,
).model_dump()

yield _send_event(
streaming_events_helper.get_response_created_event(initial_response))
yield _send_event(
streaming_events_helper.get_response_in_progress_event(
initial_response))

stream_request_id = f"responses-api-{request.request_id}"
harmony_adapter = get_harmony_adapter()
async for res in generator:
final_res = res
# TODO(JunyiXu-nv): handle multiple outputs
output = res.outputs[0]

event_generator = None
if use_harmony:
event_generator = _generate_streaming_event_harmony(
harmony_adapter=harmony_adapter,
stream_request_id=stream_request_id,
output=output,
request=request,
streaming_events_helper=streaming_events_helper,
)

else:
event_generator = _generate_streaming_event(
output=output,
request=request,
finished_generation=res.finished,
streaming_events_helper=streaming_events_helper,
reasoning_parser_id=reasoning_parser,
tool_parser_id=tool_parser,
reasoning_parser_dict=reasoning_parser_dict,
tool_parser_dict=tool_parser_dict,
)

if event_generator is None:
raise RuntimeError("Failed to generate streaming events")

for event in event_generator:
yield _send_event(event)

final_response = await create_response(
generator=generator,
streaming_processor = ResponsesStreamingProcessor(
request=request,
sampling_params=sampling_params,
model_name=model_name,
create_time=create_time,
conversation_store=conversation_store,
generation_result=final_res,
enable_store=enable_store,
use_harmony=use_harmony,
create_time=response_creation_time,
reasoning_parser=reasoning_parser,
tool_parser=tool_parser,
)

yield _send_event(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
))
initial_responses = streaming_processor.get_initial_responses()
for initial_response in initial_responses:
yield initial_response

async for res in generator:
final_res = res
events = streaming_processor.process_single_output(res)
for event in events:
yield event

final_response = await streaming_processor.get_final_response(final_res)

yield final_response


class ServerArrivalTimeMiddleware:
@ -479,3 +479,8 @@ triton_server/test_triton.py::test_gpt_disaggregated_serving_bls[gpt-disaggregat
cpp/test_multi_gpu.py::test_cache_transceiver[8proc-mooncake_kvcache-90] SKIP (https://nvbugs/5760737)
unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py::test_allreduce_pg_op[seqlen:16-hidden:1024] SKIP (https://nvbugs/5760740)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3_4gpus[trtllm-two_model-no_overlap_scheduler] SKIP (https://nvbugs/5760747)
unittest/_torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-strategy:8-dtype:bfloat16-hidden:8192-seqlen:[15]] SKIP (https://nvbugs/5761364)
triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm] SKIP (https://nvbugs/5762852)
accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] SKIP (https://nvbugs/5762852)

@ -6,6 +6,7 @@ import torch

import tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.torch_moe # noqa: F401
import tensorrt_llm._torch.custom_ops.torch_custom_ops as trt_ops # noqa: F401
from tensorrt_llm._torch.utils import ActivationType


def test_flashinfer_fused_moe_matches_torch_moe():
@ -75,8 +76,8 @@ def test_flashinfer_fused_moe_matches_torch_moe():
w1_weight=w1_list, # gate projection
w2_weight=w2_list, # down projection
w3_weight=w3_list, # up projection
mlp_style="gated_mlp",
act_fn="silu",
is_gated_mlp=True,
act_fn=int(ActivationType.Silu),
)

# Compare outputs

@ -186,7 +186,7 @@ def test_llama4_stacked_moe_pattern_detection():
moe_node = graph.call_function(
torch.ops.auto_deploy.torch_moe,
args=(x, selected_experts, routing_weights, w1_list, w2_list, w3_list),
kwargs={"mlp_style": "gated_mlp", "apply_routing_on_input": True},
kwargs={"is_gated_mlp": True, "apply_routing_on_input": True},
)
graph.output(moe_node)


@ -7,6 +7,7 @@ from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_availab
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale
from tensorrt_llm._torch.modules.fused_moe import MoE # noqa: F401
from tensorrt_llm._torch.utils import ActivationType


def setup_moe_test(dtype, num_experts):
@ -173,8 +174,8 @@ def test_bmm_based_moe_op_run(dtype):
[fused_w3_w1_stacked_weight], # Wrap in list for unified interface
[fused_w2_weight], # Wrap in list for unified interface
[], # Empty w3_weight list for stacked gated MLP
mlp_style="gated_mlp",
act_fn="silu",
is_gated_mlp=True,
act_fn=ActivationType.Silu,
apply_routing_on_input=True,
)
output_torch_fused_moe = torch.ops.auto_deploy.torch_moe_fused(
@ -82,7 +82,7 @@ def compute_with_experts(
alpha=None,
beta=None,
limit=None,
activation_func="silu",
activation_func: ActivationType = ActivationType.Silu,
):
def relu2(x: torch.Tensor) -> torch.Tensor:
return torch.square(F.relu(x))
@ -110,7 +110,7 @@ def compute_with_experts(

inter = x1_scaled * x2
else:
if activation_func == "swiglu" or activation_func == "silu":
if activation_func == ActivationType.Swiglu or activation_func == ActivationType.Silu:
inter = F.silu(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t())
else:
inter = relu2(expert_inputs @ w1_expert.t())
@ -136,10 +136,6 @@ def _get_test_data(
return x, router_logits, w31_weight, w2_weight, w31_empty_scales, w2_empty_scales


def _activation_type_from_str(activation_func: str) -> ActivationType:
return ActivationType.Swiglu if activation_func in ["swiglu", "silu"] else ActivationType.Relu2


def _print_diff_if(
condition: Callable[[torch.Tensor], bool],
diff: torch.Tensor,
@ -183,7 +179,7 @@ F16_TEST_DTYPES = [
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("itype, otype, wtype", F16_TEST_DTYPES)
@pytest.mark.parametrize("activation_func", ["silu", "relu2"])
@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2])
@skip_pre_hopper
def test_trtllm_fused_moe(
batch_size,
@ -201,13 +197,13 @@ def test_trtllm_fused_moe(
pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})")

torch.manual_seed(42)
if activation_func in ["swiglu", "silu"]:
if activation_func in [ActivationType.Swiglu, ActivationType.Silu]:
X_GEN_SCALE = 1.0
else:
X_GEN_SCALE = 0.5
W_GEN_SCALE = 0.1

x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data(
x, router_logits, w31_weight, w2_weight, _, _ = _get_test_data(
otype,
wtype,
batch_size,
@ -239,19 +235,17 @@ def test_trtllm_fused_moe(
"F16 test only supports bfloat16 or float16"
)

activation_type = _activation_type_from_str(activation_func)

def get_fc1_expert_weights(
activation_func: str, w31_weight: torch.Tensor, w1_weight: torch.Tensor
activation_func: ActivationType, w31_weight: torch.Tensor, w1_weight: torch.Tensor
) -> torch.Tensor:
if activation_func == "relu2":
if activation_func == ActivationType.Relu2:
return w1_weight.contiguous()
else:
return w31_weight

# (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size)
_, w1_weight = torch.chunk(w31_weight, 2, dim=1)
mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
is_gated_mlp = False if activation_func == ActivationType.Relu2 else True

torch.cuda.synchronize()
ad_test_output = torch.ops.auto_deploy.trtllm_moe_fused(
@ -260,9 +254,13 @@ def test_trtllm_fused_moe(
routing_weights,
w3_w1_stacked_weight=get_fc1_expert_weights(activation_func, w31_weight, w1_weight),
w2_stacked_weight=w2_weight,
mlp_style=mlp_style,
is_gated_mlp=is_gated_mlp,
act_fn=activation_func,
)
# Convert ActivationType.Silu to ActivationType.Swiglu for C++ op compatibility
cpp_activation_type = (
ActivationType.Swiglu if activation_func == ActivationType.Silu else activation_func
)
trtllm_test_output = torch.ops.trtllm.fused_moe(
x,
selected_experts.to(torch.int),
@ -273,11 +271,11 @@ def test_trtllm_fused_moe(
fc2_expert_biases=None,
output_dtype=otype,
quant_scales=[],
activation_type=activation_type,
activation_type=cpp_activation_type,
)[0].view(x.shape)

torch.cuda.synchronize()
if mlp_style == "mlp":
if not is_gated_mlp:
with torch.inference_mode():
output_triton_moe = torch.ops.auto_deploy.triton_moe_fused(
x,
@ -285,6 +283,7 @@ def test_trtllm_fused_moe(
routing_weights,
w1_weight.contiguous(),
w2_weight.contiguous(),
is_gated_mlp=False,
)[0].view(x.shape)
torch.testing.assert_close(output_triton_moe, ad_test_output, rtol=1e-2, atol=1e-2)

@ -308,7 +307,7 @@ FP8_TEST_DTYPES = [
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("itype, otype, wtype", FP8_TEST_DTYPES)
@pytest.mark.parametrize("activation_func", ["silu", "relu2"])
@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2])
@pytest.mark.skipif(
not fp8_compatible() or not trtllm_ops_available(),
reason="Requires fp8 and trtllm support",
@ -336,7 +335,7 @@ def test_trtllm_fused_moe_fp8(
)

torch.manual_seed(42)
if activation_func in ["swiglu", "silu"]:
if activation_func in [ActivationType.Swiglu, ActivationType.Silu]:
X_GEN_SCALE = 1.0
else:
X_GEN_SCALE = 0.5
@ -399,7 +398,7 @@ def test_trtllm_fused_moe_fp8(

# (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size)
w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1)
mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
is_gated_mlp = False if activation_func == ActivationType.Relu2 else True

# compute quant_scales
gemm1_dequant = (w1_scales * hidden_states_scale).contiguous().squeeze().to(torch.float32)
@ -424,13 +423,13 @@ def test_trtllm_fused_moe_fp8(
gemm1_dequant=gemm1_dequant,
gemm2_act_quant=gemm2_act_quant,
gemm2_dequant=gemm2_dequant,
mlp_style=mlp_style,
is_gated_mlp=is_gated_mlp,
act_fn=activation_func,
)

torch.cuda.synchronize()

if mlp_style == "mlp":
if not is_gated_mlp:
with torch.inference_mode():
output_triton_fp8_moe = torch.ops.auto_deploy.triton_quant_fp8_moe(
x,
@ -445,7 +444,7 @@ def test_trtllm_fused_moe_fp8(
w1_scales,
w2_scales,
w3_scales,
mlp_style=mlp_style,
is_gated_mlp=is_gated_mlp,
act_fn=activation_func,
)
torch.testing.assert_close(output_triton_fp8_moe, ref_output, rtol=1e-1, atol=1e-1)
@ -569,7 +568,7 @@ NVFP4_TEST_DTYPES = [
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("otype, wtype", NVFP4_TEST_DTYPES)
@pytest.mark.parametrize("activation_func", ["silu", "relu2"])
@pytest.mark.parametrize("activation_func", [ActivationType.Silu, ActivationType.Relu2])
@pytest.mark.skipif(
not fp4_compatible() or not trtllm_ops_available(),
reason="Requires fp4 and trtllm support",
@ -693,25 +692,23 @@ def test_trtllm_fused_moe_nvfp4(
fc1_alpha = 1.0 / (fc1_activation_gs * fc1_weight_gs)
fc2_alpha = 1.0 / (fc2_activation_gs * w2_gs)

mlp_style = "mlp" if activation_func == "relu2" else "gated_mlp"
if mlp_style == "gated_mlp":
is_gated_mlp = False if activation_func == ActivationType.Relu2 else True
if is_gated_mlp:
# For gated MLP, concatenate w1 and w3 as [w3, w1]
fc1_expert_weights_fp4 = torch.cat([w3_q_fp4, w1_q_fp4], dim=1).contiguous()
fc1_weight_blockscale_fp8 = torch.cat([w3_blockscale, w1_blockscale], dim=1)
fc1_weight_gs = torch.max(w3_gs, w1_gs)
if activation_func != "silu":
if activation_func != ActivationType.Silu:
raise ValueError(
f"Unsupported activation '{activation_func}' for gated_mlp. Use 'silu'."
)
elif mlp_style == "mlp":
else:
# For non-gated MLP with ReLU^2
fc1_expert_weights_fp4 = w1_q_fp4
fc1_weight_blockscale_fp8 = w1_blockscale.view(torch.long)
fc1_weight_gs = w1_gs
if activation_func != "relu2":
if activation_func != ActivationType.Relu2:
raise ValueError(f"Unsupported activation '{activation_func}' for mlp. Use 'relu2'.")
else:
raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")

fc2_expert_weights_fp4 = w2_q_fp4.view(torch.long)
fc2_weight_blockscale_fp8 = w2_blockscale.view(torch.long)
@ -729,7 +726,7 @@ def test_trtllm_fused_moe_nvfp4(
fc2_activation_gs,
fc1_alpha,
fc2_alpha,
mlp_style=mlp_style,
is_gated_mlp=is_gated_mlp,
act_fn=activation_func,
)

@ -747,8 +744,7 @@ def test_trtllm_fused_moe_nvfp4(
block_size=NVFP4_BLOCK_SIZE,
)

concat_w3_w1 = mlp_style == "gated_mlp"
if concat_w3_w1:
if is_gated_mlp:
w1_gs = w3_gs = torch.max(w1_gs, w3_gs)

w1_dq = torch.empty(w1.shape, device="cuda", dtype=otype)
@ -782,14 +778,18 @@ def test_trtllm_fused_moe_nvfp4(
block_size=NVFP4_BLOCK_SIZE,
)

# Convert ActivationType.Silu to ActivationType.Swiglu for reference op compatibility
resolved_activation_type = (
ActivationType.Swiglu if activation_func == ActivationType.Silu else activation_func
)
ref_output = torch_moe_nvfp4(
x_dq,
torch.cat([w3_dq, w1_dq], dim=1) if concat_w3_w1 else w1_dq,
torch.cat([w3_dq, w1_dq], dim=1) if is_gated_mlp else w1_dq,
w2_dq,
top_k,
routing_weights,
selected_experts,
_activation_type_from_str(activation_func),
resolved_activation_type,
)
return ref_output


@ -4,6 +4,7 @@ from utils.util import skip_pre_hopper

import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.load_moe_align import moe_align_block_size
from tensorrt_llm._torch.utils import ActivationType # noqa: F401


def _pack_routed_tokens_reference(
@ -131,6 +132,7 @@ def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit):
routing_weights,
w_up_stacked,
w_down_stacked,
is_gated_mlp=False,
)

# Reference Torch MoE in mlp mode with relu2 activation
@ -141,8 +143,8 @@ def test_triton_moe_matches_torch_moe_mlp_relu2(early_exit):
w1_weight=w_up_list,
w2_weight=w_down_list,
w3_weight=[],
mlp_style="mlp",
act_fn="relu2",
is_gated_mlp=False,
act_fn=ActivationType.Relu2,
)
torch.testing.assert_close(out_triton, out_torch, rtol=5e-2, atol=5e-2)

@ -364,8 +366,8 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
w1_weight_scale,
w2_weight_scale,
w3_weight_scale_tensor,
mlp_style="mlp",
act_fn="relu2",
is_gated_mlp=False,
act_fn=ActivationType.Relu2,
)

# Reference: Torch quantized FP8 MoE (uses lists of tensors and scales)
@ -382,8 +384,8 @@ def test_triton_quant_fp8_moe_matches_torch_quant_fp8_moe(early_exit):
w1_weight_scale=w1_weight_scale_list,
w2_weight_scale=w2_weight_scale_list,
w3_weight_scale=w3_weight_scale_list,
mlp_style="mlp",
act_fn="relu2",
is_gated_mlp=False,
act_fn=ActivationType.Relu2,
)

torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2)

@ -566,6 +566,7 @@ class TestMoEAlltoAll:
(4, [32, 32, 32, 32], 4),
(4, [1, 1, 1, 1], 2),
(8, [640, 640, 640, 640, 640, 640, 640, 640], 4),
(4, [32, 0, 16, 0], 2),
],
indirect=["mpi_pool_executor"])
def test_combine(self, mpi_pool_executor, all_num_tokens, top_k):
@ -19,49 +19,23 @@ example_images = [
str(test_data_root / "61.jpg"),
]


@pytest.fixture(scope="function")
def multimodal_model_config():
"""Get multimodal model configuration similar to integration tests"""
# You can extend this to support multiple models or get from environment
model_configs = {
'llava-v1.6-mistral-7b-hf': {
'model_name':
'llava-v1.6-mistral-7b-hf',
'hf_model_dir':
'llava-hf/llava-v1.6-mistral-7b-hf',
'model_dir':
llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf",
}
}

return model_configs['llava-v1.6-mistral-7b-hf']
_LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf"
_QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct"


# TODO: Add multi-image in single chat test
@pytest.mark.parametrize("model_key", [
"llava-v1.6-mistral-7b-hf",
])
@pytest.mark.parametrize("model_dir", [_LLAVA_DIR, _QWEN_2_5_VL_DIR])
@pytest.mark.parametrize("pd_disagg", [False, True])
def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
def test_single_image_chat(model_dir, pd_disagg):
"""Test processing single image using encoder (pass mm_embeddings) + LLM API.

This test verifies that encoder (pass mm_embeddings) + LLM API produces identical
results to standard llm generation (pass raw image) by comparing outputs.
"""
# Get model configuration
if model_key != "llava-v1.6-mistral-7b-hf":
#TODO: add more model tests progressively here
pytest.skip(
f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now"
)

# Extract model information from config
encoder_model_dir = multimodal_model_config['model_dir']

# Test configuration
max_tokens = 64
free_gpu_memory_fraction = 0.6 if not pd_disagg else 0.2
free_gpu_memory_fraction = 0.2
max_batch_size = 1

# Test data - OpenAI chat completion format
@ -76,15 +50,14 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
)

# Process multimodal data using encoder (pass mm_embeddings)
encoder = MultimodalEncoder(model=encoder_model_dir,
max_batch_size=max_batch_size)
encoder = MultimodalEncoder(model=model_dir, max_batch_size=max_batch_size)

cache_transceiver_cfg = CacheTransceiverConfig(
backend="DEFAULT") if pd_disagg else None

disable_overlap_scheduler = pd_disagg

llm = LLM(model=encoder_model_dir,
llm = LLM(model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
trust_remote_code=True,
@ -93,7 +66,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):

llm_decode = None
if pd_disagg:
llm_decode = LLM(model=encoder_model_dir,
llm_decode = LLM(model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
trust_remote_code=True,
@ -141,6 +114,7 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):

assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None"
ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only"

outputs = llm.generate(inputs,
sampling_params=sampling_params,
disaggregated_params=ep_disaggregated_params)
@ -151,10 +125,10 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
pd_disaggregated_params = outputs[0].disaggregated_params
pd_disaggregated_params.request_type = "generation_only"
sampling_params = SamplingParams(max_tokens=max_tokens)
inputs[0][
'multi_modal_data'] = None # remove multimodal data from input as decoder worker doesn't need it
inputs[0]['prompt_token_ids'] = outputs[
0].prompt_token_ids # use prompt token ids from encoder output
# remove multimodal data from input as decoder worker doesn't need it
inputs[0]['multi_modal_data'] = None
# use prompt token ids from encoder output
inputs[0]['prompt_token_ids'] = outputs[0].prompt_token_ids

outputs = llm_decode.generate(
inputs,
@ -199,24 +173,23 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
f"Log probabilities don't match for output {i}, generation {j}"


@pytest.mark.parametrize("model_key", [
"llava-v1.6-mistral-7b-hf",
])
def test_multi_request_batch_chat(model_key, multimodal_model_config):
@pytest.mark.parametrize(
"model_dir, encoder_max_batch_size",
[
(_LLAVA_DIR, 3),
# Qwen2.5 VL's vision encoder seems to output different embeddings based on this value.
# The test only passes with this set to 1.
(_QWEN_2_5_VL_DIR, 1),
],
)
def test_multi_request_batch_chat(model_dir, encoder_max_batch_size):
"""Test batching multiple multimodal requests and verify encoder path matches raw path.

This mirrors test_single_image_chat but with a batch of size 3.
"""
if model_key != "llava-v1.6-mistral-7b-hf":
pytest.skip(
f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now"
)

encoder_model_dir = multimodal_model_config['model_dir']

max_tokens = 64
free_gpu_memory_fraction = 0.6
max_batch_size = 3

prompts = [
"Describe the natural environment in the image.",
@ -232,10 +205,10 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
free_gpu_memory_fraction=free_gpu_memory_fraction,
)

encoder = MultimodalEncoder(model=encoder_model_dir,
max_batch_size=max_batch_size)
encoder = MultimodalEncoder(model=model_dir,
max_batch_size=encoder_max_batch_size)
llm = LLM(
model=encoder_model_dir,
model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
max_batch_size=1, # fix batch size to reduce non-determinism in tests
@ -305,8 +278,7 @@ def test_multi_request_batch_chat(model_key, multimodal_model_config):
"Describe the weather in the image.",
], 2),
])
def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates,
multimodal_model_config):
def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates):
"""Test mm_keys in KV cache events with cache reuse scenarios.

This test verifies:
@ -316,7 +288,7 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates,
- Same media + same prompts: full reuse (0 duplicate offsets)
- Same media + different prompts: partial reuse (prefix blocks reused)
"""
encoder_model_dir = multimodal_model_config['model_dir']
encoder_model_dir = _LLAVA_DIR

max_tokens = 16
free_gpu_memory_fraction = 0.6
@ -21,11 +21,18 @@ def model(request):
return request.param


@pytest.fixture(scope="module",
params=[0, 2],
ids=["disable_processpool", "enable_processpool"])
def num_postprocess_workers(request):
return request.param


@pytest.fixture(scope="module")
def server(model: str):
def server(model: str, num_postprocess_workers: int):
model_path = get_model_path(model)

args = []
args = ["--num_postprocess_workers", f"{num_postprocess_workers}"]
if model.startswith("Qwen3"):
args.extend(["--reasoning_parser", "qwen3"])
elif model.startswith("DeepSeek-R1"):

@ -1,6 +1,3 @@
import os
import tempfile

import openai
import pytest
import yaml
@ -22,34 +19,28 @@ def backend(request):


@pytest.fixture(scope="module")
def temp_extra_llm_api_options_file():
temp_dir = tempfile.gettempdir()
temp_file_path = os.path.join(temp_dir, "extra_llm_api_options.yaml")
try:
extra_llm_api_options_dict = {
"enable_chunked_prefill": False,
"gather_generation_logits": True,
"kv_cache_config": {
"enable_block_reuse": False,
}
def temp_extra_llm_api_options_file(tmp_path_factory):
extra_llm_api_options_dict = {
"enable_chunked_prefill": False,
"gather_generation_logits": True,
"kv_cache_config": {
"enable_block_reuse": False,
}
}

with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)

yield temp_file_path
finally:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
temp_file_path = tmp_path_factory.mktemp(
"config") / "extra_llm_api_options.yaml"
with open(temp_file_path, 'w') as f:
yaml.dump(extra_llm_api_options_dict, f)
return temp_file_path


@pytest.fixture(scope="module")
def server(model_name: str, backend: str, temp_extra_llm_api_options_file: str):
model_path = get_model_path(model_name)
args = [
"--backend", f"{backend}", "--extra_llm_api_options",
temp_extra_llm_api_options_file
]
args = ["--backend", f"{backend}"]
if backend == "trt":
args += ["--extra_llm_api_options", temp_extra_llm_api_options_file]
with RemoteOpenAIServer(model_path, args) as remote_server:
yield remote_server

@ -61,11 +52,7 @@ def async_client(server: RemoteOpenAIServer):

@pytest.mark.asyncio(loop_scope="module")
async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
model_name: str, backend: str):
# Skip if backend is PyTorch as it does not support topk logprobs when k > 1
if backend == "pytorch":
pytest.skip("Topk logprobs is not supported")

model_name: str):
messages = [{
"role": "system",
"content": "You are a helpful assistant."
@ -94,42 +81,3 @@ async def test_chat_completion_top5_logprobs(async_client: openai.AsyncOpenAI,
assert logprob_content.bytes is not None
assert logprob_content.top_logprobs is not None
assert len(logprob_content.top_logprobs) == 5


@pytest.mark.asyncio(loop_scope="module")
async def test_chat_completion_top1_logprobs(async_client: openai.AsyncOpenAI,
model_name: str, backend: str):
# Skip if backend is TRT because it is tested in test_chat_completion_top5_logprobs
if backend == "trt":
pytest.skip(
"TRT top logprobs is already tested in test_chat_completion_top5_logprobs"
)

messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "What is the capital of France?"
}]
# Test top_logprobs=1
chat_completion = await async_client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.0,
logprobs=True,
top_logprobs=1,
extra_body={
"ignore_eos": True,
})
logprobs = chat_completion.choices[0].logprobs
assert logprobs is not None and logprobs.content is not None
assert len(logprobs.content) == 10
for logprob_content in logprobs.content:
assert logprob_content.token is not None
assert logprob_content.logprob is not None
assert logprob_content.bytes is not None
assert logprob_content.top_logprobs is not None
# Check that the top_logprobs contains only one entry
assert len(logprob_content.top_logprobs) == 1
@ -213,21 +213,33 @@ def _test_sync_generation_tp_inner(llama_7b_tp2_path: Path):
result.outputs[0].token_ids) == ", neural network,"

try:
stats = await executor.aget_stats()
stats = json.loads(stats)
assert stats["iter"] == 0
assert stats["cpuMemUsage"] > 0
assert stats["gpuMemUsage"] > 0
assert stats["inflightBatchingStats"]["numCtxTokens"] == 3
assert stats["inflightBatchingStats"]["numGenRequests"] == 0
assert stats["kvCacheStats"]["usedNumBlocks"] == 1
stats_result = executor.aget_stats(timeout=2)
# aget_stats now returns IterationResult, iterate to get stats
async for stats_str in stats_result:
stats = json.loads(stats_str) if isinstance(stats_str,
str) else stats_str
assert stats["iter"] >= 0
assert stats["cpuMemUsage"] > 0
assert stats["gpuMemUsage"] > 0
break # Just check first result
except AsyncQueue.EventLoopShutdownError:
pass

asyncio.run(async_stats_task())

stats = executor.get_stats()
assert json.loads(stats)["iter"] == 1
# Poll for stats since RPC calls return immediately
import time
stats_list = []
for _ in range(10):
stats_list = executor.get_stats(timeout=0.5)
if stats_list:
break
time.sleep(0.1)

assert len(stats_list) > 0
stats = json.loads(stats_list[0]) if isinstance(stats_list[0],
str) else stats_list[0]
assert stats["iter"] == 1
executor.shutdown()


@ -4,6 +4,7 @@ import gc
import json
import os
import sys
import time

# Required for test_generate_with_seed to pass.
# See the discussion in https://github.com/NVIDIA/TensorRT-LLM/pull/4264#issuecomment-2943269891
@ -2193,6 +2194,7 @@ def llm_get_stats_test_harness(tp_size: int = 1,
sampling_params=sampling_params):
print(output)

time.sleep(2)
results = llm.get_stats(2)

validate_stats(results=results,
@ -2203,7 +2205,7 @@ def llm_get_stats_test_harness(tp_size: int = 1,
enable_chunked_prefill=enable_chunked_prefill,
enable_iter_req_stats=enable_iter_req_stats)

assert not llm.get_stats(2)
assert not llm.get_stats(0.5)

# test that IterationResult()._done is properly set
_ = llm.generate(prompts, sampling_params=sampling_params)
@ -2340,8 +2342,9 @@ def llm_get_stats_async_test_harness(tp_size: int = 1,
async def task1(repetition_index: int):
results = []
await asyncio.sleep(
3) # ensure there's stats to collect for the assertion
async for stats in llm.get_stats_async(timeout=2):
4) # ensure there's stats to collect for the assertion
async for stats in llm.get_stats_async(
10): # it will return immediately
results.append(stats)

assert results

@ -487,11 +487,12 @@ def test_llm_get_kv_cache_events_tp2():
# created + stored events
assert events and len(events) >= 2
for event in events:
print(f"event: {event}")
if event:
if event[0]["event_id"] == 0:
assert event[0]["data"]["type"] == "created"
elif event[0]["event_id"] == 1:
assert event[0]["data"]["type"] == "stored"
if event["event_id"] == 0:
assert event["data"]["type"] == "created"
elif event["event_id"] == 1:
assert event["data"]["type"] == "stored"


@pytest.fixture(scope="module")
@ -1,3 +1,4 @@
import json
import random
import time
from contextlib import contextmanager, nullcontext
@ -976,6 +977,62 @@ async def test_llm_rpc_streaming():
print(f"get result: {outputs}")


@skip_ray
def test_llm_rpc_get_stats():
"""Test that get_stats works with RPC orchestrator."""

with LLM(model=llama_model_path,
kv_cache_config=global_kvcache_config,
enable_iter_perf_stats=True,
orchestrator_type="rpc") as llm:
assert isinstance(llm._executor, GenerationExecutorRpcProxy)

# Generate some output to produce stats
for output in llm.generate(
prompts, sampling_params=SamplingParams(max_tokens=5)):
print(output)

stats = llm.get_stats(timeout=5)

assert len(stats) > 0, "Should have at least one stats entry"
# Stats should be JSON strings that can be parsed
parsed = json.loads(stats[0]) if isinstance(stats[0], str) else stats[0]
assert "iter" in parsed, "Stats should contain 'iter' field"
assert "cpuMemUsage" in parsed, "Stats should contain 'cpuMemUsage' field"


@skip_ray
@pytest.mark.asyncio
async def test_llm_rpc_get_stats_async():
"""Test that get_stats_async works with RPC orchestrator."""
import json

with LLM(model=llama_model_path,
kv_cache_config=global_kvcache_config,
enable_iter_perf_stats=True,
orchestrator_type="rpc") as llm:
assert isinstance(llm._executor, GenerationExecutorRpcProxy)

# Generate some output to produce stats
async for output in llm.generate_async(
prompts[0], sampling_params=SamplingParams(max_tokens=5)):
print(output)

# Get stats via async API
stats_result = llm.get_stats_async(timeout=2)

# Should be able to iterate over results
stats_count = 0
async for stat in stats_result:
parsed = json.loads(stat) if isinstance(stat, str) else stat
assert "iter" in parsed, "Stats should contain 'iter' field"
stats_count += 1
if stats_count >= 1:
break # Just verify we can get at least one

assert stats_count > 0, "Should have received at least one stat"


@pytest.mark.threadleak(enabled=False)
@pytest.mark.part0
@skip_ray