Merge branch 'main' into wentao-enable-all-dense-for-mrv2

This commit is contained in:
yewentao256
2026-06-04 18:33:39 +00:00
227 changed files with 9771 additions and 3209 deletions
+4 -5
View File
@@ -1299,12 +1299,11 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/entrypoints/llm - tests/entrypoints/llm
- tests/entrypoints/offline_mode
commands: commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py --ignore=entrypoints/llm/offline_mode
- pytest -v -s entrypoints/llm/test_generate.py - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/offline_mode - pytest -v -s entrypoints/llm/offline_mode # Needs to avoid interference with other tests
- label: Entrypoints Integration (Pooling) # TBD - label: Entrypoints Integration (Pooling) # TBD
timeout_in_minutes: 180 timeout_in_minutes: 180
@@ -1346,7 +1345,7 @@ steps:
- vllm/platforms/rocm.py - vllm/platforms/rocm.py
commands: commands:
- pytest -v -s entrypoints/openai/tool_parsers - pytest -v -s entrypoints/openai/tool_parsers
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- label: OpenAI API correctness # TBD - label: OpenAI API correctness # TBD
timeout_in_minutes: 180 timeout_in_minutes: 180
+3 -4
View File
@@ -11,7 +11,7 @@ steps:
- tests/entrypoints/ - tests/entrypoints/
commands: commands:
- pytest -v -s entrypoints/openai/tool_parsers - pytest -v -s entrypoints/openai/tool_parsers
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate - pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
- label: Entrypoints Integration (LLM) - label: Entrypoints Integration (LLM)
key: entrypoints-integration-llm key: entrypoints-integration-llm
@@ -20,12 +20,11 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
- tests/entrypoints/llm - tests/entrypoints/llm
- tests/entrypoints/offline_mode
commands: commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py --ignore=entrypoints/llm/offline_mode
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - pytest -v -s entrypoints/llm/offline_mode # Needs to avoid interference with other tests
mirror: mirror:
amd: amd:
device: mi325_1 device: mi325_1
-7
View File
@@ -33,10 +33,3 @@ share/python-wheels/
*.egg *.egg
MANIFEST MANIFEST
rust/target/ rust/target/
# Not needed in Docker builds
docs/
.github/
.pre-commit-config.yaml
.clang-format
.gitattributes
format.sh
+2 -1
View File
@@ -34,10 +34,11 @@
/vllm/entrypoints/speech_to_text/realtime @njhill /vllm/entrypoints/speech_to_text/realtime @njhill
/vllm/entrypoints/speech_to_text @NickLucche /vllm/entrypoints/speech_to_text @NickLucche
/vllm/entrypoints/pooling @noooop /vllm/entrypoints/pooling @noooop
/vllm/entrypoints/sagemaker @DarkLight1337 /vllm/entrypoints/serve/sagemaker @DarkLight1337
/vllm/entrypoints/serve @njhill /vllm/entrypoints/serve @njhill
/vllm/entrypoints/*.py @njhill /vllm/entrypoints/*.py @njhill
/vllm/entrypoints/chat_utils.py @DarkLight1337 /vllm/entrypoints/chat_utils.py @DarkLight1337
/vllm/entrypoints/offline_utils.py @DarkLight1337
/vllm/entrypoints/llm.py @DarkLight1337 /vllm/entrypoints/llm.py @DarkLight1337
# Rust Frontend # Rust Frontend
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
actions: write actions: write
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1 - uses: actions/stale@eb5cf3af3ac0a1aa4c9c45633dd1ae542a27a899 # v10.3.0
with: with:
# Increasing this value ensures that changes to this workflow # Increasing this value ensures that changes to this workflow
# propagate to all issues and PRs in days rather than months # propagate to all issues and PRs in days rather than months
@@ -102,6 +102,35 @@ constexpr float NUM_TOKEN_CUTOFF = 1024;
constexpr int kNumLanes = 32; constexpr int kNumLanes = 32;
constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16 constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16
// Pack this lane's 16 fp32 elements into per-tensor E4M3 FP8 (one uint4 = 16
// B), scaling by `scale` (a reciprocal scale) and saturating to ±448. Used by
// the FlashInfer full-cache path for both the Q and KV stores.
//
//   values: pointer to at least kElemsPerLane floats owned by this lane.
//   scale:  multiplier applied to every element before the FP8 conversion.
//   returns: the kElemsPerLane converted bytes, element i in byte i.
//
// Both platform branches write directly into the uint4 result. Filling a
// separate uint8_t array and reading it back through a uint4 pointer would
// rely on the array happening to be 16-byte aligned, which is not guaranteed.
__device__ __forceinline__ uint4 packFp8E4M3x16(float const* values,
                                                float const scale) {
  static_assert(kElemsPerLane == sizeof(uint4),
                "one uint4 must hold exactly kElemsPerLane fp8 bytes");
  uint4 out;
#ifndef USE_ROCM
  // Convert two floats at a time into one packed fp8x2 storage word.
  auto* out2 = reinterpret_cast<__nv_fp8x2_storage_t*>(&out);
  #pragma unroll
  for (int i = 0; i < kElemsPerLane / 2; i++) {
    float2 scaled =
        make_float2(values[2 * i] * scale, values[2 * i + 1] * scale);
    scaled.x = fminf(fmaxf(scaled.x, -kFp8Max), kFp8Max);
    scaled.y = fminf(fmaxf(scaled.y, -kFp8Max), kFp8Max);
    out2[i] = __nv_cvt_float2_to_fp8x2(scaled, __NV_SATFINITE, __NV_E4M3);
  }
#else
  // No packed x2 conversion is used on ROCm: convert element by element and
  // store each byte into the (16-byte aligned) uint4 result.
  auto* out_bytes = reinterpret_cast<uint8_t*>(&out);
  #pragma unroll
  for (int i = 0; i < kElemsPerLane; i++) {
    float scaled = values[i] * scale;
    scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max);
    out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled);
  }
#endif
  return out;
}
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
// Small inline helpers // Small inline helpers
// ──────────────────────────────────────────────────────────────────────────── // ────────────────────────────────────────────────────────────────────────────
@@ -649,6 +678,257 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
#undef DISPATCH #undef DISPATCH
} }
// ────────────────────────────────────────────────────────────────────────────
// FlashInfer full-cache kernel
// ────────────────────────────────────────────────────────────────────────────
//
// Sibling to the FlashMLA kernel above, used by the FlashInfer V4 sparse-MLA
// backend. Differences from the legacy path:
// * No Q head padding — output Q layout matches the input num_heads_q.
// * KV is written as a *contiguous* 512-wide row per token (token-strided),
// not the legacy UE8M0 paged layout with a separate scale tail.
// * Q/KV are stored either as bf16 or as per-tensor E4M3 FP8 (one global
// scale), selected by the STORE_Q_FP8 / STORE_KV_FP8 template flags.
//
// Grid: 1D, gridDim.x = ceil(num_tokens_full * (num_heads_q + 1) /
// warpsPerBlock), where warpsPerBlock = blockDim.x / 32.
// Each warp handles one (token, slot): slot < num_heads_q → Q, slot ==
// num_heads_q → KV.
//
// Per-warp pipeline: every lane loads kElemsPerLane consecutive elements of
// the row, Q rows are RMS-normalised (no weight), lanes whose elements lie at
// or beyond kNopeDim apply the rotary embedding, and the row is stored either
// as the input scalar type or as FP8.
template <typename scalar_t_in, bool STORE_Q_FP8, bool STORE_KV_FP8>
__global__ void fusedDeepseekV4FullCacheKernel(
scalar_t_in* __restrict__ q_inout, // [N, H, 512], in place (bf16)
uint8_t* __restrict__ q_fp8_out, // [N, H, 512] fp8, optional
int64_t const q_fp8_stride0, // elements (fp8 == bytes)
int64_t const q_fp8_stride1, // elements (fp8 == bytes)
scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16
uint8_t* __restrict__ k_cache, // contiguous bf16 or fp8 cache
int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64
int64_t const* __restrict__ position_ids, // [N] i64
float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32
float const* __restrict__ fp8_scale_ptr, // scalar, KV fp8 only
float const* __restrict__ q_fp8_scale_inv, // scalar, Q fp8 only
float const eps,
int const num_tokens_full, // = q.size(0) = kv.size(0)
int const num_tokens_insert, // = slot_mapping.size(0)
int const num_heads_q, // H (no padding)
int const cache_block_size, // tokens per cache block
int64_t const kv_block_stride, // bytes per cache block
int64_t const kv_token_stride) { // bytes per cache token
// On non-ROCm builds, when __CUDA_ARCH__ is undefined or below 800, the
// BFloat16 instantiation is reduced to an immediate return; every other
// scalar type runs the body inside the `else` branch (closed at the matching
// #if at the end of the kernel).
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if constexpr (std::is_same_v<scalar_t_in, c10::BFloat16>) {
return;
} else {
#endif
using Converter = vllm::_typeConvert<scalar_t_in>;
// Work decomposition: groups of 32 consecutive threads form one "warp", and
// each warp owns exactly one (token, slot) pair.
int const warpsPerBlock = blockDim.x / 32;
int const warpId = threadIdx.x / 32;
int const laneId = threadIdx.x % 32;
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
int const slotsPerToken = num_heads_q + 1;
int const tokenIdx = globalWarpIdx / slotsPerToken;
int const slotIdx = globalWarpIdx % slotsPerToken;
// Both early exits depend only on per-warp values, so all 32 lanes of a warp
// leave together and the warp-wide reduction below is never partially
// populated.
if (tokenIdx >= num_tokens_full) return;
bool const isKV = (slotIdx == num_heads_q);
// KV branch: skip DP-padded tokens (no slot reserved for them).
if (isKV && tokenIdx >= num_tokens_insert) return;
// Compiled for sm_90+ only; placed before the first global-memory read.
// NOTE(review): presumably the programmatic-dependent-launch wait that pairs
// with the programmaticStreamSerialization launch attribute — confirm against
// the CUDA runtime documentation.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaGridDependencySynchronize();
#endif
int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16
// Select the source row: the token's KV row, or head `slotIdx` of its Q row.
scalar_t_in const* src_ptr;
if (isKV) {
src_ptr = kv_in + static_cast<int64_t>(tokenIdx) * kHeadDim + dim_base;
} else {
src_ptr = q_inout +
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) *
kHeadDim +
dim_base;
}
// Two 16-byte vector loads, the second starting 8 elements after the first.
// NOTE(review): this covers the lane's 16 elements only if scalar_t_in is a
// 2-byte type — confirm no wider type is ever dispatched here.
uint4 const v0 = *reinterpret_cast<uint4 const*>(src_ptr);
uint4 const v1 = *reinterpret_cast<uint4 const*>(src_ptr + 8);
// ── Decode bf16 → 16 fp32 registers ───────────────────────────────────
float elements[kElemsPerLane];
{
auto const* p0 =
reinterpret_cast<typename Converter::packed_hip_type const*>(&v0);
auto const* p1 =
reinterpret_cast<typename Converter::packed_hip_type const*>(&v1);
// Each packed value converts to a float2, so 4 packed values fill 8 floats.
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 f2 = Converter::convert(p0[i]);
elements[2 * i] = f2.x;
elements[2 * i + 1] = f2.y;
}
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 f2 = Converter::convert(p1[i]);
elements[8 + 2 * i] = f2.x;
elements[8 + 2 * i + 1] = f2.y;
}
}
// ── Q branch: RMSNorm (no weight) ─────────────────────────────────────
// x * rsqrt(sum(x^2) / kHeadDim + eps); the sum of squares is accumulated
// per lane and then combined across the warp with warpSum.
if (!isKV) {
float sumOfSquares = 0.0f;
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
sumOfSquares += elements[i] * elements[i];
}
sumOfSquares = warpSum<float>(sumOfSquares);
float const rms_rcp =
rsqrtf(sumOfSquares / static_cast<float>(kHeadDim) + eps);
#pragma unroll
for (int i = 0; i < kElemsPerLane; i++) {
elements[i] = elements[i] * rms_rcp;
}
}
// ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ───────────────────────────
// Only lanes whose 16-element slice starts at or after kNopeDim rotate.
// Adjacent (even, odd) elements form one rotation pair.
bool const is_rope_lane = dim_base >= kNopeDim;
if (is_rope_lane) {
int64_t const pos = position_ids[tokenIdx];
constexpr int kHalfRope = kRopeDim / 2;
// Row layout of cos_sin_cache as indexed here: kHalfRope cos values followed
// by kHalfRope sin values.
float const* cos_ptr = cos_sin_cache + pos * kRopeDim;
float const* sin_ptr = cos_ptr + kHalfRope;
int const rope_local_base = dim_base - kNopeDim;
// One cos/sin entry per pair, so the pair index is half the element offset.
int const half_base = rope_local_base >> 1;
float4 const c0 = *reinterpret_cast<float4 const*>(cos_ptr + half_base);
float4 const c1 = *reinterpret_cast<float4 const*>(cos_ptr + half_base + 4);
float4 const s0 = *reinterpret_cast<float4 const*>(sin_ptr + half_base);
float4 const s1 = *reinterpret_cast<float4 const*>(sin_ptr + half_base + 4);
float const cos_arr[8] = {c0.x, c0.y, c0.z, c0.w, c1.x, c1.y, c1.z, c1.w};
float const sin_arr[8] = {s0.x, s0.y, s0.z, s0.w, s1.x, s1.y, s1.z, s1.w};
// 2-D rotation of each pair: (x_even, x_odd) by the angle whose cosine and
// sine are cos_arr[p] and sin_arr[p].
#pragma unroll
for (int p = 0; p < kElemsPerLane / 2; p++) {
float const x_even = elements[2 * p];
float const x_odd = elements[2 * p + 1];
elements[2 * p] = x_even * cos_arr[p] - x_odd * sin_arr[p];
elements[2 * p + 1] = x_even * sin_arr[p] + x_odd * cos_arr[p];
}
}
// ── Store ─────────────────────────────────────────────────────────────
if (!isKV) {
if constexpr (STORE_Q_FP8) {
// FP8 Q: multiply by the supplied reciprocal scale and write one 16-byte
// vector to the separate q_fp8_out tensor; q_inout is left untouched.
float const scale_inv = VLLM_LDG(q_fp8_scale_inv);
uint4 const out = packFp8E4M3x16(elements, scale_inv);
uint8_t* dst = q_fp8_out +
static_cast<int64_t>(tokenIdx) * q_fp8_stride0 +
static_cast<int64_t>(slotIdx) * q_fp8_stride1 + dim_base;
*reinterpret_cast<uint4*>(dst) = out;
} else {
// Non-FP8 Q: convert back to the input scalar type and overwrite the
// same location in q_inout that was read above (in-place update).
uint4 out0, out1;
auto* po0 = reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
auto* po1 = reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
#pragma unroll
for (int i = 0; i < 4; i++) {
po0[i] = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
}
#pragma unroll
for (int i = 0; i < 4; i++) {
po1[i] = Converter::convert(
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
}
scalar_t_in* dst =
q_inout +
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
dim_base;
*reinterpret_cast<uint4*>(dst) = out0;
*reinterpret_cast<uint4*>(dst + 8) = out1;
}
} else {
// KV: a negative slot id means "do not write this token".
int64_t const slot_id = slot_mapping[tokenIdx];
if (slot_id >= 0) {
// Translate the flat slot id into (block, position-in-block) and then into
// a byte offset using the caller-provided byte strides.
int64_t const block_idx = slot_id / cache_block_size;
int64_t const pos_in_block = slot_id % cache_block_size;
uint8_t* cache_row =
k_cache + block_idx * kv_block_stride + pos_in_block * kv_token_stride;
if constexpr (STORE_KV_FP8) {
// fp8_scale_ptr holds the scale itself, so it is inverted here (unlike
// the Q path, which receives a reciprocal directly).
float const inv_scale = 1.0f / VLLM_LDG(fp8_scale_ptr);
uint4 const out = packFp8E4M3x16(elements, inv_scale);
*reinterpret_cast<uint4*>(cache_row + dim_base) = out;
} else {
uint4 out0, out1;
auto* po0 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
auto* po1 =
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
#pragma unroll
for (int i = 0; i < 4; i++) {
po0[i] = Converter::convert(
make_float2(elements[2 * i], elements[2 * i + 1]));
}
#pragma unroll
for (int i = 0; i < 4; i++) {
po1[i] = Converter::convert(
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
}
// cache_row is a byte pointer; dim_base is applied after the cast so it
// counts elements of scalar_t_in rather than bytes.
scalar_t_in* dst = reinterpret_cast<scalar_t_in*>(cache_row) + dim_base;
*reinterpret_cast<uint4*>(dst) = out0;
*reinterpret_cast<uint4*>(dst + 8) = out1;
}
}
}
// Compiled for sm_90+ only; issued after this thread's stores.
// NOTE(review): presumably signals dependent launches that they may start —
// confirm against the CUDA runtime documentation.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Closes the `else {` opened by the matching guard at the top of the kernel.
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
}
#endif
}
// Configure + launch helper shared by the bf16 and fp8 full-cache launchers.
template <typename scalar_t_in, bool STORE_Q_FP8, bool STORE_KV_FP8>
static void launchFullCacheKernel(
scalar_t_in* q_inout, uint8_t* q_fp8_out, int64_t q_fp8_stride0,
int64_t q_fp8_stride1, scalar_t_in const* kv_in, uint8_t* k_cache,
int64_t const* slot_mapping, int64_t const* position_ids,
float const* cos_sin_cache, float const* fp8_scale,
float const* q_fp8_scale_inv, float const eps, int const num_tokens_full,
int const num_tokens_insert, int const num_heads_q,
int const cache_block_size, int64_t const kv_block_stride,
int64_t const kv_token_stride, char const* op_name, cudaStream_t stream) {
constexpr int kBlockSize = 256;
constexpr int kWarpsPerBlock = kBlockSize / 32;
int64_t const total_warps =
static_cast<int64_t>(num_tokens_full) * (num_heads_q + 1);
int const grid =
static_cast<int>((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock);
auto* kernel =
fusedDeepseekV4FullCacheKernel<scalar_t_in, STORE_Q_FP8, STORE_KV_FP8>;
#ifndef USE_ROCM
static int const sm_version = getSMVersion();
STD_TORCH_CHECK(sm_version >= 80, op_name,
" requires sm_80+ (Ampere or newer); got sm_", sm_version);
cudaLaunchConfig_t config;
config.gridDim = dim3(grid);
config.blockDim = dim3(kBlockSize);
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.attrs = attrs;
config.numAttrs = (sm_version >= 90) ? 1 : 0;
cudaLaunchKernelEx(&config, kernel, q_inout, q_fp8_out, q_fp8_stride0,
q_fp8_stride1, kv_in, k_cache, slot_mapping, position_ids,
cos_sin_cache, fp8_scale, q_fp8_scale_inv, eps,
num_tokens_full, num_tokens_insert, num_heads_q,
cache_block_size, kv_block_stride, kv_token_stride);
#else
kernel<<<grid, kBlockSize, 0, stream>>>(
q_inout, q_fp8_out, q_fp8_stride0, q_fp8_stride1, kv_in, k_cache,
slot_mapping, position_ids, cos_sin_cache, fp8_scale, q_fp8_scale_inv,
eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
kv_block_stride, kv_token_stride);
#endif
}
} // namespace deepseek_v4_fused_ops } // namespace deepseek_v4_fused_ops
} // namespace vllm } // namespace vllm
@@ -735,3 +1015,167 @@ torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
}); });
return q_out; return q_out;
} }
// ────────────────────────────────────────────────────────────────────────────
// FlashInfer full-cache torch ops
// ────────────────────────────────────────────────────────────────────────────

// bf16 variant: normalises and rotates Q in place, rotates KV, and writes the
// KV rows into a bf16 cache addressed through `slot_mapping`.
//
// Validates devices, dtypes, ranks and layouts, then launches the full-cache
// kernel with STORE_Q_FP8 = STORE_KV_FP8 = false on the current CUDA stream of
// q's device. Raises (via STD_TORCH_CHECK) on any violated precondition.
//
// The kernel indexes slot_mapping, position_ids and cos_sin_cache as dense
// arrays through raw pointers, so they are required to be contiguous here;
// a strided view would otherwise be read incorrectly without any error.
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
    torch::stable::Tensor& q,                    // [N, H, 512] bf16, in place
    torch::stable::Tensor const& kv,             // [N, 512] bf16, read-only
    torch::stable::Tensor& k_cache,              // [num_blocks, bs, 512] bf16
    torch::stable::Tensor const& slot_mapping,   // [num_tokens_insert] int64
    torch::stable::Tensor const& position_ids,   // [N] int64
    torch::stable::Tensor const& cos_sin_cache,  // [max_pos, 64] float32
    double eps, int64_t cache_block_size) {
  using torch::headeronly::ScalarType;
  // Token and head counts are passed to the kernel as `int`.
  constexpr int64_t kInt32Max = 0x7FFFFFFF;

  STD_TORCH_CHECK(q.device().is_cuda() && q.is_contiguous(),
                  "q must be contiguous CUDA");
  STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
                  "kv must be contiguous CUDA");
  STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
  STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
                      slot_mapping.scalar_type() == ScalarType::Long,
                  "slot_mapping must be int64 CUDA");
  STD_TORCH_CHECK(slot_mapping.dim() == 1 && slot_mapping.is_contiguous(),
                  "slot_mapping must be a contiguous 1-D tensor");
  STD_TORCH_CHECK(position_ids.device().is_cuda() &&
                      position_ids.scalar_type() == ScalarType::Long,
                  "position_ids must be int64 CUDA");
  STD_TORCH_CHECK(position_ids.dim() == 1 && position_ids.is_contiguous(),
                  "position_ids must be a contiguous 1-D tensor");
  STD_TORCH_CHECK(cos_sin_cache.device().is_cuda() &&
                      cos_sin_cache.scalar_type() == ScalarType::Float &&
                      cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
                  "cos_sin_cache shape [max_pos, 64] float32");
  STD_TORCH_CHECK(cos_sin_cache.is_contiguous(),
                  "cos_sin_cache must be contiguous");
  STD_TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
  STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
  STD_TORCH_CHECK(q.scalar_type() == ScalarType::BFloat16 &&
                      kv.scalar_type() == ScalarType::BFloat16,
                  "q and kv must be bfloat16");
  // cache_block_size is a divisor in the kernel's slot -> block translation.
  STD_TORCH_CHECK(cache_block_size > 0, "cache_block_size must be positive");
  STD_TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size &&
                      k_cache.size(2) == 512 && k_cache.stride(2) == 1,
                  "k_cache shape [num_blocks, cache_block_size, 512] contiguous");
  STD_TORCH_CHECK(k_cache.scalar_type() == ScalarType::BFloat16,
                  "k_cache must be bfloat16");

  // Compare row counts at full 64-bit width; narrowing first could make two
  // different sizes compare equal.
  STD_TORCH_CHECK(q.size(0) <= kInt32Max && q.size(1) <= kInt32Max,
                  "q token and head counts must fit in int32");
  STD_TORCH_CHECK(kv.size(0) == q.size(0) && position_ids.size(0) == q.size(0),
                  "q/kv/position_ids row counts must match");
  STD_TORCH_CHECK(slot_mapping.size(0) <= q.size(0),
                  "slot_mapping must not exceed q row count");
  int const num_tokens_full = static_cast<int>(q.size(0));
  int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
  int const num_heads_q = static_cast<int>(q.size(1));

  const torch::stable::accelerator::DeviceGuard device_guard(
      q.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(q.get_device_index());

  // bf16 cache: 2 bytes/element -> byte strides for the uint8-addressed kernel.
  int64_t const kv_block_stride = k_cache.stride(0) * 2;
  int64_t const kv_token_stride = k_cache.stride(1) * 2;

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      q.scalar_type(),
      "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert", [&] {
        vllm::deepseek_v4_fused_ops::launchFullCacheKernel<scalar_t, false,
                                                           false>(
            // No FP8 Q output and no scales in the bf16 path, hence the
            // nullptr / zero-stride arguments.
            reinterpret_cast<scalar_t*>(q.mutable_data_ptr()), nullptr, 0, 0,
            reinterpret_cast<scalar_t const*>(kv.const_data_ptr()),
            reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
            slot_mapping.const_data_ptr<int64_t>(),
            position_ids.const_data_ptr<int64_t>(),
            cos_sin_cache.const_data_ptr<float>(), nullptr, nullptr,
            static_cast<float>(eps), num_tokens_full, num_tokens_insert,
            num_heads_q, static_cast<int>(cache_block_size), kv_block_stride,
            kv_token_stride,
            "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert",
            stream);
      });
}
// fp8 variant: reads Q and KV, writes the normalised + rotated Q into the
// separate `q_fp8` tensor and the rotated KV rows into an fp8 cache addressed
// through `slot_mapping`. `q` itself is not modified.
//
// `fp8_scale` is the KV scale (the kernel divides by it) and
// `q_fp8_scale_inv` is the reciprocal Q scale (the kernel multiplies by it).
// Both must hold exactly one float32 value; a 0-dim tensor and a 1-element
// tensor are both accepted.
//
// Validates devices, dtypes, ranks and layouts, then launches the full-cache
// kernel with STORE_Q_FP8 = STORE_KV_FP8 = true on the current CUDA stream of
// q's device. Raises (via STD_TORCH_CHECK) on any violated precondition.
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
    torch::stable::Tensor const& q,              // [N, H, 512] bf16, read-only
    torch::stable::Tensor const& kv,             // [N, 512] bf16, read-only
    torch::stable::Tensor& q_fp8,                // [N, H, 512] fp8 e4m3
    torch::stable::Tensor& k_cache,              // [num_blocks, bs, 512] fp8
    torch::stable::Tensor const& slot_mapping,   // [num_tokens_insert] int64
    torch::stable::Tensor const& position_ids,   // [N] int64
    torch::stable::Tensor const& cos_sin_cache,  // [max_pos, 64] float32
    torch::stable::Tensor const& fp8_scale,      // scalar float32 (KV scale)
    torch::stable::Tensor const& q_fp8_scale_inv,  // scalar float32 (1 / Q scale)
    double eps, int64_t cache_block_size) {
  using torch::headeronly::ScalarType;
  // Token and head counts are passed to the kernel as `int`.
  constexpr int64_t kInt32Max = 0x7FFFFFFF;

  STD_TORCH_CHECK(q.device().is_cuda() && q.is_contiguous(),
                  "q must be contiguous CUDA");
  STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
                  "kv must be contiguous CUDA");
  // Establish q/kv rank before any check that indexes their sizes, so a
  // wrong-rank q produces the shape message rather than an index error.
  STD_TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
  STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
  STD_TORCH_CHECK(q.scalar_type() == kv.scalar_type(),
                  "q and kv dtype must match");
  STD_TORCH_CHECK(q_fp8.device().is_cuda() && q_fp8.is_contiguous() &&
                      q_fp8.scalar_type() == ScalarType::Float8_e4m3fn &&
                      q_fp8.dim() == 3 && q_fp8.size(0) == q.size(0) &&
                      q_fp8.size(1) == q.size(1) && q_fp8.size(2) == q.size(2),
                  "q_fp8 must be a contiguous float8_e4m3fn tensor matching q");
  STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
  STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
                      slot_mapping.scalar_type() == ScalarType::Long,
                  "slot_mapping must be int64 CUDA");
  STD_TORCH_CHECK(slot_mapping.dim() == 1 && slot_mapping.is_contiguous(),
                  "slot_mapping must be a contiguous 1-D tensor");
  STD_TORCH_CHECK(position_ids.device().is_cuda() &&
                      position_ids.scalar_type() == ScalarType::Long,
                  "position_ids must be int64 CUDA");
  STD_TORCH_CHECK(position_ids.dim() == 1 && position_ids.is_contiguous(),
                  "position_ids must be a contiguous 1-D tensor");
  STD_TORCH_CHECK(cos_sin_cache.device().is_cuda() &&
                      cos_sin_cache.scalar_type() == ScalarType::Float &&
                      cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
                  "cos_sin_cache shape [max_pos, 64] float32");
  STD_TORCH_CHECK(cos_sin_cache.is_contiguous(),
                  "cos_sin_cache must be contiguous");
  // numel() rather than size(0): a true scalar is 0-dimensional and has no
  // dimension 0 to query.
  STD_TORCH_CHECK(fp8_scale.device().is_cuda() &&
                      fp8_scale.scalar_type() == ScalarType::Float &&
                      fp8_scale.numel() == 1,
                  "fp8_scale must be a scalar float32 CUDA tensor");
  STD_TORCH_CHECK(q_fp8_scale_inv.device().is_cuda() &&
                      q_fp8_scale_inv.scalar_type() == ScalarType::Float &&
                      q_fp8_scale_inv.numel() == 1,
                  "q_fp8_scale_inv must be a scalar float32 CUDA tensor");
  // cache_block_size is a divisor in the kernel's slot -> block translation.
  STD_TORCH_CHECK(cache_block_size > 0, "cache_block_size must be positive");
  STD_TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size &&
                      k_cache.size(2) == 512 && k_cache.stride(2) == 1,
                  "k_cache shape [num_blocks, cache_block_size, 512] contiguous");
  STD_TORCH_CHECK(k_cache.scalar_type() == ScalarType::Float8_e4m3fn,
                  "k_cache must be float8_e4m3fn");

  // Compare row counts at full 64-bit width; narrowing first could make two
  // different sizes compare equal.
  STD_TORCH_CHECK(q.size(0) <= kInt32Max && q.size(1) <= kInt32Max,
                  "q token and head counts must fit in int32");
  STD_TORCH_CHECK(kv.size(0) == q.size(0) && position_ids.size(0) == q.size(0),
                  "q/kv/position_ids row counts must match");
  STD_TORCH_CHECK(slot_mapping.size(0) <= q.size(0),
                  "slot_mapping must not exceed q row count");
  int const num_tokens_full = static_cast<int>(q.size(0));
  int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
  int const num_heads_q = static_cast<int>(q.size(1));

  const torch::stable::accelerator::DeviceGuard device_guard(
      q.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(q.get_device_index());

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      q.scalar_type(),
      "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert", [&] {
        vllm::deepseek_v4_fused_ops::launchFullCacheKernel<scalar_t, true,
                                                           true>(
            // q is read-only in the fp8 path (the kernel writes q_fp8); the
            // launcher signature is non-const, so cast away const on the ptr.
            reinterpret_cast<scalar_t*>(
                const_cast<void*>(q.const_data_ptr())),
            reinterpret_cast<uint8_t*>(q_fp8.mutable_data_ptr()),
            q_fp8.stride(0), q_fp8.stride(1),
            reinterpret_cast<scalar_t const*>(kv.const_data_ptr()),
            reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
            slot_mapping.const_data_ptr<int64_t>(),
            position_ids.const_data_ptr<int64_t>(),
            cos_sin_cache.const_data_ptr<float>(),
            fp8_scale.const_data_ptr<float>(),
            q_fp8_scale_inv.const_data_ptr<float>(), static_cast<float>(eps),
            num_tokens_full, num_tokens_insert, num_heads_q,
            static_cast<int>(cache_block_size),
            // fp8 cache: 1 byte/element -> stride already in bytes.
            k_cache.stride(0), k_cache.stride(1),
            "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert",
            stream);
      });
}
+17
View File
@@ -238,6 +238,23 @@ torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded, torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
double eps, int64_t cache_block_size); double eps, int64_t cache_block_size);
// FlashInfer V4 full-cache op, bf16 variant.
// q [N, H, 512] bf16 is updated in place; kv [N, 512] bf16 is read-only;
// k_cache [num_blocks, cache_block_size, 512] bf16 receives the KV rows at the
// slots given by slot_mapping (int64). position_ids is int64 with one entry
// per row of q, and cos_sin_cache is float32 [max_pos, 64].
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
torch::stable::Tensor& q, torch::stable::Tensor const& kv,
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
torch::stable::Tensor const& position_ids,
torch::stable::Tensor const& cos_sin_cache, double eps,
int64_t cache_block_size);
// FlashInfer V4 full-cache op, fp8 variant.
// q [N, H, 512] and kv [N, 512] are read-only and must share a dtype; the
// processed Q is written to q_fp8 (float8_e4m3fn, same shape as q) and the KV
// rows to k_cache [num_blocks, cache_block_size, 512] float8_e4m3fn at the
// slots given by slot_mapping (int64). fp8_scale is the single-value float32
// KV scale and q_fp8_scale_inv the single-value float32 reciprocal Q scale.
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
torch::stable::Tensor const& q, torch::stable::Tensor const& kv,
torch::stable::Tensor& q_fp8, torch::stable::Tensor& k_cache,
torch::stable::Tensor const& slot_mapping,
torch::stable::Tensor const& position_ids,
torch::stable::Tensor const& cos_sin_cache,
torch::stable::Tensor const& fp8_scale,
torch::stable::Tensor const& q_fp8_scale_inv, double eps,
int64_t cache_block_size);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::stable::Tensor minimax_allreduce_rms( torch::stable::Tensor minimax_allreduce_rms(
torch::stable::Tensor const& input, torch::stable::Tensor const& input,
+20
View File
@@ -343,6 +343,20 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, " "Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"int q_head_padded, float eps, int cache_block_size) -> Tensor"); "int q_head_padded, float eps, int cache_block_size) -> Tensor");
// FlashInfer V4 full-cache variants: write Q in place (bf16) or to a separate
// FP8 tensor, and KV into a contiguous 512-wide token-strided cache.
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert("
"Tensor! q, Tensor kv, Tensor! k_cache, Tensor slot_mapping, "
"Tensor position_ids, Tensor cos_sin_cache, float eps, "
"int cache_block_size) -> ()");
ops.def(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert("
"Tensor q, Tensor kv, Tensor! q_fp8, Tensor! k_cache, "
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
"Tensor fp8_scale, Tensor q_fp8_scale_inv, float eps, "
"int cache_block_size) -> ()");
#ifndef USE_ROCM #ifndef USE_ROCM
ops.def( ops.def(
"minimax_allreduce_rms(" "minimax_allreduce_rms("
@@ -591,6 +605,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope)); ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert)); TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
ops.impl(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert",
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert));
ops.impl(
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert",
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert));
#ifndef USE_ROCM #ifndef USE_ROCM
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms)); ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk)); ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
+2 -1
View File
@@ -55,7 +55,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and // Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one // GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch. Registered in _C_stable_libtorch. // kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4
// full-cache bf16/fp8 variants).
// Quantization ops // Quantization ops
#ifndef USE_ROCM #ifndef USE_ROCM
-1
View File
@@ -98,7 +98,6 @@ RUN if [ "$USE_SCCACHE" = "1" ]; then \
ARG USE_SCCACHE ARG USE_SCCACHE
ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET_NAME}} ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET_NAME}}
ENV SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION_NAME}} ENV SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION_NAME}}
ENV SCCACHE_ENDPOINT=${USE_SCCACHE:+${SCCACHE_ENDPOINT}}
ENV SCCACHE_S3_NO_CREDENTIALS=${USE_SCCACHE:+${SCCACHE_S3_NO_CREDENTIALS}} ENV SCCACHE_S3_NO_CREDENTIALS=${USE_SCCACHE:+${SCCACHE_S3_NO_CREDENTIALS}}
ENV SCCACHE_IDLE_TIMEOUT=${USE_SCCACHE:+0} ENV SCCACHE_IDLE_TIMEOUT=${USE_SCCACHE:+0}
+14
View File
@@ -228,3 +228,17 @@ MLA decode backends are selected using the standard
| `TOKENSPEED_MLA` | fp16, bf16 | `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | | `TOKENSPEED_MLA` | fp16, bf16 | `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | Any | | `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
### DeepSeek V4 Decode Backends
DeepSeek V4 sparse MLA uses its own decode backends, selected via
`--attention-backend=<BACKEND>` (e.g., `FLASHMLA_SPARSE_DSV4`,
`FLASHINFER_MLA_SPARSE_DSV4`). They share the V4 sparse-index
pipeline (compressor + SWA + indexer, 256-token blocks, head size 512).
The default on NVIDIA GPUs is `FLASHMLA_SPARSE_DSV4`.
| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Non-Causal | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
| ------- | ------ | --------- | ----------- | ---------- | ---- | ---------- | ------ | --------- | --- | --------------- | ------------ |
| `FLASHINFER_MLA_SPARSE_DSV4` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |
| `FLASHMLA_SPARSE_DSV4` | fp16, bf16 | `auto` | 256 | 512 | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |
| `ROCM_FLASHMLA_SPARSE_DSV4` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
+1
View File
@@ -82,6 +82,7 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
| Architecture | Models | CG for Image | CG for Video | | Architecture | Models | CG for Image | CG for Video |
| ------------ | ------ | ------------ | ------------ | | ------------ | ------ | ------------ | ------------ |
| `InternVLChatModel` | `InternVL3.5`, `InternVL3`, `InternVL2.5`, `InternVL2` | ✅︎ | ✅︎ |
| `Qwen2VLForConditionalGeneration` | `Qwen2-VL` | ✅︎ | ✅︎ | | `Qwen2VLForConditionalGeneration` | `Qwen2-VL` | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ | | `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
+17 -14
View File
@@ -3,7 +3,7 @@
Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices. Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices.
!!! tip !!! tip
To get started with quantization, see [LLM Compressor](llm_compressor.md), a library for optimizing models for deployment with vLLM that supports FP8, INT8, INT4, and other quantization formats. To get started with quantization, see [LLM Compressor](llm_compressor/README.md), a library for optimizing models for deployment with vLLM that supports FP8, INT8, INT4, and other quantization formats.
The following are the supported quantization formats for vLLM: The following are the supported quantization formats for vLLM:
@@ -12,9 +12,11 @@ The following are the supported quantization formats for vLLM:
- [GGUF](gguf.md) - [GGUF](gguf.md)
- [GPTQModel](gptqmodel.md) - [GPTQModel](gptqmodel.md)
- [Intel Neural Compressor](inc.md) - [Intel Neural Compressor](inc.md)
- [INT4 W4A16](int4.md) - [LLM Compressor](llm_compressor/README.md)
- [INT8 W8A8](int8.md) - [FP8 W8A8](llm_compressor/fp8.md)
- [FP8 W8A8](fp8.md) - [INT4 W4A16](llm_compressor/int4.md)
- [INT8 W4A8](llm_compressor/int8_w4a8.md)
- [INT8 W8A8](llm_compressor/int8_w8a8.md)
- [NVIDIA Model Optimizer](modelopt.md) - [NVIDIA Model Optimizer](modelopt.md)
- [Online Quantization](online.md) - [Online Quantization](online.md)
- [AMD Quark](quark.md) - [AMD Quark](quark.md)
@@ -46,16 +48,17 @@ th:not(:first-child) {
} }
</style> </style>
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | | Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | Arm CPU |
| ------------------------- | ----- | ------ | ------ | --- | ------ | ------- | --------- | ------- | | ------------------------- | ----- | ------ | ------ | --- | ------ | ------- | --------- | ------- | ------- |
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | | AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ |
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | | GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ |
| Marlin (GPTQ/AWQ/FP8/FP4) | ❌ | ✅︎* | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | | Marlin (GPTQ/AWQ/FP8/FP4) | ❌ | ✅︎* | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | | llm-compressor INT8 (W8A8)| ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | | llm-compressor INT8 (W4A8)| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | | llm-compressor FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | | bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | | DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
- ✅︎ indicates that the quantization method is supported on the specified hardware. - ✅︎ indicates that the quantization method is supported on the specified hardware.
@@ -21,9 +21,17 @@ The FP8 types typically supported in hardware have two distinct representations,
To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library: To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
```bash ```bash
pip install llmcompressor (venv-llm-compressor) pip install llmcompressor
``` ```
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
```bash
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
```
Please use separate environments for vLLM and llm-compressor as they might not work together.
## Quantization Process ## Quantization Process
The quantization process involves three main steps: The quantization process involves three main steps:
@@ -57,36 +65,28 @@ For FP8 quantization, we can recover accuracy with simple RTN quantization. We r
Since simple RTN does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow. Since simple RTN does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow.
??? code ```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
```python # Configure the simple PTQ quantization
from llmcompressor import oneshot recipe = QuantizationModifier(
from llmcompressor.modifiers.quantization import QuantizationModifier targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["lm_head"],
)
# Configure the simple PTQ quantization # Apply the quantization algorithm.
recipe = QuantizationModifier( oneshot(model=model, recipe=recipe)
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["lm_head"],
)
# Apply the quantization algorithm. # Save the model: Meta-Llama-3-8B-Instruct-FP8-Dynamic
oneshot(model=model, recipe=recipe) SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR)
# Save the model: Meta-Llama-3-8B-Instruct-FP8-Dynamic tokenizer.save_pretrained(SAVE_DIR)
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic" ```
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
```
### 3. Evaluating Accuracy ### 3. Evaluating Accuracy
Install `vllm` and `lm-evaluation-harness` for evaluation:
```bash
pip install vllm "lm-eval[api]>=0.4.12"
```
Load and run the model in `vllm`: Load and run the model in `vllm`:
```python ```python
@@ -12,15 +12,17 @@ Please visit the HF collection of [quantized INT4 checkpoints of popular LLMs re
To use INT4 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library: To use INT4 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
```bash ```bash
pip install llmcompressor (venv-llm-compressor) pip install llmcompressor
``` ```
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
```bash ```bash
pip install vllm "lm-eval[api]>=0.4.12" (venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
``` ```
Please use separate environments for vLLM and llm-compressor as they might not work together.
## Quantization Process ## Quantization Process
The quantization process involves four main steps: The quantization process involves four main steps:
@@ -52,55 +54,51 @@ When quantizing weights to INT4, you need sample data to estimate the weight upd
It's best to use calibration data that closely matches your deployment data. It's best to use calibration data that closely matches your deployment data.
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`: For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
??? code ```python
from datasets import load_dataset
```python NUM_CALIBRATION_SAMPLES = 512
from datasets import load_dataset MAX_SEQUENCE_LENGTH = 2048
NUM_CALIBRATION_SAMPLES = 512 # Load and preprocess the dataset
MAX_SEQUENCE_LENGTH = 2048 ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
# Load and preprocess the dataset def preprocess(example):
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft") return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) ds = ds.map(preprocess)
def preprocess(example): def tokenize(sample):
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)} return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(preprocess) ds = ds.map(tokenize, remove_columns=ds.column_names)
```
def tokenize(sample):
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)
```
### 3. Applying Quantization ### 3. Applying Quantization
Now, apply the quantization algorithms: Now, apply the quantization algorithms:
??? code ```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
```python # Configure the quantization algorithms
from llmcompressor import oneshot recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
# Configure the quantization algorithms # Apply quantization
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Apply quantization # Save the compressed model: Meta-Llama-3-8B-Instruct-W4A16-G128
oneshot( SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model=model, model.save_pretrained(SAVE_DIR, save_compressed=True)
dataset=ds, tokenizer.save_pretrained(SAVE_DIR)
recipe=recipe, ```
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A16-G128
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
This process creates a W4A16 model with weights quantized to 4-bit integers. This process creates a W4A16 model with weights quantized to 4-bit integers.
@@ -141,36 +139,34 @@ lm_eval --model vllm \
The following is an example of an expanded quantization recipe you can tune to your own use case: The following is an example of an expanded quantization recipe you can tune to your own use case:
??? code ```python
from compressed_tensors.quantization import (
```python QuantizationArgs,
from compressed_tensors.quantization import ( QuantizationScheme,
QuantizationArgs, QuantizationStrategy,
QuantizationScheme, QuantizationType,
QuantizationStrategy, )
QuantizationType, recipe = GPTQModifier(
) targets="Linear",
recipe = GPTQModifier( config_groups={
targets="Linear", "config_group": QuantizationScheme(
config_groups={ targets=["Linear"],
"config_group": QuantizationScheme( weights=QuantizationArgs(
targets=["Linear"], num_bits=4,
weights=QuantizationArgs( type=QuantizationType.INT,
num_bits=4, strategy=QuantizationStrategy.GROUP,
type=QuantizationType.INT, group_size=128,
strategy=QuantizationStrategy.GROUP, symmetric=True,
group_size=128, dynamic=False,
symmetric=True, actorder="weight",
dynamic=False,
actorder="weight",
),
), ),
}, ),
ignore=["lm_head"], },
update_size=NUM_CALIBRATION_SAMPLES, ignore=["lm_head"],
dampening_frac=0.01, update_size=NUM_CALIBRATION_SAMPLES,
) dampening_frac=0.01,
``` )
```
## Troubleshooting and Support ## Troubleshooting and Support
@@ -0,0 +1,217 @@
# INT8 W4A8
vLLM supports quantizing weights to INT4 and activations to INT8 for memory savings and inference acceleration.
This quantization method is particularly useful for reducing model size while maintaining good performance.
## Prerequisites
To use INT8 W4A8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library.
```bash
(venv-llm-compressor) pip install llmcompressor
```
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
```bash
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
```
Please use separate environments for vLLM and llm-compressor as they might not work together.
## Quantization Process
The quantization process involves four main steps:
1. Loading the model
2. Preparing calibration data
3. Applying quantization
4. Evaluating accuracy in vLLM
### 1. Loading the Model
Load your model and tokenizer using the standard `transformers` AutoModel classes:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```
### 2. Preparing Calibration Data
When quantizing weights to INT4 with GPTQ, you need sample data to calibrate the weight quantization; the INT8 activations in the recipes below are quantized dynamically at runtime and need no calibration.
It's best to use calibration data that closely matches your deployment data.
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
```python
from datasets import load_dataset
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
# Load and preprocess the dataset
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
def preprocess(example):
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.map(preprocess)
def tokenize(sample):
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)
```
### 3. Applying Quantization
Now, apply the quantization algorithms.
The following recipes create W4A8 models (int4 weights, int8 activations). On Arm® CPUs, this is accelerated through [KleidiAI](https://github.com/ARM-software/kleidiai).
Use groupwise for best accuracy, and channelwise for best inference performance.
=== "Groupwise"
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
# Configure the quantization algorithms
recipe = [
GPTQModifier(
targets="Linear",
scheme="W4A8",
ignore=["lm_head"],
dampening_frac=0.01
),
]
# Apply quantization
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A8-G128-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
=== "Channelwise"
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
scheme = {
"targets": ["Linear"],
"weights": {
"num_bits": 4,
"type": QuantizationType.INT,
"strategy": QuantizationStrategy.CHANNEL,
"symmetric": True,
"dynamic": False,
"group_size": None,
},
"input_activations": {
"num_bits": 8,
"type": QuantizationType.INT,
"strategy": QuantizationStrategy.TOKEN,
"dynamic": True,
"symmetric": False,
"observer": None,
},
"output_activations": None,
}
recipe = [
GPTQModifier(
targets="Linear",
config_groups={"group_0": scheme},
ignore=["lm_head"],
dampening_frac=0.01,
),
]
oneshot(
model=model,
dataset=ds,
recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A8-Channelwise-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
### 4. Evaluating Accuracy
=== "Groupwise"
After quantization, you can load and run the model in vLLM:
```python
from vllm import LLM
llm = LLM("./Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token")
```
To evaluate accuracy, you can use `lm_eval`:
```bash
lm_eval --model vllm \
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token",add_bos_token=true \
--tasks gsm8k \
--num_fewshot 5 \
--limit 250 \
--batch_size 'auto'
```
=== "Channelwise"
After quantization, you can load and run the model in vLLM:
```python
from vllm import LLM
llm = LLM("./Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token")
```
To evaluate accuracy, you can use `lm_eval`:
```bash
lm_eval --model vllm \
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token",add_bos_token=true \
--tasks gsm8k \
--num_fewshot 5 \
--limit 250 \
--batch_size 'auto'
```
!!! note
Quantized models can be sensitive to the presence of the `bos` token. Make sure to include the `add_bos_token=true` argument in `--model_args` when running evaluations, as in the commands above.
## Best Practices
- Start with 512 samples for calibration data (increase if accuracy drops)
- Use a sequence length of 2048 as a starting point
- Employ the chat template or instruction template that the model was trained with
- If you've fine-tuned a model, consider using a sample of your training data for calibration
## Troubleshooting and Support
If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository.
@@ -17,15 +17,17 @@ Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs re
To use INT8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library: To use INT8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
```bash ```bash
pip install llmcompressor (venv-llm-compressor) pip install llmcompressor
``` ```
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
```bash ```bash
pip install vllm "lm-eval[api]>=0.4.12" (venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
``` ```
Please use separate environments for vLLM and llm-compressor as they might not work together.
## Quantization Process ## Quantization Process
The quantization process involves four main steps: The quantization process involves four main steps:
@@ -57,26 +59,24 @@ When quantizing activations to INT8, you need sample data to estimate the activa
It's best to use calibration data that closely matches your deployment data. It's best to use calibration data that closely matches your deployment data.
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`: For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
??? code ```python
from datasets import load_dataset
```python NUM_CALIBRATION_SAMPLES = 512
from datasets import load_dataset MAX_SEQUENCE_LENGTH = 2048
NUM_CALIBRATION_SAMPLES = 512 # Load and preprocess the dataset
MAX_SEQUENCE_LENGTH = 2048 ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
# Load and preprocess the dataset def preprocess(example):
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft") return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) ds = ds.map(preprocess)
def preprocess(example): def tokenize(sample):
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)} return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(preprocess) ds = ds.map(tokenize, remove_columns=ds.column_names)
```
def tokenize(sample):
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
ds = ds.map(tokenize, remove_columns=ds.column_names)
```
</details> </details>
@@ -84,33 +84,31 @@ For a general-purpose instruction-tuned model, you can use a dataset like `ultra
Now, apply the quantization algorithms: Now, apply the quantization algorithms:
??? code ```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
```python # Configure the quantization algorithms
from llmcompressor import oneshot recipe = [
from llmcompressor.modifiers.quantization import GPTQModifier SmoothQuantModifier(smoothing_strength=0.8),
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]
# Configure the quantization algorithms # Apply quantization
recipe = [ oneshot(
SmoothQuantModifier(smoothing_strength=0.8), model=model,
GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]), dataset=ds,
] recipe=recipe,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Apply quantization # Save the compressed model: Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token
oneshot( SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
model=model, model.save_pretrained(SAVE_DIR, save_compressed=True)
dataset=ds, tokenizer.save_pretrained(SAVE_DIR)
recipe=recipe, ```
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
# Save the compressed model: Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
This process creates a W8A8 model with weights and activations quantized to 8-bit integers. This process creates a W8A8 model with weights and activations quantized to 8-bit integers.
+2
View File
@@ -569,6 +569,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `GlmOcrForConditionalGeneration` | GLM-OCR | T + I<sup>E+</sup> | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ | | `GlmOcrForConditionalGeneration` | GLM-OCR | T + I<sup>E+</sup> | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ |
| `Granite4VisionForConditionalGeneration` | Granite 4 Vision | T + I<sup>E+</sup> | `ibm-granite/granite-4.1-3b-vision`, etc. | ✅︎ | ✅︎ | | `Granite4VisionForConditionalGeneration` | Granite 4 Vision | T + I<sup>E+</sup> | `ibm-granite/granite-4.1-3b-vision`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
| `GraniteSpeechPlusForConditionalGeneration` | Granite Speech Plus | T + A | `ibm-granite/granite-speech-4.1-2b-plus` | ✅︎ | ✅︎ |
| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | | `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | |
| `HCXVisionV2ForCausalLM` | HyperCLOVAX-SEED-Think-32B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Think-32B` | | | | `HCXVisionV2ForCausalLM` | HyperCLOVAX-SEED-Think-32B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Think-32B` | | |
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎ |
@@ -709,6 +710,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ | | `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-4.0-1b-speech`, `ibm-granite/granite-speech-3.3-2b`, etc. | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-4.0-1b-speech`, `ibm-granite/granite-speech-3.3-2b`, etc. | ✅︎ | ✅︎ |
| `GraniteSpeechPlusForConditionalGeneration` | Granite Speech Plus | `ibm-granite/granite-speech-4.1-2b-plus` | ✅︎ | ✅︎ |
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | ✅︎ | ✅︎ | | `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | ✅︎ | ✅︎ |
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, etc. | | ✅︎ | | `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, etc. | | ✅︎ |
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ | | `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
+7 -1
View File
@@ -24,8 +24,14 @@ echo "Checking pre-commit/pre-run-check status..."
MAX_WAIT=300 MAX_WAIT=300
INTERVAL=60 INTERVAL=60
ELAPSED=0 ELAPSED=0
# Use a GitHub token if provided to raise the API rate limit (60 -> 5000
# requests/hour). Set GITHUB_TOKEN in the Read the Docs environment variables.
CURL_AUTH=()
if [ -n "$GITHUB_TOKEN" ]; then
CURL_AUTH=(-H "Authorization: Bearer $GITHUB_TOKEN")
fi
while :; do while :; do
RAW=$(curl -sS -w "\n%{http_code}" "https://api.github.com/repos/vllm-project/vllm/commits/${READTHEDOCS_GIT_COMMIT_HASH}/check-runs?check_name=pre-run-check&filter=latest") RAW=$(curl -sS "${CURL_AUTH[@]}" -w "\n%{http_code}" "https://api.github.com/repos/vllm-project/vllm/commits/${READTHEDOCS_GIT_COMMIT_HASH}/check-runs?check_name=pre-run-check&filter=latest")
HTTP_CODE=$(printf %s "$RAW" | tail -n1) HTTP_CODE=$(printf %s "$RAW" | tail -n1)
BODY=$(printf %s "$RAW" | sed '$d') BODY=$(printf %s "$RAW" | sed '$d')
if [ "$HTTP_CODE" != "200" ]; then if [ "$HTTP_CODE" != "200" ]; then
@@ -2554,6 +2554,7 @@ MODELS_NEED_VIDEO_METADATA = [
MODELS_SUPPORT_VIT_CUDA_GRAPH = [ MODELS_SUPPORT_VIT_CUDA_GRAPH = [
"internvl_chat",
"qwen2_5_vl", "qwen2_5_vl",
"qwen3_vl", "qwen3_vl",
"qwen3_vl_moe", "qwen3_vl_moe",
+3
View File
@@ -110,6 +110,9 @@ plugins:
redirect_maps: redirect_maps:
features/spec_decode/README.md: features/speculative_decoding/README.md features/spec_decode/README.md: features/speculative_decoding/README.md
features/spec_decode/speculators.md: features/speculative_decoding/speculators.md features/spec_decode/speculators.md: features/speculative_decoding/speculators.md
features/quantization/fp8.md: features/quantization/llm_compressor/fp8.md
features/quantization/int4.md: features/quantization/llm_compressor/int4.md
features/quantization/int8.md: features/quantization/llm_compressor/int8_w8a8.md
serving/openai_compatible_server.md: serving/online_serving/README.md serving/openai_compatible_server.md: serving/online_serving/README.md
markdown_extensions: markdown_extensions:
+1 -1
View File
@@ -38,7 +38,7 @@ pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops # Required for Qwen2-VL. einops # Required for Qwen2-VL.
compressed-tensors == 0.15.0.1 # required for compressed-tensors compressed-tensors == 0.17.0 # required for compressed-tensors
depyf==0.20.0 # required for profiling and debugging with compilation config depyf==0.20.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files watchfiles # required for http server to monitor the updates of TLS files
+2 -2
View File
@@ -18,7 +18,7 @@ tilelang==0.1.9
nvidia-cudnn-frontend>=1.13.0,<1.19.0 nvidia-cudnn-frontend>=1.13.0,<1.19.0
# Required for faster safetensors model loading # Required for faster safetensors model loading
fastsafetensors >= 0.2.2 fastsafetensors >= 0.3.2
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation) # QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
nvidia-cutlass-dsl[cu13]==4.5.2 nvidia-cutlass-dsl[cu13]==4.5.2
@@ -28,4 +28,4 @@ quack-kernels>=0.3.3
tokenspeed-mla==0.1.2 tokenspeed-mla==0.1.2
# Humming kernels for quantization gemm # Humming kernels for quantization gemm
humming-kernels[cu13]==0.1.2 humming-kernels[cu13]==0.1.4
+4 -1
View File
@@ -19,7 +19,10 @@ setuptools-rust>=1.9.0
runai-model-streamer[s3,gcs,azure]==0.15.7 runai-model-streamer[s3,gcs,azure]==0.15.7
conch-triton-kernels==1.2.1 conch-triton-kernels==1.2.1
timm>=1.0.17 timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm # amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py # To be consistent with test_quark.py
amd-quark>=0.8.99 amd-quark>=0.8.99
tilelang==0.1.10 tilelang==0.1.10
# Required for faster safetensors model loading
fastsafetensors >= 0.3.2
+1 -1
View File
@@ -57,7 +57,7 @@ arctic-inference == 0.1.1; platform_machine == "x86_64" # Required for suffix de
numba == 0.65.0 # Required for N-gram speculative decoding numba == 0.65.0 # Required for N-gram speculative decoding
numpy numpy
runai-model-streamer[s3,gcs,azure]==0.15.7 runai-model-streamer[s3,gcs,azure]==0.15.7
fastsafetensors>=0.2.2; platform_machine == "x86_64" # 0.2.2 contains important fixes for multi-GPU mem usage fastsafetensors>=0.3.2
instanttensor>=0.1.5; platform_machine == "x86_64" instanttensor>=0.1.5; platform_machine == "x86_64"
pydantic>=2.12 # 2.11 leads to error on python 3.13 pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0; platform_machine == "x86_64" decord==0.6.0; platform_machine == "x86_64"
+1 -1
View File
@@ -191,7 +191,7 @@ fastparquet==2024.11.0
# via genai-perf # via genai-perf
fastrlock==0.8.2 fastrlock==0.8.2
# via cupy-cuda12x # via cupy-cuda12x
fastsafetensors==0.2.2 fastsafetensors==0.3.2
# via # via
# -c requirements/cuda.txt # -c requirements/cuda.txt
# -r requirements/test/cuda.in # -r requirements/test/cuda.in
+1 -1
View File
@@ -43,6 +43,6 @@ tritonclient>=2.51.0
numba == 0.65.0 # Required for N-gram speculative decoding numba == 0.65.0 # Required for N-gram speculative decoding
numpy numpy
runai-model-streamer[s3,gcs,azure]==0.15.7 runai-model-streamer[s3,gcs,azure]==0.15.7
fastsafetensors>=0.2.2 fastsafetensors>=0.3.2
instanttensor>=0.1.5 instanttensor>=0.1.5
pydantic>=2.12 # 2.11 leads to error on python 3.13 pydantic>=2.12 # 2.11 leads to error on python 3.13
+1 -1
View File
@@ -56,7 +56,7 @@ arctic-inference==0.1.1 # Required for suffix decoding test
numba==0.65.0 # Required for N-gram speculative decoding numba==0.65.0 # Required for N-gram speculative decoding
numpy numpy
runai-model-streamer[s3,gcs,azure]==0.15.7 runai-model-streamer[s3,gcs,azure]==0.15.7
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@0.2.2 # PyPI only ships CUDA wheels fastsafetensors>=0.3.2
instanttensor>=0.1.5 instanttensor>=0.1.5
pydantic>=2.12 # 2.11 leads to error on python 3.13 pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0 decord==0.6.0
+5 -3
View File
@@ -143,7 +143,7 @@ colorful==0.5.8
# via ray # via ray
colorlog==6.10.1 colorlog==6.10.1
# via optuna # via optuna
compressed-tensors==0.15.0.1 compressed-tensors==0.17.0
# via # via
# -c requirements/common.txt # -c requirements/common.txt
# -r requirements/test/../common.txt # -r requirements/test/../common.txt
@@ -240,8 +240,10 @@ fastar==0.10.0
# via fastapi-cloud-cli # via fastapi-cloud-cli
fastparquet==2026.3.0 fastparquet==2026.3.0
# via genai-perf # via genai-perf
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@65d80088fca7a8f567fba30415fbcc80f7d2259c fastsafetensors==0.3.2
# via -r requirements/test/rocm.in # via
# -c requirements/rocm.txt
# -r requirements/test/rocm.in
filelock==3.25.2 filelock==3.25.2
# via # via
# -c requirements/common.txt # -c requirements/common.txt
+1 -1
View File
@@ -1168,7 +1168,7 @@ setup(
"zen": ["zentorch==2.11.0.0"], "zen": ["zentorch==2.11.0.0"],
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"], "bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
"tensorizer": ["tensorizer==2.10.1"], "tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.2.2"], "fastsafetensors": ["fastsafetensors >= 0.3.2"],
"instanttensor": ["instanttensor >= 0.1.5"], "instanttensor": ["instanttensor >= 0.1.5"],
"runai": ["runai-model-streamer[s3,gcs,azure] >= 0.15.7"], "runai": ["runai-model-streamer[s3,gcs,azure] >= 0.15.7"],
"audio": [ "audio": [
@@ -14,6 +14,7 @@ from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
AllReduceFusionPass, AllReduceFusionPass,
RocmAiterAllReduceFusionPass, RocmAiterAllReduceFusionPass,
) )
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.fix_functionalization import ( from vllm.compilation.passes.utility.fix_functionalization import (
FixFunctionalizationPass, FixFunctionalizationPass,
) )
@@ -33,7 +34,7 @@ from vllm.distributed.parallel_state import (
init_distributed_environment, init_distributed_environment,
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
@@ -91,6 +92,49 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceGemmaRMSNormModel(torch.nn.Module):
    """Toy model that interleaves all-reduce with ``GemmaRMSNorm`` layers.

    The forward pass runs four all-reduce + norm stages: the first norm is
    called without a residual, the remaining three are called with one and
    return an updated residual. ``token_num`` and ``dtype`` are accepted by
    the constructor but are not used inside this class.
    """

    def __init__(
        self,
        hidden_size=16,
        token_num=16,
        eps=1e-6,
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        norm_layers = [GemmaRMSNorm(hidden_size, eps) for _ in range(4)]
        # Non-trivial weight (~Gemma range) so (1 + w) exercises the scale path.
        for layer in norm_layers:
            layer.weight.data.normal_(mean=0.0, std=0.1)
        self.norm = norm_layers
        # One projection matrix per residual stage (stages 1..3).
        self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]

    def forward(self, x):
        # avoid having graph input be an arg to a pattern directly
        activated = torch.relu(x)
        reduced = tensor_model_parallel_all_reduce(activated)
        # The first all-reduce output doubles as the initial residual.
        resid = reduced
        hidden = self.norm[0](reduced)

        # Stages 1..3: matmul -> all-reduce -> norm with residual.
        for weight, norm in zip(self.w, self.norm[1:]):
            projected = torch.mm(hidden, weight)
            reduced = tensor_model_parallel_all_reduce(projected)
            hidden, resid = norm(reduced, resid)
        return hidden

    def ops_in_model_before(self):
        """Return the ops expected in the graph before the fusion pass."""
        return [torch.ops.vllm.all_reduce.default]

    def ops_in_model_after(self):
        """Return the ops expected in the graph after the fusion pass."""
        return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym quant_key = kFp8StaticTensorSym
@@ -209,6 +253,15 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
"test_model, enable_quant_fp8_custom_op, use_aiter", "test_model, enable_quant_fp8_custom_op, use_aiter",
[ [
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND), (TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
pytest.param(
TestAllReduceGemmaRMSNormModel,
False,
False,
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Not supported on ROCm platform",
),
),
pytest.param( pytest.param(
TestAllReduceRMSNormStaticQuantFP8Model, TestAllReduceRMSNormStaticQuantFP8Model,
True, True,
@@ -404,4 +457,9 @@ def all_reduce_fusion_pass_on_test_model(
) )
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after()) backend.check_after_ops(model.ops_in_model_after())
if test_model_cls is TestAllReduceGemmaRMSNormModel:
fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
fused_nodes = list(find_op_nodes(fused_op, backend.graph_post_pass))
assert fused_nodes
assert all(n.kwargs.get("weight_bias") == 1.0 for n in fused_nodes)
del all_reduce_fusion_pass del all_reduce_fusion_pass
@@ -0,0 +1,250 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Inductor FALLBACK_ALLOW_LIST patch in env_override.py.
The patch wraps ``torch._inductor.lowering.FALLBACK_ALLOW_LIST`` in a thin
proxy that auto-allows any custom op in the ``vllm::`` or ``vllm_aiter::``
namespaces. This routes those ops through Inductor's fast-path
``make_fallback(target, warn=False, override_decomp=True)`` and avoids the
expensive ``error.operator_str(target, args, kwargs)`` formatting that
recursively stringifies every input ``TensorBox``.
The slow path is what made ``torch.compile`` effectively hang on Kimi-K2.6
TP=8 (deep MoE/TP IR provenance trees). These tests cover both the proxy's
semantics in isolation and the membership-check fast-path that Inductor's
``GraphLowering.call_function`` actually performs, so we can validate the
optimization without needing a full GPU compile.
"""
import time
import pytest
from vllm.env_override import (
_patch_inductor_fallback_allow_list,
_VllmFallbackAllowList,
)
class TestVllmFallbackAllowListProxy:
    """Checks ``_VllmFallbackAllowList`` in isolation, wrapping local sets."""

    def test_vllm_namespace_auto_allowed(self):
        # An empty inner set: membership can only come from the prefix rule.
        allow = _VllmFallbackAllowList(set())
        for op_name in (
            "vllm::all_reduce",
            "vllm::fused_add_rms_norm",
            "vllm::all_reduce.default",
        ):
            assert op_name in allow

    def test_vllm_aiter_namespace_auto_allowed(self):
        allow = _VllmFallbackAllowList(set())
        for op_name in (
            "vllm_aiter::fused_add_rms_norm",
            "vllm_aiter::rocm_aiter_fused_moe",
        ):
            assert op_name in allow

    def test_unknown_namespace_falls_through(self):
        allow = _VllmFallbackAllowList({"torchvision::roi_align"})
        assert "torchvision::roi_align" in allow
        assert "made_up_ns::nonexistent_op" not in allow

    def test_non_string_falls_through_to_inner(self):
        marker = object()
        allow = _VllmFallbackAllowList({marker})
        assert marker in allow
        # A fresh object is not in the inner set, so it must be rejected.
        assert object() not in allow

    def test_prefix_only_match_not_substring(self):
        allow = _VllmFallbackAllowList(set())
        for op_name in ("not_vllm::something", " vllm::space_prefixed"):
            assert op_name not in allow

    def test_standard_entries_preserved(self):
        upstream = {"torchvision::roi_align", "aten::index_add"}
        allow = _VllmFallbackAllowList(upstream)
        assert "torchvision::roi_align" in allow
        assert "aten::index_add" in allow
        assert "aten::__not_present__" not in allow

    def test_add_and_discard_delegate_to_inner(self):
        backing: set[str] = set()
        allow = _VllmFallbackAllowList(backing)

        # Mutations through the proxy must be visible on the wrapped set.
        allow.add("custom::op")
        assert "custom::op" in backing

        allow.discard("custom::op")
        assert "custom::op" not in backing

    def test_iter_len_repr(self):
        upstream = {"torchvision::roi_align", "aten::index_add"}
        allow = _VllmFallbackAllowList(upstream)
        assert set(iter(allow)) == upstream
        assert len(allow) == len(upstream)
        assert "torchvision::roi_align" in repr(allow)

    def test_getattr_delegates_to_inner(self):
        class _Wrapped:
            sentinel = "i_am_inner"

            def some_method(self):
                return 42

        allow = _VllmFallbackAllowList(_Wrapped())
        assert allow.sentinel == "i_am_inner"
        assert allow.some_method() == 42

    def test_sentinel_attribute(self):
        allow = _VllmFallbackAllowList(set())
        assert allow._vllm_patched is True
class TestPatchApplication:
    """Checks that the patch is visible on the live ``torch._inductor`` modules."""

    def test_patch_applied_to_lowering(self):
        import torch._inductor.lowering as _lowering

        is_patched = getattr(_lowering.FALLBACK_ALLOW_LIST, "_vllm_patched", False)
        assert is_patched, (
            "env_override._patch_inductor_fallback_allow_list did not run"
        )

    def test_graph_module_local_binding_rebound(self):
        # ``torch/_inductor/graph.py`` does:
        #     from torch._inductor.lowering import FALLBACK_ALLOW_LIST
        # so the patch has to overwrite the graph module's local binding too,
        # otherwise the fast-path check in GraphLowering.call_function still
        # sees the original (unwrapped) OrderedSet.
        import torch._inductor.graph as _graph
        import torch._inductor.lowering as _lowering

        has_local_binding = hasattr(_graph, "FALLBACK_ALLOW_LIST")
        if not has_local_binding:
            pytest.skip(
                "torch._inductor.graph no longer imports FALLBACK_ALLOW_LIST "
                "as a module-level symbol; nothing to rebind."
            )
        assert _graph.FALLBACK_ALLOW_LIST is _lowering.FALLBACK_ALLOW_LIST

    def test_patch_is_idempotent(self):
        import torch._inductor.lowering as _lowering

        before = _lowering.FALLBACK_ALLOW_LIST
        # Re-applying the patch twice must keep the very same proxy object.
        for _ in range(2):
            _patch_inductor_fallback_allow_list()
        assert _lowering.FALLBACK_ALLOW_LIST is before

    def test_real_vllm_ops_in_real_allow_list(self):
        # End-to-end membership check using the live (already-patched) object.
        import torch._inductor.lowering as _lowering

        live_list = _lowering.FALLBACK_ALLOW_LIST
        for op_name in (
            "vllm::all_reduce",
            "vllm::fused_add_rms_norm",
            "vllm_aiter::fused_add_rms_norm",
        ):
            assert op_name in live_list
class TestInductorFallbackFastPath:
    """Mimics the FALLBACK_ALLOW_LIST test in ``GraphLowering.call_function``.

    The relevant snippet in ``torch/_inductor/graph.py`` is roughly::

        base_name = target.name()
        if base_name not in FALLBACK_ALLOW_LIST:
            log.info(
                "Creating implicit fallback for:\\n%s",
                error.operator_str(target, args, kwargs),
            )
        out = make_fallback(target, ...)

    On a deep MoE/TP graph (Kimi-K2.6 at TP=4/8) ``operator_str`` recurses
    through every input ``TensorBox.__str__`` and ends up taking many minutes
    of CPU per encountered op. The patch makes the membership test succeed
    for ``vllm::*``/``vllm_aiter::*`` ops so that branch is never entered.
    These tests pin that behaviour using only membership checks, without a
    real GPU compile.
    """

    def _simulate_graph_lowering(self, target_names: list[str]):
        """Return, in input order, the names that are NOT in the live
        (patched) FALLBACK_ALLOW_LIST, i.e. those that would have reached
        the slow ``operator_str()`` path.
        """
        import torch._inductor.lowering as _lowering

        live_list = _lowering.FALLBACK_ALLOW_LIST
        return [name for name in target_names if name not in live_list]

    def test_vllm_ops_skip_slow_path(self):
        candidates = [
            "vllm::all_reduce",
            "vllm::fused_add_rms_norm",
            "vllm_aiter::rocm_aiter_fused_moe",
            "vllm_aiter::asm_moe",
        ]
        slow = self._simulate_graph_lowering(candidates)
        assert slow == [], (
            "Patched FALLBACK_ALLOW_LIST must short-circuit for all "
            f"vllm::*/vllm_aiter::* ops; got slow-path hits: {slow}"
        )

    def test_non_vllm_ops_still_hit_slow_path(self):
        # Without the patch this is also what would happen; with the patch
        # the behaviour for non-vllm namespaces must be unchanged.
        foreign_ops = ["my_user_ns::custom_op", "fancy_ns::something_else"]
        slow = self._simulate_graph_lowering(foreign_ops)
        assert "my_user_ns::custom_op" in slow
        assert "fancy_ns::something_else" in slow

    def test_kimi_k2_6_style_op_stream(self):
        """Push a many-layer stream of vllm ops through the membership check.

        Kimi-K2.6 at TP=4 lowers a stream of ``vllm::all_reduce`` +
        ``vllm_aiter::fused_add_rms_norm`` calls (one per residual block)
        plus a handful of fused-MoE ops. Pre-patch every one of these would
        invoke ``operator_str`` and stringify a hundreds-deep IR provenance
        tree; post-patch they must all short-circuit.
        """
        n_layers = 64  # Kimi-K2.6 has ~64 decoder layers per replica
        ops_per_layer = [
            "vllm::all_reduce",
            "vllm_aiter::fused_add_rms_norm",
            "vllm_aiter::rocm_aiter_fused_moe",
        ]
        # Repeat the per-layer triple once per layer.
        op_stream: list[str] = ops_per_layer * n_layers

        began = time.perf_counter()
        slow = self._simulate_graph_lowering(op_stream)
        elapsed_s = time.perf_counter() - began

        assert slow == [], (
            f"Expected all {len(op_stream)} vllm/vllm_aiter ops to take "
            f"the fast path; got {len(slow)} slow-path hits."
        )
        # ``__contains__`` is O(1) per call, so a Kimi-sized stream should
        # complete in well under a second even on a slow runner. The
        # pre-patch slow path took many minutes per op on Kimi-K2.6 TP=8.
        assert elapsed_s < 1.0, (
            f"FALLBACK_ALLOW_LIST membership check is unexpectedly slow: "
            f"{elapsed_s:.3f}s for {len(op_stream)} ops"
        )

    def test_inner_set_membership_still_works_for_standard_ops(self):
        """Non-vllm upstream entries such as ``torchvision::roi_align`` must
        still be reported as members through the proxy."""
        import torch._inductor.lowering as _lowering

        live_list = _lowering.FALLBACK_ALLOW_LIST
        # ``torchvision::roi_align`` has been a member of the upstream
        # FALLBACK_ALLOW_LIST since the original Inductor implementation.
        # If the proxy ever broke pass-through, this would regress.
        # NOTE(review): nothing is asserted past this guard, so the test can
        # only pass or skip; a broken pass-through would surface as a skip.
        roi_align_known = "torchvision::roi_align" in live_list
        if not roi_align_known:
            pytest.skip(
                "Upstream FALLBACK_ALLOW_LIST no longer ships "
                "torchvision::roi_align; nothing to verify."
            )
