Merge branch 'main' into wentao-fp8-scaled-mm-oddM

This commit is contained in:
Wentao Ye
2026-06-05 11:25:58 -04:00
committed by GitHub
226 changed files with 4981 additions and 2390 deletions
+13 -12
View File
@@ -28,18 +28,19 @@ steps:
pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py
pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py"
- label: CPU-Compatibility Tests
depends_on: []
device: intel_cpu
no_plugin: true
source_file_dependencies:
- cmake/cpu_extension.cmake
- setup.py
- vllm/platforms/cpu.py
commands:
- |
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
bash .buildkite/scripts/hardware_ci/run-cpu-compatibility-test.sh"
# NOTE: The step below is disabled because SDE can't be downloaded from the CI host (the download is blocked by AWS WAF).
# - label: CPU-Compatibility Tests
# depends_on: []
# device: intel_cpu
# no_plugin: true
# source_file_dependencies:
# - cmake/cpu_extension.cmake
# - setup.py
# - vllm/platforms/cpu.py
# commands:
# - |
# bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
# bash .buildkite/scripts/hardware_ci/run-cpu-compatibility-test.sh"
- label: CPU-Language Generation and Pooling Model Tests
depends_on: []
+2 -1
View File
@@ -40,7 +40,8 @@ steps:
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192 &&
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 &&
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel &&
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --max-model-len 8192
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --max-model-len 8192 &&
VLLM_XPU_FUSED_MOE_USE_REF=1 python3 examples/basic/offline_inference/generate.py --model Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --enforce-eager -tp 2
'
- label: "XPU V1 test"
depends_on:
-12
View File
@@ -36,18 +36,6 @@ pull_request_rules:
For future commits, `pre-commit` will run automatically on changed files before each commit.
> [!TIP]
> <details>
> <summary>Is <code>mypy</code> failing?</summary>
> <br/>
> <code>mypy</code> is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
>
> ```bash
> # For mypy (substitute "3.10" with the failing version if needed)
> pre-commit run --hook-stage manual mypy-3.10
> ```
> </details>
- name: comment-dco-failure
description: Comment on PR when DCO check fails
conditions:
+7 -13
View File
@@ -148,33 +148,27 @@ repos:
language: python
entry: python tools/pre_commit/generate_nightly_torch_test.py
files: ^requirements/test/cuda\.(in|txt)$
- id: mypy-local
name: Run mypy locally for lowest supported Python version
entry: python tools/pre_commit/mypy.py 0 "3.10"
stages: [pre-commit] # Don't run in CI
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10
entry: python tools/pre_commit/mypy.py "3.10"
<<: &mypy_common
language: python
types_or: [python, pyi]
require_serial: true
additional_dependencies: ["mypy[faster-cache]==1.19.1", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10
entry: python tools/pre_commit/mypy.py 1 "3.10"
<<: *mypy_common
stages: [manual] # Only run in CI
additional_dependencies: ["mypy==1.20.2", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.11
entry: python tools/pre_commit/mypy.py 1 "3.11"
entry: python tools/pre_commit/mypy.py "3.11"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.12
entry: python tools/pre_commit/mypy.py 1 "3.12"
entry: python tools/pre_commit/mypy.py "3.12"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.13
entry: python tools/pre_commit/mypy.py 1 "3.13"
entry: python tools/pre_commit/mypy.py "3.13"
<<: *mypy_common
stages: [manual] # Only run in CI
- id: shellcheck
+3 -1
View File
@@ -98,11 +98,13 @@ pre-commit run --all-files
pre-commit run ruff-check --all-files
# Run mypy as it is in CI:
pre-commit run mypy-3.10 --all-files --hook-stage manual
pre-commit run mypy-3.12 --all-files --hook-stage manual
```
The line length limit for Python code is 88 characters. If you are not sure, use pre-commit to check.
Use [Google-style docstrings](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) (`Args:`/`Returns:`/`Raises:` sections), not reStructuredText/Sphinx fields (`:param:`, `:return:`, `:rtype:`).
### Commit messages
Add attribution using commit trailers such as `Co-authored-by:` (other projects use `Assisted-by:` or `Generated-by:`). For example:
+7 -11
View File
@@ -307,10 +307,7 @@ endif()
#
set(VLLM_EXT_SRC
"csrc/cuda_view.cu"
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
@@ -346,9 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
FetchContent_MakeAvailable(cutlass)
list(APPEND VLLM_EXT_SRC
"csrc/cutlass_extensions/common.cpp")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
@@ -627,6 +621,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
#
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp"
"csrc/libtorch_stable/cuda_view.cu"
"csrc/libtorch_stable/activation_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
@@ -639,6 +634,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
"csrc/libtorch_stable/layernorm_kernels.cu"
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
"csrc/libtorch_stable/attention/merge_attn_states.cu"
"csrc/libtorch_stable/sampler.cu"
"csrc/libtorch_stable/topk.cu"
@@ -653,8 +649,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
"csrc/cuda_utils_kernels.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/cuda_utils_kernels.cu"
"csrc/libtorch_stable/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
@@ -1076,11 +1072,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
WITH_SOABI)
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
# This ensures we only use C-shim APIs available in PyTorch 2.10.
# This ensures we only use C-shim APIs available in PyTorch 2.11.
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
# which is currently set to 2.10.
# which is currently set to 2.11.
target_compile_definitions(_C_stable_libtorch PRIVATE
TORCH_TARGET_VERSION=0x020A000000000000ULL)
TORCH_TARGET_VERSION=0x020B000000000000ULL)
# Needed to use cuda/hip APIs from C-shim
if(VLLM_GPU_LANG STREQUAL "CUDA")
-59
View File
@@ -1,59 +0,0 @@
#include <torch/all.h>
#include <torch/cuda.h>
#include <cuda_runtime.h>
// Returns a CUDA tensor whose data pointer is the device-side address
// (obtained with cudaHostGetDevicePointer) of pinned host memory.
//
// Preconditions: `cpu_tensor` is a CPU tensor (checked below) and UVA
// (Unified Virtual Addressing) is enabled (assumed, not checked).
//
// Behavior by case:
//   - Empty input: returns a new empty CUDA tensor with the same sizes.
//   - Pinned input: the result points at `cpu_tensor`'s own memory and holds
//     a reference to `cpu_tensor` so that memory outlives the view.
//   - Unpinned input: the data is copied once into a freshly allocated mapped
//     pinned buffer that is freed when the result is destroyed; the result
//     therefore does not alias `cpu_tensor`.
//
// Raises c10::Error (via TORCH_CHECK) if the input is not on the CPU or if a
// CUDA runtime call fails.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");

  // handle empty tensor
  if (cpu_tensor.numel() == 0) {
    return torch::empty(cpu_tensor.sizes(),
                        cpu_tensor.options().device(torch::kCUDA));
  }

  if (cpu_tensor.is_pinned()) {
    // If CPU tensor is pinned, directly get the device pointer.
    void* host_ptr = cpu_tensor.data_ptr();
    void* device_ptr = nullptr;
    cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
    TORCH_CHECK(err == cudaSuccess,
                "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

    return torch::from_blob(
        device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(),
        [base = cpu_tensor](void*) {},  // keep cpu tensor alive
        cpu_tensor.options().device(torch::kCUDA));
  }

  // If CPU tensor is not pinned, allocate a new pinned memory buffer.
  torch::Tensor contiguous_cpu = cpu_tensor.contiguous();
  size_t nbytes = contiguous_cpu.nbytes();

  void* host_ptr = nullptr;
  cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
  // Nothing has been allocated yet on failure, so no cleanup is needed here.
  TORCH_CHECK(err == cudaSuccess,
              "cudaHostAlloc failed: ", cudaGetErrorString(err));

  err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes,
                   cudaMemcpyDefault);
  if (err != cudaSuccess) {
    // Release the pinned buffer before raising so it is not leaked.
    cudaFreeHost(host_ptr);
    TORCH_CHECK(false, "cudaMemcpy failed: ", cudaGetErrorString(err));
  }

  void* device_ptr = nullptr;
  err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
  if (err != cudaSuccess) {
    cudaFreeHost(host_ptr);
    TORCH_CHECK(false,
                "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
  }

  // The deleter frees the host allocation (not the device alias) when the
  // returned tensor's storage is released.
  auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };

  return torch::from_blob(device_ptr, contiguous_cpu.sizes(),
                          contiguous_cpu.strides(), deleter,
                          contiguous_cpu.options().device(torch::kCUDA));
}
+76
View File
@@ -0,0 +1,76 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/headeronly/core/ScalarType.h>
#include <torch/csrc/stable/device.h>
#include <torch/csrc/stable/c/shim.h>
#include <torch/headeronly/version.h>
#include <cuda_runtime.h>
#include <array>
#include <optional>
// Stable-ABI implementation: returns a CUDA tensor whose data pointer is the
// device-side address (from cudaHostGetDevicePointer) of pinned host memory.
//
// Preconditions: `cpu_tensor` is a CPU tensor (checked below) and UVA
// (Unified Virtual Addressing) is enabled (assumed, not checked).
//
// Empty input yields a new empty CUDA tensor. Pinned input is wrapped in
// place, with a reference to `cpu_tensor` captured by the deleter. Unpinned
// input is copied into a newly allocated mapped pinned buffer that is freed
// when the returned tensor releases its storage.
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
    torch::stable::Tensor& cpu_tensor) {
  STD_TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");

  const auto dtype = cpu_tensor.scalar_type();
  const auto layout = cpu_tensor.layout();
  const torch::stable::Device cuda_device(
      torch::headeronly::DeviceType::CUDA);

  // Nothing to map for an empty tensor; hand back an empty CUDA tensor.
  if (cpu_tensor.numel() == 0) {
    return torch::stable::empty(cpu_tensor.sizes(), dtype, layout,
                                cuda_device);
  }

  // Ask the dispatcher whether the tensor is pinned. The boxed call takes its
  // arguments from the stack and the result is read back from slot 0.
  std::array<StableIValue, 2> pinned_query{
      torch::stable::detail::from(cpu_tensor),
      torch::stable::detail::from(std::nullopt)};
  TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
      "aten::is_pinned", "", pinned_query.data(), TORCH_ABI_VERSION));
  const bool already_pinned =
      torch::stable::detail::to<bool>(pinned_query[0]);

  if (already_pinned) {
    // Pinned input: translate the existing host pointer, no copy needed.
    void* pinned_host = const_cast<void*>(cpu_tensor.mutable_data_ptr());
    void* mapped_device = nullptr;
    cudaError_t status =
        cudaHostGetDevicePointer(&mapped_device, pinned_host, 0);
    STD_TORCH_CHECK(status == cudaSuccess,
                    "cudaHostGetDevicePointer failed: ",
                    cudaGetErrorString(status));

    return torch::stable::from_blob(
        mapped_device, cpu_tensor.sizes(), cpu_tensor.strides(), cuda_device,
        dtype,
        [base = cpu_tensor](void*) {});  // captured copy keeps the source alive
  }

  // Unpinned input: stage a contiguous copy in a new mapped pinned buffer.
  torch::stable::Tensor packed = torch::stable::contiguous(cpu_tensor);
  size_t byte_count = packed.numel() * packed.element_size();

  void* pinned_host = nullptr;
  cudaError_t status =
      cudaHostAlloc(&pinned_host, byte_count, cudaHostAllocMapped);
  // No allocation exists on failure, so there is nothing to release here.
  STD_TORCH_CHECK(status == cudaSuccess,
                  "cudaHostAlloc failed: ", cudaGetErrorString(status));

  status = cudaMemcpy(pinned_host, packed.const_data_ptr(), byte_count,
                      cudaMemcpyDefault);
  if (status != cudaSuccess) {
    // Free the staging buffer before raising so it is not leaked.
    cudaFreeHost(pinned_host);
    STD_TORCH_CHECK(false, "cudaMemcpy failed: ", cudaGetErrorString(status));
  }

  void* mapped_device = nullptr;
  status = cudaHostGetDevicePointer(&mapped_device, pinned_host, 0);
  if (status != cudaSuccess) {
    cudaFreeHost(pinned_host);
    STD_TORCH_CHECK(false, "cudaHostGetDevicePointer failed: ",
                    cudaGetErrorString(status));
  }

  // The deleter releases the host allocation, not the device-side alias.
  auto free_pinned = [pinned_host](void*) { cudaFreeHost(pinned_host); };

  return torch::stable::from_blob(mapped_device, packed.sizes(),
                                  packed.strides(), cuda_device,
                                  packed.scalar_type(), free_pinned);
}
@@ -1,4 +1,4 @@
#include "cutlass_extensions/common.hpp"
#include "common.hpp"
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
+11
View File
@@ -164,6 +164,10 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
#endif
// CPU tensor -> CUDA UVA view (shared CUDA/ROCm)
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
torch::stable::Tensor& cpu_tensor);
// Attention kernels (shared CUDA/ROCm)
void merge_attn_states(
torch::stable::Tensor& output,
@@ -215,6 +219,13 @@ void rms_norm_per_block_quant(torch::stable::Tensor& out,
std::optional<torch::stable::Tensor> residual,
int64_t group_size, bool is_scale_transposed);
void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
torch::stable::Tensor const& input,
torch::stable::Tensor& scales,
int64_t group_size,
std::optional<torch::stable::Tensor> scale_ub,
bool is_scale_transposed);
// Positional encoding kernels (shared CUDA/ROCm)
void rotary_embedding(torch::stable::Tensor& positions,
torch::stable::Tensor& query,
@@ -18,7 +18,7 @@
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -21,7 +21,7 @@
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include <cuda_runtime.h>
@@ -12,7 +12,7 @@
#include <cutlass/arch/arch.h>
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
@@ -20,7 +20,7 @@
#include <cutlass/arch/arch.h>
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cutlass/cutlass.h"
@@ -18,7 +18,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "cutlass/cutlass.h"
@@ -1,11 +1,10 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../torch_utils.h"
#include "../../dispatch_utils.h"
#include "libtorch_stable/quantization/fused_kernels/quant_conversions.cuh"
#include "quant_conversions.cuh"
namespace vllm {
@@ -105,60 +104,66 @@ __global__ void silu_and_mul_per_block_quant_kernel(
} // namespace vllm
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
torch::stable::Tensor const& input,
torch::stable::Tensor& scales,
int64_t group_size,
std::optional<torch::stable::Tensor> scale_ub,
bool is_scale_transposed) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
static torch::headeronly::ScalarType kFp8Type =
is_fp8_ocp() ? torch::headeronly::ScalarType::Float8_e4m3fn
: torch::headeronly::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
TORCH_CHECK(
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
STD_TORCH_CHECK(out.scalar_type() == kFp8Type ||
out.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Input must be FP16 or BF16");
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
TORCH_CHECK(group_size == 128 || group_size == 64,
STD_TORCH_CHECK(scales.scalar_type() == torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size);
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
STD_TORCH_CHECK(out.scalar_type() == kFp8Type);
}
int32_t hidden_size = out.size(-1);
auto num_tokens = input.size(0);
int32_t num_groups = hidden_size / group_size;
TORCH_CHECK(input.size(-1) == hidden_size * 2,
STD_TORCH_CHECK(input.size(-1) == hidden_size * 2,
"input last dim must be 2x output hidden_size");
TORCH_CHECK(hidden_size % group_size == 0,
STD_TORCH_CHECK(hidden_size % group_size == 0,
"hidden_size must be divisible by group_size");
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
dim3 grid(num_tokens, num_groups);
dim3 block(group_size);
VLLM_DISPATCH_FLOATING_TYPES(
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_in_t = scalar_t;
VLLM_DISPATCH_QUANT_TYPES(
VLLM_STABLE_DISPATCH_QUANT_TYPES(
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
using scalar_out_t = scalar_t;
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
VLLM_STABLE_DISPATCH_BOOL(
is_scale_transposed, transpose_scale, [&] {
vllm::silu_and_mul_per_block_quant_kernel<
scalar_in_t, scalar_out_t, transpose_scale, gs>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_out_t>(),
scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
out.mutable_data_ptr<scalar_out_t>(),
scales.mutable_data_ptr<float>(),
input.const_data_ptr<scalar_in_t>(),
scale_ub.has_value()
? scale_ub->const_data_ptr<float>()
: nullptr,
hidden_size);
});
@@ -20,7 +20,7 @@
#include "cutlass/util/packed_stride.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
// clang-format on
namespace vllm::c3x {
@@ -15,7 +15,7 @@
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
// clang-format on
/*
@@ -1,7 +1,7 @@
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::stable::Tensor& c,
@@ -8,7 +8,7 @@
#include <torch/csrc/stable/ops.h>
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"
using namespace cute;
@@ -23,7 +23,7 @@
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
// clang-format on
using namespace cute;
@@ -4,7 +4,7 @@
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "libtorch_stable/cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
+33
View File
@@ -1,4 +1,5 @@
#include "ops.h"
#include "cuda_utils.h"
#include "core/registration.h"
#include <torch/csrc/stable/library.h>
@@ -27,6 +28,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
#ifndef USE_ROCM
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
#endif
@@ -321,6 +324,16 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor? scale_ub, Tensor!? residual, int group_size, "
"bool is_scale_transposed) -> ()");
// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
@@ -599,6 +612,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("rms_norm_dynamic_per_token_quant",
TORCH_BOX(&rms_norm_dynamic_per_token_quant));
ops.impl("rms_norm_per_block_quant", TORCH_BOX(&rms_norm_per_block_quant));
ops.impl("silu_and_mul_per_block_quant",
TORCH_BOX(&silu_and_mul_per_block_quant));
// Positional encoding kernels (shared CUDA/ROCm)
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
@@ -661,6 +676,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2));
}
// Implementations registered under the CPU dispatch key of the `_C` library.
// NOTE(review): presumably CPU rather than CUDA because the op's only tensor
// argument is a CPU tensor — confirm against the dispatcher rules.
STABLE_TORCH_LIBRARY_IMPL(_C, CPU, ops) {
// Boxed wrapper around the stable-ABI get_cuda_view_from_cpu_tensor function.
ops.impl("get_cuda_view_from_cpu_tensor",
TORCH_BOX(&get_cuda_view_from_cpu_tensor));
}
// These capability-check functions take only primitive args (no tensors), so
// there is no device to dispatch on. CompositeExplicitAutograd makes them
// available for all backends. This is the stable ABI equivalent of calling
@@ -681,6 +701,19 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
}
// Schema definitions for the `_C_cuda_utils` library. Both ops take only
// integer arguments and return an integer.
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cuda_utils, cuda_utils) {
// Gets the specified device attribute.
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
// Gets the maximum shared memory per block device attribute.
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
}
// Implementations for the `_C_cuda_utils` ops, registered under
// CompositeExplicitAutograd.
// NOTE(review): presumably this key is used because the ops take no tensor
// arguments, so there is no device to dispatch on — confirm.
STABLE_TORCH_LIBRARY_IMPL(_C_cuda_utils, CompositeExplicitAutograd,
cuda_utils) {
// Each function is wrapped with TORCH_BOX before registration.
cuda_utils.impl("get_device_attribute", TORCH_BOX(&get_device_attribute));
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
TORCH_BOX(&get_max_shared_memory_per_block_device_attribute));
}
// Cache ops
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
// Swap in (out) the cache blocks from src to dst.
-8
View File
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void silu_and_mul_per_block_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scales, int64_t group_size,
std::optional<torch::Tensor> scale_ub,
bool is_scale_transposed);
// rotary_embedding also exist in csrc/libtorch_stable/ops.h (torch::stable
// ABI for CUDA). It remains here because the CPU build still uses these
// torch::Tensor declarations.
@@ -84,8 +78,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table, double scale);
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp);
-31
View File
@@ -2,7 +2,6 @@
// cache.h, which is no longer included here after cache ops moved to
// _C_stable_libtorch).
#include <torch/all.h>
#include "cuda_utils.h"
#include "ops.h"
#include "core/registration.h"
#include <torch/library.h>
@@ -32,27 +31,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
&get_cuda_view_from_cpu_tensor);
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
// Fused SiLU+Mul + per-block quantization
ops.def(
"silu_and_mul_per_block_quant("
"Tensor! out, "
"Tensor input, "
"Tensor! scales, "
"int group_size, "
"Tensor? scale_ub=None, "
"bool is_scale_transposed=False) -> ()");
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
&silu_and_mul_per_block_quant);
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
// kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4
@@ -178,18 +161,4 @@ TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
}
#endif
// Defines and implements the CUDA utility ops in the library named
// `<TORCH_EXTENSION_NAME>_cuda_utils`. Each op's schema (`def`) is followed
// immediately by its implementation (`impl`), registered without a dispatch
// key argument.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils
// Gets the specified device attribute.
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
cuda_utils.impl("get_device_attribute", &get_device_attribute);
// Gets the maximum shared memory per block device attribute.
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
+1 -1
View File
@@ -296,7 +296,7 @@ llm = LLM(model="Qwen/Qwen3-8B")
The `fastokens` Python package (>= 0.2.0) must be installed; if it isn't,
vLLM raises a clear `ImportError` at tokenizer load. The override applies to
any `--tokenizer-mode` that ends up loading an HF fast tokenizer (`hf`,
`deepseek_v32`, `deepseek_v4`, `qwen_vl`, …). Modes that don't use the HF
`deepseek_v32`, `deepseek_v4`, `qwen_vl`, …). Models that don't use the HF
fast tokenizer (`mistral`, `grok2`, `kimi_audio`) ignore the flag.
Tokenizer-bound workloads — long shared prefixes, bursty short prompts,
+1 -1
View File
@@ -101,7 +101,7 @@ vLLM's `pre-commit` hooks will now run automatically every time you commit.
Some `pre-commit` hooks only run in CI. If you need to, you can run them locally with:
```bash
pre-commit run --hook-stage manual mypy-3.10
pre-commit run --hook-stage manual mypy-3.11
```
### Documentation
+1 -1
View File
@@ -128,7 +128,7 @@ The lease mechanism is controlled through `kv_connector_extra_config` in `--kv-t
vllm serve <MODEL> \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_both",
"kv_role": "kv_producer",
"kv_connector_extra_config": {"kv_lease_duration": 60}
}'
```
+19 -9
View File
@@ -50,7 +50,7 @@ To select a different backend, set `kv_connector_extra_config.backends` in `--kv
vllm serve <MODEL> \
--kv-transfer-config '{
"kv_connector":"NixlConnector",
"kv_role":"kv_both",
"kv_role":"kv_producer",
"kv_connector_extra_config":{"backends":["LIBFABRIC"]}
}'
```
@@ -60,7 +60,7 @@ You can also pass JSON keys individually using dotted arguments, and you can app
```bash
vllm serve <MODEL> \
--kv-transfer-config.kv_connector NixlConnector \
--kv-transfer-config.kv_role kv_both \
--kv-transfer-config.kv_role kv_producer \
--kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC
```
@@ -81,7 +81,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
vllm serve Qwen/Qwen3-0.6B \
--port 8100 \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}'
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer","kv_load_failure_policy":"fail"}'
```
### Consumer (Decoder) Configuration
@@ -96,7 +96,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
vllm serve Qwen/Qwen3-0.6B \
--port 8200 \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}'
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer","kv_load_failure_policy":"fail"}'
```
### Proxy Server
@@ -212,10 +212,21 @@ sequenceDiagram
Enable bidirectional KV transfer by setting `bidirectional_kv_xfer` in `kv_connector_extra_config` on **both** P and D instances:
```bash
# Prefill instance
vllm serve <MODEL> \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_both",
"kv_role": "kv_producer",
"kv_connector_extra_config": {
"bidirectional_kv_xfer": true
}
}'
# Decode instance
vllm serve <MODEL> \
--kv-transfer-config '{
"kv_connector": "NixlConnector",
"kv_role": "kv_consumer",
"kv_connector_extra_config": {
"bidirectional_kv_xfer": true
}
@@ -359,11 +370,10 @@ For multi-host DP deployment, only need to provide the host/port of the head ins
- **kv_producer**: For prefiller instances that generate KV caches
- **kv_consumer**: For decoder instances that consume KV caches from prefiller
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
- **kv_both** (deprecated): Previously used as a catch-all when the role was not predetermined. This value is now deprecated for NixlConnector and will be removed in a future release.
!!! tip
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
!!! warning
`kv_role="kv_both"` is deprecated for NixlConnector. Please set `kv_role="kv_producer"` for prefill instances and `kv_role="kv_consumer"` for decode instances. See [#33702](https://github.com/vllm-project/vllm/issues/33702) for details.
### KV Load Failure Policy
+1 -1
View File
@@ -169,7 +169,7 @@ speculative decoding, breaking down the guarantees into three key areas:
> distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
> - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
> without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
> provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](/tests/v1/spec_decode).
> provides a lossless guarantee. Almost all of the tests in [tests/v1/spec_decode](../../../tests/v1/spec_decode)
> verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291)
3. **vLLM Logprob Stability**
+9 -9
View File
@@ -83,22 +83,22 @@ plugins:
- "re:vllm\\._.*" # Internal modules
- "vllm.third_party"
- "vllm.vllm_flash_attn"
- "re:vllm\\.grpc\\..*_pb2.*" # Auto-generated protobuf files
- "vllm.transformers_utils.configs"
- "vllm.transformers_utils.processors"
- !ENV [API_AUTONAV_EXCLUDE, "re:^$"] # Match nothing by default
- mkdocstrings:
handlers:
python:
options:
show_symbol_type_heading: true
show_symbol_type_toc: true
filters:
- "!.*_pb2_grpc" # Exclude auto-generated gRPC stubs
summary:
modules: true
show_signature_annotations: true
separate_signature: true
filters: []
show_overloads: true
signature_crossrefs: true
# Recommendations from api-autonav
docstring_section_style: list
parameter_headings: true
show_symbol_type_heading: true
show_symbol_type_toc: true
summary: true
inventories:
- https://docs.python.org/3/objects.inv
- https://typing-extensions.readthedocs.io/en/latest/objects.inv
+1 -1
View File
@@ -32,7 +32,7 @@ partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0
msgspec
gguf >= 0.17.0
mistral_common[image] >= 1.11.2
mistral_common[image] >= 1.11.3
opencv-python-headless >= 4.13.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
+1 -1
View File
@@ -31,7 +31,7 @@ torchaudio==2.11.0
torchvision==0.26.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.11.2 # required for voxtral test
mistral_common[image,audio] >= 1.11.3 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless >= 4.13.0 # required for video test
+1 -1
View File
@@ -409,7 +409,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.11.2
mistral-common==1.11.3
# via
# -c requirements/common.txt
# -r requirements/test/cuda.in
+1 -1
View File
@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.11.2 # required for voxtral test
mistral_common[image,audio] >= 1.11.3 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.13.0 # required for video test
datamodel_code_generator # required for minicpm3 test
+1 -1
View File
@@ -30,7 +30,7 @@ tblib # for pickling test exceptions
timm>=1.0.17 # required for internvl and gemma3n-mm test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio]>=1.11.2 # required for voxtral test
mistral_common[image,audio]>=1.11.3 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
opencv-python-headless>=4.13.0 # required for video test
+1 -1
View File
@@ -512,7 +512,7 @@ mcp==1.27.0
# via -r requirements/test/../common.txt
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.11.2
mistral-common==1.11.3
# via
# -c requirements/common.txt
# -r requirements/test/../common.txt
+1 -1
View File
@@ -266,7 +266,7 @@ mbstrdecoder==1.1.4
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.11.2
mistral-common==1.11.3
# via
# -c requirements/common.txt
# -r requirements/test/xpu.in
+55 -2
View File
@@ -38,13 +38,17 @@ impl HfChatBackend {
) -> Result<Self> {
let model_config = load_model_config(files.config_path.as_deref())?;
let model_type = model_config.model_type().unwrap_or_default();
let multimodal_model_info = MultimodalModelInfo::from_paths(
let multimodal_model_info = if options.language_model_only {
None
} else {
MultimodalModelInfo::from_paths(
model_id.clone(),
(!model_type.is_empty()).then_some(model_type.to_string()),
files.config_path.as_deref(),
files.preprocessor_config_path.as_deref(),
tokenizer.clone(),
)?;
)?
};
let multimodal_render_info = resolve_multimodal_render_info(multimodal_model_info.as_ref());
let renderer = options.renderer.resolve(model_type);
@@ -225,6 +229,7 @@ mod tests {
"test-model".to_string(),
LoadModelBackendsOptions {
renderer,
language_model_only: false,
chat_template_content_format: Default::default(),
chat_template: None,
default_chat_template_kwargs: HashMap::new(),
@@ -267,6 +272,54 @@ mod tests {
assert_eq!(prompt, "hello");
}
// Checks that `language_model_only: true` makes
// `HfChatBackend::from_resolved_model_files` skip multimodal model info, while
// loading the very same files without the flag fails on
// `preprocessor_config.json`.
#[test]
fn language_model_only_skips_multimodal_preprocessor_config() {
let mut files = resolved_files(
r#"{"model_type":"deepseek_v0_vl"}"#,
r#"{"chat_template":"{{ messages[0].content }}"}"#,
);
// Write a `preprocessor_config.json` next to the resolved `config.json` and
// register it on `files`. The second half of the test relies on this payload
// being rejected when the multimodal path is taken.
let preprocessor_config_path = files
.config_path
.as_ref()
.unwrap()
.parent()
.unwrap()
.join("preprocessor_config.json");
write_json(&preprocessor_config_path, r#"{"size":[672,672]}"#);
files.preprocessor_config_path = Some(preprocessor_config_path);
// With the flag set, construction succeeds and no multimodal info is exposed.
let backend = HfChatBackend::from_resolved_model_files(
files.clone(),
"test-model".to_string(),
LoadModelBackendsOptions {
language_model_only: true,
chat_template_content_format: Default::default(),
chat_template: None,
default_chat_template_kwargs: HashMap::new(),
..Default::default()
},
test_tokenizer(),
)
.unwrap();
assert!(backend.multimodal_model_info().is_none());
// Control case: `language_model_only` is left at its `Default` value here, and
// construction must fail with a preprocessor-config parse error.
let error = HfChatBackend::from_resolved_model_files(
files,
"test-model".to_string(),
LoadModelBackendsOptions {
chat_template_content_format: Default::default(),
chat_template: None,
default_chat_template_kwargs: HashMap::new(),
..Default::default()
},
test_tokenizer(),
)
.err()
.expect("invalid preprocessor config should fail without language_model_only");
assert!(error.to_string().contains("failed to parse preprocessor_config.json"));
}
#[test]
fn explicit_deepseek_renderer_overrides_generic_model_type() {
let prompt = render_prompt(
+3
View File
@@ -60,6 +60,9 @@ pub type DynChatTextBackend = Arc<dyn ChatTextBackend>;
pub struct LoadModelBackendsOptions {
/// Which chat renderer implementation to use.
pub renderer: RendererSelection,
/// Disable frontend-side multimodal preprocessing and render the model as
/// language-only.
pub language_model_only: bool,
/// How to serialize `message.content` when rendering the chat template.
pub chat_template_content_format: ChatTemplateContentFormatOption,
/// Optional server-default chat template override, provided either as an
+7
View File
@@ -116,6 +116,10 @@ pub struct SharedRuntimeArgs {
#[arg(long = "tokenizer-mode", default_value_t)]
#[serde(default, rename = "tokenizer_mode")]
pub renderer: RendererSelection,
/// Disable multimodal inputs and treat the model as language-only.
#[arg(long)]
#[serde(default)]
pub language_model_only: bool,
/// Override the maximum model context length. When set, the frontend uses
/// this value instead of the model's `max_position_embeddings` from
/// `config.json`.
@@ -243,6 +247,7 @@ impl SharedRuntimeArgs {
tool_call_parser: self.tool_call_parser,
reasoning_parser: self.reasoning_parser,
renderer: self.renderer,
language_model_only: self.language_model_only,
chat_template: self.chat_template,
default_chat_template_kwargs: self.default_chat_template_kwargs,
chat_template_content_format: self.chat_template_content_format,
@@ -284,6 +289,7 @@ impl SharedRuntimeArgs {
tool_call_parser: self.tool_call_parser,
reasoning_parser: self.reasoning_parser,
renderer: self.renderer,
language_model_only: self.language_model_only,
chat_template: self.chat_template,
default_chat_template_kwargs: self.default_chat_template_kwargs,
chat_template_content_format: self.chat_template_content_format,
@@ -419,6 +425,7 @@ impl ServeArgs {
self.managed_engine.clone().into_config(
self.runtime.model.clone(),
self.runtime.max_model_len,
self.runtime.language_model_only,
handshake_port,
)
}
+8 -1
View File
@@ -34,6 +34,7 @@ fn serve_args_forward_python_flags_with_separator() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
language_model_only: false,
max_model_len: Some(
512,
),
@@ -263,6 +264,7 @@ fn frontend_args_accept_json() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
language_model_only: false,
max_model_len: None,
grpc_port: None,
shutdown_timeout: 0,
@@ -321,7 +323,7 @@ fn frontend_args_json_accepts_supported_non_default_fields() {
"--output-address",
"ipc:///tmp/output.sock",
"--args-json",
r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_ready_timeout_secs":42,"tool_call_parser":"hermes","reasoning_parser":"qwen3_thinking","tokenizer_mode":"deepseek_v32","max_model_len":8192,"shutdown_timeout":3}"#,
r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_ready_timeout_secs":42,"tool_call_parser":"hermes","reasoning_parser":"qwen3_thinking","tokenizer_mode":"deepseek_v32","language_model_only":true,"max_model_len":8192,"shutdown_timeout":3}"#,
])
.unwrap();
@@ -338,6 +340,7 @@ fn frontend_args_json_accepts_supported_non_default_fields() {
ParserSelection::Explicit("qwen3_thinking".to_string())
);
assert_eq!(args.runtime.renderer, RendererSelection::DeepSeekV32);
assert!(args.runtime.language_model_only);
assert_eq!(args.runtime.max_model_len, Some(8192));
assert_eq!(args.runtime.shutdown_timeout, 3);
}
@@ -662,6 +665,7 @@ fn serve_args_accept_handshake_aliases() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
language_model_only: false,
max_model_len: None,
grpc_port: None,
shutdown_timeout: 0,
@@ -783,6 +787,7 @@ fn serve_frontend_config_uses_dp_address_as_advertised_host() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
language_model_only: false,
chat_template: None,
default_chat_template_kwargs: None,
chat_template_content_format: Auto,
@@ -846,6 +851,7 @@ fn serve_frontend_config_keeps_tcp_transport_for_non_local_only_topology() {
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
language_model_only: false,
chat_template: None,
default_chat_template_kwargs: None,
chat_template_content_format: Auto,
@@ -924,6 +930,7 @@ fn frontend_config_uses_external_coordinator_when_coordinator_address_is_present
tool_call_parser: Auto,
reasoning_parser: Auto,
renderer: Auto,
language_model_only: false,
chat_template: None,
default_chat_template_kwargs: None,
chat_template_content_format: Auto,
+20 -7
View File
@@ -1,5 +1,4 @@
use std::collections::BTreeMap;
use std::slice;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
@@ -253,8 +252,18 @@ pub(crate) async fn run_abort_loop(
inner: Arc<ClientInner>,
mut abort_rx: mpsc::UnboundedReceiver<AbortRequest>,
) {
// TODO: receive and abort requests in batch
while let Some(AbortRequest { request_id, cause }) = abort_rx.recv().await {
// Coalesce bursts of auto-aborts into a single Abort message per engine.
// A dropped-stream storm (e.g. many clients disconnecting at once under
// high concurrency) would otherwise issue one engine round-trip per
// request. `recv_many` returns as soon as at least one item is ready, so a
// lone abort is still forwarded promptly.
const MAX_DRAIN: usize = 1024;
let mut batch: Vec<AbortRequest> = Vec::new();
while abort_rx.recv_many(&mut batch, MAX_DRAIN).await > 0 {
let mut by_engine: BTreeMap<EngineId, Vec<String>> = BTreeMap::new();
for AbortRequest { request_id, cause } in batch.drain(..) {
let Some(engine_id) = inner.take_auto_abort_target(&request_id) else {
debug!(request_id, "skip auto-abort for inactive request");
continue;
@@ -272,17 +281,21 @@ pub(crate) async fn run_abort_loop(
}
}
if let Err(error) = inner.do_abort_requests(&engine_id, slice::from_ref(&request_id)).await
{
by_engine.entry(engine_id).or_default().push(request_id);
}
for (engine_id, request_ids) in by_engine {
if let Err(error) = inner.do_abort_requests(&engine_id, &request_ids).await {
warn!(
request_id,
?engine_id,
?request_ids,
error = %error.as_report(),
"failed to auto-abort dropped request stream"
"failed to auto-abort request streams"
);
}
}
}
}
/// Background loop that listens for engine-core outputs and dispatches them to
/// the corresponding request streams based on their `request_id`.
@@ -1225,6 +1225,86 @@ async fn dropping_a_live_stream_triggers_abort() {
client.shutdown().await.unwrap();
}
// Checks that dropping three live request streams back-to-back results in a
// single abort message reaching the mock engine, carrying all three request
// ids, with no follow-up abort messages.
#[tokio::test]
async fn dropping_multiple_live_streams_aborts_all_in_a_burst() {
init_tracing();
let ipc = IpcNamespace::new().unwrap();
let handshake_address = ipc.handshake_endpoint();
let engine_id = b"engine-burst".to_vec();
let request_ids = ["req-1", "req-2", "req-3"];
let (shutdown_tx, engine_task) = spawn_mock_engine_task(
handshake_address.clone(),
engine_id.clone(),
|dealer, push| {
Box::pin(async move {
// Wait for one message per opened request. The first frame is asserted to
// be 0x00 (presumably the "add request" type tag; confirm against the
// engine protocol definition).
for _ in 0..3 {
let add = recv_engine_message(dealer).await;
assert_eq!(add[0].as_ref(), &[0x00]);
}
// Emit one token (99) for every request in a single outputs batch so each
// client stream has something to yield before it is dropped.
send_outputs(
push,
EngineCoreOutputs {
outputs: vec![
request_output("req-1", vec![99], None),
request_output("req-2", vec![99], None),
request_output("req-3", vec![99], None),
],
..Default::default()
},
)
.await;
// One message with first frame 0x01 must arrive within a second, and its
// second frame must decode (MessagePack) to all three ids in open order.
let abort =
timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap();
assert_eq!(abort[0].as_ref(), &[0x01]);
let ids: Vec<String> = rmp_serde::from_slice(&abort[1]).unwrap();
assert_eq!(
ids,
vec![
"req-1".to_string(),
"req-2".to_string(),
"req-3".to_string()
]
);
// Nothing else may arrive during the next 100 ms: a timeout here means the
// aborts were not sent as separate messages.
assert!(
timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err()
);
})
},
);
let client = connect_client_with_ipc(
handshake_test_config(
handshake_address,
1,
"test-model",
Duration::from_secs(2),
0,
None,
),
&ipc,
)
.await;
// Open every request first so all three adds reach the engine before it
// emits outputs, then drain the first token from each stream.
let mut streams = Vec::new();
for id in request_ids {
streams.push(client.call(sample_request_with_id(id)).await.unwrap());
}
for stream in streams.iter_mut() {
let first = timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap();
assert_eq!(first.new_token_ids, vec![99]);
}
// Drop the whole burst back-to-back so the abort worker can batch them.
drop(streams);
// The send result is deliberately ignored; the mock engine task is awaited
// right after, which surfaces any assertion failure from the closure above.
let _ = shutdown_tx.send(());
engine_task.await.unwrap();
client.shutdown().await.unwrap();
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn dispatcher_failure_propagates_to_streams_and_future_calls() {
init_tracing();
+4
View File
@@ -71,6 +71,7 @@ impl ManagedEngineArgs {
self,
model: String,
max_model_len: Option<u32>,
language_model_only: bool,
handshake_port: u16,
) -> ManagedEngineConfig {
let mut python_args = self.python_args;
@@ -79,6 +80,9 @@ impl ManagedEngineArgs {
python_args.push("--max-model-len".to_string());
python_args.push(max_model_len.to_string());
}
if language_model_only {
python_args.push("--language-model-only".to_string());
}
if let Some(data_parallel_size_local) = self.data_parallel_size_local {
python_args.push("--data-parallel-size-local".to_string());
python_args.push(data_parallel_size_local.to_string());
@@ -64,6 +64,7 @@ async fn main() -> Result<()> {
tool_call_parser: ParserSelection::Auto,
reasoning_parser: ParserSelection::Auto,
renderer: RendererSelection::Auto,
language_model_only: false,
chat_template: None,
default_chat_template_kwargs: None,
chat_template_content_format: ChatTemplateContentFormatOption::Auto,
+3
View File
@@ -53,6 +53,9 @@ pub struct Config {
pub reasoning_parser: ParserSelection,
/// Chat renderer selection.
pub renderer: RendererSelection,
/// Disable frontend-side multimodal preprocessing and render the model as
/// language-only.
pub language_model_only: bool,
/// Server-default chat template override, as a file path or inline
/// template.
pub chat_template: Option<String>,
+1
View File
@@ -42,6 +42,7 @@ async fn build_state(config: &Config) -> Result<Arc<AppState>> {
&config.model,
LoadModelBackendsOptions {
renderer: config.renderer,
language_model_only: config.language_model_only,
chat_template: config.chat_template.clone(),
chat_template_content_format: config.chat_template_content_format,
default_chat_template_kwargs: config
@@ -85,6 +85,7 @@ pub async fn chat_completions(
log_request,
prepared.include_usage,
prepared.requested_logprobs,
prepared.include_reasoning,
prepared.echo,
prepared.return_token_ids,
prepared.return_tokens_as_token_ids,
@@ -100,6 +101,7 @@ pub async fn chat_completions(
created,
prepared.requested_logprobs,
prepared.include_prompt_logprobs,
prepared.include_reasoning,
prepared.echo,
prepared.return_token_ids,
prepared.return_tokens_as_token_ids,
@@ -134,6 +136,7 @@ async fn collect_chat_completion(
created: u64,
requested_logprobs: bool,
include_prompt_logprobs: bool,
include_reasoning: bool,
echo: Option<String>,
return_token_ids: bool,
return_tokens_as_token_ids: bool,
@@ -157,6 +160,11 @@ async fn collect_chat_completion(
} = collected;
let stop_reason = finish_reason.as_stop_reason().map(stop_reason_to_json);
let saw_tool_calls = message.tool_calls().next().is_some();
let reasoning = message.reasoning();
// Output logprobs and token IDs cover the complete generated token stream.
// When reasoning is hidden, omit them rather than leaking hidden reasoning
// tokens through per-token metadata.
let include_output_metadata = include_reasoning || reasoning.is_none();
let finish_reason = chat_finish_reason_to_openai(&finish_reason, saw_tool_calls)?.to_string();
let tool_calls = message
.tool_calls()
@@ -169,7 +177,7 @@ async fn collect_chat_completion(
},
})
.collect::<Vec<_>>();
let logprobs = if requested_logprobs {
let logprobs = if requested_logprobs && include_output_metadata {
Some(decoded_logprobs_to_openai_chat(
logprobs.as_ref().ok_or_else(|| {
server_error!("chat response requested logprobs but generation returned none")
@@ -207,12 +215,12 @@ async fn collect_chat_completion(
None => Some(message.text()).filter(|t| !t.is_empty()),
},
tool_calls: Some(tool_calls).filter(|calls| !calls.is_empty()),
reasoning: message.reasoning(),
reasoning: if include_reasoning { reasoning } else { None },
},
logprobs,
finish_reason: Some(finish_reason),
stop_reason,
token_ids: return_token_ids.then_some(token_ids),
token_ids: (return_token_ids && include_output_metadata).then_some(token_ids),
}],
usage: Some(usage),
system_fingerprint: None,
@@ -232,12 +240,18 @@ async fn chat_completion_chunk_stream(
log_request: bool,
include_usage: bool,
requested_logprobs: bool,
include_reasoning: bool,
echo: Option<String>,
return_token_ids: bool,
return_tokens_as_token_ids: bool,
mut y: TryYielder<ChatCompletionStreamResponse, ApiError>,
) -> Result<(), ApiError> {
let mut saw_tool_calls = false;
// `LogprobsDelta` is emitted after all chat events for one decoded update.
// If that update contains hidden reasoning, including delimiter-only block
// starts or ends, omit its token metadata as well as its visible delta.
let mut inside_hidden_reasoning = false;
let mut suppress_current_update_metadata = false;
// If the client requested logprobs or token_ids, we need to buffer chunks until
// we receive the separate `LogprobsDelta` event, so that we can emit one
@@ -268,6 +282,9 @@ async fn chat_completion_chunk_stream(
}
}
Ok(ChatEvent::BlockDelta { kind, delta, .. }) => {
let include_delta =
include_reasoning || !matches!(kind, AssistantBlockKind::Reasoning);
if include_delta {
if let Some(pending_chunk) = pending_chunk.as_mut() {
pending_chunk.push_block_delta(kind, delta);
} else {
@@ -280,17 +297,29 @@ async fn chat_completion_chunk_stream(
))
.await;
}
} else {
suppress_current_update_metadata = true;
}
}
Ok(ChatEvent::LogprobsDelta {
logprobs,
token_ids,
}) => {
let openai_logprobs = logprobs
let include_metadata =
!suppress_current_update_metadata && !inside_hidden_reasoning;
suppress_current_update_metadata = false;
let openai_logprobs = if include_metadata {
logprobs
.as_ref()
.map(|lp| decoded_logprobs_to_openai_chat(lp, return_tokens_as_token_ids))
.transpose()?;
let openai_token_ids =
return_token_ids.then_some(token_ids).filter(|t| !t.is_empty());
.transpose()?
} else {
None
};
let openai_token_ids = include_metadata
.then_some(token_ids)
.and_then(|token_ids| return_token_ids.then_some(token_ids))
.filter(|t| !t.is_empty());
if let Some(pending_chunk) = pending_chunk.as_mut() {
pending_chunk.logprobs = openai_logprobs;
pending_chunk.token_ids = openai_token_ids;
@@ -311,9 +340,17 @@ async fn chat_completion_chunk_stream(
}
Ok(ChatEvent::BlockStart { kind, .. }) => {
debug!(?kind, "starting new block");
if !include_reasoning && matches!(kind, AssistantBlockKind::Reasoning) {
inside_hidden_reasoning = true;
suppress_current_update_metadata = true;
}
}
Ok(ChatEvent::BlockEnd { .. }) => {
debug!("ending current block");
if inside_hidden_reasoning {
inside_hidden_reasoning = false;
suppress_current_update_metadata = true;
}
}
Ok(ChatEvent::ToolCallStart { index, id, name }) => {
let tool_index = index as u32;
@@ -763,7 +800,9 @@ fn stop_reason_to_json(stop_reason: &StopReason) -> Value {
mod tests {
use futures::{StreamExt as _, stream};
use serde_json::json;
use vllm_chat::{AssistantBlockKind, AssistantToolCall, ChatEvent, FinishReason};
use vllm_chat::{
AssistantBlockKind, AssistantContentBlock, AssistantToolCall, ChatEvent, FinishReason,
};
use vllm_engine_core_client::protocol::StopReason;
use vllm_text::{DecodedLogprobs, DecodedPositionLogprobs, DecodedTokenLogprob};
@@ -895,6 +934,7 @@ mod tests {
false,
false,
true,
true,
None,
false,
false,
@@ -958,6 +998,7 @@ mod tests {
false,
false,
true,
true,
None,
false,
false,
@@ -976,6 +1017,292 @@ mod tests {
assert!(chunks[1].choices[0].logprobs.is_some());
}
// Checks that, with `include_reasoning` disabled, the chunk stream emits the
// text delta but never a chunk carrying a `reasoning` delta.
#[tokio::test]
async fn chunk_stream_omits_reasoning_delta_when_disabled() {
// Input events: one reasoning delta followed by one text delta, no logprobs.
let stream = stream::iter(vec![
Ok(ChatEvent::Start {
prompt_token_ids: vec![].into(),
prompt_logprobs: None,
}),
Ok(ChatEvent::BlockDelta {
index: 0,
kind: AssistantBlockKind::Reasoning,
delta: "think".to_string(),
}),
Ok(ChatEvent::BlockDelta {
index: 1,
kind: AssistantBlockKind::Text,
delta: "answer".to_string(),
}),
Ok(ChatEvent::Done {
message: Default::default(),
prompt_token_count: 1,
output_token_count: 2,
finish_reason: FinishReason::stop_eos(),
kv_transfer_params: None,
}),
]);
let chunks = chat_completion_chunk_stream(
stream,
"chatcmpl-1".to_string(),
"model".to_string(),
1,
false, // log_request
false, // include_usage
false, // requested_logprobs
false, // include_reasoning
None, // echo
false, // return_token_ids
false, // return_tokens_as_token_ids
)
.collect::<Vec<_>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.expect("stream chunks");
// Three chunks are expected, with the text delta in the middle one
// (presumably bracketed by an opening and a finishing chunk; confirm against
// the stream implementation).
assert_eq!(chunks.len(), 3);
assert_eq!(
chunks[1].choices[0].delta.content.as_deref(),
Some("answer")
);
// No chunk may expose reasoning content.
assert!(
chunks
.iter()
.all(|chunk| chunk.choices.iter().all(|choice| choice.delta.reasoning.is_none()))
);
}
// Checks that, with `include_reasoning` disabled while logprobs and token ids
// are requested, the metadata belonging to a suppressed reasoning delta is
// dropped while the metadata of the visible text delta is kept.
#[tokio::test]
async fn chunk_stream_omits_logprobs_for_suppressed_reasoning() {
let stream = stream::iter(vec![
Ok(ChatEvent::Start {
prompt_token_ids: vec![].into(),
prompt_logprobs: None,
}),
// Hidden update: reasoning delta followed by its logprobs / token id 11.
Ok(ChatEvent::BlockDelta {
index: 0,
kind: AssistantBlockKind::Reasoning,
delta: "think".to_string(),
}),
Ok(ChatEvent::LogprobsDelta {
logprobs: Some(DecodedLogprobs {
positions: vec![DecodedPositionLogprobs {
entries: vec![DecodedTokenLogprob {
token_id: 11,
token: "think".to_string(),
logprob: -0.1,
rank: 1,
}],
}],
}),
token_ids: vec![11],
}),
// Visible update: text delta followed by its logprobs / token id 22.
Ok(ChatEvent::BlockDelta {
index: 1,
kind: AssistantBlockKind::Text,
delta: "answer".to_string(),
}),
Ok(ChatEvent::LogprobsDelta {
logprobs: Some(DecodedLogprobs {
positions: vec![DecodedPositionLogprobs {
entries: vec![DecodedTokenLogprob {
token_id: 22,
token: "answer".to_string(),
logprob: -0.2,
rank: 1,
}],
}],
}),
token_ids: vec![22],
}),
Ok(ChatEvent::Done {
message: Default::default(),
prompt_token_count: 1,
output_token_count: 2,
finish_reason: FinishReason::stop_eos(),
kv_transfer_params: None,
}),
]);
let chunks = chat_completion_chunk_stream(
stream,
"chatcmpl-1".to_string(),
"model".to_string(),
1,
false, // log_request
false, // include_usage
true, // requested_logprobs
false, // include_reasoning
None, // echo
true, // return_token_ids
false, // return_tokens_as_token_ids
)
.collect::<Vec<_>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.expect("stream chunks");
assert_eq!(chunks.len(), 3);
// The middle chunk carries the visible answer together with its own metadata.
let choice = &chunks[1].choices[0];
assert_eq!(choice.delta.content.as_deref(), Some("answer"));
assert_eq!(choice.token_ids.as_deref(), Some(&[22][..]));
let logprobs = choice.logprobs.as_ref().expect("answer logprobs");
let content = logprobs.content.as_ref().expect("logprobs content");
assert_eq!(content[0].token, "answer");
// Across all chunks: no reasoning delta, no token id 11, and no logprob entry
// for the hidden "think" token.
assert!(chunks.iter().all(|chunk| {
chunk.choices.iter().all(|choice| {
choice.delta.reasoning.is_none()
&& choice.token_ids.as_deref() != Some(&[11][..])
&& choice
.logprobs
.as_ref()
.and_then(|logprobs| logprobs.content.as_ref())
.is_none_or(|content| content.iter().all(|entry| entry.token != "think"))
})
}));
}
// Checks that, with `include_reasoning` disabled, token metadata is suppressed
// not only for reasoning deltas but also for updates that contain just a
// reasoning block start or end (the `<think>` / `</think>` delimiter tokens in
// this fixture).
#[tokio::test]
async fn chunk_stream_omits_logprobs_for_hidden_reasoning_delimiters() {
let stream = stream::iter(vec![
Ok(ChatEvent::Start {
prompt_token_ids: vec![].into(),
prompt_logprobs: None,
}),
// Update 1: reasoning block opens; metadata is the "<think>" token (id 11).
Ok(ChatEvent::BlockStart {
index: 0,
kind: AssistantBlockKind::Reasoning,
}),
Ok(ChatEvent::LogprobsDelta {
logprobs: Some(DecodedLogprobs {
positions: vec![DecodedPositionLogprobs {
entries: vec![DecodedTokenLogprob {
token_id: 11,
token: "<think>".to_string(),
logprob: -0.1,
rank: 1,
}],
}],
}),
token_ids: vec![11],
}),
// Update 2: reasoning body; metadata is the "think" token (id 12).
Ok(ChatEvent::BlockDelta {
index: 0,
kind: AssistantBlockKind::Reasoning,
delta: "think".to_string(),
}),
Ok(ChatEvent::LogprobsDelta {
logprobs: Some(DecodedLogprobs {
positions: vec![DecodedPositionLogprobs {
entries: vec![DecodedTokenLogprob {
token_id: 12,
token: "think".to_string(),
logprob: -0.2,
rank: 1,
}],
}],
}),
token_ids: vec![12],
}),
// Update 3: reasoning block closes; metadata is the "</think>" token (id 13).
Ok(ChatEvent::BlockEnd {
index: 0,
block: AssistantContentBlock::Reasoning {
text: "think".to_string(),
},
}),
Ok(ChatEvent::LogprobsDelta {
logprobs: Some(DecodedLogprobs {
positions: vec![DecodedPositionLogprobs {
entries: vec![DecodedTokenLogprob {
token_id: 13,
token: "</think>".to_string(),
logprob: -0.3,
rank: 1,
}],
}],
}),
token_ids: vec![13],
}),
// Update 4: visible text block; metadata is the "answer" token (id 22).
Ok(ChatEvent::BlockStart {
index: 1,
kind: AssistantBlockKind::Text,
}),
Ok(ChatEvent::BlockDelta {
index: 1,
kind: AssistantBlockKind::Text,
delta: "answer".to_string(),
}),
Ok(ChatEvent::LogprobsDelta {
logprobs: Some(DecodedLogprobs {
positions: vec![DecodedPositionLogprobs {
entries: vec![DecodedTokenLogprob {
token_id: 22,
token: "answer".to_string(),
logprob: -0.4,
rank: 1,
}],
}],
}),
token_ids: vec![22],
}),
Ok(ChatEvent::Done {
message: Default::default(),
prompt_token_count: 1,
output_token_count: 4,
finish_reason: FinishReason::stop_eos(),
kv_transfer_params: None,
}),
]);
let chunks = chat_completion_chunk_stream(
stream,
"chatcmpl-1".to_string(),
"model".to_string(),
1,
false, // log_request
false, // include_usage
true, // requested_logprobs
false, // include_reasoning
None, // echo
true, // return_token_ids
false, // return_tokens_as_token_ids
)
.collect::<Vec<_>>()
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
.expect("stream chunks");
// Only the visible answer produces a content chunk; the three hidden updates
// must not add chunks of their own.
assert_eq!(chunks.len(), 3);
let choice = &chunks[1].choices[0];
assert_eq!(choice.delta.content.as_deref(), Some("answer"));
assert_eq!(choice.token_ids.as_deref(), Some(&[22][..]));
let logprobs = choice.logprobs.as_ref().expect("answer logprobs");
let content = logprobs.content.as_ref().expect("logprobs content");
assert_eq!(content[0].token, "answer");
// Across all chunks: no reasoning delta, none of the hidden token ids
// (11, 12, 13), and no logprob entry for any hidden token string.
assert!(chunks.iter().all(|chunk| {
chunk.choices.iter().all(|choice| {
choice.delta.reasoning.is_none()
&& !choice
.token_ids
.as_ref()
.is_some_and(|ids| matches!(ids.as_slice(), [11] | [12] | [13]))
&& choice
.logprobs
.as_ref()
.and_then(|logprobs| logprobs.content.as_ref())
.is_none_or(|content| {
content.iter().all(|entry| {
!matches!(entry.token.as_str(), "<think>" | "think" | "</think>")
})
})
})
}));
}
#[tokio::test]
async fn chunk_stream_preserves_tool_call_index_and_omits_id_from_arguments_delta() {
let stream = stream::iter(vec![
@@ -1017,6 +1344,7 @@ mod tests {
false,
false,
false,
true,
None,
false,
false,
@@ -29,6 +29,8 @@ pub struct PreparedRequest {
pub requested_logprobs: bool,
/// Whether the caller requested top-level prompt logprobs.
pub include_prompt_logprobs: bool,
/// Whether to include reasoning content in OpenAI responses.
pub include_reasoning: bool,
/// Lowered chat request for `vllm-chat`.
pub chat_request: ChatRequest,
/// Last assistant-role message content to echo back when `echo=true`.
@@ -57,6 +59,7 @@ pub(crate) fn prepare_chat_request(
.as_ref()
.map(|request| request.lora_name.clone())
.unwrap_or_else(|| lora_resolution.model_names.first().cloned().unwrap_or_default());
let include_reasoning = request.include_reasoning;
let echo = request
.echo
.then(|| extract_last_assistant_content(&request.messages))
@@ -146,6 +149,7 @@ pub(crate) fn prepare_chat_request(
include_usage,
requested_logprobs,
include_prompt_logprobs,
include_reasoning,
chat_request,
echo,
return_token_ids: request.return_token_ids.unwrap_or(false),
@@ -480,6 +484,23 @@ mod tests {
assert_eq!(prepared.chat_request.tool_choice, ChatToolChoice::Auto);
}
// Checks that `include_reasoning: false` on the incoming request is carried
// through `prepare_chat_request` onto the prepared request unchanged.
#[test]
fn prepare_chat_request_preserves_include_reasoning_false() {
let request = ChatCompletionRequest {
include_reasoning: false,
..base_request()
};
let prepared = prepare_chat_request(
request,
&served(&["Qwen/Qwen1.5-0.5B-Chat"]),
ResolvedRequestContext::default(),
)
.expect("request is valid");
assert!(!prepared.include_reasoning);
}
#[test]
fn prepare_chat_request_preserves_sampling_passthrough_fields() {
let request = ChatCompletionRequest {
@@ -120,12 +120,6 @@ pub(super) fn validate_request_compat(
"thinking_token_budget",
"thinking_token_budget is not supported.",
)?;
if !request.include_reasoning {
bail_invalid_request!(
param = "include_reasoning",
"include_reasoning is not supported."
);
}
reject_non_default(
request.media_io_kwargs.as_ref(),
"media_io_kwargs",
@@ -312,6 +306,17 @@ mod tests {
.expect("reasoning_effort should be accepted");
}
// Checks that `validate_request_compat` returns `Ok` for a request whose only
// change from `base_request()` is `include_reasoning: false`.
#[test]
fn validate_request_compat_accepts_include_reasoning_false() {
let request = ChatCompletionRequest {
include_reasoning: false,
..base_request()
};
validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
.expect("include_reasoning=false should be accepted");
}
#[test]
fn validate_request_compat_rejects_top_logprobs_without_logprobs() {
let request = ChatCompletionRequest {
+164
View File
@@ -3492,6 +3492,170 @@ async fn reasoning_blocks_are_mapped_to_reasoning_sse_chunks() {
assert!(text.contains("\"content\":\"answer\""), "{text}");
}
// End-to-end check of the non-streaming `/v1/chat/completions` route: with
// `"include_reasoning": false`, the response message keeps the answer text and
// has no `reasoning` key at all.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn include_reasoning_false_suppresses_reasoning_in_non_stream_chat() {
// The mock engine produces a `<think>…</think>` segment followed by the
// answer. NOTE(review): the fake backend's model id ("Qwen/Qwen3-0.6B")
// differs from the served model name used in the request below; presumably
// the backend id is what selects the reasoning parser. Confirm.
let (app, engine_task) = test_app_with_backend_and_stream_output_specs(
Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")),
vec![
(bytes_to_token_ids(b"<think>think</think>"), None),
(
bytes_to_token_ids(b"answer"),
Some(EngineCoreFinishReason::Length),
),
],
)
.await;
let response = app
.clone()
.call(
Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(
json!({
"model": "Qwen/Qwen1.5-0.5B-Chat",
"stream": false,
"include_reasoning": false,
"messages": [{"role": "user", "content": "hello"}]
})
.to_string(),
))
.expect("build request"),
)
.await
.expect("call app");
assert_eq!(response.status(), StatusCode::OK);
let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body");
engine_task.await.expect("mock engine task");
let text = String::from_utf8(body.to_vec()).expect("utf8 body");
let json: serde_json::Value = serde_json::from_str(&text).expect("decode json");
assert_eq!(json["choices"][0]["message"]["content"], "answer");
// The key must be absent from the message object, not merely `null`; the raw
// body is attached as the failure message for debugging.
assert!(
json["choices"][0]["message"]
.as_object()
.is_some_and(|message| !message.contains_key("reasoning")),
"{text}"
);
}
// Verifies that a non-streaming chat completion sent with
// `include_reasoning: false` together with `logprobs` and `return_token_ids`
// returns only the answer text: the message has no `reasoning` key, the choice
// has neither `logprobs` nor `token_ids`, and `prompt_token_ids` is still an
// array at the top level of the response.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn include_reasoning_false_suppresses_non_stream_output_metadata() {
let ipc = IpcNamespace::new().expect("create ipc namespace");
let handshake_address = ipc.handshake_endpoint();
let engine_id = b"engine-openai-hidden-reasoning-logprobs".to_vec();
// Mock engine: reads one incoming message, decodes the request from its
// second frame, and answers with a single batch of two outputs.
let engine_task = MockEngineTask::new(spawn_mock_engine_task(
handshake_address.clone(),
engine_id.clone(),
|dealer, push| {
boxed_test_future(async move {
let add = recv_engine_message(dealer).await;
let request: EngineCoreRequest =
rmp_serde::from_slice(&add[1]).expect("decode request");
let reasoning_token_ids = bytes_to_token_ids(b"<think>think</think>");
let answer_token_ids = bytes_to_token_ids(b"answer");
send_outputs(
push,
EngineCoreOutputs {
engine_index: 0,
outputs: vec![
// First output: the reasoning tokens with their sample logprobs and
// no finish reason.
request_output_with_logprobs(
&request.request_id,
reasoning_token_ids.clone(),
None,
None,
Some(sample_logprobs_for_tokens(&reasoning_token_ids)),
None,
),
// Second output: the answer tokens with their sample logprobs; this
// one finishes the request with `Length`.
request_output_with_logprobs(
&request.request_id,
answer_token_ids.clone(),
Some(EngineCoreFinishReason::Length),
None,
Some(sample_logprobs_for_tokens(&answer_token_ids)),
None,
),
],
scheduler_stats: None,
timestamp: 0.0,
utility_output: None,
finished_requests: None,
wave_complete: None,
start_wave: None,
},
)
.await;
})
},
));
// Connect a real client to the mock engine over the IPC endpoints.
let client = EngineCoreClient::connect(
EngineCoreClientConfig::new_single(handshake_address)
.with_model_name("test-model")
.with_local_input_output_addresses(
Some(ipc.input_endpoint()),
Some(ipc.output_endpoint()),
),
)
.await
.expect("connect client");
// NOTE(review): the fake backend's model id differs from the served model
// name below; presumably it drives reasoning-parser selection -- confirm.
let chat = ChatLlm::from_shared_backend(
test_llm(client),
Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")),
);
let mut app = build_router(Arc::new(AppState::new(
vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()],
chat,
)));
// Request logprobs and token ids explicitly so their absence from the
// choice below is attributable to `include_reasoning: false`.
let response = app
.call(
Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(
json!({
"model": "Qwen/Qwen1.5-0.5B-Chat",
"stream": false,
"include_reasoning": false,
"logprobs": true,
"return_token_ids": true,
"messages": [{"role": "user", "content": "hello"}]
})
.to_string(),
))
.expect("build request"),
)
.await
.expect("call app");
assert_eq!(response.status(), StatusCode::OK);
let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body");
// Wait for the mock engine task so a failure inside it fails this test.
engine_task.await.expect("mock engine task");
let text = String::from_utf8(body.to_vec()).expect("utf8 body");
let json: serde_json::Value = serde_json::from_str(&text).expect("decode json");
let choice = json["choices"][0].as_object().expect("choice object");
assert_eq!(json["choices"][0]["message"]["content"], "answer");
// The message must be an object and must not carry a `reasoning` key.
assert!(
json["choices"][0]["message"]
.as_object()
.is_some_and(|message| !message.contains_key("reasoning")),
"{text}"
);
// Per-choice metadata is dropped, while the prompt token ids are still
// reported at the top level.
assert!(!choice.contains_key("logprobs"), "{text}");
assert!(!choice.contains_key("token_ids"), "{text}");
assert!(json["prompt_token_ids"].is_array(), "{text}");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn tool_calls_are_mapped_to_tool_call_sse_chunks() {
+25
View File
@@ -115,6 +115,8 @@ impl<T: Tokenizer + ?Sized> IncrementalDecoder for DecodeStream<'_, T> {
fn next_chunk(&mut self) -> Option<String> {
let cutoff = self.cumulative_output.len().saturating_sub(self.min_bytes_to_buffer);
// Ensure we split at a utf-8 char boundary.
let cutoff = self.cumulative_output.floor_char_boundary(cutoff);
(cutoff > self.output_index).then(|| {
let chunk = self.cumulative_output[self.output_index..cutoff].to_string();
self.output_index = cutoff;
@@ -356,4 +358,27 @@ mod tests {
assert_eq!(last_chunk.as_deref(), Some("lo!"));
assert_eq!(full_text, "Hello!");
}
#[test]
fn next_chunk_cutoff_respects_char_boundary() {
// Regression: next_chunk's cutoff (len - min_bytes_to_buffer) must be
// aligned to a UTF-8 char boundary like push_token/flush; otherwise
// streaming multi-byte output (CJK/emoji) with a hold-back buffer (set
// by a stop string) panics slicing cumulative_output mid-character.
let backend = Utf8Backend;
let mut decoder = backend.create_decode_stream(&[], false, 2);
let mut out = String::new();
for byte in "你好A".bytes() {
decoder.push_token(u32::from(byte)).unwrap();
if let Some(chunk) = decoder.next_chunk() {
out.push_str(&chunk);
}
}
let (last_chunk, full_text) = decoder.flush(None).unwrap();
if let Some(chunk) = last_chunk {
out.push_str(&chunk);
}
assert_eq!(full_text, "你好A");
assert_eq!(out, "你好A");
}
}
+2
View File
@@ -565,6 +565,8 @@ def test_size_used_in_multiple_consumer_subgraphs():
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(y, 0)
torch.compile(model_fn, backend=capturing_backend)(x, y)
assert captured_graph is not None, "Graph should be captured by backend"
assert captured_inputs is not None, "Example inputs should be captured by backend"
split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])
+7
View File
@@ -15,6 +15,7 @@ import pytest
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config.parallel import ParallelConfig
from vllm.utils.network_utils import get_open_port
from vllm.utils.system_utils import update_environment_variables
@@ -379,6 +380,12 @@ def _distributed_packed_a2a_worker(env: dict[str, str]) -> None:
update_environment_variables(env)
local_rank = int(env["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group(backend="nccl")
use_workspace = env.get("USE_WORKSPACE") == "1"
if use_workspace:
+13
View File
@@ -9,6 +9,7 @@ import pytest
import torch
import torch.distributed
import vllm.envs as envs
from tests.utils import ensure_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
@@ -82,6 +83,13 @@ def test_pynccl():
@worker_fn_wrapper
def multiple_allreduce_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
# Eager-init path: parent PG has bound_device_id + a CPU backend,
# so split_group is supported.
group = torch.distributed.split_group(
split_ranks=[[0, 1], [2, 3]], backend="cpu:gloo,cuda:nccl"
)
else:
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
@@ -339,6 +347,11 @@ def test_pynccl_send_recv():
@worker_fn_wrapper
def multiple_send_recv_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
group = torch.distributed.split_group(
split_ranks=[[0, 2], [1, 3]], backend="cpu:gloo,cuda:nccl"
)
else:
groups = [
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
@@ -9,6 +9,7 @@ import ray
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.quick_all_reduce import (
@@ -397,12 +398,26 @@ def qr_variable_input(rank, world_size):
ranks = []
for i in range(world_size):
ranks.append(i)
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
device_id=device,
)
else:
dist.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
)
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
cpu_group = torch.distributed.split_group(
split_ranks=[ranks], backend="cpu:gloo,cuda:nccl"
)
else:
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
handle = ops.qr_get_handle(_ptr)
+233
View File
@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for split_group in GroupCoordinator.
These tests verify that:
1. split_group is used for both device and CPU group creation.
2. Multiple subgroups work correctly with split_group.
3. Both GPU and CPU all-reduce work on split groups.
"""
import os
from typing import Any
import multiprocess as mp
import pytest
import torch
import torch.distributed
import vllm.envs as envs
from vllm.distributed.parallel_state import (
GroupCoordinator,
init_distributed_environment,
)
from vllm.utils.system_utils import update_environment_variables
# The whole module exercises the split_group code path, which is opt-in
# behind VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1.
pytestmark = pytest.mark.skipif(
not envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP,
reason=("VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1 not set; split_group path is opt-in."),
)
mp.set_start_method("spawn", force=True)
def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` spawned processes and require clean exits.

    Every worker is handed a dict of environment variables describing its
    rank and the rendezvous address; ``fn`` receives that dict as its only
    argument. All processes are joined before any exit code is checked.
    """
    workers: list[mp.Process] = []
    for rank in range(world_size):
        env: dict[str, str] = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12346",
            # propagate the opt-in flag to the spawned child workers
            "VLLM_DISTRIBUTED_USE_SPLIT_GROUP": "1",
        }
        proc = mp.Process(target=fn, args=(env,))
        workers.append(proc)
        proc.start()

    for proc in workers:
        proc.join()

    for proc in workers:
        assert proc.exitcode == 0
def worker_fn_wrapper(fn):
    """Wrap ``fn`` so it runs inside an initialized distributed environment.

    The returned callable takes the environment dict built by the launcher,
    applies it, selects the accelerator device named by ``LOCAL_RANK``,
    initializes the distributed environment and finally calls ``fn()``.
    """

    def wrapped_fn(env):
        # The launcher passes rank/rendezvous settings as plain env vars.
        update_environment_variables(env)
        local_rank = os.environ["LOCAL_RANK"]
        torch.accelerator.set_device_index(torch.device(f"cuda:{local_rank}"))
        init_distributed_environment()
        fn()

    return wrapped_fn
def _verify_device_group(coordinator: GroupCoordinator):
    """All-reduce a tensor of ones over the device group and check the sum.

    Each participant contributes 1 per element, so every element must equal
    the coordinator's world size after the reduction.
    """
    local_rank = torch.distributed.get_rank()
    tensor = torch.ones(
        16,
        16,
        dtype=torch.float32,
        device=torch.device(f"cuda:{local_rank}"),
    )
    torch.distributed.all_reduce(tensor, group=coordinator.device_group)
    # Make sure the collective has finished before inspecting the result.
    torch.accelerator.synchronize()

    expected = coordinator.world_size
    all_match = torch.all(tensor == expected).cpu().item()
    assert all_match, (
        f"Device group all-reduce failed: expected {expected}, "
        f"got {tensor.flatten()[0].item()}"
    )
def _verify_cpu_group(coordinator: GroupCoordinator):
    """All-reduce a CPU tensor of ones over the CPU group and check the sum.

    Each participant contributes 1 per element, so every element must equal
    the coordinator's world size after the reduction.
    """
    tensor = torch.ones(16, dtype=torch.float32)
    torch.distributed.all_reduce(tensor, group=coordinator.cpu_group)

    expected = coordinator.world_size
    all_match = torch.all(tensor == expected).cpu().item()
    assert all_match, (
        f"CPU group all-reduce failed: expected {expected}, "
        f"got {tensor.flatten()[0].item()}"
    )
# ---------------------------------------------------------------------------
# Test 1: Basic split_group path with 2 GPUs
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_basic_worker():
    """Build one coordinator spanning every rank and exercise both groups."""
    rank = torch.distributed.get_rank()
    # A single subgroup containing the whole world.
    everyone = list(range(torch.distributed.get_world_size()))

    coordinator = GroupCoordinator(
        group_ranks=[everyone],
        local_rank=rank,
        torch_distributed_backend="nccl",
        use_device_communicator=False,
        group_name="test_split_basic",
    )

    for verify in (_verify_device_group, _verify_cpu_group):
        verify(coordinator)
@pytest.mark.skipif(
    torch.accelerator.device_count() < 2,
    reason="Need at least 2 GPUs to run the test.",
)
def test_split_group_basic():
    """Create a whole-world GroupCoordinator via split_group on two ranks."""
    distributed_run(fn=split_group_basic_worker, world_size=2)
# ---------------------------------------------------------------------------
# Test 2: Multiple subgroups with split_group (4 GPUs)
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_multiple_subgroups_worker():
    """Split four ranks into two pairs and check each pair works on its own."""
    rank = torch.distributed.get_rank()

    coordinator = GroupCoordinator(
        group_ranks=[[0, 1], [2, 3]],
        local_rank=rank,
        torch_distributed_backend="nccl",
        use_device_communicator=False,
        group_name="test_split_multi",
    )

    assert coordinator.world_size == 2
    _verify_device_group(coordinator)
    _verify_cpu_group(coordinator)

    # Each rank must end up in the pair that lists it.
    expected_ranks = [0, 1] if rank in [0, 1] else [2, 3]
    assert coordinator.ranks == expected_ranks
@pytest.mark.skipif(
    torch.accelerator.device_count() < 4,
    reason="Need at least 4 GPUs to run the test.",
)
def test_split_group_multiple_subgroups():
    """Run the two-pair subgroup worker on four ranks."""
    distributed_run(fn=split_group_multiple_subgroups_worker, world_size=4)
# ---------------------------------------------------------------------------
# Test 3: split_group contract — every parent rank must enter with the same
# ``split_ranks``. NCCL happens to produce
# correct subgroups for disjoint partitions because the wrapper hashes
# ``my_group`` to derive the comm-split color, but the contract violation is
# real and would break under non-partition / non-NCCL backends. This test
# captures the actual ``split_ranks`` argument passed on every rank and
# asserts they match.
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_contract_worker():
    """Check that every rank passes identical ``split_ranks`` to split_group.

    ``torch.distributed.split_group`` is temporarily replaced by a wrapper
    that records the ``split_ranks`` argument of each call made while a
    ``GroupCoordinator`` is constructed. The recorded arguments are then
    gathered across all ranks and compared (order-insensitively) against
    rank 0.

    Raises:
        AssertionError: if this rank did not make exactly two split_group
            calls, or if any rank's ``split_ranks`` differs from rank 0's.
    """
    rank = torch.distributed.get_rank()
    group_ranks = [[0, 1], [2, 3]]

    captured: list[list[list[int]]] = []
    original_split_group = torch.distributed.split_group

    def capturing_split_group(*args, **kwargs):
        # ``split_ranks`` is the second parameter of split_group (after
        # ``parent_pg``), so it may arrive by keyword or positionally.
        # Record it either way instead of assuming the keyword form.
        if "split_ranks" in kwargs:
            split_ranks = kwargs["split_ranks"]
        elif len(args) > 1:
            split_ranks = args[1]
        else:
            split_ranks = None
        captured.append([list(g) for g in split_ranks or []])
        return original_split_group(*args, **kwargs)

    torch.distributed.split_group = capturing_split_group
    try:
        GroupCoordinator(
            group_ranks=group_ranks,
            local_rank=rank,
            torch_distributed_backend="nccl",
            use_device_communicator=False,
            group_name="test_split_contract",
        )
    finally:
        # Always restore the real function, even if construction fails.
        torch.distributed.split_group = original_split_group

    # GroupCoordinator builds two subgroups (device + cpu) per coordinator,
    # so every rank must have made exactly two split_group calls.
    if len(captured) != 2:
        raise AssertionError(
            f"rank {rank} expected 2 split_group calls (device + cpu), "
            f"got {len(captured)}: {captured}"
        )

    world_size = torch.distributed.get_world_size()
    for call_idx in range(2):
        gathered: list[Any] = [None] * world_size
        torch.distributed.all_gather_object(gathered, captured[call_idx])
        # Normalize for stable comparison (sort each subgroup and the outer list).
        norm = [
            sorted([sorted(sg) for sg in per_rank_args]) for per_rank_args in gathered
        ]
        reference = norm[0]
        for r, args in enumerate(norm):
            if args != reference:
                raise AssertionError(
                    f"split_group contract violation on call #{call_idx}: "
                    f"rank {r} passed split_ranks={gathered[r]}, but rank 0 "
                    f"passed split_ranks={gathered[0]}. PyTorch requires every "
                    "parent rank to enter split_group with the same split_ranks."
                )
@pytest.mark.skipif(
    torch.accelerator.device_count() < 4,
    reason="Need at least 4 GPUs to run the test.",
)
def test_split_group_contract_same_split_ranks_on_all_ranks():
    """Every parent rank must hand the same ``split_ranks`` to split_group.

    Guards against each rank passing only its own subgroup
    (``split_ranks=[ranks]``): NCCL tolerates that for disjoint partitions,
    but it violates the documented contract of
    ``torch.distributed.split_group``.
    """
    distributed_run(fn=split_group_contract_worker, world_size=4)
+14 -1
View File
@@ -5,12 +5,25 @@
import os
import random
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_world_group
# Let PyTorch choose the WORLD backend for the current device type.
# By default, let PyTorch choose the WORLD backend for the current device
# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
# use the explicit eager-init pattern required by `split_group` (mixed
# cpu:gloo,cuda:nccl backend + device_id binding).
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
local_rank = int(os.environ["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group()
# Create prompts
+14 -1
View File
@@ -5,12 +5,25 @@
import os
import random
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_tp_group, get_world_group
# Let PyTorch choose the WORLD backend for the current device type.
# By default, let PyTorch choose the WORLD backend for the current device
# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
# use the explicit eager-init pattern required by `split_group` (mixed
# cpu:gloo,cuda:nccl backend + device_id binding).
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
local_rank = int(os.environ["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group()
# Create prompts
@@ -808,14 +808,20 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA
"logprobs": False,
}
chat_completion = await client.chat.completions.create(**request_args)
# Use raw HTTP for both endpoints so we compare server responses
# directly, without the openai SDK injecting extra fields
# (e.g. `moderation` added in newer SDK versions).
chat_response = requests.post(
server.url_for("v1/chat/completions"), json=request_args
)
chat_response.raise_for_status()
invocation_response = requests.post(
server.url_for("invocations"), json=request_args
)
invocation_response.raise_for_status()
chat_output = chat_completion.model_dump()
chat_output = chat_response.json()
invocation_output = invocation_response.json()
assert chat_output.keys() == invocation_output.keys()
+19 -11
View File
@@ -5,8 +5,16 @@ import pytest
import torch
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
DEVICE = current_platform.device_type
pytestmark = pytest.mark.skipif(
not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
reason="Lightning attention Triton kernels require CUDA/ROCm or XPU.",
)
NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
@@ -121,7 +129,7 @@ def test_linear_decode_forward_triton(
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.set_default_device(DEVICE)
set_random_seed(42)
base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
@@ -129,16 +137,16 @@ def test_linear_decode_forward_triton(
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
)
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
slope_rate = torch.zeros(num_heads, device=DEVICE)
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.arange(batch_size, device="cuda")
slot_idx = torch.arange(batch_size, device=DEVICE)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
@@ -162,7 +170,7 @@ def test_linear_decode_forward_triton_with_padding(
head_size: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.set_default_device(DEVICE)
set_random_seed(42)
batch_size = 4
@@ -172,16 +180,16 @@ def test_linear_decode_forward_triton_with_padding(
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
kv_caches = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
)
kv_caches_copy = kv_caches.clone()
slope_rate = torch.zeros(num_heads, device="cuda")
slope_rate = torch.zeros(num_heads, device=DEVICE)
for h in range(num_heads):
slope_rate[h] = 0.1 * (h + 1)
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
slot_idx = torch.tensor([0, 1, -1, 2], device=DEVICE)
triton_output = linear_decode_forward_triton(
q, k, v, kv_caches, slope_rate, slot_idx
@@ -224,7 +232,7 @@ def test_lightning_attention_reference(
seq_len: int,
dtype: torch.dtype,
):
torch.set_default_device("cuda")
torch.set_default_device(DEVICE)
set_random_seed(42)
base = 0.01
@@ -232,12 +240,12 @@ def test_lightning_attention_reference(
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
ed = torch.zeros(num_heads, device="cuda")
ed = torch.zeros(num_heads, device=DEVICE)
for h in range(num_heads):
ed[h] = 0.1 * (h + 1)
kv_history = base * torch.randn(
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
)
kv_history_clone = kv_history.clone()
@@ -10,6 +10,7 @@ import torch
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import ParamSpec
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (
cleanup_dist_env_and_memory,
@@ -60,7 +61,15 @@ def _set_vllm_config(
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
)
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
cpu_group = torch.distributed.split_group(
split_ranks=[list(range(world_size))],
group_desc="moe_test_cpu",
)
else:
cpu_group = torch.distributed.new_group(
list(range(world_size)), backend="gloo"
)
return cpu_group
+13 -5
View File
@@ -205,7 +205,10 @@ def run_with_expert_maps(
w2 = kwargs["w2"]
a = kwargs["hidden_states"]
moe_config = make_dummy_moe_config(
num_experts=w2.shape[0],
max_num_tokens=kwargs.get("hidden_states").shape[0],
experts_per_token=kwargs.get("topk_ids").shape[1],
num_experts=num_experts,
num_local_experts=num_local_experts,
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
in_dtype=a.dtype,
@@ -258,23 +261,27 @@ def run_8_bit(
a1_scale=None,
)
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts
kwargs = {
"hidden_states": moe_tensors.a,
"w1": moe_tensors.w1_q, # type: ignore[union-attr]
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr]
"global_num_experts": num_experts,
"activation": MoEActivation.SILU,
"expert_map": None,
"apply_router_weight_on_input": False,
}
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep:
moe_config = make_dummy_moe_config(
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
max_num_tokens=moe_tensors.a.shape[0],
experts_per_token=topk_ids.shape[1],
num_experts=num_experts,
num_local_experts=num_local_experts,
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
in_dtype=moe_tensors.a.dtype,
@@ -581,6 +588,7 @@ def test_run_cutlass_moe_fp8(
per_out_channel,
False,
topk_weights,
None,
)
workspace13.random_()
@@ -14,6 +14,7 @@ import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -375,6 +376,12 @@ def _test_deepep_deepgemm_moe(
w1_scale = w1_scale.to(device=device)
w2_scale = w2_scale.to(device=device)
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
pg = torch.distributed.split_group(
split_ranks=[list(range(pgi.world_size))],
group_desc="deepep_deepgemm_test",
)
else:
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, pgi.rank)
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
+7
View File
@@ -10,6 +10,7 @@ import pytest
import torch.distributed
from torch.distributed import ProcessGroup
import vllm.envs as envs
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
@@ -375,6 +376,12 @@ def _deep_ep_moe(
w1_scale = w1_scale.to(device=device_idx)
w2_scale = w2_scale.to(device=device_idx)
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
pg = torch.distributed.split_group(
split_ranks=[list(range(pgi.world_size))],
group_desc="deepep_test",
)
else:
pg = torch.distributed.new_group(list(range(pgi.world_size)))
test_tensors = TestTensors.make(config, low_latency_mode)
+6 -2
View File
@@ -49,10 +49,12 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
def make_dummy_moe_config(
num_experts: int = 1,
num_local_experts: int | None = None,
experts_per_token: int = 1,
hidden_dim: int = 1,
intermediate_size_per_partition: int = 1,
in_dtype: torch.dtype = torch.bfloat16,
max_num_tokens: int = 512,
) -> FusedMoEConfig:
"""
This is a dummy config for the mk constructor interface
@@ -66,14 +68,16 @@ def make_dummy_moe_config(
experts_per_token=experts_per_token,
hidden_dim=hidden_dim,
intermediate_size_per_partition=intermediate_size_per_partition,
num_local_experts=num_experts,
num_local_experts=num_local_experts
if num_local_experts is not None
else num_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation=MoEActivation.SILU,
in_dtype=in_dtype,
device="cuda",
routing_method=RoutingMethodType.TopK,
max_num_tokens=512,
max_num_tokens=max_num_tokens,
)
@@ -13,9 +13,15 @@ from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton,
awq_gemm_triton,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
device = "cuda"
pytestmark = pytest.mark.skipif(
not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
reason="AWQ Triton kernels require CUDA/ROCm or XPU.",
)
device = current_platform.device_type
def reverse_awq_order(t: torch.Tensor):
@@ -100,32 +100,6 @@ def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) ->
assert output.shape == (2, output_size)
def test_kv_cache_scale_name_handling():
    """A quant config with cache scales is queried once with the weight name."""
    weight_name = "layers.0.self_attn.k_proj.weight"
    expected_scale = "layers.0.self_attn.kv_scale"

    # Stand-in for a quant config that supports cache scales.
    quant_config = Mock()
    quant_config.get_cache_scale = Mock(return_value=expected_scale)

    # Mirrors the lookup performed in load_weights.
    resolved = quant_config.get_cache_scale(weight_name)

    quant_config.get_cache_scale.assert_called_once_with(weight_name)
    assert resolved == expected_scale
def test_kv_cache_scale_name_no_scale():
    """A weight without a cache scale resolves to ``None``."""
    # Stand-in for a quant config whose get_cache_scale finds nothing.
    quant_config = Mock()
    quant_config.get_cache_scale = Mock(return_value=None)

    resolved = quant_config.get_cache_scale("layers.0.mlp.gate_proj.weight")

    assert resolved is None
def test_maybe_remap_kv_scale_name():
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
@@ -183,33 +157,3 @@ def test_eagle3_lm_head_receives_quant_config():
assert call_kwargs["quant_config"] is mock_quant_config, (
"ParallelLMHead must receive the draft model's quant_config"
)
def test_load_weights_kv_scale_handling():
    """Replay the KV-cache-scale branch of load_weights against mocks.

    The scale name returned by the quant config must select the matching
    parameter, and a non-scalar loaded tensor contributes its first element.
    """
    kv_scale_param = Mock()
    kv_scale_param.weight_loader = Mock()
    params_dict = {
        "layers.0.self_attn.kv_scale": kv_scale_param,
    }

    mock_quant_config = Mock()
    mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")

    name = "layers.0.self_attn.k_proj.weight"
    loaded_weight_tensor = torch.tensor([1.0, 2.0])

    # Same lookup order as load_weights: no config or no scale means no work.
    scale_name = None
    if mock_quant_config is not None:
        scale_name = mock_quant_config.get_cache_scale(name)
    if not scale_name:
        return

    param = params_dict[scale_name]
    assert param is kv_scale_param

    # Scalars are used as-is; otherwise only the first element is taken.
    if loaded_weight_tensor.dim() == 0:
        weight_to_load = loaded_weight_tensor
    else:
        weight_to_load = loaded_weight_tensor[0]

    assert scale_name == "layers.0.self_attn.kv_scale"
    assert weight_to_load == loaded_weight_tensor[0]
@@ -14,7 +14,7 @@ from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
# -----------------------------------------------------------------------
# Model definitions: (model_name, colbert_dim, extra vllm_runner kwargs)
# -----------------------------------------------------------------------
COLBERT_MODELS = {
COLBERT_MODELS: dict[str, dict] = {
"bert": {
"model": "answerdotai/answerai-colbert-small-v1",
"colbert_dim": 96,
@@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import csv
import importlib
import importlib.util
import os
import pytest
import torch
from tests.utils import TestFP8Layer
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
FP8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTokenSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
aiter_available = importlib.util.find_spec("aiter") is not None
pytestmark = [
pytest.mark.skipif(
not (
current_platform.is_rocm()
and current_platform.supports_fp8()
and aiter_available
),
reason="Requires ROCm + FP8 support + aiter",
),
pytest.mark.usefixtures("default_vllm_config"),
]
@pytest.fixture
def enable_hipb_mm_kernel(monkeypatch: pytest.MonkeyPatch):
    """Enable the AITER hipb_mm linear kernel for the duration of one test.

    Sets the three opt-in environment variables, refreshes the flags cached
    by ``rocm_aiter_ops``, and restores both the environment and the cached
    flags when the test finishes.
    """
    for flag in (
        "VLLM_ROCM_USE_AITER",
        "VLLM_ROCM_USE_AITER_LINEAR",
        "VLLM_ROCM_USE_AITER_LINEAR_HIPBMM",
    ):
        monkeypatch.setenv(flag, "1")
    rocm_aiter_ops.refresh_env_variables()
    yield
    # This fixture is torn down before ``monkeypatch`` itself, so the
    # patched variables are still set here. Undo them first; otherwise the
    # refresh would simply re-read the enabled values and leak them into
    # later tests.
    monkeypatch.undo()
    rocm_aiter_ops.refresh_env_variables()
def _make_config(
    *,
    weight_quant_key=kFp8StaticChannelSym,
    out_dtype: torch.dtype = torch.bfloat16,
    weight_shape: tuple[int, int] = (512, 4096),
) -> FP8ScaledMMLinearLayerConfig:
    """Build an FP8 scaled-mm layer config for the tests in this module.

    Activations always use the dynamic per-token key and a bfloat16 input
    dtype; the weight quantization key, the output dtype and the weight
    shape can be overridden by keyword.
    """
    settings = {
        "weight_quant_key": weight_quant_key,
        "activation_quant_key": kFp8DynamicTokenSym,
        "weight_shape": weight_shape,
        "input_dtype": torch.bfloat16,
        "out_dtype": out_dtype,
    }
    return FP8ScaledMMLinearLayerConfig(**settings)
def _find_csv_row(path: str, m: int, n: int, k: int) -> dict | None:
if not os.path.exists(path):
return None
with open(path, newline="") as f:
reader = csv.DictReader(f, skipinitialspace=True)
for row in reader:
try:
if (
int(row.get("m", -1)) == m
and int(row.get("n", -1)) == n
and int(row.get("k", -1)) == k
):
return dict(row)
except (TypeError, ValueError):
continue
return None
def _skip_if_no_hipb_mm_solution(exc: RuntimeError) -> None:
    """Skip the current test when ``exc`` reports zero hipBLASLt solutions.

    Any other error message leaves the test running and returns ``None``.
    """
    no_solution_marker = "hipblasLtMatmulAlgoGetHeuristic found 0 valid solutions"
    if no_solution_marker not in str(exc):
        return
    pytest.skip(
        "hipb_mm bpreshuffle path has no valid hipBLASLt solution on "
        "this ROCm stack."
    )
def _check_bpreshuffle_runtime_support(weight_shape: tuple[int, int], num_tokens: int):
    """Probe one hipb_mm call with ``bpreshuffle=True`` on random inputs.

    Skips the calling test when the call fails because hipBLASLt reports no
    valid solution; any other ``RuntimeError`` is re-raised.
    """
    import aiter
    from aiter.ops.shuffle import shuffle_weight

    activations = torch.randn(
        num_tokens, weight_shape[1], dtype=torch.bfloat16, device="cuda"
    )
    weights = torch.randn(weight_shape, dtype=torch.bfloat16, device="cuda")

    aiter.hipb_create_extension()
    x_q, x_scale = aiter.pertoken_quant(
        activations, quant_dtype=current_platform.fp8_dtype()
    )
    w_q, w_scale = aiter.pertoken_quant(
        weights, quant_dtype=current_platform.fp8_dtype()
    )

    try:
        # Operand preparation stays inside the try so its errors are
        # classified the same way as errors from the matmul itself.
        shuffled_weight = shuffle_weight(w_q, layout=(16, 16)).t()
        weight_scale = w_scale.t().contiguous()
        aiter.hipb_mm(
            x_q,
            shuffled_weight,
            solution_index=-1,
            out_dtype=torch.bfloat16,
            scaleA=x_scale,
            scaleB=weight_scale,
            scaleOut=None,
            bpreshuffle=True,
        )
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise
def test_hipb_mm_kernel_requires_hipbmm_flag(monkeypatch: pytest.MonkeyPatch):
    """The kernel reports itself unsupported when the opt-in flags are off."""
    # The kernel rejects when `is_hip_fp8bmm_enabled()` is False. That helper
    # requires AITER + AITER_LINEAR + MI3xx, so dropping AITER_LINEAR exercises
    # the rejection branch.
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0")
    monkeypatch.delenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", raising=False)

    try:
        rocm_aiter_ops.refresh_env_variables()
        is_supported, reason = (
            AiterHipbMMPerTokenFp8ScaledMMLinearKernel.is_supported()
        )
    finally:
        # The refreshed flags are cached on rocm_aiter_ops; restore the
        # environment and refresh again so they do not leak into later tests.
        monkeypatch.undo()
        rocm_aiter_ops.refresh_env_variables()

    assert not is_supported
    assert reason == (
        "requires setting `VLLM_ROCM_USE_AITER=1`, "
        "`VLLM_ROCM_USE_AITER_LINEAR=1`, "
        "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`."
    )
def test_hipb_mm_flag_enables_hip_online_tuning(
    monkeypatch: pytest.MonkeyPatch,
):
    """Reloading the ROCm platform module with all three AITER flags set must
    leave ``HIP_ONLINE_TUNING=1`` in the process environment.

    The environment and the reloaded modules are put back in the ``finally``
    block so later tests see the state they started with.
    """
    import vllm.envs as envs_mod
    import vllm.platforms.rocm as rocm_mod

    # The rocm.py gate requires all three AITER flags (and MI3xx) to auto-set
    # HIP_ONLINE_TUNING.
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1")
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1")
    # Start without HIP_ONLINE_TUNING so the assertion below can only pass if
    # the reload exports it. monkeypatch records a pre-existing value (if any)
    # and puts it back on undo().
    monkeypatch.delenv("HIP_ONLINE_TUNING", raising=False)

    try:
        importlib.reload(envs_mod)
        importlib.reload(rocm_mod)

        assert envs_mod.VLLM_ROCM_USE_AITER
        assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR
        assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
        assert os.environ.get("HIP_ONLINE_TUNING") == "1"
    finally:
        # Remove what the reload wrote *before* undoing monkeypatch: undo()
        # then restores a value that existed before the test instead of that
        # value being popped afterwards.
        os.environ.pop("HIP_ONLINE_TUNING", None)
        monkeypatch.undo()
        importlib.reload(envs_mod)
        importlib.reload(rocm_mod)
        rocm_aiter_ops.refresh_env_variables()
def test_hipb_mm_kernel_can_implement_success(enable_hipb_mm_kernel):
    """The default ``_make_config()`` is accepted and no reason is given."""
    config = _make_config()
    accepted, why_not = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement(
        config
    )

    assert accepted
    assert why_not is None
@pytest.mark.parametrize(
    ("config", "expected_reason"),
    [
        # Weight quantization key that is not per-channel.
        (
            _make_config(weight_quant_key=kFp8StaticTensorSym),
            "requires per token activation scales and per channel weight scales.",
        ),
        # Output dtype other than bfloat16.
        (
            _make_config(out_dtype=torch.float16),
            "requires bfloat16 output dtype.",
        ),
        # N below 16 and K not a multiple of 16.
        (
            _make_config(weight_shape=(8, 4090)),
            "requires N >= 16 and both N and K divisible by 16, "
            "received N=8 and K=4090.",
        ),
    ],
)
def test_hipb_mm_kernel_can_implement_rejects_unsupported_configs(
    enable_hipb_mm_kernel,
    config: FP8ScaledMMLinearLayerConfig,
    expected_reason: str,
):
    """Each unsupported config is rejected with its exact reason string."""
    kernel_cls = AiterHipbMMPerTokenFp8ScaledMMLinearKernel
    accepted, reason = kernel_cls.can_implement(config)

    assert not accepted
    assert reason == expected_reason
def test_hipb_mm_kernel_process_weights_after_loading_shuffles_weights(
    enable_hipb_mm_kernel,
):
    """``process_weights_after_loading`` stores the shuffled weight (as a
    ``.t()`` view) and the transposed-contiguous weight scale on the layer."""
    weight_shape = (512, 4096)
    param_names = ("weight", "weight_scale", "input_scale", "input_scale_ub")
    kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel(
        _make_config(weight_shape=weight_shape),
        layer_param_names=param_names,
    )

    fp8_weight = torch.rand(weight_shape, device="cuda").to(
        current_platform.fp8_dtype()
    )
    channel_scale = torch.rand(
        (weight_shape[0], 1), dtype=torch.float32, device="cuda"
    )

    layer = torch.nn.Module()
    layer.weight = torch.nn.Parameter(fp8_weight.t(), requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(channel_scale, requires_grad=False)
    layer.input_scale = None
    layer.input_scale_ub = None

    # Snapshot the inputs: the kernel replaces the layer's tensors.
    weight_before = layer.weight.detach().clone()
    scale_before = layer.weight_scale.detach().clone()

    kernel.process_weights_after_loading(layer)

    # process_weights_after_loading now pre-applies the transposes that used
    # to live in _rocm_aiter_hipb_mm_fp8_impl, so the stored weight is the
    # shuffled tensor with a trailing `.t()` view, and the stored weight scale
    # is its transposed-contiguous form.
    shuffled = rocm_aiter_ops.shuffle_weight(weight_before.t().contiguous())
    torch.testing.assert_close(layer.weight, shuffled.t())
    torch.testing.assert_close(layer.weight_scale, scale_before.t().contiguous())
def test_hipb_mm_kernel_forward_matches_raw_aiter_hipb_mm(enable_hipb_mm_kernel):
    """The layer's forward output equals a direct ``aiter.hipb_mm`` call made
    with the layer's own processed weight and weight scale."""
    import aiter

    weight_shape = (512, 4096)
    out_features, in_features = weight_shape
    _check_bpreshuffle_runtime_support(weight_shape, num_tokens=32)

    layer = TestFP8Layer(
        weight_shape=weight_shape,
        activation_quant_key=kFp8DynamicTokenSym,
        weight_quant_key=kFp8StaticChannelSym,
        input_dtype=torch.bfloat16,
        out_dtype=torch.bfloat16,
        device=torch.device("cuda"),
        force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
    )

    # hipb_mm uses a transposed-result GEMM internally, so the flattened token
    # count becomes the effective N dimension passed into hipBLASLt. Keep it
    # aligned to avoid heuristic failures for tiny N.
    x = torch.randn(2, 16, in_features, dtype=torch.bfloat16, device="cuda")
    bias = torch.randn(out_features, dtype=torch.bfloat16, device="cuda")

    try:
        out = layer(x, bias)
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise

    # Quantize the flattened activations with the kernel's own quantizer.
    flat_x = x.view(-1, x.shape[-1])
    x_q, x_scale = layer.kernel.quant_fp8(
        flat_x, layer.input_scale, layer.input_scale_ub
    )

    try:
        # process_weights_after_loading already applies the trailing `.t()` on
        # the shuffled weight and the `.t().contiguous()` on the weight scale,
        # so the raw aiter call uses them directly.
        raw_out = aiter.hipb_mm(
            x_q,
            layer.weight,
            solution_index=-1,
            bias=bias,
            out_dtype=torch.bfloat16,
            scaleA=x_scale,
            scaleB=layer.weight_scale,
            scaleOut=None,
            bpreshuffle=True,
        )
        expected = raw_out.view(*out.shape)
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise

    assert isinstance(layer.kernel, AiterHipbMMPerTokenFp8ScaledMMLinearKernel)
    assert out.shape == (2, 16, out_features)
    torch.testing.assert_close(out, expected)
def test_hipb_mm_kernel_forward_accuracy(enable_hipb_mm_kernel):
    """Kernel output should match a dequantized fp32 reference within
    fp8 per-token / per-channel quantization noise."""
    weight_shape = (512, 4096)  # (N, K)
    num_tokens = 32
    _check_bpreshuffle_runtime_support(weight_shape, num_tokens=num_tokens)

    fp8_dtype = current_platform.fp8_dtype()
    fp8_max = torch.finfo(fp8_dtype).max
    device = torch.device("cuda")

    def _quantize_rowwise(tensor):
        """Return ``(fp8 tensor, fp32 scale)`` with one scale per row."""
        row_amax = tensor.abs().amax(dim=1, keepdim=True).to(torch.float32)
        scale = (row_amax / fp8_max).clamp(min=1e-12)
        scaled = tensor.to(torch.float32) / scale
        return scaled.clamp(-fp8_max, fp8_max).to(fp8_dtype), scale

    # Build a bf16 weight and quantize per output channel (one scale per row).
    w_bf16 = torch.randn(weight_shape, dtype=torch.bfloat16, device=device)
    w_fp8, w_scale = _quantize_rowwise(w_bf16)
    w_dequant = w_fp8.to(torch.float32) * w_scale

    bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device=device)

    layer = torch.nn.Module()
    # Pre-`process_weights_after_loading` convention: weight stored as the
    # `[K, N]` view of the fp8 tensor.
    layer.weight = torch.nn.Parameter(w_fp8.t(), requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(w_scale, requires_grad=False)
    layer.input_scale = None
    layer.input_scale_ub = None

    kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel(
        _make_config(weight_shape=weight_shape),
        layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"),
    )
    kernel.process_weights_after_loading(layer)

    x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device=device)
    try:
        out = kernel.apply_weights(layer, x, bias)
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise

    # Reference: quantize x per-token the same way the kernel does, then run
    # the matmul in fp32 against the dequantized weight. This isolates plumbing
    # / reduction bugs from inherent fp8 quantization noise.
    x_q, x_scale_ref = _quantize_rowwise(x)
    x_dequant = x_q.to(torch.float32) * x_scale_ref
    reference_fp32 = x_dequant @ w_dequant.t() + bias.to(torch.float32)
    expected = reference_fp32.to(torch.bfloat16)

    assert out.shape == (num_tokens, weight_shape[0])
    # K=4096 fp8 reduction leaves room for accumulation order drift and
    # catastrophic cancellation on near-zero outputs; tolerances are loose
    # enough to absorb that but tight enough to catch wrong layouts, missing
    # bias, swapped scales, etc.
    torch.testing.assert_close(out, expected, atol=5.0, rtol=0.1)
def test_hipb_mm_kernel_online_tuning_writes_csv(
    enable_hipb_mm_kernel,
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
):
    """With ``HIP_ONLINE_TUNING=1`` a forward pass leaves
    ``hip_online_tuning_res.csv`` in the working directory, containing a row
    for the GEMM that was run."""
    weight_shape = (256, 4096)
    out_features, in_features = weight_shape
    num_tokens = 16
    cache_file = tmp_path / "hip_online_tuning_res.csv"

    _check_bpreshuffle_runtime_support(weight_shape, num_tokens=16)

    # Run from tmp_path; the CSV is looked up there after the forward pass.
    monkeypatch.setenv("HIP_ONLINE_TUNING", "1")
    monkeypatch.chdir(tmp_path)

    layer = TestFP8Layer(
        weight_shape=weight_shape,
        activation_quant_key=kFp8DynamicTokenSym,
        weight_quant_key=kFp8StaticChannelSym,
        input_dtype=torch.bfloat16,
        out_dtype=torch.bfloat16,
        device=torch.device("cuda"),
        force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
    )

    # The effective heuristic N dimension is the flattened token count.
    x = torch.randn(num_tokens, in_features, dtype=torch.bfloat16, device="cuda")
    try:
        out = layer(x)
    except RuntimeError as exc:
        _skip_if_no_hipb_mm_solution(exc)
        raise
    torch.accelerator.synchronize()

    assert out.shape == (16, out_features)
    assert cache_file.exists()

    # hipb_mm records the internal GEMM dimensions used by hipBLASLt after its
    # transposed-result transformation.
    tuned_row = _find_csv_row(
        str(cache_file),
        m=out_features,
        n=x.shape[0],
        k=in_features,
    )
    assert tuned_row is not None
+2 -2
View File
@@ -797,11 +797,11 @@ class TestMistralTokenizer:
True,
(
[1, 3, 23325, 2294, 1686, 4, 23325],
[1, 3, 22177, 4304, 2662, 4, 22177, 2],
[1, 3, 22177, 4304, 2662, 4, 22177],
),
(
"<s>[INST]▁Hello▁world▁![/INST]▁Hello",
("<s>[INST]Hello world ![/INST]Hello</s>"),
"<s>[INST]Hello world ![/INST]Hello",
),
),
],
@@ -49,11 +49,13 @@ else
KV_EXTRA_CONFIG=''
fi
# Build the kv-transfer-config once
# Build the kv-transfer-config for P and D
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
KV_CONFIG_P='{"kv_connector":"NixlConnector","kv_role":"kv_producer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
KV_CONFIG_D='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
KV_CONFIG_P="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
KV_CONFIG_D="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
fi
# Models to run
@@ -159,7 +161,7 @@ run_tests_for_model() {
--block-size ${PREFILL_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
--kv-transfer-config '$KV_CONFIG_P'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
@@ -207,7 +209,7 @@ run_tests_for_model() {
--enforce-eager \
--block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'"
--kv-transfer-config '$KV_CONFIG_D'"
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
for arg in "${extra_args[@]}"; do
@@ -193,9 +193,11 @@ run_test_for_device() {
local kv_device=$1
if [[ "$kv_device" == "cuda" ]]; then
local kv_config='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
local kv_config_p='{"kv_connector":"NixlConnector","kv_role":"kv_producer"}'
local kv_config_d='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}'
else
local kv_config="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${kv_device}\"}"
local kv_config_p="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"${kv_device}\"}"
local kv_config_d="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"${kv_device}\"}"
fi
echo ""
@@ -248,7 +250,7 @@ run_test_for_device() {
--block-size ${BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config "$kv_config" \
--kv-transfer-config "$kv_config_p" \
--speculative-config "$PREFILL_SPEC_CONFIG" \
--attention-backend $ATTENTION_BACKEND \
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
@@ -287,7 +289,7 @@ run_test_for_device() {
--block-size ${BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $DECODER_TP_SIZE \
--kv-transfer-config "$kv_config" \
--kv-transfer-config "$kv_config_d" \
--speculative-config "$DECODE_SPEC_CONFIG" \
--attention-backend $ATTENTION_BACKEND \
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
@@ -0,0 +1,232 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
from typing import Any
import pytest
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorHandshakeMetadata,
)
from vllm.v1.engine import core as engine_core_module
# Apply the ``cpu_test`` marker to every test in this module.
pytestmark = pytest.mark.cpu_test
class _Metadata(KVConnectorHandshakeMetadata):
    """Empty handshake-metadata subclass used as a distinguishable payload."""
class _FakeExecutor:
    """Executor stand-in that hands back canned handshake metadata.

    ``handshake_metadata_src`` is assigned on the class before it is
    instantiated. Each new instance copies that value, counts how often the
    metadata is requested, and records itself in ``last_instance`` so the test
    can inspect it afterwards.
    """

    # Metadata the next instance will serve; set by the test helper.
    handshake_metadata_src: (
        list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None
    )
    # Most recently constructed instance, or None if none was built yet.
    last_instance: "_FakeExecutor | None" = None

    def __init__(
        self,
        vllm_config: Any,
    ) -> None:
        # The config argument is accepted but not used.
        del vllm_config
        self.handshake_calls = 0
        self.handshake_metadata = self.handshake_metadata_src
        _FakeExecutor.last_instance = self

    def get_kv_connector_handshake_metadata(
        self,
    ) -> list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None:
        """Return the canned metadata and bump ``handshake_calls``."""
        self.handshake_calls += 1
        return self.handshake_metadata

    def init_kv_output_aggregator(self, connector: KVConnectorBase_V1) -> None:
        """No-op hook."""
        return None
def _run_engine_core_handshake(
    monkeypatch: pytest.MonkeyPatch,
    connector: KVConnectorBase_V1,
    *,
    handshake_metadata: (
        list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None
    ),
) -> _FakeExecutor:
    """Construct an ``EngineCore`` around fakes and return the executor it built.

    The scheduler class handed to the engine always returns ``connector``, the
    executor class serves ``handshake_metadata``, and the engine-core module
    attributes listed below are replaced with trivial stubs. The returned
    object is the ``_FakeExecutor`` instance recorded during construction.
    """

    class _FakeScheduler:
        def __init__(self, **kwargs: Any) -> None:
            self.connector = connector

        def get_kv_connector(self) -> KVConnectorBase_V1:
            return connector

    # Prime the fake executor class for this run.
    _FakeExecutor.handshake_metadata_src = handshake_metadata
    _FakeExecutor.last_instance = None

    monkeypatch.setattr("vllm.plugins.load_general_plugins", lambda: None)
    monkeypatch.setattr(
        engine_core_module.EngineCore,
        "_initialize_kv_caches",
        lambda self, vllm_config: SimpleNamespace(kv_cache_groups=[object()]),
    )

    # Module-level names in engine core that are swapped for stubs.
    module_stubs = {
        "StructuredOutputManager": lambda vllm_config: object(),
        "resolve_kv_cache_block_sizes": lambda kv_cache_config, vllm_config: (16, 16),
        "MULTIMODAL_REGISTRY": SimpleNamespace(
            engine_receiver_cache_from_config=lambda vllm_config: None
        ),
        "freeze_gc_heap": lambda: None,
        "maybe_attach_gc_debug_callback": lambda: None,
        "enable_envs_cache": lambda: None,
        "get_hash_fn_by_name": lambda name: None,
        "init_none_hash": lambda hash_fn: None,
        "get_request_block_hasher": lambda *args: None,
    }
    for attr_name, stub in module_stubs.items():
        monkeypatch.setattr(engine_core_module, attr_name, stub)

    scheduler_config = SimpleNamespace(
        get_scheduler_cls=lambda: _FakeScheduler,
        enable_chunked_prefill=False,
        async_scheduling=False,
    )
    cache_config = SimpleNamespace(
        enable_prefix_caching=False,
        prefix_caching_hash_algo="builtin",
    )
    vllm_config = SimpleNamespace(
        parallel_config=SimpleNamespace(data_parallel_rank_local=0),
        scheduler_config=scheduler_config,
        speculative_config=None,
        ec_transfer_config=None,
        max_concurrent_batches=1,
        model_config=SimpleNamespace(runner_type="generate"),
        cache_config=cache_config,
    )

    engine_core_module.EngineCore(vllm_config, _FakeExecutor, log_stats=False)

    built_executor = _FakeExecutor.last_instance
    assert built_executor is not None
    return built_executor
class _LegacyConnector(KVConnectorBase_V1):
    """Connector that only records metadata passed to the legacy handshake hook.

    ``set_xfer_handshake_metadata`` stores its argument in ``legacy_metadata``.
    The remaining connector methods are inert, except ``build_connector_meta``,
    which raises ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # Holds whatever the legacy hook last received.
        self.legacy_metadata: dict[int, KVConnectorHandshakeMetadata] | None = None

    def start_load_kv(self, forward_context: Any, **kwargs: Any) -> None:
        return None

    def wait_for_layer_load(self, layer_name: str) -> None:
        return None

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: Any,
        attn_metadata: Any,
        **kwargs: Any,
    ) -> None:
        return None

    def wait_for_save(self) -> None:
        return None

    def get_num_new_matched_tokens(
        self, request: Any, num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        return (0, False)

    def update_state_after_alloc(
        self, request: Any, blocks: Any, num_external_tokens: int
    ) -> None:
        return None

    def build_connector_meta(self, scheduler_output: Any) -> Any:
        raise NotImplementedError

    def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]
    ) -> None:
        """Record the metadata handed over by the engine."""
        self.legacy_metadata = metadata
class _PPAwareConnector(_LegacyConnector):
    """Legacy connector extended with the PP-aware handshake hook.

    ``set_xfer_handshake_metadata_pp_aware`` stores its argument in
    ``pp_aware_metadata``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Holds whatever the PP-aware hook last received.
        self.pp_aware_metadata: (
            dict[tuple[int, int], KVConnectorHandshakeMetadata] | None
        ) = None

    def set_xfer_handshake_metadata_pp_aware(
        self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata]
    ) -> None:
        """Record the ``(pp_rank, tp_rank)``-keyed metadata."""
        self.pp_aware_metadata = metadata
def test_engine_unwraps_handshake_metadata_for_legacy_connector(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Engine core always asks workers for `(pp_rank, tp_rank)`-keyed metadata,
    then unwraps to `{tp_rank: metadata}` for a connector that has not opted
    into PP-aware handshake (single-PP producer, all `pp_rank == 0`)."""
    meta_tp0, meta_tp1 = _Metadata(), _Metadata()
    connector = _LegacyConnector()

    # One entry per worker; the middle worker reports nothing.
    worker_replies = [
        {(0, 0): meta_tp0},
        None,
        {(0, 1): meta_tp1},
    ]
    executor = _run_engine_core_handshake(
        monkeypatch,
        connector,
        handshake_metadata=worker_replies,
    )

    assert executor.handshake_calls == 1
    assert connector.legacy_metadata == {0: meta_tp0, 1: meta_tp1}
def test_engine_rejects_pp_producer_for_legacy_connector(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A connector that has not opted into PP-aware handshake must not silently
    drop metadata from `pp_rank > 0`; engine core init raises instead."""
    connector = _LegacyConnector()
    # Second worker reports pp_rank == 1, which the legacy hook cannot express.
    worker_replies = [{(0, 0): _Metadata()}, {(1, 0): _Metadata()}]

    with pytest.raises(ValueError, match="does not support PP-disaggregated"):
        _run_engine_core_handshake(
            monkeypatch,
            connector,
            handshake_metadata=worker_replies,
        )
def test_engine_passes_handshake_metadata_through_for_pp_aware_connector(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A PP-aware connector receives the full `(pp_rank, tp_rank)`-keyed dict
    unchanged."""
    meta_pp0, meta_pp1 = _Metadata(), _Metadata()
    connector = _PPAwareConnector()

    executor = _run_engine_core_handshake(
        monkeypatch,
        connector,
        handshake_metadata=[{(0, 0): meta_pp0}, {(1, 0): meta_pp1}],
    )

    assert executor.handshake_calls == 1
    # Only the PP-aware hook may have been used.
    assert connector.legacy_metadata is None
    expected = {(0, 0): meta_pp0, (1, 0): meta_pp1}
    assert connector.pp_aware_metadata == expected
@@ -261,11 +261,12 @@ def test_multi_example_connector_consistency():
storage1_scheduler_events = _ignore_event_collection(events["storage1-SCHEDULER"])
storage2_scheduler_events = _ignore_event_collection(events["storage2-SCHEDULER"])
# First event is bind_gpu_block_pool from initialization, then
# set_xfer_handshake_metadata, then on_new_request when the request is enqueued,
# then get_num_new_matched_tokens and update_state_after_alloc from generate().
# set_xfer_handshake_metadata_pp_aware, then on_new_request when the request is
# enqueued, then get_num_new_matched_tokens and update_state_after_alloc from
# generate().
assert storage1_scheduler_events[:6] == [
"bind_gpu_block_pool",
"set_xfer_handshake_metadata",
"set_xfer_handshake_metadata_pp_aware",
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
@@ -285,7 +286,7 @@ def test_multi_example_connector_consistency():
]
assert storage2_scheduler_events[:6] == [
"bind_gpu_block_pool",
"set_xfer_handshake_metadata",
"set_xfer_handshake_metadata_pp_aware",
"on_new_request",
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
@@ -1057,7 +1058,7 @@ def test_multi_connector_mixed_hma_disables_hybrid_kv_cache(monkeypatch):
"connectors": [
{
"kv_connector": "NixlConnector",
"kv_role": "kv_both",
"kv_role": "kv_consumer",
},
{
"kv_connector": "MockConnector",
@@ -1363,7 +1363,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
timeout = 6
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
kv_role="kv_consumer",
kv_connector_extra_config={"kv_lease_duration": timeout},
)
llm_kwargs = {
@@ -2737,3 +2737,50 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
f"got {notif!r} (expected {expected_notif!r}, "
f"buggy form would be {bad_notif!r})"
)
def test_kv_both_deprecation_warning(default_vllm_config, dist_init):
    """kv_role='kv_both' should emit a deprecation log warning."""
    from unittest.mock import patch

    from vllm.logger import _print_warning_once

    # NOTE(review): the connector's logger is replaced by a mock below, so this
    # cache reset may be redundant; it is kept from the original test.
    _print_warning_once.cache_clear()

    vllm_config = create_vllm_config(kv_role="kv_both")
    with patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger"
    ) as mock_logger:
        # `mock_logger.warning_once` is auto-created by the mock on first
        # access, so no extra wiring is needed before building the connector.
        NixlConnector(
            vllm_config,
            KVConnectorRole.WORKER,
            make_kv_cache_config(block_size=16),
        )
        mock_logger.warning_once.assert_called_once()
        # First positional argument of the single recorded call.
        msg = mock_logger.warning_once.call_args[0][0]
        assert "kv_role='kv_both'" in msg
        assert "deprecated" in msg
def test_explicit_kv_role_no_deprecation_warning(default_vllm_config, dist_init):
    """kv_role='kv_consumer' or 'kv_producer' should NOT emit a warning."""
    from unittest.mock import patch

    for role in ("kv_consumer", "kv_producer"):
        vllm_config = create_vllm_config(kv_role=role)
        with patch(
            "vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger"
        ) as mock_logger:
            NixlConnector(
                vllm_config,
                KVConnectorRole.WORKER,
                make_kv_cache_config(block_size=16),
            )
        # A plain assert is used so the failure output names the offending
        # role; `assert_not_called()` takes no message argument.
        assert not mock_logger.warning_once.called, (
            f"kv_role={role!r} should not emit deprecation warning"
        )
@@ -453,7 +453,7 @@ def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
"""
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
kv_role="kv_consumer",
)
block_size = 16
llm_kwargs = {
@@ -75,7 +75,7 @@ def test_gpu_memory_rixl_hma(model_name, sw_size):
"gpu_memory_utilization": 0.5,
"kv_transfer_config": KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
kv_role="kv_consumer",
),
"max_model_len": 2048,
"disable_hybrid_kv_cache_manager": False,
@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineTransferInfo,
TransferTopology,
)
# Apply the ``cpu_test`` marker to every test in this module.
pytestmark = pytest.mark.cpu_test
class _FakeAttentionBackend:
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[int, int, int, int, int]:
return (2, num_blocks, num_kv_heads, block_size, head_size)
def _make_topology(
    *,
    tp_rank: int = 1,
    tp_size: int = 4,
    total_num_kv_heads: int = 8,
) -> TransferTopology:
    """Build a ``TransferTopology`` for the engine named ``"local-engine"``.

    Only the tensor-parallel rank/size and the total KV-head count vary per
    test; block size, MLA/Mamba flags and the attention backend are fixed.
    """
    backends = [_FakeAttentionBackend]
    topology = TransferTopology(
        engine_id="local-engine",
        tp_rank=tp_rank,
        tp_size=tp_size,
        total_num_kv_heads=total_num_kv_heads,
        block_size=16,
        is_mla=False,
        is_mamba=False,
        attn_backends=backends,
    )
    return topology
def test_legacy_register_remote_engine_uses_pp_rank_zero() -> None:
    """Info registered without a ``remote_pp_rank`` is filed under rank 0."""
    engine_id = "remote-engine"
    topology = _make_topology()
    info = EngineTransferInfo(
        remote_tp_size=2,
        remote_block_len=1024,
        remote_block_size=16,
        remote_physical_blocks_per_logical=1,
    )

    registered = topology.register_remote_engine(engine_id, info)

    assert registered == info
    assert registered.remote_pp_rank == 0
    # Reachable through the public getter and the internal (engine, pp) key.
    assert topology.get_engine_info(engine_id) == info
    assert topology._engines[(engine_id, 0)] == info
    assert topology.target_remote_ranks(engine_id) == [0]
def test_register_remote_engine_stores_pp_ranks_separately() -> None:
    """Two infos for one engine id with different PP ranks are both kept."""
    engine_id = "remote-engine"
    topology = _make_topology(tp_rank=0, tp_size=2)

    stage0 = EngineTransferInfo(
        remote_tp_size=2,
        remote_block_len=1024,
        remote_block_size=16,
        remote_physical_blocks_per_logical=1,
        remote_pp_rank=0,
        start_layer=0,
        end_layer=16,
    )
    stage1 = EngineTransferInfo(
        remote_tp_size=1,
        remote_block_len=512,
        remote_block_size=8,
        remote_physical_blocks_per_logical=2,
        remote_pp_rank=1,
        start_layer=16,
        end_layer=32,
    )

    registered0 = topology.register_remote_engine(engine_id, stage0)
    registered1 = topology.register_remote_engine(engine_id, stage1)

    assert registered0 == stage0
    assert registered1 == stage1
    # Omitting the PP rank returns the rank-0 entry.
    assert topology.get_engine_info(engine_id) == stage0
    assert topology.get_engine_info(engine_id, 0) == stage0
    assert topology.get_engine_info(engine_id, 1) == stage1
    expected_keys = {(engine_id, 0), (engine_id, 1)}
    assert set(topology._engines) == expected_keys
def test_helpers_use_requested_pp_rank() -> None:
    """Topology helpers answer per PP rank, not per engine id."""
    engine_id = "remote-engine"
    topology = _make_topology(tp_rank=1, tp_size=2, total_num_kv_heads=2)

    # (remote_pp_rank, remote_tp_size, start_layer, end_layer)
    stages = (
        (0, 1, 0, 8),
        (1, 4, 8, 16),
    )
    for pp_rank, remote_tp, first_layer, last_layer in stages:
        topology.register_remote_engine(
            engine_id,
            EngineTransferInfo(
                remote_tp_size=remote_tp,
                remote_block_len=1024,
                remote_block_size=16,
                remote_physical_blocks_per_logical=1,
                remote_pp_rank=pp_rank,
                start_layer=first_layer,
                end_layer=last_layer,
            ),
        )

    assert not topology.is_kv_replicated(engine_id, 0)
    assert topology.is_kv_replicated(engine_id, 1)
    assert topology.replicates_kv_cache(engine_id, 1)
    assert topology.target_remote_ranks(engine_id, 0) == [0]
    assert topology.target_remote_ranks(engine_id, 1) == [2, 3]
    assert "remote_pp=1" in topology.describe(engine_id, 1)
def test_engine_info_fields_have_backward_compatible_defaults() -> None:
    """``remote_pp_rank``, ``start_layer`` and ``end_layer`` default to 0."""
    engine_id = "remote-engine"
    topology = _make_topology()

    # Only the four fields without defaults are supplied.
    minimal_info = EngineTransferInfo(
        remote_tp_size=2,
        remote_block_len=1024,
        remote_block_size=16,
        remote_physical_blocks_per_logical=1,
    )
    registered = topology.register_remote_engine(engine_id, minimal_info)

    assert topology.get_engine_info(engine_id) == registered
    assert registered.remote_pp_rank == 0
    assert registered.start_layer == 0
    assert registered.end_layer == 0
+1 -1
View File
@@ -103,7 +103,7 @@ def create_vllm_config(
kv_load_failure_policy: Literal["recompute", "fail"] = "fail",
kv_connector: str = "NixlConnector",
kv_connector_module_path: str | None = None,
kv_role: str = "kv_both",
kv_role: str = "kv_consumer",
disable_hybrid_kv_cache_manager: bool | None = None,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
@@ -0,0 +1,365 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Mock-based unit tests for ObjectStoreSecondaryTierManager.
These tests replace the NIXL backend with an in-memory mock so they run
without S3 credentials or a live object store. They verify the manager's
state machine: job submission, transfer completion polling, and lookup.
"""
import uuid
from collections.abc import Callable
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import numpy as np
import torch
from vllm.v1.kv_offload.base import OffloadKey, ReqContext, make_offload_key
from vllm.v1.kv_offload.tiering.base import JobMetadata, JobResult
from vllm.v1.kv_offload.tiering.obj.manager import ObjectStoreSecondaryTierManager
# ---------------------------------------------------------------------------
# Shared stubs
# ---------------------------------------------------------------------------
def _make_vllm_config():
return SimpleNamespace(
model_config=SimpleNamespace(model="test/model"),
cache_config=SimpleNamespace(block_size=16, cache_dtype="float16"),
parallel_config=SimpleNamespace(
tensor_parallel_size=1,
pipeline_parallel_size=1,
prefill_context_parallel_size=1,
decode_context_parallel_size=1,
rank=0,
),
)
# Offloading-spec stand-in passed to the manager in _make_tier: the stub vLLM
# config plus a KV-cache config with no cache groups.
_OFFLOADING_SPEC = SimpleNamespace(
    vllm_config=_make_vllm_config(),
    kv_cache_config=SimpleNamespace(kv_cache_groups=[]),
)
# Placeholder object-store settings passed as `store_config` in _make_tier.
_STORE_CONFIG = {
    "bucket": "mock-bucket",
    "endpoint_override": "mock:9000",
    "access_key": "mock-access",
    "secret_key": "mock-secret",
}
# Row width and dtype of the zero tensor that _make_tier exposes as the
# primary KV view.
_BLOCK_ELEMENTS = 256
_DTYPE = torch.float32
# Per-process random prefix handed to the manager as `prefix`.
_RUN_PREFIX = f"test/{uuid.uuid4().hex[:8]}"
# Request context shared by every job and lookup in this module.
_CTX = ReqContext(req_id="test-req")
def key(n: int) -> OffloadKey:
    """Build an offload key from ``n`` encoded as 8 big-endian bytes.

    The second argument to ``make_offload_key`` is always 0.
    """
    encoded = n.to_bytes(8, "big")
    return make_offload_key(encoded, 0)
def make_job(
    job_id: int,
    keys: list[OffloadKey],
    block_ids: list[int] | None = None,
) -> JobMetadata:
    """Create a non-promotion ``JobMetadata`` for ``keys``.

    When ``block_ids`` is omitted, the blocks ``0 .. len(keys) - 1`` are used.
    The ids are stored as an ``int64`` NumPy array.
    """
    chosen_ids = list(range(len(keys))) if block_ids is None else block_ids
    id_array = np.array(chosen_ids, dtype=np.int64)
    return JobMetadata(
        job_id=job_id,
        keys=keys,
        block_ids=id_array,
        is_promotion=False,
        req_context=_CTX,
    )
# ---------------------------------------------------------------------------
# Mock NIXL agent
# ---------------------------------------------------------------------------
class MockNixlAgent:
    """In-memory replacement for the NIXL agent.

    Object keys seen in the most recent ``OBJ`` registration are attached to
    the next prepared transfer. ``transfer()`` always answers ``"PROC"``;
    polling with ``check_xfer_state()`` answers ``"DONE"`` and, for a
    ``"WRITE"`` transfer, adds its keys to the stored-key set that
    ``query_memory()`` consults.

    ``register_memory``, ``make_prepped_xfer``, ``check_xfer_state`` and
    ``query_memory`` are plain instance attributes (declared ``Callable``) so
    that tests can replace them on a single agent without mypy complaints.
    """

    # Callable attributes: tests may reassign these on instances.
    register_memory: Callable
    make_prepped_xfer: Callable
    check_xfer_state: Callable
    query_memory: Callable

    def __init__(self):
        # Default implementations of the overridable entry points.
        self.register_memory = self._register_memory
        self.make_prepped_xfer = self._make_prepped_xfer
        self.check_xfer_state = self._check_xfer_state
        self.query_memory = self._query_memory

        # Keys whose WRITE transfer has been polled to completion.
        self._stored_obj_keys: set[str] = set()
        # handle_id -> (op, [obj_keys])
        self._pending: dict[int, tuple[str, list[str]]] = {}
        self._handle_counter = 0
        # Keys captured from the latest OBJ registration.
        self._last_obj_keys: list[str] = []

    def create_backend(self, backend_type, params):
        return None

    def _register_memory(self, descs, mem_type=None, backends=None):
        registration = MagicMock()
        registration.trim.return_value = MagicMock()
        # Capture obj_keys from OBJ 4-tuples: (addr, len, dev_id, obj_key)
        if mem_type == "OBJ" and descs:
            self._last_obj_keys = [desc[3] for desc in descs if desc[3]]
        return registration

    def deregister_memory(self, desc):
        return None

    def prep_xfer_dlist(self, agent_name, descs, mem_type=None, backends=None):
        return MagicMock()

    def _make_prepped_xfer(
        self,
        op,
        local_handle,
        local_indices,
        remote_handle,
        remote_indices,
        notif_msg=b"",
        backends=None,
        skip_desc_merge=False,
    ):
        handle_id = self._handle_counter
        self._handle_counter = handle_id + 1
        # Snapshot the keys so later registrations do not alter this transfer.
        self._pending[handle_id] = (op, list(self._last_obj_keys))

        handle = MagicMock()
        handle._id = handle_id
        return handle

    def transfer(self, handle):
        return "PROC"

    def _check_xfer_state(self, handle):
        pending = self._pending.pop(handle._id, None)
        if pending is not None:
            op, obj_keys = pending
            if op == "WRITE":
                # The write is considered committed on its first poll.
                self._stored_obj_keys.update(obj_keys)
        return "DONE"

    def release_xfer_handle(self, handle):
        return None

    def release_dlist_handle(self, handle):
        return None

    def _query_memory(self, queries, mem_type, agent_name):
        # One entry per query: a placeholder object if the key (index 3 of the
        # query tuple) is stored, otherwise None.
        answers = []
        for query in queries:
            found = query[3] in self._stored_obj_keys
            answers.append(object() if found else None)
        return answers
# ---------------------------------------------------------------------------
# Fixture
# ---------------------------------------------------------------------------
def _make_tier(
    num_blocks: int = 4,
) -> tuple[ObjectStoreSecondaryTierManager, MockNixlAgent]:
    """Build an object-store tier whose NIXL agent is a new ``MockNixlAgent``.

    The manager module's ``nixl_agent_config`` and ``nixl_agent`` are patched
    only while the manager is constructed. Returns ``(tier, agent)``.
    """
    agent = MockNixlAgent()
    # Zero-filled primary buffer: one row of _BLOCK_ELEMENTS values per block.
    kv_tensor = torch.zeros((num_blocks, _BLOCK_ELEMENTS), dtype=_DTYPE)
    kv_view = memoryview(kv_tensor.numpy())

    with patch("vllm.v1.kv_offload.tiering.obj.manager.nixl_agent_config"):
        with patch(
            "vllm.v1.kv_offload.tiering.obj.manager.nixl_agent",
            return_value=agent,
        ):
            tier = ObjectStoreSecondaryTierManager(
                offloading_spec=_OFFLOADING_SPEC,
                primary_kv_view=kv_view,
                tier_type="obj",
                store_config=_STORE_CONFIG,
                prefix=_RUN_PREFIX,
            )
    return tier, agent
def drain(
    tier: ObjectStoreSecondaryTierManager, max_rounds: int = 20
) -> list[JobResult]:
    """Collect finished jobs until the tier has no transfers left.

    ``get_finished_jobs()`` is polled at most ``max_rounds`` times; whatever
    was gathered by then is returned, even if transfers are still pending.
    """
    finished: list[JobResult] = []
    for _round in range(max_rounds):
        finished += tier.get_finished_jobs()
        if not tier._transfers:
            return finished
    return finished
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestMockObjTierBasic:
    """Store / load / lookup behaviour on a four-block tier."""

    def setup_method(self):
        # Every test gets its own tier and mock agent.
        self.tier, self.agent = _make_tier(num_blocks=4)

    def test_lookup_empty_tier(self):
        """Nothing stored yet, so lookup is False."""
        found = self.tier.lookup(key(1), _CTX)
        assert found is False

    def test_store_and_lookup(self):
        """A completed store makes its key visible to lookup."""
        self.tier.submit_store(make_job(1, [key(1)], [0]))
        finished = drain(self.tier)

        assert len(finished) == 1
        assert finished[0].success
        assert self.tier.lookup(key(1), _CTX) is True

    def test_lookup_unrelated_key_returns_false(self):
        """Storing one key does not make a different key visible."""
        self.tier.submit_store(make_job(1, [key(1)], [0]))
        drain(self.tier)

        assert self.tier.lookup(key(999), _CTX) is False

    def test_store_then_load_roundtrip(self):
        """A load of previously stored keys completes successfully."""
        stored_keys = [key(1), key(2)]
        self.tier.submit_store(make_job(1, stored_keys, [0, 1]))
        store_results = drain(self.tier)
        assert store_results[0].success

        self.tier.submit_load(make_job(2, [key(1), key(2)], [0, 1]))
        load_results = drain(self.tier)
        assert len(load_results) == 1
        assert load_results[0].success

    def test_multiple_jobs_tracked_independently(self):
        """Two submitted stores yield two successful results."""
        self.tier.submit_store(make_job(1, [key(1)], [0]))
        self.tier.submit_store(make_job(2, [key(2)], [1]))
        finished = drain(self.tier)

        assert len(finished) == 2
        assert all(result.success for result in finished)

    def test_failed_transfer_reported(self):
        """An ``"ERR"`` transfer state surfaces as an unsuccessful result."""
        self.agent.check_xfer_state = lambda h: "ERR"
        self.tier.submit_store(make_job(1, [key(1)], [0]))
        finished = drain(self.tier)

        assert len(finished) == 1
        assert not finished[0].success

    def test_pending_transfer_not_returned_until_done(self):
        """A job is reported only once its transfer stops being ``"PROC"``."""
        # First poll returns PROC; second poll returns DONE.
        real_check = self.agent.check_xfer_state
        polls = 0

        def proc_then_real(h):
            nonlocal polls
            polls += 1
            if polls == 1:
                return "PROC"
            return real_check(h)

        self.agent.check_xfer_state = proc_then_real
        self.tier.submit_store(make_job(1, [key(1)], [0]))

        assert list(self.tier.get_finished_jobs()) == []
        finished = list(self.tier.get_finished_jobs())
        assert len(finished) == 1
        assert finished[0].success
class TestMockObjTierMultiBlock:
    """Jobs that span several blocks."""

    def test_store_multiple_blocks(self):
        """One job storing eight blocks yields one result and eight visible keys."""
        tier, _agent = _make_tier(num_blocks=8)
        block_ids = list(range(8))
        keys = [key(i) for i in range(8)]
        tier.submit_store(make_job(1, keys, block_ids))
        finished = drain(tier)
        assert len(finished) == 1
        assert finished[0].success
        assert all(tier.lookup(k, _CTX) for k in keys)

    def test_partial_block_lookup(self):
        """Only the keys that were stored are found by lookup."""
        tier, _agent = _make_tier(num_blocks=4)
        tier.submit_store(make_job(1, [key(0), key(1)], [0, 1]))
        drain(tier)
        # Keys 0 and 1 were stored; key 2 was not.
        assert tier.lookup(key(0), _CTX) is True
        assert tier.lookup(key(1), _CTX) is True
        assert tier.lookup(key(2), _CTX) is False
class TestMockObjTierFailures:
    """Failure paths: agent calls that raise or return ``None``."""

    def test_lookup_exception_returns_false(self):
        """A raising ``query_memory`` is reported as a lookup miss."""
        tier, agent = _make_tier(num_blocks=4)

        def raise_backend_error(*a, **k):
            raise RuntimeError("backend error")

        agent.query_memory = raise_backend_error
        assert tier.lookup(key(1), _CTX) is False

    def test_submit_store_register_memory_failure_reported_in_get_finished(self):
        """``register_memory`` returning None fails the store job."""
        tier, agent = _make_tier(num_blocks=4)
        agent.register_memory = lambda *a, **k: None
        tier.submit_store(make_job(1, [key(1)], [0]))
        finished = list(tier.get_finished_jobs())
        assert len(finished) == 1
        failed = finished[0]
        assert failed.job_id == 1
        assert not failed.success

    def test_submit_load_register_memory_failure_reported_in_get_finished(self):
        """``register_memory`` returning None fails the load job."""
        tier, agent = _make_tier(num_blocks=4)
        agent.register_memory = lambda *a, **k: None
        tier.submit_load(make_job(2, [key(1)], [0]))
        finished = list(tier.get_finished_jobs())
        assert len(finished) == 1
        failed = finished[0]
        assert failed.job_id == 2
        assert not failed.success

    def test_submit_store_make_prepped_xfer_failure_reported_in_get_finished(self):
        """``make_prepped_xfer`` returning None fails the store job."""
        tier, agent = _make_tier(num_blocks=4)
        agent.make_prepped_xfer = lambda *a, **k: None
        tier.submit_store(make_job(3, [key(1)], [0]))
        finished = list(tier.get_finished_jobs())
        assert len(finished) == 1
        failed = finished[0]
        assert failed.job_id == 3
        assert not failed.success

    def test_failure_and_success_both_returned_by_get_finished(self):
        """A submission-time failure and an in-flight success are both reported."""
        tier, agent = _make_tier(num_blocks=4)
        real_register = agent.register_memory
        calls = 0

        def fail_first_call_only(*a, **k):
            # Only the very first registration returns None.
            nonlocal calls
            calls += 1
            if calls == 1:
                return None
            return real_register(*a, **k)

        agent.register_memory = fail_first_call_only
        tier.submit_store(make_job(1, [key(1)], [0]))  # fails immediately
        tier.submit_store(make_job(2, [key(2)], [1]))  # succeeds
        finished = drain(tier)
        assert len(finished) == 2
        results_by_job = {result.job_id: result for result in finished}
        assert not results_by_job[1].success
        assert results_by_job[2].success
class TestMockObjTierShutdown:
    """Behaviour of ``shutdown()``."""

    def test_shutdown_clears_in_flight_transfers(self):
        """Shutdown drops in-flight transfers and resets the checked handles."""
        tier, agent = _make_tier(num_blocks=4)

        # Keep the transfer in flight by never reporting completion.
        def never_done(h):
            return "PROC"

        agent.check_xfer_state = never_done
        tier.submit_store(make_job(1, [key(1)], [0]))
        assert len(tier._transfers) == 1

        tier.shutdown()

        assert len(tier._transfers) == 0
        assert tier._dram_prepped_handle is None
        assert tier._primary_reg is None

    def test_shutdown_idempotent(self):
        """Calling shutdown twice must not raise."""
        tier, _agent = _make_tier(num_blocks=4)
        for _attempt in range(2):
            tier.shutdown()
+64 -1
View File
@@ -544,13 +544,15 @@ def native_sample_recovered_tokens(
target_probs: torch.Tensor, # [num_tokens, vocab_size]
sampling_metadata: SamplingMetadata,
device: torch.device,
use_fp64_gumbel: bool = False,
) -> torch.Tensor:
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
dtype=q_dtype,
device=device,
)
q.exponential_()
@@ -935,6 +937,67 @@ def test_sample_recovered_tokens(
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
def test_sample_recovered_tokens_uses_fp64_exponential_race_when_requested():
    """``sample_recovered_tokens`` must match the native reference with fp64 noise."""
    batch_size = 2
    vocab_size = 64
    max_spec_len = 2
    num_tokens = batch_size * max_spec_len
    probs_shape = (num_tokens, vocab_size)

    # Draw order (draft, then target, then multinomial) is kept fixed.
    draft_probs = F.softmax(
        torch.rand(*probs_shape, dtype=torch.float32, device=DEVICE_TYPE),
        dim=-1,
    )
    target_probs = F.softmax(
        torch.rand(*probs_shape, dtype=torch.float32, device=DEVICE_TYPE),
        dim=-1,
    )
    draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)

    # One generator per request, seeded with the request index.
    generators = {}
    for i in range(batch_size):
        generators[i] = torch.Generator(device=DEVICE_TYPE).manual_seed(i)
    sampling_metadata = create_sampling_metadata(
        all_greedy=False,
        temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
        generators=generators,
    )
    spec_decode_metadata = create_spec_decode_metadata(
        draft_token_ids.reshape(batch_size, max_spec_len).tolist(),
        target_probs.log(),
    )

    # Both implementations receive exactly the same arguments.
    shared_args = (
        max_spec_len,
        spec_decode_metadata.num_draft_tokens,
        spec_decode_metadata.cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        sampling_metadata,
    )
    expected = native_sample_recovered_tokens(
        *shared_args,
        device=torch.device(DEVICE_TYPE),
        use_fp64_gumbel=True,
    )
    actual = sample_recovered_tokens(
        *shared_args,
        device=torch.device(DEVICE_TYPE),
        use_fp64_gumbel=True,
    )
    assert torch.equal(actual, expected)
########################### Tests for Synthetic Rejection Sampling #########
+39 -1
View File
@@ -6,7 +6,12 @@ from torch import Generator
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.sample.ops.topk_topp_sampler import (
apply_top_k_top_p_pytorch,
random_sample,
)
from vllm.v1.sample.sampler import Sampler
DEVICE_TYPE = current_platform.device_type
@@ -38,6 +43,10 @@ def _flashinfer_topk_topp_supported() -> bool:
FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported()
def _seed_default_generator(seed: int) -> None:
    """Seed random state for a test by delegating to ``set_random_seed``."""
    set_random_seed(seed)
@pytest.fixture(autouse=True)
def reset_default_device():
"""
@@ -49,6 +58,35 @@ def reset_default_device():
torch.set_default_device(original_device)
def test_sampler_threads_fp64_gumbel_to_topk_topp_sampler():
    """``Sampler(use_fp64_gumbel=True)`` must set the flag on its top-k/top-p sampler."""
    topk_topp = Sampler(use_fp64_gumbel=True).topk_topp_sampler
    assert topk_topp.use_fp64_gumbel
def test_random_sample_uses_fp64_exponential_race_when_requested():
    """``random_sample`` with fp64 noise must match a hand-rolled fp64 race."""
    torch.set_default_device(DEVICE_TYPE)
    seed = 12345
    rows = [
        [0.70, 0.20, 0.10],
        [0.05, 0.15, 0.80],
        [0.25, 0.25, 0.50],
    ]
    probs = torch.tensor(rows, dtype=torch.float32, device=DEVICE_TYPE)

    # Reference: argmax(probs / q) with q drawn in float64.
    _seed_default_generator(seed)
    noise = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
    noise.exponential_()
    noise.reciprocal_()
    noise.mul_(probs)
    expected = noise.argmax(dim=-1).view(-1)

    # Re-seed so the function under test starts from the same seeded state.
    _seed_default_generator(seed)
    actual = random_sample(probs.clone(), {}, use_fp64_gumbel=True)

    assert torch.equal(actual, expected)
def test_topk_impl_equivalence():
torch.set_default_device(DEVICE_TYPE)
generator = Generator(device=DEVICE_TYPE).manual_seed(33)
+2 -1
View File
@@ -1034,7 +1034,8 @@ def test_propose_stores_probabilistic_draft_probs(monkeypatch):
proposer.model = model_mock
proposer._draft_attn_layer_names = {"layer.0"}
def fake_compute_probs(logits, sampling_metadata):
def fake_compute_probs(logits, sampling_metadata, use_fp64_gumbel):
assert use_fp64_gumbel == proposer.use_fp64_gumbel
probs = torch.softmax(logits, dim=-1)
return probs.argmax(dim=-1), probs
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.llm_base_proposer import (
compute_probs_and_sample_next_token,
)
DEVICE_TYPE = current_platform.device_type
def _seed_default_generator(seed: int) -> None:
    """Seed random state for a test by delegating to ``set_random_seed``."""
    set_random_seed(seed)
def _make_sampling_metadata(batch_size: int) -> SamplingMetadata:
    """Build all-random ``SamplingMetadata`` with temperature 1 for each request.

    Top-k/top-p are unset, no generators are supplied, penalties are marked
    as absent (empty penalty tensors), and the per-request token lists are
    empty.

    Args:
        batch_size: Number of requests the metadata describes.
    """
    unit_temperature = torch.ones(
        batch_size, dtype=torch.float32, device=DEVICE_TYPE
    )
    no_output_tokens = [[] for _ in range(batch_size)]
    no_spec_tokens = [[] for _ in range(batch_size)]
    return SamplingMetadata(
        temperature=unit_temperature,
        all_greedy=False,
        all_random=True,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=None,
        no_penalties=True,
        prompt_token_ids=None,
        frequency_penalties=torch.empty(0, device=DEVICE_TYPE),
        presence_penalties=torch.empty(0, device=DEVICE_TYPE),
        repetition_penalties=torch.empty(0, device=DEVICE_TYPE),
        output_token_ids=no_output_tokens,
        spec_token_ids=no_spec_tokens,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=LogitsProcessors(),
    )
def test_compute_probs_and_sample_next_token_uses_fp64_exponential_race():
    """The proposer's sampler must match a hand-rolled fp64 exponential race."""
    batch_size = 4
    vocab_size = 32
    seed = 12345

    # Logits come from a dedicated generator, separate from the default one.
    logits_rng = torch.Generator(device=DEVICE_TYPE).manual_seed(11)
    logits = torch.randn(
        batch_size,
        vocab_size,
        dtype=torch.float32,
        device=DEVICE_TYPE,
        generator=logits_rng,
    )
    metadata = _make_sampling_metadata(batch_size)

    # Reference: softmax, then argmax(probs / q) with q drawn in float64.
    _seed_default_generator(seed)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    noise = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
    noise.exponential_()
    noise.reciprocal_()
    noise.mul_(probs)
    expected_ids = noise.argmax(dim=-1).view(-1)

    # Re-seed so the function under test starts from the same seeded state.
    _seed_default_generator(seed)
    actual_ids, actual_probs = compute_probs_and_sample_next_token(
        logits.clone(),
        metadata,
        use_fp64_gumbel=True,
    )

    assert torch.equal(actual_ids, expected_ids)
    assert torch.allclose(actual_probs, probs)
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CUDA proof for fp32 exponential-race tail truncation.
This script is intentionally not a unit test. It is a reproducible, GPU-only
statistical proof for the hidden Gumbel-max idiom:
q.exponential_()
sample = (probs / q).argmax()
For q ~ Exp(1), this is equivalent to argmax(log(probs) + Gumbel). On CUDA,
fp32 exponential samples inherit a 24-bit uniform lower-tail cutoff, so very
small q values are impossible. The many-tail experiment below chooses a case
where a correct sampler should select a low-probability tail token dozens of
times, while fp32 q cannot select one.
"""
from __future__ import annotations
import argparse
import math
import time
import torch
def _seed(seed: int) -> None:
    """Seed PyTorch's random number generation via ``torch.manual_seed``."""
    torch.manual_seed(seed)
def measure_exponential_lower_tail(
    *,
    device: torch.device,
    samples: int,
    chunk_size: int,
    seed: int,
) -> None:
    """Count ``exponential_()`` draws below ``2**-24`` for fp32 and fp64.

    For each dtype the generator is re-seeded with ``seed`` and ``samples``
    draws are produced in chunks of at most ``chunk_size`` elements. One
    line per dtype is printed with the count below the threshold, the
    minimum and maximum draw, and the elapsed wall-clock time.

    Args:
        device: Device on which the sample tensors are allocated.
        samples: Total number of draws per dtype.
        chunk_size: Maximum number of draws generated per iteration.
        seed: Seed applied before each dtype's run.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        # A non-positive chunk can never make progress through `samples`.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    threshold = 2.0**-24
    print(f"lower-tail threshold: {threshold:.18e}")

    for dtype in (torch.float32, torch.float64):
        _seed(seed)
        count_below = 0
        min_q = float("inf")
        max_q = 0.0
        start = time.perf_counter()
        remaining = samples
        while remaining > 0:
            n = min(chunk_size, remaining)
            q = torch.empty((n,), dtype=dtype, device=device)
            q.exponential_()
            count_below += int((q < threshold).sum().item())
            min_q = min(min_q, float(q.min().item()))
            max_q = max(max_q, float(q.max().item()))
            remaining -= n
        # Wait for outstanding device work so the timing covers all of it.
        torch.accelerator.synchronize()
        elapsed = time.perf_counter() - start
        print(
            f"{dtype}: samples={samples} count(q < 2^-24)={count_below} "
            f"min={min_q:.18e} max={max_q:.6f} elapsed={elapsed:.2f}s"
        )
def run_many_tail_race(
    *,
    device: torch.device,
    trials: int,
    num_tail_tokens: int,
    gap: float,
    chunk_trials: int,
    seed: int,
) -> None:
    """Race one head token against many low-probability tail tokens.

    The head token has weight 1 and each of the ``num_tail_tokens`` tail
    tokens has weight ``exp(-gap)``. Every trial draws exponential noise
    ``q`` for each token and scores it as ``weight / q``; a tail hit is a
    trial in which the best tail score exceeds the head score. The
    experiment runs once in fp32 and once in fp64 and prints the hit count
    for each, after printing the analytically expected count
    ``trials * N*p / (1 + N*p)``.

    Args:
        device: Device on which the sample tensors are allocated.
        trials: Total number of races per dtype.
        num_tail_tokens: Number of tail tokens in each race.
        gap: Log-weight gap between the head token and each tail token.
        chunk_trials: Maximum number of races evaluated per iteration.
        seed: Seed applied before each dtype's run.

    Raises:
        ValueError: If ``chunk_trials`` is less than 1.
    """
    if chunk_trials < 1:
        # With chunk_trials == 0 every batch is empty, `remaining` never
        # shrinks, and the loop below would never terminate.
        raise ValueError(f"chunk_trials must be >= 1, got {chunk_trials}")

    p_tail = math.exp(-gap)
    expected_tail_hits = (
        trials * (num_tail_tokens * p_tail) / (1.0 + num_tail_tokens * p_tail)
    )
    print(
        "many-tail race: "
        f"trials={trials} num_tail_tokens={num_tail_tokens} gap={gap} "
        f"expected_tail_hits={expected_tail_hits:.4f}"
    )

    for dtype in (torch.float32, torch.float64):
        _seed(seed)
        hits = 0
        p0 = torch.tensor(1.0, dtype=dtype, device=device)
        pt = torch.tensor(p_tail, dtype=dtype, device=device)
        start = time.perf_counter()
        remaining = trials
        while remaining > 0:
            batch = min(chunk_trials, remaining)
            q0 = torch.empty((batch,), dtype=dtype, device=device)
            q0.exponential_()
            qt = torch.empty((batch, num_tail_tokens), dtype=dtype, device=device)
            qt.exponential_()
            head_score = p0 / q0
            # Only the best tail token of each trial can beat the head.
            tail_score = (pt / qt).amax(dim=-1)
            hits += int((tail_score > head_score).sum().item())
            remaining -= batch
        # Wait for outstanding device work so the timing covers all of it.
        torch.accelerator.synchronize()
        elapsed = time.perf_counter() - start
        print(f"{dtype}: tail_hits={hits} elapsed={elapsed:.2f}s")
def main() -> None:
    """Parse CLI options, require a CUDA accelerator, and run both experiments.

    Raises:
        RuntimeError: If no accelerator is available or the current
            accelerator is not of type ``"cuda"``.
    """
    # (flag, type, default), registered in this order.
    option_table = (
        ("--lower-tail-samples", int, 200_000_000),
        ("--lower-tail-chunk-size", int, 10_000_000),
        ("--race-trials", int, 100_000),
        ("--race-tail-tokens", int, 262_144),
        ("--race-gap", float, 20.5),
        ("--race-chunk-trials", int, 64),
        ("--seed", int, 2026),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, default in option_table:
        cli.add_argument(flag, type=value_type, default=default)
    opts = cli.parse_args()

    if not torch.accelerator.is_available():
        raise RuntimeError("CUDA is required for this proof.")
    device = torch.accelerator.current_accelerator()
    if device.type != "cuda":
        raise RuntimeError("CUDA is required for this proof.")

    print(f"torch={torch.__version__} cuda={torch.version.cuda}")
    print(f"device={device}")

    lower_tail_kwargs = {
        "device": device,
        "samples": opts.lower_tail_samples,
        "chunk_size": opts.lower_tail_chunk_size,
        "seed": opts.seed,
    }
    measure_exponential_lower_tail(**lower_tail_kwargs)

    race_kwargs = {
        "device": device,
        "trials": opts.race_trials,
        "num_tail_tokens": opts.race_tail_tokens,
        "gap": opts.race_gap,
        "chunk_trials": opts.race_chunk_trials,
        "seed": opts.seed,
    }
    run_many_tail_race(**race_kwargs)
if __name__ == "__main__":
main()
+3 -5
View File
@@ -8,11 +8,9 @@ on files that have been changed. It groups files into different mypy calls
based on their directory to avoid import following issues.
Usage:
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>
python tools/pre_commit/mypy.py <python_version> <changed_files...>
Args:
ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
"silent" for the main group of files.
python_version: Python version to use (e.g., "3.10") or "local" to use
the local Python version.
changed_files: List of changed files to check.
@@ -98,8 +96,8 @@ def mypy(
def main():
python_version = sys.argv[2]
file_groups = group_files(sys.argv[3:])
python_version = sys.argv[1]
file_groups = group_files(sys.argv[2:])
if python_version == "local":
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+73
View File
@@ -27,6 +27,16 @@ except ImportError:
# on ROCm the fp8_dtype always calls is_fp8_fnuz
# which is a host op, so we cache it once here.
FP8_DTYPE = current_platform.fp8_dtype()
_HIPB_MM_INITIALIZED_DEVICES: set[int] = set()
def _ensure_hipb_mm_extension_initialized() -> None:
    """Call ``aiter.hipb_create_extension()`` at most once per device index.

    The current accelerator device index is recorded in
    ``_HIPB_MM_INITIALIZED_DEVICES`` after the call, so later invocations
    for the same index return without calling it again.
    """
    import aiter

    device_index = torch.accelerator.current_device_index()
    if device_index in _HIPB_MM_INITIALIZED_DEVICES:
        return
    aiter.hipb_create_extension()
    _HIPB_MM_INITIALIZED_DEVICES.add(device_index)
def is_aiter_found() -> bool:
@@ -625,6 +635,43 @@ def _rocm_aiter_preshuffled_per_token_w8a8_gemm_fake(
return torch.empty(m, n, dtype=output_dtype, device=A.device)
def _rocm_aiter_hipb_mm_fp8_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Run ``aiter.hipb_mm`` on ``A`` and ``B`` with scales ``As`` and ``Bs``.

    The hipb extension is initialised for the current device first. The
    call uses ``solution_index=-1``, no output scale and
    ``bpreshuffle=True``.

    Args:
        A: First matmul operand.
        B: Second matmul operand.
        As: Scale passed as ``scaleA``.
        Bs: Scale passed as ``scaleB``.
        bias: Optional bias forwarded to ``hipb_mm``.
        output_dtype: Dtype requested for the result.

    Returns:
        The tensor returned by ``hipb_mm``.
    """
    from aiter import hipb_mm

    _ensure_hipb_mm_extension_initialized()
    result = hipb_mm(
        A,
        B,
        solution_index=-1,
        bias=bias,
        out_dtype=output_dtype,
        scaleA=As,
        scaleB=Bs,
        scaleOut=None,
        bpreshuffle=True,
    )
    return result
def _rocm_aiter_hipb_mm_fp8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[1]
return torch.empty(m, n, dtype=output_dtype, device=A.device)
def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
A: torch.Tensor,
B: torch.Tensor,
@@ -1308,6 +1355,7 @@ class rocm_aiter_ops:
# TODO: Consolidate under _LINEAR_ENABLED
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
_FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
_LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
# TODO: Consolidate under _LINEAR_ENABLED
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
@@ -1340,6 +1388,7 @@ class rocm_aiter_ops:
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
cls._LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
@@ -1512,6 +1561,13 @@ class rocm_aiter_ops:
return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950()
@classmethod
@if_aiter_supported
def is_linear_hipbmm_enabled(cls) -> bool:
    """Report whether the hipb_mm linear path should be used.

    Requires ``is_linear_enabled()``, a positive ``on_mi3xx()`` check and
    the ``_LINEAR_HIPBMM_ENABLED`` flag, evaluated in that order with
    short-circuiting.
    """
    from vllm.platforms.rocm import on_mi3xx

    linear_on_mi3xx = cls.is_linear_enabled() and on_mi3xx()
    return linear_on_mi3xx and cls._LINEAR_HIPBMM_ENABLED
@classmethod
@if_aiter_supported
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
@@ -1668,6 +1724,12 @@ class rocm_aiter_ops:
fake_impl=_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_hipb_mm_fp8",
op_func=_rocm_aiter_hipb_mm_fp8_impl,
fake_impl=_rocm_aiter_hipb_mm_fp8_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
@@ -1858,6 +1920,17 @@ class rocm_aiter_ops:
A, B, As, Bs, bias, output_dtype
)
@staticmethod
def hipb_mm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Dispatch to the ``torch.ops.vllm.rocm_aiter_hipb_mm_fp8`` custom op.

    All arguments are forwarded positionally and unchanged.
    """
    custom_op = torch.ops.vllm.rocm_aiter_hipb_mm_fp8
    return custom_op(A, B, As, Bs, bias, output_dtype)
@staticmethod
def triton_gemm_a8w8_blockscale(
A: torch.Tensor,
+9 -5
View File
@@ -868,9 +868,10 @@ def cutlass_scaled_mm_azp(
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
:param azp_adj: In the per-tensor case, this should include the azp.
Args:
azp_adj: In the per-tensor case, this should include the azp.
Always per-channel.
:param azp: Only set in the per-token case. Per-token if set.
azp: Only set in the per-token case. Per-token if set.
"""
assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
@@ -3886,9 +3887,12 @@ def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
Note that sylvester hadamard transforms are also symmetric, which means that
this function also applies the (transpose <=> inverse) transform.
:param x: value to be transformed inplace
:param inplace: modify value in place
:return: value after transformation
Args:
x: value to be transformed inplace
inplace: modify value in place
Returns:
value after transformation
"""
return torch.ops._C.hadacore_transform(x, inplace)
+6 -3
View File
@@ -82,11 +82,12 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
def hash_source(*srcs: str | Any) -> str:
"""
Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash.
Args:
srcs: strings or objects to add to the hash.
Objects and functions have their source inspected.
Results are cached by resolved types to avoid repeated
inspect.getsource() calls.
:return:
"""
# Resolve instances to their class for a hashable cache key.
cache_key = tuple(
@@ -99,7 +100,9 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
def hash_dict(dict_: dict[Any, Any]) -> str:
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
Returns:
A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
@@ -276,8 +276,10 @@ class FixFunctionalizationPass(VllmInductorPass):
"""
Replace mutated getitem users of the auto-functionalized node with the
mutated arguments.
:param node: The auto-functionalized node
:param mutated_args: The mutated arguments, indexed by getitem index.
Args:
node: The auto-functionalized node
mutated_args: The mutated arguments, indexed by getitem index.
If the value of an arg is a string, `node.kwargs[arg]` is used.
"""
for idx, user in self.getitem_users(node).items():
@@ -317,9 +319,10 @@ class FixFunctionalizationPass(VllmInductorPass):
as node.kwargs cannot be used.
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
:param graph: Graph to insert the defunctionalized node into
:param node: The auto-functionalized node to defunctionalize
:param args: If we cannot use kwargs, specify args directly.
Args:
graph: Graph to insert the defunctionalized node into
node: The auto-functionalized node to defunctionalize
args: If we cannot use kwargs, specify args directly.
If an arg is a string, `node.kwargs[arg]` is used.
""" # noqa: E501
assert is_func(node, auto_functionalized), (
@@ -108,9 +108,13 @@ class NoOpEliminationPass(VllmInductorPass):
def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape/slice
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
Args:
dim: The dimension arg to reshape/slice
i_dim: The corresponding dimension in the input tensor
Returns:
Are the dimensions equivalent?
There are two cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
+4 -3
View File
@@ -235,9 +235,10 @@ class ModelConfig:
temperature and top_k/top_p.
"""
use_fp64_gumbel: bool = False
"""Whether to use FP64 (instead of FP32) for the Gumbel noise used by the
sampler. FP64 reduces the chance of ties in Gumbel-max sampling at the cost
of significantly lower kernel throughput on most GPUs."""
"""Whether to use FP64 (instead of FP32) random noise for Gumbel-max and
equivalent exponential-race sampling. FP64 preserves lower-tail sampling
events that fp32 uniform/exponential draws can truncate, at the cost of
significantly lower throughput on most GPUs."""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
+6 -3
View File
@@ -180,7 +180,8 @@ class CuMemAllocator:
All data in the memory allocation with the specified tag will be
offloaded to CPU memory, and others will be discarded.
:param offload_tags: The tags of the memory allocation that will be
Args:
offload_tags: The tags of the memory allocation that will be
offloaded. The rest of the memory allocation will be discarded.
"""
if offload_tags is None:
@@ -230,7 +231,8 @@ class CuMemAllocator:
All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded
Args:
tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory.
"""
@@ -255,7 +257,8 @@ class CuMemAllocator:
All memory allocation created inside the context will be allocated
in the memory pool, and has the specified tag.
:param tag: The tag of the memory allocation. If None, the default tag
Args:
tag: The tag of the memory allocation. If None, the default tag
will be used.
"""
if tag is None:

Some files were not shown because too many files have changed in this diff Show More