mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Merge branch 'main' into wentao-enable-all-dense-for-mrv2
This commit is contained in:
@@ -1299,12 +1299,11 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/llm
|
||||
- tests/entrypoints/offline_mode
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/llm/test_generate.py
|
||||
- pytest -v -s entrypoints/offline_mode
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py --ignore=entrypoints/llm/offline_mode
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Integration (Pooling) # TBD
|
||||
timeout_in_minutes: 180
|
||||
@@ -1346,7 +1345,7 @@ steps:
|
||||
- vllm/platforms/rocm.py
|
||||
commands:
|
||||
- pytest -v -s entrypoints/openai/tool_parsers
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
|
||||
|
||||
- label: OpenAI API correctness # TBD
|
||||
timeout_in_minutes: 180
|
||||
|
||||
@@ -11,7 +11,7 @@ steps:
|
||||
- tests/entrypoints/
|
||||
commands:
|
||||
- pytest -v -s entrypoints/openai/tool_parsers
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/offline_mode --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/serve --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling --ignore=entrypoints/speech_to_text --ignore=tests/entrypoints/generate
|
||||
|
||||
- label: Entrypoints Integration (LLM)
|
||||
key: entrypoints-integration-llm
|
||||
@@ -20,12 +20,11 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/llm
|
||||
- tests/entrypoints/offline_mode
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py --ignore=entrypoints/llm/offline_mode
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
- pytest -v -s entrypoints/llm/offline_mode # Needs to avoid interference with other tests
|
||||
mirror:
|
||||
amd:
|
||||
device: mi325_1
|
||||
|
||||
@@ -33,10 +33,3 @@ share/python-wheels/
|
||||
*.egg
|
||||
MANIFEST
|
||||
rust/target/
|
||||
# Not needed in Docker builds
|
||||
docs/
|
||||
.github/
|
||||
.pre-commit-config.yaml
|
||||
.clang-format
|
||||
.gitattributes
|
||||
format.sh
|
||||
|
||||
+2
-1
@@ -34,10 +34,11 @@
|
||||
/vllm/entrypoints/speech_to_text/realtime @njhill
|
||||
/vllm/entrypoints/speech_to_text @NickLucche
|
||||
/vllm/entrypoints/pooling @noooop
|
||||
/vllm/entrypoints/sagemaker @DarkLight1337
|
||||
/vllm/entrypoints/serve/sagemaker @DarkLight1337
|
||||
/vllm/entrypoints/serve @njhill
|
||||
/vllm/entrypoints/*.py @njhill
|
||||
/vllm/entrypoints/chat_utils.py @DarkLight1337
|
||||
/vllm/entrypoints/offline_utils.py @DarkLight1337
|
||||
/vllm/entrypoints/llm.py @DarkLight1337
|
||||
|
||||
# Rust Frontend
|
||||
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
actions: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
|
||||
- uses: actions/stale@eb5cf3af3ac0a1aa4c9c45633dd1ae542a27a899 # v10.3.0
|
||||
with:
|
||||
# Increasing this value ensures that changes to this workflow
|
||||
# propagate to all issues and PRs in days rather than months
|
||||
|
||||
@@ -102,6 +102,35 @@ constexpr float NUM_TOKEN_CUTOFF = 1024;
|
||||
constexpr int kNumLanes = 32;
|
||||
constexpr int kElemsPerLane = kHeadDim / kNumLanes; // 16
|
||||
|
||||
// Pack this lane's 16 fp32 elements into per-tensor E4M3 FP8 (one uint4 = 16
|
||||
// B), scaling by `scale` (a reciprocal scale) and saturating to ±448. Used by
|
||||
// the FlashInfer full-cache path for both the Q and KV stores.
|
||||
__device__ __forceinline__ uint4 packFp8E4M3x16(float const* values,
|
||||
float const scale) {
|
||||
#ifndef USE_ROCM
|
||||
uint4 out;
|
||||
auto* out2 = reinterpret_cast<__nv_fp8x2_storage_t*>(&out);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kElemsPerLane / 2; i++) {
|
||||
float2 scaled =
|
||||
make_float2(values[2 * i] * scale, values[2 * i + 1] * scale);
|
||||
scaled.x = fminf(fmaxf(scaled.x, -kFp8Max), kFp8Max);
|
||||
scaled.y = fminf(fmaxf(scaled.y, -kFp8Max), kFp8Max);
|
||||
out2[i] = __nv_cvt_float2_to_fp8x2(scaled, __NV_SATFINITE, __NV_E4M3);
|
||||
}
|
||||
return out;
|
||||
#else
|
||||
uint8_t out_bytes[kElemsPerLane];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kElemsPerLane; i++) {
|
||||
float scaled = values[i] * scale;
|
||||
scaled = fminf(fmaxf(scaled, -kFp8Max), kFp8Max);
|
||||
out_bytes[i] = rocm_cvt_float_to_fp8_e4m3(scaled);
|
||||
}
|
||||
return *reinterpret_cast<uint4 const*>(out_bytes);
|
||||
#endif
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// Small inline helpers
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
@@ -649,6 +678,257 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert(
|
||||
#undef DISPATCH
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// FlashInfer full-cache kernel
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
//
|
||||
// Sibling to the FlashMLA kernel above, used by the FlashInfer V4 sparse-MLA
|
||||
// backend. Differences from the legacy path:
|
||||
// * No Q head padding — output Q layout matches the input num_heads_q.
|
||||
// * KV is written as a *contiguous* 512-wide row per token (token-strided),
|
||||
// not the legacy UE8M0 paged layout with a separate scale tail.
|
||||
// * Q/KV are stored either as bf16 or as per-tensor E4M3 FP8 (one global
|
||||
// scale), selected by the STORE_Q_FP8 / STORE_KV_FP8 template flags.
|
||||
//
|
||||
// Grid: 1D, gridDim.x = ceil(num_tokens_full * (num_heads_q + 1) / warps_per_block).
|
||||
// Each warp handles one (token, slot): slot < num_heads_q → Q, slot ==
|
||||
// num_heads_q → KV.
|
||||
template <typename scalar_t_in, bool STORE_Q_FP8, bool STORE_KV_FP8>
|
||||
__global__ void fusedDeepseekV4FullCacheKernel(
|
||||
scalar_t_in* __restrict__ q_inout, // [N, H, 512], in place (bf16)
|
||||
uint8_t* __restrict__ q_fp8_out, // [N, H, 512] fp8, optional
|
||||
int64_t const q_fp8_stride0, // elements (fp8 == bytes)
|
||||
int64_t const q_fp8_stride1, // elements (fp8 == bytes)
|
||||
scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16
|
||||
uint8_t* __restrict__ k_cache, // contiguous bf16 or fp8 cache
|
||||
int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64
|
||||
int64_t const* __restrict__ position_ids, // [N] i64
|
||||
float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32
|
||||
float const* __restrict__ fp8_scale_ptr, // scalar, KV fp8 only
|
||||
float const* __restrict__ q_fp8_scale_inv, // scalar, Q fp8 only
|
||||
float const eps,
|
||||
int const num_tokens_full, // = q.size(0) = kv.size(0)
|
||||
int const num_tokens_insert, // = slot_mapping.size(0)
|
||||
int const num_heads_q, // H (no padding)
|
||||
int const cache_block_size, // tokens per cache block
|
||||
int64_t const kv_block_stride, // bytes per cache block
|
||||
int64_t const kv_token_stride) { // bytes per cache token
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||
if constexpr (std::is_same_v<scalar_t_in, c10::BFloat16>) {
|
||||
return;
|
||||
} else {
|
||||
#endif
|
||||
using Converter = vllm::_typeConvert<scalar_t_in>;
|
||||
int const warpsPerBlock = blockDim.x / 32;
|
||||
int const warpId = threadIdx.x / 32;
|
||||
int const laneId = threadIdx.x % 32;
|
||||
int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId;
|
||||
|
||||
int const slotsPerToken = num_heads_q + 1;
|
||||
int const tokenIdx = globalWarpIdx / slotsPerToken;
|
||||
int const slotIdx = globalWarpIdx % slotsPerToken;
|
||||
if (tokenIdx >= num_tokens_full) return;
|
||||
bool const isKV = (slotIdx == num_heads_q);
|
||||
// KV branch: skip DP-padded tokens (no slot reserved for them).
|
||||
if (isKV && tokenIdx >= num_tokens_insert) return;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
int const dim_base = laneId * kElemsPerLane; // in [0, 512) step 16
|
||||
scalar_t_in const* src_ptr;
|
||||
if (isKV) {
|
||||
src_ptr = kv_in + static_cast<int64_t>(tokenIdx) * kHeadDim + dim_base;
|
||||
} else {
|
||||
src_ptr = q_inout +
|
||||
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) *
|
||||
kHeadDim +
|
||||
dim_base;
|
||||
}
|
||||
uint4 const v0 = *reinterpret_cast<uint4 const*>(src_ptr);
|
||||
uint4 const v1 = *reinterpret_cast<uint4 const*>(src_ptr + 8);
|
||||
|
||||
// ── Decode bf16 → 16 fp32 registers ───────────────────────────────────
|
||||
float elements[kElemsPerLane];
|
||||
{
|
||||
auto const* p0 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type const*>(&v0);
|
||||
auto const* p1 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type const*>(&v1);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
float2 f2 = Converter::convert(p0[i]);
|
||||
elements[2 * i] = f2.x;
|
||||
elements[2 * i + 1] = f2.y;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
float2 f2 = Converter::convert(p1[i]);
|
||||
elements[8 + 2 * i] = f2.x;
|
||||
elements[8 + 2 * i + 1] = f2.y;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Q branch: RMSNorm (no weight) ─────────────────────────────────────
|
||||
if (!isKV) {
|
||||
float sumOfSquares = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kElemsPerLane; i++) {
|
||||
sumOfSquares += elements[i] * elements[i];
|
||||
}
|
||||
sumOfSquares = warpSum<float>(sumOfSquares);
|
||||
float const rms_rcp =
|
||||
rsqrtf(sumOfSquares / static_cast<float>(kHeadDim) + eps);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kElemsPerLane; i++) {
|
||||
elements[i] = elements[i] * rms_rcp;
|
||||
}
|
||||
}
|
||||
|
||||
// ── GPT-J RoPE on dims [NOPE_DIM, HEAD_DIM) ───────────────────────────
|
||||
bool const is_rope_lane = dim_base >= kNopeDim;
|
||||
if (is_rope_lane) {
|
||||
int64_t const pos = position_ids[tokenIdx];
|
||||
constexpr int kHalfRope = kRopeDim / 2;
|
||||
float const* cos_ptr = cos_sin_cache + pos * kRopeDim;
|
||||
float const* sin_ptr = cos_ptr + kHalfRope;
|
||||
int const rope_local_base = dim_base - kNopeDim;
|
||||
int const half_base = rope_local_base >> 1;
|
||||
float4 const c0 = *reinterpret_cast<float4 const*>(cos_ptr + half_base);
|
||||
float4 const c1 = *reinterpret_cast<float4 const*>(cos_ptr + half_base + 4);
|
||||
float4 const s0 = *reinterpret_cast<float4 const*>(sin_ptr + half_base);
|
||||
float4 const s1 = *reinterpret_cast<float4 const*>(sin_ptr + half_base + 4);
|
||||
float const cos_arr[8] = {c0.x, c0.y, c0.z, c0.w, c1.x, c1.y, c1.z, c1.w};
|
||||
float const sin_arr[8] = {s0.x, s0.y, s0.z, s0.w, s1.x, s1.y, s1.z, s1.w};
|
||||
#pragma unroll
|
||||
for (int p = 0; p < kElemsPerLane / 2; p++) {
|
||||
float const x_even = elements[2 * p];
|
||||
float const x_odd = elements[2 * p + 1];
|
||||
elements[2 * p] = x_even * cos_arr[p] - x_odd * sin_arr[p];
|
||||
elements[2 * p + 1] = x_even * sin_arr[p] + x_odd * cos_arr[p];
|
||||
}
|
||||
}
|
||||
|
||||
// ── Store ─────────────────────────────────────────────────────────────
|
||||
if (!isKV) {
|
||||
if constexpr (STORE_Q_FP8) {
|
||||
float const scale_inv = VLLM_LDG(q_fp8_scale_inv);
|
||||
uint4 const out = packFp8E4M3x16(elements, scale_inv);
|
||||
uint8_t* dst = q_fp8_out +
|
||||
static_cast<int64_t>(tokenIdx) * q_fp8_stride0 +
|
||||
static_cast<int64_t>(slotIdx) * q_fp8_stride1 + dim_base;
|
||||
*reinterpret_cast<uint4*>(dst) = out;
|
||||
} else {
|
||||
uint4 out0, out1;
|
||||
auto* po0 = reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
|
||||
auto* po1 = reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
po0[i] = Converter::convert(
|
||||
make_float2(elements[2 * i], elements[2 * i + 1]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
po1[i] = Converter::convert(
|
||||
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
|
||||
}
|
||||
scalar_t_in* dst =
|
||||
q_inout +
|
||||
(static_cast<int64_t>(tokenIdx) * num_heads_q + slotIdx) * kHeadDim +
|
||||
dim_base;
|
||||
*reinterpret_cast<uint4*>(dst) = out0;
|
||||
*reinterpret_cast<uint4*>(dst + 8) = out1;
|
||||
}
|
||||
} else {
|
||||
int64_t const slot_id = slot_mapping[tokenIdx];
|
||||
if (slot_id >= 0) {
|
||||
int64_t const block_idx = slot_id / cache_block_size;
|
||||
int64_t const pos_in_block = slot_id % cache_block_size;
|
||||
uint8_t* cache_row =
|
||||
k_cache + block_idx * kv_block_stride + pos_in_block * kv_token_stride;
|
||||
if constexpr (STORE_KV_FP8) {
|
||||
float const inv_scale = 1.0f / VLLM_LDG(fp8_scale_ptr);
|
||||
uint4 const out = packFp8E4M3x16(elements, inv_scale);
|
||||
*reinterpret_cast<uint4*>(cache_row + dim_base) = out;
|
||||
} else {
|
||||
uint4 out0, out1;
|
||||
auto* po0 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type*>(&out0);
|
||||
auto* po1 =
|
||||
reinterpret_cast<typename Converter::packed_hip_type*>(&out1);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
po0[i] = Converter::convert(
|
||||
make_float2(elements[2 * i], elements[2 * i + 1]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
po1[i] = Converter::convert(
|
||||
make_float2(elements[8 + 2 * i], elements[8 + 2 * i + 1]));
|
||||
}
|
||||
scalar_t_in* dst = reinterpret_cast<scalar_t_in*>(cache_row) + dim_base;
|
||||
*reinterpret_cast<uint4*>(dst) = out0;
|
||||
*reinterpret_cast<uint4*>(dst + 8) = out1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Configure + launch helper shared by the bf16 and fp8 full-cache launchers.
|
||||
template <typename scalar_t_in, bool STORE_Q_FP8, bool STORE_KV_FP8>
|
||||
static void launchFullCacheKernel(
|
||||
scalar_t_in* q_inout, uint8_t* q_fp8_out, int64_t q_fp8_stride0,
|
||||
int64_t q_fp8_stride1, scalar_t_in const* kv_in, uint8_t* k_cache,
|
||||
int64_t const* slot_mapping, int64_t const* position_ids,
|
||||
float const* cos_sin_cache, float const* fp8_scale,
|
||||
float const* q_fp8_scale_inv, float const eps, int const num_tokens_full,
|
||||
int const num_tokens_insert, int const num_heads_q,
|
||||
int const cache_block_size, int64_t const kv_block_stride,
|
||||
int64_t const kv_token_stride, char const* op_name, cudaStream_t stream) {
|
||||
constexpr int kBlockSize = 256;
|
||||
constexpr int kWarpsPerBlock = kBlockSize / 32;
|
||||
int64_t const total_warps =
|
||||
static_cast<int64_t>(num_tokens_full) * (num_heads_q + 1);
|
||||
int const grid =
|
||||
static_cast<int>((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock);
|
||||
auto* kernel =
|
||||
fusedDeepseekV4FullCacheKernel<scalar_t_in, STORE_Q_FP8, STORE_KV_FP8>;
|
||||
#ifndef USE_ROCM
|
||||
static int const sm_version = getSMVersion();
|
||||
STD_TORCH_CHECK(sm_version >= 80, op_name,
|
||||
" requires sm_80+ (Ampere or newer); got sm_", sm_version);
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = dim3(grid);
|
||||
config.blockDim = dim3(kBlockSize);
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
config.attrs = attrs;
|
||||
config.numAttrs = (sm_version >= 90) ? 1 : 0;
|
||||
cudaLaunchKernelEx(&config, kernel, q_inout, q_fp8_out, q_fp8_stride0,
|
||||
q_fp8_stride1, kv_in, k_cache, slot_mapping, position_ids,
|
||||
cos_sin_cache, fp8_scale, q_fp8_scale_inv, eps,
|
||||
num_tokens_full, num_tokens_insert, num_heads_q,
|
||||
cache_block_size, kv_block_stride, kv_token_stride);
|
||||
#else
|
||||
kernel<<<grid, kBlockSize, 0, stream>>>(
|
||||
q_inout, q_fp8_out, q_fp8_stride0, q_fp8_stride1, kv_in, k_cache,
|
||||
slot_mapping, position_ids, cos_sin_cache, fp8_scale, q_fp8_scale_inv,
|
||||
eps, num_tokens_full, num_tokens_insert, num_heads_q, cache_block_size,
|
||||
kv_block_stride, kv_token_stride);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace deepseek_v4_fused_ops
|
||||
} // namespace vllm
|
||||
|
||||
@@ -735,3 +1015,167 @@ torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
});
|
||||
return q_out;
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// FlashInfer full-cache torch ops
|
||||
// ────────────────────────────────────────────────────────────────────────────
|
||||
// bf16 full-cache op: validates the arguments, then launches the full-cache
// kernel with both FP8 store flags off, so Q is rewritten in place in `q` and
// each token's KV row is written into `k_cache` at the slot given by
// `slot_mapping`.
//
// `slot_mapping` may be shorter than `q` (it must not be longer); only its
// first slot_mapping.size(0) tokens get a KV write. All checks raise through
// STD_TORCH_CHECK.
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
    torch::stable::Tensor& q,                    // [N, H, 512] bf16, in place
    torch::stable::Tensor const& kv,             // [N, 512] bf16, read-only
    torch::stable::Tensor& k_cache,              // [num_blocks, bs, 512] bf16
    torch::stable::Tensor const& slot_mapping,   // [num_tokens_insert] int64
    torch::stable::Tensor const& position_ids,   // [N] int64
    torch::stable::Tensor const& cos_sin_cache,  // [max_pos, 64] float32
    double eps, int64_t cache_block_size) {
  using torch::headeronly::ScalarType;

  // ── Device / dtype / layout ─────────────────────────────────────────────
  STD_TORCH_CHECK(q.device().is_cuda() && q.is_contiguous(),
                  "q must be contiguous CUDA");
  STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
                  "kv must be contiguous CUDA");
  STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
  STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
                      slot_mapping.scalar_type() == ScalarType::Long,
                  "slot_mapping must be int64 CUDA");
  STD_TORCH_CHECK(position_ids.device().is_cuda() &&
                      position_ids.scalar_type() == ScalarType::Long,
                  "position_ids must be int64 CUDA");
  STD_TORCH_CHECK(cos_sin_cache.device().is_cuda() &&
                      cos_sin_cache.scalar_type() == ScalarType::Float &&
                      cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
                  "cos_sin_cache shape [max_pos, 64] float32");
  // The kernel indexes these three through raw pointers as dense arrays, so a
  // strided view would be read incorrectly without any error.
  STD_TORCH_CHECK(slot_mapping.is_contiguous(),
                  "slot_mapping must be contiguous");
  STD_TORCH_CHECK(position_ids.is_contiguous(),
                  "position_ids must be contiguous");
  STD_TORCH_CHECK(cos_sin_cache.is_contiguous(),
                  "cos_sin_cache must be contiguous");

  // ── Shapes ──────────────────────────────────────────────────────────────
  STD_TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
  STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
  STD_TORCH_CHECK(q.scalar_type() == ScalarType::BFloat16 &&
                      kv.scalar_type() == ScalarType::BFloat16,
                  "q and kv must be bfloat16");
  // cache_block_size is a divisor inside the kernel, so it must be positive.
  STD_TORCH_CHECK(cache_block_size > 0, "cache_block_size must be positive");
  STD_TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size &&
                      k_cache.size(2) == 512 && k_cache.stride(2) == 1,
                  "k_cache shape [num_blocks, cache_block_size, 512] contiguous");
  STD_TORCH_CHECK(k_cache.scalar_type() == ScalarType::BFloat16,
                  "k_cache must be bfloat16");

  // ── Row counts (compared at full width, before narrowing to int) ────────
  STD_TORCH_CHECK(
      kv.size(0) == q.size(0) && position_ids.size(0) == q.size(0),
      "q/kv/position_ids row counts must match");
  STD_TORCH_CHECK(slot_mapping.size(0) <= q.size(0),
                  "slot_mapping must not exceed q row count");
  int const num_tokens_full = static_cast<int>(q.size(0));
  int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
  int const num_heads_q = static_cast<int>(q.size(1));

  const torch::stable::accelerator::DeviceGuard device_guard(
      q.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(q.get_device_index());

  // bf16 cache: 2 bytes/element -> byte strides for the uint8-addressed kernel.
  int64_t const kv_block_stride = k_cache.stride(0) * 2;
  int64_t const kv_token_stride = k_cache.stride(1) * 2;

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      q.scalar_type(),
      "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert", [&] {
        vllm::deepseek_v4_fused_ops::launchFullCacheKernel<scalar_t, false,
                                                           false>(
            // No FP8 output in this path: q_fp8_out and both scale pointers
            // are null and the q_fp8 strides are zero.
            reinterpret_cast<scalar_t*>(q.mutable_data_ptr()), nullptr, 0, 0,
            reinterpret_cast<scalar_t const*>(kv.const_data_ptr()),
            reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
            slot_mapping.const_data_ptr<int64_t>(),
            position_ids.const_data_ptr<int64_t>(),
            cos_sin_cache.const_data_ptr<float>(), nullptr, nullptr,
            static_cast<float>(eps), num_tokens_full, num_tokens_insert,
            num_heads_q, static_cast<int>(cache_block_size), kv_block_stride,
            kv_token_stride,
            "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert",
            stream);
      });
}
|
||||
|
||||
// fp8 full-cache op: validates the arguments, then launches the full-cache
// kernel with both FP8 store flags on, so Q is written to `q_fp8` (`q` itself
// is only read) and each token's KV row is written as fp8 into `k_cache` at
// the slot given by `slot_mapping`.
//
// `fp8_scale` is the KV scale (the kernel multiplies by its reciprocal);
// `q_fp8_scale_inv` is already the reciprocal of the Q scale. Both must hold
// exactly one float32 element. `slot_mapping` may be shorter than `q` but not
// longer. All checks raise through STD_TORCH_CHECK.
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
    torch::stable::Tensor const& q,              // [N, H, 512] bf16, read-only
    torch::stable::Tensor const& kv,             // [N, 512] bf16, read-only
    torch::stable::Tensor& q_fp8,                // [N, H, 512] fp8 e4m3
    torch::stable::Tensor& k_cache,              // [num_blocks, bs, 512] fp8
    torch::stable::Tensor const& slot_mapping,   // [num_tokens_insert] int64
    torch::stable::Tensor const& position_ids,   // [N] int64
    torch::stable::Tensor const& cos_sin_cache,  // [max_pos, 64] float32
    torch::stable::Tensor const& fp8_scale,      // scalar float32 (KV scale)
    torch::stable::Tensor const& q_fp8_scale_inv,  // scalar float32 (1 / Q scale)
    double eps, int64_t cache_block_size) {
  using torch::headeronly::ScalarType;

  // ── Inputs: device / layout / shape ─────────────────────────────────────
  STD_TORCH_CHECK(q.device().is_cuda() && q.is_contiguous(),
                  "q must be contiguous CUDA");
  STD_TORCH_CHECK(kv.device().is_cuda() && kv.is_contiguous(),
                  "kv must be contiguous CUDA");
  // Validate q's rank before any check below reads q.size(2).
  STD_TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]");
  STD_TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]");
  STD_TORCH_CHECK(q.scalar_type() == kv.scalar_type(),
                  "q and kv dtype must match");

  // ── FP8 Q output ────────────────────────────────────────────────────────
  STD_TORCH_CHECK(q_fp8.device().is_cuda() && q_fp8.is_contiguous() &&
                      q_fp8.scalar_type() == ScalarType::Float8_e4m3fn &&
                      q_fp8.dim() == 3 && q_fp8.size(0) == q.size(0) &&
                      q_fp8.size(1) == q.size(1) && q_fp8.size(2) == q.size(2),
                  "q_fp8 must be a contiguous float8_e4m3fn tensor matching q");

  // ── Index / table tensors ───────────────────────────────────────────────
  STD_TORCH_CHECK(k_cache.device().is_cuda(), "k_cache must be CUDA");
  STD_TORCH_CHECK(slot_mapping.device().is_cuda() &&
                      slot_mapping.scalar_type() == ScalarType::Long,
                  "slot_mapping must be int64 CUDA");
  STD_TORCH_CHECK(position_ids.device().is_cuda() &&
                      position_ids.scalar_type() == ScalarType::Long,
                  "position_ids must be int64 CUDA");
  STD_TORCH_CHECK(cos_sin_cache.device().is_cuda() &&
                      cos_sin_cache.scalar_type() == ScalarType::Float &&
                      cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64,
                  "cos_sin_cache shape [max_pos, 64] float32");
  // The kernel indexes these three through raw pointers as dense arrays, so a
  // strided view would be read incorrectly without any error.
  STD_TORCH_CHECK(slot_mapping.is_contiguous(),
                  "slot_mapping must be contiguous");
  STD_TORCH_CHECK(position_ids.is_contiguous(),
                  "position_ids must be contiguous");
  STD_TORCH_CHECK(cos_sin_cache.is_contiguous(),
                  "cos_sin_cache must be contiguous");

  // ── Scales ──────────────────────────────────────────────────────────────
  // numel() == 1 accepts both 0-dim scalars and one-element tensors, and
  // rejects multi-element tensors whose leading dimension happens to be 1.
  STD_TORCH_CHECK(fp8_scale.device().is_cuda() &&
                      fp8_scale.scalar_type() == ScalarType::Float &&
                      fp8_scale.numel() == 1,
                  "fp8_scale must be a scalar float32 CUDA tensor");
  STD_TORCH_CHECK(q_fp8_scale_inv.device().is_cuda() &&
                      q_fp8_scale_inv.scalar_type() == ScalarType::Float &&
                      q_fp8_scale_inv.numel() == 1,
                  "q_fp8_scale_inv must be a scalar float32 CUDA tensor");

  // ── Cache ───────────────────────────────────────────────────────────────
  // cache_block_size is a divisor inside the kernel, so it must be positive.
  STD_TORCH_CHECK(cache_block_size > 0, "cache_block_size must be positive");
  STD_TORCH_CHECK(k_cache.dim() == 3 && k_cache.size(1) == cache_block_size &&
                      k_cache.size(2) == 512 && k_cache.stride(2) == 1,
                  "k_cache shape [num_blocks, cache_block_size, 512] contiguous");
  STD_TORCH_CHECK(k_cache.scalar_type() == ScalarType::Float8_e4m3fn,
                  "k_cache must be float8_e4m3fn");

  // ── Row counts (compared at full width, before narrowing to int) ────────
  STD_TORCH_CHECK(
      kv.size(0) == q.size(0) && position_ids.size(0) == q.size(0),
      "q/kv/position_ids row counts must match");
  STD_TORCH_CHECK(slot_mapping.size(0) <= q.size(0),
                  "slot_mapping must not exceed q row count");
  int const num_tokens_full = static_cast<int>(q.size(0));
  int const num_tokens_insert = static_cast<int>(slot_mapping.size(0));
  int const num_heads_q = static_cast<int>(q.size(1));

  const torch::stable::accelerator::DeviceGuard device_guard(
      q.get_device_index());
  const cudaStream_t stream = get_current_cuda_stream(q.get_device_index());

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      q.scalar_type(),
      "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert", [&] {
        vllm::deepseek_v4_fused_ops::launchFullCacheKernel<scalar_t, true,
                                                           true>(
            // q is read-only in the fp8 path (the kernel writes q_fp8); the
            // launcher signature is non-const, so cast away const on the ptr.
            reinterpret_cast<scalar_t*>(
                const_cast<void*>(q.const_data_ptr())),
            reinterpret_cast<uint8_t*>(q_fp8.mutable_data_ptr()),
            q_fp8.stride(0), q_fp8.stride(1),
            reinterpret_cast<scalar_t const*>(kv.const_data_ptr()),
            reinterpret_cast<uint8_t*>(k_cache.mutable_data_ptr()),
            slot_mapping.const_data_ptr<int64_t>(),
            position_ids.const_data_ptr<int64_t>(),
            cos_sin_cache.const_data_ptr<float>(),
            fp8_scale.const_data_ptr<float>(),
            q_fp8_scale_inv.const_data_ptr<float>(), static_cast<float>(eps),
            num_tokens_full, num_tokens_insert, num_heads_q,
            static_cast<int>(cache_block_size),
            // fp8 cache: 1 byte/element -> stride already in bytes.
            k_cache.stride(0), k_cache.stride(1),
            "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert",
            stream);
      });
}
|
||||
|
||||
@@ -238,6 +238,23 @@ torch::stable::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
|
||||
torch::stable::Tensor const& cos_sin_cache, int64_t q_head_padded,
|
||||
double eps, int64_t cache_block_size);
|
||||
|
||||
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
|
||||
torch::stable::Tensor& q, torch::stable::Tensor const& kv,
|
||||
torch::stable::Tensor& k_cache, torch::stable::Tensor const& slot_mapping,
|
||||
torch::stable::Tensor const& position_ids,
|
||||
torch::stable::Tensor const& cos_sin_cache, double eps,
|
||||
int64_t cache_block_size);
|
||||
|
||||
void fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
|
||||
torch::stable::Tensor const& q, torch::stable::Tensor const& kv,
|
||||
torch::stable::Tensor& q_fp8, torch::stable::Tensor& k_cache,
|
||||
torch::stable::Tensor const& slot_mapping,
|
||||
torch::stable::Tensor const& position_ids,
|
||||
torch::stable::Tensor const& cos_sin_cache,
|
||||
torch::stable::Tensor const& fp8_scale,
|
||||
torch::stable::Tensor const& q_fp8_scale_inv, double eps,
|
||||
int64_t cache_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::stable::Tensor minimax_allreduce_rms(
|
||||
torch::stable::Tensor const& input,
|
||||
|
||||
@@ -343,6 +343,20 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
||||
"int q_head_padded, float eps, int cache_block_size) -> Tensor");
|
||||
|
||||
// FlashInfer V4 full-cache variants: write Q in place (bf16) or to a separate
|
||||
// FP8 tensor, and KV into a contiguous 512-wide token-strided cache.
|
||||
ops.def(
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert("
|
||||
"Tensor! q, Tensor kv, Tensor! k_cache, Tensor slot_mapping, "
|
||||
"Tensor position_ids, Tensor cos_sin_cache, float eps, "
|
||||
"int cache_block_size) -> ()");
|
||||
ops.def(
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert("
|
||||
"Tensor q, Tensor kv, Tensor! q_fp8, Tensor! k_cache, "
|
||||
"Tensor slot_mapping, Tensor position_ids, Tensor cos_sin_cache, "
|
||||
"Tensor fp8_scale, Tensor q_fp8_scale_inv, float eps, "
|
||||
"int cache_block_size) -> ()");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
"minimax_allreduce_rms("
|
||||
@@ -591,6 +605,12 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
ops.impl("fused_qk_norm_rope", TORCH_BOX(&fused_qk_norm_rope));
|
||||
ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
|
||||
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert));
|
||||
ops.impl(
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert",
|
||||
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert));
|
||||
ops.impl(
|
||||
"fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert",
|
||||
TORCH_BOX(&fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert));
|
||||
#ifndef USE_ROCM
|
||||
ops.impl("minimax_allreduce_rms", TORCH_BOX(&minimax_allreduce_rms));
|
||||
ops.impl("minimax_allreduce_rms_qk", TORCH_BOX(&minimax_allreduce_rms_qk));
|
||||
|
||||
@@ -55,7 +55,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
|
||||
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
|
||||
// kernel launch. Registered in _C_stable_libtorch.
|
||||
// kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4
|
||||
// full-cache bf16/fp8 variants).
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
|
||||
@@ -98,7 +98,6 @@ RUN if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
ARG USE_SCCACHE
|
||||
ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET_NAME}}
|
||||
ENV SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION_NAME}}
|
||||
ENV SCCACHE_ENDPOINT=${USE_SCCACHE:+${SCCACHE_ENDPOINT}}
|
||||
ENV SCCACHE_S3_NO_CREDENTIALS=${USE_SCCACHE:+${SCCACHE_S3_NO_CREDENTIALS}}
|
||||
ENV SCCACHE_IDLE_TIMEOUT=${USE_SCCACHE:+0}
|
||||
|
||||
|
||||
@@ -228,3 +228,17 @@ MLA decode backends are selected using the standard
|
||||
| `TOKENSPEED_MLA` | fp16, bf16 | `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
|
||||
| `TRITON_MLA` | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
||||
| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `float16`, `bfloat16` | Any | 576 | ❌ | ❌ | ✅ | ❌ | ❌ | Decoder | Any |
|
||||
|
||||
### DeepSeek V4 Decode Backends
|
||||
|
||||
DeepSeek V4 sparse MLA uses its own decode backends, selected via
|
||||
`--attention-backend=<BACKEND>` (e.g., `FLASHMLA_SPARSE_DSV4`,
|
||||
`FLASHINFER_MLA_SPARSE_DSV4`). They share the V4 sparse-index
|
||||
pipeline (compressor + SWA + indexer, 256-token blocks, head 512);
|
||||
default on NVIDIA is `FLASHMLA_SPARSE_DSV4`.
|
||||
|
||||
| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Non-Causal | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
|
||||
| ------- | ------ | --------- | ----------- | ---------- | ---- | ---------- | ------ | --------- | --- | --------------- | ------------ |
|
||||
| `FLASHINFER_MLA_SPARSE_DSV4` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |
|
||||
| `FLASHMLA_SPARSE_DSV4` | fp16, bf16 | `auto` | 256 | 512 | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | Any |
|
||||
| `ROCM_FLASHMLA_SPARSE_DSV4` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
|
||||
@@ -82,6 +82,7 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
|
||||
|
||||
| Architecture | Models | CG for Image | CG for Video |
|
||||
| ------------ | ------ | ------------ | ------------ |
|
||||
| `InternVLChatModel` | `InternVL3.5`, `InternVL3`, `InternVL2.5`, `InternVL2` | ✅︎ | ✅︎ |
|
||||
| `Qwen2VLForConditionalGeneration` | `Qwen2-VL` | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
|
||||
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices.
|
||||
|
||||
!!! tip
|
||||
To get started with quantization, see [LLM Compressor](llm_compressor.md), a library for optimizing models for deployment with vLLM that supports FP8, INT8, INT4, and other quantization formats.
|
||||
To get started with quantization, see [LLM Compressor](llm_compressor/README.md), a library for optimizing models for deployment with vLLM that supports FP8, INT8, INT4, and other quantization formats.
|
||||
|
||||
The following are the supported quantization formats for vLLM:
|
||||
|
||||
@@ -12,9 +12,11 @@ The following are the supported quantization formats for vLLM:
|
||||
- [GGUF](gguf.md)
|
||||
- [GPTQModel](gptqmodel.md)
|
||||
- [Intel Neural Compressor](inc.md)
|
||||
- [INT4 W4A16](int4.md)
|
||||
- [INT8 W8A8](int8.md)
|
||||
- [FP8 W8A8](fp8.md)
|
||||
- [LLM Compressor](llm_compressor/README.md)
|
||||
- [FP8 W8A8](llm_compressor/fp8.md)
|
||||
- [INT4 W4A16](llm_compressor/int4.md)
|
||||
- [INT8 W4A8](llm_compressor/int8_w4a8.md)
|
||||
- [INT8 W8A8](llm_compressor/int8_w8a8.md)
|
||||
- [NVIDIA Model Optimizer](modelopt.md)
|
||||
- [Online Quantization](online.md)
|
||||
- [AMD Quark](quark.md)
|
||||
@@ -46,16 +48,17 @@ th:not(:first-child) {
|
||||
}
|
||||
</style>
|
||||
|
||||
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU |
|
||||
| ------------------------- | ----- | ------ | ------ | --- | ------ | ------- | --------- | ------- |
|
||||
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ |
|
||||
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ |
|
||||
| Marlin (GPTQ/AWQ/FP8/FP4) | ❌ | ✅︎* | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
||||
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ |
|
||||
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ |
|
||||
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
||||
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
||||
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ |
|
||||
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | Arm CPU |
|
||||
| ------------------------- | ----- | ------ | ------ | --- | ------ | ------- | --------- | ------- | ------- |
|
||||
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ |
|
||||
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ |
|
||||
| Marlin (GPTQ/AWQ/FP8/FP4) | ❌ | ✅︎* | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||
| llm-compressor INT8 (W8A8)| ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ |
|
||||
| llm-compressor INT8 (W4A8)| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ |
|
||||
| llm-compressor FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
||||
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ |
|
||||
|
||||
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
|
||||
- ✅︎ indicates that the quantization method is supported on the specified hardware.
|
||||
|
||||
+9
-9
@@ -21,9 +21,17 @@ The FP8 types typically supported in hardware have two distinct representations,
|
||||
To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
||||
|
||||
```bash
|
||||
pip install llmcompressor
|
||||
(venv-llm-compressor) pip install llmcompressor
|
||||
```
|
||||
|
||||
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||
|
||||
```bash
|
||||
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
|
||||
```
|
||||
|
||||
Please use separate environments for vLLM and llm-compressor as they might not work together.
|
||||
|
||||
## Quantization Process
|
||||
|
||||
The quantization process involves three main steps:
|
||||
@@ -57,8 +65,6 @@ For FP8 quantization, we can recover accuracy with simple RTN quantization. We r
|
||||
|
||||
Since simple RTN does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow.
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from llmcompressor import oneshot
|
||||
from llmcompressor.modifiers.quantization import QuantizationModifier
|
||||
@@ -81,12 +87,6 @@ Since simple RTN does not require data for weight quantization and the activatio
|
||||
|
||||
### 3. Evaluating Accuracy
|
||||
|
||||
Install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||
|
||||
```bash
|
||||
pip install vllm "lm-eval[api]>=0.4.12"
|
||||
```
|
||||
|
||||
Load and run the model in `vllm`:
|
||||
|
||||
```python
|
||||
+4
-8
@@ -12,15 +12,17 @@ Please visit the HF collection of [quantized INT4 checkpoints of popular LLMs re
|
||||
To use INT4 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
||||
|
||||
```bash
|
||||
pip install llmcompressor
|
||||
(venv-llm-compressor) pip install llmcompressor
|
||||
```
|
||||
|
||||
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||
|
||||
```bash
|
||||
pip install vllm "lm-eval[api]>=0.4.12"
|
||||
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
|
||||
```
|
||||
|
||||
Please use separate environments for vLLM and llm-compressor as they might not work together.
|
||||
|
||||
## Quantization Process
|
||||
|
||||
The quantization process involves four main steps:
|
||||
@@ -52,8 +54,6 @@ When quantizing weights to INT4, you need sample data to estimate the weight upd
|
||||
It's best to use calibration data that closely matches your deployment data.
|
||||
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
@@ -77,8 +77,6 @@ For a general-purpose instruction-tuned model, you can use a dataset like `ultra
|
||||
|
||||
Now, apply the quantization algorithms:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from llmcompressor import oneshot
|
||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||
@@ -141,8 +139,6 @@ lm_eval --model vllm \
|
||||
|
||||
The following is an example of an expanded quantization recipe you can tune to your own use case:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from compressed_tensors.quantization import (
|
||||
QuantizationArgs,
|
||||
@@ -0,0 +1,217 @@
|
||||
# INT8 W4A8
|
||||
|
||||
vLLM supports quantizing weights to INT4 and activations to INT8 for memory savings and inference acceleration.
|
||||
This quantization method is particularly useful for reducing model size while maintaining good performance.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
To use INT8 W4A8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library.
|
||||
|
||||
```bash
|
||||
(venv-llm-compressor) pip install llmcompressor
|
||||
```
|
||||
|
||||
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||
|
||||
```bash
|
||||
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
|
||||
```
|
||||
|
||||
Please use separate environments for vLLM and llm-compressor as they might not work together.
|
||||
|
||||
## Quantization Process
|
||||
|
||||
The quantization process involves four main steps:
|
||||
|
||||
1. Loading the model
|
||||
2. Preparing calibration data
|
||||
3. Applying quantization
|
||||
4. Evaluating accuracy in vLLM
|
||||
|
||||
### 1. Loading the Model
|
||||
|
||||
Load your model and tokenizer using the standard `transformers` AutoModel classes:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_ID,
|
||||
dtype="auto",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
```
|
||||
|
||||
### 2. Preparing Calibration Data
|
||||
|
||||
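When quantizing weights to INT4 with GPTQ, you need sample data to estimate the weight updates; the INT8 activations in the recipes below are quantized dynamically per token, so no static activation scales are calibrated.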
|
||||
It's best to use calibration data that closely matches your deployment data.
|
||||
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
NUM_CALIBRATION_SAMPLES = 512
|
||||
MAX_SEQUENCE_LENGTH = 2048
|
||||
|
||||
# Load and preprocess the dataset
|
||||
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
|
||||
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
||||
|
||||
def preprocess(example):
|
||||
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
|
||||
ds = ds.map(preprocess)
|
||||
|
||||
def tokenize(sample):
|
||||
return tokenizer(sample["text"], padding=False, max_length=MAX_SEQUENCE_LENGTH, truncation=True, add_special_tokens=False)
|
||||
ds = ds.map(tokenize, remove_columns=ds.column_names)
|
||||
```
|
||||
|
||||
### 3. Applying Quantization
|
||||
|
||||
Now, apply the quantization algorithms.
|
||||
|
||||
The following recipes create W4A8 models (int4 weights, int8 activations). On Arm® CPUs, this is accelerated through [KleidiAI](https://github.com/ARM-software/kleidiai).
|
||||
|
||||
Use groupwise weight quantization for the best accuracy, and channelwise weight quantization for the best inference performance.
|
||||
|
||||
=== "Groupwise"
|
||||
|
||||
```python
|
||||
from llmcompressor import oneshot
|
||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||
|
||||
# Configure the quantization algorithms
|
||||
recipe = [
|
||||
GPTQModifier(
|
||||
targets="Linear",
|
||||
scheme="W4A8",
|
||||
ignore=["lm_head"],
|
||||
dampening_frac=0.01
|
||||
),
|
||||
]
|
||||
|
||||
# Apply quantization
|
||||
oneshot(
|
||||
model=model,
|
||||
dataset=ds,
|
||||
recipe=recipe,
|
||||
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||
)
|
||||
|
||||
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token
|
||||
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A8-G128-Dynamic-Per-Token"
|
||||
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||
tokenizer.save_pretrained(SAVE_DIR)
|
||||
```
|
||||
|
||||
=== "Channelwise"
|
||||
|
||||
```python
|
||||
from llmcompressor import oneshot
|
||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||
from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
|
||||
|
||||
scheme = {
|
||||
"targets": ["Linear"],
|
||||
"weights": {
|
||||
"num_bits": 4,
|
||||
"type": QuantizationType.INT,
|
||||
"strategy": QuantizationStrategy.CHANNEL,
|
||||
"symmetric": True,
|
||||
"dynamic": False,
|
||||
"group_size": None,
|
||||
},
|
||||
"input_activations": {
|
||||
"num_bits": 8,
|
||||
"type": QuantizationType.INT,
|
||||
"strategy": QuantizationStrategy.TOKEN,
|
||||
"dynamic": True,
|
||||
"symmetric": False,
|
||||
"observer": None,
|
||||
},
|
||||
"output_activations": None,
|
||||
}
|
||||
|
||||
recipe = [
|
||||
GPTQModifier(
|
||||
targets="Linear",
|
||||
config_groups={"group_0": scheme},
|
||||
ignore=["lm_head"],
|
||||
dampening_frac=0.01,
|
||||
),
|
||||
]
|
||||
|
||||
oneshot(
|
||||
model=model,
|
||||
dataset=ds,
|
||||
recipe=recipe,
|
||||
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||
)
|
||||
|
||||
# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token
|
||||
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A8-Channelwise-Dynamic-Per-Token"
|
||||
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||
tokenizer.save_pretrained(SAVE_DIR)
|
||||
```
|
||||
|
||||
### 4. Evaluating Accuracy
|
||||
|
||||
=== "Groupwise"
|
||||
|
||||
After quantization, you can load and run the model in vLLM:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
llm = LLM("./Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token")
|
||||
```
|
||||
|
||||
To evaluate accuracy, you can use `lm_eval`:
|
||||
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A8-G128-Dynamic-Per-Token",add_bos_token=true \
|
||||
--tasks gsm8k \
|
||||
--num_fewshot 5 \
|
||||
--limit 250 \
|
||||
--batch_size 'auto'
|
||||
```
|
||||
|
||||
=== "Channelwise"
|
||||
|
||||
After quantization, you can load and run the model in vLLM:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
llm = LLM("./Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token")
|
||||
```
|
||||
|
||||
To evaluate accuracy, you can use `lm_eval`:
|
||||
|
||||
```bash
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A8-Channelwise-Dynamic-Per-Token",add_bos_token=true \
|
||||
--tasks gsm8k \
|
||||
--num_fewshot 5 \
|
||||
--limit 250 \
|
||||
--batch_size 'auto'
|
||||
```
|
||||
|
||||
!!! note
|
||||
Quantized models can be sensitive to the presence of the `bos` token. Make sure to include `add_bos_token=true` in `--model_args` when running evaluations, as shown in the `lm_eval` commands above.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- Start with 512 samples for calibration data (increase if accuracy drops)
|
||||
- Use a sequence length of 2048 as a starting point
|
||||
- Employ the chat template or instruction template that the model was trained with
|
||||
- If you've fine-tuned a model, consider using a sample of your training data for calibration
|
||||
|
||||
## Troubleshooting and Support
|
||||
|
||||
If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository.
|
||||
+4
-6
@@ -17,15 +17,17 @@ Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs re
|
||||
To use INT8 quantization with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
||||
|
||||
```bash
|
||||
pip install llmcompressor
|
||||
(venv-llm-compressor) pip install llmcompressor
|
||||
```
|
||||
|
||||
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||
|
||||
```bash
|
||||
pip install vllm "lm-eval[api]>=0.4.12"
|
||||
(venv-vllm) pip install vllm "lm-eval[api]>=0.4.12"
|
||||
```
|
||||
|
||||
Please use separate environments for vLLM and llm-compressor as they might not work together.
|
||||
|
||||
## Quantization Process
|
||||
|
||||
The quantization process involves four main steps:
|
||||
@@ -57,8 +59,6 @@ When quantizing activations to INT8, you need sample data to estimate the activa
|
||||
It's best to use calibration data that closely matches your deployment data.
|
||||
For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
@@ -84,8 +84,6 @@ For a general-purpose instruction-tuned model, you can use a dataset like `ultra
|
||||
|
||||
Now, apply the quantization algorithms:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from llmcompressor import oneshot
|
||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||
@@ -569,6 +569,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `GlmOcrForConditionalGeneration` | GLM-OCR | T + I<sup>E+</sup> | `zai-org/GLM-OCR`, etc. | ✅︎ | ✅︎ |
|
||||
| `Granite4VisionForConditionalGeneration` | Granite 4 Vision | T + I<sup>E+</sup> | `ibm-granite/granite-4.1-3b-vision`, etc. | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechPlusForConditionalGeneration` | Granite Speech Plus | T + A | `ibm-granite/granite-speech-4.1-2b-plus` | ✅︎ | ✅︎ |
|
||||
| `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | |
|
||||
| `HCXVisionV2ForCausalLM` | HyperCLOVAX-SEED-Think-32B | T + I<sup>+</sup> + V<sup>+</sup> | `naver-hyperclovax/HyperCLOVAX-SEED-Think-32B` | | |
|
||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | ✅︎ | ✅︎ |
|
||||
@@ -709,6 +710,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
|
||||
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
|
||||
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-4.0-1b-speech`, `ibm-granite/granite-speech-3.3-2b`, etc. | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechPlusForConditionalGeneration` | Granite Speech Plus | `ibm-granite/granite-speech-4.1-2b-plus` | ✅︎ | ✅︎ |
|
||||
| `Qwen3ASRForConditionalGeneration` | Qwen3-ASR | `Qwen/Qwen3-ASR-1.7B`, etc. | ✅︎ | ✅︎ |
|
||||
| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, etc. | | ✅︎ |
|
||||
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
|
||||
|
||||
@@ -24,8 +24,14 @@ echo "Checking pre-commit/pre-run-check status..."
|
||||
MAX_WAIT=300
|
||||
INTERVAL=60
|
||||
ELAPSED=0
|
||||
# Use a GitHub token if provided to raise the API rate limit (60 -> 5000
|
||||
# requests/hour). Set GITHUB_TOKEN in the Read the Docs environment variables.
|
||||
CURL_AUTH=()
|
||||
if [ -n "$GITHUB_TOKEN" ]; then
|
||||
CURL_AUTH=(-H "Authorization: Bearer $GITHUB_TOKEN")
|
||||
fi
|
||||
while :; do
|
||||
RAW=$(curl -sS -w "\n%{http_code}" "https://api.github.com/repos/vllm-project/vllm/commits/${READTHEDOCS_GIT_COMMIT_HASH}/check-runs?check_name=pre-run-check&filter=latest")
|
||||
RAW=$(curl -sS "${CURL_AUTH[@]}" -w "\n%{http_code}" "https://api.github.com/repos/vllm-project/vllm/commits/${READTHEDOCS_GIT_COMMIT_HASH}/check-runs?check_name=pre-run-check&filter=latest")
|
||||
HTTP_CODE=$(printf %s "$RAW" | tail -n1)
|
||||
BODY=$(printf %s "$RAW" | sed '$d')
|
||||
if [ "$HTTP_CODE" != "200" ]; then
|
||||
|
||||
@@ -2554,6 +2554,7 @@ MODELS_NEED_VIDEO_METADATA = [
|
||||
|
||||
|
||||
MODELS_SUPPORT_VIT_CUDA_GRAPH = [
|
||||
"internvl_chat",
|
||||
"qwen2_5_vl",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
|
||||
@@ -110,6 +110,9 @@ plugins:
|
||||
redirect_maps:
|
||||
features/spec_decode/README.md: features/speculative_decoding/README.md
|
||||
features/spec_decode/speculators.md: features/speculative_decoding/speculators.md
|
||||
features/quantization/fp8.md: features/quantization/llm_compressor/fp8.md
|
||||
features/quantization/int4.md: features/quantization/llm_compressor/int4.md
|
||||
features/quantization/int8.md: features/quantization/llm_compressor/int8_w8a8.md
|
||||
serving/openai_compatible_server.md: serving/online_serving/README.md
|
||||
|
||||
markdown_extensions:
|
||||
|
||||
@@ -38,7 +38,7 @@ pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=77.0.3,<81.0.0; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.15.0.1 # required for compressed-tensors
|
||||
compressed-tensors == 0.17.0 # required for compressed-tensors
|
||||
depyf==0.20.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
watchfiles # required for http server to monitor the updates of TLS files
|
||||
|
||||
@@ -18,7 +18,7 @@ tilelang==0.1.9
|
||||
nvidia-cudnn-frontend>=1.13.0,<1.19.0
|
||||
|
||||
# Required for faster safetensors model loading
|
||||
fastsafetensors >= 0.2.2
|
||||
fastsafetensors >= 0.3.2
|
||||
|
||||
# QuACK and Cutlass DSL for FA4 (cute-DSL implementation)
|
||||
nvidia-cutlass-dsl[cu13]==4.5.2
|
||||
@@ -28,4 +28,4 @@ quack-kernels>=0.3.3
|
||||
tokenspeed-mla==0.1.2
|
||||
|
||||
# Humming kernels for quantization gemm
|
||||
humming-kernels[cu13]==0.1.2
|
||||
humming-kernels[cu13]==0.1.4
|
||||
|
||||
@@ -23,3 +23,6 @@ timm>=1.0.17
|
||||
# To be consistent with test_quark.py
|
||||
amd-quark>=0.8.99
|
||||
tilelang==0.1.10
|
||||
|
||||
# Required for faster safetensors model loading
|
||||
fastsafetensors >= 0.3.2
|
||||
|
||||
@@ -57,7 +57,7 @@ arctic-inference == 0.1.1; platform_machine == "x86_64" # Required for suffix de
|
||||
numba == 0.65.0 # Required for N-gram speculative decoding
|
||||
numpy
|
||||
runai-model-streamer[s3,gcs,azure]==0.15.7
|
||||
fastsafetensors>=0.2.2; platform_machine == "x86_64" # 0.2.2 contains important fixes for multi-GPU mem usage
|
||||
fastsafetensors>=0.3.2
|
||||
instanttensor>=0.1.5; platform_machine == "x86_64"
|
||||
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
||||
decord==0.6.0; platform_machine == "x86_64"
|
||||
|
||||
@@ -191,7 +191,7 @@ fastparquet==2024.11.0
|
||||
# via genai-perf
|
||||
fastrlock==0.8.2
|
||||
# via cupy-cuda12x
|
||||
fastsafetensors==0.2.2
|
||||
fastsafetensors==0.3.2
|
||||
# via
|
||||
# -c requirements/cuda.txt
|
||||
# -r requirements/test/cuda.in
|
||||
|
||||
@@ -43,6 +43,6 @@ tritonclient>=2.51.0
|
||||
numba == 0.65.0 # Required for N-gram speculative decoding
|
||||
numpy
|
||||
runai-model-streamer[s3,gcs,azure]==0.15.7
|
||||
fastsafetensors>=0.2.2
|
||||
fastsafetensors>=0.3.2
|
||||
instanttensor>=0.1.5
|
||||
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
||||
|
||||
@@ -56,7 +56,7 @@ arctic-inference==0.1.1 # Required for suffix decoding test
|
||||
numba==0.65.0 # Required for N-gram speculative decoding
|
||||
numpy
|
||||
runai-model-streamer[s3,gcs,azure]==0.15.7
|
||||
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@0.2.2 # PyPI only ships CUDA wheels
|
||||
fastsafetensors>=0.3.2
|
||||
instanttensor>=0.1.5
|
||||
pydantic>=2.12 # 2.11 leads to error on python 3.13
|
||||
decord==0.6.0
|
||||
|
||||
@@ -143,7 +143,7 @@ colorful==0.5.8
|
||||
# via ray
|
||||
colorlog==6.10.1
|
||||
# via optuna
|
||||
compressed-tensors==0.15.0.1
|
||||
compressed-tensors==0.17.0
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test/../common.txt
|
||||
@@ -240,8 +240,10 @@ fastar==0.10.0
|
||||
# via fastapi-cloud-cli
|
||||
fastparquet==2026.3.0
|
||||
# via genai-perf
|
||||
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@65d80088fca7a8f567fba30415fbcc80f7d2259c
|
||||
# via -r requirements/test/rocm.in
|
||||
fastsafetensors==0.3.2
|
||||
# via
|
||||
# -c requirements/rocm.txt
|
||||
# -r requirements/test/rocm.in
|
||||
filelock==3.25.2
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
|
||||
@@ -1168,7 +1168,7 @@ setup(
|
||||
"zen": ["zentorch==2.11.0.0"],
|
||||
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
|
||||
"tensorizer": ["tensorizer==2.10.1"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.2.2"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.3.2"],
|
||||
"instanttensor": ["instanttensor >= 0.1.5"],
|
||||
"runai": ["runai-model-streamer[s3,gcs,azure] >= 0.15.7"],
|
||||
"audio": [
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.compilation.passes.fusion.allreduce_rms_fusion import (
|
||||
AllReduceFusionPass,
|
||||
RocmAiterAllReduceFusionPass,
|
||||
)
|
||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||
from vllm.compilation.passes.utility.fix_functionalization import (
|
||||
FixFunctionalizationPass,
|
||||
)
|
||||
@@ -33,7 +34,7 @@ from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
@@ -91,6 +92,49 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceGemmaRMSNormModel(torch.nn.Module):
    """Toy model that alternates tensor-parallel all-reduce with GemmaRMSNorm.

    The first norm is applied without a residual; the remaining three are
    applied with a residual that is threaded through each stage.
    """

    def __init__(
        self,
        hidden_size=16,
        token_num=16,
        eps=1e-6,
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps

        # Build all four norm layers first, then randomize their weights.
        norm_layers = []
        for _ in range(4):
            norm_layers.append(GemmaRMSNorm(hidden_size, eps))
        self.norm = norm_layers

        # Give the weights non-trivial values (roughly the Gemma range) so
        # that the (1 + w) scale path is actually exercised.
        for layer in self.norm:
            layer.weight.data.normal_(mean=0.0, std=0.1)

        # One projection matrix per residual stage.
        projections = []
        for _ in range(3):
            projections.append(torch.rand(hidden_size, hidden_size))
        self.w = projections

    def forward(self, x):
        """Run the four all-reduce + norm stages and return the last output."""
        # Apply relu first so the graph input is never a direct pattern arg.
        activated = torch.relu(x)
        reduced = tensor_model_parallel_all_reduce(activated)
        resid = reduced
        hidden = self.norm[0](reduced)

        # Remaining stages: matmul -> all-reduce -> norm with residual.
        for weight, norm_layer in zip(self.w, self.norm[1:]):
            projected = torch.mm(hidden, weight)
            reduced = tensor_model_parallel_all_reduce(projected)
            hidden, resid = norm_layer(reduced, resid)

        return hidden

    def ops_in_model_before(self):
        """Ops expected in the graph before the fusion pass runs."""
        all_reduce_op = torch.ops.vllm.all_reduce.default
        return [all_reduce_op]

    def ops_in_model_after(self):
        """Ops expected in the graph after the fusion pass runs."""
        fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
        return [fused_op]
|
||||
|
||||
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
@@ -209,6 +253,15 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
"test_model, enable_quant_fp8_custom_op, use_aiter",
|
||||
[
|
||||
(TestAllReduceRMSNormModel, False, IS_AITER_FOUND),
|
||||
pytest.param(
|
||||
TestAllReduceGemmaRMSNormModel,
|
||||
False,
|
||||
False,
|
||||
marks=pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="Not supported on ROCm platform",
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
TestAllReduceRMSNormStaticQuantFP8Model,
|
||||
True,
|
||||
@@ -404,4 +457,9 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
)
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
if test_model_cls is TestAllReduceGemmaRMSNormModel:
|
||||
fused_op = torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
|
||||
fused_nodes = list(find_op_nodes(fused_op, backend.graph_post_pass))
|
||||
assert fused_nodes
|
||||
assert all(n.kwargs.get("weight_bias") == 1.0 for n in fused_nodes)
|
||||
del all_reduce_fusion_pass
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the Inductor FALLBACK_ALLOW_LIST patch in env_override.py.
|
||||
|
||||
The patch wraps ``torch._inductor.lowering.FALLBACK_ALLOW_LIST`` in a thin
|
||||
proxy that auto-allows any custom op in the ``vllm::`` or ``vllm_aiter::``
|
||||
namespaces. This routes those ops through Inductor's fast-path
|
||||
``make_fallback(target, warn=False, override_decomp=True)`` and avoids the
|
||||
expensive ``error.operator_str(target, args, kwargs)`` formatting that
|
||||
recursively stringifies every input ``TensorBox``.
|
||||
|
||||
The slow path is what made ``torch.compile`` effectively hang on Kimi-K2.6
|
||||
TP=8 (deep MoE/TP IR provenance trees). These tests cover both the proxy's
|
||||
semantics in isolation and the membership-check fast-path that Inductor's
|
||||
``GraphLowering.call_function`` actually performs, so we can validate the
|
||||
optimization without needing a full GPU compile.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.env_override import (
|
||||
_patch_inductor_fallback_allow_list,
|
||||
_VllmFallbackAllowList,
|
||||
)
|
||||
|
||||
|
||||
class TestVllmFallbackAllowListProxy:
    """Checks the membership behaviour of the allow-list proxy on its own."""

    def test_vllm_namespace_auto_allowed(self):
        allow_list = _VllmFallbackAllowList(set())
        for op_name in (
            "vllm::all_reduce",
            "vllm::fused_add_rms_norm",
            "vllm::all_reduce.default",
        ):
            assert op_name in allow_list

    def test_vllm_aiter_namespace_auto_allowed(self):
        allow_list = _VllmFallbackAllowList(set())
        for op_name in (
            "vllm_aiter::fused_add_rms_norm",
            "vllm_aiter::rocm_aiter_fused_moe",
        ):
            assert op_name in allow_list

    def test_unknown_namespace_falls_through(self):
        # Names outside the auto-allowed namespaces are answered by the
        # wrapped container.
        allow_list = _VllmFallbackAllowList({"torchvision::roi_align"})
        assert "torchvision::roi_align" in allow_list
        assert "made_up_ns::nonexistent_op" not in allow_list

    def test_non_string_falls_through_to_inner(self):
        # Non-string keys cannot match a namespace prefix, so only the
        # wrapped container decides.
        sentinel = object()
        allow_list = _VllmFallbackAllowList({sentinel})
        assert sentinel in allow_list
        assert object() not in allow_list

    def test_prefix_only_match_not_substring(self):
        allow_list = _VllmFallbackAllowList(set())
        for op_name in ("not_vllm::something", " vllm::space_prefixed"):
            assert op_name not in allow_list

    def test_standard_entries_preserved(self):
        wrapped = {"torchvision::roi_align", "aten::index_add"}
        allow_list = _VllmFallbackAllowList(wrapped)
        assert "torchvision::roi_align" in allow_list
        assert "aten::index_add" in allow_list
        assert "aten::__not_present__" not in allow_list

    def test_add_and_discard_delegate_to_inner(self):
        wrapped: set[str] = set()
        allow_list = _VllmFallbackAllowList(wrapped)

        allow_list.add("custom::op")
        assert "custom::op" in wrapped

        allow_list.discard("custom::op")
        assert "custom::op" not in wrapped

    def test_iter_len_repr(self):
        wrapped = {"torchvision::roi_align", "aten::index_add"}
        allow_list = _VllmFallbackAllowList(wrapped)
        assert set(iter(allow_list)) == wrapped
        assert len(allow_list) == len(wrapped)
        assert "torchvision::roi_align" in repr(allow_list)

    def test_getattr_delegates_to_inner(self):
        class _Wrapped:
            sentinel = "i_am_inner"

            def some_method(self):
                return 42

        allow_list = _VllmFallbackAllowList(_Wrapped())
        assert allow_list.sentinel == "i_am_inner"
        assert allow_list.some_method() == 42

    def test_sentinel_attribute(self):
        allow_list = _VllmFallbackAllowList(set())
        assert allow_list._vllm_patched is True
|
||||
|
||||
|
||||
class TestPatchApplication:
    """Checks that the patch is visible on the live ``torch._inductor`` modules."""

    def test_patch_applied_to_lowering(self):
        import torch._inductor.lowering as _lowering

        patched = getattr(_lowering.FALLBACK_ALLOW_LIST, "_vllm_patched", False)
        assert patched, "env_override._patch_inductor_fallback_allow_list did not run"

    def test_graph_module_local_binding_rebound(self):
        # ``torch/_inductor/graph.py`` pulls the list in with
        #     from torch._inductor.lowering import FALLBACK_ALLOW_LIST
        # which gives the graph module its own binding. The patch must
        # replace that binding as well; otherwise the fast-path check in
        # GraphLowering.call_function keeps using the original (unwrapped)
        # OrderedSet.
        import torch._inductor.graph as _graph
        import torch._inductor.lowering as _lowering

        graph_has_binding = hasattr(_graph, "FALLBACK_ALLOW_LIST")
        if not graph_has_binding:
            pytest.skip(
                "torch._inductor.graph no longer imports FALLBACK_ALLOW_LIST "
                "as a module-level symbol; nothing to rebind."
            )

        assert _graph.FALLBACK_ALLOW_LIST is _lowering.FALLBACK_ALLOW_LIST

    def test_patch_is_idempotent(self):
        import torch._inductor.lowering as _lowering

        before = _lowering.FALLBACK_ALLOW_LIST
        # Re-applying the patch twice must leave the same proxy object in place.
        for _ in range(2):
            _patch_inductor_fallback_allow_list()
        assert _lowering.FALLBACK_ALLOW_LIST is before

    def test_real_vllm_ops_in_real_allow_list(self):
        # Membership check against the live, already-patched object.
        import torch._inductor.lowering as _lowering

        live_allow_list = _lowering.FALLBACK_ALLOW_LIST
        for op_name in (
            "vllm::all_reduce",
            "vllm::fused_add_rms_norm",
            "vllm_aiter::fused_add_rms_norm",
        ):
            assert op_name in live_allow_list
|
||||
|
||||
|
||||
class TestInductorFallbackFastPath:
    """Emulates ``GraphLowering.call_function``'s FALLBACK_ALLOW_LIST check.

    The relevant snippet in ``torch/_inductor/graph.py`` is roughly::

        base_name = target.name()
        if base_name not in FALLBACK_ALLOW_LIST:
            log.info(
                "Creating implicit fallback for:\\n%s",
                error.operator_str(target, args, kwargs),
            )
        out = make_fallback(target, ...)

    On a deep MoE/TP graph (Kimi-K2.6 at TP=4/8) ``operator_str`` recurses
    through every input ``TensorBox.__str__`` and ends up taking many minutes
    of CPU per encountered op. The patch ensures the membership test
    short-circuits for ``vllm::*``/``vllm_aiter::*`` ops so the slow path is
    never entered. These tests pin that behaviour without needing a real
    GPU compile.
    """

    def _simulate_graph_lowering(self, target_names: list[str]) -> list[str]:
        """Return the target names that would hit the slow ``operator_str()`` path.

        Applies the same ``name not in FALLBACK_ALLOW_LIST`` test as the
        snippet in the class docstring, against the live
        ``torch._inductor.lowering.FALLBACK_ALLOW_LIST``.

        Args:
            target_names: Op names to check, in lowering order.

        Returns:
            A list (input order, duplicates kept) of the names that are not
            members of the allow list.
        """
        import torch._inductor.lowering as _lowering

        allow_list = _lowering.FALLBACK_ALLOW_LIST
        return [name for name in target_names if name not in allow_list]

    def test_vllm_ops_skip_slow_path(self):
        """``vllm::*`` and ``vllm_aiter::*`` names never reach the slow path."""
        slow = self._simulate_graph_lowering(
            [
                "vllm::all_reduce",
                "vllm::fused_add_rms_norm",
                "vllm_aiter::rocm_aiter_fused_moe",
                "vllm_aiter::asm_moe",
            ]
        )
        assert slow == [], (
            "Patched FALLBACK_ALLOW_LIST must short-circuit for all "
            f"vllm::*/vllm_aiter::* ops; got slow-path hits: {slow}"
        )

    def test_non_vllm_ops_still_hit_slow_path(self):
        """Names in other namespaces are still reported as slow-path hits."""
        # Without the patch this is also what would happen; with the patch
        # the behaviour for non-vllm namespaces must be unchanged.
        slow = self._simulate_graph_lowering(
            ["my_user_ns::custom_op", "fancy_ns::something_else"]
        )
        assert "my_user_ns::custom_op" in slow
        assert "fancy_ns::something_else" in slow

    def test_kimi_k2_6_style_op_stream(self):
        """Emulates one decoder layer's worth of fallback hits.

        Kimi-K2.6 at TP=4 lowers a stream of ``vllm::all_reduce`` +
        ``vllm_aiter::fused_add_rms_norm`` calls (one per residual block)
        plus a handful of fused-MoE ops. Pre-patch every one of these would
        invoke ``operator_str`` and stringify a hundreds-deep IR provenance
        tree; post-patch they must all short-circuit.
        """
        n_layers = 64  # Kimi-K2.6 has ~64 decoder layers per replica
        # Same three ops repeated once per layer, in the same order.
        op_stream: list[str] = [
            "vllm::all_reduce",
            "vllm_aiter::fused_add_rms_norm",
            "vllm_aiter::rocm_aiter_fused_moe",
        ] * n_layers

        start = time.perf_counter()
        slow = self._simulate_graph_lowering(op_stream)
        elapsed_s = time.perf_counter() - start

        assert slow == [], (
            f"Expected all {len(op_stream)} vllm/vllm_aiter ops to take "
            f"the fast path; got {len(slow)} slow-path hits."
        )
        # ``__contains__`` is O(1) per call, so a Kimi-sized stream should
        # complete in well under a second even on a slow runner. The
        # pre-patch slow path took many minutes per op on Kimi-K2.6 TP=8.
        assert elapsed_s < 1.0, (
            f"FALLBACK_ALLOW_LIST membership check is unexpectedly slow: "
            f"{elapsed_s:.3f}s for {len(op_stream)} ops"
        )

    def test_inner_set_membership_still_works_for_standard_ops(self):
        """The patch must not break Inductor's existing fallback decisions
        for non-vllm ops such as ``torchvision::roi_align``."""
        import torch._inductor.lowering as _lowering

        allow_list = _lowering.FALLBACK_ALLOW_LIST
        standard_op = "torchvision::roi_align"

        # ``torchvision::roi_align`` has been a member of the upstream
        # FALLBACK_ALLOW_LIST since the original Inductor implementation.
        # Decide whether upstream still ships it by *iterating* the allow
        # list rather than using ``in``: if the skip decision went through
        # the proxy's ``__contains__``, a broken pass-through would skip the
        # test instead of failing it.
        upstream_entries = set(iter(allow_list))
        if standard_op not in upstream_entries:
            pytest.skip(
                "Upstream FALLBACK_ALLOW_LIST no longer ships "
                "torchvision::roi_align; nothing to verify."
            )

        # If the proxy ever broke pass-through, this would regress.
        assert standard_op in allow_list, (
            f"{standard_op} is present when iterating FALLBACK_ALLOW_LIST "
            "but the membership test does not pass it through"
        )
|
||||
@@ -277,12 +277,15 @@ def assert_verification_synced(local_ok: bool, msg: str) -> None:
|
||||
assert bool(ok_tensor.item()), msg
|
||||
|
||||
|
||||
def create_eplb_communicator_or_raise(*, group_coordinator, backend, expert_weights):
|
||||
def create_eplb_communicator_or_raise(
|
||||
*, group_coordinator, backend, expert_weights, expert_buffer
|
||||
):
|
||||
try:
|
||||
return create_eplb_communicator(
|
||||
group_coordinator=group_coordinator,
|
||||
backend=backend,
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
@@ -355,7 +358,8 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend=eplb_communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
communicator.set_stream(cuda_stream)
|
||||
|
||||
@@ -368,6 +372,7 @@ def _test_async_transfer_layer_without_mtp_worker(
|
||||
ep_group=ep_group,
|
||||
communicator=communicator,
|
||||
cuda_stream=cuda_stream,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
cuda_stream.synchronize()
|
||||
move_from_buffer(
|
||||
@@ -460,10 +465,12 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices
|
||||
)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend=eplb_communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
# Execute weight rearrangement
|
||||
@@ -471,9 +478,9 @@ def _test_rearrange_expert_weights_with_redundancy(
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
communicator=communicator,
|
||||
communicator,
|
||||
)
|
||||
|
||||
# Verify the rearrangement result
|
||||
@@ -593,10 +600,12 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend="torch_nccl",
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
# Execute rearrangement (should be no change)
|
||||
@@ -604,9 +613,9 @@ def _test_rearrange_expert_weights_no_change(env, world_size) -> None:
|
||||
indices,
|
||||
indices, # Same indices
|
||||
expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
communicator,
|
||||
is_profile=False,
|
||||
)
|
||||
|
||||
# Verify that the weights have not changed
|
||||
@@ -726,10 +735,12 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
layer_copy.append(weight.clone())
|
||||
original_weights.append(layer_copy)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
communicator = create_eplb_communicator_or_raise(
|
||||
group_coordinator=ep_group_coordinator,
|
||||
backend="torch_nccl",
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
# Execute profile mode rearrangement
|
||||
@@ -737,9 +748,10 @@ def _test_rearrange_expert_weights_profile_mode(env, world_size) -> None:
|
||||
old_indices,
|
||||
new_indices,
|
||||
expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
communicator,
|
||||
is_profile=True, # Profile mode
|
||||
is_profile=True,
|
||||
)
|
||||
|
||||
# In profile mode, the weights should remain unchanged
|
||||
|
||||
@@ -9,9 +9,11 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_eplb_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
@@ -213,12 +215,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
|
||||
for lidx in range(test_config.num_layers):
|
||||
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
|
||||
communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend="torch_nccl",
|
||||
expert_weights=rank_expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
shuffled_indices,
|
||||
rank_expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
communicator,
|
||||
)
|
||||
|
||||
num_local_experts = test_config.num_local_experts
|
||||
|
||||
@@ -10,11 +10,13 @@ import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
|
||||
from vllm.distributed.eplb.eplb_state import EplbLayerState
|
||||
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
get_dp_group,
|
||||
get_eplb_group,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
@@ -171,12 +173,20 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
|
||||
for lidx in range(test_config.num_layers):
|
||||
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in rank_expert_weights[0]]
|
||||
communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend="torch_nccl",
|
||||
expert_weights=rank_expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
rearrange_expert_weights_inplace(
|
||||
indices,
|
||||
shuffled_indices,
|
||||
rank_expert_weights,
|
||||
expert_buffer,
|
||||
ep_group,
|
||||
is_profile=False,
|
||||
communicator,
|
||||
)
|
||||
|
||||
num_global_experts = test_config.num_experts
|
||||
|
||||
@@ -6,6 +6,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
@@ -13,10 +14,13 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
|
||||
from vllm.entrypoints.pooling.typing import PoolingServeContext
|
||||
from vllm.entrypoints.serve.lora.protocol import (
|
||||
LoadLoRAAdapterRequest,
|
||||
UnloadLoRAAdapterRequest,
|
||||
)
|
||||
from vllm.exceptions import VLLMNotFoundError
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
@@ -130,3 +134,60 @@ async def test_unload_lora_adapter_not_found():
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "NotFoundError"
|
||||
assert response.error.code == HTTPStatus.NOT_FOUND
|
||||
|
||||
|
||||
class _ConcretePoolingServing(PoolingServingBase):
    """Minimal concrete subclass used only in these unit tests."""

    # Fixed prefix value used by this test-only subclass.
    request_id_prefix = "test"

    def get_io_processor(self, request):
        """Not exercised by these tests; always raises."""
        raise NotImplementedError

    def _build_response(self, ctx):
        """Not exercised by these tests; always raises."""
        raise NotImplementedError
|
||||
|
||||
|
||||
def _make_pooling_serving(lora_name: str) -> _ConcretePoolingServing:
    """Build a serving object whose ``models`` mock knows one LoRA adapter.

    The instance is created with ``object.__new__`` so no ``__init__`` runs;
    only the ``models`` attribute is populated.
    """
    models_stub = MagicMock()
    models_stub.lora_requests = {
        lora_name: LoRARequest(
            lora_name=lora_name, lora_int_id=1, lora_path="/path/to/lora"
        )
    }
    # Only MODEL_NAME is reported as the base model.
    models_stub.is_base_model.side_effect = lambda name: name == MODEL_NAME

    instance = object.__new__(_ConcretePoolingServing)
    instance.models = models_stub
    return instance
|
||||
|
||||
|
||||
def _make_pooling_ctx(model_name: str) -> PoolingServeContext:
    """Build a serve context whose mocked request asks for ``model_name``.

    The context's own ``model_name`` field is always ``MODEL_NAME``; only the
    request's ``model`` attribute varies.
    """
    # Keyword arguments to MagicMock are set as attributes on the mock.
    request_stub = MagicMock(model=model_name)
    ctx_fields = {
        "request": request_stub,
        "model_name": MODEL_NAME,
        "request_id": "test-id",
        "pooling_params": PoolingParams(),
    }
    return PoolingServeContext(**ctx_fields)
|
||||
|
||||
|
||||
def test_pooling_maybe_get_adapters_lora_name_sets_lora_request():
    """LoRA adapter name must populate ctx.lora_request without raising."""
    adapter_name = "bot-embed-lora"
    serving = _make_pooling_serving(adapter_name)
    ctx = _make_pooling_ctx(adapter_name)

    serving._maybe_get_adapters(ctx)

    resolved = ctx.lora_request
    assert resolved is not None
    assert resolved.lora_name == adapter_name
|
||||
|
||||
|
||||
def test_pooling_maybe_get_adapters_unknown_model_raises():
    """An unrecognised model name must still raise VLLMNotFoundError."""
    # The requested name is neither MODEL_NAME nor the registered adapter.
    unknown_ctx = _make_pooling_ctx("unknown-model")
    serving = _make_pooling_serving("some-lora")

    with pytest.raises(VLLMNotFoundError):
        serving._maybe_get_adapters(unknown_ctx)
|
||||
|
||||
+1
-1
@@ -6,7 +6,7 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# Model name constants used across tests
|
||||
MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct"
|
||||
+2
-1
@@ -22,7 +22,8 @@ import tempfile
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
from .conftest import (
|
||||
MODEL_NAME_SMOLLM,
|
||||
)
|
||||
+2
-1
@@ -4,7 +4,8 @@ import openai # use the official async_client for correctness check
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
from .conftest import MODEL_NAME_SMOLLM
|
||||
|
||||
|
||||
+2
-1
@@ -12,7 +12,8 @@ import tempfile
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
from .conftest import (
|
||||
MODEL_NAME_SMOLLM,
|
||||
)
|
||||
+2
-1
@@ -6,7 +6,8 @@ import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
from .conftest import (
|
||||
HEADER_SAGEMAKER_CLOSED_SESSION_ID,
|
||||
HEADER_SAGEMAKER_NEW_SESSION_ID,
|
||||
@@ -4,7 +4,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import StreamOptions
|
||||
from vllm.entrypoints.utils import (
|
||||
from vllm.entrypoints.serve.utils.api_utils import (
|
||||
get_max_tokens,
|
||||
sanitize_message,
|
||||
should_include_usage,
|
||||
+1
-1
@@ -6,7 +6,7 @@ from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai import fingerprint as fp
|
||||
from vllm.entrypoints.serve.utils import fingerprint as fp
|
||||
|
||||
|
||||
def _cfg(tp=1, pp=1, dp=1, ep=False, digest="a3b21f94deadbeef"):
|
||||
@@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from vllm.entrypoints.serve.utils.request_logger import RequestLogger
|
||||
|
||||
|
||||
def test_request_logger_log_outputs():
    """Test the new log_outputs functionality."""
    # Stand-in for the module logger so the ``info`` call can be inspected.
    logger_stub = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        # Basic, non-streaming output logging.
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-123",
            outputs="Hello, world!",
            output_token_ids=[1, 2, 3, 4],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        logger_stub.info.assert_called_once()
        logged = logger_stub.info.call_args.args
        fmt = logged[0]
        assert "Generated response %s%s" in fmt
        assert logged[1] == "test-123"
        assert logged[3] == "Hello, world!"
        assert logged[4] == [1, 2, 3, 4]
        assert logged[5] == "stop"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_streaming_delta():
    """Test log_outputs with streaming delta mode."""
    logger_stub = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        # Streaming + delta: one incremental chunk is logged.
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-456",
            outputs="Hello",
            output_token_ids=[1],
            finish_reason=None,
            is_streaming=True,
            delta=True,
        )

        logger_stub.info.assert_called_once()
        logged = logger_stub.info.call_args.args
        fmt = logged[0]
        assert "Generated response %s%s" in fmt
        assert logged[1] == "test-456"
        assert logged[2] == " (streaming delta)"
        assert logged[3] == "Hello"
        assert logged[4] == [1]
        assert logged[5] is None
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_streaming_complete():
    """Test log_outputs with streaming complete mode."""
    logger_stub = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        # Streaming without delta: the completed response is logged.
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-789",
            outputs="Complete response",
            output_token_ids=[1, 2, 3],
            finish_reason="length",
            is_streaming=True,
            delta=False,
        )

        logger_stub.info.assert_called_once()
        logged = logger_stub.info.call_args.args
        fmt = logged[0]
        assert "Generated response %s%s" in fmt
        assert logged[1] == "test-789"
        assert logged[2] == " (streaming complete)"
        assert logged[3] == "Complete response"
        assert logged[4] == [1, 2, 3]
        assert logged[5] == "length"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_with_truncation():
    """Test log_outputs respects max_log_len setting."""
    logger_stub = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        limit = 10
        req_logger = RequestLogger(max_log_len=limit)

        # Both the text and the token list are longer than the limit.
        text = "This is a very long output that should be truncated"
        token_ids = list(range(20))  # 20 tokens

        req_logger.log_outputs(
            request_id="test-truncate",
            outputs=text,
            output_token_ids=token_ids,
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        logger_stub.info.assert_called_once()
        logged = logger_stub.info.call_args.args

        # Output text is cut to its first 10 characters.
        shown_text = logged[3]
        assert shown_text == "This is a "
        assert len(shown_text) == 10

        # Token IDs are cut to the first 10 entries.
        shown_ids = logged[4]
        assert shown_ids == list(range(10))
        assert len(shown_ids) == 10
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_none_values():
    """Test log_outputs handles None values correctly."""
    logger_stub = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        # ``output_token_ids`` is deliberately None here.
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-none",
            outputs="Test output",
            output_token_ids=None,
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        logger_stub.info.assert_called_once()
        logged = logger_stub.info.call_args.args
        fmt = logged[0]
        assert "Generated response %s%s" in fmt
        assert logged[1] == "test-none"
        assert logged[3] == "Test output"
        assert logged[4] is None
        assert logged[5] == "stop"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_empty_output():
    """Test log_outputs handles empty output correctly."""
    logger_stub = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        # Empty text and empty token list with a small ``max_log_len``.
        RequestLogger(max_log_len=5).log_outputs(
            request_id="test-empty",
            outputs="",
            output_token_ids=[],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        logger_stub.info.assert_called_once()
        logged = logger_stub.info.call_args.args
        fmt = logged[0]
        assert "Generated response %s%s" in fmt
        assert logged[1] == "test-empty"
        assert logged[3] == ""
        assert logged[4] == []
        assert logged[5] == "stop"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_integration():
    """Test that log_outputs can be called alongside log_inputs."""
    logger_stub = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        req_logger = RequestLogger(max_log_len=None)

        # Log the inputs first, then the outputs, on the same logger.
        req_logger.log_inputs(
            request_id="test-integration",
            prompt="Test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_embeds=None,
            params=None,
            lora_request=None,
        )
        req_logger.log_outputs(
            request_id="test-integration",
            outputs="Test output",
            output_token_ids=[4, 5, 6],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        # One ``info`` call per method.
        assert logger_stub.info.call_count == 2

        recorded = logger_stub.info.call_args_list
        inputs_args = recorded[0].args
        outputs_args = recorded[1].args

        assert "Received request %s" in inputs_args[0]
        assert inputs_args[1] == "test-integration"

        assert "Generated response %s%s" in outputs_args[0]
        assert outputs_args[1] == "test-integration"
|
||||
|
||||
|
||||
def test_streaming_complete_logs_full_text_content():
    """Test that streaming complete logging includes
    full accumulated text, not just token count."""
    logger_stub = MagicMock()

    with patch("vllm.entrypoints.serve.utils.request_logger.logger", logger_stub):
        # Real text content rather than a token-count placeholder.
        expected_text = "This is a complete response from streaming"
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-streaming-full-text",
            outputs=expected_text,
            output_token_ids=None,
            finish_reason="streaming_complete",
            is_streaming=True,
            delta=False,
        )

        logger_stub.info.assert_called_once()
        logged = logger_stub.info.call_args.args

        # The logged output is the full text, not a token count format.
        shown_text = logged[3]
        assert shown_text == expected_text
        for unwanted in ("tokens>", "streaming_complete"):
            assert unwanted not in shown_text

        # Remaining positional arguments.
        assert logged[1] == "test-streaming-full-text"
        assert logged[2] == " (streaming complete)"
        assert logged[5] == "streaming_complete"
|
||||
+1
-1
@@ -7,7 +7,7 @@ from ssl import SSLContext
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.ssl import SSLCertRefresher
|
||||
from vllm.entrypoints.serve.utils.ssl import SSLCertRefresher
|
||||
|
||||
|
||||
class MockSSLContext(SSLContext):
|
||||
@@ -2,4 +2,5 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
||||
accuracy_threshold: 0.72
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
rocm_request_timeout_seconds: 1800
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@@ -2,4 +2,5 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
|
||||
accuracy_threshold: 0.45
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
rocm_request_timeout_seconds: 1800
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@@ -106,7 +106,7 @@ async def call_vllm_api(
|
||||
completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
|
||||
return text, completion_tokens
|
||||
except Exception as e:
|
||||
print(f"Error calling vLLM API: {e}")
|
||||
print(f"Error calling vLLM API ({type(e).__name__}): {e}")
|
||||
return "", 0
|
||||
|
||||
|
||||
@@ -177,6 +177,7 @@ def evaluate_gsm8k(
|
||||
port: int = 8000,
|
||||
temperature: float = 0.0,
|
||||
seed: int | None = 42,
|
||||
request_timeout_seconds: float = 600,
|
||||
) -> dict[str, float | int]:
|
||||
"""
|
||||
Evaluate GSM8K accuracy using vLLM serve endpoint.
|
||||
@@ -205,9 +206,8 @@ def evaluate_gsm8k(
|
||||
output_tokens[i] = tokens
|
||||
return answer, tokens
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=600)
|
||||
) as session:
|
||||
timeout = aiohttp.ClientTimeout(total=request_timeout_seconds)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
tasks = [get_answer(session, i) for i in range(num_questions)]
|
||||
await tqdm.gather(*tasks, desc="Evaluating")
|
||||
|
||||
|
||||
@@ -39,11 +39,18 @@ def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
|
||||
host = f"http://{host}"
|
||||
|
||||
# Run GSM8K evaluation
|
||||
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
|
||||
if current_platform.is_rocm():
|
||||
request_timeout_seconds = eval_config.get(
|
||||
"rocm_request_timeout_seconds", request_timeout_seconds
|
||||
)
|
||||
|
||||
results = evaluate_gsm8k(
|
||||
num_questions=eval_config["num_questions"],
|
||||
num_shots=eval_config["num_fewshot"],
|
||||
host=host,
|
||||
port=port,
|
||||
request_timeout_seconds=request_timeout_seconds,
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -90,6 +97,12 @@ def test_gsm8k_correctness(config_filename):
|
||||
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
|
||||
print(f"Number of questions: {eval_config['num_questions']}")
|
||||
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
|
||||
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
|
||||
if current_platform.is_rocm():
|
||||
request_timeout_seconds = eval_config.get(
|
||||
"rocm_request_timeout_seconds", request_timeout_seconds
|
||||
)
|
||||
print(f"Request timeout: {request_timeout_seconds}s")
|
||||
print(f"Server args: {' '.join(server_args)}")
|
||||
print(f"Environment variables: {env_dict}")
|
||||
|
||||
|
||||
@@ -0,0 +1,339 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""ROCm kernel correctness tests for AITER unified attention.
|
||||
|
||||
Compares ``aiter.ops.triton.unified_attention`` against ``ref_paged_attn`` under
|
||||
decode, prefill, and mixed batches with varied shapes.
|
||||
"""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.attention.test_triton_unified_attention import ref_paged_attn
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
# True unless running on ROCm and ``on_mi3xx()`` reports a match.
if current_platform.is_rocm():
    from vllm.platforms.rocm import on_mi3xx

    _SKIP_NON_MI3XX = not on_mi3xx()
else:
    _SKIP_NON_MI3XX = True

# Module-wide skip conditions applied to every test in this file.
pytestmark = [
    pytest.mark.skipif(not current_platform.is_rocm(), reason="ROCm-specific tests"),
    pytest.mark.skipif(_SKIP_NON_MI3XX, reason="MI300/MI350 ROCm only"),
]

# Shape / dtype constants for the test cases.
NUM_Q_HEADS = 8
NUM_KV_HEADS = 8
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 64]
DTYPES = [torch.bfloat16, torch.float16]
FP8_DTYPE = current_platform.fp8_dtype()

# (query_len, kv_len) per sequence
MIXED_SEQ_LENS = [
    [(1, 128), (5, 18), (129, 463)],
    [(10, 256), (5, 64), (32, 128)],
    [(1, 1024), (5, 18), (129, 1328)],
]
DECODE_SEQ_LENS = [
    [(1, 128), (1, 256), (1, 384), (1, 512)],
    [(1, 1024), (1, 1536), (1, 2048)],
]
PREFILL_SEQ_LENS = [
    [(256, 256), (128, 512)],
    [(64, 128), (32, 256), (16, 512)],
    [(256, 1024), (128, 2048)],
]

# Comparison tolerances: default and FP8.
DEFAULT_ATOL = 1.5e-2
DEFAULT_RTOL = 1e-2
FP8_ATOL = 1.5e-1
FP8_RTOL = 1.5e-1
# Non-unity scale so q_descale handling is exercised explicitly.
Q_SCALE = 0.75
K_SCALE = 0.5
V_SCALE = 0.25

Fp8Variant = Literal["fp8_kv", "fp8_query", "fp8_query_kv"]

# One pytest parameter per FP8 variant, with the variant name as its id.
FP8_VARIANTS = [
    pytest.param(variant, id=variant)
    for variant in ("fp8_kv", "fp8_query", "fp8_query_kv")
]

# Subset of the sequence-length cases above reused for the FP8 variants.
FP8_SEQ_LENS = [
    MIXED_SEQ_LENS[0],
    DECODE_SEQ_LENS[0],
    DECODE_SEQ_LENS[1],
    PREFILL_SEQ_LENS[0],
    PREFILL_SEQ_LENS[2],
]
|
||||
|
||||
|
||||
def _require_aiter() -> None:
    """Skip the calling test unless vLLM reports AITER as found and supported."""
    # Imported inside the helper, not at module import time.
    from vllm._aiter_ops import is_aiter_found_and_supported

    if is_aiter_found_and_supported():
        return
    pytest.skip("aiter is required on supported ROCm hardware for this test")
|
||||
|
||||
|
||||
def _make_case(
    *,
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
    num_blocks: int = 2048,
    kv_cache_dtype: torch.dtype | None = None,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
    q_dtype: torch.dtype | None = None,
    q_scale: float = Q_SCALE,
) -> dict[str, Any]:
    """Build the random inputs for one attention comparison.

    Args:
        seq_lens: ``(query_len, kv_len)`` pairs, one per sequence.
        head_size: Size of the last dimension of the query and both caches.
        block_size: Number of slots per KV-cache block.
        dtype: Dtype of the query and, when ``kv_cache_dtype`` is None, of the
            KV cache.
        num_blocks: Number of blocks allocated in each cache.
        kv_cache_dtype: When given, the caches are drawn in the default dtype,
            clamped to ``[-1, 1]`` and cast to this dtype.
        k_scale: Fill value of the ``k_descale`` tensor.
        v_scale: Fill value of the ``v_descale`` tensor.
        q_dtype: When given, the kernel query is ``query / q_scale`` cast to
            this dtype and ``q_descale`` holds ``q_scale``.
        q_scale: Scale applied to the query when ``q_dtype`` is given.

    Returns:
        A dict with the tensors and scalars read by the kernel runner and the
        reference helper.
    """
    # Also makes "cuda" the default device for every later factory call.
    torch.set_default_device("cuda")

    query_lens = [pair[0] for pair in seq_lens]
    kv_lens = [pair[1] for pair in seq_lens]
    cache_shape = (num_blocks, block_size, NUM_KV_HEADS, head_size)

    # Random draws happen in a fixed order: query, key cache, value cache,
    # then the block table further down.
    query = torch.randn(sum(query_lens), NUM_Q_HEADS, head_size, dtype=dtype)
    if kv_cache_dtype is not None:
        key_cache = torch.randn(*cache_shape).clamp(-1.0, 1.0).to(kv_cache_dtype)
        value_cache = torch.randn(*cache_shape).clamp(-1.0, 1.0).to(kv_cache_dtype)
    else:
        key_cache = torch.randn(*cache_shape, dtype=dtype)
        value_cache = torch.randn_like(key_cache)

    # Cumulative query offsets with a leading zero.
    query_offsets = torch.tensor([0] + query_lens, dtype=torch.int32)
    cu_seqlens_q = query_offsets.cumsum(dim=0, dtype=torch.int32)
    seq_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    longest_kv = max(kv_lens)
    # Ceiling division: enough blocks to hold the longest KV sequence.
    blocks_per_seq = (longest_kv + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (len(seq_lens), blocks_per_seq), dtype=torch.int32
    )

    descale_shape = (len(seq_lens), NUM_KV_HEADS)
    k_descale = torch.full(descale_shape, k_scale, dtype=torch.float32, device="cuda")
    v_descale = torch.full(descale_shape, v_scale, dtype=torch.float32, device="cuda")

    if q_dtype is None:
        kernel_query, q_descale = query, None
    else:
        q_descale = torch.tensor(q_scale, dtype=torch.float32, device="cuda")
        kernel_query = (query / q_scale).to(q_dtype)

    return {
        "query": query,
        "kernel_query": kernel_query,
        "key_cache": key_cache,
        "value_cache": value_cache,
        "block_tables": block_tables,
        "query_lens": query_lens,
        "kv_lens": kv_lens,
        "seq_lens_tensor": seq_lens_tensor,
        "cu_seqlens_q": cu_seqlens_q,
        "q_descale": q_descale,
        "k_descale": k_descale,
        "v_descale": v_descale,
        "scale": head_size**-0.5,
        "max_query_len": max(query_lens),
        "max_kv_len": longest_kv,
        "query_dtype": dtype,
        "k_scale": k_scale,
        "v_scale": v_scale,
    }
|
||||
|
||||
|
||||
def _make_fp8_case(
    *,
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    variant: Fp8Variant,
) -> dict[str, Any]:
    """Build a bf16 case with FP8 applied to the KV cache, the query, or both.

    ``variant`` selects which parts are FP8: the KV cache for ``"fp8_kv"`` and
    ``"fp8_query_kv"``, the query for ``"fp8_query"`` and ``"fp8_query_kv"``.
    The result is forwarded from ``_make_case``.
    """
    fp8_kv = variant in ("fp8_kv", "fp8_query_kv")
    fp8_query = variant in ("fp8_query", "fp8_query_kv")

    # Without an FP8 cache the KV scales stay at 1.0.
    if fp8_kv:
        kv_cache_dtype, k_scale, v_scale = FP8_DTYPE, K_SCALE, V_SCALE
    else:
        kv_cache_dtype, k_scale, v_scale = None, 1.0, 1.0
    q_dtype = FP8_DTYPE if fp8_query else None

    return _make_case(
        seq_lens=seq_lens,
        head_size=head_size,
        block_size=block_size,
        dtype=torch.bfloat16,
        kv_cache_dtype=kv_cache_dtype,
        k_scale=k_scale,
        v_scale=v_scale,
        q_dtype=q_dtype,
    )
|
||||
|
||||
|
||||
def _run_aiter_unified_attention(case: dict[str, Any]) -> torch.Tensor:
    """Run AITER's Triton ``unified_attention`` on ``case`` and return its output.

    The output buffer is allocated with the shape and dtype of ``case["query"]``
    and handed to the kernel through its ``out`` argument.
    """
    from aiter.ops.triton.unified_attention import unified_attention

    # Kernel writes high-precision output even when Q is FP8 (matches vLLM usage).
    out = torch.empty_like(case["query"])
    kernel_kwargs = dict(
        q=case["kernel_query"],
        k=case["key_cache"],
        v=case["value_cache"],
        out=out,
        cu_seqlens_q=case["cu_seqlens_q"],
        max_seqlen_q=case["max_query_len"],
        seqused_k=case["seq_lens_tensor"],
        max_seqlen_k=case["max_kv_len"],
        softmax_scale=case["scale"],
        causal=True,
        alibi_slopes=None,
        window_size=(-1, -1),
        block_table=case["block_tables"],
        softcap=0,
        q_descale=case["q_descale"],
        k_descale=case["k_descale"],
        v_descale=case["v_descale"],
        sinks=None,
        output_scale=None,
    )
    unified_attention(**kernel_kwargs)
    return out
|
||||
|
||||
|
||||
def _ref_output(case: dict[str, Any]) -> torch.Tensor:
    """Compute the reference output for ``case`` with ``ref_paged_attn``.

    When the key cache dtype differs from the query dtype, both caches are cast
    to the query dtype and multiplied by their ``k_scale`` / ``v_scale`` first.
    """
    ref_dtype = case["query_dtype"]
    keys, values = case["key_cache"], case["value_cache"]
    if keys.dtype != ref_dtype:
        keys = keys.to(ref_dtype) * case["k_scale"]
        values = values.to(ref_dtype) * case["v_scale"]

    return ref_paged_attn(
        query=case["query"],
        key_cache=keys,
        value_cache=values,
        query_lens=case["query_lens"],
        kv_lens=case["kv_lens"],
        block_tables=case["block_tables"],
        scale=case["scale"],
    )
|
||||
|
||||
|
||||
def _assert_matches_reference(
    case: dict[str, Any],
    *,
    atol: float = DEFAULT_ATOL,
    rtol: float = DEFAULT_RTOL,
) -> None:
    """Assert that the AITER kernel output is close to the reference output.

    Args:
        case: Inputs produced by ``_make_case``.
        atol: Absolute tolerance passed to ``torch.testing.assert_close``.
        rtol: Relative tolerance passed to ``torch.testing.assert_close``.
    """
    # Arguments evaluate left to right, so the kernel runs before the reference.
    torch.testing.assert_close(
        _run_aiter_unified_attention(case),
        _ref_output(case),
        atol=atol,
        rtol=rtol,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", MIXED_SEQ_LENS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_aiter_unified_attn_mixed_batch(
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
) -> None:
    """Batches from MIXED_SEQ_LENS, compared against the reference in native dtypes."""
    _require_aiter()
    set_random_seed(0)

    _assert_matches_reference(
        _make_case(
            seq_lens=seq_lens,
            head_size=head_size,
            block_size=block_size,
            dtype=dtype,
        )
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", DECODE_SEQ_LENS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_aiter_unified_attn_decode(
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
) -> None:
    """Decode batches (query length 1 for every sequence), bf16 only."""
    _require_aiter()
    set_random_seed(0)

    _assert_matches_reference(
        _make_case(
            seq_lens=seq_lens,
            head_size=head_size,
            block_size=block_size,
            dtype=dtype,
        )
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", PREFILL_SEQ_LENS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_aiter_unified_attn_prefill(
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
) -> None:
    """Prefill-only batches (every query length above 1), run in bf16."""
    _require_aiter()
    set_random_seed(0)

    _assert_matches_reference(
        _make_case(
            seq_lens=seq_lens,
            head_size=head_size,
            block_size=block_size,
            dtype=torch.bfloat16,
        )
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.supports_fp8(),
    reason="FP8 not supported on this hardware",
)
@pytest.mark.parametrize("variant", FP8_VARIANTS)
@pytest.mark.parametrize("seq_lens", FP8_SEQ_LENS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16, 64])
@torch.inference_mode()
def test_aiter_unified_attn_fp8(
    variant: Fp8Variant,
    seq_lens: list[tuple[int, int]],
    head_size: int,
    block_size: int,
) -> None:
    """FP8 KV cache, FP8 query, or both, checked with the FP8 tolerances."""
    _require_aiter()
    set_random_seed(0)

    fp8_case = _make_fp8_case(
        seq_lens=seq_lens,
        head_size=head_size,
        block_size=block_size,
        variant=variant,
    )
    _assert_matches_reference(fp8_case, atol=FP8_ATOL, rtol=FP8_RTOL)
|
||||
@@ -205,7 +205,10 @@ def run_with_expert_maps(
|
||||
w2 = kwargs["w2"]
|
||||
a = kwargs["hidden_states"]
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=w2.shape[0],
|
||||
max_num_tokens=kwargs.get("hidden_states").shape[0],
|
||||
experts_per_token=kwargs.get("topk_ids").shape[1],
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_local_experts,
|
||||
hidden_dim=w2.shape[1],
|
||||
intermediate_size_per_partition=w2.shape[2],
|
||||
in_dtype=a.dtype,
|
||||
@@ -258,23 +261,27 @@ def run_8_bit(
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
|
||||
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
||||
|
||||
kwargs = {
|
||||
"hidden_states": moe_tensors.a,
|
||||
"w1": moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr]
|
||||
"global_num_experts": num_experts,
|
||||
"activation": MoEActivation.SILU,
|
||||
"expert_map": None,
|
||||
"apply_router_weight_on_input": False,
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
|
||||
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
||||
if not with_ep:
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
|
||||
max_num_tokens=moe_tensors.a.shape[0],
|
||||
experts_per_token=topk_ids.shape[1],
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_local_experts,
|
||||
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
|
||||
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
|
||||
in_dtype=moe_tensors.a.dtype,
|
||||
@@ -581,6 +588,7 @@ def test_run_cutlass_moe_fp8(
|
||||
per_out_channel,
|
||||
False,
|
||||
topk_weights,
|
||||
None,
|
||||
)
|
||||
|
||||
workspace13.random_()
|
||||
|
||||
@@ -1287,10 +1287,12 @@ def _test_body_eplb(
|
||||
|
||||
expert_weights = [list(eplb_moe_layer.get_expert_weights())]
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||
communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend=vllm_config.parallel_config.eplb_config.communicator,
|
||||
expert_weights=expert_weights[0],
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
# Rearrange expert weights across EP ranks
|
||||
@@ -1298,6 +1300,7 @@ def _test_body_eplb(
|
||||
old_global_expert_indices=initial_indices.unsqueeze(0),
|
||||
new_global_expert_indices=shuffled_indices.unsqueeze(0),
|
||||
expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
ep_group=cpu_group,
|
||||
communicator=communicator,
|
||||
)
|
||||
|
||||
@@ -49,10 +49,12 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
def make_dummy_moe_config(
|
||||
num_experts: int = 1,
|
||||
num_local_experts: int | None = None,
|
||||
experts_per_token: int = 1,
|
||||
hidden_dim: int = 1,
|
||||
intermediate_size_per_partition: int = 1,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
max_num_tokens: int = 512,
|
||||
) -> FusedMoEConfig:
|
||||
"""
|
||||
This is a dummy config for the mk constructor interface
|
||||
@@ -66,14 +68,16 @@ def make_dummy_moe_config(
|
||||
experts_per_token=experts_per_token,
|
||||
hidden_dim=hidden_dim,
|
||||
intermediate_size_per_partition=intermediate_size_per_partition,
|
||||
num_local_experts=num_experts,
|
||||
num_local_experts=num_local_experts
|
||||
if num_local_experts is not None
|
||||
else num_experts,
|
||||
num_logical_experts=num_experts,
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
activation=MoEActivation.SILU,
|
||||
in_dtype=in_dtype,
|
||||
device="cuda",
|
||||
routing_method=RoutingMethodType.TopK,
|
||||
max_num_tokens=512,
|
||||
max_num_tokens=max_num_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the Triton dequant-gather kernel used by
|
||||
``CompressedTensorsEmbeddingWNA16Int`` (quantized embedding lookup)."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from compressed_tensors.compressors.pack_quantized.helpers import unpack_from_int32
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_embedding import ( # noqa: E501
|
||||
_dequant_gather_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def _dequant_gather_torch(
    ids: torch.Tensor,
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    hidden: int,
    num_bits: int,
) -> torch.Tensor:
    """Reference path: gather rows by id, unpack the int32 packing, dequantize.

    The scale rows are gathered with the same ids. A single scale column is
    broadcast across the row; otherwise each scale column multiplies one
    contiguous group of ``hidden // num_groups`` values.
    """
    num_rows = ids.shape[0]
    unpacked = unpack_from_int32(
        weight_packed[ids], num_bits, torch.Size([num_rows, hidden])
    )
    scales = weight_scale[ids]
    values = unpacked.to(scales.dtype)

    num_groups = scales.shape[1]
    if num_groups != 1:
        grouped = values.view(num_rows, num_groups, hidden // num_groups)
        return (grouped * scales.unsqueeze(-1)).view(num_rows, hidden)
    return values * scales
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="Triton dequant kernel requires CUDA"
)
@pytest.mark.parametrize("num_bits", [2, 4, 8])
@pytest.mark.parametrize("group_size", [0, 256])  # 0 -> channel
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("num_ids", [1, 17, 4096])
def test_dequant_gather(num_bits, group_size, dtype, num_ids):
    """Compare the Triton dequant-gather output with the PyTorch reference."""
    torch.manual_seed(0)
    device = "cuda"
    vocab = 1000
    hidden = 2048
    packed_cols = hidden // (32 // num_bits)

    # Random full-range int32 packed weights (covers the sign bit -> exercises the
    # arithmetic-shift + mask unpack path).
    weight_packed = torch.randint(
        -(2**31),
        2**31,
        (vocab, packed_cols),
        dtype=torch.int32,
        device=device,
    )

    # group_size == 0 means a single scale column per row.
    num_groups = hidden // group_size if group_size != 0 else 1
    weight_scale = torch.rand(vocab, num_groups, dtype=dtype, device=device) + 0.01

    ids = torch.randint(0, vocab, (num_ids,), dtype=torch.long, device=device)

    gather_args = (ids, weight_packed, weight_scale, hidden, num_bits)
    out = _dequant_gather_triton(*gather_args)
    ref = _dequant_gather_torch(*gather_args)

    assert out.shape == (num_ids, hidden)
    assert out.dtype == dtype
    torch.testing.assert_close(out, ref)
|
||||
@@ -468,6 +468,7 @@ def _reference_kv_compress_norm_rope(
|
||||
use_fp4: bool = False,
|
||||
rms_eps: float = 1e-6,
|
||||
fp8_max: float = 448.0,
|
||||
return_full_cache: bool = False,
|
||||
):
|
||||
"""Compress → RMSNorm → GPT-J RoPE → quantize.
|
||||
|
||||
@@ -521,6 +522,12 @@ def _reference_kv_compress_norm_rope(
|
||||
results.append(torch.cat([nope, rope]).to(state_cache.dtype))
|
||||
result = torch.stack(results)
|
||||
|
||||
if return_full_cache:
|
||||
# Contiguous 512-wide bf16 row (nope unrotated + rope rotated), matching
|
||||
# the FlashInfer full-cache layout before any per-tensor fp8 quant. The
|
||||
# kernel rounds the fp32 result to bf16 once at the store.
|
||||
return result.to(torch.bfloat16)
|
||||
|
||||
if use_fp4:
|
||||
return quantize_to_mxfp4(result)
|
||||
else:
|
||||
@@ -667,3 +674,145 @@ def test_fused_kv_insert_indexer(num_tokens: int, kv_block_size: int, use_fp4: b
|
||||
assert torch.equal(actual_scale, scale[i : i + 1]), (
|
||||
f"token {i}: scale {actual_scale.item()} != {scale[i].item()}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("compress_ratio", [4, 128])
@pytest.mark.parametrize("store_fp8", [False, True])
def test_cutedsl_full_cache_store(compress_ratio: int, store_fp8: bool):
    """CuTeDSL compressor full-cache (FlashInfer) store parity for head=512.

    Runs the fused kernel when ``compress_ratio == 4`` and the split kernel
    otherwise, storing contiguous bf16 or per-tensor fp8 rows into ``k_cache``,
    then compares the stored rows with the PyTorch reference.
    """
    pytest.importorskip("cutlass")
    from vllm.models.deepseek_v4.nvidia.ops.sparse_attn_compress_cutedsl import (
        fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
        split_kv_compress_norm_rope_insert_sparse_attn_cutedsl,
    )

    head_dim = 512
    rope_dim = 64
    rms_eps = 1e-6
    fp8_max = 448.0
    # C128 compress (Block8 kernel) requires state-cache block_size=8; C4 uses 16.
    state_block_size = 8 if compress_ratio == 128 else 16
    kv_block_size = 64
    device = "cuda"
    torch.manual_seed(7)

    overlap = 1 if compress_ratio == 4 else 0
    coff = 1 + overlap
    num_tokens = 8

    num_pages = (compress_ratio * num_tokens - 1) // state_block_size + 2
    # The production CompressorStateCache is fp32.
    state_cache = torch.randn(
        num_pages,
        state_block_size,
        2 * coff * head_dim,
        dtype=torch.float32,
        device=device,
    )
    block_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0)
    token_to_req = torch.zeros(num_tokens, dtype=torch.int32, device=device)
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    # One position per token: the last position of each compress window.
    positions = torch.arange(
        compress_ratio - 1,
        compress_ratio * num_tokens,
        compress_ratio,
        dtype=torch.int64,
        device=device,
    )
    rms_weight = torch.randn(head_dim, dtype=torch.bfloat16, device=device)
    cos_sin_cache = torch.randn(
        compress_ratio * num_tokens, rope_dim, dtype=torch.float32, device=device
    )

    cache_dtype = torch.float8_e4m3fn if store_fp8 else torch.bfloat16
    kv_n_blocks = (num_tokens + kv_block_size - 1) // kv_block_size + 1
    k_cache = torch.zeros(
        kv_n_blocks, kv_block_size, head_dim, dtype=cache_dtype, device=device
    )
    fp8_scale = torch.tensor(
        [0.5 if store_fp8 else 1.0], dtype=torch.float32, device=device
    )

    # Both kernels take the same leading and trailing positional arguments; the
    # split kernel additionally receives a scratch buffer between them.
    leading_args = (
        state_cache,
        token_to_req,
        positions,
        slot_mapping,
        block_table,
        state_block_size,
    )
    trailing_args = (
        rms_weight,
        rms_eps,
        cos_sin_cache,
        k_cache,
        slot_mapping,
        kv_block_size,
        k_cache.stride(0),
    )
    store_kwargs = dict(
        head_size=head_dim,
        state_width=coff * head_dim,
        rope_head_dim=rope_dim,
        fp8_max=fp8_max,
        quant_block=64,
        token_stride=576,
        scale_dim=8,
        compress_ratio=compress_ratio,
        overlap=bool(overlap),
        store_full_kv=True,
        store_full_fp8=store_fp8,
        fp8_scale=fp8_scale,
    )

    if compress_ratio == 4:
        fused_kv_compress_norm_rope_insert_sparse_attn_cutedsl(
            *leading_args, *trailing_args, **store_kwargs
        )
    else:
        compressed_kv = torch.empty(
            (num_tokens, head_dim), dtype=torch.float32, device=device
        )
        split_kv_compress_norm_rope_insert_sparse_attn_cutedsl(
            *leading_args, compressed_kv, *trailing_args, **store_kwargs
        )

    ref = _reference_kv_compress_norm_rope(
        state_cache,
        block_table,
        positions,
        rms_weight,
        cos_sin_cache,
        compress_ratio,
        overlap,
        rms_eps=rms_eps,
        return_full_cache=True,
    )  # [num_tokens, head_dim] bf16

    # Token i sits at slot i, i.e. (block, offset) = divmod(i, kv_block_size).
    stored = torch.stack(
        [k_cache[divmod(i, kv_block_size)] for i in range(num_tokens)]
    )
    if not store_fp8:
        torch.testing.assert_close(stored.float(), ref.float(), rtol=3e-2, atol=3e-2)
        return

    ref_fp8 = torch.clamp(ref.float() / fp8_scale, -fp8_max, fp8_max).to(
        torch.float8_e4m3fn
    )
    torch.testing.assert_close(stored.float(), ref_fp8.float(), rtol=0.0, atol=0.3)
|
||||
|
||||
@@ -67,7 +67,7 @@ def apply_rope_gptj_last_k(
|
||||
head_dim = x.shape[-1]
|
||||
nope_dim = head_dim - rope_dim
|
||||
|
||||
cs = cos_sin_cache[positions].to(torch.float32)
|
||||
cs = cos_sin_cache[positions.long()].to(torch.float32)
|
||||
cos = cs[..., :half]
|
||||
sin = cs[..., half:]
|
||||
|
||||
@@ -114,6 +114,18 @@ def _op_available() -> bool:
|
||||
return hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert")
|
||||
|
||||
|
||||
def _full_cache_fp8_op_available() -> bool:
|
||||
return hasattr(
|
||||
torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert"
|
||||
)
|
||||
|
||||
|
||||
def _full_cache_bf16_op_available() -> bool:
|
||||
return hasattr(
|
||||
torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert"
|
||||
)
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or not _op_available(),
|
||||
reason="CUDA not available or fused DeepseekV4 op not built in",
|
||||
@@ -415,3 +427,238 @@ def test_combined_q_and_kv(
|
||||
"padded head slots must be exact zero"
|
||||
)
|
||||
torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
|
||||
|
||||
|
||||
# ── Full-cache (FlashInfer) path parity ──────────────────────────────────────
|
||||
|
||||
|
||||
def _call_full_cache_fp8_fused(
    q,
    kv,
    q_fp8,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    fp8_scale,
    q_fp8_scale_inv,
    eps,
    bs,
):
    """Invoke the fused full-cache FP8 insert op with the given arguments.

    ``positions`` is converted with ``.long()`` before the call; every other
    argument is forwarded unchanged. Nothing is returned.
    """
    fused_op = torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert
    positions_i64 = positions.long()
    fused_op(
        q,
        kv,
        q_fp8,
        k_cache,
        slot_mapping,
        positions_i64,
        cos_sin_cache,
        fp8_scale,
        q_fp8_scale_inv,
        eps,
        bs,
    )
|
||||
|
||||
|
||||
def _call_full_cache_bf16_fused(
    q,
    kv,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    bs,
):
    """Invoke the fused full-cache BF16 insert op with the given arguments.

    ``positions`` is converted with ``.long()`` before the call; every other
    argument is forwarded unchanged. Nothing is returned.
    """
    fused_op = torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert
    positions_i64 = positions.long()
    fused_op(
        q,
        kv,
        k_cache,
        slot_mapping,
        positions_i64,
        cos_sin_cache,
        eps,
        bs,
    )
|
||||
|
||||
|
||||
def _fp8_full_cache_reference(
    q,
    kv,
    k_cache,
    q_fp8,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
    fp8_scale,
    q_fp8_scale_inv,
):
    """PyTorch reference for the full-cache per-tensor FP8 path.

    Writes its results in place: ``q_fp8`` receives the normalized, rotated
    query multiplied by ``q_fp8_scale_inv``, and ``k_cache`` receives the
    rotated ``kv`` divided by ``fp8_scale``. Both are clamped to ``±FP8_MAX``
    and cast to ``float8_e4m3fn``. Tokens with a negative slot are not stored.
    """
    q_rot = apply_rope_gptj_last_k(
        rmsnorm_no_weight(q, eps), positions, cos_sin_cache
    )
    q_scaled = torch.clamp(q_rot.float() * q_fp8_scale_inv, -FP8_MAX, FP8_MAX)
    q_fp8.copy_(q_scaled.to(torch.float8_e4m3fn))

    kv_rot = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
    keep = slot_mapping >= 0
    live_slots = slot_mapping[keep]
    kv_scaled = torch.clamp(kv_rot[keep].float() / fp8_scale, -FP8_MAX, FP8_MAX)
    # A flat slot index splits into (block, offset within block).
    k_cache[live_slots // block_size, live_slots % block_size] = kv_scaled.to(
        torch.float8_e4m3fn
    )
|
||||
|
||||
|
||||
def _bf16_full_cache_reference(
    q,
    kv,
    k_cache,
    slot_mapping,
    positions,
    cos_sin_cache,
    eps,
    block_size,
):
    """PyTorch reference for the full-cache BF16 path.

    Stores the rotated ``kv`` rows into ``k_cache`` in place (tokens with a
    negative slot are not stored) and returns the normalized, rotated query
    cast back to ``q.dtype``.
    """
    # Kernel keeps RMSNorm+RoPE in fp32 and rounds to bf16 once at the store.
    q_rot = apply_rope_gptj_last_k(
        rmsnorm_no_weight(q, eps), positions, cos_sin_cache
    ).to(q.dtype)

    kv_rot = apply_rope_gptj_last_k(kv, positions, cos_sin_cache)
    keep = slot_mapping >= 0
    live_slots = slot_mapping[keep]
    # A flat slot index splits into (block, offset within block).
    k_cache[live_slots // block_size, live_slots % block_size] = kv_rot[keep]
    return q_rot
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not _full_cache_fp8_op_available(),
    reason="full-cache per-tensor FP8 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_per_tensor_fp8_matches_reference(
    num_tokens: int,
    n_heads: int,
    positions_dtype: torch.dtype,
):
    """Fused full-cache FP8 op vs. the PyTorch reference for Q and the K cache."""
    torch.manual_seed(4)
    device = "cuda"
    in_dtype = torch.bfloat16
    fp8 = torch.float8_e4m3fn
    eps = 1e-6
    block_size = 16
    max_pos = 4096

    q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=in_dtype, device=device)
    kv = torch.randn(num_tokens, HEAD_DIM, dtype=in_dtype, device=device)
    positions = torch.arange(num_tokens, dtype=positions_dtype, device=device)
    cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)

    # Ceiling division plus one extra block.
    num_blocks = (num_tokens + block_size - 1) // block_size + 1
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    fp8_scale = torch.tensor([1.0], dtype=torch.float32, device=device)
    q_fp8_scale_inv = torch.tensor([1.0], dtype=torch.float32, device=device)

    # Separate output buffers for the reference and the fused op.
    q_fp8_ref = torch.empty_like(q, dtype=fp8)
    q_fp8_fused = torch.empty_like(q, dtype=fp8)
    k_cache_ref = torch.zeros(
        num_blocks, block_size, HEAD_DIM, dtype=fp8, device=device
    )
    k_cache_fused = torch.zeros_like(k_cache_ref)

    _fp8_full_cache_reference(
        q,
        kv,
        k_cache_ref,
        q_fp8_ref,
        slot_mapping,
        positions,
        cos_sin_cache,
        eps,
        block_size,
        fp8_scale,
        q_fp8_scale_inv,
    )
    # The fused op receives a copy of q.
    _call_full_cache_fp8_fused(
        q.clone(),
        kv,
        q_fp8_fused,
        k_cache_fused,
        slot_mapping,
        positions,
        cos_sin_cache,
        fp8_scale,
        q_fp8_scale_inv,
        eps,
        block_size,
    )

    for fused, expected in ((q_fp8_fused, q_fp8_ref), (k_cache_fused, k_cache_ref)):
        torch.testing.assert_close(fused.float(), expected.float(), rtol=0, atol=0.25)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not _full_cache_bf16_op_available(),
    reason="full-cache BF16 DeepseekV4 op not built in",
)
@pytest.mark.parametrize("num_tokens", [4, 17])
@pytest.mark.parametrize("n_heads", [8, 17])
@pytest.mark.parametrize("positions_dtype", [torch.int32, torch.int64])
def test_full_cache_bf16_matches_reference(
    num_tokens: int,
    n_heads: int,
    positions_dtype: torch.dtype,
):
    """Fused full-cache BF16 op vs. the PyTorch reference for Q and the K cache."""
    torch.manual_seed(5)
    device = "cuda"
    in_dtype = torch.bfloat16
    eps = 1e-6
    block_size = 16
    max_pos = 4096

    q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=in_dtype, device=device)
    kv = torch.randn(num_tokens, HEAD_DIM, dtype=in_dtype, device=device)
    positions = torch.arange(num_tokens, dtype=positions_dtype, device=device)
    cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device)

    # Ceiling division plus one extra block.
    num_blocks = (num_tokens + block_size - 1) // block_size + 1
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)

    # The fused op is handed its own copy of q; the reference gets the original.
    q_fused = q.clone()
    k_cache_ref = torch.zeros(
        num_blocks, block_size, HEAD_DIM, dtype=torch.bfloat16, device=device
    )
    k_cache_fused = torch.zeros_like(k_cache_ref)

    shared_tail = (slot_mapping, positions, cos_sin_cache, eps, block_size)
    q_ref = _bf16_full_cache_reference(q, kv, k_cache_ref, *shared_tail)
    _call_full_cache_bf16_fused(q_fused, kv, k_cache_fused, *shared_tail)

    torch.testing.assert_close(q_fused, q_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0)
|
||||
|
||||
@@ -0,0 +1,481 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for sequence and token pooler head classes."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
|
||||
from vllm.model_executor.layers.pooler.seqwise.heads import (
|
||||
ClassifierPoolerHead,
|
||||
EmbeddingPoolerHead,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise.heads import (
|
||||
TokenClassifierPoolerHead,
|
||||
TokenEmbeddingPoolerHead,
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
|
||||
|
||||
_HIDDEN = 16
|
||||
_BATCH = 3
|
||||
|
||||
|
||||
def _make_params(
    n: int,
    *,
    task: str = "embed",
    dimensions: int | None = None,
    use_activation: bool | None = None,
) -> list[PoolingParams]:
    """Return ``n`` separately constructed ``PoolingParams`` with the same settings."""
    params: list[PoolingParams] = []
    for _ in range(n):
        params.append(
            PoolingParams(
                task=task, dimensions=dimensions, use_activation=use_activation
            )
        )
    return params
|
||||
|
||||
|
||||
def _make_metadata(pooling_params: list[PoolingParams]) -> PoolingMetadata:
|
||||
n = len(pooling_params)
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.ones(n, dtype=torch.long),
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids_cpu=None,
|
||||
pooling_params=pooling_params,
|
||||
pooling_states=[PoolingStates() for _ in range(n)],
|
||||
)
|
||||
|
||||
|
||||
def _linear(in_f: int, out_f: int) -> nn.Linear:
|
||||
torch.manual_seed(42)
|
||||
return nn.Linear(in_f, out_f, bias=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EmbeddingPoolerHead
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestEmbeddingPoolerHead:
|
||||
def test_supported_tasks(self):
|
||||
head = EmbeddingPoolerHead()
|
||||
assert head.get_supported_tasks() == {"embed"}
|
||||
|
||||
def test_passthrough(self):
|
||||
head = EmbeddingPoolerHead()
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH))
|
||||
out = head(x, meta)
|
||||
assert torch.equal(out, x)
|
||||
|
||||
def test_head_dtype(self):
|
||||
head = EmbeddingPoolerHead(head_dtype=torch.float16)
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH))
|
||||
out = head(x, meta)
|
||||
assert out.dtype == torch.float16
|
||||
|
||||
def test_projector(self):
|
||||
proj = _linear(_HIDDEN, 8)
|
||||
head = EmbeddingPoolerHead(projector=proj)
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH))
|
||||
out = head(x, meta)
|
||||
assert out.shape == (_BATCH, 8)
|
||||
assert torch.allclose(out, proj(x))
|
||||
|
||||
def test_matryoshka_uniform(self):
|
||||
head = EmbeddingPoolerHead()
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
params = _make_params(_BATCH, dimensions=4)
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
assert out.shape == (_BATCH, 4)
|
||||
assert torch.equal(out, x[..., :4])
|
||||
|
||||
def test_matryoshka_mixed(self):
|
||||
head = EmbeddingPoolerHead()
|
||||
x = torch.randn(2, _HIDDEN)
|
||||
params = [
|
||||
PoolingParams(task="embed", dimensions=4),
|
||||
PoolingParams(task="embed", dimensions=8),
|
||||
]
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
assert isinstance(out, list)
|
||||
assert len(out) == 2
|
||||
assert out[0].shape[-1] == 4
|
||||
assert out[1].shape[-1] == 8
|
||||
|
||||
def test_matryoshka_mixed_with_none(self):
|
||||
head = EmbeddingPoolerHead()
|
||||
x = torch.randn(2, _HIDDEN)
|
||||
params = [
|
||||
PoolingParams(task="embed", dimensions=4),
|
||||
PoolingParams(task="embed", dimensions=None),
|
||||
]
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
assert isinstance(out, list)
|
||||
assert out[0].shape[-1] == 4
|
||||
assert torch.equal(out[1], x[1])
|
||||
|
||||
def test_activation_uniform_true(self):
|
||||
head = EmbeddingPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
params = _make_params(_BATCH, use_activation=True)
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
norms = torch.linalg.norm(out, dim=-1)
|
||||
assert torch.allclose(norms, torch.ones(_BATCH), atol=1e-5)
|
||||
|
||||
def test_activation_uniform_false(self):
|
||||
head = EmbeddingPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
params = _make_params(_BATCH, use_activation=False)
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
assert torch.equal(out, x)
|
||||
|
||||
def test_activation_mixed_flags(self):
|
||||
head = EmbeddingPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(2, _HIDDEN)
|
||||
params = [
|
||||
PoolingParams(task="embed", use_activation=True),
|
||||
PoolingParams(task="embed", use_activation=False),
|
||||
]
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
assert isinstance(out, list)
|
||||
norm_0 = torch.linalg.norm(out[0], dim=-1)
|
||||
assert torch.allclose(norm_0, torch.ones(1), atol=1e-5)
|
||||
assert torch.equal(out[1], x[1])
|
||||
|
||||
def test_list_input_gets_stacked(self):
|
||||
head = EmbeddingPoolerHead()
|
||||
tensors = [torch.randn(_HIDDEN) for _ in range(_BATCH)]
|
||||
meta = _make_metadata(_make_params(_BATCH))
|
||||
out = head(tensors, meta)
|
||||
assert out.shape == (_BATCH, _HIDDEN)
|
||||
expected = torch.stack(tensors)
|
||||
assert torch.equal(out, expected)
|
||||
|
||||
def test_projector_then_matryoshka(self):
|
||||
proj = _linear(_HIDDEN, 8)
|
||||
head = EmbeddingPoolerHead(projector=proj)
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
params = _make_params(_BATCH, dimensions=4)
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
assert out.shape == (_BATCH, 4)
|
||||
assert torch.equal(out, proj(x)[..., :4])
|
||||
|
||||
def test_matryoshka_then_activation(self):
|
||||
head = EmbeddingPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
params = _make_params(_BATCH, dimensions=4, use_activation=True)
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
assert out.shape == (_BATCH, 4)
|
||||
norms = torch.linalg.norm(out, dim=-1)
|
||||
assert torch.allclose(norms, torch.ones(_BATCH), atol=1e-5)
|
||||
|
||||
def test_empty_batch(self):
|
||||
head = EmbeddingPoolerHead()
|
||||
x = torch.randn(0, _HIDDEN)
|
||||
meta = _make_metadata([])
|
||||
out = head(x, meta)
|
||||
assert out.shape == (0, _HIDDEN)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ClassifierPoolerHead
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestClassifierPoolerHead:
|
||||
def test_supported_tasks(self):
|
||||
head = ClassifierPoolerHead()
|
||||
assert head.get_supported_tasks() == {"classify"}
|
||||
|
||||
def test_passthrough(self):
|
||||
head = ClassifierPoolerHead()
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH, task="classify"))
|
||||
out = head(x, meta)
|
||||
assert torch.equal(out, x)
|
||||
|
||||
def test_head_dtype(self):
|
||||
head = ClassifierPoolerHead(head_dtype=torch.float16)
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH, task="classify"))
|
||||
out = head(x, meta)
|
||||
assert out.dtype == torch.float16
|
||||
|
||||
def test_classifier(self):
|
||||
clf = _linear(_HIDDEN, 3)
|
||||
head = ClassifierPoolerHead(classifier=clf)
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH, task="classify"))
|
||||
out = head(x, meta)
|
||||
assert out.shape == (_BATCH, 3)
|
||||
assert torch.allclose(out, clf(x))
|
||||
|
||||
def test_logit_mean(self):
|
||||
head = ClassifierPoolerHead(logit_mean=2.0)
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH, task="classify"))
|
||||
out = head(x, meta)
|
||||
assert torch.allclose(out, x - 2.0)
|
||||
|
||||
def test_logit_sigma(self):
|
||||
head = ClassifierPoolerHead(logit_sigma=0.5)
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH, task="classify"))
|
||||
out = head(x, meta)
|
||||
assert torch.allclose(out, x / 0.5)
|
||||
|
||||
def test_platt_scaling_combined(self):
|
||||
head = ClassifierPoolerHead(logit_mean=1.0, logit_sigma=2.0)
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH, task="classify"))
|
||||
out = head(x, meta)
|
||||
assert torch.allclose(out, (x - 1.0) / 2.0)
|
||||
|
||||
def test_activation_uniform_true(self):
|
||||
head = ClassifierPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
params = _make_params(_BATCH, task="classify", use_activation=True)
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
norms = torch.linalg.norm(out, dim=-1)
|
||||
assert torch.allclose(norms, torch.ones(_BATCH), atol=1e-5)
|
||||
|
||||
def test_activation_uniform_false(self):
|
||||
head = ClassifierPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
params = _make_params(_BATCH, task="classify", use_activation=False)
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
assert torch.equal(out, x)
|
||||
|
||||
def test_activation_mixed_flags(self):
|
||||
head = ClassifierPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(2, _HIDDEN)
|
||||
params = [
|
||||
PoolingParams(task="classify", use_activation=True),
|
||||
PoolingParams(task="classify", use_activation=False),
|
||||
]
|
||||
meta = _make_metadata(params)
|
||||
out = head(x, meta)
|
||||
assert isinstance(out, list)
|
||||
norm_0 = torch.linalg.norm(out[0], dim=-1)
|
||||
assert torch.allclose(norm_0, torch.ones(1), atol=1e-5)
|
||||
assert torch.equal(out[1], x[1])
|
||||
|
||||
def test_list_input_gets_stacked(self):
|
||||
head = ClassifierPoolerHead()
|
||||
tensors = [torch.randn(_HIDDEN) for _ in range(_BATCH)]
|
||||
meta = _make_metadata(_make_params(_BATCH, task="classify"))
|
||||
out = head(tensors, meta)
|
||||
assert out.shape == (_BATCH, _HIDDEN)
|
||||
expected = torch.stack(tensors)
|
||||
assert torch.equal(out, expected)
|
||||
|
||||
def test_classifier_then_platt_scaling(self):
|
||||
clf = _linear(_HIDDEN, 3)
|
||||
head = ClassifierPoolerHead(classifier=clf, logit_mean=1.0, logit_sigma=2.0)
|
||||
x = torch.randn(_BATCH, _HIDDEN)
|
||||
meta = _make_metadata(_make_params(_BATCH, task="classify"))
|
||||
out = head(x, meta)
|
||||
expected = (clf(x) - 1.0) / 2.0
|
||||
assert torch.allclose(out, expected)
|
||||
|
||||
def test_empty_batch(self):
|
||||
head = ClassifierPoolerHead()
|
||||
x = torch.randn(0, _HIDDEN)
|
||||
meta = _make_metadata([])
|
||||
out = head(x, meta)
|
||||
assert out.shape == (0, _HIDDEN)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TokenEmbeddingPoolerHead
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTokenEmbeddingPoolerHead:
|
||||
def test_supported_tasks(self):
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
assert head.get_supported_tasks() == {"token_embed"}
|
||||
|
||||
def test_passthrough(self):
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_embed")
|
||||
out = head.forward_chunk(x, param)
|
||||
assert torch.equal(out, x)
|
||||
|
||||
def test_none_chunked_prefill(self):
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
param = PoolingParams(task="token_embed")
|
||||
out = head.forward_chunk(None, param)
|
||||
assert out is None
|
||||
|
||||
def test_head_dtype(self):
|
||||
head = TokenEmbeddingPoolerHead(head_dtype=torch.float16)
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_embed")
|
||||
out = head.forward_chunk(x, param)
|
||||
assert out.dtype == torch.float16
|
||||
|
||||
def test_projector(self):
|
||||
proj = _linear(_HIDDEN, 8)
|
||||
head = TokenEmbeddingPoolerHead(projector=proj)
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_embed")
|
||||
out = head.forward_chunk(x, param)
|
||||
assert out.shape == (5, 8)
|
||||
assert torch.allclose(out, proj(x))
|
||||
|
||||
def test_matryoshka_truncation(self):
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_embed", dimensions=4)
|
||||
out = head.forward_chunk(x, param)
|
||||
assert out.shape == (5, 4)
|
||||
assert torch.equal(out, x[..., :4])
|
||||
|
||||
def test_activation_true(self):
|
||||
head = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_embed", use_activation=True)
|
||||
out = head.forward_chunk(x, param)
|
||||
norms = torch.linalg.norm(out, dim=-1)
|
||||
assert torch.allclose(norms, torch.ones(5), atol=1e-5)
|
||||
|
||||
def test_activation_false(self):
|
||||
head = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_embed", use_activation=False)
|
||||
out = head.forward_chunk(x, param)
|
||||
assert torch.equal(out, x)
|
||||
|
||||
def test_projector_then_matryoshka(self):
|
||||
proj = _linear(_HIDDEN, 8)
|
||||
head = TokenEmbeddingPoolerHead(projector=proj)
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_embed", dimensions=4)
|
||||
out = head.forward_chunk(x, param)
|
||||
assert out.shape == (5, 4)
|
||||
assert torch.equal(out, proj(x)[..., :4])
|
||||
|
||||
def test_matryoshka_then_activation(self):
|
||||
head = TokenEmbeddingPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_embed", dimensions=4, use_activation=True)
|
||||
out = head.forward_chunk(x, param)
|
||||
assert out.shape == (5, 4)
|
||||
norms = torch.linalg.norm(out, dim=-1)
|
||||
assert torch.allclose(norms, torch.ones(5), atol=1e-5)
|
||||
|
||||
def test_forward_mixed_batch_chunked_prefill(self):
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
pooled_data = [torch.randn(5, _HIDDEN), None, torch.randn(3, _HIDDEN)]
|
||||
params = _make_params(3, task="token_embed")
|
||||
meta = _make_metadata(params)
|
||||
out = head(pooled_data, meta)
|
||||
assert len(out) == 3
|
||||
assert torch.equal(out[0], pooled_data[0])
|
||||
assert out[1] is None
|
||||
assert torch.equal(out[2], pooled_data[2])
|
||||
|
||||
def test_forward_empty_batch(self):
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
meta = _make_metadata([])
|
||||
out = head([], meta)
|
||||
assert out == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TokenClassifierPoolerHead
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTokenClassifierPoolerHead:
|
||||
def test_supported_tasks(self):
|
||||
head = TokenClassifierPoolerHead()
|
||||
assert head.get_supported_tasks() == {"token_classify"}
|
||||
|
||||
def test_passthrough(self):
|
||||
head = TokenClassifierPoolerHead()
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_classify")
|
||||
out = head.forward_chunk(x, param)
|
||||
assert torch.equal(out, x)
|
||||
|
||||
def test_none_chunked_prefill(self):
|
||||
head = TokenClassifierPoolerHead()
|
||||
param = PoolingParams(task="token_classify")
|
||||
out = head.forward_chunk(None, param)
|
||||
assert out is None
|
||||
|
||||
def test_head_dtype(self):
|
||||
head = TokenClassifierPoolerHead(head_dtype=torch.float16)
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_classify")
|
||||
out = head.forward_chunk(x, param)
|
||||
assert out.dtype == torch.float16
|
||||
|
||||
def test_classifier(self):
|
||||
clf = _linear(_HIDDEN, 3)
|
||||
head = TokenClassifierPoolerHead(classifier=clf)
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_classify")
|
||||
out = head.forward_chunk(x, param)
|
||||
assert out.shape == (5, 3)
|
||||
assert torch.allclose(out, clf(x))
|
||||
|
||||
def test_logit_mean(self):
|
||||
head = TokenClassifierPoolerHead(logit_mean=2.0)
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_classify")
|
||||
out = head.forward_chunk(x, param)
|
||||
assert torch.allclose(out, x - 2.0)
|
||||
|
||||
def test_logit_sigma(self):
|
||||
head = TokenClassifierPoolerHead(logit_sigma=0.5)
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_classify")
|
||||
out = head.forward_chunk(x, param)
|
||||
assert torch.allclose(out, x / 0.5)
|
||||
|
||||
def test_platt_scaling_combined(self):
|
||||
head = TokenClassifierPoolerHead(logit_mean=1.0, logit_sigma=2.0)
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_classify")
|
||||
out = head.forward_chunk(x, param)
|
||||
assert torch.allclose(out, (x - 1.0) / 2.0)
|
||||
|
||||
def test_activation_true(self):
|
||||
head = TokenClassifierPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_classify", use_activation=True)
|
||||
out = head.forward_chunk(x, param)
|
||||
norms = torch.linalg.norm(out, dim=-1)
|
||||
assert torch.allclose(norms, torch.ones(5), atol=1e-5)
|
||||
|
||||
def test_activation_false(self):
|
||||
head = TokenClassifierPoolerHead(activation=PoolerNormalize())
|
||||
x = torch.randn(5, _HIDDEN)
|
||||
param = PoolingParams(task="token_classify", use_activation=False)
|
||||
out = head.forward_chunk(x, param)
|
||||
assert torch.equal(out, x)
|
||||
|
||||
def test_forward_mixed_batch_chunked_prefill(self):
|
||||
head = TokenClassifierPoolerHead()
|
||||
pooled_data = [torch.randn(5, _HIDDEN), None, torch.randn(3, _HIDDEN)]
|
||||
params = _make_params(3, task="token_classify")
|
||||
meta = _make_metadata(params)
|
||||
out = head(pooled_data, meta)
|
||||
assert len(out) == 3
|
||||
assert torch.equal(out[0], pooled_data[0])
|
||||
assert out[1] is None
|
||||
assert torch.equal(out[2], pooled_data[2])
|
||||
|
||||
def test_forward_empty_batch(self):
|
||||
head = TokenClassifierPoolerHead()
|
||||
meta = _make_metadata([])
|
||||
out = head([], meta)
|
||||
assert out == []
|
||||
@@ -2,11 +2,12 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager, nullcontext
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
from tests.utils import multi_gpu_test
|
||||
from tests.utils import multi_gpu_test, wait_for_gpu_memory_to_clear
|
||||
from vllm import LLM
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
@@ -404,6 +405,30 @@ def _get_vllm_runner_params(
|
||||
}
|
||||
|
||||
|
||||
def _wait_for_rocm_memory_to_settle() -> None:
|
||||
if not current_platform.is_rocm():
|
||||
return
|
||||
|
||||
num_gpus = current_platform.device_count()
|
||||
if num_gpus == 0:
|
||||
return
|
||||
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=list(range(num_gpus)),
|
||||
threshold_ratio=0.01,
|
||||
timeout_s=120,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _owned_vLLM_runner(vllm_runner, kwargs):
|
||||
try:
|
||||
with vllm_runner(**kwargs) as runner:
|
||||
yield runner
|
||||
finally:
|
||||
_wait_for_rocm_memory_to_settle()
|
||||
|
||||
|
||||
def _get_vLLM_output(
|
||||
vllm_runner,
|
||||
kwargs,
|
||||
@@ -413,14 +438,18 @@ def _get_vLLM_output(
|
||||
num_repetitions=1,
|
||||
vllm_model=None,
|
||||
):
|
||||
runner_context = (
|
||||
_owned_vLLM_runner(vllm_runner, kwargs)
|
||||
if vllm_model is None
|
||||
else nullcontext(vllm_model)
|
||||
)
|
||||
with runner_context as runner:
|
||||
outs = []
|
||||
if vllm_model is None:
|
||||
vllm_model = vllm_runner(**kwargs)
|
||||
for _ in range(num_repetitions):
|
||||
if num_logprobs < 0:
|
||||
vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
|
||||
vllm_output = runner.generate_greedy(prompts, max_tokens)
|
||||
else:
|
||||
vllm_output = vllm_model.generate_greedy_logprobs(
|
||||
vllm_output = runner.generate_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs
|
||||
)
|
||||
outs.append(vllm_output)
|
||||
@@ -772,8 +801,14 @@ def test_apc_multiple_prompts_partial_cached_outputs(
|
||||
|
||||
# Cache only part of all the prompts
|
||||
vllm_runner_kwargs["enable_prefix_caching"] = True
|
||||
vllm_outputs_partial_cache, vllm_model = _get_vLLM_output(
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs
|
||||
with _owned_vLLM_runner(vllm_runner, vllm_runner_kwargs) as vllm_model:
|
||||
vllm_outputs_partial_cache, _ = _get_vLLM_output(
|
||||
vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts[:3],
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
vllm_model=vllm_model,
|
||||
)
|
||||
|
||||
compare_operator(
|
||||
@@ -826,7 +861,7 @@ def test_same_mamba_output_apc_on_vs_off(
|
||||
|
||||
# No prefix caching
|
||||
kwargs_no_apc = {**base_kwargs, "enable_prefix_caching": False}
|
||||
with vllm_runner(**kwargs_no_apc) as vllm_model:
|
||||
with _owned_vLLM_runner(vllm_runner, kwargs_no_apc) as vllm_model:
|
||||
outputs_no_apc, _ = _get_vLLM_output(
|
||||
vllm_runner,
|
||||
kwargs_no_apc,
|
||||
@@ -841,7 +876,7 @@ def test_same_mamba_output_apc_on_vs_off(
|
||||
"enable_prefix_caching": True,
|
||||
"mamba_block_size": 16,
|
||||
}
|
||||
with vllm_runner(**kwargs_with_apc) as vllm_model:
|
||||
with _owned_vLLM_runner(vllm_runner, kwargs_with_apc) as vllm_model:
|
||||
outputs_with_apc, _ = _get_vLLM_output(
|
||||
vllm_runner,
|
||||
kwargs_with_apc,
|
||||
|
||||
@@ -30,11 +30,14 @@ def vllm_to_hf_output(
|
||||
|
||||
MODEL_NAME = "ibm-granite/granite-speech-3.3-2b"
|
||||
MODEL_NAME_4_0 = "ibm-granite/granite-4.0-1b-speech"
|
||||
# "plus" variant of granite speech (uses GraniteSpeechPlusForConditionalGeneration).
|
||||
MODEL_NAME_4_1_PLUS = "ibm-granite/granite-speech-4.1-2b-plus"
|
||||
# Audio lora co-exists directly in the 3.3 model directory,
|
||||
# the 4.0 model has adapters merged into the weights.
|
||||
# the 4.0 and 4.1-plus models have adapters merged into the weights.
|
||||
models: dict[str, str | None] = {
|
||||
MODEL_NAME: MODEL_NAME,
|
||||
MODEL_NAME_4_0: None,
|
||||
MODEL_NAME_4_1_PLUS: None,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -43,6 +43,10 @@ def qwen_vl_chat_template(content: str) -> str:
|
||||
return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
|
||||
def internvl_chat_template(content: str) -> str:
|
||||
return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
|
||||
def step3_vl_chat_template(content: str) -> str:
|
||||
return (
|
||||
"<|begin▁of▁sentence|> You are a helpful assistant.<|BOT|>user\n "
|
||||
@@ -51,6 +55,17 @@ def step3_vl_chat_template(content: str) -> str:
|
||||
|
||||
|
||||
MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
|
||||
"internvl": VitCudagraphTestConfig(
|
||||
model="OpenGVLab/InternVL3-1B",
|
||||
num_video_frames=8,
|
||||
image_prompt=internvl_chat_template("<image>\nWhat is in this image?"),
|
||||
video_prompt=internvl_chat_template(
|
||||
"<video>\nDescribe this video in one sentence."
|
||||
),
|
||||
needs_video_metadata=False,
|
||||
vllm_runner_kwargs={"trust_remote_code": True},
|
||||
marks=[pytest.mark.core_model],
|
||||
),
|
||||
"qwen2_5_vl": VitCudagraphTestConfig(
|
||||
model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
image_prompt=qwen_vl_chat_template(
|
||||
|
||||
@@ -938,6 +938,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"ibm-granite/granite-speech-3.3-2b",
|
||||
extras={"4.0-1b": "ibm-granite/granite-4.0-1b-speech"},
|
||||
),
|
||||
"GraniteSpeechPlusForConditionalGeneration": _HfExamplesInfo(
|
||||
"ibm-granite/granite-speech-4.1-2b-plus",
|
||||
min_transformers_version="5.8.0",
|
||||
),
|
||||
"GLM4VForCausalLM": _HfExamplesInfo(
|
||||
"zai-org/glm-4v-9b",
|
||||
trust_remote_code=True,
|
||||
|
||||
@@ -36,11 +36,24 @@ def tokenizer():
|
||||
return get_tokenizer("Qwen/Qwen3-32B")
|
||||
|
||||
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def request_obj():
|
||||
return ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=TOOLS,
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
|
||||
@@ -328,3 +341,27 @@ def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj):
|
||||
tc.function.arguments for tc in tool_calls if tc.function.arguments
|
||||
)
|
||||
assert tool_args.endswith(remainder)
|
||||
|
||||
|
||||
def test_parse_delta_tool_choice_none(tokenizer, request_obj):
|
||||
parser = make_parser(tokenizer, reasoning=False, tool=True)
|
||||
request = request_obj.model_copy(update={"tool_choice": "none"})
|
||||
results = stream_text(parser, tokenizer, MODEL_OUTPUT, request, prompt_token_ids=[])
|
||||
reasoning, content, tool_calls = collect_fields(results)
|
||||
|
||||
assert reasoning == ""
|
||||
assert len(tool_calls) == 0
|
||||
assert "<tool_call>" in content
|
||||
assert "get_weather" in content
|
||||
|
||||
|
||||
def test_parse_delta_tool_choice_none_with_reasoning(tokenizer, request_obj):
|
||||
parser = make_parser(tokenizer, reasoning=True, tool=True)
|
||||
request = request_obj.model_copy(update={"tool_choice": "none"})
|
||||
results = stream_text(parser, tokenizer, MODEL_OUTPUT, request, prompt_token_ids=[])
|
||||
reasoning, content, tool_calls = collect_fields(results)
|
||||
|
||||
assert "let me think about this" in reasoning
|
||||
assert len(tool_calls) == 0
|
||||
assert "<tool_call>" in content
|
||||
assert "get_weather" in content
|
||||
|
||||
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsW4A4Fp4,
|
||||
CompressedTensorsW4A4Mxfp4,
|
||||
CompressedTensorsW4A8Fp8,
|
||||
CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A8Mxfp8,
|
||||
@@ -37,9 +36,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
|
||||
@@ -376,13 +372,12 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
|
||||
@pytest.mark.parametrize(
|
||||
"args",
|
||||
[
|
||||
# TODO: Enable once model is available again
|
||||
# ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", True),
|
||||
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", False),
|
||||
],
|
||||
)
|
||||
def test_compressed_tensors_nvfp4(vllm_runner, args):
|
||||
model, scheme = args
|
||||
model, use_a16 = args
|
||||
with vllm_runner(model, enforce_eager=True) as llm:
|
||||
|
||||
def check_model(model):
|
||||
@@ -390,15 +385,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
|
||||
|
||||
qkv_proj = layer.self_attn.qkv_proj
|
||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||
if (
|
||||
isinstance(qkv_proj.scheme, scheme)
|
||||
or isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
|
||||
and not cutlass_fp4_supported()
|
||||
):
|
||||
assert True
|
||||
else:
|
||||
raise AssertionError("FP4 Scheme Mismatch")
|
||||
|
||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A4Fp4)
|
||||
assert qkv_proj.scheme.use_a16 == use_a16
|
||||
assert qkv_proj.scheme.group_size == 16
|
||||
|
||||
llm.apply_model(check_model)
|
||||
|
||||
+1
-244
@@ -10,12 +10,11 @@ from dataclasses import dataclass
|
||||
from json.decoder import JSONDecodeError
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.logger import (
|
||||
_DATE_FORMAT,
|
||||
_FORMAT,
|
||||
@@ -269,248 +268,6 @@ def test_prepare_object_to_dump():
|
||||
assert prepare_object_to_dump(CustomClass(1, "b")) == "CustomClass(a=1, b='b')"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs():
|
||||
"""Test the new log_outputs functionality."""
|
||||
# Create a mock logger to capture log calls
|
||||
mock_logger = MagicMock()
|
||||
|
||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
||||
request_logger = RequestLogger(max_log_len=None)
|
||||
|
||||
# Test basic output logging
|
||||
request_logger.log_outputs(
|
||||
request_id="test-123",
|
||||
outputs="Hello, world!",
|
||||
output_token_ids=[1, 2, 3, 4],
|
||||
finish_reason="stop",
|
||||
is_streaming=False,
|
||||
delta=False,
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args.args
|
||||
assert "Generated response %s%s" in call_args[0]
|
||||
assert call_args[1] == "test-123"
|
||||
assert call_args[3] == "Hello, world!"
|
||||
assert call_args[4] == [1, 2, 3, 4]
|
||||
assert call_args[5] == "stop"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_streaming_delta():
|
||||
"""Test log_outputs with streaming delta mode."""
|
||||
mock_logger = MagicMock()
|
||||
|
||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
||||
request_logger = RequestLogger(max_log_len=None)
|
||||
|
||||
# Test streaming delta logging
|
||||
request_logger.log_outputs(
|
||||
request_id="test-456",
|
||||
outputs="Hello",
|
||||
output_token_ids=[1],
|
||||
finish_reason=None,
|
||||
is_streaming=True,
|
||||
delta=True,
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args.args
|
||||
assert "Generated response %s%s" in call_args[0]
|
||||
assert call_args[1] == "test-456"
|
||||
assert call_args[2] == " (streaming delta)"
|
||||
assert call_args[3] == "Hello"
|
||||
assert call_args[4] == [1]
|
||||
assert call_args[5] is None
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_streaming_complete():
|
||||
"""Test log_outputs with streaming complete mode."""
|
||||
mock_logger = MagicMock()
|
||||
|
||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
||||
request_logger = RequestLogger(max_log_len=None)
|
||||
|
||||
# Test streaming complete logging
|
||||
request_logger.log_outputs(
|
||||
request_id="test-789",
|
||||
outputs="Complete response",
|
||||
output_token_ids=[1, 2, 3],
|
||||
finish_reason="length",
|
||||
is_streaming=True,
|
||||
delta=False,
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args.args
|
||||
assert "Generated response %s%s" in call_args[0]
|
||||
assert call_args[1] == "test-789"
|
||||
assert call_args[2] == " (streaming complete)"
|
||||
assert call_args[3] == "Complete response"
|
||||
assert call_args[4] == [1, 2, 3]
|
||||
assert call_args[5] == "length"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_with_truncation():
|
||||
"""Test log_outputs respects max_log_len setting."""
|
||||
mock_logger = MagicMock()
|
||||
|
||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
||||
# Set max_log_len to 10
|
||||
request_logger = RequestLogger(max_log_len=10)
|
||||
|
||||
# Test output truncation
|
||||
long_output = "This is a very long output that should be truncated"
|
||||
long_token_ids = list(range(20)) # 20 tokens
|
||||
|
||||
request_logger.log_outputs(
|
||||
request_id="test-truncate",
|
||||
outputs=long_output,
|
||||
output_token_ids=long_token_ids,
|
||||
finish_reason="stop",
|
||||
is_streaming=False,
|
||||
delta=False,
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args
|
||||
|
||||
# Check that output was truncated to first 10 characters
|
||||
logged_output = call_args[0][3]
|
||||
assert logged_output == "This is a "
|
||||
assert len(logged_output) == 10
|
||||
|
||||
# Check that token IDs were truncated to first 10 tokens
|
||||
logged_token_ids = call_args[0][4]
|
||||
assert logged_token_ids == list(range(10))
|
||||
assert len(logged_token_ids) == 10
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_none_values():
    """Check that log_outputs logs a None token-ID list as None."""
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        # No length limit is configured; output_token_ids is deliberately None.
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-none",
            outputs="Test output",
            output_token_ids=None,
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    fake_logger.info.assert_called_once()
    logged = fake_logger.info.call_args.args

    assert "Generated response %s%s" in logged[0]
    for position, wanted in ((1, "test-none"), (3, "Test output")):
        assert logged[position] == wanted
    assert logged[4] is None
    assert logged[5] == "stop"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_empty_output():
    """Check that log_outputs logs an empty string and an empty ID list."""
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        # A limit of 5 is configured, but there is nothing to clip.
        RequestLogger(max_log_len=5).log_outputs(
            request_id="test-empty",
            outputs="",
            output_token_ids=[],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    fake_logger.info.assert_called_once()
    logged = fake_logger.info.call_args.args

    assert "Generated response %s%s" in logged[0]
    for position, wanted in ((1, "test-empty"), (3, ""), (4, []), (5, "stop")):
        assert logged[position] == wanted
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_integration():
    """Check that log_inputs and log_outputs each emit their own info record."""
    request_id = "test-integration"
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        shared_logger = RequestLogger(max_log_len=None)

        # Both methods are invoked on the same RequestLogger instance.
        shared_logger.log_inputs(
            request_id=request_id,
            prompt="Test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_embeds=None,
            params=None,
            lora_request=None,
        )
        shared_logger.log_outputs(
            request_id=request_id,
            outputs="Test output",
            output_token_ids=[4, 5, 6],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

    # One record for the inputs, one for the outputs.
    assert fake_logger.info.call_count == 2

    recorded = fake_logger.info.call_args_list
    input_args = recorded[0].args
    output_args = recorded[1].args

    # The first record uses the input format string.
    assert "Received request %s" in input_args[0]
    assert input_args[1] == "test-integration"

    # The second record uses the output format string.
    assert "Generated response %s%s" in output_args[0]
    assert output_args[1] == "test-integration"
|
||||
|
||||
|
||||
def test_streaming_complete_logs_full_text_content():
    """Check that a streaming-complete record carries the accumulated text.

    The logged output must be the full response string rather than a
    token-count placeholder.
    """
    expected_text = "This is a complete response from streaming"
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        RequestLogger(max_log_len=None).log_outputs(
            request_id="test-streaming-full-text",
            outputs=expected_text,
            output_token_ids=None,
            finish_reason="streaming_complete",
            is_streaming=True,
            delta=False,
        )

    fake_logger.info.assert_called_once()
    logged = fake_logger.info.call_args.args

    # The text argument is the response itself, not a summary of it.
    text_arg = logged[3]
    assert text_arg == expected_text
    for forbidden in ("tokens>", "streaming_complete"):
        assert forbidden not in text_arg

    # Remaining positional arguments: request id, streaming suffix, finish reason.
    assert logged[1] == "test-streaming-full-text"
    assert logged[2] == " (streaming complete)"
    assert logged[5] == "streaming_complete"
|
||||
|
||||
|
||||
# Add vllm prefix to make sure logs go through the vllm logger
|
||||
test_logger = init_logger("vllm.test_logger")
|
||||
|
||||
|
||||
@@ -54,15 +54,14 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
|
||||
(
|
||||
MiniMaxText01LinearAttention,
|
||||
dict(
|
||||
hidden_size=128,
|
||||
hidden_inner_size=256,
|
||||
num_heads=8,
|
||||
config=SimpleNamespace(
|
||||
hidden_size=256,
|
||||
num_attention_heads=8,
|
||||
head_dim=32,
|
||||
max_position=2048,
|
||||
block_size=64,
|
||||
num_hidden_layer=12,
|
||||
layer_idx=0,
|
||||
linear_layer_idx=0,
|
||||
num_hidden_layers=12,
|
||||
block=64,
|
||||
),
|
||||
prefix="layers.0.self_attn",
|
||||
),
|
||||
LinearAttentionBackend,
|
||||
MambaAttentionBackendEnum.LINEAR,
|
||||
@@ -88,6 +87,8 @@ def test_mamba_layers_get_attn_backend(
|
||||
expected_mamba_type,
|
||||
):
|
||||
"""Test that Mamba-like layers return the correct attention backend."""
|
||||
if layer_class is MiniMaxText01LinearAttention:
|
||||
init_kwargs["vllm_config"] = default_vllm_config
|
||||
layer = layer_class(**init_kwargs)
|
||||
|
||||
backend_class = layer.get_attn_backend()
|
||||
|
||||
@@ -358,6 +358,43 @@ def test_free_kv_cache_block_queue_append_n():
|
||||
)
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_prepend_n():
    """Check that prepend_n splices blocks in front of the current head."""
    b0, b1, b2, _b3, b4, b5 = (KVCacheBlock(block_id=i) for i in range(6))

    # Start from fake_head->b0->fake_tail so there is a head to splice before.
    queue = FreeKVCacheBlockQueue([b0])
    head = queue.fake_free_list_head
    tail = queue.fake_free_list_tail

    # Prepending nothing leaves the queue untouched.
    queue.prepend_n([])
    assert queue.num_free_blocks == 1
    assert head.next_free_block is b0

    # Prepending [b4, b5] yields fake_head->b4->b5->b0->fake_tail; verify
    # both link directions for every neighbouring pair.
    queue.prepend_n([b4, b5])
    assert queue.num_free_blocks == 3
    expected_chain = [head, b4, b5, b0, tail]
    for left, right in zip(expected_chain, expected_chain[1:]):
        assert left.next_free_block is right
        assert right.prev_free_block is left

    # A second prepend lands ahead of everything prepended before:
    # fake_head->b1->b2->b4->b5->b0->fake_tail (forward links checked only).
    queue.prepend_n([b1, b2])
    assert queue.num_free_blocks == 5
    front_chain = [head, b1, b2, b4]
    for left, right in zip(front_chain, front_chain[1:]):
        assert left.next_free_block is right

    # Popping from the left walks the queue front to back.
    popped_ids = []
    for _ in range(5):
        popped_ids.append(queue.popleft().block_id)
    assert popped_ids == [1, 2, 4, 5, 0]
    assert queue.num_free_blocks == 0
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_popleft_n():
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
|
||||
# Create an empty FreeKVCacheBlockQueue with these blocks
|
||||
|
||||
@@ -39,6 +39,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
KVCacheGroupSpec,
|
||||
KVCacheSpecKind,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
|
||||
@@ -2875,6 +2876,350 @@ def test_hybrid_cache_blocks_clamped_to_lcm():
|
||||
)
|
||||
|
||||
|
||||
def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch):
    """Check which SWA blocks stay cached under a fixed 64-token interval.

    With VLLM_PREFIX_CACHE_RETENTION_INTERVAL=64 on a full-attention plus
    sliding-window config, the sliding-window group is expected to retain
    sparse interval tails plus the latest replay tail.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
    block_size = 8

    full_attention_group = KVCacheGroupSpec(
        ["layer1"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    sliding_window_group = KVCacheGroupSpec(
        ["layer2"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
    )
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=100,
            kv_cache_tensors=[],
            kv_cache_groups=[full_attention_group, sliding_window_group],
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    # The SWA manager uses the configured 64-token interval (a multiple of the
    # 32-token lcm_block_size) as its retention segment. For this 128-token
    # prompt the retained SWA tails are the 64-token interval boundary, the
    # 96-token replay boundary, and the 128-token interval boundary.
    token_ids = [token for token in range(16) for _ in range(block_size)]
    req = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req)
    num_computed = len(computed_blocks.blocks[0]) * block_size
    allocated = manager.allocate_slots(
        req,
        len(token_ids),
        num_computed,
        computed_blocks,
    )
    assert allocated is not None

    pool = manager.block_pool
    expected_swa_cached = {7, 11, 15}
    for i in range(16):
        hit = pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[1])
        if i not in expected_swa_cached:
            assert hit is None, f"SWA hash {i} should not be cached"
            continue
        assert hit is not None, f"SWA hash {i} should be cached"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "interval, expected_match",
    [
        # scheduler_block_size is 32 (= lcm(4*8, 8)); 33 is not a multiple of it.
        ("33", "multiple of scheduler_block_size"),
        # A negative multiple (-32 % 32 == 0) must still be rejected explicitly,
        # otherwise it would pass the modulo check and silently degrade to dense.
        ("-32", "non-negative"),
    ],
)
def test_hybrid_local_kv_retention_interval_rejects_invalid(
    monkeypatch, interval, expected_match
):
    """Check that an invalid retention interval fails manager construction.

    Building the manager must raise ValueError when the interval is negative
    or is not a multiple of scheduler_block_size.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", interval)
    block_size = 8

    full_attention_group = KVCacheGroupSpec(
        ["layer1"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    sliding_window_group = KVCacheGroupSpec(
        ["layer2"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
    )
    # The config is built outside the raises-context so that only the manager
    # construction is allowed to raise.
    hybrid_config = KVCacheConfig(
        num_blocks=100,
        kv_cache_tensors=[],
        kv_cache_groups=[full_attention_group, sliding_window_group],
    )

    with pytest.raises(ValueError, match=expected_match):
        make_kv_cache_manager(
            kv_cache_config=hybrid_config,
            max_model_len=8192,
            enable_caching=True,
            hash_block_size=block_size,
        )
|
||||
|
||||
|
||||
def test_hybrid_local_kv_retention_interval_survives_recycling(monkeypatch):
    """Check that retained local checkpoints are still hit after recycling.

    A prompt is filled and freed, replayed, then a second unrelated prompt is
    filled and freed before replaying again; both replays must report the same
    prefix-cache hit.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "1024")
    hash_block_size = 4

    full_group = KVCacheGroupSpec(
        ["full"],
        MLAAttentionSpec(
            block_size=64 * hash_block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.uint8,
            compress_ratio=4,
        ),
    )
    # (layer name, block-size multiplier, sliding window) per local group.
    local_groups = [
        KVCacheGroupSpec(
            [layer_name],
            SlidingWindowSpec(
                block_size=multiplier * hash_block_size,
                num_kv_heads=1,
                head_size=1,
                dtype=torch.float32,
                sliding_window=window,
            ),
        )
        for layer_name, multiplier, window in (
            ("swa", 16, 512),
            ("c128", 2, 128),
            ("c4", 1, 8),
        )
    ]
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=800,
            kv_cache_tensors=[],
            kv_cache_groups=[full_group, *local_groups],
        ),
        max_model_len=4096,
        enable_caching=True,
        hash_block_size=hash_block_size,
    )

    def fill_request(request_id: str, token_offset: int) -> list[int]:
        """Allocate a 4096-token prompt in chunks of up to 512, then free it."""
        token_ids = [
            token_offset + i for i in range(1024) for _ in range(hash_block_size)
        ]
        total = len(token_ids)
        fill_req = make_request(request_id, token_ids, hash_block_size, sha256)
        while fill_req.num_computed_tokens < total:
            chunk = min(512, total - fill_req.num_computed_tokens)
            allocated = manager.allocate_slots(fill_req, chunk)
            assert allocated is not None
            fill_req.num_computed_tokens += chunk
        manager.free(fill_req)
        return token_ids

    def assert_replay_hit(request_id: str) -> None:
        """Replay the first 1800 tokens and check the reported cache hit."""
        replay_req = make_request(
            request_id, token_ids[:1800], hash_block_size, sha256
        )
        computed_blocks, num_computed_tokens = manager.get_computed_blocks(
            replay_req
        )
        assert num_computed_tokens == 1024
        hit_lengths = [len(group_blocks) for group_blocks in computed_blocks.blocks]
        assert hit_lengths == [4, 16, 128, 256]

    token_ids = fill_request("fill_0", 0)
    assert_replay_hit("replay")

    # Fill and free an unrelated prompt, then replay the original one again.
    fill_request("fill_1", 100_000)
    assert_replay_hit("replay_again")
|
||||
|
||||
|
||||
def test_hybrid_local_kv_retention_latest_only_reuses_replay_boundary(monkeypatch):
    """Check that interval 0 keeps only the replayable prompt boundary.

    With VLLM_PREFIX_CACHE_RETENTION_INTERVAL=0, a single SWA hash is expected
    to stay cached; a full replay hits it, while a shorter prompt gets no hit.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
    block_size = 8

    full_attention_group = KVCacheGroupSpec(
        ["layer1"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    sliding_window_group = KVCacheGroupSpec(
        ["layer2"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
    )
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=100,
            kv_cache_tensors=[],
            kv_cache_groups=[full_attention_group, sliding_window_group],
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    token_ids = [token for token in range(16) for _ in range(block_size)]
    req0 = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req0)
    num_computed = len(computed_blocks.blocks[0]) * block_size
    allocated = manager.allocate_slots(
        req0,
        len(token_ids),
        num_computed,
        computed_blocks,
    )
    assert allocated is not None

    pool = manager.block_pool
    expected_swa_cached = {11}
    for i in range(16):
        hit = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
        if i not in expected_swa_cached:
            assert hit is None, f"SWA hash {i} should not be cached"
            continue
        assert hit is not None, f"SWA hash {i} should be cached"

    # After the request is freed, the retained block is still cached and
    # carries no references.
    manager.free(req0)
    retained_swa_block = pool.get_cached_block(req0.block_hashes[11], [1])
    assert retained_swa_block is not None
    assert retained_swa_block[0].ref_cnt == 0

    # Full prompt hits intentionally recompute the final block for logits, so
    # the longest usable hit is the previous LCM boundary: 96 tokens.
    req1 = make_request("1", token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens == 12 * block_size
    assert len(computed_blocks.blocks[1]) == 12

    # A prompt cut to the first 96 tokens gets no hit at all.
    truncated_prompt = token_ids[: 12 * block_size]
    shorter_req = make_request("2", truncated_prompt, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(shorter_req)
    assert num_computed_tokens == 0
    assert len(computed_blocks.blocks[1]) == 0
|
||||
|
||||
|
||||
def test_hybrid_local_kv_retention_mtp_reuses_latest_boundary(monkeypatch):
    """Check that MTP/EAGLE SWA retention keeps the extra proof block.

    EAGLE/MTP lookup matches one additional local block after the returned
    prefix and then drops it, so sparse retention has to cache the normal
    local tail at the latest replay boundary plus one extra SWA block.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
    block_size = 8

    full_attention_group = KVCacheGroupSpec(
        ["full"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    eagle_swa_group = KVCacheGroupSpec(
        ["swa_mtp"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
        is_eagle_group=True,
    )
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=100,
            kv_cache_tensors=[],
            kv_cache_groups=[full_attention_group, eagle_swa_group],
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )

    # 127 tokens: latest replay boundary is floor((127 - 1) / 32) * 32 = 96.
    # The EAGLE/MTP SWA lookup group must cache the local tail ending at
    # 104 tokens, and that tail is two 8-token blocks wide: hashes 11 and 12.
    full_blocks = [token for token in range(15) for _ in range(block_size)]
    token_ids = full_blocks + [15] * 7
    req0 = make_request("0", token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert num_computed_tokens == 0
    allocated = manager.allocate_slots(
        req0,
        len(token_ids),
        num_computed_tokens,
        computed_blocks,
    )
    assert allocated is not None

    pool = manager.block_pool
    expected_swa_cached = {11, 12}
    for i in range(15):
        hit = pool.get_cached_block(req0.block_hashes[i], kv_cache_group_ids=[1])
        if i not in expected_swa_cached:
            assert hit is None, f"SWA hash {i} should not be cached"
            continue
        assert hit is not None, f"SWA hash {i} should be cached"

    manager.free(req0)

    # Replaying the same prompt hits the 96-token boundary in both groups.
    req1 = make_request("1", token_ids, block_size, sha256)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert num_computed_tokens == 12 * block_size
    hit_lengths = [len(group_blocks) for group_blocks in computed_blocks.blocks]
    assert hit_lengths == [3, 12]
|
||||
|
||||
|
||||
def test_block_lookup_cache_single_block_per_key():
|
||||
cache = BlockHashToBlockMap()
|
||||
key0 = BlockHashWithGroupId(b"hash0")
|
||||
@@ -3058,3 +3403,215 @@ def test_can_fit_full_sequence_full_attention_still_gates_oversized():
|
||||
req = make_request("oversized", list(range(prompt_len)), block_size, sha256)
|
||||
|
||||
assert manager.allocate_slots(req, block_size, full_sequence_must_fit=True) is None
|
||||
|
||||
|
||||
def test_swa_free_split_keeps_cached_tail_ahead_of_scratch(monkeypatch):
    """Check the free-queue order of SWA blocks when no retention is set.

    Freeing an SWA request must put its uncached scratch blocks at the front
    of the free queue (recycled first) and its cached checkpoint blocks at the
    back (retained for prefix hits). The retention variable is removed here,
    so this covers the default path; the split is meant to be always-on.
    """
    monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
    block_size = 8

    full_attention_group = KVCacheGroupSpec(
        ["layer1"],
        FullAttentionSpec(
            block_size=4 * block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float16,
        ),
    )
    sliding_window_group = KVCacheGroupSpec(
        ["layer2"],
        SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=block_size,
        ),
    )
    manager = make_kv_cache_manager(
        kv_cache_config=KVCacheConfig(
            num_blocks=100,
            kv_cache_tensors=[],
            kv_cache_groups=[full_attention_group, sliding_window_group],
        ),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    token_ids = [token for token in range(16) for _ in range(block_size)]
    req = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req)
    num_computed = len(computed_blocks.blocks[0]) * block_size
    allocated = manager.allocate_slots(
        req,
        len(token_ids),
        num_computed,
        computed_blocks,
    )
    assert allocated is not None

    # Classify the request's non-null SWA blocks by whether they carry a hash.
    swa_manager = manager.coordinator.single_type_managers[1]
    null_block = manager.block_pool.null_block
    cached_ids: set[int] = set()
    uncached_ids: set[int] = set()
    cached_hash_indices: list[int] = []
    request_blocks = swa_manager.req_to_blocks[req.request_id]
    for i, block in enumerate(request_blocks):
        if block is null_block:
            continue
        if block.block_hash is not None:
            cached_ids.add(block.block_id)
            cached_hash_indices.append(i)
        else:
            uncached_ids.add(block.block_id)

    # The dense default mask caches only the per-segment tails, so a 16-block
    # SWA prompt must produce a mix of retained and scratch blocks.
    assert cached_ids, "expected some retained (cached) SWA tail blocks"
    assert uncached_ids, "expected some scratch (uncached) SWA blocks"

    manager.free(req)

    free_blocks = manager.block_pool.free_block_queue.get_all_free_blocks()
    pos = {block.block_id: rank for rank, block in enumerate(free_blocks)}

    # Every scratch block is recycled before every retained block.
    last_scratch = max(pos[bid] for bid in uncached_ids)
    first_retained = min(pos[bid] for bid in cached_ids)
    assert last_scratch < first_retained

    # The retained tails survive the free and still serve a prefix-cache hit.
    for i in cached_hash_indices:
        hit = manager.block_pool.get_cached_block(
            req.block_hashes[i], kv_cache_group_ids=[1]
        )
        assert hit is not None
|
||||
|
||||
|
||||
def _make_pure_swa_manager(block_size, sliding_window, num_blocks=100, **kwargs):
    """Build a KV cache manager with one sliding-window group.

    The single-group layout is intended to select UnitaryKVCacheCoordinator.
    Extra keyword arguments are forwarded to make_kv_cache_manager.
    """
    swa_spec = SlidingWindowSpec(
        block_size=block_size,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=sliding_window,
    )
    single_group_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[KVCacheGroupSpec(["layer"], swa_spec)],
    )
    return make_kv_cache_manager(
        kv_cache_config=single_group_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        **kwargs,
    )
|
||||
|
||||
|
||||
def test_pure_swa_retention_interval_caches_sparse_tails(monkeypatch):
    """Check sparse retention on a pure-SWA, single-group model.

    Sparse retention must work here too, not only for hybrid models: only the
    per-interval tails plus the latest replay tail are cached, and a replay
    still hits the latest replayable boundary.
    """
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "64")
    block_size = 16
    manager = _make_pure_swa_manager(block_size, sliding_window=block_size)
    assert type(manager.coordinator).__name__ == "UnitaryKVCacheCoordinator"

    token_ids = [token for token in range(16) for _ in range(block_size)]
    req = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req)
    num_computed_before = len(computed_blocks.blocks[0]) * block_size
    allocated = manager.allocate_slots(
        req,
        len(token_ids),
        num_computed_before,
        computed_blocks,
    )
    assert allocated is not None

    # Collect the hash indices that are present in the block pool.
    pool = manager.block_pool
    cached = set()
    for index in range(16):
        hit = pool.get_cached_block(req.block_hashes[index], kv_cache_group_ids=[0])
        if hit is not None:
            cached.add(index)

    # per_segment = 64 / 16 = 4, need = cdiv(16-1, 16) = 1 -> segment tails at
    # i%4==3 -> {3,7,11,15}; latest replay boundary (255//16*16 = 240) -> tail
    # block 14. Crucially this is a strict subset of all 16 blocks: retention
    # is actually sparse for pure SWA (not silently dense).
    assert cached == {3, 7, 11, 14, 15}

    # A replay of the same prompt hits the latest replayable boundary (240).
    replay = make_request("1", token_ids, block_size, sha256)
    _, num_computed = manager.get_computed_blocks(replay)
    assert num_computed == 240
|
||||
|
||||
|
||||
def test_pure_swa_retention_latest_only(monkeypatch):
    """Check that `=0` on a pure-SWA model keeps only the latest replay tail."""
    monkeypatch.setenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", "0")
    block_size = 16
    manager = _make_pure_swa_manager(block_size, sliding_window=block_size)

    token_ids = [token for token in range(16) for _ in range(block_size)]
    req = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req)
    num_computed_before = len(computed_blocks.blocks[0]) * block_size
    allocated = manager.allocate_slots(
        req,
        len(token_ids),
        num_computed_before,
        computed_blocks,
    )
    assert allocated is not None

    # Collect the hash indices that are present in the block pool.
    pool = manager.block_pool
    cached = set()
    for index in range(16):
        hit = pool.get_cached_block(req.block_hashes[index], kv_cache_group_ids=[0])
        if hit is not None:
            cached.add(index)

    # No segment tails (interval 0); only the latest replay tail (block 14).
    assert cached == {14}

    # Replaying the same prompt reports 240 computed tokens.
    replay = make_request("1", token_ids, block_size, sha256)
    _, num_computed = manager.get_computed_blocks(replay)
    assert num_computed == 240
|
||||
|
||||
|
||||
def test_pure_swa_retention_dense_default_caches_all(monkeypatch):
    """Check that a pure-SWA model stays dense when retention is unset.

    Every block boundary is a potential hit in that mode, so all 16 blocks
    are expected to be cached.
    """
    monkeypatch.delenv("VLLM_PREFIX_CACHE_RETENTION_INTERVAL", raising=False)
    block_size = 16
    manager = _make_pure_swa_manager(block_size, sliding_window=block_size)

    token_ids = [token for token in range(16) for _ in range(block_size)]
    req = make_request("0", token_ids, block_size, sha256)
    computed_blocks, _ = manager.get_computed_blocks(req)
    num_computed_before = len(computed_blocks.blocks[0]) * block_size
    allocated = manager.allocate_slots(
        req,
        len(token_ids),
        num_computed_before,
        computed_blocks,
    )
    assert allocated is not None

    # Collect the hash indices that are present in the block pool.
    pool = manager.block_pool
    cached = set()
    for index in range(16):
        hit = pool.get_cached_block(req.block_hashes[index], kv_cache_group_ids=[0])
        if hit is not None:
            cached.add(index)

    assert cached == set(range(16))
|
||||
|
||||
@@ -11,7 +11,9 @@ import pytest
|
||||
import torch
|
||||
from utils import skip_unsupported
|
||||
|
||||
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
rms_norm_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -51,7 +53,7 @@ def test_rms_norm_batch_invariant_vs_standard(
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation (Triton)
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||
|
||||
# Compare outputs
|
||||
# Use looser tolerance for bfloat16 due to its lower precision
|
||||
@@ -125,7 +127,7 @@ def test_fused_add_rms_norm_batch_invariant_residual_path(
|
||||
)
|
||||
|
||||
merged_single = x_single + residual_single
|
||||
ref_out = triton_rms_norm(merged_single, weight, eps=eps)
|
||||
ref_out = rms_norm_batch_invariant(merged_single, weight, eps=eps)
|
||||
|
||||
torch.testing.assert_close(
|
||||
residual_out_single,
|
||||
@@ -193,7 +195,7 @@ def test_rms_norm_3d_input(
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||
|
||||
# Use looser tolerance for bfloat16
|
||||
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
||||
@@ -242,7 +244,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||
|
||||
# Check for NaN or Inf
|
||||
assert not torch.isnan(standard_output).any(), (
|
||||
@@ -289,7 +291,7 @@ def test_rms_norm_formula(default_vllm_config):
|
||||
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||
|
||||
# Compare against formula
|
||||
torch.testing.assert_close(
|
||||
@@ -325,7 +327,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||
|
||||
# Use looser tolerance for bfloat16
|
||||
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
||||
@@ -360,7 +362,7 @@ def test_rms_norm_determinism(default_vllm_config):
|
||||
# Run multiple times
|
||||
outputs = []
|
||||
for _ in range(5):
|
||||
output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
|
||||
output = rms_norm_batch_invariant(input_tensor.clone(), weight, eps=eps)
|
||||
outputs.append(output)
|
||||
|
||||
# All outputs should be identical
|
||||
@@ -395,7 +397,7 @@ if __name__ == "__main__":
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
triton_output = rms_norm_batch_invariant(input_tensor, weight, eps=eps)
|
||||
|
||||
# Compare
|
||||
max_diff = (triton_output - standard_output).abs().max().item()
|
||||
|
||||
@@ -98,7 +98,6 @@ def _make_connector_with_fake_worker(
|
||||
)
|
||||
worker = connector.connector_worker
|
||||
assert isinstance(worker.nixl_wrapper, FakeNixlWrapper)
|
||||
worker.nixl_wrapper.set_cycles_before_xfer_done(cycles_before_done)
|
||||
worker.kv_cache_layout = "HND"
|
||||
if do_handshake:
|
||||
remote_agents = worker._nixl_handshake(
|
||||
|
||||
@@ -6,25 +6,56 @@
|
||||
import pytest
|
||||
|
||||
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
class _StubLMCacheMPConnector:
|
||||
"""Stand-in for LMCacheMPConnector used in config-translation tests.
|
||||
|
||||
The real connector module hard-imports the optional ``lmcache`` package
|
||||
at module load time, which is not installed in the cpu_test image. This
|
||||
test only asserts on the connector *name* and the ``extra_config`` dict
|
||||
produced by ``VllmConfig``, never instantiates the connector, so a bare
|
||||
placeholder class is sufficient. Not subclassing ``SupportsHMA`` mirrors
|
||||
the real connector's HMA support (it does not support HMA either)."""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stub_lmcache_mp_connector(monkeypatch):
|
||||
"""Replace the lazy loader so VllmConfig.__post_init__ does not import
|
||||
``lmcache_mp_connector`` (and thus ``lmcache``) during config tests."""
|
||||
monkeypatch.setitem(
|
||||
KVConnectorFactory._registry,
|
||||
"LMCacheMPConnector",
|
||||
lambda: _StubLMCacheMPConnector,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
|
||||
[
|
||||
("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
|
||||
# bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
|
||||
("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30)),
|
||||
("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
|
||||
# size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
|
||||
("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
|
||||
# ``lmcache`` backend now defaults to LMCacheMPConnector. The KV
|
||||
# storage capacity is owned by the standalone LMCache server, so
|
||||
# ``kv_offloading_size`` is intentionally not propagated.
|
||||
("lmcache", 4.0, 1, 1, "LMCacheMPConnector", None),
|
||||
("lmcache", 8.0, 2, 2, "LMCacheMPConnector", None),
|
||||
# When kv_offloading_size is None, offloading is disabled (backend is ignored)
|
||||
("native", None, 1, 1, None, None),
|
||||
],
|
||||
)
|
||||
def test_kv_connector(
|
||||
kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
|
||||
stub_lmcache_mp_connector,
|
||||
kv_offloading_backend,
|
||||
kv_offloading_size,
|
||||
tp,
|
||||
pp,
|
||||
expected_backend,
|
||||
expected_bytes,
|
||||
):
|
||||
kv_transfer_config = (
|
||||
KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
|
||||
@@ -59,10 +90,12 @@ def test_kv_connector(
|
||||
# Existing config should be preserved
|
||||
assert kv_connector_extra_config["existing_key"] == "existing_value"
|
||||
elif kv_offloading_backend == "lmcache":
|
||||
assert kv_connector_extra_config["lmcache.local_cpu"] is True
|
||||
assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
|
||||
# Existing config should be replaced
|
||||
assert "existing_key" not in kv_connector_extra_config
|
||||
# MP mode does not push lmcache.local_cpu / max_local_cpu_size into
|
||||
# extra config (the LMCache server owns capacity). Pre-existing
|
||||
# extra config entries are preserved as-is.
|
||||
assert "lmcache.local_cpu" not in kv_connector_extra_config
|
||||
assert "lmcache.max_local_cpu_size" not in kv_connector_extra_config
|
||||
assert kv_connector_extra_config["existing_key"] == "existing_value"
|
||||
|
||||
|
||||
def _build_config(
|
||||
|
||||
@@ -197,13 +197,6 @@ class FakeNixlWrapper:
|
||||
def get_xfer_telemetry(self, handle: int) -> dict:
|
||||
return get_default_xfer_telemetry()
|
||||
|
||||
############################################################
|
||||
# Follow are for changing the behavior during testing.
|
||||
############################################################
|
||||
|
||||
def set_cycles_before_xfer_done(self, cycles: int):
|
||||
"""Set the number of cycles before a transfer is considered done."""
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _make_fake_nixl_pkg():
|
||||
@@ -578,10 +571,7 @@ class TestNixlHandshake:
|
||||
"""Test case where multiple xfers are initiated to the same engine.
|
||||
|
||||
This test triggers the connector to load remote KV for the same
|
||||
`request_id`. The transfer is not done immediately due to
|
||||
`set_cycles_before_xfer_done`, so there is a state where there are
|
||||
multiple transfer states for the same `request_id`, and `get_finished`
|
||||
should handle it correctly (wait for all transfers to be done).
|
||||
`request_id`.
|
||||
"""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
@@ -598,7 +588,6 @@ class TestNixlHandshake:
|
||||
)
|
||||
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
|
||||
worker = connector.connector_worker
|
||||
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
|
||||
# simulate handshake
|
||||
worker.dst_xfer_side_handles = {
|
||||
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
|
||||
@@ -1304,7 +1293,6 @@ def test_scheduler_kv_connector_stats_aggregation():
|
||||
# Worker stats with transfer metrics
|
||||
worker_stats = NixlKVConnectorStats()
|
||||
worker_stats.record_transfer(get_default_xfer_telemetry())
|
||||
worker_stats.data["remote_tokens"] = []
|
||||
|
||||
# Scheduler stats with custom metric (needs dummy transfer to avoid being skipped)
|
||||
scheduler_stats = NixlKVConnectorStats()
|
||||
@@ -1314,7 +1302,6 @@ def test_scheduler_kv_connector_stats_aggregation():
|
||||
"post_duration": [0],
|
||||
"bytes_transferred": [0],
|
||||
"num_descriptors": [0],
|
||||
"remote_tokens": [128],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1355,7 +1342,6 @@ def test_scheduler_kv_connector_stats_aggregation():
|
||||
).scheduler_stats.kv_connector_stats
|
||||
nixl_stats = final_stats["NixlConnector"]
|
||||
assert nixl_stats.num_successful_transfers == 2
|
||||
assert nixl_stats.data["remote_tokens"] == [128]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
|
||||
|
||||
@@ -275,6 +275,83 @@ def test_apply_prefix_caching_mamba_hybrid(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
|
||||
@pytest.mark.parametrize(
|
||||
"local_physical_per_logical,remote_physical_per_logical,"
|
||||
"local_block_ids,remote_block_ids,"
|
||||
"expected_local,expected_remote",
|
||||
[
|
||||
# SSM prefix caching: remote has 3 placeholder + 1 real block,
|
||||
# local has only the 1 real block. FA blocks are equal (no trim).
|
||||
pytest.param(
|
||||
10,
|
||||
10,
|
||||
[list(range(10)), [42]],
|
||||
[list(range(10)), [40, 41, 42, 43]],
|
||||
[list(range(10)), [42]],
|
||||
[list(range(10)), [43]],
|
||||
id="ssm_prefix_trim_only",
|
||||
),
|
||||
# FA partial prefix cache hit with homogeneous TP: local has 4 FA
|
||||
# blocks (prefix cached), remote has full 10. SSM equal (no trim).
|
||||
pytest.param(
|
||||
10,
|
||||
10,
|
||||
[list(range(6, 10)), [42]],
|
||||
[list(range(10)), [42]],
|
||||
[list(range(6, 10)), [42]],
|
||||
[list(range(6, 10)), [42]],
|
||||
id="fa_prefix_hit_homo_tp",
|
||||
),
|
||||
# Both: FA partial prefix hit + SSM placeholder trim.
|
||||
# local FA=[6..9] (4 blocks, prefix cached), remote FA=[0..9]
|
||||
# local SSM=[99], remote SSM=[10, 20, 99] (2 placeholders + real)
|
||||
pytest.param(
|
||||
10,
|
||||
10,
|
||||
[[6, 7, 8, 9], [99]],
|
||||
[list(range(10)), [10, 20, 99]],
|
||||
[[6, 7, 8, 9], [99]],
|
||||
[[6, 7, 8, 9], [99]],
|
||||
id="fa_prefix_hit_and_ssm_trim",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_apply_prefix_caching_ssm_prefix_cache_hit(
|
||||
local_physical_per_logical,
|
||||
remote_physical_per_logical,
|
||||
local_block_ids,
|
||||
remote_block_ids,
|
||||
expected_local,
|
||||
expected_remote,
|
||||
):
|
||||
"""_apply_prefix_caching end-trims SSM remote blocks to match the single
|
||||
local block (placeholders dropped) and end-trims FA remote blocks on
|
||||
partial prefix cache hits when physical_per_logical matches.
|
||||
"""
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker import (
|
||||
NixlConnectorWorker,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
|
||||
|
||||
worker = object.__new__(NixlConnectorWorker)
|
||||
worker._has_mamba = True
|
||||
worker._physical_blocks_per_logical_kv_block = local_physical_per_logical
|
||||
worker._group_spec_types = (FullAttentionSpec, MambaSpec)
|
||||
worker.kv_cache_config = make_kv_cache_config(block_size=16, mamba_enabled=True)
|
||||
|
||||
aligned_local, aligned_remote = worker._apply_prefix_caching(
|
||||
local_block_ids, remote_block_ids, remote_physical_per_logical
|
||||
)
|
||||
|
||||
assert aligned_local == expected_local, (
|
||||
f"Expected local {expected_local}, got {aligned_local}"
|
||||
)
|
||||
assert aligned_remote == expected_remote, (
|
||||
f"Expected remote {expected_remote}, got {aligned_remote}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.cpu_test
|
||||
@pytest.mark.parametrize(
|
||||
"local_physical_per_logical,remote_physical_per_logical,"
|
||||
|
||||
@@ -1072,8 +1072,8 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="Attention backend FLASHINFER is not supported on ROCm.",
|
||||
not current_platform.is_cuda(),
|
||||
reason="Attention backend FLASHINFER is only supported on CUDA.",
|
||||
)
|
||||
def test_hybrid_attention_mamba_tensor_shapes():
|
||||
"""
|
||||
@@ -1508,8 +1508,8 @@ def test_is_uniform_decode() -> None:
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="Attention backend FLASHINFER is not supported on ROCm.",
|
||||
not current_platform.is_cuda(),
|
||||
reason="Attention backend FLASHINFER is only supported on CUDA.",
|
||||
)
|
||||
def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
|
||||
"""Test that a ValueError is raised when max_num_seqs exceeds the
|
||||
|
||||
@@ -1562,7 +1562,9 @@ def generate_legend() -> str:
|
||||
|
||||
|
||||
def generate_mla_section(
|
||||
prefill_backends: list[dict[str, Any]], decode_backends: list[dict[str, Any]]
|
||||
prefill_backends: list[dict[str, Any]],
|
||||
decode_backends: list[dict[str, Any]],
|
||||
v4_decode_backends: list[dict[str, Any]] | None = None,
|
||||
) -> str:
|
||||
"""Generate the complete MLA section with prefill and decode tables."""
|
||||
lines = [
|
||||
@@ -1611,6 +1613,22 @@ def generate_mla_section(
|
||||
columns = _build_columns(is_mla=True, has_versions=False)
|
||||
lines.extend(_render_table(columns, decode_backends))
|
||||
|
||||
if v4_decode_backends:
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"### DeepSeek V4 Decode Backends",
|
||||
"",
|
||||
"DeepSeek V4 sparse MLA uses its own decode backends, selected via",
|
||||
"`--attention-backend=<BACKEND>` (e.g., `FLASHMLA_SPARSE_DSV4`,",
|
||||
"`FLASHINFER_MLA_SPARSE_DSV4`). They share the V4 sparse-index",
|
||||
"pipeline (compressor + SWA + indexer, 256-token blocks, head 512);",
|
||||
"default on NVIDIA is `FLASHMLA_SPARSE_DSV4`.",
|
||||
"",
|
||||
]
|
||||
)
|
||||
lines.extend(_render_table(columns, v4_decode_backends))
|
||||
|
||||
lines.append("")
|
||||
return "\n".join(lines)
|
||||
|
||||
@@ -1651,9 +1669,15 @@ def generate_docs() -> str:
|
||||
if fi_features:
|
||||
all_backends = _expand_flashinfer_variants(all_backends, fi_features)
|
||||
|
||||
# Split into MLA and non-MLA
|
||||
mla_backends = [b for b in all_backends if b["is_mla"]]
|
||||
non_mla_backends = [b for b in all_backends if not b["is_mla"]]
|
||||
# DeepSeek V4 (*_DSV4) decode backends get their own subsection rather than
|
||||
# mixing into the main MLA / standard tables (the ROCm V4 backend isn't
|
||||
# flagged is_mla by the AST heuristic, so filter purely on the name).
|
||||
def _is_v4(b: dict[str, Any]) -> bool:
|
||||
return b["name"].endswith("_DSV4")
|
||||
|
||||
v4_decode_backends = [b for b in all_backends if _is_v4(b)]
|
||||
mla_backends = [b for b in all_backends if b["is_mla"] and not _is_v4(b)]
|
||||
non_mla_backends = [b for b in all_backends if not b["is_mla"] and not _is_v4(b)]
|
||||
|
||||
# Generate documentation
|
||||
script_path = "tools/pre_commit/generate_attention_backend_docs.py"
|
||||
@@ -1703,7 +1727,9 @@ def generate_docs() -> str:
|
||||
doc_lines.append("\n>\n".join(footnotes) + "\n")
|
||||
|
||||
# Add MLA section with prefill and decode backends
|
||||
doc_lines.append(generate_mla_section(mla_prefill_backends, mla_backends))
|
||||
doc_lines.append(
|
||||
generate_mla_section(mla_prefill_backends, mla_backends, v4_decode_backends)
|
||||
)
|
||||
|
||||
return "\n".join(doc_lines)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
from vllm.entrypoints.serve.utils.api_utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
|
||||
from .plot import SweepPlotArgs
|
||||
from .plot import main as plot_main
|
||||
|
||||
@@ -44,6 +44,24 @@ from .matcher_utils import MatcherQuantFP8
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
_IR_RMS_NORM_OP = torch.ops.vllm_ir.rms_norm.default
|
||||
_IR_FUSED_ADD_RMS_NORM_OP = torch.ops.vllm_ir.fused_add_rms_norm.default
|
||||
|
||||
|
||||
def _norm_input_weight_dtype_match(match: pm.Match) -> bool:
|
||||
"""Prevent fusion when the norm input and weight dtypes differ (e.g. a Gemma
|
||||
fp32 weight.float()+1 gamma), covering rms_norm and fused_add_rms_norm."""
|
||||
for node in match.nodes:
|
||||
if node.target == _IR_RMS_NORM_OP:
|
||||
x, weight = node.args[0], node.args[1]
|
||||
elif node.target == _IR_FUSED_ADD_RMS_NORM_OP:
|
||||
x, weight = node.args[0], node.args[2]
|
||||
else:
|
||||
continue
|
||||
if isinstance(x, fx.Node) and isinstance(weight, fx.Node):
|
||||
return x.meta["val"].dtype == weight.meta["val"].dtype
|
||||
return True
|
||||
|
||||
|
||||
# The empirical value for small batch
|
||||
PDL_ADVANCE_LAUNCH_TOKENS = 16
|
||||
@@ -132,6 +150,7 @@ if flashinfer_comm is not None:
|
||||
quant_out: torch.Tensor | None = None,
|
||||
scale_out: torch.Tensor | None = None,
|
||||
scale_factor: torch.Tensor | None = None,
|
||||
weight_bias: float = 0.0,
|
||||
) -> None:
|
||||
num_tokens, hidden_size = allreduce_in.shape
|
||||
element_size = allreduce_in.element_size()
|
||||
@@ -208,6 +227,7 @@ if flashinfer_comm is not None:
|
||||
layout_code=layout_code,
|
||||
use_oneshot=use_oneshot,
|
||||
fp32_acc=fp32_acc,
|
||||
weight_bias=weight_bias,
|
||||
trigger_completion_at_end=num_tokens > PDL_ADVANCE_LAUNCH_TOKENS,
|
||||
)
|
||||
|
||||
@@ -225,6 +245,7 @@ if flashinfer_comm is not None:
|
||||
quant_out: torch.Tensor | None = None,
|
||||
scale_out: torch.Tensor | None = None,
|
||||
scale_factor: torch.Tensor | None = None,
|
||||
weight_bias: float = 0.0,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -399,14 +420,142 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
# allreduce_in, residual
|
||||
return allreduce[1], allreduce[2]
|
||||
|
||||
# extra_check routes a Gemma fp32 gamma to AllReduceFusedAddGemmaRMSNormPattern.
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
pattern,
|
||||
replacement,
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
extra_check=_norm_input_weight_dtype_match,
|
||||
)
|
||||
|
||||
# Same pattern, but only return the output and not residual
|
||||
# (helpful for end of graph where residual is not used again)
|
||||
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
|
||||
|
||||
pm.register_replacement(
|
||||
first_return_only(pattern), # type: ignore[no-untyped-call]
|
||||
first_return_only(replacement), # type: ignore[no-untyped-call]
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
extra_check=_norm_input_weight_dtype_match,
|
||||
)
|
||||
|
||||
|
||||
class AllReduceGemmaRMSNormPattern(BasePattern):
|
||||
"""Gemma-style variant of AllReduceRMSNormPattern (no residual)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
return [self.empty(5, 16), self.empty(16)]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = vllm.ir.ops.rms_norm(
|
||||
allreduce_output, weight.float() + 1.0, self.epsilon
|
||||
)
|
||||
return rms, allreduce_output
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.zeros_like(input)
|
||||
rms_result = torch.empty_like(input)
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=rms_result,
|
||||
quant_out=None,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
|
||||
weight_bias=1.0,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
return allreduce[3], allreduce[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusedAddGemmaRMSNormPattern(BasePattern):
|
||||
"""Gemma-style variant of AllReduceFusedAddRMSNormPattern (with residual)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 16)
|
||||
residual = self.empty(5, 16)
|
||||
weight = self.empty(16)
|
||||
return [residual, input.to(self.dtype), weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms, residual = vllm.ir.ops.fused_add_rms_norm(
|
||||
allreduce_output, residual, weight.float() + 1.0, self.epsilon
|
||||
)
|
||||
return rms, residual
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=None,
|
||||
quant_out=None,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
|
||||
weight_bias=1.0,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
return allreduce[1], allreduce[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
|
||||
|
||||
pm.register_replacement(
|
||||
first_return_only(pattern), # type: ignore[no-untyped-call]
|
||||
first_return_only(replacement), # type: ignore[no-untyped-call]
|
||||
@@ -881,6 +1030,18 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
AllReduceGemmaRMSNormPattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
AllReduceFusedAddGemmaRMSNormPattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
|
||||
# WARNING: This is a hack to clear the pattern matcher cache
|
||||
# and allow multiple values of epsilon.
|
||||
|
||||
@@ -36,10 +36,12 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
||||
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
if hasattr(torch.ops._C, "per_token_group_fp8_quant"):
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501
|
||||
|
||||
|
||||
@@ -84,9 +84,10 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
||||
kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if hasattr(torch.ops._C, "per_token_group_fp8_quant"):
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out
|
||||
|
||||
|
||||
+19
-17
@@ -109,22 +109,24 @@ class EPLBConfig:
|
||||
class ParallelConfig:
|
||||
"""Configuration for the distributed execution."""
|
||||
|
||||
pipeline_parallel_size: int = 1
|
||||
pipeline_parallel_size: int = Field(default=1, ge=1)
|
||||
"""Number of pipeline parallel groups."""
|
||||
tensor_parallel_size: int = 1
|
||||
tensor_parallel_size: int = Field(default=1, ge=1)
|
||||
"""Number of tensor parallel groups."""
|
||||
prefill_context_parallel_size: int = 1
|
||||
prefill_context_parallel_size: int = Field(default=1, ge=1)
|
||||
"""Number of prefill context parallel groups."""
|
||||
data_parallel_size: int = 1
|
||||
data_parallel_size: int = Field(default=1, ge=1)
|
||||
"""Number of data parallel groups. MoE layers will be sharded according to
|
||||
the product of the tensor parallel size and data parallel size."""
|
||||
data_parallel_size_local: int = 1
|
||||
"""Number of local data parallel groups."""
|
||||
data_parallel_rank: int = 0
|
||||
"""Rank of the data parallel group."""
|
||||
data_parallel_size_local: int = Field(default=1, ge=0)
|
||||
"""Number of local data parallel groups. A value of 0 is a sentinel used by
|
||||
the engine-args layer to signal that data parallelism was specified
|
||||
externally (see `ParallelConfig.__post_init__`)."""
|
||||
data_parallel_rank: int = Field(default=0, ge=0)
|
||||
"""Rank of the data parallel group. The runtime check at
|
||||
``__post_init__`` further bounds this by ``data_parallel_size``."""
|
||||
data_parallel_rank_local: int | None = None
|
||||
"""Local rank of the data parallel group,
|
||||
set only in SPMD mode."""
|
||||
"""Local rank of the data parallel group, set only in SPMD mode."""
|
||||
data_parallel_master_ip: str = "127.0.0.1"
|
||||
"""IP of the data parallel master."""
|
||||
data_parallel_rpc_port: int = 29550
|
||||
@@ -184,7 +186,7 @@ class ParallelConfig:
|
||||
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
|
||||
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
|
||||
|
||||
max_parallel_loading_workers: int | None = None
|
||||
max_parallel_loading_workers: int | None = Field(default=None, ge=1)
|
||||
"""Maximum number of parallel loading workers when loading model
|
||||
sequentially in multiple batches. To avoid RAM OOM when using tensor
|
||||
parallel and large models."""
|
||||
@@ -197,15 +199,15 @@ class ParallelConfig:
|
||||
|
||||
enable_dbo: bool = False
|
||||
"""Enable dual batch overlap for the model executor."""
|
||||
ubatch_size: int = 0
|
||||
ubatch_size: int = Field(default=0, ge=0)
|
||||
"""Number of ubatch size."""
|
||||
|
||||
dbo_decode_token_threshold: int = 32
|
||||
dbo_decode_token_threshold: int = Field(default=32, ge=0)
|
||||
"""The threshold for dual batch overlap for batches only containing decodes.
|
||||
If the number of tokens in the request is greater than this threshold,
|
||||
microbatching will be used. Otherwise, the request will be processed in a
|
||||
single batch."""
|
||||
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune
|
||||
dbo_prefill_token_threshold: int = Field(default=512, ge=0) # TODO(lucas): tune
|
||||
"""The threshold for dual batch overlap for batches that contain one or more
|
||||
prefills. If the number of tokens in the request is greater than this
|
||||
threshold, microbatching will be used. Otherwise, the request will be
|
||||
@@ -260,10 +262,10 @@ class ParallelConfig:
|
||||
master_port: int = 29501
|
||||
"""distributed master port for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
node_rank: int = 0
|
||||
node_rank: int = Field(default=0, ge=0)
|
||||
"""distributed node rank for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
nnodes: int = 1
|
||||
nnodes: int = Field(default=1, ge=1)
|
||||
"""num of nodes for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
numa_bind: bool = False
|
||||
@@ -318,7 +320,7 @@ class ParallelConfig:
|
||||
"""Port of the coordination TCPStore. Can be set by the API server; workers
|
||||
connect as clients to exchange self-picked group ports at runtime."""
|
||||
|
||||
decode_context_parallel_size: int = 1
|
||||
decode_context_parallel_size: int = Field(default=1, ge=1)
|
||||
"""Number of decode context parallel groups, because the world size does
|
||||
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
||||
needs to be divisible by dcp_size."""
|
||||
|
||||
+6
-10
@@ -771,10 +771,6 @@ class VllmConfig:
|
||||
# If no KVTransferConfig is provided, create a default one.
|
||||
if self.kv_transfer_config is None:
|
||||
self.kv_transfer_config = KVTransferConfig()
|
||||
num_kv_ranks = (
|
||||
self.parallel_config.tensor_parallel_size
|
||||
* self.parallel_config.pipeline_parallel_size
|
||||
)
|
||||
|
||||
if kv_offloading_backend == "native":
|
||||
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
|
||||
@@ -786,12 +782,12 @@ class VllmConfig:
|
||||
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
|
||||
)
|
||||
elif kv_offloading_backend == "lmcache":
|
||||
self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
|
||||
kv_gb_per_rank = kv_offloading_size / num_kv_ranks
|
||||
self.kv_transfer_config.kv_connector_extra_config = {
|
||||
"lmcache.local_cpu": True,
|
||||
"lmcache.max_local_cpu_size": kv_gb_per_rank,
|
||||
}
|
||||
# Default to LMCache multi-process (MP) mode. The actual KV
|
||||
# storage capacity is managed by the standalone LMCache server
|
||||
# process, so ``kv_offloading_size`` is not propagated here.
|
||||
# ``LMCacheMPConnector`` falls back to ``tcp://localhost:5555``
|
||||
# when host/port are not provided via extra_config.
|
||||
self.kv_transfer_config.kv_connector = "LMCacheMPConnector"
|
||||
|
||||
# This is the same for all backends
|
||||
self.kv_transfer_config.kv_role = "kv_both"
|
||||
|
||||
@@ -40,11 +40,7 @@ def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
from vllm.distributed.utils import is_weak_contiguous # noqa: E402
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
|
||||
@@ -24,11 +24,7 @@ except Exception:
|
||||
quick_ar = False
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor):
|
||||
return inp.is_contiguous() or (
|
||||
inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
|
||||
== inp.numel() * inp.element_size()
|
||||
)
|
||||
from vllm.distributed.utils import is_weak_contiguous # noqa: E402, F401
|
||||
|
||||
|
||||
class QuickReduceRegime(Enum):
|
||||
|
||||
@@ -470,10 +470,14 @@ class ElasticEPScalingExecutor:
|
||||
module._replace_quant_method(module.quant_method.old_quant_method)
|
||||
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
||||
|
||||
eplb_model_state.expert_buffer = [
|
||||
torch.empty_like(w) for w in model.expert_weights[0]
|
||||
]
|
||||
eplb_model_state.communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend=parallel_config.eplb_config.communicator,
|
||||
expert_weights=model.expert_weights[0],
|
||||
expert_weights=model.expert_weights,
|
||||
expert_buffer=eplb_model_state.expert_buffer,
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -120,6 +120,7 @@ def transfer_run_periodically(
|
||||
ep_group=eplb_group,
|
||||
is_profile=is_profile,
|
||||
cuda_stream=cuda_stream,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
# Wait until all writes to expert_buffer have finished before making the
|
||||
|
||||
@@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import (
|
||||
is_local_first_rank,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.distributed.utils import is_weak_contiguous
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -63,8 +64,22 @@ class EplbCommunicator(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
pass
|
||||
def execute(self) -> None:
|
||||
"""Complete all enqueued transfers.
|
||||
|
||||
Some backends perform communication here; others (e.g. NIXL)
|
||||
issue transfers eagerly in add_recv and only wait here.
|
||||
On return, all data is available in the destination buffers.
|
||||
"""
|
||||
|
||||
def set_transfer_context( # noqa: B027
|
||||
self, old_indices: np.ndarray, layer_idx: int
|
||||
) -> None:
|
||||
"""Pre-set layer context before add_recv calls.
|
||||
|
||||
Default is a no-op; overridden by backends (e.g. NIXL) that need
|
||||
layer-level context to issue transfers inside add_recv.
|
||||
"""
|
||||
|
||||
@property
|
||||
def needs_profile_buffer_reservation(self) -> bool:
|
||||
@@ -125,7 +140,7 @@ class TorchDistNcclEplbCommunicator(EplbCommunicator):
|
||||
)
|
||||
)
|
||||
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
def execute(self) -> None:
|
||||
if not self._p2p_ops:
|
||||
return
|
||||
try:
|
||||
@@ -168,7 +183,7 @@ class TorchDistGlooStagedEplbCommunicator(EplbCommunicator):
|
||||
for tensor in tensors:
|
||||
self._ops.append(("recv", tensor, src_rank))
|
||||
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
def execute(self) -> None:
|
||||
if not self._ops:
|
||||
return
|
||||
|
||||
@@ -229,28 +244,46 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
all_expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||
expert_buffer: Sequence[torch.Tensor],
|
||||
) -> None:
|
||||
assert expert_weights, "NixlEplbCommunicator requires non-empty expert_weights."
|
||||
assert all_expert_weights, (
|
||||
"NixlEplbCommunicator requires non-empty all_expert_weights."
|
||||
)
|
||||
assert expert_buffer, "NixlEplbCommunicator requires non-empty expert_buffer."
|
||||
nixl_wrapper_cls = nixl_utils.NixlWrapper
|
||||
if nixl_wrapper_cls is None:
|
||||
raise RuntimeError("NIXL/ RIXL is unavailable.")
|
||||
|
||||
self._cpu_group = cpu_group
|
||||
self._cuda_stream = cuda_stream
|
||||
self._world_size = cpu_group.size()
|
||||
self._rank = cpu_group.rank()
|
||||
# expert_id -> weight tensors to pack into the send buffer.
|
||||
self._expert_send_map: dict[int, list[torch.Tensor]] = {}
|
||||
# src_rank -> expert_id -> weight tensors to unpack after transfer.
|
||||
self._recv_map: dict[int, dict[int, list[torch.Tensor]]] = {}
|
||||
self._num_local_experts: int = expert_weights[0].shape[0]
|
||||
self._device = expert_weights[0].device
|
||||
for tensor in expert_weights:
|
||||
assert tensor.device == self._device, (
|
||||
"All local EPLB tensors are expected to be on the same device: "
|
||||
f"expected={self._device}, got={tensor.device}"
|
||||
|
||||
self._all_expert_weights = all_expert_weights
|
||||
self._expert_buffer = expert_buffer
|
||||
self._num_local_experts: int = all_expert_weights[0][0].shape[0]
|
||||
self._device = all_expert_weights[0][0].device
|
||||
|
||||
for layer_tensors in all_expert_weights:
|
||||
for tensor in layer_tensors:
|
||||
assert is_weak_contiguous(tensor), (
|
||||
"Expert weight tensors must be contiguous in memory"
|
||||
)
|
||||
assert tensor.device == self._device, (
|
||||
"All local EPLB tensors are expected to be on the same "
|
||||
f"device: expected={self._device}, got={tensor.device}"
|
||||
)
|
||||
for tensor in expert_buffer:
|
||||
assert is_weak_contiguous(tensor), (
|
||||
"expert_buffer tensors must be contiguous in memory"
|
||||
)
|
||||
|
||||
# (local_dlist, remote_dlist, xfer_handle) for in-flight READs;
|
||||
# accumulated by add_recv, drained by execute.
|
||||
self._xfer_entries: list[tuple[int, int, int]] = []
|
||||
# Per-rank expert_id -> physical row; set by set_transfer_context.
|
||||
self._expert_to_src_row: list[dict[int, int]] | None = None
|
||||
self._layer_idx: int | None = None
|
||||
|
||||
nixl_agent_config = nixl_utils.nixl_agent_config
|
||||
config = (
|
||||
@@ -260,15 +293,16 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
)
|
||||
self._nixl_wrapper = nixl_wrapper_cls(self._make_agent_name(), config)
|
||||
self._nixl_memory_type = "VRAM"
|
||||
self._registered_desc: object | None = None
|
||||
# NIXL registration handles; deregistered in __del__.
|
||||
self._registered_descs: list[object] = []
|
||||
self._remote_agents: dict[int, str] = {}
|
||||
self._remote_send_meta: dict[int, tuple[int, int]] = {}
|
||||
self._send_buffer: torch.Tensor = torch.empty(0)
|
||||
self._recv_buffer: torch.Tensor = torch.empty(0)
|
||||
self._expert_bytes: int = 0
|
||||
# peer -> (layer, tensor) -> (base_ptr, bytes_per_expert, dev_id).
|
||||
self._remote_send_meta: dict[
|
||||
int, dict[tuple[int, int], tuple[int, int, int]]
|
||||
] = {}
|
||||
|
||||
self._cuda_device_id = int(self._device.index or 0)
|
||||
self._init_step("buffers", self._init_registered_buffers, expert_weights)
|
||||
self._init_step("buffers", self._init_registered_buffers)
|
||||
self._init_step("agents", self._init_remote_agents)
|
||||
self._init_step("send meta", self._exchange_remote_send_meta)
|
||||
self._log_initialized()
|
||||
@@ -291,19 +325,34 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
uid = uuid.uuid4().hex[:8]
|
||||
return f"eplb-{self._rank}{pp_suffix}-{uid}"
|
||||
|
||||
def set_stream(self, cuda_stream: torch.cuda.Stream | None) -> None:
|
||||
pass
|
||||
|
||||
def add_send(
|
||||
self,
|
||||
tensors: list[torch.Tensor],
|
||||
dst_rank: int,
|
||||
expert_id: int,
|
||||
) -> None:
|
||||
assert dst_rank != self._rank, (
|
||||
"EPLB communicator should not enqueue same-rank sends: "
|
||||
f"rank={self._rank}, dst_rank={dst_rank}"
|
||||
# No-op: NIXL READ is receiver-initiated. The sender's expert
|
||||
# weights are pre-registered and always readable in-place.
|
||||
pass
|
||||
|
||||
def set_transfer_context(self, old_indices: np.ndarray, layer_idx: int) -> None:
|
||||
# Pre-compute expert_id -> src_row mapping for every rank so that
|
||||
# add_recv can immediately issue NIXL READs.
|
||||
assert not self._xfer_entries, (
|
||||
f"set_transfer_context() called with {len(self._xfer_entries)} "
|
||||
f"pending transfers from layer {self._layer_idx}; "
|
||||
f"execute() was not called after previous add_recv() calls"
|
||||
)
|
||||
# An expert sent to multiple peers is packed only once; skip duplicates.
|
||||
if expert_id not in self._expert_send_map:
|
||||
self._expert_send_map[expert_id] = tensors
|
||||
self._layer_idx = layer_idx
|
||||
n = self._num_local_experts
|
||||
rank_experts = old_indices[: self._world_size * n].reshape(self._world_size, n)
|
||||
self._expert_to_src_row = [
|
||||
{int(eid): i for i, eid in enumerate(row) if eid != -1}
|
||||
for row in rank_experts
|
||||
]
|
||||
|
||||
def add_recv(
|
||||
self,
|
||||
@@ -311,13 +360,44 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
src_rank: int,
|
||||
expert_id: int,
|
||||
) -> None:
|
||||
assert src_rank != self._rank, (
|
||||
"EPLB communicator should not enqueue same-rank recvs: "
|
||||
f"rank={self._rank}, src_rank={src_rank}"
|
||||
# Build NIXL descriptors and issue the RDMA READ immediately,
|
||||
# overlapping the transfer with the remaining Python loop in
|
||||
# move_to_buffer.
|
||||
assert self._expert_to_src_row is not None and self._layer_idx is not None, (
|
||||
"set_transfer_context() must be called before add_recv()"
|
||||
)
|
||||
recv_experts = self._recv_map.setdefault(src_rank, {})
|
||||
if expert_id not in recv_experts:
|
||||
recv_experts[expert_id] = tensors
|
||||
src_row = self._expert_to_src_row[src_rank][expert_id]
|
||||
layer_idx = self._layer_idx
|
||||
|
||||
local_descs: list[tuple[int, int, int]] = []
|
||||
remote_descs: list[tuple[int, int, int]] = []
|
||||
for t_idx, t in enumerate(tensors):
|
||||
send_base, send_stride, remote_dev = self._remote_send_meta[src_rank][
|
||||
(layer_idx, t_idx)
|
||||
]
|
||||
assert t.nbytes == send_stride, (
|
||||
f"tensor {t_idx} size {t.nbytes} != remote stride {send_stride}"
|
||||
)
|
||||
local_descs.append(
|
||||
(
|
||||
t.data_ptr(),
|
||||
t.nbytes,
|
||||
self._cuda_device_id,
|
||||
)
|
||||
)
|
||||
remote_descs.append(
|
||||
(
|
||||
send_base + src_row * send_stride,
|
||||
send_stride,
|
||||
remote_dev,
|
||||
)
|
||||
)
|
||||
|
||||
local_h, remote_h, xfer_h = self._create_peer_xfer(
|
||||
src_rank, local_descs, remote_descs
|
||||
)
|
||||
self._nixl_wrapper.transfer(xfer_h)
|
||||
self._xfer_entries.append((local_h, remote_h, xfer_h))
|
||||
|
||||
def _init_remote_agents(self) -> None:
|
||||
local_metadata = self._nixl_wrapper.get_agent_metadata()
|
||||
@@ -334,73 +414,60 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
peer_metadata
|
||||
)
|
||||
|
||||
def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> None:
|
||||
total_bytes = max(sum(t.nbytes for t in expert_weights), 1)
|
||||
assert total_bytes % self._num_local_experts == 0, (
|
||||
f"Number of bytes in moe layer {total_bytes} is not divisible "
|
||||
f"by number of local experts {self._num_local_experts}"
|
||||
)
|
||||
self._expert_bytes = total_bytes // self._num_local_experts
|
||||
def _init_registered_buffers(self) -> None:
|
||||
all_tensors: list[torch.Tensor] = []
|
||||
for layer_tensors in self._all_expert_weights:
|
||||
all_tensors.extend(layer_tensors)
|
||||
all_tensors.extend(self._expert_buffer)
|
||||
|
||||
self._send_buffer = torch.empty(
|
||||
total_bytes, device=self._device, dtype=torch.uint8
|
||||
)
|
||||
self._recv_buffer = torch.empty(
|
||||
total_bytes, device=self._device, dtype=torch.uint8
|
||||
)
|
||||
|
||||
descs = self._nixl_wrapper.get_reg_descs([self._send_buffer, self._recv_buffer])
|
||||
descs = self._nixl_wrapper.get_reg_descs(all_tensors)
|
||||
self._nixl_wrapper.register_memory(descs)
|
||||
self._registered_desc = descs
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
def _exchange_remote_send_meta(self) -> None:
|
||||
"""Exchange send-buffer metadata so each rank can build dynamic
|
||||
descriptors at execute time."""
|
||||
local_meta: tuple[int, int] = (
|
||||
self._send_buffer.data_ptr(),
|
||||
"""Exchange per-layer per-tensor metadata so receivers can compute
|
||||
remote RDMA addresses at transfer time."""
|
||||
local_meta: dict[tuple[int, int], tuple[int, int, int]] = {}
|
||||
for layer_idx, layer_tensors in enumerate(self._all_expert_weights):
|
||||
for t_idx, t in enumerate(layer_tensors):
|
||||
nbytes_per_expert = t.nbytes // self._num_local_experts
|
||||
local_meta[(layer_idx, t_idx)] = (
|
||||
t.data_ptr(),
|
||||
nbytes_per_expert,
|
||||
self._cuda_device_id,
|
||||
)
|
||||
gathered_meta: list[tuple[int, int] | None] = [None] * self._world_size
|
||||
|
||||
# Per-rank map: (layer_idx, tensor_idx) -> (base_ptr, bytes_per_expert, dev_id).
|
||||
# add_recv uses base_ptr + src_row * bytes_per_expert to compute
|
||||
# the remote RDMA address for each expert.
|
||||
gathered_meta: list[dict[tuple[int, int], tuple[int, int, int]] | None] = [
|
||||
None
|
||||
] * self._world_size
|
||||
torch.distributed.all_gather_object(
|
||||
gathered_meta, local_meta, group=self._cpu_group
|
||||
)
|
||||
|
||||
local_keys = set(local_meta.keys())
|
||||
for peer in self._remote_agents:
|
||||
peer_meta = gathered_meta[peer]
|
||||
assert peer_meta is not None
|
||||
peer_keys = set(peer_meta.keys())
|
||||
if peer_keys != local_keys:
|
||||
raise RuntimeError(
|
||||
f"NIXL EPLB metadata key mismatch with rank {peer}: "
|
||||
f"local={sorted(local_keys)}, peer={sorted(peer_keys)}"
|
||||
)
|
||||
for key in local_keys:
|
||||
_, local_stride, _ = local_meta[key]
|
||||
_, peer_stride, _ = peer_meta[key]
|
||||
if local_stride != peer_stride:
|
||||
raise RuntimeError(
|
||||
f"NIXL EPLB nbytes_per_expert mismatch for {key} "
|
||||
f"with rank {peer}: "
|
||||
f"local={local_stride}, peer={peer_stride}"
|
||||
)
|
||||
self._remote_send_meta[peer] = peer_meta
|
||||
|
||||
@staticmethod
|
||||
def _pack_send_buffer(
|
||||
in_tensors: list[torch.Tensor],
|
||||
send_buffer: torch.Tensor,
|
||||
byte_offset: int,
|
||||
) -> None:
|
||||
for tensor in in_tensors:
|
||||
raw = tensor.reshape(-1).view(torch.uint8)
|
||||
if raw.numel() == 0:
|
||||
continue
|
||||
send_buffer[byte_offset : byte_offset + raw.numel()].copy_(
|
||||
raw, non_blocking=True
|
||||
)
|
||||
byte_offset += raw.numel()
|
||||
|
||||
@staticmethod
|
||||
def _unpack_recv_buffer(
|
||||
recv_buffer: torch.Tensor,
|
||||
out_tensors: list[torch.Tensor],
|
||||
byte_offset: int,
|
||||
) -> None:
|
||||
for tensor in out_tensors:
|
||||
num_bytes = tensor.numel() * tensor.element_size()
|
||||
if num_bytes == 0:
|
||||
continue
|
||||
tensor.reshape(-1).view(torch.uint8).copy_(
|
||||
recv_buffer[byte_offset : byte_offset + num_bytes],
|
||||
non_blocking=True,
|
||||
)
|
||||
byte_offset += num_bytes
|
||||
|
||||
def _wait_for_all_transfers(self, handles: list[int]) -> None:
|
||||
pending = set(handles)
|
||||
while pending:
|
||||
@@ -456,110 +523,52 @@ class NixlEplbCommunicator(EplbCommunicator):
|
||||
)
|
||||
return (local_handle, remote_handle, xfer_handle)
|
||||
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
assert old_indices is not None, (
|
||||
"NixlEplbCommunicator.execute requires old_indices"
|
||||
def execute(self) -> None:
|
||||
assert self._layer_idx is not None or not self._xfer_entries, (
|
||||
"set_transfer_context() must be called before execute() "
|
||||
"if any add_recv() calls were made"
|
||||
)
|
||||
|
||||
xfer_entries: list[tuple[int, int, int]] = []
|
||||
try:
|
||||
n = self._num_local_experts
|
||||
rank_experts = old_indices[: self._world_size * n].reshape(
|
||||
self._world_size, n
|
||||
)
|
||||
# Build expert_id -> send slot mapping per rank.
|
||||
expert_to_send_slot: list[dict[int, int]] = [
|
||||
{int(eid): i for i, eid in enumerate(row) if eid != -1}
|
||||
for row in rank_experts
|
||||
]
|
||||
self._wait_for_all_transfers([x[2] for x in self._xfer_entries])
|
||||
|
||||
# Phase 1: pack each expert at its slot offset in the send buffer.
|
||||
with torch.cuda.stream(self._cuda_stream):
|
||||
for expert_id, tensors in self._expert_send_map.items():
|
||||
slot = expert_to_send_slot[self._rank][expert_id]
|
||||
byte_offset = slot * self._expert_bytes
|
||||
self._pack_send_buffer(tensors, self._send_buffer, byte_offset)
|
||||
|
||||
# Ensure all packed data is visible in device memory before pulls.
|
||||
if self._cuda_stream is not None:
|
||||
self._cuda_stream.synchronize()
|
||||
else:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
# READ is receiver-initiated; synchronize all ranks before transfer.
|
||||
# We use monitored_barrier so a rank that crashes or exits early
|
||||
# produces a diagnostic timeout instead of a silent hang.
|
||||
# Post-READ barrier.
|
||||
# Correctness fence for zero-copy: prevents overwrite-while-
|
||||
# remote-read race.
|
||||
torch.distributed.monitored_barrier(
|
||||
group=self._cpu_group,
|
||||
timeout=timedelta(minutes=5),
|
||||
)
|
||||
|
||||
# Phase 2: issue one batched READ per peer.
|
||||
recv_offsets: dict[tuple[int, int], int] = {}
|
||||
recv_offset = 0
|
||||
recv_base = self._recv_buffer.data_ptr()
|
||||
for src in range(self._world_size):
|
||||
if src == self._rank:
|
||||
continue
|
||||
recv_experts = self._recv_map.get(src)
|
||||
if not recv_experts:
|
||||
continue
|
||||
expert_ids = list(recv_experts.keys())
|
||||
remote_base, remote_dev = self._remote_send_meta[src]
|
||||
local_descs: list[tuple[int, int, int]] = []
|
||||
remote_descs: list[tuple[int, int, int]] = []
|
||||
for expert_id in expert_ids:
|
||||
slot = expert_to_send_slot[src][expert_id]
|
||||
remote_off = slot * self._expert_bytes
|
||||
recv_offsets[(src, expert_id)] = recv_offset
|
||||
local_descs.append(
|
||||
(
|
||||
recv_base + recv_offset,
|
||||
self._expert_bytes,
|
||||
self._cuda_device_id,
|
||||
)
|
||||
)
|
||||
remote_descs.append(
|
||||
(remote_base + remote_off, self._expert_bytes, remote_dev)
|
||||
)
|
||||
recv_offset += self._expert_bytes
|
||||
assert recv_offset <= self._recv_buffer.nbytes
|
||||
local_h, remote_h, xfer_h = self._create_peer_xfer(
|
||||
src, local_descs, remote_descs
|
||||
)
|
||||
self._nixl_wrapper.transfer(xfer_h)
|
||||
xfer_entries.append((local_h, remote_h, xfer_h))
|
||||
|
||||
# Phase 3: wait for all in-flight transfers, then unpack.
|
||||
self._wait_for_all_transfers([x[2] for x in xfer_entries])
|
||||
|
||||
with torch.cuda.stream(self._cuda_stream):
|
||||
for (src, expert_id), offset in recv_offsets.items():
|
||||
self._unpack_recv_buffer(
|
||||
self._recv_buffer,
|
||||
self._recv_map[src][expert_id],
|
||||
offset,
|
||||
)
|
||||
finally:
|
||||
for local_h, remote_h, xfer_h in xfer_entries:
|
||||
for local_h, remote_h, xfer_h in self._xfer_entries:
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_xfer_handle(xfer_h)
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_dlist_handle(local_h)
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_dlist_handle(remote_h)
|
||||
self._expert_send_map.clear()
|
||||
self._recv_map.clear()
|
||||
self._xfer_entries.clear()
|
||||
self._expert_to_src_row = None
|
||||
self._layer_idx = None
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
if self._registered_desc is not None:
|
||||
self._nixl_wrapper.deregister_memory(self._registered_desc)
|
||||
self._registered_desc = None
|
||||
with contextlib.suppress(Exception):
|
||||
for local_h, remote_h, xfer_h in self._xfer_entries:
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_xfer_handle(xfer_h)
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_dlist_handle(local_h)
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.release_dlist_handle(remote_h)
|
||||
with contextlib.suppress(Exception):
|
||||
for descs in self._registered_descs:
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.deregister_memory(descs)
|
||||
self._registered_descs.clear()
|
||||
with contextlib.suppress(Exception):
|
||||
for agent_name in self._remote_agents.values():
|
||||
with contextlib.suppress(Exception):
|
||||
self._nixl_wrapper.remove_remote_agent(agent_name)
|
||||
self._remote_agents.clear()
|
||||
except Exception as e:
|
||||
logger.warning("Error during NixlEplbCommunicator cleanup: %s", e)
|
||||
|
||||
|
||||
class PyNcclEplbCommunicator(EplbCommunicator):
|
||||
@@ -600,7 +609,7 @@ class PyNcclEplbCommunicator(EplbCommunicator):
|
||||
for tensor in tensors:
|
||||
self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream)
|
||||
|
||||
def execute(self, old_indices: np.ndarray | None = None) -> None:
|
||||
def execute(self) -> None:
|
||||
if self._group_started:
|
||||
self._pynccl_comm.group_end()
|
||||
self._group_started = False
|
||||
@@ -609,7 +618,8 @@ class PyNcclEplbCommunicator(EplbCommunicator):
|
||||
def create_eplb_communicator(
|
||||
group_coordinator: GroupCoordinator,
|
||||
backend: str | None,
|
||||
expert_weights: Sequence[torch.Tensor],
|
||||
expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||
expert_buffer: Sequence[torch.Tensor],
|
||||
) -> EplbCommunicator:
|
||||
"""Create an EPLB communicator for the given backend.
|
||||
|
||||
@@ -624,16 +634,18 @@ def create_eplb_communicator(
|
||||
``"pynccl"`` in that case. When tensors reside on CPU,
|
||||
``"torch_gloo"`` or ``"torch_nccl"`` are used via the CPU
|
||||
process group.
|
||||
expert_weights: Expert weight tensors from *one* MoE layer.
|
||||
NixlEplbCommunicator pre-allocates send/recv buffers sized
|
||||
to this layer, so all other MoE layers must have the same
|
||||
tensor count, shapes, and dtypes.
|
||||
expert_weights: Expert weight tensors for *all* MoE layers.
|
||||
Shape ``(num_layers)(num_tensors_per_layer)``.
|
||||
NixlEplbCommunicator registers all layers with NIXL for
|
||||
zero-copy RDMA reads.
|
||||
expert_buffer: Pre-allocated receive buffer tensors (one per
|
||||
weight tensor in a single layer).
|
||||
"""
|
||||
# Keep a safe default for callers that have not resolved communicator yet.
|
||||
if backend is None:
|
||||
backend = "torch_nccl"
|
||||
|
||||
tensor_device_type = expert_weights[0].device.type if expert_weights else "cpu"
|
||||
first_layer = expert_weights[0] if expert_weights else []
|
||||
tensor_device_type = first_layer[0].device.type if first_layer else "cpu"
|
||||
torch_group = (
|
||||
group_coordinator.cpu_group
|
||||
if tensor_device_type == "cpu"
|
||||
@@ -649,7 +661,7 @@ def create_eplb_communicator(
|
||||
unsupported_dtypes = sorted(
|
||||
{
|
||||
tensor.dtype
|
||||
for tensor in expert_weights
|
||||
for tensor in first_layer
|
||||
if not ncclDataTypeEnum.supports_torch_dtype(tensor.dtype)
|
||||
},
|
||||
key=str,
|
||||
@@ -704,7 +716,8 @@ def create_eplb_communicator(
|
||||
try:
|
||||
return NixlEplbCommunicator(
|
||||
cpu_group=group_coordinator.cpu_group,
|
||||
expert_weights=expert_weights,
|
||||
all_expert_weights=expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -450,7 +450,8 @@ class EplbState:
|
||||
communicator = create_eplb_communicator(
|
||||
group_coordinator=get_eplb_group(),
|
||||
backend=self.parallel_config.eplb_config.communicator,
|
||||
expert_weights=model.expert_weights[0],
|
||||
expert_weights=model.expert_weights,
|
||||
expert_buffer=expert_buffer,
|
||||
)
|
||||
|
||||
model_state = EplbModelState(
|
||||
@@ -766,6 +767,7 @@ class EplbState:
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
new_physical_to_logical_map,
|
||||
eplb_model_state.model.expert_weights,
|
||||
eplb_model_state.expert_buffer,
|
||||
ep_group,
|
||||
eplb_model_state.communicator,
|
||||
is_profile,
|
||||
|
||||
@@ -178,6 +178,7 @@ def move_to_buffer(
|
||||
cuda_stream: torch.cuda.Stream | None,
|
||||
ep_rank: int,
|
||||
communicator: EplbCommunicator,
|
||||
layer_idx: int = 0,
|
||||
) -> TransferMetadata:
|
||||
"""
|
||||
Rearranges expert weights during EPLB rebalancing.
|
||||
@@ -193,6 +194,7 @@ def move_to_buffer(
|
||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||
ep_rank: Rank of this process in expert parallel group.
|
||||
communicator: EplbCommunicator instance for P2P communication.
|
||||
layer_idx: Index of the MoE layer being transferred.
|
||||
|
||||
Returns:
|
||||
TransferMetadata: Metadata needed for completing remote weight transfers.
|
||||
@@ -265,6 +267,8 @@ def move_to_buffer(
|
||||
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||
b[dst].copy_(w[src_local], non_blocking=True)
|
||||
|
||||
communicator.set_transfer_context(old_indices, layer_idx)
|
||||
|
||||
# 2. Post sends
|
||||
if send_count > 0:
|
||||
experts = send_expert_ids[:send_count]
|
||||
@@ -331,9 +335,8 @@ def move_to_buffer(
|
||||
expert_id=int(expert),
|
||||
)
|
||||
|
||||
# 4. Execute the P2P operations. The real communication happens here.
|
||||
communicator.execute(old_indices=old_indices)
|
||||
# wait for the communication to finish
|
||||
# 4. Execute transfers and wait for completion.
|
||||
communicator.execute()
|
||||
return TransferMetadata(
|
||||
is_unchanged=is_unchanged,
|
||||
is_received_locally=is_received_locally,
|
||||
@@ -431,6 +434,7 @@ def transfer_layer(
|
||||
is_profile: bool = False,
|
||||
cuda_stream: torch.cuda.Stream | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
layer_idx: int = 0,
|
||||
) -> TransferMetadata:
|
||||
"""
|
||||
Rearranges the expert weights in place according to the new expert indices.
|
||||
@@ -452,6 +456,7 @@ def transfer_layer(
|
||||
communications to reserve enough memory for the buffers.
|
||||
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||
rank_mapping: Optional rank mapping for elastic expert parallelism.
|
||||
layer_idx: Index of the MoE layer being transferred.
|
||||
|
||||
Returns:
|
||||
TransferMetadata: Metadata needed for completing remote weight transfers,
|
||||
@@ -499,6 +504,7 @@ def transfer_layer(
|
||||
cuda_stream=cuda_stream,
|
||||
ep_rank=ep_group.rank(),
|
||||
communicator=communicator,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
|
||||
@@ -506,6 +512,7 @@ def rearrange_expert_weights_inplace(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
new_global_expert_indices: torch.Tensor,
|
||||
expert_weights: Sequence[Sequence[torch.Tensor]],
|
||||
expert_buffer: Sequence[torch.Tensor],
|
||||
ep_group: ProcessGroup,
|
||||
communicator: EplbCommunicator,
|
||||
is_profile: bool = False,
|
||||
@@ -524,6 +531,8 @@ def rearrange_expert_weights_inplace(
|
||||
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
||||
For example, a linear layer may have up and down projection,
|
||||
so weight_count = 2. Each weight's hidden size can be different.
|
||||
expert_buffer: Pre-allocated receive buffer tensors (one per
|
||||
weight tensor in a single layer).
|
||||
ep_group: The device process group for expert parallelism.
|
||||
communicator: EplbCommunicator instance for P2P communication.
|
||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||
@@ -566,10 +575,10 @@ def rearrange_expert_weights_inplace(
|
||||
# Reserve NCCL communication buffers via a dummy all_gather.
|
||||
# Backends that pre-allocate their own transfer buffers
|
||||
# skip this to avoid the extra memory spike during profiling.
|
||||
weights_buffer: list[torch.Tensor] = [
|
||||
profile_buffer: list[torch.Tensor] = [
|
||||
torch.empty_like(w) for w in first_layer_weights
|
||||
]
|
||||
for weight, buffer in zip(expert_weights[0], weights_buffer):
|
||||
for weight, buffer in zip(expert_weights[0], profile_buffer):
|
||||
dummy_recv_buffer = [buffer for _ in range(ep_size)]
|
||||
torch.distributed.barrier()
|
||||
all_gather(
|
||||
@@ -579,10 +588,7 @@ def rearrange_expert_weights_inplace(
|
||||
)
|
||||
return
|
||||
|
||||
# Buffers to hold the expert weights during the exchange.
|
||||
# NOTE: Currently we assume the same weights across different layers
|
||||
# have the same shape.
|
||||
weights_buffer = [torch.empty_like(w) for w in first_layer_weights]
|
||||
weights_buffer = list(expert_buffer)
|
||||
|
||||
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
|
||||
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
|
||||
@@ -597,6 +603,7 @@ def rearrange_expert_weights_inplace(
|
||||
cuda_stream=None,
|
||||
ep_rank=ep_rank,
|
||||
communicator=communicator,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
move_from_buffer(
|
||||
|
||||
@@ -120,14 +120,11 @@ class KVOutputAggregator:
|
||||
# Use the first worker's kv_connector_stats as accumulator.
|
||||
aggregated_kv_connector_stats = kv_output.kv_connector_stats
|
||||
elif kv_connector_stats := kv_output.kv_connector_stats:
|
||||
if aggregated_kv_connector_stats is None:
|
||||
aggregated_kv_connector_stats = kv_connector_stats
|
||||
else:
|
||||
assert isinstance(
|
||||
aggregated_kv_connector_stats, type(kv_connector_stats)
|
||||
)
|
||||
aggregated_kv_connector_stats = (
|
||||
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
|
||||
aggregated_kv_connector_stats = aggregated_kv_connector_stats.aggregate(
|
||||
kv_connector_stats
|
||||
)
|
||||
|
||||
# Aggregate kv_connector_worker_meta from all workers.
|
||||
|
||||
@@ -2333,9 +2333,25 @@ class NixlConnectorWorker:
|
||||
for i, remote_group in enumerate(remote_block_ids):
|
||||
num_local_blocks = len(local_block_ids[i])
|
||||
num_remote_blocks = len(remote_group)
|
||||
if _is_ssm_spec(self._group_spec_types[i]):
|
||||
assert num_local_blocks == num_remote_blocks
|
||||
if (
|
||||
_is_ssm_spec(self._group_spec_types[i])
|
||||
and num_local_blocks < num_remote_blocks
|
||||
):
|
||||
# NOTE (NickLucche): With prefix caching on SSM, (remote) blocks
|
||||
# prior to the last one are placeholders (null blocks). Mind that
|
||||
# this doesn't really impact transfer, as we only still care about
|
||||
# the last "block", the full in-place state.
|
||||
assert num_local_blocks == 1, "SSM can only have one local block"
|
||||
remote_block_ids[i] = remote_group[-num_local_blocks:]
|
||||
elif (
|
||||
self._physical_blocks_per_logical_kv_block
|
||||
== remote_physical_per_logical
|
||||
and num_local_blocks < num_remote_blocks
|
||||
):
|
||||
# Partial prefix cache hit for FA group.
|
||||
remote_block_ids[i] = remote_group[-num_local_blocks:]
|
||||
else:
|
||||
# TODO Handle prefix caching with different block_sizes
|
||||
max_padding = max(
|
||||
self._physical_blocks_per_logical_kv_block,
|
||||
remote_physical_per_logical,
|
||||
|
||||
@@ -1270,9 +1270,6 @@ def get_dcp_group() -> GroupCoordinator:
|
||||
return _DCP
|
||||
|
||||
|
||||
# kept for backward compatibility
|
||||
get_context_model_parallel_group = get_dcp_group
|
||||
|
||||
_PP: GroupCoordinator | None = None
|
||||
|
||||
|
||||
@@ -1840,31 +1837,6 @@ def model_parallel_is_initialized():
|
||||
_TP_STATE_PATCHED = False
|
||||
|
||||
|
||||
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator

    Raises:
        AssertionError: if the tp group is already patched (nested use).
    """
    global _TP_STATE_PATCHED, _TP
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Capture the current group BEFORE touching any global state: if
    # get_tp_group() raises, the "patched" flag must not be left set,
    # otherwise every later call would trip the assertion above.
    old_tp_group = get_tp_group()

    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group
|
||||
|
||||
|
||||
def get_tensor_model_parallel_world_size() -> int:
    """Report the world size of the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
|
||||
@@ -1875,16 +1847,6 @@ def get_tensor_model_parallel_rank() -> int:
|
||||
return get_tp_group().rank_in_group
|
||||
|
||||
|
||||
def get_decode_context_model_parallel_world_size() -> int:
    """Report the world size of the decode context model parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size
|
||||
|
||||
|
||||
def get_decode_context_model_parallel_rank() -> int:
    """Report this process's rank within the decode context model parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group
|
||||
|
||||
|
||||
def get_node_count() -> int:
|
||||
"""Return the total number of nodes in the distributed environment."""
|
||||
assert _NODE_COUNT is not None, "distributed environment is not initialized"
|
||||
|
||||
@@ -64,6 +64,20 @@ def divide(numerator, denominator):
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def is_weak_contiguous(inp: torch.Tensor) -> bool:
    """Check that *inp* occupies a single contiguous block of memory.

    Unlike ``torch.Tensor.is_contiguous()``, this also accepts tensors
    whose strides are not strictly C-contiguous (e.g. column-major) as
    long as the underlying storage from the tensor's offset onward is
    exactly ``numel * element_size`` bytes.

    Args:
        inp: Tensor to inspect.

    Returns:
        ``True`` if *inp* is C-contiguous, or if the number of storage
        bytes from the tensor's offset to the end of its storage equals
        ``numel * element_size``; ``False`` otherwise.
    """
    if inp.is_contiguous():
        return True

    element_size = inp.element_size()
    # Tensor.storage() is deprecated and emits a TypedStorage warning;
    # untyped_storage().nbytes() reports the same byte count.
    bytes_from_offset = (
        inp.untyped_storage().nbytes() - inp.storage_offset() * element_size
    )
    return bytes_from_offset == inp.numel() * element_size
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
|
||||
@@ -17,9 +17,9 @@ from vllm.entrypoints.anthropic.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.utils import (
|
||||
from vllm.entrypoints.serve.utils.api_utils import (
|
||||
load_aware_call,
|
||||
validate_json_request,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -29,7 +29,6 @@ from vllm.entrypoints.anthropic.protocol import (
|
||||
AnthropicUsage,
|
||||
)
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionRequest,
|
||||
@@ -45,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.serve.utils.request_logger import RequestLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
||||
|
||||
@@ -22,7 +22,7 @@ import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.utils import with_cancellation
|
||||
from vllm.entrypoints.serve.utils.api_utils import with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
@@ -7,7 +7,7 @@ import typing
|
||||
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
from vllm.entrypoints.cli.types import CLISubcommand
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
from vllm.entrypoints.serve.utils.api_utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user