+21 -9
View File
@@ -277,12 +277,15 @@ def assert_verification_synced(local_ok: bool, msg: str) -> None:
assert bool(ok_tensor.item()), msg assert bool(ok_tensor.item()), msg
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights): def create_eplb_communicator_or_raise(
*, group_coordinator, backend, expert_weights, expert_buffer
):
try: try:
return create_eplb_communicator( return create_eplb_communicator(
group_coordinator=group_coordinator, group_coordinator=group_coordinator,
backend=backend, backend=backend,
expert_weights=expert_weights, expert_weights=expert_weights,
expert_buffer=expert_buffer,
) )
except Exception as exc: except Exception as exc:
raise RuntimeError( raise RuntimeError(
@@ -355,7 +358,8 @@ def _test_async_transfer_layer_without_mtp_worker(
communicator = create_eplb_communicator_or_raise( communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator, group_coordinator=ep_group_coordinator,
backend=eplb_communicator, backend=eplb_communicator,
expert_weights=expert_weights[0], expert_weights=expert_weights,
expert_buffer=expert_buffer,
) )
communicator.set_stream(cuda_stream) communicator.set_stream(cuda_stream)
@@ -368,6 +372,7 @@ def _test_async_transfer_layer_without_mtp_worker(
ep_group=ep_group, ep_group=ep_group,
communicator=communicator, communicator=communicator,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
layer_idx=layer_idx,
) )
cuda_stream.synchronize() cuda_stream.synchronize()
move_from_buffer( move_from_buffer(
@@ -460,10 +465,12 @@ def _test_rearrange_expert_weights_with_redundancy(
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
) )
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator_or_raise( communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator, group_coordinator=ep_group_coordinator,
backend=eplb_communicator, backend=eplb_communicator,
expert_weights=expert_weights[0], expert_weights=expert_weights,
expert_buffer=expert_buffer,
) )
# Execute weight rearrangement # Execute weight rearrangement
@@ -471,9 +478,9 @@ def _test_rearrange_expert_weights_with_redundancy(
old_indices, old_indices,
new_indices, new_indices,
expert_weights, expert_weights,
expert_buffer,
ep_group, ep_group,
is_profile=False, communicator,
communicator=communicator,
) )
# Verify the rearrangement result # Verify the rearrangement result
@@ -593,10 +600,12 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
layer_copy.append(weight.clone()) layer_copy.append(weight.clone())
original_weights.append(layer_copy) original_weights.append(layer_copy)
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator_or_raise( communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator, group_coordinator=ep_group_coordinator,
backend="torch_nccl", backend="torch_nccl",
expert_weights=expert_weights[0], expert_weights=expert_weights,
expert_buffer=expert_buffer,
) )
# Execute rearrangement (should be no change) # Execute rearrangement (should be no change)
@@ -604,9 +613,9 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
indices, indices,
indices, # Same indices indices, # Same indices
expert_weights, expert_weights,
expert_buffer,
ep_group, ep_group,
communicator, communicator,
is_profile=False,
) )
# Verify that the weights have not changed # Verify that the weights have not changed
@@ -726,10 +735,12 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
layer_copy.append(weight.clone()) layer_copy.append(weight.clone())
original_weights.append(layer_copy) original_weights.append(layer_copy)
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator_or_raise( communicator = create_eplb_communicator_or_raise(
group_coordinator=ep_group_coordinator, group_coordinator=ep_group_coordinator,
backend="torch_nccl", backend="torch_nccl",
expert_weights=expert_weights[0], expert_weights=expert_weights,
expert_buffer=expert_buffer,
) )
# Execute profile mode rearrangement # Execute profile mode rearrangement
@@ -737,9 +748,10 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
old_indices, old_indices,
new_indices, new_indices,
expert_weights, expert_weights,
expert_buffer,
ep_group, ep_group,
communicator, communicator,
is_profile=True, # Profile mode is_profile=True,
) )
# In profile mode, the weights should remain unchanged # In profile mode, the weights should remain unchanged
+11 -1
View File
@@ -9,9 +9,11 @@ import pytest
import torch import torch
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
get_eplb_group,
get_tp_group, get_tp_group,
) )
from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -213,12 +215,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
for lidx in range(test_config.num_layers): for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts) shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend="torch_nccl",
expert_weights=rank_expert_weights,
expert_buffer=expert_buffer,
)
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
indices, indices,
shuffled_indices, shuffled_indices,
rank_expert_weights, rank_expert_weights,
expert_buffer,
ep_group, ep_group,
is_profile=False, communicator,
) )
num_local_experts = test_config.num_local_experts num_local_experts = test_config.num_local_experts
@@ -10,11 +10,13 @@ import torch
from tests.kernels.moe.utils import make_test_quant_config from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
get_dp_group, get_dp_group,
get_eplb_group,
) )
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -171,12 +173,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
for lidx in range(test_config.num_layers): for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts) shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend="torch_nccl",
expert_weights=rank_expert_weights,
expert_buffer=expert_buffer,
)
rearrange_expert_weights_inplace( rearrange_expert_weights_inplace(
indices, indices,
shuffled_indices, shuffled_indices,
rank_expert_weights, rank_expert_weights,
expert_buffer,
ep_group, ep_group,
is_profile=False, communicator,
) )
num_global_experts = test_config.num_experts num_global_experts = test_config.num_experts
@@ -6,6 +6,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from vllm import PoolingParams
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
@@ -13,10 +14,13 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.entrypoints.serve.lora.protocol import ( from vllm.entrypoints.serve.lora.protocol import (
LoadLoRAAdapterRequest, LoadLoRAAdapterRequest,
UnloadLoRAAdapterRequest, UnloadLoRAAdapterRequest,
) )
from vllm.exceptions import VLLMNotFoundError
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
@@ -130,3 +134,60 @@ async def test_unload_lora_adapter_not_found():
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.error.type == "NotFoundError" assert response.error.type == "NotFoundError"
assert response.error.code == HTTPStatus.NOT_FOUND assert response.error.code == HTTPStatus.NOT_FOUND
class _ConcretePoolingServing(PoolingServingBase):
    """Minimal concrete subclass used only in these unit tests.

    Both overrides raise ``NotImplementedError``. The tests visible in this
    file only call ``_maybe_get_adapters``, so neither is expected to run.
    NOTE(review): presumably these are the abstract hooks of
    ``PoolingServingBase`` -- confirm against the base class.
    """

    # Class-level prefix; the value is arbitrary for these tests.
    request_id_prefix = "test"

    def get_io_processor(self, request):
        # Stub: fail loudly if a test ever reaches this hook.
        raise NotImplementedError

    def _build_response(self, ctx):
        # Stub: fail loudly if a test ever reaches this hook.
        raise NotImplementedError
