Merge branch 'main' into wentao-enable-all-dense-for-mrv2

This commit is contained in:
yewentao256
2026-06-04 18:33:39 +00:00
227 changed files with 9771 additions and 3209 deletions
+4 -5
View File
@@ -1299,12 +1299,11 @@ steps:
source_file_dependencies:
- vllm/
- tests/entrypoints/llm
- tests/entrypoints/offline_mode
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_generate.py
- pytest -v -s entrypoints/offline_mode
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py --ignore=entrypoints/llm/offline_mode
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/offline_mode # Needs to avoid interference with other tests
- label: Entrypoints Integration (Pooling) # TBD
timeout_in_minutes: 180
@@ -1346,7 +1345,7 @@ steps:
- vllm/platforms/rocm.py
commands:
- pytest -v -s entrypoints/openai/tool_parsers
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- label: OpenAI API correctness # TBD
timeout_in_minutes: 180
+3 -4
View File
@@ -11,7 +11,7 @@ steps:
- tests/entrypoints/
commands:
- pytest -v -s entrypoints/openai/tool_parsers
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- label: Entrypoints Integration (LLM)
key: entrypoints-integration-llm
@@ -20,12 +20,11 @@ steps:
source_file_dependencies:
- vllm/
- tests/entrypoints/llm
- tests/entrypoints/offline_mode
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py --ignore=entrypoints/llm/offline_mode
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
- pytest -v -s entrypoints/llm/offline_mode # Needs to avoid interference with other tests
mirror:
amd:
device: mi325_1
-7
View File
@@ -33,10 +33,3 @@ share/python-wheels/
*.egg
MANIFEST
rust/target/
# Not needed in Docker builds
docs/
.github/
.pre-commit-config.yaml
.clang-format
.gitattributes
format.sh
+2 -1
View File
@@ -34,10 +34,11 @@
/vllm/entrypoints/speech_to_text/realtime @njhill
/vllm/entrypoints/speech_to_text @NickLucche
/vllm/entrypoints/pooling @noooop
/vllm/entrypoints/sagemaker @DarkLight1337
/vllm/entrypoints/serve/sagemaker @DarkLight1337
/vllm/entrypoints/serve @njhill
/vllm/entrypoints/*.py @njhill
/vllm/entrypoints/chat_utils.py @DarkLight1337
/vllm/entrypoints/offline_utils.py @DarkLight1337
/vllm/entrypoints/llm.py @DarkLight1337
# Rust Frontend
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
actions: write
runs-on: ubuntu-latest
steps:
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
- uses: actions/stale@eb5cf3af3ac0a1aa4c9c45633dd1ae542a27a899 # v10.3.0
with:
# Increasing this value ensures that changes to this workflow
# propagate to all issues and PRs in days rather than months
@@ -102,6 +102,35 @@ constexpr float NUM_TOKEN_CUTOFF = 1024;
constexpr int kNumLanes = 32;
constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16
// Pack this lane's 16 fp32 elements into per-tensor E4M3 FP8 (one uint4 = 16
// B), scaling by `scale` (a reciprocal scale) and saturating to ±448. Used by
// the FlashInfer full-cache path for both the Q and KV stores.
__device__ __forceinline__ uint4 packFp8E4M3x16(float const* values,
float const scale) {
#ifndef USE_ROCM
uint4 out;
auto* out2 = reinterpret_cast<__nv_fp8x2_storage_t*>(&out);
#pragma unroll
for (int i = 0; i < kElemsPerLane / 2; i++) {
float2 scaled =
make_float2(values[2 * i] * scale, values[2 * i + 1] * scale);
scaled.x = fminf(fmaxf(scaled.x, -kFp8Max), kFp8Max);
scaled.y = fminf(fmaxf(scaled.y, -kFp8Max), kFp8Max);
out2[i] = __nv_cvt_float2_to_fp8x2(scaled, __NV_SATFINITE, __NV_E4M3);
}
return out;
#else
uint8_t out_bytes[kElemsPerLane];
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
float scaled = values[i] * scale;
scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max);
out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled);
}
return *reinterpret_cast<uint4 const*>(out_bytes);
#endif
}
// ────────────────────────────────────────────────────────────────────────────
// Small inline helpers
// ────────────────────────────────────────────────────────────────────────────
@@ -649,6 +678,257 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
#undef DISPATCH
}
// ────────────────────────────────────────────────────────────────────────────
// FlashInfer full-cache kernel
// ────────────────────────────────────────────────────────────────────────────
//
// Sibling to the FlashMLA kernel above, used by the FlashInfer V4 sparse-MLA
// backend. Differences from the legacy path:
// * No Q head padding — output Q layout matches the input num_heads_q.
// * KV is written as a *contiguous* 512-wide row per token (token-strided),
// not the legacy UE8M0 paged layout with a separate scale tail.
// * Q/KV are stored either as bf16 or as per-tensor E4M3 FP8 (one global
// scale), selected by the STORE_Q_FP8 / STORE_KV_FP8 template flags.
//
// Grid: 1D, gridDim.x = ceil(num_tokens_full * (num_heads_q + 1) / warps).
// Each warp handles one (token, slot): slot < num_heads_q → Q, slot ==
// num_heads_q → KV.
template <typename scalar_t_in, bool STORE_Q_FP8, bool STORE_KV_FP8>
__global__ void fusedDeepseekV4FullCacheKernel(
scalar_t_in* __restrict__ q_inout, // [N, H, 512], in place (bf16)
uint8_t* __restrict__ q_fp8_out, // [N, H, 512] fp8, optional
int64_t const q_fp8_stride0, // elements (fp8 == bytes)
int64_t const q_fp8_stride1, // elements (fp8 == bytes)
scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16
uint8_t* __restrict__ k_cache, // contiguous bf16 or fp8 cache
int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64
int64_t const* __restrict__ position_ids, // [N] i64
float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32
float const* __restrict__ fp8_scale_ptr, // scalar, KV fp8 only
float const* __restrict__ q_fp8_scale_inv, // scalar, Q fp8 only
float const eps,
int const num_tokens_full, // = q.size(0) = kv.size(0)
int const num_tokens_insert, // = slot_mapping.size(0)
int const num_heads_q, // H (no padding)
int const cache_block_size, // tokens per cache block
int64_t const kv_block_stride, // bytes per cache block
int64_t const kv_token_stride) { // bytes per cache token
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if constexpr (std::is_same_v<scalar_t_in, c10::BFloat16>) {
return;
} else {
#endif
using Converter = vllm::_typeConvert<scalar_t_in>;
int const warpsPerBlock = blockDim.x / 32;
int const warpId = threadIdx.x / 32;
int const laneId = threadIdx.x % 32;
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
int const slotsPerToken = num_heads_q + 1;
int const tokenIdx = globalWarpIdx / slotsPerToken;
int const slotIdx = globalWarpIdx % slotsPerToken;
if (tokenIdx >= num_tokens_full) return;
bool const isKV = (slotIdx == num_heads_q);
// KV branch: skip DP-padded tokens (no slot reserved for them).
if (isKV && tokenIdx >= num_tokens_insert) return;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaGridDependencySynchronize();
#endif
int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16
scalar_t_in const* src_ptr;
if (isKV) {
src_ptr = kv_in + static_cast<int64_t>(tokenIdx) * kHeadDim + dim_base;
} else {
src_ptr = q_inout +
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) *
kHeadDim +
dim_base;
}
uint4 const v0 = *reinterpret_cast<uint4 const*>(src_ptr);
uint4 const v1 = *reinterpret_cast<uint4 const*>(src_ptr + 8);
// ── Decode bf16 → 16 fp32 registers ───────────────────────────────────
float elements[kElemsPerLane];
{
auto const* p0 =
reinterpret_cast<typename Converter::packed_hip_type const*>(&v0);
auto const* p1 =
reinterpret_cast<typename Converter::packed_hip_type const*>(&v1);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 f2 = Converter::convert(p0[i]);
elements[2 * i] = f2.x;
elements[2 * i + 1] = f2.y;
}
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 f2 = Converter::convert(p1[i]);
elements[8 + 2 * i] = f2.x;
elements[8 + 2 * i + 1] = f2.y;
}
}
// ── Q branch: RMSNorm (no weight) ─────────────────────────────────────
if (!isKV) {
float sumOfSquares = 0.0f;
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
sumOfSquares += elements[i] * elements[i];
}
sumOfSquares = warpSum<float>(sumOfSquares);
float const rms_rcp =
rsqrtf(sumOfSquares / static_cast<float>(kHeadDim) + eps);
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
elements[i] = elements[i] * rms_rcp;
}
}
// ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ───────────────────────────
bool const is_rope_lane = dim_base >= kNopeDim;
if (is_rope_lane) {
int64_t const pos = position_ids[tokenIdx];
constexpr int kHalfRope = kRopeDim / 2;
float const* cos_ptr = cos_sin_cache + pos * kRopeDim;
float const* sin_ptr = cos_ptr + kHalfRope;
int const rope_local_base = dim_base - kNopeDim;
int const half_base = rope_local_base >> 1;
float4 const c0 = *reinterpret_cast<float4 const*>(cos_ptr + half_base);
float4 const c1 = *reinterpret_cast<float4 const*>(cos_ptr + half_base + 4);
float4 const s0 = *reinterpret_cast<float4 const*>(sin_ptr + half_base);
float4 const s1 = *reinterpret_cast<float4 const*>(sin_ptr + half_base + 4);
float const cos_arr[8] = {c0.x, c0.y, c0.z, c0.w, c1.x, c1.y, c1.z, c1.w};
float const sin_arr[8] = {s0.x, s0.y, s0.z, s0.w, s1.x, s1.y, s1.z, s1.w};
#pragma unroll
for (int p = 0; p < kElemsPerLane / 2; p++) {
float const x_even = elements[2 * p];
float const x_odd = elements[2 * p + 1];
elements[2 * p] = x_even * cos_arr[p] - x_odd * sin_arr[p];
elements[2 * p + 1] = x_even * sin_arr[p] + x_odd * cos_arr[p];
}
}
// ── Store ─────────────────────────────────────────────────────────────
if (!isKV) {
if constexpr (STORE_Q_FP8) {
float const scale_inv = VLLM_LDG(q_fp8_scale_inv);
uint4 const out = packFp8E4M3x16(elements, scale_inv);
uint8_t* dst = q_fp8_out +
static_cast<int64_t>(tokenIdx) * q_fp8_stride0 +
static_cast<int64_t>(slotIdx) * q_fp8_stride1 + dim_base;
*reinterpret_cast<uint4*>(dst) = out;
} else {
uint4 out0, out1;
auto* po0 = reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
auto* po1 = reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
#pragma unroll
for (int i = 0; i < 4; i++) {
po0[i] = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
}
#pragma unroll
for (int i = 0; i < 4; i++) {
po1[i] = Converter::convert(
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
}
scalar_t_in* dst =
q_inout +
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
dim_base;
*reinterpret_cast<uint4*>(dst) = out0;
*reinterpret_cast<uint4*>(dst + 8) = out1;
}
} else {
int64_t const slot_id = slot_mapping[tokenIdx];
if (slot_id >= 0) {
int64_t const block_idx = slot_id / cache_block_size;
int64_t const pos_in_block = slot_id % cache_block_size;
uint8_t* cache_row =
k_cache + block_idx * kv_block_stride + pos_in_block * kv_token_stride;
if constexpr (STORE_KV_FP8) {
float const inv_scale = 1.0f / VLLM_LDG(fp8_scale_ptr);
uint4 const out = packFp8E4M3x16(elements, inv_scale);
*reinterpret_cast<uint4*>(cache_row + dim_base) = out;
} else {
uint4 out0, out1;
auto* po0 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
auto* po1 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
#pragma unroll
for (int i = 0; i < 4; i++) {
po0[i] = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
}
#pragma unroll
for (int i = 0; i < 4; i++) {
po1[i] = Converter::convert(
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
}
scalar_t_in* dst = reinterpret_cast<scalar_t_in*>(cache_row) + dim_base;
*reinterpret_cast<uint4*>(dst) = out0;
*reinterpret_cast<uint4*>(dst + 8) = out1;
}
}
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
}
#endif
}
// Configure + launch helper shared by the bf16 and fp8 full-cache launchers.
template <typename scalar_t_in, bool STORE_Q_FP8, bool STORE_KV_FP8>
static void launchFullCacheKernel(
scalar_t_in* q_inout, uint8_t* q_fp8_out, int64_t q_fp8_stride0,
int64_t q_fp8_stride1, scalar_t_in const* kv_in, uint8_t* k_cache,
int64_t const* slot_mapping, int64_t const* position_ids,
float const* cos_sin_cache, float const* fp8_scale,
float const* q_fp8_scale_inv, float const eps, int const num_tokens_full,
int const num_tokens_insert, int const num_heads_q,
int const cache_block_size, int64_t const kv_block_stride,
int64_t const kv_token_stride, char const* op_name, cudaStream_t stream) {
constexpr int kBlockSize = 256;
constexpr int kWarpsPerBlock = kBlockSize / 32;
int64_t const total_warps =
static_cast<int64_t>(num_tokens_full) * (num_heads_q + 1);
int const grid =
static_cast<int>((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock);
auto* kernel =
fusedDeepseekV4FullCacheKernel<scalar_t_in, STORE_Q_FP8, STORE_KV_FP8>;
#ifndef USE_ROCM
static int const sm_version = getSMVersion();
STD_TORCH_CHECK(sm_version >= 80, op_name,
" requires sm_80+ (Ampere or newer); got sm_", sm_version);
cudaLaunchConfig_t config;
config.gridDim = dim3(grid);
config.blockDim = dim3(kBlockSize);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.attrs = attrs;
config.numAttrs = (sm_version >= 90) ? 1 : 0;
cudaLaunchKernelEx(&config, kernel, q_inout, q_fp8_out, q_fp8_stride0,
q_fp8_stride1, kv_in, k_cache, slot_mapping, position_ids,
cos_sin_cache, fp8_scale, q_fp8_scale_inv, eps,
num_tokens_full, num_tokens_insert, num_heads_q,
cache_block_size, kv_block_stride, kv_token_stride);
#else
kernel<<<grid, kBlockSize, 0, stream>>>(
q_inout, q_fp8_out, q_fp8_stride0, q_fp8_stride1, kv_in, k_cache,
slot_mapping, position_ids, cos_sin_cache, fp8_scale, q_fp8_scale_inv,
eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
kv_block_stride, kv_token_stride);
#endif
}
} // namespace deepseek_v4_fused_ops
} // namespace vllm
@@ -735,3 +1015,167 @@ torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
});
return q_out;
}
// ────────────────────────────────────────────────────────────────────────────
// FlashInfer full-cache torch ops
// ────────────────────────────────────────────────────────────────────────────
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
torch::stable::Tensor& q, // [N, H, 512] bf16, in place
torch::stable::Tensor const& kv, // [N, 512] bf16, read-only
torch::stable::Tensor& k_cache, // [num_blocks, bs, 512] bf16
torch::stable::Tensor const& slot_mapping, // [num_tokens_insert] int64
torch::stable::Tensor const& position_ids, // [N] int64
torch::stable::Tensor const& cos_sin_cache, // [max_pos, 64] float32
double eps, int64_t cache_block_size) {
using torch::headeronly::ScalarType;
STD_TORCH_CHECK(q.device().is_cuda() && q.is_contiguous(),
"q must be contiguous CUDA");
STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
"kv must be contiguous CUDA");
STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
slot_mapping.scalar_type() == ScalarType::Long,
"slot_mapping must be int64 CUDA");
STD_TORCH_CHECK(position_ids.device().is_cuda() &&
position_ids.scalar_type() == ScalarType::Long,
"position_ids must be int64 CUDA");
STD_TORCH_CHECK(cos_sin_cache.device().is_cuda() &&
cos_sin_cache.scalar_type() == ScalarType::Float &&
cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
"cos_sin_cache shape [max_pos, 64] float32");
STD_TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
STD_TORCH_CHECK(q.scalar_type() == ScalarType::BFloat16 &&
kv.scalar_type() == ScalarType::BFloat16,
"q and kv must be bfloat16");
STD_TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size &&
k_cache.size(2) == 512 && k_cache.stride(2) == 1,
"k_cache shape [num_blocks, cache_block_size, 512] contiguous");
STD_TORCH_CHECK(k_cache.scalar_type() == ScalarType::BFloat16,
"k_cache must be bfloat16");
int const num_tokens_full = static_cast<int>(q.size(0));
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
STD_TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
static_cast<int>(position_ids.size(0)) == num_tokens_full,
"q/kv/position_ids row counts must match");
STD_TORCH_CHECK(num_tokens_insert <= num_tokens_full,
"slot_mapping must not exceed q row count");
int const num_heads_q = static_cast<int>(q.size(1));
const torch::stable::accelerator::DeviceGuard device_guard(
q.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(q.get_device_index());
// bf16 cache: 2 bytes/element -> byte strides for the uint8-addressed kernel.
int64_t const kv_block_stride = k_cache.stride(0) * 2;
int64_t const kv_token_stride = k_cache.stride(1) * 2;
VLLM_STABLE_DISPATCH_HALF_TYPES(
q.scalar_type(),
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert", [&] {
vllm::deepseek_v4_fused_ops::launchFullCacheKernel<scalar_t, false,
false>(
reinterpret_cast<scalar_t*>(q.mutable_data_ptr()), nullptr, 0, 0,
reinterpret_cast<scalar_t const*>(kv.const_data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
slot_mapping.const_data_ptr<int64_t>(),
position_ids.const_data_ptr<int64_t>(),
cos_sin_cache.const_data_ptr<float>(), nullptr, nullptr,
static_cast<float>(eps), num_tokens_full, num_tokens_insert,
num_heads_q, static_cast<int>(cache_block_size), kv_block_stride,
kv_token_stride,
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert",
stream);
});
}
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
torch::stable::Tensor const& q, // [N, H, 512] bf16, read-only
torch::stable::Tensor const& kv, // [N, 512] bf16, read-only
torch::stable::Tensor& q_fp8, // [N, H, 512] fp8 e4m3
torch::stable::Tensor& k_cache, // [num_blocks, bs, 512] fp8
torch::stable::Tensor const& slot_mapping, // [num_tokens_insert] int64
torch::stable::Tensor const& position_ids, // [N] int64
torch::stable::Tensor const& cos_sin_cache, // [max_pos, 64] float32
torch::stable::Tensor const& fp8_scale, // scalar float32 (KV scale)
torch::stable::Tensor const& q_fp8_scale_inv, // scalar float32 (1 / Q scale)
double eps, int64_t cache_block_size) {
using torch::headeronly::ScalarType;
STD_TORCH_CHECK(q.device().is_cuda() && q.is_contiguous(),
"q must be contiguous CUDA");
STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
"kv must be contiguous CUDA");
STD_TORCH_CHECK(q_fp8.device().is_cuda() && q_fp8.is_contiguous() &&
q_fp8.scalar_type() == ScalarType::Float8_e4m3fn &&
q_fp8.dim() == 3 && q_fp8.size(0) == q.size(0) &&
q_fp8.size(1) == q.size(1) && q_fp8.size(2) == q.size(2),
"q_fp8 must be a contiguous float8_e4m3fn tensor matching q");
STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
slot_mapping.scalar_type() == ScalarType::Long,
"slot_mapping must be int64 CUDA");
STD_TORCH_CHECK(position_ids.device().is_cuda() &&
position_ids.scalar_type() == ScalarType::Long,
"position_ids must be int64 CUDA");
STD_TORCH_CHECK(cos_sin_cache.device().is_cuda() &&
cos_sin_cache.scalar_type() == ScalarType::Float &&
cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
"cos_sin_cache shape [max_pos, 64] float32");
STD_TORCH_CHECK(fp8_scale.device().is_cuda() &&
fp8_scale.scalar_type() == ScalarType::Float &&
fp8_scale.size(0) == 1,
"fp8_scale must be a scalar float32 CUDA tensor");
STD_TORCH_CHECK(q_fp8_scale_inv.device().is_cuda() &&
q_fp8_scale_inv.scalar_type() == ScalarType::Float &&
q_fp8_scale_inv.size(0) == 1,
"q_fp8_scale_inv must be a scalar float32 CUDA tensor");
STD_TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
STD_TORCH_CHECK(q.scalar_type() == kv.scalar_type(),
"q and kv dtype must match");
STD_TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size &&
k_cache.size(2) == 512 && k_cache.stride(2) == 1,
"k_cache shape [num_blocks, cache_block_size, 512] contiguous");
STD_TORCH_CHECK(k_cache.scalar_type() == ScalarType::Float8_e4m3fn,
"k_cache must be float8_e4m3fn");
int const num_tokens_full = static_cast<int>(q.size(0));
int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
STD_TORCH_CHECK(static_cast<int>(kv.size(0)) == num_tokens_full &&
static_cast<int>(position_ids.size(0)) == num_tokens_full,
"q/kv/position_ids row counts must match");
STD_TORCH_CHECK(num_tokens_insert <= num_tokens_full,
"slot_mapping must not exceed q row count");
int const num_heads_q = static_cast<int>(q.size(1));
const torch::stable::accelerator::DeviceGuard device_guard(
q.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(q.get_device_index());
VLLM_STABLE_DISPATCH_HALF_TYPES(
q.scalar_type(),
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert", [&] {
vllm::deepseek_v4_fused_ops::launchFullCacheKernel<scalar_t, true,
true>(
// q is read-only in the fp8 path (the kernel writes q_fp8); the
// launcher signature is non-const, so cast away const on the ptr.
reinterpret_cast<scalar_t*>(
const_cast<void*>(q.const_data_ptr())),
reinterpret_cast<uint8_t*>(q_fp8.mutable_data_ptr()),
q_fp8.stride(0), q_fp8.stride(1),
reinterpret_cast<scalar_t const*>(kv.const_data_ptr()),
reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
slot_mapping.const_data_ptr<int64_t>(),
position_ids.const_data_ptr<int64_t>(),
cos_sin_cache.const_data_ptr<float>(),
fp8_scale.const_data_ptr<float>(),
q_fp8_scale_inv.const_data_ptr<float>(), static_cast<float>(eps),
num_tokens_full, num_tokens_insert, num_heads_q,
static_cast<int>(cache_block_size),
// fp8 cache: 1 byte/element -> stride already in bytes.
k_cache.stride(0), k_cache.stride(1),
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert",
stream);
});
}
+17
View File
@@ -238,6 +238,23 @@ torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
double eps, int64_t cache_block_size);
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
torch::stable::Tensor& q, torch::stable::Tensor const& kv,
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
torch::stable::Tensor const& position_ids,
torch::stable::Tensor const& cos_sin_cache, double eps,
int64_t cache_block_size);
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
torch::stable::Tensor const& q, torch::stable::Tensor const& kv,
torch::stable::Tensor& q_fp8, torch::stable::Tensor& k_cache,
torch::stable::Tensor const& slot_mapping,
torch::stable::Tensor const& position_ids,
torch::stable::Tensor const& cos_sin_cache,
torch::stable::Tensor const& fp8_scale,
torch::stable::Tensor const& q_fp8_scale_inv, double eps,
int64_t cache_block_size);
#ifndef USE_ROCM
torch::stable::Tensor minimax_allreduce_rms(
torch::stable::Tensor const& input,
+20
View File
@@ -343,6 +343,20 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
// FlashInfer V4 full-cache variants: write Q in place (bf16) or to a separate
// FP8 tensor, and KV into a contiguous 512-wide token-strided cache.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert("
"Tensor! q, Tensor kv, Tensor! k_cache, Tensor slot_mapping, "
"Tensor position_ids, Tensor cos_sin_cache, float eps, "
"int cache_block_size) -> ()");
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert("
"Tensor q, Tensor kv, Tensor! q_fp8, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"Tensor fp8_scale, Tensor q_fp8_scale_inv, float eps, "
"int cache_block_size) -> ()");
#ifndef USE_ROCM
ops.def(
"minimax_allreduce_rms("
@@ -591,6 +605,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
ops.impl(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert",
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert));
ops.impl(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert",
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert));
#ifndef USE_ROCM
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
+2 -1
View File
@@ -55,7 +55,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch. Registered in _C_stable_libtorch.
// kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4
// full-cache bf16/fp8 variants).
// Quantization ops
#ifndef USE_ROCM
-1
View File
@@ -98,7 +98,6 @@ RUN if [ "$USE_SCCACHE" = "1" ]; then \
ARG USE_SCCACHE
ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET_NAME}}
ENV SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION_NAME}}
ENV SCCACHE_ENDPOINT=${USE_SCCACHE:+${SCCACHE_ENDPOINT}}
ENV SCCACHE_S3_NO_CREDENTIALS=${USE_SCCACHE:+${SCCACHE_S3_NO_CREDENTIALS}}
ENV SCCACHE_IDLE_TIMEOUT=${USE_SCCACHE:+0}
+14
View File
@@ -228,3 +228,17 @@ MLA decode backends are selected using the standard
| `TOKENSPEED_MLA` | fp16, bf16 | `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
### DeepSeek V4 Decode Backends
DeepSeek V4 sparse MLA uses its own decode backends, selected via
`--attention-backend=<BACKEND>` (e.g., `FLASHMLA_SPARSE_DSV4`,
`FLASHINFER_MLA_SPARSE_DSV4`). They share the V4 sparse-index
pipeline (compressor + SWA + indexer, 256-token blocks, head 512);
default on NVIDIA is `FLASHMLA_SPARSE_DSV4`.
| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Non-Causal | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
| ------- | ------ | --------- | ----------- | ---------- | ---- | ---------- | ------ | --------- | --- | --------------- | ------------ |
| `FLASHINFER_MLA_SPARSE_DSV4` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |
| `FLASHMLA_SPARSE_DSV4` | fp16, bf16 | `auto` | 256 | 512 | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |
| `ROCM_FLASHMLA_SPARSE_DSV4` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
+1
View File
@@ -82,6 +82,7 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
| Architecture | Models | CG for Image | CG for Video |
| ------------ | ------ | ------------ | ------------ |
| `InternVLChatModel` | `InternVL3.5`, `InternVL3`, `InternVL2.5`, `InternVL2` | ✅︎ | ✅︎ |
| `Qwen2VLForConditionalGeneration` | `Qwen2-VL` | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
+17 -14
View File
@@ -3,7 +3,7 @@
Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices.
!!! tip
To get started with quantization, see [LLM Compressor](llm_compressor.md), a library for optimizing models for deployment with vLLM that supports FP8, INT8, INT4, and other quantization formats.
To get started with quantization, see [LLM Compressor](llm_compressor/README.md), a library for optimizing models for deployment with vLLM that supports FP8, INT8, INT4, and other quantization formats.
The following are the supported quantization formats for vLLM:
@@ -12,9 +12,11 @@ The following are the supported quantization formats for vLLM:
- [GGUF](gguf.md)
- [GPTQModel](gptqmodel.md)
- [Intel Neural Compressor](inc.md)
- [INT4 W4A16](int4.md)
- [INT8 W8A8](int8.md)
- [FP8 W8A8](fp8.md)
- [LLM Compressor](llm_compressor/README.md)
- [FP8 W8A8](llm_compressor/fp8.md)
- [INT4 W4A16](llm_compressor/int4.md)
- [INT8 W4A8](llm_compressor/int8_w4a8.md)
- [INT8 W8A8](llm_compressor/int8_w8a8.md)
- [NVIDIA Model Optimizer](modelopt.md)
- [Online Quantization](online.md)
- [AMD Quark](quark.md)
@@ -46,16 +48,17 @@ th:not(:first-child) {
}
</style>
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU |
| ------------------------- | ----- | ------ | ------ | --- | ------ | ------- | --------- | ------- |
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ |
| Marlin (GPTQ/AWQ/FP8/FP4) | ❌ | ✅︎* | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ |
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | Arm CPU |
| ------------------------- | ----- | ------ | ------ | --- | ------ | ------- | --------- | ------- | ------- |
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ |
| Marlin (GPTQ/AWQ/FP8/FP4) | ❌ | ✅︎* | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| llm-compressor INT8 (W8A8)| ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ |
| llm-compressor INT8 (W4A8)| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | ❌ | ✅︎ |
| llm-compressor FP8 (W8A8) | ❌ | ❌ | | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
- ✅︎ indicates that the quantization method is supported on the specified hardware.
@@ -21,9 +21,17 @@ The FP8 types typically supported in hardware have two distinct representations,
To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
```bash
pip install llmcompressor
(venv-llm-compressor) pip install llmcompressor
```
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
```bash
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
```
Please use separate environments for vLLM and llm-compressor as they might not work together.
## Quantization Process
The quantization process involves three main steps:
@@ -57,36 +65,28 @@ For FP8 quantization, we can recover accuracy with simple RTN quantization. We r
Since simple RTN does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow.
??? code
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
# Configure the simple PTQ quantization
recipe = QuantizationModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["lm_head"],
)
# Configure the simple PTQ quantization
recipe = QuantizationModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["lm_head"],
)
# Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)
# Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)
# Save the model: Meta-Llama-3-8B-Instruct-FP8-Dynamic
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
```
# Save the model: Meta-Llama-3-8B-Instruct-FP8-Dynamic
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
```
### 3. Evaluating Accuracy
Install `vllm` and `lm-evaluation-harness` for evaluation:
```bash
pip install vllm "lm-eval[api]>=0.4.12"
```
Load and run the model in `vllm`:
```python
@@ -12,15 +12,17 @@ Please visit the HF collection of [quantized INT4 checkpoints of popular LLMs re
To use INT4 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
```bash
pip install llmcompressor
(venv-llm-compressor) pip install llmcompressor
```
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
```bash
pip install vllm "lm-eval[api]>=0.4.12"
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
```
Please use separate environments for vLLM and llm-compressor as they might not work together.
## Quantization Process
The quantization process involves four main steps:
@@ -52,55 +54,51 @@ When quantizing weights to INT4, you need sample data to estimate the weight upd
It's best to use calibration data that closely matches your deployment data.
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
??? code
```python
from datasets import load_dataset
```python
from datasets import load_dataset
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
# Load and preprocess the dataset
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
# Load and preprocess the dataset
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
def preprocess(example):
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.map(preprocess)
def preprocess(example):
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.map(preprocess)
def tokenize(sample):
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)
```
def tokenize(sample):
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)
```
### 3. Applying Quantization
Now, apply the quantization algorithms:
??? code
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
# Configure the quantization algorithms
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
# Configure the quantization algorithms
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
# Apply quantization
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Apply quantization
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A16-G128
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A16-G128
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
This process creates a W4A16 model with weights quantized to 4-bit integers.
@@ -141,36 +139,34 @@ lm_eval --model vllm \
The following is an example of an expanded quantization recipe you can tune to your own use case:
??? code
```python
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationScheme,
QuantizationStrategy,
QuantizationType,
)
recipe = GPTQModifier(
targets="Linear",
config_groups={
"config_group": QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(
num_bits=4,
type=QuantizationType.INT,
strategy=QuantizationStrategy.GROUP,
group_size=128,
symmetric=True,
dynamic=False,
actorder="weight",
),
```python
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationScheme,
QuantizationStrategy,
QuantizationType,
)
recipe = GPTQModifier(
targets="Linear",
config_groups={
"config_group": QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(
num_bits=4,
type=QuantizationType.INT,
strategy=QuantizationStrategy.GROUP,
group_size=128,
symmetric=True,
dynamic=False,
actorder="weight",
),
},
ignore=["lm_head"],
update_size=NUM_CALIBRATION_SAMPLES,
dampening_frac=0.01,
)
```
),
},
ignore=["lm_head"],
update_size=NUM_CALIBRATION_SAMPLES,
dampening_frac=0.01,
)
```
## Troubleshooting and Support
@@ -0,0 +1,217 @@
# INT8 W4A8
vLLM supports quantizing weights to INT4 and activations to INT8 for memory savings and inference acceleration.
This quantization method is particularly useful for reducing model size while maintaining good performance.
## Prerequisites
To use INT8 W4A8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library.
```bash
(venv-llm-compressor) pip install llmcompressor
```
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
```bash
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
```
Please use separate environments for vLLM and llm-compressor as they might not work together.
## Quantization Process
The quantization process involves four main steps:
1. Loading the model
2. Preparing calibration data
3. Applying quantization
4. Evaluating accuracy in vLLM
### 1. Loading the Model
Load your model and tokenizer using the standard `transformers` AutoModel classes:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```
### 2. Preparing Calibration Data
When quantizing activations to INT8 and weights to INT4, you need sample data to estimate the activation scales.
It's best to use calibration data that closely matches your deployment data.
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
```python
from datasets import load_dataset
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
# Load and preprocess the dataset
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
def preprocess(example):
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.map(preprocess)
def tokenize(sample):
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)
```
### 3. Applying Quantization
Now, apply the quantization algorithms.
The following recipes create W4A8 models (int4 weights, int8 activations). On Arm® CPUs, this is accelerated through [KleidiAI](https://github.com/ARM-software/kleidiai).
Use groupwise for best accuracy, and channelwise for best inference performance.
=== "Groupwise"
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
# Configure the quantization algorithms
recipe = [
GPTQModifier(
targets="Linear",
scheme="W4A8",
ignore=["lm_head"],
dampening_frac=0.01
),
]
# Apply quantization
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A8-G128-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
=== "Channelwise"
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
scheme = {
"targets": ["Linear"],
"weights": {
"num_bits": 4,
"type": QuantizationType.INT,
"strategy": QuantizationStrategy.CHANNEL,
"symmetric": True,
"dynamic": False,
"group_size": None,
},
"input_activations": {
"num_bits": 8,
"type": QuantizationType.INT,
"strategy": QuantizationStrategy.TOKEN,
"dynamic": True,
"symmetric": False,
"observer": None,
},
"output_activations": None,
}
recipe = [
GPTQModifier(
targets="Linear",
config_groups={"group_0": scheme},
ignore=["lm_head"],
dampening_frac=0.01,
),
]
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A8-Channelwise-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
### 4. Evaluating Accuracy
=== "Groupwise"
After quantization, you can load and run the model in vLLM:
```python
from vllm import LLM
llm = LLM("./Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token")
```
To evaluate accuracy, you can use `lm_eval`:
```bash
lm_eval --model vllm \
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token",add_bos_token=true \
--tasks gsm8k \
--num_fewshot 5 \
--limit 250 \
--batch_size 'auto'
```
=== "Channelwise"
After quantization, you can load and run the model in vLLM:
```python
from vllm import LLM
llm = LLM("./Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token")
```
To evaluate accuracy, you can use `lm_eval`:
```bash
lm_eval --model vllm \
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token",add_bos_token=true \
--tasks gsm8k \
--num_fewshot 5 \
--limit 250 \
--batch_size 'auto'
```
!!! note
Quantized models can be sensitive to the presence of the `bos` token. Make sure to include the `add_bos_token=True` argument when running evaluations.
## Best Practices
- Start with 512 samples for calibration data (increase if accuracy drops)
- Use a sequence length of 2048 as a starting point
- Employ the chat template or instruction template that the model was trained with
- If you've fine-tuned a model, consider using a sample of your training data for calibration
## Troubleshooting and Support
If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository.
@@ -17,15 +17,17 @@ Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs re
To use INT8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
```bash
pip install llmcompressor
(venv-llm-compressor) pip install llmcompressor
```
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
```bash
pip install vllm "lm-eval[api]>=0.4.12"
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
```
Please use separate environments for vLLM and llm-compressor as they might not work together.
## Quantization Process
The quantization process involves four main steps:
@@ -57,26 +59,24 @@ When quantizing activations to INT8, you need sample data to estimate the activa
It's best to use calibration data that closely matches your deployment data.
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
??? code
```python
from datasets import load_dataset
```python
from datasets import load_dataset
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
# Load and preprocess the dataset
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
# Load and preprocess the dataset
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
def preprocess(example):
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.map(preprocess)
def preprocess(example):
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.map(preprocess)
def tokenize(sample):
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)
```
def tokenize(sample):
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)
```
</details>
@@ -84,33 +84,31 @@ For a general-purpose instruction-tuned model, you can use a dataset like `ultra
Now, apply the quantization algorithms:
??? code
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
# Configure the quantization algorithms
recipe = [
SmoothQuantModifier(smoothing_strength=0.8),
GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]
# Configure the quantization algorithms
recipe = [
SmoothQuantModifier(smoothing_strength=0.8),
GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]
# Apply quantization
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Apply quantization
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Save the compressed model: Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
# Save the compressed model: Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
This process creates a W8A8 model with weights and activations quantized to 8-bit integers.
+2
View File
@@ -569,6 +569,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `GlmOcrForConditionalGeneration` | GLM-OCR | T + I<sup>E+</sup> | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ |
| `Granite4VisionForConditionalGeneration` | Granite 4 Vision | T + I<sup>E+</sup> | `ibm-granite/granite-4.1-3b-vision`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `GraniteSpeechPlusForConditionalGeneration` | Granite Speech Plus | T + A | `ibm-granite/granite-speech-4.1-2b-plus` | ✅︎ | ✅︎ |
| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | |
| `HCXVisionV2ForCausalLM` | HyperCLOVAX-SEED-Think-32B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Think-32B` | | |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎ |
@@ -709,6 +710,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-4.0-1b-speech`, `ibm-granite/granite-speech-3.3-2b`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechPlusForConditionalGeneration` | Granite Speech Plus | `ibm-granite/granite-speech-4.1-2b-plus` | ✅︎ | ✅︎ |
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | ✅︎ | ✅︎ |
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, etc. | | ✅︎ |
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
+7 -1
View File
@@ -24,8 +24,14 @@ echo "Checking pre-commit/pre-run-check status..."
MAX_WAIT=300
INTERVAL=60
ELAPSED=0
# Use a GitHub token if provided to raise the API rate limit (60 -> 5000
# requests/hour). Set GITHUB_TOKEN in the Read the Docs environment variables.
CURL_AUTH=()
if [ -n "$GITHUB_TOKEN" ]; then
CURL_AUTH=(-H "Authorization: Bearer $GITHUB_TOKEN")
fi
while :; do
RAW=$(curl -sS -w "\n%{http_code}" "https://api.github.com/repos/vllm-project/vllm/commits/${READTHEDOCS_GIT_COMMIT_HASH}/check-runs?check_name=pre-run-check&filter=latest")
RAW=$(curl -sS "${CURL_AUTH[@]}" -w "\n%{http_code}" "https://api.github.com/repos/vllm-project/vllm/commits/${READTHEDOCS_GIT_COMMIT_HASH}/check-runs?check_name=pre-run-check&filter=latest")
HTTP_CODE=$(printf %s "$RAW" | tail -n1)
BODY=$(printf %s "$RAW" | sed '$d')
if [ "$HTTP_CODE" != "200" ]; then
@@ -2554,6 +2554,7 @@ MODELS_NEED_VIDEO_METADATA = [
MODELS_SUPPORT_VIT_CUDA_GRAPH = [
"internvl_chat",
"qwen2_5_vl",
"qwen3_vl",
"qwen3_vl_moe",
+3
View File
@@ -110,6 +110,9 @@ plugins:
redirect_maps:
features/spec_decode/README.md: features/speculative_decoding/README.md
features/spec_decode/speculators.md: features/speculative_decoding/speculators.md
features/quantization/fp8.md: features/quantization/llm_compressor/fp8.md
features/quantization/int4.md: features/quantization/llm_compressor/int4.md
features/quantization/int8.md: features/quantization/llm_compressor/int8_w8a8.md
serving/openai_compatible_server.md: serving/online_serving/README.md
markdown_extensions:
+1 -1
View File
@@ -38,7 +38,7 @@ pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL.
compressed-tensors == 0.15.0.1 # required for compressed-tensors
compressed-tensors == 0.17.0 # required for compressed-tensors
depyf==0.20.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files
+2 -2
View File
@@ -18,7 +18,7 @@ tilelang==0.1.9
nvidia-cudnn-frontend>=1.13.0,<1.19.0
# Required for faster safetensors model loading
fastsafetensors >= 0.2.2
fastsafetensors >= 0.3.2
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
nvidia-cutlass-dsl[cu13]==4.5.2
@@ -28,4 +28,4 @@ quack-kernels>=0.3.3
tokenspeed-mla==0.1.2
# Humming kernels for quantization gemm
humming-kernels[cu13]==0.1.2
humming-kernels[cu13]==0.1.4
+4 -1
View File
@@ -19,7 +19,10 @@ setuptools-rust>=1.9.0
runai-model-streamer[s3,gcs,azure]==0.15.7
conch-triton-kernels==1.2.1
timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99
tilelang==0.1.10
# Required for faster safetensors model loading
fastsafetensors >= 0.3.2
+1 -1
View File
@@ -57,7 +57,7 @@ arctic-inference == 0.1.1; platform_machine == "x86_64" # Required for suffix de
numba == 0.65.0 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs,azure]==0.15.7
fastsafetensors>=0.2.2; platform_machine == "x86_64" # 0.2.2 contains important fixes for multi-GPU mem usage
fastsafetensors>=0.3.2
instanttensor>=0.1.5; platform_machine == "x86_64"
pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0; platform_machine == "x86_64"
+1 -1
View File
@@ -191,7 +191,7 @@ fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
# via cupy-cuda12x
fastsafetensors==0.2.2
fastsafetensors==0.3.2
# via
# -c requirements/cuda.txt
# -r requirements/test/cuda.in
+1 -1
View File
@@ -43,6 +43,6 @@ tritonclient>=2.51.0
numba == 0.65.0 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs,azure]==0.15.7
fastsafetensors>=0.2.2
fastsafetensors>=0.3.2
instanttensor>=0.1.5
pydantic>=2.12 # 2.11 leads to error on python 3.13
+1 -1
View File
@@ -56,7 +56,7 @@ arctic-inference==0.1.1 # Required for suffix decoding test
numba==0.65.0 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs,azure]==0.15.7
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@0.2.2 # PyPI only ships CUDA wheels
fastsafetensors>=0.3.2
instanttensor>=0.1.5
pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0
+5 -3
View File
@@ -143,7 +143,7 @@ colorful==0.5.8
# via ray
colorlog==6.10.1
# via optuna
compressed-tensors==0.15.0.1
compressed-tensors==0.17.0
# via
# -c requirements/common.txt
# -r requirements/test/../common.txt
@@ -240,8 +240,10 @@ fastar==0.10.0
# via fastapi-cloud-cli
fastparquet==2026.3.0
# via genai-perf
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@65d80088fca7a8f567fba30415fbcc80f7d2259c
# via -r requirements/test/rocm.in
fastsafetensors==0.3.2
# via
# -c requirements/rocm.txt
# -r requirements/test/rocm.in
filelock==3.25.2
# via
# -c requirements/common.txt
+1 -1
View File
@@ -1168,7 +1168,7 @@ setup(
"zen": ["zentorch==2.11.0.0"],
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
"fastsafetensors": ["fastsafetensors >= 0.3.2"],
"instanttensor": ["instanttensor >= 0.1.5"],
"runai": ["runai-model-streamer[s3,gcs,azure] >= 0.15.7"],
"audio": [
@@ -14,6 +14,7 @@ from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
AllReduceFusionPass,
RocmAiterAllReduceFusionPass,
)
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.fix_functionalization import (
FixFunctionalizationPass,
)
@@ -33,7 +34,7 @@ from vllm.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
@@ -91,6 +92,49 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceGemmaRMSNormModel(torch.nn.Module):
def __init__(
self,
hidden_size=16,
token_num=16,
eps=1e-6,
dtype: torch.dtype = torch.float16,
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = [GemmaRMSNorm(hidden_size, eps) for _ in range(4)]
# Non-trivial weight (~Gemma range) so (1 + w) exercises the scale path.
for n in self.norm:
n.weight.data.normal_(mean=0.0, std=0.1)
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(x)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = torch.mm(y, self.w[0])
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = torch.mm(y2, self.w[1])
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid)
z4 = torch.mm(y3, self.w[2])
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid)
return y4
def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
@@ -209,6 +253,15 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
"test_model, enable_quant_fp8_custom_op, use_aiter",
[
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
pytest.param(
TestAllReduceGemmaRMSNormModel,
False,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
pytest.param(
TestAllReduceRMSNormStaticQuantFP8Model,
True,
@@ -404,4 +457,9 @@ def all_reduce_fusion_pass_on_test_model(
)
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
if test_model_cls is TestAllReduceGemmaRMSNormModel:
fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
fused_nodes = list(find_op_nodes(fused_op, backend.graph_post_pass))
assert fused_nodes
assert all(n.kwargs.get("weight_bias") == 1.0 for n in fused_nodes)
del all_reduce_fusion_pass
@@ -0,0 +1,250 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Inductor FALLBACK_ALLOW_LIST patch in env_override.py.
The patch wraps ``torch._inductor.lowering.FALLBACK_ALLOW_LIST`` in a thin
proxy that auto-allows any custom op in the ``vllm::`` or ``vllm_aiter::``
namespaces. This routes those ops through Inductor's fast-path
``make_fallback(target, warn=False, override_decomp=True)`` and avoids the
expensive ``error.operator_str(target, args, kwargs)`` formatting that
recursively stringifies every input ``TensorBox``.
The slow path is what made ``torch.compile`` effectively hang on Kimi-K2.6
TP=8 (deep MoE/TP IR provenance trees). These tests cover both the proxy's
semantics in isolation and the membership-check fast-path that Inductor's
``GraphLowering.call_function`` actually performs, so we can validate the
optimization without needing a full GPU compile.
"""
import time
import pytest
from vllm.env_override import (
_patch_inductor_fallback_allow_list,
_VllmFallbackAllowList,
)
class TestVllmFallbackAllowListProxy:
"""Unit tests for the membership-proxy semantics."""
def test_vllm_namespace_auto_allowed(self):
proxy = _VllmFallbackAllowList(set())
assert "vllm::all_reduce" in proxy
assert "vllm::fused_add_rms_norm" in proxy
assert "vllm::all_reduce.default" in proxy
def test_vllm_aiter_namespace_auto_allowed(self):
proxy = _VllmFallbackAllowList(set())
assert "vllm_aiter::fused_add_rms_norm" in proxy
assert "vllm_aiter::rocm_aiter_fused_moe" in proxy
def test_unknown_namespace_falls_through(self):
proxy = _VllmFallbackAllowList({"torchvision::roi_align"})
assert "torchvision::roi_align" in proxy
assert "made_up_ns::nonexistent_op" not in proxy
def test_non_string_falls_through_to_inner(self):
sentinel = object()
inner = {sentinel}
proxy = _VllmFallbackAllowList(inner)
assert sentinel in proxy
assert object() not in proxy
def test_prefix_only_match_not_substring(self):
proxy = _VllmFallbackAllowList(set())
assert "not_vllm::something" not in proxy
assert " vllm::space_prefixed" not in proxy
def test_standard_entries_preserved(self):
base = {"torchvision::roi_align", "aten::index_add"}
proxy = _VllmFallbackAllowList(base)
assert "torchvision::roi_align" in proxy
assert "aten::index_add" in proxy
assert "aten::__not_present__" not in proxy
def test_add_and_discard_delegate_to_inner(self):
inner: set[str] = set()
proxy = _VllmFallbackAllowList(inner)
proxy.add("custom::op")
assert "custom::op" in inner
proxy.discard("custom::op")
assert "custom::op" not in inner
def test_iter_len_repr(self):
base = {"torchvision::roi_align", "aten::index_add"}
proxy = _VllmFallbackAllowList(base)
assert set(iter(proxy)) == base
assert len(proxy) == len(base)
assert "torchvision::roi_align" in repr(proxy)
def test_getattr_delegates_to_inner(self):
class _Inner:
sentinel = "i_am_inner"
def some_method(self):
return 42
inner = _Inner()
proxy = _VllmFallbackAllowList(inner)
assert proxy.sentinel == "i_am_inner"
assert proxy.some_method() == 42
def test_sentinel_attribute(self):
proxy = _VllmFallbackAllowList(set())
assert proxy._vllm_patched is True
class TestPatchApplication:
"""Integration tests verifying the patch reaches ``torch._inductor``."""
def test_patch_applied_to_lowering(self):
import torch._inductor.lowering as _lowering
assert getattr(_lowering.FALLBACK_ALLOW_LIST, "_vllm_patched", False), (
"env_override._patch_inductor_fallback_allow_list did not run"
)
def test_graph_module_local_binding_rebound(self):
# ``torch/_inductor/graph.py`` does:
# from torch._inductor.lowering import FALLBACK_ALLOW_LIST
# so the patch has to overwrite the graph module's local binding too,
# otherwise the fast-path check in GraphLowering.call_function still
# sees the original (unwrapped) OrderedSet.
import torch._inductor.graph as _graph
import torch._inductor.lowering as _lowering
if not hasattr(_graph, "FALLBACK_ALLOW_LIST"):
pytest.skip(
"torch._inductor.graph no longer imports FALLBACK_ALLOW_LIST "
"as a module-level symbol; nothing to rebind."
)
assert _graph.FALLBACK_ALLOW_LIST is _lowering.FALLBACK_ALLOW_LIST
def test_patch_is_idempotent(self):
import torch._inductor.lowering as _lowering
first = _lowering.FALLBACK_ALLOW_LIST
_patch_inductor_fallback_allow_list()
_patch_inductor_fallback_allow_list()
assert _lowering.FALLBACK_ALLOW_LIST is first
def test_real_vllm_ops_in_real_allow_list(self):
# End-to-end membership check using the live (already-patched) object.
import torch._inductor.lowering as _lowering
allow_list = _lowering.FALLBACK_ALLOW_LIST
assert "vllm::all_reduce" in allow_list
assert "vllm::fused_add_rms_norm" in allow_list
assert "vllm_aiter::fused_add_rms_norm" in allow_list
class TestInductorFallbackFastPath:
"""Emulates ``GraphLowering.call_function``'s FALLBACK_ALLOW_LIST check.
The relevant snippet in ``torch/_inductor/graph.py`` is roughly::
base_name = target.name()
if base_name not in FALLBACK_ALLOW_LIST:
log.info(
"Creating implicit fallback for:\\n%s",
error.operator_str(target, args, kwargs),
)
out = make_fallback(target, ...)
On a deep MoE/TP graph (Kimi-K2.6 at TP=4/8) ``operator_str`` recurses
through every input ``TensorBox.__str__`` and ends up taking many minutes
of CPU per encountered op. The patch ensures the membership test
short-circuits for ``vllm::*``/``vllm_aiter::*`` ops so the slow path is
never entered. These tests pin that behaviour without needing a real
GPU compile.
"""
def _simulate_graph_lowering(self, target_names: list[str]):
"""Returns the set of target names that would have hit the slow
operator_str() path under the patched FALLBACK_ALLOW_LIST.
"""
import torch._inductor.lowering as _lowering
allow_list = _lowering.FALLBACK_ALLOW_LIST
slow_path_hits: list[str] = []
for name in target_names:
if name not in allow_list:
slow_path_hits.append(name)
return slow_path_hits
def test_vllm_ops_skip_slow_path(self):
slow = self._simulate_graph_lowering(
[
"vllm::all_reduce",
"vllm::fused_add_rms_norm",
"vllm_aiter::rocm_aiter_fused_moe",
"vllm_aiter::asm_moe",
]
)
assert slow == [], (
"Patched FALLBACK_ALLOW_LIST must short-circuit for all "
f"vllm::*/vllm_aiter::* ops; got slow-path hits: {slow}"
)
def test_non_vllm_ops_still_hit_slow_path(self):
# Without the patch this is also what would happen; with the patch
# the behaviour for non-vllm namespaces must be unchanged.
slow = self._simulate_graph_lowering(
["my_user_ns::custom_op", "fancy_ns::something_else"]
)
assert "my_user_ns::custom_op" in slow
assert "fancy_ns::something_else" in slow
def test_kimi_k2_6_style_op_stream(self):
"""Emulates one decoder layer's worth of fallback hits.
Kimi-K2.6 at TP=4 lowers a stream of ``vllm::all_reduce`` +
``vllm_aiter::fused_add_rms_norm`` calls (one per residual block)
plus a handful of fused-MoE ops. Pre-patch every one of these would
invoke ``operator_str`` and stringify a hundreds-deep IR provenance
tree; post-patch they must all short-circuit.
"""
n_layers = 64 # Kimi-K2.6 has ~64 decoder layers per replica
op_stream: list[str] = []
for _ in range(n_layers):
op_stream.extend(
[
"vllm::all_reduce",
"vllm_aiter::fused_add_rms_norm",
"vllm_aiter::rocm_aiter_fused_moe",
]
)
start = time.perf_counter()
slow = self._simulate_graph_lowering(op_stream)
elapsed_s = time.perf_counter() - start
assert slow == [], (
f"Expected all {len(op_stream)} vllm/vllm_aiter ops to take "
f"the fast path; got {len(slow)} slow-path hits."
)
# ``__contains__`` is O(1) per call, so a Kimi-sized stream should
# complete in well under a second even on a slow runner. The
# pre-patch slow path took many minutes per op on Kimi-K2.6 TP=8.
assert elapsed_s < 1.0, (
f"FALLBACK_ALLOW_LIST membership check is unexpectedly slow: "
f"{elapsed_s:.3f}s for {len(op_stream)} ops"
)
def test_inner_set_membership_still_works_for_standard_ops(self):
"""The patch must not break Inductor's existing fallback decisions
for non-vllm ops such as ``torchvision::roi_align``."""
import torch._inductor.lowering as _lowering
allow_list = _lowering.FALLBACK_ALLOW_LIST
# ``torchvision::roi_align`` has been a member of the upstream
# FALLBACK_ALLOW_LIST since the original Inductor implementation.
# If the proxy ever broke pass-through, this would regress.
if "torchvision::roi_align" not in allow_list:
pytest.skip(
"Upstream FALLBACK_ALLOW_LIST no longer ships "
"torchvision::roi_align; nothing to verify."
)
+21 -9
View File
@@ -277,12 +277,15 @@ def assert_verification_synced(local_ok: bool, msg: str) -> None:
assert bool(ok_tensor.item()), msg
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
def create_eplb_communicator_or_raise(
*, group_coordinator, backend, expert_weights, expert_buffer
):
try:
return create_eplb_communicator(
group_coordinator=group_coordinator,
backend=backend,
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
except Exception as exc:
raise RuntimeError(
@@ -355,7 +358,8 @@ def _test_async_transfer_layer_without_mtp_worker(
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend=eplb_communicator,
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
communicator.set_stream(cuda_stream)
@@ -368,6 +372,7 @@ def _test_async_transfer_layer_without_mtp_worker(
ep_group=ep_group,
communicator=communicator,
cuda_stream=cuda_stream,
layer_idx=layer_idx,
)
cuda_stream.synchronize()
move_from_buffer(
@@ -460,10 +465,12 @@ def _test_rearrange_expert_weights_with_redundancy(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
)
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend=eplb_communicator,
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
# Execute weight rearrangement
@@ -471,9 +478,9 @@ def _test_rearrange_expert_weights_with_redundancy(
old_indices,
new_indices,
expert_weights,
expert_buffer,
ep_group,
is_profile=False,
communicator=communicator,
communicator,
)
# Verify the rearrangement result
@@ -593,10 +600,12 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend="torch_nccl",
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
# Execute rearrangement (should be no change)
@@ -604,9 +613,9 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
indices,
indices, # Same indices
expert_weights,
expert_buffer,
ep_group,
communicator,
is_profile=False,
)
# Verify that the weights have not changed
@@ -726,10 +735,12 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
layer_copy.append(weight.clone())
original_weights.append(layer_copy)
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator,
backend="torch_nccl",
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
# Execute profile mode rearrangement
@@ -737,9 +748,10 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
old_indices,
new_indices,
expert_weights,
expert_buffer,
ep_group,
communicator,
is_profile=True, # Profile mode
is_profile=True,
)
# In profile mode, the weights should remain unchanged
+11 -1
View File
@@ -9,9 +9,11 @@ import pytest
import torch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_eplb_group,
get_tp_group,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -213,12 +215,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend="torch_nccl",
expert_weights=rank_expert_weights,
expert_buffer=expert_buffer,
)
rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
expert_buffer,
ep_group,
is_profile=False,
communicator,
)
num_local_experts = test_config.num_local_experts
@@ -10,11 +10,13 @@ import torch
from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_dp_group,
get_eplb_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -171,12 +173,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend="torch_nccl",
expert_weights=rank_expert_weights,
expert_buffer=expert_buffer,
)
rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
expert_buffer,
ep_group,
is_profile=False,
communicator,
)
num_global_experts = test_config.num_experts
@@ -6,6 +6,7 @@ from unittest.mock import MagicMock
import pytest
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import (
@@ -13,10 +14,13 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.entrypoints.serve.lora.protocol import (
LoadLoRAAdapterRequest,
UnloadLoRAAdapterRequest,
)
from vllm.exceptions import VLLMNotFoundError
from vllm.lora.request import LoRARequest
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
@@ -130,3 +134,60 @@ async def test_unload_lora_adapter_not_found():
assert isinstance(response, ErrorResponse)
assert response.error.type == "NotFoundError"
assert response.error.code == HTTPStatus.NOT_FOUND
class _ConcretePoolingServing(PoolingServingBase):
"""Minimal concrete subclass used only in these unit tests."""
request_id_prefix = "test"
def get_io_processor(self, request):
raise NotImplementedError
def _build_response(self, ctx):
raise NotImplementedError
def _make_pooling_serving(lora_name: str) -> _ConcretePoolingServing:
lora_request = LoRARequest(
lora_name=lora_name, lora_int_id=1, lora_path="/path/to/lora"
)
mock_models = MagicMock()
mock_models.lora_requests = {lora_name: lora_request}
mock_models.is_base_model.side_effect = lambda name: name == MODEL_NAME
serving = object.__new__(_ConcretePoolingServing)
serving.models = mock_models
return serving
def _make_pooling_ctx(model_name: str) -> PoolingServeContext:
mock_request = MagicMock()
mock_request.model = model_name
return PoolingServeContext(
request=mock_request,
model_name=MODEL_NAME,
request_id="test-id",
pooling_params=PoolingParams(),
)
def test_pooling_maybe_get_adapters_lora_name_sets_lora_request():
"""LoRA adapter name must populate ctx.lora_request without raising."""
lora_name = "bot-embed-lora"
serving = _make_pooling_serving(lora_name)
ctx = _make_pooling_ctx(lora_name)
serving._maybe_get_adapters(ctx)
assert ctx.lora_request is not None
assert ctx.lora_request.lora_name == lora_name
def test_pooling_maybe_get_adapters_unknown_model_raises():
"""An unrecognised model name must still raise VLLMNotFoundError."""
serving = _make_pooling_serving("some-lora")
ctx = _make_pooling_ctx("unknown-model")
with pytest.raises(VLLMNotFoundError):
serving._maybe_get_adapters(ctx)
@@ -6,7 +6,7 @@
import pytest
import pytest_asyncio
from ...utils import RemoteOpenAIServer
from tests.utils import RemoteOpenAIServer
# Model name constants used across tests
MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct"
@@ -22,7 +22,8 @@ import tempfile
import pytest
import requests
from ...utils import RemoteOpenAIServer
from tests.utils import RemoteOpenAIServer
from .conftest import (
MODEL_NAME_SMOLLM,
)
@@ -4,7 +4,8 @@ import openai # use the official async_client for correctness check
import pytest
import requests
from ...utils import RemoteOpenAIServer
from tests.utils import RemoteOpenAIServer
from .conftest import MODEL_NAME_SMOLLM
@@ -12,7 +12,8 @@ import tempfile
import pytest
import requests
from ...utils import RemoteOpenAIServer
from tests.utils import RemoteOpenAIServer
from .conftest import (
MODEL_NAME_SMOLLM,
)
@@ -6,7 +6,8 @@ import openai # use the official client for correctness check
import pytest
import requests
from ...utils import RemoteOpenAIServer
from tests.utils import RemoteOpenAIServer
from .conftest import (
HEADER_SAGEMAKER_CLOSED_SESSION_ID,
HEADER_SAGEMAKER_NEW_SESSION_ID,
@@ -4,7 +4,7 @@
import pytest
from vllm.entrypoints.openai.engine.protocol import StreamOptions
from vllm.entrypoints.utils import (
from vllm.entrypoints.serve.utils.api_utils import (
get_max_tokens,
sanitize_message,
should_include_usage,
@@ -6,7 +6,7 @@ from types import SimpleNamespace
import pytest
from vllm.entrypoints.openai import fingerprint as fp
from vllm.entrypoints.serve.utils import fingerprint as fp
def _cfg(tp=1, pp=1, dp=1, ep=False, digest="a3b21f94deadbeef"):
@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
from vllm.entrypoints.serve.utils.request_logger import RequestLogger
def test_request_logger_log_outputs():
"""Test the new log_outputs functionality."""
# Create a mock logger to capture log calls
mock_logger = MagicMock()
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test basic output logging
request_logger.log_outputs(
request_id="test-123",
outputs="Hello, world!",
output_token_ids=[1, 2, 3, 4],
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-123"
assert call_args[3] == "Hello, world!"
assert call_args[4] == [1, 2, 3, 4]
assert call_args[5] == "stop"
def test_request_logger_log_outputs_streaming_delta():
"""Test log_outputs with streaming delta mode."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test streaming delta logging
request_logger.log_outputs(
request_id="test-456",
outputs="Hello",
output_token_ids=[1],
finish_reason=None,
is_streaming=True,
delta=True,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-456"
assert call_args[2] == " (streaming delta)"
assert call_args[3] == "Hello"
assert call_args[4] == [1]
assert call_args[5] is None
def test_request_logger_log_outputs_streaming_complete():
"""Test log_outputs with streaming complete mode."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test streaming complete logging
request_logger.log_outputs(
request_id="test-789",
outputs="Complete response",
output_token_ids=[1, 2, 3],
finish_reason="length",
is_streaming=True,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-789"
assert call_args[2] == " (streaming complete)"
assert call_args[3] == "Complete response"
assert call_args[4] == [1, 2, 3]
assert call_args[5] == "length"
def test_request_logger_log_outputs_with_truncation():
"""Test log_outputs respects max_log_len setting."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
# Set max_log_len to 10
request_logger = RequestLogger(max_log_len=10)
# Test output truncation
long_output = "This is a very long output that should be truncated"
long_token_ids = list(range(20)) # 20 tokens
request_logger.log_outputs(
request_id="test-truncate",
outputs=long_output,
output_token_ids=long_token_ids,
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args
# Check that output was truncated to first 10 characters
logged_output = call_args[0][3]
assert logged_output == "This is a "
assert len(logged_output) == 10
# Check that token IDs were truncated to first 10 tokens
logged_token_ids = call_args[0][4]
assert logged_token_ids == list(range(10))
assert len(logged_token_ids) == 10
def test_request_logger_log_outputs_none_values():
"""Test log_outputs handles None values correctly."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test with None output_token_ids
request_logger.log_outputs(
request_id="test-none",
outputs="Test output",
output_token_ids=None,
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-none"
assert call_args[3] == "Test output"
assert call_args[4] is None
assert call_args[5] == "stop"
def test_request_logger_log_outputs_empty_output():
"""Test log_outputs handles empty output correctly."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=5)
# Test with empty output
request_logger.log_outputs(
request_id="test-empty",
outputs="",
output_token_ids=[],
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-empty"
assert call_args[3] == ""
assert call_args[4] == []
assert call_args[5] == "stop"
def test_request_logger_log_outputs_integration():
"""Test that log_outputs can be called alongside log_inputs."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test that both methods can be called without interference
request_logger.log_inputs(
request_id="test-integration",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_embeds=None,
params=None,
lora_request=None,
)
request_logger.log_outputs(
request_id="test-integration",
outputs="Test output",
output_token_ids=[4, 5, 6],
finish_reason="stop",
is_streaming=False,
delta=False,
)
# Should have been called twice - once for inputs, once for outputs
assert mock_logger.info.call_count == 2
# Check that the calls were made with correct patterns
input_call = mock_logger.info.call_args_list[0][0]
output_call = mock_logger.info.call_args_list[1][0]
assert "Received request %s" in input_call[0]
assert input_call[1] == "test-integration"
assert "Generated response %s%s" in output_call[0]
assert output_call[1] == "test-integration"
def test_streaming_complete_logs_full_text_content():
"""Test that streaming complete logging includes
full accumulated text, not just token count."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test with actual content instead of token count format
full_response = "This is a complete response from streaming"
request_logger.log_outputs(
request_id="test-streaming-full-text",
outputs=full_response,
output_token_ids=None,
finish_reason="streaming_complete",
is_streaming=True,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
# Verify the logged output is the full text, not a token count format
logged_output = call_args[3]
assert logged_output == full_response
assert "tokens>" not in logged_output
assert "streaming_complete" not in logged_output
# Verify other parameters
assert call_args[1] == "test-streaming-full-text"
assert call_args[2] == " (streaming complete)"
assert call_args[5] == "streaming_complete"
@@ -7,7 +7,7 @@ from ssl import SSLContext
import pytest
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.entrypoints.serve.utils.ssl import SSLCertRefresher
class MockSSLContext(SSLContext):
@@ -2,4 +2,5 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
accuracy_threshold: 0.72
num_questions: 1319
num_fewshot: 5
rocm_request_timeout_seconds: 1800
server_args: "--enforce-eager --max-model-len 4096"
@@ -2,4 +2,5 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
accuracy_threshold: 0.45
num_questions: 1319
num_fewshot: 5
rocm_request_timeout_seconds: 1800
server_args: "--enforce-eager --max-model-len 4096"
+4 -4
View File
@@ -106,7 +106,7 @@ async def call_vllm_api(
completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
return text, completion_tokens
except Exception as e:
print(f"Error calling vLLM API: {e}")
print(f"Error calling vLLM API ({type(e).__name__}): {e}")
return "", 0
@@ -177,6 +177,7 @@ def evaluate_gsm8k(
port: int = 8000,
temperature: float = 0.0,
seed: int | None = 42,
request_timeout_seconds: float = 600,
) -> dict[str, float | int]:
"""
Evaluate GSM8K accuracy using vLLM serve endpoint.
@@ -205,9 +206,8 @@ def evaluate_gsm8k(
output_tokens[i] = tokens
return answer, tokens
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=600)
) as session:
timeout = aiohttp.ClientTimeout(total=request_timeout_seconds)
async with aiohttp.ClientSession(timeout=timeout) as session:
tasks = [get_answer(session, i) for i in range(num_questions)]
await tqdm.gather(*tasks, desc="Evaluating")
@@ -39,11 +39,18 @@ def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
host = f"http://{host}"
# Run GSM8K evaluation
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
if current_platform.is_rocm():
request_timeout_seconds = eval_config.get(
"rocm_request_timeout_seconds", request_timeout_seconds
)
results = evaluate_gsm8k(
num_questions=eval_config["num_questions"],
num_shots=eval_config["num_fewshot"],
host=host,
port=port,
request_timeout_seconds=request_timeout_seconds,
)
return results
@@ -90,6 +97,12 @@ def test_gsm8k_correctness(config_filename):
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
print(f"Number of questions: {eval_config['num_questions']}")
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
if current_platform.is_rocm():
request_timeout_seconds = eval_config.get(
"rocm_request_timeout_seconds", request_timeout_seconds
)
print(f"Request timeout: {request_timeout_seconds}s")
print(f"Server args: {' '.join(server_args)}")
print(f"Environment variables: {env_dict}")
@@ -0,0 +1,339 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""ROCm kernel correctness tests for AITER unified attention.
Compares ``aiter.ops.triton.unified_attention`` against ``ref_paged_attn`` under
decode, prefill, and mixed batches with varied shapes.
"""
from typing import Any, Literal
import pytest
import torch
from tests.kernels.attention.test_triton_unified_attention import ref_paged_attn
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
_SKIP_NON_MI3XX = True
if current_platform.is_rocm():
from vllm.platforms.rocm import on_mi3xx
_SKIP_NON_MI3XX = not on_mi3xx()
pytestmark = [
pytest.mark.skipif(not current_platform.is_rocm(), reason="ROCm-specific tests"),
pytest.mark.skipif(_SKIP_NON_MI3XX, reason="MI300/MI350 ROCm only"),
]
NUM_Q_HEADS = 8
NUM_KV_HEADS = 8
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 64]
DTYPES = [torch.bfloat16, torch.float16]
FP8_DTYPE = current_platform.fp8_dtype()
# (query_len, kv_len) per sequence
MIXED_SEQ_LENS = [
[(1, 128), (5, 18), (129, 463)],
[(10, 256), (5, 64), (32, 128)],
[(1, 1024), (5, 18), (129, 1328)],
]
DECODE_SEQ_LENS = [
[(1, 128), (1, 256), (1, 384), (1, 512)],
[(1, 1024), (1, 1536), (1, 2048)],
]
PREFILL_SEQ_LENS = [
[(256, 256), (128, 512)],
[(64, 128), (32, 256), (16, 512)],
[(256, 1024), (128, 2048)],
]
DEFAULT_ATOL, DEFAULT_RTOL = 1.5e-2, 1e-2
FP8_ATOL, FP8_RTOL = 1.5e-1, 1.5e-1
# Non-unity scale so q_descale handling is exercised explicitly.
Q_SCALE = 0.75
K_SCALE, V_SCALE = 0.5, 0.25
Fp8Variant = Literal["fp8_kv", "fp8_query", "fp8_query_kv"]
FP8_VARIANTS = [
pytest.param("fp8_kv", id="fp8_kv"),
pytest.param("fp8_query", id="fp8_query"),
pytest.param("fp8_query_kv", id="fp8_query_kv"),
]
FP8_SEQ_LENS = [
MIXED_SEQ_LENS[0],
DECODE_SEQ_LENS[0],
DECODE_SEQ_LENS[1],
PREFILL_SEQ_LENS[0],
PREFILL_SEQ_LENS[2],
]
def _require_aiter() -> None:
from vllm._aiter_ops import is_aiter_found_and_supported
if not is_aiter_found_and_supported():
pytest.skip("aiter is required on supported ROCm hardware for this test")
def _make_case(
*,
seq_lens: list[tuple[int, int]],
head_size: int,
block_size: int,
dtype: torch.dtype,
num_blocks: int = 2048,
kv_cache_dtype: torch.dtype | None = None,
k_scale: float = 1.0,
v_scale: float = 1.0,
q_dtype: torch.dtype | None = None,
q_scale: float = Q_SCALE,
) -> dict[str, Any]:
torch.set_default_device("cuda")
query_lens = [q for q, _ in seq_lens]
kv_lens = [k for _, k in seq_lens]
num_seqs = len(seq_lens)
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
scale = head_size**-0.5
query = torch.randn(sum(query_lens), NUM_Q_HEADS, head_size, dtype=dtype)
if kv_cache_dtype is None:
key_cache = torch.randn(
num_blocks, block_size, NUM_KV_HEADS, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
else:
key_cache = torch.clamp(
torch.randn(num_blocks, block_size, NUM_KV_HEADS, head_size),
-1.0,
1.0,
).to(kv_cache_dtype)
value_cache = torch.clamp(
torch.randn(num_blocks, block_size, NUM_KV_HEADS, head_size),
-1.0,
1.0,
).to(kv_cache_dtype)
cu_seqlens_q = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
seq_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks), dtype=torch.int32
)
descale_shape = (num_seqs, NUM_KV_HEADS)
k_descale = torch.full(descale_shape, k_scale, dtype=torch.float32, device="cuda")
v_descale = torch.full(descale_shape, v_scale, dtype=torch.float32, device="cuda")
kernel_query = query
q_descale = None
if q_dtype is not None:
q_descale = torch.tensor(q_scale, dtype=torch.float32, device="cuda")
kernel_query = (query / q_scale).to(q_dtype)
return {
"query": query,
"kernel_query": kernel_query,
"key_cache": key_cache,
"value_cache": value_cache,
"block_tables": block_tables,
"query_lens": query_lens,
"kv_lens": kv_lens,
"seq_lens_tensor": seq_lens_tensor,
"cu_seqlens_q": cu_seqlens_q,
"q_descale": q_descale,
"k_descale": k_descale,
"v_descale": v_descale,
"scale": scale,
"max_query_len": max_query_len,
"max_kv_len": max_kv_len,
"query_dtype": dtype,
"k_scale": k_scale,
"v_scale": v_scale,
}
def _make_fp8_case(
*,
seq_lens: list[tuple[int, int]],
head_size: int,
block_size: int,
variant: Fp8Variant,
) -> dict[str, Any]:
use_fp8_kv = variant in ("fp8_kv", "fp8_query_kv")
use_fp8_query = variant in ("fp8_query", "fp8_query_kv")
return _make_case(
seq_lens=seq_lens,
head_size=head_size,
block_size=block_size,
dtype=torch.bfloat16,
kv_cache_dtype=FP8_DTYPE if use_fp8_kv else None,
k_scale=K_SCALE if use_fp8_kv else 1.0,
v_scale=V_SCALE if use_fp8_kv else 1.0,
q_dtype=FP8_DTYPE if use_fp8_query else None,
)
def _run_aiter_unified_attention(case: dict[str, Any]) -> torch.Tensor:
from aiter.ops.triton.unified_attention import unified_attention
kernel_query = case["kernel_query"]
# Kernel writes high-precision output even when Q is FP8 (matches vLLM usage).
output = torch.empty_like(case["query"])
unified_attention(
q=kernel_query,
k=case["key_cache"],
v=case["value_cache"],
out=output,
cu_seqlens_q=case["cu_seqlens_q"],
max_seqlen_q=case["max_query_len"],
seqused_k=case["seq_lens_tensor"],
max_seqlen_k=case["max_kv_len"],
softmax_scale=case["scale"],
causal=True,
alibi_slopes=None,
window_size=(-1, -1),
block_table=case["block_tables"],
softcap=0,
q_descale=case["q_descale"],
k_descale=case["k_descale"],
v_descale=case["v_descale"],
sinks=None,
output_scale=None,
)
return output
def _ref_output(case: dict[str, Any]) -> torch.Tensor:
key_cache = case["key_cache"]
value_cache = case["value_cache"]
if key_cache.dtype != case["query_dtype"]:
key_cache = key_cache.to(case["query_dtype"]) * case["k_scale"]
value_cache = value_cache.to(case["query_dtype"]) * case["v_scale"]
return ref_paged_attn(
query=case["query"],
key_cache=key_cache,
value_cache=value_cache,
query_lens=case["query_lens"],
kv_lens=case["kv_lens"],
block_tables=case["block_tables"],
scale=case["scale"],
)
def _assert_matches_reference(
case: dict[str, Any],
*,
atol: float = DEFAULT_ATOL,
rtol: float = DEFAULT_RTOL,
) -> None:
output = _run_aiter_unified_attention(case)
output_ref = _ref_output(case)
torch.testing.assert_close(output, output_ref, atol=atol, rtol=rtol)
@pytest.mark.parametrize("seq_lens", MIXED_SEQ_LENS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_aiter_unified_attn_mixed_batch(
seq_lens: list[tuple[int, int]],
head_size: int,
block_size: int,
dtype: torch.dtype,
) -> None:
"""Decode + prefill sequences in one batch (native dtypes)."""
_require_aiter()
set_random_seed(0)
case = _make_case(
seq_lens=seq_lens,
head_size=head_size,
block_size=block_size,
dtype=dtype,
)
_assert_matches_reference(case)
@pytest.mark.parametrize("seq_lens", DECODE_SEQ_LENS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_aiter_unified_attn_decode(
seq_lens: list[tuple[int, int]],
head_size: int,
block_size: int,
dtype: torch.dtype,
) -> None:
"""Single-token decode (native dtypes)."""
_require_aiter()
set_random_seed(0)
case = _make_case(
seq_lens=seq_lens,
head_size=head_size,
block_size=block_size,
dtype=dtype,
)
_assert_matches_reference(case)
@pytest.mark.parametrize("seq_lens", PREFILL_SEQ_LENS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_aiter_unified_attn_prefill(
seq_lens: list[tuple[int, int]],
head_size: int,
block_size: int,
) -> None:
"""Prefill-only batches with query_len > 1 (native dtypes)."""
_require_aiter()
set_random_seed(0)
case = _make_case(
seq_lens=seq_lens,
head_size=head_size,
block_size=block_size,
dtype=torch.bfloat16,
)
_assert_matches_reference(case)
@pytest.mark.skipif(
not current_platform.supports_fp8(),
reason="FP8 not supported on this hardware",
)
@pytest.mark.parametrize("variant", FP8_VARIANTS)
@pytest.mark.parametrize("seq_lens", FP8_SEQ_LENS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16, 64])
@torch.inference_mode()
def test_aiter_unified_attn_fp8(
variant: Fp8Variant,
seq_lens: list[tuple[int, int]],
head_size: int,
block_size: int,
) -> None:
"""FP8 KV cache, FP8 query, or both; compared at bf16 reference precision."""
_require_aiter()
set_random_seed(0)
case = _make_fp8_case(
seq_lens=seq_lens,
head_size=head_size,
block_size=block_size,
variant=variant,
)
_assert_matches_reference(case, atol=FP8_ATOL, rtol=FP8_RTOL)
+13 -5
View File
@@ -205,7 +205,10 @@ def run_with_expert_maps(
w2 = kwargs["w2"]
a = kwargs["hidden_states"]
moe_config = make_dummy_moe_config(
num_experts=w2.shape[0],
max_num_tokens=kwargs.get("hidden_states").shape[0],
experts_per_token=kwargs.get("topk_ids").shape[1],
num_experts=num_experts,
num_local_experts=num_local_experts,
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
@@ -258,23 +261,27 @@ def run_8_bit(
a1_scale=None,
)
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts
kwargs = {
"hidden_states": moe_tensors.a,
"w1": moe_tensors.w1_q, # type: ignore[union-attr]
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr]
"global_num_experts": num_experts,
"activation": MoEActivation.SILU,
"expert_map": None,
"apply_router_weight_on_input": False,
}
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep:
moe_config = make_dummy_moe_config(
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
max_num_tokens=moe_tensors.a.shape[0],
experts_per_token=topk_ids.shape[1],
num_experts=num_experts,
num_local_experts=num_local_experts,
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
in_dtype=moe_tensors.a.dtype,
@@ -581,6 +588,7 @@ def test_run_cutlass_moe_fp8(
per_out_channel,
False,
topk_weights,
None,
)
workspace13.random_()
+4 -1
View File
@@ -1287,10 +1287,12 @@ def _test_body_eplb(
expert_weights = [list(eplb_moe_layer.get_expert_weights())]
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=vllm_config.parallel_config.eplb_config.communicator,
expert_weights=expert_weights[0],
expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
# Rearrange expert weights across EP ranks
@@ -1298,6 +1300,7 @@ def _test_body_eplb(
old_global_expert_indices=initial_indices.unsqueeze(0),
new_global_expert_indices=shuffled_indices.unsqueeze(0),
expert_weights=expert_weights,
expert_buffer=expert_buffer,
ep_group=cpu_group,
communicator=communicator,
)
+6 -2
View File
@@ -49,10 +49,12 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
def make_dummy_moe_config(
num_experts: int = 1,
num_local_experts: int | None = None,
experts_per_token: int = 1,
hidden_dim: int = 1,
intermediate_size_per_partition: int = 1,
in_dtype: torch.dtype = torch.bfloat16,
max_num_tokens: int = 512,
) -> FusedMoEConfig:
"""
This is a dummy config for the mk constructor interface
@@ -66,14 +68,16 @@ def make_dummy_moe_config(
experts_per_token=experts_per_token,
hidden_dim=hidden_dim,
intermediate_size_per_partition=intermediate_size_per_partition,
num_local_experts=num_experts,
num_local_experts=num_local_experts
if num_local_experts is not None
else num_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation=MoEActivation.SILU,
in_dtype=in_dtype,
device="cuda",
routing_method=RoutingMethodType.TopK,
max_num_tokens=512,
max_num_tokens=max_num_tokens,
)
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Triton dequant-gather kernel used by
``CompressedTensorsEmbeddingWNA16Int`` (quantized embedding lookup)."""
import pytest
import torch
from compressed_tensors.compressors.pack_quantized.helpers import unpack_from_int32
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding import ( # noqa: E501
_dequant_gather_triton,
)
from vllm.platforms import current_platform
def _dequant_gather_torch(
ids: torch.Tensor,
weight_packed: torch.Tensor,
weight_scale: torch.Tensor,
hidden: int,
num_bits: int,
) -> torch.Tensor:
"""Reference: gather packed rows by id, unpack int32-packed INT, dequant."""
n = ids.shape[0]
int8 = unpack_from_int32(weight_packed[ids], num_bits, torch.Size([n, hidden]))
scale_rows = weight_scale[ids]
w = int8.to(scale_rows.dtype)
if scale_rows.shape[1] == 1:
return w * scale_rows
ng = scale_rows.shape[1]
return (w.view(n, ng, hidden // ng) * scale_rows.unsqueeze(-1)).view(n, hidden)
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="Triton dequant kernel requires CUDA"
)
@pytest.mark.parametrize("num_bits", [2, 4, 8])
@pytest.mark.parametrize("group_size", [0, 256]) # 0 -> channel
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("num_ids", [1, 17, 4096])
def test_dequant_gather(num_bits, group_size, dtype, num_ids):
torch.manual_seed(0)
device = "cuda"
vocab, hidden = 1000, 2048
pack_factor = 32 // num_bits
# Random full-range int32 packed weights (covers the sign bit -> exercises the
# arithmetic-shift + mask unpack path).
weight_packed = torch.randint(
-(2**31),
2**31,
(vocab, hidden // pack_factor),
dtype=torch.int32,
device=device,
)
num_groups = 1 if group_size == 0 else hidden // group_size
weight_scale = torch.rand(vocab, num_groups, dtype=dtype, device=device) + 0.01
ids = torch.randint(0, vocab, (num_ids,), dtype=torch.long, device=device)
out = _dequant_gather_triton(ids, weight_packed, weight_scale, hidden, num_bits)
ref = _dequant_gather_torch(ids, weight_packed, weight_scale, hidden, num_bits)
assert out.shape == (num_ids, hidden)
assert out.dtype == dtype
torch.testing.assert_close(out, ref)
+149
View File
@@ -468,6 +468,7 @@ def _reference_kv_compress_norm_rope(
use_fp4: bool = False,
rms_eps: float = 1e-6,
fp8_max: float = 448.0,
return_full_cache: bool = False,
):
"""Compress → RMSNorm → GPT-J RoPE → quantize.
@@ -521,6 +522,12 @@ def _reference_kv_compress_norm_rope(
results.append(torch.cat([nope, rope]).to(state_cache.dtype))
result = torch.stack(results)
if return_full_cache:
# Contiguous 512-wide bf16 row (nope unrotated + rope rotated), matching
# the FlashInfer full-cache layout before any per-tensor fp8 quant. The
# kernel rounds the fp32 result to bf16 once at the store.
return result.to(torch.bfloat16)
if use_fp4:
return quantize_to_mxfp4(result)
else:
@@ -667,3 +674,145 @@ def test_fused_kv_insert_indexer(num_tokens: int, kv_block_size: int, use_fp4: b
assert torch.equal(actual_scale, scale[i : i + 1]), (
f"token {i}: scale {actual_scale.item()} != {scale[i].item()}"
)
@pytest.mark.parametrize("compress_ratio", [4, 128])
@pytest.mark.parametrize("store_fp8", [False, True])
def test_cutedsl_full_cache_store(compress_ratio: int, store_fp8: bool):
"""CuTeDSL compressor full-cache (FlashInfer) store parity for head=512.
Exercises the contiguous bf16 / per-tensor fp8 store branch of both the C4
fused kernel and the C128 split kernel against the PyTorch reference.
"""
cutedsl = pytest.importorskip("cutlass") # noqa: F841
from vllm.models.deepseek_v4.nvidia.ops.sparse_attn_compress_cutedsl import (
fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
split_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
)
HEAD_DIM = 512
ROPE_DIM = 64
RMS_EPS = 1e-6
FP8_MAX = 448.0
# C128 compress (Block8 kernel) requires state-cache block_size=8; C4 uses 16.
BLOCK_SIZE = 8 if compress_ratio == 128 else 16
KV_BLOCK_SIZE = 64
device = "cuda"
torch.manual_seed(7)
overlap = 1 if compress_ratio == 4 else 0
coff = 1 + overlap
num_tokens = 8
num_pages = (compress_ratio * num_tokens - 1) // BLOCK_SIZE + 2
# The production CompressorStateCache is fp32.
state_cache = torch.randn(
num_pages, BLOCK_SIZE, 2 * coff * HEAD_DIM, dtype=torch.float32, device=device
)
block_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0)
token_to_req = torch.zeros(num_tokens, dtype=torch.int32, device=device)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
positions = torch.arange(
compress_ratio - 1,
compress_ratio * num_tokens,
compress_ratio,
dtype=torch.int64,
device=device,
)
rms_weight = torch.randn(HEAD_DIM, dtype=torch.bfloat16, device=device)
cos_sin_cache = torch.randn(
compress_ratio * num_tokens, ROPE_DIM, dtype=torch.float32, device=device
)
dtype = torch.float8_e4m3fn if store_fp8 else torch.bfloat16
kv_n_blocks = (num_tokens + KV_BLOCK_SIZE - 1) // KV_BLOCK_SIZE + 1
k_cache = torch.zeros(
kv_n_blocks, KV_BLOCK_SIZE, HEAD_DIM, dtype=dtype, device=device
)
fp8_scale = torch.tensor(
[0.5 if store_fp8 else 1.0], dtype=torch.float32, device=device
)
if compress_ratio == 4:
fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(
state_cache,
token_to_req,
positions,
slot_mapping,
block_table,
BLOCK_SIZE,
rms_weight,
RMS_EPS,
cos_sin_cache,
k_cache,
slot_mapping,
KV_BLOCK_SIZE,
k_cache.stride(0),
head_size=HEAD_DIM,
state_width=coff * HEAD_DIM,
rope_head_dim=ROPE_DIM,
fp8_max=FP8_MAX,
quant_block=64,
token_stride=576,
scale_dim=8,
compress_ratio=compress_ratio,
overlap=True,
store_full_kv=True,
store_full_fp8=store_fp8,
fp8_scale=fp8_scale,
)
else:
compressed_kv = torch.empty(
(num_tokens, HEAD_DIM), dtype=torch.float32, device=device
)
split_kv_compress_norm_rope_insert_sparse_attn_cutedsl(
state_cache,
token_to_req,
positions,
slot_mapping,
block_table,
BLOCK_SIZE,
compressed_kv,
rms_weight,
RMS_EPS,
cos_sin_cache,
k_cache,
slot_mapping,
KV_BLOCK_SIZE,
k_cache.stride(0),
head_size=HEAD_DIM,
state_width=coff * HEAD_DIM,
rope_head_dim=ROPE_DIM,
fp8_max=FP8_MAX,
quant_block=64,
token_stride=576,
scale_dim=8,
compress_ratio=compress_ratio,
overlap=bool(overlap),
store_full_kv=True,
store_full_fp8=store_fp8,
fp8_scale=fp8_scale,
)
ref = _reference_kv_compress_norm_rope(
state_cache,
block_table,
positions,
rms_weight,
cos_sin_cache,
compress_ratio,
overlap,
rms_eps=RMS_EPS,
return_full_cache=True,
) # [num_tokens, HEAD_DIM] bf16
actual = torch.stack(
[k_cache[i // KV_BLOCK_SIZE, i % KV_BLOCK_SIZE] for i in range(num_tokens)]
)
if store_fp8:
ref_fp8 = torch.clamp(ref.float() / fp8_scale, -FP8_MAX, FP8_MAX).to(
torch.float8_e4m3fn
)
torch.testing.assert_close(actual.float(), ref_fp8.float(), rtol=0.0, atol=0.3)
else:
torch.testing.assert_close(actual.float(), ref.float(), rtol=3e-2, atol=3e-2)
@@ -67,7 +67,7 @@ def apply_rope_gptj_last_k(
head_dim = x.shape[-1]
nope_dim = head_dim - rope_dim
cs = cos_sin_cache[positions].to(torch.float32)
cs = cos_sin_cache[positions.long()].to(torch.float32)
cos = cs[..., :half]
sin = cs[..., half:]
@@ -114,6 +114,18 @@ def _op_available() -> bool:
return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")
def _full_cache_fp8_op_available() -> bool:
return hasattr(
torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert"
)
def _full_cache_bf16_op_available() -> bool:
return hasattr(
torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert"
)
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available() or not _op_available(),
reason="CUDA not available or fused DeepseekV4 op not built in",
@@ -415,3 +427,238 @@ def test_combined_q_and_kv(
"padded head slots must be exact zero"
)
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
# ── Full-cache (FlashInfer) path parity ──────────────────────────────────────
def _call_full_cache_fp8_fused(
q,
kv,
q_fp8,
k_cache,
slot_mapping,
positions,
cos_sin_cache,
fp8_scale,
q_fp8_scale_inv,
eps,
bs,
):
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
q,
kv,
q_fp8,
k_cache,
slot_mapping,
positions.long(),
cos_sin_cache,
fp8_scale,
q_fp8_scale_inv,
eps,
bs,
)
def _call_full_cache_bf16_fused(
q,
kv,
k_cache,
slot_mapping,
positions,
cos_sin_cache,
eps,
bs,
):
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
q,
kv,
k_cache,
slot_mapping,
positions.long(),
cos_sin_cache,
eps,
bs,
)
def _fp8_full_cache_reference(
q,
kv,
k_cache,
q_fp8,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
fp8_scale,
q_fp8_scale_inv,
):
q_ref = rmsnorm_no_weight(q, eps)
q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache)
q_fp8.copy_(
torch.clamp(q_ref.float() * q_fp8_scale_inv, -FP8_MAX, FP8_MAX).to(
torch.float8_e4m3fn
)
)
kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
valid = slot_mapping >= 0
slots = slot_mapping[valid]
block_idx = slots // block_size
pos_in_block = slots % block_size
k_cache[block_idx, pos_in_block] = torch.clamp(
kv_ref[valid].float() / fp8_scale, -FP8_MAX, FP8_MAX
).to(torch.float8_e4m3fn)
def _bf16_full_cache_reference(
q,
kv,
k_cache,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
):
q_ref = rmsnorm_no_weight(q, eps)
# Kernel keeps RMSNorm+RoPE in fp32 and rounds to bf16 once at the store.
q_ref = apply_rope_gptj_last_k(q_ref, positions, cos_sin_cache).to(q.dtype)
kv_ref = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
valid = slot_mapping >= 0
slots = slot_mapping[valid]
block_idx = slots // block_size
pos_in_block = slots % block_size
k_cache[block_idx, pos_in_block] = kv_ref[valid]
return q_ref
@pytest.mark.skipif(
not _full_cache_fp8_op_available(),
reason="full-cache per-tensor FP8 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_per_tensor_fp8_matches_reference(
num_tokens: int,
n_heads: int,
positions_dtype: torch.dtype,
):
torch.manual_seed(4)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
block_size = 16
max_pos = 4096
q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(num_tokens, dtype=positions_dtype, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
num_blocks = (num_tokens + block_size - 1) // block_size + 1
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
fp8_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
q_fp8_scale_inv = torch.tensor([1.0], dtype=torch.float32, device=device)
q_fp8_ref = torch.empty_like(q, dtype=torch.float8_e4m3fn)
q_fp8_fused = torch.empty_like(q, dtype=torch.float8_e4m3fn)
k_cache_ref = torch.zeros(
num_blocks, block_size, HEAD_DIM, dtype=torch.float8_e4m3fn, device=device
)
k_cache_fused = torch.zeros_like(k_cache_ref)
_fp8_full_cache_reference(
q,
kv,
k_cache_ref,
q_fp8_ref,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
fp8_scale,
q_fp8_scale_inv,
)
_call_full_cache_fp8_fused(
q.clone(),
kv,
q_fp8_fused,
k_cache_fused,
slot_mapping,
positions,
cos_sin_cache,
fp8_scale,
q_fp8_scale_inv,
eps,
block_size,
)
torch.testing.assert_close(
q_fp8_fused.float(), q_fp8_ref.float(), rtol=0, atol=0.25
)
torch.testing.assert_close(
k_cache_fused.float(), k_cache_ref.float(), rtol=0, atol=0.25
)
@pytest.mark.skipif(
not _full_cache_bf16_op_available(),
reason="full-cache BF16 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_bf16_matches_reference(
num_tokens: int,
n_heads: int,
positions_dtype: torch.dtype,
):
torch.manual_seed(5)
device = "cuda"
dtype = torch.bfloat16
eps = 1e-6
block_size = 16
max_pos = 4096
q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
positions = torch.arange(num_tokens, dtype=positions_dtype, device=device)
cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)
num_blocks = (num_tokens + block_size - 1) // block_size + 1
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
q_fused = q.clone()
k_cache_ref = torch.zeros(
num_blocks, block_size, HEAD_DIM, dtype=torch.bfloat16, device=device
)
k_cache_fused = torch.zeros_like(k_cache_ref)
q_ref = _bf16_full_cache_reference(
q,
kv,
k_cache_ref,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
)
_call_full_cache_bf16_fused(
q_fused,
kv,
k_cache_fused,
slot_mapping,
positions,
cos_sin_cache,
eps,
block_size,
)
torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
@@ -0,0 +1,481 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for sequence and token pooler head classes."""
import torch
import torch.nn as nn
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
from vllm.model_executor.layers.pooler.seqwise.heads import (
ClassifierPoolerHead,
EmbeddingPoolerHead,
)
from vllm.model_executor.layers.pooler.tokwise.heads import (
TokenClassifierPoolerHead,
TokenEmbeddingPoolerHead,
)
from vllm.pooling_params import PoolingParams
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
_HIDDEN = 16
_BATCH = 3
def _make_params(
n: int,
*,
task: str = "embed",
dimensions: int | None = None,
use_activation: bool | None = None,
) -> list[PoolingParams]:
return [
PoolingParams(task=task, dimensions=dimensions, use_activation=use_activation)
for _ in range(n)
]
def _make_metadata(pooling_params: list[PoolingParams]) -> PoolingMetadata:
n = len(pooling_params)
return PoolingMetadata(
prompt_lens=torch.ones(n, dtype=torch.long),
prompt_token_ids=None,
prompt_token_ids_cpu=None,
pooling_params=pooling_params,
pooling_states=[PoolingStates() for _ in range(n)],
)
def _linear(in_f: int, out_f: int) -> nn.Linear:
torch.manual_seed(42)
return nn.Linear(in_f, out_f, bias=False)
# ---------------------------------------------------------------------------
# EmbeddingPoolerHead
# ---------------------------------------------------------------------------
class TestEmbeddingPoolerHead:
def test_supported_tasks(self):
head = EmbeddingPoolerHead()
assert head.get_supported_tasks() == {"embed"}
def test_passthrough(self):
head = EmbeddingPoolerHead()
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH))
out = head(x, meta)
assert torch.equal(out, x)
def test_head_dtype(self):
head = EmbeddingPoolerHead(head_dtype=torch.float16)
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH))
out = head(x, meta)
assert out.dtype == torch.float16
def test_projector(self):
proj = _linear(_HIDDEN, 8)
head = EmbeddingPoolerHead(projector=proj)
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH))
out = head(x, meta)
assert out.shape == (_BATCH, 8)
assert torch.allclose(out, proj(x))
def test_matryoshka_uniform(self):
head = EmbeddingPoolerHead()
x = torch.randn(_BATCH, _HIDDEN)
params = _make_params(_BATCH, dimensions=4)
meta = _make_metadata(params)
out = head(x, meta)
assert out.shape == (_BATCH, 4)
assert torch.equal(out, x[..., :4])
def test_matryoshka_mixed(self):
head = EmbeddingPoolerHead()
x = torch.randn(2, _HIDDEN)
params = [
PoolingParams(task="embed", dimensions=4),
PoolingParams(task="embed", dimensions=8),
]
meta = _make_metadata(params)
out = head(x, meta)
assert isinstance(out, list)
assert len(out) == 2
assert out[0].shape[-1] == 4
assert out[1].shape[-1] == 8
def test_matryoshka_mixed_with_none(self):
head = EmbeddingPoolerHead()
x = torch.randn(2, _HIDDEN)
params = [
PoolingParams(task="embed", dimensions=4),
PoolingParams(task="embed", dimensions=None),
]
meta = _make_metadata(params)
out = head(x, meta)
assert isinstance(out, list)
assert out[0].shape[-1] == 4
assert torch.equal(out[1], x[1])
def test_activation_uniform_true(self):
head = EmbeddingPoolerHead(activation=PoolerNormalize())
x = torch.randn(_BATCH, _HIDDEN)
params = _make_params(_BATCH, use_activation=True)
meta = _make_metadata(params)
out = head(x, meta)
norms = torch.linalg.norm(out, dim=-1)
assert torch.allclose(norms, torch.ones(_BATCH), atol=1e-5)
def test_activation_uniform_false(self):
head = EmbeddingPoolerHead(activation=PoolerNormalize())
x = torch.randn(_BATCH, _HIDDEN)
params = _make_params(_BATCH, use_activation=False)
meta = _make_metadata(params)
out = head(x, meta)
assert torch.equal(out, x)
def test_activation_mixed_flags(self):
head = EmbeddingPoolerHead(activation=PoolerNormalize())
x = torch.randn(2, _HIDDEN)
params = [
PoolingParams(task="embed", use_activation=True),
PoolingParams(task="embed", use_activation=False),
]
meta = _make_metadata(params)
out = head(x, meta)
assert isinstance(out, list)
norm_0 = torch.linalg.norm(out[0], dim=-1)
assert torch.allclose(norm_0, torch.ones(1), atol=1e-5)
assert torch.equal(out[1], x[1])
def test_list_input_gets_stacked(self):
head = EmbeddingPoolerHead()
tensors = [torch.randn(_HIDDEN) for _ in range(_BATCH)]
meta = _make_metadata(_make_params(_BATCH))
out = head(tensors, meta)
assert out.shape == (_BATCH, _HIDDEN)
expected = torch.stack(tensors)
assert torch.equal(out, expected)
def test_projector_then_matryoshka(self):
proj = _linear(_HIDDEN, 8)
head = EmbeddingPoolerHead(projector=proj)
x = torch.randn(_BATCH, _HIDDEN)
params = _make_params(_BATCH, dimensions=4)
meta = _make_metadata(params)
out = head(x, meta)
assert out.shape == (_BATCH, 4)
assert torch.equal(out, proj(x)[..., :4])
def test_matryoshka_then_activation(self):
head = EmbeddingPoolerHead(activation=PoolerNormalize())
x = torch.randn(_BATCH, _HIDDEN)
params = _make_params(_BATCH, dimensions=4, use_activation=True)
meta = _make_metadata(params)
out = head(x, meta)
assert out.shape == (_BATCH, 4)
norms = torch.linalg.norm(out, dim=-1)
assert torch.allclose(norms, torch.ones(_BATCH), atol=1e-5)
def test_empty_batch(self):
head = EmbeddingPoolerHead()
x = torch.randn(0, _HIDDEN)
meta = _make_metadata([])
out = head(x, meta)
assert out.shape == (0, _HIDDEN)
# ---------------------------------------------------------------------------
# ClassifierPoolerHead
# ---------------------------------------------------------------------------
class TestClassifierPoolerHead:
def test_supported_tasks(self):
head = ClassifierPoolerHead()
assert head.get_supported_tasks() == {"classify"}
def test_passthrough(self):
head = ClassifierPoolerHead()
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH, task="classify"))
out = head(x, meta)
assert torch.equal(out, x)
def test_head_dtype(self):
head = ClassifierPoolerHead(head_dtype=torch.float16)
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH, task="classify"))
out = head(x, meta)
assert out.dtype == torch.float16
def test_classifier(self):
clf = _linear(_HIDDEN, 3)
head = ClassifierPoolerHead(classifier=clf)
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH, task="classify"))
out = head(x, meta)
assert out.shape == (_BATCH, 3)
assert torch.allclose(out, clf(x))
def test_logit_mean(self):
head = ClassifierPoolerHead(logit_mean=2.0)
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH, task="classify"))
out = head(x, meta)
assert torch.allclose(out, x - 2.0)
def test_logit_sigma(self):
head = ClassifierPoolerHead(logit_sigma=0.5)
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH, task="classify"))
out = head(x, meta)
assert torch.allclose(out, x / 0.5)
def test_platt_scaling_combined(self):
head = ClassifierPoolerHead(logit_mean=1.0, logit_sigma=2.0)
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH, task="classify"))
out = head(x, meta)
assert torch.allclose(out, (x - 1.0) / 2.0)
def test_activation_uniform_true(self):
head = ClassifierPoolerHead(activation=PoolerNormalize())
x = torch.randn(_BATCH, _HIDDEN)
params = _make_params(_BATCH, task="classify", use_activation=True)
meta = _make_metadata(params)
out = head(x, meta)
norms = torch.linalg.norm(out, dim=-1)
assert torch.allclose(norms, torch.ones(_BATCH), atol=1e-5)
def test_activation_uniform_false(self):
head = ClassifierPoolerHead(activation=PoolerNormalize())
x = torch.randn(_BATCH, _HIDDEN)
params = _make_params(_BATCH, task="classify", use_activation=False)
meta = _make_metadata(params)
out = head(x, meta)
assert torch.equal(out, x)
def test_activation_mixed_flags(self):
head = ClassifierPoolerHead(activation=PoolerNormalize())
x = torch.randn(2, _HIDDEN)
params = [
PoolingParams(task="classify", use_activation=True),
PoolingParams(task="classify", use_activation=False),
]
meta = _make_metadata(params)
out = head(x, meta)
assert isinstance(out, list)
norm_0 = torch.linalg.norm(out[0], dim=-1)
assert torch.allclose(norm_0, torch.ones(1), atol=1e-5)
assert torch.equal(out[1], x[1])
def test_list_input_gets_stacked(self):
head = ClassifierPoolerHead()
tensors = [torch.randn(_HIDDEN) for _ in range(_BATCH)]
meta = _make_metadata(_make_params(_BATCH, task="classify"))
out = head(tensors, meta)
assert out.shape == (_BATCH, _HIDDEN)
expected = torch.stack(tensors)
assert torch.equal(out, expected)
def test_classifier_then_platt_scaling(self):
clf = _linear(_HIDDEN, 3)
head = ClassifierPoolerHead(classifier=clf, logit_mean=1.0, logit_sigma=2.0)
x = torch.randn(_BATCH, _HIDDEN)
meta = _make_metadata(_make_params(_BATCH, task="classify"))
out = head(x, meta)
expected = (clf(x) - 1.0) / 2.0
assert torch.allclose(out, expected)
def test_empty_batch(self):
head = ClassifierPoolerHead()
x = torch.randn(0, _HIDDEN)
meta = _make_metadata([])
out = head(x, meta)
assert out.shape == (0, _HIDDEN)
# ---------------------------------------------------------------------------
# TokenEmbeddingPoolerHead
# ---------------------------------------------------------------------------
class TestTokenEmbeddingPoolerHead:
def test_supported_tasks(self):
head = TokenEmbeddingPoolerHead()
assert head.get_supported_tasks() == {"token_embed"}
def test_passthrough(self):
head = TokenEmbeddingPoolerHead()
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_embed")
out = head.forward_chunk(x, param)
assert torch.equal(out, x)
def test_none_chunked_prefill(self):
head = TokenEmbeddingPoolerHead()
param = PoolingParams(task="token_embed")
out = head.forward_chunk(None, param)
assert out is None
def test_head_dtype(self):
head = TokenEmbeddingPoolerHead(head_dtype=torch.float16)
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_embed")
out = head.forward_chunk(x, param)
assert out.dtype == torch.float16
def test_projector(self):
proj = _linear(_HIDDEN, 8)
head = TokenEmbeddingPoolerHead(projector=proj)
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_embed")
out = head.forward_chunk(x, param)
assert out.shape == (5, 8)
assert torch.allclose(out, proj(x))
def test_matryoshka_truncation(self):
head = TokenEmbeddingPoolerHead()
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_embed", dimensions=4)
out = head.forward_chunk(x, param)
assert out.shape == (5, 4)
assert torch.equal(out, x[..., :4])
def test_activation_true(self):
head = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_embed", use_activation=True)
out = head.forward_chunk(x, param)
norms = torch.linalg.norm(out, dim=-1)
assert torch.allclose(norms, torch.ones(5), atol=1e-5)
def test_activation_false(self):
head = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_embed", use_activation=False)
out = head.forward_chunk(x, param)
assert torch.equal(out, x)
def test_projector_then_matryoshka(self):
proj = _linear(_HIDDEN, 8)
head = TokenEmbeddingPoolerHead(projector=proj)
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_embed", dimensions=4)
out = head.forward_chunk(x, param)
assert out.shape == (5, 4)
assert torch.equal(out, proj(x)[..., :4])
def test_matryoshka_then_activation(self):
head = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_embed", dimensions=4, use_activation=True)
out = head.forward_chunk(x, param)
assert out.shape == (5, 4)
norms = torch.linalg.norm(out, dim=-1)
assert torch.allclose(norms, torch.ones(5), atol=1e-5)
def test_forward_mixed_batch_chunked_prefill(self):
head = TokenEmbeddingPoolerHead()
pooled_data = [torch.randn(5, _HIDDEN), None, torch.randn(3, _HIDDEN)]
params = _make_params(3, task="token_embed")
meta = _make_metadata(params)
out = head(pooled_data, meta)
assert len(out) == 3
assert torch.equal(out[0], pooled_data[0])
assert out[1] is None
assert torch.equal(out[2], pooled_data[2])
def test_forward_empty_batch(self):
head = TokenEmbeddingPoolerHead()
meta = _make_metadata([])
out = head([], meta)
assert out == []
# ---------------------------------------------------------------------------
# TokenClassifierPoolerHead
# ---------------------------------------------------------------------------
class TestTokenClassifierPoolerHead:
def test_supported_tasks(self):
head = TokenClassifierPoolerHead()
assert head.get_supported_tasks() == {"token_classify"}
def test_passthrough(self):
head = TokenClassifierPoolerHead()
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_classify")
out = head.forward_chunk(x, param)
assert torch.equal(out, x)
def test_none_chunked_prefill(self):
head = TokenClassifierPoolerHead()
param = PoolingParams(task="token_classify")
out = head.forward_chunk(None, param)
assert out is None
def test_head_dtype(self):
head = TokenClassifierPoolerHead(head_dtype=torch.float16)
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_classify")
out = head.forward_chunk(x, param)
assert out.dtype == torch.float16
def test_classifier(self):
clf = _linear(_HIDDEN, 3)
head = TokenClassifierPoolerHead(classifier=clf)
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_classify")
out = head.forward_chunk(x, param)
assert out.shape == (5, 3)
assert torch.allclose(out, clf(x))
def test_logit_mean(self):
head = TokenClassifierPoolerHead(logit_mean=2.0)
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_classify")
out = head.forward_chunk(x, param)
assert torch.allclose(out, x - 2.0)
def test_logit_sigma(self):
head = TokenClassifierPoolerHead(logit_sigma=0.5)
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_classify")
out = head.forward_chunk(x, param)
assert torch.allclose(out, x / 0.5)
def test_platt_scaling_combined(self):
head = TokenClassifierPoolerHead(logit_mean=1.0, logit_sigma=2.0)
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_classify")
out = head.forward_chunk(x, param)
assert torch.allclose(out, (x - 1.0) / 2.0)
def test_activation_true(self):
head = TokenClassifierPoolerHead(activation=PoolerNormalize())
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_classify", use_activation=True)
out = head.forward_chunk(x, param)
norms = torch.linalg.norm(out, dim=-1)
assert torch.allclose(norms, torch.ones(5), atol=1e-5)
def test_activation_false(self):
head = TokenClassifierPoolerHead(activation=PoolerNormalize())
x = torch.randn(5, _HIDDEN)
param = PoolingParams(task="token_classify", use_activation=False)
out = head.forward_chunk(x, param)
assert torch.equal(out, x)
def test_forward_mixed_batch_chunked_prefill(self):
head = TokenClassifierPoolerHead()
pooled_data = [torch.randn(5, _HIDDEN), None, torch.randn(3, _HIDDEN)]
params = _make_params(3, task="token_classify")
meta = _make_metadata(params)
out = head(pooled_data, meta)
assert len(out) == 3
assert torch.equal(out[0], pooled_data[0])
assert out[1] is None
assert torch.equal(out[2], pooled_data[2])
def test_forward_empty_batch(self):
head = TokenClassifierPoolerHead()
meta = _make_metadata([])
out = head([], meta)
assert out == []
+76 -41
View File
@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
import pytest
from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from tests.utils import multi_gpu_test, wait_for_gpu_memory_to_clear
from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
@@ -404,6 +405,30 @@ def _get_vllm_runner_params(
}
def _wait_for_rocm_memory_to_settle() -> None:
if not current_platform.is_rocm():
return
num_gpus = current_platform.device_count()
if num_gpus == 0:
return
wait_for_gpu_memory_to_clear(
devices=list(range(num_gpus)),
threshold_ratio=0.01,
timeout_s=120,
)
@contextmanager
def _owned_vLLM_runner(vllm_runner, kwargs):
try:
with vllm_runner(**kwargs) as runner:
yield runner
finally:
_wait_for_rocm_memory_to_settle()
def _get_vLLM_output(
vllm_runner,
kwargs,
@@ -413,17 +438,21 @@ def _get_vLLM_output(
num_repetitions=1,
vllm_model=None,
):
outs = []
if vllm_model is None:
vllm_model = vllm_runner(**kwargs)
for _ in range(num_repetitions):
if num_logprobs < 0:
vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
else:
vllm_output = vllm_model.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs
)
outs.append(vllm_output)
runner_context = (
_owned_vLLM_runner(vllm_runner, kwargs)
if vllm_model is None
else nullcontext(vllm_model)
)
with runner_context as runner:
outs = []
for _ in range(num_repetitions):
if num_logprobs < 0:
vllm_output = runner.generate_greedy(prompts, max_tokens)
else:
vllm_output = runner.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs
)
outs.append(vllm_output)
return outs, vllm_model
@@ -772,38 +801,44 @@ def test_apc_multiple_prompts_partial_cached_outputs(
# Cache only part of all the prompts
vllm_runner_kwargs["enable_prefix_caching"] = True
vllm_outputs_partial_cache, vllm_model = _get_vLLM_output(
vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs
)
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0][:3],
outputs_1_lst=vllm_outputs_partial_cache[0],
name_0="vllm_no_cache",
name_1="vllm_partial_cache",
)
vllm_outputs_cache_rep, _ = _get_vLLM_output(
vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens,
num_logprobs,
n_repetitions,
vllm_model=vllm_model,
)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
with _owned_vLLM_runner(vllm_runner, vllm_runner_kwargs) as vllm_model:
vllm_outputs_partial_cache, _ = _get_vLLM_output(
vllm_runner,
vllm_runner_kwargs,
generated_prompts[:3],
max_tokens,
num_logprobs,
vllm_model=vllm_model,
)
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
outputs_0_lst=vllm_outputs_no_cache[0][:3],
outputs_1_lst=vllm_outputs_partial_cache[0],
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
name_1="vllm_partial_cache",
)
vllm_outputs_cache_rep, _ = _get_vLLM_output(
vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens,
num_logprobs,
n_repetitions,
vllm_model=vllm_model,
)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
# Test that outputs match whether prefix caching is enabled or not for mamba.
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
@@ -826,7 +861,7 @@ def test_same_mamba_output_apc_on_vs_off(
# No prefix caching
kwargs_no_apc = {**base_kwargs, "enable_prefix_caching": False}
with vllm_runner(**kwargs_no_apc) as vllm_model:
with _owned_vLLM_runner(vllm_runner, kwargs_no_apc) as vllm_model:
outputs_no_apc, _ = _get_vLLM_output(
vllm_runner,
kwargs_no_apc,
@@ -841,7 +876,7 @@ def test_same_mamba_output_apc_on_vs_off(
"enable_prefix_caching": True,
"mamba_block_size": 16,
}
with vllm_runner(**kwargs_with_apc) as vllm_model:
with _owned_vLLM_runner(vllm_runner, kwargs_with_apc) as vllm_model:
outputs_with_apc, _ = _get_vLLM_output(
vllm_runner,
kwargs_with_apc,
@@ -30,11 +30,14 @@ def vllm_to_hf_output(
MODEL_NAME = "ibm-granite/granite-speech-3.3-2b"
MODEL_NAME_4_0 = "ibm-granite/granite-4.0-1b-speech"
# "plus" variant of granite speech (uses GraniteSpeechPlusForConditionalGeneration).
MODEL_NAME_4_1_PLUS = "ibm-granite/granite-speech-4.1-2b-plus"
# Audio lora co-exists directly in the 3.3 model directory,
# the 4.0 model has adapters merged into the weights.
# the 4.0 and 4.1-plus models have adapters merged into the weights.
models: dict[str, str | None] = {
MODEL_NAME: MODEL_NAME,
MODEL_NAME_4_0: None,
MODEL_NAME_4_1_PLUS: None,
}
@@ -43,6 +43,10 @@ def qwen_vl_chat_template(content: str) -> str:
return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
def internvl_chat_template(content: str) -> str:
return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
def step3_vl_chat_template(content: str) -> str:
return (
"<begin▁of▁sentence> You are a helpful assistant.<|BOT|>user\n "
@@ -51,6 +55,17 @@ def step3_vl_chat_template(content: str) -> str:
MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
"internvl": VitCudagraphTestConfig(
model="OpenGVLab/InternVL3-1B",
num_video_frames=8,
image_prompt=internvl_chat_template("<image>\nWhat is in this image?"),
video_prompt=internvl_chat_template(
"<video>\nDescribe this video in one sentence."
),
needs_video_metadata=False,
vllm_runner_kwargs={"trust_remote_code": True},
marks=[pytest.mark.core_model],
),
"qwen2_5_vl": VitCudagraphTestConfig(
model="Qwen/Qwen2.5-VL-3B-Instruct",
image_prompt=qwen_vl_chat_template(
+4
View File
@@ -938,6 +938,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"ibm-granite/granite-speech-3.3-2b",
extras={"4.0-1b": "ibm-granite/granite-4.0-1b-speech"},
),
"GraniteSpeechPlusForConditionalGeneration": _HfExamplesInfo(
"ibm-granite/granite-speech-4.1-2b-plus",
min_transformers_version="5.8.0",
),
"GLM4VForCausalLM": _HfExamplesInfo(
"zai-org/glm-4v-9b",
trust_remote_code=True,
+37
View File
@@ -36,11 +36,24 @@ def tokenizer():
return get_tokenizer("Qwen/Qwen3-32B")
TOOLS = [
{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {"type": "object", "properties": {}},
},
}
]
@pytest.fixture
def request_obj():
return ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "hi"}],
tools=TOOLS,
tool_choice="auto",
)
@@ -328,3 +341,27 @@ def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj):
tc.function.arguments for tc in tool_calls if tc.function.arguments
)
assert tool_args.endswith(remainder)
def test_parse_delta_tool_choice_none(tokenizer, request_obj):
parser = make_parser(tokenizer, reasoning=False, tool=True)
request = request_obj.model_copy(update={"tool_choice": "none"})
results = stream_text(parser, tokenizer, MODEL_OUTPUT, request, prompt_token_ids=[])
reasoning, content, tool_calls = collect_fields(results)
assert reasoning == ""
assert len(tool_calls) == 0
assert "<tool_call>" in content
assert "get_weather" in content
def test_parse_delta_tool_choice_none_with_reasoning(tokenizer, request_obj):
parser = make_parser(tokenizer, reasoning=True, tool=True)
request = request_obj.model_copy(update={"tool_choice": "none"})
results = stream_text(parser, tokenizer, MODEL_OUTPUT, request, prompt_token_ids=[])
reasoning, content, tool_calls = collect_fields(results)
assert "let me think about this" in reasoning
assert len(tool_calls) == 0
assert "<tool_call>" in content
assert "get_weather" in content
+5 -17
View File
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A4Fp4,
CompressedTensorsW4A4Mxfp4,
CompressedTensorsW4A8Fp8,
CompressedTensorsW4A16Fp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A8Mxfp8,
@@ -37,9 +36,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -376,13 +372,12 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
@pytest.mark.parametrize(
"args",
[
# TODO: Enable once model is available again
# ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", True),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", False),
],
)
def test_compressed_tensors_nvfp4(vllm_runner, args):
model, scheme = args
model, use_a16 = args
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
@@ -390,15 +385,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
if (
isinstance(qkv_proj.scheme, scheme)
or isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
and not cutlass_fp4_supported()
):
assert True
else:
raise AssertionError("FP4 Scheme Mismatch")
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A4Fp4)
assert qkv_proj.scheme.use_a16 == use_a16
assert qkv_proj.scheme.group_size == 16
llm.apply_model(check_model)
+1 -244
View File
@@ -10,12 +10,11 @@ from dataclasses import dataclass
from json.decoder import JSONDecodeError
from tempfile import NamedTemporaryFile
from typing import Any
from unittest.mock import MagicMock, patch
from unittest.mock import patch
from uuid import uuid4
import pytest
from vllm.entrypoints.logger import RequestLogger
from vllm.logger import (
_DATE_FORMAT,
_FORMAT,
@@ -269,248 +268,6 @@ def test_prepare_object_to_dump():
assert prepare_object_to_dump(CustomClass(1, "b")) == "CustomClass(a=1, b='b')"
def test_request_logger_log_outputs():
"""Test the new log_outputs functionality."""
# Create a mock logger to capture log calls
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test basic output logging
request_logger.log_outputs(
request_id="test-123",
outputs="Hello, world!",
output_token_ids=[1, 2, 3, 4],
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-123"
assert call_args[3] == "Hello, world!"
assert call_args[4] == [1, 2, 3, 4]
assert call_args[5] == "stop"
def test_request_logger_log_outputs_streaming_delta():
"""Test log_outputs with streaming delta mode."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test streaming delta logging
request_logger.log_outputs(
request_id="test-456",
outputs="Hello",
output_token_ids=[1],
finish_reason=None,
is_streaming=True,
delta=True,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-456"
assert call_args[2] == " (streaming delta)"
assert call_args[3] == "Hello"
assert call_args[4] == [1]
assert call_args[5] is None
def test_request_logger_log_outputs_streaming_complete():
"""Test log_outputs with streaming complete mode."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test streaming complete logging
request_logger.log_outputs(
request_id="test-789",
outputs="Complete response",
output_token_ids=[1, 2, 3],
finish_reason="length",
is_streaming=True,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-789"
assert call_args[2] == " (streaming complete)"
assert call_args[3] == "Complete response"
assert call_args[4] == [1, 2, 3]
assert call_args[5] == "length"
def test_request_logger_log_outputs_with_truncation():
"""Test log_outputs respects max_log_len setting."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
# Set max_log_len to 10
request_logger = RequestLogger(max_log_len=10)
# Test output truncation
long_output = "This is a very long output that should be truncated"
long_token_ids = list(range(20)) # 20 tokens
request_logger.log_outputs(
request_id="test-truncate",
outputs=long_output,
output_token_ids=long_token_ids,
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args
# Check that output was truncated to first 10 characters
logged_output = call_args[0][3]
assert logged_output == "This is a "
assert len(logged_output) == 10
# Check that token IDs were truncated to first 10 tokens
logged_token_ids = call_args[0][4]
assert logged_token_ids == list(range(10))
assert len(logged_token_ids) == 10
def test_request_logger_log_outputs_none_values():
"""Test log_outputs handles None values correctly."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test with None output_token_ids
request_logger.log_outputs(
request_id="test-none",
outputs="Test output",
output_token_ids=None,
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-none"
assert call_args[3] == "Test output"
assert call_args[4] is None
assert call_args[5] == "stop"
def test_request_logger_log_outputs_empty_output():
"""Test log_outputs handles empty output correctly."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=5)
# Test with empty output
request_logger.log_outputs(
request_id="test-empty",
outputs="",
output_token_ids=[],
finish_reason="stop",
is_streaming=False,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
assert "Generated response %s%s" in call_args[0]
assert call_args[1] == "test-empty"
assert call_args[3] == ""
assert call_args[4] == []
assert call_args[5] == "stop"
def test_request_logger_log_outputs_integration():
"""Test that log_outputs can be called alongside log_inputs."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test that both methods can be called without interference
request_logger.log_inputs(
request_id="test-integration",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_embeds=None,
params=None,
lora_request=None,
)
request_logger.log_outputs(
request_id="test-integration",
outputs="Test output",
output_token_ids=[4, 5, 6],
finish_reason="stop",
is_streaming=False,
delta=False,
)
# Should have been called twice - once for inputs, once for outputs
assert mock_logger.info.call_count == 2
# Check that the calls were made with correct patterns
input_call = mock_logger.info.call_args_list[0][0]
output_call = mock_logger.info.call_args_list[1][0]
assert "Received request %s" in input_call[0]
assert input_call[1] == "test-integration"
assert "Generated response %s%s" in output_call[0]
assert output_call[1] == "test-integration"
def test_streaming_complete_logs_full_text_content():
"""Test that streaming complete logging includes
full accumulated text, not just token count."""
mock_logger = MagicMock()
with patch("vllm.entrypoints.logger.logger", mock_logger):
request_logger = RequestLogger(max_log_len=None)
# Test with actual content instead of token count format
full_response = "This is a complete response from streaming"
request_logger.log_outputs(
request_id="test-streaming-full-text",
outputs=full_response,
output_token_ids=None,
finish_reason="streaming_complete",
is_streaming=True,
delta=False,
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args.args
# Verify the logged output is the full text, not a token count format
logged_output = call_args[3]
assert logged_output == full_response
assert "tokens>" not in logged_output
assert "streaming_complete" not in logged_output
# Verify other parameters
assert call_args[1] == "test-streaming-full-text"
assert call_args[2] == " (streaming complete)"
assert call_args[5] == "streaming_complete"
# Add vllm prefix to make sure logs go through the vllm logger
test_logger = init_logger("vllm.test_logger")
@@ -54,15 +54,14 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
(
MiniMaxText01LinearAttention,
dict(
hidden_size=128,
hidden_inner_size=256,
num_heads=8,
head_dim=32,
max_position=2048,
block_size=64,
num_hidden_layer=12,
layer_idx=0,
linear_layer_idx=0,
config=SimpleNamespace(
hidden_size=256,
num_attention_heads=8,
head_dim=32,
num_hidden_layers=12,
block=64,
),
prefix="layers.0.self_attn",
),
LinearAttentionBackend,
MambaAttentionBackendEnum.LINEAR,
@@ -88,6 +87,8 @@ def test_mamba_layers_get_attn_backend(
expected_mamba_type,
):
"""Test that Mamba-like layers return the correct attention backend."""
if layer_class is MiniMaxText01LinearAttention:
init_kwargs["vllm_config"] = default_vllm_config
layer = layer_class(**init_kwargs)
backend_class = layer.get_attn_backend()
+37
View File
@@ -358,6 +358,43 @@ def test_free_kv_cache_block_queue_append_n():
)
def test_free_kv_cache_block_queue_prepend_n():
# Seed the queue with one block so prepend has an existing head to splice
# in front of (fake_head->b0->fake_tail).
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
queue = FreeKVCacheBlockQueue(blocks[0:1])
# Prepend 0 blocks is a no-op.
queue.prepend_n([])
assert queue.num_free_blocks == 1
assert queue.fake_free_list_head.next_free_block is blocks[0]
# Prepend 2 blocks; they land in front of the existing head, in order.
# fake_head->b4->b5->b0->fake_tail
queue.prepend_n(blocks[4:6])
assert queue.num_free_blocks == 3
assert queue.fake_free_list_head.next_free_block is blocks[4]
assert blocks[4].prev_free_block is queue.fake_free_list_head
assert blocks[4].next_free_block is blocks[5]
assert blocks[5].prev_free_block is blocks[4]
assert blocks[5].next_free_block is blocks[0]
assert blocks[0].prev_free_block is blocks[5]
assert blocks[0].next_free_block is queue.fake_free_list_tail
assert queue.fake_free_list_tail.prev_free_block is blocks[0]
# A second prepend goes ahead of everything previously prepended.
# fake_head->b1->b2->b4->b5->b0->fake_tail
queue.prepend_n(blocks[1:3])
assert queue.num_free_blocks == 5
assert queue.fake_free_list_head.next_free_block is blocks[1]
assert blocks[1].next_free_block is blocks[2]
assert blocks[2].next_free_block is blocks[4]
# The popleft order reflects the front-to-back queue order.
assert [queue.popleft().block_id for _ in range(5)] == [1, 2, 4, 5, 0]
assert queue.num_free_blocks == 0
def test_free_kv_cache_block_queue_popleft_n():
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
# Create an empty FreeKVCacheBlockQueue with these blocks
+557
View File
@@ -39,6 +39,7 @@ from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec,
KVCacheSpecKind,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
)
@@ -2875,6 +2876,350 @@ def test_hybrid_cache_blocks_clamped_to_lcm():
)
def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch):
"""Verify fixed intervals retain sparse tails plus the latest replay tail."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# The SWA manager uses the configured 64-token interval (a multiple of the
# 32-token lcm_block_size) as its retention segment. For this 128-token
# prompt, the retained SWA tails are the 64-token interval boundary, the
# 96-token replay boundary, and the 128-token interval boundary.
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
expected_swa_cached = {7, 11, 15}
for i in range(16):
cached = pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[1])
if i in expected_swa_cached:
assert cached is not None, f"SWA hash {i} should be cached"
else:
assert cached is None, f"SWA hash {i} should not be cached"
@pytest.mark.parametrize(
"interval, expected_match",
[
# scheduler_block_size is 32 (= lcm(4*8, 8)); 33 is not a multiple of it.
("33", "multiple of scheduler_block_size"),
# A negative multiple (-32 % 32 == 0) must still be rejected explicitly,
# otherwise it would pass the modulo check and silently degrade to dense.
("-32", "non-negative"),
],
)
def test_hybrid_local_kv_retention_interval_rejects_invalid(
monkeypatch, interval, expected_match
):
"""A retention interval that is negative or not a multiple of
scheduler_block_size errors out at construction time."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", interval)
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
),
],
)
with pytest.raises(ValueError, match=expected_match):
make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
def test_hybrid_local_kv_retention_interval_survives_recycling(monkeypatch):
"""Verify retained local checkpoints are reused after block recycling."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "1024")
hash_block_size = 4
kv_cache_config = KVCacheConfig(
num_blocks=800,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["full"],
MLAAttentionSpec(
block_size=64 * hash_block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.uint8,
compress_ratio=4,
),
),
KVCacheGroupSpec(
["swa"],
SlidingWindowSpec(
block_size=16 * hash_block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=512,
),
),
KVCacheGroupSpec(
["c128"],
SlidingWindowSpec(
block_size=2 * hash_block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=128,
),
),
KVCacheGroupSpec(
["c4"],
SlidingWindowSpec(
block_size=hash_block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=8,
),
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=4096,
enable_caching=True,
hash_block_size=hash_block_size,
)
def fill_request(request_id: str, token_offset: int) -> list[int]:
token_ids = [
token_offset + i for i in range(1024) for _ in range(hash_block_size)
]
fill_req = make_request(request_id, token_ids, hash_block_size, sha256)
while fill_req.num_computed_tokens < len(token_ids):
num_new_tokens = min(512, len(token_ids) - fill_req.num_computed_tokens)
blocks = manager.allocate_slots(fill_req, num_new_tokens)
assert blocks is not None
fill_req.num_computed_tokens += num_new_tokens
manager.free(fill_req)
return token_ids
token_ids = fill_request("fill_0", 0)
replay_req = make_request("replay", token_ids[:1800], hash_block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req)
assert num_computed_tokens == 1024
assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256]
fill_request("fill_1", 100_000)
replay_req = make_request("replay_again", token_ids[:1800], hash_block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req)
assert num_computed_tokens == 1024
assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256]
def test_hybrid_local_kv_retention_latest_only_reuses_replay_boundary(monkeypatch):
"""Verify latest-only retention reuses only the replayable prompt boundary."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
token_ids = [i for i in range(16) for _ in range(block_size)]
req0 = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req0)
blocks = manager.allocate_slots(
req0,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
expected_swa_cached = {11}
for i in range(16):
cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
if i in expected_swa_cached:
assert cached is not None, f"SWA hash {i} should be cached"
else:
assert cached is None, f"SWA hash {i} should not be cached"
manager.free(req0)
retained_swa_block = pool.get_cached_block(req0.block_hashes[11], [1])
assert retained_swa_block is not None
assert retained_swa_block[0].ref_cnt == 0
req1 = make_request("1", token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Full prompt hits intentionally recompute the final block for logits, so
# the longest usable hit is the previous LCM boundary: 96 tokens.
assert num_computed_tokens == 12 * block_size
assert len(computed_blocks.blocks[1]) == 12
shorter_req = make_request("2", token_ids[: 12 * block_size], block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(shorter_req)
assert num_computed_tokens == 0
assert len(computed_blocks.blocks[1]) == 0
def test_hybrid_local_kv_retention_mtp_reuses_latest_boundary(monkeypatch):
"""Verify MTP/EAGLE SWA retention keeps the extra proof block.
EAGLE/MTP lookup matches one additional local block after the returned
prefix and then drops it. Sparse retention must therefore cache the normal
local tail at the latest replay boundary plus one extra SWA block.
"""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["full"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["swa_mtp"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
is_eagle_group=True,
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
use_eagle=True,
)
# 127 tokens: latest replay boundary is floor((127 - 1) / 32) * 32 = 96.
# The EAGLE/MTP SWA lookup group must cache the local tail ending at
# 104 tokens, and that tail is two 8-token blocks wide: hashes 11 and 12.
token_ids = [i for i in range(15) for _ in range(block_size)] + [15] * 7
req0 = make_request("0", token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert num_computed_tokens == 0
blocks = manager.allocate_slots(
req0,
len(token_ids),
num_computed_tokens,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
expected_swa_cached = {11, 12}
for i in range(15):
cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
if i in expected_swa_cached:
assert cached is not None, f"SWA hash {i} should be cached"
else:
assert cached is None, f"SWA hash {i} should not be cached"
manager.free(req0)
req1 = make_request("1", token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert num_computed_tokens == 12 * block_size
assert [len(blocks) for blocks in computed_blocks.blocks] == [3, 12]
def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0")
@@ -3058,3 +3403,215 @@ def test_can_fit_full_sequence_full_attention_still_gates_oversized():
req = make_request("oversized", list(range(prompt_len)), block_size, sha256)
assert manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is None
def test_swa_free_split_keeps_cached_tail_ahead_of_scratch(monkeypatch):
"""Default path (no retention): freeing an SWA request must place its
uncached scratch blocks at the front of the free queue (recycled first)
and keep its cached checkpoint blocks at the back (retained for prefix
hits). This split is always-on, independent of the retention interval."""
monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
block_size = 8
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=4 * block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
),
),
],
)
manager = make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
swa_manager = manager.coordinator.single_type_managers[1]
null_block = manager.block_pool.null_block
cached_ids: set[int] = set()
uncached_ids: set[int] = set()
cached_hash_indices: list[int] = []
for i, block in enumerate(swa_manager.req_to_blocks[req.request_id]):
if block is null_block:
continue
if block.block_hash is None:
uncached_ids.add(block.block_id)
else:
cached_ids.add(block.block_id)
cached_hash_indices.append(i)
# The dense default mask caches only the per-segment tails, so a 16-block
# SWA prompt must produce a mix of retained and scratch blocks.
assert cached_ids, "expected some retained (cached) SWA tail blocks"
assert uncached_ids, "expected some scratch (uncached) SWA blocks"
manager.free(req)
order = [
b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks()
]
pos = {bid: i for i, bid in enumerate(order)}
# Every scratch block is recycled before every retained block.
assert max(pos[bid] for bid in uncached_ids) < min(pos[bid] for bid in cached_ids)
# The retained tails survive the free and still serve a prefix-cache hit.
for i in cached_hash_indices:
assert (
manager.block_pool.get_cached_block(
req.block_hashes[i], kv_cache_group_ids=[1]
)
is not None
)
def _make_pure_swa_manager(block_size, sliding_window, num_blocks=100, **kwargs):
"""Single sliding-window group (UnitaryKVCacheCoordinator)."""
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window,
),
),
],
)
return make_kv_cache_manager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
**kwargs,
)
def test_pure_swa_retention_interval_caches_sparse_tails(monkeypatch):
"""Sparse retention must work for a pure-SWA single-group model, not just
hybrid models: only the per-interval tails plus the latest replay tail are
cached, and a replay still hits the latest replayable boundary."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
block_size = 16
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
assert type(manager.coordinator).__name__ == "UnitaryKVCacheCoordinator"
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
cached = {
i
for i in range(16)
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
is not None
}
# per_segment = 64 / 16 = 4, need = cdiv(16-1, 16) = 1 -> segment tails at
# i%4==3 -> {3,7,11,15}; latest replay boundary (255//16*16 = 240) -> tail
# block 14. Crucially this is a strict subset of all 16 blocks: retention
# is actually sparse for pure SWA (not silently dense).
assert cached == {3, 7, 11, 14, 15}
# A replay of the same prompt hits the latest replayable boundary (240).
replay = make_request("1", token_ids, block_size, sha256)
_, num_computed = manager.get_computed_blocks(replay)
assert num_computed == 240
def test_pure_swa_retention_latest_only(monkeypatch):
"""`=0` on a pure-SWA model keeps only the latest replay tail."""
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
block_size = 16
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
cached = {
i
for i in range(16)
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
is not None
}
# No segment tails (interval 0); only the latest replay tail (block 14).
assert cached == {14}
replay = make_request("1", token_ids, block_size, sha256)
_, num_computed = manager.get_computed_blocks(replay)
assert num_computed == 240
def test_pure_swa_retention_dense_default_caches_all(monkeypatch):
"""With retention unset, a pure-SWA model must keep the dense behavior:
every block boundary is a potential hit, so all blocks are cached."""
monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
block_size = 16
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
token_ids = [i for i in range(16) for _ in range(block_size)]
req = make_request("0", token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req)
blocks = manager.allocate_slots(
req,
len(token_ids),
len(computed_blocks.blocks[0]) * block_size,
computed_blocks,
)
assert blocks is not None
pool = manager.block_pool
cached = {
i
for i in range(16)
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
is not None
}
assert cached == set(range(16))
@@ -11,7 +11,9 @@ import pytest
import torch
from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
@@ -51,7 +53,7 @@ def test_rms_norm_batch_invariant_vs_standard(
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation (Triton)
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Compare outputs
# Use looser tolerance for bfloat16 due to its lower precision
@@ -125,7 +127,7 @@ def test_fused_add_rms_norm_batch_invariant_residual_path(
)
merged_single = x_single + residual_single
ref_out = triton_rms_norm(merged_single, weight, eps=eps)
ref_out = rms_norm_batch_invariant(merged_single, weight, eps=eps)
torch.testing.assert_close(
residual_out_single,
@@ -193,7 +195,7 @@ def test_rms_norm_3d_input(
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
@@ -242,7 +244,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Check for NaN or Inf
assert not torch.isnan(standard_output).any(), (
@@ -289,7 +291,7 @@ def test_rms_norm_formula(default_vllm_config):
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Compare against formula
torch.testing.assert_close(
@@ -325,7 +327,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
@@ -360,7 +362,7 @@ def test_rms_norm_determinism(default_vllm_config):
# Run multiple times
outputs = []
for _ in range(5):
output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
output = rms_norm_batch_invariant(input_tensor.clone(), weight, eps=eps)
outputs.append(output)
# All outputs should be identical
@@ -395,7 +397,7 @@ if __name__ == "__main__":
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Compare
max_diff = (triton_output - standard_output).abs().max().item()
@@ -98,7 +98,6 @@ def _make_connector_with_fake_worker(
)
worker = connector.connector_worker
assert isinstance(worker.nixl_wrapper, FakeNixlWrapper)
worker.nixl_wrapper.set_cycles_before_xfer_done(cycles_before_done)
worker.kv_cache_layout = "HND"
if do_handshake:
remote_agents = worker._nixl_handshake(
+41 -8
View File
@@ -6,25 +6,56 @@
import pytest
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
pytestmark = pytest.mark.cpu_test
class _StubLMCacheMPConnector:
"""Stand-in for LMCacheMPConnector used in config-translation tests.
The real connector module hard-imports the optional ``lmcache`` package
at module load time, which is not installed in the cpu_test image. This
test only asserts on the connector *name* and the ``extra_config`` dict
produced by ``VllmConfig``, never instantiates the connector, so a bare
placeholder class is sufficient. Not subclassing ``SupportsHMA`` mirrors
the real connector's HMA support (it does not support HMA either)."""
@pytest.fixture
def stub_lmcache_mp_connector(monkeypatch):
"""Replace the lazy loader so VllmConfig.__post_init__ does not import
``lmcache_mp_connector`` (and thus ``lmcache``) during config tests."""
monkeypatch.setitem(
KVConnectorFactory._registry,
"LMCacheMPConnector",
lambda: _StubLMCacheMPConnector,
)
@pytest.mark.parametrize(
"kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
[
("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
# bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30)),
("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
# size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
# ``lmcache`` backend now defaults to LMCacheMPConnector. The KV
# storage capacity is owned by the standalone LMCache server, so
# ``kv_offloading_size`` is intentionally not propagated.
("lmcache", 4.0, 1, 1, "LMCacheMPConnector", None),
("lmcache", 8.0, 2, 2, "LMCacheMPConnector", None),
# When kv_offloading_size is None, offloading is disabled (backend is ignored)
("native", None, 1, 1, None, None),
],
)
def test_kv_connector(
kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
stub_lmcache_mp_connector,
kv_offloading_backend,
kv_offloading_size,
tp,
pp,
expected_backend,
expected_bytes,
):
kv_transfer_config = (
KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
@@ -59,10 +90,12 @@ def test_kv_connector(
# Existing config should be preserved
assert kv_connector_extra_config["existing_key"] == "existing_value"
elif kv_offloading_backend == "lmcache":
assert kv_connector_extra_config["lmcache.local_cpu"] is True
assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
# Existing config should be replaced
assert "existing_key" not in kv_connector_extra_config
# MP mode does not push lmcache.local_cpu / max_local_cpu_size into
# extra config (the LMCache server owns capacity). Pre-existing
# extra config entries are preserved as-is.
assert "lmcache.local_cpu" not in kv_connector_extra_config
assert "lmcache.max_local_cpu_size" not in kv_connector_extra_config
assert kv_connector_extra_config["existing_key"] == "existing_value"
def _build_config(
@@ -197,13 +197,6 @@ class FakeNixlWrapper:
def get_xfer_telemetry(self, handle: int) -> dict:
return get_default_xfer_telemetry()
############################################################
# Follow are for changing the behavior during testing.
############################################################
def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
@contextlib.contextmanager
def _make_fake_nixl_pkg():
@@ -578,10 +571,7 @@ class TestNixlHandshake:
"""Test case where multiple xfers are initiated to the same engine.
This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to
`set_cycles_before_xfer_done`, so there is a state where there are
multiple transfer states for the same `request_id`, and `get_finished`
should handle it correctly (wait for all transfers to be done).
`request_id`.
"""
vllm_config = create_vllm_config()
@@ -598,7 +588,6 @@ class TestNixlHandshake:
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
@@ -1304,7 +1293,6 @@ def test_scheduler_kv_connector_stats_aggregation():
# Worker stats with transfer metrics
worker_stats = NixlKVConnectorStats()
worker_stats.record_transfer(get_default_xfer_telemetry())
worker_stats.data["remote_tokens"] = []
# Scheduler stats with custom metric (needs dummy transfer to avoid being skipped)
scheduler_stats = NixlKVConnectorStats()
@@ -1314,7 +1302,6 @@ def test_scheduler_kv_connector_stats_aggregation():
"post_duration": [0],
"bytes_transferred": [0],
"num_descriptors": [0],
"remote_tokens": [128],
}
)
@@ -1355,7 +1342,6 @@ def test_scheduler_kv_connector_stats_aggregation():
).scheduler_stats.kv_connector_stats
nixl_stats = final_stats["NixlConnector"]
assert nixl_stats.num_successful_transfers == 2
assert nixl_stats.data["remote_tokens"] == [128]
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@@ -275,6 +275,83 @@ def test_apply_prefix_caching_mamba_hybrid(
)
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"local_physical_per_logical,remote_physical_per_logical,"
"local_block_ids,remote_block_ids,"
"expected_local,expected_remote",
[
# SSM prefix caching: remote has 3 placeholder + 1 real block,
# local has only the 1 real block. FA blocks are equal (no trim).
pytest.param(
10,
10,
[list(range(10)), [42]],
[list(range(10)), [40, 41, 42, 43]],
[list(range(10)), [42]],
[list(range(10)), [43]],
id="ssm_prefix_trim_only",
),
# FA partial prefix cache hit with homogeneous TP: local has 4 FA
# blocks (prefix cached), remote has full 10. SSM equal (no trim).
pytest.param(
10,
10,
[list(range(6, 10)), [42]],
[list(range(10)), [42]],
[list(range(6, 10)), [42]],
[list(range(6, 10)), [42]],
id="fa_prefix_hit_homo_tp",
),
# Both: FA partial prefix hit + SSM placeholder trim.
# local FA=[6..9] (4 blocks, prefix cached), remote FA=[0..9]
# local SSM=[99], remote SSM=[10, 20, 99] (2 placeholders + real)
pytest.param(
10,
10,
[[6, 7, 8, 9], [99]],
[list(range(10)), [10, 20, 99]],
[[6, 7, 8, 9], [99]],
[[6, 7, 8, 9], [99]],
id="fa_prefix_hit_and_ssm_trim",
),
],
)
def test_apply_prefix_caching_ssm_prefix_cache_hit(
local_physical_per_logical,
remote_physical_per_logical,
local_block_ids,
remote_block_ids,
expected_local,
expected_remote,
):
"""_apply_prefix_caching end-trims SSM remote blocks to match the single
local block (placeholders dropped) and end-trims FA remote blocks on
partial prefix cache hits when physical_per_logical matches.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
NixlConnectorWorker,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
worker = object.__new__(NixlConnectorWorker)
worker._has_mamba = True
worker._physical_blocks_per_logical_kv_block = local_physical_per_logical
worker._group_spec_types = (FullAttentionSpec, MambaSpec)
worker.kv_cache_config = make_kv_cache_config(block_size=16, mamba_enabled=True)
aligned_local, aligned_remote = worker._apply_prefix_caching(
local_block_ids, remote_block_ids, remote_physical_per_logical
)
assert aligned_local == expected_local, (
f"Expected local {expected_local}, got {aligned_local}"
)
assert aligned_remote == expected_remote, (
f"Expected remote {expected_remote}, got {aligned_remote}"
)
@pytest.mark.cpu_test
@pytest.mark.parametrize(
"local_physical_per_logical,remote_physical_per_logical,"
+4 -4
View File
@@ -1072,8 +1072,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
@pytest.mark.skipif(
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
not current_platform.is_cuda(),
reason="Attention backend FLASHINFER is only supported on CUDA.",
)
def test_hybrid_attention_mamba_tensor_shapes():
"""
@@ -1508,8 +1508,8 @@ def test_is_uniform_decode() -> None:
@pytest.mark.skipif(
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
not current_platform.is_cuda(),
reason="Attention backend FLASHINFER is only supported on CUDA.",
)
def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
"""Test that a ValueError is raised when max_num_seqs exceeds the
@@ -1562,7 +1562,9 @@ def generate_legend() -> str:
def generate_mla_section(
prefill_backends: list[dict[str, Any]], decode_backends: list[dict[str, Any]]
prefill_backends: list[dict[str, Any]],
decode_backends: list[dict[str, Any]],
v4_decode_backends: list[dict[str, Any]] | None = None,
) -> str:
"""Generate the complete MLA section with prefill and decode tables."""
lines = [
@@ -1611,6 +1613,22 @@ def generate_mla_section(
columns = _build_columns(is_mla=True, has_versions=False)
lines.extend(_render_table(columns, decode_backends))
if v4_decode_backends:
lines.extend(
[
"",
"### DeepSeek V4 Decode Backends",
"",
"DeepSeek V4 sparse MLA uses its own decode backends, selected via",
"`--attention-backend=<BACKEND>` (e.g., `FLASHMLA_SPARSE_DSV4`,",
"`FLASHINFER_MLA_SPARSE_DSV4`). They share the V4 sparse-index",
"pipeline (compressor + SWA + indexer, 256-token blocks, head 512);",
"default on NVIDIA is `FLASHMLA_SPARSE_DSV4`.",
"",
]
)
lines.extend(_render_table(columns, v4_decode_backends))
lines.append("")
return "\n".join(lines)
@@ -1651,9 +1669,15 @@ def generate_docs() -> str:
if fi_features:
all_backends = _expand_flashinfer_variants(all_backends, fi_features)
# Split into MLA and non-MLA
mla_backends = [b for b in all_backends if b["is_mla"]]
non_mla_backends = [b for b in all_backends if not b["is_mla"]]
# DeepSeek V4 (*_DSV4) decode backends get their own subsection rather than
# mixing into the main MLA / standard tables (the ROCm V4 backend isn't
# flagged is_mla by the AST heuristic, so filter purely on the name).
def _is_v4(b: dict[str, Any]) -> bool:
return b["name"].endswith("_DSV4")
v4_decode_backends = [b for b in all_backends if _is_v4(b)]
mla_backends = [b for b in all_backends if b["is_mla"] and not _is_v4(b)]
non_mla_backends = [b for b in all_backends if not b["is_mla"] and not _is_v4(b)]
# Generate documentation
script_path = "tools/pre_commit/generate_attention_backend_docs.py"
@@ -1703,7 +1727,9 @@ def generate_docs() -> str:
doc_lines.append("\n>\n".join(footnotes) + "\n")
# Add MLA section with prefill and decode backends
doc_lines.append(generate_mla_section(mla_prefill_backends, mla_backends))
doc_lines.append(
generate_mla_section(mla_prefill_backends, mla_backends, v4_decode_backends)
)
return "\n".join(doc_lines)
+1 -1
View File
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from vllm.entrypoints.serve.utils.api_utils import VLLM_SUBCMD_PARSER_EPILOG
from .plot import SweepPlotArgs
from .plot import main as plot_main
@@ -44,6 +44,24 @@ from .matcher_utils import MatcherQuantFP8
FP8_DTYPE = current_platform.fp8_dtype()
_IR_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default
_IR_FUSED_ADD_RMS_NORM_OP = torch.ops.vllm_ir.fused_add_rms_norm.default
def _norm_input_weight_dtype_match(match: pm.Match) -> bool:
"""Prevent fusion when the norm input and weight dtypes differ (e.g. a Gemma
fp32 weight.float()+1 gamma), covering rms_norm and fused_add_rms_norm."""
for node in match.nodes:
if node.target == _IR_RMS_NORM_OP:
x, weight = node.args[0], node.args[1]
elif node.target == _IR_FUSED_ADD_RMS_NORM_OP:
x, weight = node.args[0], node.args[2]
else:
continue
if isinstance(x, fx.Node) and isinstance(weight, fx.Node):
return x.meta["val"].dtype == weight.meta["val"].dtype
return True
# The empirical value for small batch
PDL_ADVANCE_LAUNCH_TOKENS = 16
@@ -132,6 +150,7 @@ if flashinfer_comm is not None:
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
scale_factor: torch.Tensor | None = None,
weight_bias: float = 0.0,
) -> None:
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
@@ -208,6 +227,7 @@ if flashinfer_comm is not None:
layout_code=layout_code,
use_oneshot=use_oneshot,
fp32_acc=fp32_acc,
weight_bias=weight_bias,
trigger_completion_at_end=num_tokens > PDL_ADVANCE_LAUNCH_TOKENS,
)
@@ -225,6 +245,7 @@ if flashinfer_comm is not None:
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
scale_factor: torch.Tensor | None = None,
weight_bias: float = 0.0,
) -> None:
pass
@@ -399,14 +420,142 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
# allreduce_in, residual
return allreduce[1], allreduce[2]
# extra_check routes a Gemma fp32 gamma to AllReduceFusedAddGemmaRMSNormPattern.
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
pattern,
replacement,
self.get_inputs(),
pm.fwd_only,
pm_pass,
extra_check=_norm_input_weight_dtype_match,
)
# Same pattern, but only return the output and not residual
# (helpful for end of graph where residual is not used again)
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
pm.register_replacement(
first_return_only(pattern), # type: ignore[no-untyped-call]
first_return_only(replacement), # type: ignore[no-untyped-call]
self.get_inputs(),
pm.fwd_only,
pm_pass,
extra_check=_norm_input_weight_dtype_match,
)
class AllReduceGemmaRMSNormPattern(BasePattern):
"""Gemma-style variant of AllReduceRMSNormPattern (no residual)."""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
def get_inputs(self) -> list[torch.Tensor]:
return [self.empty(5, 16), self.empty(16)]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = vllm.ir.ops.rms_norm(
allreduce_output, weight.float() + 1.0, self.epsilon
)
return rms, allreduce_output
def replacement(
input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
rms_result = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=rms_result,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
weight_bias=1.0,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
return allreduce[3], allreduce[1]
pm.register_replacement(
pattern,
replacement,
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
class AllReduceFusedAddGemmaRMSNormPattern(BasePattern):
"""Gemma-style variant of AllReduceFusedAddRMSNormPattern (with residual)."""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
def get_inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16)
residual = self.empty(5, 16)
weight = self.empty(16)
return [residual, input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, residual = vllm.ir.ops.fused_add_rms_norm(
allreduce_output, residual, weight.float() + 1.0, self.epsilon
)
return rms, residual
def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
weight_bias=1.0,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
return allreduce[1], allreduce[2]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
pm.register_replacement(
first_return_only(pattern), # type: ignore[no-untyped-call]
first_return_only(replacement), # type: ignore[no-untyped-call]
@@ -881,6 +1030,18 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceGemmaRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddGemmaRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
@@ -36,10 +36,12 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
}
if hasattr(torch.ops._C, "per_token_group_fp8_quant"):
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
@@ -84,9 +84,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
}
if hasattr(torch.ops._C, "per_token_group_fp8_quant"):
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
+20 -18
View File
@@ -109,22 +109,24 @@ class EPLBConfig:
class ParallelConfig:
"""Configuration for the distributed execution."""
pipeline_parallel_size: int = 1
pipeline_parallel_size: int = Field(default=1, ge=1)
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
tensor_parallel_size: int = Field(default=1, ge=1)
"""Number of tensor parallel groups."""
prefill_context_parallel_size: int = 1
prefill_context_parallel_size: int = Field(default=1, ge=1)
"""Number of prefill context parallel groups."""
data_parallel_size: int = 1
data_parallel_size: int = Field(default=1, ge=1)
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
data_parallel_size_local: int = 1
"""Number of local data parallel groups."""
data_parallel_rank: int = 0
"""Rank of the data parallel group."""
data_parallel_size_local: int = Field(default=1, ge=0)
"""Number of local data parallel groups. A value of 0 is a sentinel used by
the engine-args layer to signal that data parallelism was specified
externally (see `ParallelConfig.__post_init__`)."""
data_parallel_rank: int = Field(default=0, ge=0)
"""Rank of the data parallel group. The runtime check at
``__post_init__`` further bounds this by ``data_parallel_size``."""
data_parallel_rank_local: int | None = None
"""Local rank of the data parallel group,
set only in SPMD mode."""
"""Local rank of the data parallel group, set only in SPMD mode."""
data_parallel_master_ip: str = "127.0.0.1"
"""IP of the data parallel master."""
data_parallel_rpc_port: int = 29550
@@ -184,7 +186,7 @@ class ParallelConfig:
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
max_parallel_loading_workers: int | None = None
max_parallel_loading_workers: int | None = Field(default=None, ge=1)
"""Maximum number of parallel loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor
parallel and large models."""
@@ -197,15 +199,15 @@ class ParallelConfig:
enable_dbo: bool = False
"""Enable dual batch overlap for the model executor."""
ubatch_size: int = 0
ubatch_size: int = Field(default=0, ge=0)
"""Number of ubatch size."""
dbo_decode_token_threshold: int = 32
dbo_decode_token_threshold: int = Field(default=32, ge=0)
"""The threshold for dual batch overlap for batches only containing decodes.
If the number of tokens in the request is greater than this threshold,
microbatching will be used. Otherwise, the request will be processed in a
single batch."""
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune
dbo_prefill_token_threshold: int = Field(default=512, ge=0) # TODO(lucas): tune
"""The threshold for dual batch overlap for batches that contain one or more
prefills. If the number of tokens in the request is greater than this
threshold, microbatching will be used. Otherwise, the request will be
@@ -260,10 +262,10 @@ class ParallelConfig:
master_port: int = 29501
"""distributed master port for multi-node distributed
inference when distributed_executor_backend is mp."""
node_rank: int = 0
"""distributed node rank for multi-node distributed
node_rank: int = Field(default=0, ge=0)
"""distributed node rank for multi-node distributed
inference when distributed_executor_backend is mp."""
nnodes: int = 1
nnodes: int = Field(default=1, ge=1)
"""num of nodes for multi-node distributed
inference when distributed_executor_backend is mp."""
numa_bind: bool = False
@@ -318,7 +320,7 @@ class ParallelConfig:
"""Port of the coordination TCPStore. Can be set by the API server; workers
connect as clients to exchange self-picked group ports at runtime."""
decode_context_parallel_size: int = 1
decode_context_parallel_size: int = Field(default=1, ge=1)
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
needs to be divisible by dcp_size."""
+6 -10
View File
@@ -771,10 +771,6 @@ class VllmConfig:
# If no KVTransferConfig is provided, create a default one.
if self.kv_transfer_config is None:
self.kv_transfer_config = KVTransferConfig()
num_kv_ranks = (
self.parallel_config.tensor_parallel_size
* self.parallel_config.pipeline_parallel_size
)
if kv_offloading_backend == "native":
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
@@ -786,12 +782,12 @@ class VllmConfig:
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
)
elif kv_offloading_backend == "lmcache":
self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
kv_gb_per_rank = kv_offloading_size / num_kv_ranks
self.kv_transfer_config.kv_connector_extra_config = {
"lmcache.local_cpu": True,
"lmcache.max_local_cpu_size": kv_gb_per_rank,
}
# Default to LMCache multi-process (MP) mode. The actual KV
# storage capacity is managed by the standalone LMCache server
# process, so ``kv_offloading_size`` is not propagated here.
# ``LMCacheMPConnector`` falls back to ``tcp://localhost:5555``
# when host/port are not provided via extra_config.
self.kv_transfer_config.kv_connector = "LMCacheMPConnector"
# This is the same for all backends
self.kv_transfer_config.kv_role = "kv_both"
@@ -40,11 +40,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return True
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
from vllm.distributed.utils import is_weak_contiguous # noqa: E402
class CustomAllreduce:
@@ -24,11 +24,7 @@ except Exception:
quick_ar = False
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
from vllm.distributed.utils import is_weak_contiguous # noqa: E402, F401
class QuickReduceRegime(Enum):
@@ -470,10 +470,14 @@ class ElasticEPScalingExecutor:
module._replace_quant_method(module.quant_method.old_quant_method)
prepare_communication_buffer_for_model(self.worker.model_runner.model)
eplb_model_state.expert_buffer = [
torch.empty_like(w) for w in model.expert_weights[0]
]
eplb_model_state.communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=parallel_config.eplb_config.communicator,
expert_weights=model.expert_weights[0],
expert_weights=model.expert_weights,
expert_buffer=eplb_model_state.expert_buffer,
)
if (
+1
View File
@@ -120,6 +120,7 @@ def transfer_run_periodically(
ep_group=eplb_group,
is_profile=is_profile,
cuda_stream=cuda_stream,
layer_idx=layer_idx,
)
# Wait until all writes to expert_buffer have finished before making the
+199 -186
View File
@@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import (
is_local_first_rank,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.distributed.utils import is_weak_contiguous
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -63,8 +64,22 @@ class EplbCommunicator(ABC):
pass
@abstractmethod
def execute(self, old_indices: np.ndarray | None = None) -> None:
pass
def execute(self) -> None:
"""Complete all enqueued transfers.
Some backends perform communication here; others (e.g. NIXL)
issue transfers eagerly in add_recv and only wait here.
On return, all data is available in the destination buffers.
"""
def set_transfer_context( # noqa: B027
self, old_indices: np.ndarray, layer_idx: int
) -> None:
"""Pre-set layer context before add_recv calls.
Default is a no-op; overridden by backends (e.g. NIXL) that need
layer-level context to issue transfers inside add_recv.
"""
@property
def needs_profile_buffer_reservation(self) -> bool:
@@ -125,7 +140,7 @@ class TorchDistNcclEplbCommunicator(EplbCommunicator):
)
)
def execute(self, old_indices: np.ndarray | None = None) -> None:
def execute(self) -> None:
if not self._p2p_ops:
return
try:
@@ -168,7 +183,7 @@ class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
for tensor in tensors:
self._ops.append(("recv", tensor, src_rank))
def execute(self, old_indices: np.ndarray | None = None) -> None:
def execute(self) -> None:
if not self._ops:
return
@@ -229,29 +244,47 @@ class NixlEplbCommunicator(EplbCommunicator):
def __init__(
self,
cpu_group: ProcessGroup,
expert_weights: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None = None,
all_expert_weights: Sequence[Sequence[torch.Tensor]],
expert_buffer: Sequence[torch.Tensor],
) -> None:
assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights."
assert all_expert_weights, (
"NixlEplbCommunicator requires non-empty all_expert_weights."
)
assert expert_buffer, "NixlEplbCommunicator requires non-empty expert_buffer."
nixl_wrapper_cls = nixl_utils.NixlWrapper
if nixl_wrapper_cls is None:
raise RuntimeError("NIXL/ RIXL is unavailable.")
self._cpu_group = cpu_group
self._cuda_stream = cuda_stream
self._world_size = cpu_group.size()
self._rank = cpu_group.rank()
# expert_id -> weight tensors to pack into the send buffer.
self._expert_send_map: dict[int, list[torch.Tensor]] = {}
# src_rank -> expert_id -> weight tensors to unpack after transfer.
self._recv_map: dict[int, dict[int, list[torch.Tensor]]] = {}
self._num_local_experts: int = expert_weights[0].shape[0]
self._device = expert_weights[0].device
for tensor in expert_weights:
assert tensor.device == self._device, (
"All local EPLB tensors are expected to be on the same device: "
f"expected={self._device}, got={tensor.device}"
self._all_expert_weights = all_expert_weights
self._expert_buffer = expert_buffer
self._num_local_experts: int = all_expert_weights[0][0].shape[0]
self._device = all_expert_weights[0][0].device
for layer_tensors in all_expert_weights:
for tensor in layer_tensors:
assert is_weak_contiguous(tensor), (
"Expert weight tensors must be contiguous in memory"
)
assert tensor.device == self._device, (
"All local EPLB tensors are expected to be on the same "
f"device: expected={self._device}, got={tensor.device}"
)
for tensor in expert_buffer:
assert is_weak_contiguous(tensor), (
"expert_buffer tensors must be contiguous in memory"
)
# (local_dlist, remote_dlist, xfer_handle) for in-flight READs;
# accumulated by add_recv, drained by execute.
self._xfer_entries: list[tuple[int, int, int]] = []
# Per-rank expert_id -> physical row; set by set_transfer_context.
self._expert_to_src_row: list[dict[int, int]] | None = None
self._layer_idx: int | None = None
nixl_agent_config = nixl_utils.nixl_agent_config
config = (
nixl_agent_config(capture_telemetry=False)
@@ -260,15 +293,16 @@ class NixlEplbCommunicator(EplbCommunicator):
)
self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config)
self._nixl_memory_type = "VRAM"
self._registered_desc: object | None = None
# NIXL registration handles; deregistered in __del__.
self._registered_descs: list[object] = []
self._remote_agents: dict[int, str] = {}
self._remote_send_meta: dict[int, tuple[int, int]] = {}
self._send_buffer: torch.Tensor = torch.empty(0)
self._recv_buffer: torch.Tensor = torch.empty(0)
self._expert_bytes: int = 0
# peer -> (layer, tensor) -> (base_ptr, bytes_per_expert, dev_id).
self._remote_send_meta: dict[
int, dict[tuple[int, int], tuple[int, int, int]]
] = {}
self._cuda_device_id = int(self._device.index or 0)
self._init_step("buffers", self._init_registered_buffers, expert_weights)
self._init_step("buffers", self._init_registered_buffers)
self._init_step("agents", self._init_remote_agents)
self._init_step("send meta", self._exchange_remote_send_meta)
self._log_initialized()
@@ -291,19 +325,34 @@ class NixlEplbCommunicator(EplbCommunicator):
uid = uuid.uuid4().hex[:8]
return f"eplb-{self._rank}{pp_suffix}-{uid}"
def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
pass
def add_send(
self,
tensors: list[torch.Tensor],
dst_rank: int,
expert_id: int,
) -> None:
assert dst_rank != self._rank, (
"EPLB communicator should not enqueue same-rank sends: "
f"rank={self._rank}, dst_rank={dst_rank}"
# No-op: NIXL READ is receiver-initiated. The sender's expert
# weights are pre-registered and always readable in-place.
pass
def set_transfer_context(self, old_indices: np.ndarray, layer_idx: int) -> None:
# Pre-compute expert_id -> src_row mapping for every rank so that
# add_recv can immediately issue NIXL READs.
assert not self._xfer_entries, (
f"set_transfer_context() called with {len(self._xfer_entries)} "
f"pending transfers from layer {self._layer_idx}; "
f"execute() was not called after previous add_recv() calls"
)
# An expert sent to multiple peers is packed only once; skip duplicates.
if expert_id not in self._expert_send_map:
self._expert_send_map[expert_id] = tensors
self._layer_idx = layer_idx
n = self._num_local_experts
rank_experts = old_indices[: self._world_size * n].reshape(self._world_size, n)
self._expert_to_src_row = [
{int(eid): i for i, eid in enumerate(row) if eid != -1}
for row in rank_experts
]
def add_recv(
self,
@@ -311,13 +360,44 @@ class NixlEplbCommunicator(EplbCommunicator):
src_rank: int,
expert_id: int,
) -> None:
assert src_rank != self._rank, (
"EPLB communicator should not enqueue same-rank recvs: "
f"rank={self._rank}, src_rank={src_rank}"
# Build NIXL descriptors and issue the RDMA READ immediately,
# overlapping the transfer with the remaining Python loop in
# move_to_buffer.
assert self._expert_to_src_row is not None and self._layer_idx is not None, (
"set_transfer_context() must be called before add_recv()"
)
recv_experts = self._recv_map.setdefault(src_rank, {})
if expert_id not in recv_experts:
recv_experts[expert_id] = tensors
src_row = self._expert_to_src_row[src_rank][expert_id]
layer_idx = self._layer_idx
local_descs: list[tuple[int, int, int]] = []
remote_descs: list[tuple[int, int, int]] = []
for t_idx, t in enumerate(tensors):
send_base, send_stride, remote_dev = self._remote_send_meta[src_rank][
(layer_idx, t_idx)
]
assert t.nbytes == send_stride, (
f"tensor {t_idx} size {t.nbytes} != remote stride {send_stride}"
)
local_descs.append(
(
t.data_ptr(),
t.nbytes,
self._cuda_device_id,
)
)
remote_descs.append(
(
send_base + src_row * send_stride,
send_stride,
remote_dev,
)
)
local_h, remote_h, xfer_h = self._create_peer_xfer(
src_rank, local_descs, remote_descs
)
self._nixl_wrapper.transfer(xfer_h)
self._xfer_entries.append((local_h, remote_h, xfer_h))
def _init_remote_agents(self) -> None:
local_metadata = self._nixl_wrapper.get_agent_metadata()
@@ -334,73 +414,60 @@ class NixlEplbCommunicator(EplbCommunicator):
peer_metadata
)
def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> None:
total_bytes = max(sum(t.nbytes for t in expert_weights), 1)
assert total_bytes % self._num_local_experts == 0, (
f"Number of bytes in moe layer {total_bytes} is not divisible "
f"by number of local experts {self._num_local_experts}"
)
self._expert_bytes = total_bytes // self._num_local_experts
def _init_registered_buffers(self) -> None:
all_tensors: list[torch.Tensor] = []
for layer_tensors in self._all_expert_weights:
all_tensors.extend(layer_tensors)
all_tensors.extend(self._expert_buffer)
self._send_buffer = torch.empty(
total_bytes, device=self._device, dtype=torch.uint8
)
self._recv_buffer = torch.empty(
total_bytes, device=self._device, dtype=torch.uint8
)
descs = self._nixl_wrapper.get_reg_descs([self._send_buffer, self._recv_buffer])
descs = self._nixl_wrapper.get_reg_descs(all_tensors)
self._nixl_wrapper.register_memory(descs)
self._registered_desc = descs
self._registered_descs.append(descs)
def _exchange_remote_send_meta(self) -> None:
"""Exchange send-buffer metadata so each rank can build dynamic
descriptors at execute time."""
local_meta: tuple[int, int] = (
self._send_buffer.data_ptr(),
self._cuda_device_id,
)
gathered_meta: list[tuple[int, int] | None] = [None] * self._world_size
"""Exchange per-layer per-tensor metadata so receivers can compute
remote RDMA addresses at transfer time."""
local_meta: dict[tuple[int, int], tuple[int, int, int]] = {}
for layer_idx, layer_tensors in enumerate(self._all_expert_weights):
for t_idx, t in enumerate(layer_tensors):
nbytes_per_expert = t.nbytes // self._num_local_experts
local_meta[(layer_idx, t_idx)] = (
t.data_ptr(),
nbytes_per_expert,
self._cuda_device_id,
)
# Per-rank map: (layer_idx, tensor_idx) -> (base_ptr, bytes_per_expert, dev_id).
# add_recv uses base_ptr + src_row * bytes_per_expert to compute
# the remote RDMA address for each expert.
gathered_meta: list[dict[tuple[int, int], tuple[int, int, int]] | None] = [
None
] * self._world_size
torch.distributed.all_gather_object(
gathered_meta, local_meta, group=self._cpu_group
)
local_keys = set(local_meta.keys())
for peer in self._remote_agents:
peer_meta = gathered_meta[peer]
assert peer_meta is not None
peer_keys = set(peer_meta.keys())
if peer_keys != local_keys:
raise RuntimeError(
f"NIXL EPLB metadata key mismatch with rank {peer}: "
f"local={sorted(local_keys)}, peer={sorted(peer_keys)}"
)
for key in local_keys:
_, local_stride, _ = local_meta[key]
_, peer_stride, _ = peer_meta[key]
if local_stride != peer_stride:
raise RuntimeError(
f"NIXL EPLB nbytes_per_expert mismatch for {key} "
f"with rank {peer}: "
f"local={local_stride}, peer={peer_stride}"
)
self._remote_send_meta[peer] = peer_meta
@staticmethod
def _pack_send_buffer(
in_tensors: list[torch.Tensor],
send_buffer: torch.Tensor,
byte_offset: int,
) -> None:
for tensor in in_tensors:
raw = tensor.reshape(-1).view(torch.uint8)
if raw.numel() == 0:
continue
send_buffer[byte_offset : byte_offset + raw.numel()].copy_(
raw, non_blocking=True
)
byte_offset += raw.numel()
@staticmethod
def _unpack_recv_buffer(
recv_buffer: torch.Tensor,
out_tensors: list[torch.Tensor],
byte_offset: int,
) -> None:
for tensor in out_tensors:
num_bytes = tensor.numel() * tensor.element_size()
if num_bytes == 0:
continue
tensor.reshape(-1).view(torch.uint8).copy_(
recv_buffer[byte_offset : byte_offset + num_bytes],
non_blocking=True,
)
byte_offset += num_bytes
def _wait_for_all_transfers(self, handles: list[int]) -> None:
pending = set(handles)
while pending:
@@ -456,110 +523,52 @@ class NixlEplbCommunicator(EplbCommunicator):
)
return (local_handle, remote_handle, xfer_handle)
def execute(self, old_indices: np.ndarray | None = None) -> None:
assert old_indices is not None, (
"NixlEplbCommunicator.execute requires old_indices"
def execute(self) -> None:
assert self._layer_idx is not None or not self._xfer_entries, (
"set_transfer_context() must be called before execute() "
"if any add_recv() calls were made"
)
xfer_entries: list[tuple[int, int, int]] = []
try:
n = self._num_local_experts
rank_experts = old_indices[: self._world_size * n].reshape(
self._world_size, n
)
# Build expert_id -> send slot mapping per rank.
expert_to_send_slot: list[dict[int, int]] = [
{int(eid): i for i, eid in enumerate(row) if eid != -1}
for row in rank_experts
]
self._wait_for_all_transfers([x[2] for x in self._xfer_entries])
# Phase 1: pack each expert at its slot offset in the send buffer.
with torch.cuda.stream(self._cuda_stream):
for expert_id, tensors in self._expert_send_map.items():
slot = expert_to_send_slot[self._rank][expert_id]
byte_offset = slot * self._expert_bytes
self._pack_send_buffer(tensors, self._send_buffer, byte_offset)
# Ensure all packed data is visible in device memory before pulls.
if self._cuda_stream is not None:
self._cuda_stream.synchronize()
else:
torch.cuda.current_stream().synchronize()
# READ is receiver-initiated; synchronize all ranks before transfer.
# We use monitored_barrier so a rank that crashes or exits early
# produces a diagnostic timeout instead of a silent hang.
# Post-READ barrier.
# Correctness fence for zero-copy: prevents overwrite-while-
# remote-read race.
torch.distributed.monitored_barrier(
group=self._cpu_group,
timeout=timedelta(minutes=5),
)
# Phase 2: issue one batched READ per peer.
recv_offsets: dict[tuple[int, int], int] = {}
recv_offset = 0
recv_base = self._recv_buffer.data_ptr()
for src in range(self._world_size):
if src == self._rank:
continue
recv_experts = self._recv_map.get(src)
if not recv_experts:
continue
expert_ids = list(recv_experts.keys())
remote_base, remote_dev = self._remote_send_meta[src]
local_descs: list[tuple[int, int, int]] = []
remote_descs: list[tuple[int, int, int]] = []
for expert_id in expert_ids:
slot = expert_to_send_slot[src][expert_id]
remote_off = slot * self._expert_bytes
recv_offsets[(src, expert_id)] = recv_offset
local_descs.append(
(
recv_base + recv_offset,
self._expert_bytes,
self._cuda_device_id,
)
)
remote_descs.append(
(remote_base + remote_off, self._expert_bytes, remote_dev)
)
recv_offset += self._expert_bytes
assert recv_offset <= self._recv_buffer.nbytes
local_h, remote_h, xfer_h = self._create_peer_xfer(
src, local_descs, remote_descs
)
self._nixl_wrapper.transfer(xfer_h)
xfer_entries.append((local_h, remote_h, xfer_h))
# Phase 3: wait for all in-flight transfers, then unpack.
self._wait_for_all_transfers([x[2] for x in xfer_entries])
with torch.cuda.stream(self._cuda_stream):
for (src, expert_id), offset in recv_offsets.items():
self._unpack_recv_buffer(
self._recv_buffer,
self._recv_map[src][expert_id],
offset,
)
finally:
for local_h, remote_h, xfer_h in xfer_entries:
for local_h, remote_h, xfer_h in self._xfer_entries:
with contextlib.suppress(Exception):
self._nixl_wrapper.release_xfer_handle(xfer_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(local_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(remote_h)
self._expert_send_map.clear()
self._recv_map.clear()
self._xfer_entries.clear()
self._expert_to_src_row = None
self._layer_idx = None
def __del__(self) -> None:
try:
if self._registered_desc is not None:
self._nixl_wrapper.deregister_memory(self._registered_desc)
self._registered_desc = None
with contextlib.suppress(Exception):
for local_h, remote_h, xfer_h in self._xfer_entries:
with contextlib.suppress(Exception):
self._nixl_wrapper.release_xfer_handle(xfer_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(local_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(remote_h)
with contextlib.suppress(Exception):
for descs in self._registered_descs:
with contextlib.suppress(Exception):
self._nixl_wrapper.deregister_memory(descs)
self._registered_descs.clear()
with contextlib.suppress(Exception):
for agent_name in self._remote_agents.values():
self._nixl_wrapper.remove_remote_agent(agent_name)
with contextlib.suppress(Exception):
self._nixl_wrapper.remove_remote_agent(agent_name)
self._remote_agents.clear()
except Exception as e:
logger.warning("Error during NixlEplbCommunicator cleanup: %s", e)
class PyNcclEplbCommunicator(EplbCommunicator):
@@ -600,7 +609,7 @@ class PyNcclEplbCommunicator(EplbCommunicator):
for tensor in tensors:
self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)
def execute(self, old_indices: np.ndarray | None = None) -> None:
def execute(self) -> None:
if self._group_started:
self._pynccl_comm.group_end()
self._group_started = False
@@ -609,7 +618,8 @@ class PyNcclEplbCommunicator(EplbCommunicator):
def create_eplb_communicator(
group_coordinator: GroupCoordinator,
backend: str | None,
expert_weights: Sequence[torch.Tensor],
expert_weights: Sequence[Sequence[torch.Tensor]],
expert_buffer: Sequence[torch.Tensor],
) -> EplbCommunicator:
"""Create an EPLB communicator for the given backend.
@@ -624,16 +634,18 @@ def create_eplb_communicator(
``"pynccl"`` in that case. When tensors reside on CPU,
``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU
process group.
expert_weights: Expert weight tensors from *one* MoE layer.
NixlEplbCommunicator pre-allocates send/recv buffers sized
to this layer, so all other MoE layers must have the same
tensor count, shapes, and dtypes.
expert_weights: Expert weight tensors for *all* MoE layers.
Shape ``(num_layers)(num_tensors_per_layer)``.
NixlEplbCommunicator registers all layers with NIXL for
zero-copy RDMA reads.
expert_buffer: Pre-allocated receive buffer tensors (one per
weight tensor in a single layer).
"""
# Keep a safe default for callers that have not resolved communicator yet.
if backend is None:
backend = "torch_nccl"
tensor_device_type = expert_weights[0].device.type if expert_weights else "cpu"
first_layer = expert_weights[0] if expert_weights else []
tensor_device_type = first_layer[0].device.type if first_layer else "cpu"
torch_group = (
group_coordinator.cpu_group
if tensor_device_type == "cpu"
@@ -649,7 +661,7 @@ def create_eplb_communicator(
unsupported_dtypes = sorted(
{
tensor.dtype
for tensor in expert_weights
for tensor in first_layer
if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
},
key=str,
@@ -704,7 +716,8 @@ def create_eplb_communicator(
try:
return NixlEplbCommunicator(
cpu_group=group_coordinator.cpu_group,
expert_weights=expert_weights,
all_expert_weights=expert_weights,
expert_buffer=expert_buffer,
)
except Exception as exc:
raise RuntimeError(
+3 -1
View File
@@ -450,7 +450,8 @@ class EplbState:
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=self.parallel_config.eplb_config.communicator,
expert_weights=model.expert_weights[0],
expert_weights=model.expert_weights,
expert_buffer=expert_buffer,
)
model_state = EplbModelState(
@@ -766,6 +767,7 @@ class EplbState:
eplb_model_state.physical_to_logical_map,
new_physical_to_logical_map,
eplb_model_state.model.expert_weights,
eplb_model_state.expert_buffer,
ep_group,
eplb_model_state.communicator,
is_profile,
+16 -9
View File
@@ -178,6 +178,7 @@ def move_to_buffer(
cuda_stream: torch.cuda.Stream | None,
ep_rank: int,
communicator: EplbCommunicator,
layer_idx: int = 0,
) -> TransferMetadata:
"""
Rearranges expert weights during EPLB rebalancing.
@@ -193,6 +194,7 @@ def move_to_buffer(
cuda_stream: CUDA stream for async copies (can be None for sync mode).
ep_rank: Rank of this process in expert parallel group.
communicator: EplbCommunicator instance for P2P communication.
layer_idx: Index of the MoE layer being transferred.
Returns:
TransferMetadata: Metadata needed for completing remote weight transfers.
@@ -265,6 +267,8 @@ def move_to_buffer(
for w, b in zip(expert_weights, expert_weights_buffers):
b[dst].copy_(w[src_local], non_blocking=True)
communicator.set_transfer_context(old_indices, layer_idx)
# 2. Post sends
if send_count > 0:
experts = send_expert_ids[:send_count]
@@ -331,9 +335,8 @@ def move_to_buffer(
expert_id=int(expert),
)
# 4. Execute the P2P operations. The real communication happens here.
communicator.execute(old_indices=old_indices)
# wait for the communication to finish
# 4. Execute transfers and wait for completion.
communicator.execute()
return TransferMetadata(
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
@@ -431,6 +434,7 @@ def transfer_layer(
is_profile: bool = False,
cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None,
layer_idx: int = 0,
) -> TransferMetadata:
"""
Rearranges the expert weights in place according to the new expert indices.
@@ -452,6 +456,7 @@ def transfer_layer(
communications to reserve enough memory for the buffers.
cuda_stream: CUDA stream for async copies (can be None for sync mode).
rank_mapping: Optional rank mapping for elastic expert parallelism.
layer_idx: Index of the MoE layer being transferred.
Returns:
TransferMetadata: Metadata needed for completing remote weight transfers,
@@ -499,6 +504,7 @@ def transfer_layer(
cuda_stream=cuda_stream,
ep_rank=ep_group.rank(),
communicator=communicator,
layer_idx=layer_idx,
)
@@ -506,6 +512,7 @@ def rearrange_expert_weights_inplace(
old_global_expert_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Sequence[torch.Tensor]],
expert_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup,
communicator: EplbCommunicator,
is_profile: bool = False,
@@ -524,6 +531,8 @@ def rearrange_expert_weights_inplace(
of tensors of shape (num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection,
so weight_count = 2. Each weight's hidden size can be different.
expert_buffer: Pre-allocated receive buffer tensors (one per
weight tensor in a single layer).
ep_group: The device process group for expert parallelism.
communicator: EplbCommunicator instance for P2P communication.
is_profile (bool): If `True`, do not perform any actual weight copy.
@@ -566,10 +575,10 @@ def rearrange_expert_weights_inplace(
# Reserve NCCL communication buffers via a dummy all_gather.
# Backends that pre-allocate their own transfer buffers
# skip this to avoid the extra memory spike during profiling.
weights_buffer: list[torch.Tensor] = [
profile_buffer: list[torch.Tensor] = [
torch.empty_like(w) for w in first_layer_weights
]
for weight, buffer in zip(expert_weights[0], weights_buffer):
for weight, buffer in zip(expert_weights[0], profile_buffer):
dummy_recv_buffer = [buffer for _ in range(ep_size)]
torch.distributed.barrier()
all_gather(
@@ -579,10 +588,7 @@ def rearrange_expert_weights_inplace(
)
return
# Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
weights_buffer = [torch.empty_like(w) for w in first_layer_weights]
weights_buffer = list(expert_buffer)
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
@@ -597,6 +603,7 @@ def rearrange_expert_weights_inplace(
cuda_stream=None,
ep_rank=ep_rank,
communicator=communicator,
layer_idx=layer_idx,
)
move_from_buffer(
@@ -120,15 +120,12 @@ class KVOutputAggregator:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = kv_output.kv_connector_stats
elif kv_connector_stats := kv_output.kv_connector_stats:
if aggregated_kv_connector_stats is None:
aggregated_kv_connector_stats = kv_connector_stats
else:
assert isinstance(
aggregated_kv_connector_stats, type(kv_connector_stats)
)
aggregated_kv_connector_stats = (
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)
assert isinstance(
aggregated_kv_connector_stats, type(kv_connector_stats)
)
aggregated_kv_connector_stats = aggregated_kv_connector_stats.aggregate(
kv_connector_stats
)
# Aggregate kv_connector_worker_meta from all workers.
if aggregated_kv_connector_worker_meta is None:
@@ -2333,9 +2333,25 @@ class NixlConnectorWorker:
for i, remote_group in enumerate(remote_block_ids):
num_local_blocks = len(local_block_ids[i])
num_remote_blocks = len(remote_group)
if _is_ssm_spec(self._group_spec_types[i]):
assert num_local_blocks == num_remote_blocks
if (
_is_ssm_spec(self._group_spec_types[i])
and num_local_blocks < num_remote_blocks
):
# NOTE (NickLucche): With prefix caching on SSM, (remote) blocks
# prior to the last one are placeholders (null blocks). Mind that
# this doesn't really impact transfer, as we only still care about
# the last "block", the full in-place state.
assert num_local_blocks == 1, "SSM can only have one local block"
remote_block_ids[i] = remote_group[-num_local_blocks:]
elif (
self._physical_blocks_per_logical_kv_block
== remote_physical_per_logical
and num_local_blocks < num_remote_blocks
):
# Partial prefix cache hit for FA group.
remote_block_ids[i] = remote_group[-num_local_blocks:]
else:
# TODO Handle prefix caching with different block_sizes
max_padding = max(
self._physical_blocks_per_logical_kv_block,
remote_physical_per_logical,
-38
View File
@@ -1270,9 +1270,6 @@ def get_dcp_group() -> GroupCoordinator:
return _DCP
# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group
_PP: GroupCoordinator | None = None
@@ -1840,31 +1837,6 @@ def model_parallel_is_initialized():
_TP_STATE_PATCHED = False
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
"""Patch the tp group temporarily until this function ends.
This method is for draft workers of speculative decoding to run draft model
with different tp degree from that of target model workers.
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _TP_STATE_PATCHED
assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
_TP_STATE_PATCHED = True
old_tp_group = get_tp_group()
global _TP
_TP = tp_group
try:
yield
finally:
# restore the original state
_TP_STATE_PATCHED = False
_TP = old_tp_group
def get_tensor_model_parallel_world_size() -> int:
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
@@ -1875,16 +1847,6 @@ def get_tensor_model_parallel_rank() -> int:
return get_tp_group().rank_in_group
def get_decode_context_model_parallel_world_size() -> int:
"""Return world size for the decode context model parallel group."""
return get_dcp_group().world_size
def get_decode_context_model_parallel_rank() -> int:
"""Return my rank for the decode context model parallel group."""
return get_dcp_group().rank_in_group
def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment."""
assert _NODE_COUNT is not None, "distributed environment is not initialized"
+14
View File
@@ -64,6 +64,20 @@ def divide(numerator, denominator):
return numerator // denominator
def is_weak_contiguous(inp: torch.Tensor) -> bool:
"""Check that *inp* occupies a single contiguous block of memory.
Unlike ``torch.Tensor.is_contiguous()``, this also accepts tensors
whose strides are not strictly C-contiguous (e.g. column-major) as
long as the underlying storage from the tensor's offset onward is
exactly ``numel * element_size`` bytes.
"""
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
+2 -2
View File
@@ -17,9 +17,9 @@ from vllm.entrypoints.anthropic.protocol import (
)
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.utils import (
from vllm.entrypoints.serve.utils.api_utils import (
load_aware_call,
validate_json_request,
with_cancellation,
)
from vllm.logger import init_logger
+1 -1
View File
@@ -29,7 +29,6 @@ from vllm.entrypoints.anthropic.protocol import (
AnthropicUsage,
)
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
@@ -45,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
StreamOptions,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.utils.request_logger import RequestLogger
if TYPE_CHECKING:
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
+1 -1
View File
@@ -22,7 +22,7 @@ import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.entrypoints.serve.utils.api_utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
+1 -1
View File
@@ -7,7 +7,7 @@ import typing
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from vllm.entrypoints.serve.utils.api_utils import VLLM_SUBCMD_PARSER_EPILOG
if typing.TYPE_CHECKING:
from vllm.utils.argparse_utils import FlexibleArgumentParser

Some files were not shown because too many files have changed in this diff Show More