mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Merge branch 'main' into wentao-enable-all-dense-for-mrv2
This commit is contained in:
@@ -1299,12 +1299,11 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/entrypoints/llm
|
- tests/entrypoints/llm
|
||||||
- tests/entrypoints/offline_mode
|
|
||||||
commands:
|
commands:
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py --ignore=entrypoints/llm/offline_mode
|
||||||
- pytest -v -s entrypoints/llm/test_generate.py
|
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||||
- pytest -v -s entrypoints/offline_mode
|
- pytest -v -s entrypoints/llm/offline_mode # Needs to avoid interference with other tests
|
||||||
|
|
||||||
- label: Entrypoints Integration (Pooling) # TBD
|
- label: Entrypoints Integration (Pooling) # TBD
|
||||||
timeout_in_minutes: 180
|
timeout_in_minutes: 180
|
||||||
@@ -1346,7 +1345,7 @@ steps:
|
|||||||
- vllm/platforms/rocm.py
|
- vllm/platforms/rocm.py
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s entrypoints/openai/tool_parsers
|
- pytest -v -s entrypoints/openai/tool_parsers
|
||||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
|
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
|
||||||
|
|
||||||
- label: OpenAI API correctness # TBD
|
- label: OpenAI API correctness # TBD
|
||||||
timeout_in_minutes: 180
|
timeout_in_minutes: 180
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ steps:
|
|||||||
- tests/entrypoints/
|
- tests/entrypoints/
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s entrypoints/openai/tool_parsers
|
- pytest -v -s entrypoints/openai/tool_parsers
|
||||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
|
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
|
||||||
|
|
||||||
- label: Entrypoints Integration (LLM)
|
- label: Entrypoints Integration (LLM)
|
||||||
key: entrypoints-integration-llm
|
key: entrypoints-integration-llm
|
||||||
@@ -20,12 +20,11 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/entrypoints/llm
|
- tests/entrypoints/llm
|
||||||
- tests/entrypoints/offline_mode
|
|
||||||
commands:
|
commands:
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py --ignore=entrypoints/llm/offline_mode
|
||||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
- pytest -v -s entrypoints/llm/offline_mode # Needs to avoid interference with other tests
|
||||||
mirror:
|
mirror:
|
||||||
amd:
|
amd:
|
||||||
device: mi325_1
|
device: mi325_1
|
||||||
|
|||||||
@@ -33,10 +33,3 @@ share/python-wheels/
|
|||||||
*.egg
|
*.egg
|
||||||
MANIFEST
|
MANIFEST
|
||||||
rust/target/
|
rust/target/
|
||||||
# Not needed in Docker builds
|
|
||||||
docs/
|
|
||||||
.github/
|
|
||||||
.pre-commit-config.yaml
|
|
||||||
.clang-format
|
|
||||||
.gitattributes
|
|
||||||
format.sh
|
|
||||||
|
|||||||
+2
-1
@@ -34,10 +34,11 @@
|
|||||||
/vllm/entrypoints/speech_to_text/realtime @njhill
|
/vllm/entrypoints/speech_to_text/realtime @njhill
|
||||||
/vllm/entrypoints/speech_to_text @NickLucche
|
/vllm/entrypoints/speech_to_text @NickLucche
|
||||||
/vllm/entrypoints/pooling @noooop
|
/vllm/entrypoints/pooling @noooop
|
||||||
/vllm/entrypoints/sagemaker @DarkLight1337
|
/vllm/entrypoints/serve/sagemaker @DarkLight1337
|
||||||
/vllm/entrypoints/serve @njhill
|
/vllm/entrypoints/serve @njhill
|
||||||
/vllm/entrypoints/*.py @njhill
|
/vllm/entrypoints/*.py @njhill
|
||||||
/vllm/entrypoints/chat_utils.py @DarkLight1337
|
/vllm/entrypoints/chat_utils.py @DarkLight1337
|
||||||
|
/vllm/entrypoints/offline_utils.py @DarkLight1337
|
||||||
/vllm/entrypoints/llm.py @DarkLight1337
|
/vllm/entrypoints/llm.py @DarkLight1337
|
||||||
|
|
||||||
# Rust Frontend
|
# Rust Frontend
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ jobs:
|
|||||||
actions: write
|
actions: write
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
|
- uses: actions/stale@eb5cf3af3ac0a1aa4c9c45633dd1ae542a27a899 # v10.3.0
|
||||||
with:
|
with:
|
||||||
# Increasing this value ensures that changes to this workflow
|
# Increasing this value ensures that changes to this workflow
|
||||||
# propagate to all issues and PRs in days rather than months
|
# propagate to all issues and PRs in days rather than months
|
||||||
|
|||||||
@@ -102,6 +102,35 @@ constexpr float NUM_TOKEN_CUTOFF = 1024;
|
|||||||
constexpr int kNumLanes = 32;
|
constexpr int kNumLanes = 32;
|
||||||
constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16
|
constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16
|
||||||
|
|
||||||
|
// Pack this lane's 16 fp32 elements into per-tensor E4M3 FP8 (one uint4 = 16
|
||||||
|
// B), scaling by `scale` (a reciprocal scale) and saturating to ±448. Used by
|
||||||
|
// the FlashInfer full-cache path for both the Q and KV stores.
|
||||||
|
__device__ __forceinline__ uint4 packFp8E4M3x16(float const* values,
|
||||||
|
float const scale) {
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
uint4 out;
|
||||||
|
auto* out2 = reinterpret_cast<__nv_fp8x2_storage_t*>(&out);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kElemsPerLane / 2; i++) {
|
||||||
|
float2 scaled =
|
||||||
|
make_float2(values[2 * i] * scale, values[2 * i + 1] * scale);
|
||||||
|
scaled.x = fminf(fmaxf(scaled.x, -kFp8Max), kFp8Max);
|
||||||
|
scaled.y = fminf(fmaxf(scaled.y, -kFp8Max), kFp8Max);
|
||||||
|
out2[i] = __nv_cvt_float2_to_fp8x2(scaled, __NV_SATFINITE, __NV_E4M3);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
#else
|
||||||
|
uint8_t out_bytes[kElemsPerLane];
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kElemsPerLane; i++) {
|
||||||
|
float scaled = values[i] * scale;
|
||||||
|
scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max);
|
||||||
|
out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled);
|
||||||
|
}
|
||||||
|
return *reinterpret_cast<uint4 const*>(out_bytes);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
// Small inline helpers
|
// Small inline helpers
|
||||||
// ────────────────────────────────────────────────────────────────────────────
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
@@ -649,6 +678,257 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
|
|||||||
#undef DISPATCH
|
#undef DISPATCH
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// FlashInfer full-cache kernel
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
//
|
||||||
|
// Sibling to the FlashMLA kernel above, used by the FlashInfer V4 sparse-MLA
|
||||||
|
// backend. Differences from the legacy path:
|
||||||
|
// * No Q head padding — output Q layout matches the input num_heads_q.
|
||||||
|
// * KV is written as a *contiguous* 512-wide row per token (token-strided),
|
||||||
|
// not the legacy UE8M0 paged layout with a separate scale tail.
|
||||||
|
// * Q/KV are stored either as bf16 or as per-tensor E4M3 FP8 (one global
|
||||||
|
// scale), selected by the STORE_Q_FP8 / STORE_KV_FP8 template flags.
|
||||||
|
//
|
||||||
|
// Grid: 1D, gridDim.x = ceil(num_tokens_full * (num_heads_q + 1) / warps).
|
||||||
|
// Each warp handles one (token, slot): slot < num_heads_q → Q, slot ==
|
||||||
|
// num_heads_q → KV.
|
||||||
|
template <typename scalar_t_in, bool STORE_Q_FP8, bool STORE_KV_FP8>
|
||||||
|
__global__ void fusedDeepseekV4FullCacheKernel(
|
||||||
|
scalar_t_in* __restrict__ q_inout, // [N, H, 512], in place (bf16)
|
||||||
|
uint8_t* __restrict__ q_fp8_out, // [N, H, 512] fp8, optional
|
||||||
|
int64_t const q_fp8_stride0, // elements (fp8 == bytes)
|
||||||
|
int64_t const q_fp8_stride1, // elements (fp8 == bytes)
|
||||||
|
scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16
|
||||||
|
uint8_t* __restrict__ k_cache, // contiguous bf16 or fp8 cache
|
||||||
|
int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64
|
||||||
|
int64_t const* __restrict__ position_ids, // [N] i64
|
||||||
|
float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32
|
||||||
|
float const* __restrict__ fp8_scale_ptr, // scalar, KV fp8 only
|
||||||
|
float const* __restrict__ q_fp8_scale_inv, // scalar, Q fp8 only
|
||||||
|
float const eps,
|
||||||
|
int const num_tokens_full, // = q.size(0) = kv.size(0)
|
||||||
|
int const num_tokens_insert, // = slot_mapping.size(0)
|
||||||
|
int const num_heads_q, // H (no padding)
|
||||||
|
int const cache_block_size, // tokens per cache block
|
||||||
|
int64_t const kv_block_stride, // bytes per cache block
|
||||||
|
int64_t const kv_token_stride) { // bytes per cache token
|
||||||
|
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||||
|
if constexpr (std::is_same_v<scalar_t_in, c10::BFloat16>) {
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
#endif
|
||||||
|
using Converter = vllm::_typeConvert<scalar_t_in>;
|
||||||
|
int const warpsPerBlock = blockDim.x / 32;
|
||||||
|
int const warpId = threadIdx.x / 32;
|
||||||
|
int const laneId = threadIdx.x % 32;
|
||||||
|
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
|
||||||
|
|
||||||
|
int const slotsPerToken = num_heads_q + 1;
|
||||||
|
int const tokenIdx = globalWarpIdx / slotsPerToken;
|
||||||
|
int const slotIdx = globalWarpIdx % slotsPerToken;
|
||||||
|
if (tokenIdx >= num_tokens_full) return;
|
||||||
|
bool const isKV = (slotIdx == num_heads_q);
|
||||||
|
// KV branch: skip DP-padded tokens (no slot reserved for them).
|
||||||
|
if (isKV && tokenIdx >= num_tokens_insert) return;
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||||
|
cudaGridDependencySynchronize();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16
|
||||||
|
scalar_t_in const* src_ptr;
|
||||||
|
if (isKV) {
|
||||||
|
src_ptr = kv_in + static_cast<int64_t>(tokenIdx) * kHeadDim + dim_base;
|
||||||
|
} else {
|
||||||
|
src_ptr = q_inout +
|
||||||
|
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) *
|
||||||
|
kHeadDim +
|
||||||
|
dim_base;
|
||||||
|
}
|
||||||
|
uint4 const v0 = *reinterpret_cast<uint4 const*>(src_ptr);
|
||||||
|
uint4 const v1 = *reinterpret_cast<uint4 const*>(src_ptr + 8);
|
||||||
|
|
||||||
|
// ── Decode bf16 → 16 fp32 registers ───────────────────────────────────
|
||||||
|
float elements[kElemsPerLane];
|
||||||
|
{
|
||||||
|
auto const* p0 =
|
||||||
|
reinterpret_cast<typename Converter::packed_hip_type const*>(&v0);
|
||||||
|
auto const* p1 =
|
||||||
|
reinterpret_cast<typename Converter::packed_hip_type const*>(&v1);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
float2 f2 = Converter::convert(p0[i]);
|
||||||
|
elements[2 * i] = f2.x;
|
||||||
|
elements[2 * i + 1] = f2.y;
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
float2 f2 = Converter::convert(p1[i]);
|
||||||
|
elements[8 + 2 * i] = f2.x;
|
||||||
|
elements[8 + 2 * i + 1] = f2.y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Q branch: RMSNorm (no weight) ─────────────────────────────────────
|
||||||
|
if (!isKV) {
|
||||||
|
float sumOfSquares = 0.0f;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kElemsPerLane; i++) {
|
||||||
|
sumOfSquares += elements[i] * elements[i];
|
||||||
|
}
|
||||||
|
sumOfSquares = warpSum<float>(sumOfSquares);
|
||||||
|
float const rms_rcp =
|
||||||
|
rsqrtf(sumOfSquares / static_cast<float>(kHeadDim) + eps);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < kElemsPerLane; i++) {
|
||||||
|
elements[i] = elements[i] * rms_rcp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ───────────────────────────
|
||||||
|
bool const is_rope_lane = dim_base >= kNopeDim;
|
||||||
|
if (is_rope_lane) {
|
||||||
|
int64_t const pos = position_ids[tokenIdx];
|
||||||
|
constexpr int kHalfRope = kRopeDim / 2;
|
||||||
|
float const* cos_ptr = cos_sin_cache + pos * kRopeDim;
|
||||||
|
float const* sin_ptr = cos_ptr + kHalfRope;
|
||||||
|
int const rope_local_base = dim_base - kNopeDim;
|
||||||
|
int const half_base = rope_local_base >> 1;
|
||||||
|
float4 const c0 = *reinterpret_cast<float4 const*>(cos_ptr + half_base);
|
||||||
|
float4 const c1 = *reinterpret_cast<float4 const*>(cos_ptr + half_base + 4);
|
||||||
|
float4 const s0 = *reinterpret_cast<float4 const*>(sin_ptr + half_base);
|
||||||
|
float4 const s1 = *reinterpret_cast<float4 const*>(sin_ptr + half_base + 4);
|
||||||
|
float const cos_arr[8] = {c0.x, c0.y, c0.z, c0.w, c1.x, c1.y, c1.z, c1.w};
|
||||||
|
float const sin_arr[8] = {s0.x, s0.y, s0.z, s0.w, s1.x, s1.y, s1.z, s1.w};
|
||||||
|
#pragma unroll
|
||||||
|
for (int p = 0; p < kElemsPerLane / 2; p++) {
|
||||||
|
float const x_even = elements[2 * p];
|
||||||
|
float const x_odd = elements[2 * p + 1];
|
||||||
|
elements[2 * p] = x_even * cos_arr[p] - x_odd * sin_arr[p];
|
||||||
|
elements[2 * p + 1] = x_even * sin_arr[p] + x_odd * cos_arr[p];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Store ─────────────────────────────────────────────────────────────
|
||||||
|
if (!isKV) {
|
||||||
|
if constexpr (STORE_Q_FP8) {
|
||||||
|
float const scale_inv = VLLM_LDG(q_fp8_scale_inv);
|
||||||
|
uint4 const out = packFp8E4M3x16(elements, scale_inv);
|
||||||
|
uint8_t* dst = q_fp8_out +
|
||||||
|
static_cast<int64_t>(tokenIdx) * q_fp8_stride0 +
|
||||||
|
static_cast<int64_t>(slotIdx) * q_fp8_stride1 + dim_base;
|
||||||
|
*reinterpret_cast<uint4*>(dst) = out;
|
||||||
|
} else {
|
||||||
|
uint4 out0, out1;
|
||||||
|
auto* po0 = reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
|
||||||
|
auto* po1 = reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
po0[i] = Converter::convert(
|
||||||
|
make_float2(elements[2 * i], elements[2 * i + 1]));
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
po1[i] = Converter::convert(
|
||||||
|
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
|
||||||
|
}
|
||||||
|
scalar_t_in* dst =
|
||||||
|
q_inout +
|
||||||
|
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
|
||||||
|
dim_base;
|
||||||
|
*reinterpret_cast<uint4*>(dst) = out0;
|
||||||
|
*reinterpret_cast<uint4*>(dst + 8) = out1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int64_t const slot_id = slot_mapping[tokenIdx];
|
||||||
|
if (slot_id >= 0) {
|
||||||
|
int64_t const block_idx = slot_id / cache_block_size;
|
||||||
|
int64_t const pos_in_block = slot_id % cache_block_size;
|
||||||
|
uint8_t* cache_row =
|
||||||
|
k_cache + block_idx * kv_block_stride + pos_in_block * kv_token_stride;
|
||||||
|
if constexpr (STORE_KV_FP8) {
|
||||||
|
float const inv_scale = 1.0f / VLLM_LDG(fp8_scale_ptr);
|
||||||
|
uint4 const out = packFp8E4M3x16(elements, inv_scale);
|
||||||
|
*reinterpret_cast<uint4*>(cache_row + dim_base) = out;
|
||||||
|
} else {
|
||||||
|
uint4 out0, out1;
|
||||||
|
auto* po0 =
|
||||||
|
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
|
||||||
|
auto* po1 =
|
||||||
|
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
po0[i] = Converter::convert(
|
||||||
|
make_float2(elements[2 * i], elements[2 * i + 1]));
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
po1[i] = Converter::convert(
|
||||||
|
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
|
||||||
|
}
|
||||||
|
scalar_t_in* dst = reinterpret_cast<scalar_t_in*>(cache_row) + dim_base;
|
||||||
|
*reinterpret_cast<uint4*>(dst) = out0;
|
||||||
|
*reinterpret_cast<uint4*>(dst + 8) = out1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||||
|
cudaTriggerProgrammaticLaunchCompletion();
|
||||||
|
#endif
|
||||||
|
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure + launch helper shared by the bf16 and fp8 full-cache launchers.
|
||||||
|
template <typename scalar_t_in, bool STORE_Q_FP8, bool STORE_KV_FP8>
|
||||||
|
static void launchFullCacheKernel(
|
||||||
|
scalar_t_in* q_inout, uint8_t* q_fp8_out, int64_t q_fp8_stride0,
|
||||||
|
int64_t q_fp8_stride1, scalar_t_in const* kv_in, uint8_t* k_cache,
|
||||||
|
int64_t const* slot_mapping, int64_t const* position_ids,
|
||||||
|
float const* cos_sin_cache, float const* fp8_scale,
|
||||||
|
float const* q_fp8_scale_inv, float const eps, int const num_tokens_full,
|
||||||
|
int const num_tokens_insert, int const num_heads_q,
|
||||||
|
int const cache_block_size, int64_t const kv_block_stride,
|
||||||
|
int64_t const kv_token_stride, char const* op_name, cudaStream_t stream) {
|
||||||
|
constexpr int kBlockSize = 256;
|
||||||
|
constexpr int kWarpsPerBlock = kBlockSize / 32;
|
||||||
|
int64_t const total_warps =
|
||||||
|
static_cast<int64_t>(num_tokens_full) * (num_heads_q + 1);
|
||||||
|
int const grid =
|
||||||
|
static_cast<int>((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock);
|
||||||
|
auto* kernel =
|
||||||
|
fusedDeepseekV4FullCacheKernel<scalar_t_in, STORE_Q_FP8, STORE_KV_FP8>;
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
static int const sm_version = getSMVersion();
|
||||||
|
STD_TORCH_CHECK(sm_version >= 80, op_name,
|
||||||
|
" requires sm_80+ (Ampere or newer); got sm_", sm_version);
|
||||||
|
cudaLaunchConfig_t config;
|
||||||
|
config.gridDim = dim3(grid);
|
||||||
|
config.blockDim = dim3(kBlockSize);
|
||||||
|
config.dynamicSmemBytes = 0;
|
||||||
|
config.stream = stream;
|
||||||
|
cudaLaunchAttribute attrs[1];
|
||||||
|
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||||
|
attrs[0].val.programmaticStreamSerializationAllowed = 1;
|
||||||
|
config.attrs = attrs;
|
||||||
|
config.numAttrs = (sm_version >= 90) ? 1 : 0;
|
||||||
|
cudaLaunchKernelEx(&config, kernel, q_inout, q_fp8_out, q_fp8_stride0,
|
||||||
|
q_fp8_stride1, kv_in, k_cache, slot_mapping, position_ids,
|
||||||
|
cos_sin_cache, fp8_scale, q_fp8_scale_inv, eps,
|
||||||
|
num_tokens_full, num_tokens_insert, num_heads_q,
|
||||||
|
cache_block_size, kv_block_stride, kv_token_stride);
|
||||||
|
#else
|
||||||
|
kernel<<<grid, kBlockSize, 0, stream>>>(
|
||||||
|
q_inout, q_fp8_out, q_fp8_stride0, q_fp8_stride1, kv_in, k_cache,
|
||||||
|
slot_mapping, position_ids, cos_sin_cache, fp8_scale, q_fp8_scale_inv,
|
||||||
|
eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
|
||||||
|
kv_block_stride, kv_token_stride);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace deepseek_v4_fused_ops
|
} // namespace deepseek_v4_fused_ops
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
@@ -735,3 +1015,167 @@ torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
|||||||
});
|
});
|
||||||
return q_out;
|
return q_out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// FlashInfer full-cache torch ops
|
||||||
|
// ────────────────────────────────────────────────────────────────────────────
|
||||||
|
// bf16 full-cache op: validates its arguments and launches the full-cache
// kernel with both FP8 flags off, so Q is rewritten in place and the KV row
// for each token is written into `k_cache` at the slot given by
// `slot_mapping`. `slot_mapping` may be shorter than q (num_tokens_insert <=
// N). Raises via STD_TORCH_CHECK on any device, dtype, shape, contiguity or
// stride-alignment violation.
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
    torch::stable::Tensor& q,                    // [N, H, 512] bf16, in place
    torch::stable::Tensor const& kv,             // [N, 512] bf16, read-only
    torch::stable::Tensor& k_cache,              // [num_blocks, bs, 512] bf16
    torch::stable::Tensor const& slot_mapping,   // [num_tokens_insert] int64
    torch::stable::Tensor const& position_ids,   // [N] int64
    torch::stable::Tensor const& cos_sin_cache,  // [max_pos, 64] float32
    double eps, int64_t cache_block_size) {
  using torch::headeronly::ScalarType;
  // The kernel stores 16 bytes at a time; with 2-byte elements every cache
  // row must therefore start on a multiple of 8 elements.
  constexpr int64_t kElemsPer16Bytes = 16 / 2;

  STD_TORCH_CHECK(q.device().is_cuda() && q.is_contiguous(),
                  "q must be contiguous CUDA");
  STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
                  "kv must be contiguous CUDA");
  STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
  STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
                      slot_mapping.scalar_type() == ScalarType::Long,
                  "slot_mapping must be int64 CUDA");
  STD_TORCH_CHECK(position_ids.device().is_cuda() &&
                      position_ids.scalar_type() == ScalarType::Long,
                  "position_ids must be int64 CUDA");
  STD_TORCH_CHECK(cos_sin_cache.device().is_cuda() &&
                      cos_sin_cache.scalar_type() == ScalarType::Float &&
                      cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
                  "cos_sin_cache shape [max_pos, 64] float32");
  // The kernel indexes these three buffers linearly from their data pointer.
  STD_TORCH_CHECK(slot_mapping.is_contiguous() &&
                      position_ids.is_contiguous() &&
                      cos_sin_cache.is_contiguous(),
                  "slot_mapping, position_ids and cos_sin_cache must be "
                  "contiguous");
  STD_TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
  STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
  STD_TORCH_CHECK(q.scalar_type() == ScalarType::BFloat16 &&
                      kv.scalar_type() == ScalarType::BFloat16,
                  "q and kv must be bfloat16");
  // The kernel divides slot ids by cache_block_size.
  STD_TORCH_CHECK(cache_block_size > 0, "cache_block_size must be positive");
  STD_TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size &&
                      k_cache.size(2) == 512 && k_cache.stride(2) == 1,
                  "k_cache shape [num_blocks, cache_block_size, 512] contiguous");
  STD_TORCH_CHECK(k_cache.scalar_type() == ScalarType::BFloat16,
                  "k_cache must be bfloat16");
  // Strides of size-1 dimensions are never multiplied by a non-zero index.
  STD_TORCH_CHECK(
      (k_cache.size(0) <= 1 || k_cache.stride(0) % kElemsPer16Bytes == 0) &&
          (k_cache.size(1) <= 1 || k_cache.stride(1) % kElemsPer16Bytes == 0),
      "k_cache block and token strides must be multiples of 16 bytes");

  // Compare row counts at full width before narrowing to int.
  STD_TORCH_CHECK(kv.size(0) == q.size(0) && position_ids.size(0) == q.size(0),
                  "q/kv/position_ids row counts must match");
  STD_TORCH_CHECK(slot_mapping.size(0) <= q.size(0),
                  "slot_mapping must not exceed q row count");
  int const num_tokens_full = static_cast<int>(q.size(0));
  int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
  int const num_heads_q = static_cast<int>(q.size(1));

  const torch::stable::accelerator::DeviceGuard device_guard(
      q.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(q.get_device_index());

  // bf16 cache: 2 bytes/element -> byte strides for the uint8-addressed kernel.
  int64_t const kv_block_stride = k_cache.stride(0) * 2;
  int64_t const kv_token_stride = k_cache.stride(1) * 2;

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      q.scalar_type(),
      "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert", [&] {
        vllm::deepseek_v4_fused_ops::launchFullCacheKernel<scalar_t, false,
                                                           false>(
            reinterpret_cast<scalar_t*>(q.mutable_data_ptr()), nullptr, 0, 0,
            reinterpret_cast<scalar_t const*>(kv.const_data_ptr()),
            reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
            slot_mapping.const_data_ptr<int64_t>(),
            position_ids.const_data_ptr<int64_t>(),
            cos_sin_cache.const_data_ptr<float>(), nullptr, nullptr,
            static_cast<float>(eps), num_tokens_full, num_tokens_insert,
            num_heads_q, static_cast<int>(cache_block_size), kv_block_stride,
            kv_token_stride,
            "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert",
            stream);
      });
}
|
||||||
|
|
||||||
|
// FP8 full-cache op: validates its arguments and launches the full-cache
// kernel with both FP8 flags on. `q` is only read; the processed Q is written
// to `q_fp8` and each token's KV row is written into the FP8 `k_cache` at the
// slot given by `slot_mapping`. `fp8_scale` is the KV scale (the kernel uses
// its reciprocal) and `q_fp8_scale_inv` is already the reciprocal Q scale;
// both are single-element float32 tensors. Raises via STD_TORCH_CHECK on any
// device, dtype, shape, contiguity or stride-alignment violation.
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
    torch::stable::Tensor const& q,       // [N, H, 512] bf16, read-only
    torch::stable::Tensor const& kv,      // [N, 512] bf16, read-only
    torch::stable::Tensor& q_fp8,         // [N, H, 512] fp8 e4m3
    torch::stable::Tensor& k_cache,       // [num_blocks, bs, 512] fp8
    torch::stable::Tensor const& slot_mapping,     // [num_tokens_insert] int64
    torch::stable::Tensor const& position_ids,     // [N] int64
    torch::stable::Tensor const& cos_sin_cache,    // [max_pos, 64] float32
    torch::stable::Tensor const& fp8_scale,        // scalar float32 (KV scale)
    torch::stable::Tensor const& q_fp8_scale_inv,  // scalar float32 (1 / Q scale)
    double eps, int64_t cache_block_size) {
  using torch::headeronly::ScalarType;
  // The kernel stores 16 bytes at a time; with 1-byte elements every cache
  // row must therefore start on a multiple of 16 elements.
  constexpr int64_t kElemsPer16Bytes = 16;

  STD_TORCH_CHECK(q.device().is_cuda() && q.is_contiguous(),
                  "q must be contiguous CUDA");
  STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
                  "kv must be contiguous CUDA");
  // Validate q's rank before q.size(2) is read by the q_fp8 shape check.
  STD_TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
  STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
  STD_TORCH_CHECK(q_fp8.device().is_cuda() && q_fp8.is_contiguous() &&
                      q_fp8.scalar_type() == ScalarType::Float8_e4m3fn &&
                      q_fp8.dim() == 3 && q_fp8.size(0) == q.size(0) &&
                      q_fp8.size(1) == q.size(1) && q_fp8.size(2) == q.size(2),
                  "q_fp8 must be a contiguous float8_e4m3fn tensor matching q");
  STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
  STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
                      slot_mapping.scalar_type() == ScalarType::Long,
                  "slot_mapping must be int64 CUDA");
  STD_TORCH_CHECK(position_ids.device().is_cuda() &&
                      position_ids.scalar_type() == ScalarType::Long,
                  "position_ids must be int64 CUDA");
  STD_TORCH_CHECK(cos_sin_cache.device().is_cuda() &&
                      cos_sin_cache.scalar_type() == ScalarType::Float &&
                      cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
                  "cos_sin_cache shape [max_pos, 64] float32");
  // The kernel indexes these three buffers linearly from their data pointer.
  STD_TORCH_CHECK(slot_mapping.is_contiguous() &&
                      position_ids.is_contiguous() &&
                      cos_sin_cache.is_contiguous(),
                  "slot_mapping, position_ids and cos_sin_cache must be "
                  "contiguous");
  // numel() accepts both 0-dim scalars and shape-[1] tensors; size(0) would
  // fail on a genuine 0-dim scalar.
  STD_TORCH_CHECK(fp8_scale.device().is_cuda() &&
                      fp8_scale.scalar_type() == ScalarType::Float &&
                      fp8_scale.numel() == 1,
                  "fp8_scale must be a scalar float32 CUDA tensor");
  STD_TORCH_CHECK(q_fp8_scale_inv.device().is_cuda() &&
                      q_fp8_scale_inv.scalar_type() == ScalarType::Float &&
                      q_fp8_scale_inv.numel() == 1,
                  "q_fp8_scale_inv must be a scalar float32 CUDA tensor");
  STD_TORCH_CHECK(q.scalar_type() == kv.scalar_type(),
                  "q and kv dtype must match");
  // The kernel divides slot ids by cache_block_size.
  STD_TORCH_CHECK(cache_block_size > 0, "cache_block_size must be positive");
  STD_TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size &&
                      k_cache.size(2) == 512 && k_cache.stride(2) == 1,
                  "k_cache shape [num_blocks, cache_block_size, 512] contiguous");
  STD_TORCH_CHECK(k_cache.scalar_type() == ScalarType::Float8_e4m3fn,
                  "k_cache must be float8_e4m3fn");
  // Strides of size-1 dimensions are never multiplied by a non-zero index.
  STD_TORCH_CHECK(
      (k_cache.size(0) <= 1 || k_cache.stride(0) % kElemsPer16Bytes == 0) &&
          (k_cache.size(1) <= 1 || k_cache.stride(1) % kElemsPer16Bytes == 0),
      "k_cache block and token strides must be multiples of 16 bytes");

  // Compare row counts at full width before narrowing to int.
  STD_TORCH_CHECK(kv.size(0) == q.size(0) && position_ids.size(0) == q.size(0),
                  "q/kv/position_ids row counts must match");
  STD_TORCH_CHECK(slot_mapping.size(0) <= q.size(0),
                  "slot_mapping must not exceed q row count");
  int const num_tokens_full = static_cast<int>(q.size(0));
  int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
  int const num_heads_q = static_cast<int>(q.size(1));

  const torch::stable::accelerator::DeviceGuard device_guard(
      q.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(q.get_device_index());

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      q.scalar_type(),
      "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert", [&] {
        vllm::deepseek_v4_fused_ops::launchFullCacheKernel<scalar_t, true,
                                                           true>(
            // q is read-only in the fp8 path (the kernel writes q_fp8); the
            // launcher signature is non-const, so cast away const on the ptr.
            reinterpret_cast<scalar_t*>(
                const_cast<void*>(q.const_data_ptr())),
            reinterpret_cast<uint8_t*>(q_fp8.mutable_data_ptr()),
            q_fp8.stride(0), q_fp8.stride(1),
            reinterpret_cast<scalar_t const*>(kv.const_data_ptr()),
            reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
            slot_mapping.const_data_ptr<int64_t>(),
            position_ids.const_data_ptr<int64_t>(),
            cos_sin_cache.const_data_ptr<float>(),
            fp8_scale.const_data_ptr<float>(),
            q_fp8_scale_inv.const_data_ptr<float>(), static_cast<float>(eps),
            num_tokens_full, num_tokens_insert, num_heads_q,
            static_cast<int>(cache_block_size),
            // fp8 cache: 1 byte/element -> stride already in bytes.
            k_cache.stride(0), k_cache.stride(1),
            "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert",
            stream);
      });
}
|
||||||
|
|||||||
@@ -238,6 +238,23 @@ torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
|||||||
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
|
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
|
||||||
double eps, int64_t cache_block_size);
|
double eps, int64_t cache_block_size);
|
||||||
|
|
||||||
|
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
|
||||||
|
torch::stable::Tensor& q, torch::stable::Tensor const& kv,
|
||||||
|
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
|
||||||
|
torch::stable::Tensor const& position_ids,
|
||||||
|
torch::stable::Tensor const& cos_sin_cache, double eps,
|
||||||
|
int64_t cache_block_size);
|
||||||
|
|
||||||
|
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
|
||||||
|
torch::stable::Tensor const& q, torch::stable::Tensor const& kv,
|
||||||
|
torch::stable::Tensor& q_fp8, torch::stable::Tensor& k_cache,
|
||||||
|
torch::stable::Tensor const& slot_mapping,
|
||||||
|
torch::stable::Tensor const& position_ids,
|
||||||
|
torch::stable::Tensor const& cos_sin_cache,
|
||||||
|
torch::stable::Tensor const& fp8_scale,
|
||||||
|
torch::stable::Tensor const& q_fp8_scale_inv, double eps,
|
||||||
|
int64_t cache_block_size);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
torch::stable::Tensor minimax_allreduce_rms(
|
torch::stable::Tensor minimax_allreduce_rms(
|
||||||
torch::stable::Tensor const& input,
|
torch::stable::Tensor const& input,
|
||||||
|
|||||||
@@ -343,6 +343,20 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
|||||||
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
||||||
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
|
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
|
||||||
|
|
||||||
|
// FlashInfer V4 full-cache variants: write Q in place (bf16) or to a separate
|
||||||
|
// FP8 tensor, and KV into a contiguous 512-wide token-strided cache.
|
||||||
|
ops.def(
|
||||||
|
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert("
|
||||||
|
"Tensor! q, Tensor kv, Tensor! k_cache, Tensor slot_mapping, "
|
||||||
|
"Tensor position_ids, Tensor cos_sin_cache, float eps, "
|
||||||
|
"int cache_block_size) -> ()");
|
||||||
|
ops.def(
|
||||||
|
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert("
|
||||||
|
"Tensor q, Tensor kv, Tensor! q_fp8, Tensor! k_cache, "
|
||||||
|
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
||||||
|
"Tensor fp8_scale, Tensor q_fp8_scale_inv, float eps, "
|
||||||
|
"int cache_block_size) -> ()");
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
ops.def(
|
ops.def(
|
||||||
"minimax_allreduce_rms("
|
"minimax_allreduce_rms("
|
||||||
@@ -591,6 +605,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
|||||||
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
|
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
|
||||||
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
|
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
|
||||||
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
|
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
|
||||||
|
ops.impl(
|
||||||
|
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert",
|
||||||
|
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert));
|
||||||
|
ops.impl(
|
||||||
|
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert",
|
||||||
|
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert));
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
|
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
|
||||||
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
|
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
|
||||||
|
|||||||
@@ -55,7 +55,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
|
|
||||||
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
|
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
|
||||||
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
|
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
|
||||||
// kernel launch. Registered in _C_stable_libtorch.
|
// kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4
|
||||||
|
// full-cache bf16/fp8 variants).
|
||||||
|
|
||||||
// Quantization ops
|
// Quantization ops
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ RUN if [ "$USE_SCCACHE" = "1" ]; then \
|
|||||||
ARG USE_SCCACHE
|
ARG USE_SCCACHE
|
||||||
ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET_NAME}}
|
ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET_NAME}}
|
||||||
ENV SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION_NAME}}
|
ENV SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION_NAME}}
|
||||||
ENV SCCACHE_ENDPOINT=${USE_SCCACHE:+${SCCACHE_ENDPOINT}}
|
|
||||||
ENV SCCACHE_S3_NO_CREDENTIALS=${USE_SCCACHE:+${SCCACHE_S3_NO_CREDENTIALS}}
|
ENV SCCACHE_S3_NO_CREDENTIALS=${USE_SCCACHE:+${SCCACHE_S3_NO_CREDENTIALS}}
|
||||||
ENV SCCACHE_IDLE_TIMEOUT=${USE_SCCACHE:+0}
|
ENV SCCACHE_IDLE_TIMEOUT=${USE_SCCACHE:+0}
|
||||||
|
|
||||||
|
|||||||
@@ -228,3 +228,17 @@ MLA decode backends are selected using the standard
|
|||||||
| `TOKENSPEED_MLA` | fp16, bf16 | `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
|
| `TOKENSPEED_MLA` | fp16, bf16 | `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
|
||||||
| `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
| `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
||||||
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
|
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
|
||||||
|
|
||||||
|
### DeepSeek V4 Decode Backends
|
||||||
|
|
||||||
|
DeepSeek V4 sparse MLA uses its own decode backends, selected via
|
||||||
|
`--attention-backend=<BACKEND>` (e.g., `FLASHMLA_SPARSE_DSV4`,
|
||||||
|
`FLASHINFER_MLA_SPARSE_DSV4`). They share the V4 sparse-index
|
||||||
|
pipeline (compressor + SWA + indexer, 256-token blocks, head size 512);
|
||||||
|
the default on NVIDIA is `FLASHMLA_SPARSE_DSV4`.
|
||||||
|
|
||||||
|
| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Non-Causal | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
|
||||||
|
| ------- | ------ | --------- | ----------- | ---------- | ---- | ---------- | ------ | --------- | --- | --------------- | ------------ |
|
||||||
|
| `FLASHINFER_MLA_SPARSE_DSV4` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |
|
||||||
|
| `FLASHMLA_SPARSE_DSV4` | fp16, bf16 | `auto` | 256 | 512 | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |
|
||||||
|
| `ROCM_FLASHMLA_SPARSE_DSV4` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
|
|||||||
|
|
||||||
| Architecture | Models | CG for Image | CG for Video |
|
| Architecture | Models | CG for Image | CG for Video |
|
||||||
| ------------ | ------ | ------------ | ------------ |
|
| ------------ | ------ | ------------ | ------------ |
|
||||||
|
| `InternVLChatModel` | `InternVL3.5`, `InternVL3`, `InternVL2.5`, `InternVL2` | ✅︎ | ✅︎ |
|
||||||
| `Qwen2VLForConditionalGeneration` | `Qwen2-VL` | ✅︎ | ✅︎ |
|
| `Qwen2VLForConditionalGeneration` | `Qwen2-VL` | ✅︎ | ✅︎ |
|
||||||
| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
|
| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
|
||||||
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
|
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices.
|
Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices.
|
||||||
|
|
||||||
!!! tip
|
!!! tip
|
||||||
To get started with quantization, see [LLM Compressor](llm_compressor.md), a library for optimizing models for deployment with vLLM that supports FP8, INT8, INT4, and other quantization formats.
|
To get started with quantization, see [LLM Compressor](llm_compressor/README.md), a library for optimizing models for deployment with vLLM that supports FP8, INT8, INT4, and other quantization formats.
|
||||||
|
|
||||||
The following are the supported quantization formats for vLLM:
|
The following are the supported quantization formats for vLLM:
|
||||||
|
|
||||||
@@ -12,9 +12,11 @@ The following are the supported quantization formats for vLLM:
|
|||||||
- [GGUF](gguf.md)
|
- [GGUF](gguf.md)
|
||||||
- [GPTQModel](gptqmodel.md)
|
- [GPTQModel](gptqmodel.md)
|
||||||
- [Intel Neural Compressor](inc.md)
|
- [Intel Neural Compressor](inc.md)
|
||||||
- [INT4 W4A16](int4.md)
|
- [LLM Compressor](llm_compressor/README.md)
|
||||||
- [INT8 W8A8](int8.md)
|
- [FP8 W8A8](llm_compressor/fp8.md)
|
||||||
- [FP8 W8A8](fp8.md)
|
- [INT4 W4A16](llm_compressor/int4.md)
|
||||||
|
- [INT8 W4A8](llm_compressor/int8_w4a8.md)
|
||||||
|
- [INT8 W8A8](llm_compressor/int8_w8a8.md)
|
||||||
- [NVIDIA Model Optimizer](modelopt.md)
|
- [NVIDIA Model Optimizer](modelopt.md)
|
||||||
- [Online Quantization](online.md)
|
- [Online Quantization](online.md)
|
||||||
- [AMD Quark](quark.md)
|
- [AMD Quark](quark.md)
|
||||||
@@ -46,16 +48,17 @@ th:not(:first-child) {
|
|||||||
}
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|
||||||
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU |
|
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | Arm CPU |
|
||||||
| ------------------------- | ----- | ------ | ------ | --- | ------ | ------- | --------- | ------- |
|
| ------------------------- | ----- | ------ | ------ | --- | ------ | ------- | --------- | ------- | ------- |
|
||||||
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ |
|
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ |
|
||||||
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ |
|
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ |
|
||||||
| Marlin (GPTQ/AWQ/FP8/FP4) | ❌ | ✅︎* | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
| Marlin (GPTQ/AWQ/FP8/FP4) | ❌ | ✅︎* | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ |
|
| llm-compressor INT8 (W8A8)| ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ |
|
||||||
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ |
|
| llm-compressor INT8 (W4A8)| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ |
|
||||||
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
| llm-compressor FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
||||||
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ |
|
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||||
|
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
||||||
|
|
||||||
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
|
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
|
||||||
- ✅︎ indicates that the quantization method is supported on the specified hardware.
|
- ✅︎ indicates that the quantization method is supported on the specified hardware.
|
||||||
|
|||||||
+25
-25
@@ -21,9 +21,17 @@ The FP8 types typically supported in hardware have two distinct representations,
|
|||||||
To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install llmcompressor
|
(venv-llm-compressor) pip install llmcompressor
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
|
||||||
|
```
|
||||||
|
|
||||||
|
Please use separate environments for vLLM and llm-compressor as they might not work together.
|
||||||
|
|
||||||
## Quantization Process
|
## Quantization Process
|
||||||
|
|
||||||
The quantization process involves three main steps:
|
The quantization process involves three main steps:
|
||||||
@@ -57,36 +65,28 @@ For FP8 quantization, we can recover accuracy with simple RTN quantization. We r
|
|||||||
|
|
||||||
Since simple RTN does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow.
|
Since simple RTN does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow.
|
||||||
|
|
||||||
??? code
|
```python
|
||||||
|
from llmcompressor import oneshot
|
||||||
|
from llmcompressor.modifiers.quantization import QuantizationModifier
|
||||||
|
|
||||||
```python
|
# Configure the simple PTQ quantization
|
||||||
from llmcompressor import oneshot
|
recipe = QuantizationModifier(
|
||||||
from llmcompressor.modifiers.quantization import QuantizationModifier
|
targets="Linear",
|
||||||
|
scheme="FP8_DYNAMIC",
|
||||||
|
ignore=["lm_head"],
|
||||||
|
)
|
||||||
|
|
||||||
# Configure the simple PTQ quantization
|
# Apply the quantization algorithm.
|
||||||
recipe = QuantizationModifier(
|
oneshot(model=model, recipe=recipe)
|
||||||
targets="Linear",
|
|
||||||
scheme="FP8_DYNAMIC",
|
|
||||||
ignore=["lm_head"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply the quantization algorithm.
|
# Save the model: Meta-Llama-3-8B-Instruct-FP8-Dynamic
|
||||||
oneshot(model=model, recipe=recipe)
|
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
|
||||||
|
model.save_pretrained(SAVE_DIR)
|
||||||
# Save the model: Meta-Llama-3-8B-Instruct-FP8-Dynamic
|
tokenizer.save_pretrained(SAVE_DIR)
|
||||||
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
|
```
|
||||||
model.save_pretrained(SAVE_DIR)
|
|
||||||
tokenizer.save_pretrained(SAVE_DIR)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Evaluating Accuracy
|
### 3. Evaluating Accuracy
|
||||||
|
|
||||||
Install `vllm` and `lm-evaluation-harness` for evaluation:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install vllm "lm-eval[api]>=0.4.12"
|
|
||||||
```
|
|
||||||
|
|
||||||
Load and run the model in `vllm`:
|
Load and run the model in `vllm`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
+64
-68
@@ -12,15 +12,17 @@ Please visit the HF collection of [quantized INT4 checkpoints of popular LLMs re
|
|||||||
To use INT4 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
To use INT4 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install llmcompressor
|
(venv-llm-compressor) pip install llmcompressor
|
||||||
```
|
```
|
||||||
|
|
||||||
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install vllm "lm-eval[api]>=0.4.12"
|
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Please use separate environments for vLLM and llm-compressor as they might not work together.
|
||||||
|
|
||||||
## Quantization Process
|
## Quantization Process
|
||||||
|
|
||||||
The quantization process involves four main steps:
|
The quantization process involves four main steps:
|
||||||
@@ -52,55 +54,51 @@ When quantizing weights to INT4, you need sample data to estimate the weight upd
|
|||||||
It's best to use calibration data that closely matches your deployment data.
|
It's best to use calibration data that closely matches your deployment data.
|
||||||
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
|
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
|
||||||
|
|
||||||
??? code
|
```python
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
```python
|
NUM_CALIBRATION_SAMPLES = 512
|
||||||
from datasets import load_dataset
|
MAX_SEQUENCE_LENGTH = 2048
|
||||||
|
|
||||||
NUM_CALIBRATION_SAMPLES = 512
|
# Load and preprocess the dataset
|
||||||
MAX_SEQUENCE_LENGTH = 2048
|
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
|
||||||
|
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
||||||
|
|
||||||
# Load and preprocess the dataset
|
def preprocess(example):
|
||||||
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
|
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
|
||||||
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
ds = ds.map(preprocess)
|
||||||
|
|
||||||
def preprocess(example):
|
def tokenize(sample):
|
||||||
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
|
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
|
||||||
ds = ds.map(preprocess)
|
ds = ds.map(tokenize, remove_columns=ds.column_names)
|
||||||
|
```
|
||||||
def tokenize(sample):
|
|
||||||
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
|
|
||||||
ds = ds.map(tokenize, remove_columns=ds.column_names)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Applying Quantization
|
### 3. Applying Quantization
|
||||||
|
|
||||||
Now, apply the quantization algorithms:
|
Now, apply the quantization algorithms:
|
||||||
|
|
||||||
??? code
|
```python
|
||||||
|
from llmcompressor import oneshot
|
||||||
|
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||||
|
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
||||||
|
|
||||||
```python
|
# Configure the quantization algorithms
|
||||||
from llmcompressor import oneshot
|
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
|
||||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
|
||||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
|
||||||
|
|
||||||
# Configure the quantization algorithms
|
# Apply quantization
|
||||||
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
|
oneshot(
|
||||||
|
model=model,
|
||||||
|
dataset=ds,
|
||||||
|
recipe=recipe,
|
||||||
|
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||||
|
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||||
|
)
|
||||||
|
|
||||||
# Apply quantization
|
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A16-G128
|
||||||
oneshot(
|
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
|
||||||
model=model,
|
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||||
dataset=ds,
|
tokenizer.save_pretrained(SAVE_DIR)
|
||||||
recipe=recipe,
|
```
|
||||||
max_seq_length=MAX_SEQUENCE_LENGTH,
|
|
||||||
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A16-G128
|
|
||||||
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
|
|
||||||
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
|
||||||
tokenizer.save_pretrained(SAVE_DIR)
|
|
||||||
```
|
|
||||||
|
|
||||||
This process creates a W4A16 model with weights quantized to 4-bit integers.
|
This process creates a W4A16 model with weights quantized to 4-bit integers.
|
||||||
|
|
||||||
@@ -141,36 +139,34 @@ lm_eval --model vllm \
|
|||||||
|
|
||||||
The following is an example of an expanded quantization recipe you can tune to your own use case:
|
The following is an example of an expanded quantization recipe you can tune to your own use case:
|
||||||
|
|
||||||
??? code
|
```python
|
||||||
|
from compressed_tensors.quantization import (
|
||||||
```python
|
QuantizationArgs,
|
||||||
from compressed_tensors.quantization import (
|
QuantizationScheme,
|
||||||
QuantizationArgs,
|
QuantizationStrategy,
|
||||||
QuantizationScheme,
|
QuantizationType,
|
||||||
QuantizationStrategy,
|
)
|
||||||
QuantizationType,
|
recipe = GPTQModifier(
|
||||||
)
|
targets="Linear",
|
||||||
recipe = GPTQModifier(
|
config_groups={
|
||||||
targets="Linear",
|
"config_group": QuantizationScheme(
|
||||||
config_groups={
|
targets=["Linear"],
|
||||||
"config_group": QuantizationScheme(
|
weights=QuantizationArgs(
|
||||||
targets=["Linear"],
|
num_bits=4,
|
||||||
weights=QuantizationArgs(
|
type=QuantizationType.INT,
|
||||||
num_bits=4,
|
strategy=QuantizationStrategy.GROUP,
|
||||||
type=QuantizationType.INT,
|
group_size=128,
|
||||||
strategy=QuantizationStrategy.GROUP,
|
symmetric=True,
|
||||||
group_size=128,
|
dynamic=False,
|
||||||
symmetric=True,
|
actorder="weight",
|
||||||
dynamic=False,
|
|
||||||
actorder="weight",
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
},
|
),
|
||||||
ignore=["lm_head"],
|
},
|
||||||
update_size=NUM_CALIBRATION_SAMPLES,
|
ignore=["lm_head"],
|
||||||
dampening_frac=0.01,
|
update_size=NUM_CALIBRATION_SAMPLES,
|
||||||
)
|
dampening_frac=0.01,
|
||||||
```
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## Troubleshooting and Support
|
## Troubleshooting and Support
|
||||||
|
|
||||||
@@ -0,0 +1,217 @@
|
|||||||
|
# INT8 W4A8
|
||||||
|
|
||||||
|
vLLM supports quantizing weights to INT4 and activations to INT8 for memory savings and inference acceleration.
|
||||||
|
This quantization method is particularly useful for reducing model size while maintaining good performance.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
To use INT8 W4A8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
(venv-llm-compressor) pip install llmcompressor
|
||||||
|
```
|
||||||
|
|
||||||
|
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
|
||||||
|
```
|
||||||
|
|
||||||
|
Please use separate environments for vLLM and llm-compressor as they might not work together.
|
||||||
|
|
||||||
|
## Quantization Process
|
||||||
|
|
||||||
|
The quantization process involves four main steps:
|
||||||
|
|
||||||
|
1. Loading the model
|
||||||
|
2. Preparing calibration data
|
||||||
|
3. Applying quantization
|
||||||
|
4. Evaluating accuracy in vLLM
|
||||||
|
|
||||||
|
### 1. Loading the Model
|
||||||
|
|
||||||
|
Load your model and tokenizer using the standard `transformers` AutoModel classes:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
|
||||||
|
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
MODEL_ID,
|
||||||
|
dtype="auto",
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Preparing Calibration Data
|
||||||
|
|
||||||
|
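When quantizing weights to INT4 and activations to INT8, you need sample data to calibrate the quantization (the GPTQ recipes below use it to compute the weight updates; activations are quantized dynamically per token).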
|
||||||
|
It's best to use calibration data that closely matches your deployment data.
|
||||||
|
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
NUM_CALIBRATION_SAMPLES = 512
|
||||||
|
MAX_SEQUENCE_LENGTH = 2048
|
||||||
|
|
||||||
|
# Load and preprocess the dataset
|
||||||
|
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
|
||||||
|
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
||||||
|
|
||||||
|
def preprocess(example):
|
||||||
|
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
|
||||||
|
ds = ds.map(preprocess)
|
||||||
|
|
||||||
|
def tokenize(sample):
|
||||||
|
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
|
||||||
|
ds = ds.map(tokenize, remove_columns=ds.column_names)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Applying Quantization
|
||||||
|
|
||||||
|
Now, apply the quantization algorithms.
|
||||||
|
|
||||||
|
The following recipes create W4A8 models (int4 weights, int8 activations). On Arm® CPUs, this is accelerated through [KleidiAI](https://github.com/ARM-software/kleidiai).
|
||||||
|
|
||||||
|
Use the groupwise recipe for the best accuracy, and the channelwise recipe for the best inference performance.
|
||||||
|
|
||||||
|
=== "Groupwise"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llmcompressor import oneshot
|
||||||
|
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||||
|
|
||||||
|
# Configure the quantization algorithms
|
||||||
|
recipe = [
|
||||||
|
GPTQModifier(
|
||||||
|
targets="Linear",
|
||||||
|
scheme="W4A8",
|
||||||
|
ignore=["lm_head"],
|
||||||
|
dampening_frac=0.01
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply quantization
|
||||||
|
oneshot(
|
||||||
|
model=model,
|
||||||
|
dataset=ds,
|
||||||
|
recipe=recipe,
|
||||||
|
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||||
|
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token
|
||||||
|
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A8-G128-Dynamic-Per-Token"
|
||||||
|
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||||
|
tokenizer.save_pretrained(SAVE_DIR)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Channelwise"
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llmcompressor import oneshot
|
||||||
|
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||||
|
from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
|
||||||
|
|
||||||
|
scheme = {
|
||||||
|
"targets": ["Linear"],
|
||||||
|
"weights": {
|
||||||
|
"num_bits": 4,
|
||||||
|
"type": QuantizationType.INT,
|
||||||
|
"strategy": QuantizationStrategy.CHANNEL,
|
||||||
|
"symmetric": True,
|
||||||
|
"dynamic": False,
|
||||||
|
"group_size": None,
|
||||||
|
},
|
||||||
|
"input_activations": {
|
||||||
|
"num_bits": 8,
|
||||||
|
"type": QuantizationType.INT,
|
||||||
|
"strategy": QuantizationStrategy.TOKEN,
|
||||||
|
"dynamic": True,
|
||||||
|
"symmetric": False,
|
||||||
|
"observer": None,
|
||||||
|
},
|
||||||
|
"output_activations": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
recipe = [
|
||||||
|
GPTQModifier(
|
||||||
|
targets="Linear",
|
||||||
|
config_groups={"group_0": scheme},
|
||||||
|
ignore=["lm_head"],
|
||||||
|
dampening_frac=0.01,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
oneshot(
|
||||||
|
model=model,
|
||||||
|
dataset=ds,
|
||||||
|
recipe=recipe,
|
||||||
|
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||||
|
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token
|
||||||
|
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A8-Channelwise-Dynamic-Per-Token"
|
||||||
|
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||||
|
tokenizer.save_pretrained(SAVE_DIR)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Evaluating Accuracy
|
||||||
|
|
||||||
|
=== "Groupwise"
|
||||||
|
|
||||||
|
After quantization, you can load and run the model in vLLM:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from vllm import LLM
|
||||||
|
|
||||||
|
llm = LLM("./Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token")
|
||||||
|
```
|
||||||
|
|
||||||
|
To evaluate accuracy, you can use `lm_eval`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lm_eval --model vllm \
|
||||||
|
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token",add_bos_token=true \
|
||||||
|
--tasks gsm8k \
|
||||||
|
--num_fewshot 5 \
|
||||||
|
--limit 250 \
|
||||||
|
--batch_size 'auto'
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Channelwise"
|
||||||
|
|
||||||
|
After quantization, you can load and run the model in vLLM:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from vllm import LLM
|
||||||
|
|
||||||
|
llm = LLM("./Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token")
|
||||||
|
```
|
||||||
|
|
||||||
|
To evaluate accuracy, you can use `lm_eval`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lm_eval --model vllm \
|
||||||
|
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token",add_bos_token=true \
|
||||||
|
--tasks gsm8k \
|
||||||
|
--num_fewshot 5 \
|
||||||
|
--limit 250 \
|
||||||
|
--batch_size 'auto'
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
Quantized models can be sensitive to the presence of the beginning-of-sequence (`bos`) token. Make sure to include the `add_bos_token=true` argument, as in the commands above, when running evaluations.
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
- Start with 512 samples for calibration data (increase if accuracy drops)
|
||||||
|
- Use a sequence length of 2048 as a starting point
|
||||||
|
- Employ the chat template or instruction template that the model was trained with
|
||||||
|
- If you've fine-tuned a model, consider using a sample of your training data for calibration
|
||||||
|
|
||||||
|
## Troubleshooting and Support
|
||||||
|
|
||||||
|
If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository.
|
||||||
+40
-42
@@ -17,15 +17,17 @@ Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs re
|
|||||||
To use INT8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
To use INT8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install llmcompressor
|
(venv-llm-compressor) pip install llmcompressor
|
||||||
```
|
```
|
||||||
|
|
||||||
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install vllm "lm-eval[api]>=0.4.12"
|
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Please use separate environments for vLLM and llm-compressor as they might not work together.
|
||||||
|
|
||||||
## Quantization Process
|
## Quantization Process
|
||||||
|
|
||||||
The quantization process involves four main steps:
|
The quantization process involves four main steps:
|
||||||
@@ -57,26 +59,24 @@ When quantizing activations to INT8, you need sample data to estimate the activa
|
|||||||
It's best to use calibration data that closely matches your deployment data.
|
It's best to use calibration data that closely matches your deployment data.
|
||||||
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
|
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
|
||||||
|
|
||||||
??? code
|
```python
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
```python
|
NUM_CALIBRATION_SAMPLES = 512
|
||||||
from datasets import load_dataset
|
MAX_SEQUENCE_LENGTH = 2048
|
||||||
|
|
||||||
NUM_CALIBRATION_SAMPLES = 512
|
# Load and preprocess the dataset
|
||||||
MAX_SEQUENCE_LENGTH = 2048
|
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
|
||||||
|
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
||||||
|
|
||||||
# Load and preprocess the dataset
|
def preprocess(example):
|
||||||
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
|
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
|
||||||
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
ds = ds.map(preprocess)
|
||||||
|
|
||||||
def preprocess(example):
|
def tokenize(sample):
|
||||||
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
|
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
|
||||||
ds = ds.map(preprocess)
|
ds = ds.map(tokenize, remove_columns=ds.column_names)
|
||||||
|
```
|
||||||
def tokenize(sample):
|
|
||||||
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
|
|
||||||
ds = ds.map(tokenize, remove_columns=ds.column_names)
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -84,33 +84,31 @@ For a general-purpose instruction-tuned model, you can use a dataset like `ultra
|
|||||||
|
|
||||||
Now, apply the quantization algorithms:
|
Now, apply the quantization algorithms:
|
||||||
|
|
||||||
??? code
|
```python
|
||||||
|
from llmcompressor import oneshot
|
||||||
|
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||||
|
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
||||||
|
|
||||||
```python
|
# Configure the quantization algorithms
|
||||||
from llmcompressor import oneshot
|
recipe = [
|
||||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
SmoothQuantModifier(smoothing_strength=0.8),
|
||||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
|
||||||
|
]
|
||||||
|
|
||||||
# Configure the quantization algorithms
|
# Apply quantization
|
||||||
recipe = [
|
oneshot(
|
||||||
SmoothQuantModifier(smoothing_strength=0.8),
|
model=model,
|
||||||
GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
|
dataset=ds,
|
||||||
]
|
recipe=recipe,
|
||||||
|
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||||
|
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||||
|
)
|
||||||
|
|
||||||
# Apply quantization
|
# Save the compressed model: Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token
|
||||||
oneshot(
|
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
|
||||||
model=model,
|
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||||
dataset=ds,
|
tokenizer.save_pretrained(SAVE_DIR)
|
||||||
recipe=recipe,
|
```
|
||||||
max_seq_length=MAX_SEQUENCE_LENGTH,
|
|
||||||
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save the compressed model: Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token
|
|
||||||
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
|
|
||||||
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
|
||||||
tokenizer.save_pretrained(SAVE_DIR)
|
|
||||||
```
|
|
||||||
|
|
||||||
This process creates a W8A8 model with weights and activations quantized to 8-bit integers.
|
This process creates a W8A8 model with weights and activations quantized to 8-bit integers.
|
||||||
|
|
||||||
@@ -569,6 +569,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
|||||||
| `GlmOcrForConditionalGeneration` | GLM-OCR | T + I<sup>E+</sup> | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ |
|
| `GlmOcrForConditionalGeneration` | GLM-OCR | T + I<sup>E+</sup> | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ |
|
||||||
| `Granite4VisionForConditionalGeneration` | Granite 4 Vision | T + I<sup>E+</sup> | `ibm-granite/granite-4.1-3b-vision`, etc. | ✅︎ | ✅︎ |
|
| `Granite4VisionForConditionalGeneration` | Granite 4 Vision | T + I<sup>E+</sup> | `ibm-granite/granite-4.1-3b-vision`, etc. | ✅︎ | ✅︎ |
|
||||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
|
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
|
||||||
|
| `GraniteSpeechPlusForConditionalGeneration` | Granite Speech Plus | T + A | `ibm-granite/granite-speech-4.1-2b-plus` | ✅︎ | ✅︎ |
|
||||||
| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | |
|
| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | |
|
||||||
| `HCXVisionV2ForCausalLM` | HyperCLOVAX-SEED-Think-32B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Think-32B` | | |
|
| `HCXVisionV2ForCausalLM` | HyperCLOVAX-SEED-Think-32B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Think-32B` | | |
|
||||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎ |
|
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎ |
|
||||||
@@ -709,6 +710,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
|
|||||||
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
||||||
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
|
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
|
||||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-4.0-1b-speech`, `ibm-granite/granite-speech-3.3-2b`, etc. | ✅︎ | ✅︎ |
|
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-4.0-1b-speech`, `ibm-granite/granite-speech-3.3-2b`, etc. | ✅︎ | ✅︎ |
|
||||||
|
| `GraniteSpeechPlusForConditionalGeneration` | Granite Speech Plus | `ibm-granite/granite-speech-4.1-2b-plus` | ✅︎ | ✅︎ |
|
||||||
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | ✅︎ | ✅︎ |
|
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | ✅︎ | ✅︎ |
|
||||||
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, etc. | | ✅︎ |
|
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, etc. | | ✅︎ |
|
||||||
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
|
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@@ -24,8 +24,14 @@ echo "Checking pre-commit/pre-run-check status..."
|
|||||||
MAX_WAIT=300
|
MAX_WAIT=300
|
||||||
INTERVAL=60
|
INTERVAL=60
|
||||||
ELAPSED=0
|
ELAPSED=0
|
||||||
|
# Use a GitHub token if provided to raise the API rate limit (60 -> 5000
|
||||||
|
# requests/hour). Set GITHUB_TOKEN in the Read the Docs environment variables.
|
||||||
|
CURL_AUTH=()
|
||||||
|
if [ -n "$GITHUB_TOKEN" ]; then
|
||||||
|
CURL_AUTH=(-H "Authorization: Bearer $GITHUB_TOKEN")
|
||||||
|
fi
|
||||||
while :; do
|
while :; do
|
||||||
RAW=$(curl -sS -w "\n%{http_code}" "https://api.github.com/repos/vllm-project/vllm/commits/${READTHEDOCS_GIT_COMMIT_HASH}/check-runs?check_name=pre-run-check&filter=latest")
|
RAW=$(curl -sS "${CURL_AUTH[@]}" -w "\n%{http_code}" "https://api.github.com/repos/vllm-project/vllm/commits/${READTHEDOCS_GIT_COMMIT_HASH}/check-runs?check_name=pre-run-check&filter=latest")
|
||||||
HTTP_CODE=$(printf %s "$RAW" | tail -n1)
|
HTTP_CODE=$(printf %s "$RAW" | tail -n1)
|
||||||
BODY=$(printf %s "$RAW" | sed '$d')
|
BODY=$(printf %s "$RAW" | sed '$d')
|
||||||
if [ "$HTTP_CODE" != "200" ]; then
|
if [ "$HTTP_CODE" != "200" ]; then
|
||||||
|
|||||||
@@ -2554,6 +2554,7 @@ MODELS_NEED_VIDEO_METADATA = [
|
|||||||
|
|
||||||
|
|
||||||
MODELS_SUPPORT_VIT_CUDA_GRAPH = [
|
MODELS_SUPPORT_VIT_CUDA_GRAPH = [
|
||||||
|
"internvl_chat",
|
||||||
"qwen2_5_vl",
|
"qwen2_5_vl",
|
||||||
"qwen3_vl",
|
"qwen3_vl",
|
||||||
"qwen3_vl_moe",
|
"qwen3_vl_moe",
|
||||||
|
|||||||
@@ -110,6 +110,9 @@ plugins:
|
|||||||
redirect_maps:
|
redirect_maps:
|
||||||
features/spec_decode/README.md: features/speculative_decoding/README.md
|
features/spec_decode/README.md: features/speculative_decoding/README.md
|
||||||
features/spec_decode/speculators.md: features/speculative_decoding/speculators.md
|
features/spec_decode/speculators.md: features/speculative_decoding/speculators.md
|
||||||
|
features/quantization/fp8.md: features/quantization/llm_compressor/fp8.md
|
||||||
|
features/quantization/int4.md: features/quantization/llm_compressor/int4.md
|
||||||
|
features/quantization/int8.md: features/quantization/llm_compressor/int8_w8a8.md
|
||||||
serving/openai_compatible_server.md: serving/online_serving/README.md
|
serving/openai_compatible_server.md: serving/online_serving/README.md
|
||||||
|
|
||||||
markdown_extensions:
|
markdown_extensions:
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ pyyaml
|
|||||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||||
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||||
einops # Required for Qwen2-VL.
|
einops # Required for Qwen2-VL.
|
||||||
compressed-tensors == 0.15.0.1 # required for compressed-tensors
|
compressed-tensors == 0.17.0 # required for compressed-tensors
|
||||||
depyf==0.20.0 # required for profiling and debugging with compilation config
|
depyf==0.20.0 # required for profiling and debugging with compilation config
|
||||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||||
watchfiles # required for http server to monitor the updates of TLS files
|
watchfiles # required for http server to monitor the updates of TLS files
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ tilelang==0.1.9
|
|||||||
nvidia-cudnn-frontend>=1.13.0,<1.19.0
|
nvidia-cudnn-frontend>=1.13.0,<1.19.0
|
||||||
|
|
||||||
# Required for faster safetensors model loading
|
# Required for faster safetensors model loading
|
||||||
fastsafetensors >= 0.2.2
|
fastsafetensors >= 0.3.2
|
||||||
|
|
||||||
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
|
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
|
||||||
nvidia-cutlass-dsl[cu13]==4.5.2
|
nvidia-cutlass-dsl[cu13]==4.5.2
|
||||||
@@ -28,4 +28,4 @@ quack-kernels>=0.3.3
|
|||||||
tokenspeed-mla==0.1.2
|
tokenspeed-mla==0.1.2
|
||||||
|
|
||||||
# Humming kernels for quantization gemm
|
# Humming kernels for quantization gemm
|
||||||
humming-kernels[cu13]==0.1.2
|
humming-kernels[cu13]==0.1.4
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ setuptools-rust>=1.9.0
|
|||||||
runai-model-streamer[s3,gcs,azure]==0.15.7
|
runai-model-streamer[s3,gcs,azure]==0.15.7
|
||||||
conch-triton-kernels==1.2.1
|
conch-triton-kernels==1.2.1
|
||||||
timm>=1.0.17
|
timm>=1.0.17
|
||||||
# amd-quark: required for Quark quantization on ROCm
|
# amd-quark: required for Quark quantization on ROCm
|
||||||
# To be consistent with test_quark.py
|
# To be consistent with test_quark.py
|
||||||
amd-quark>=0.8.99
|
amd-quark>=0.8.99
|
||||||
tilelang==0.1.10
|
tilelang==0.1.10
|
||||||
|
|
||||||
|
# Required for faster safetensors model loading
|
||||||
|
fastsafetensors >= 0.3.2
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ arctic-inference == 0.1.1; platform_machine == "x86_64" # Required for suffix de
|
|||||||
numba == 0.65.0 # Required for N-gram speculative decoding
|
numba == 0.65.0 # Required for N-gram speculative decoding
|
||||||
numpy
|
numpy
|
||||||
runai-model-streamer[s3,gcs,azure]==0.15.7
|
runai-model-streamer[s3,gcs,azure]==0.15.7
|
||||||
fastsafetensors>=0.2.2; platform_machine == "x86_64" # 0.2.2 contains important fixes for multi-GPU mem usage
|
fastsafetensors>=0.3.2
|
||||||
instanttensor>=0.1.5; platform_machine == "x86_64"
|
instanttensor>=0.1.5; platform_machine == "x86_64"
|
||||||
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
||||||
decord==0.6.0; platform_machine == "x86_64"
|
decord==0.6.0; platform_machine == "x86_64"
|
||||||
|
|||||||
@@ -191,7 +191,7 @@ fastparquet==2024.11.0
|
|||||||
# via genai-perf
|
# via genai-perf
|
||||||
fastrlock==0.8.2
|
fastrlock==0.8.2
|
||||||
# via cupy-cuda12x
|
# via cupy-cuda12x
|
||||||
fastsafetensors==0.2.2
|
fastsafetensors==0.3.2
|
||||||
# via
|
# via
|
||||||
# -c requirements/cuda.txt
|
# -c requirements/cuda.txt
|
||||||
# -r requirements/test/cuda.in
|
# -r requirements/test/cuda.in
|
||||||
|
|||||||
@@ -43,6 +43,6 @@ tritonclient>=2.51.0
|
|||||||
numba == 0.65.0 # Required for N-gram speculative decoding
|
numba == 0.65.0 # Required for N-gram speculative decoding
|
||||||
numpy
|
numpy
|
||||||
runai-model-streamer[s3,gcs,azure]==0.15.7
|
runai-model-streamer[s3,gcs,azure]==0.15.7
|
||||||
fastsafetensors>=0.2.2
|
fastsafetensors>=0.3.2
|
||||||
instanttensor>=0.1.5
|
instanttensor>=0.1.5
|
||||||
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ arctic-inference==0.1.1 # Required for suffix decoding test
|
|||||||
numba==0.65.0 # Required for N-gram speculative decoding
|
numba==0.65.0 # Required for N-gram speculative decoding
|
||||||
numpy
|
numpy
|
||||||
runai-model-streamer[s3,gcs,azure]==0.15.7
|
runai-model-streamer[s3,gcs,azure]==0.15.7
|
||||||
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@0.2.2 # PyPI only ships CUDA wheels
|
fastsafetensors>=0.3.2
|
||||||
instanttensor>=0.1.5
|
instanttensor>=0.1.5
|
||||||
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
||||||
decord==0.6.0
|
decord==0.6.0
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ colorful==0.5.8
|
|||||||
# via ray
|
# via ray
|
||||||
colorlog==6.10.1
|
colorlog==6.10.1
|
||||||
# via optuna
|
# via optuna
|
||||||
compressed-tensors==0.15.0.1
|
compressed-tensors==0.17.0
|
||||||
# via
|
# via
|
||||||
# -c requirements/common.txt
|
# -c requirements/common.txt
|
||||||
# -r requirements/test/../common.txt
|
# -r requirements/test/../common.txt
|
||||||
@@ -240,8 +240,10 @@ fastar==0.10.0
|
|||||||
# via fastapi-cloud-cli
|
# via fastapi-cloud-cli
|
||||||
fastparquet==2026.3.0
|
fastparquet==2026.3.0
|
||||||
# via genai-perf
|
# via genai-perf
|
||||||
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@65d80088fca7a8f567fba30415fbcc80f7d2259c
|
fastsafetensors==0.3.2
|
||||||
# via -r requirements/test/rocm.in
|
# via
|
||||||
|
# -c requirements/rocm.txt
|
||||||
|
# -r requirements/test/rocm.in
|
||||||
filelock==3.25.2
|
filelock==3.25.2
|
||||||
# via
|
# via
|
||||||
# -c requirements/common.txt
|
# -c requirements/common.txt
|
||||||
|
|||||||
@@ -1168,7 +1168,7 @@ setup(
|
|||||||
"zen": ["zentorch==2.11.0.0"],
|
"zen": ["zentorch==2.11.0.0"],
|
||||||
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
|
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
|
||||||
"tensorizer": ["tensorizer==2.10.1"],
|
"tensorizer": ["tensorizer==2.10.1"],
|
||||||
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
|
"fastsafetensors": ["fastsafetensors >= 0.3.2"],
|
||||||
"instanttensor": ["instanttensor >= 0.1.5"],
|
"instanttensor": ["instanttensor >= 0.1.5"],
|
||||||
"runai": ["runai-model-streamer[s3,gcs,azure] >= 0.15.7"],
|
"runai": ["runai-model-streamer[s3,gcs,azure] >= 0.15.7"],
|
||||||
"audio": [
|
"audio": [
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
|
|||||||
AllReduceFusionPass,
|
AllReduceFusionPass,
|
||||||
RocmAiterAllReduceFusionPass,
|
RocmAiterAllReduceFusionPass,
|
||||||
)
|
)
|
||||||
|
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||||
from vllm.compilation.passes.utility.fix_functionalization import (
|
from vllm.compilation.passes.utility.fix_functionalization import (
|
||||||
FixFunctionalizationPass,
|
FixFunctionalizationPass,
|
||||||
)
|
)
|
||||||
@@ -33,7 +34,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
)
|
)
|
||||||
@@ -91,6 +92,49 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
|||||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAllReduceGemmaRMSNormModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size=16,
|
||||||
|
token_num=16,
|
||||||
|
eps=1e-6,
|
||||||
|
dtype: torch.dtype = torch.float16,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.eps = eps
|
||||||
|
self.norm = [GemmaRMSNorm(hidden_size, eps) for _ in range(4)]
|
||||||
|
# Non-trivial weight (~Gemma range) so (1 + w) exercises the scale path.
|
||||||
|
for n in self.norm:
|
||||||
|
n.weight.data.normal_(mean=0.0, std=0.1)
|
||||||
|
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# avoid having graph input be an arg to a pattern directly
|
||||||
|
z = torch.relu(x)
|
||||||
|
x = resid = tensor_model_parallel_all_reduce(z)
|
||||||
|
y = self.norm[0](x)
|
||||||
|
|
||||||
|
z2 = torch.mm(y, self.w[0])
|
||||||
|
x2 = tensor_model_parallel_all_reduce(z2)
|
||||||
|
y2, resid = self.norm[1](x2, resid)
|
||||||
|
|
||||||
|
z3 = torch.mm(y2, self.w[1])
|
||||||
|
x3 = tensor_model_parallel_all_reduce(z3)
|
||||||
|
y3, resid = self.norm[2](x3, resid)
|
||||||
|
|
||||||
|
z4 = torch.mm(y3, self.w[2])
|
||||||
|
x4 = tensor_model_parallel_all_reduce(z4)
|
||||||
|
y4, resid = self.norm[3](x4, resid)
|
||||||
|
return y4
|
||||||
|
|
||||||
|
def ops_in_model_before(self):
|
||||||
|
return [torch.ops.vllm.all_reduce.default]
|
||||||
|
|
||||||
|
def ops_in_model_after(self):
|
||||||
|
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||||
|
|
||||||
|
|
||||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||||
quant_key = kFp8StaticTensorSym
|
quant_key = kFp8StaticTensorSym
|
||||||
|
|
||||||
@@ -209,6 +253,15 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
|||||||
"test_model, enable_quant_fp8_custom_op, use_aiter",
|
"test_model, enable_quant_fp8_custom_op, use_aiter",
|
||||||
[
|
[
|
||||||
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
|
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
|
||||||
|
pytest.param(
|
||||||
|
TestAllReduceGemmaRMSNormModel,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
marks=pytest.mark.skipif(
|
||||||
|
current_platform.is_rocm(),
|
||||||
|
reason="Not supported on ROCm platform",
|
||||||
|
),
|
||||||
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
TestAllReduceRMSNormStaticQuantFP8Model,
|
TestAllReduceRMSNormStaticQuantFP8Model,
|
||||||
True,
|
True,
|
||||||
@@ -404,4 +457,9 @@ def all_reduce_fusion_pass_on_test_model(
|
|||||||
)
|
)
|
||||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||||
backend.check_after_ops(model.ops_in_model_after())
|
backend.check_after_ops(model.ops_in_model_after())
|
||||||
|
if test_model_cls is TestAllReduceGemmaRMSNormModel:
|
||||||
|
fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
|
||||||
|
fused_nodes = list(find_op_nodes(fused_op, backend.graph_post_pass))
|
||||||
|
assert fused_nodes
|
||||||
|
assert all(n.kwargs.get("weight_bias") == 1.0 for n in fused_nodes)
|
||||||
del all_reduce_fusion_pass
|
del all_reduce_fusion_pass
|
||||||
|
|||||||
@@ -0,0 +1,250 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Tests for the Inductor FALLBACK_ALLOW_LIST patch in env_override.py.
|
||||||
|
|
||||||
|
The patch wraps ``torch._inductor.lowering.FALLBACK_ALLOW_LIST`` in a thin
|
||||||
|
proxy that auto-allows any custom op in the ``vllm::`` or ``vllm_aiter::``
|
||||||
|
namespaces. This routes those ops through Inductor's fast-path
|
||||||
|
``make_fallback(target, warn=False, override_decomp=True)`` and avoids the
|
||||||
|
expensive ``error.operator_str(target, args, kwargs)`` formatting that
|
||||||
|
recursively stringifies every input ``TensorBox``.
|
||||||
|
|
||||||
|
The slow path is what made ``torch.compile`` effectively hang on Kimi-K2.6
|
||||||
|
TP=8 (deep MoE/TP IR provenance trees). These tests cover both the proxy's
|
||||||
|
semantics in isolation and the membership-check fast-path that Inductor's
|
||||||
|
``GraphLowering.call_function`` actually performs, so we can validate the
|
||||||
|
optimization without needing a full GPU compile.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.env_override import (
|
||||||
|
_patch_inductor_fallback_allow_list,
|
||||||
|
_VllmFallbackAllowList,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVllmFallbackAllowListProxy:
|
||||||
|
"""Unit tests for the membership-proxy semantics."""
|
||||||
|
|
||||||
|
def test_vllm_namespace_auto_allowed(self):
|
||||||
|
proxy = _VllmFallbackAllowList(set())
|
||||||
|
assert "vllm::all_reduce" in proxy
|
||||||
|
assert "vllm::fused_add_rms_norm" in proxy
|
||||||
|
assert "vllm::all_reduce.default" in proxy
|
||||||
|
|
||||||
|
def test_vllm_aiter_namespace_auto_allowed(self):
|
||||||
|
proxy = _VllmFallbackAllowList(set())
|
||||||
|
assert "vllm_aiter::fused_add_rms_norm" in proxy
|
||||||
|
assert "vllm_aiter::rocm_aiter_fused_moe" in proxy
|
||||||
|
|
||||||
|
def test_unknown_namespace_falls_through(self):
|
||||||
|
proxy = _VllmFallbackAllowList({"torchvision::roi_align"})
|
||||||
|
assert "torchvision::roi_align" in proxy
|
||||||
|
assert "made_up_ns::nonexistent_op" not in proxy
|
||||||
|
|
||||||
|
def test_non_string_falls_through_to_inner(self):
|
||||||
|
sentinel = object()
|
||||||
|
inner = {sentinel}
|
||||||
|
proxy = _VllmFallbackAllowList(inner)
|
||||||
|
assert sentinel in proxy
|
||||||
|
assert object() not in proxy
|
||||||
|
|
||||||
|
def test_prefix_only_match_not_substring(self):
|
||||||
|
proxy = _VllmFallbackAllowList(set())
|
||||||
|
assert "not_vllm::something" not in proxy
|
||||||
|
assert " vllm::space_prefixed" not in proxy
|
||||||
|
|
||||||
|
def test_standard_entries_preserved(self):
|
||||||
|
base = {"torchvision::roi_align", "aten::index_add"}
|
||||||
|
proxy = _VllmFallbackAllowList(base)
|
||||||
|
assert "torchvision::roi_align" in proxy
|
||||||
|
assert "aten::index_add" in proxy
|
||||||
|
assert "aten::__not_present__" not in proxy
|
||||||
|
|
||||||
|
def test_add_and_discard_delegate_to_inner(self):
|
||||||
|
inner: set[str] = set()
|
||||||
|
proxy = _VllmFallbackAllowList(inner)
|
||||||
|
proxy.add("custom::op")
|
||||||
|
assert "custom::op" in inner
|
||||||
|
proxy.discard("custom::op")
|
||||||
|
assert "custom::op" not in inner
|
||||||
|
|
||||||
|
def test_iter_len_repr(self):
|
||||||
|
base = {"torchvision::roi_align", "aten::index_add"}
|
||||||
|
proxy = _VllmFallbackAllowList(base)
|
||||||
|
assert set(iter(proxy)) == base
|
||||||
|
assert len(proxy) == len(base)
|
||||||
|
assert "torchvision::roi_align" in repr(proxy)
|
||||||
|
|
||||||
|
def test_getattr_delegates_to_inner(self):
|
||||||
|
class _Inner:
|
||||||
|
sentinel = "i_am_inner"
|
||||||
|
|
||||||
|
def some_method(self):
|
||||||
|
return 42
|
||||||
|
|
||||||
|
inner = _Inner()
|
||||||
|
proxy = _VllmFallbackAllowList(inner)
|
||||||
|
assert proxy.sentinel == "i_am_inner"
|
||||||
|
assert proxy.some_method() == 42
|
||||||
|
|
||||||
|
def test_sentinel_attribute(self):
|
||||||
|
proxy = _VllmFallbackAllowList(set())
|
||||||
|
assert proxy._vllm_patched is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatchApplication:
|
||||||
|
"""Integration tests verifying the patch reaches ``torch._inductor``."""
|
||||||
|
|
||||||
|
def test_patch_applied_to_lowering(self):
|
||||||
|
import torch._inductor.lowering as _lowering
|
||||||
|
|
||||||
|
assert getattr(_lowering.FALLBACK_ALLOW_LIST, "_vllm_patched", False), (
|
||||||
|
"env_override._patch_inductor_fallback_allow_list did not run"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_graph_module_local_binding_rebound(self):
|
||||||
|
# ``torch/_inductor/graph.py`` does:
|
||||||
|
# from torch._inductor.lowering import FALLBACK_ALLOW_LIST
|
||||||
|
# so the patch has to overwrite the graph module's local binding too,
|
||||||
|
# otherwise the fast-path check in GraphLowering.call_function still
|
||||||
|
# sees the original (unwrapped) OrderedSet.
|
||||||
|
import torch._inductor.graph as _graph
|
||||||
|
import torch._inductor.lowering as _lowering
|
||||||
|
|
||||||
|
if not hasattr(_graph, "FALLBACK_ALLOW_LIST"):
|
||||||
|
pytest.skip(
|
||||||
|
"torch._inductor.graph no longer imports FALLBACK_ALLOW_LIST "
|
||||||
|
"as a module-level symbol; nothing to rebind."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _graph.FALLBACK_ALLOW_LIST is _lowering.FALLBACK_ALLOW_LIST
|
||||||
|
|
||||||
|
def test_patch_is_idempotent(self):
|
||||||
|
import torch._inductor.lowering as _lowering
|
||||||
|
|
||||||
|
first = _lowering.FALLBACK_ALLOW_LIST
|
||||||
|
_patch_inductor_fallback_allow_list()
|
||||||
|
_patch_inductor_fallback_allow_list()
|
||||||
|
assert _lowering.FALLBACK_ALLOW_LIST is first
|
||||||
|
|
||||||
|
def test_real_vllm_ops_in_real_allow_list(self):
|
||||||
|
# End-to-end membership check using the live (already-patched) object.
|
||||||
|
import torch._inductor.lowering as _lowering
|
||||||
|
|
||||||
|
allow_list = _lowering.FALLBACK_ALLOW_LIST
|
||||||
|
assert "vllm::all_reduce" in allow_list
|
||||||
|
assert "vllm::fused_add_rms_norm" in allow_list
|
||||||
|
assert "vllm_aiter::fused_add_rms_norm" in allow_list
|
||||||
|
|
||||||
|
|
||||||
|
class TestInductorFallbackFastPath:
|
||||||
|
"""Emulates ``GraphLowering.call_function``'s FALLBACK_ALLOW_LIST check.
|
||||||
|
|
||||||
|
The relevant snippet in ``torch/_inductor/graph.py`` is roughly::
|
||||||
|
|
||||||
|
base_name = target.name()
|
||||||
|
if base_name not in FALLBACK_ALLOW_LIST:
|
||||||
|
log.info(
|
||||||
|
"Creating implicit fallback for:\\n%s",
|
||||||
|
error.operator_str(target, args, kwargs),
|
||||||
|
)
|
||||||
|
out = make_fallback(target, ...)
|
||||||
|
|
||||||
|
On a deep MoE/TP graph (Kimi-K2.6 at TP=4/8) ``operator_str`` recurses
|
||||||
|
through every input ``TensorBox.__str__`` and ends up taking many minutes
|
||||||
|
of CPU per encountered op. The patch ensures the membership test
|
||||||
|
short-circuits for ``vllm::*``/``vllm_aiter::*`` ops so the slow path is
|
||||||
|
never entered. These tests pin that behaviour without needing a real
|
||||||
|
GPU compile.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _simulate_graph_lowering(self, target_names: list[str]):
|
||||||
|
"""Returns the set of target names that would have hit the slow
|
||||||
|
operator_str() path under the patched FALLBACK_ALLOW_LIST.
|
||||||
|
"""
|
||||||
|
import torch._inductor.lowering as _lowering
|
||||||
|
|
||||||
|
allow_list = _lowering.FALLBACK_ALLOW_LIST
|
||||||
|
slow_path_hits: list[str] = []
|
||||||
|
for name in target_names:
|
||||||
|
if name not in allow_list:
|
||||||
|
slow_path_hits.append(name)
|
||||||
|
return slow_path_hits
|
||||||
|
|
||||||
|
def test_vllm_ops_skip_slow_path(self):
|
||||||
|
slow = self._simulate_graph_lowering(
|
||||||
|
[
|
||||||
|
"vllm::all_reduce",
|
||||||
|
"vllm::fused_add_rms_norm",
|
||||||
|
"vllm_aiter::rocm_aiter_fused_moe",
|
||||||
|
"vllm_aiter::asm_moe",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert slow == [], (
|
||||||
|
"Patched FALLBACK_ALLOW_LIST must short-circuit for all "
|
||||||
|
f"vllm::*/vllm_aiter::* ops; got slow-path hits: {slow}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_non_vllm_ops_still_hit_slow_path(self):
|
||||||
|
# Without the patch this is also what would happen; with the patch
|
||||||
|
# the behaviour for non-vllm namespaces must be unchanged.
|
||||||
|
slow = self._simulate_graph_lowering(
|
||||||
|
["my_user_ns::custom_op", "fancy_ns::something_else"]
|
||||||
|
)
|
||||||
|
assert "my_user_ns::custom_op" in slow
|
||||||
|
assert "fancy_ns::something_else" in slow
|
||||||
|
|
||||||
|
def test_kimi_k2_6_style_op_stream(self):
|
||||||
|
"""Emulates one decoder layer's worth of fallback hits.
|
||||||
|
|
||||||
|
Kimi-K2.6 at TP=4 lowers a stream of ``vllm::all_reduce`` +
|
||||||
|
``vllm_aiter::fused_add_rms_norm`` calls (one per residual block)
|
||||||
|
plus a handful of fused-MoE ops. Pre-patch every one of these would
|
||||||
|
invoke ``operator_str`` and stringify a hundreds-deep IR provenance
|
||||||
|
tree; post-patch they must all short-circuit.
|
||||||
|
"""
|
||||||
|
n_layers = 64 # Kimi-K2.6 has ~64 decoder layers per replica
|
||||||
|
op_stream: list[str] = []
|
||||||
|
for _ in range(n_layers):
|
||||||
|
op_stream.extend(
|
||||||
|
[
|
||||||
|
"vllm::all_reduce",
|
||||||
|
"vllm_aiter::fused_add_rms_norm",
|
||||||
|
"vllm_aiter::rocm_aiter_fused_moe",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
slow = self._simulate_graph_lowering(op_stream)
|
||||||
|
elapsed_s = time.perf_counter() - start
|
||||||
|
|
||||||
|
assert slow == [], (
|
||||||
|
f"Expected all {len(op_stream)} vllm/vllm_aiter ops to take "
|
||||||
|
f"the fast path; got {len(slow)} slow-path hits."
|
||||||
|
)
|
||||||
|
# ``__contains__`` is O(1) per call, so a Kimi-sized stream should
|
||||||
|
# complete in well under a second even on a slow runner. The
|
||||||
|
# pre-patch slow path took many minutes per op on Kimi-K2.6 TP=8.
|
||||||
|
assert elapsed_s < 1.0, (
|
||||||
|
f"FALLBACK_ALLOW_LIST membership check is unexpectedly slow: "
|
||||||
|
f"{elapsed_s:.3f}s for {len(op_stream)} ops"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_inner_set_membership_still_works_for_standard_ops(self):
|
||||||
|
"""The patch must not break Inductor's existing fallback decisions
|
||||||
|
for non-vllm ops such as ``torchvision::roi_align``."""
|
||||||
|
import torch._inductor.lowering as _lowering
|
||||||
|
|
||||||
|
allow_list = _lowering.FALLBACK_ALLOW_LIST
|
||||||
|
# ``torchvision::roi_align`` has been a member of the upstream
|
||||||
|
# FALLBACK_ALLOW_LIST since the original Inductor implementation.
|
||||||
|
# If the proxy ever broke pass-through, this would regress.
|
||||||
|
if "torchvision::roi_align" not in allow_list:
|
||||||
|
pytest.skip(
|
||||||
|
"Upstream FALLBACK_ALLOW_LIST no longer ships "
|
||||||
|
"torchvision::roi_align; nothing to verify."
|
||||||
|
)
|
||||||
@@ -277,12 +277,15 @@ def assert_verification_synced(local_ok: bool, msg: str) -> None:
|
|||||||
assert bool(ok_tensor.item()), msg
|
assert bool(ok_tensor.item()), msg
|
||||||
|
|
||||||
|
|
||||||
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
|
def create_eplb_communicator_or_raise(
|
||||||
|
*, group_coordinator, backend, expert_weights, expert_buffer
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
return create_eplb_communicator(
|
return create_eplb_communicator(
|
||||||
group_coordinator=group_coordinator,
|
group_coordinator=group_coordinator,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
expert_weights=expert_weights,
|
expert_weights=expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -355,7 +358,8 @@ def _test_async_transfer_layer_without_mtp_worker(
|
|||||||
communicator = create_eplb_communicator_or_raise(
|
communicator = create_eplb_communicator_or_raise(
|
||||||
group_coordinator=ep_group_coordinator,
|
group_coordinator=ep_group_coordinator,
|
||||||
backend=eplb_communicator,
|
backend=eplb_communicator,
|
||||||
expert_weights=expert_weights[0],
|
expert_weights=expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
)
|
)
|
||||||
communicator.set_stream(cuda_stream)
|
communicator.set_stream(cuda_stream)
|
||||||
|
|
||||||
@@ -368,6 +372,7 @@ def _test_async_transfer_layer_without_mtp_worker(
|
|||||||
ep_group=ep_group,
|
ep_group=ep_group,
|
||||||
communicator=communicator,
|
communicator=communicator,
|
||||||
cuda_stream=cuda_stream,
|
cuda_stream=cuda_stream,
|
||||||
|
layer_idx=layer_idx,
|
||||||
)
|
)
|
||||||
cuda_stream.synchronize()
|
cuda_stream.synchronize()
|
||||||
move_from_buffer(
|
move_from_buffer(
|
||||||
@@ -460,10 +465,12 @@ def _test_rearrange_expert_weights_with_redundancy(
|
|||||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||||
)
|
)
|
||||||
|
|
||||||
|
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||||
communicator = create_eplb_communicator_or_raise(
|
communicator = create_eplb_communicator_or_raise(
|
||||||
group_coordinator=ep_group_coordinator,
|
group_coordinator=ep_group_coordinator,
|
||||||
backend=eplb_communicator,
|
backend=eplb_communicator,
|
||||||
expert_weights=expert_weights[0],
|
expert_weights=expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute weight rearrangement
|
# Execute weight rearrangement
|
||||||
@@ -471,9 +478,9 @@ def _test_rearrange_expert_weights_with_redundancy(
|
|||||||
old_indices,
|
old_indices,
|
||||||
new_indices,
|
new_indices,
|
||||||
expert_weights,
|
expert_weights,
|
||||||
|
expert_buffer,
|
||||||
ep_group,
|
ep_group,
|
||||||
is_profile=False,
|
communicator,
|
||||||
communicator=communicator,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the rearrangement result
|
# Verify the rearrangement result
|
||||||
@@ -593,10 +600,12 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
|||||||
layer_copy.append(weight.clone())
|
layer_copy.append(weight.clone())
|
||||||
original_weights.append(layer_copy)
|
original_weights.append(layer_copy)
|
||||||
|
|
||||||
|
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||||
communicator = create_eplb_communicator_or_raise(
|
communicator = create_eplb_communicator_or_raise(
|
||||||
group_coordinator=ep_group_coordinator,
|
group_coordinator=ep_group_coordinator,
|
||||||
backend="torch_nccl",
|
backend="torch_nccl",
|
||||||
expert_weights=expert_weights[0],
|
expert_weights=expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute rearrangement (should be no change)
|
# Execute rearrangement (should be no change)
|
||||||
@@ -604,9 +613,9 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
|||||||
indices,
|
indices,
|
||||||
indices, # Same indices
|
indices, # Same indices
|
||||||
expert_weights,
|
expert_weights,
|
||||||
|
expert_buffer,
|
||||||
ep_group,
|
ep_group,
|
||||||
communicator,
|
communicator,
|
||||||
is_profile=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify that the weights have not changed
|
# Verify that the weights have not changed
|
||||||
@@ -726,10 +735,12 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
|||||||
layer_copy.append(weight.clone())
|
layer_copy.append(weight.clone())
|
||||||
original_weights.append(layer_copy)
|
original_weights.append(layer_copy)
|
||||||
|
|
||||||
|
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||||
communicator = create_eplb_communicator_or_raise(
|
communicator = create_eplb_communicator_or_raise(
|
||||||
group_coordinator=ep_group_coordinator,
|
group_coordinator=ep_group_coordinator,
|
||||||
backend="torch_nccl",
|
backend="torch_nccl",
|
||||||
expert_weights=expert_weights[0],
|
expert_weights=expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute profile mode rearrangement
|
# Execute profile mode rearrangement
|
||||||
@@ -737,9 +748,10 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
|||||||
old_indices,
|
old_indices,
|
||||||
new_indices,
|
new_indices,
|
||||||
expert_weights,
|
expert_weights,
|
||||||
|
expert_buffer,
|
||||||
ep_group,
|
ep_group,
|
||||||
communicator,
|
communicator,
|
||||||
is_profile=True, # Profile mode
|
is_profile=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# In profile mode, the weights should remain unchanged
|
# In profile mode, the weights should remain unchanged
|
||||||
|
|||||||
@@ -9,9 +9,11 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
ensure_model_parallel_initialized,
|
ensure_model_parallel_initialized,
|
||||||
|
get_eplb_group,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||||
@@ -213,12 +215,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
|
|||||||
for lidx in range(test_config.num_layers):
|
for lidx in range(test_config.num_layers):
|
||||||
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
||||||
|
|
||||||
|
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
|
||||||
|
communicator = create_eplb_communicator(
|
||||||
|
group_coordinator=get_eplb_group(),
|
||||||
|
backend="torch_nccl",
|
||||||
|
expert_weights=rank_expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
|
)
|
||||||
rearrange_expert_weights_inplace(
|
rearrange_expert_weights_inplace(
|
||||||
indices,
|
indices,
|
||||||
shuffled_indices,
|
shuffled_indices,
|
||||||
rank_expert_weights,
|
rank_expert_weights,
|
||||||
|
expert_buffer,
|
||||||
ep_group,
|
ep_group,
|
||||||
is_profile=False,
|
communicator,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_local_experts = test_config.num_local_experts
|
num_local_experts = test_config.num_local_experts
|
||||||
|
|||||||
@@ -10,11 +10,13 @@ import torch
|
|||||||
|
|
||||||
from tests.kernels.moe.utils import make_test_quant_config
|
from tests.kernels.moe.utils import make_test_quant_config
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
ensure_model_parallel_initialized,
|
ensure_model_parallel_initialized,
|
||||||
get_dp_group,
|
get_dp_group,
|
||||||
|
get_eplb_group,
|
||||||
)
|
)
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||||
@@ -171,12 +173,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
|
|||||||
for lidx in range(test_config.num_layers):
|
for lidx in range(test_config.num_layers):
|
||||||
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
||||||
|
|
||||||
|
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
|
||||||
|
communicator = create_eplb_communicator(
|
||||||
|
group_coordinator=get_eplb_group(),
|
||||||
|
backend="torch_nccl",
|
||||||
|
expert_weights=rank_expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
|
)
|
||||||
rearrange_expert_weights_inplace(
|
rearrange_expert_weights_inplace(
|
||||||
indices,
|
indices,
|
||||||
shuffled_indices,
|
shuffled_indices,
|
||||||
rank_expert_weights,
|
rank_expert_weights,
|
||||||
|
expert_buffer,
|
||||||
ep_group,
|
ep_group,
|
||||||
is_profile=False,
|
communicator,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_global_experts = test_config.num_experts
|
num_global_experts = test_config.num_experts
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from vllm import PoolingParams
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.entrypoints.openai.engine.protocol import (
|
from vllm.entrypoints.openai.engine.protocol import (
|
||||||
@@ -13,10 +14,13 @@ from vllm.entrypoints.openai.engine.protocol import (
|
|||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||||
|
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
|
||||||
|
from vllm.entrypoints.pooling.typing import PoolingServeContext
|
||||||
from vllm.entrypoints.serve.lora.protocol import (
|
from vllm.entrypoints.serve.lora.protocol import (
|
||||||
LoadLoRAAdapterRequest,
|
LoadLoRAAdapterRequest,
|
||||||
UnloadLoRAAdapterRequest,
|
UnloadLoRAAdapterRequest,
|
||||||
)
|
)
|
||||||
|
from vllm.exceptions import VLLMNotFoundError
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||||
@@ -130,3 +134,60 @@ async def test_unload_lora_adapter_not_found():
|
|||||||
assert isinstance(response, ErrorResponse)
|
assert isinstance(response, ErrorResponse)
|
||||||
assert response.error.type == "NotFoundError"
|
assert response.error.type == "NotFoundError"
|
||||||
assert response.error.code == HTTPStatus.NOT_FOUND
|
assert response.error.code == HTTPStatus.NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
class _ConcretePoolingServing(PoolingServingBase):
|
||||||
|
"""Minimal concrete subclass used only in these unit tests."""
|
||||||
|
|
||||||
|
request_id_prefix = "test"
|
||||||
|
|
||||||
|
def get_io_processor(self, request):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _build_response(self, ctx):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pooling_serving(lora_name: str) -> _ConcretePoolingServing:
|
||||||
|
lora_request = LoRARequest(
|
||||||
|
lora_name=lora_name, lora_int_id=1, lora_path="/path/to/lora"
|
||||||
|
)
|
||||||
|
mock_models = MagicMock()
|
||||||
|
mock_models.lora_requests = {lora_name: lora_request}
|
||||||
|
mock_models.is_base_model.side_effect = lambda name: name == MODEL_NAME
|
||||||
|
|
||||||
|
serving = object.__new__(_ConcretePoolingServing)
|
||||||
|
serving.models = mock_models
|
||||||
|
return serving
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pooling_ctx(model_name: str) -> PoolingServeContext:
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.model = model_name
|
||||||
|
return PoolingServeContext(
|
||||||
|
request=mock_request,
|
||||||
|
model_name=MODEL_NAME,
|
||||||
|
request_id="test-id",
|
||||||
|
pooling_params=PoolingParams(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pooling_maybe_get_adapters_lora_name_sets_lora_request():
|
||||||
|
"""LoRA adapter name must populate ctx.lora_request without raising."""
|
||||||
|
lora_name = "bot-embed-lora"
|
||||||
|
serving = _make_pooling_serving(lora_name)
|
||||||
|
ctx = _make_pooling_ctx(lora_name)
|
||||||
|
|
||||||
|
serving._maybe_get_adapters(ctx)
|
||||||
|
|
||||||
|
assert ctx.lora_request is not None
|
||||||
|
assert ctx.lora_request.lora_name == lora_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_pooling_maybe_get_adapters_unknown_model_raises():
|
||||||
|
"""An unrecognised model name must still raise VLLMNotFoundError."""
|
||||||
|
serving = _make_pooling_serving("some-lora")
|
||||||
|
ctx = _make_pooling_ctx("unknown-model")
|
||||||
|
|
||||||
|
with pytest.raises(VLLMNotFoundError):
|
||||||
|
serving._maybe_get_adapters(ctx)
|
||||||
|
|||||||
+1
-1
@@ -6,7 +6,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
|
|
||||||
# Model name constants used across tests
|
# Model name constants used across tests
|
||||||
MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct"
|
MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct"
|
||||||
+2
-1
@@ -22,7 +22,8 @@ import tempfile
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
MODEL_NAME_SMOLLM,
|
MODEL_NAME_SMOLLM,
|
||||||
)
|
)
|
||||||
+2
-1
@@ -4,7 +4,8 @@ import openai # use the official async_client for correctness check
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
|
|
||||||
from .conftest import MODEL_NAME_SMOLLM
|
from .conftest import MODEL_NAME_SMOLLM
|
||||||
|
|
||||||
|
|
||||||
+2
-1
@@ -12,7 +12,8 @@ import tempfile
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
MODEL_NAME_SMOLLM,
|
MODEL_NAME_SMOLLM,
|
||||||
)
|
)
|
||||||
+2
-1
@@ -6,7 +6,8 @@ import openai # use the official client for correctness check
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
|
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
HEADER_SAGEMAKER_CLOSED_SESSION_ID,
|
HEADER_SAGEMAKER_CLOSED_SESSION_ID,
|
||||||
HEADER_SAGEMAKER_NEW_SESSION_ID,
|
HEADER_SAGEMAKER_NEW_SESSION_ID,
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.openai.engine.protocol import StreamOptions
|
from vllm.entrypoints.openai.engine.protocol import StreamOptions
|
||||||
from vllm.entrypoints.utils import (
|
from vllm.entrypoints.serve.utils.api_utils import (
|
||||||
get_max_tokens,
|
get_max_tokens,
|
||||||
sanitize_message,
|
sanitize_message,
|
||||||
should_include_usage,
|
should_include_usage,
|
||||||
+1
-1
@@ -6,7 +6,7 @@ from types import SimpleNamespace
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.openai import fingerprint as fp
|
from vllm.entrypoints.serve.utils import fingerprint as fp
|
||||||
|
|
||||||
|
|
||||||
def _cfg(tp=1, pp=1, dp=1, ep=False, digest="a3b21f94deadbeef"):
|
def _cfg(tp=1, pp=1, dp=1, ep=False, digest="a3b21f94deadbeef"):
|
||||||
@@ -0,0 +1,248 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from vllm.entrypoints.serve.utils.request_logger import RequestLogger
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs():
|
||||||
|
"""Test the new log_outputs functionality."""
|
||||||
|
# Create a mock logger to capture log calls
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
|
||||||
|
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
|
||||||
|
request_logger = RequestLogger(max_log_len=None)
|
||||||
|
|
||||||
|
# Test basic output logging
|
||||||
|
request_logger.log_outputs(
|
||||||
|
request_id="test-123",
|
||||||
|
outputs="Hello, world!",
|
||||||
|
output_token_ids=[1, 2, 3, 4],
|
||||||
|
finish_reason="stop",
|
||||||
|
is_streaming=False,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
call_args = mock_logger.info.call_args.args
|
||||||
|
assert "Generated response %s%s" in call_args[0]
|
||||||
|
assert call_args[1] == "test-123"
|
||||||
|
assert call_args[3] == "Hello, world!"
|
||||||
|
assert call_args[4] == [1, 2, 3, 4]
|
||||||
|
assert call_args[5] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_streaming_delta():
|
||||||
|
"""Test log_outputs with streaming delta mode."""
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
|
||||||
|
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
|
||||||
|
request_logger = RequestLogger(max_log_len=None)
|
||||||
|
|
||||||
|
# Test streaming delta logging
|
||||||
|
request_logger.log_outputs(
|
||||||
|
request_id="test-456",
|
||||||
|
outputs="Hello",
|
||||||
|
output_token_ids=[1],
|
||||||
|
finish_reason=None,
|
||||||
|
is_streaming=True,
|
||||||
|
delta=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
call_args = mock_logger.info.call_args.args
|
||||||
|
assert "Generated response %s%s" in call_args[0]
|
||||||
|
assert call_args[1] == "test-456"
|
||||||
|
assert call_args[2] == " (streaming delta)"
|
||||||
|
assert call_args[3] == "Hello"
|
||||||
|
assert call_args[4] == [1]
|
||||||
|
assert call_args[5] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_streaming_complete():
|
||||||
|
"""Test log_outputs with streaming complete mode."""
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
|
||||||
|
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
|
||||||
|
request_logger = RequestLogger(max_log_len=None)
|
||||||
|
|
||||||
|
# Test streaming complete logging
|
||||||
|
request_logger.log_outputs(
|
||||||
|
request_id="test-789",
|
||||||
|
outputs="Complete response",
|
||||||
|
output_token_ids=[1, 2, 3],
|
||||||
|
finish_reason="length",
|
||||||
|
is_streaming=True,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
call_args = mock_logger.info.call_args.args
|
||||||
|
assert "Generated response %s%s" in call_args[0]
|
||||||
|
assert call_args[1] == "test-789"
|
||||||
|
assert call_args[2] == " (streaming complete)"
|
||||||
|
assert call_args[3] == "Complete response"
|
||||||
|
assert call_args[4] == [1, 2, 3]
|
||||||
|
assert call_args[5] == "length"
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_with_truncation():
|
||||||
|
"""Test log_outputs respects max_log_len setting."""
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
|
||||||
|
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
|
||||||
|
# Set max_log_len to 10
|
||||||
|
request_logger = RequestLogger(max_log_len=10)
|
||||||
|
|
||||||
|
# Test output truncation
|
||||||
|
long_output = "This is a very long output that should be truncated"
|
||||||
|
long_token_ids = list(range(20)) # 20 tokens
|
||||||
|
|
||||||
|
request_logger.log_outputs(
|
||||||
|
request_id="test-truncate",
|
||||||
|
outputs=long_output,
|
||||||
|
output_token_ids=long_token_ids,
|
||||||
|
finish_reason="stop",
|
||||||
|
is_streaming=False,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
call_args = mock_logger.info.call_args
|
||||||
|
|
||||||
|
# Check that output was truncated to first 10 characters
|
||||||
|
logged_output = call_args[0][3]
|
||||||
|
assert logged_output == "This is a "
|
||||||
|
assert len(logged_output) == 10
|
||||||
|
|
||||||
|
# Check that token IDs were truncated to first 10 tokens
|
||||||
|
logged_token_ids = call_args[0][4]
|
||||||
|
assert logged_token_ids == list(range(10))
|
||||||
|
assert len(logged_token_ids) == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_none_values():
|
||||||
|
"""Test log_outputs handles None values correctly."""
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
|
||||||
|
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
|
||||||
|
request_logger = RequestLogger(max_log_len=None)
|
||||||
|
|
||||||
|
# Test with None output_token_ids
|
||||||
|
request_logger.log_outputs(
|
||||||
|
request_id="test-none",
|
||||||
|
outputs="Test output",
|
||||||
|
output_token_ids=None,
|
||||||
|
finish_reason="stop",
|
||||||
|
is_streaming=False,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
call_args = mock_logger.info.call_args.args
|
||||||
|
assert "Generated response %s%s" in call_args[0]
|
||||||
|
assert call_args[1] == "test-none"
|
||||||
|
assert call_args[3] == "Test output"
|
||||||
|
assert call_args[4] is None
|
||||||
|
assert call_args[5] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_empty_output():
|
||||||
|
"""Test log_outputs handles empty output correctly."""
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
|
||||||
|
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
|
||||||
|
request_logger = RequestLogger(max_log_len=5)
|
||||||
|
|
||||||
|
# Test with empty output
|
||||||
|
request_logger.log_outputs(
|
||||||
|
request_id="test-empty",
|
||||||
|
outputs="",
|
||||||
|
output_token_ids=[],
|
||||||
|
finish_reason="stop",
|
||||||
|
is_streaming=False,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
call_args = mock_logger.info.call_args.args
|
||||||
|
assert "Generated response %s%s" in call_args[0]
|
||||||
|
assert call_args[1] == "test-empty"
|
||||||
|
assert call_args[3] == ""
|
||||||
|
assert call_args[4] == []
|
||||||
|
assert call_args[5] == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_logger_log_outputs_integration():
|
||||||
|
"""Test that log_outputs can be called alongside log_inputs."""
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
|
||||||
|
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
|
||||||
|
request_logger = RequestLogger(max_log_len=None)
|
||||||
|
|
||||||
|
# Test that both methods can be called without interference
|
||||||
|
request_logger.log_inputs(
|
||||||
|
request_id="test-integration",
|
||||||
|
prompt="Test prompt",
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
prompt_embeds=None,
|
||||||
|
params=None,
|
||||||
|
lora_request=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_logger.log_outputs(
|
||||||
|
request_id="test-integration",
|
||||||
|
outputs="Test output",
|
||||||
|
output_token_ids=[4, 5, 6],
|
||||||
|
finish_reason="stop",
|
||||||
|
is_streaming=False,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have been called twice - once for inputs, once for outputs
|
||||||
|
assert mock_logger.info.call_count == 2
|
||||||
|
|
||||||
|
# Check that the calls were made with correct patterns
|
||||||
|
input_call = mock_logger.info.call_args_list[0][0]
|
||||||
|
output_call = mock_logger.info.call_args_list[1][0]
|
||||||
|
|
||||||
|
assert "Received request %s" in input_call[0]
|
||||||
|
assert input_call[1] == "test-integration"
|
||||||
|
|
||||||
|
assert "Generated response %s%s" in output_call[0]
|
||||||
|
assert output_call[1] == "test-integration"
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_complete_logs_full_text_content():
|
||||||
|
"""Test that streaming complete logging includes
|
||||||
|
full accumulated text, not just token count."""
|
||||||
|
mock_logger = MagicMock()
|
||||||
|
|
||||||
|
with patch("vllm.entrypoints.serve.utils.request_logger.logger", mock_logger):
|
||||||
|
request_logger = RequestLogger(max_log_len=None)
|
||||||
|
|
||||||
|
# Test with actual content instead of token count format
|
||||||
|
full_response = "This is a complete response from streaming"
|
||||||
|
request_logger.log_outputs(
|
||||||
|
request_id="test-streaming-full-text",
|
||||||
|
outputs=full_response,
|
||||||
|
output_token_ids=None,
|
||||||
|
finish_reason="streaming_complete",
|
||||||
|
is_streaming=True,
|
||||||
|
delta=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.info.assert_called_once()
|
||||||
|
call_args = mock_logger.info.call_args.args
|
||||||
|
|
||||||
|
# Verify the logged output is the full text, not a token count format
|
||||||
|
logged_output = call_args[3]
|
||||||
|
assert logged_output == full_response
|
||||||
|
assert "tokens>" not in logged_output
|
||||||
|
assert "streaming_complete" not in logged_output
|
||||||
|
|
||||||
|
# Verify other parameters
|
||||||
|
assert call_args[1] == "test-streaming-full-text"
|
||||||
|
assert call_args[2] == " (streaming complete)"
|
||||||
|
assert call_args[5] == "streaming_complete"
|
||||||
+1
-1
@@ -7,7 +7,7 @@ from ssl import SSLContext
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.ssl import SSLCertRefresher
|
from vllm.entrypoints.serve.utils.ssl import SSLCertRefresher
|
||||||
|
|
||||||
|
|
||||||
class MockSSLContext(SSLContext):
|
class MockSSLContext(SSLContext):
|
||||||
@@ -2,4 +2,5 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
|||||||
accuracy_threshold: 0.72
|
accuracy_threshold: 0.72
|
||||||
num_questions: 1319
|
num_questions: 1319
|
||||||
num_fewshot: 5
|
num_fewshot: 5
|
||||||
|
rocm_request_timeout_seconds: 1800
|
||||||
server_args: "--enforce-eager --max-model-len 4096"
|
server_args: "--enforce-eager --max-model-len 4096"
|
||||||
|
|||||||
@@ -2,4 +2,5 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
|
|||||||
accuracy_threshold: 0.45
|
accuracy_threshold: 0.45
|
||||||
num_questions: 1319
|
num_questions: 1319
|
||||||
num_fewshot: 5
|
num_fewshot: 5
|
||||||
|
rocm_request_timeout_seconds: 1800
|
||||||
server_args: "--enforce-eager --max-model-len 4096"
|
server_args: "--enforce-eager --max-model-len 4096"
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ async def call_vllm_api(
|
|||||||
completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
|
completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
|
||||||
return text, completion_tokens
|
return text, completion_tokens
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error calling vLLM API: {e}")
|
print(f"Error calling vLLM API ({type(e).__name__}): {e}")
|
||||||
return "", 0
|
return "", 0
|
||||||
|
|
||||||
|
|
||||||
@@ -177,6 +177,7 @@ def evaluate_gsm8k(
|
|||||||
port: int = 8000,
|
port: int = 8000,
|
||||||
temperature: float = 0.0,
|
temperature: float = 0.0,
|
||||||
seed: int | None = 42,
|
seed: int | None = 42,
|
||||||
|
request_timeout_seconds: float = 600,
|
||||||
) -> dict[str, float | int]:
|
) -> dict[str, float | int]:
|
||||||
"""
|
"""
|
||||||
Evaluate GSM8K accuracy using vLLM serve endpoint.
|
Evaluate GSM8K accuracy using vLLM serve endpoint.
|
||||||
@@ -205,9 +206,8 @@ def evaluate_gsm8k(
|
|||||||
output_tokens[i] = tokens
|
output_tokens[i] = tokens
|
||||||
return answer, tokens
|
return answer, tokens
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
timeout = aiohttp.ClientTimeout(total=request_timeout_seconds)
|
||||||
timeout=aiohttp.ClientTimeout(total=600)
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
) as session:
|
|
||||||
tasks = [get_answer(session, i) for i in range(num_questions)]
|
tasks = [get_answer(session, i) for i in range(num_questions)]
|
||||||
await tqdm.gather(*tasks, desc="Evaluating")
|
await tqdm.gather(*tasks, desc="Evaluating")
|
||||||
|
|
||||||
|
|||||||
@@ -39,11 +39,18 @@ def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
|
|||||||
host = f"http://{host}"
|
host = f"http://{host}"
|
||||||
|
|
||||||
# Run GSM8K evaluation
|
# Run GSM8K evaluation
|
||||||
|
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
request_timeout_seconds = eval_config.get(
|
||||||
|
"rocm_request_timeout_seconds", request_timeout_seconds
|
||||||
|
)
|
||||||
|
|
||||||
results = evaluate_gsm8k(
|
results = evaluate_gsm8k(
|
||||||
num_questions=eval_config["num_questions"],
|
num_questions=eval_config["num_questions"],
|
||||||
num_shots=eval_config["num_fewshot"],
|
num_shots=eval_config["num_fewshot"],
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
|
request_timeout_seconds=request_timeout_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@@ -90,6 +97,12 @@ def test_gsm8k_correctness(config_filename):
|
|||||||
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
|
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
|
||||||
print(f"Number of questions: {eval_config['num_questions']}")
|
print(f"Number of questions: {eval_config['num_questions']}")
|
||||||
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
|
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
|
||||||
|
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
request_timeout_seconds = eval_config.get(
|
||||||
|
"rocm_request_timeout_seconds", request_timeout_seconds
|
||||||
|
)
|
||||||
|
print(f"Request timeout: {request_timeout_seconds}s")
|
||||||
print(f"Server args: {' '.join(server_args)}")
|
print(f"Server args: {' '.join(server_args)}")
|
||||||
print(f"Environment variables: {env_dict}")
|
print(f"Environment variables: {env_dict}")
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,339 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""ROCm kernel correctness tests for AITER unified attention.
|
||||||
|
|
||||||
|
Compares ``aiter.ops.triton.unified_attention`` against ``ref_paged_attn`` under
|
||||||
|
decode, prefill, and mixed batches with varied shapes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tests.kernels.attention.test_triton_unified_attention import ref_paged_attn
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.torch_utils import set_random_seed
|
||||||
|
|
||||||
|
# Default to skipping; only a ROCm platform gets to query the GPU family.
_SKIP_NON_MI3XX = True
if current_platform.is_rocm():
    # Imported only on ROCm, inside the platform guard.
    from vllm.platforms.rocm import on_mi3xx

    _SKIP_NON_MI3XX = not on_mi3xx()

# Module-wide skip markers: every test here needs ROCm and on_mi3xx() == True.
pytestmark = [
    pytest.mark.skipif(not current_platform.is_rocm(), reason="ROCm-specific tests"),
    pytest.mark.skipif(_SKIP_NON_MI3XX, reason="MI300/MI350 ROCm only"),
]

# Head counts are fixed for all tests; both are 8.
NUM_Q_HEADS = 8
NUM_KV_HEADS = 8
# Parametrization axes shared by the native-dtype tests.
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 64]
DTYPES = [torch.bfloat16, torch.float16]
# Platform-specific FP8 storage dtype, as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# (query_len, kv_len) per sequence
MIXED_SEQ_LENS = [
    [(1, 128), (5, 18), (129, 463)],
    [(10, 256), (5, 64), (32, 128)],
    [(1, 1024), (5, 18), (129, 1328)],
]
# Every sequence has query_len == 1.
DECODE_SEQ_LENS = [
    [(1, 128), (1, 256), (1, 384), (1, 512)],
    [(1, 1024), (1, 1536), (1, 2048)],
]
# Every sequence has query_len > 1.
PREFILL_SEQ_LENS = [
    [(256, 256), (128, 512)],
    [(64, 128), (32, 256), (16, 512)],
    [(256, 1024), (128, 2048)],
]

# Comparison tolerances; the FP8 pair is an order of magnitude looser.
DEFAULT_ATOL, DEFAULT_RTOL = 1.5e-2, 1e-2
FP8_ATOL, FP8_RTOL = 1.5e-1, 1.5e-1
# Non-unity scale so q_descale handling is exercised explicitly.
Q_SCALE = 0.75
K_SCALE, V_SCALE = 0.5, 0.25

# Names of the FP8 configurations: which of query / KV cache is stored in FP8.
Fp8Variant = Literal["fp8_kv", "fp8_query", "fp8_query_kv"]

FP8_VARIANTS = [
    pytest.param("fp8_kv", id="fp8_kv"),
    pytest.param("fp8_query", id="fp8_query"),
    pytest.param("fp8_query_kv", id="fp8_query_kv"),
]

# Subset of the shape lists above reused by the FP8 test.
FP8_SEQ_LENS = [
    MIXED_SEQ_LENS[0],
    DECODE_SEQ_LENS[0],
    DECODE_SEQ_LENS[1],
    PREFILL_SEQ_LENS[0],
    PREFILL_SEQ_LENS[2],
]
|
||||||
|
|
||||||
|
|
||||||
|
def _require_aiter() -> None:
    """Skip the calling test unless AITER is reported as found and supported."""
    # Function-scope import: resolved only when a test actually calls this.
    from vllm._aiter_ops import is_aiter_found_and_supported

    aiter_ready = is_aiter_found_and_supported()
    if aiter_ready:
        return
    pytest.skip("aiter is required on supported ROCm hardware for this test")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_case(
    *,
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
    num_blocks: int = 2048,
    kv_cache_dtype: torch.dtype | None = None,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
    q_dtype: torch.dtype | None = None,
    q_scale: float = Q_SCALE,
) -> dict[str, Any]:
    """Build one randomized paged-attention test case.

    Args:
        seq_lens: One ``(query_len, kv_len)`` pair per sequence in the batch.
        head_size: Size of the last dimension of query / cache tensors.
        block_size: Number of token slots per KV-cache block.
        dtype: Dtype of the high-precision query (and of the caches when
            ``kv_cache_dtype`` is None).
        num_blocks: Number of blocks allocated in each cache.
        kv_cache_dtype: When given, caches are sampled in the default float
            dtype, clamped to [-1, 1] and cast to this dtype.
        k_scale: Fill value for the per-(sequence, kv-head) key descale tensor.
        v_scale: Fill value for the per-(sequence, kv-head) value descale tensor.
        q_dtype: When given, the kernel query is ``query / q_scale`` cast to
            this dtype and a scalar ``q_descale`` tensor is produced.
        q_scale: Query scale used only when ``q_dtype`` is given.

    Returns:
        A dict holding the tensors and scalars consumed by the kernel runner
        and by the reference implementation.
    """
    # Sets the process-wide default device; tensors below without an explicit
    # ``device=`` are therefore allocated on "cuda".
    torch.set_default_device("cuda")

    query_lens: list[int] = []
    kv_lens: list[int] = []
    for q_len, kv_len in seq_lens:
        query_lens.append(q_len)
        kv_lens.append(kv_len)

    batch = len(seq_lens)
    longest_query = max(query_lens)
    longest_kv = max(kv_lens)
    softmax_scale = head_size**-0.5
    cache_shape = (num_blocks, block_size, NUM_KV_HEADS, head_size)

    # Random draws happen in the order query -> key cache -> value cache ->
    # block table, so results are reproducible for a given seed.
    query = torch.randn(sum(query_lens), NUM_Q_HEADS, head_size, dtype=dtype)
    if kv_cache_dtype is not None:
        key_cache = torch.randn(cache_shape).clamp(-1.0, 1.0).to(kv_cache_dtype)
        value_cache = torch.randn(cache_shape).clamp(-1.0, 1.0).to(kv_cache_dtype)
    else:
        key_cache = torch.randn(cache_shape, dtype=dtype)
        value_cache = torch.randn_like(key_cache)

    # Prefix sums of the query lengths, with a leading zero.
    query_offsets = torch.tensor([0] + query_lens, dtype=torch.int32)
    cu_seqlens_q = query_offsets.cumsum(dim=0, dtype=torch.int32)
    seq_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    # Ceil-divide so the longest sequence fits in its row of the block table.
    blocks_per_seq = (longest_kv + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (batch, blocks_per_seq), dtype=torch.int32
    )

    descale_shape = (batch, NUM_KV_HEADS)
    k_descale = torch.full(descale_shape, k_scale, dtype=torch.float32, device="cuda")
    v_descale = torch.full(descale_shape, v_scale, dtype=torch.float32, device="cuda")

    if q_dtype is None:
        # Kernel consumes the high-precision query directly.
        kernel_query, q_descale = query, None
    else:
        q_descale = torch.tensor(q_scale, dtype=torch.float32, device="cuda")
        kernel_query = (query / q_scale).to(q_dtype)

    return {
        "query": query,
        "kernel_query": kernel_query,
        "key_cache": key_cache,
        "value_cache": value_cache,
        "block_tables": block_tables,
        "query_lens": query_lens,
        "kv_lens": kv_lens,
        "seq_lens_tensor": seq_lens_tensor,
        "cu_seqlens_q": cu_seqlens_q,
        "q_descale": q_descale,
        "k_descale": k_descale,
        "v_descale": v_descale,
        "scale": softmax_scale,
        "max_query_len": longest_query,
        "max_kv_len": longest_kv,
        "query_dtype": dtype,
        "k_scale": k_scale,
        "v_scale": v_scale,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fp8_case(
    *,
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    variant: Fp8Variant,
) -> dict[str, Any]:
    """Build a bf16 test case with the KV cache and/or the query stored in FP8.

    ``variant`` selects which side is quantized: ``"fp8_kv"`` (cache only),
    ``"fp8_query"`` (query only) or ``"fp8_query_kv"`` (both). The non-unity
    K/V scales are applied only when the cache is quantized.
    """
    quantize_kv = variant in {"fp8_kv", "fp8_query_kv"}
    quantize_query = variant in {"fp8_query", "fp8_query_kv"}

    if quantize_kv:
        cache_dtype, key_scale, value_scale = FP8_DTYPE, K_SCALE, V_SCALE
    else:
        cache_dtype, key_scale, value_scale = None, 1.0, 1.0
    query_storage_dtype = FP8_DTYPE if quantize_query else None

    return _make_case(
        seq_lens=seq_lens,
        head_size=head_size,
        block_size=block_size,
        dtype=torch.bfloat16,
        kv_cache_dtype=cache_dtype,
        k_scale=key_scale,
        v_scale=value_scale,
        q_dtype=query_storage_dtype,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _run_aiter_unified_attention(case: dict[str, Any]) -> torch.Tensor:
    """Run AITER's Triton ``unified_attention`` on ``case`` and return its output.

    The output buffer is allocated with the shape and dtype of the
    high-precision ``case["query"]``, not of ``case["kernel_query"]``.
    """
    from aiter.ops.triton.unified_attention import unified_attention

    # Kernel writes high-precision output even when Q is FP8 (matches vLLM usage).
    result = torch.empty_like(case["query"])

    unified_attention(
        q=case["kernel_query"],
        k=case["key_cache"],
        v=case["value_cache"],
        out=result,
        cu_seqlens_q=case["cu_seqlens_q"],
        max_seqlen_q=case["max_query_len"],
        seqused_k=case["seq_lens_tensor"],
        max_seqlen_k=case["max_kv_len"],
        softmax_scale=case["scale"],
        causal=True,
        alibi_slopes=None,
        window_size=(-1, -1),
        block_table=case["block_tables"],
        softcap=0,
        q_descale=case["q_descale"],
        k_descale=case["k_descale"],
        v_descale=case["v_descale"],
        sinks=None,
        output_scale=None,
    )
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _ref_output(case: dict[str, Any]) -> torch.Tensor:
    """Compute the reference attention output for ``case`` via ``ref_paged_attn``.

    When the caches are stored in a dtype other than the query dtype, they are
    first cast to the query dtype and multiplied by their scales.
    """
    ref_dtype = case["query_dtype"]
    keys, values = case["key_cache"], case["value_cache"]
    if keys.dtype != ref_dtype:
        # Dequantize: cast to the reference dtype, then apply the scale.
        keys = keys.to(ref_dtype) * case["k_scale"]
        values = values.to(ref_dtype) * case["v_scale"]

    return ref_paged_attn(
        query=case["query"],
        key_cache=keys,
        value_cache=values,
        query_lens=case["query_lens"],
        kv_lens=case["kv_lens"],
        block_tables=case["block_tables"],
        scale=case["scale"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_matches_reference(
    case: dict[str, Any],
    *,
    atol: float = DEFAULT_ATOL,
    rtol: float = DEFAULT_RTOL,
) -> None:
    """Assert that the AITER kernel output is close to the reference output.

    The kernel is run first, then the reference; the two tensors are compared
    with ``torch.testing.assert_close`` using the given tolerances.
    """
    kernel_out = _run_aiter_unified_attention(case)
    expected = _ref_output(case)
    torch.testing.assert_close(kernel_out, expected, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seq_lens", MIXED_SEQ_LENS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_aiter_unified_attn_mixed_batch(
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
) -> None:
    """Batches from MIXED_SEQ_LENS, unquantized query and caches."""
    _require_aiter()
    # Seed before building inputs so every parametrization is reproducible.
    set_random_seed(0)

    inputs = _make_case(
        seq_lens=seq_lens, head_size=head_size, block_size=block_size, dtype=dtype
    )
    _assert_matches_reference(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seq_lens", DECODE_SEQ_LENS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_aiter_unified_attn_decode(
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
) -> None:
    """Batches from DECODE_SEQ_LENS, unquantized query and caches."""
    _require_aiter()
    # Seed before building inputs so every parametrization is reproducible.
    set_random_seed(0)

    inputs = _make_case(
        seq_lens=seq_lens, head_size=head_size, block_size=block_size, dtype=dtype
    )
    _assert_matches_reference(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seq_lens", PREFILL_SEQ_LENS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_aiter_unified_attn_prefill(
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
) -> None:
    """Batches from PREFILL_SEQ_LENS in bf16, unquantized query and caches."""
    _require_aiter()
    # Seed before building inputs so every parametrization is reproducible.
    set_random_seed(0)

    inputs = _make_case(
        seq_lens=seq_lens,
        head_size=head_size,
        block_size=block_size,
        dtype=torch.bfloat16,
    )
    _assert_matches_reference(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not current_platform.supports_fp8(),
    reason="FP8 not supported on this hardware",
)
@pytest.mark.parametrize("variant", FP8_VARIANTS)
@pytest.mark.parametrize("seq_lens", FP8_SEQ_LENS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16, 64])
@torch.inference_mode()
def test_aiter_unified_attn_fp8(
    variant: Fp8Variant,
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
) -> None:
    """FP8 query and/or KV cache, checked against the bf16 reference.

    Uses the looser FP8 tolerances instead of the defaults.
    """
    _require_aiter()
    # Seed before building inputs so every parametrization is reproducible.
    set_random_seed(0)

    inputs = _make_fp8_case(
        seq_lens=seq_lens, head_size=head_size, block_size=block_size, variant=variant
    )
    _assert_matches_reference(inputs, atol=FP8_ATOL, rtol=FP8_RTOL)
|
||||||
@@ -205,7 +205,10 @@ def run_with_expert_maps(
|
|||||||
w2 = kwargs["w2"]
|
w2 = kwargs["w2"]
|
||||||
a = kwargs["hidden_states"]
|
a = kwargs["hidden_states"]
|
||||||
moe_config = make_dummy_moe_config(
|
moe_config = make_dummy_moe_config(
|
||||||
num_experts=w2.shape[0],
|
max_num_tokens=kwargs.get("hidden_states").shape[0],
|
||||||
|
experts_per_token=kwargs.get("topk_ids").shape[1],
|
||||||
|
num_experts=num_experts,
|
||||||
|
num_local_experts=num_local_experts,
|
||||||
hidden_dim=w2.shape[1],
|
hidden_dim=w2.shape[1],
|
||||||
intermediate_size_per_partition=w2.shape[2],
|
intermediate_size_per_partition=w2.shape[2],
|
||||||
in_dtype=a.dtype,
|
in_dtype=a.dtype,
|
||||||
@@ -258,23 +261,27 @@ def run_8_bit(
|
|||||||
a1_scale=None,
|
a1_scale=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
|
||||||
|
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"hidden_states": moe_tensors.a,
|
"hidden_states": moe_tensors.a,
|
||||||
"w1": moe_tensors.w1_q, # type: ignore[union-attr]
|
"w1": moe_tensors.w1_q, # type: ignore[union-attr]
|
||||||
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
|
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
|
||||||
"topk_weights": topk_weights,
|
"topk_weights": topk_weights,
|
||||||
"topk_ids": topk_ids,
|
"topk_ids": topk_ids,
|
||||||
"global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr]
|
"global_num_experts": num_experts,
|
||||||
"activation": MoEActivation.SILU,
|
"activation": MoEActivation.SILU,
|
||||||
"expert_map": None,
|
"expert_map": None,
|
||||||
"apply_router_weight_on_input": False,
|
"apply_router_weight_on_input": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
|
|
||||||
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
|
||||||
if not with_ep:
|
if not with_ep:
|
||||||
moe_config = make_dummy_moe_config(
|
moe_config = make_dummy_moe_config(
|
||||||
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
|
max_num_tokens=moe_tensors.a.shape[0],
|
||||||
|
experts_per_token=topk_ids.shape[1],
|
||||||
|
num_experts=num_experts,
|
||||||
|
num_local_experts=num_local_experts,
|
||||||
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
|
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
|
||||||
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
|
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
|
||||||
in_dtype=moe_tensors.a.dtype,
|
in_dtype=moe_tensors.a.dtype,
|
||||||
@@ -581,6 +588,7 @@ def test_run_cutlass_moe_fp8(
|
|||||||
per_out_channel,
|
per_out_channel,
|
||||||
False,
|
False,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
workspace13.random_()
|
workspace13.random_()
|
||||||
|
|||||||
@@ -1287,10 +1287,12 @@ def _test_body_eplb(
|
|||||||
|
|
||||||
expert_weights = [list(eplb_moe_layer.get_expert_weights())]
|
expert_weights = [list(eplb_moe_layer.get_expert_weights())]
|
||||||
|
|
||||||
|
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||||
communicator = create_eplb_communicator(
|
communicator = create_eplb_communicator(
|
||||||
group_coordinator=get_eplb_group(),
|
group_coordinator=get_eplb_group(),
|
||||||
backend=vllm_config.parallel_config.eplb_config.communicator,
|
backend=vllm_config.parallel_config.eplb_config.communicator,
|
||||||
expert_weights=expert_weights[0],
|
expert_weights=expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Rearrange expert weights across EP ranks
|
# Rearrange expert weights across EP ranks
|
||||||
@@ -1298,6 +1300,7 @@ def _test_body_eplb(
|
|||||||
old_global_expert_indices=initial_indices.unsqueeze(0),
|
old_global_expert_indices=initial_indices.unsqueeze(0),
|
||||||
new_global_expert_indices=shuffled_indices.unsqueeze(0),
|
new_global_expert_indices=shuffled_indices.unsqueeze(0),
|
||||||
expert_weights=expert_weights,
|
expert_weights=expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
ep_group=cpu_group,
|
ep_group=cpu_group,
|
||||||
communicator=communicator,
|
communicator=communicator,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -49,10 +49,12 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
def make_dummy_moe_config(
|
def make_dummy_moe_config(
|
||||||
num_experts: int = 1,
|
num_experts: int = 1,
|
||||||
|
num_local_experts: int | None = None,
|
||||||
experts_per_token: int = 1,
|
experts_per_token: int = 1,
|
||||||
hidden_dim: int = 1,
|
hidden_dim: int = 1,
|
||||||
intermediate_size_per_partition: int = 1,
|
intermediate_size_per_partition: int = 1,
|
||||||
in_dtype: torch.dtype = torch.bfloat16,
|
in_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
max_num_tokens: int = 512,
|
||||||
) -> FusedMoEConfig:
|
) -> FusedMoEConfig:
|
||||||
"""
|
"""
|
||||||
This is a dummy config for the mk constructor interface
|
This is a dummy config for the mk constructor interface
|
||||||
@@ -66,14 +68,16 @@ def make_dummy_moe_config(
|
|||||||
experts_per_token=experts_per_token,
|
experts_per_token=experts_per_token,
|
||||||
hidden_dim=hidden_dim,
|
hidden_dim=hidden_dim,
|
||||||
intermediate_size_per_partition=intermediate_size_per_partition,
|
intermediate_size_per_partition=intermediate_size_per_partition,
|
||||||
num_local_experts=num_experts,
|
num_local_experts=num_local_experts
|
||||||
|
if num_local_experts is not None
|
||||||
|
else num_experts,
|
||||||
num_logical_experts=num_experts,
|
num_logical_experts=num_experts,
|
||||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||||
activation=MoEActivation.SILU,
|
activation=MoEActivation.SILU,
|
||||||
in_dtype=in_dtype,
|
in_dtype=in_dtype,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
routing_method=RoutingMethodType.TopK,
|
routing_method=RoutingMethodType.TopK,
|
||||||
max_num_tokens=512,
|
max_num_tokens=max_num_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,67 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Tests for the Triton dequant-gather kernel used by
|
||||||
|
``CompressedTensorsEmbeddingWNA16Int`` (quantized embedding lookup)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from compressed_tensors.compressors.pack_quantized.helpers import unpack_from_int32
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding import ( # noqa: E501
|
||||||
|
_dequant_gather_triton,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
def _dequant_gather_torch(
    ids: torch.Tensor,
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    hidden: int,
    num_bits: int,
) -> torch.Tensor:
    """PyTorch reference for the dequantizing gather.

    Selects the packed rows and scale rows addressed by ``ids``, unpacks the
    rows with ``unpack_from_int32`` to ``(len(ids), hidden)``, casts them to
    the scale dtype and multiplies by the scales. A single scale column is
    broadcast across the whole row; otherwise the row is split into
    ``weight_scale.shape[1]`` equal groups, one scale per group.
    """
    num_rows = ids.shape[0]
    scales = weight_scale[ids]
    unpacked = unpack_from_int32(
        weight_packed[ids], num_bits, torch.Size([num_rows, hidden])
    )
    values = unpacked.to(scales.dtype)

    num_groups = scales.shape[1]
    if num_groups != 1:
        # Group-wise: reshape so each group lines up with its own scale.
        grouped = values.view(num_rows, num_groups, hidden // num_groups)
        return (grouped * scales.unsqueeze(-1)).view(num_rows, hidden)
    # Channel-wise: one scale per row, broadcast over the hidden dimension.
    return values * scales
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="Triton dequant kernel requires CUDA"
)
@pytest.mark.parametrize("num_bits", [2, 4, 8])
@pytest.mark.parametrize("group_size", [0, 256])  # 0 -> channel
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("num_ids", [1, 17, 4096])
def test_dequant_gather(num_bits, group_size, dtype, num_ids):
    """Check the Triton dequant-gather kernel against the PyTorch reference."""
    torch.manual_seed(0)
    device = "cuda"
    vocab = 1000
    hidden = 2048
    values_per_word = 32 // num_bits

    # Random full-range int32 packed weights (covers the sign bit -> exercises the
    # arithmetic-shift + mask unpack path).
    weight_packed = torch.randint(
        -(2**31),
        2**31,
        (vocab, hidden // values_per_word),
        dtype=torch.int32,
        device=device,
    )

    if group_size == 0:
        scale_columns = 1
    else:
        scale_columns = hidden // group_size
    # Offset keeps every scale strictly positive.
    weight_scale = torch.rand(vocab, scale_columns, dtype=dtype, device=device) + 0.01

    ids = torch.randint(0, vocab, (num_ids,), dtype=torch.long, device=device)

    actual = _dequant_gather_triton(ids, weight_packed, weight_scale, hidden, num_bits)
    expected = _dequant_gather_torch(
        ids, weight_packed, weight_scale, hidden, num_bits
    )

    assert actual.shape == (num_ids, hidden)
    assert actual.dtype == dtype
    torch.testing.assert_close(actual, expected)
|
||||||
@@ -468,6 +468,7 @@ def _reference_kv_compress_norm_rope(
|
|||||||
use_fp4: bool = False,
|
use_fp4: bool = False,
|
||||||
rms_eps: float = 1e-6,
|
rms_eps: float = 1e-6,
|
||||||
fp8_max: float = 448.0,
|
fp8_max: float = 448.0,
|
||||||
|
return_full_cache: bool = False,
|
||||||
):
|
):
|
||||||
"""Compress → RMSNorm → GPT-J RoPE → quantize.
|
"""Compress → RMSNorm → GPT-J RoPE → quantize.
|
||||||
|
|
||||||
@@ -521,6 +522,12 @@ def _reference_kv_compress_norm_rope(
|
|||||||
results.append(torch.cat([nope, rope]).to(state_cache.dtype))
|
results.append(torch.cat([nope, rope]).to(state_cache.dtype))
|
||||||
result = torch.stack(results)
|
result = torch.stack(results)
|
||||||
|
|
||||||
|
if return_full_cache:
|
||||||
|
# Contiguous 512-wide bf16 row (nope unrotated + rope rotated), matching
|
||||||
|
# the FlashInfer full-cache layout before any per-tensor fp8 quant. The
|
||||||
|
# kernel rounds the fp32 result to bf16 once at the store.
|
||||||
|
return result.to(torch.bfloat16)
|
||||||
|
|
||||||
if use_fp4:
|
if use_fp4:
|
||||||
return quantize_to_mxfp4(result)
|
return quantize_to_mxfp4(result)
|
||||||
else:
|
else:
|
||||||
@@ -667,3 +674,145 @@ def test_fused_kv_insert_indexer(num_tokens: int, kv_block_size: int, use_fp4: b
|
|||||||
assert torch.equal(actual_scale, scale[i : i + 1]), (
|
assert torch.equal(actual_scale, scale[i : i + 1]), (
|
||||||
f"token {i}: scale {actual_scale.item()} != {scale[i].item()}"
|
f"token {i}: scale {actual_scale.item()} != {scale[i].item()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("compress_ratio", [4, 128])
@pytest.mark.parametrize("store_fp8", [False, True])
def test_cutedsl_full_cache_store(compress_ratio: int, store_fp8: bool):
    """Full-cache store parity of the CuTeDSL compressor kernels, head dim 512.

    ``compress_ratio == 4`` drives the fused kernel and ``128`` the split
    kernel. Each writes contiguous rows into ``k_cache`` (bf16, or per-tensor
    fp8 when ``store_fp8``), which are compared with the PyTorch reference.
    """
    # Skips the test when the ``cutlass`` package cannot be imported.
    pytest.importorskip("cutlass")
    from vllm.models.deepseek_v4.nvidia.ops.sparse_attn_compress_cutedsl import (
        fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
        split_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
    )

    head_dim = 512
    rope_dim = 64
    rms_eps = 1e-6
    fp8_max = 448.0
    # C128 compress (Block8 kernel) requires state-cache block_size=8; C4 uses 16.
    state_block_size = 8 if compress_ratio == 128 else 16
    kv_block_size = 64
    device = "cuda"
    torch.manual_seed(7)

    overlap = 1 if compress_ratio == 4 else 0
    coff = 1 + overlap
    num_tokens = 8
    total_positions = compress_ratio * num_tokens

    num_pages = (total_positions - 1) // state_block_size + 2
    # The production CompressorStateCache is fp32.
    state_cache = torch.randn(
        num_pages,
        state_block_size,
        2 * coff * head_dim,
        dtype=torch.float32,
        device=device,
    )
    block_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0)
    token_to_req = torch.zeros(num_tokens, dtype=torch.int32, device=device)
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    # One position per output token: the last position of each compression window.
    positions = torch.arange(
        compress_ratio - 1,
        total_positions,
        compress_ratio,
        dtype=torch.int64,
        device=device,
    )
    rms_weight = torch.randn(head_dim, dtype=torch.bfloat16, device=device)
    cos_sin_cache = torch.randn(
        total_positions, rope_dim, dtype=torch.float32, device=device
    )

    cache_dtype = torch.float8_e4m3fn if store_fp8 else torch.bfloat16
    kv_n_blocks = (num_tokens + kv_block_size - 1) // kv_block_size + 1
    k_cache = torch.zeros(
        kv_n_blocks, kv_block_size, head_dim, dtype=cache_dtype, device=device
    )
    fp8_scale = torch.tensor(
        [0.5 if store_fp8 else 1.0], dtype=torch.float32, device=device
    )

    # Arguments common to both kernels; the split kernel additionally takes a
    # ``compressed_kv`` scratch tensor between the two positional groups.
    state_args = (
        state_cache,
        token_to_req,
        positions,
        slot_mapping,
        block_table,
        state_block_size,
    )
    store_args = (
        rms_weight,
        rms_eps,
        cos_sin_cache,
        k_cache,
        slot_mapping,
        kv_block_size,
        k_cache.stride(0),
    )
    shared_kwargs = dict(
        head_size=head_dim,
        state_width=coff * head_dim,
        rope_head_dim=rope_dim,
        fp8_max=fp8_max,
        quant_block=64,
        token_stride=576,
        scale_dim=8,
        compress_ratio=compress_ratio,
        overlap=bool(overlap),
        store_full_kv=True,
        store_full_fp8=store_fp8,
        fp8_scale=fp8_scale,
    )

    if compress_ratio == 4:
        fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(
            *state_args, *store_args, **shared_kwargs
        )
    else:
        compressed_kv = torch.empty(
            (num_tokens, head_dim), dtype=torch.float32, device=device
        )
        split_kv_compress_norm_rope_insert_sparse_attn_cutedsl(
            *state_args, compressed_kv, *store_args, **shared_kwargs
        )

    ref = _reference_kv_compress_norm_rope(
        state_cache,
        block_table,
        positions,
        rms_weight,
        cos_sin_cache,
        compress_ratio,
        overlap,
        rms_eps=rms_eps,
        return_full_cache=True,
    )  # [num_tokens, head_dim] bf16

    # Read back the rows the kernel wrote, in token order.
    written_rows = [
        k_cache[tok // kv_block_size, tok % kv_block_size] for tok in range(num_tokens)
    ]
    actual = torch.stack(written_rows)

    if not store_fp8:
        torch.testing.assert_close(actual.float(), ref.float(), rtol=3e-2, atol=3e-2)
        return

    # Quantize the reference the same way before comparing fp8 payloads.
    ref_fp8 = torch.clamp(ref.float() / fp8_scale, -fp8_max, fp8_max).to(
        torch.float8_e4m3fn
    )
    torch.testing.assert_close(actual.float(), ref_fp8.float(), rtol=0.0, atol=0.3)
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ def apply_rope_gptj_last_k(
|
|||||||
head_dim = x.shape[-1]
|
head_dim = x.shape[-1]
|
||||||
nope_dim = head_dim - rope_dim
|
nope_dim = head_dim - rope_dim
|
||||||
|
|
||||||
cs = cos_sin_cache[positions].to(torch.float32)
|
cs = cos_sin_cache[positions.long()].to(torch.float32)
|
||||||
cos = cs[..., :half]
|
cos = cs[..., :half]
|
||||||
sin = cs[..., half:]
|
sin = cs[..., half:]
|
||||||
|
|
||||||
@@ -114,6 +114,18 @@ def _op_available() -> bool:
|
|||||||
return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")
|
return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")
|
||||||
|
|
||||||
|
|
||||||
|
def _full_cache_fp8_op_available() -> bool:
    """Return True when ``torch.ops._C`` exposes the full-cache FP8 insert op."""
    op_name = "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert"
    return hasattr(torch.ops._C, op_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _full_cache_bf16_op_available() -> bool:
    """Return True when ``torch.ops._C`` exposes the full-cache bf16 insert op."""
    op_name = "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert"
    return hasattr(torch.ops._C, op_name)
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(
|
pytestmark = pytest.mark.skipif(
|
||||||
not torch.cuda.is_available() or not _op_available(),
|
not torch.cuda.is_available() or not _op_available(),
|
||||||
reason="CUDA not available or fused DeepseekV4 op not built in",
|
reason="CUDA not available or fused DeepseekV4 op not built in",
|
||||||
@@ -415,3 +427,238 @@ def test_combined_q_and_kv(
|
|||||||
"padded head slots must be exact zero"
|
"padded head slots must be exact zero"
|
||||||
)
|
)
|
||||||
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
|
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Full-cache (FlashInfer) path parity ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _call_full_cache_fp8_fused(
    q,
    kv,
    q_fp8,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    fp8_scale,
    q_fp8_scale_inv,
    eps,
    bs,
):
    """Invoke the fused full-cache FP8 insert op from ``torch.ops._C``.

    Arguments are forwarded positionally in the order given; ``positions`` is
    converted with ``.long()`` first. Nothing is returned.
    """
    fused_op = torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert
    positions_i64 = positions.long()
    fused_op(
        q,
        kv,
        q_fp8,
        k_cache,
        slot_mapping,
        positions_i64,
        cos_sin_cache,
        fp8_scale,
        q_fp8_scale_inv,
        eps,
        bs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _call_full_cache_bf16_fused(
    q,
    kv,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    bs,
):
    """Invoke the fused full-cache bf16 insert op from ``torch.ops._C``.

    Arguments are forwarded positionally in the order given; ``positions`` is
    converted with ``.long()`` first. Nothing is returned.
    """
    fused_op = torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert
    positions_i64 = positions.long()
    fused_op(
        q,
        kv,
        k_cache,
        slot_mapping,
        positions_i64,
        cos_sin_cache,
        eps,
        bs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fp8_full_cache_reference(
    q,
    kv,
    k_cache,
    q_fp8,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
    fp8_scale,
    q_fp8_scale_inv,
):
    """Pure-PyTorch reference for the full-cache per-tensor FP8 path.

    Query side: weightless RMSNorm, then RoPE, then multiply by
    ``q_fp8_scale_inv``, clamp to ``[-FP8_MAX, FP8_MAX]`` and cast to
    ``float8_e4m3fn``; the result is copied into ``q_fp8``.

    KV side: RoPE, then divide by ``fp8_scale``, clamp and cast the same
    way. Rows whose slot is negative are skipped; the rest are written to
    ``k_cache[slot // block_size, slot % block_size]``.

    Both outputs are written in place; nothing is returned.
    """

    def _quantize(values):
        # Saturate to the FP8 range before the cast to float8_e4m3fn.
        return torch.clamp(values, -FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

    q_normed = rmsnorm_no_weight(q, eps)
    q_rotated = apply_rope_gptj_last_k(q_normed, positions, cos_sin_cache)
    q_fp8.copy_(_quantize(q_rotated.float() * q_fp8_scale_inv))

    kv_rotated = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
    keep = slot_mapping >= 0
    live_slots = slot_mapping[keep]
    rows = live_slots // block_size
    cols = live_slots % block_size
    k_cache[rows, cols] = _quantize(kv_rotated[keep].float() / fp8_scale)
|
||||||
|
|
||||||
|
|
||||||
|
def _bf16_full_cache_reference(
    q,
    kv,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
):
    """Pure-PyTorch reference for the full-cache BF16 path.

    Returns the expected query: weightless RMSNorm followed by RoPE, cast
    back to ``q.dtype``. As a side effect, the RoPE-rotated ``kv`` rows
    with a non-negative slot are stored into
    ``k_cache[slot // block_size, slot % block_size]``.
    """
    q_normed = rmsnorm_no_weight(q, eps)
    q_rotated = apply_rope_gptj_last_k(q_normed, positions, cos_sin_cache)
    # Kernel keeps RMSNorm+RoPE in fp32 and rounds to bf16 once at the store.
    q_expected = q_rotated.to(q.dtype)

    kv_rotated = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
    keep = slot_mapping >= 0
    live_slots = slot_mapping[keep]
    rows = live_slots // block_size
    cols = live_slots % block_size
    k_cache[rows, cols] = kv_rotated[keep]
    return q_expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not _full_cache_fp8_op_available(),
    reason="full-cache per-tensor FP8 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_per_tensor_fp8_matches_reference(
    num_tokens: int,
    n_heads: int,
    positions_dtype: torch.dtype,
):
    """Fused full-cache FP8 op must agree with the PyTorch reference.

    Both paths see the same random ``q``/``kv``, an identity slot mapping
    and unit scales. The quantized query and the K cache are compared after
    casting to float, with an absolute tolerance of 0.25.
    """
    torch.manual_seed(4)
    device, dtype = "cuda", torch.bfloat16
    eps, block_size, max_pos = 1e-6, 16, 4096

    # Random inputs are drawn in this order (q first, then kv).
    q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
    kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
    positions = torch.arange(num_tokens, dtype=positions_dtype, device=device)
    cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)

    # ceil(num_tokens / block_size) blocks plus one spare block.
    num_blocks = -(-num_tokens // block_size) + 1
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    fp8_scale = torch.ones(1, dtype=torch.float32, device=device)
    q_fp8_scale_inv = torch.ones(1, dtype=torch.float32, device=device)

    q_fp8_ref = torch.empty_like(q, dtype=torch.float8_e4m3fn)
    q_fp8_fused = torch.empty_like(q, dtype=torch.float8_e4m3fn)
    k_cache_ref = torch.zeros(
        num_blocks, block_size, HEAD_DIM, dtype=torch.float8_e4m3fn, device=device
    )
    k_cache_fused = torch.zeros_like(k_cache_ref)

    _fp8_full_cache_reference(
        q=q,
        kv=kv,
        k_cache=k_cache_ref,
        q_fp8=q_fp8_ref,
        slot_mapping=slot_mapping,
        positions=positions,
        cos_sin_cache=cos_sin_cache,
        eps=eps,
        block_size=block_size,
        fp8_scale=fp8_scale,
        q_fp8_scale_inv=q_fp8_scale_inv,
    )
    # The fused op gets its own copy of q (presumably because it may write
    # to q in place -- confirm against the kernel).
    _call_full_cache_fp8_fused(
        q=q.clone(),
        kv=kv,
        q_fp8=q_fp8_fused,
        k_cache=k_cache_fused,
        slot_mapping=slot_mapping,
        positions=positions,
        cos_sin_cache=cos_sin_cache,
        fp8_scale=fp8_scale,
        q_fp8_scale_inv=q_fp8_scale_inv,
        eps=eps,
        bs=block_size,
    )

    for fused, reference in (
        (q_fp8_fused, q_fp8_ref),
        (k_cache_fused, k_cache_ref),
    ):
        torch.testing.assert_close(
            fused.float(), reference.float(), rtol=0, atol=0.25
        )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not _full_cache_bf16_op_available(),
    reason="full-cache BF16 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_bf16_matches_reference(
    num_tokens: int,
    n_heads: int,
    positions_dtype: torch.dtype,
):
    """Fused full-cache BF16 op must agree with the PyTorch reference.

    The query the fused op was called with is compared to the reference
    query with ``rtol=atol=1e-2``; the K cache must match exactly.
    """
    torch.manual_seed(5)
    device, dtype = "cuda", torch.bfloat16
    eps, block_size, max_pos = 1e-6, 16, 4096

    # Random inputs are drawn in this order (q first, then kv).
    q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device)
    kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device)
    positions = torch.arange(num_tokens, dtype=positions_dtype, device=device)
    cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)

    # ceil(num_tokens / block_size) blocks plus one spare block.
    num_blocks = -(-num_tokens // block_size) + 1
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)

    # The fused op is handed this copy, which is then checked against q_ref.
    q_fused = q.clone()
    k_cache_ref = torch.zeros(
        num_blocks, block_size, HEAD_DIM, dtype=torch.bfloat16, device=device
    )
    k_cache_fused = torch.zeros_like(k_cache_ref)

    q_ref = _bf16_full_cache_reference(
        q=q,
        kv=kv,
        k_cache=k_cache_ref,
        slot_mapping=slot_mapping,
        positions=positions,
        cos_sin_cache=cos_sin_cache,
        eps=eps,
        block_size=block_size,
    )
    _call_full_cache_bf16_fused(
        q=q_fused,
        kv=kv,
        k_cache=k_cache_fused,
        slot_mapping=slot_mapping,
        positions=positions,
        cos_sin_cache=cos_sin_cache,
        eps=eps,
        bs=block_size,
    )

    torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
|
||||||
|
|||||||
@@ -0,0 +1,481 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Unit tests for sequence and token pooler head classes."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
|
||||||
|
from vllm.model_executor.layers.pooler.seqwise.heads import (
|
||||||
|
ClassifierPoolerHead,
|
||||||
|
EmbeddingPoolerHead,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.pooler.tokwise.heads import (
|
||||||
|
TokenClassifierPoolerHead,
|
||||||
|
TokenEmbeddingPoolerHead,
|
||||||
|
)
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
|
||||||
|
|
||||||
|
_HIDDEN = 16
|
||||||
|
_BATCH = 3
|
||||||
|
|
||||||
|
|
||||||
|
def _make_params(
    n: int,
    *,
    task: str = "embed",
    dimensions: int | None = None,
    use_activation: bool | None = None,
) -> list[PoolingParams]:
    """Return ``n`` separately constructed, identically configured ``PoolingParams``."""
    params: list[PoolingParams] = []
    for _ in range(n):
        params.append(
            PoolingParams(
                task=task,
                dimensions=dimensions,
                use_activation=use_activation,
            )
        )
    return params
|
||||||
|
|
||||||
|
|
||||||
|
def _make_metadata(pooling_params: list[PoolingParams]) -> PoolingMetadata:
    """Wrap ``pooling_params`` in a minimal ``PoolingMetadata``.

    Every request gets a prompt length of 1 and a fresh ``PoolingStates``;
    both prompt-token-id fields are left as ``None``.
    """
    batch = len(pooling_params)
    unit_lens = torch.ones(batch, dtype=torch.long)
    fresh_states = [PoolingStates() for _ in range(batch)]
    return PoolingMetadata(
        prompt_lens=unit_lens,
        prompt_token_ids=None,
        prompt_token_ids_cpu=None,
        pooling_params=pooling_params,
        pooling_states=fresh_states,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _linear(in_f: int, out_f: int) -> nn.Linear:
|
||||||
|
torch.manual_seed(42)
|
||||||
|
return nn.Linear(in_f, out_f, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# EmbeddingPoolerHead
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestEmbeddingPoolerHead:
    """Checks for the sequence-level ``EmbeddingPoolerHead``."""

    def test_supported_tasks(self):
        """Only the "embed" task is advertised."""
        pooler = EmbeddingPoolerHead()
        assert pooler.get_supported_tasks() == {"embed"}

    def test_passthrough(self):
        """A default head returns values equal to its input."""
        pooler = EmbeddingPoolerHead()
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH))
        result = pooler(hidden, metadata)
        assert torch.equal(result, hidden)

    def test_head_dtype(self):
        """The output carries the configured ``head_dtype``."""
        pooler = EmbeddingPoolerHead(head_dtype=torch.float16)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH))
        result = pooler(hidden, metadata)
        assert result.dtype == torch.float16

    def test_projector(self):
        """A projector changes the last dimension and matches a direct call."""
        projector = _linear(_HIDDEN, 8)
        pooler = EmbeddingPoolerHead(projector=projector)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH))
        result = pooler(hidden, metadata)
        assert result.shape == (_BATCH, 8)
        assert torch.allclose(result, projector(hidden))

    def test_matryoshka_uniform(self):
        """The same ``dimensions`` for all requests truncates the batch."""
        pooler = EmbeddingPoolerHead()
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, dimensions=4))
        result = pooler(hidden, metadata)
        assert result.shape == (_BATCH, 4)
        assert torch.equal(result, hidden[..., :4])

    def test_matryoshka_mixed(self):
        """Differing ``dimensions`` yield a list with per-request widths."""
        pooler = EmbeddingPoolerHead()
        hidden = torch.randn(2, _HIDDEN)
        requests = [
            PoolingParams(task="embed", dimensions=4),
            PoolingParams(task="embed", dimensions=8),
        ]
        result = pooler(hidden, _make_metadata(requests))
        assert isinstance(result, list)
        assert len(result) == 2
        assert result[0].shape[-1] == 4
        assert result[1].shape[-1] == 8

    def test_matryoshka_mixed_with_none(self):
        """A request with ``dimensions=None`` keeps its full row."""
        pooler = EmbeddingPoolerHead()
        hidden = torch.randn(2, _HIDDEN)
        requests = [
            PoolingParams(task="embed", dimensions=4),
            PoolingParams(task="embed", dimensions=None),
        ]
        result = pooler(hidden, _make_metadata(requests))
        assert isinstance(result, list)
        assert result[0].shape[-1] == 4
        assert torch.equal(result[1], hidden[1])

    def test_activation_uniform_true(self):
        """With activation on for all requests, every row has unit norm."""
        pooler = EmbeddingPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, use_activation=True))
        result = pooler(hidden, metadata)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(_BATCH), atol=1e-5)

    def test_activation_uniform_false(self):
        """With activation off for all requests, the input is unchanged."""
        pooler = EmbeddingPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, use_activation=False))
        result = pooler(hidden, metadata)
        assert torch.equal(result, hidden)

    def test_activation_mixed_flags(self):
        """Mixed flags yield a list: normalized first row, raw second row."""
        pooler = EmbeddingPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(2, _HIDDEN)
        requests = [
            PoolingParams(task="embed", use_activation=True),
            PoolingParams(task="embed", use_activation=False),
        ]
        result = pooler(hidden, _make_metadata(requests))
        assert isinstance(result, list)
        first_norm = torch.linalg.norm(result[0], dim=-1)
        assert torch.allclose(first_norm, torch.ones(1), atol=1e-5)
        assert torch.equal(result[1], hidden[1])

    def test_list_input_gets_stacked(self):
        """A list of 1-D tensors comes back as one stacked 2-D tensor."""
        pooler = EmbeddingPoolerHead()
        rows = [torch.randn(_HIDDEN) for _ in range(_BATCH)]
        metadata = _make_metadata(_make_params(_BATCH))
        result = pooler(rows, metadata)
        assert result.shape == (_BATCH, _HIDDEN)
        assert torch.equal(result, torch.stack(rows))

    def test_projector_then_matryoshka(self):
        """Truncation applies to the projector's output."""
        projector = _linear(_HIDDEN, 8)
        pooler = EmbeddingPoolerHead(projector=projector)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, dimensions=4))
        result = pooler(hidden, metadata)
        assert result.shape == (_BATCH, 4)
        assert torch.equal(result, projector(hidden)[..., :4])

    def test_matryoshka_then_activation(self):
        """Truncated rows are unit-norm when activation is also requested."""
        pooler = EmbeddingPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(
            _make_params(_BATCH, dimensions=4, use_activation=True)
        )
        result = pooler(hidden, metadata)
        assert result.shape == (_BATCH, 4)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(_BATCH), atol=1e-5)

    def test_empty_batch(self):
        """A zero-row input produces a zero-row output of the same width."""
        pooler = EmbeddingPoolerHead()
        hidden = torch.randn(0, _HIDDEN)
        metadata = _make_metadata([])
        result = pooler(hidden, metadata)
        assert result.shape == (0, _HIDDEN)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ClassifierPoolerHead
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestClassifierPoolerHead:
    """Checks for the sequence-level ``ClassifierPoolerHead``."""

    def test_supported_tasks(self):
        """Only the "classify" task is advertised."""
        pooler = ClassifierPoolerHead()
        assert pooler.get_supported_tasks() == {"classify"}

    def test_passthrough(self):
        """A default head returns values equal to its input."""
        pooler = ClassifierPoolerHead()
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler(hidden, metadata)
        assert torch.equal(result, hidden)

    def test_head_dtype(self):
        """The output carries the configured ``head_dtype``."""
        pooler = ClassifierPoolerHead(head_dtype=torch.float16)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler(hidden, metadata)
        assert result.dtype == torch.float16

    def test_classifier(self):
        """A classifier layer sets the output width and matches a direct call."""
        classifier = _linear(_HIDDEN, 3)
        pooler = ClassifierPoolerHead(classifier=classifier)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler(hidden, metadata)
        assert result.shape == (_BATCH, 3)
        assert torch.allclose(result, classifier(hidden))

    def test_logit_mean(self):
        """``logit_mean`` is subtracted from the logits."""
        pooler = ClassifierPoolerHead(logit_mean=2.0)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler(hidden, metadata)
        assert torch.allclose(result, hidden - 2.0)

    def test_logit_sigma(self):
        """The logits are divided by ``logit_sigma``."""
        pooler = ClassifierPoolerHead(logit_sigma=0.5)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler(hidden, metadata)
        assert torch.allclose(result, hidden / 0.5)

    def test_platt_scaling_combined(self):
        """Mean and sigma together give ``(x - mean) / sigma``."""
        pooler = ClassifierPoolerHead(logit_mean=1.0, logit_sigma=2.0)
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler(hidden, metadata)
        assert torch.allclose(result, (hidden - 1.0) / 2.0)

    def test_activation_uniform_true(self):
        """With activation on for all requests, every row has unit norm."""
        pooler = ClassifierPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(
            _make_params(_BATCH, task="classify", use_activation=True)
        )
        result = pooler(hidden, metadata)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(_BATCH), atol=1e-5)

    def test_activation_uniform_false(self):
        """With activation off for all requests, the input is unchanged."""
        pooler = ClassifierPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(
            _make_params(_BATCH, task="classify", use_activation=False)
        )
        result = pooler(hidden, metadata)
        assert torch.equal(result, hidden)

    def test_activation_mixed_flags(self):
        """Mixed flags yield a list: normalized first row, raw second row."""
        pooler = ClassifierPoolerHead(activation=PoolerNormalize())
        hidden = torch.randn(2, _HIDDEN)
        requests = [
            PoolingParams(task="classify", use_activation=True),
            PoolingParams(task="classify", use_activation=False),
        ]
        result = pooler(hidden, _make_metadata(requests))
        assert isinstance(result, list)
        first_norm = torch.linalg.norm(result[0], dim=-1)
        assert torch.allclose(first_norm, torch.ones(1), atol=1e-5)
        assert torch.equal(result[1], hidden[1])

    def test_list_input_gets_stacked(self):
        """A list of 1-D tensors comes back as one stacked 2-D tensor."""
        pooler = ClassifierPoolerHead()
        rows = [torch.randn(_HIDDEN) for _ in range(_BATCH)]
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler(rows, metadata)
        assert result.shape == (_BATCH, _HIDDEN)
        assert torch.equal(result, torch.stack(rows))

    def test_classifier_then_platt_scaling(self):
        """Scaling is applied to the classifier's output."""
        classifier = _linear(_HIDDEN, 3)
        pooler = ClassifierPoolerHead(
            classifier=classifier, logit_mean=1.0, logit_sigma=2.0
        )
        hidden = torch.randn(_BATCH, _HIDDEN)
        metadata = _make_metadata(_make_params(_BATCH, task="classify"))
        result = pooler(hidden, metadata)
        scaled = (classifier(hidden) - 1.0) / 2.0
        assert torch.allclose(result, scaled)

    def test_empty_batch(self):
        """A zero-row input produces a zero-row output of the same width."""
        pooler = ClassifierPoolerHead()
        hidden = torch.randn(0, _HIDDEN)
        metadata = _make_metadata([])
        result = pooler(hidden, metadata)
        assert result.shape == (0, _HIDDEN)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TokenEmbeddingPoolerHead
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestTokenEmbeddingPoolerHead:
    """Checks for the token-level ``TokenEmbeddingPoolerHead``."""

    def test_supported_tasks(self):
        """Only the "token_embed" task is advertised."""
        pooler = TokenEmbeddingPoolerHead()
        assert pooler.get_supported_tasks() == {"token_embed"}

    def test_passthrough(self):
        """A default head returns a chunk equal to its input."""
        pooler = TokenEmbeddingPoolerHead()
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed")
        result = pooler.forward_chunk(tokens, request)
        assert torch.equal(result, tokens)

    def test_none_chunked_prefill(self):
        """A ``None`` chunk is passed through as ``None``."""
        pooler = TokenEmbeddingPoolerHead()
        request = PoolingParams(task="token_embed")
        result = pooler.forward_chunk(None, request)
        assert result is None

    def test_head_dtype(self):
        """The output carries the configured ``head_dtype``."""
        pooler = TokenEmbeddingPoolerHead(head_dtype=torch.float16)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed")
        result = pooler.forward_chunk(tokens, request)
        assert result.dtype == torch.float16

    def test_projector(self):
        """A projector changes the last dimension and matches a direct call."""
        projector = _linear(_HIDDEN, 8)
        pooler = TokenEmbeddingPoolerHead(projector=projector)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed")
        result = pooler.forward_chunk(tokens, request)
        assert result.shape == (5, 8)
        assert torch.allclose(result, projector(tokens))

    def test_matryoshka_truncation(self):
        """``dimensions`` keeps only the leading features of each token."""
        pooler = TokenEmbeddingPoolerHead()
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed", dimensions=4)
        result = pooler.forward_chunk(tokens, request)
        assert result.shape == (5, 4)
        assert torch.equal(result, tokens[..., :4])

    def test_activation_true(self):
        """With activation requested, every token row has unit norm."""
        pooler = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed", use_activation=True)
        result = pooler.forward_chunk(tokens, request)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(5), atol=1e-5)

    def test_activation_false(self):
        """With activation disabled, the chunk is unchanged."""
        pooler = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed", use_activation=False)
        result = pooler.forward_chunk(tokens, request)
        assert torch.equal(result, tokens)

    def test_projector_then_matryoshka(self):
        """Truncation applies to the projector's output."""
        projector = _linear(_HIDDEN, 8)
        pooler = TokenEmbeddingPoolerHead(projector=projector)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_embed", dimensions=4)
        result = pooler.forward_chunk(tokens, request)
        assert result.shape == (5, 4)
        assert torch.equal(result, projector(tokens)[..., :4])

    def test_matryoshka_then_activation(self):
        """Truncated token rows are unit-norm when activation is requested."""
        pooler = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(
            task="token_embed", dimensions=4, use_activation=True
        )
        result = pooler.forward_chunk(tokens, request)
        assert result.shape == (5, 4)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(5), atol=1e-5)

    def test_forward_mixed_batch_chunked_prefill(self):
        """``None`` entries in a batch stay ``None``; tensors pass through."""
        pooler = TokenEmbeddingPoolerHead()
        chunks = [torch.randn(5, _HIDDEN), None, torch.randn(3, _HIDDEN)]
        metadata = _make_metadata(_make_params(3, task="token_embed"))
        result = pooler(chunks, metadata)
        assert len(result) == 3
        assert torch.equal(result[0], chunks[0])
        assert result[1] is None
        assert torch.equal(result[2], chunks[2])

    def test_forward_empty_batch(self):
        """An empty batch yields an empty list."""
        pooler = TokenEmbeddingPoolerHead()
        metadata = _make_metadata([])
        result = pooler([], metadata)
        assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TokenClassifierPoolerHead
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class TestTokenClassifierPoolerHead:
    """Checks for the token-level ``TokenClassifierPoolerHead``."""

    def test_supported_tasks(self):
        """Only the "token_classify" task is advertised."""
        pooler = TokenClassifierPoolerHead()
        assert pooler.get_supported_tasks() == {"token_classify"}

    def test_passthrough(self):
        """A default head returns a chunk equal to its input."""
        pooler = TokenClassifierPoolerHead()
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler.forward_chunk(tokens, request)
        assert torch.equal(result, tokens)

    def test_none_chunked_prefill(self):
        """A ``None`` chunk is passed through as ``None``."""
        pooler = TokenClassifierPoolerHead()
        request = PoolingParams(task="token_classify")
        result = pooler.forward_chunk(None, request)
        assert result is None

    def test_head_dtype(self):
        """The output carries the configured ``head_dtype``."""
        pooler = TokenClassifierPoolerHead(head_dtype=torch.float16)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler.forward_chunk(tokens, request)
        assert result.dtype == torch.float16

    def test_classifier(self):
        """A classifier layer sets the output width and matches a direct call."""
        classifier = _linear(_HIDDEN, 3)
        pooler = TokenClassifierPoolerHead(classifier=classifier)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler.forward_chunk(tokens, request)
        assert result.shape == (5, 3)
        assert torch.allclose(result, classifier(tokens))

    def test_logit_mean(self):
        """``logit_mean`` is subtracted from the logits."""
        pooler = TokenClassifierPoolerHead(logit_mean=2.0)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler.forward_chunk(tokens, request)
        assert torch.allclose(result, tokens - 2.0)

    def test_logit_sigma(self):
        """The logits are divided by ``logit_sigma``."""
        pooler = TokenClassifierPoolerHead(logit_sigma=0.5)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler.forward_chunk(tokens, request)
        assert torch.allclose(result, tokens / 0.5)

    def test_platt_scaling_combined(self):
        """Mean and sigma together give ``(x - mean) / sigma``."""
        pooler = TokenClassifierPoolerHead(logit_mean=1.0, logit_sigma=2.0)
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify")
        result = pooler.forward_chunk(tokens, request)
        assert torch.allclose(result, (tokens - 1.0) / 2.0)

    def test_activation_true(self):
        """With activation requested, every token row has unit norm."""
        pooler = TokenClassifierPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify", use_activation=True)
        result = pooler.forward_chunk(tokens, request)
        row_norms = torch.linalg.norm(result, dim=-1)
        assert torch.allclose(row_norms, torch.ones(5), atol=1e-5)

    def test_activation_false(self):
        """With activation disabled, the chunk is unchanged."""
        pooler = TokenClassifierPoolerHead(activation=PoolerNormalize())
        tokens = torch.randn(5, _HIDDEN)
        request = PoolingParams(task="token_classify", use_activation=False)
        result = pooler.forward_chunk(tokens, request)
        assert torch.equal(result, tokens)

    def test_forward_mixed_batch_chunked_prefill(self):
        """``None`` entries in a batch stay ``None``; tensors pass through."""
        pooler = TokenClassifierPoolerHead()
        chunks = [torch.randn(5, _HIDDEN), None, torch.randn(3, _HIDDEN)]
        metadata = _make_metadata(_make_params(3, task="token_classify"))
        result = pooler(chunks, metadata)
        assert len(result) == 3
        assert torch.equal(result[0], chunks[0])
        assert result[1] is None
        assert torch.equal(result[2], chunks[2])

    def test_forward_empty_batch(self):
        """An empty batch yields an empty list."""
        pooler = TokenClassifierPoolerHead()
        metadata = _make_metadata([])
        result = pooler([], metadata)
        assert result == []
|
||||||
@@ -2,11 +2,12 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||||
from tests.utils import multi_gpu_test
|
from tests.utils import multi_gpu_test, wait_for_gpu_memory_to_clear
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -404,6 +405,30 @@ def _get_vllm_runner_params(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_for_rocm_memory_to_settle() -> None:
    """On ROCm, wait for GPU memory to clear on every visible device.

    Does nothing on other platforms or when no devices are reported.
    Otherwise it calls ``wait_for_gpu_memory_to_clear`` for all device
    indices with ``threshold_ratio=0.01`` and ``timeout_s=120``.
    """
    if current_platform.is_rocm():
        gpu_total = current_platform.device_count()
        if gpu_total != 0:
            wait_for_gpu_memory_to_clear(
                devices=[*range(gpu_total)],
                threshold_ratio=0.01,
                timeout_s=120,
            )
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def _owned_vLLM_runner(vllm_runner, kwargs):
    """Context manager that creates a runner and waits for memory afterwards.

    Builds ``vllm_runner(**kwargs)``, enters it, and yields the entered
    runner. Once the runner's own context has exited -- normally or through
    an exception, including one raised while constructing it --
    ``_wait_for_rocm_memory_to_settle()`` is called.
    """
    try:
        runner_cm = vllm_runner(**kwargs)
        with runner_cm as active_runner:
            yield active_runner
    finally:
        # Runs after the runner context has exited.
        _wait_for_rocm_memory_to_settle()
|
||||||
|
|
||||||
|
|
||||||
def _get_vLLM_output(
|
def _get_vLLM_output(
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
kwargs,
|
kwargs,
|
||||||
@@ -413,17 +438,21 @@ def _get_vLLM_output(
|
|||||||
num_repetitions=1,
|
num_repetitions=1,
|
||||||
vllm_model=None,
|
vllm_model=None,
|
||||||
):
|
):
|
||||||
outs = []
|
runner_context = (
|
||||||
if vllm_model is None:
|
_owned_vLLM_runner(vllm_runner, kwargs)
|
||||||
vllm_model = vllm_runner(**kwargs)
|
if vllm_model is None
|
||||||
for _ in range(num_repetitions):
|
else nullcontext(vllm_model)
|
||||||
if num_logprobs < 0:
|
)
|
||||||
vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
|
with runner_context as runner:
|
||||||
else:
|
outs = []
|
||||||
vllm_output = vllm_model.generate_greedy_logprobs(
|
for _ in range(num_repetitions):
|
||||||
prompts, max_tokens, num_logprobs
|
if num_logprobs < 0:
|
||||||
)
|
vllm_output = runner.generate_greedy(prompts, max_tokens)
|
||||||
outs.append(vllm_output)
|
else:
|
||||||
|
vllm_output = runner.generate_greedy_logprobs(
|
||||||
|
prompts, max_tokens, num_logprobs
|
||||||
|
)
|
||||||
|
outs.append(vllm_output)
|
||||||
|
|
||||||
return outs, vllm_model
|
return outs, vllm_model
|
||||||
|
|
||||||
@@ -772,38 +801,44 @@ def test_apc_multiple_prompts_partial_cached_outputs(
|
|||||||
|
|
||||||
# Cache only part of all the prompts
|
# Cache only part of all the prompts
|
||||||
vllm_runner_kwargs["enable_prefix_caching"] = True
|
vllm_runner_kwargs["enable_prefix_caching"] = True
|
||||||
vllm_outputs_partial_cache, vllm_model = _get_vLLM_output(
|
with _owned_vLLM_runner(vllm_runner, vllm_runner_kwargs) as vllm_model:
|
||||||
vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs
|
vllm_outputs_partial_cache, _ = _get_vLLM_output(
|
||||||
)
|
vllm_runner,
|
||||||
|
vllm_runner_kwargs,
|
||||||
compare_operator(
|
generated_prompts[:3],
|
||||||
outputs_0_lst=vllm_outputs_no_cache[0][:3],
|
max_tokens,
|
||||||
outputs_1_lst=vllm_outputs_partial_cache[0],
|
num_logprobs,
|
||||||
name_0="vllm_no_cache",
|
vllm_model=vllm_model,
|
||||||
name_1="vllm_partial_cache",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(
|
|
||||||
vllm_runner,
|
|
||||||
vllm_runner_kwargs,
|
|
||||||
generated_prompts,
|
|
||||||
max_tokens,
|
|
||||||
num_logprobs,
|
|
||||||
n_repetitions,
|
|
||||||
vllm_model=vllm_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
|
||||||
# In the first repetition, the caches are filled
|
|
||||||
# In the second repetition, these caches are reused
|
|
||||||
|
|
||||||
compare_operator(
|
compare_operator(
|
||||||
outputs_0_lst=vllm_outputs_no_cache[0],
|
outputs_0_lst=vllm_outputs_no_cache[0][:3],
|
||||||
outputs_1_lst=vllm_outputs_cache_itn,
|
outputs_1_lst=vllm_outputs_partial_cache[0],
|
||||||
name_0="vllm_no_cache",
|
name_0="vllm_no_cache",
|
||||||
name_1=f"vllm_cache_it_{r_idx + 1}",
|
name_1="vllm_partial_cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
vllm_outputs_cache_rep, _ = _get_vLLM_output(
|
||||||
|
vllm_runner,
|
||||||
|
vllm_runner_kwargs,
|
||||||
|
generated_prompts,
|
||||||
|
max_tokens,
|
||||||
|
num_logprobs,
|
||||||
|
n_repetitions,
|
||||||
|
vllm_model=vllm_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||||
|
# In the first repetition, the caches are filled
|
||||||
|
# In the second repetition, these caches are reused
|
||||||
|
|
||||||
|
compare_operator(
|
||||||
|
outputs_0_lst=vllm_outputs_no_cache[0],
|
||||||
|
outputs_1_lst=vllm_outputs_cache_itn,
|
||||||
|
name_0="vllm_no_cache",
|
||||||
|
name_1=f"vllm_cache_it_{r_idx + 1}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Test that outputs match whether prefix caching is enabled or not for mamba.
|
# Test that outputs match whether prefix caching is enabled or not for mamba.
|
||||||
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
|
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
|
||||||
@@ -826,7 +861,7 @@ def test_same_mamba_output_apc_on_vs_off(
|
|||||||
|
|
||||||
# No prefix caching
|
# No prefix caching
|
||||||
kwargs_no_apc = {**base_kwargs, "enable_prefix_caching": False}
|
kwargs_no_apc = {**base_kwargs, "enable_prefix_caching": False}
|
||||||
with vllm_runner(**kwargs_no_apc) as vllm_model:
|
with _owned_vLLM_runner(vllm_runner, kwargs_no_apc) as vllm_model:
|
||||||
outputs_no_apc, _ = _get_vLLM_output(
|
outputs_no_apc, _ = _get_vLLM_output(
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
kwargs_no_apc,
|
kwargs_no_apc,
|
||||||
@@ -841,7 +876,7 @@ def test_same_mamba_output_apc_on_vs_off(
|
|||||||
"enable_prefix_caching": True,
|
"enable_prefix_caching": True,
|
||||||
"mamba_block_size": 16,
|
"mamba_block_size": 16,
|
||||||
}
|
}
|
||||||
with vllm_runner(**kwargs_with_apc) as vllm_model:
|
with _owned_vLLM_runner(vllm_runner, kwargs_with_apc) as vllm_model:
|
||||||
outputs_with_apc, _ = _get_vLLM_output(
|
outputs_with_apc, _ = _get_vLLM_output(
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
kwargs_with_apc,
|
kwargs_with_apc,
|
||||||
|
|||||||
@@ -30,11 +30,14 @@ def vllm_to_hf_output(
|
|||||||
|
|
||||||
MODEL_NAME = "ibm-granite/granite-speech-3.3-2b"
|
MODEL_NAME = "ibm-granite/granite-speech-3.3-2b"
|
||||||
MODEL_NAME_4_0 = "ibm-granite/granite-4.0-1b-speech"
|
MODEL_NAME_4_0 = "ibm-granite/granite-4.0-1b-speech"
|
||||||
|
# "plus" variant of granite speech (uses GraniteSpeechPlusForConditionalGeneration).
|
||||||
|
MODEL_NAME_4_1_PLUS = "ibm-granite/granite-speech-4.1-2b-plus"
|
||||||
# Audio lora co-exists directly in the 3.3 model directory,
|
# Audio lora co-exists directly in the 3.3 model directory,
|
||||||
# the 4.0 model has adapters merged into the weights.
|
# the 4.0 and 4.1-plus models have adapters merged into the weights.
|
||||||
models: dict[str, str | None] = {
|
models: dict[str, str | None] = {
|
||||||
MODEL_NAME: MODEL_NAME,
|
MODEL_NAME: MODEL_NAME,
|
||||||
MODEL_NAME_4_0: None,
|
MODEL_NAME_4_0: None,
|
||||||
|
MODEL_NAME_4_1_PLUS: None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,10 @@ def qwen_vl_chat_template(content: str) -> str:
|
|||||||
return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
|
return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
||||||
|
|
||||||
|
def internvl_chat_template(content: str) -> str:
|
||||||
|
return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
||||||
|
|
||||||
def step3_vl_chat_template(content: str) -> str:
|
def step3_vl_chat_template(content: str) -> str:
|
||||||
return (
|
return (
|
||||||
"<|begin▁of▁sentence|> You are a helpful assistant.<|BOT|>user\n "
|
"<|begin▁of▁sentence|> You are a helpful assistant.<|BOT|>user\n "
|
||||||
@@ -51,6 +55,17 @@ def step3_vl_chat_template(content: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
|
MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
|
||||||
|
"internvl": VitCudagraphTestConfig(
|
||||||
|
model="OpenGVLab/InternVL3-1B",
|
||||||
|
num_video_frames=8,
|
||||||
|
image_prompt=internvl_chat_template("<image>\nWhat is in this image?"),
|
||||||
|
video_prompt=internvl_chat_template(
|
||||||
|
"<video>\nDescribe this video in one sentence."
|
||||||
|
),
|
||||||
|
needs_video_metadata=False,
|
||||||
|
vllm_runner_kwargs={"trust_remote_code": True},
|
||||||
|
marks=[pytest.mark.core_model],
|
||||||
|
),
|
||||||
"qwen2_5_vl": VitCudagraphTestConfig(
|
"qwen2_5_vl": VitCudagraphTestConfig(
|
||||||
model="Qwen/Qwen2.5-VL-3B-Instruct",
|
model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
image_prompt=qwen_vl_chat_template(
|
image_prompt=qwen_vl_chat_template(
|
||||||
|
|||||||
@@ -938,6 +938,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
"ibm-granite/granite-speech-3.3-2b",
|
"ibm-granite/granite-speech-3.3-2b",
|
||||||
extras={"4.0-1b": "ibm-granite/granite-4.0-1b-speech"},
|
extras={"4.0-1b": "ibm-granite/granite-4.0-1b-speech"},
|
||||||
),
|
),
|
||||||
|
"GraniteSpeechPlusForConditionalGeneration": _HfExamplesInfo(
|
||||||
|
"ibm-granite/granite-speech-4.1-2b-plus",
|
||||||
|
min_transformers_version="5.8.0",
|
||||||
|
),
|
||||||
"GLM4VForCausalLM": _HfExamplesInfo(
|
"GLM4VForCausalLM": _HfExamplesInfo(
|
||||||
"zai-org/glm-4v-9b",
|
"zai-org/glm-4v-9b",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
|
|||||||
@@ -36,11 +36,24 @@ def tokenizer():
|
|||||||
return get_tokenizer("Qwen/Qwen3-32B")
|
return get_tokenizer("Qwen/Qwen3-32B")
|
||||||
|
|
||||||
|
|
||||||
|
TOOLS = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def request_obj():
|
def request_obj():
|
||||||
return ChatCompletionRequest(
|
return ChatCompletionRequest(
|
||||||
model="test-model",
|
model="test-model",
|
||||||
messages=[{"role": "user", "content": "hi"}],
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
tools=TOOLS,
|
||||||
|
tool_choice="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -328,3 +341,27 @@ def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj):
|
|||||||
tc.function.arguments for tc in tool_calls if tc.function.arguments
|
tc.function.arguments for tc in tool_calls if tc.function.arguments
|
||||||
)
|
)
|
||||||
assert tool_args.endswith(remainder)
|
assert tool_args.endswith(remainder)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_delta_tool_choice_none(tokenizer, request_obj):
|
||||||
|
parser = make_parser(tokenizer, reasoning=False, tool=True)
|
||||||
|
request = request_obj.model_copy(update={"tool_choice": "none"})
|
||||||
|
results = stream_text(parser, tokenizer, MODEL_OUTPUT, request, prompt_token_ids=[])
|
||||||
|
reasoning, content, tool_calls = collect_fields(results)
|
||||||
|
|
||||||
|
assert reasoning == ""
|
||||||
|
assert len(tool_calls) == 0
|
||||||
|
assert "<tool_call>" in content
|
||||||
|
assert "get_weather" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_delta_tool_choice_none_with_reasoning(tokenizer, request_obj):
|
||||||
|
parser = make_parser(tokenizer, reasoning=True, tool=True)
|
||||||
|
request = request_obj.model_copy(update={"tool_choice": "none"})
|
||||||
|
results = stream_text(parser, tokenizer, MODEL_OUTPUT, request, prompt_token_ids=[])
|
||||||
|
reasoning, content, tool_calls = collect_fields(results)
|
||||||
|
|
||||||
|
assert "let me think about this" in reasoning
|
||||||
|
assert len(tool_calls) == 0
|
||||||
|
assert "<tool_call>" in content
|
||||||
|
assert "get_weather" in content
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
|||||||
CompressedTensorsW4A4Fp4,
|
CompressedTensorsW4A4Fp4,
|
||||||
CompressedTensorsW4A4Mxfp4,
|
CompressedTensorsW4A4Mxfp4,
|
||||||
CompressedTensorsW4A8Fp8,
|
CompressedTensorsW4A8Fp8,
|
||||||
CompressedTensorsW4A16Fp4,
|
|
||||||
CompressedTensorsW8A8Fp8,
|
CompressedTensorsW8A8Fp8,
|
||||||
CompressedTensorsW8A8Int8,
|
CompressedTensorsW8A8Int8,
|
||||||
CompressedTensorsW8A8Mxfp8,
|
CompressedTensorsW8A8Mxfp8,
|
||||||
@@ -37,9 +36,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
|||||||
find_matched_target,
|
find_matched_target,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
|
||||||
cutlass_fp4_supported,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||||
@@ -376,13 +372,12 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"args",
|
"args",
|
||||||
[
|
[
|
||||||
# TODO: Enable once model is available again
|
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", True),
|
||||||
# ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4),
|
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", False),
|
||||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_compressed_tensors_nvfp4(vllm_runner, args):
|
def test_compressed_tensors_nvfp4(vllm_runner, args):
|
||||||
model, scheme = args
|
model, use_a16 = args
|
||||||
with vllm_runner(model, enforce_eager=True) as llm:
|
with vllm_runner(model, enforce_eager=True) as llm:
|
||||||
|
|
||||||
def check_model(model):
|
def check_model(model):
|
||||||
@@ -390,15 +385,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
|
|||||||
|
|
||||||
qkv_proj = layer.self_attn.qkv_proj
|
qkv_proj = layer.self_attn.qkv_proj
|
||||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||||
if (
|
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A4Fp4)
|
||||||
isinstance(qkv_proj.scheme, scheme)
|
assert qkv_proj.scheme.use_a16 == use_a16
|
||||||
or isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
|
|
||||||
and not cutlass_fp4_supported()
|
|
||||||
):
|
|
||||||
assert True
|
|
||||||
else:
|
|
||||||
raise AssertionError("FP4 Scheme Mismatch")
|
|
||||||
|
|
||||||
assert qkv_proj.scheme.group_size == 16
|
assert qkv_proj.scheme.group_size == 16
|
||||||
|
|
||||||
llm.apply_model(check_model)
|
llm.apply_model(check_model)
|
||||||
|
|||||||
+1
-244
@@ -10,12 +10,11 @@ from dataclasses import dataclass
|
|||||||
from json.decoder import JSONDecodeError
|
from json.decoder import JSONDecodeError
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import patch
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
|
||||||
from vllm.logger import (
|
from vllm.logger import (
|
||||||
_DATE_FORMAT,
|
_DATE_FORMAT,
|
||||||
_FORMAT,
|
_FORMAT,
|
||||||
@@ -269,248 +268,6 @@ def test_prepare_object_to_dump():
|
|||||||
assert prepare_object_to_dump(CustomClass(1, "b")) == "CustomClass(a=1, b='b')"
|
assert prepare_object_to_dump(CustomClass(1, "b")) == "CustomClass(a=1, b='b')"
|
||||||
|
|
||||||
|
|
||||||
def test_request_logger_log_outputs():
|
|
||||||
"""Test the new log_outputs functionality."""
|
|
||||||
# Create a mock logger to capture log calls
|
|
||||||
mock_logger = MagicMock()
|
|
||||||
|
|
||||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
|
||||||
request_logger = RequestLogger(max_log_len=None)
|
|
||||||
|
|
||||||
# Test basic output logging
|
|
||||||
request_logger.log_outputs(
|
|
||||||
request_id="test-123",
|
|
||||||
outputs="Hello, world!",
|
|
||||||
output_token_ids=[1, 2, 3, 4],
|
|
||||||
finish_reason="stop",
|
|
||||||
is_streaming=False,
|
|
||||||
delta=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_logger.info.assert_called_once()
|
|
||||||
call_args = mock_logger.info.call_args.args
|
|
||||||
assert "Generated response %s%s" in call_args[0]
|
|
||||||
assert call_args[1] == "test-123"
|
|
||||||
assert call_args[3] == "Hello, world!"
|
|
||||||
assert call_args[4] == [1, 2, 3, 4]
|
|
||||||
assert call_args[5] == "stop"
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_logger_log_outputs_streaming_delta():
|
|
||||||
"""Test log_outputs with streaming delta mode."""
|
|
||||||
mock_logger = MagicMock()
|
|
||||||
|
|
||||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
|
||||||
request_logger = RequestLogger(max_log_len=None)
|
|
||||||
|
|
||||||
# Test streaming delta logging
|
|
||||||
request_logger.log_outputs(
|
|
||||||
request_id="test-456",
|
|
||||||
outputs="Hello",
|
|
||||||
output_token_ids=[1],
|
|
||||||
finish_reason=None,
|
|
||||||
is_streaming=True,
|
|
||||||
delta=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_logger.info.assert_called_once()
|
|
||||||
call_args = mock_logger.info.call_args.args
|
|
||||||
assert "Generated response %s%s" in call_args[0]
|
|
||||||
assert call_args[1] == "test-456"
|
|
||||||
assert call_args[2] == " (streaming delta)"
|
|
||||||
assert call_args[3] == "Hello"
|
|
||||||
assert call_args[4] == [1]
|
|
||||||
assert call_args[5] is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_logger_log_outputs_streaming_complete():
|
|
||||||
"""Test log_outputs with streaming complete mode."""
|
|
||||||
mock_logger = MagicMock()
|
|
||||||
|
|
||||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
|
||||||
request_logger = RequestLogger(max_log_len=None)
|
|
||||||
|
|
||||||
# Test streaming complete logging
|
|
||||||
request_logger.log_outputs(
|
|
||||||
request_id="test-789",
|
|
||||||
outputs="Complete response",
|
|
||||||
output_token_ids=[1, 2, 3],
|
|
||||||
finish_reason="length",
|
|
||||||
is_streaming=True,
|
|
||||||
delta=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_logger.info.assert_called_once()
|
|
||||||
call_args = mock_logger.info.call_args.args
|
|
||||||
assert "Generated response %s%s" in call_args[0]
|
|
||||||
assert call_args[1] == "test-789"
|
|
||||||
assert call_args[2] == " (streaming complete)"
|
|
||||||
assert call_args[3] == "Complete response"
|
|
||||||
assert call_args[4] == [1, 2, 3]
|
|
||||||
assert call_args[5] == "length"
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_logger_log_outputs_with_truncation():
|
|
||||||
"""Test log_outputs respects max_log_len setting."""
|
|
||||||
mock_logger = MagicMock()
|
|
||||||
|
|
||||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
|
||||||
# Set max_log_len to 10
|
|
||||||
request_logger = RequestLogger(max_log_len=10)
|
|
||||||
|
|
||||||
# Test output truncation
|
|
||||||
long_output = "This is a very long output that should be truncated"
|
|
||||||
long_token_ids = list(range(20)) # 20 tokens
|
|
||||||
|
|
||||||
request_logger.log_outputs(
|
|
||||||
request_id="test-truncate",
|
|
||||||
outputs=long_output,
|
|
||||||
output_token_ids=long_token_ids,
|
|
||||||
finish_reason="stop",
|
|
||||||
is_streaming=False,
|
|
||||||
delta=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_logger.info.assert_called_once()
|
|
||||||
call_args = mock_logger.info.call_args
|
|
||||||
|
|
||||||
# Check that output was truncated to first 10 characters
|
|
||||||
logged_output = call_args[0][3]
|
|
||||||
assert logged_output == "This is a "
|
|
||||||
assert len(logged_output) == 10
|
|
||||||
|
|
||||||
# Check that token IDs were truncated to first 10 tokens
|
|
||||||
logged_token_ids = call_args[0][4]
|
|
||||||
assert logged_token_ids == list(range(10))
|
|
||||||
assert len(logged_token_ids) == 10
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_logger_log_outputs_none_values():
|
|
||||||
"""Test log_outputs handles None values correctly."""
|
|
||||||
mock_logger = MagicMock()
|
|
||||||
|
|
||||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
|
||||||
request_logger = RequestLogger(max_log_len=None)
|
|
||||||
|
|
||||||
# Test with None output_token_ids
|
|
||||||
request_logger.log_outputs(
|
|
||||||
request_id="test-none",
|
|
||||||
outputs="Test output",
|
|
||||||
output_token_ids=None,
|
|
||||||
finish_reason="stop",
|
|
||||||
is_streaming=False,
|
|
||||||
delta=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_logger.info.assert_called_once()
|
|
||||||
call_args = mock_logger.info.call_args.args
|
|
||||||
assert "Generated response %s%s" in call_args[0]
|
|
||||||
assert call_args[1] == "test-none"
|
|
||||||
assert call_args[3] == "Test output"
|
|
||||||
assert call_args[4] is None
|
|
||||||
assert call_args[5] == "stop"
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_logger_log_outputs_empty_output():
|
|
||||||
"""Test log_outputs handles empty output correctly."""
|
|
||||||
mock_logger = MagicMock()
|
|
||||||
|
|
||||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
|
||||||
request_logger = RequestLogger(max_log_len=5)
|
|
||||||
|
|
||||||
# Test with empty output
|
|
||||||
request_logger.log_outputs(
|
|
||||||
request_id="test-empty",
|
|
||||||
outputs="",
|
|
||||||
output_token_ids=[],
|
|
||||||
finish_reason="stop",
|
|
||||||
is_streaming=False,
|
|
||||||
delta=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_logger.info.assert_called_once()
|
|
||||||
call_args = mock_logger.info.call_args.args
|
|
||||||
assert "Generated response %s%s" in call_args[0]
|
|
||||||
assert call_args[1] == "test-empty"
|
|
||||||
assert call_args[3] == ""
|
|
||||||
assert call_args[4] == []
|
|
||||||
assert call_args[5] == "stop"
|
|
||||||
|
|
||||||
|
|
||||||
def test_request_logger_log_outputs_integration():
|
|
||||||
"""Test that log_outputs can be called alongside log_inputs."""
|
|
||||||
mock_logger = MagicMock()
|
|
||||||
|
|
||||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
|
||||||
request_logger = RequestLogger(max_log_len=None)
|
|
||||||
|
|
||||||
# Test that both methods can be called without interference
|
|
||||||
request_logger.log_inputs(
|
|
||||||
request_id="test-integration",
|
|
||||||
prompt="Test prompt",
|
|
||||||
prompt_token_ids=[1, 2, 3],
|
|
||||||
prompt_embeds=None,
|
|
||||||
params=None,
|
|
||||||
lora_request=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
request_logger.log_outputs(
|
|
||||||
request_id="test-integration",
|
|
||||||
outputs="Test output",
|
|
||||||
output_token_ids=[4, 5, 6],
|
|
||||||
finish_reason="stop",
|
|
||||||
is_streaming=False,
|
|
||||||
delta=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should have been called twice - once for inputs, once for outputs
|
|
||||||
assert mock_logger.info.call_count == 2
|
|
||||||
|
|
||||||
# Check that the calls were made with correct patterns
|
|
||||||
input_call = mock_logger.info.call_args_list[0][0]
|
|
||||||
output_call = mock_logger.info.call_args_list[1][0]
|
|
||||||
|
|
||||||
assert "Received request %s" in input_call[0]
|
|
||||||
assert input_call[1] == "test-integration"
|
|
||||||
|
|
||||||
assert "Generated response %s%s" in output_call[0]
|
|
||||||
assert output_call[1] == "test-integration"
|
|
||||||
|
|
||||||
|
|
||||||
def test_streaming_complete_logs_full_text_content():
|
|
||||||
"""Test that streaming complete logging includes
|
|
||||||
full accumulated text, not just token count."""
|
|
||||||
mock_logger = MagicMock()
|
|
||||||
|
|
||||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
|
||||||
request_logger = RequestLogger(max_log_len=None)
|
|
||||||
|
|
||||||
# Test with actual content instead of token count format
|
|
||||||
full_response = "This is a complete response from streaming"
|
|
||||||
request_logger.log_outputs(
|
|
||||||
request_id="test-streaming-full-text",
|
|
||||||
outputs=full_response,
|
|
||||||
output_token_ids=None,
|
|
||||||
finish_reason="streaming_complete",
|
|
||||||
is_streaming=True,
|
|
||||||
delta=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_logger.info.assert_called_once()
|
|
||||||
call_args = mock_logger.info.call_args.args
|
|
||||||
|
|
||||||
# Verify the logged output is the full text, not a token count format
|
|
||||||
logged_output = call_args[3]
|
|
||||||
assert logged_output == full_response
|
|
||||||
assert "tokens>" not in logged_output
|
|
||||||
assert "streaming_complete" not in logged_output
|
|
||||||
|
|
||||||
# Verify other parameters
|
|
||||||
assert call_args[1] == "test-streaming-full-text"
|
|
||||||
assert call_args[2] == " (streaming complete)"
|
|
||||||
assert call_args[5] == "streaming_complete"
|
|
||||||
|
|
||||||
|
|
||||||
# Add vllm prefix to make sure logs go through the vllm logger
|
# Add vllm prefix to make sure logs go through the vllm logger
|
||||||
test_logger = init_logger("vllm.test_logger")
|
test_logger = init_logger("vllm.test_logger")
|
||||||
|
|
||||||
|
|||||||
@@ -54,15 +54,14 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
|
|||||||
(
|
(
|
||||||
MiniMaxText01LinearAttention,
|
MiniMaxText01LinearAttention,
|
||||||
dict(
|
dict(
|
||||||
hidden_size=128,
|
config=SimpleNamespace(
|
||||||
hidden_inner_size=256,
|
hidden_size=256,
|
||||||
num_heads=8,
|
num_attention_heads=8,
|
||||||
head_dim=32,
|
head_dim=32,
|
||||||
max_position=2048,
|
num_hidden_layers=12,
|
||||||
block_size=64,
|
block=64,
|
||||||
num_hidden_layer=12,
|
),
|
||||||
layer_idx=0,
|
prefix="layers.0.self_attn",
|
||||||
linear_layer_idx=0,
|
|
||||||
),
|
),
|
||||||
LinearAttentionBackend,
|
LinearAttentionBackend,
|
||||||
MambaAttentionBackendEnum.LINEAR,
|
MambaAttentionBackendEnum.LINEAR,
|
||||||
@@ -88,6 +87,8 @@ def test_mamba_layers_get_attn_backend(
|
|||||||
expected_mamba_type,
|
expected_mamba_type,
|
||||||
):
|
):
|
||||||
"""Test that Mamba-like layers return the correct attention backend."""
|
"""Test that Mamba-like layers return the correct attention backend."""
|
||||||
|
if layer_class is MiniMaxText01LinearAttention:
|
||||||
|
init_kwargs["vllm_config"] = default_vllm_config
|
||||||
layer = layer_class(**init_kwargs)
|
layer = layer_class(**init_kwargs)
|
||||||
|
|
||||||
backend_class = layer.get_attn_backend()
|
backend_class = layer.get_attn_backend()
|
||||||
|
|||||||
@@ -358,6 +358,43 @@ def test_free_kv_cache_block_queue_append_n():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_free_kv_cache_block_queue_prepend_n():
|
||||||
|
# Seed the queue with one block so prepend has an existing head to splice
|
||||||
|
# in front of (fake_head->b0->fake_tail).
|
||||||
|
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
|
||||||
|
queue = FreeKVCacheBlockQueue(blocks[0:1])
|
||||||
|
|
||||||
|
# Prepend 0 blocks is a no-op.
|
||||||
|
queue.prepend_n([])
|
||||||
|
assert queue.num_free_blocks == 1
|
||||||
|
assert queue.fake_free_list_head.next_free_block is blocks[0]
|
||||||
|
|
||||||
|
# Prepend 2 blocks; they land in front of the existing head, in order.
|
||||||
|
# fake_head->b4->b5->b0->fake_tail
|
||||||
|
queue.prepend_n(blocks[4:6])
|
||||||
|
assert queue.num_free_blocks == 3
|
||||||
|
assert queue.fake_free_list_head.next_free_block is blocks[4]
|
||||||
|
assert blocks[4].prev_free_block is queue.fake_free_list_head
|
||||||
|
assert blocks[4].next_free_block is blocks[5]
|
||||||
|
assert blocks[5].prev_free_block is blocks[4]
|
||||||
|
assert blocks[5].next_free_block is blocks[0]
|
||||||
|
assert blocks[0].prev_free_block is blocks[5]
|
||||||
|
assert blocks[0].next_free_block is queue.fake_free_list_tail
|
||||||
|
assert queue.fake_free_list_tail.prev_free_block is blocks[0]
|
||||||
|
|
||||||
|
# A second prepend goes ahead of everything previously prepended.
|
||||||
|
# fake_head->b1->b2->b4->b5->b0->fake_tail
|
||||||
|
queue.prepend_n(blocks[1:3])
|
||||||
|
assert queue.num_free_blocks == 5
|
||||||
|
assert queue.fake_free_list_head.next_free_block is blocks[1]
|
||||||
|
assert blocks[1].next_free_block is blocks[2]
|
||||||
|
assert blocks[2].next_free_block is blocks[4]
|
||||||
|
|
||||||
|
# The popleft order reflects the front-to-back queue order.
|
||||||
|
assert [queue.popleft().block_id for _ in range(5)] == [1, 2, 4, 5, 0]
|
||||||
|
assert queue.num_free_blocks == 0
|
||||||
|
|
||||||
|
|
||||||
def test_free_kv_cache_block_queue_popleft_n():
|
def test_free_kv_cache_block_queue_popleft_n():
|
||||||
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
|
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
|
||||||
# Create an empty FreeKVCacheBlockQueue with these blocks
|
# Create an empty FreeKVCacheBlockQueue with these blocks
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from vllm.v1.kv_cache_interface import (
|
|||||||
KVCacheGroupSpec,
|
KVCacheGroupSpec,
|
||||||
KVCacheSpecKind,
|
KVCacheSpecKind,
|
||||||
MambaSpec,
|
MambaSpec,
|
||||||
|
MLAAttentionSpec,
|
||||||
SlidingWindowSpec,
|
SlidingWindowSpec,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2875,6 +2876,350 @@ def test_hybrid_cache_blocks_clamped_to_lcm():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch):
|
||||||
|
"""Verify fixed intervals retain sparse tails plus the latest replay tail."""
|
||||||
|
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
|
||||||
|
block_size = 8
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=100,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer1"],
|
||||||
|
FullAttentionSpec(
|
||||||
|
block_size=4 * block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float16,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer2"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=block_size,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
manager = make_kv_cache_manager(
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
max_model_len=8192,
|
||||||
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The SWA manager uses the configured 64-token interval (a multiple of the
|
||||||
|
# 32-token lcm_block_size) as its retention segment. For this 128-token
|
||||||
|
# prompt, the retained SWA tails are the 64-token interval boundary, the
|
||||||
|
# 96-token replay boundary, and the 128-token interval boundary.
|
||||||
|
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||||
|
req = make_request("0", token_ids, block_size, sha256)
|
||||||
|
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||||
|
blocks = manager.allocate_slots(
|
||||||
|
req,
|
||||||
|
len(token_ids),
|
||||||
|
len(computed_blocks.blocks[0]) * block_size,
|
||||||
|
computed_blocks,
|
||||||
|
)
|
||||||
|
assert blocks is not None
|
||||||
|
|
||||||
|
pool = manager.block_pool
|
||||||
|
expected_swa_cached = {7, 11, 15}
|
||||||
|
for i in range(16):
|
||||||
|
cached = pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[1])
|
||||||
|
if i in expected_swa_cached:
|
||||||
|
assert cached is not None, f"SWA hash {i} should be cached"
|
||||||
|
else:
|
||||||
|
assert cached is None, f"SWA hash {i} should not be cached"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"interval, expected_match",
|
||||||
|
[
|
||||||
|
# scheduler_block_size is 32 (= lcm(4*8, 8)); 33 is not a multiple of it.
|
||||||
|
("33", "multiple of scheduler_block_size"),
|
||||||
|
# A negative multiple (-32 % 32 == 0) must still be rejected explicitly,
|
||||||
|
# otherwise it would pass the modulo check and silently degrade to dense.
|
||||||
|
("-32", "non-negative"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_hybrid_local_kv_retention_interval_rejects_invalid(
|
||||||
|
monkeypatch, interval, expected_match
|
||||||
|
):
|
||||||
|
"""A retention interval that is negative or not a multiple of
|
||||||
|
scheduler_block_size errors out at construction time."""
|
||||||
|
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", interval)
|
||||||
|
block_size = 8
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=100,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer1"],
|
||||||
|
FullAttentionSpec(
|
||||||
|
block_size=4 * block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float16,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer2"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=block_size,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match=expected_match):
|
||||||
|
make_kv_cache_manager(
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
max_model_len=8192,
|
||||||
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hybrid_local_kv_retention_interval_survives_recycling(monkeypatch):
|
||||||
|
"""Verify retained local checkpoints are reused after block recycling."""
|
||||||
|
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "1024")
|
||||||
|
hash_block_size = 4
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=800,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["full"],
|
||||||
|
MLAAttentionSpec(
|
||||||
|
block_size=64 * hash_block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.uint8,
|
||||||
|
compress_ratio=4,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["swa"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size=16 * hash_block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=512,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["c128"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size=2 * hash_block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=128,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["c4"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size=hash_block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=8,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
manager = make_kv_cache_manager(
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
max_model_len=4096,
|
||||||
|
enable_caching=True,
|
||||||
|
hash_block_size=hash_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fill_request(request_id: str, token_offset: int) -> list[int]:
|
||||||
|
token_ids = [
|
||||||
|
token_offset + i for i in range(1024) for _ in range(hash_block_size)
|
||||||
|
]
|
||||||
|
fill_req = make_request(request_id, token_ids, hash_block_size, sha256)
|
||||||
|
while fill_req.num_computed_tokens < len(token_ids):
|
||||||
|
num_new_tokens = min(512, len(token_ids) - fill_req.num_computed_tokens)
|
||||||
|
blocks = manager.allocate_slots(fill_req, num_new_tokens)
|
||||||
|
assert blocks is not None
|
||||||
|
fill_req.num_computed_tokens += num_new_tokens
|
||||||
|
manager.free(fill_req)
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
token_ids = fill_request("fill_0", 0)
|
||||||
|
replay_req = make_request("replay", token_ids[:1800], hash_block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req)
|
||||||
|
assert num_computed_tokens == 1024
|
||||||
|
assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256]
|
||||||
|
|
||||||
|
fill_request("fill_1", 100_000)
|
||||||
|
replay_req = make_request("replay_again", token_ids[:1800], hash_block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(replay_req)
|
||||||
|
assert num_computed_tokens == 1024
|
||||||
|
assert [len(blocks) for blocks in computed_blocks.blocks] == [4, 16, 128, 256]
|
||||||
|
|
||||||
|
|
||||||
|
def test_hybrid_local_kv_retention_latest_only_reuses_replay_boundary(monkeypatch):
|
||||||
|
"""Verify latest-only retention reuses only the replayable prompt boundary."""
|
||||||
|
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
|
||||||
|
block_size = 8
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=100,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer1"],
|
||||||
|
FullAttentionSpec(
|
||||||
|
block_size=4 * block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float16,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer2"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=block_size,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
manager = make_kv_cache_manager(
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
max_model_len=8192,
|
||||||
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||||
|
req0 = make_request("0", token_ids, block_size, sha256)
|
||||||
|
computed_blocks, _ = manager.get_computed_blocks(req0)
|
||||||
|
blocks = manager.allocate_slots(
|
||||||
|
req0,
|
||||||
|
len(token_ids),
|
||||||
|
len(computed_blocks.blocks[0]) * block_size,
|
||||||
|
computed_blocks,
|
||||||
|
)
|
||||||
|
assert blocks is not None
|
||||||
|
|
||||||
|
pool = manager.block_pool
|
||||||
|
expected_swa_cached = {11}
|
||||||
|
for i in range(16):
|
||||||
|
cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
|
||||||
|
if i in expected_swa_cached:
|
||||||
|
assert cached is not None, f"SWA hash {i} should be cached"
|
||||||
|
else:
|
||||||
|
assert cached is None, f"SWA hash {i} should not be cached"
|
||||||
|
|
||||||
|
manager.free(req0)
|
||||||
|
retained_swa_block = pool.get_cached_block(req0.block_hashes[11], [1])
|
||||||
|
assert retained_swa_block is not None
|
||||||
|
assert retained_swa_block[0].ref_cnt == 0
|
||||||
|
|
||||||
|
req1 = make_request("1", token_ids, block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
|
# Full prompt hits intentionally recompute the final block for logits, so
|
||||||
|
# the longest usable hit is the previous LCM boundary: 96 tokens.
|
||||||
|
assert num_computed_tokens == 12 * block_size
|
||||||
|
assert len(computed_blocks.blocks[1]) == 12
|
||||||
|
|
||||||
|
shorter_req = make_request("2", token_ids[: 12 * block_size], block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(shorter_req)
|
||||||
|
assert num_computed_tokens == 0
|
||||||
|
assert len(computed_blocks.blocks[1]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_hybrid_local_kv_retention_mtp_reuses_latest_boundary(monkeypatch):
|
||||||
|
"""Verify MTP/EAGLE SWA retention keeps the extra proof block.
|
||||||
|
|
||||||
|
EAGLE/MTP lookup matches one additional local block after the returned
|
||||||
|
prefix and then drops it. Sparse retention must therefore cache the normal
|
||||||
|
local tail at the latest replay boundary plus one extra SWA block.
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
|
||||||
|
block_size = 8
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=100,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["full"],
|
||||||
|
FullAttentionSpec(
|
||||||
|
block_size=4 * block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float16,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["swa_mtp"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=block_size,
|
||||||
|
),
|
||||||
|
is_eagle_group=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
manager = make_kv_cache_manager(
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
max_model_len=8192,
|
||||||
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
|
use_eagle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 127 tokens: latest replay boundary is floor((127 - 1) / 32) * 32 = 96.
|
||||||
|
# The EAGLE/MTP SWA lookup group must cache the local tail ending at
|
||||||
|
# 104 tokens, and that tail is two 8-token blocks wide: hashes 11 and 12.
|
||||||
|
token_ids = [i for i in range(15) for _ in range(block_size)] + [15] * 7
|
||||||
|
req0 = make_request("0", token_ids, block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
|
assert num_computed_tokens == 0
|
||||||
|
blocks = manager.allocate_slots(
|
||||||
|
req0,
|
||||||
|
len(token_ids),
|
||||||
|
num_computed_tokens,
|
||||||
|
computed_blocks,
|
||||||
|
)
|
||||||
|
assert blocks is not None
|
||||||
|
|
||||||
|
pool = manager.block_pool
|
||||||
|
expected_swa_cached = {11, 12}
|
||||||
|
for i in range(15):
|
||||||
|
cached = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
|
||||||
|
if i in expected_swa_cached:
|
||||||
|
assert cached is not None, f"SWA hash {i} should be cached"
|
||||||
|
else:
|
||||||
|
assert cached is None, f"SWA hash {i} should not be cached"
|
||||||
|
|
||||||
|
manager.free(req0)
|
||||||
|
|
||||||
|
req1 = make_request("1", token_ids, block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
|
assert num_computed_tokens == 12 * block_size
|
||||||
|
assert [len(blocks) for blocks in computed_blocks.blocks] == [3, 12]
|
||||||
|
|
||||||
|
|
||||||
def test_block_lookup_cache_single_block_per_key():
|
def test_block_lookup_cache_single_block_per_key():
|
||||||
cache = BlockHashToBlockMap()
|
cache = BlockHashToBlockMap()
|
||||||
key0 = BlockHashWithGroupId(b"hash0")
|
key0 = BlockHashWithGroupId(b"hash0")
|
||||||
@@ -3058,3 +3403,215 @@ def test_can_fit_full_sequence_full_attention_still_gates_oversized():
|
|||||||
req = make_request("oversized", list(range(prompt_len)), block_size, sha256)
|
req = make_request("oversized", list(range(prompt_len)), block_size, sha256)
|
||||||
|
|
||||||
assert manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is None
|
assert manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_swa_free_split_keeps_cached_tail_ahead_of_scratch(monkeypatch):
|
||||||
|
"""Default path (no retention): freeing an SWA request must place its
|
||||||
|
uncached scratch blocks at the front of the free queue (recycled first)
|
||||||
|
and keep its cached checkpoint blocks at the back (retained for prefix
|
||||||
|
hits). This split is always-on, independent of the retention interval."""
|
||||||
|
monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
|
||||||
|
block_size = 8
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=100,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer1"],
|
||||||
|
FullAttentionSpec(
|
||||||
|
block_size=4 * block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float16,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer2"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=block_size,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
manager = make_kv_cache_manager(
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
max_model_len=8192,
|
||||||
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||||
|
req = make_request("0", token_ids, block_size, sha256)
|
||||||
|
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||||
|
blocks = manager.allocate_slots(
|
||||||
|
req,
|
||||||
|
len(token_ids),
|
||||||
|
len(computed_blocks.blocks[0]) * block_size,
|
||||||
|
computed_blocks,
|
||||||
|
)
|
||||||
|
assert blocks is not None
|
||||||
|
|
||||||
|
swa_manager = manager.coordinator.single_type_managers[1]
|
||||||
|
null_block = manager.block_pool.null_block
|
||||||
|
cached_ids: set[int] = set()
|
||||||
|
uncached_ids: set[int] = set()
|
||||||
|
cached_hash_indices: list[int] = []
|
||||||
|
for i, block in enumerate(swa_manager.req_to_blocks[req.request_id]):
|
||||||
|
if block is null_block:
|
||||||
|
continue
|
||||||
|
if block.block_hash is None:
|
||||||
|
uncached_ids.add(block.block_id)
|
||||||
|
else:
|
||||||
|
cached_ids.add(block.block_id)
|
||||||
|
cached_hash_indices.append(i)
|
||||||
|
# The dense default mask caches only the per-segment tails, so a 16-block
|
||||||
|
# SWA prompt must produce a mix of retained and scratch blocks.
|
||||||
|
assert cached_ids, "expected some retained (cached) SWA tail blocks"
|
||||||
|
assert uncached_ids, "expected some scratch (uncached) SWA blocks"
|
||||||
|
|
||||||
|
manager.free(req)
|
||||||
|
|
||||||
|
order = [
|
||||||
|
b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks()
|
||||||
|
]
|
||||||
|
pos = {bid: i for i, bid in enumerate(order)}
|
||||||
|
# Every scratch block is recycled before every retained block.
|
||||||
|
assert max(pos[bid] for bid in uncached_ids) < min(pos[bid] for bid in cached_ids)
|
||||||
|
# The retained tails survive the free and still serve a prefix-cache hit.
|
||||||
|
for i in cached_hash_indices:
|
||||||
|
assert (
|
||||||
|
manager.block_pool.get_cached_block(
|
||||||
|
req.block_hashes[i], kv_cache_group_ids=[1]
|
||||||
|
)
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_pure_swa_manager(block_size, sliding_window, num_blocks=100, **kwargs):
|
||||||
|
"""Single sliding-window group (UnitaryKVCacheCoordinator)."""
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=num_blocks,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=1,
|
||||||
|
head_size=1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
sliding_window=sliding_window,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return make_kv_cache_manager(
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
max_model_len=8192,
|
||||||
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pure_swa_retention_interval_caches_sparse_tails(monkeypatch):
|
||||||
|
"""Sparse retention must work for a pure-SWA single-group model, not just
|
||||||
|
hybrid models: only the per-interval tails plus the latest replay tail are
|
||||||
|
cached, and a replay still hits the latest replayable boundary."""
|
||||||
|
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
|
||||||
|
block_size = 16
|
||||||
|
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
|
||||||
|
assert type(manager.coordinator).__name__ == "UnitaryKVCacheCoordinator"
|
||||||
|
|
||||||
|
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||||
|
req = make_request("0", token_ids, block_size, sha256)
|
||||||
|
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||||
|
blocks = manager.allocate_slots(
|
||||||
|
req,
|
||||||
|
len(token_ids),
|
||||||
|
len(computed_blocks.blocks[0]) * block_size,
|
||||||
|
computed_blocks,
|
||||||
|
)
|
||||||
|
assert blocks is not None
|
||||||
|
|
||||||
|
pool = manager.block_pool
|
||||||
|
cached = {
|
||||||
|
i
|
||||||
|
for i in range(16)
|
||||||
|
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
|
||||||
|
is not None
|
||||||
|
}
|
||||||
|
# per_segment = 64 / 16 = 4, need = cdiv(16-1, 16) = 1 -> segment tails at
|
||||||
|
# i%4==3 -> {3,7,11,15}; latest replay boundary (255//16*16 = 240) -> tail
|
||||||
|
# block 14. Crucially this is a strict subset of all 16 blocks: retention
|
||||||
|
# is actually sparse for pure SWA (not silently dense).
|
||||||
|
assert cached == {3, 7, 11, 14, 15}
|
||||||
|
|
||||||
|
# A replay of the same prompt hits the latest replayable boundary (240).
|
||||||
|
replay = make_request("1", token_ids, block_size, sha256)
|
||||||
|
_, num_computed = manager.get_computed_blocks(replay)
|
||||||
|
assert num_computed == 240
|
||||||
|
|
||||||
|
|
||||||
|
def test_pure_swa_retention_latest_only(monkeypatch):
|
||||||
|
"""`=0` on a pure-SWA model keeps only the latest replay tail."""
|
||||||
|
monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
|
||||||
|
block_size = 16
|
||||||
|
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
|
||||||
|
|
||||||
|
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||||
|
req = make_request("0", token_ids, block_size, sha256)
|
||||||
|
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||||
|
blocks = manager.allocate_slots(
|
||||||
|
req,
|
||||||
|
len(token_ids),
|
||||||
|
len(computed_blocks.blocks[0]) * block_size,
|
||||||
|
computed_blocks,
|
||||||
|
)
|
||||||
|
assert blocks is not None
|
||||||
|
|
||||||
|
pool = manager.block_pool
|
||||||
|
cached = {
|
||||||
|
i
|
||||||
|
for i in range(16)
|
||||||
|
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
|
||||||
|
is not None
|
||||||
|
}
|
||||||
|
# No segment tails (interval 0); only the latest replay tail (block 14).
|
||||||
|
assert cached == {14}
|
||||||
|
|
||||||
|
replay = make_request("1", token_ids, block_size, sha256)
|
||||||
|
_, num_computed = manager.get_computed_blocks(replay)
|
||||||
|
assert num_computed == 240
|
||||||
|
|
||||||
|
|
||||||
|
def test_pure_swa_retention_dense_default_caches_all(monkeypatch):
|
||||||
|
"""With retention unset, a pure-SWA model must keep the dense behavior:
|
||||||
|
every block boundary is a potential hit, so all blocks are cached."""
|
||||||
|
monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
|
||||||
|
block_size = 16
|
||||||
|
manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
|
||||||
|
|
||||||
|
token_ids = [i for i in range(16) for _ in range(block_size)]
|
||||||
|
req = make_request("0", token_ids, block_size, sha256)
|
||||||
|
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||||
|
blocks = manager.allocate_slots(
|
||||||
|
req,
|
||||||
|
len(token_ids),
|
||||||
|
len(computed_blocks.blocks[0]) * block_size,
|
||||||
|
computed_blocks,
|
||||||
|
)
|
||||||
|
assert blocks is not None
|
||||||
|
|
||||||
|
pool = manager.block_pool
|
||||||
|
cached = {
|
||||||
|
i
|
||||||
|
for i in range(16)
|
||||||
|
if pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[0])
|
||||||
|
is not None
|
||||||
|
}
|
||||||
|
assert cached == set(range(16))
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from utils import skip_unsupported
|
from utils import skip_unsupported
|
||||||
|
|
||||||
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
|
from vllm.model_executor.layers.batch_invariant import (
|
||||||
|
rms_norm_batch_invariant,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@@ -51,7 +53,7 @@ def test_rms_norm_batch_invariant_vs_standard(
|
|||||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||||
|
|
||||||
# Batch-invariant implementation (Triton)
|
# Batch-invariant implementation (Triton)
|
||||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||||
|
|
||||||
# Compare outputs
|
# Compare outputs
|
||||||
# Use looser tolerance for bfloat16 due to its lower precision
|
# Use looser tolerance for bfloat16 due to its lower precision
|
||||||
@@ -125,7 +127,7 @@ def test_fused_add_rms_norm_batch_invariant_residual_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
merged_single = x_single + residual_single
|
merged_single = x_single + residual_single
|
||||||
ref_out = triton_rms_norm(merged_single, weight, eps=eps)
|
ref_out = rms_norm_batch_invariant(merged_single, weight, eps=eps)
|
||||||
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
residual_out_single,
|
residual_out_single,
|
||||||
@@ -193,7 +195,7 @@ def test_rms_norm_3d_input(
|
|||||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||||
|
|
||||||
# Batch-invariant implementation
|
# Batch-invariant implementation
|
||||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||||
|
|
||||||
# Use looser tolerance for bfloat16
|
# Use looser tolerance for bfloat16
|
||||||
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
||||||
@@ -242,7 +244,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
|
|||||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||||
|
|
||||||
# Batch-invariant implementation
|
# Batch-invariant implementation
|
||||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||||
|
|
||||||
# Check for NaN or Inf
|
# Check for NaN or Inf
|
||||||
assert not torch.isnan(standard_output).any(), (
|
assert not torch.isnan(standard_output).any(), (
|
||||||
@@ -289,7 +291,7 @@ def test_rms_norm_formula(default_vllm_config):
|
|||||||
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
|
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
|
||||||
|
|
||||||
# Batch-invariant implementation
|
# Batch-invariant implementation
|
||||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||||
|
|
||||||
# Compare against formula
|
# Compare against formula
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
@@ -325,7 +327,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
|
|||||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||||
|
|
||||||
# Batch-invariant implementation
|
# Batch-invariant implementation
|
||||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||||
|
|
||||||
# Use looser tolerance for bfloat16
|
# Use looser tolerance for bfloat16
|
||||||
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
||||||
@@ -360,7 +362,7 @@ def test_rms_norm_determinism(default_vllm_config):
|
|||||||
# Run multiple times
|
# Run multiple times
|
||||||
outputs = []
|
outputs = []
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
|
output = rms_norm_batch_invariant(input_tensor.clone(), weight, eps=eps)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
# All outputs should be identical
|
# All outputs should be identical
|
||||||
@@ -395,7 +397,7 @@ if __name__ == "__main__":
|
|||||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||||
|
|
||||||
# Batch-invariant implementation
|
# Batch-invariant implementation
|
||||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||||
|
|
||||||
# Compare
|
# Compare
|
||||||
max_diff = (triton_output - standard_output).abs().max().item()
|
max_diff = (triton_output - standard_output).abs().max().item()
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ def _make_connector_with_fake_worker(
|
|||||||
)
|
)
|
||||||
worker = connector.connector_worker
|
worker = connector.connector_worker
|
||||||
assert isinstance(worker.nixl_wrapper, FakeNixlWrapper)
|
assert isinstance(worker.nixl_wrapper, FakeNixlWrapper)
|
||||||
worker.nixl_wrapper.set_cycles_before_xfer_done(cycles_before_done)
|
|
||||||
worker.kv_cache_layout = "HND"
|
worker.kv_cache_layout = "HND"
|
||||||
if do_handshake:
|
if do_handshake:
|
||||||
remote_agents = worker._nixl_handshake(
|
remote_agents = worker._nixl_handshake(
|
||||||
|
|||||||
@@ -6,25 +6,56 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
|
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||||
|
|
||||||
pytestmark = pytest.mark.cpu_test
|
pytestmark = pytest.mark.cpu_test
|
||||||
|
|
||||||
|
|
||||||
|
class _StubLMCacheMPConnector:
|
||||||
|
"""Stand-in for LMCacheMPConnector used in config-translation tests.
|
||||||
|
|
||||||
|
The real connector module hard-imports the optional ``lmcache`` package
|
||||||
|
at module load time, which is not installed in the cpu_test image. This
|
||||||
|
test only asserts on the connector *name* and the ``extra_config`` dict
|
||||||
|
produced by ``VllmConfig``, never instantiates the connector, so a bare
|
||||||
|
placeholder class is sufficient. Not subclassing ``SupportsHMA`` mirrors
|
||||||
|
the real connector's HMA support (it does not support HMA either)."""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def stub_lmcache_mp_connector(monkeypatch):
|
||||||
|
"""Replace the lazy loader so VllmConfig.__post_init__ does not import
|
||||||
|
``lmcache_mp_connector`` (and thus ``lmcache``) during config tests."""
|
||||||
|
monkeypatch.setitem(
|
||||||
|
KVConnectorFactory._registry,
|
||||||
|
"LMCacheMPConnector",
|
||||||
|
lambda: _StubLMCacheMPConnector,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
|
"kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
|
||||||
[
|
[
|
||||||
("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
|
("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
|
||||||
# bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
|
# bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
|
||||||
("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30)),
|
("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30)),
|
||||||
("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
|
# ``lmcache`` backend now defaults to LMCacheMPConnector. The KV
|
||||||
# size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
|
# storage capacity is owned by the standalone LMCache server, so
|
||||||
("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
|
# ``kv_offloading_size`` is intentionally not propagated.
|
||||||
|
("lmcache", 4.0, 1, 1, "LMCacheMPConnector", None),
|
||||||
|
("lmcache", 8.0, 2, 2, "LMCacheMPConnector", None),
|
||||||
# When kv_offloading_size is None, offloading is disabled (backend is ignored)
|
# When kv_offloading_size is None, offloading is disabled (backend is ignored)
|
||||||
("native", None, 1, 1, None, None),
|
("native", None, 1, 1, None, None),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_kv_connector(
|
def test_kv_connector(
|
||||||
kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
|
stub_lmcache_mp_connector,
|
||||||
|
kv_offloading_backend,
|
||||||
|
kv_offloading_size,
|
||||||
|
tp,
|
||||||
|
pp,
|
||||||
|
expected_backend,
|
||||||
|
expected_bytes,
|
||||||
):
|
):
|
||||||
kv_transfer_config = (
|
kv_transfer_config = (
|
||||||
KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
|
KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
|
||||||
@@ -59,10 +90,12 @@ def test_kv_connector(
|
|||||||
# Existing config should be preserved
|
# Existing config should be preserved
|
||||||
assert kv_connector_extra_config["existing_key"] == "existing_value"
|
assert kv_connector_extra_config["existing_key"] == "existing_value"
|
||||||
elif kv_offloading_backend == "lmcache":
|
elif kv_offloading_backend == "lmcache":
|
||||||
assert kv_connector_extra_config["lmcache.local_cpu"] is True
|
# MP mode does not push lmcache.local_cpu / max_local_cpu_size into
|
||||||
assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
|
# extra config (the LMCache server owns capacity). Pre-existing
|
||||||
# Existing config should be replaced
|
# extra config entries are preserved as-is.
|
||||||
assert "existing_key" not in kv_connector_extra_config
|
assert "lmcache.local_cpu" not in kv_connector_extra_config
|
||||||
|
assert "lmcache.max_local_cpu_size" not in kv_connector_extra_config
|
||||||
|
assert kv_connector_extra_config["existing_key"] == "existing_value"
|
||||||
|
|
||||||
|
|
||||||
def _build_config(
|
def _build_config(
|
||||||
|
|||||||
@@ -197,13 +197,6 @@ class FakeNixlWrapper:
|
|||||||
def get_xfer_telemetry(self, handle: int) -> dict:
|
def get_xfer_telemetry(self, handle: int) -> dict:
|
||||||
return get_default_xfer_telemetry()
|
return get_default_xfer_telemetry()
|
||||||
|
|
||||||
############################################################
|
|
||||||
# Follow are for changing the behavior during testing.
|
|
||||||
############################################################
|
|
||||||
|
|
||||||
def set_cycles_before_xfer_done(self, cycles: int):
|
|
||||||
"""Set the number of cycles before a transfer is considered done."""
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _make_fake_nixl_pkg():
|
def _make_fake_nixl_pkg():
|
||||||
@@ -578,10 +571,7 @@ class TestNixlHandshake:
|
|||||||
"""Test case where multiple xfers are initiated to the same engine.
|
"""Test case where multiple xfers are initiated to the same engine.
|
||||||
|
|
||||||
This test triggers the connector to load remote KV for the same
|
This test triggers the connector to load remote KV for the same
|
||||||
`request_id`. The transfer is not done immediately due to
|
`request_id`.
|
||||||
`set_cycles_before_xfer_done`, so there is a state where there are
|
|
||||||
multiple transfer states for the same `request_id`, and `get_finished`
|
|
||||||
should handle it correctly (wait for all transfers to be done).
|
|
||||||
"""
|
"""
|
||||||
vllm_config = create_vllm_config()
|
vllm_config = create_vllm_config()
|
||||||
|
|
||||||
@@ -598,7 +588,6 @@ class TestNixlHandshake:
|
|||||||
)
|
)
|
||||||
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
|
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
|
||||||
worker = connector.connector_worker
|
worker = connector.connector_worker
|
||||||
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
|
|
||||||
# simulate handshake
|
# simulate handshake
|
||||||
worker.dst_xfer_side_handles = {
|
worker.dst_xfer_side_handles = {
|
||||||
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
|
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
|
||||||
@@ -1304,7 +1293,6 @@ def test_scheduler_kv_connector_stats_aggregation():
|
|||||||
# Worker stats with transfer metrics
|
# Worker stats with transfer metrics
|
||||||
worker_stats = NixlKVConnectorStats()
|
worker_stats = NixlKVConnectorStats()
|
||||||
worker_stats.record_transfer(get_default_xfer_telemetry())
|
worker_stats.record_transfer(get_default_xfer_telemetry())
|
||||||
worker_stats.data["remote_tokens"] = []
|
|
||||||
|
|
||||||
# Scheduler stats with custom metric (needs dummy transfer to avoid being skipped)
|
# Scheduler stats with custom metric (needs dummy transfer to avoid being skipped)
|
||||||
scheduler_stats = NixlKVConnectorStats()
|
scheduler_stats = NixlKVConnectorStats()
|
||||||
@@ -1314,7 +1302,6 @@ def test_scheduler_kv_connector_stats_aggregation():
|
|||||||
"post_duration": [0],
|
"post_duration": [0],
|
||||||
"bytes_transferred": [0],
|
"bytes_transferred": [0],
|
||||||
"num_descriptors": [0],
|
"num_descriptors": [0],
|
||||||
"remote_tokens": [128],
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1355,7 +1342,6 @@ def test_scheduler_kv_connector_stats_aggregation():
|
|||||||
).scheduler_stats.kv_connector_stats
|
).scheduler_stats.kv_connector_stats
|
||||||
nixl_stats = final_stats["NixlConnector"]
|
nixl_stats = final_stats["NixlConnector"]
|
||||||
assert nixl_stats.num_successful_transfers == 2
|
assert nixl_stats.num_successful_transfers == 2
|
||||||
assert nixl_stats.data["remote_tokens"] == [128]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
|
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
|
||||||
|
|||||||
@@ -275,6 +275,83 @@ def test_apply_prefix_caching_mamba_hybrid(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.cpu_test
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"local_physical_per_logical,remote_physical_per_logical,"
|
||||||
|
"local_block_ids,remote_block_ids,"
|
||||||
|
"expected_local,expected_remote",
|
||||||
|
[
|
||||||
|
# SSM prefix caching: remote has 3 placeholder + 1 real block,
|
||||||
|
# local has only the 1 real block. FA blocks are equal (no trim).
|
||||||
|
pytest.param(
|
||||||
|
10,
|
||||||
|
10,
|
||||||
|
[list(range(10)), [42]],
|
||||||
|
[list(range(10)), [40, 41, 42, 43]],
|
||||||
|
[list(range(10)), [42]],
|
||||||
|
[list(range(10)), [43]],
|
||||||
|
id="ssm_prefix_trim_only",
|
||||||
|
),
|
||||||
|
# FA partial prefix cache hit with homogeneous TP: local has 4 FA
|
||||||
|
# blocks (prefix cached), remote has full 10. SSM equal (no trim).
|
||||||
|
pytest.param(
|
||||||
|
10,
|
||||||
|
10,
|
||||||
|
[list(range(6, 10)), [42]],
|
||||||
|
[list(range(10)), [42]],
|
||||||
|
[list(range(6, 10)), [42]],
|
||||||
|
[list(range(6, 10)), [42]],
|
||||||
|
id="fa_prefix_hit_homo_tp",
|
||||||
|
),
|
||||||
|
# Both: FA partial prefix hit + SSM placeholder trim.
|
||||||
|
# local FA=[6..9] (4 blocks, prefix cached), remote FA=[0..9]
|
||||||
|
# local SSM=[99], remote SSM=[10, 20, 99] (2 placeholders + real)
|
||||||
|
pytest.param(
|
||||||
|
10,
|
||||||
|
10,
|
||||||
|
[[6, 7, 8, 9], [99]],
|
||||||
|
[list(range(10)), [10, 20, 99]],
|
||||||
|
[[6, 7, 8, 9], [99]],
|
||||||
|
[[6, 7, 8, 9], [99]],
|
||||||
|
id="fa_prefix_hit_and_ssm_trim",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_apply_prefix_caching_ssm_prefix_cache_hit(
    local_physical_per_logical,
    remote_physical_per_logical,
    local_block_ids,
    remote_block_ids,
    expected_local,
    expected_remote,
):
    """_apply_prefix_caching end-trims SSM remote blocks to match the single
    local block (placeholders dropped) and end-trims FA remote blocks on
    partial prefix cache hits when physical_per_logical matches.

    Each parametrized case supplies per-group block-id lists (first group FA,
    second group SSM, matching ``_group_spec_types`` below) together with the
    aligned local/remote lists the call is expected to return.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
        NixlConnectorWorker,
    )
    from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

    # Bypass __init__ so no real connector state is built; only the attributes
    # assigned here are populated (presumably exactly those that
    # _apply_prefix_caching reads -- confirm against the worker if it changes).
    worker = object.__new__(NixlConnectorWorker)
    worker._has_mamba = True
    worker._physical_blocks_per_logical_kv_block = local_physical_per_logical
    worker._group_spec_types = (FullAttentionSpec, MambaSpec)
    worker.kv_cache_config = make_kv_cache_config(block_size=16, mamba_enabled=True)

    aligned_local, aligned_remote = worker._apply_prefix_caching(
        local_block_ids, remote_block_ids, remote_physical_per_logical
    )

    assert aligned_local == expected_local, (
        f"Expected local {expected_local}, got {aligned_local}"
    )
    assert aligned_remote == expected_remote, (
        f"Expected remote {expected_remote}, got {aligned_remote}"
    )
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.cpu_test
|
@pytest.mark.cpu_test
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"local_physical_per_logical,remote_physical_per_logical,"
|
"local_physical_per_logical,remote_physical_per_logical,"
|
||||||
|
|||||||
@@ -1072,8 +1072,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
current_platform.is_rocm(),
|
not current_platform.is_cuda(),
|
||||||
reason="Attention backend FLASHINFER is not supported on ROCm.",
|
reason="Attention backend FLASHINFER is only supported on CUDA.",
|
||||||
)
|
)
|
||||||
def test_hybrid_attention_mamba_tensor_shapes():
|
def test_hybrid_attention_mamba_tensor_shapes():
|
||||||
"""
|
"""
|
||||||
@@ -1508,8 +1508,8 @@ def test_is_uniform_decode() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
current_platform.is_rocm(),
|
not current_platform.is_cuda(),
|
||||||
reason="Attention backend FLASHINFER is not supported on ROCm.",
|
reason="Attention backend FLASHINFER is only supported on CUDA.",
|
||||||
)
|
)
|
||||||
def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
|
def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
|
||||||
"""Test that a ValueError is raised when max_num_seqs exceeds the
|
"""Test that a ValueError is raised when max_num_seqs exceeds the
|
||||||
|
|||||||
@@ -1562,7 +1562,9 @@ def generate_legend() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def generate_mla_section(
|
def generate_mla_section(
|
||||||
prefill_backends: list[dict[str, Any]], decode_backends: list[dict[str, Any]]
|
prefill_backends: list[dict[str, Any]],
|
||||||
|
decode_backends: list[dict[str, Any]],
|
||||||
|
v4_decode_backends: list[dict[str, Any]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate the complete MLA section with prefill and decode tables."""
|
"""Generate the complete MLA section with prefill and decode tables."""
|
||||||
lines = [
|
lines = [
|
||||||
@@ -1611,6 +1613,22 @@ def generate_mla_section(
|
|||||||
columns = _build_columns(is_mla=True, has_versions=False)
|
columns = _build_columns(is_mla=True, has_versions=False)
|
||||||
lines.extend(_render_table(columns, decode_backends))
|
lines.extend(_render_table(columns, decode_backends))
|
||||||
|
|
||||||
|
if v4_decode_backends:
|
||||||
|
lines.extend(
|
||||||
|
[
|
||||||
|
"",
|
||||||
|
"### DeepSeek V4 Decode Backends",
|
||||||
|
"",
|
||||||
|
"DeepSeek V4 sparse MLA uses its own decode backends, selected via",
|
||||||
|
"`--attention-backend=<BACKEND>` (e.g., `FLASHMLA_SPARSE_DSV4`,",
|
||||||
|
"`FLASHINFER_MLA_SPARSE_DSV4`). They share the V4 sparse-index",
|
||||||
|
"pipeline (compressor + SWA + indexer, 256-token blocks, head 512);",
|
||||||
|
"default on NVIDIA is `FLASHMLA_SPARSE_DSV4`.",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
lines.extend(_render_table(columns, v4_decode_backends))
|
||||||
|
|
||||||
lines.append("")
|
lines.append("")
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
@@ -1651,9 +1669,15 @@ def generate_docs() -> str:
|
|||||||
if fi_features:
|
if fi_features:
|
||||||
all_backends = _expand_flashinfer_variants(all_backends, fi_features)
|
all_backends = _expand_flashinfer_variants(all_backends, fi_features)
|
||||||
|
|
||||||
# Split into MLA and non-MLA
|
# DeepSeek V4 (*_DSV4) decode backends get their own subsection rather than
|
||||||
mla_backends = [b for b in all_backends if b["is_mla"]]
|
# mixing into the main MLA / standard tables (the ROCm V4 backend isn't
|
||||||
non_mla_backends = [b for b in all_backends if not b["is_mla"]]
|
# flagged is_mla by the AST heuristic, so filter purely on the name).
|
||||||
|
def _is_v4(b: dict[str, Any]) -> bool:
|
||||||
|
return b["name"].endswith("_DSV4")
|
||||||
|
|
||||||
|
v4_decode_backends = [b for b in all_backends if _is_v4(b)]
|
||||||
|
mla_backends = [b for b in all_backends if b["is_mla"] and not _is_v4(b)]
|
||||||
|
non_mla_backends = [b for b in all_backends if not b["is_mla"] and not _is_v4(b)]
|
||||||
|
|
||||||
# Generate documentation
|
# Generate documentation
|
||||||
script_path = "tools/pre_commit/generate_attention_backend_docs.py"
|
script_path = "tools/pre_commit/generate_attention_backend_docs.py"
|
||||||
@@ -1703,7 +1727,9 @@ def generate_docs() -> str:
|
|||||||
doc_lines.append("\n>\n".join(footnotes) + "\n")
|
doc_lines.append("\n>\n".join(footnotes) + "\n")
|
||||||
|
|
||||||
# Add MLA section with prefill and decode backends
|
# Add MLA section with prefill and decode backends
|
||||||
doc_lines.append(generate_mla_section(mla_prefill_backends, mla_backends))
|
doc_lines.append(
|
||||||
|
generate_mla_section(mla_prefill_backends, mla_backends, v4_decode_backends)
|
||||||
|
)
|
||||||
|
|
||||||
return "\n".join(doc_lines)
|
return "\n".join(doc_lines)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
from vllm.entrypoints.serve.utils.api_utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||||
|
|
||||||
from .plot import SweepPlotArgs
|
from .plot import SweepPlotArgs
|
||||||
from .plot import main as plot_main
|
from .plot import main as plot_main
|
||||||
|
|||||||
@@ -44,6 +44,24 @@ from .matcher_utils import MatcherQuantFP8
|
|||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
_IR_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default
|
||||||
|
_IR_FUSED_ADD_RMS_NORM_OP = torch.ops.vllm_ir.fused_add_rms_norm.default
|
||||||
|
|
||||||
|
|
||||||
|
def _norm_input_weight_dtype_match(match: pm.Match) -> bool:
    """Prevent fusion when the norm input and weight dtypes differ (e.g. a Gemma
    fp32 weight.float()+1 gamma), covering rms_norm and fused_add_rms_norm.

    Used as the ``extra_check`` callback of ``pm.register_replacement``.

    Args:
        match: The candidate pattern match whose nodes are inspected.

    Returns:
        For the first rms_norm / fused_add_rms_norm node whose input and
        weight are both graph nodes: whether their ``meta["val"]`` dtypes are
        equal. ``True`` when no such node is found, so the replacement is not
        blocked.
    """
    for node in match.nodes:
        if node.target == _IR_RMS_NORM_OP:
            # rms_norm: weight is read from positional argument 1.
            x, weight = node.args[0], node.args[1]
        elif node.target == _IR_FUSED_ADD_RMS_NORM_OP:
            # fused_add_rms_norm: weight is read from positional argument 2.
            x, weight = node.args[0], node.args[2]
        else:
            continue
        if isinstance(x, fx.Node) and isinstance(weight, fx.Node):
            # The first norm node with node-valued operands decides the result.
            return x.meta["val"].dtype == weight.meta["val"].dtype
    return True
|
||||||
|
|
||||||
|
|
||||||
# The empirical value for small batch
|
# The empirical value for small batch
|
||||||
PDL_ADVANCE_LAUNCH_TOKENS = 16
|
PDL_ADVANCE_LAUNCH_TOKENS = 16
|
||||||
@@ -132,6 +150,7 @@ if flashinfer_comm is not None:
|
|||||||
quant_out: torch.Tensor | None = None,
|
quant_out: torch.Tensor | None = None,
|
||||||
scale_out: torch.Tensor | None = None,
|
scale_out: torch.Tensor | None = None,
|
||||||
scale_factor: torch.Tensor | None = None,
|
scale_factor: torch.Tensor | None = None,
|
||||||
|
weight_bias: float = 0.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
num_tokens, hidden_size = allreduce_in.shape
|
num_tokens, hidden_size = allreduce_in.shape
|
||||||
element_size = allreduce_in.element_size()
|
element_size = allreduce_in.element_size()
|
||||||
@@ -208,6 +227,7 @@ if flashinfer_comm is not None:
|
|||||||
layout_code=layout_code,
|
layout_code=layout_code,
|
||||||
use_oneshot=use_oneshot,
|
use_oneshot=use_oneshot,
|
||||||
fp32_acc=fp32_acc,
|
fp32_acc=fp32_acc,
|
||||||
|
weight_bias=weight_bias,
|
||||||
trigger_completion_at_end=num_tokens > PDL_ADVANCE_LAUNCH_TOKENS,
|
trigger_completion_at_end=num_tokens > PDL_ADVANCE_LAUNCH_TOKENS,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -225,6 +245,7 @@ if flashinfer_comm is not None:
|
|||||||
quant_out: torch.Tensor | None = None,
|
quant_out: torch.Tensor | None = None,
|
||||||
scale_out: torch.Tensor | None = None,
|
scale_out: torch.Tensor | None = None,
|
||||||
scale_factor: torch.Tensor | None = None,
|
scale_factor: torch.Tensor | None = None,
|
||||||
|
weight_bias: float = 0.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -399,14 +420,142 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
|||||||
# allreduce_in, residual
|
# allreduce_in, residual
|
||||||
return allreduce[1], allreduce[2]
|
return allreduce[1], allreduce[2]
|
||||||
|
|
||||||
|
# extra_check routes a Gemma fp32 gamma to AllReduceFusedAddGemmaRMSNormPattern.
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
pattern,
|
||||||
|
replacement,
|
||||||
|
self.get_inputs(),
|
||||||
|
pm.fwd_only,
|
||||||
|
pm_pass,
|
||||||
|
extra_check=_norm_input_weight_dtype_match,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Same pattern, but only return the output and not residual
|
# Same pattern, but only return the output and not residual
|
||||||
# (helpful for end of graph where residual is not used again)
|
# (helpful for end of graph where residual is not used again)
|
||||||
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
|
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
|
||||||
|
|
||||||
|
pm.register_replacement(
|
||||||
|
first_return_only(pattern), # type: ignore[no-untyped-call]
|
||||||
|
first_return_only(replacement), # type: ignore[no-untyped-call]
|
||||||
|
self.get_inputs(),
|
||||||
|
pm.fwd_only,
|
||||||
|
pm_pass,
|
||||||
|
extra_check=_norm_input_weight_dtype_match,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AllReduceGemmaRMSNormPattern(BasePattern):
    """Gemma-style variant of AllReduceRMSNormPattern (no residual).

    Matches an all-reduce followed by ``rms_norm`` whose weight is
    ``weight.float() + 1.0`` and replaces it with a single
    ``flashinfer_trtllm_fused_allreduce_norm`` call that receives the raw
    ``weight`` plus ``weight_bias=1.0``.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store the norm epsilon and the fused all-reduce parameters."""
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example ``[input, weight]`` tensors passed to
        ``pm.register_replacement`` for tracing the pattern."""
        return [self.empty(5, 16), self.empty(16)]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            allreduce_output = tensor_model_parallel_all_reduce(input)
            # Gemma-style gamma: the weight is upcast and offset by one.
            rms = vllm.ir.ops.rms_norm(
                allreduce_output, weight.float() + 1.0, self.epsilon
            )
            return rms, allreduce_output

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The fused op is invoked with the residual pattern code, so a
            # zero residual is supplied (presumably to make the residual add
            # a no-op -- confirm against the kernel semantics).
            residual = torch.zeros_like(input)
            rms_result = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=rms_result,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                # The "+ 1.0" of the matched pattern is passed to the kernel
                # as a bias on the raw weight.
                weight_bias=1.0,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # NOTE(review): index 3 is presumably norm_out and index 1
            # allreduce_in -- confirm against the op's mutated-argument order.
            return allreduce[3], allreduce[1]

        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class AllReduceFusedAddGemmaRMSNormPattern(BasePattern):
|
||||||
|
"""Gemma-style variant of AllReduceFusedAddRMSNormPattern (with residual)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
epsilon: float,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str | None,
|
||||||
|
allreduce_params: FlashInferFusedAllReduceParams,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(dtype, device)
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.allreduce_params = allreduce_params
|
||||||
|
|
||||||
|
def get_inputs(self) -> list[torch.Tensor]:
|
||||||
|
input = self.empty(5, 16)
|
||||||
|
residual = self.empty(5, 16)
|
||||||
|
weight = self.empty(16)
|
||||||
|
return [residual, input.to(self.dtype), weight]
|
||||||
|
|
||||||
|
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||||
|
def pattern(
|
||||||
|
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||||
|
rms, residual = vllm.ir.ops.fused_add_rms_norm(
|
||||||
|
allreduce_output, residual, weight.float() + 1.0, self.epsilon
|
||||||
|
)
|
||||||
|
return rms, residual
|
||||||
|
|
||||||
|
def replacement(
|
||||||
|
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||||
|
allreduce = auto_functionalized(
|
||||||
|
flashinfer_trtllm_fused_allreduce_norm,
|
||||||
|
allreduce_in=input,
|
||||||
|
residual=residual,
|
||||||
|
norm_out=None,
|
||||||
|
quant_out=None,
|
||||||
|
scale_out=None,
|
||||||
|
rms_gamma=weight,
|
||||||
|
rms_eps=self.epsilon,
|
||||||
|
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
|
||||||
|
weight_bias=1.0,
|
||||||
|
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||||
|
)
|
||||||
|
return allreduce[1], allreduce[2]
|
||||||
|
|
||||||
|
pm.register_replacement(
|
||||||
|
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||||
|
)
|
||||||
|
|
||||||
|
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
first_return_only(pattern), # type: ignore[no-untyped-call]
|
first_return_only(pattern), # type: ignore[no-untyped-call]
|
||||||
first_return_only(replacement), # type: ignore[no-untyped-call]
|
first_return_only(replacement), # type: ignore[no-untyped-call]
|
||||||
@@ -881,6 +1030,18 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
|||||||
self.device,
|
self.device,
|
||||||
self.allreduce_params,
|
self.allreduce_params,
|
||||||
).register(self.patterns)
|
).register(self.patterns)
|
||||||
|
AllReduceGemmaRMSNormPattern(
|
||||||
|
epsilon,
|
||||||
|
self.model_dtype,
|
||||||
|
self.device,
|
||||||
|
self.allreduce_params,
|
||||||
|
).register(self.patterns)
|
||||||
|
AllReduceFusedAddGemmaRMSNormPattern(
|
||||||
|
epsilon,
|
||||||
|
self.model_dtype,
|
||||||
|
self.device,
|
||||||
|
self.allreduce_params,
|
||||||
|
).register(self.patterns)
|
||||||
|
|
||||||
# WARNING: This is a hack to clear the pattern matcher cache
|
# WARNING: This is a hack to clear the pattern matcher cache
|
||||||
# and allow multiple values of epsilon.
|
# and allow multiple values of epsilon.
|
||||||
|
|||||||
@@ -36,10 +36,12 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
|||||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||||
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
|
||||||
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hasattr(torch.ops._C, "per_token_group_fp8_quant"):
|
||||||
|
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||||
|
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||||
|
|
||||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
|
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
|
||||||
|
|
||||||
|
|||||||
@@ -84,9 +84,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
|||||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||||
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
|
||||||
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
|
||||||
}
|
}
|
||||||
|
if hasattr(torch.ops._C, "per_token_group_fp8_quant"):
|
||||||
|
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||||
|
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
|
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
|
||||||
|
|
||||||
|
|||||||
+20
-18
@@ -109,22 +109,24 @@ class EPLBConfig:
|
|||||||
class ParallelConfig:
|
class ParallelConfig:
|
||||||
"""Configuration for the distributed execution."""
|
"""Configuration for the distributed execution."""
|
||||||
|
|
||||||
pipeline_parallel_size: int = 1
|
pipeline_parallel_size: int = Field(default=1, ge=1)
|
||||||
"""Number of pipeline parallel groups."""
|
"""Number of pipeline parallel groups."""
|
||||||
tensor_parallel_size: int = 1
|
tensor_parallel_size: int = Field(default=1, ge=1)
|
||||||
"""Number of tensor parallel groups."""
|
"""Number of tensor parallel groups."""
|
||||||
prefill_context_parallel_size: int = 1
|
prefill_context_parallel_size: int = Field(default=1, ge=1)
|
||||||
"""Number of prefill context parallel groups."""
|
"""Number of prefill context parallel groups."""
|
||||||
data_parallel_size: int = 1
|
data_parallel_size: int = Field(default=1, ge=1)
|
||||||
"""Number of data parallel groups. MoE layers will be sharded according to
|
"""Number of data parallel groups. MoE layers will be sharded according to
|
||||||
the product of the tensor parallel size and data parallel size."""
|
the product of the tensor parallel size and data parallel size."""
|
||||||
data_parallel_size_local: int = 1
|
data_parallel_size_local: int = Field(default=1, ge=0)
|
||||||
"""Number of local data parallel groups."""
|
"""Number of local data parallel groups. A value of 0 is a sentinel used by
|
||||||
data_parallel_rank: int = 0
|
the engine-args layer to signal that data parallelism was specified
|
||||||
"""Rank of the data parallel group."""
|
externally (see `ParallelConfig.__post_init__`)."""
|
||||||
|
data_parallel_rank: int = Field(default=0, ge=0)
|
||||||
|
"""Rank of the data parallel group. The runtime check at
|
||||||
|
``__post_init__`` further bounds this by ``data_parallel_size``."""
|
||||||
data_parallel_rank_local: int | None = None
|
data_parallel_rank_local: int | None = None
|
||||||
"""Local rank of the data parallel group,
|
"""Local rank of the data parallel group, set only in SPMD mode."""
|
||||||
set only in SPMD mode."""
|
|
||||||
data_parallel_master_ip: str = "127.0.0.1"
|
data_parallel_master_ip: str = "127.0.0.1"
|
||||||
"""IP of the data parallel master."""
|
"""IP of the data parallel master."""
|
||||||
data_parallel_rpc_port: int = 29550
|
data_parallel_rpc_port: int = 29550
|
||||||
@@ -184,7 +186,7 @@ class ParallelConfig:
|
|||||||
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
|
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
|
||||||
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
|
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
|
||||||
|
|
||||||
max_parallel_loading_workers: int | None = None
|
max_parallel_loading_workers: int | None = Field(default=None, ge=1)
|
||||||
"""Maximum number of parallel loading workers when loading model
|
"""Maximum number of parallel loading workers when loading model
|
||||||
sequentially in multiple batches. To avoid RAM OOM when using tensor
|
sequentially in multiple batches. To avoid RAM OOM when using tensor
|
||||||
parallel and large models."""
|
parallel and large models."""
|
||||||
@@ -197,15 +199,15 @@ class ParallelConfig:
|
|||||||
|
|
||||||
enable_dbo: bool = False
|
enable_dbo: bool = False
|
||||||
"""Enable dual batch overlap for the model executor."""
|
"""Enable dual batch overlap for the model executor."""
|
||||||
ubatch_size: int = 0
|
ubatch_size: int = Field(default=0, ge=0)
|
||||||
"""Number of ubatch size."""
|
"""Number of ubatch size."""
|
||||||
|
|
||||||
dbo_decode_token_threshold: int = 32
|
dbo_decode_token_threshold: int = Field(default=32, ge=0)
|
||||||
"""The threshold for dual batch overlap for batches only containing decodes.
|
"""The threshold for dual batch overlap for batches only containing decodes.
|
||||||
If the number of tokens in the request is greater than this threshold,
|
If the number of tokens in the request is greater than this threshold,
|
||||||
microbatching will be used. Otherwise, the request will be processed in a
|
microbatching will be used. Otherwise, the request will be processed in a
|
||||||
single batch."""
|
single batch."""
|
||||||
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune
|
dbo_prefill_token_threshold: int = Field(default=512, ge=0) # TODO(lucas): tune
|
||||||
"""The threshold for dual batch overlap for batches that contain one or more
|
"""The threshold for dual batch overlap for batches that contain one or more
|
||||||
prefills. If the number of tokens in the request is greater than this
|
prefills. If the number of tokens in the request is greater than this
|
||||||
threshold, microbatching will be used. Otherwise, the request will be
|
threshold, microbatching will be used. Otherwise, the request will be
|
||||||
@@ -260,10 +262,10 @@ class ParallelConfig:
|
|||||||
master_port: int = 29501
|
master_port: int = 29501
|
||||||
"""distributed master port for multi-node distributed
|
"""distributed master port for multi-node distributed
|
||||||
inference when distributed_executor_backend is mp."""
|
inference when distributed_executor_backend is mp."""
|
||||||
node_rank: int = 0
|
node_rank: int = Field(default=0, ge=0)
|
||||||
"""distributed node rank for multi-node distributed
|
"""distributed node rank for multi-node distributed
|
||||||
inference when distributed_executor_backend is mp."""
|
inference when distributed_executor_backend is mp."""
|
||||||
nnodes: int = 1
|
nnodes: int = Field(default=1, ge=1)
|
||||||
"""num of nodes for multi-node distributed
|
"""num of nodes for multi-node distributed
|
||||||
inference when distributed_executor_backend is mp."""
|
inference when distributed_executor_backend is mp."""
|
||||||
numa_bind: bool = False
|
numa_bind: bool = False
|
||||||
@@ -318,7 +320,7 @@ class ParallelConfig:
|
|||||||
"""Port of the coordination TCPStore. Can be set by the API server; workers
|
"""Port of the coordination TCPStore. Can be set by the API server; workers
|
||||||
connect as clients to exchange self-picked group ports at runtime."""
|
connect as clients to exchange self-picked group ports at runtime."""
|
||||||
|
|
||||||
decode_context_parallel_size: int = 1
|
decode_context_parallel_size: int = Field(default=1, ge=1)
|
||||||
"""Number of decode context parallel groups, because the world size does
|
"""Number of decode context parallel groups, because the world size does
|
||||||
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
||||||
needs to be divisible by dcp_size."""
|
needs to be divisible by dcp_size."""
|
||||||
|
|||||||
+6
-10
@@ -771,10 +771,6 @@ class VllmConfig:
|
|||||||
# If no KVTransferConfig is provided, create a default one.
|
# If no KVTransferConfig is provided, create a default one.
|
||||||
if self.kv_transfer_config is None:
|
if self.kv_transfer_config is None:
|
||||||
self.kv_transfer_config = KVTransferConfig()
|
self.kv_transfer_config = KVTransferConfig()
|
||||||
num_kv_ranks = (
|
|
||||||
self.parallel_config.tensor_parallel_size
|
|
||||||
* self.parallel_config.pipeline_parallel_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if kv_offloading_backend == "native":
|
if kv_offloading_backend == "native":
|
||||||
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
|
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
|
||||||
@@ -786,12 +782,12 @@ class VllmConfig:
|
|||||||
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
|
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
|
||||||
)
|
)
|
||||||
elif kv_offloading_backend == "lmcache":
|
elif kv_offloading_backend == "lmcache":
|
||||||
self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
|
# Default to LMCache multi-process (MP) mode. The actual KV
|
||||||
kv_gb_per_rank = kv_offloading_size / num_kv_ranks
|
# storage capacity is managed by the standalone LMCache server
|
||||||
self.kv_transfer_config.kv_connector_extra_config = {
|
# process, so ``kv_offloading_size`` is not propagated here.
|
||||||
"lmcache.local_cpu": True,
|
# ``LMCacheMPConnector`` falls back to ``tcp://localhost:5555``
|
||||||
"lmcache.max_local_cpu_size": kv_gb_per_rank,
|
# when host/port are not provided via extra_config.
|
||||||
}
|
self.kv_transfer_config.kv_connector = "LMCacheMPConnector"
|
||||||
|
|
||||||
# This is the same for all backends
|
# This is the same for all backends
|
||||||
self.kv_transfer_config.kv_role = "kv_both"
|
self.kv_transfer_config.kv_role = "kv_both"
|
||||||
|
|||||||
@@ -40,11 +40,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def is_weak_contiguous(inp: torch.Tensor):
|
from vllm.distributed.utils import is_weak_contiguous # noqa: E402
|
||||||
return inp.is_contiguous() or (
|
|
||||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
|
||||||
== inp.numel() * inp.element_size()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CustomAllreduce:
|
class CustomAllreduce:
|
||||||
|
|||||||
@@ -24,11 +24,7 @@ except Exception:
|
|||||||
quick_ar = False
|
quick_ar = False
|
||||||
|
|
||||||
|
|
||||||
def is_weak_contiguous(inp: torch.Tensor):
|
from vllm.distributed.utils import is_weak_contiguous # noqa: E402, F401
|
||||||
return inp.is_contiguous() or (
|
|
||||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
|
||||||
== inp.numel() * inp.element_size()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class QuickReduceRegime(Enum):
|
class QuickReduceRegime(Enum):
|
||||||
|
|||||||
@@ -470,10 +470,14 @@ class ElasticEPScalingExecutor:
|
|||||||
module._replace_quant_method(module.quant_method.old_quant_method)
|
module._replace_quant_method(module.quant_method.old_quant_method)
|
||||||
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
||||||
|
|
||||||
|
eplb_model_state.expert_buffer = [
|
||||||
|
torch.empty_like(w) for w in model.expert_weights[0]
|
||||||
|
]
|
||||||
eplb_model_state.communicator = create_eplb_communicator(
|
eplb_model_state.communicator = create_eplb_communicator(
|
||||||
group_coordinator=get_eplb_group(),
|
group_coordinator=get_eplb_group(),
|
||||||
backend=parallel_config.eplb_config.communicator,
|
backend=parallel_config.eplb_config.communicator,
|
||||||
expert_weights=model.expert_weights[0],
|
expert_weights=model.expert_weights,
|
||||||
|
expert_buffer=eplb_model_state.expert_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ def transfer_run_periodically(
|
|||||||
ep_group=eplb_group,
|
ep_group=eplb_group,
|
||||||
is_profile=is_profile,
|
is_profile=is_profile,
|
||||||
cuda_stream=cuda_stream,
|
cuda_stream=cuda_stream,
|
||||||
|
layer_idx=layer_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait until all writes to expert_buffer have finished before making the
|
# Wait until all writes to expert_buffer have finished before making the
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
is_local_first_rank,
|
is_local_first_rank,
|
||||||
)
|
)
|
||||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||||
|
from vllm.distributed.utils import is_weak_contiguous
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@@ -63,8 +64,22 @@ class EplbCommunicator(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
def execute(self) -> None:
|
||||||
pass
|
"""Complete all enqueued transfers.
|
||||||
|
|
||||||
|
Some backends perform communication here; others (e.g. NIXL)
|
||||||
|
issue transfers eagerly in add_recv and only wait here.
|
||||||
|
On return, all data is available in the destination buffers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def set_transfer_context( # noqa: B027
|
||||||
|
self, old_indices: np.ndarray, layer_idx: int
|
||||||
|
) -> None:
|
||||||
|
"""Pre-set layer context before add_recv calls.
|
||||||
|
|
||||||
|
Default is a no-op; overridden by backends (e.g. NIXL) that need
|
||||||
|
layer-level context to issue transfers inside add_recv.
|
||||||
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def needs_profile_buffer_reservation(self) -> bool:
|
def needs_profile_buffer_reservation(self) -> bool:
|
||||||
@@ -125,7 +140,7 @@ class TorchDistNcclEplbCommunicator(EplbCommunicator):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
def execute(self) -> None:
|
||||||
if not self._p2p_ops:
|
if not self._p2p_ops:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
@@ -168,7 +183,7 @@ class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
|
|||||||
for tensor in tensors:
|
for tensor in tensors:
|
||||||
self._ops.append(("recv", tensor, src_rank))
|
self._ops.append(("recv", tensor, src_rank))
|
||||||
|
|
||||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
def execute(self) -> None:
|
||||||
if not self._ops:
|
if not self._ops:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -229,29 +244,47 @@ class NixlEplbCommunicator(EplbCommunicator):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cpu_group: ProcessGroup,
|
cpu_group: ProcessGroup,
|
||||||
expert_weights: Sequence[torch.Tensor],
|
all_expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||||
cuda_stream: torch.cuda.Stream | None = None,
|
expert_buffer: Sequence[torch.Tensor],
|
||||||
) -> None:
|
) -> None:
|
||||||
assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights."
|
assert all_expert_weights, (
|
||||||
|
"NixlEplbCommunicator requires non-empty all_expert_weights."
|
||||||
|
)
|
||||||
|
assert expert_buffer, "NixlEplbCommunicator requires non-empty expert_buffer."
|
||||||
nixl_wrapper_cls = nixl_utils.NixlWrapper
|
nixl_wrapper_cls = nixl_utils.NixlWrapper
|
||||||
if nixl_wrapper_cls is None:
|
if nixl_wrapper_cls is None:
|
||||||
raise RuntimeError("NIXL/ RIXL is unavailable.")
|
raise RuntimeError("NIXL/ RIXL is unavailable.")
|
||||||
|
|
||||||
self._cpu_group = cpu_group
|
self._cpu_group = cpu_group
|
||||||
self._cuda_stream = cuda_stream
|
|
||||||
self._world_size = cpu_group.size()
|
self._world_size = cpu_group.size()
|
||||||
self._rank = cpu_group.rank()
|
self._rank = cpu_group.rank()
|
||||||
# expert_id -> weight tensors to pack into the send buffer.
|
|
||||||
self._expert_send_map: dict[int, list[torch.Tensor]] = {}
|
self._all_expert_weights = all_expert_weights
|
||||||
# src_rank -> expert_id -> weight tensors to unpack after transfer.
|
self._expert_buffer = expert_buffer
|
||||||
self._recv_map: dict[int, dict[int, list[torch.Tensor]]] = {}
|
self._num_local_experts: int = all_expert_weights[0][0].shape[0]
|
||||||
self._num_local_experts: int = expert_weights[0].shape[0]
|
self._device = all_expert_weights[0][0].device
|
||||||
self._device = expert_weights[0].device
|
|
||||||
for tensor in expert_weights:
|
for layer_tensors in all_expert_weights:
|
||||||
assert tensor.device == self._device, (
|
for tensor in layer_tensors:
|
||||||
"All local EPLB tensors are expected to be on the same device: "
|
assert is_weak_contiguous(tensor), (
|
||||||
f"expected={self._device}, got={tensor.device}"
|
"Expert weight tensors must be contiguous in memory"
|
||||||
|
)
|
||||||
|
assert tensor.device == self._device, (
|
||||||
|
"All local EPLB tensors are expected to be on the same "
|
||||||
|
f"device: expected={self._device}, got={tensor.device}"
|
||||||
|
)
|
||||||
|
for tensor in expert_buffer:
|
||||||
|
assert is_weak_contiguous(tensor), (
|
||||||
|
"expert_buffer tensors must be contiguous in memory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# (local_dlist, remote_dlist, xfer_handle) for in-flight READs;
|
||||||
|
# accumulated by add_recv, drained by execute.
|
||||||
|
self._xfer_entries: list[tuple[int, int, int]] = []
|
||||||
|
# Per-rank expert_id -> physical row; set by set_transfer_context.
|
||||||
|
self._expert_to_src_row: list[dict[int, int]] | None = None
|
||||||
|
self._layer_idx: int | None = None
|
||||||
|
|
||||||
nixl_agent_config = nixl_utils.nixl_agent_config
|
nixl_agent_config = nixl_utils.nixl_agent_config
|
||||||
config = (
|
config = (
|
||||||
nixl_agent_config(capture_telemetry=False)
|
nixl_agent_config(capture_telemetry=False)
|
||||||
@@ -260,15 +293,16 @@ class NixlEplbCommunicator(EplbCommunicator):
|
|||||||
)
|
)
|
||||||
self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config)
|
self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config)
|
||||||
self._nixl_memory_type = "VRAM"
|
self._nixl_memory_type = "VRAM"
|
||||||
self._registered_desc: object | None = None
|
# NIXL registration handles; deregistered in __del__.
|
||||||
|
self._registered_descs: list[object] = []
|
||||||
self._remote_agents: dict[int, str] = {}
|
self._remote_agents: dict[int, str] = {}
|
||||||
self._remote_send_meta: dict[int, tuple[int, int]] = {}
|
# peer -> (layer, tensor) -> (base_ptr, bytes_per_expert, dev_id).
|
||||||
self._send_buffer: torch.Tensor = torch.empty(0)
|
self._remote_send_meta: dict[
|
||||||
self._recv_buffer: torch.Tensor = torch.empty(0)
|
int, dict[tuple[int, int], tuple[int, int, int]]
|
||||||
self._expert_bytes: int = 0
|
] = {}
|
||||||
|
|
||||||
self._cuda_device_id = int(self._device.index or 0)
|
self._cuda_device_id = int(self._device.index or 0)
|
||||||
self._init_step("buffers", self._init_registered_buffers, expert_weights)
|
self._init_step("buffers", self._init_registered_buffers)
|
||||||
self._init_step("agents", self._init_remote_agents)
|
self._init_step("agents", self._init_remote_agents)
|
||||||
self._init_step("send meta", self._exchange_remote_send_meta)
|
self._init_step("send meta", self._exchange_remote_send_meta)
|
||||||
self._log_initialized()
|
self._log_initialized()
|
||||||
@@ -291,19 +325,34 @@ class NixlEplbCommunicator(EplbCommunicator):
|
|||||||
uid = uuid.uuid4().hex[:8]
|
uid = uuid.uuid4().hex[:8]
|
||||||
return f"eplb-{self._rank}{pp_suffix}-{uid}"
|
return f"eplb-{self._rank}{pp_suffix}-{uid}"
|
||||||
|
|
||||||
|
def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
def add_send(
|
def add_send(
|
||||||
self,
|
self,
|
||||||
tensors: list[torch.Tensor],
|
tensors: list[torch.Tensor],
|
||||||
dst_rank: int,
|
dst_rank: int,
|
||||||
expert_id: int,
|
expert_id: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert dst_rank != self._rank, (
|
# No-op: NIXL READ is receiver-initiated. The sender's expert
|
||||||
"EPLB communicator should not enqueue same-rank sends: "
|
# weights are pre-registered and always readable in-place.
|
||||||
f"rank={self._rank}, dst_rank={dst_rank}"
|
pass
|
||||||
|
|
||||||
|
def set_transfer_context(self, old_indices: np.ndarray, layer_idx: int) -> None:
|
||||||
|
# Pre-compute expert_id -> src_row mapping for every rank so that
|
||||||
|
# add_recv can immediately issue NIXL READs.
|
||||||
|
assert not self._xfer_entries, (
|
||||||
|
f"set_transfer_context() called with {len(self._xfer_entries)} "
|
||||||
|
f"pending transfers from layer {self._layer_idx}; "
|
||||||
|
f"execute() was not called after previous add_recv() calls"
|
||||||
)
|
)
|
||||||
# An expert sent to multiple peers is packed only once; skip duplicates.
|
self._layer_idx = layer_idx
|
||||||
if expert_id not in self._expert_send_map:
|
n = self._num_local_experts
|
||||||
self._expert_send_map[expert_id] = tensors
|
rank_experts = old_indices[: self._world_size * n].reshape(self._world_size, n)
|
||||||
|
self._expert_to_src_row = [
|
||||||
|
{int(eid): i for i, eid in enumerate(row) if eid != -1}
|
||||||
|
for row in rank_experts
|
||||||
|
]
|
||||||
|
|
||||||
def add_recv(
|
def add_recv(
|
||||||
self,
|
self,
|
||||||
@@ -311,13 +360,44 @@ class NixlEplbCommunicator(EplbCommunicator):
|
|||||||
src_rank: int,
|
src_rank: int,
|
||||||
expert_id: int,
|
expert_id: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert src_rank != self._rank, (
|
# Build NIXL descriptors and issue the RDMA READ immediately,
|
||||||
"EPLB communicator should not enqueue same-rank recvs: "
|
# overlapping the transfer with the remaining Python loop in
|
||||||
f"rank={self._rank}, src_rank={src_rank}"
|
# move_to_buffer.
|
||||||
|
assert self._expert_to_src_row is not None and self._layer_idx is not None, (
|
||||||
|
"set_transfer_context() must be called before add_recv()"
|
||||||
)
|
)
|
||||||
recv_experts = self._recv_map.setdefault(src_rank, {})
|
src_row = self._expert_to_src_row[src_rank][expert_id]
|
||||||
if expert_id not in recv_experts:
|
layer_idx = self._layer_idx
|
||||||
recv_experts[expert_id] = tensors
|
|
||||||
|
local_descs: list[tuple[int, int, int]] = []
|
||||||
|
remote_descs: list[tuple[int, int, int]] = []
|
||||||
|
for t_idx, t in enumerate(tensors):
|
||||||
|
send_base, send_stride, remote_dev = self._remote_send_meta[src_rank][
|
||||||
|
(layer_idx, t_idx)
|
||||||
|
]
|
||||||
|
assert t.nbytes == send_stride, (
|
||||||
|
f"tensor {t_idx} size {t.nbytes} != remote stride {send_stride}"
|
||||||
|
)
|
||||||
|
local_descs.append(
|
||||||
|
(
|
||||||
|
t.data_ptr(),
|
||||||
|
t.nbytes,
|
||||||
|
self._cuda_device_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
remote_descs.append(
|
||||||
|
(
|
||||||
|
send_base + src_row * send_stride,
|
||||||
|
send_stride,
|
||||||
|
remote_dev,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
local_h, remote_h, xfer_h = self._create_peer_xfer(
|
||||||
|
src_rank, local_descs, remote_descs
|
||||||
|
)
|
||||||
|
self._nixl_wrapper.transfer(xfer_h)
|
||||||
|
self._xfer_entries.append((local_h, remote_h, xfer_h))
|
||||||
|
|
||||||
def _init_remote_agents(self) -> None:
|
def _init_remote_agents(self) -> None:
|
||||||
local_metadata = self._nixl_wrapper.get_agent_metadata()
|
local_metadata = self._nixl_wrapper.get_agent_metadata()
|
||||||
@@ -334,73 +414,60 @@ class NixlEplbCommunicator(EplbCommunicator):
|
|||||||
peer_metadata
|
peer_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> None:
|
def _init_registered_buffers(self) -> None:
|
||||||
total_bytes = max(sum(t.nbytes for t in expert_weights), 1)
|
all_tensors: list[torch.Tensor] = []
|
||||||
assert total_bytes % self._num_local_experts == 0, (
|
for layer_tensors in self._all_expert_weights:
|
||||||
f"Number of bytes in moe layer {total_bytes} is not divisible "
|
all_tensors.extend(layer_tensors)
|
||||||
f"by number of local experts {self._num_local_experts}"
|
all_tensors.extend(self._expert_buffer)
|
||||||
)
|
|
||||||
self._expert_bytes = total_bytes // self._num_local_experts
|
|
||||||
|
|
||||||
self._send_buffer = torch.empty(
|
descs = self._nixl_wrapper.get_reg_descs(all_tensors)
|
||||||
total_bytes, device=self._device, dtype=torch.uint8
|
|
||||||
)
|
|
||||||
self._recv_buffer = torch.empty(
|
|
||||||
total_bytes, device=self._device, dtype=torch.uint8
|
|
||||||
)
|
|
||||||
|
|
||||||
descs = self._nixl_wrapper.get_reg_descs([self._send_buffer, self._recv_buffer])
|
|
||||||
self._nixl_wrapper.register_memory(descs)
|
self._nixl_wrapper.register_memory(descs)
|
||||||
self._registered_desc = descs
|
self._registered_descs.append(descs)
|
||||||
|
|
||||||
def _exchange_remote_send_meta(self) -> None:
|
def _exchange_remote_send_meta(self) -> None:
|
||||||
"""Exchange send-buffer metadata so each rank can build dynamic
|
"""Exchange per-layer per-tensor metadata so receivers can compute
|
||||||
descriptors at execute time."""
|
remote RDMA addresses at transfer time."""
|
||||||
local_meta: tuple[int, int] = (
|
local_meta: dict[tuple[int, int], tuple[int, int, int]] = {}
|
||||||
self._send_buffer.data_ptr(),
|
for layer_idx, layer_tensors in enumerate(self._all_expert_weights):
|
||||||
self._cuda_device_id,
|
for t_idx, t in enumerate(layer_tensors):
|
||||||
)
|
nbytes_per_expert = t.nbytes // self._num_local_experts
|
||||||
gathered_meta: list[tuple[int, int] | None] = [None] * self._world_size
|
local_meta[(layer_idx, t_idx)] = (
|
||||||
|
t.data_ptr(),
|
||||||
|
nbytes_per_expert,
|
||||||
|
self._cuda_device_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Per-rank map: (layer_idx, tensor_idx) -> (base_ptr, bytes_per_expert, dev_id).
|
||||||
|
# add_recv uses base_ptr + src_row * bytes_per_expert to compute
|
||||||
|
# the remote RDMA address for each expert.
|
||||||
|
gathered_meta: list[dict[tuple[int, int], tuple[int, int, int]] | None] = [
|
||||||
|
None
|
||||||
|
] * self._world_size
|
||||||
torch.distributed.all_gather_object(
|
torch.distributed.all_gather_object(
|
||||||
gathered_meta, local_meta, group=self._cpu_group
|
gathered_meta, local_meta, group=self._cpu_group
|
||||||
)
|
)
|
||||||
|
|
||||||
|
local_keys = set(local_meta.keys())
|
||||||
for peer in self._remote_agents:
|
for peer in self._remote_agents:
|
||||||
peer_meta = gathered_meta[peer]
|
peer_meta = gathered_meta[peer]
|
||||||
assert peer_meta is not None
|
assert peer_meta is not None
|
||||||
|
peer_keys = set(peer_meta.keys())
|
||||||
|
if peer_keys != local_keys:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"NIXL EPLB metadata key mismatch with rank {peer}: "
|
||||||
|
f"local={sorted(local_keys)}, peer={sorted(peer_keys)}"
|
||||||
|
)
|
||||||
|
for key in local_keys:
|
||||||
|
_, local_stride, _ = local_meta[key]
|
||||||
|
_, peer_stride, _ = peer_meta[key]
|
||||||
|
if local_stride != peer_stride:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"NIXL EPLB nbytes_per_expert mismatch for {key} "
|
||||||
|
f"with rank {peer}: "
|
||||||
|
f"local={local_stride}, peer={peer_stride}"
|
||||||
|
)
|
||||||
self._remote_send_meta[peer] = peer_meta
|
self._remote_send_meta[peer] = peer_meta
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _pack_send_buffer(
|
|
||||||
in_tensors: list[torch.Tensor],
|
|
||||||
send_buffer: torch.Tensor,
|
|
||||||
byte_offset: int,
|
|
||||||
) -> None:
|
|
||||||
for tensor in in_tensors:
|
|
||||||
raw = tensor.reshape(-1).view(torch.uint8)
|
|
||||||
if raw.numel() == 0:
|
|
||||||
continue
|
|
||||||
send_buffer[byte_offset : byte_offset + raw.numel()].copy_(
|
|
||||||
raw, non_blocking=True
|
|
||||||
)
|
|
||||||
byte_offset += raw.numel()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unpack_recv_buffer(
|
|
||||||
recv_buffer: torch.Tensor,
|
|
||||||
out_tensors: list[torch.Tensor],
|
|
||||||
byte_offset: int,
|
|
||||||
) -> None:
|
|
||||||
for tensor in out_tensors:
|
|
||||||
num_bytes = tensor.numel() * tensor.element_size()
|
|
||||||
if num_bytes == 0:
|
|
||||||
continue
|
|
||||||
tensor.reshape(-1).view(torch.uint8).copy_(
|
|
||||||
recv_buffer[byte_offset : byte_offset + num_bytes],
|
|
||||||
non_blocking=True,
|
|
||||||
)
|
|
||||||
byte_offset += num_bytes
|
|
||||||
|
|
||||||
def _wait_for_all_transfers(self, handles: list[int]) -> None:
|
def _wait_for_all_transfers(self, handles: list[int]) -> None:
|
||||||
pending = set(handles)
|
pending = set(handles)
|
||||||
while pending:
|
while pending:
|
||||||
@@ -456,110 +523,52 @@ class NixlEplbCommunicator(EplbCommunicator):
|
|||||||
)
|
)
|
||||||
return (local_handle, remote_handle, xfer_handle)
|
return (local_handle, remote_handle, xfer_handle)
|
||||||
|
|
||||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
def execute(self) -> None:
|
||||||
assert old_indices is not None, (
|
assert self._layer_idx is not None or not self._xfer_entries, (
|
||||||
"NixlEplbCommunicator.execute requires old_indices"
|
"set_transfer_context() must be called before execute() "
|
||||||
|
"if any add_recv() calls were made"
|
||||||
)
|
)
|
||||||
|
|
||||||
xfer_entries: list[tuple[int, int, int]] = []
|
|
||||||
try:
|
try:
|
||||||
n = self._num_local_experts
|
self._wait_for_all_transfers([x[2] for x in self._xfer_entries])
|
||||||
rank_experts = old_indices[: self._world_size * n].reshape(
|
|
||||||
self._world_size, n
|
|
||||||
)
|
|
||||||
# Build expert_id -> send slot mapping per rank.
|
|
||||||
expert_to_send_slot: list[dict[int, int]] = [
|
|
||||||
{int(eid): i for i, eid in enumerate(row) if eid != -1}
|
|
||||||
for row in rank_experts
|
|
||||||
]
|
|
||||||
|
|
||||||
# Phase 1: pack each expert at its slot offset in the send buffer.
|
# Post-READ barrier.
|
||||||
with torch.cuda.stream(self._cuda_stream):
|
# Correctness fence for zero-copy: prevents overwrite-while-
|
||||||
for expert_id, tensors in self._expert_send_map.items():
|
# remote-read race.
|
||||||
slot = expert_to_send_slot[self._rank][expert_id]
|
|
||||||
byte_offset = slot * self._expert_bytes
|
|
||||||
self._pack_send_buffer(tensors, self._send_buffer, byte_offset)
|
|
||||||
|
|
||||||
# Ensure all packed data is visible in device memory before pulls.
|
|
||||||
if self._cuda_stream is not None:
|
|
||||||
self._cuda_stream.synchronize()
|
|
||||||
else:
|
|
||||||
torch.cuda.current_stream().synchronize()
|
|
||||||
# READ is receiver-initiated; synchronize all ranks before transfer.
|
|
||||||
# We use monitored_barrier so a rank that crashes or exits early
|
|
||||||
# produces a diagnostic timeout instead of a silent hang.
|
|
||||||
torch.distributed.monitored_barrier(
|
torch.distributed.monitored_barrier(
|
||||||
group=self._cpu_group,
|
group=self._cpu_group,
|
||||||
timeout=timedelta(minutes=5),
|
timeout=timedelta(minutes=5),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Phase 2: issue one batched READ per peer.
|
|
||||||
recv_offsets: dict[tuple[int, int], int] = {}
|
|
||||||
recv_offset = 0
|
|
||||||
recv_base = self._recv_buffer.data_ptr()
|
|
||||||
for src in range(self._world_size):
|
|
||||||
if src == self._rank:
|
|
||||||
continue
|
|
||||||
recv_experts = self._recv_map.get(src)
|
|
||||||
if not recv_experts:
|
|
||||||
continue
|
|
||||||
expert_ids = list(recv_experts.keys())
|
|
||||||
remote_base, remote_dev = self._remote_send_meta[src]
|
|
||||||
local_descs: list[tuple[int, int, int]] = []
|
|
||||||
remote_descs: list[tuple[int, int, int]] = []
|
|
||||||
for expert_id in expert_ids:
|
|
||||||
slot = expert_to_send_slot[src][expert_id]
|
|
||||||
remote_off = slot * self._expert_bytes
|
|
||||||
recv_offsets[(src, expert_id)] = recv_offset
|
|
||||||
local_descs.append(
|
|
||||||
(
|
|
||||||
recv_base + recv_offset,
|
|
||||||
self._expert_bytes,
|
|
||||||
self._cuda_device_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
remote_descs.append(
|
|
||||||
(remote_base + remote_off, self._expert_bytes, remote_dev)
|
|
||||||
)
|
|
||||||
recv_offset += self._expert_bytes
|
|
||||||
assert recv_offset <= self._recv_buffer.nbytes
|
|
||||||
local_h, remote_h, xfer_h = self._create_peer_xfer(
|
|
||||||
src, local_descs, remote_descs
|
|
||||||
)
|
|
||||||
self._nixl_wrapper.transfer(xfer_h)
|
|
||||||
xfer_entries.append((local_h, remote_h, xfer_h))
|
|
||||||
|
|
||||||
# Phase 3: wait for all in-flight transfers, then unpack.
|
|
||||||
self._wait_for_all_transfers([x[2] for x in xfer_entries])
|
|
||||||
|
|
||||||
with torch.cuda.stream(self._cuda_stream):
|
|
||||||
for (src, expert_id), offset in recv_offsets.items():
|
|
||||||
self._unpack_recv_buffer(
|
|
||||||
self._recv_buffer,
|
|
||||||
self._recv_map[src][expert_id],
|
|
||||||
offset,
|
|
||||||
)
|
|
||||||
finally:
|
finally:
|
||||||
for local_h, remote_h, xfer_h in xfer_entries:
|
for local_h, remote_h, xfer_h in self._xfer_entries:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self._nixl_wrapper.release_xfer_handle(xfer_h)
|
self._nixl_wrapper.release_xfer_handle(xfer_h)
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self._nixl_wrapper.release_dlist_handle(local_h)
|
self._nixl_wrapper.release_dlist_handle(local_h)
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
self._nixl_wrapper.release_dlist_handle(remote_h)
|
self._nixl_wrapper.release_dlist_handle(remote_h)
|
||||||
self._expert_send_map.clear()
|
self._xfer_entries.clear()
|
||||||
self._recv_map.clear()
|
self._expert_to_src_row = None
|
||||||
|
self._layer_idx = None
|
||||||
|
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
try:
|
with contextlib.suppress(Exception):
|
||||||
if self._registered_desc is not None:
|
for local_h, remote_h, xfer_h in self._xfer_entries:
|
||||||
self._nixl_wrapper.deregister_memory(self._registered_desc)
|
with contextlib.suppress(Exception):
|
||||||
self._registered_desc = None
|
self._nixl_wrapper.release_xfer_handle(xfer_h)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self._nixl_wrapper.release_dlist_handle(local_h)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self._nixl_wrapper.release_dlist_handle(remote_h)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
for descs in self._registered_descs:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
self._nixl_wrapper.deregister_memory(descs)
|
||||||
|
self._registered_descs.clear()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
for agent_name in self._remote_agents.values():
|
for agent_name in self._remote_agents.values():
|
||||||
self._nixl_wrapper.remove_remote_agent(agent_name)
|
with contextlib.suppress(Exception):
|
||||||
|
self._nixl_wrapper.remove_remote_agent(agent_name)
|
||||||
self._remote_agents.clear()
|
self._remote_agents.clear()
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Error during NixlEplbCommunicator cleanup: %s", e)
|
|
||||||
|
|
||||||
|
|
||||||
class PyNcclEplbCommunicator(EplbCommunicator):
|
class PyNcclEplbCommunicator(EplbCommunicator):
|
||||||
@@ -600,7 +609,7 @@ class PyNcclEplbCommunicator(EplbCommunicator):
|
|||||||
for tensor in tensors:
|
for tensor in tensors:
|
||||||
self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)
|
self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)
|
||||||
|
|
||||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
def execute(self) -> None:
|
||||||
if self._group_started:
|
if self._group_started:
|
||||||
self._pynccl_comm.group_end()
|
self._pynccl_comm.group_end()
|
||||||
self._group_started = False
|
self._group_started = False
|
||||||
@@ -609,7 +618,8 @@ class PyNcclEplbCommunicator(EplbCommunicator):
|
|||||||
def create_eplb_communicator(
|
def create_eplb_communicator(
|
||||||
group_coordinator: GroupCoordinator,
|
group_coordinator: GroupCoordinator,
|
||||||
backend: str | None,
|
backend: str | None,
|
||||||
expert_weights: Sequence[torch.Tensor],
|
expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||||
|
expert_buffer: Sequence[torch.Tensor],
|
||||||
) -> EplbCommunicator:
|
) -> EplbCommunicator:
|
||||||
"""Create an EPLB communicator for the given backend.
|
"""Create an EPLB communicator for the given backend.
|
||||||
|
|
||||||
@@ -624,16 +634,18 @@ def create_eplb_communicator(
|
|||||||
``"pynccl"`` in that case. When tensors reside on CPU,
|
``"pynccl"`` in that case. When tensors reside on CPU,
|
||||||
``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU
|
``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU
|
||||||
process group.
|
process group.
|
||||||
expert_weights: Expert weight tensors from *one* MoE layer.
|
expert_weights: Expert weight tensors for *all* MoE layers.
|
||||||
NixlEplbCommunicator pre-allocates send/recv buffers sized
|
Shape ``(num_layers)(num_tensors_per_layer)``.
|
||||||
to this layer, so all other MoE layers must have the same
|
NixlEplbCommunicator registers all layers with NIXL for
|
||||||
tensor count, shapes, and dtypes.
|
zero-copy RDMA reads.
|
||||||
|
expert_buffer: Pre-allocated receive buffer tensors (one per
|
||||||
|
weight tensor in a single layer).
|
||||||
"""
|
"""
|
||||||
# Keep a safe default for callers that have not resolved communicator yet.
|
|
||||||
if backend is None:
|
if backend is None:
|
||||||
backend = "torch_nccl"
|
backend = "torch_nccl"
|
||||||
|
|
||||||
tensor_device_type = expert_weights[0].device.type if expert_weights else "cpu"
|
first_layer = expert_weights[0] if expert_weights else []
|
||||||
|
tensor_device_type = first_layer[0].device.type if first_layer else "cpu"
|
||||||
torch_group = (
|
torch_group = (
|
||||||
group_coordinator.cpu_group
|
group_coordinator.cpu_group
|
||||||
if tensor_device_type == "cpu"
|
if tensor_device_type == "cpu"
|
||||||
@@ -649,7 +661,7 @@ def create_eplb_communicator(
|
|||||||
unsupported_dtypes = sorted(
|
unsupported_dtypes = sorted(
|
||||||
{
|
{
|
||||||
tensor.dtype
|
tensor.dtype
|
||||||
for tensor in expert_weights
|
for tensor in first_layer
|
||||||
if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
|
if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
|
||||||
},
|
},
|
||||||
key=str,
|
key=str,
|
||||||
@@ -704,7 +716,8 @@ def create_eplb_communicator(
|
|||||||
try:
|
try:
|
||||||
return NixlEplbCommunicator(
|
return NixlEplbCommunicator(
|
||||||
cpu_group=group_coordinator.cpu_group,
|
cpu_group=group_coordinator.cpu_group,
|
||||||
expert_weights=expert_weights,
|
all_expert_weights=expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@@ -450,7 +450,8 @@ class EplbState:
|
|||||||
communicator = create_eplb_communicator(
|
communicator = create_eplb_communicator(
|
||||||
group_coordinator=get_eplb_group(),
|
group_coordinator=get_eplb_group(),
|
||||||
backend=self.parallel_config.eplb_config.communicator,
|
backend=self.parallel_config.eplb_config.communicator,
|
||||||
expert_weights=model.expert_weights[0],
|
expert_weights=model.expert_weights,
|
||||||
|
expert_buffer=expert_buffer,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_state = EplbModelState(
|
model_state = EplbModelState(
|
||||||
@@ -766,6 +767,7 @@ class EplbState:
|
|||||||
eplb_model_state.physical_to_logical_map,
|
eplb_model_state.physical_to_logical_map,
|
||||||
new_physical_to_logical_map,
|
new_physical_to_logical_map,
|
||||||
eplb_model_state.model.expert_weights,
|
eplb_model_state.model.expert_weights,
|
||||||
|
eplb_model_state.expert_buffer,
|
||||||
ep_group,
|
ep_group,
|
||||||
eplb_model_state.communicator,
|
eplb_model_state.communicator,
|
||||||
is_profile,
|
is_profile,
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ def move_to_buffer(
|
|||||||
cuda_stream: torch.cuda.Stream | None,
|
cuda_stream: torch.cuda.Stream | None,
|
||||||
ep_rank: int,
|
ep_rank: int,
|
||||||
communicator: EplbCommunicator,
|
communicator: EplbCommunicator,
|
||||||
|
layer_idx: int = 0,
|
||||||
) -> TransferMetadata:
|
) -> TransferMetadata:
|
||||||
"""
|
"""
|
||||||
Rearranges expert weights during EPLB rebalancing.
|
Rearranges expert weights during EPLB rebalancing.
|
||||||
@@ -193,6 +194,7 @@ def move_to_buffer(
|
|||||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||||
ep_rank: Rank of this process in expert parallel group.
|
ep_rank: Rank of this process in expert parallel group.
|
||||||
communicator: EplbCommunicator instance for P2P communication.
|
communicator: EplbCommunicator instance for P2P communication.
|
||||||
|
layer_idx: Index of the MoE layer being transferred.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TransferMetadata: Metadata needed for completing remote weight transfers.
|
TransferMetadata: Metadata needed for completing remote weight transfers.
|
||||||
@@ -265,6 +267,8 @@ def move_to_buffer(
|
|||||||
for w, b in zip(expert_weights, expert_weights_buffers):
|
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||||
b[dst].copy_(w[src_local], non_blocking=True)
|
b[dst].copy_(w[src_local], non_blocking=True)
|
||||||
|
|
||||||
|
communicator.set_transfer_context(old_indices, layer_idx)
|
||||||
|
|
||||||
# 2. Post sends
|
# 2. Post sends
|
||||||
if send_count > 0:
|
if send_count > 0:
|
||||||
experts = send_expert_ids[:send_count]
|
experts = send_expert_ids[:send_count]
|
||||||
@@ -331,9 +335,8 @@ def move_to_buffer(
|
|||||||
expert_id=int(expert),
|
expert_id=int(expert),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Execute the P2P operations. The real communication happens here.
|
# 4. Execute transfers and wait for completion.
|
||||||
communicator.execute(old_indices=old_indices)
|
communicator.execute()
|
||||||
# wait for the communication to finish
|
|
||||||
return TransferMetadata(
|
return TransferMetadata(
|
||||||
is_unchanged=is_unchanged,
|
is_unchanged=is_unchanged,
|
||||||
is_received_locally=is_received_locally,
|
is_received_locally=is_received_locally,
|
||||||
@@ -431,6 +434,7 @@ def transfer_layer(
|
|||||||
is_profile: bool = False,
|
is_profile: bool = False,
|
||||||
cuda_stream: torch.cuda.Stream | None = None,
|
cuda_stream: torch.cuda.Stream | None = None,
|
||||||
rank_mapping: dict[int, int] | None = None,
|
rank_mapping: dict[int, int] | None = None,
|
||||||
|
layer_idx: int = 0,
|
||||||
) -> TransferMetadata:
|
) -> TransferMetadata:
|
||||||
"""
|
"""
|
||||||
Rearranges the expert weights in place according to the new expert indices.
|
Rearranges the expert weights in place according to the new expert indices.
|
||||||
@@ -452,6 +456,7 @@ def transfer_layer(
|
|||||||
communications to reserve enough memory for the buffers.
|
communications to reserve enough memory for the buffers.
|
||||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||||
rank_mapping: Optional rank mapping for elastic expert parallelism.
|
rank_mapping: Optional rank mapping for elastic expert parallelism.
|
||||||
|
layer_idx: Index of the MoE layer being transferred.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TransferMetadata: Metadata needed for completing remote weight transfers,
|
TransferMetadata: Metadata needed for completing remote weight transfers,
|
||||||
@@ -499,6 +504,7 @@ def transfer_layer(
|
|||||||
cuda_stream=cuda_stream,
|
cuda_stream=cuda_stream,
|
||||||
ep_rank=ep_group.rank(),
|
ep_rank=ep_group.rank(),
|
||||||
communicator=communicator,
|
communicator=communicator,
|
||||||
|
layer_idx=layer_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -506,6 +512,7 @@ def rearrange_expert_weights_inplace(
|
|||||||
old_global_expert_indices: torch.Tensor,
|
old_global_expert_indices: torch.Tensor,
|
||||||
new_global_expert_indices: torch.Tensor,
|
new_global_expert_indices: torch.Tensor,
|
||||||
expert_weights: Sequence[Sequence[torch.Tensor]],
|
expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||||
|
expert_buffer: Sequence[torch.Tensor],
|
||||||
ep_group: ProcessGroup,
|
ep_group: ProcessGroup,
|
||||||
communicator: EplbCommunicator,
|
communicator: EplbCommunicator,
|
||||||
is_profile: bool = False,
|
is_profile: bool = False,
|
||||||
@@ -524,6 +531,8 @@ def rearrange_expert_weights_inplace(
|
|||||||
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
||||||
For example, a linear layer may have up and down projection,
|
For example, a linear layer may have up and down projection,
|
||||||
so weight_count = 2. Each weight's hidden size can be different.
|
so weight_count = 2. Each weight's hidden size can be different.
|
||||||
|
expert_buffer: Pre-allocated receive buffer tensors (one per
|
||||||
|
weight tensor in a single layer).
|
||||||
ep_group: The device process group for expert parallelism.
|
ep_group: The device process group for expert parallelism.
|
||||||
communicator: EplbCommunicator instance for P2P communication.
|
communicator: EplbCommunicator instance for P2P communication.
|
||||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||||
@@ -566,10 +575,10 @@ def rearrange_expert_weights_inplace(
|
|||||||
# Reserve NCCL communication buffers via a dummy all_gather.
|
# Reserve NCCL communication buffers via a dummy all_gather.
|
||||||
# Backends that pre-allocate their own transfer buffers
|
# Backends that pre-allocate their own transfer buffers
|
||||||
# skip this to avoid the extra memory spike during profiling.
|
# skip this to avoid the extra memory spike during profiling.
|
||||||
weights_buffer: list[torch.Tensor] = [
|
profile_buffer: list[torch.Tensor] = [
|
||||||
torch.empty_like(w) for w in first_layer_weights
|
torch.empty_like(w) for w in first_layer_weights
|
||||||
]
|
]
|
||||||
for weight, buffer in zip(expert_weights[0], weights_buffer):
|
for weight, buffer in zip(expert_weights[0], profile_buffer):
|
||||||
dummy_recv_buffer = [buffer for _ in range(ep_size)]
|
dummy_recv_buffer = [buffer for _ in range(ep_size)]
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
all_gather(
|
all_gather(
|
||||||
@@ -579,10 +588,7 @@ def rearrange_expert_weights_inplace(
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Buffers to hold the expert weights during the exchange.
|
weights_buffer = list(expert_buffer)
|
||||||
# NOTE: Currently we assume the same weights across different layers
|
|
||||||
# have the same shape.
|
|
||||||
weights_buffer = [torch.empty_like(w) for w in first_layer_weights]
|
|
||||||
|
|
||||||
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
|
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
|
||||||
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
|
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
|
||||||
@@ -597,6 +603,7 @@ def rearrange_expert_weights_inplace(
|
|||||||
cuda_stream=None,
|
cuda_stream=None,
|
||||||
ep_rank=ep_rank,
|
ep_rank=ep_rank,
|
||||||
communicator=communicator,
|
communicator=communicator,
|
||||||
|
layer_idx=layer_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
move_from_buffer(
|
move_from_buffer(
|
||||||
|
|||||||
@@ -120,15 +120,12 @@ class KVOutputAggregator:
|
|||||||
# Use the first worker's kv_connector_stats as accumulator.
|
# Use the first worker's kv_connector_stats as accumulator.
|
||||||
aggregated_kv_connector_stats = kv_output.kv_connector_stats
|
aggregated_kv_connector_stats = kv_output.kv_connector_stats
|
||||||
elif kv_connector_stats := kv_output.kv_connector_stats:
|
elif kv_connector_stats := kv_output.kv_connector_stats:
|
||||||
if aggregated_kv_connector_stats is None:
|
assert isinstance(
|
||||||
aggregated_kv_connector_stats = kv_connector_stats
|
aggregated_kv_connector_stats, type(kv_connector_stats)
|
||||||
else:
|
)
|
||||||
assert isinstance(
|
aggregated_kv_connector_stats = aggregated_kv_connector_stats.aggregate(
|
||||||
aggregated_kv_connector_stats, type(kv_connector_stats)
|
kv_connector_stats
|
||||||
)
|
)
|
||||||
aggregated_kv_connector_stats = (
|
|
||||||
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Aggregate kv_connector_worker_meta from all workers.
|
# Aggregate kv_connector_worker_meta from all workers.
|
||||||
if aggregated_kv_connector_worker_meta is None:
|
if aggregated_kv_connector_worker_meta is None:
|
||||||
|
|||||||
@@ -2333,9 +2333,25 @@ class NixlConnectorWorker:
|
|||||||
for i, remote_group in enumerate(remote_block_ids):
|
for i, remote_group in enumerate(remote_block_ids):
|
||||||
num_local_blocks = len(local_block_ids[i])
|
num_local_blocks = len(local_block_ids[i])
|
||||||
num_remote_blocks = len(remote_group)
|
num_remote_blocks = len(remote_group)
|
||||||
if _is_ssm_spec(self._group_spec_types[i]):
|
if (
|
||||||
assert num_local_blocks == num_remote_blocks
|
_is_ssm_spec(self._group_spec_types[i])
|
||||||
|
and num_local_blocks < num_remote_blocks
|
||||||
|
):
|
||||||
|
# NOTE (NickLucche): With prefix caching on SSM, (remote) blocks
|
||||||
|
# prior to the last one are placeholders (null blocks). Mind that
|
||||||
|
# this doesn't really impact transfer, as we only still care about
|
||||||
|
# the last "block", the full in-place state.
|
||||||
|
assert num_local_blocks == 1, "SSM can only have one local block"
|
||||||
|
remote_block_ids[i] = remote_group[-num_local_blocks:]
|
||||||
|
elif (
|
||||||
|
self._physical_blocks_per_logical_kv_block
|
||||||
|
== remote_physical_per_logical
|
||||||
|
and num_local_blocks < num_remote_blocks
|
||||||
|
):
|
||||||
|
# Partial prefix cache hit for FA group.
|
||||||
|
remote_block_ids[i] = remote_group[-num_local_blocks:]
|
||||||
else:
|
else:
|
||||||
|
# TODO Handle prefix caching with different block_sizes
|
||||||
max_padding = max(
|
max_padding = max(
|
||||||
self._physical_blocks_per_logical_kv_block,
|
self._physical_blocks_per_logical_kv_block,
|
||||||
remote_physical_per_logical,
|
remote_physical_per_logical,
|
||||||
|
|||||||
@@ -1270,9 +1270,6 @@ def get_dcp_group() -> GroupCoordinator:
|
|||||||
return _DCP
|
return _DCP
|
||||||
|
|
||||||
|
|
||||||
# kept for backward compatibility
|
|
||||||
get_context_model_parallel_group = get_dcp_group
|
|
||||||
|
|
||||||
_PP: GroupCoordinator | None = None
|
_PP: GroupCoordinator | None = None
|
||||||
|
|
||||||
|
|
||||||
@@ -1840,31 +1837,6 @@ def model_parallel_is_initialized():
|
|||||||
_TP_STATE_PATCHED = False
|
_TP_STATE_PATCHED = False
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
|
|
||||||
"""Patch the tp group temporarily until this function ends.
|
|
||||||
|
|
||||||
This method is for draft workers of speculative decoding to run draft model
|
|
||||||
with different tp degree from that of target model workers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tp_group (GroupCoordinator): the tp group coordinator
|
|
||||||
"""
|
|
||||||
global _TP_STATE_PATCHED
|
|
||||||
assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
|
|
||||||
|
|
||||||
_TP_STATE_PATCHED = True
|
|
||||||
old_tp_group = get_tp_group()
|
|
||||||
global _TP
|
|
||||||
_TP = tp_group
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
# restore the original state
|
|
||||||
_TP_STATE_PATCHED = False
|
|
||||||
_TP = old_tp_group
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_world_size() -> int:
|
def get_tensor_model_parallel_world_size() -> int:
|
||||||
"""Return world size for the tensor model parallel group."""
|
"""Return world size for the tensor model parallel group."""
|
||||||
return get_tp_group().world_size
|
return get_tp_group().world_size
|
||||||
@@ -1875,16 +1847,6 @@ def get_tensor_model_parallel_rank() -> int:
|
|||||||
return get_tp_group().rank_in_group
|
return get_tp_group().rank_in_group
|
||||||
|
|
||||||
|
|
||||||
def get_decode_context_model_parallel_world_size() -> int:
|
|
||||||
"""Return world size for the decode context model parallel group."""
|
|
||||||
return get_dcp_group().world_size
|
|
||||||
|
|
||||||
|
|
||||||
def get_decode_context_model_parallel_rank() -> int:
|
|
||||||
"""Return my rank for the decode context model parallel group."""
|
|
||||||
return get_dcp_group().rank_in_group
|
|
||||||
|
|
||||||
|
|
||||||
def get_node_count() -> int:
|
def get_node_count() -> int:
|
||||||
"""Return the total number of nodes in the distributed environment."""
|
"""Return the total number of nodes in the distributed environment."""
|
||||||
assert _NODE_COUNT is not None, "distributed environment is not initialized"
|
assert _NODE_COUNT is not None, "distributed environment is not initialized"
|
||||||
|
|||||||
@@ -64,6 +64,20 @@ def divide(numerator, denominator):
|
|||||||
return numerator // denominator
|
return numerator // denominator
|
||||||
|
|
||||||
|
|
||||||
|
def is_weak_contiguous(inp: torch.Tensor) -> bool:
|
||||||
|
"""Check that *inp* occupies a single contiguous block of memory.
|
||||||
|
|
||||||
|
Unlike ``torch.Tensor.is_contiguous()``, this also accepts tensors
|
||||||
|
whose strides are not strictly C-contiguous (e.g. column-major) as
|
||||||
|
long as the underlying storage from the tensor's offset onward is
|
||||||
|
exactly ``numel * element_size`` bytes.
|
||||||
|
"""
|
||||||
|
return inp.is_contiguous() or (
|
||||||
|
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||||
|
== inp.numel() * inp.element_size()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def split_tensor_along_last_dim(
|
def split_tensor_along_last_dim(
|
||||||
tensor: torch.Tensor,
|
tensor: torch.Tensor,
|
||||||
num_partitions: int,
|
num_partitions: int,
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ from vllm.entrypoints.anthropic.protocol import (
|
|||||||
)
|
)
|
||||||
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
|
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
|
||||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||||
from vllm.entrypoints.openai.utils import validate_json_request
|
from vllm.entrypoints.serve.utils.api_utils import (
|
||||||
from vllm.entrypoints.utils import (
|
|
||||||
load_aware_call,
|
load_aware_call,
|
||||||
|
validate_json_request,
|
||||||
with_cancellation,
|
with_cancellation,
|
||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from vllm.entrypoints.anthropic.protocol import (
|
|||||||
AnthropicUsage,
|
AnthropicUsage,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
|
||||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||||
ChatCompletionNamedToolChoiceParam,
|
ChatCompletionNamedToolChoiceParam,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
@@ -45,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
|||||||
StreamOptions,
|
StreamOptions,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||||
|
from vllm.entrypoints.serve.utils.request_logger import RequestLogger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import vllm.envs as envs
|
|||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.entrypoints.launcher import serve_http
|
from vllm.entrypoints.launcher import serve_http
|
||||||
from vllm.entrypoints.utils import with_cancellation
|
from vllm.entrypoints.serve.utils.api_utils import with_cancellation
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import typing
|
|||||||
|
|
||||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||||
from vllm.entrypoints.cli.types import CLISubcommand
|
from vllm.entrypoints.cli.types import CLISubcommand
|
||||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
from vllm.entrypoints.serve.utils.api_utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user