def _make_pooling_serving(lora_name: str) -> _ConcretePoolingServing:
    """Build a ``_ConcretePoolingServing`` whose ``models`` is a mock that
    knows exactly one LoRA adapter, registered under ``lora_name``.

    The instance is created without running ``__init__``; only the
    ``models`` attribute is populated.
    """
    adapter = LoRARequest(
        lora_name=lora_name,
        lora_int_id=1,
        lora_path="/path/to/lora",
    )

    models_stub = MagicMock()
    models_stub.lora_requests = {lora_name: adapter}
    # Only the module-level MODEL_NAME counts as the base model.
    models_stub.is_base_model.side_effect = lambda name: name == MODEL_NAME

    instance = object.__new__(_ConcretePoolingServing)
    instance.models = models_stub
    return instance
def _make_pooling_ctx(model_name: str) -> PoolingServeContext:
    """Return a ``PoolingServeContext`` wrapping a mock request.

    ``model_name`` is placed on the mock request's ``model`` attribute; the
    context's own ``model_name`` field is always the module-level
    ``MODEL_NAME``.
    """
    fake_request = MagicMock()
    fake_request.model = model_name

    ctx = PoolingServeContext(
        request=fake_request,
        model_name=MODEL_NAME,
        request_id="test-id",
        pooling_params=PoolingParams(),
    )
    return ctx
def test_pooling_maybe_get_adapters_lora_name_sets_lora_request():
    """Requesting a registered LoRA adapter by name must fill in
    ``ctx.lora_request`` and must not raise."""
    adapter_name = "bot-embed-lora"
    ctx = _make_pooling_ctx(adapter_name)

    _make_pooling_serving(adapter_name)._maybe_get_adapters(ctx)

    resolved = ctx.lora_request
    assert resolved is not None
    assert resolved.lora_name == adapter_name
def test_pooling_maybe_get_adapters_unknown_model_raises():
    """A model name that is neither the base model nor a registered LoRA
    adapter must raise ``VLLMNotFoundError``."""
    ctx = _make_pooling_ctx("unknown-model")
    handler = _make_pooling_serving("some-lora")

    with pytest.raises(VLLMNotFoundError):
        handler._maybe_get_adapters(ctx)
@@ -6,7 +6,7 @@
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from ...utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
# Model name constants used across tests # Model name constants used across tests
MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct" MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct"
@@ -22,7 +22,8 @@ import tempfile
import pytest import pytest
import requests import requests
from ...utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from .conftest import ( from .conftest import (
MODEL_NAME_SMOLLM, MODEL_NAME_SMOLLM,
) )
@@ -4,7 +4,8 @@ import openai # use the official async_client for correctness check
import pytest import pytest
import requests import requests
from ...utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from .conftest import MODEL_NAME_SMOLLM from .conftest import MODEL_NAME_SMOLLM
@@ -12,7 +12,8 @@ import tempfile
import pytest import pytest
import requests import requests
from ...utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from .conftest import ( from .conftest import (
MODEL_NAME_SMOLLM, MODEL_NAME_SMOLLM,
) )
@@ -6,7 +6,8 @@ import openai # use the official client for correctness check
import pytest import pytest
import requests import requests
from ...utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from .conftest import ( from .conftest import (
HEADER_SAGEMAKER_CLOSED_SESSION_ID, HEADER_SAGEMAKER_CLOSED_SESSION_ID,
HEADER_SAGEMAKER_NEW_SESSION_ID, HEADER_SAGEMAKER_NEW_SESSION_ID,
@@ -4,7 +4,7 @@
import pytest import pytest
from vllm.entrypoints.openai.engine.protocol import StreamOptions from vllm.entrypoints.openai.engine.protocol import StreamOptions
from vllm.entrypoints.utils import ( from vllm.entrypoints.serve.utils.api_utils import (
get_max_tokens, get_max_tokens,
sanitize_message, sanitize_message,
should_include_usage, should_include_usage,
@@ -6,7 +6,7 @@ from types import SimpleNamespace
import pytest import pytest
from vllm.entrypoints.openai import fingerprint as fp from vllm.entrypoints.serve.utils import fingerprint as fp
def _cfg(tp=1, pp=1, dp=1, ep=False, digest="a3b21f94deadbeef"): def _cfg(tp=1, pp=1, dp=1, ep=False, digest="a3b21f94deadbeef"):
@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
from vllm.entrypoints.serve.utils.request_logger import RequestLogger
def test_request_logger_log_outputs():
    """Non-streaming ``log_outputs`` emits one info record with the request
    id, text, token ids and finish reason as positional log arguments."""
    # Stand-in for the module logger so the call can be inspected.
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-123",
            outputs="Hello, world!",
            output_token_ids=[1, 2, 3, 4],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        fake_logger.info.assert_called_once()
        # Positional layout: format, request id, mode suffix, text, ids, reason.
        fmt, req_id, _mode, text, token_ids, reason, *_ = (
            fake_logger.info.call_args.args
        )
        assert "Generated response %s%s" in fmt
        assert req_id == "test-123"
        assert text == "Hello, world!"
        assert token_ids == [1, 2, 3, 4]
        assert reason == "stop"
def test_request_logger_log_outputs_streaming_delta():
    """A streaming delta is logged with the `` (streaming delta)`` suffix."""
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-456",
            outputs="Hello",
            output_token_ids=[1],
            finish_reason=None,
            is_streaming=True,
            delta=True,
        )

        fake_logger.info.assert_called_once()
        fmt, req_id, mode, text, token_ids, reason, *_ = (
            fake_logger.info.call_args.args
        )
        assert "Generated response %s%s" in fmt
        assert req_id == "test-456"
        assert mode == " (streaming delta)"
        assert text == "Hello"
        assert token_ids == [1]
        assert reason is None
def test_request_logger_log_outputs_streaming_complete():
    """A finished stream is logged with the `` (streaming complete)`` suffix."""
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-789",
            outputs="Complete response",
            output_token_ids=[1, 2, 3],
            finish_reason="length",
            is_streaming=True,
            delta=False,
        )

        fake_logger.info.assert_called_once()
        fmt, req_id, mode, text, token_ids, reason, *_ = (
            fake_logger.info.call_args.args
        )
        assert "Generated response %s%s" in fmt
        assert req_id == "test-789"
        assert mode == " (streaming complete)"
        assert text == "Complete response"
        assert token_ids == [1, 2, 3]
        assert reason == "length"
def test_request_logger_log_outputs_with_truncation():
    """With ``max_log_len=10`` both the logged text and the logged token ids
    are cut to their first 10 items."""
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", fake_logger):
        limit = 10
        full_text = "This is a very long output that should be truncated"
        full_ids = list(range(20))  # 20 tokens

        RequestLogger(max_log_len=limit).log_outputs(
            request_id="test-truncate",
            outputs=full_text,
            output_token_ids=full_ids,
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        fake_logger.info.assert_called_once()
        positional = fake_logger.info.call_args.args

        # Text is truncated to the first 10 characters.
        text_logged = positional[3]
        assert text_logged == "This is a "
        assert len(text_logged) == 10

        # Token ids are truncated to the first 10 entries.
        ids_logged = positional[4]
        assert ids_logged == list(range(10))
        assert len(ids_logged) == 10
def test_request_logger_log_outputs_none_values():
    """``output_token_ids=None`` is passed through to the log call as None."""
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-none",
            outputs="Test output",
            output_token_ids=None,
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        fake_logger.info.assert_called_once()
        fmt, req_id, _mode, text, token_ids, reason, *_ = (
            fake_logger.info.call_args.args
        )
        assert "Generated response %s%s" in fmt
        assert req_id == "test-none"
        assert text == "Test output"
        assert token_ids is None
        assert reason == "stop"
def test_request_logger_log_outputs_empty_output():
    """Check that log_outputs logs an empty string and empty token list as-is."""
    logger_stub = MagicMock()
    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        RequestLogger(max_log_len=5).log_outputs(
            request_id="test-empty",
            outputs="",
            output_token_ids=[],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        logger_stub.info.assert_called_once()
        # Slot 2 (the streaming suffix) is not checked by this test.
        fmt, request_id, _suffix, text, token_ids, finish_reason, *_ = (
            logger_stub.info.call_args.args
        )
        assert "Generated response %s%s" in fmt
        assert request_id == "test-empty"
        assert text == ""
        assert token_ids == []
        assert finish_reason == "stop"
def test_request_logger_log_outputs_integration():
    """Check that log_inputs and log_outputs each emit one info record."""
    logger_stub = MagicMock()
    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        request_logger = RequestLogger(max_log_len=None)

        # Log the request first, then its response, on the same logger.
        request_logger.log_inputs(
            request_id="test-integration",
            prompt="Test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_embeds=None,
            params=None,
            lora_request=None,
        )
        request_logger.log_outputs(
            request_id="test-integration",
            outputs="Test output",
            output_token_ids=[4, 5, 6],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        # Exactly two info records: the input one followed by the output one.
        assert logger_stub.info.call_count == 2
        input_args, output_args = (
            record.args for record in logger_stub.info.call_args_list
        )

        assert "Received request %s" in input_args[0]
        assert input_args[1] == "test-integration"
        assert "Generated response %s%s" in output_args[0]
        assert output_args[1] == "test-integration"
def test_streaming_complete_logs_full_text_content():
    """Check that a completed stream logs the accumulated text itself,
    rather than a token-count placeholder."""
    expected_text = "This is a complete response from streaming"
    logger_stub = MagicMock()
    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-streaming-full-text",
            outputs=expected_text,
            output_token_ids=None,
            finish_reason="streaming_complete",
            is_streaming=True,
            delta=False,
        )

        logger_stub.info.assert_called_once()
        positional = logger_stub.info.call_args.args

        # The logged output must be the verbatim text.
        text_arg = positional[3]
        assert text_arg == expected_text
        assert "tokens>" not in text_arg
        assert "streaming_complete" not in text_arg

        # The remaining slots identify the request and how it finished.
        assert positional[1] == "test-streaming-full-text"
        assert positional[2] == " (streaming complete)"
        assert positional[5] == "streaming_complete"
@@ -7,7 +7,7 @@ from ssl import SSLContext
import pytest import pytest
from vllm.entrypoints.ssl import SSLCertRefresher from vllm.entrypoints.serve.utils.ssl import SSLCertRefresher
class MockSSLContext(SSLContext): class MockSSLContext(SSLContext):
@@ -2,4 +2,5 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
accuracy_threshold: 0.72 accuracy_threshold: 0.72
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
rocm_request_timeout_seconds: 1800
server_args: "--enforce-eager --max-model-len 4096" server_args: "--enforce-eager --max-model-len 4096"
@@ -2,4 +2,5 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
accuracy_threshold: 0.45 accuracy_threshold: 0.45
num_questions: 1319 num_questions: 1319
num_fewshot: 5 num_fewshot: 5
rocm_request_timeout_seconds: 1800
server_args: "--enforce-eager --max-model-len 4096" server_args: "--enforce-eager --max-model-len 4096"
+4 -4
View File
@@ -106,7 +106,7 @@ async def call_vllm_api(
completion_tokens = result.get("usage", {}).get("completion_tokens", 0) completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
return text, completion_tokens return text, completion_tokens
except Exception as e: except Exception as e:
print(f"Error calling vLLM API: {e}") print(f"Error calling vLLM API ({type(e).__name__}): {e}")
return "", 0 return "", 0
@@ -177,6 +177,7 @@ def evaluate_gsm8k(
port: int = 8000, port: int = 8000,
temperature: float = 0.0, temperature: float = 0.0,
seed: int | None = 42, seed: int | None = 42,
request_timeout_seconds: float = 600,
) -> dict[str, float | int]: ) -> dict[str, float | int]:
""" """
Evaluate GSM8K accuracy using vLLM serve endpoint. Evaluate GSM8K accuracy using vLLM serve endpoint.
@@ -205,9 +206,8 @@ def evaluate_gsm8k(
output_tokens[i] = tokens output_tokens[i] = tokens
return answer, tokens return answer, tokens
async with aiohttp.ClientSession( timeout = aiohttp.ClientTimeout(total=request_timeout_seconds)
timeout=aiohttp.ClientTimeout(total=600) async with aiohttp.ClientSession(timeout=timeout) as session:
) as session:
tasks = [get_answer(session, i) for i in range(num_questions)] tasks = [get_answer(session, i) for i in range(num_questions)]
await tqdm.gather(*tasks, desc="Evaluating") await tqdm.gather(*tasks, desc="Evaluating")
@@ -39,11 +39,18 @@ def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
host = f"http://{host}" host = f"http://{host}"
# Run GSM8K evaluation # Run GSM8K evaluation
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
if current_platform.is_rocm():
request_timeout_seconds = eval_config.get(
"rocm_request_timeout_seconds", request_timeout_seconds
)
results = evaluate_gsm8k( results = evaluate_gsm8k(
num_questions=eval_config["num_questions"], num_questions=eval_config["num_questions"],
num_shots=eval_config["num_fewshot"], num_shots=eval_config["num_fewshot"],
host=host, host=host,
port=port, port=port,
request_timeout_seconds=request_timeout_seconds,
) )
return results return results
@@ -90,6 +97,12 @@ def test_gsm8k_correctness(config_filename):
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}") print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
print(f"Number of questions: {eval_config['num_questions']}") print(f"Number of questions: {eval_config['num_questions']}")
print(f"Number of few-shot examples: {eval_config['num_fewshot']}") print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
if current_platform.is_rocm():
request_timeout_seconds = eval_config.get(
"rocm_request_timeout_seconds", request_timeout_seconds
)
print(f"Request timeout: {request_timeout_seconds}s")
print(f"Server args: {' '.join(server_args)}") print(f"Server args: {' '.join(server_args)}")
print(f"Environment variables: {env_dict}") print(f"Environment variables: {env_dict}")
@@ -0,0 +1,339 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""ROCm kernel correctness tests for AITER unified attention.
Compares ``aiter.ops.triton.unified_attention`` against ``ref_paged_attn`` under
decode, prefill, and mixed batches with varied shapes.
"""
from typing import Any, Literal
import pytest
import torch
from tests.kernels.attention.test_triton_unified_attention import ref_paged_attn
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
# Off ROCm the MI3xx probe is never imported; the flag simply stays True.
if current_platform.is_rocm():
    from vllm.platforms.rocm import on_mi3xx

    _SKIP_NON_MI3XX = not on_mi3xx()
else:
    _SKIP_NON_MI3XX = True

# Module-wide gating: ROCm only, and additionally only where on_mi3xx() holds.
pytestmark = [
    pytest.mark.skipif(not current_platform.is_rocm(), reason="ROCm-specific tests"),
    pytest.mark.skipif(_SKIP_NON_MI3XX, reason="MI300/MI350 ROCm only"),
]
# Head counts used for every query / KV-cache tensor built in this module.
NUM_Q_HEADS = 8
NUM_KV_HEADS = 8

# Parametrization axes for the tests below.
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 64]
DTYPES = [torch.bfloat16, torch.float16]
FP8_DTYPE = current_platform.fp8_dtype()

# (query_len, kv_len) per sequence
MIXED_SEQ_LENS = [
    [(1, 128), (5, 18), (129, 463)],
    [(10, 256), (5, 64), (32, 128)],
    [(1, 1024), (5, 18), (129, 1328)],
]
# Every sequence here has query_len == 1.
DECODE_SEQ_LENS = [
    [(1, 128), (1, 256), (1, 384), (1, 512)],
    [(1, 1024), (1, 1536), (1, 2048)],
]
# Every sequence here has query_len > 1.
PREFILL_SEQ_LENS = [
    [(256, 256), (128, 512)],
    [(64, 128), (32, 256), (16, 512)],
    [(256, 1024), (128, 2048)],
]

# Comparison tolerances: the default pair is used unless a test overrides it;
# the FP8 test passes the looser FP8 pair.
DEFAULT_ATOL, DEFAULT_RTOL = 1.5e-2, 1e-2
FP8_ATOL, FP8_RTOL = 1.5e-1, 1.5e-1

# Non-unity scale so q_descale handling is exercised explicitly.
Q_SCALE = 0.75
# K/V scales applied only when the KV cache is stored in FP8.
K_SCALE, V_SCALE = 0.5, 0.25

# Which tensors are quantized to FP8 in a given FP8 test case.
Fp8Variant = Literal["fp8_kv", "fp8_query", "fp8_query_kv"]
FP8_VARIANTS = [
    pytest.param("fp8_kv", id="fp8_kv"),
    pytest.param("fp8_query", id="fp8_query"),
    pytest.param("fp8_query_kv", id="fp8_query_kv"),
]
# Subset of the mixed / decode / prefill shapes reused by the FP8 test.
FP8_SEQ_LENS = [
    MIXED_SEQ_LENS[0],
    DECODE_SEQ_LENS[0],
    DECODE_SEQ_LENS[1],
    PREFILL_SEQ_LENS[0],
    PREFILL_SEQ_LENS[2],
]
def _require_aiter() -> None:
    """Skip the calling test unless AITER is reported as found and supported."""
    # Imported lazily, inside the helper, rather than at module import time.
    from vllm._aiter_ops import is_aiter_found_and_supported

    if is_aiter_found_and_supported():
        return
    pytest.skip("aiter is required on supported ROCm hardware for this test")
def _make_case(
    *,
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
    num_blocks: int = 2048,
    kv_cache_dtype: torch.dtype | None = None,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
    q_dtype: torch.dtype | None = None,
    q_scale: float = Q_SCALE,
) -> dict[str, Any]:
    """Build random inputs for one paged-attention comparison.

    Args:
        seq_lens: One ``(query_len, kv_len)`` pair per sequence.
        head_size: Size of each attention head.
        block_size: Number of token slots per KV-cache block.
        dtype: Dtype of the query (and of the KV cache when it is not
            separately quantized).
        num_blocks: Number of blocks allocated in the KV cache.
        kv_cache_dtype: When given, the caches are drawn in the default
            dtype, clamped to [-1, 1] and cast to this dtype.
        k_scale: Value filling the key descale tensor.
        v_scale: Value filling the value descale tensor.
        q_dtype: When given, the kernel receives ``query / q_scale`` cast to
            this dtype together with a scalar ``q_descale``.
        q_scale: Query scale, used only when ``q_dtype`` is given.

    Returns:
        A dict holding the tensors and scalars consumed by the kernel runner
        and the reference implementation.
    """
    # NOTE: this changes the process-wide default device and is not undone.
    torch.set_default_device("cuda")

    query_lens = [q_len for q_len, _ in seq_lens]
    kv_lens = [kv_len for _, kv_len in seq_lens]
    num_seqs = len(seq_lens)
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    # Random draws happen in a fixed order: query, key cache, value cache,
    # block tables.
    query = torch.randn(sum(query_lens), NUM_Q_HEADS, head_size, dtype=dtype)

    cache_shape = (num_blocks, block_size, NUM_KV_HEADS, head_size)
    if kv_cache_dtype is not None:
        # Keep values in [-1, 1] before casting to the cache dtype.
        key_cache = torch.randn(*cache_shape).clamp(-1.0, 1.0).to(kv_cache_dtype)
        value_cache = torch.randn(*cache_shape).clamp(-1.0, 1.0).to(kv_cache_dtype)
    else:
        key_cache = torch.randn(*cache_shape, dtype=dtype)
        value_cache = torch.randn_like(key_cache)

    # Prefix sums of the query lengths, starting at 0.
    cu_seqlens_q = torch.tensor([0, *query_lens], dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    seq_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    # Enough blocks per sequence to hold the longest KV length (ceil division).
    blocks_per_seq = -(-max_kv_len // block_size)
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, blocks_per_seq), dtype=torch.int32
    )

    descale_shape = (num_seqs, NUM_KV_HEADS)
    k_descale = torch.full(descale_shape, k_scale, dtype=torch.float32, device="cuda")
    v_descale = torch.full(descale_shape, v_scale, dtype=torch.float32, device="cuda")

    if q_dtype is None:
        kernel_query, q_descale = query, None
    else:
        q_descale = torch.tensor(q_scale, dtype=torch.float32, device="cuda")
        kernel_query = (query / q_scale).to(q_dtype)

    return {
        "query": query,
        "kernel_query": kernel_query,
        "key_cache": key_cache,
        "value_cache": value_cache,
        "block_tables": block_tables,
        "query_lens": query_lens,
        "kv_lens": kv_lens,
        "seq_lens_tensor": seq_lens_tensor,
        "cu_seqlens_q": cu_seqlens_q,
        "q_descale": q_descale,
        "k_descale": k_descale,
        "v_descale": v_descale,
        "scale": scale,
        "max_query_len": max_query_len,
        "max_kv_len": max_kv_len,
        "query_dtype": dtype,
        "k_scale": k_scale,
        "v_scale": v_scale,
    }
def _make_fp8_case(
    *,
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    variant: Fp8Variant,
) -> dict[str, Any]:
    """Build a bf16 case with the KV cache, the query, or both in FP8.

    ``variant`` selects which side is quantized; the side that is not
    quantized keeps unit scales and no FP8 dtype.
    """
    if variant in ("fp8_kv", "fp8_query_kv"):
        kv_cache_dtype, k_scale, v_scale = FP8_DTYPE, K_SCALE, V_SCALE
    else:
        kv_cache_dtype, k_scale, v_scale = None, 1.0, 1.0

    if variant in ("fp8_query", "fp8_query_kv"):
        q_dtype = FP8_DTYPE
    else:
        q_dtype = None

    return _make_case(
        seq_lens=seq_lens,
        head_size=head_size,
        block_size=block_size,
        dtype=torch.bfloat16,
        kv_cache_dtype=kv_cache_dtype,
        k_scale=k_scale,
        v_scale=v_scale,
        q_dtype=q_dtype,
    )
def _run_aiter_unified_attention(case: dict[str, Any]) -> torch.Tensor:
    """Run AITER ``unified_attention`` on ``case`` and return its output.

    The output buffer takes the shape and dtype of the unquantized query, so
    the result is high precision even when the kernel query is FP8.
    """
    from aiter.ops.triton.unified_attention import unified_attention

    # Kernel writes high-precision output even when Q is FP8 (matches vLLM usage).
    result = torch.empty_like(case["query"])
    kernel_kwargs = {
        "q": case["kernel_query"],
        "k": case["key_cache"],
        "v": case["value_cache"],
        "out": result,
        "cu_seqlens_q": case["cu_seqlens_q"],
        "max_seqlen_q": case["max_query_len"],
        "seqused_k": case["seq_lens_tensor"],
        "max_seqlen_k": case["max_kv_len"],
        "softmax_scale": case["scale"],
        "causal": True,
        "alibi_slopes": None,
        "window_size": (-1, -1),
        "block_table": case["block_tables"],
        "softcap": 0,
        "q_descale": case["q_descale"],
        "k_descale": case["k_descale"],
        "v_descale": case["v_descale"],
        "sinks": None,
        "output_scale": None,
    }
    unified_attention(**kernel_kwargs)
    return result
def _ref_output(case: dict[str, Any]) -> torch.Tensor:
    """Compute the ``ref_paged_attn`` result for ``case``.

    When the key cache dtype differs from the query dtype, both caches are
    cast to the query dtype and multiplied by their scales first.
    """
    ref_dtype = case["query_dtype"]
    keys, values = case["key_cache"], case["value_cache"]
    if keys.dtype != ref_dtype:
        keys = keys.to(ref_dtype) * case["k_scale"]
        values = values.to(ref_dtype) * case["v_scale"]

    return ref_paged_attn(
        query=case["query"],
        key_cache=keys,
        value_cache=values,
        query_lens=case["query_lens"],
        kv_lens=case["kv_lens"],
        block_tables=case["block_tables"],
        scale=case["scale"],
    )
def _assert_matches_reference(
    case: dict[str, Any],
    *,
    atol: float = DEFAULT_ATOL,
    rtol: float = DEFAULT_RTOL,
) -> None:
    """Assert that the AITER kernel output is close to the reference output.

    The kernel runs first, then the reference; ``atol`` and ``rtol`` are
    forwarded to ``torch.testing.assert_close``.
    """
    actual = _run_aiter_unified_attention(case)
    expected = _ref_output(case)
    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
@pytest.mark.parametrize("seq_lens", MIXED_SEQ_LENS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_aiter_unified_attn_mixed_batch(
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
) -> None:
    """Batches mixing decode and prefill sequences, in native dtypes."""
    _require_aiter()
    # Seed before building the case so the random inputs are reproducible.
    set_random_seed(0)
    _assert_matches_reference(
        _make_case(
            seq_lens=seq_lens,
            head_size=head_size,
            block_size=block_size,
            dtype=dtype,
        )
    )
@pytest.mark.parametrize("seq_lens", DECODE_SEQ_LENS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_aiter_unified_attn_decode(
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
) -> None:
    """Decode-only batches (one query token per sequence), native dtypes."""
    _require_aiter()
    # Seed before building the case so the random inputs are reproducible.
    set_random_seed(0)
    _assert_matches_reference(
        _make_case(
            seq_lens=seq_lens,
            head_size=head_size,
            block_size=block_size,
            dtype=dtype,
        )
    )
@pytest.mark.parametrize("seq_lens", PREFILL_SEQ_LENS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_aiter_unified_attn_prefill(
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
) -> None:
    """Prefill-only batches (query_len > 1 everywhere), always in bf16."""
    _require_aiter()
    # Seed before building the case so the random inputs are reproducible.
    set_random_seed(0)
    _assert_matches_reference(
        _make_case(
            seq_lens=seq_lens,
            head_size=head_size,
            block_size=block_size,
            dtype=torch.bfloat16,
        )
    )
@pytest.mark.skipif(
    not current_platform.supports_fp8(),
    reason="FP8 not supported on this hardware",
)
@pytest.mark.parametrize("variant", FP8_VARIANTS)
@pytest.mark.parametrize("seq_lens", FP8_SEQ_LENS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16, 64])
@torch.inference_mode()
def test_aiter_unified_attn_fp8(
    variant: Fp8Variant,
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
) -> None:
    """FP8 on the KV cache, the query, or both, checked against a bf16
    reference with the looser FP8 tolerances."""
    _require_aiter()
    # Seed before building the case so the random inputs are reproducible.
    set_random_seed(0)
    fp8_case = _make_fp8_case(
        seq_lens=seq_lens,
        head_size=head_size,
        block_size=block_size,
        variant=variant,
    )
    _assert_matches_reference(fp8_case, atol=FP8_ATOL, rtol=FP8_RTOL)
+13 -5
View File
@@ -205,7 +205,10 @@ def run_with_expert_maps(
w2 = kwargs["w2"] w2 = kwargs["w2"]
a = kwargs["hidden_states"] a = kwargs["hidden_states"]
moe_config = make_dummy_moe_config( moe_config = make_dummy_moe_config(
num_experts=w2.shape[0], max_num_tokens=kwargs.get("hidden_states").shape[0],
experts_per_token=kwargs.get("topk_ids").shape[1],
num_experts=num_experts,
num_local_experts=num_local_experts,
hidden_dim=w2.shape[1], hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2], intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype, in_dtype=a.dtype,
@@ -258,23 +261,27 @@ def run_8_bit(
a1_scale=None, a1_scale=None,
) )
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts
kwargs = { kwargs = {
"hidden_states": moe_tensors.a, "hidden_states": moe_tensors.a,
"w1": moe_tensors.w1_q, # type: ignore[union-attr] "w1": moe_tensors.w1_q, # type: ignore[union-attr]
"w2": moe_tensors.w2_q, # type: ignore[union-attr] "w2": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights, "topk_weights": topk_weights,
"topk_ids": topk_ids, "topk_ids": topk_ids,
"global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr] "global_num_experts": num_experts,
"activation": MoEActivation.SILU, "activation": MoEActivation.SILU,
"expert_map": None, "expert_map": None,
"apply_router_weight_on_input": False, "apply_router_weight_on_input": False,
} }
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep: if not with_ep:
moe_config = make_dummy_moe_config( moe_config = make_dummy_moe_config(
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr] max_num_tokens=moe_tensors.a.shape[0],
experts_per_token=topk_ids.shape[1],
num_experts=num_experts,
num_local_experts=num_local_experts,
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr] hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr] intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
in_dtype=moe_tensors.a.dtype, in_dtype=moe_tensors.a.dtype,
@@ -581,6 +588,7 @@ def test_run_cutlass_moe_fp8(
per_out_channel, per_out_channel,
False, False,
topk_weights, topk_weights,
None,
) )
workspace13.random_() workspace13.random_()
+4 -1
View File
@@ -1287,10 +1287,12 @@ def _test_body_eplb(
expert_weights = [list(eplb_moe_layer.get_expert_weights())] expert_weights = [list(eplb_moe_layer.get_expert_weights())]
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
communicator = create_eplb_communicator( communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(), group_coordinator=get_eplb_group(),
backend=vllm_config.parallel_config.eplb_config.communicator, backend=vllm_config.parallel_config.eplb_config.communicator,
expert_weights=expert_weights[0], expert_weights=expert_weights,
expert_buffer=expert_buffer,
) )
# Rearrange expert weights across EP ranks # Rearrange expert weights across EP ranks
@@ -1298,6 +1300,7 @@ def _test_body_eplb(
old_global_expert_indices=initial_indices.unsqueeze(0), old_global_expert_indices=initial_indices.unsqueeze(0),
new_global_expert_indices=shuffled_indices.unsqueeze(0), new_global_expert_indices=shuffled_indices.unsqueeze(0),
expert_weights=expert_weights, expert_weights=expert_weights,
expert_buffer=expert_buffer,
ep_group=cpu_group, ep_group=cpu_group,
communicator=communicator, communicator=communicator,
) )
+6 -2
View File
@@ -49,10 +49,12 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
def make_dummy_moe_config( def make_dummy_moe_config(
num_experts: int = 1, num_experts: int = 1,
num_local_experts: int | None = None,
experts_per_token: int = 1, experts_per_token: int = 1,
hidden_dim: int = 1, hidden_dim: int = 1,
intermediate_size_per_partition: int = 1, intermediate_size_per_partition: int = 1,
in_dtype: torch.dtype = torch.bfloat16, in_dtype: torch.dtype = torch.bfloat16,
max_num_tokens: int = 512,
) -> FusedMoEConfig: ) -> FusedMoEConfig:
""" """
This is a dummy config for the mk constructor interface This is a dummy config for the mk constructor interface
@@ -66,14 +68,16 @@ def make_dummy_moe_config(
experts_per_token=experts_per_token, experts_per_token=experts_per_token,
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
intermediate_size_per_partition=intermediate_size_per_partition, intermediate_size_per_partition=intermediate_size_per_partition,
num_local_experts=num_experts, num_local_experts=num_local_experts
if num_local_experts is not None
else num_experts,
num_logical_experts=num_experts, num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation=MoEActivation.SILU, activation=MoEActivation.SILU,
in_dtype=in_dtype, in_dtype=in_dtype,
device="cuda", device="cuda",
routing_method=RoutingMethodType.TopK, routing_method=RoutingMethodType.TopK,
max_num_tokens=512, max_num_tokens=max_num_tokens,
) )
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Triton dequant-gather kernel used by
``CompressedTensorsEmbeddingWNA16Int`` (quantized embedding lookup)."""
import pytest
import torch
from compressed_tensors.compressors.pack_quantized.helpers import unpack_from_int32
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding import ( # noqa: E501
_dequant_gather_triton,
)
from vllm.platforms import current_platform
def _dequant_gather_torch(
    ids: torch.Tensor,
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    hidden: int,
    num_bits: int,
) -> torch.Tensor:
    """Reference path: gather rows by id, unpack from int32, apply scales.

    A single scale column is broadcast across the whole row; otherwise the
    row is split into as many equal groups as there are scale columns.
    """
    num_rows = ids.shape[0]
    scales = weight_scale[ids]
    unpacked = unpack_from_int32(
        weight_packed[ids], num_bits, torch.Size([num_rows, hidden])
    ).to(scales.dtype)

    num_groups = scales.shape[1]
    if num_groups == 1:
        return unpacked * scales

    # One scale per contiguous group of hidden // num_groups values.
    grouped = unpacked.view(num_rows, num_groups, hidden // num_groups)
    return (grouped * scales.unsqueeze(-1)).view(num_rows, hidden)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="Triton dequant kernel requires CUDA"
)
@pytest.mark.parametrize("num_bits", [2, 4, 8])
@pytest.mark.parametrize("group_size", [0, 256])  # 0 -> channel
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("num_ids", [1, 17, 4096])
def test_dequant_gather(num_bits, group_size, dtype, num_ids):
    """Compare the Triton dequant-gather output with the PyTorch reference."""
    torch.manual_seed(0)
    vocab, hidden = 1000, 2048
    packed_cols = hidden // (32 // num_bits)
    scale_cols = hidden // group_size if group_size != 0 else 1

    # Random full-range int32 packed weights (covers the sign bit -> exercises the
    # arithmetic-shift + mask unpack path).
    weight_packed = torch.randint(
        -(2**31),
        2**31,
        (vocab, packed_cols),
        dtype=torch.int32,
        device="cuda",
    )
    # The 0.01 offset keeps every scale strictly positive.
    weight_scale = torch.rand(vocab, scale_cols, dtype=dtype, device="cuda") + 0.01
    ids = torch.randint(0, vocab, (num_ids,), dtype=torch.long, device="cuda")

    actual = _dequant_gather_triton(ids, weight_packed, weight_scale, hidden, num_bits)
    expected = _dequant_gather_torch(
        ids, weight_packed, weight_scale, hidden, num_bits
    )

    assert actual.shape == (num_ids, hidden)
    assert actual.dtype == dtype
    torch.testing.assert_close(actual, expected)
+149
View File
@@ -468,6 +468,7 @@ def _reference_kv_compress_norm_rope(
use_fp4: bool = False, use_fp4: bool = False,
rms_eps: float = 1e-6, rms_eps: float = 1e-6,
fp8_max: float = 448.0, fp8_max: float = 448.0,
return_full_cache: bool = False,
): ):
"""Compress → RMSNorm → GPT-J RoPE → quantize. """Compress → RMSNorm → GPT-J RoPE → quantize.
@@ -521,6 +522,12 @@ def _reference_kv_compress_norm_rope(
results.append(torch.cat([nope, rope]).to(state_cache.dtype)) results.append(torch.cat([nope, rope]).to(state_cache.dtype))
result = torch.stack(results) result = torch.stack(results)
if return_full_cache:
# Contiguous 512-wide bf16 row (nope unrotated + rope rotated), matching
# the FlashInfer full-cache layout before any per-tensor fp8 quant. The
# kernel rounds the fp32 result to bf16 once at the store.
return result.to(torch.bfloat16)
if use_fp4: if use_fp4:
return quantize_to_mxfp4(result) return quantize_to_mxfp4(result)
else: else:
@@ -667,3 +674,145 @@ def test_fused_kv_insert_indexer(num_tokens: int, kv_block_size: int, use_fp4: b
assert torch.equal(actual_scale, scale[i : i + 1]), ( assert torch.equal(actual_scale, scale[i : i + 1]), (
f"token {i}: scale {actual_scale.item()} != {scale[i].item()}" f"token {i}: scale {actual_scale.item()} != {scale[i].item()}"
) )
@pytest.mark.parametrize("compress_ratio", [4, 128])
@pytest.mark.parametrize("store_fp8", [False, True])
def test_cutedsl_full_cache_store(compress_ratio: int, store_fp8: bool):
    """Full-cache store parity for the CuTeDSL compressor kernels at head=512.

    Ratio 4 runs the fused kernel and any other ratio runs the split kernel.
    Each writes a contiguous bf16 or per-tensor fp8 row into ``k_cache``,
    which is then compared with the PyTorch reference.
    """
    pytest.importorskip("cutlass")
    from vllm.models.deepseek_v4.nvidia.ops.sparse_attn_compress_cutedsl import (
        fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
        split_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
    )

    head_dim, rope_dim = 512, 64
    rms_eps, fp8_max = 1e-6, 448.0
    # C128 compress (Block8 kernel) requires state-cache block_size=8; C4 uses 16.
    state_block_size = 8 if compress_ratio == 128 else 16
    kv_block_size = 64
    cuda = "cuda"

    torch.manual_seed(7)
    overlap = 1 if compress_ratio == 4 else 0
    coff = 1 + overlap
    num_tokens = 8
    last_position = compress_ratio * num_tokens
    num_pages = (last_position - 1) // state_block_size + 2

    # The production CompressorStateCache is fp32.
    state_cache = torch.randn(
        num_pages,
        state_block_size,
        2 * coff * head_dim,
        dtype=torch.float32,
        device=cuda,
    )
    block_table = torch.arange(num_pages, dtype=torch.int32, device=cuda).unsqueeze(0)
    token_to_req = torch.zeros(num_tokens, dtype=torch.int32, device=cuda)
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=cuda)
    # One position per output token: the last position of each window.
    positions = torch.arange(
        compress_ratio - 1,
        last_position,
        compress_ratio,
        dtype=torch.int64,
        device=cuda,
    )
    rms_weight = torch.randn(head_dim, dtype=torch.bfloat16, device=cuda)
    cos_sin_cache = torch.randn(
        last_position, rope_dim, dtype=torch.float32, device=cuda
    )

    cache_dtype = torch.float8_e4m3fn if store_fp8 else torch.bfloat16
    kv_n_blocks = (num_tokens + kv_block_size - 1) // kv_block_size + 1
    k_cache = torch.zeros(
        kv_n_blocks, kv_block_size, head_dim, dtype=cache_dtype, device=cuda
    )
    fp8_scale = torch.tensor(
        [0.5 if store_fp8 else 1.0], dtype=torch.float32, device=cuda
    )

    # Arguments shared by both kernels; the split kernel takes one extra
    # scratch tensor between the two positional groups.
    state_args = (
        state_cache,
        token_to_req,
        positions,
        slot_mapping,
        block_table,
        state_block_size,
    )
    store_args = (
        rms_weight,
        rms_eps,
        cos_sin_cache,
        k_cache,
        slot_mapping,
        kv_block_size,
        k_cache.stride(0),
    )
    kernel_kwargs = dict(
        head_size=head_dim,
        state_width=coff * head_dim,
        rope_head_dim=rope_dim,
        fp8_max=fp8_max,
        quant_block=64,
        token_stride=576,
        scale_dim=8,
        compress_ratio=compress_ratio,
        # True for ratio 4 (overlap == 1), False otherwise.
        overlap=bool(overlap),
        store_full_kv=True,
        store_full_fp8=store_fp8,
        fp8_scale=fp8_scale,
    )

    if compress_ratio == 4:
        fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(
            *state_args, *store_args, **kernel_kwargs
        )
    else:
        compressed_kv = torch.empty(
            (num_tokens, head_dim), dtype=torch.float32, device=cuda
        )
        split_kv_compress_norm_rope_insert_sparse_attn_cutedsl(
            *state_args, compressed_kv, *store_args, **kernel_kwargs
        )

    ref = _reference_kv_compress_norm_rope(
        state_cache,
        block_table,
        positions,
        rms_weight,
        cos_sin_cache,
        compress_ratio,
        overlap,
        rms_eps=rms_eps,
        return_full_cache=True,
    )  # [num_tokens, head_dim] bf16

    # Read back the row stored for each token slot.
    actual = torch.stack(
        [k_cache[divmod(slot, kv_block_size)] for slot in range(num_tokens)]
    )

    if not store_fp8:
        torch.testing.assert_close(actual.float(), ref.float(), rtol=3e-2, atol=3e-2)
        return

    # Quantize the reference the same way before comparing fp8 values.
    ref_fp8 = torch.clamp(ref.float() / fp8_scale, -fp8_max, fp8_max).to(
        torch.float8_e4m3fn
    )
    torch.testing.assert_close(actual.float(), ref_fp8.float(), rtol=0.0, atol=0.3)
@@ -67,7 +67,7 @@ def apply_rope_gptj_last_k(
head_dim = x.shape[-1] head_dim = x.shape[-1]
nope_dim = head_dim - rope_dim nope_dim = head_dim - rope_dim
cs = cos_sin_cache[positions].to(torch.float32) cs = cos_sin_cache[positions.long()].to(torch.float32)
cos = cs[..., :half] cos = cs[..., :half]
sin = cs[..., half:] sin = cs[..., half:]
@@ -114,6 +114,18 @@ def _op_available() -> bool:
return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert") return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")
def _full_cache_fp8_op_available() -> bool:
return hasattr(
torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert"
)
def _full_cache_bf16_op_available() -> bool:
return hasattr(
torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert"
)
pytestmark = pytest.mark.skipif( pytestmark = pytest.mark.skipif(
not torch.cuda.is_available() or not _op_available(), not torch.cuda.is_available() or not _op_available(),
reason="CUDA not available or fused DeepseekV4 op not built in", reason="CUDA not available or fused DeepseekV4 op not built in",
@@ -415,3 +427,238 @@ def test_combined_q_and_kv(
"padded head slots must be exact zero" "padded head slots must be exact zero"
) )
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
# ── Full-cache (FlashInfer) path parity ──────────────────────────────────────
def _call_full_cache_fp8_fused(
    q,
    kv,
    q_fp8,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    fp8_scale,
    q_fp8_scale_inv,
    eps,
    bs,
):
    """Invoke the fused full-cache FP8 op.

    Arguments are forwarded in order, except that ``positions`` is converted
    with ``.long()`` first.
    """
    fused_op = torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert
    positions_i64 = positions.long()
    fused_op(
        q,
        kv,
        q_fp8,
        k_cache,
        slot_mapping,
        positions_i64,
        cos_sin_cache,
        fp8_scale,
        q_fp8_scale_inv,
        eps,
        bs,
    )
def _call_full_cache_bf16_fused(
    q,
    kv,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    bs,
):
    """Invoke the fused full-cache BF16 op.

    Arguments are forwarded in order, except that ``positions`` is converted
    with ``.long()`` first.
    """
    fused_op = torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert
    positions_i64 = positions.long()
    fused_op(
        q,
        kv,
        k_cache,
        slot_mapping,
        positions_i64,
        cos_sin_cache,
        eps,
        bs,
    )
def _fp8_full_cache_reference(
    q,
    kv,
    k_cache,
    q_fp8,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
    fp8_scale,
    q_fp8_scale_inv,
):
    """PyTorch reference for the full-cache per-tensor FP8 path.

    Fills ``q_fp8`` in place with the normalized, rotated and scaled query,
    and writes the rotated, scaled KV rows into ``k_cache`` for every token
    whose slot is non-negative.
    """

    def _to_fp8(values):
        # Saturate to the FP8 range, then cast.
        return torch.clamp(values, -FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

    q_rot = apply_rope_gptj_last_k(
        rmsnorm_no_weight(q, eps), positions, cos_sin_cache
    )
    q_fp8.copy_(_to_fp8(q_rot.float() * q_fp8_scale_inv))

    kv_rot = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
    # Negative slots are skipped.
    keep = slot_mapping >= 0
    live_slots = slot_mapping[keep]
    k_cache[live_slots // block_size, live_slots % block_size] = _to_fp8(
        kv_rot[keep].float() / fp8_scale
    )
def _bf16_full_cache_reference(
    q,
    kv,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
):
    """PyTorch reference for the full-cache BF16 path.

    Writes the rotated KV rows into ``k_cache`` for every token whose slot is
    non-negative, and returns the normalized, rotated query cast back to the
    dtype of ``q``.
    """
    # Kernel keeps RMSNorm+RoPE in fp32 and rounds to bf16 once at the store.
    q_rot = apply_rope_gptj_last_k(
        rmsnorm_no_weight(q, eps), positions, cos_sin_cache
    ).to(q.dtype)

    kv_rot = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
    # Negative slots are skipped.
    keep = slot_mapping >= 0
    live_slots = slot_mapping[keep]
    k_cache[live_slots // block_size, live_slots % block_size] = kv_rot[keep]
    return q_rot
@pytest.mark.skipif(
    not _full_cache_fp8_op_available(),
    reason="full-cache per-tensor FP8 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_per_tensor_fp8_matches_reference(
    num_tokens: int,
    n_heads: int,
    positions_dtype: torch.dtype,
):
    """Compare the fused full-cache FP8 op with the PyTorch reference.

    Both the FP8 query output and the FP8 KV cache are checked, with unit
    scales and an absolute tolerance of 0.25.
    """
    torch.manual_seed(4)
    cuda = "cuda"
    rms_eps = 1e-6
    page_size = 16

    q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=torch.bfloat16, device=cuda)
    kv = torch.randn(num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=cuda)
    positions = torch.arange(num_tokens, dtype=positions_dtype, device=cuda)
    cos_sin_cache = make_cos_sin_cache(4096, ROPE_DIM, torch.float32, cuda)
    # Token i is stored in slot i.
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=cuda)

    fp8_scale = torch.tensor([1.0], dtype=torch.float32, device=cuda)
    q_fp8_scale_inv = torch.tensor([1.0], dtype=torch.float32, device=cuda)

    # Enough pages for all tokens plus one spare.
    num_pages = (num_tokens + page_size - 1) // page_size + 1
    expected_q = torch.empty_like(q, dtype=torch.float8_e4m3fn)
    actual_q = torch.empty_like(q, dtype=torch.float8_e4m3fn)
    expected_cache = torch.zeros(
        num_pages, page_size, HEAD_DIM, dtype=torch.float8_e4m3fn, device=cuda
    )
    actual_cache = torch.zeros_like(expected_cache)

    _fp8_full_cache_reference(
        q=q,
        kv=kv,
        k_cache=expected_cache,
        q_fp8=expected_q,
        slot_mapping=slot_mapping,
        positions=positions,
        cos_sin_cache=cos_sin_cache,
        eps=rms_eps,
        block_size=page_size,
        fp8_scale=fp8_scale,
        q_fp8_scale_inv=q_fp8_scale_inv,
    )
    # The fused op receives a copy of q.
    _call_full_cache_fp8_fused(
        q=q.clone(),
        kv=kv,
        q_fp8=actual_q,
        k_cache=actual_cache,
        slot_mapping=slot_mapping,
        positions=positions,
        cos_sin_cache=cos_sin_cache,
        fp8_scale=fp8_scale,
        q_fp8_scale_inv=q_fp8_scale_inv,
        eps=rms_eps,
        bs=page_size,
    )

    for fused, reference in ((actual_q, expected_q), (actual_cache, expected_cache)):
        torch.testing.assert_close(
            fused.float(), reference.float(), rtol=0, atol=0.25
        )
@pytest.mark.skipif(
    not _full_cache_bf16_op_available(),
    reason="full-cache BF16 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_bf16_matches_reference(
    num_tokens: int,
    n_heads: int,
    positions_dtype: torch.dtype,
):
    """Check the fused BF16 full-cache op against the eager reference.

    The fused op runs on a clone of ``q``; that clone is compared with the
    reference's returned Q (rtol/atol 1e-2), and the two KV caches must
    match exactly.
    """
    torch.manual_seed(5)
    device = "cuda"
    bf16 = torch.bfloat16
    eps = 1e-6
    block_size = 16

    # Random inputs are drawn in a fixed order so the seed pins them down.
    q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=bf16, device=device)
    kv = torch.randn(num_tokens, HEAD_DIM, dtype=bf16, device=device)
    positions = torch.arange(num_tokens, dtype=positions_dtype, device=device)
    cos_sin_cache = make_cos_sin_cache(4096, ROPE_DIM, torch.float32, device)

    # Enough blocks to hold every token, plus one spare block.
    full_blocks, leftover = divmod(num_tokens, block_size)
    num_blocks = full_blocks + (1 if leftover else 0) + 1
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)

    q_fused = q.clone()
    k_cache_ref = torch.zeros(
        num_blocks, block_size, HEAD_DIM, dtype=torch.bfloat16, device=device
    )
    k_cache_fused = torch.zeros_like(k_cache_ref)

    shared_tail = (slot_mapping, positions, cos_sin_cache, eps, block_size)
    q_ref = _bf16_full_cache_reference(q, kv, k_cache_ref, *shared_tail)
    _call_full_cache_bf16_fused(q_fused, kv, k_cache_fused, *shared_tail)

    torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
@@ -0,0 +1,481 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for sequence and token pooler head classes."""
import torch
import torch.nn as nn
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
from vllm.model_executor.layers.pooler.seqwise.heads import (
ClassifierPoolerHead,
EmbeddingPoolerHead,
)
from vllm.model_executor.layers.pooler.tokwise.heads import (
TokenClassifierPoolerHead,
TokenEmbeddingPoolerHead,
)
from vllm.pooling_params import PoolingParams
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
_HIDDEN = 16
_BATCH = 3
def _make_params(
    n: int,
    *,
    task: str = "embed",
    dimensions: int | None = None,
    use_activation: bool | None = None,
) -> list[PoolingParams]:
    """Build ``n`` separate ``PoolingParams`` objects with identical settings."""
    shared_fields = {
        "task": task,
        "dimensions": dimensions,
        "use_activation": use_activation,
    }
    # A fresh object per request, so no two entries alias each other.
    return [PoolingParams(**shared_fields) for _ in range(n)]
def _make_metadata(pooling_params: list[PoolingParams]) -> PoolingMetadata:
    """Wrap ``pooling_params`` in a minimal ``PoolingMetadata``.

    Every request gets a prompt length of 1 and a fresh ``PoolingStates``;
    both prompt-token-id fields are left as ``None``.
    """
    batch = len(pooling_params)
    unit_lengths = torch.ones(batch, dtype=torch.long)
    fresh_states = [PoolingStates() for _ in pooling_params]
    return PoolingMetadata(
        prompt_lens=unit_lengths,
        prompt_token_ids=None,
        prompt_token_ids_cpu=None,
        pooling_params=pooling_params,
        pooling_states=fresh_states,
    )
def _linear(in_f: int, out_f: int) -> nn.Linear:
torch.manual_seed(42)
return nn.Linear(in_f, out_f, bias=False)
# ---------------------------------------------------------------------------
# EmbeddingPoolerHead
# ---------------------------------------------------------------------------
class TestEmbeddingPoolerHead:
    """Tests for the sequence-wise ``EmbeddingPoolerHead``."""

    def test_supported_tasks(self):
        """Only the ``embed`` task is reported as supported."""
        supported = EmbeddingPoolerHead().get_supported_tasks()
        assert supported == {"embed"}

    def test_passthrough(self):
        """A default head returns its input unchanged."""
        pooler_head = EmbeddingPoolerHead()
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH))
        assert torch.equal(pooler_head(hidden, metadata), hidden)

    def test_head_dtype(self):
        """``head_dtype`` determines the dtype of the output."""
        pooler_head = EmbeddingPoolerHead(head_dtype=torch.float16)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH))
        result = pooler_head(hidden, metadata)
        assert result.dtype == torch.float16

    def test_projector(self):
        """With a projector, the output equals the projector's output."""
        projector = _linear(_HIDDEN, 8)
        pooler_head = EmbeddingPoolerHead(projector=projector)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH))
        result = pooler_head(hidden, metadata)
        assert result.shape == (_BATCH, 8)
        assert torch.allclose(result, projector(hidden))

    def test_matryoshka_uniform(self):
        """A shared ``dimensions`` value truncates the last axis."""
        pooler_head = EmbeddingPoolerHead()
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, dimensions=4))
        result = pooler_head(hidden, metadata)
        assert result.shape == (_BATCH, 4)
        assert torch.equal(result, hidden[..., :4])

    def test_matryoshka_mixed(self):
        """Differing ``dimensions`` give a list of per-request tensors."""
        requested_dims = [4, 8]
        pooler_head = EmbeddingPoolerHead()
        hidden = torch.randn(2, _HIDDEN)
        metadata = _make_metadata(
            [PoolingParams(task="embed", dimensions=d) for d in requested_dims]
        )
        result = pooler_head(hidden, metadata)
        assert isinstance(result, list)
        assert len(result) == 2
        for item, width in zip(result, requested_dims):
            assert item.shape[-1] == width

    def test_matryoshka_mixed_with_none(self):
        """A request with ``dimensions=None`` keeps its full row."""
        pooler_head = EmbeddingPoolerHead()
        hidden = torch.randn(2, _HIDDEN)
        truncated = PoolingParams(task="embed", dimensions=4)
        untouched = PoolingParams(task="embed", dimensions=None)
        metadata = _make_metadata([truncated, untouched])
        result = pooler_head(hidden, metadata)
        assert isinstance(result, list)
        assert result[0].shape[-1] == 4
        assert torch.equal(result[1], hidden[1])

    def test_activation_uniform_true(self):
        """With ``use_activation=True`` every output row has unit norm."""
        pooler_head = EmbeddingPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, use_activation=True))
        row_norms = torch.linalg.norm(pooler_head(hidden, metadata), dim=-1)
        assert torch.allclose(row_norms, torch.ones(_BATCH), atol=1e-5)

    def test_activation_uniform_false(self):
        """With ``use_activation=False`` the input is returned unchanged."""
        pooler_head = EmbeddingPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, use_activation=False))
        assert torch.equal(pooler_head(hidden, metadata), hidden)

    def test_activation_mixed_flags(self):
        """Mixed flags: a list where only the flagged row is normalized."""
        pooler_head = EmbeddingPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(2, _HIDDEN)
        metadata = _make_metadata(
            [
                PoolingParams(task="embed", use_activation=flag)
                for flag in (True, False)
            ]
        )
        result = pooler_head(hidden, metadata)
        assert isinstance(result, list)
        first_norm = torch.linalg.norm(result[0], dim=-1)
        assert torch.allclose(first_norm, torch.ones(1), atol=1e-5)
        assert torch.equal(result[1], hidden[1])

    def test_list_input_gets_stacked(self):
        """A list of 1-D tensors comes back as one stacked 2-D tensor."""
        pooler_head = EmbeddingPoolerHead()
        rows = [torch.randn(_HIDDEN) for _ in range(_BATCH)]
        metadata = _make_metadata(_make_params(_BATCH))
        result = pooler_head(rows, metadata)
        assert result.shape == (_BATCH, _HIDDEN)
        assert torch.equal(result, torch.stack(rows))

    def test_projector_then_matryoshka(self):
        """The output equals the truncated projector output."""
        projector = _linear(_HIDDEN, 8)
        pooler_head = EmbeddingPoolerHead(projector=projector)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, dimensions=4))
        result = pooler_head(hidden, metadata)
        assert result.shape == (_BATCH, 4)
        assert torch.equal(result, projector(hidden)[..., :4])

    def test_matryoshka_then_activation(self):
        """Truncated rows still come out with unit norm."""
        pooler_head = EmbeddingPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(
            _make_params(_BATCH, dimensions=4, use_activation=True)
        )
        result = pooler_head(hidden, metadata)
        assert result.shape == (_BATCH, 4)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(_BATCH), atol=1e-5)

    def test_empty_batch(self):
        """A zero-row input yields a zero-row output of the same width."""
        pooler_head = EmbeddingPoolerHead()
        hidden = torch.randn(0, _HIDDEN)
        metadata = _make_metadata([])
        assert pooler_head(hidden, metadata).shape == (0, _HIDDEN)
# ---------------------------------------------------------------------------
# ClassifierPoolerHead
# ---------------------------------------------------------------------------
class TestClassifierPoolerHead:
    """Tests for the sequence-wise ``ClassifierPoolerHead``."""

    def test_supported_tasks(self):
        """Only the ``classify`` task is reported as supported."""
        supported = ClassifierPoolerHead().get_supported_tasks()
        assert supported == {"classify"}

    def test_passthrough(self):
        """A default head returns its input unchanged."""
        pooler_head = ClassifierPoolerHead()
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        assert torch.equal(pooler_head(hidden, metadata), hidden)

    def test_head_dtype(self):
        """``head_dtype`` determines the dtype of the output."""
        pooler_head = ClassifierPoolerHead(head_dtype=torch.float16)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler_head(hidden, metadata)
        assert result.dtype == torch.float16

    def test_classifier(self):
        """With a classifier, the output equals the classifier's output."""
        classifier = _linear(_HIDDEN, 3)
        pooler_head = ClassifierPoolerHead(classifier=classifier)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler_head(hidden, metadata)
        assert result.shape == (_BATCH, 3)
        assert torch.allclose(result, classifier(hidden))

    def test_logit_mean(self):
        """``logit_mean`` is subtracted from the input."""
        pooler_head = ClassifierPoolerHead(logit_mean=2.0)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        assert torch.allclose(pooler_head(hidden, metadata), hidden - 2.0)

    def test_logit_sigma(self):
        """The input is divided by ``logit_sigma``."""
        pooler_head = ClassifierPoolerHead(logit_sigma=0.5)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        assert torch.allclose(pooler_head(hidden, metadata), hidden / 0.5)

    def test_platt_scaling_combined(self):
        """With both set, the output is ``(x - mean) / sigma``."""
        pooler_head = ClassifierPoolerHead(logit_mean=1.0, logit_sigma=2.0)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        expected = (hidden - 1.0) / 2.0
        assert torch.allclose(pooler_head(hidden, metadata), expected)

    def test_activation_uniform_true(self):
        """With ``use_activation=True`` every output row has unit norm."""
        pooler_head = ClassifierPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(
            _make_params(_BATCH, task="classify", use_activation=True)
        )
        row_norms = torch.linalg.norm(pooler_head(hidden, metadata), dim=-1)
        assert torch.allclose(row_norms, torch.ones(_BATCH), atol=1e-5)

    def test_activation_uniform_false(self):
        """With ``use_activation=False`` the input is returned unchanged."""
        pooler_head = ClassifierPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(
            _make_params(_BATCH, task="classify", use_activation=False)
        )
        assert torch.equal(pooler_head(hidden, metadata), hidden)

    def test_activation_mixed_flags(self):
        """Mixed flags: a list where only the flagged row is normalized."""
        pooler_head = ClassifierPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(2, _HIDDEN)
        metadata = _make_metadata(
            [
                PoolingParams(task="classify", use_activation=flag)
                for flag in (True, False)
            ]
        )
        result = pooler_head(hidden, metadata)
        assert isinstance(result, list)
        first_norm = torch.linalg.norm(result[0], dim=-1)
        assert torch.allclose(first_norm, torch.ones(1), atol=1e-5)
        assert torch.equal(result[1], hidden[1])

    def test_list_input_gets_stacked(self):
        """A list of 1-D tensors comes back as one stacked 2-D tensor."""
        pooler_head = ClassifierPoolerHead()
        rows = [torch.randn(_HIDDEN) for _ in range(_BATCH)]
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler_head(rows, metadata)
        assert result.shape == (_BATCH, _HIDDEN)
        assert torch.equal(result, torch.stack(rows))

    def test_classifier_then_platt_scaling(self):
        """The output is ``(classifier(x) - mean) / sigma``."""
        classifier = _linear(_HIDDEN, 3)
        pooler_head = ClassifierPoolerHead(
            classifier=classifier, logit_mean=1.0, logit_sigma=2.0
        )
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler_head(hidden, metadata)
        assert torch.allclose(result, (classifier(hidden) - 1.0) / 2.0)

    def test_empty_batch(self):
        """A zero-row input yields a zero-row output of the same width."""
        pooler_head = ClassifierPoolerHead()
        hidden = torch.randn(0, _HIDDEN)
        metadata = _make_metadata([])
        assert pooler_head(hidden, metadata).shape == (0, _HIDDEN)
# ---------------------------------------------------------------------------
# TokenEmbeddingPoolerHead
# ---------------------------------------------------------------------------
class TestTokenEmbeddingPoolerHead:
    """Tests for the token-wise ``TokenEmbeddingPoolerHead``."""

    def test_supported_tasks(self):
        """Only the ``token_embed`` task is reported as supported."""
        supported = TokenEmbeddingPoolerHead().get_supported_tasks()
        assert supported == {"token_embed"}

    def test_passthrough(self):
        """A default head's ``forward_chunk`` returns its input unchanged."""
        pooler_head = TokenEmbeddingPoolerHead()
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed")
        assert torch.equal(pooler_head.forward_chunk(tokens, request), tokens)

    def test_none_chunked_prefill(self):
        """``forward_chunk`` returns ``None`` when given ``None``."""
        pooler_head = TokenEmbeddingPoolerHead()
        request = PoolingParams(task="token_embed")
        assert pooler_head.forward_chunk(None, request) is None

    def test_head_dtype(self):
        """``head_dtype`` determines the dtype of the output."""
        pooler_head = TokenEmbeddingPoolerHead(head_dtype=torch.float16)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed")
        result = pooler_head.forward_chunk(tokens, request)
        assert result.dtype == torch.float16

    def test_projector(self):
        """With a projector, the output equals the projector's output."""
        projector = _linear(_HIDDEN, 8)
        pooler_head = TokenEmbeddingPoolerHead(projector=projector)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed")
        result = pooler_head.forward_chunk(tokens, request)
        assert result.shape == (5, 8)
        assert torch.allclose(result, projector(tokens))

    def test_matryoshka_truncation(self):
        """``dimensions`` truncates the last axis of the chunk."""
        pooler_head = TokenEmbeddingPoolerHead()
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed", dimensions=4)
        result = pooler_head.forward_chunk(tokens, request)
        assert result.shape == (5, 4)
        assert torch.equal(result, tokens[..., :4])

    def test_activation_true(self):
        """With ``use_activation=True`` every token row has unit norm."""
        pooler_head = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed", use_activation=True)
        result = pooler_head.forward_chunk(tokens, request)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(5), atol=1e-5)

    def test_activation_false(self):
        """With ``use_activation=False`` the input is returned unchanged."""
        pooler_head = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed", use_activation=False)
        assert torch.equal(pooler_head.forward_chunk(tokens, request), tokens)

    def test_projector_then_matryoshka(self):
        """The output equals the truncated projector output."""
        projector = _linear(_HIDDEN, 8)
        pooler_head = TokenEmbeddingPoolerHead(projector=projector)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed", dimensions=4)
        result = pooler_head.forward_chunk(tokens, request)
        assert result.shape == (5, 4)
        assert torch.equal(result, projector(tokens)[..., :4])

    def test_matryoshka_then_activation(self):
        """Truncated token rows still come out with unit norm."""
        pooler_head = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(
            task="token_embed", dimensions=4, use_activation=True
        )
        result = pooler_head.forward_chunk(tokens, request)
        assert result.shape == (5, 4)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(5), atol=1e-5)

    def test_forward_mixed_batch_chunked_prefill(self):
        """``None`` entries in a batch stay ``None``; tensors pass through."""
        pooler_head = TokenEmbeddingPoolerHead()
        pooled_data = [torch.randn(5, _HIDDEN), None, torch.randn(3, _HIDDEN)]
        metadata = _make_metadata(_make_params(3, task="token_embed"))
        result = pooler_head(pooled_data, metadata)
        assert len(result) == 3
        for got, given in zip(result, pooled_data):
            if given is None:
                assert got is None
            else:
                assert torch.equal(got, given)

    def test_forward_empty_batch(self):
        """An empty batch yields an empty list."""
        pooler_head = TokenEmbeddingPoolerHead()
        metadata = _make_metadata([])
        assert pooler_head([], metadata) == []
# ---------------------------------------------------------------------------
# TokenClassifierPoolerHead
# ---------------------------------------------------------------------------
class TestTokenClassifierPoolerHead:
    """Tests for the token-wise ``TokenClassifierPoolerHead``."""

    def test_supported_tasks(self):
        """Only the ``token_classify`` task is reported as supported."""
        supported = TokenClassifierPoolerHead().get_supported_tasks()
        assert supported == {"token_classify"}

    def test_passthrough(self):
        """A default head's ``forward_chunk`` returns its input unchanged."""
        pooler_head = TokenClassifierPoolerHead()
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        assert torch.equal(pooler_head.forward_chunk(tokens, request), tokens)

    def test_none_chunked_prefill(self):
        """``forward_chunk`` returns ``None`` when given ``None``."""
        pooler_head = TokenClassifierPoolerHead()
        request = PoolingParams(task="token_classify")
        assert pooler_head.forward_chunk(None, request) is None

    def test_head_dtype(self):
        """``head_dtype`` determines the dtype of the output."""
        pooler_head = TokenClassifierPoolerHead(head_dtype=torch.float16)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler_head.forward_chunk(tokens, request)
        assert result.dtype == torch.float16

    def test_classifier(self):
        """With a classifier, the output equals the classifier's output."""
        classifier = _linear(_HIDDEN, 3)
        pooler_head = TokenClassifierPoolerHead(classifier=classifier)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler_head.forward_chunk(tokens, request)
        assert result.shape == (5, 3)
        assert torch.allclose(result, classifier(tokens))

    def test_logit_mean(self):
        """``logit_mean`` is subtracted from the input."""
        pooler_head = TokenClassifierPoolerHead(logit_mean=2.0)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler_head.forward_chunk(tokens, request)
        assert torch.allclose(result, tokens - 2.0)

    def test_logit_sigma(self):
        """The input is divided by ``logit_sigma``."""
        pooler_head = TokenClassifierPoolerHead(logit_sigma=0.5)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler_head.forward_chunk(tokens, request)
        assert torch.allclose(result, tokens / 0.5)

    def test_platt_scaling_combined(self):
        """With both set, the output is ``(x - mean) / sigma``."""
        pooler_head = TokenClassifierPoolerHead(logit_mean=1.0, logit_sigma=2.0)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler_head.forward_chunk(tokens, request)
        assert torch.allclose(result, (tokens - 1.0) / 2.0)

    def test_activation_true(self):
        """With ``use_activation=True`` every token row has unit norm."""
        pooler_head = TokenClassifierPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify", use_activation=True)
        result = pooler_head.forward_chunk(tokens, request)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(5), atol=1e-5)

    def test_activation_false(self):
        """With ``use_activation=False`` the input is returned unchanged."""
        pooler_head = TokenClassifierPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify", use_activation=False)
        assert torch.equal(pooler_head.forward_chunk(tokens, request), tokens)

    def test_forward_mixed_batch_chunked_prefill(self):
        """``None`` entries in a batch stay ``None``; tensors pass through."""
        pooler_head = TokenClassifierPoolerHead()
        pooled_data = [torch.randn(5, _HIDDEN), None, torch.randn(3, _HIDDEN)]
        metadata = _make_metadata(_make_params(3, task="token_classify"))
        result = pooler_head(pooled_data, metadata)
        assert len(result) == 3
        for got, given in zip(result, pooled_data):
            if given is None:
                assert got is None
            else:
                assert torch.equal(got, given)

    def test_forward_empty_batch(self):
        """An empty batch yields an empty list."""
        pooler_head = TokenClassifierPoolerHead()
        metadata = _make_metadata([])
        assert pooler_head([], metadata) == []
+76 -41
View File
@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from contextlib import contextmanager, nullcontext
import pytest import pytest
from tests.models.registry import HF_EXAMPLE_MODELS from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test from tests.utils import multi_gpu_test, wait_for_gpu_memory_to_clear
from vllm import LLM from vllm import LLM
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
@@ -404,6 +405,30 @@ def _get_vllm_runner_params(
} }
def _wait_for_rocm_memory_to_settle() -> None:
    """Block until GPU memory has cleared, on ROCm platforms only.

    Does nothing when the current platform is not ROCm or reports zero
    devices. Otherwise waits on every device index with a threshold ratio
    of 0.01 and a 120 second timeout.
    """
    if current_platform.is_rocm():
        gpu_count = current_platform.device_count()
        if gpu_count != 0:
            wait_for_gpu_memory_to_clear(
                devices=list(range(gpu_count)),
                threshold_ratio=0.01,
                timeout_s=120,
            )
@contextmanager
def _owned_vLLM_runner(vllm_runner, kwargs):
    """Yield a runner built from ``vllm_runner(**kwargs)``.

    The runner is entered as a context manager and yielded to the caller.
    Whether the body finishes or raises, ``_wait_for_rocm_memory_to_settle``
    runs after the runner's own context has exited.
    """
    try:
        with vllm_runner(**kwargs) as owned_runner:
            yield owned_runner
    finally:
        # Runs on both success and failure.
        _wait_for_rocm_memory_to_settle()
def _get_vLLM_output( def _get_vLLM_output(
vllm_runner, vllm_runner,
kwargs, kwargs,
@@ -413,17 +438,21 @@ def _get_vLLM_output(
num_repetitions=1, num_repetitions=1,
vllm_model=None, vllm_model=None,
): ):
outs = [] runner_context = (
if vllm_model is None: _owned_vLLM_runner(vllm_runner, kwargs)
vllm_model = vllm_runner(**kwargs) if vllm_model is None
for _ in range(num_repetitions): else nullcontext(vllm_model)
if num_logprobs < 0: )
vllm_output = vllm_model.generate_greedy(prompts, max_tokens) with runner_context as runner:
else: outs = []
vllm_output = vllm_model.generate_greedy_logprobs( for _ in range(num_repetitions):
prompts, max_tokens, num_logprobs if num_logprobs < 0:
) vllm_output = runner.generate_greedy(prompts, max_tokens)
outs.append(vllm_output) else:
vllm_output = runner.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs
)
outs.append(vllm_output)
return outs, vllm_model return outs, vllm_model
@@ -772,38 +801,44 @@ def test_apc_multiple_prompts_partial_cached_outputs(
# Cache only part of all the prompts # Cache only part of all the prompts
vllm_runner_kwargs["enable_prefix_caching"] = True vllm_runner_kwargs["enable_prefix_caching"] = True
vllm_outputs_partial_cache, vllm_model = _get_vLLM_output( with _owned_vLLM_runner(vllm_runner, vllm_runner_kwargs) as vllm_model:
vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs vllm_outputs_partial_cache, _ = _get_vLLM_output(
) vllm_runner,
vllm_runner_kwargs,
compare_operator( generated_prompts[:3],
outputs_0_lst=vllm_outputs_no_cache[0][:3], max_tokens,
outputs_1_lst=vllm_outputs_partial_cache[0], num_logprobs,
name_0="vllm_no_cache", vllm_model=vllm_model,
name_1="vllm_partial_cache", )
)
vllm_outputs_cache_rep, _ = _get_vLLM_output(
vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens,
num_logprobs,
n_repetitions,
vllm_model=vllm_model,
)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator( compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0], outputs_0_lst=vllm_outputs_no_cache[0][:3],
outputs_1_lst=vllm_outputs_cache_itn, outputs_1_lst=vllm_outputs_partial_cache[0],
name_0="vllm_no_cache", name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}", name_1="vllm_partial_cache",
) )
vllm_outputs_cache_rep, _ = _get_vLLM_output(
vllm_runner,
vllm_runner_kwargs,
generated_prompts,
max_tokens,
num_logprobs,
n_repetitions,
vllm_model=vllm_model,
)
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
# In the first repetition, the caches are filled
# In the second repetition, these caches are reused
compare_operator(
outputs_0_lst=vllm_outputs_no_cache[0],
outputs_1_lst=vllm_outputs_cache_itn,
name_0="vllm_no_cache",
name_1=f"vllm_cache_it_{r_idx + 1}",
)
# Test that outputs match whether prefix caching is enabled or not for mamba. # Test that outputs match whether prefix caching is enabled or not for mamba.
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"]) @pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
@@ -826,7 +861,7 @@ def test_same_mamba_output_apc_on_vs_off(
# No prefix caching # No prefix caching
kwargs_no_apc = {**base_kwargs, "enable_prefix_caching": False} kwargs_no_apc = {**base_kwargs, "enable_prefix_caching": False}
with vllm_runner(**kwargs_no_apc) as vllm_model: with _owned_vLLM_runner(vllm_runner, kwargs_no_apc) as vllm_model:
outputs_no_apc, _ = _get_vLLM_output( outputs_no_apc, _ = _get_vLLM_output(
vllm_runner, vllm_runner,
kwargs_no_apc, kwargs_no_apc,
@@ -841,7 +876,7 @@ def test_same_mamba_output_apc_on_vs_off(
"enable_prefix_caching": True, "enable_prefix_caching": True,
"mamba_block_size": 16, "mamba_block_size": 16,
} }
with vllm_runner(**kwargs_with_apc) as vllm_model: with _owned_vLLM_runner(vllm_runner, kwargs_with_apc) as vllm_model:
outputs_with_apc, _ = _get_vLLM_output( outputs_with_apc, _ = _get_vLLM_output(
vllm_runner, vllm_runner,
kwargs_with_apc, kwargs_with_apc,
@@ -30,11 +30,14 @@ def vllm_to_hf_output(
MODEL_NAME = "ibm-granite/granite-speech-3.3-2b" MODEL_NAME = "ibm-granite/granite-speech-3.3-2b"
MODEL_NAME_4_0 = "ibm-granite/granite-4.0-1b-speech" MODEL_NAME_4_0 = "ibm-granite/granite-4.0-1b-speech"
# "plus" variant of granite speech (uses GraniteSpeechPlusForConditionalGeneration).
MODEL_NAME_4_1_PLUS = "ibm-granite/granite-speech-4.1-2b-plus"
# Audio lora co-exists directly in the 3.3 model directory, # Audio lora co-exists directly in the 3.3 model directory,
# the 4.0 model has adapters merged into the weights. # the 4.0 and 4.1-plus models have adapters merged into the weights.
models: dict[str, str | None] = { models: dict[str, str | None] = {
MODEL_NAME: MODEL_NAME, MODEL_NAME: MODEL_NAME,
MODEL_NAME_4_0: None, MODEL_NAME_4_0: None,
MODEL_NAME_4_1_PLUS: None,
} }
@@ -43,6 +43,10 @@ def qwen_vl_chat_template(content: str) -> str:
return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n" return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
def internvl_chat_template(content: str) -> str:
    """Wrap ``content`` as a single user turn followed by an open assistant turn."""
    template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
    return template.format(content)
def step3_vl_chat_template(content: str) -> str: def step3_vl_chat_template(content: str) -> str:
return ( return (
"<begin▁of▁sentence> You are a helpful assistant.<|BOT|>user\n " "<begin▁of▁sentence> You are a helpful assistant.<|BOT|>user\n "
@@ -51,6 +55,17 @@ def step3_vl_chat_template(content: str) -> str:
MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = { MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
"internvl": VitCudagraphTestConfig(
model="OpenGVLab/InternVL3-1B",
num_video_frames=8,
image_prompt=internvl_chat_template("<image>\nWhat is in this image?"),
video_prompt=internvl_chat_template(
"<video>\nDescribe this video in one sentence."
),
needs_video_metadata=False,
vllm_runner_kwargs={"trust_remote_code": True},
marks=[pytest.mark.core_model],
),
"qwen2_5_vl": VitCudagraphTestConfig( "qwen2_5_vl": VitCudagraphTestConfig(
model="Qwen/Qwen2.5-VL-3B-Instruct", model="Qwen/Qwen2.5-VL-3B-Instruct",
image_prompt=qwen_vl_chat_template( image_prompt=qwen_vl_chat_template(
+4
View File
@@ -938,6 +938,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"ibm-granite/granite-speech-3.3-2b", "ibm-granite/granite-speech-3.3-2b",
extras={"4.0-1b": "ibm-granite/granite-4.0-1b-speech"}, extras={"4.0-1b": "ibm-granite/granite-4.0-1b-speech"},
), ),
"GraniteSpeechPlusForConditionalGeneration": _HfExamplesInfo(
"ibm-granite/granite-speech-4.1-2b-plus",
min_transformers_version="5.8.0",
),
"GLM4VForCausalLM": _HfExamplesInfo( "GLM4VForCausalLM": _HfExamplesInfo(
"zai-org/glm-4v-9b", "zai-org/glm-4v-9b",
trust_remote_code=True, trust_remote_code=True,
+37
View File
@@ -36,11 +36,24 @@ def tokenizer():
return get_tokenizer("Qwen/Qwen3-32B") return get_tokenizer("Qwen/Qwen3-32B")
# Tool list containing a single function tool, `get_weather`, whose
# parameters are an object schema with no properties.
TOOLS = [
{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {"type": "object", "properties": {}},
},
}
]
@pytest.fixture @pytest.fixture
def request_obj(): def request_obj():
return ChatCompletionRequest( return ChatCompletionRequest(
model="test-model", model="test-model",
messages=[{"role": "user", "content": "hi"}], messages=[{"role": "user", "content": "hi"}],
tools=TOOLS,
tool_choice="auto",
) )
@@ -328,3 +341,27 @@ def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj):
tc.function.arguments for tc in tool_calls if tc.function.arguments tc.function.arguments for tc in tool_calls if tc.function.arguments
) )
assert tool_args.endswith(remainder) assert tool_args.endswith(remainder)
def test_parse_delta_tool_choice_none(tokenizer, request_obj):
    """With ``tool_choice="none"`` the tool-call text is streamed as content.

    Reasoning parsing is off, so reasoning must be empty; no tool calls may
    be produced, and the tool-call markup must appear in the content.
    """
    parser = make_parser(tokenizer, reasoning=False, tool=True)
    no_tool_request = request_obj.model_copy(update={"tool_choice": "none"})

    streamed = stream_text(
        parser, tokenizer, MODEL_OUTPUT, no_tool_request, prompt_token_ids=[]
    )
    reasoning, content, tool_calls = collect_fields(streamed)

    assert reasoning == ""
    assert len(tool_calls) == 0
    for fragment in ("<tool_call>", "get_weather"):
        assert fragment in content
def test_parse_delta_tool_choice_none_with_reasoning(tokenizer, request_obj):
    """With ``tool_choice="none"``, reasoning is still extracted.

    The reasoning text must be captured, no tool calls may be produced, and
    the tool-call markup must appear in the content.
    """
    parser = make_parser(tokenizer, reasoning=True, tool=True)
    no_tool_request = request_obj.model_copy(update={"tool_choice": "none"})

    streamed = stream_text(
        parser, tokenizer, MODEL_OUTPUT, no_tool_request, prompt_token_ids=[]
    )
    reasoning, content, tool_calls = collect_fields(streamed)

    assert "let me think about this" in reasoning
    assert len(tool_calls) == 0
    for fragment in ("<tool_call>", "get_weather"):
        assert fragment in content
+5 -17
View File
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A4Fp4, CompressedTensorsW4A4Fp4,
CompressedTensorsW4A4Mxfp4, CompressedTensorsW4A4Mxfp4,
CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Fp8,
CompressedTensorsW4A16Fp4,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A8Mxfp8, CompressedTensorsW8A8Mxfp8,
@@ -37,9 +36,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, find_matched_target,
) )
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
@@ -376,13 +372,12 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"args", "args",
[ [
# TODO: Enable once model is available again ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", True),
# ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4), ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", False),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
], ],
) )
def test_compressed_tensors_nvfp4(vllm_runner, args): def test_compressed_tensors_nvfp4(vllm_runner, args):
model, scheme = args model, use_a16 = args
with vllm_runner(model, enforce_eager=True) as llm: with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model): def check_model(model):
@@ -390,15 +385,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
qkv_proj = layer.self_attn.qkv_proj qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
if ( assert isinstance(qkv_proj.scheme, CompressedTensorsW4A4Fp4)
isinstance(qkv_proj.scheme, scheme) assert qkv_proj.scheme.use_a16 == use_a16
or isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
and not cutlass_fp4_supported()
):
assert True
else:
raise AssertionError("FP4 Scheme Mismatch")
assert qkv_proj.scheme.group_size == 16 assert qkv_proj.scheme.group_size == 16
llm.apply_model(check_model) llm.apply_model(check_model)
+1 -244
View File
@@ -10,12 +10,11 @@ from dataclasses import dataclass
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
from unittest.mock import MagicMock, patch from unittest.mock import patch
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from vllm.entrypoints.logger import RequestLogger
from vllm.logger import ( from vllm.logger import (
_DATE_FORMAT, _DATE_FORMAT,
_FORMAT, _FORMAT,
@@ -269,248 +268,6 @@ def test_prepare_object_to_dump():
assert prepare_object_to_dump(CustomClass(1, "b")) == "CustomClass(a=1, b='b')" assert prepare_object_to_dump(CustomClass(1, "b")) == "CustomClass(a=1, b='b')"
def test_request_logger_log_outputs():
    """Check that a plain, non-streaming ``log_outputs`` call logs exactly once.

    ``vllm.entrypoints.logger.logger`` is replaced by a mock so the positional
    arguments passed to ``logger.info`` can be inspected.
    """
    mock_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", mock_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-123",
            outputs="Hello, world!",
            output_token_ids=[1, 2, 3, 4],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    # The mock keeps its call record after the patch has been undone.
    mock_logger.info.assert_called_once()
    logged = mock_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-123"
    assert logged[3] == "Hello, world!"
    assert logged[4] == [1, 2, 3, 4]
    assert logged[5] == "stop"
def test_request_logger_log_outputs_streaming_delta():
    """Check the log record produced for a streaming delta chunk.

    With ``is_streaming=True`` and ``delta=True`` the second format argument
    is expected to be the ``" (streaming delta)"`` marker.
    """
    mock_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", mock_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-456",
            outputs="Hello",
            output_token_ids=[1],
            finish_reason=None,
            is_streaming=True,
            delta=True,
        )

    mock_logger.info.assert_called_once()
    logged = mock_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-456"
    assert logged[2] == " (streaming delta)"
    assert logged[3] == "Hello"
    assert logged[4] == [1]
    assert logged[5] is None
def test_request_logger_log_outputs_streaming_complete():
    """Check the log record produced when a streamed response completes.

    With ``is_streaming=True`` and ``delta=False`` the second format argument
    is expected to be the ``" (streaming complete)"`` marker.
    """
    mock_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", mock_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-789",
            outputs="Complete response",
            output_token_ids=[1, 2, 3],
            finish_reason="length",
            is_streaming=True,
            delta=False,
        )

    mock_logger.info.assert_called_once()
    logged = mock_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-789"
    assert logged[2] == " (streaming complete)"
    assert logged[3] == "Complete response"
    assert logged[4] == [1, 2, 3]
    assert logged[5] == "length"
def test_request_logger_log_outputs_with_truncation():
    """Check that ``max_log_len`` truncates both the text and the token ids."""
    mock_logger = MagicMock()
    long_output = "This is a very long output that should be truncated"
    long_token_ids = list(range(20))  # 20 tokens

    with patch("vllm.entrypoints.logger.logger", mock_logger):
        # Limit logged payloads to 10 characters / 10 token ids.
        RequestLogger(max_log_len=10).log_outputs(
            request_id="test-truncate",
            outputs=long_output,
            output_token_ids=long_token_ids,
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    mock_logger.info.assert_called_once()
    logged = mock_logger.info.call_args.args

    # Only the first 10 characters of the output text are logged.
    logged_output = logged[3]
    assert logged_output == "This is a "
    assert len(logged_output) == 10

    # Only the first 10 token ids are logged.
    logged_token_ids = logged[4]
    assert logged_token_ids == list(range(10))
    assert len(logged_token_ids) == 10
def test_request_logger_log_outputs_none_values():
    """Check that ``output_token_ids=None`` is passed through to the log call."""
    mock_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", mock_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-none",
            outputs="Test output",
            output_token_ids=None,
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    mock_logger.info.assert_called_once()
    logged = mock_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-none"
    assert logged[3] == "Test output"
    assert logged[4] is None
    assert logged[5] == "stop"
def test_request_logger_log_outputs_empty_output():
    """Check that empty text and an empty token list are logged unchanged."""
    mock_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", mock_logger):
        RequestLogger(max_log_len=5).log_outputs(
            request_id="test-empty",
            outputs="",
            output_token_ids=[],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    mock_logger.info.assert_called_once()
    logged = mock_logger.info.call_args.args
    assert "Generated response %s%s" in logged[0]
    assert logged[1] == "test-empty"
    assert logged[3] == ""
    assert logged[4] == []
    assert logged[5] == "stop"
def test_request_logger_log_outputs_integration():
    """Check that ``log_inputs`` followed by ``log_outputs`` yields two records."""
    mock_logger = MagicMock()
    with patch("vllm.entrypoints.logger.logger", mock_logger):
        request_logger = RequestLogger(max_log_len=None)

        # Log the request first, then its response, on the same logger.
        request_logger.log_inputs(
            request_id="test-integration",
            prompt="Test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_embeds=None,
            params=None,
            lora_request=None,
        )
        request_logger.log_outputs(
            request_id="test-integration",
            outputs="Test output",
            output_token_ids=[4, 5, 6],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    # One ``info`` call per method: inputs first, outputs second.
    assert mock_logger.info.call_count == 2
    input_call = mock_logger.info.call_args_list[0].args
    output_call = mock_logger.info.call_args_list[1].args

    assert "Received request %s" in input_call[0]
    assert input_call[1] == "test-integration"
    assert "Generated response %s%s" in output_call[0]
    assert output_call[1] == "test-integration"
def test_streaming_complete_logs_full_text_content():
    """Check that the streaming-complete record carries the full response text.

    The logged output must be the accumulated text itself rather than a
    token-count style placeholder.
    """
    mock_logger = MagicMock()
    full_response = "This is a complete response from streaming"

    with patch("vllm.entrypoints.logger.logger", mock_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-streaming-full-text",
            outputs=full_response,
            output_token_ids=None,
            finish_reason="streaming_complete",
            is_streaming=True,
            delta=False,
        )

    mock_logger.info.assert_called_once()
    logged = mock_logger.info.call_args.args

    # The text argument is the full response, not a summary of it.
    logged_output = logged[3]
    assert logged_output == full_response
    assert "tokens>" not in logged_output
    assert "streaming_complete" not in logged_output

    # The remaining positional arguments identify the request and its state.
    assert logged[1] == "test-streaming-full-text"
    assert logged[2] == " (streaming complete)"
    assert logged[5] == "streaming_complete"
# Add vllm prefix to make sure logs go through the vllm logger # Add vllm prefix to make sure logs go through the vllm logger
test_logger = init_logger("vllm.test_logger") test_logger = init_logger("vllm.test_logger")
@@ -54,15 +54,14 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
( (
MiniMaxText01LinearAttention, MiniMaxText01LinearAttention,
dict( dict(
hidden_size=128, config=SimpleNamespace(
hidden_inner_size=256, hidden_size=256,
num_heads=8, num_attention_heads=8,
head_dim=32, head_dim=32,
max_position=2048, num_hidden_layers=12,
block_size=64, block=64,
num_hidden_layer=12, ),
layer_idx=0, prefix="layers.0.self_attn",
linear_layer_idx=0,
), ),
LinearAttentionBackend, LinearAttentionBackend,
MambaAttentionBackendEnum.LINEAR, MambaAttentionBackendEnum.LINEAR,
@@ -88,6 +87,8 @@ def test_mamba_layers_get_attn_backend(
expected_mamba_type, expected_mamba_type,
): ):
"""Test that Mamba-like layers return the correct attention backend.""" """Test that Mamba-like layers return the correct attention backend."""
if layer_class is MiniMaxText01LinearAttention:
init_kwargs["vllm_config"] = default_vllm_config
layer = layer_class(**init_kwargs) layer = layer_class(**init_kwargs)
backend_class = layer.get_attn_backend() backend_class = layer.get_attn_backend()
+37
View File
@@ -358,6 +358,43 @@ def test_free_kv_cache_block_queue_append_n():
) )
def test_free_kv_cache_block_queue_prepend_n():
    """Check that ``prepend_n`` splices blocks, in order, at the queue front."""
    blocks = [KVCacheBlock(block_id=i) for i in range(6)]
    # Start with a single block so there is an existing head to splice in
    # front of: fake_head -> b0 -> fake_tail.
    queue = FreeKVCacheBlockQueue(blocks[0:1])

    # Prepending nothing leaves the queue untouched.
    queue.prepend_n([])
    assert queue.num_free_blocks == 1
    assert queue.fake_free_list_head.next_free_block is blocks[0]

    # First prepend: fake_head -> b4 -> b5 -> b0 -> fake_tail.
    queue.prepend_n(blocks[4:6])
    assert queue.num_free_blocks == 3
    expected_chain = [
        queue.fake_free_list_head,
        blocks[4],
        blocks[5],
        blocks[0],
        queue.fake_free_list_tail,
    ]
    # Every adjacent pair must be linked in both directions.
    for left, right in zip(expected_chain, expected_chain[1:]):
        assert left.next_free_block is right
        assert right.prev_free_block is left

    # Second prepend lands ahead of everything already queued:
    # fake_head -> b1 -> b2 -> b4 -> b5 -> b0 -> fake_tail.
    queue.prepend_n(blocks[1:3])
    assert queue.num_free_blocks == 5
    assert queue.fake_free_list_head.next_free_block is blocks[1]
    assert blocks[1].next_free_block is blocks[2]
    assert blocks[2].next_free_block is blocks[4]

    # Draining from the left yields the blocks in front-to-back order.
    popped_ids = []
    for _ in range(5):
        popped_ids.append(queue.popleft().block_id)
    assert popped_ids == [1, 2, 4, 5, 0]
    assert queue.num_free_blocks == 0
def test_free_kv_cache_block_queue_popleft_n(): def test_free_kv_cache_block_queue_popleft_n():
blocks = [KVCacheBlock(block_id=i) for i in range(6)] blocks = [KVCacheBlock(block_id=i) for i in range(6)]
# Create an empty FreeKVCacheBlockQueue with these blocks # Create an empty FreeKVCacheBlockQueue with these blocks
+557
View File
@@ -39,6 +39,7 @@ from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec, KVCacheGroupSpec,
KVCacheSpecKind, KVCacheSpecKind,
MambaSpec, MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec, SlidingWindowSpec,
) )
@@ -2875,6 +2876,350 @@ def test_hybrid_cache_blocks_clamped_to_lcm():
) )
def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch):
    """Check which sliding-window hashes stay cached with a 64-token interval.

    A full-attention group (32-token blocks) is paired with a sliding-window
    group (8-token blocks); after allocating a 128-token prompt only SWA
    hashes 7, 11 and 15 are expected in the block pool.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
    block_size = 8

    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(["layer1"], full_spec),
            KVCacheGroupSpec(["layer2"], swa_spec),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # Per the original test notes: the SWA manager uses the configured
    # 64-token interval (a multiple of the 32-token lcm_block_size) as its
    # retention segment. For this 128-token prompt the retained SWA tails are
    # the 64-token interval boundary, the 96-token replay boundary, and the
    # 128-token interval boundary.
    token_ids = []
    for value in range(16):
        token_ids.extend([value] * block_size)
    req = make_request("0", token_ids, block_size, sha256)

    computed_blocks, _ = manager.get_computed_blocks(req)
    num_cached_tokens = len(computed_blocks.blocks[0]) * block_size
    blocks = manager.allocate_slots(
        req,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    pool = manager.block_pool
    expected_swa_cached = {7, 11, 15}
    for i in range(16):
        cached = pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[1])
        if i not in expected_swa_cached:
            assert cached is None, f"SWA hash {i} should not be cached"
        else:
            assert cached is not None, f"SWA hash {i} should be cached"
@pytest.mark.parametrize(
    "interval, expected_match",
    [
        # scheduler_block_size is 32 (= lcm(4*8, 8)); 33 is not a multiple of it.
        ("33", "multiple of scheduler_block_size"),
        # -32 % 32 == 0, so a negative multiple slips past a pure modulo check;
        # it has to be rejected explicitly instead of silently degrading to
        # dense.
        ("-32", "non-negative"),
    ],
)
def test_hybrid_local_kv_retention_interval_rejects_invalid(
    monkeypatch, interval, expected_match
):
    """Check that an invalid retention interval fails manager construction.

    Building the manager must raise ``ValueError`` (with a message matching
    ``expected_match``) when the interval is negative or is not a multiple of
    scheduler_block_size.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", interval)
    block_size = 8

    full_group = KVCacheGroupSpec(
        ["layer1"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    swa_group = KVCacheGroupSpec(
        ["layer2"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[full_group, swa_group],
    )

    with pytest.raises(ValueError, match=expected_match):
        make_kv_cache_manager(
            kv_cache_config=kv_cache_config,
            max_model_len=8192,
            enable_caching=True,
            hash_block_size=block_size,
        )
def test_hybrid_local_kv_retention_interval_survives_recycling(monkeypatch):
    """Check that retained local checkpoints still hit after block recycling.

    One MLA group and three sliding-window groups share the pool. After an
    unrelated request is filled and freed, replaying a prefix of the first
    prompt must report the same 1024 computed tokens and per-group block
    counts as before.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "1024")
    hash_block_size = 4

    def sliding_group(layer_name, block_multiple, window):
        # Sliding-window group whose block size is a multiple of the hash
        # block size.
        return KVCacheGroupSpec(
            [layer_name],
            SlidingWindowSpec(
                block_size=block_multiple * hash_block_size,
                num_kv_heads=1,
                head_size=1,
                dtype=torch.float32,
                sliding_window=window,
            ),
        )

    full_group = KVCacheGroupSpec(
        ["full"],
        MLAAttentionSpec(
            block_size=64 * hash_block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.uint8,
            compress_ratio=4,
        ),
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=800,
        kv_cache_tensors=[],
        kv_cache_groups=[
            full_group,
            sliding_group("swa", 16, 512),
            sliding_group("c128", 2, 128),
            sliding_group("c4", 1, 8),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=4096,
        enable_caching=True,
        hash_block_size=hash_block_size,
    )

    def fill_request(request_id: str, token_offset: int) -> list[int]:
        # Allocate a 4096-token prompt in chunks of at most 512 tokens, then
        # free it.
        token_ids = []
        for value in range(1024):
            token_ids.extend([token_offset + value] * hash_block_size)
        fill_req = make_request(request_id, token_ids, hash_block_size, sha256)
        total_tokens = len(token_ids)
        while fill_req.num_computed_tokens < total_tokens:
            num_new_tokens = min(512, total_tokens - fill_req.num_computed_tokens)
            allocated = manager.allocate_slots(fill_req, num_new_tokens)
            assert allocated is not None
            fill_req.num_computed_tokens += num_new_tokens
        manager.free(fill_req)
        return token_ids

    def assert_replay_hit(request_id):
        # Replaying the first 1800 tokens must hit the 1024-token checkpoint.
        replay_req = make_request(
            request_id, token_ids[:1800], hash_block_size, sha256
        )
        computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req)
        assert num_computed_tokens == 1024
        assert list(map(len, computed_blocks.blocks)) == [4, 16, 128, 256]

    token_ids = fill_request("fill_0", 0)
    assert_replay_hit("replay")

    # Fill and free a second, unrelated prompt, then replay the first again.
    fill_request("fill_1", 100_000)
    assert_replay_hit("replay_again")
def test_hybrid_local_kv_retention_latest_only_reuses_replay_boundary(monkeypatch):
    """Check latest-only retention (interval ``0``) on a hybrid model.

    Only SWA hash 11 is expected to be cached. It must survive ``free`` with a
    zero ref count, serve a 96-token hit for the same prompt, and give no hit
    for a prompt that is exactly 96 tokens long.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
    block_size = 8

    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(["layer1"], full_spec),
            KVCacheGroupSpec(["layer2"], swa_spec),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    token_ids = []
    for value in range(16):
        token_ids.extend([value] * block_size)
    req0 = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req0)
    num_cached_tokens = len(computed_blocks.blocks[0]) * block_size
    blocks = manager.allocate_slots(
        req0,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    pool = manager.block_pool
    expected_swa_cached = {11}
    for i in range(16):
        cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
        if i not in expected_swa_cached:
            assert cached is None, f"SWA hash {i} should not be cached"
        else:
            assert cached is not None, f"SWA hash {i} should be cached"

    # After freeing, the retained block is still cached but unreferenced.
    manager.free(req0)
    retained_swa_block = pool.get_cached_block(req0.block_hashes[11], [1])
    assert retained_swa_block is not None
    assert retained_swa_block[0].ref_cnt == 0

    req1 = make_request("1", token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    # Full prompt hits intentionally recompute the final block for logits, so
    # the longest usable hit is the previous LCM boundary: 96 tokens.
    assert num_computed_tokens == 12 * block_size
    assert len(computed_blocks.blocks[1]) == 12

    # A prompt that is exactly 96 tokens long gets no hit at all.
    shorter_req = make_request("2", token_ids[: 12 * block_size], block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(shorter_req)
    assert num_computed_tokens == 0
    assert len(computed_blocks.blocks[1]) == 0
def test_hybrid_local_kv_retention_mtp_reuses_latest_boundary(monkeypatch):
    """Check that MTP/EAGLE SWA retention keeps the extra proof block.

    Per the original test notes, the EAGLE/MTP lookup matches one additional
    local block after the returned prefix and then drops it, so sparse
    retention has to cache the normal local tail at the latest replay
    boundary plus one extra SWA block.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
    block_size = 8

    full_group = KVCacheGroupSpec(
        ["full"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    swa_mtp_group = KVCacheGroupSpec(
        ["swa_mtp"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
        is_eagle_group=True,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[full_group, swa_mtp_group],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 127 tokens: latest replay boundary is floor((127 - 1) / 32) * 32 = 96.
    # The EAGLE/MTP SWA lookup group must cache the local tail ending at
    # 104 tokens, and that tail is two 8-token blocks wide: hashes 11 and 12.
    token_ids = []
    for value in range(15):
        token_ids.extend([value] * block_size)
    token_ids += [15] * 7

    req0 = make_request("0", token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0,
        len(token_ids),
        num_computed_tokens,
        computed_blocks,
    )
    assert blocks is not None

    pool = manager.block_pool
    expected_swa_cached = {11, 12}
    for i in range(15):
        cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
        if i not in expected_swa_cached:
            assert cached is None, f"SWA hash {i} should not be cached"
        else:
            assert cached is not None, f"SWA hash {i} should be cached"

    # Replaying the identical prompt after a free reuses the 96-token prefix.
    manager.free(req0)
    req1 = make_request("1", token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens == 12 * block_size
    assert list(map(len, computed_blocks.blocks)) == [3, 12]
def test_block_lookup_cache_single_block_per_key(): def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap() cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0") key0 = BlockHashWithGroupId(b"hash0")
@@ -3058,3 +3403,215 @@ def test_can_fit_full_sequence_full_attention_still_gates_oversized():
req = make_request("oversized", list(range(prompt_len)), block_size, sha256) req = make_request("oversized", list(range(prompt_len)), block_size, sha256)
assert manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is None assert manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is None
def test_swa_free_split_keeps_cached_tail_ahead_of_scratch(monkeypatch):
    """Check the free-queue ordering of SWA blocks with retention unset.

    After freeing a request, every SWA block without a hash (scratch) must sit
    earlier in the free queue than every SWA block with a hash (cached), and
    the cached blocks must still be found by hash in the block pool.
    """
    monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
    block_size = 8

    full_spec = FullAttentionSpec(
        block_size=4 * block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float16,
    )
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=block_size,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(["layer1"], full_spec),
            KVCacheGroupSpec(["layer2"], swa_spec),
        ],
    )
    manager = make_kv_cache_manager(
        kv_cache_config=kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    token_ids = []
    for value in range(16):
        token_ids.extend([value] * block_size)
    req = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req)
    num_cached_tokens = len(computed_blocks.blocks[0]) * block_size
    blocks = manager.allocate_slots(
        req,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    # Classify the request's SWA blocks (group index 1) before freeing them.
    swa_manager = manager.coordinator.single_type_managers[1]
    null_block = manager.block_pool.null_block
    cached_ids: set[int] = set()
    uncached_ids: set[int] = set()
    cached_hash_indices: list[int] = []
    for idx, blk in enumerate(swa_manager.req_to_blocks[req.request_id]):
        if blk is null_block:
            continue
        if blk.block_hash is not None:
            cached_ids.add(blk.block_id)
            cached_hash_indices.append(idx)
        else:
            uncached_ids.add(blk.block_id)

    # The dense default mask caches only the per-segment tails, so a 16-block
    # SWA prompt must produce a mix of retained and scratch blocks.
    assert cached_ids, "expected some retained (cached) SWA tail blocks"
    assert uncached_ids, "expected some scratch (uncached) SWA blocks"

    manager.free(req)

    free_blocks = manager.block_pool.free_block_queue.get_all_free_blocks()
    order = [free_block.block_id for free_block in free_blocks]
    pos = dict(zip(order, range(len(order))))
    # Every scratch block is recycled before every retained block.
    last_scratch_pos = max(pos[bid] for bid in uncached_ids)
    first_retained_pos = min(pos[bid] for bid in cached_ids)
    assert last_scratch_pos < first_retained_pos

    # The retained tails survive the free and still serve a prefix-cache hit.
    for idx in cached_hash_indices:
        hit = manager.block_pool.get_cached_block(
            req.block_hashes[idx], kv_cache_group_ids=[1]
        )
        assert hit is not None
def _make_pure_swa_manager(block_size, sliding_window, num_blocks=100, **kwargs):
    """Build a KV cache manager whose only group is a sliding-window layer.

    Extra keyword arguments are forwarded to ``make_kv_cache_manager``. The
    hash block size equals ``block_size``.
    """
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=sliding_window,
    )
    single_group_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[KVCacheGroupSpec(["layer"], swa_spec)],
    )
    return make_kv_cache_manager(
        kv_cache_config=single_group_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        **kwargs,
    )
def test_pure_swa_retention_interval_caches_sparse_tails(monkeypatch):
    """Check sparse retention on a single-group sliding-window model.

    With a 64-token interval and 16-token blocks, only the per-interval tails
    plus the latest replay tail are expected to be cached, and replaying the
    same prompt must hit the 240-token boundary.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
    block_size = 16
    manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
    assert type(manager.coordinator).__name__ == "UnitaryKVCacheCoordinator"

    token_ids = []
    for value in range(16):
        token_ids.extend([value] * block_size)
    req = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req)
    num_cached_tokens = len(computed_blocks.blocks[0]) * block_size
    blocks = manager.allocate_slots(
        req,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    # Collect the indices of the block hashes found in the pool.
    pool = manager.block_pool
    cached = set()
    for idx in range(16):
        hit = pool.get_cached_block(req.block_hashes[idx], kv_cache_group_ids=[0])
        if hit is not None:
            cached.add(idx)

    # per_segment = 64 / 16 = 4, need = cdiv(16-1, 16) = 1 -> segment tails at
    # i%4==3 -> {3,7,11,15}; latest replay boundary (255//16*16 = 240) -> tail
    # block 14. This is a strict subset of all 16 blocks, so retention really
    # is sparse for pure SWA rather than silently dense.
    assert cached == {3, 7, 11, 14, 15}

    # A replay of the same prompt hits the latest replayable boundary (240).
    replay = make_request("1", token_ids, block_size, sha256)
    _, num_computed = manager.get_computed_blocks(replay)
    assert num_computed == 240
def test_pure_swa_retention_latest_only(monkeypatch):
    """Check that interval ``0`` on a pure-SWA model keeps only the latest tail."""
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
    block_size = 16
    manager = _make_pure_swa_manager(block_size, sliding_window=block_size)

    token_ids = []
    for value in range(16):
        token_ids.extend([value] * block_size)
    req = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req)
    num_cached_tokens = len(computed_blocks.blocks[0]) * block_size
    blocks = manager.allocate_slots(
        req,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    pool = manager.block_pool
    cached = set()
    for idx in range(16):
        hit = pool.get_cached_block(req.block_hashes[idx], kv_cache_group_ids=[0])
        if hit is not None:
            cached.add(idx)

    # No segment tails (interval 0); only the latest replay tail (block 14).
    assert cached == {14}

    replay = make_request("1", token_ids, block_size, sha256)
    _, num_computed = manager.get_computed_blocks(replay)
    assert num_computed == 240
def test_pure_swa_retention_dense_default_caches_all(monkeypatch):
    """Check that a pure-SWA model caches every block when retention is unset."""
    monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
    block_size = 16
    manager = _make_pure_swa_manager(block_size, sliding_window=block_size)

    token_ids = []
    for value in range(16):
        token_ids.extend([value] * block_size)
    req = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req)
    num_cached_tokens = len(computed_blocks.blocks[0]) * block_size
    blocks = manager.allocate_slots(
        req,
        len(token_ids),
        num_cached_tokens,
        computed_blocks,
    )
    assert blocks is not None

    # Dense behavior: every one of the 16 block hashes is found in the pool.
    pool = manager.block_pool
    cached = set()
    for idx in range(16):
        hit = pool.get_cached_block(req.block_hashes[idx], kv_cache_group_ids=[0])
        if hit is not None:
            cached.add(idx)
    assert cached == set(range(16))
@@ -11,7 +11,9 @@ import pytest
import torch import torch
from utils import skip_unsupported from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform from vllm.platforms import current_platform
@@ -51,7 +53,7 @@ def test_rms_norm_batch_invariant_vs_standard(
standard_output = rms_norm_layer.forward_cuda(input_tensor) standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation (Triton) # Batch-invariant implementation (Triton)
triton_output = triton_rms_norm(input_tensor, weight, eps=eps) triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Compare outputs # Compare outputs
# Use looser tolerance for bfloat16 due to its lower precision # Use looser tolerance for bfloat16 due to its lower precision
@@ -125,7 +127,7 @@ def test_fused_add_rms_norm_batch_invariant_residual_path(
) )
merged_single = x_single + residual_single merged_single = x_single + residual_single
ref_out = triton_rms_norm(merged_single, weight, eps=eps) ref_out = rms_norm_batch_invariant(merged_single, weight, eps=eps)
torch.testing.assert_close( torch.testing.assert_close(
residual_out_single, residual_out_single,
@@ -193,7 +195,7 @@ def test_rms_norm_3d_input(
standard_output = rms_norm_layer.forward_cuda(input_tensor) standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation # Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps) triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Use looser tolerance for bfloat16 # Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16 rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
@@ -242,7 +244,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
standard_output = rms_norm_layer.forward_cuda(input_tensor) standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation # Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps) triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Check for NaN or Inf # Check for NaN or Inf
assert not torch.isnan(standard_output).any(), ( assert not torch.isnan(standard_output).any(), (
@@ -289,7 +291,7 @@ def test_rms_norm_formula(default_vllm_config):
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
# Batch-invariant implementation # Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps) triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Compare against formula # Compare against formula
torch.testing.assert_close( torch.testing.assert_close(
@@ -325,7 +327,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
standard_output = rms_norm_layer.forward_cuda(input_tensor) standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation # Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps) triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Use looser tolerance for bfloat16 # Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16 rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
@@ -360,7 +362,7 @@ def test_rms_norm_determinism(default_vllm_config):
# Run multiple times # Run multiple times
outputs = [] outputs = []
for _ in range(5): for _ in range(5):
output = triton_rms_norm(input_tensor.clone(), weight, eps=eps) output = rms_norm_batch_invariant(input_tensor.clone(), weight, eps=eps)
outputs.append(output) outputs.append(output)
# All outputs should be identical # All outputs should be identical
@@ -395,7 +397,7 @@ if __name__ == "__main__":
standard_output = rms_norm_layer.forward_cuda(input_tensor) standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation # Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps) triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
# Compare # Compare
max_diff = (triton_output - standard_output).abs().max().item() max_diff = (triton_output - standard_output).abs().max().item()
@@ -98,7 +98,6 @@ def _make_connector_with_fake_worker(
) )
worker = connector.connector_worker worker = connector.connector_worker
assert isinstance(worker.nixl_wrapper, FakeNixlWrapper) assert isinstance(worker.nixl_wrapper, FakeNixlWrapper)
worker.nixl_wrapper.set_cycles_before_xfer_done(cycles_before_done)
worker.kv_cache_layout = "HND" worker.kv_cache_layout = "HND"
if do_handshake: if do_handshake:
remote_agents = worker._nixl_handshake( remote_agents = worker._nixl_handshake(
+41 -8
View File
@@ -6,25 +6,56 @@
import pytest import pytest
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
class _StubLMCacheMPConnector:
    """Placeholder for ``LMCacheMPConnector`` in config-translation tests.

    The real connector module imports the optional ``lmcache`` package at
    module load time, and that package is not installed in the cpu_test
    image. The tests here only assert on the connector *name* and on the
    ``extra_config`` dict produced by ``VllmConfig``; they never instantiate
    the connector, so an empty class is enough.

    The stub deliberately does not subclass ``SupportsHMA``, mirroring the
    real connector, which does not support HMA either.
    """
@pytest.fixture
def stub_lmcache_mp_connector(monkeypatch):
    """Point the ``LMCacheMPConnector`` registry entry at a stub loader.

    With the lazy loader swapped out, ``VllmConfig.__post_init__`` does not
    import ``lmcache_mp_connector`` (and therefore ``lmcache``) while the
    config tests run. ``monkeypatch`` restores the original registry entry
    when the test finishes.
    """

    def _load_stub():
        # The registry stores zero-argument loaders that return the class.
        return _StubLMCacheMPConnector

    monkeypatch.setitem(
        KVConnectorFactory._registry,
        "LMCacheMPConnector",
        _load_stub,
    )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes", "kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
[ [
("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)), ("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
# bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB # bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30)), ("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30)),
("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0), # ``lmcache`` backend now defaults to LMCacheMPConnector. The KV
# size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB # storage capacity is owned by the standalone LMCache server, so
("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0), # ``kv_offloading_size`` is intentionally not propagated.
("lmcache", 4.0, 1, 1, "LMCacheMPConnector", None),
("lmcache", 8.0, 2, 2, "LMCacheMPConnector", None),
# When kv_offloading_size is None, offloading is disabled (backend is ignored) # When kv_offloading_size is None, offloading is disabled (backend is ignored)
("native", None, 1, 1, None, None), ("native", None, 1, 1, None, None),
], ],
) )
def test_kv_connector( def test_kv_connector(
kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes stub_lmcache_mp_connector,
kv_offloading_backend,
kv_offloading_size,
tp,
pp,
expected_backend,
expected_bytes,
): ):
kv_transfer_config = ( kv_transfer_config = (
KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"}) KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
@@ -59,10 +90,12 @@ def test_kv_connector(
# Existing config should be preserved # Existing config should be preserved
assert kv_connector_extra_config["existing_key"] == "existing_value" assert kv_connector_extra_config["existing_key"] == "existing_value"
elif kv_offloading_backend == "lmcache": elif kv_offloading_backend == "lmcache":
assert kv_connector_extra_config["lmcache.local_cpu"] is True # MP mode does not push lmcache.local_cpu / max_local_cpu_size into
assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes # extra config (the LMCache server owns capacity). Pre-existing
# Existing config should be replaced # extra config entries are preserved as-is.
assert "existing_key" not in kv_connector_extra_config assert "lmcache.local_cpu" not in kv_connector_extra_config
assert "lmcache.max_local_cpu_size" not in kv_connector_extra_config
assert kv_connector_extra_config["existing_key"] == "existing_value"
def _build_config( def _build_config(
@@ -197,13 +197,6 @@ class FakeNixlWrapper:
def get_xfer_telemetry(self, handle: int) -> dict: def get_xfer_telemetry(self, handle: int) -> dict:
return get_default_xfer_telemetry() return get_default_xfer_telemetry()
############################################################
# Follow are for changing the behavior during testing.
############################################################
def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
@contextlib.contextmanager @contextlib.contextmanager
def _make_fake_nixl_pkg(): def _make_fake_nixl_pkg():
@@ -578,10 +571,7 @@ class TestNixlHandshake:
"""Test case where multiple xfers are initiated to the same engine. """Test case where multiple xfers are initiated to the same engine.
This test triggers the connector to load remote KV for the same This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to `request_id`.
`set_cycles_before_xfer_done`, so there is a state where there are
multiple transfer states for the same `request_id`, and `get_finished`
should handle it correctly (wait for all transfers to be done).
""" """
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
@@ -598,7 +588,6 @@ class TestNixlHandshake:
) )
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
worker = connector.connector_worker worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake # simulate handshake
worker.dst_xfer_side_handles = { worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1} FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
@@ -1304,7 +1293,6 @@ def test_scheduler_kv_connector_stats_aggregation():
# Worker stats with transfer metrics # Worker stats with transfer metrics
worker_stats = NixlKVConnectorStats() worker_stats = NixlKVConnectorStats()
worker_stats.record_transfer(get_default_xfer_telemetry()) worker_stats.record_transfer(get_default_xfer_telemetry())
worker_stats.data["remote_tokens"] = []
# Scheduler stats with custom metric (needs dummy transfer to avoid being skipped) # Scheduler stats with custom metric (needs dummy transfer to avoid being skipped)
scheduler_stats = NixlKVConnectorStats() scheduler_stats = NixlKVConnectorStats()
@@ -1314,7 +1302,6 @@ def test_scheduler_kv_connector_stats_aggregation():
"post_duration": [0], "post_duration": [0],
"bytes_transferred": [0], "bytes_transferred": [0],
"num_descriptors": [0], "num_descriptors": [0],
"remote_tokens": [128],
} }
) )
@@ -1355,7 +1342,6 @@ def test_scheduler_kv_connector_stats_aggregation():
).scheduler_stats.kv_connector_stats ).scheduler_stats.kv_connector_stats
nixl_stats = final_stats["NixlConnector"] nixl_stats = final_stats["NixlConnector"]
assert nixl_stats.num_successful_transfers == 2 assert nixl_stats.num_successful_transfers == 2
assert nixl_stats.data["remote_tokens"] == [128]
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
@@ -275,6 +275,83 @@ def test_apply_prefix_caching_mamba_hybrid(
) )
@pytest.mark.cpu_test
@pytest.mark.parametrize(
    (
        "local_physical_per_logical",
        "remote_physical_per_logical",
        "local_block_ids",
        "remote_block_ids",
        "expected_local",
        "expected_remote",
    ),
    [
        # SSM prefix caching only: the remote SSM group carries three
        # placeholder blocks followed by one real block, while the local side
        # holds just the real block. The FA groups already match, so they are
        # left untouched.
        pytest.param(
            10,
            10,
            [list(range(10)), [42]],
            [list(range(10)), [40, 41, 42, 43]],
            [list(range(10)), [42]],
            [list(range(10)), [43]],
            id="ssm_prefix_trim_only",
        ),
        # FA partial prefix-cache hit with homogeneous TP: the local side has
        # 4 FA blocks (the rest were prefix cached) and the remote side has
        # all 10. The SSM groups already match, so they are left untouched.
        pytest.param(
            10,
            10,
            [[6, 7, 8, 9], [42]],
            [list(range(10)), [42]],
            [[6, 7, 8, 9], [42]],
            [[6, 7, 8, 9], [42]],
            id="fa_prefix_hit_homo_tp",
        ),
        # Both effects at once: FA partial prefix hit plus SSM placeholder
        # trim. Local FA = [6..9] vs. remote FA = [0..9]; local SSM = [99]
        # vs. remote SSM = [10, 20, 99] (two placeholders, then the real one).
        pytest.param(
            10,
            10,
            [[6, 7, 8, 9], [99]],
            [list(range(10)), [10, 20, 99]],
            [[6, 7, 8, 9], [99]],
            [[6, 7, 8, 9], [99]],
            id="fa_prefix_hit_and_ssm_trim",
        ),
    ],
)
def test_apply_prefix_caching_ssm_prefix_cache_hit(
    local_physical_per_logical,
    remote_physical_per_logical,
    local_block_ids,
    remote_block_ids,
    expected_local,
    expected_remote,
):
    """Check how ``_apply_prefix_caching`` aligns remote blocks to local ones.

    For the SSM group the remote list is end-trimmed down to the single local
    block, dropping the leading placeholders. For the FA group the remote list
    is end-trimmed on a partial prefix-cache hit when the local and remote
    ``physical_per_logical`` values are equal.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
        NixlConnectorWorker,
    )
    from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

    # Allocate the worker without running __init__, then set only the
    # attributes this test configures by hand.
    worker = object.__new__(NixlConnectorWorker)
    worker._has_mamba = True
    worker._physical_blocks_per_logical_kv_block = local_physical_per_logical
    worker._group_spec_types = (FullAttentionSpec, MambaSpec)
    worker.kv_cache_config = make_kv_cache_config(block_size=16, mamba_enabled=True)

    aligned_local, aligned_remote = worker._apply_prefix_caching(
        local_block_ids, remote_block_ids, remote_physical_per_logical
    )

    assert aligned_local == expected_local, (
        f"Expected local {expected_local}, got {aligned_local}"
    )
    assert aligned_remote == expected_remote, (
        f"Expected remote {expected_remote}, got {aligned_remote}"
    )
@pytest.mark.cpu_test @pytest.mark.cpu_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"local_physical_per_logical,remote_physical_per_logical," "local_physical_per_logical,remote_physical_per_logical,"
+4 -4
View File
@@ -1072,8 +1072,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
@pytest.mark.skipif( @pytest.mark.skipif(
current_platform.is_rocm(), not current_platform.is_cuda(),
reason="Attention backend FLASHINFER is not supported on ROCm.", reason="Attention backend FLASHINFER is only supported on CUDA.",
) )
def test_hybrid_attention_mamba_tensor_shapes(): def test_hybrid_attention_mamba_tensor_shapes():
""" """
@@ -1508,8 +1508,8 @@ def test_is_uniform_decode() -> None:
@pytest.mark.skipif( @pytest.mark.skipif(
current_platform.is_rocm(), not current_platform.is_cuda(),
reason="Attention backend FLASHINFER is not supported on ROCm.", reason="Attention backend FLASHINFER is only supported on CUDA.",
) )
def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks(): def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
"""Test that a ValueError is raised when max_num_seqs exceeds the """Test that a ValueError is raised when max_num_seqs exceeds the
@@ -1562,7 +1562,9 @@ def generate_legend() -> str:
def generate_mla_section( def generate_mla_section(
prefill_backends: list[dict[str, Any]], decode_backends: list[dict[str, Any]] prefill_backends: list[dict[str, Any]],
decode_backends: list[dict[str, Any]],
v4_decode_backends: list[dict[str, Any]] | None = None,
) -> str: ) -> str:
"""Generate the complete MLA section with prefill and decode tables.""" """Generate the complete MLA section with prefill and decode tables."""
lines = [ lines = [
@@ -1611,6 +1613,22 @@ def generate_mla_section(
columns = _build_columns(is_mla=True, has_versions=False) columns = _build_columns(is_mla=True, has_versions=False)
lines.extend(_render_table(columns, decode_backends)) lines.extend(_render_table(columns, decode_backends))
if v4_decode_backends:
lines.extend(
[
"",
"### DeepSeek V4 Decode Backends",
"",
"DeepSeek V4 sparse MLA uses its own decode backends, selected via",
"`--attention-backend=<BACKEND>` (e.g., `FLASHMLA_SPARSE_DSV4`,",
"`FLASHINFER_MLA_SPARSE_DSV4`). They share the V4 sparse-index",
"pipeline (compressor + SWA + indexer, 256-token blocks, head 512);",
"default on NVIDIA is `FLASHMLA_SPARSE_DSV4`.",
"",
]
)
lines.extend(_render_table(columns, v4_decode_backends))
lines.append("") lines.append("")
return "\n".join(lines) return "\n".join(lines)
@@ -1651,9 +1669,15 @@ def generate_docs() -> str:
if fi_features: if fi_features:
all_backends = _expand_flashinfer_variants(all_backends, fi_features) all_backends = _expand_flashinfer_variants(all_backends, fi_features)
# Split into MLA and non-MLA # DeepSeek V4 (*_DSV4) decode backends get their own subsection rather than
mla_backends = [b for b in all_backends if b["is_mla"]] # mixing into the main MLA / standard tables (the ROCm V4 backend isn't
non_mla_backends = [b for b in all_backends if not b["is_mla"]] # flagged is_mla by the AST heuristic, so filter purely on the name).
def _is_v4(b: dict[str, Any]) -> bool:
return b["name"].endswith("_DSV4")
v4_decode_backends = [b for b in all_backends if _is_v4(b)]
mla_backends = [b for b in all_backends if b["is_mla"] and not _is_v4(b)]
non_mla_backends = [b for b in all_backends if not b["is_mla"] and not _is_v4(b)]
# Generate documentation # Generate documentation
script_path = "tools/pre_commit/generate_attention_backend_docs.py" script_path = "tools/pre_commit/generate_attention_backend_docs.py"
@@ -1703,7 +1727,9 @@ def generate_docs() -> str:
doc_lines.append("\n>\n".join(footnotes) + "\n") doc_lines.append("\n>\n".join(footnotes) + "\n")
# Add MLA section with prefill and decode backends # Add MLA section with prefill and decode backends
doc_lines.append(generate_mla_section(mla_prefill_backends, mla_backends)) doc_lines.append(
generate_mla_section(mla_prefill_backends, mla_backends, v4_decode_backends)
)
return "\n".join(doc_lines) return "\n".join(doc_lines)
+1 -1
View File
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse import argparse
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG from vllm.entrypoints.serve.utils.api_utils import VLLM_SUBCMD_PARSER_EPILOG
from .plot import SweepPlotArgs from .plot import SweepPlotArgs
from .plot import main as plot_main from .plot import main as plot_main
@@ -44,6 +44,24 @@ from .matcher_utils import MatcherQuantFP8
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
_IR_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default
_IR_FUSED_ADD_RMS_NORM_OP = torch.ops.vllm_ir.fused_add_rms_norm.default
def _norm_input_weight_dtype_match(match: pm.Match) -> bool:
    """Return whether the matched norm's input and weight share a dtype.

    Used as an ``extra_check`` so that fusion is rejected when the dtypes
    differ (e.g. a Gemma fp32 ``weight.float() + 1`` gamma). Handles both
    ``rms_norm`` and ``fused_add_rms_norm`` nodes.

    Args:
        match: The pattern-matcher match whose nodes are inspected.

    Returns:
        The dtype comparison for the first norm node whose input and weight
        are both ``fx.Node`` objects; ``True`` when no such node is found.
    """
    for candidate in match.nodes:
        op = candidate.target
        # The weight sits at a different positional index in each op.
        if op == _IR_RMS_NORM_OP:
            weight_pos = 1
        elif op == _IR_FUSED_ADD_RMS_NORM_OP:
            weight_pos = 2
        else:
            continue

        activation = candidate.args[0]
        gamma = candidate.args[weight_pos]
        if not (isinstance(activation, fx.Node) and isinstance(gamma, fx.Node)):
            # Without node metadata there is no dtype to compare; keep looking.
            continue
        return activation.meta["val"].dtype == gamma.meta["val"].dtype

    return True
# The empirical value for small batch # The empirical value for small batch
PDL_ADVANCE_LAUNCH_TOKENS = 16 PDL_ADVANCE_LAUNCH_TOKENS = 16
@@ -132,6 +150,7 @@ if flashinfer_comm is not None:
quant_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None,
scale_factor: torch.Tensor | None = None, scale_factor: torch.Tensor | None = None,
weight_bias: float = 0.0,
) -> None: ) -> None:
num_tokens, hidden_size = allreduce_in.shape num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size() element_size = allreduce_in.element_size()
@@ -208,6 +227,7 @@ if flashinfer_comm is not None:
layout_code=layout_code, layout_code=layout_code,
use_oneshot=use_oneshot, use_oneshot=use_oneshot,
fp32_acc=fp32_acc, fp32_acc=fp32_acc,
weight_bias=weight_bias,
trigger_completion_at_end=num_tokens > PDL_ADVANCE_LAUNCH_TOKENS, trigger_completion_at_end=num_tokens > PDL_ADVANCE_LAUNCH_TOKENS,
) )
@@ -225,6 +245,7 @@ if flashinfer_comm is not None:
quant_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None, scale_out: torch.Tensor | None = None,
scale_factor: torch.Tensor | None = None, scale_factor: torch.Tensor | None = None,
weight_bias: float = 0.0,
) -> None: ) -> None:
pass pass
@@ -399,14 +420,142 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
# allreduce_in, residual # allreduce_in, residual
return allreduce[1], allreduce[2] return allreduce[1], allreduce[2]
# extra_check routes a Gemma fp32 gamma to AllReduceFusedAddGemmaRMSNormPattern.
pm.register_replacement( pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass pattern,
replacement,
self.get_inputs(),
pm.fwd_only,
pm_pass,
extra_check=_norm_input_weight_dtype_match,
) )
# Same pattern, but only return the output and not residual # Same pattern, but only return the output and not residual
# (helpful for end of graph where residual is not used again) # (helpful for end of graph where residual is not used again)
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0] first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
pm.register_replacement(
first_return_only(pattern), # type: ignore[no-untyped-call]
first_return_only(replacement), # type: ignore[no-untyped-call]
self.get_inputs(),
pm.fwd_only,
pm_pass,
extra_check=_norm_input_weight_dtype_match,
)
class AllReduceGemmaRMSNormPattern(BasePattern):
    """Gemma-style variant of AllReduceRMSNormPattern (no residual).

    Matches an all-reduce followed by ``rms_norm`` whose gamma is
    ``weight.float() + 1.0`` and replaces it with a single
    ``flashinfer_trtllm_fused_allreduce_norm`` call. The replacement passes
    the raw ``weight`` together with ``weight_bias=1.0`` instead of the
    pre-biased gamma.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.allreduce_params = allreduce_params
        self.epsilon = epsilon

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[input, weight]`` tensors used to trace the pattern."""
        sample_input = self.empty(5, 16)
        sample_weight = self.empty(16)
        return [sample_input, sample_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search pattern and its fused replacement on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            # Gemma applies the norm with gamma = weight (as fp32) + 1.
            gemma_gamma = weight.float() + 1.0
            normed = vllm.ir.ops.rms_norm(reduced, gemma_gamma, self.epsilon)
            return normed, reduced

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # There is no residual in this pattern, so feed an all-zero one
            # to the residual-based fused kernel.
            zero_residual = torch.zeros_like(input)
            norm_buffer = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fusion_code = flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=norm_buffer,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=fusion_code,
                weight_bias=1.0,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Presumably index 3 is norm_out and index 1 is allreduce_in
            # (mutated args in declaration order) -- TODO confirm.
            return fused[3], fused[1]

        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
class AllReduceFusedAddGemmaRMSNormPattern(BasePattern):
"""Gemma-style variant of AllReduceFusedAddRMSNormPattern (with residual)."""
def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
) -> None:
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
def get_inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16)
residual = self.empty(5, 16)
weight = self.empty(16)
return [residual, input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, residual = vllm.ir.ops.fused_add_rms_norm(
allreduce_output, residual, weight.float() + 1.0, self.epsilon
)
return rms, residual
def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
weight_bias=1.0,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
return allreduce[1], allreduce[2]
pm.register_replacement(
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
pm.register_replacement( pm.register_replacement(
first_return_only(pattern), # type: ignore[no-untyped-call] first_return_only(pattern), # type: ignore[no-untyped-call]
first_return_only(replacement), # type: ignore[no-untyped-call] first_return_only(replacement), # type: ignore[no-untyped-call]
@@ -881,6 +1030,18 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
self.device, self.device,
self.allreduce_params, self.allreduce_params,
).register(self.patterns) ).register(self.patterns)
AllReduceGemmaRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddGemmaRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache # WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon. # and allow multiple values of epsilon.
@@ -36,10 +36,12 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
} }
if hasattr(torch.ops._C, "per_token_group_fp8_quant"):
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501 QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
@@ -84,9 +84,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
} }
if hasattr(torch.ops._C, "per_token_group_fp8_quant"):
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
+20 -18
View File
@@ -109,22 +109,24 @@ class EPLBConfig:
class ParallelConfig: class ParallelConfig:
"""Configuration for the distributed execution.""" """Configuration for the distributed execution."""
pipeline_parallel_size: int = 1 pipeline_parallel_size: int = Field(default=1, ge=1)
"""Number of pipeline parallel groups.""" """Number of pipeline parallel groups."""
tensor_parallel_size: int = 1 tensor_parallel_size: int = Field(default=1, ge=1)
"""Number of tensor parallel groups.""" """Number of tensor parallel groups."""
prefill_context_parallel_size: int = 1 prefill_context_parallel_size: int = Field(default=1, ge=1)
"""Number of prefill context parallel groups.""" """Number of prefill context parallel groups."""
data_parallel_size: int = 1 data_parallel_size: int = Field(default=1, ge=1)
"""Number of data parallel groups. MoE layers will be sharded according to """Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size.""" the product of the tensor parallel size and data parallel size."""
data_parallel_size_local: int = 1 data_parallel_size_local: int = Field(default=1, ge=0)
"""Number of local data parallel groups.""" """Number of local data parallel groups. A value of 0 is a sentinel used by
data_parallel_rank: int = 0 the engine-args layer to signal that data parallelism was specified
"""Rank of the data parallel group.""" externally (see `ParallelConfig.__post_init__`)."""
data_parallel_rank: int = Field(default=0, ge=0)
"""Rank of the data parallel group. The runtime check at
``__post_init__`` further bounds this by ``data_parallel_size``."""
data_parallel_rank_local: int | None = None data_parallel_rank_local: int | None = None
"""Local rank of the data parallel group, """Local rank of the data parallel group, set only in SPMD mode."""
set only in SPMD mode."""
data_parallel_master_ip: str = "127.0.0.1" data_parallel_master_ip: str = "127.0.0.1"
"""IP of the data parallel master.""" """IP of the data parallel master."""
data_parallel_rpc_port: int = 29550 data_parallel_rpc_port: int = 29550
@@ -184,7 +186,7 @@ class ParallelConfig:
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl - "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels""" - "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
max_parallel_loading_workers: int | None = None max_parallel_loading_workers: int | None = Field(default=None, ge=1)
"""Maximum number of parallel loading workers when loading model """Maximum number of parallel loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor sequentially in multiple batches. To avoid RAM OOM when using tensor
parallel and large models.""" parallel and large models."""
@@ -197,15 +199,15 @@ class ParallelConfig:
enable_dbo: bool = False enable_dbo: bool = False
"""Enable dual batch overlap for the model executor.""" """Enable dual batch overlap for the model executor."""
ubatch_size: int = 0 ubatch_size: int = Field(default=0, ge=0)
"""Number of ubatch size.""" """Number of ubatch size."""
dbo_decode_token_threshold: int = 32 dbo_decode_token_threshold: int = Field(default=32, ge=0)
"""The threshold for dual batch overlap for batches only containing decodes. """The threshold for dual batch overlap for batches only containing decodes.
If the number of tokens in the request is greater than this threshold, If the number of tokens in the request is greater than this threshold,
microbatching will be used. Otherwise, the request will be processed in a microbatching will be used. Otherwise, the request will be processed in a
single batch.""" single batch."""
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune dbo_prefill_token_threshold: int = Field(default=512, ge=0) # TODO(lucas): tune
"""The threshold for dual batch overlap for batches that contain one or more """The threshold for dual batch overlap for batches that contain one or more
prefills. If the number of tokens in the request is greater than this prefills. If the number of tokens in the request is greater than this
threshold, microbatching will be used. Otherwise, the request will be threshold, microbatching will be used. Otherwise, the request will be
@@ -260,10 +262,10 @@ class ParallelConfig:
master_port: int = 29501 master_port: int = 29501
"""distributed master port for multi-node distributed """distributed master port for multi-node distributed
inference when distributed_executor_backend is mp.""" inference when distributed_executor_backend is mp."""
node_rank: int = 0 node_rank: int = Field(default=0, ge=0)
"""distributed node rank for multi-node distributed """distributed node rank for multi-node distributed
inference when distributed_executor_backend is mp.""" inference when distributed_executor_backend is mp."""
nnodes: int = 1 nnodes: int = Field(default=1, ge=1)
"""num of nodes for multi-node distributed """num of nodes for multi-node distributed
inference when distributed_executor_backend is mp.""" inference when distributed_executor_backend is mp."""
numa_bind: bool = False numa_bind: bool = False
@@ -318,7 +320,7 @@ class ParallelConfig:
"""Port of the coordination TCPStore. Can be set by the API server; workers """Port of the coordination TCPStore. Can be set by the API server; workers
connect as clients to exchange self-picked group ports at runtime.""" connect as clients to exchange self-picked group ports at runtime."""
decode_context_parallel_size: int = 1 decode_context_parallel_size: int = Field(default=1, ge=1)
"""Number of decode context parallel groups, because the world size does """Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size not change by dcp, it simply reuse the GPUs of TP group, and tp_size
needs to be divisible by dcp_size.""" needs to be divisible by dcp_size."""
+6 -10
View File
@@ -771,10 +771,6 @@ class VllmConfig:
# If no KVTransferConfig is provided, create a default one. # If no KVTransferConfig is provided, create a default one.
if self.kv_transfer_config is None: if self.kv_transfer_config is None:
self.kv_transfer_config = KVTransferConfig() self.kv_transfer_config = KVTransferConfig()
num_kv_ranks = (
self.parallel_config.tensor_parallel_size
* self.parallel_config.pipeline_parallel_size
)
if kv_offloading_backend == "native": if kv_offloading_backend == "native":
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD: if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
@@ -786,12 +782,12 @@ class VllmConfig:
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)} {"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
) )
elif kv_offloading_backend == "lmcache": elif kv_offloading_backend == "lmcache":
self.kv_transfer_config.kv_connector = "LMCacheConnectorV1" # Default to LMCache multi-process (MP) mode. The actual KV
kv_gb_per_rank = kv_offloading_size / num_kv_ranks # storage capacity is managed by the standalone LMCache server
self.kv_transfer_config.kv_connector_extra_config = { # process, so ``kv_offloading_size`` is not propagated here.
"lmcache.local_cpu": True, # ``LMCacheMPConnector`` falls back to ``tcp://localhost:5555``
"lmcache.max_local_cpu_size": kv_gb_per_rank, # when host/port are not provided via extra_config.
} self.kv_transfer_config.kv_connector = "LMCacheMPConnector"
# This is the same for all backends # This is the same for all backends
self.kv_transfer_config.kv_role = "kv_both" self.kv_transfer_config.kv_role = "kv_both"
@@ -40,11 +40,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return True return True
def is_weak_contiguous(inp: torch.Tensor): from vllm.distributed.utils import is_weak_contiguous # noqa: E402
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
class CustomAllreduce: class CustomAllreduce:
@@ -24,11 +24,7 @@ except Exception:
quick_ar = False quick_ar = False
def is_weak_contiguous(inp: torch.Tensor): from vllm.distributed.utils import is_weak_contiguous # noqa: E402, F401
return inp.is_contiguous() or (
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size()
)
class QuickReduceRegime(Enum): class QuickReduceRegime(Enum):
@@ -470,10 +470,14 @@ class ElasticEPScalingExecutor:
module._replace_quant_method(module.quant_method.old_quant_method) module._replace_quant_method(module.quant_method.old_quant_method)
prepare_communication_buffer_for_model(self.worker.model_runner.model) prepare_communication_buffer_for_model(self.worker.model_runner.model)
eplb_model_state.expert_buffer = [
torch.empty_like(w) for w in model.expert_weights[0]
]
eplb_model_state.communicator = create_eplb_communicator( eplb_model_state.communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(), group_coordinator=get_eplb_group(),
backend=parallel_config.eplb_config.communicator, backend=parallel_config.eplb_config.communicator,
expert_weights=model.expert_weights[0], expert_weights=model.expert_weights,
expert_buffer=eplb_model_state.expert_buffer,
) )
if ( if (
+1
View File
@@ -120,6 +120,7 @@ def transfer_run_periodically(
ep_group=eplb_group, ep_group=eplb_group,
is_profile=is_profile, is_profile=is_profile,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
layer_idx=layer_idx,
) )
# Wait until all writes to expert_buffer have finished before making the # Wait until all writes to expert_buffer have finished before making the
+199 -186
View File
@@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import (
is_local_first_rank, is_local_first_rank,
) )
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.distributed.utils import is_weak_contiguous
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
@@ -63,8 +64,22 @@ class EplbCommunicator(ABC):
pass pass
@abstractmethod @abstractmethod
def execute(self, old_indices: np.ndarray | None = None) -> None: def execute(self) -> None:
pass """Complete all enqueued transfers.
Some backends perform communication here; others (e.g. NIXL)
issue transfers eagerly in add_recv and only wait here.
On return, all data is available in the destination buffers.
"""
def set_transfer_context( # noqa: B027
self, old_indices: np.ndarray, layer_idx: int
) -> None:
"""Pre-set layer context before add_recv calls.
Default is a no-op; overridden by backends (e.g. NIXL) that need
layer-level context to issue transfers inside add_recv.
"""
@property @property
def needs_profile_buffer_reservation(self) -> bool: def needs_profile_buffer_reservation(self) -> bool:
@@ -125,7 +140,7 @@ class TorchDistNcclEplbCommunicator(EplbCommunicator):
) )
) )
def execute(self, old_indices: np.ndarray | None = None) -> None: def execute(self) -> None:
if not self._p2p_ops: if not self._p2p_ops:
return return
try: try:
@@ -168,7 +183,7 @@ class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
for tensor in tensors: for tensor in tensors:
self._ops.append(("recv", tensor, src_rank)) self._ops.append(("recv", tensor, src_rank))
def execute(self, old_indices: np.ndarray | None = None) -> None: def execute(self) -> None:
if not self._ops: if not self._ops:
return return
@@ -229,29 +244,47 @@ class NixlEplbCommunicator(EplbCommunicator):
def __init__( def __init__(
self, self,
cpu_group: ProcessGroup, cpu_group: ProcessGroup,
expert_weights: Sequence[torch.Tensor], all_expert_weights: Sequence[Sequence[torch.Tensor]],
cuda_stream: torch.cuda.Stream | None = None, expert_buffer: Sequence[torch.Tensor],
) -> None: ) -> None:
assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights." assert all_expert_weights, (
"NixlEplbCommunicator requires non-empty all_expert_weights."
)
assert expert_buffer, "NixlEplbCommunicator requires non-empty expert_buffer."
nixl_wrapper_cls = nixl_utils.NixlWrapper nixl_wrapper_cls = nixl_utils.NixlWrapper
if nixl_wrapper_cls is None: if nixl_wrapper_cls is None:
raise RuntimeError("NIXL/ RIXL is unavailable.") raise RuntimeError("NIXL/ RIXL is unavailable.")
self._cpu_group = cpu_group self._cpu_group = cpu_group
self._cuda_stream = cuda_stream
self._world_size = cpu_group.size() self._world_size = cpu_group.size()
self._rank = cpu_group.rank() self._rank = cpu_group.rank()
# expert_id -> weight tensors to pack into the send buffer.
self._expert_send_map: dict[int, list[torch.Tensor]] = {} self._all_expert_weights = all_expert_weights
# src_rank -> expert_id -> weight tensors to unpack after transfer. self._expert_buffer = expert_buffer
self._recv_map: dict[int, dict[int, list[torch.Tensor]]] = {} self._num_local_experts: int = all_expert_weights[0][0].shape[0]
self._num_local_experts: int = expert_weights[0].shape[0] self._device = all_expert_weights[0][0].device
self._device = expert_weights[0].device
for tensor in expert_weights: for layer_tensors in all_expert_weights:
assert tensor.device == self._device, ( for tensor in layer_tensors:
"All local EPLB tensors are expected to be on the same device: " assert is_weak_contiguous(tensor), (
f"expected={self._device}, got={tensor.device}" "Expert weight tensors must be contiguous in memory"
)
assert tensor.device == self._device, (
"All local EPLB tensors are expected to be on the same "
f"device: expected={self._device}, got={tensor.device}"
)
for tensor in expert_buffer:
assert is_weak_contiguous(tensor), (
"expert_buffer tensors must be contiguous in memory"
) )
# (local_dlist, remote_dlist, xfer_handle) for in-flight READs;
# accumulated by add_recv, drained by execute.
self._xfer_entries: list[tuple[int, int, int]] = []
# Per-rank expert_id -> physical row; set by set_transfer_context.
self._expert_to_src_row: list[dict[int, int]] | None = None
self._layer_idx: int | None = None
nixl_agent_config = nixl_utils.nixl_agent_config nixl_agent_config = nixl_utils.nixl_agent_config
config = ( config = (
nixl_agent_config(capture_telemetry=False) nixl_agent_config(capture_telemetry=False)
@@ -260,15 +293,16 @@ class NixlEplbCommunicator(EplbCommunicator):
) )
self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config) self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config)
self._nixl_memory_type = "VRAM" self._nixl_memory_type = "VRAM"
self._registered_desc: object | None = None # NIXL registration handles; deregistered in __del__.
self._registered_descs: list[object] = []
self._remote_agents: dict[int, str] = {} self._remote_agents: dict[int, str] = {}
self._remote_send_meta: dict[int, tuple[int, int]] = {} # peer -> (layer, tensor) -> (base_ptr, bytes_per_expert, dev_id).
self._send_buffer: torch.Tensor = torch.empty(0) self._remote_send_meta: dict[
self._recv_buffer: torch.Tensor = torch.empty(0) int, dict[tuple[int, int], tuple[int, int, int]]
self._expert_bytes: int = 0 ] = {}
self._cuda_device_id = int(self._device.index or 0) self._cuda_device_id = int(self._device.index or 0)
self._init_step("buffers", self._init_registered_buffers, expert_weights) self._init_step("buffers", self._init_registered_buffers)
self._init_step("agents", self._init_remote_agents) self._init_step("agents", self._init_remote_agents)
self._init_step("send meta", self._exchange_remote_send_meta) self._init_step("send meta", self._exchange_remote_send_meta)
self._log_initialized() self._log_initialized()
@@ -291,19 +325,34 @@ class NixlEplbCommunicator(EplbCommunicator):
uid = uuid.uuid4().hex[:8] uid = uuid.uuid4().hex[:8]
return f"eplb-{self._rank}{pp_suffix}-{uid}" return f"eplb-{self._rank}{pp_suffix}-{uid}"
def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
pass
def add_send( def add_send(
self, self,
tensors: list[torch.Tensor], tensors: list[torch.Tensor],
dst_rank: int, dst_rank: int,
expert_id: int, expert_id: int,
) -> None: ) -> None:
assert dst_rank != self._rank, ( # No-op: NIXL READ is receiver-initiated. The sender's expert
"EPLB communicator should not enqueue same-rank sends: " # weights are pre-registered and always readable in-place.
f"rank={self._rank}, dst_rank={dst_rank}" pass
def set_transfer_context(self, old_indices: np.ndarray, layer_idx: int) -> None:
# Pre-compute expert_id -> src_row mapping for every rank so that
# add_recv can immediately issue NIXL READs.
assert not self._xfer_entries, (
f"set_transfer_context() called with {len(self._xfer_entries)} "
f"pending transfers from layer {self._layer_idx}; "
f"execute() was not called after previous add_recv() calls"
) )
# An expert sent to multiple peers is packed only once; skip duplicates. self._layer_idx = layer_idx
if expert_id not in self._expert_send_map: n = self._num_local_experts
self._expert_send_map[expert_id] = tensors rank_experts = old_indices[: self._world_size * n].reshape(self._world_size, n)
self._expert_to_src_row = [
{int(eid): i for i, eid in enumerate(row) if eid != -1}
for row in rank_experts
]
def add_recv( def add_recv(
self, self,
@@ -311,13 +360,44 @@ class NixlEplbCommunicator(EplbCommunicator):
src_rank: int, src_rank: int,
expert_id: int, expert_id: int,
) -> None: ) -> None:
assert src_rank != self._rank, ( # Build NIXL descriptors and issue the RDMA READ immediately,
"EPLB communicator should not enqueue same-rank recvs: " # overlapping the transfer with the remaining Python loop in
f"rank={self._rank}, src_rank={src_rank}" # move_to_buffer.
assert self._expert_to_src_row is not None and self._layer_idx is not None, (
"set_transfer_context() must be called before add_recv()"
) )
recv_experts = self._recv_map.setdefault(src_rank, {}) src_row = self._expert_to_src_row[src_rank][expert_id]
if expert_id not in recv_experts: layer_idx = self._layer_idx
recv_experts[expert_id] = tensors
local_descs: list[tuple[int, int, int]] = []
remote_descs: list[tuple[int, int, int]] = []
for t_idx, t in enumerate(tensors):
send_base, send_stride, remote_dev = self._remote_send_meta[src_rank][
(layer_idx, t_idx)
]
assert t.nbytes == send_stride, (
f"tensor {t_idx} size {t.nbytes} != remote stride {send_stride}"
)
local_descs.append(
(
t.data_ptr(),
t.nbytes,
self._cuda_device_id,
)
)
remote_descs.append(
(
send_base + src_row * send_stride,
send_stride,
remote_dev,
)
)
local_h, remote_h, xfer_h = self._create_peer_xfer(
src_rank, local_descs, remote_descs
)
self._nixl_wrapper.transfer(xfer_h)
self._xfer_entries.append((local_h, remote_h, xfer_h))
def _init_remote_agents(self) -> None: def _init_remote_agents(self) -> None:
local_metadata = self._nixl_wrapper.get_agent_metadata() local_metadata = self._nixl_wrapper.get_agent_metadata()
@@ -334,73 +414,60 @@ class NixlEplbCommunicator(EplbCommunicator):
peer_metadata peer_metadata
) )
def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> None: def _init_registered_buffers(self) -> None:
total_bytes = max(sum(t.nbytes for t in expert_weights), 1) all_tensors: list[torch.Tensor] = []
assert total_bytes % self._num_local_experts == 0, ( for layer_tensors in self._all_expert_weights:
f"Number of bytes in moe layer {total_bytes} is not divisible " all_tensors.extend(layer_tensors)
f"by number of local experts {self._num_local_experts}" all_tensors.extend(self._expert_buffer)
)
self._expert_bytes = total_bytes // self._num_local_experts
self._send_buffer = torch.empty( descs = self._nixl_wrapper.get_reg_descs(all_tensors)
total_bytes, device=self._device, dtype=torch.uint8
)
self._recv_buffer = torch.empty(
total_bytes, device=self._device, dtype=torch.uint8
)
descs = self._nixl_wrapper.get_reg_descs([self._send_buffer, self._recv_buffer])
self._nixl_wrapper.register_memory(descs) self._nixl_wrapper.register_memory(descs)
self._registered_desc = descs self._registered_descs.append(descs)
def _exchange_remote_send_meta(self) -> None: def _exchange_remote_send_meta(self) -> None:
"""Exchange send-buffer metadata so each rank can build dynamic """Exchange per-layer per-tensor metadata so receivers can compute
descriptors at execute time.""" remote RDMA addresses at transfer time."""
local_meta: tuple[int, int] = ( local_meta: dict[tuple[int, int], tuple[int, int, int]] = {}
self._send_buffer.data_ptr(), for layer_idx, layer_tensors in enumerate(self._all_expert_weights):
self._cuda_device_id, for t_idx, t in enumerate(layer_tensors):
) nbytes_per_expert = t.nbytes // self._num_local_experts
gathered_meta: list[tuple[int, int] | None] = [None] * self._world_size local_meta[(layer_idx, t_idx)] = (
t.data_ptr(),
nbytes_per_expert,
self._cuda_device_id,
)
# Per-rank map: (layer_idx, tensor_idx) -> (base_ptr, bytes_per_expert, dev_id).
# add_recv uses base_ptr + src_row * bytes_per_expert to compute
# the remote RDMA address for each expert.
gathered_meta: list[dict[tuple[int, int], tuple[int, int, int]] | None] = [
None
] * self._world_size
torch.distributed.all_gather_object( torch.distributed.all_gather_object(
gathered_meta, local_meta, group=self._cpu_group gathered_meta, local_meta, group=self._cpu_group
) )
local_keys = set(local_meta.keys())
for peer in self._remote_agents: for peer in self._remote_agents:
peer_meta = gathered_meta[peer] peer_meta = gathered_meta[peer]
assert peer_meta is not None assert peer_meta is not None
peer_keys = set(peer_meta.keys())
if peer_keys != local_keys:
raise RuntimeError(
f"NIXL EPLB metadata key mismatch with rank {peer}: "
f"local={sorted(local_keys)}, peer={sorted(peer_keys)}"
)
for key in local_keys:
_, local_stride, _ = local_meta[key]
_, peer_stride, _ = peer_meta[key]
if local_stride != peer_stride:
raise RuntimeError(
f"NIXL EPLB nbytes_per_expert mismatch for {key} "
f"with rank {peer}: "
f"local={local_stride}, peer={peer_stride}"
)
self._remote_send_meta[peer] = peer_meta self._remote_send_meta[peer] = peer_meta
@staticmethod
def _pack_send_buffer(
in_tensors: list[torch.Tensor],
send_buffer: torch.Tensor,
byte_offset: int,
) -> None:
for tensor in in_tensors:
raw = tensor.reshape(-1).view(torch.uint8)
if raw.numel() == 0:
continue
send_buffer[byte_offset : byte_offset + raw.numel()].copy_(
raw, non_blocking=True
)
byte_offset += raw.numel()
@staticmethod
def _unpack_recv_buffer(
recv_buffer: torch.Tensor,
out_tensors: list[torch.Tensor],
byte_offset: int,
) -> None:
for tensor in out_tensors:
num_bytes = tensor.numel() * tensor.element_size()
if num_bytes == 0:
continue
tensor.reshape(-1).view(torch.uint8).copy_(
recv_buffer[byte_offset : byte_offset + num_bytes],
non_blocking=True,
)
byte_offset += num_bytes
def _wait_for_all_transfers(self, handles: list[int]) -> None: def _wait_for_all_transfers(self, handles: list[int]) -> None:
pending = set(handles) pending = set(handles)
while pending: while pending:
@@ -456,110 +523,52 @@ class NixlEplbCommunicator(EplbCommunicator):
) )
return (local_handle, remote_handle, xfer_handle) return (local_handle, remote_handle, xfer_handle)
def execute(self, old_indices: np.ndarray | None = None) -> None: def execute(self) -> None:
assert old_indices is not None, ( assert self._layer_idx is not None or not self._xfer_entries, (
"NixlEplbCommunicator.execute requires old_indices" "set_transfer_context() must be called before execute() "
"if any add_recv() calls were made"
) )
xfer_entries: list[tuple[int, int, int]] = []
try: try:
n = self._num_local_experts self._wait_for_all_transfers([x[2] for x in self._xfer_entries])
rank_experts = old_indices[: self._world_size * n].reshape(
self._world_size, n
)
# Build expert_id -> send slot mapping per rank.
expert_to_send_slot: list[dict[int, int]] = [
{int(eid): i for i, eid in enumerate(row) if eid != -1}
for row in rank_experts
]
# Phase 1: pack each expert at its slot offset in the send buffer. # Post-READ barrier.
with torch.cuda.stream(self._cuda_stream): # Correctness fence for zero-copy: prevents overwrite-while-
for expert_id, tensors in self._expert_send_map.items(): # remote-read race.
slot = expert_to_send_slot[self._rank][expert_id]
byte_offset = slot * self._expert_bytes
self._pack_send_buffer(tensors, self._send_buffer, byte_offset)
# Ensure all packed data is visible in device memory before pulls.
if self._cuda_stream is not None:
self._cuda_stream.synchronize()
else:
torch.cuda.current_stream().synchronize()
# READ is receiver-initiated; synchronize all ranks before transfer.
# We use monitored_barrier so a rank that crashes or exits early
# produces a diagnostic timeout instead of a silent hang.
torch.distributed.monitored_barrier( torch.distributed.monitored_barrier(
group=self._cpu_group, group=self._cpu_group,
timeout=timedelta(minutes=5), timeout=timedelta(minutes=5),
) )
# Phase 2: issue one batched READ per peer.
recv_offsets: dict[tuple[int, int], int] = {}
recv_offset = 0
recv_base = self._recv_buffer.data_ptr()
for src in range(self._world_size):
if src == self._rank:
continue
recv_experts = self._recv_map.get(src)
if not recv_experts:
continue
expert_ids = list(recv_experts.keys())
remote_base, remote_dev = self._remote_send_meta[src]
local_descs: list[tuple[int, int, int]] = []
remote_descs: list[tuple[int, int, int]] = []
for expert_id in expert_ids:
slot = expert_to_send_slot[src][expert_id]
remote_off = slot * self._expert_bytes
recv_offsets[(src, expert_id)] = recv_offset
local_descs.append(
(
recv_base + recv_offset,
self._expert_bytes,
self._cuda_device_id,
)
)
remote_descs.append(
(remote_base + remote_off, self._expert_bytes, remote_dev)
)
recv_offset += self._expert_bytes
assert recv_offset <= self._recv_buffer.nbytes
local_h, remote_h, xfer_h = self._create_peer_xfer(
src, local_descs, remote_descs
)
self._nixl_wrapper.transfer(xfer_h)
xfer_entries.append((local_h, remote_h, xfer_h))
# Phase 3: wait for all in-flight transfers, then unpack.
self._wait_for_all_transfers([x[2] for x in xfer_entries])
with torch.cuda.stream(self._cuda_stream):
for (src, expert_id), offset in recv_offsets.items():
self._unpack_recv_buffer(
self._recv_buffer,
self._recv_map[src][expert_id],
offset,
)
finally: finally:
for local_h, remote_h, xfer_h in xfer_entries: for local_h, remote_h, xfer_h in self._xfer_entries:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self._nixl_wrapper.release_xfer_handle(xfer_h) self._nixl_wrapper.release_xfer_handle(xfer_h)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(local_h) self._nixl_wrapper.release_dlist_handle(local_h)
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(remote_h) self._nixl_wrapper.release_dlist_handle(remote_h)
self._expert_send_map.clear() self._xfer_entries.clear()
self._recv_map.clear() self._expert_to_src_row = None
self._layer_idx = None
def __del__(self) -> None: def __del__(self) -> None:
try: with contextlib.suppress(Exception):
if self._registered_desc is not None: for local_h, remote_h, xfer_h in self._xfer_entries:
self._nixl_wrapper.deregister_memory(self._registered_desc) with contextlib.suppress(Exception):
self._registered_desc = None self._nixl_wrapper.release_xfer_handle(xfer_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(local_h)
with contextlib.suppress(Exception):
self._nixl_wrapper.release_dlist_handle(remote_h)
with contextlib.suppress(Exception):
for descs in self._registered_descs:
with contextlib.suppress(Exception):
self._nixl_wrapper.deregister_memory(descs)
self._registered_descs.clear()
with contextlib.suppress(Exception):
for agent_name in self._remote_agents.values(): for agent_name in self._remote_agents.values():
self._nixl_wrapper.remove_remote_agent(agent_name) with contextlib.suppress(Exception):
self._nixl_wrapper.remove_remote_agent(agent_name)
self._remote_agents.clear() self._remote_agents.clear()
except Exception as e:
logger.warning("Error during NixlEplbCommunicator cleanup: %s", e)
class PyNcclEplbCommunicator(EplbCommunicator): class PyNcclEplbCommunicator(EplbCommunicator):
@@ -600,7 +609,7 @@ class PyNcclEplbCommunicator(EplbCommunicator):
for tensor in tensors: for tensor in tensors:
self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream) self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)
def execute(self, old_indices: np.ndarray | None = None) -> None: def execute(self) -> None:
if self._group_started: if self._group_started:
self._pynccl_comm.group_end() self._pynccl_comm.group_end()
self._group_started = False self._group_started = False
@@ -609,7 +618,8 @@ class PyNcclEplbCommunicator(EplbCommunicator):
def create_eplb_communicator( def create_eplb_communicator(
group_coordinator: GroupCoordinator, group_coordinator: GroupCoordinator,
backend: str | None, backend: str | None,
expert_weights: Sequence[torch.Tensor], expert_weights: Sequence[Sequence[torch.Tensor]],
expert_buffer: Sequence[torch.Tensor],
) -> EplbCommunicator: ) -> EplbCommunicator:
"""Create an EPLB communicator for the given backend. """Create an EPLB communicator for the given backend.
@@ -624,16 +634,18 @@ def create_eplb_communicator(
``"pynccl"`` in that case. When tensors reside on CPU, ``"pynccl"`` in that case. When tensors reside on CPU,
``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU ``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU
process group. process group.
expert_weights: Expert weight tensors from *one* MoE layer. expert_weights: Expert weight tensors for *all* MoE layers.
NixlEplbCommunicator pre-allocates send/recv buffers sized Shape ``(num_layers)(num_tensors_per_layer)``.
to this layer, so all other MoE layers must have the same NixlEplbCommunicator registers all layers with NIXL for
tensor count, shapes, and dtypes. zero-copy RDMA reads.
expert_buffer: Pre-allocated receive buffer tensors (one per
weight tensor in a single layer).
""" """
# Keep a safe default for callers that have not resolved communicator yet.
if backend is None: if backend is None:
backend = "torch_nccl" backend = "torch_nccl"
tensor_device_type = expert_weights[0].device.type if expert_weights else "cpu" first_layer = expert_weights[0] if expert_weights else []
tensor_device_type = first_layer[0].device.type if first_layer else "cpu"
torch_group = ( torch_group = (
group_coordinator.cpu_group group_coordinator.cpu_group
if tensor_device_type == "cpu" if tensor_device_type == "cpu"
@@ -649,7 +661,7 @@ def create_eplb_communicator(
unsupported_dtypes = sorted( unsupported_dtypes = sorted(
{ {
tensor.dtype tensor.dtype
for tensor in expert_weights for tensor in first_layer
if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype) if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
}, },
key=str, key=str,
@@ -704,7 +716,8 @@ def create_eplb_communicator(
try: try:
return NixlEplbCommunicator( return NixlEplbCommunicator(
cpu_group=group_coordinator.cpu_group, cpu_group=group_coordinator.cpu_group,
expert_weights=expert_weights, all_expert_weights=expert_weights,
expert_buffer=expert_buffer,
) )
except Exception as exc: except Exception as exc:
raise RuntimeError( raise RuntimeError(
+3 -1
View File
@@ -450,7 +450,8 @@ class EplbState:
communicator = create_eplb_communicator( communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(), group_coordinator=get_eplb_group(),
backend=self.parallel_config.eplb_config.communicator, backend=self.parallel_config.eplb_config.communicator,
expert_weights=model.expert_weights[0], expert_weights=model.expert_weights,
expert_buffer=expert_buffer,
) )
model_state = EplbModelState( model_state = EplbModelState(
@@ -766,6 +767,7 @@ class EplbState:
eplb_model_state.physical_to_logical_map, eplb_model_state.physical_to_logical_map,
new_physical_to_logical_map, new_physical_to_logical_map,
eplb_model_state.model.expert_weights, eplb_model_state.model.expert_weights,
eplb_model_state.expert_buffer,
ep_group, ep_group,
eplb_model_state.communicator, eplb_model_state.communicator,
is_profile, is_profile,
+16 -9
View File
@@ -178,6 +178,7 @@ def move_to_buffer(
cuda_stream: torch.cuda.Stream | None, cuda_stream: torch.cuda.Stream | None,
ep_rank: int, ep_rank: int,
communicator: EplbCommunicator, communicator: EplbCommunicator,
layer_idx: int = 0,
) -> TransferMetadata: ) -> TransferMetadata:
""" """
Rearranges expert weights during EPLB rebalancing. Rearranges expert weights during EPLB rebalancing.
@@ -193,6 +194,7 @@ def move_to_buffer(
cuda_stream: CUDA stream for async copies (can be None for sync mode). cuda_stream: CUDA stream for async copies (can be None for sync mode).
ep_rank: Rank of this process in expert parallel group. ep_rank: Rank of this process in expert parallel group.
communicator: EplbCommunicator instance for P2P communication. communicator: EplbCommunicator instance for P2P communication.
layer_idx: Index of the MoE layer being transferred.
Returns: Returns:
TransferMetadata: Metadata needed for completing remote weight transfers. TransferMetadata: Metadata needed for completing remote weight transfers.
@@ -265,6 +267,8 @@ def move_to_buffer(
for w, b in zip(expert_weights, expert_weights_buffers): for w, b in zip(expert_weights, expert_weights_buffers):
b[dst].copy_(w[src_local], non_blocking=True) b[dst].copy_(w[src_local], non_blocking=True)
communicator.set_transfer_context(old_indices, layer_idx)
# 2. Post sends # 2. Post sends
if send_count > 0: if send_count > 0:
experts = send_expert_ids[:send_count] experts = send_expert_ids[:send_count]
@@ -331,9 +335,8 @@ def move_to_buffer(
expert_id=int(expert), expert_id=int(expert),
) )
# 4. Execute the P2P operations. The real communication happens here. # 4. Execute transfers and wait for completion.
communicator.execute(old_indices=old_indices) communicator.execute()
# wait for the communication to finish
return TransferMetadata( return TransferMetadata(
is_unchanged=is_unchanged, is_unchanged=is_unchanged,
is_received_locally=is_received_locally, is_received_locally=is_received_locally,
@@ -431,6 +434,7 @@ def transfer_layer(
is_profile: bool = False, is_profile: bool = False,
cuda_stream: torch.cuda.Stream | None = None, cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
layer_idx: int = 0,
) -> TransferMetadata: ) -> TransferMetadata:
""" """
Rearranges the expert weights in place according to the new expert indices. Rearranges the expert weights in place according to the new expert indices.
@@ -452,6 +456,7 @@ def transfer_layer(
communications to reserve enough memory for the buffers. communications to reserve enough memory for the buffers.
cuda_stream: CUDA stream for async copies (can be None for sync mode). cuda_stream: CUDA stream for async copies (can be None for sync mode).
rank_mapping: Optional rank mapping for elastic expert parallelism. rank_mapping: Optional rank mapping for elastic expert parallelism.
layer_idx: Index of the MoE layer being transferred.
Returns: Returns:
TransferMetadata: Metadata needed for completing remote weight transfers, TransferMetadata: Metadata needed for completing remote weight transfers,
@@ -499,6 +504,7 @@ def transfer_layer(
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
ep_rank=ep_group.rank(), ep_rank=ep_group.rank(),
communicator=communicator, communicator=communicator,
layer_idx=layer_idx,
) )
@@ -506,6 +512,7 @@ def rearrange_expert_weights_inplace(
old_global_expert_indices: torch.Tensor, old_global_expert_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor, new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Sequence[torch.Tensor]], expert_weights: Sequence[Sequence[torch.Tensor]],
expert_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup, ep_group: ProcessGroup,
communicator: EplbCommunicator, communicator: EplbCommunicator,
is_profile: bool = False, is_profile: bool = False,
@@ -524,6 +531,8 @@ def rearrange_expert_weights_inplace(
of tensors of shape (num_local_physical_experts, hidden_size_i). of tensors of shape (num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection, For example, a linear layer may have up and down projection,
so weight_count = 2. Each weight's hidden size can be different. so weight_count = 2. Each weight's hidden size can be different.
expert_buffer: Pre-allocated receive buffer tensors (one per
weight tensor in a single layer).
ep_group: The device process group for expert parallelism. ep_group: The device process group for expert parallelism.
communicator: EplbCommunicator instance for P2P communication. communicator: EplbCommunicator instance for P2P communication.
is_profile (bool): If `True`, do not perform any actual weight copy. is_profile (bool): If `True`, do not perform any actual weight copy.
@@ -566,10 +575,10 @@ def rearrange_expert_weights_inplace(
# Reserve NCCL communication buffers via a dummy all_gather. # Reserve NCCL communication buffers via a dummy all_gather.
# Backends that pre-allocate their own transfer buffers # Backends that pre-allocate their own transfer buffers
# skip this to avoid the extra memory spike during profiling. # skip this to avoid the extra memory spike during profiling.
weights_buffer: list[torch.Tensor] = [ profile_buffer: list[torch.Tensor] = [
torch.empty_like(w) for w in first_layer_weights torch.empty_like(w) for w in first_layer_weights
] ]
for weight, buffer in zip(expert_weights[0], weights_buffer): for weight, buffer in zip(expert_weights[0], profile_buffer):
dummy_recv_buffer = [buffer for _ in range(ep_size)] dummy_recv_buffer = [buffer for _ in range(ep_size)]
torch.distributed.barrier() torch.distributed.barrier()
all_gather( all_gather(
@@ -579,10 +588,7 @@ def rearrange_expert_weights_inplace(
) )
return return
# Buffers to hold the expert weights during the exchange. weights_buffer = list(expert_buffer)
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
weights_buffer = [torch.empty_like(w) for w in first_layer_weights]
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy() old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy() new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
@@ -597,6 +603,7 @@ def rearrange_expert_weights_inplace(
cuda_stream=None, cuda_stream=None,
ep_rank=ep_rank, ep_rank=ep_rank,
communicator=communicator, communicator=communicator,
layer_idx=layer_idx,
) )
move_from_buffer( move_from_buffer(
@@ -120,15 +120,12 @@ class KVOutputAggregator:
# Use the first worker's kv_connector_stats as accumulator. # Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = kv_output.kv_connector_stats aggregated_kv_connector_stats = kv_output.kv_connector_stats
elif kv_connector_stats := kv_output.kv_connector_stats: elif kv_connector_stats := kv_output.kv_connector_stats:
if aggregated_kv_connector_stats is None: assert isinstance(
aggregated_kv_connector_stats = kv_connector_stats aggregated_kv_connector_stats, type(kv_connector_stats)
else: )
assert isinstance( aggregated_kv_connector_stats = aggregated_kv_connector_stats.aggregate(
aggregated_kv_connector_stats, type(kv_connector_stats) kv_connector_stats
) )
aggregated_kv_connector_stats = (
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)
# Aggregate kv_connector_worker_meta from all workers. # Aggregate kv_connector_worker_meta from all workers.
if aggregated_kv_connector_worker_meta is None: if aggregated_kv_connector_worker_meta is None:
@@ -2333,9 +2333,25 @@ class NixlConnectorWorker:
for i, remote_group in enumerate(remote_block_ids): for i, remote_group in enumerate(remote_block_ids):
num_local_blocks = len(local_block_ids[i]) num_local_blocks = len(local_block_ids[i])
num_remote_blocks = len(remote_group) num_remote_blocks = len(remote_group)
if _is_ssm_spec(self._group_spec_types[i]): if (
assert num_local_blocks == num_remote_blocks _is_ssm_spec(self._group_spec_types[i])
and num_local_blocks < num_remote_blocks
):
# NOTE (NickLucche): With prefix caching on SSM, the (remote) blocks
# preceding the last one are placeholders (null blocks). This does not
# affect the transfer, since we only care about the last "block", which
# holds the full in-place state.
assert num_local_blocks == 1, "SSM can only have one local block"
remote_block_ids[i] = remote_group[-num_local_blocks:]
elif (
self._physical_blocks_per_logical_kv_block
== remote_physical_per_logical
and num_local_blocks < num_remote_blocks
):
# Partial prefix cache hit for FA group.
remote_block_ids[i] = remote_group[-num_local_blocks:]
else: else:
# TODO Handle prefix caching with different block_sizes
max_padding = max( max_padding = max(
self._physical_blocks_per_logical_kv_block, self._physical_blocks_per_logical_kv_block,
remote_physical_per_logical, remote_physical_per_logical,
-38
View File
@@ -1270,9 +1270,6 @@ def get_dcp_group() -> GroupCoordinator:
return _DCP return _DCP
# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group
_PP: GroupCoordinator | None = None _PP: GroupCoordinator | None = None
@@ -1840,31 +1837,6 @@ def model_parallel_is_initialized():
_TP_STATE_PATCHED = False _TP_STATE_PATCHED = False
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Temporarily replace the module-level tensor-parallel group.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    The previous group and the "patched" flag are restored when the ``with``
    block exits, whether it exits normally or through an exception.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator to install for
            the duration of the context.

    Raises:
        AssertionError: if a patch is already active (nesting is not
            supported).
    """
    global _TP, _TP_STATE_PATCHED

    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Read the current group *before* touching any global state. If
    # get_tp_group() raises (presumably when the TP group has not been
    # initialized yet -- TODO confirm), the flag must not be left stuck at
    # True, otherwise every later call would fail the assertion above.
    old_tp_group = get_tp_group()

    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group
def get_tensor_model_parallel_world_size() -> int: def get_tensor_model_parallel_world_size() -> int:
"""Return world size for the tensor model parallel group.""" """Return world size for the tensor model parallel group."""
return get_tp_group().world_size return get_tp_group().world_size
@@ -1875,16 +1847,6 @@ def get_tensor_model_parallel_rank() -> int:
return get_tp_group().rank_in_group return get_tp_group().rank_in_group
def get_decode_context_model_parallel_world_size() -> int:
    """Report how many ranks the decode context model parallel group has."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size
def get_decode_context_model_parallel_rank() -> int:
    """Report the caller's rank inside the decode context model parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group
def get_node_count() -> int: def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment.""" """Return the total number of nodes in the distributed environment."""
assert _NODE_COUNT is not None, "distributed environment is not initialized" assert _NODE_COUNT is not None, "distributed environment is not initialized"
+14
View File
@@ -64,6 +64,20 @@ def divide(numerator, denominator):
return numerator // denominator return numerator // denominator
def is_weak_contiguous(inp: torch.Tensor) -> bool:
    """Check that *inp* occupies a single contiguous block of memory.

    Unlike ``torch.Tensor.is_contiguous()``, this also accepts tensors
    whose strides are not strictly C-contiguous (e.g. column-major) as
    long as the underlying storage from the tensor's offset onward is
    exactly ``numel * element_size`` bytes.

    Args:
        inp: The tensor to inspect.

    Returns:
        ``True`` if *inp* is C-contiguous, or if the bytes of its storage
        starting at its storage offset number exactly
        ``inp.numel() * inp.element_size()``; ``False`` otherwise.
    """
    if inp.is_contiguous():
        return True

    elem_size = inp.element_size()
    # ``untyped_storage()`` is used instead of the deprecated ``storage()``
    # (TypedStorage), which emits a deprecation warning; both report the
    # same size in bytes.
    bytes_from_offset = (
        inp.untyped_storage().nbytes() - inp.storage_offset() * elem_size
    )
    return bytes_from_offset == inp.numel() * elem_size
def split_tensor_along_last_dim( def split_tensor_along_last_dim(
tensor: torch.Tensor, tensor: torch.Tensor,
num_partitions: int, num_partitions: int,
+2 -2
View File
@@ -17,9 +17,9 @@ from vllm.entrypoints.anthropic.protocol import (
) )
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.serve.utils.api_utils import (
from vllm.entrypoints.utils import (
load_aware_call, load_aware_call,
validate_json_request,
with_cancellation, with_cancellation,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
+1 -1
View File
@@ -29,7 +29,6 @@ from vllm.entrypoints.anthropic.protocol import (
AnthropicUsage, AnthropicUsage,
) )
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionRequest,
@@ -45,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
StreamOptions, StreamOptions,
) )
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.utils.request_logger import RequestLogger
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.render.serving import OpenAIServingRender
+1 -1
View File
@@ -22,7 +22,7 @@ import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation from vllm.entrypoints.serve.utils.api_utils import with_cancellation
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
+1 -1
View File
@@ -7,7 +7,7 @@ import typing
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG from vllm.entrypoints.serve.utils.api_utils import VLLM_SUBCMD_PARSER_EPILOG
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser

Some files were not shown because too many files have changed in this diff Show More