mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
Merge branch 'main' into wentao-fp8-scaled-mm-oddM
This commit is contained in:
@@ -28,18 +28,19 @@ steps:
|
||||
pytest -x -v -s tests/kernels/quantization/test_cpu_fp8_scaled_mm.py
|
||||
pytest -x -v -s tests/kernels/mamba/cpu/test_cpu_gdn_ops.py"
|
||||
|
||||
- label: CPU-Compatibility Tests
|
||||
depends_on: []
|
||||
device: intel_cpu
|
||||
no_plugin: true
|
||||
source_file_dependencies:
|
||||
- cmake/cpu_extension.cmake
|
||||
- setup.py
|
||||
- vllm/platforms/cpu.py
|
||||
commands:
|
||||
- |
|
||||
bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
|
||||
bash .buildkite/scripts/hardware_ci/run-cpu-compatibility-test.sh"
|
||||
# Note: SDE can't be downloaded from CI host because of AWS WAF
|
||||
# - label: CPU-Compatibility Tests
|
||||
# depends_on: []
|
||||
# device: intel_cpu
|
||||
# no_plugin: true
|
||||
# source_file_dependencies:
|
||||
# - cmake/cpu_extension.cmake
|
||||
# - setup.py
|
||||
# - vllm/platforms/cpu.py
|
||||
# commands:
|
||||
# - |
|
||||
# bash .buildkite/scripts/hardware_ci/run-cpu-test.sh 20m "
|
||||
# bash .buildkite/scripts/hardware_ci/run-cpu-compatibility-test.sh"
|
||||
|
||||
- label: CPU-Language Generation and Pooling Model Tests
|
||||
depends_on: []
|
||||
|
||||
@@ -40,7 +40,8 @@ steps:
|
||||
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --block-size 64 --enforce-eager --max-model-len 8192 &&
|
||||
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 &&
|
||||
python3 examples/basic/offline_inference/generate.py --model ibm-research/PowerMoE-3b --block-size 64 --enforce-eager -tp 2 --enable-expert-parallel &&
|
||||
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --max-model-len 8192
|
||||
python3 examples/basic/offline_inference/generate.py --model superjob/Qwen3-4B-Instruct-2507-GPTQ-Int4 --max-model-len 8192 &&
|
||||
VLLM_XPU_FUSED_MOE_USE_REF=1 python3 examples/basic/offline_inference/generate.py --model Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --enforce-eager -tp 2
|
||||
'
|
||||
- label: "XPU V1 test"
|
||||
depends_on:
|
||||
|
||||
@@ -36,18 +36,6 @@ pull_request_rules:
|
||||
|
||||
For future commits, `pre-commit` will run automatically on changed files before each commit.
|
||||
|
||||
> [!TIP]
|
||||
> <details>
|
||||
> <summary>Is <code>mypy</code> failing?</summary>
|
||||
> <br/>
|
||||
> <code>mypy</code> is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
|
||||
>
|
||||
> ```bash
|
||||
> # For mypy (substitute "3.10" with the failing version if needed)
|
||||
> pre-commit run --hook-stage manual mypy-3.10
|
||||
> ```
|
||||
> </details>
|
||||
|
||||
- name: comment-dco-failure
|
||||
description: Comment on PR when DCO check fails
|
||||
conditions:
|
||||
|
||||
+7
-13
@@ -148,33 +148,27 @@ repos:
|
||||
language: python
|
||||
entry: python tools/pre_commit/generate_nightly_torch_test.py
|
||||
files: ^requirements/test/cuda\.(in|txt)$
|
||||
- id: mypy-local
|
||||
name: Run mypy locally for lowest supported Python version
|
||||
entry: python tools/pre_commit/mypy.py 0 "3.10"
|
||||
stages: [pre-commit] # Don't run in CI
|
||||
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.10
|
||||
entry: python tools/pre_commit/mypy.py "3.10"
|
||||
<<: &mypy_common
|
||||
language: python
|
||||
types_or: [python, pyi]
|
||||
require_serial: true
|
||||
additional_dependencies: ["mypy[faster-cache]==1.19.1", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
|
||||
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.10
|
||||
entry: python tools/pre_commit/mypy.py 1 "3.10"
|
||||
<<: *mypy_common
|
||||
stages: [manual] # Only run in CI
|
||||
additional_dependencies: ["mypy==1.20.2", regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
|
||||
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.11
|
||||
entry: python tools/pre_commit/mypy.py 1 "3.11"
|
||||
entry: python tools/pre_commit/mypy.py "3.11"
|
||||
<<: *mypy_common
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.12
|
||||
entry: python tools/pre_commit/mypy.py 1 "3.12"
|
||||
entry: python tools/pre_commit/mypy.py "3.12"
|
||||
<<: *mypy_common
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.13
|
||||
entry: python tools/pre_commit/mypy.py 1 "3.13"
|
||||
entry: python tools/pre_commit/mypy.py "3.13"
|
||||
<<: *mypy_common
|
||||
stages: [manual] # Only run in CI
|
||||
- id: shellcheck
|
||||
|
||||
@@ -98,11 +98,13 @@ pre-commit run --all-files
|
||||
pre-commit run ruff-check --all-files
|
||||
|
||||
# Run mypy as it is in CI:
|
||||
pre-commit run mypy-3.10 --all-files --hook-stage manual
|
||||
pre-commit run mypy-3.12 --all-files --hook-stage manual
|
||||
```
|
||||
|
||||
The line length limit for Python code is 88 characters. If you are not sure, use pre-commit to check.
|
||||
|
||||
Use [Google-style docstrings](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) (`Args:`/`Returns:`/`Raises:` sections), not reStructuredText/Sphinx fields (`:param:`, `:return:`, `:rtype:`).
|
||||
|
||||
### Commit messages
|
||||
|
||||
Add attribution using commit trailers such as `Co-authored-by:` (other projects use `Assisted-by:` or `Generated-by:`). For example:
|
||||
|
||||
+7
-11
@@ -307,10 +307,7 @@ endif()
|
||||
#
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cuda_view.cu"
|
||||
"csrc/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
@@ -346,9 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
@@ -627,6 +621,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
#
|
||||
set(VLLM_STABLE_EXT_SRC
|
||||
"csrc/libtorch_stable/torch_bindings.cpp"
|
||||
"csrc/libtorch_stable/cuda_view.cu"
|
||||
"csrc/libtorch_stable/activation_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/int8/scaled_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/w8a8/fp8/common.cu"
|
||||
@@ -639,6 +634,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
"csrc/libtorch_stable/layernorm_kernels.cu"
|
||||
"csrc/libtorch_stable/layernorm_quant_kernels.cu"
|
||||
"csrc/libtorch_stable/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/libtorch_stable/quantization/fused_kernels/fused_silu_mul_block_quant.cu"
|
||||
"csrc/libtorch_stable/attention/merge_attn_states.cu"
|
||||
"csrc/libtorch_stable/sampler.cu"
|
||||
"csrc/libtorch_stable/topk.cu"
|
||||
@@ -653,8 +649,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_STABLE_EXT_SRC
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/libtorch_stable/cuda_utils_kernels.cu"
|
||||
"csrc/libtorch_stable/cutlass_extensions/common.cpp"
|
||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
@@ -1076,11 +1072,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
WITH_SOABI)
|
||||
|
||||
# Set TORCH_TARGET_VERSION for stable ABI compatibility.
|
||||
# This ensures we only use C-shim APIs available in PyTorch 2.10.
|
||||
# This ensures we only use C-shim APIs available in PyTorch 2.11.
|
||||
# _C_stable_libtorch is abi compatible with PyTorch >= TORCH_TARGET_VERSION
|
||||
# which is currently set to 2.10.
|
||||
# which is currently set to 2.11.
|
||||
target_compile_definitions(_C_stable_libtorch PRIVATE
|
||||
TORCH_TARGET_VERSION=0x020A000000000000ULL)
|
||||
TORCH_TARGET_VERSION=0x020B000000000000ULL)
|
||||
|
||||
# Needed to use cuda/hip APIs from C-shim
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
#include <torch/all.h>
|
||||
#include <torch/cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// This function assumes that `cpu_tensor` is a CPU tensor,
|
||||
// and that UVA (Unified Virtual Addressing) is enabled.
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
|
||||
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
|
||||
|
||||
// handle empty tensor
|
||||
if (cpu_tensor.numel() == 0) {
|
||||
return torch::empty(cpu_tensor.sizes(),
|
||||
cpu_tensor.options().device(torch::kCUDA));
|
||||
}
|
||||
|
||||
if (cpu_tensor.is_pinned()) {
|
||||
// If CPU tensor is pinned, directly get the device pointer.
|
||||
void* host_ptr = const_cast<void*>(cpu_tensor.data_ptr());
|
||||
void* device_ptr = nullptr;
|
||||
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
TORCH_CHECK(err == cudaSuccess,
|
||||
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||
|
||||
return torch::from_blob(
|
||||
device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(),
|
||||
[base = cpu_tensor](void*) {}, // keep cpu tensor alive
|
||||
cpu_tensor.options().device(torch::kCUDA));
|
||||
}
|
||||
|
||||
// If CPU tensor is not pinned, allocate a new pinned memory buffer.
|
||||
torch::Tensor contiguous_cpu = cpu_tensor.contiguous();
|
||||
size_t nbytes = contiguous_cpu.nbytes();
|
||||
|
||||
void* host_ptr = nullptr;
|
||||
cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
|
||||
if (err != cudaSuccess) {
|
||||
AT_ERROR("cudaHostAlloc failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
err = cudaMemcpy(host_ptr, contiguous_cpu.data_ptr(), nbytes,
|
||||
cudaMemcpyDefault);
|
||||
if (err != cudaSuccess) {
|
||||
cudaFreeHost(host_ptr);
|
||||
AT_ERROR("cudaMemcpy failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
void* device_ptr = nullptr;
|
||||
err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
if (err != cudaSuccess) {
|
||||
cudaFreeHost(host_ptr);
|
||||
AT_ERROR("cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
|
||||
|
||||
return torch::from_blob(device_ptr, contiguous_cpu.sizes(),
|
||||
contiguous_cpu.strides(), deleter,
|
||||
contiguous_cpu.options().device(torch::kCUDA));
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/csrc/stable/device.h>
|
||||
#include <torch/csrc/stable/c/shim.h>
|
||||
#include <torch/headeronly/version.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <array>
|
||||
#include <optional>
|
||||
|
||||
// This function assumes that `cpu_tensor` is a CPU tensor,
|
||||
// and that UVA (Unified Virtual Addressing) is enabled.
|
||||
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
|
||||
torch::stable::Tensor& cpu_tensor) {
|
||||
STD_TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
|
||||
|
||||
const auto dtype = cpu_tensor.scalar_type();
|
||||
const auto layout = cpu_tensor.layout();
|
||||
const torch::stable::Device cuda_dev(torch::headeronly::DeviceType::CUDA);
|
||||
|
||||
// handle empty tensor
|
||||
if (cpu_tensor.numel() == 0) {
|
||||
return torch::stable::empty(cpu_tensor.sizes(), dtype, layout, cuda_dev);
|
||||
}
|
||||
|
||||
std::array<StableIValue, 2> is_pinned_stack{
|
||||
torch::stable::detail::from(cpu_tensor),
|
||||
torch::stable::detail::from(std::nullopt)};
|
||||
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
|
||||
"aten::is_pinned", "", is_pinned_stack.data(), TORCH_ABI_VERSION));
|
||||
if (torch::stable::detail::to<bool>(is_pinned_stack[0])) {
|
||||
// If CPU tensor is pinned, directly get the device pointer.
|
||||
void* host_ptr = const_cast<void*>(cpu_tensor.mutable_data_ptr());
|
||||
void* device_ptr = nullptr;
|
||||
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
STD_TORCH_CHECK(err == cudaSuccess, "cudaHostGetDevicePointer failed: ",
|
||||
cudaGetErrorString(err));
|
||||
|
||||
return torch::stable::from_blob(
|
||||
device_ptr, cpu_tensor.sizes(), cpu_tensor.strides(), cuda_dev, dtype,
|
||||
[base = cpu_tensor](void*) {}); // keep cpu tensor alive
|
||||
}
|
||||
|
||||
// If CPU tensor is not pinned, allocate a new pinned memory buffer.
|
||||
torch::stable::Tensor contiguous_cpu = torch::stable::contiguous(cpu_tensor);
|
||||
size_t nbytes = contiguous_cpu.numel() * contiguous_cpu.element_size();
|
||||
|
||||
void* host_ptr = nullptr;
|
||||
cudaError_t err = cudaHostAlloc(&host_ptr, nbytes, cudaHostAllocMapped);
|
||||
if (err != cudaSuccess) {
|
||||
STD_TORCH_CHECK(false, "cudaHostAlloc failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
err = cudaMemcpy(host_ptr, contiguous_cpu.const_data_ptr(), nbytes,
|
||||
cudaMemcpyDefault);
|
||||
if (err != cudaSuccess) {
|
||||
cudaFreeHost(host_ptr);
|
||||
STD_TORCH_CHECK(false, "cudaMemcpy failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
void* device_ptr = nullptr;
|
||||
err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
if (err != cudaSuccess) {
|
||||
cudaFreeHost(host_ptr);
|
||||
STD_TORCH_CHECK(
|
||||
false, "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
auto deleter = [host_ptr](void*) { cudaFreeHost(host_ptr); };
|
||||
|
||||
return torch::stable::from_blob(device_ptr, contiguous_cpu.sizes(),
|
||||
contiguous_cpu.strides(), cuda_dev,
|
||||
contiguous_cpu.scalar_type(), deleter);
|
||||
}
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
int32_t get_sm_version_num() {
|
||||
int32_t major_capability, minor_capability;
|
||||
@@ -164,6 +164,10 @@ torch::stable::Tensor awq_dequantize(torch::stable::Tensor _kernel,
|
||||
|
||||
#endif
|
||||
|
||||
// CPU tensor -> CUDA UVA view (shared CUDA/ROCm)
|
||||
torch::stable::Tensor get_cuda_view_from_cpu_tensor(
|
||||
torch::stable::Tensor& cpu_tensor);
|
||||
|
||||
// Attention kernels (shared CUDA/ROCm)
|
||||
void merge_attn_states(
|
||||
torch::stable::Tensor& output,
|
||||
@@ -215,6 +219,13 @@ void rms_norm_per_block_quant(torch::stable::Tensor& out,
|
||||
std::optional<torch::stable::Tensor> residual,
|
||||
int64_t group_size, bool is_scale_transposed);
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor& scales,
|
||||
int64_t group_size,
|
||||
std::optional<torch::stable::Tensor> scale_ub,
|
||||
bool is_scale_transposed);
|
||||
|
||||
// Positional encoding kernels (shared CUDA/ROCm)
|
||||
void rotary_embedding(torch::stable::Tensor& positions,
|
||||
torch::stable::Tensor& query,
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "get_group_starts.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
#include "nvfp4_utils.cuh"
|
||||
|
||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
|
||||
+34
-29
@@ -1,11 +1,10 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../../torch_utils.h"
|
||||
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "libtorch_stable/quantization/fused_kernels/quant_conversions.cuh"
|
||||
#include "quant_conversions.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@@ -105,60 +104,66 @@ __global__ void silu_and_mul_per_block_quant_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
void silu_and_mul_per_block_quant(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& input,
|
||||
torch::stable::Tensor& scales,
|
||||
int64_t group_size,
|
||||
std::optional<torch::stable::Tensor> scale_ub,
|
||||
bool is_scale_transposed) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
static torch::headeronly::ScalarType kFp8Type =
|
||||
is_fp8_ocp() ? torch::headeronly::ScalarType::Float8_e4m3fn
|
||||
: torch::headeronly::ScalarType::Float8_e4m3fnuz;
|
||||
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
TORCH_CHECK(
|
||||
input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
|
||||
STD_TORCH_CHECK(out.scalar_type() == kFp8Type ||
|
||||
out.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
STD_TORCH_CHECK(
|
||||
input.scalar_type() == torch::headeronly::ScalarType::Half ||
|
||||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
|
||||
"Input must be FP16 or BF16");
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat32, "Scales must be FP32");
|
||||
TORCH_CHECK(group_size == 128 || group_size == 64,
|
||||
STD_TORCH_CHECK(scales.scalar_type() == torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(group_size == 128 || group_size == 64,
|
||||
"Unsupported group size: ", group_size);
|
||||
|
||||
if (scale_ub.has_value()) {
|
||||
TORCH_CHECK(out.dtype() == kFp8Type);
|
||||
STD_TORCH_CHECK(out.scalar_type() == kFp8Type);
|
||||
}
|
||||
|
||||
int32_t hidden_size = out.size(-1);
|
||||
auto num_tokens = input.size(0);
|
||||
int32_t num_groups = hidden_size / group_size;
|
||||
|
||||
TORCH_CHECK(input.size(-1) == hidden_size * 2,
|
||||
STD_TORCH_CHECK(input.size(-1) == hidden_size * 2,
|
||||
"input last dim must be 2x output hidden_size");
|
||||
TORCH_CHECK(hidden_size % group_size == 0,
|
||||
STD_TORCH_CHECK(hidden_size % group_size == 0,
|
||||
"hidden_size must be divisible by group_size");
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||
input.get_device_index());
|
||||
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
|
||||
|
||||
dim3 grid(num_tokens, num_groups);
|
||||
dim3 block(group_size);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
VLLM_STABLE_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "silu_and_mul_per_block_quant", [&] {
|
||||
using scalar_in_t = scalar_t;
|
||||
|
||||
VLLM_DISPATCH_QUANT_TYPES(
|
||||
VLLM_STABLE_DISPATCH_QUANT_TYPES(
|
||||
out.scalar_type(), "silu_and_mul_per_block_quant", [&] {
|
||||
using scalar_out_t = scalar_t;
|
||||
|
||||
VLLM_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
|
||||
VLLM_DISPATCH_BOOL(is_scale_transposed, transpose_scale, [&] {
|
||||
VLLM_STABLE_DISPATCH_GROUP_SIZE(group_size, gs, [&] {
|
||||
VLLM_STABLE_DISPATCH_BOOL(
|
||||
is_scale_transposed, transpose_scale, [&] {
|
||||
vllm::silu_and_mul_per_block_quant_kernel<
|
||||
scalar_in_t, scalar_out_t, transpose_scale, gs>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_out_t>(),
|
||||
scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>()
|
||||
out.mutable_data_ptr<scalar_out_t>(),
|
||||
scales.mutable_data_ptr<float>(),
|
||||
input.const_data_ptr<scalar_in_t>(),
|
||||
scale_ub.has_value()
|
||||
? scale_ub->const_data_ptr<float>()
|
||||
: nullptr,
|
||||
hidden_size);
|
||||
});
|
||||
@@ -20,7 +20,7 @@
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
namespace vllm::c3x {
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
/*
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||
void dispatch_scaled_mm(torch::stable::Tensor& c,
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
#include "get_group_starts.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "libtorch_stable/cutlass_extensions/common.hpp"
|
||||
|
||||
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "ops.h"
|
||||
#include "cuda_utils.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/csrc/stable/library.h>
|
||||
@@ -27,6 +28,8 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||
"()");
|
||||
|
||||
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
|
||||
#endif
|
||||
@@ -321,6 +324,16 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"Tensor? scale_ub, Tensor!? residual, int group_size, "
|
||||
"bool is_scale_transposed) -> ()");
|
||||
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
|
||||
// Rotary embedding
|
||||
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
|
||||
ops.def(
|
||||
@@ -599,6 +612,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
ops.impl("rms_norm_dynamic_per_token_quant",
|
||||
TORCH_BOX(&rms_norm_dynamic_per_token_quant));
|
||||
ops.impl("rms_norm_per_block_quant", TORCH_BOX(&rms_norm_per_block_quant));
|
||||
ops.impl("silu_and_mul_per_block_quant",
|
||||
TORCH_BOX(&silu_and_mul_per_block_quant));
|
||||
|
||||
// Positional encoding kernels (shared CUDA/ROCm)
|
||||
ops.impl("rotary_embedding", TORCH_BOX(&rotary_embedding));
|
||||
@@ -661,6 +676,11 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
ops.impl("paged_attention_v2", TORCH_BOX(&paged_attention_v2));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CPU, ops) {
|
||||
ops.impl("get_cuda_view_from_cpu_tensor",
|
||||
TORCH_BOX(&get_cuda_view_from_cpu_tensor));
|
||||
}
|
||||
|
||||
// These capability-check functions take only primitive args (no tensors), so
|
||||
// there is no device to dispatch on. CompositeExplicitAutograd makes them
|
||||
// available for all backends. This is the stable ABI equivalent of calling
|
||||
@@ -681,6 +701,19 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
||||
ops.impl("ggml_moe_get_block_size", TORCH_BOX(&ggml_moe_get_block_size));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cuda_utils, cuda_utils) {
|
||||
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C_cuda_utils, CompositeExplicitAutograd,
|
||||
cuda_utils) {
|
||||
cuda_utils.impl("get_device_attribute", TORCH_BOX(&get_device_attribute));
|
||||
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
|
||||
TORCH_BOX(&get_max_shared_memory_per_block_device_attribute));
|
||||
}
|
||||
|
||||
// Cache ops
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(_C_cache_ops, ops) {
|
||||
// Swap in (out) the cache blocks from src to dst.
|
||||
|
||||
@@ -40,12 +40,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, double epsilon);
|
||||
|
||||
void silu_and_mul_per_block_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scales, int64_t group_size,
|
||||
std::optional<torch::Tensor> scale_ub,
|
||||
bool is_scale_transposed);
|
||||
|
||||
// rotary_embedding also exist in csrc/libtorch_stable/ops.h (torch::stable
|
||||
// ABI for CUDA). It remains here because the CPU build still uses these
|
||||
// torch::Tensor declarations.
|
||||
@@ -84,8 +78,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale,
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
// cache.h, which is no longer included here after cache ops moved to
|
||||
// _C_stable_libtorch).
|
||||
#include <torch/all.h>
|
||||
#include "cuda_utils.h"
|
||||
#include "ops.h"
|
||||
#include "core/registration.h"
|
||||
#include <torch/library.h>
|
||||
@@ -32,27 +31,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
||||
|
||||
ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
|
||||
ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
|
||||
&get_cuda_view_from_cpu_tensor);
|
||||
|
||||
// Activation ops (quantized only — basic ops moved to _C_stable_libtorch)
|
||||
ops.def(
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
|
||||
// Fused SiLU+Mul + per-block quantization
|
||||
ops.def(
|
||||
"silu_and_mul_per_block_quant("
|
||||
"Tensor! out, "
|
||||
"Tensor input, "
|
||||
"Tensor! scales, "
|
||||
"int group_size, "
|
||||
"Tensor? scale_ub=None, "
|
||||
"bool is_scale_transposed=False) -> ()");
|
||||
ops.impl("silu_and_mul_per_block_quant", torch::kCUDA,
|
||||
&silu_and_mul_per_block_quant);
|
||||
|
||||
// Horizontally-fused DeepseekV4-MLA: per-head RMSNorm + GPT-J RoPE for Q, and
|
||||
// GPT-J RoPE + UE8M0 FP8 quant + paged cache insert for KV, all in one
|
||||
// kernel launch. Registered in _C_stable_libtorch (incl. the FlashInfer V4
|
||||
@@ -178,18 +161,4 @@ TORCH_LIBRARY_FRAGMENT(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
}
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
// Cuda utils
|
||||
|
||||
// Gets the specified device attribute.
|
||||
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
|
||||
cuda_utils.impl("get_device_attribute", &get_device_attribute);
|
||||
|
||||
// Gets the maximum shared memory per block device attribute.
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
|
||||
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
|
||||
&get_max_shared_memory_per_block_device_attribute);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
@@ -296,7 +296,7 @@ llm = LLM(model="Qwen/Qwen3-8B")
|
||||
The `fastokens` Python package (>= 0.2.0) must be installed; if it isn't,
|
||||
vLLM raises a clear `ImportError` at tokenizer load. The override applies to
|
||||
any `--tokenizer-mode` that ends up loading an HF fast tokenizer (`hf`,
|
||||
`deepseek_v32`, `deepseek_v4`, `qwen_vl`, …). Modes that don't use the HF
|
||||
`deepseek_v32`, `deepseek_v4`, `qwen_vl`, …). Models that don't use the HF
|
||||
fast tokenizer (`mistral`, `grok2`, `kimi_audio`) ignore the flag.
|
||||
|
||||
Tokenizer-bound workloads — long shared prefixes, bursty short prompts,
|
||||
|
||||
@@ -101,7 +101,7 @@ vLLM's `pre-commit` hooks will now run automatically every time you commit.
|
||||
Some `pre-commit` hooks only run in CI. If you need to, you can run them locally with:
|
||||
|
||||
```bash
|
||||
pre-commit run --hook-stage manual mypy-3.10
|
||||
pre-commit run --hook-stage manual mypy-3.11
|
||||
```
|
||||
|
||||
### Documentation
|
||||
|
||||
@@ -128,7 +128,7 @@ The lease mechanism is controlled through `kv_connector_extra_config` in `--kv-t
|
||||
vllm serve <MODEL> \
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_both",
|
||||
"kv_role": "kv_producer",
|
||||
"kv_connector_extra_config": {"kv_lease_duration": 60}
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -50,7 +50,7 @@ To select a different backend, set `kv_connector_extra_config.backends` in `--kv
|
||||
vllm serve <MODEL> \
|
||||
--kv-transfer-config '{
|
||||
"kv_connector":"NixlConnector",
|
||||
"kv_role":"kv_both",
|
||||
"kv_role":"kv_producer",
|
||||
"kv_connector_extra_config":{"backends":["LIBFABRIC"]}
|
||||
}'
|
||||
```
|
||||
@@ -60,7 +60,7 @@ You can also pass JSON keys individually using dotted arguments, and you can app
|
||||
```bash
|
||||
vllm serve <MODEL> \
|
||||
--kv-transfer-config.kv_connector NixlConnector \
|
||||
--kv-transfer-config.kv_role kv_both \
|
||||
--kv-transfer-config.kv_role kv_producer \
|
||||
--kv-transfer-config.kv_connector_extra_config.backends+ LIBFABRIC
|
||||
```
|
||||
|
||||
@@ -81,7 +81,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5600 \
|
||||
vllm serve Qwen/Qwen3-0.6B \
|
||||
--port 8100 \
|
||||
--enforce-eager \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}'
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_producer","kv_load_failure_policy":"fail"}'
|
||||
```
|
||||
|
||||
### Consumer (Decoder) Configuration
|
||||
@@ -96,7 +96,7 @@ VLLM_NIXL_SIDE_CHANNEL_PORT=5601 \
|
||||
vllm serve Qwen/Qwen3-0.6B \
|
||||
--port 8200 \
|
||||
--enforce-eager \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail"}'
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_consumer","kv_load_failure_policy":"fail"}'
|
||||
```
|
||||
|
||||
### Proxy Server
|
||||
@@ -212,10 +212,21 @@ sequenceDiagram
|
||||
Enable bidirectional KV transfer by setting `bidirectional_kv_xfer` in `kv_connector_extra_config` on **both** P and D instances:
|
||||
|
||||
```bash
|
||||
# Prefill instance
|
||||
vllm serve <MODEL> \
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_both",
|
||||
"kv_role": "kv_producer",
|
||||
"kv_connector_extra_config": {
|
||||
"bidirectional_kv_xfer": true
|
||||
}
|
||||
}'
|
||||
|
||||
# Decode instance
|
||||
vllm serve <MODEL> \
|
||||
--kv-transfer-config '{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_consumer",
|
||||
"kv_connector_extra_config": {
|
||||
"bidirectional_kv_xfer": true
|
||||
}
|
||||
@@ -359,11 +370,10 @@ For multi-host DP deployment, only need to provide the host/port of the head ins
|
||||
|
||||
- **kv_producer**: For prefiller instances that generate KV caches
|
||||
- **kv_consumer**: For decoder instances that consume KV caches from prefiller
|
||||
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.
|
||||
- **kv_both** (deprecated): Previously used as a catch-all when the role was not predetermined. This value is now deprecated for NixlConnector and will be removed in a future release.
|
||||
|
||||
!!! tip
|
||||
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
|
||||
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
|
||||
!!! warning
|
||||
`kv_role="kv_both"` is deprecated for NixlConnector. Please set `kv_role="kv_producer"` for prefill instances and `kv_role="kv_consumer"` for decode instances. See [#33702](https://github.com/vllm-project/vllm/issues/33702) for details.
|
||||
|
||||
### KV Load Failure Policy
|
||||
|
||||
|
||||
@@ -169,7 +169,7 @@ speculative decoding, breaking down the guarantees into three key areas:
|
||||
> distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
|
||||
> - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
|
||||
> without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
|
||||
> provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](/tests/v1/spec_decode).
|
||||
> provides a lossless guarantee. Almost all of the tests in [tests/spec_decode/e2e](../../../tests/v1/spec_decode).
|
||||
> verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291)
|
||||
|
||||
3. **vLLM Logprob Stability**
|
||||
|
||||
+9
-9
@@ -83,22 +83,22 @@ plugins:
|
||||
- "re:vllm\\._.*" # Internal modules
|
||||
- "vllm.third_party"
|
||||
- "vllm.vllm_flash_attn"
|
||||
- "re:vllm\\.grpc\\..*_pb2.*" # Auto-generated protobuf files
|
||||
- "vllm.transformers_utils.configs"
|
||||
- "vllm.transformers_utils.processors"
|
||||
- !ENV [API_AUTONAV_EXCLUDE, "re:^$"] # Match nothing by default
|
||||
- mkdocstrings:
|
||||
handlers:
|
||||
python:
|
||||
options:
|
||||
show_symbol_type_heading: true
|
||||
show_symbol_type_toc: true
|
||||
filters:
|
||||
- "!.*_pb2_grpc" # Exclude auto-generated gRPC stubs
|
||||
summary:
|
||||
modules: true
|
||||
show_signature_annotations: true
|
||||
separate_signature: true
|
||||
filters: []
|
||||
show_overloads: true
|
||||
signature_crossrefs: true
|
||||
# Recommendations from api-autonav
|
||||
docstring_section_style: list
|
||||
parameter_headings: true
|
||||
show_symbol_type_heading: true
|
||||
show_symbol_type_toc: true
|
||||
summary: true
|
||||
inventories:
|
||||
- https://docs.python.org/3/objects.inv
|
||||
- https://typing-extensions.readthedocs.io/en/latest/objects.inv
|
||||
|
||||
@@ -32,7 +32,7 @@ partial-json-parser # used for parsing partial JSON outputs
|
||||
pyzmq >= 25.0.0
|
||||
msgspec
|
||||
gguf >= 0.17.0
|
||||
mistral_common[image] >= 1.11.2
|
||||
mistral_common[image] >= 1.11.3
|
||||
opencv-python-headless >= 4.13.0 # required for video IO
|
||||
pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
|
||||
@@ -31,7 +31,7 @@ torchaudio==2.11.0
|
||||
torchvision==0.26.0
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[image,audio] >= 1.11.2 # required for voxtral test
|
||||
mistral_common[image,audio] >= 1.11.3 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
|
||||
opencv-python-headless >= 4.13.0 # required for video test
|
||||
|
||||
@@ -409,7 +409,7 @@ mbstrdecoder==1.1.3
|
||||
# typepy
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.11.2
|
||||
mistral-common==1.11.3
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test/cuda.in
|
||||
|
||||
@@ -23,7 +23,7 @@ jiwer # required for audio tests
|
||||
timm # required for internvl test
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[image,audio] >= 1.11.2 # required for voxtral test
|
||||
mistral_common[image,audio] >= 1.11.3 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
opencv-python-headless >= 4.13.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
|
||||
@@ -30,7 +30,7 @@ tblib # for pickling test exceptions
|
||||
timm>=1.0.17 # required for internvl and gemma3n-mm test
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[image,audio]>=1.11.2 # required for voxtral test
|
||||
mistral_common[image,audio]>=1.11.3 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
open_clip_torch==2.32.0 # Required for nemotron_vl test, Nemotron Parse in test_common.py
|
||||
opencv-python-headless>=4.13.0 # required for video test
|
||||
|
||||
@@ -512,7 +512,7 @@ mcp==1.27.0
|
||||
# via -r requirements/test/../common.txt
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.11.2
|
||||
mistral-common==1.11.3
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test/../common.txt
|
||||
|
||||
@@ -266,7 +266,7 @@ mbstrdecoder==1.1.4
|
||||
# typepy
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.11.2
|
||||
mistral-common==1.11.3
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test/xpu.in
|
||||
|
||||
@@ -38,13 +38,17 @@ impl HfChatBackend {
|
||||
) -> Result<Self> {
|
||||
let model_config = load_model_config(files.config_path.as_deref())?;
|
||||
let model_type = model_config.model_type().unwrap_or_default();
|
||||
let multimodal_model_info = MultimodalModelInfo::from_paths(
|
||||
let multimodal_model_info = if options.language_model_only {
|
||||
None
|
||||
} else {
|
||||
MultimodalModelInfo::from_paths(
|
||||
model_id.clone(),
|
||||
(!model_type.is_empty()).then_some(model_type.to_string()),
|
||||
files.config_path.as_deref(),
|
||||
files.preprocessor_config_path.as_deref(),
|
||||
tokenizer.clone(),
|
||||
)?;
|
||||
)?
|
||||
};
|
||||
let multimodal_render_info = resolve_multimodal_render_info(multimodal_model_info.as_ref());
|
||||
|
||||
let renderer = options.renderer.resolve(model_type);
|
||||
@@ -225,6 +229,7 @@ mod tests {
|
||||
"test-model".to_string(),
|
||||
LoadModelBackendsOptions {
|
||||
renderer,
|
||||
language_model_only: false,
|
||||
chat_template_content_format: Default::default(),
|
||||
chat_template: None,
|
||||
default_chat_template_kwargs: HashMap::new(),
|
||||
@@ -267,6 +272,54 @@ mod tests {
|
||||
assert_eq!(prompt, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn language_model_only_skips_multimodal_preprocessor_config() {
|
||||
let mut files = resolved_files(
|
||||
r#"{"model_type":"deepseek_v0_vl"}"#,
|
||||
r#"{"chat_template":"{{ messages[0].content }}"}"#,
|
||||
);
|
||||
let preprocessor_config_path = files
|
||||
.config_path
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("preprocessor_config.json");
|
||||
write_json(&preprocessor_config_path, r#"{"size":[672,672]}"#);
|
||||
files.preprocessor_config_path = Some(preprocessor_config_path);
|
||||
|
||||
let backend = HfChatBackend::from_resolved_model_files(
|
||||
files.clone(),
|
||||
"test-model".to_string(),
|
||||
LoadModelBackendsOptions {
|
||||
language_model_only: true,
|
||||
chat_template_content_format: Default::default(),
|
||||
chat_template: None,
|
||||
default_chat_template_kwargs: HashMap::new(),
|
||||
..Default::default()
|
||||
},
|
||||
test_tokenizer(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(backend.multimodal_model_info().is_none());
|
||||
|
||||
let error = HfChatBackend::from_resolved_model_files(
|
||||
files,
|
||||
"test-model".to_string(),
|
||||
LoadModelBackendsOptions {
|
||||
chat_template_content_format: Default::default(),
|
||||
chat_template: None,
|
||||
default_chat_template_kwargs: HashMap::new(),
|
||||
..Default::default()
|
||||
},
|
||||
test_tokenizer(),
|
||||
)
|
||||
.err()
|
||||
.expect("invalid preprocessor config should fail without language_model_only");
|
||||
assert!(error.to_string().contains("failed to parse preprocessor_config.json"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn explicit_deepseek_renderer_overrides_generic_model_type() {
|
||||
let prompt = render_prompt(
|
||||
|
||||
@@ -60,6 +60,9 @@ pub type DynChatTextBackend = Arc<dyn ChatTextBackend>;
|
||||
pub struct LoadModelBackendsOptions {
|
||||
/// Which chat renderer implementation to use.
|
||||
pub renderer: RendererSelection,
|
||||
/// Disable frontend-side multimodal preprocessing and render the model as
|
||||
/// language-only.
|
||||
pub language_model_only: bool,
|
||||
/// How to serialize `message.content` when rendering the chat template.
|
||||
pub chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
/// Optional server-default chat template override, provided either as an
|
||||
|
||||
@@ -116,6 +116,10 @@ pub struct SharedRuntimeArgs {
|
||||
#[arg(long = "tokenizer-mode", default_value_t)]
|
||||
#[serde(default, rename = "tokenizer_mode")]
|
||||
pub renderer: RendererSelection,
|
||||
/// Disable multimodal inputs and treat the model as language-only.
|
||||
#[arg(long)]
|
||||
#[serde(default)]
|
||||
pub language_model_only: bool,
|
||||
/// Override the maximum model context length. When set, the frontend uses
|
||||
/// this value instead of the model's `max_position_embeddings` from
|
||||
/// `config.json`.
|
||||
@@ -243,6 +247,7 @@ impl SharedRuntimeArgs {
|
||||
tool_call_parser: self.tool_call_parser,
|
||||
reasoning_parser: self.reasoning_parser,
|
||||
renderer: self.renderer,
|
||||
language_model_only: self.language_model_only,
|
||||
chat_template: self.chat_template,
|
||||
default_chat_template_kwargs: self.default_chat_template_kwargs,
|
||||
chat_template_content_format: self.chat_template_content_format,
|
||||
@@ -284,6 +289,7 @@ impl SharedRuntimeArgs {
|
||||
tool_call_parser: self.tool_call_parser,
|
||||
reasoning_parser: self.reasoning_parser,
|
||||
renderer: self.renderer,
|
||||
language_model_only: self.language_model_only,
|
||||
chat_template: self.chat_template,
|
||||
default_chat_template_kwargs: self.default_chat_template_kwargs,
|
||||
chat_template_content_format: self.chat_template_content_format,
|
||||
@@ -419,6 +425,7 @@ impl ServeArgs {
|
||||
self.managed_engine.clone().into_config(
|
||||
self.runtime.model.clone(),
|
||||
self.runtime.max_model_len,
|
||||
self.runtime.language_model_only,
|
||||
handshake_port,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ fn serve_args_forward_python_flags_with_separator() {
|
||||
tool_call_parser: Auto,
|
||||
reasoning_parser: Auto,
|
||||
renderer: Auto,
|
||||
language_model_only: false,
|
||||
max_model_len: Some(
|
||||
512,
|
||||
),
|
||||
@@ -263,6 +264,7 @@ fn frontend_args_accept_json() {
|
||||
tool_call_parser: Auto,
|
||||
reasoning_parser: Auto,
|
||||
renderer: Auto,
|
||||
language_model_only: false,
|
||||
max_model_len: None,
|
||||
grpc_port: None,
|
||||
shutdown_timeout: 0,
|
||||
@@ -321,7 +323,7 @@ fn frontend_args_json_accepts_supported_non_default_fields() {
|
||||
"--output-address",
|
||||
"ipc:///tmp/output.sock",
|
||||
"--args-json",
|
||||
r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_ready_timeout_secs":42,"tool_call_parser":"hermes","reasoning_parser":"qwen3_thinking","tokenizer_mode":"deepseek_v32","max_model_len":8192,"shutdown_timeout":3}"#,
|
||||
r#"{"model_tag":"Qwen/Qwen3-0.6B","engine_ready_timeout_secs":42,"tool_call_parser":"hermes","reasoning_parser":"qwen3_thinking","tokenizer_mode":"deepseek_v32","language_model_only":true,"max_model_len":8192,"shutdown_timeout":3}"#,
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
@@ -338,6 +340,7 @@ fn frontend_args_json_accepts_supported_non_default_fields() {
|
||||
ParserSelection::Explicit("qwen3_thinking".to_string())
|
||||
);
|
||||
assert_eq!(args.runtime.renderer, RendererSelection::DeepSeekV32);
|
||||
assert!(args.runtime.language_model_only);
|
||||
assert_eq!(args.runtime.max_model_len, Some(8192));
|
||||
assert_eq!(args.runtime.shutdown_timeout, 3);
|
||||
}
|
||||
@@ -662,6 +665,7 @@ fn serve_args_accept_handshake_aliases() {
|
||||
tool_call_parser: Auto,
|
||||
reasoning_parser: Auto,
|
||||
renderer: Auto,
|
||||
language_model_only: false,
|
||||
max_model_len: None,
|
||||
grpc_port: None,
|
||||
shutdown_timeout: 0,
|
||||
@@ -783,6 +787,7 @@ fn serve_frontend_config_uses_dp_address_as_advertised_host() {
|
||||
tool_call_parser: Auto,
|
||||
reasoning_parser: Auto,
|
||||
renderer: Auto,
|
||||
language_model_only: false,
|
||||
chat_template: None,
|
||||
default_chat_template_kwargs: None,
|
||||
chat_template_content_format: Auto,
|
||||
@@ -846,6 +851,7 @@ fn serve_frontend_config_keeps_tcp_transport_for_non_local_only_topology() {
|
||||
tool_call_parser: Auto,
|
||||
reasoning_parser: Auto,
|
||||
renderer: Auto,
|
||||
language_model_only: false,
|
||||
chat_template: None,
|
||||
default_chat_template_kwargs: None,
|
||||
chat_template_content_format: Auto,
|
||||
@@ -924,6 +930,7 @@ fn frontend_config_uses_external_coordinator_when_coordinator_address_is_present
|
||||
tool_call_parser: Auto,
|
||||
reasoning_parser: Auto,
|
||||
renderer: Auto,
|
||||
language_model_only: false,
|
||||
chat_template: None,
|
||||
default_chat_template_kwargs: None,
|
||||
chat_template_content_format: Auto,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::slice;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arc_swap::ArcSwapOption;
|
||||
@@ -253,8 +252,18 @@ pub(crate) async fn run_abort_loop(
|
||||
inner: Arc<ClientInner>,
|
||||
mut abort_rx: mpsc::UnboundedReceiver<AbortRequest>,
|
||||
) {
|
||||
// TODO: receive and abort requests in batch
|
||||
while let Some(AbortRequest { request_id, cause }) = abort_rx.recv().await {
|
||||
// Coalesce bursts of auto-aborts into a single Abort message per engine.
|
||||
// A dropped-stream storm (e.g. many clients disconnecting at once under
|
||||
// high concurrency) would otherwise issue one engine round-trip per
|
||||
// request. `recv_many` returns as soon as at least one item is ready, so a
|
||||
// lone abort is still forwarded promptly.
|
||||
const MAX_DRAIN: usize = 1024;
|
||||
let mut batch: Vec<AbortRequest> = Vec::new();
|
||||
|
||||
while abort_rx.recv_many(&mut batch, MAX_DRAIN).await > 0 {
|
||||
let mut by_engine: BTreeMap<EngineId, Vec<String>> = BTreeMap::new();
|
||||
|
||||
for AbortRequest { request_id, cause } in batch.drain(..) {
|
||||
let Some(engine_id) = inner.take_auto_abort_target(&request_id) else {
|
||||
debug!(request_id, "skip auto-abort for inactive request");
|
||||
continue;
|
||||
@@ -272,16 +281,20 @@ pub(crate) async fn run_abort_loop(
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(error) = inner.do_abort_requests(&engine_id, slice::from_ref(&request_id)).await
|
||||
{
|
||||
by_engine.entry(engine_id).or_default().push(request_id);
|
||||
}
|
||||
|
||||
for (engine_id, request_ids) in by_engine {
|
||||
if let Err(error) = inner.do_abort_requests(&engine_id, &request_ids).await {
|
||||
warn!(
|
||||
request_id,
|
||||
?engine_id,
|
||||
?request_ids,
|
||||
error = %error.as_report(),
|
||||
"failed to auto-abort dropped request stream"
|
||||
"failed to auto-abort request streams"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Background loop that listens for engine-core outputs and dispatches them to
|
||||
|
||||
@@ -1225,6 +1225,86 @@ async fn dropping_a_live_stream_triggers_abort() {
|
||||
client.shutdown().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dropping_multiple_live_streams_aborts_all_in_a_burst() {
|
||||
init_tracing();
|
||||
let ipc = IpcNamespace::new().unwrap();
|
||||
let handshake_address = ipc.handshake_endpoint();
|
||||
let engine_id = b"engine-burst".to_vec();
|
||||
let request_ids = ["req-1", "req-2", "req-3"];
|
||||
|
||||
let (shutdown_tx, engine_task) = spawn_mock_engine_task(
|
||||
handshake_address.clone(),
|
||||
engine_id.clone(),
|
||||
|dealer, push| {
|
||||
Box::pin(async move {
|
||||
for _ in 0..3 {
|
||||
let add = recv_engine_message(dealer).await;
|
||||
assert_eq!(add[0].as_ref(), &[0x00]);
|
||||
}
|
||||
send_outputs(
|
||||
push,
|
||||
EngineCoreOutputs {
|
||||
outputs: vec![
|
||||
request_output("req-1", vec![99], None),
|
||||
request_output("req-2", vec![99], None),
|
||||
request_output("req-3", vec![99], None),
|
||||
],
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
let abort =
|
||||
timeout(Duration::from_secs(1), recv_engine_message(dealer)).await.unwrap();
|
||||
assert_eq!(abort[0].as_ref(), &[0x01]);
|
||||
let ids: Vec<String> = rmp_serde::from_slice(&abort[1]).unwrap();
|
||||
assert_eq!(
|
||||
ids,
|
||||
vec![
|
||||
"req-1".to_string(),
|
||||
"req-2".to_string(),
|
||||
"req-3".to_string()
|
||||
]
|
||||
);
|
||||
assert!(
|
||||
timeout(Duration::from_millis(100), recv_engine_message(dealer)).await.is_err()
|
||||
);
|
||||
})
|
||||
},
|
||||
);
|
||||
|
||||
let client = connect_client_with_ipc(
|
||||
handshake_test_config(
|
||||
handshake_address,
|
||||
1,
|
||||
"test-model",
|
||||
Duration::from_secs(2),
|
||||
0,
|
||||
None,
|
||||
),
|
||||
&ipc,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Open every request first so all three adds reach the engine before it
|
||||
// emits outputs, then drain the first token from each stream.
|
||||
let mut streams = Vec::new();
|
||||
for id in request_ids {
|
||||
streams.push(client.call(sample_request_with_id(id)).await.unwrap());
|
||||
}
|
||||
for stream in streams.iter_mut() {
|
||||
let first = timeout(Duration::from_secs(1), stream.next()).await.unwrap().unwrap().unwrap();
|
||||
assert_eq!(first.new_token_ids, vec![99]);
|
||||
}
|
||||
// Drop the whole burst back-to-back so the abort worker can batch them.
|
||||
drop(streams);
|
||||
|
||||
let _ = shutdown_tx.send(());
|
||||
engine_task.await.unwrap();
|
||||
client.shutdown().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn dispatcher_failure_propagates_to_streams_and_future_calls() {
|
||||
init_tracing();
|
||||
|
||||
@@ -71,6 +71,7 @@ impl ManagedEngineArgs {
|
||||
self,
|
||||
model: String,
|
||||
max_model_len: Option<u32>,
|
||||
language_model_only: bool,
|
||||
handshake_port: u16,
|
||||
) -> ManagedEngineConfig {
|
||||
let mut python_args = self.python_args;
|
||||
@@ -79,6 +80,9 @@ impl ManagedEngineArgs {
|
||||
python_args.push("--max-model-len".to_string());
|
||||
python_args.push(max_model_len.to_string());
|
||||
}
|
||||
if language_model_only {
|
||||
python_args.push("--language-model-only".to_string());
|
||||
}
|
||||
if let Some(data_parallel_size_local) = self.data_parallel_size_local {
|
||||
python_args.push("--data-parallel-size-local".to_string());
|
||||
python_args.push(data_parallel_size_local.to_string());
|
||||
|
||||
@@ -64,6 +64,7 @@ async fn main() -> Result<()> {
|
||||
tool_call_parser: ParserSelection::Auto,
|
||||
reasoning_parser: ParserSelection::Auto,
|
||||
renderer: RendererSelection::Auto,
|
||||
language_model_only: false,
|
||||
chat_template: None,
|
||||
default_chat_template_kwargs: None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption::Auto,
|
||||
|
||||
@@ -53,6 +53,9 @@ pub struct Config {
|
||||
pub reasoning_parser: ParserSelection,
|
||||
/// Chat renderer selection.
|
||||
pub renderer: RendererSelection,
|
||||
/// Disable frontend-side multimodal preprocessing and render the model as
|
||||
/// language-only.
|
||||
pub language_model_only: bool,
|
||||
/// Server-default chat template override, as a file path or inline
|
||||
/// template.
|
||||
pub chat_template: Option<String>,
|
||||
|
||||
@@ -42,6 +42,7 @@ async fn build_state(config: &Config) -> Result<Arc<AppState>> {
|
||||
&config.model,
|
||||
LoadModelBackendsOptions {
|
||||
renderer: config.renderer,
|
||||
language_model_only: config.language_model_only,
|
||||
chat_template: config.chat_template.clone(),
|
||||
chat_template_content_format: config.chat_template_content_format,
|
||||
default_chat_template_kwargs: config
|
||||
|
||||
@@ -85,6 +85,7 @@ pub async fn chat_completions(
|
||||
log_request,
|
||||
prepared.include_usage,
|
||||
prepared.requested_logprobs,
|
||||
prepared.include_reasoning,
|
||||
prepared.echo,
|
||||
prepared.return_token_ids,
|
||||
prepared.return_tokens_as_token_ids,
|
||||
@@ -100,6 +101,7 @@ pub async fn chat_completions(
|
||||
created,
|
||||
prepared.requested_logprobs,
|
||||
prepared.include_prompt_logprobs,
|
||||
prepared.include_reasoning,
|
||||
prepared.echo,
|
||||
prepared.return_token_ids,
|
||||
prepared.return_tokens_as_token_ids,
|
||||
@@ -134,6 +136,7 @@ async fn collect_chat_completion(
|
||||
created: u64,
|
||||
requested_logprobs: bool,
|
||||
include_prompt_logprobs: bool,
|
||||
include_reasoning: bool,
|
||||
echo: Option<String>,
|
||||
return_token_ids: bool,
|
||||
return_tokens_as_token_ids: bool,
|
||||
@@ -157,6 +160,11 @@ async fn collect_chat_completion(
|
||||
} = collected;
|
||||
let stop_reason = finish_reason.as_stop_reason().map(stop_reason_to_json);
|
||||
let saw_tool_calls = message.tool_calls().next().is_some();
|
||||
let reasoning = message.reasoning();
|
||||
// Output logprobs and token IDs cover the complete generated token stream.
|
||||
// When reasoning is hidden, omit them rather than leaking hidden reasoning
|
||||
// tokens through per-token metadata.
|
||||
let include_output_metadata = include_reasoning || reasoning.is_none();
|
||||
let finish_reason = chat_finish_reason_to_openai(&finish_reason, saw_tool_calls)?.to_string();
|
||||
let tool_calls = message
|
||||
.tool_calls()
|
||||
@@ -169,7 +177,7 @@ async fn collect_chat_completion(
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let logprobs = if requested_logprobs {
|
||||
let logprobs = if requested_logprobs && include_output_metadata {
|
||||
Some(decoded_logprobs_to_openai_chat(
|
||||
logprobs.as_ref().ok_or_else(|| {
|
||||
server_error!("chat response requested logprobs but generation returned none")
|
||||
@@ -207,12 +215,12 @@ async fn collect_chat_completion(
|
||||
None => Some(message.text()).filter(|t| !t.is_empty()),
|
||||
},
|
||||
tool_calls: Some(tool_calls).filter(|calls| !calls.is_empty()),
|
||||
reasoning: message.reasoning(),
|
||||
reasoning: if include_reasoning { reasoning } else { None },
|
||||
},
|
||||
logprobs,
|
||||
finish_reason: Some(finish_reason),
|
||||
stop_reason,
|
||||
token_ids: return_token_ids.then_some(token_ids),
|
||||
token_ids: (return_token_ids && include_output_metadata).then_some(token_ids),
|
||||
}],
|
||||
usage: Some(usage),
|
||||
system_fingerprint: None,
|
||||
@@ -232,12 +240,18 @@ async fn chat_completion_chunk_stream(
|
||||
log_request: bool,
|
||||
include_usage: bool,
|
||||
requested_logprobs: bool,
|
||||
include_reasoning: bool,
|
||||
echo: Option<String>,
|
||||
return_token_ids: bool,
|
||||
return_tokens_as_token_ids: bool,
|
||||
mut y: TryYielder<ChatCompletionStreamResponse, ApiError>,
|
||||
) -> Result<(), ApiError> {
|
||||
let mut saw_tool_calls = false;
|
||||
// `LogprobsDelta` is emitted after all chat events for one decoded update.
|
||||
// If that update contains hidden reasoning, including delimiter-only block
|
||||
// starts or ends, omit its token metadata as well as its visible delta.
|
||||
let mut inside_hidden_reasoning = false;
|
||||
let mut suppress_current_update_metadata = false;
|
||||
|
||||
// If the client requested logprobs or token_ids, we need to buffer chunks until
|
||||
// we receive the separate `LogprobsDelta` event, so that we can emit one
|
||||
@@ -268,6 +282,9 @@ async fn chat_completion_chunk_stream(
|
||||
}
|
||||
}
|
||||
Ok(ChatEvent::BlockDelta { kind, delta, .. }) => {
|
||||
let include_delta =
|
||||
include_reasoning || !matches!(kind, AssistantBlockKind::Reasoning);
|
||||
if include_delta {
|
||||
if let Some(pending_chunk) = pending_chunk.as_mut() {
|
||||
pending_chunk.push_block_delta(kind, delta);
|
||||
} else {
|
||||
@@ -280,17 +297,29 @@ async fn chat_completion_chunk_stream(
|
||||
))
|
||||
.await;
|
||||
}
|
||||
} else {
|
||||
suppress_current_update_metadata = true;
|
||||
}
|
||||
}
|
||||
Ok(ChatEvent::LogprobsDelta {
|
||||
logprobs,
|
||||
token_ids,
|
||||
}) => {
|
||||
let openai_logprobs = logprobs
|
||||
let include_metadata =
|
||||
!suppress_current_update_metadata && !inside_hidden_reasoning;
|
||||
suppress_current_update_metadata = false;
|
||||
let openai_logprobs = if include_metadata {
|
||||
logprobs
|
||||
.as_ref()
|
||||
.map(|lp| decoded_logprobs_to_openai_chat(lp, return_tokens_as_token_ids))
|
||||
.transpose()?;
|
||||
let openai_token_ids =
|
||||
return_token_ids.then_some(token_ids).filter(|t| !t.is_empty());
|
||||
.transpose()?
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let openai_token_ids = include_metadata
|
||||
.then_some(token_ids)
|
||||
.and_then(|token_ids| return_token_ids.then_some(token_ids))
|
||||
.filter(|t| !t.is_empty());
|
||||
if let Some(pending_chunk) = pending_chunk.as_mut() {
|
||||
pending_chunk.logprobs = openai_logprobs;
|
||||
pending_chunk.token_ids = openai_token_ids;
|
||||
@@ -311,9 +340,17 @@ async fn chat_completion_chunk_stream(
|
||||
}
|
||||
Ok(ChatEvent::BlockStart { kind, .. }) => {
|
||||
debug!(?kind, "starting new block");
|
||||
if !include_reasoning && matches!(kind, AssistantBlockKind::Reasoning) {
|
||||
inside_hidden_reasoning = true;
|
||||
suppress_current_update_metadata = true;
|
||||
}
|
||||
}
|
||||
Ok(ChatEvent::BlockEnd { .. }) => {
|
||||
debug!("ending current block");
|
||||
if inside_hidden_reasoning {
|
||||
inside_hidden_reasoning = false;
|
||||
suppress_current_update_metadata = true;
|
||||
}
|
||||
}
|
||||
Ok(ChatEvent::ToolCallStart { index, id, name }) => {
|
||||
let tool_index = index as u32;
|
||||
@@ -763,7 +800,9 @@ fn stop_reason_to_json(stop_reason: &StopReason) -> Value {
|
||||
mod tests {
|
||||
use futures::{StreamExt as _, stream};
|
||||
use serde_json::json;
|
||||
use vllm_chat::{AssistantBlockKind, AssistantToolCall, ChatEvent, FinishReason};
|
||||
use vllm_chat::{
|
||||
AssistantBlockKind, AssistantContentBlock, AssistantToolCall, ChatEvent, FinishReason,
|
||||
};
|
||||
use vllm_engine_core_client::protocol::StopReason;
|
||||
use vllm_text::{DecodedLogprobs, DecodedPositionLogprobs, DecodedTokenLogprob};
|
||||
|
||||
@@ -895,6 +934,7 @@ mod tests {
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
None,
|
||||
false,
|
||||
false,
|
||||
@@ -958,6 +998,7 @@ mod tests {
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
None,
|
||||
false,
|
||||
false,
|
||||
@@ -976,6 +1017,292 @@ mod tests {
|
||||
assert!(chunks[1].choices[0].logprobs.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chunk_stream_omits_reasoning_delta_when_disabled() {
|
||||
let stream = stream::iter(vec![
|
||||
Ok(ChatEvent::Start {
|
||||
prompt_token_ids: vec![].into(),
|
||||
prompt_logprobs: None,
|
||||
}),
|
||||
Ok(ChatEvent::BlockDelta {
|
||||
index: 0,
|
||||
kind: AssistantBlockKind::Reasoning,
|
||||
delta: "think".to_string(),
|
||||
}),
|
||||
Ok(ChatEvent::BlockDelta {
|
||||
index: 1,
|
||||
kind: AssistantBlockKind::Text,
|
||||
delta: "answer".to_string(),
|
||||
}),
|
||||
Ok(ChatEvent::Done {
|
||||
message: Default::default(),
|
||||
prompt_token_count: 1,
|
||||
output_token_count: 2,
|
||||
finish_reason: FinishReason::stop_eos(),
|
||||
kv_transfer_params: None,
|
||||
}),
|
||||
]);
|
||||
|
||||
let chunks = chat_completion_chunk_stream(
|
||||
stream,
|
||||
"chatcmpl-1".to_string(),
|
||||
"model".to_string(),
|
||||
1,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
None,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.expect("stream chunks");
|
||||
|
||||
assert_eq!(chunks.len(), 3);
|
||||
assert_eq!(
|
||||
chunks[1].choices[0].delta.content.as_deref(),
|
||||
Some("answer")
|
||||
);
|
||||
assert!(
|
||||
chunks
|
||||
.iter()
|
||||
.all(|chunk| chunk.choices.iter().all(|choice| choice.delta.reasoning.is_none()))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chunk_stream_omits_logprobs_for_suppressed_reasoning() {
|
||||
let stream = stream::iter(vec![
|
||||
Ok(ChatEvent::Start {
|
||||
prompt_token_ids: vec![].into(),
|
||||
prompt_logprobs: None,
|
||||
}),
|
||||
Ok(ChatEvent::BlockDelta {
|
||||
index: 0,
|
||||
kind: AssistantBlockKind::Reasoning,
|
||||
delta: "think".to_string(),
|
||||
}),
|
||||
Ok(ChatEvent::LogprobsDelta {
|
||||
logprobs: Some(DecodedLogprobs {
|
||||
positions: vec![DecodedPositionLogprobs {
|
||||
entries: vec![DecodedTokenLogprob {
|
||||
token_id: 11,
|
||||
token: "think".to_string(),
|
||||
logprob: -0.1,
|
||||
rank: 1,
|
||||
}],
|
||||
}],
|
||||
}),
|
||||
token_ids: vec![11],
|
||||
}),
|
||||
Ok(ChatEvent::BlockDelta {
|
||||
index: 1,
|
||||
kind: AssistantBlockKind::Text,
|
||||
delta: "answer".to_string(),
|
||||
}),
|
||||
Ok(ChatEvent::LogprobsDelta {
|
||||
logprobs: Some(DecodedLogprobs {
|
||||
positions: vec![DecodedPositionLogprobs {
|
||||
entries: vec![DecodedTokenLogprob {
|
||||
token_id: 22,
|
||||
token: "answer".to_string(),
|
||||
logprob: -0.2,
|
||||
rank: 1,
|
||||
}],
|
||||
}],
|
||||
}),
|
||||
token_ids: vec![22],
|
||||
}),
|
||||
Ok(ChatEvent::Done {
|
||||
message: Default::default(),
|
||||
prompt_token_count: 1,
|
||||
output_token_count: 2,
|
||||
finish_reason: FinishReason::stop_eos(),
|
||||
kv_transfer_params: None,
|
||||
}),
|
||||
]);
|
||||
|
||||
let chunks = chat_completion_chunk_stream(
|
||||
stream,
|
||||
"chatcmpl-1".to_string(),
|
||||
"model".to_string(),
|
||||
1,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
None,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.expect("stream chunks");
|
||||
|
||||
assert_eq!(chunks.len(), 3);
|
||||
let choice = &chunks[1].choices[0];
|
||||
assert_eq!(choice.delta.content.as_deref(), Some("answer"));
|
||||
assert_eq!(choice.token_ids.as_deref(), Some(&[22][..]));
|
||||
let logprobs = choice.logprobs.as_ref().expect("answer logprobs");
|
||||
let content = logprobs.content.as_ref().expect("logprobs content");
|
||||
assert_eq!(content[0].token, "answer");
|
||||
assert!(chunks.iter().all(|chunk| {
|
||||
chunk.choices.iter().all(|choice| {
|
||||
choice.delta.reasoning.is_none()
|
||||
&& choice.token_ids.as_deref() != Some(&[11][..])
|
||||
&& choice
|
||||
.logprobs
|
||||
.as_ref()
|
||||
.and_then(|logprobs| logprobs.content.as_ref())
|
||||
.is_none_or(|content| content.iter().all(|entry| entry.token != "think"))
|
||||
})
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chunk_stream_omits_logprobs_for_hidden_reasoning_delimiters() {
|
||||
let stream = stream::iter(vec![
|
||||
Ok(ChatEvent::Start {
|
||||
prompt_token_ids: vec![].into(),
|
||||
prompt_logprobs: None,
|
||||
}),
|
||||
Ok(ChatEvent::BlockStart {
|
||||
index: 0,
|
||||
kind: AssistantBlockKind::Reasoning,
|
||||
}),
|
||||
Ok(ChatEvent::LogprobsDelta {
|
||||
logprobs: Some(DecodedLogprobs {
|
||||
positions: vec![DecodedPositionLogprobs {
|
||||
entries: vec![DecodedTokenLogprob {
|
||||
token_id: 11,
|
||||
token: "<think>".to_string(),
|
||||
logprob: -0.1,
|
||||
rank: 1,
|
||||
}],
|
||||
}],
|
||||
}),
|
||||
token_ids: vec![11],
|
||||
}),
|
||||
Ok(ChatEvent::BlockDelta {
|
||||
index: 0,
|
||||
kind: AssistantBlockKind::Reasoning,
|
||||
delta: "think".to_string(),
|
||||
}),
|
||||
Ok(ChatEvent::LogprobsDelta {
|
||||
logprobs: Some(DecodedLogprobs {
|
||||
positions: vec![DecodedPositionLogprobs {
|
||||
entries: vec![DecodedTokenLogprob {
|
||||
token_id: 12,
|
||||
token: "think".to_string(),
|
||||
logprob: -0.2,
|
||||
rank: 1,
|
||||
}],
|
||||
}],
|
||||
}),
|
||||
token_ids: vec![12],
|
||||
}),
|
||||
Ok(ChatEvent::BlockEnd {
|
||||
index: 0,
|
||||
block: AssistantContentBlock::Reasoning {
|
||||
text: "think".to_string(),
|
||||
},
|
||||
}),
|
||||
Ok(ChatEvent::LogprobsDelta {
|
||||
logprobs: Some(DecodedLogprobs {
|
||||
positions: vec![DecodedPositionLogprobs {
|
||||
entries: vec![DecodedTokenLogprob {
|
||||
token_id: 13,
|
||||
token: "</think>".to_string(),
|
||||
logprob: -0.3,
|
||||
rank: 1,
|
||||
}],
|
||||
}],
|
||||
}),
|
||||
token_ids: vec![13],
|
||||
}),
|
||||
Ok(ChatEvent::BlockStart {
|
||||
index: 1,
|
||||
kind: AssistantBlockKind::Text,
|
||||
}),
|
||||
Ok(ChatEvent::BlockDelta {
|
||||
index: 1,
|
||||
kind: AssistantBlockKind::Text,
|
||||
delta: "answer".to_string(),
|
||||
}),
|
||||
Ok(ChatEvent::LogprobsDelta {
|
||||
logprobs: Some(DecodedLogprobs {
|
||||
positions: vec![DecodedPositionLogprobs {
|
||||
entries: vec![DecodedTokenLogprob {
|
||||
token_id: 22,
|
||||
token: "answer".to_string(),
|
||||
logprob: -0.4,
|
||||
rank: 1,
|
||||
}],
|
||||
}],
|
||||
}),
|
||||
token_ids: vec![22],
|
||||
}),
|
||||
Ok(ChatEvent::Done {
|
||||
message: Default::default(),
|
||||
prompt_token_count: 1,
|
||||
output_token_count: 4,
|
||||
finish_reason: FinishReason::stop_eos(),
|
||||
kv_transfer_params: None,
|
||||
}),
|
||||
]);
|
||||
|
||||
let chunks = chat_completion_chunk_stream(
|
||||
stream,
|
||||
"chatcmpl-1".to_string(),
|
||||
"model".to_string(),
|
||||
1,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
None,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.expect("stream chunks");
|
||||
|
||||
assert_eq!(chunks.len(), 3);
|
||||
let choice = &chunks[1].choices[0];
|
||||
assert_eq!(choice.delta.content.as_deref(), Some("answer"));
|
||||
assert_eq!(choice.token_ids.as_deref(), Some(&[22][..]));
|
||||
let logprobs = choice.logprobs.as_ref().expect("answer logprobs");
|
||||
let content = logprobs.content.as_ref().expect("logprobs content");
|
||||
assert_eq!(content[0].token, "answer");
|
||||
assert!(chunks.iter().all(|chunk| {
|
||||
chunk.choices.iter().all(|choice| {
|
||||
choice.delta.reasoning.is_none()
|
||||
&& !choice
|
||||
.token_ids
|
||||
.as_ref()
|
||||
.is_some_and(|ids| matches!(ids.as_slice(), [11] | [12] | [13]))
|
||||
&& choice
|
||||
.logprobs
|
||||
.as_ref()
|
||||
.and_then(|logprobs| logprobs.content.as_ref())
|
||||
.is_none_or(|content| {
|
||||
content.iter().all(|entry| {
|
||||
!matches!(entry.token.as_str(), "<think>" | "think" | "</think>")
|
||||
})
|
||||
})
|
||||
})
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chunk_stream_preserves_tool_call_index_and_omits_id_from_arguments_delta() {
|
||||
let stream = stream::iter(vec![
|
||||
@@ -1017,6 +1344,7 @@ mod tests {
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
None,
|
||||
false,
|
||||
false,
|
||||
|
||||
@@ -29,6 +29,8 @@ pub struct PreparedRequest {
|
||||
pub requested_logprobs: bool,
|
||||
/// Whether the caller requested top-level prompt logprobs.
|
||||
pub include_prompt_logprobs: bool,
|
||||
/// Whether to include reasoning content in OpenAI responses.
|
||||
pub include_reasoning: bool,
|
||||
/// Lowered chat request for `vllm-chat`.
|
||||
pub chat_request: ChatRequest,
|
||||
/// Last assistant-role message content to echo back when `echo=true`.
|
||||
@@ -57,6 +59,7 @@ pub(crate) fn prepare_chat_request(
|
||||
.as_ref()
|
||||
.map(|request| request.lora_name.clone())
|
||||
.unwrap_or_else(|| lora_resolution.model_names.first().cloned().unwrap_or_default());
|
||||
let include_reasoning = request.include_reasoning;
|
||||
let echo = request
|
||||
.echo
|
||||
.then(|| extract_last_assistant_content(&request.messages))
|
||||
@@ -146,6 +149,7 @@ pub(crate) fn prepare_chat_request(
|
||||
include_usage,
|
||||
requested_logprobs,
|
||||
include_prompt_logprobs,
|
||||
include_reasoning,
|
||||
chat_request,
|
||||
echo,
|
||||
return_token_ids: request.return_token_ids.unwrap_or(false),
|
||||
@@ -480,6 +484,23 @@ mod tests {
|
||||
assert_eq!(prepared.chat_request.tool_choice, ChatToolChoice::Auto);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prepare_chat_request_preserves_include_reasoning_false() {
|
||||
let request = ChatCompletionRequest {
|
||||
include_reasoning: false,
|
||||
..base_request()
|
||||
};
|
||||
|
||||
let prepared = prepare_chat_request(
|
||||
request,
|
||||
&served(&["Qwen/Qwen1.5-0.5B-Chat"]),
|
||||
ResolvedRequestContext::default(),
|
||||
)
|
||||
.expect("request is valid");
|
||||
|
||||
assert!(!prepared.include_reasoning);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prepare_chat_request_preserves_sampling_passthrough_fields() {
|
||||
let request = ChatCompletionRequest {
|
||||
|
||||
@@ -120,12 +120,6 @@ pub(super) fn validate_request_compat(
|
||||
"thinking_token_budget",
|
||||
"thinking_token_budget is not supported.",
|
||||
)?;
|
||||
if !request.include_reasoning {
|
||||
bail_invalid_request!(
|
||||
param = "include_reasoning",
|
||||
"include_reasoning is not supported."
|
||||
);
|
||||
}
|
||||
reject_non_default(
|
||||
request.media_io_kwargs.as_ref(),
|
||||
"media_io_kwargs",
|
||||
@@ -312,6 +306,17 @@ mod tests {
|
||||
.expect("reasoning_effort should be accepted");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_request_compat_accepts_include_reasoning_false() {
|
||||
let request = ChatCompletionRequest {
|
||||
include_reasoning: false,
|
||||
..base_request()
|
||||
};
|
||||
|
||||
validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"]))
|
||||
.expect("include_reasoning=false should be accepted");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_request_compat_rejects_top_logprobs_without_logprobs() {
|
||||
let request = ChatCompletionRequest {
|
||||
|
||||
@@ -3492,6 +3492,170 @@ async fn reasoning_blocks_are_mapped_to_reasoning_sse_chunks() {
|
||||
assert!(text.contains("\"content\":\"answer\""), "{text}");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[serial]
|
||||
async fn include_reasoning_false_suppresses_reasoning_in_non_stream_chat() {
|
||||
let (app, engine_task) = test_app_with_backend_and_stream_output_specs(
|
||||
Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")),
|
||||
vec![
|
||||
(bytes_to_token_ids(b"<think>think</think>"), None),
|
||||
(
|
||||
bytes_to_token_ids(b"answer"),
|
||||
Some(EngineCoreFinishReason::Length),
|
||||
),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let response = app
|
||||
.clone()
|
||||
.call(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/chat/completions")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
json!({
|
||||
"model": "Qwen/Qwen1.5-0.5B-Chat",
|
||||
"stream": false,
|
||||
"include_reasoning": false,
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
})
|
||||
.to_string(),
|
||||
))
|
||||
.expect("build request"),
|
||||
)
|
||||
.await
|
||||
.expect("call app");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body");
|
||||
engine_task.await.expect("mock engine task");
|
||||
let text = String::from_utf8(body.to_vec()).expect("utf8 body");
|
||||
let json: serde_json::Value = serde_json::from_str(&text).expect("decode json");
|
||||
|
||||
assert_eq!(json["choices"][0]["message"]["content"], "answer");
|
||||
assert!(
|
||||
json["choices"][0]["message"]
|
||||
.as_object()
|
||||
.is_some_and(|message| !message.contains_key("reasoning")),
|
||||
"{text}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[serial]
|
||||
async fn include_reasoning_false_suppresses_non_stream_output_metadata() {
|
||||
let ipc = IpcNamespace::new().expect("create ipc namespace");
|
||||
let handshake_address = ipc.handshake_endpoint();
|
||||
let engine_id = b"engine-openai-hidden-reasoning-logprobs".to_vec();
|
||||
|
||||
let engine_task = MockEngineTask::new(spawn_mock_engine_task(
|
||||
handshake_address.clone(),
|
||||
engine_id.clone(),
|
||||
|dealer, push| {
|
||||
boxed_test_future(async move {
|
||||
let add = recv_engine_message(dealer).await;
|
||||
let request: EngineCoreRequest =
|
||||
rmp_serde::from_slice(&add[1]).expect("decode request");
|
||||
let reasoning_token_ids = bytes_to_token_ids(b"<think>think</think>");
|
||||
let answer_token_ids = bytes_to_token_ids(b"answer");
|
||||
|
||||
send_outputs(
|
||||
push,
|
||||
EngineCoreOutputs {
|
||||
engine_index: 0,
|
||||
outputs: vec![
|
||||
request_output_with_logprobs(
|
||||
&request.request_id,
|
||||
reasoning_token_ids.clone(),
|
||||
None,
|
||||
None,
|
||||
Some(sample_logprobs_for_tokens(&reasoning_token_ids)),
|
||||
None,
|
||||
),
|
||||
request_output_with_logprobs(
|
||||
&request.request_id,
|
||||
answer_token_ids.clone(),
|
||||
Some(EngineCoreFinishReason::Length),
|
||||
None,
|
||||
Some(sample_logprobs_for_tokens(&answer_token_ids)),
|
||||
None,
|
||||
),
|
||||
],
|
||||
scheduler_stats: None,
|
||||
timestamp: 0.0,
|
||||
utility_output: None,
|
||||
finished_requests: None,
|
||||
wave_complete: None,
|
||||
start_wave: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
})
|
||||
},
|
||||
));
|
||||
|
||||
let client = EngineCoreClient::connect(
|
||||
EngineCoreClientConfig::new_single(handshake_address)
|
||||
.with_model_name("test-model")
|
||||
.with_local_input_output_addresses(
|
||||
Some(ipc.input_endpoint()),
|
||||
Some(ipc.output_endpoint()),
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("connect client");
|
||||
let chat = ChatLlm::from_shared_backend(
|
||||
test_llm(client),
|
||||
Arc::new(FakeChatBackend::with_model_id("Qwen/Qwen3-0.6B")),
|
||||
);
|
||||
let mut app = build_router(Arc::new(AppState::new(
|
||||
vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()],
|
||||
chat,
|
||||
)));
|
||||
|
||||
let response = app
|
||||
.call(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/chat/completions")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
json!({
|
||||
"model": "Qwen/Qwen1.5-0.5B-Chat",
|
||||
"stream": false,
|
||||
"include_reasoning": false,
|
||||
"logprobs": true,
|
||||
"return_token_ids": true,
|
||||
"messages": [{"role": "user", "content": "hello"}]
|
||||
})
|
||||
.to_string(),
|
||||
))
|
||||
.expect("build request"),
|
||||
)
|
||||
.await
|
||||
.expect("call app");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body");
|
||||
engine_task.await.expect("mock engine task");
|
||||
let text = String::from_utf8(body.to_vec()).expect("utf8 body");
|
||||
let json: serde_json::Value = serde_json::from_str(&text).expect("decode json");
|
||||
let choice = json["choices"][0].as_object().expect("choice object");
|
||||
|
||||
assert_eq!(json["choices"][0]["message"]["content"], "answer");
|
||||
assert!(
|
||||
json["choices"][0]["message"]
|
||||
.as_object()
|
||||
.is_some_and(|message| !message.contains_key("reasoning")),
|
||||
"{text}"
|
||||
);
|
||||
assert!(!choice.contains_key("logprobs"), "{text}");
|
||||
assert!(!choice.contains_key("token_ids"), "{text}");
|
||||
assert!(json["prompt_token_ids"].is_array(), "{text}");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[serial]
|
||||
async fn tool_calls_are_mapped_to_tool_call_sse_chunks() {
|
||||
|
||||
@@ -115,6 +115,8 @@ impl<T: Tokenizer + ?Sized> IncrementalDecoder for DecodeStream<'_, T> {
|
||||
|
||||
fn next_chunk(&mut self) -> Option<String> {
|
||||
let cutoff = self.cumulative_output.len().saturating_sub(self.min_bytes_to_buffer);
|
||||
// Ensure we split at a utf-8 char boundary.
|
||||
let cutoff = self.cumulative_output.floor_char_boundary(cutoff);
|
||||
(cutoff > self.output_index).then(|| {
|
||||
let chunk = self.cumulative_output[self.output_index..cutoff].to_string();
|
||||
self.output_index = cutoff;
|
||||
@@ -356,4 +358,27 @@ mod tests {
|
||||
assert_eq!(last_chunk.as_deref(), Some("lo!"));
|
||||
assert_eq!(full_text, "Hello!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_chunk_cutoff_respects_char_boundary() {
|
||||
// Regression: next_chunk's cutoff (len - min_bytes_to_buffer) must be
|
||||
// aligned to a UTF-8 char boundary like push_token/flush; otherwise
|
||||
// streaming multi-byte output (CJK/emoji) with a hold-back buffer (set
|
||||
// by a stop string) panics slicing cumulative_output mid-character.
|
||||
let backend = Utf8Backend;
|
||||
let mut decoder = backend.create_decode_stream(&[], false, 2);
|
||||
let mut out = String::new();
|
||||
for byte in "你好A".bytes() {
|
||||
decoder.push_token(u32::from(byte)).unwrap();
|
||||
if let Some(chunk) = decoder.next_chunk() {
|
||||
out.push_str(&chunk);
|
||||
}
|
||||
}
|
||||
let (last_chunk, full_text) = decoder.flush(None).unwrap();
|
||||
if let Some(chunk) = last_chunk {
|
||||
out.push_str(&chunk);
|
||||
}
|
||||
assert_eq!(full_text, "你好A");
|
||||
assert_eq!(out, "你好A");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -565,6 +565,8 @@ def test_size_used_in_multiple_consumer_subgraphs():
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
torch._dynamo.mark_dynamic(y, 0)
|
||||
torch.compile(model_fn, backend=capturing_backend)(x, y)
|
||||
assert captured_graph is not None, "Graph should be captured by backend"
|
||||
assert captured_inputs is not None, "Example inputs should be captured by backend"
|
||||
|
||||
split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
@@ -379,6 +380,12 @@ def _distributed_packed_a2a_worker(env: dict[str, str]) -> None:
|
||||
update_environment_variables(env)
|
||||
local_rank = int(env["LOCAL_RANK"])
|
||||
torch.accelerator.set_device_index(local_rank)
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
device_id=torch.device(f"cuda:{local_rank}"),
|
||||
)
|
||||
else:
|
||||
dist.init_process_group(backend="nccl")
|
||||
use_workspace = env.get("USE_WORKSPACE") == "1"
|
||||
if use_workspace:
|
||||
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.utils import ensure_current_vllm_config
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
@@ -82,6 +83,13 @@ def test_pynccl():
|
||||
@worker_fn_wrapper
|
||||
def multiple_allreduce_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
# Eager-init path: parent PG has bound_device_id + a CPU backend,
|
||||
# so split_group is supported.
|
||||
group = torch.distributed.split_group(
|
||||
split_ranks=[[0, 1], [2, 3]], backend="cpu:gloo,cuda:nccl"
|
||||
)
|
||||
else:
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
|
||||
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
|
||||
@@ -339,6 +347,11 @@ def test_pynccl_send_recv():
|
||||
@worker_fn_wrapper
|
||||
def multiple_send_recv_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
group = torch.distributed.split_group(
|
||||
split_ranks=[[0, 2], [1, 3]], backend="cpu:gloo,cuda:nccl"
|
||||
)
|
||||
else:
|
||||
groups = [
|
||||
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
|
||||
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
|
||||
|
||||
@@ -9,6 +9,7 @@ import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import (
|
||||
@@ -397,12 +398,26 @@ def qr_variable_input(rank, world_size):
|
||||
ranks = []
|
||||
for i in range(world_size):
|
||||
ranks.append(i)
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method="tcp://127.0.0.1:29500",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
else:
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method="tcp://127.0.0.1:29500",
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
cpu_group = torch.distributed.split_group(
|
||||
split_ranks=[ranks], backend="cpu:gloo,cuda:nccl"
|
||||
)
|
||||
else:
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
|
||||
|
||||
handle = ops.qr_get_handle(_ptr)
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for split_group in GroupCoordinator.
|
||||
|
||||
These tests verify that:
|
||||
1. split_group is used for both device and CPU group creation.
|
||||
2. Multiple subgroups work correctly with split_group.
|
||||
3. Both GPU and CPU all-reduce work on split groups.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import multiprocess as mp
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.parallel_state import (
|
||||
GroupCoordinator,
|
||||
init_distributed_environment,
|
||||
)
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
# The whole module exercises the split_group code path, which is opt-in
|
||||
# behind VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1.
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP,
|
||||
reason=("VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1 not set; split_group path is opt-in."),
|
||||
)
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def distributed_run(fn, world_size):
|
||||
number_of_processes = world_size
|
||||
processes: list[mp.Process] = []
|
||||
for i in range(number_of_processes):
|
||||
env: dict[str, str] = {}
|
||||
env["RANK"] = str(i)
|
||||
env["LOCAL_RANK"] = str(i)
|
||||
env["WORLD_SIZE"] = str(number_of_processes)
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12346"
|
||||
# propagate the opt-in flag to the spawned child workers
|
||||
env["VLLM_DISTRIBUTED_USE_SPLIT_GROUP"] = "1"
|
||||
p = mp.Process(target=fn, args=(env,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
for p in processes:
|
||||
assert p.exitcode == 0
|
||||
|
||||
|
||||
def worker_fn_wrapper(fn):
|
||||
def wrapped_fn(env):
|
||||
update_environment_variables(env)
|
||||
local_rank = os.environ["LOCAL_RANK"]
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.accelerator.set_device_index(device)
|
||||
init_distributed_environment()
|
||||
fn()
|
||||
|
||||
return wrapped_fn
|
||||
|
||||
|
||||
def _verify_device_group(coordinator: GroupCoordinator):
|
||||
"""Verify device group works via all-reduce."""
|
||||
local_rank = torch.distributed.get_rank()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
tensor = torch.ones(16, 16, dtype=torch.float32, device=device)
|
||||
torch.distributed.all_reduce(tensor, group=coordinator.device_group)
|
||||
torch.accelerator.synchronize()
|
||||
expected = coordinator.world_size
|
||||
assert torch.all(tensor == expected).cpu().item(), (
|
||||
f"Device group all-reduce failed: expected {expected}, "
|
||||
f"got {tensor.flatten()[0].item()}"
|
||||
)
|
||||
|
||||
|
||||
def _verify_cpu_group(coordinator: GroupCoordinator):
|
||||
"""Verify CPU group works via all-reduce."""
|
||||
tensor = torch.ones(16, dtype=torch.float32)
|
||||
torch.distributed.all_reduce(tensor, group=coordinator.cpu_group)
|
||||
expected = coordinator.world_size
|
||||
assert torch.all(tensor == expected).cpu().item(), (
|
||||
f"CPU group all-reduce failed: expected {expected}, "
|
||||
f"got {tensor.flatten()[0].item()}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: Basic split_group path with 2 GPUs
|
||||
# ---------------------------------------------------------------------------
|
||||
@worker_fn_wrapper
|
||||
def split_group_basic_worker():
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
group_ranks = [list(range(world_size))]
|
||||
|
||||
coordinator = GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=rank,
|
||||
torch_distributed_backend="nccl",
|
||||
use_device_communicator=False,
|
||||
group_name="test_split_basic",
|
||||
)
|
||||
|
||||
_verify_device_group(coordinator)
|
||||
_verify_cpu_group(coordinator)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
torch.accelerator.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.",
|
||||
)
|
||||
def test_split_group_basic():
|
||||
"""Test basic GroupCoordinator creation with split_group."""
|
||||
distributed_run(split_group_basic_worker, 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Multiple subgroups with split_group (4 GPUs)
|
||||
# ---------------------------------------------------------------------------
|
||||
@worker_fn_wrapper
|
||||
def split_group_multiple_subgroups_worker():
|
||||
rank = torch.distributed.get_rank()
|
||||
group_ranks = [[0, 1], [2, 3]]
|
||||
|
||||
coordinator = GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=rank,
|
||||
torch_distributed_backend="nccl",
|
||||
use_device_communicator=False,
|
||||
group_name="test_split_multi",
|
||||
)
|
||||
|
||||
assert coordinator.world_size == 2
|
||||
|
||||
_verify_device_group(coordinator)
|
||||
_verify_cpu_group(coordinator)
|
||||
|
||||
if rank in [0, 1]:
|
||||
assert coordinator.ranks == [0, 1]
|
||||
else:
|
||||
assert coordinator.ranks == [2, 3]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
torch.accelerator.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.",
|
||||
)
|
||||
def test_split_group_multiple_subgroups():
|
||||
"""Test GroupCoordinator with multiple independent subgroups."""
|
||||
distributed_run(split_group_multiple_subgroups_worker, 4)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: split_group contract — every parent rank must enter with the same
|
||||
# ``split_ranks``. NCCL happens to produce
|
||||
# correct subgroups for disjoint partitions because the wrapper hashes
|
||||
# ``my_group`` to derive the comm-split color, but the contract violation is
|
||||
# real and would break under non-partition / non-NCCL backends. This test
|
||||
# captures the actual ``split_ranks`` argument passed on every rank and
|
||||
# asserts they match.
|
||||
# ---------------------------------------------------------------------------
|
||||
@worker_fn_wrapper
|
||||
def split_group_contract_worker():
|
||||
rank = torch.distributed.get_rank()
|
||||
group_ranks = [[0, 1], [2, 3]]
|
||||
|
||||
captured: list[list[list[int]]] = []
|
||||
original_split_group = torch.distributed.split_group
|
||||
|
||||
def capturing_split_group(*args, split_ranks=None, **kwargs):
|
||||
captured.append([list(g) for g in split_ranks])
|
||||
return original_split_group(*args, split_ranks=split_ranks, **kwargs)
|
||||
|
||||
torch.distributed.split_group = capturing_split_group
|
||||
try:
|
||||
GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=rank,
|
||||
torch_distributed_backend="nccl",
|
||||
use_device_communicator=False,
|
||||
group_name="test_split_contract",
|
||||
)
|
||||
finally:
|
||||
torch.distributed.split_group = original_split_group
|
||||
|
||||
# GroupCoordinator builds two subgroups (device + cpu) per coordinator,
|
||||
# so every rank must have made exactly two split_group calls.
|
||||
if len(captured) != 2:
|
||||
raise AssertionError(
|
||||
f"rank {rank} expected 2 split_group calls (device + cpu), "
|
||||
f"got {len(captured)}: {captured}"
|
||||
)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
for call_idx in range(2):
|
||||
gathered: list[Any] = [None] * world_size
|
||||
torch.distributed.all_gather_object(gathered, captured[call_idx])
|
||||
# Normalize for stable comparison (sort each subgroup and the outer list).
|
||||
norm = [
|
||||
sorted([sorted(sg) for sg in per_rank_args]) for per_rank_args in gathered
|
||||
]
|
||||
reference = norm[0]
|
||||
for r, args in enumerate(norm):
|
||||
if args != reference:
|
||||
raise AssertionError(
|
||||
f"split_group contract violation on call #{call_idx}: "
|
||||
f"rank {r} passed split_ranks={gathered[r]}, but rank 0 "
|
||||
f"passed split_ranks={gathered[0]}. PyTorch requires every "
|
||||
"parent rank to enter split_group with the same split_ranks."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
torch.accelerator.device_count() < 4,
|
||||
reason="Need at least 4 GPUs to run the test.",
|
||||
)
|
||||
def test_split_group_contract_same_split_ranks_on_all_ranks():
|
||||
"""All parent ranks must call torch.distributed.split_group with the same
|
||||
``split_ranks`` argument. This catches the bug where each rank passed
|
||||
only its own subgroup (``split_ranks=[ranks]``), which NCCL forgives for
|
||||
disjoint partitions but is a documented contract violation.
|
||||
"""
|
||||
distributed_run(split_group_contract_worker, 4)
|
||||
@@ -5,13 +5,26 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
|
||||
# Let PyTorch choose the WORLD backend for the current device type.
|
||||
dist.init_process_group()
|
||||
# By default, let PyTorch choose the WORLD backend for the current device
|
||||
# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
|
||||
# use the explicit eager-init pattern required by `split_group` (mixed
|
||||
# cpu:gloo,cuda:nccl backend + device_id binding).
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.accelerator.set_device_index(local_rank)
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
device_id=torch.device(f"cuda:{local_rank}"),
|
||||
)
|
||||
else:
|
||||
dist.init_process_group()
|
||||
|
||||
# Create prompts
|
||||
prompts = [
|
||||
|
||||
@@ -5,13 +5,26 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed.parallel_state import get_tp_group, get_world_group
|
||||
|
||||
# Let PyTorch choose the WORLD backend for the current device type.
|
||||
dist.init_process_group()
|
||||
# By default, let PyTorch choose the WORLD backend for the current device
|
||||
# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
|
||||
# use the explicit eager-init pattern required by `split_group` (mixed
|
||||
# cpu:gloo,cuda:nccl backend + device_id binding).
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.accelerator.set_device_index(local_rank)
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
device_id=torch.device(f"cuda:{local_rank}"),
|
||||
)
|
||||
else:
|
||||
dist.init_process_group()
|
||||
|
||||
# Create prompts
|
||||
prompts = [
|
||||
|
||||
@@ -808,14 +808,20 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA
|
||||
"logprobs": False,
|
||||
}
|
||||
|
||||
chat_completion = await client.chat.completions.create(**request_args)
|
||||
# Use raw HTTP for both endpoints so we compare server responses
|
||||
# directly, without the openai SDK injecting extra fields
|
||||
# (e.g. `moderation` added in newer SDK versions).
|
||||
chat_response = requests.post(
|
||||
server.url_for("v1/chat/completions"), json=request_args
|
||||
)
|
||||
chat_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(
|
||||
server.url_for("invocations"), json=request_args
|
||||
)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_completion.model_dump()
|
||||
chat_output = chat_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
|
||||
@@ -5,8 +5,16 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
DEVICE = current_platform.device_type
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
|
||||
reason="Lightning attention Triton kernels require CUDA/ROCm or XPU.",
|
||||
)
|
||||
|
||||
NUM_HEADS = [4, 8]
|
||||
HEAD_SIZES = [64]
|
||||
BATCH_SIZES = [1, 2]
|
||||
@@ -121,7 +129,7 @@ def test_linear_decode_forward_triton(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_device(DEVICE)
|
||||
set_random_seed(42)
|
||||
base = 0.01
|
||||
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
@@ -129,16 +137,16 @@ def test_linear_decode_forward_triton(
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
slope_rate = torch.zeros(num_heads, device=DEVICE)
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.arange(batch_size, device="cuda")
|
||||
slot_idx = torch.arange(batch_size, device=DEVICE)
|
||||
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
@@ -162,7 +170,7 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_device(DEVICE)
|
||||
set_random_seed(42)
|
||||
|
||||
batch_size = 4
|
||||
@@ -172,16 +180,16 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
slope_rate = torch.zeros(num_heads, device="cuda")
|
||||
slope_rate = torch.zeros(num_heads, device=DEVICE)
|
||||
for h in range(num_heads):
|
||||
slope_rate[h] = 0.1 * (h + 1)
|
||||
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device=DEVICE)
|
||||
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
@@ -224,7 +232,7 @@ def test_lightning_attention_reference(
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_device(DEVICE)
|
||||
set_random_seed(42)
|
||||
|
||||
base = 0.01
|
||||
@@ -232,12 +240,12 @@ def test_lightning_attention_reference(
|
||||
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
|
||||
ed = torch.zeros(num_heads, device="cuda")
|
||||
ed = torch.zeros(num_heads, device=DEVICE)
|
||||
for h in range(num_heads):
|
||||
ed[h] = 0.1 * (h + 1)
|
||||
|
||||
kv_history = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device=DEVICE
|
||||
)
|
||||
|
||||
kv_history_clone = kv_history.clone()
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (
|
||||
cleanup_dist_env_and_memory,
|
||||
@@ -60,7 +61,15 @@ def _set_vllm_config(
|
||||
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
|
||||
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
|
||||
)
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
cpu_group = torch.distributed.split_group(
|
||||
split_ranks=[list(range(world_size))],
|
||||
group_desc="moe_test_cpu",
|
||||
)
|
||||
else:
|
||||
cpu_group = torch.distributed.new_group(
|
||||
list(range(world_size)), backend="gloo"
|
||||
)
|
||||
return cpu_group
|
||||
|
||||
|
||||
|
||||
@@ -205,7 +205,10 @@ def run_with_expert_maps(
|
||||
w2 = kwargs["w2"]
|
||||
a = kwargs["hidden_states"]
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=w2.shape[0],
|
||||
max_num_tokens=kwargs.get("hidden_states").shape[0],
|
||||
experts_per_token=kwargs.get("topk_ids").shape[1],
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_local_experts,
|
||||
hidden_dim=w2.shape[1],
|
||||
intermediate_size_per_partition=w2.shape[2],
|
||||
in_dtype=a.dtype,
|
||||
@@ -258,23 +261,27 @@ def run_8_bit(
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
|
||||
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
||||
|
||||
kwargs = {
|
||||
"hidden_states": moe_tensors.a,
|
||||
"w1": moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
"w2": moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"global_num_experts": moe_tensors.w1_q.shape[0], # type: ignore[union-attr]
|
||||
"global_num_experts": num_experts,
|
||||
"activation": MoEActivation.SILU,
|
||||
"expert_map": None,
|
||||
"apply_router_weight_on_input": False,
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0) # type: ignore[attr-defined]
|
||||
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
||||
if not with_ep:
|
||||
moe_config = make_dummy_moe_config(
|
||||
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
|
||||
max_num_tokens=moe_tensors.a.shape[0],
|
||||
experts_per_token=topk_ids.shape[1],
|
||||
num_experts=num_experts,
|
||||
num_local_experts=num_local_experts,
|
||||
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
|
||||
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
|
||||
in_dtype=moe_tensors.a.dtype,
|
||||
@@ -581,6 +588,7 @@ def test_run_cutlass_moe_fp8(
|
||||
per_out_channel,
|
||||
False,
|
||||
topk_weights,
|
||||
None,
|
||||
)
|
||||
|
||||
workspace13.random_()
|
||||
|
||||
@@ -14,6 +14,7 @@ import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
@@ -375,6 +376,12 @@ def _test_deepep_deepgemm_moe(
|
||||
w1_scale = w1_scale.to(device=device)
|
||||
w2_scale = w2_scale.to(device=device)
|
||||
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
pg = torch.distributed.split_group(
|
||||
split_ranks=[list(range(pgi.world_size))],
|
||||
group_desc="deepep_deepgemm_test",
|
||||
)
|
||||
else:
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, pgi.rank)
|
||||
block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)]
|
||||
|
||||
@@ -10,6 +10,7 @@ import pytest
|
||||
import torch.distributed
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@@ -375,6 +376,12 @@ def _deep_ep_moe(
|
||||
w1_scale = w1_scale.to(device=device_idx)
|
||||
w2_scale = w2_scale.to(device=device_idx)
|
||||
|
||||
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
|
||||
pg = torch.distributed.split_group(
|
||||
split_ranks=[list(range(pgi.world_size))],
|
||||
group_desc="deepep_test",
|
||||
)
|
||||
else:
|
||||
pg = torch.distributed.new_group(list(range(pgi.world_size)))
|
||||
test_tensors = TestTensors.make(config, low_latency_mode)
|
||||
|
||||
|
||||
@@ -49,10 +49,12 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
def make_dummy_moe_config(
|
||||
num_experts: int = 1,
|
||||
num_local_experts: int | None = None,
|
||||
experts_per_token: int = 1,
|
||||
hidden_dim: int = 1,
|
||||
intermediate_size_per_partition: int = 1,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
max_num_tokens: int = 512,
|
||||
) -> FusedMoEConfig:
|
||||
"""
|
||||
This is a dummy config for the mk constructor interface
|
||||
@@ -66,14 +68,16 @@ def make_dummy_moe_config(
|
||||
experts_per_token=experts_per_token,
|
||||
hidden_dim=hidden_dim,
|
||||
intermediate_size_per_partition=intermediate_size_per_partition,
|
||||
num_local_experts=num_experts,
|
||||
num_local_experts=num_local_experts
|
||||
if num_local_experts is not None
|
||||
else num_experts,
|
||||
num_logical_experts=num_experts,
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
activation=MoEActivation.SILU,
|
||||
in_dtype=in_dtype,
|
||||
device="cuda",
|
||||
routing_method=RoutingMethodType.TopK,
|
||||
max_num_tokens=512,
|
||||
max_num_tokens=max_num_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,9 +13,15 @@ from vllm.model_executor.layers.quantization.awq_triton import (
|
||||
awq_dequantize_triton,
|
||||
awq_gemm_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
device = "cuda"
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda_alike() or current_platform.is_xpu()),
|
||||
reason="AWQ Triton kernels require CUDA/ROCm or XPU.",
|
||||
)
|
||||
|
||||
device = current_platform.device_type
|
||||
|
||||
|
||||
def reverse_awq_order(t: torch.Tensor):
|
||||
|
||||
@@ -100,32 +100,6 @@ def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) ->
|
||||
assert output.shape == (2, output_size)
|
||||
|
||||
|
||||
def test_kv_cache_scale_name_handling():
|
||||
# Mock a quant config that supports cache scales
|
||||
mock_quant_config = Mock()
|
||||
mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")
|
||||
|
||||
# Condition check in load_weights
|
||||
name = "layers.0.self_attn.k_proj.weight"
|
||||
scale_name = mock_quant_config.get_cache_scale(name)
|
||||
|
||||
# Check if get_cache_scale is called and returns expected value
|
||||
mock_quant_config.get_cache_scale.assert_called_once_with(name)
|
||||
assert scale_name == "layers.0.self_attn.kv_scale"
|
||||
|
||||
|
||||
def test_kv_cache_scale_name_no_scale():
|
||||
# Mock a quant config that returns None for get_cache_scale
|
||||
mock_quant_config = Mock()
|
||||
mock_quant_config.get_cache_scale = Mock(return_value=None)
|
||||
|
||||
name = "layers.0.mlp.gate_proj.weight"
|
||||
scale_name = mock_quant_config.get_cache_scale(name)
|
||||
|
||||
# Should return None for weights that don't have cache scales
|
||||
assert scale_name is None
|
||||
|
||||
|
||||
def test_maybe_remap_kv_scale_name():
|
||||
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
|
||||
|
||||
@@ -183,33 +157,3 @@ def test_eagle3_lm_head_receives_quant_config():
|
||||
assert call_kwargs["quant_config"] is mock_quant_config, (
|
||||
"ParallelLMHead must receive the draft model's quant_config"
|
||||
)
|
||||
|
||||
|
||||
def test_load_weights_kv_scale_handling():
|
||||
kv_scale_param = Mock()
|
||||
kv_scale_param.weight_loader = Mock()
|
||||
|
||||
params_dict = {
|
||||
"layers.0.self_attn.kv_scale": kv_scale_param,
|
||||
}
|
||||
|
||||
mock_quant_config = Mock()
|
||||
mock_quant_config.get_cache_scale = Mock(return_value="layers.0.self_attn.kv_scale")
|
||||
|
||||
# Load_weights logic for KV cache scales
|
||||
name = "layers.0.self_attn.k_proj.weight"
|
||||
loaded_weight_tensor = torch.tensor([1.0, 2.0])
|
||||
|
||||
if mock_quant_config is not None:
|
||||
scale_name = mock_quant_config.get_cache_scale(name)
|
||||
if scale_name:
|
||||
param = params_dict[scale_name]
|
||||
assert param is kv_scale_param
|
||||
weight_to_load = (
|
||||
loaded_weight_tensor
|
||||
if loaded_weight_tensor.dim() == 0
|
||||
else loaded_weight_tensor[0]
|
||||
)
|
||||
|
||||
assert scale_name == "layers.0.self_attn.kv_scale"
|
||||
assert weight_to_load == loaded_weight_tensor[0]
|
||||
|
||||
@@ -14,7 +14,7 @@ from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
# -----------------------------------------------------------------------
|
||||
# Model definitions: (model_name, colbert_dim, extra vllm_runner kwargs)
|
||||
# -----------------------------------------------------------------------
|
||||
COLBERT_MODELS = {
|
||||
COLBERT_MODELS: dict[str, dict] = {
|
||||
"bert": {
|
||||
"model": "answerdotai/answerai-colbert-small-v1",
|
||||
"colbert_dim": 96,
|
||||
|
||||
@@ -0,0 +1,406 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import csv
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import TestFP8Layer
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
|
||||
AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticChannelSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
aiter_available = importlib.util.find_spec("aiter") is not None
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
not (
|
||||
current_platform.is_rocm()
|
||||
and current_platform.supports_fp8()
|
||||
and aiter_available
|
||||
),
|
||||
reason="Requires ROCm + FP8 support + aiter",
|
||||
),
|
||||
pytest.mark.usefixtures("default_vllm_config"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_hipb_mm_kernel(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1")
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1")
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
yield
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
|
||||
def _make_config(
|
||||
*,
|
||||
weight_quant_key=kFp8StaticChannelSym,
|
||||
out_dtype: torch.dtype = torch.bfloat16,
|
||||
weight_shape: tuple[int, int] = (512, 4096),
|
||||
) -> FP8ScaledMMLinearLayerConfig:
|
||||
return FP8ScaledMMLinearLayerConfig(
|
||||
weight_quant_key=weight_quant_key,
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_shape=weight_shape,
|
||||
input_dtype=torch.bfloat16,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
def _find_csv_row(path: str, m: int, n: int, k: int) -> dict | None:
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
|
||||
with open(path, newline="") as f:
|
||||
reader = csv.DictReader(f, skipinitialspace=True)
|
||||
for row in reader:
|
||||
try:
|
||||
if (
|
||||
int(row.get("m", -1)) == m
|
||||
and int(row.get("n", -1)) == n
|
||||
and int(row.get("k", -1)) == k
|
||||
):
|
||||
return dict(row)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _skip_if_no_hipb_mm_solution(exc: RuntimeError) -> None:
|
||||
if "hipblasLtMatmulAlgoGetHeuristic found 0 valid solutions" in str(exc):
|
||||
pytest.skip(
|
||||
"hipb_mm bpreshuffle path has no valid hipBLASLt solution on "
|
||||
"this ROCm stack."
|
||||
)
|
||||
|
||||
|
||||
def _check_bpreshuffle_runtime_support(weight_shape: tuple[int, int], num_tokens: int):
|
||||
import aiter
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device="cuda")
|
||||
w = torch.randn(weight_shape, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
aiter.hipb_create_extension()
|
||||
x_q, x_scale = aiter.pertoken_quant(x, quant_dtype=current_platform.fp8_dtype())
|
||||
w_q, w_scale = aiter.pertoken_quant(w, quant_dtype=current_platform.fp8_dtype())
|
||||
|
||||
try:
|
||||
aiter.hipb_mm(
|
||||
x_q,
|
||||
shuffle_weight(w_q, layout=(16, 16)).t(),
|
||||
solution_index=-1,
|
||||
out_dtype=torch.bfloat16,
|
||||
scaleA=x_scale,
|
||||
scaleB=w_scale.t().contiguous(),
|
||||
scaleOut=None,
|
||||
bpreshuffle=True,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
_skip_if_no_hipb_mm_solution(exc)
|
||||
raise
|
||||
|
||||
|
||||
def test_hipb_mm_kernel_requires_hipbmm_flag(monkeypatch: pytest.MonkeyPatch):
|
||||
# The kernel rejects when `is_hip_fp8bmm_enabled()` is False. That helper
|
||||
# requires AITER + AITER_LINEAR + MI3xx, so dropping AITER_LINEAR exercises
|
||||
# the rejection branch.
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0")
|
||||
monkeypatch.delenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", raising=False)
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
is_supported, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.is_supported()
|
||||
|
||||
assert not is_supported
|
||||
assert reason == (
|
||||
"requires setting `VLLM_ROCM_USE_AITER=1`, "
|
||||
"`VLLM_ROCM_USE_AITER_LINEAR=1`, "
|
||||
"and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`."
|
||||
)
|
||||
|
||||
|
||||
def test_hipb_mm_flag_enables_hip_online_tuning(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
import vllm.envs as envs_mod
|
||||
import vllm.platforms.rocm as rocm_mod
|
||||
|
||||
# The rocm.py gate requires all three AITER flags (and MI3xx) to auto-set
|
||||
# HIP_ONLINE_TUNING.
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1")
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1")
|
||||
|
||||
try:
|
||||
importlib.reload(envs_mod)
|
||||
importlib.reload(rocm_mod)
|
||||
assert envs_mod.VLLM_ROCM_USE_AITER
|
||||
assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR
|
||||
assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
|
||||
assert os.environ.get("HIP_ONLINE_TUNING") == "1"
|
||||
finally:
|
||||
monkeypatch.undo()
|
||||
os.environ.pop("HIP_ONLINE_TUNING", None)
|
||||
importlib.reload(envs_mod)
|
||||
importlib.reload(rocm_mod)
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
|
||||
def test_hipb_mm_kernel_can_implement_success(enable_hipb_mm_kernel):
|
||||
can_implement, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement(
|
||||
_make_config()
|
||||
)
|
||||
|
||||
assert can_implement
|
||||
assert reason is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("config", "expected_reason"),
|
||||
[
|
||||
(
|
||||
_make_config(weight_quant_key=kFp8StaticTensorSym),
|
||||
"requires per token activation scales and per channel weight scales.",
|
||||
),
|
||||
(
|
||||
_make_config(out_dtype=torch.float16),
|
||||
"requires bfloat16 output dtype.",
|
||||
),
|
||||
(
|
||||
_make_config(weight_shape=(8, 4090)),
|
||||
"requires N >= 16 and both N and K divisible by 16, "
|
||||
"received N=8 and K=4090.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_hipb_mm_kernel_can_implement_rejects_unsupported_configs(
|
||||
enable_hipb_mm_kernel,
|
||||
config: FP8ScaledMMLinearLayerConfig,
|
||||
expected_reason: str,
|
||||
):
|
||||
can_implement, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement(
|
||||
config
|
||||
)
|
||||
|
||||
assert not can_implement
|
||||
assert reason == expected_reason
|
||||
|
||||
|
||||
def test_hipb_mm_kernel_process_weights_after_loading_shuffles_weights(
|
||||
enable_hipb_mm_kernel,
|
||||
):
|
||||
weight_shape = (512, 4096)
|
||||
kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel(
|
||||
_make_config(weight_shape=weight_shape),
|
||||
layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"),
|
||||
)
|
||||
|
||||
layer = torch.nn.Module()
|
||||
layer.weight = torch.nn.Parameter(
|
||||
torch.rand(weight_shape, device="cuda").to(current_platform.fp8_dtype()).t(),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
torch.rand((weight_shape[0], 1), dtype=torch.float32, device="cuda"),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.input_scale = None
|
||||
layer.input_scale_ub = None
|
||||
|
||||
original_weight = layer.weight.detach().clone()
|
||||
original_weight_scale = layer.weight_scale.detach().clone()
|
||||
|
||||
kernel.process_weights_after_loading(layer)
|
||||
|
||||
# process_weights_after_loading now pre-applies the transposes that used
|
||||
# to live in _rocm_aiter_hipb_mm_fp8_impl, so the stored weight is the
|
||||
# shuffled tensor with a trailing `.t()` view, and the stored weight scale
|
||||
# is its transposed-contiguous form.
|
||||
expected_weight = rocm_aiter_ops.shuffle_weight(
|
||||
original_weight.t().contiguous()
|
||||
).t()
|
||||
torch.testing.assert_close(layer.weight, expected_weight)
|
||||
|
||||
expected_weight_scale = original_weight_scale.t().contiguous()
|
||||
torch.testing.assert_close(layer.weight_scale, expected_weight_scale)
|
||||
|
||||
|
||||
def test_hipb_mm_kernel_forward_matches_raw_aiter_hipb_mm(enable_hipb_mm_kernel):
|
||||
import aiter
|
||||
|
||||
weight_shape = (512, 4096)
|
||||
_check_bpreshuffle_runtime_support(weight_shape, num_tokens=32)
|
||||
|
||||
layer = TestFP8Layer(
|
||||
weight_shape=weight_shape,
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticChannelSym,
|
||||
input_dtype=torch.bfloat16,
|
||||
out_dtype=torch.bfloat16,
|
||||
device=torch.device("cuda"),
|
||||
force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
|
||||
)
|
||||
|
||||
# hipb_mm uses a transposed-result GEMM internally, so the flattened token
|
||||
# count becomes the effective N dimension passed into hipBLASLt. Keep it
|
||||
# aligned to avoid heuristic failures for tiny N.
|
||||
x = torch.randn(2, 16, weight_shape[1], dtype=torch.bfloat16, device="cuda")
|
||||
bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
try:
|
||||
out = layer(x, bias)
|
||||
except RuntimeError as exc:
|
||||
_skip_if_no_hipb_mm_solution(exc)
|
||||
raise
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
x_q, x_scale = layer.kernel.quant_fp8(
|
||||
x_2d,
|
||||
layer.input_scale,
|
||||
layer.input_scale_ub,
|
||||
)
|
||||
try:
|
||||
# process_weights_after_loading already applies the trailing `.t()` on
|
||||
# the shuffled weight and the `.t().contiguous()` on the weight scale,
|
||||
# so the raw aiter call uses them directly.
|
||||
expected = aiter.hipb_mm(
|
||||
x_q,
|
||||
layer.weight,
|
||||
solution_index=-1,
|
||||
bias=bias,
|
||||
out_dtype=torch.bfloat16,
|
||||
scaleA=x_scale,
|
||||
scaleB=layer.weight_scale,
|
||||
scaleOut=None,
|
||||
bpreshuffle=True,
|
||||
).view(*out.shape)
|
||||
except RuntimeError as exc:
|
||||
_skip_if_no_hipb_mm_solution(exc)
|
||||
raise
|
||||
|
||||
assert isinstance(layer.kernel, AiterHipbMMPerTokenFp8ScaledMMLinearKernel)
|
||||
assert out.shape == (2, 16, weight_shape[0])
|
||||
torch.testing.assert_close(out, expected)
|
||||
|
||||
|
||||
def test_hipb_mm_kernel_forward_accuracy(enable_hipb_mm_kernel):
|
||||
"""Kernel output should match a dequantized fp32 reference within
|
||||
fp8 per-token / per-channel quantization noise."""
|
||||
weight_shape = (512, 4096) # (N, K)
|
||||
num_tokens = 32
|
||||
_check_bpreshuffle_runtime_support(weight_shape, num_tokens=num_tokens)
|
||||
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
fp8_max = torch.finfo(fp8_dtype).max
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Build a bf16 weight and quantize per output channel (one scale per row).
|
||||
w_bf16 = torch.randn(weight_shape, dtype=torch.bfloat16, device=device)
|
||||
w_amax = w_bf16.abs().amax(dim=1, keepdim=True).to(torch.float32)
|
||||
w_scale = (w_amax / fp8_max).clamp(min=1e-12)
|
||||
w_fp8 = (w_bf16.to(torch.float32) / w_scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
|
||||
w_dequant = w_fp8.to(torch.float32) * w_scale
|
||||
|
||||
bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device=device)
|
||||
|
||||
layer = torch.nn.Module()
|
||||
# Pre-`process_weights_after_loading` convention: weight stored as the
|
||||
# `[K, N]` view of the fp8 tensor.
|
||||
layer.weight = torch.nn.Parameter(w_fp8.t(), requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(w_scale, requires_grad=False)
|
||||
layer.input_scale = None
|
||||
layer.input_scale_ub = None
|
||||
|
||||
kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel(
|
||||
_make_config(weight_shape=weight_shape),
|
||||
layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"),
|
||||
)
|
||||
kernel.process_weights_after_loading(layer)
|
||||
|
||||
x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device=device)
|
||||
|
||||
try:
|
||||
out = kernel.apply_weights(layer, x, bias)
|
||||
except RuntimeError as exc:
|
||||
_skip_if_no_hipb_mm_solution(exc)
|
||||
raise
|
||||
|
||||
# Reference: quantize x per-token the same way the kernel does, then run
|
||||
# the matmul in fp32 against the dequantized weight. This isolates plumbing
|
||||
# / reduction bugs from inherent fp8 quantization noise.
|
||||
x_amax = x.abs().amax(dim=1, keepdim=True).to(torch.float32)
|
||||
x_scale_ref = (x_amax / fp8_max).clamp(min=1e-12)
|
||||
x_q = (x.to(torch.float32) / x_scale_ref).clamp(-fp8_max, fp8_max).to(fp8_dtype)
|
||||
x_dequant = x_q.to(torch.float32) * x_scale_ref
|
||||
expected = (x_dequant @ w_dequant.t() + bias.to(torch.float32)).to(torch.bfloat16)
|
||||
|
||||
assert out.shape == (num_tokens, weight_shape[0])
|
||||
# K=4096 fp8 reduction leaves room for accumulation order drift and
|
||||
# catastrophic cancellation on near-zero outputs; tolerances are loose
|
||||
# enough to absorb that but tight enough to catch wrong layouts, missing
|
||||
# bias, swapped scales, etc.
|
||||
torch.testing.assert_close(out, expected, atol=5.0, rtol=0.1)
|
||||
|
||||
|
||||
def test_hipb_mm_kernel_online_tuning_writes_csv(
|
||||
enable_hipb_mm_kernel,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
):
|
||||
weight_shape = (256, 4096)
|
||||
cache_file = tmp_path / "hip_online_tuning_res.csv"
|
||||
|
||||
_check_bpreshuffle_runtime_support(weight_shape, num_tokens=16)
|
||||
|
||||
monkeypatch.setenv("HIP_ONLINE_TUNING", "1")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
layer = TestFP8Layer(
|
||||
weight_shape=weight_shape,
|
||||
activation_quant_key=kFp8DynamicTokenSym,
|
||||
weight_quant_key=kFp8StaticChannelSym,
|
||||
input_dtype=torch.bfloat16,
|
||||
out_dtype=torch.bfloat16,
|
||||
device=torch.device("cuda"),
|
||||
force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel,
|
||||
)
|
||||
|
||||
# The effective heuristic N dimension is the flattened token count.
|
||||
x = torch.randn(16, weight_shape[1], dtype=torch.bfloat16, device="cuda")
|
||||
try:
|
||||
out = layer(x)
|
||||
except RuntimeError as exc:
|
||||
_skip_if_no_hipb_mm_solution(exc)
|
||||
raise
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
assert out.shape == (16, weight_shape[0])
|
||||
assert cache_file.exists()
|
||||
|
||||
# hipb_mm records the internal GEMM dimensions used by hipBLASLt after its
|
||||
# transposed-result transformation.
|
||||
row = _find_csv_row(
|
||||
str(cache_file),
|
||||
m=weight_shape[0],
|
||||
n=x.shape[0],
|
||||
k=weight_shape[1],
|
||||
)
|
||||
assert row is not None
|
||||
@@ -797,11 +797,11 @@ class TestMistralTokenizer:
|
||||
True,
|
||||
(
|
||||
[1, 3, 23325, 2294, 1686, 4, 23325],
|
||||
[1, 3, 22177, 4304, 2662, 4, 22177, 2],
|
||||
[1, 3, 22177, 4304, 2662, 4, 22177],
|
||||
),
|
||||
(
|
||||
"<s>[INST]▁Hello▁world▁![/INST]▁Hello",
|
||||
("<s>[INST]Hello world ![/INST]Hello</s>"),
|
||||
"<s>[INST]Hello world ![/INST]Hello",
|
||||
),
|
||||
),
|
||||
],
|
||||
|
||||
@@ -49,11 +49,13 @@ else
|
||||
KV_EXTRA_CONFIG=''
|
||||
fi
|
||||
|
||||
# Build the kv-transfer-config once
|
||||
# Build the kv-transfer-config for P and D
|
||||
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
|
||||
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
|
||||
KV_CONFIG_P='{"kv_connector":"NixlConnector","kv_role":"kv_producer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
|
||||
KV_CONFIG_D='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
|
||||
else
|
||||
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
|
||||
KV_CONFIG_P="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
|
||||
KV_CONFIG_D="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
|
||||
fi
|
||||
|
||||
# Models to run
|
||||
@@ -159,7 +161,7 @@ run_tests_for_model() {
|
||||
--block-size ${PREFILL_BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
--kv-transfer-config '$KV_CONFIG_P'"
|
||||
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
|
||||
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
|
||||
for arg in "${extra_args[@]}"; do
|
||||
@@ -207,7 +209,7 @@ run_tests_for_model() {
|
||||
--enforce-eager \
|
||||
--block-size ${DECODE_BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
--kv-transfer-config '$KV_CONFIG_D'"
|
||||
if [[ -n "$VLLM_SERVE_EXTRA_ARGS" ]]; then
|
||||
IFS=',' read -r -a extra_args <<< "$VLLM_SERVE_EXTRA_ARGS"
|
||||
for arg in "${extra_args[@]}"; do
|
||||
|
||||
@@ -193,9 +193,11 @@ run_test_for_device() {
|
||||
local kv_device=$1
|
||||
|
||||
if [[ "$kv_device" == "cuda" ]]; then
|
||||
local kv_config='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
|
||||
local kv_config_p='{"kv_connector":"NixlConnector","kv_role":"kv_producer"}'
|
||||
local kv_config_d='{"kv_connector":"NixlConnector","kv_role":"kv_consumer"}'
|
||||
else
|
||||
local kv_config="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"${kv_device}\"}"
|
||||
local kv_config_p="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"${kv_device}\"}"
|
||||
local kv_config_d="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"${kv_device}\"}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
@@ -248,7 +250,7 @@ run_test_for_device() {
|
||||
--block-size ${BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||
--kv-transfer-config "$kv_config" \
|
||||
--kv-transfer-config "$kv_config_p" \
|
||||
--speculative-config "$PREFILL_SPEC_CONFIG" \
|
||||
--attention-backend $ATTENTION_BACKEND \
|
||||
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
|
||||
@@ -287,7 +289,7 @@ run_test_for_device() {
|
||||
--block-size ${BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--tensor-parallel-size $DECODER_TP_SIZE \
|
||||
--kv-transfer-config "$kv_config" \
|
||||
--kv-transfer-config "$kv_config_d" \
|
||||
--speculative-config "$DECODE_SPEC_CONFIG" \
|
||||
--attention-backend $ATTENTION_BACKEND \
|
||||
${EXTRA_SERVE_ARGS[@]+"${EXTRA_SERVE_ARGS[@]}"} &
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorHandshakeMetadata,
|
||||
)
|
||||
from vllm.v1.engine import core as engine_core_module
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
class _Metadata(KVConnectorHandshakeMetadata):
|
||||
pass
|
||||
|
||||
|
||||
class _FakeExecutor:
|
||||
handshake_metadata_src: (
|
||||
list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None
|
||||
)
|
||||
last_instance: "_FakeExecutor | None" = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: Any,
|
||||
) -> None:
|
||||
del vllm_config
|
||||
self.handshake_metadata = self.handshake_metadata_src
|
||||
self.handshake_calls = 0
|
||||
_FakeExecutor.last_instance = self
|
||||
|
||||
def get_kv_connector_handshake_metadata(
|
||||
self,
|
||||
) -> list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None:
|
||||
self.handshake_calls += 1
|
||||
return self.handshake_metadata
|
||||
|
||||
def init_kv_output_aggregator(self, connector: KVConnectorBase_V1) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _run_engine_core_handshake(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
connector: KVConnectorBase_V1,
|
||||
*,
|
||||
handshake_metadata: (
|
||||
list[dict[tuple[int, int], KVConnectorHandshakeMetadata] | None] | None
|
||||
),
|
||||
) -> _FakeExecutor:
|
||||
class _FakeScheduler:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.connector = connector
|
||||
|
||||
def get_kv_connector(self) -> KVConnectorBase_V1:
|
||||
return connector
|
||||
|
||||
_FakeExecutor.handshake_metadata_src = handshake_metadata
|
||||
_FakeExecutor.last_instance = None
|
||||
|
||||
monkeypatch.setattr("vllm.plugins.load_general_plugins", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
engine_core_module.EngineCore,
|
||||
"_initialize_kv_caches",
|
||||
lambda self, vllm_config: SimpleNamespace(kv_cache_groups=[object()]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
engine_core_module,
|
||||
"StructuredOutputManager",
|
||||
lambda vllm_config: object(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
engine_core_module,
|
||||
"resolve_kv_cache_block_sizes",
|
||||
lambda kv_cache_config, vllm_config: (16, 16),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
engine_core_module,
|
||||
"MULTIMODAL_REGISTRY",
|
||||
SimpleNamespace(engine_receiver_cache_from_config=lambda vllm_config: None),
|
||||
)
|
||||
monkeypatch.setattr(engine_core_module, "freeze_gc_heap", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
engine_core_module, "maybe_attach_gc_debug_callback", lambda: None
|
||||
)
|
||||
monkeypatch.setattr(engine_core_module, "enable_envs_cache", lambda: None)
|
||||
monkeypatch.setattr(engine_core_module, "get_hash_fn_by_name", lambda name: None)
|
||||
monkeypatch.setattr(engine_core_module, "init_none_hash", lambda hash_fn: None)
|
||||
monkeypatch.setattr(
|
||||
engine_core_module, "get_request_block_hasher", lambda *args: None
|
||||
)
|
||||
|
||||
vllm_config = SimpleNamespace(
|
||||
parallel_config=SimpleNamespace(data_parallel_rank_local=0),
|
||||
scheduler_config=SimpleNamespace(
|
||||
get_scheduler_cls=lambda: _FakeScheduler,
|
||||
enable_chunked_prefill=False,
|
||||
async_scheduling=False,
|
||||
),
|
||||
speculative_config=None,
|
||||
ec_transfer_config=None,
|
||||
max_concurrent_batches=1,
|
||||
model_config=SimpleNamespace(runner_type="generate"),
|
||||
cache_config=SimpleNamespace(
|
||||
enable_prefix_caching=False,
|
||||
prefix_caching_hash_algo="builtin",
|
||||
),
|
||||
)
|
||||
|
||||
engine_core_module.EngineCore(vllm_config, _FakeExecutor, log_stats=False)
|
||||
assert _FakeExecutor.last_instance is not None
|
||||
return _FakeExecutor.last_instance
|
||||
|
||||
|
||||
class _LegacyConnector(KVConnectorBase_V1):
|
||||
def __init__(self) -> None:
|
||||
self.legacy_metadata: dict[int, KVConnectorHandshakeMetadata] | None = None
|
||||
|
||||
def start_load_kv(self, forward_context: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: Any,
|
||||
attn_metadata: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self) -> None:
|
||||
pass
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: Any, num_computed_tokens: int
|
||||
) -> tuple[int | None, bool]:
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: Any, blocks: Any, num_external_tokens: int
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def build_connector_meta(self, scheduler_output: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_xfer_handshake_metadata(
|
||||
self, metadata: dict[int, KVConnectorHandshakeMetadata]
|
||||
) -> None:
|
||||
self.legacy_metadata = metadata
|
||||
|
||||
|
||||
class _PPAwareConnector(_LegacyConnector):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.pp_aware_metadata: (
|
||||
dict[tuple[int, int], KVConnectorHandshakeMetadata] | None
|
||||
) = None
|
||||
|
||||
def set_xfer_handshake_metadata_pp_aware(
|
||||
self, metadata: dict[tuple[int, int], KVConnectorHandshakeMetadata]
|
||||
) -> None:
|
||||
self.pp_aware_metadata = metadata
|
||||
|
||||
|
||||
def test_engine_unwraps_handshake_metadata_for_legacy_connector(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Engine core always asks workers for `(pp_rank, tp_rank)`-keyed metadata,
|
||||
then unwraps to `{tp_rank: metadata}` for a connector that has not opted
|
||||
into PP-aware handshake (single-PP producer, all `pp_rank == 0`)."""
|
||||
metadata_0 = _Metadata()
|
||||
metadata_1 = _Metadata()
|
||||
connector = _LegacyConnector()
|
||||
|
||||
executor = _run_engine_core_handshake(
|
||||
monkeypatch,
|
||||
connector,
|
||||
handshake_metadata=[
|
||||
{(0, 0): metadata_0},
|
||||
None,
|
||||
{(0, 1): metadata_1},
|
||||
],
|
||||
)
|
||||
|
||||
assert executor.handshake_calls == 1
|
||||
assert connector.legacy_metadata == {0: metadata_0, 1: metadata_1}
|
||||
|
||||
|
||||
def test_engine_rejects_pp_producer_for_legacy_connector(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A connector that has not opted into PP-aware handshake must not silently
|
||||
drop metadata from `pp_rank > 0`; engine core init raises instead."""
|
||||
connector = _LegacyConnector()
|
||||
|
||||
with pytest.raises(ValueError, match="does not support PP-disaggregated"):
|
||||
_run_engine_core_handshake(
|
||||
monkeypatch,
|
||||
connector,
|
||||
handshake_metadata=[{(0, 0): _Metadata()}, {(1, 0): _Metadata()}],
|
||||
)
|
||||
|
||||
|
||||
def test_engine_passes_handshake_metadata_through_for_pp_aware_connector(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""A PP-aware connector receives the full `(pp_rank, tp_rank)`-keyed dict
|
||||
unchanged."""
|
||||
metadata_0 = _Metadata()
|
||||
metadata_1 = _Metadata()
|
||||
connector = _PPAwareConnector()
|
||||
|
||||
executor = _run_engine_core_handshake(
|
||||
monkeypatch,
|
||||
connector,
|
||||
handshake_metadata=[{(0, 0): metadata_0}, {(1, 0): metadata_1}],
|
||||
)
|
||||
|
||||
assert executor.handshake_calls == 1
|
||||
assert connector.legacy_metadata is None
|
||||
assert connector.pp_aware_metadata == {
|
||||
(0, 0): metadata_0,
|
||||
(1, 0): metadata_1,
|
||||
}
|
||||
@@ -261,11 +261,12 @@ def test_multi_example_connector_consistency():
|
||||
storage1_scheduler_events = _ignore_event_collection(events["storage1-SCHEDULER"])
|
||||
storage2_scheduler_events = _ignore_event_collection(events["storage2-SCHEDULER"])
|
||||
# First event is bind_gpu_block_pool from initialization, then
|
||||
# set_xfer_handshake_metadata, then on_new_request when the request is enqueued,
|
||||
# then get_num_new_matched_tokens and update_state_after_alloc from generate().
|
||||
# set_xfer_handshake_metadata_pp_aware, then on_new_request when the request is
|
||||
# enqueued, then get_num_new_matched_tokens and update_state_after_alloc from
|
||||
# generate().
|
||||
assert storage1_scheduler_events[:6] == [
|
||||
"bind_gpu_block_pool",
|
||||
"set_xfer_handshake_metadata",
|
||||
"set_xfer_handshake_metadata_pp_aware",
|
||||
"on_new_request",
|
||||
"get_num_new_matched_tokens 0",
|
||||
"update_state_after_alloc num_blocks=[0] 0",
|
||||
@@ -285,7 +286,7 @@ def test_multi_example_connector_consistency():
|
||||
]
|
||||
assert storage2_scheduler_events[:6] == [
|
||||
"bind_gpu_block_pool",
|
||||
"set_xfer_handshake_metadata",
|
||||
"set_xfer_handshake_metadata_pp_aware",
|
||||
"on_new_request",
|
||||
"get_num_new_matched_tokens 0",
|
||||
"update_state_after_alloc num_blocks=[0] 0",
|
||||
@@ -1057,7 +1058,7 @@ def test_multi_connector_mixed_hma_disables_hybrid_kv_cache(monkeypatch):
|
||||
"connectors": [
|
||||
{
|
||||
"kv_connector": "NixlConnector",
|
||||
"kv_role": "kv_both",
|
||||
"kv_role": "kv_consumer",
|
||||
},
|
||||
{
|
||||
"kv_connector": "MockConnector",
|
||||
|
||||
@@ -1363,7 +1363,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
||||
timeout = 6
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="NixlConnector",
|
||||
kv_role="kv_both",
|
||||
kv_role="kv_consumer",
|
||||
kv_connector_extra_config={"kv_lease_duration": timeout},
|
||||
)
|
||||
llm_kwargs = {
|
||||
@@ -2737,3 +2737,50 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
|
||||
f"got {notif!r} (expected {expected_notif!r}, "
|
||||
f"buggy form would be {bad_notif!r})"
|
||||
)
|
||||
|
||||
|
||||
def test_kv_both_deprecation_warning(default_vllm_config, dist_init):
|
||||
"""kv_role='kv_both' should emit a deprecation log warning."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from vllm.logger import _print_warning_once
|
||||
|
||||
_print_warning_once.cache_clear()
|
||||
|
||||
vllm_config = create_vllm_config(kv_role="kv_both")
|
||||
|
||||
with patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger"
|
||||
) as mock_logger:
|
||||
mock_logger.warning_once = mock_logger.warning_once
|
||||
NixlConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
make_kv_cache_config(block_size=16),
|
||||
)
|
||||
|
||||
mock_logger.warning_once.assert_called_once()
|
||||
msg = mock_logger.warning_once.call_args[0][0]
|
||||
assert "kv_role='kv_both'" in msg
|
||||
assert "deprecated" in msg
|
||||
|
||||
|
||||
def test_explicit_kv_role_no_deprecation_warning(default_vllm_config, dist_init):
|
||||
"""kv_role='kv_consumer' or 'kv_producer' should NOT emit a warning."""
|
||||
from unittest.mock import patch
|
||||
|
||||
for role in ("kv_consumer", "kv_producer"):
|
||||
vllm_config = create_vllm_config(kv_role=role)
|
||||
with patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl.connector.logger"
|
||||
) as mock_logger:
|
||||
NixlConnector(
|
||||
vllm_config,
|
||||
KVConnectorRole.WORKER,
|
||||
make_kv_cache_config(block_size=16),
|
||||
)
|
||||
|
||||
(
|
||||
mock_logger.warning_once.assert_not_called(),
|
||||
(f"kv_role={role!r} should not emit deprecation warning"),
|
||||
)
|
||||
|
||||
@@ -453,7 +453,7 @@ def test_fewer_blocks_with_hma(monkeypatch, model_name, sw_size):
|
||||
"""
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="NixlConnector",
|
||||
kv_role="kv_both",
|
||||
kv_role="kv_consumer",
|
||||
)
|
||||
block_size = 16
|
||||
llm_kwargs = {
|
||||
|
||||
@@ -75,7 +75,7 @@ def test_gpu_memory_rixl_hma(model_name, sw_size):
|
||||
"gpu_memory_utilization": 0.5,
|
||||
"kv_transfer_config": KVTransferConfig(
|
||||
kv_connector="NixlConnector",
|
||||
kv_role="kv_both",
|
||||
kv_role="kv_consumer",
|
||||
),
|
||||
"max_model_len": 2048,
|
||||
"disable_hybrid_kv_cache_manager": False,
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
EngineTransferInfo,
|
||||
TransferTopology,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
class _FakeAttentionBackend:
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
return (2, num_blocks, num_kv_heads, block_size, head_size)
|
||||
|
||||
|
||||
def _make_topology(
|
||||
*,
|
||||
tp_rank: int = 1,
|
||||
tp_size: int = 4,
|
||||
total_num_kv_heads: int = 8,
|
||||
) -> TransferTopology:
|
||||
return TransferTopology(
|
||||
tp_rank=tp_rank,
|
||||
tp_size=tp_size,
|
||||
block_size=16,
|
||||
engine_id="local-engine",
|
||||
is_mla=False,
|
||||
is_mamba=False,
|
||||
total_num_kv_heads=total_num_kv_heads,
|
||||
attn_backends=[_FakeAttentionBackend],
|
||||
)
|
||||
|
||||
|
||||
def test_legacy_register_remote_engine_uses_pp_rank_zero() -> None:
|
||||
topology = _make_topology()
|
||||
info = EngineTransferInfo(
|
||||
remote_tp_size=2,
|
||||
remote_block_len=1024,
|
||||
remote_block_size=16,
|
||||
remote_physical_blocks_per_logical=1,
|
||||
)
|
||||
|
||||
registered = topology.register_remote_engine("remote-engine", info)
|
||||
|
||||
assert registered == info
|
||||
assert registered.remote_pp_rank == 0
|
||||
assert topology.get_engine_info("remote-engine") == info
|
||||
assert topology._engines[("remote-engine", 0)] == info
|
||||
assert topology.target_remote_ranks("remote-engine") == [0]
|
||||
|
||||
|
||||
def test_register_remote_engine_stores_pp_ranks_separately() -> None:
|
||||
topology = _make_topology(tp_rank=0, tp_size=2)
|
||||
|
||||
info_0 = EngineTransferInfo(
|
||||
remote_tp_size=2,
|
||||
remote_block_len=1024,
|
||||
remote_block_size=16,
|
||||
remote_physical_blocks_per_logical=1,
|
||||
remote_pp_rank=0,
|
||||
start_layer=0,
|
||||
end_layer=16,
|
||||
)
|
||||
info_1 = EngineTransferInfo(
|
||||
remote_tp_size=1,
|
||||
remote_block_len=512,
|
||||
remote_block_size=8,
|
||||
remote_physical_blocks_per_logical=2,
|
||||
remote_pp_rank=1,
|
||||
start_layer=16,
|
||||
end_layer=32,
|
||||
)
|
||||
|
||||
registered_0 = topology.register_remote_engine("remote-engine", info_0)
|
||||
registered_1 = topology.register_remote_engine("remote-engine", info_1)
|
||||
|
||||
assert registered_0 == info_0
|
||||
assert registered_1 == info_1
|
||||
assert topology.get_engine_info("remote-engine") == info_0
|
||||
assert topology.get_engine_info("remote-engine", 0) == info_0
|
||||
assert topology.get_engine_info("remote-engine", 1) == info_1
|
||||
assert set(topology._engines) == {
|
||||
("remote-engine", 0),
|
||||
("remote-engine", 1),
|
||||
}
|
||||
|
||||
|
||||
def test_helpers_use_requested_pp_rank() -> None:
|
||||
topology = _make_topology(tp_rank=1, tp_size=2, total_num_kv_heads=2)
|
||||
topology.register_remote_engine(
|
||||
"remote-engine",
|
||||
EngineTransferInfo(
|
||||
remote_tp_size=1,
|
||||
remote_block_len=1024,
|
||||
remote_block_size=16,
|
||||
remote_physical_blocks_per_logical=1,
|
||||
remote_pp_rank=0,
|
||||
start_layer=0,
|
||||
end_layer=8,
|
||||
),
|
||||
)
|
||||
topology.register_remote_engine(
|
||||
"remote-engine",
|
||||
EngineTransferInfo(
|
||||
remote_tp_size=4,
|
||||
remote_block_len=1024,
|
||||
remote_block_size=16,
|
||||
remote_physical_blocks_per_logical=1,
|
||||
remote_pp_rank=1,
|
||||
start_layer=8,
|
||||
end_layer=16,
|
||||
),
|
||||
)
|
||||
|
||||
assert not topology.is_kv_replicated("remote-engine", 0)
|
||||
assert topology.is_kv_replicated("remote-engine", 1)
|
||||
assert topology.replicates_kv_cache("remote-engine", 1)
|
||||
assert topology.target_remote_ranks("remote-engine", 0) == [0]
|
||||
assert topology.target_remote_ranks("remote-engine", 1) == [2, 3]
|
||||
assert "remote_pp=1" in topology.describe("remote-engine", 1)
|
||||
|
||||
|
||||
def test_engine_info_fields_have_backward_compatible_defaults() -> None:
|
||||
topology = _make_topology()
|
||||
info = EngineTransferInfo(
|
||||
remote_tp_size=2,
|
||||
remote_block_len=1024,
|
||||
remote_block_size=16,
|
||||
remote_physical_blocks_per_logical=1,
|
||||
)
|
||||
|
||||
registered = topology.register_remote_engine("remote-engine", info)
|
||||
|
||||
assert topology.get_engine_info("remote-engine") == registered
|
||||
assert registered.remote_pp_rank == 0
|
||||
assert registered.start_layer == 0
|
||||
assert registered.end_layer == 0
|
||||
@@ -103,7 +103,7 @@ def create_vllm_config(
|
||||
kv_load_failure_policy: Literal["recompute", "fail"] = "fail",
|
||||
kv_connector: str = "NixlConnector",
|
||||
kv_connector_module_path: str | None = None,
|
||||
kv_role: str = "kv_both",
|
||||
kv_role: str = "kv_consumer",
|
||||
disable_hybrid_kv_cache_manager: bool | None = None,
|
||||
) -> VllmConfig:
|
||||
"""Initialize VllmConfig For Testing."""
|
||||
|
||||
@@ -0,0 +1,365 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Mock-based unit tests for ObjectStoreSecondaryTierManager.
|
||||
|
||||
These tests replace the NIXL backend with an in-memory mock so they run
|
||||
without S3 credentials or a live object store. They verify the manager's
|
||||
state machine: job submission, transfer completion polling, and lookup.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.v1.kv_offload.base import OffloadKey, ReqContext, make_offload_key
|
||||
from vllm.v1.kv_offload.tiering.base import JobMetadata, JobResult
|
||||
from vllm.v1.kv_offload.tiering.obj.manager import ObjectStoreSecondaryTierManager
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared stubs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_vllm_config():
|
||||
return SimpleNamespace(
|
||||
model_config=SimpleNamespace(model="test/model"),
|
||||
cache_config=SimpleNamespace(block_size=16, cache_dtype="float16"),
|
||||
parallel_config=SimpleNamespace(
|
||||
tensor_parallel_size=1,
|
||||
pipeline_parallel_size=1,
|
||||
prefill_context_parallel_size=1,
|
||||
decode_context_parallel_size=1,
|
||||
rank=0,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_OFFLOADING_SPEC = SimpleNamespace(
|
||||
vllm_config=_make_vllm_config(),
|
||||
kv_cache_config=SimpleNamespace(kv_cache_groups=[]),
|
||||
)
|
||||
|
||||
_STORE_CONFIG = {
|
||||
"bucket": "mock-bucket",
|
||||
"endpoint_override": "mock:9000",
|
||||
"access_key": "mock-access",
|
||||
"secret_key": "mock-secret",
|
||||
}
|
||||
|
||||
_BLOCK_ELEMENTS = 256
|
||||
_DTYPE = torch.float32
|
||||
_RUN_PREFIX = f"test/{uuid.uuid4().hex[:8]}"
|
||||
_CTX = ReqContext(req_id="test-req")
|
||||
|
||||
|
||||
def key(n: int) -> OffloadKey:
|
||||
return make_offload_key(n.to_bytes(8, "big"), 0)
|
||||
|
||||
|
||||
def make_job(
|
||||
job_id: int,
|
||||
keys: list[OffloadKey],
|
||||
block_ids: list[int] | None = None,
|
||||
) -> JobMetadata:
|
||||
if block_ids is None:
|
||||
block_ids = list(range(len(keys)))
|
||||
return JobMetadata(
|
||||
job_id=job_id,
|
||||
keys=keys,
|
||||
block_ids=np.array(block_ids, dtype=np.int64),
|
||||
is_promotion=False,
|
||||
req_context=_CTX,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock NIXL agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MockNixlAgent:
|
||||
"""In-memory NIXL agent. Tracks stored object keys and simulates async
|
||||
transfers: transfer() returns PROC, check_xfer_state() returns DONE and
|
||||
commits the write to the in-memory key set.
|
||||
|
||||
The four methods overridden by tests (register_memory, make_prepped_xfer,
|
||||
check_xfer_state, query_memory) are stored as Callable instance attributes
|
||||
so mypy allows reassignment in tests.
|
||||
"""
|
||||
|
||||
# Callable attributes — tests may reassign these on instances.
|
||||
register_memory: Callable
|
||||
make_prepped_xfer: Callable
|
||||
check_xfer_state: Callable
|
||||
query_memory: Callable
|
||||
|
||||
def __init__(self):
|
||||
self._stored_obj_keys: set[str] = set()
|
||||
# handle_id -> (op, [obj_keys])
|
||||
self._pending: dict[int, tuple[str, list[str]]] = {}
|
||||
self._handle_counter = 0
|
||||
self._last_obj_keys: list[str] = []
|
||||
# Bind default implementations as instance attributes.
|
||||
self.register_memory = self._register_memory
|
||||
self.make_prepped_xfer = self._make_prepped_xfer
|
||||
self.check_xfer_state = self._check_xfer_state
|
||||
self.query_memory = self._query_memory
|
||||
|
||||
def create_backend(self, backend_type, params):
|
||||
pass
|
||||
|
||||
def _register_memory(self, descs, mem_type=None, backends=None):
|
||||
mock = MagicMock()
|
||||
mock.trim.return_value = MagicMock()
|
||||
# Capture obj_keys from OBJ 4-tuples: (addr, len, dev_id, obj_key)
|
||||
if mem_type == "OBJ" and descs:
|
||||
self._last_obj_keys = [d[3] for d in descs if d[3]]
|
||||
return mock
|
||||
|
||||
def deregister_memory(self, desc):
|
||||
pass
|
||||
|
||||
def prep_xfer_dlist(self, agent_name, descs, mem_type=None, backends=None):
|
||||
return MagicMock()
|
||||
|
||||
def _make_prepped_xfer(
|
||||
self,
|
||||
op,
|
||||
local_handle,
|
||||
local_indices,
|
||||
remote_handle,
|
||||
remote_indices,
|
||||
notif_msg=b"",
|
||||
backends=None,
|
||||
skip_desc_merge=False,
|
||||
):
|
||||
handle = MagicMock()
|
||||
handle._id = self._handle_counter
|
||||
self._pending[self._handle_counter] = (op, list(self._last_obj_keys))
|
||||
self._handle_counter += 1
|
||||
return handle
|
||||
|
||||
def transfer(self, handle):
|
||||
return "PROC"
|
||||
|
||||
def _check_xfer_state(self, handle):
|
||||
entry = self._pending.pop(handle._id, None)
|
||||
if entry:
|
||||
op, obj_keys = entry
|
||||
if op == "WRITE":
|
||||
self._stored_obj_keys.update(obj_keys)
|
||||
return "DONE"
|
||||
|
||||
def release_xfer_handle(self, handle):
|
||||
pass
|
||||
|
||||
def release_dlist_handle(self, handle):
|
||||
pass
|
||||
|
||||
def _query_memory(self, queries, mem_type, agent_name):
|
||||
return [object() if q[3] in self._stored_obj_keys else None for q in queries]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tier(
|
||||
num_blocks: int = 4,
|
||||
) -> tuple[ObjectStoreSecondaryTierManager, MockNixlAgent]:
|
||||
"""Create a tier backed by a fresh MockNixlAgent."""
|
||||
mock_agent = MockNixlAgent()
|
||||
tensor = torch.zeros((num_blocks, _BLOCK_ELEMENTS), dtype=_DTYPE)
|
||||
view = memoryview(tensor.numpy())
|
||||
with (
|
||||
patch("vllm.v1.kv_offload.tiering.obj.manager.nixl_agent_config"),
|
||||
patch(
|
||||
"vllm.v1.kv_offload.tiering.obj.manager.nixl_agent",
|
||||
return_value=mock_agent,
|
||||
),
|
||||
):
|
||||
tier = ObjectStoreSecondaryTierManager(
|
||||
offloading_spec=_OFFLOADING_SPEC,
|
||||
primary_kv_view=view,
|
||||
tier_type="obj",
|
||||
store_config=_STORE_CONFIG,
|
||||
prefix=_RUN_PREFIX,
|
||||
)
|
||||
return tier, mock_agent
|
||||
|
||||
|
||||
def drain(
|
||||
tier: ObjectStoreSecondaryTierManager, max_rounds: int = 20
|
||||
) -> list[JobResult]:
|
||||
"""Poll get_finished_jobs() until all in-flight jobs resolve."""
|
||||
results: list[JobResult] = []
|
||||
for _ in range(max_rounds):
|
||||
results.extend(tier.get_finished_jobs())
|
||||
if not tier._transfers:
|
||||
break
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockObjTierBasic:
|
||||
def setup_method(self):
|
||||
self.tier, self.agent = _make_tier(num_blocks=4)
|
||||
|
||||
def test_lookup_empty_tier(self):
|
||||
assert self.tier.lookup(key(1), _CTX) is False
|
||||
|
||||
def test_store_and_lookup(self):
|
||||
self.tier.submit_store(make_job(1, [key(1)], [0]))
|
||||
results = drain(self.tier)
|
||||
assert len(results) == 1
|
||||
assert results[0].success
|
||||
assert self.tier.lookup(key(1), _CTX) is True
|
||||
|
||||
def test_lookup_unrelated_key_returns_false(self):
|
||||
self.tier.submit_store(make_job(1, [key(1)], [0]))
|
||||
drain(self.tier)
|
||||
assert self.tier.lookup(key(999), _CTX) is False
|
||||
|
||||
def test_store_then_load_roundtrip(self):
|
||||
self.tier.submit_store(make_job(1, [key(1), key(2)], [0, 1]))
|
||||
results = drain(self.tier)
|
||||
assert results[0].success
|
||||
|
||||
self.tier.submit_load(make_job(2, [key(1), key(2)], [0, 1]))
|
||||
results = drain(self.tier)
|
||||
assert len(results) == 1
|
||||
assert results[0].success
|
||||
|
||||
def test_multiple_jobs_tracked_independently(self):
|
||||
self.tier.submit_store(make_job(1, [key(1)], [0]))
|
||||
self.tier.submit_store(make_job(2, [key(2)], [1]))
|
||||
results = drain(self.tier)
|
||||
assert len(results) == 2
|
||||
assert all(r.success for r in results)
|
||||
|
||||
def test_failed_transfer_reported(self):
|
||||
self.agent.check_xfer_state = lambda h: "ERR"
|
||||
self.tier.submit_store(make_job(1, [key(1)], [0]))
|
||||
results = drain(self.tier)
|
||||
assert len(results) == 1
|
||||
assert not results[0].success
|
||||
|
||||
def test_pending_transfer_not_returned_until_done(self):
|
||||
# First poll returns PROC; second poll returns DONE.
|
||||
call_count = [0]
|
||||
original = self.agent.check_xfer_state
|
||||
|
||||
def delayed(h):
|
||||
call_count[0] += 1
|
||||
return "PROC" if call_count[0] == 1 else original(h)
|
||||
|
||||
self.agent.check_xfer_state = delayed
|
||||
|
||||
self.tier.submit_store(make_job(1, [key(1)], [0]))
|
||||
assert list(self.tier.get_finished_jobs()) == []
|
||||
results = list(self.tier.get_finished_jobs())
|
||||
assert len(results) == 1
|
||||
assert results[0].success
|
||||
|
||||
|
||||
class TestMockObjTierMultiBlock:
|
||||
def test_store_multiple_blocks(self):
|
||||
tier, _ = _make_tier(num_blocks=8)
|
||||
keys = [key(i) for i in range(8)]
|
||||
tier.submit_store(make_job(1, keys, list(range(8))))
|
||||
results = drain(tier)
|
||||
assert len(results) == 1
|
||||
assert results[0].success
|
||||
assert all(tier.lookup(k, _CTX) for k in keys)
|
||||
|
||||
def test_partial_block_lookup(self):
|
||||
tier, _ = _make_tier(num_blocks=4)
|
||||
tier.submit_store(make_job(1, [key(0), key(1)], [0, 1]))
|
||||
drain(tier)
|
||||
assert tier.lookup(key(0), _CTX) is True
|
||||
assert tier.lookup(key(1), _CTX) is True
|
||||
assert tier.lookup(key(2), _CTX) is False
|
||||
|
||||
|
||||
class TestMockObjTierFailures:
|
||||
def test_lookup_exception_returns_false(self):
|
||||
tier, agent = _make_tier(num_blocks=4)
|
||||
agent.query_memory = lambda *a, **k: (_ for _ in ()).throw(
|
||||
RuntimeError("backend error")
|
||||
)
|
||||
assert tier.lookup(key(1), _CTX) is False
|
||||
|
||||
def test_submit_store_register_memory_failure_reported_in_get_finished(self):
|
||||
tier, agent = _make_tier(num_blocks=4)
|
||||
agent.register_memory = lambda *a, **k: None
|
||||
tier.submit_store(make_job(1, [key(1)], [0]))
|
||||
results = list(tier.get_finished_jobs())
|
||||
assert len(results) == 1
|
||||
assert results[0].job_id == 1
|
||||
assert not results[0].success
|
||||
|
||||
def test_submit_load_register_memory_failure_reported_in_get_finished(self):
|
||||
tier, agent = _make_tier(num_blocks=4)
|
||||
agent.register_memory = lambda *a, **k: None
|
||||
tier.submit_load(make_job(2, [key(1)], [0]))
|
||||
results = list(tier.get_finished_jobs())
|
||||
assert len(results) == 1
|
||||
assert results[0].job_id == 2
|
||||
assert not results[0].success
|
||||
|
||||
def test_submit_store_make_prepped_xfer_failure_reported_in_get_finished(self):
|
||||
tier, agent = _make_tier(num_blocks=4)
|
||||
agent.make_prepped_xfer = lambda *a, **k: None
|
||||
tier.submit_store(make_job(3, [key(1)], [0]))
|
||||
results = list(tier.get_finished_jobs())
|
||||
assert len(results) == 1
|
||||
assert results[0].job_id == 3
|
||||
assert not results[0].success
|
||||
|
||||
def test_failure_and_success_both_returned_by_get_finished(self):
|
||||
# One job fails at submission, another succeeds in flight.
|
||||
tier, agent = _make_tier(num_blocks=4)
|
||||
original_register = agent.register_memory
|
||||
call_count = [0]
|
||||
|
||||
def register_once_fail(*a, **k):
|
||||
call_count[0] += 1
|
||||
return None if call_count[0] == 1 else original_register(*a, **k)
|
||||
|
||||
agent.register_memory = register_once_fail
|
||||
|
||||
tier.submit_store(make_job(1, [key(1)], [0])) # fails immediately
|
||||
tier.submit_store(make_job(2, [key(2)], [1])) # succeeds
|
||||
results = drain(tier)
|
||||
assert len(results) == 2
|
||||
by_id = {r.job_id: r for r in results}
|
||||
assert not by_id[1].success
|
||||
assert by_id[2].success
|
||||
|
||||
|
||||
class TestMockObjTierShutdown:
|
||||
def test_shutdown_clears_in_flight_transfers(self):
|
||||
tier, agent = _make_tier(num_blocks=4)
|
||||
# Keep transfer in flight by never completing it
|
||||
agent.check_xfer_state = lambda h: "PROC"
|
||||
tier.submit_store(make_job(1, [key(1)], [0]))
|
||||
assert len(tier._transfers) == 1
|
||||
tier.shutdown()
|
||||
assert len(tier._transfers) == 0
|
||||
assert tier._dram_prepped_handle is None
|
||||
assert tier._primary_reg is None
|
||||
|
||||
def test_shutdown_idempotent(self):
|
||||
tier, _ = _make_tier(num_blocks=4)
|
||||
tier.shutdown()
|
||||
tier.shutdown() # must not raise
|
||||
@@ -544,13 +544,15 @@ def native_sample_recovered_tokens(
|
||||
target_probs: torch.Tensor, # [num_tokens, vocab_size]
|
||||
sampling_metadata: SamplingMetadata,
|
||||
device: torch.device,
|
||||
use_fp64_gumbel: bool = False,
|
||||
) -> torch.Tensor:
|
||||
batch_size = len(num_draft_tokens)
|
||||
vocab_size = target_probs.shape[-1]
|
||||
q_dtype = torch.float64 if use_fp64_gumbel else torch.float32
|
||||
|
||||
q = torch.empty(
|
||||
(batch_size, vocab_size),
|
||||
dtype=torch.float32,
|
||||
dtype=q_dtype,
|
||||
device=device,
|
||||
)
|
||||
q.exponential_()
|
||||
@@ -935,6 +937,67 @@ def test_sample_recovered_tokens(
|
||||
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
|
||||
|
||||
|
||||
def test_sample_recovered_tokens_uses_fp64_exponential_race_when_requested():
|
||||
batch_size = 2
|
||||
vocab_size = 64
|
||||
max_spec_len = 2
|
||||
num_tokens = batch_size * max_spec_len
|
||||
|
||||
draft_probs = torch.rand(
|
||||
num_tokens,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
draft_probs = F.softmax(draft_probs, dim=-1)
|
||||
target_probs = torch.rand(
|
||||
num_tokens,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
target_probs = F.softmax(target_probs, dim=-1)
|
||||
draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
|
||||
|
||||
generators = {
|
||||
i: torch.Generator(device=DEVICE_TYPE).manual_seed(i) for i in range(batch_size)
|
||||
}
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False,
|
||||
temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
|
||||
generators=generators,
|
||||
)
|
||||
spec_decode_metadata = create_spec_decode_metadata(
|
||||
draft_token_ids.reshape(batch_size, max_spec_len).tolist(),
|
||||
target_probs.log(),
|
||||
)
|
||||
|
||||
expected = native_sample_recovered_tokens(
|
||||
max_spec_len,
|
||||
spec_decode_metadata.num_draft_tokens,
|
||||
spec_decode_metadata.cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device=torch.device(DEVICE_TYPE),
|
||||
use_fp64_gumbel=True,
|
||||
)
|
||||
actual = sample_recovered_tokens(
|
||||
max_spec_len,
|
||||
spec_decode_metadata.num_draft_tokens,
|
||||
spec_decode_metadata.cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device=torch.device(DEVICE_TYPE),
|
||||
use_fp64_gumbel=True,
|
||||
)
|
||||
|
||||
assert torch.equal(actual, expected)
|
||||
|
||||
|
||||
########################### Tests for Synthetic Rejection Sampling #########
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,12 @@ from torch import Generator
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import (
|
||||
apply_top_k_top_p_pytorch,
|
||||
random_sample,
|
||||
)
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
@@ -38,6 +43,10 @@ def _flashinfer_topk_topp_supported() -> bool:
|
||||
FLASHINFER_TOPK_TOPP_SUPPORTED = _flashinfer_topk_topp_supported()
|
||||
|
||||
|
||||
def _seed_default_generator(seed: int) -> None:
|
||||
set_random_seed(seed)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_default_device():
|
||||
"""
|
||||
@@ -49,6 +58,35 @@ def reset_default_device():
|
||||
torch.set_default_device(original_device)
|
||||
|
||||
|
||||
def test_sampler_threads_fp64_gumbel_to_topk_topp_sampler():
|
||||
sampler = Sampler(use_fp64_gumbel=True)
|
||||
|
||||
assert sampler.topk_topp_sampler.use_fp64_gumbel
|
||||
|
||||
|
||||
def test_random_sample_uses_fp64_exponential_race_when_requested():
|
||||
torch.set_default_device(DEVICE_TYPE)
|
||||
probs = torch.tensor(
|
||||
[
|
||||
[0.70, 0.20, 0.10],
|
||||
[0.05, 0.15, 0.80],
|
||||
[0.25, 0.25, 0.50],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
|
||||
_seed_default_generator(12345)
|
||||
q = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
|
||||
q.exponential_()
|
||||
expected = q.reciprocal_().mul_(probs).argmax(dim=-1).view(-1)
|
||||
|
||||
_seed_default_generator(12345)
|
||||
actual = random_sample(probs.clone(), {}, use_fp64_gumbel=True)
|
||||
|
||||
assert torch.equal(actual, expected)
|
||||
|
||||
|
||||
def test_topk_impl_equivalence():
|
||||
torch.set_default_device(DEVICE_TYPE)
|
||||
generator = Generator(device=DEVICE_TYPE).manual_seed(33)
|
||||
|
||||
@@ -1034,7 +1034,8 @@ def test_propose_stores_probabilistic_draft_probs(monkeypatch):
|
||||
proposer.model = model_mock
|
||||
proposer._draft_attn_layer_names = {"layer.0"}
|
||||
|
||||
def fake_compute_probs(logits, sampling_metadata):
|
||||
def fake_compute_probs(logits, sampling_metadata, use_fp64_gumbel):
|
||||
assert use_fp64_gumbel == proposer.use_fp64_gumbel
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
return probs.argmax(dim=-1), probs
|
||||
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.llm_base_proposer import (
|
||||
compute_probs_and_sample_next_token,
|
||||
)
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
|
||||
def _seed_default_generator(seed: int) -> None:
|
||||
set_random_seed(seed)
|
||||
|
||||
|
||||
def _make_sampling_metadata(batch_size: int) -> SamplingMetadata:
|
||||
return SamplingMetadata(
|
||||
temperature=torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE),
|
||||
all_greedy=False,
|
||||
all_random=True,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
generators={},
|
||||
max_num_logprobs=None,
|
||||
no_penalties=True,
|
||||
prompt_token_ids=None,
|
||||
frequency_penalties=torch.empty(0, device=DEVICE_TYPE),
|
||||
presence_penalties=torch.empty(0, device=DEVICE_TYPE),
|
||||
repetition_penalties=torch.empty(0, device=DEVICE_TYPE),
|
||||
output_token_ids=[[] for _ in range(batch_size)],
|
||||
spec_token_ids=[[] for _ in range(batch_size)],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
logitsprocs=LogitsProcessors(),
|
||||
)
|
||||
|
||||
|
||||
def test_compute_probs_and_sample_next_token_uses_fp64_exponential_race():
|
||||
batch_size = 4
|
||||
vocab_size = 32
|
||||
generator = torch.Generator(device=DEVICE_TYPE).manual_seed(11)
|
||||
logits = torch.randn(
|
||||
batch_size,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE_TYPE,
|
||||
generator=generator,
|
||||
)
|
||||
metadata = _make_sampling_metadata(batch_size)
|
||||
|
||||
_seed_default_generator(12345)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty(probs.shape, dtype=torch.float64, device=probs.device)
|
||||
q.exponential_()
|
||||
expected_ids = q.reciprocal_().mul_(probs).argmax(dim=-1).view(-1)
|
||||
|
||||
_seed_default_generator(12345)
|
||||
actual_ids, actual_probs = compute_probs_and_sample_next_token(
|
||||
logits.clone(),
|
||||
metadata,
|
||||
use_fp64_gumbel=True,
|
||||
)
|
||||
|
||||
assert torch.equal(actual_ids, expected_ids)
|
||||
assert torch.allclose(actual_probs, probs)
|
||||
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CUDA proof for fp32 exponential-race tail truncation.
|
||||
|
||||
This script is intentionally not a unit test. It is a reproducible, GPU-only
|
||||
statistical proof for the hidden Gumbel-max idiom:
|
||||
|
||||
q.exponential_()
|
||||
sample = (probs / q).argmax()
|
||||
|
||||
For q ~ Exp(1), this is equivalent to argmax(log(probs) + Gumbel). On CUDA,
|
||||
fp32 exponential samples inherit a 24-bit uniform lower-tail cutoff, so very
|
||||
small q values are impossible. The many-tail experiment below chooses a case
|
||||
where a correct sampler should select a low-probability tail token dozens of
|
||||
times, while fp32 q cannot select one.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _seed(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def measure_exponential_lower_tail(
|
||||
*,
|
||||
device: torch.device,
|
||||
samples: int,
|
||||
chunk_size: int,
|
||||
seed: int,
|
||||
) -> None:
|
||||
threshold = 2.0**-24
|
||||
print(f"lower-tail threshold: {threshold:.18e}")
|
||||
for dtype in (torch.float32, torch.float64):
|
||||
_seed(seed)
|
||||
count_below = 0
|
||||
min_q = float("inf")
|
||||
max_q = 0.0
|
||||
start = time.perf_counter()
|
||||
remaining = samples
|
||||
while remaining > 0:
|
||||
n = min(chunk_size, remaining)
|
||||
q = torch.empty((n,), dtype=dtype, device=device)
|
||||
q.exponential_()
|
||||
count_below += int((q < threshold).sum().item())
|
||||
min_q = min(min_q, float(q.min().item()))
|
||||
max_q = max(max_q, float(q.max().item()))
|
||||
remaining -= n
|
||||
torch.accelerator.synchronize()
|
||||
elapsed = time.perf_counter() - start
|
||||
print(
|
||||
f"{dtype}: samples={samples} count(q < 2^-24)={count_below} "
|
||||
f"min={min_q:.18e} max={max_q:.6f} elapsed={elapsed:.2f}s"
|
||||
)
|
||||
|
||||
|
||||
def run_many_tail_race(
|
||||
*,
|
||||
device: torch.device,
|
||||
trials: int,
|
||||
num_tail_tokens: int,
|
||||
gap: float,
|
||||
chunk_trials: int,
|
||||
seed: int,
|
||||
) -> None:
|
||||
p_tail = math.exp(-gap)
|
||||
expected_tail_hits = (
|
||||
trials * (num_tail_tokens * p_tail) / (1.0 + num_tail_tokens * p_tail)
|
||||
)
|
||||
print(
|
||||
"many-tail race: "
|
||||
f"trials={trials} num_tail_tokens={num_tail_tokens} gap={gap} "
|
||||
f"expected_tail_hits={expected_tail_hits:.4f}"
|
||||
)
|
||||
|
||||
for dtype in (torch.float32, torch.float64):
|
||||
_seed(seed)
|
||||
hits = 0
|
||||
p0 = torch.tensor(1.0, dtype=dtype, device=device)
|
||||
pt = torch.tensor(p_tail, dtype=dtype, device=device)
|
||||
start = time.perf_counter()
|
||||
remaining = trials
|
||||
while remaining > 0:
|
||||
batch = min(chunk_trials, remaining)
|
||||
q0 = torch.empty((batch,), dtype=dtype, device=device)
|
||||
q0.exponential_()
|
||||
qt = torch.empty((batch, num_tail_tokens), dtype=dtype, device=device)
|
||||
qt.exponential_()
|
||||
head_score = p0 / q0
|
||||
tail_score = (pt / qt).amax(dim=-1)
|
||||
hits += int((tail_score > head_score).sum().item())
|
||||
remaining -= batch
|
||||
torch.accelerator.synchronize()
|
||||
elapsed = time.perf_counter() - start
|
||||
print(f"{dtype}: tail_hits={hits} elapsed={elapsed:.2f}s")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--lower-tail-samples", type=int, default=200_000_000)
|
||||
parser.add_argument("--lower-tail-chunk-size", type=int, default=10_000_000)
|
||||
parser.add_argument("--race-trials", type=int, default=100_000)
|
||||
parser.add_argument("--race-tail-tokens", type=int, default=262_144)
|
||||
parser.add_argument("--race-gap", type=float, default=20.5)
|
||||
parser.add_argument("--race-chunk-trials", type=int, default=64)
|
||||
parser.add_argument("--seed", type=int, default=2026)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not torch.accelerator.is_available():
|
||||
raise RuntimeError("CUDA is required for this proof.")
|
||||
|
||||
device = torch.accelerator.current_accelerator()
|
||||
if device.type != "cuda":
|
||||
raise RuntimeError("CUDA is required for this proof.")
|
||||
|
||||
print(f"torch={torch.__version__} cuda={torch.version.cuda}")
|
||||
print(f"device={device}")
|
||||
measure_exponential_lower_tail(
|
||||
device=device,
|
||||
samples=args.lower_tail_samples,
|
||||
chunk_size=args.lower_tail_chunk_size,
|
||||
seed=args.seed,
|
||||
)
|
||||
run_many_tail_race(
|
||||
device=device,
|
||||
trials=args.race_trials,
|
||||
num_tail_tokens=args.race_tail_tokens,
|
||||
gap=args.race_gap,
|
||||
chunk_trials=args.race_chunk_trials,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -8,11 +8,9 @@ on files that have been changed. It groups files into different mypy calls
|
||||
based on their directory to avoid import following issues.
|
||||
|
||||
Usage:
|
||||
python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>
|
||||
python tools/pre_commit/mypy.py <python_version> <changed_files...>
|
||||
|
||||
Args:
|
||||
ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to
|
||||
"silent" for the main group of files.
|
||||
python_version: Python version to use (e.g., "3.10") or "local" to use
|
||||
the local Python version.
|
||||
changed_files: List of changed files to check.
|
||||
@@ -98,8 +96,8 @@ def mypy(
|
||||
|
||||
|
||||
def main():
|
||||
python_version = sys.argv[2]
|
||||
file_groups = group_files(sys.argv[3:])
|
||||
python_version = sys.argv[1]
|
||||
file_groups = group_files(sys.argv[2:])
|
||||
|
||||
if python_version == "local":
|
||||
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||
|
||||
@@ -27,6 +27,16 @@ except ImportError:
|
||||
# on ROCm the fp8_dtype always calls is_fp8_fnuz
|
||||
# which is a host op, so we cache it once here.
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
_HIPB_MM_INITIALIZED_DEVICES: set[int] = set()
|
||||
|
||||
|
||||
def _ensure_hipb_mm_extension_initialized() -> None:
|
||||
import aiter
|
||||
|
||||
device = torch.accelerator.current_device_index()
|
||||
if device not in _HIPB_MM_INITIALIZED_DEVICES:
|
||||
aiter.hipb_create_extension()
|
||||
_HIPB_MM_INITIALIZED_DEVICES.add(device)
|
||||
|
||||
|
||||
def is_aiter_found() -> bool:
|
||||
@@ -625,6 +635,43 @@ def _rocm_aiter_preshuffled_per_token_w8a8_gemm_fake(
|
||||
return torch.empty(m, n, dtype=output_dtype, device=A.device)
|
||||
|
||||
|
||||
def _rocm_aiter_hipb_mm_fp8_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
from aiter import hipb_mm
|
||||
|
||||
_ensure_hipb_mm_extension_initialized()
|
||||
return hipb_mm(
|
||||
A,
|
||||
B,
|
||||
solution_index=-1,
|
||||
bias=bias,
|
||||
out_dtype=output_dtype,
|
||||
scaleA=As,
|
||||
scaleB=Bs,
|
||||
scaleOut=None,
|
||||
bpreshuffle=True,
|
||||
)
|
||||
|
||||
|
||||
def _rocm_aiter_hipb_mm_fp8_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
m = A.shape[0]
|
||||
n = B.shape[1]
|
||||
return torch.empty(m, n, dtype=output_dtype, device=A.device)
|
||||
|
||||
|
||||
def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
@@ -1308,6 +1355,7 @@ class rocm_aiter_ops:
|
||||
# TODO: Consolidate under _LINEAR_ENABLED
|
||||
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||
_FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
|
||||
_LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
|
||||
# TODO: Consolidate under _LINEAR_ENABLED
|
||||
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||
# TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE
|
||||
@@ -1340,6 +1388,7 @@ class rocm_aiter_ops:
|
||||
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||
cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM
|
||||
cls._LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM
|
||||
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||
cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
|
||||
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
|
||||
@@ -1512,6 +1561,13 @@ class rocm_aiter_ops:
|
||||
|
||||
return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950()
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_linear_hipbmm_enabled(cls) -> bool:
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
return cls.is_linear_enabled() and on_mi3xx() and cls._LINEAR_HIPBMM_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
|
||||
@@ -1668,6 +1724,12 @@ class rocm_aiter_ops:
|
||||
fake_impl=_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_hipb_mm_fp8",
|
||||
op_func=_rocm_aiter_hipb_mm_fp8_impl,
|
||||
fake_impl=_rocm_aiter_hipb_mm_fp8_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_triton_gemm_a8w8_blockscale",
|
||||
op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl,
|
||||
@@ -1858,6 +1920,17 @@ class rocm_aiter_ops:
|
||||
A, B, As, Bs, bias, output_dtype
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def hipb_mm_fp8(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.rocm_aiter_hipb_mm_fp8(A, B, As, Bs, bias, output_dtype)
|
||||
|
||||
@staticmethod
|
||||
def triton_gemm_a8w8_blockscale(
|
||||
A: torch.Tensor,
|
||||
|
||||
+9
-5
@@ -868,9 +868,10 @@ def cutlass_scaled_mm_azp(
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
:param azp_adj: In the per-tensor case, this should include the azp.
|
||||
Args:
|
||||
azp_adj: In the per-tensor case, this should include the azp.
|
||||
Always per-channel.
|
||||
:param azp: Only set in the per-token case. Per-token if set.
|
||||
azp: Only set in the per-token case. Per-token if set.
|
||||
"""
|
||||
assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
|
||||
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
|
||||
@@ -3886,9 +3887,12 @@ def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
|
||||
Note that sylvester hadamard transforms are also symmetric, which means that
|
||||
this function is also applies the (transpose <=> inverse) transform.
|
||||
|
||||
:param x: value to be transformed inplace
|
||||
:param inplace: modify value in place
|
||||
:return: value after transformation
|
||||
Args:
|
||||
x: value to be transformed inplace
|
||||
inplace: modify value in place
|
||||
|
||||
Returns:
|
||||
value after transformation
|
||||
"""
|
||||
return torch.ops._C.hadacore_transform(x, inplace)
|
||||
|
||||
|
||||
@@ -82,11 +82,12 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
|
||||
def hash_source(*srcs: str | Any) -> str:
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
|
||||
Args:
|
||||
srcs: strings or objects to add to the hash.
|
||||
Objects and functions have their source inspected.
|
||||
Results are cached by resolved types to avoid repeated
|
||||
inspect.getsource() calls.
|
||||
:return:
|
||||
"""
|
||||
# Resolve instances to their class for a hashable cache key.
|
||||
cache_key = tuple(
|
||||
@@ -99,7 +100,9 @@ class InductorPass(CustomGraphPass): # type: ignore[misc]
|
||||
def hash_dict(dict_: dict[Any, Any]) -> str:
|
||||
"""
|
||||
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
|
||||
Returns:
|
||||
A sha256 hash of the json rep of the dictionary.
|
||||
"""
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
@@ -276,8 +276,10 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
"""
|
||||
Replace mutated getitem users of the auto-functionalized node with the
|
||||
mutated arguments.
|
||||
:param node: The auto-functionalized node
|
||||
:param mutated_args: The mutated arguments, indexed by getitem index.
|
||||
|
||||
Args:
|
||||
node: The auto-functionalized node
|
||||
mutated_args: The mutated arguments, indexed by getitem index.
|
||||
If the value of an arg is a string, `node.kwargs[arg]` is used.
|
||||
"""
|
||||
for idx, user in self.getitem_users(node).items():
|
||||
@@ -317,9 +319,10 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
as node.kwargs cannot be used.
|
||||
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
|
||||
|
||||
:param graph: Graph to insert the defunctionalized node into
|
||||
:param node: The auto-functionalized node to defunctionalize
|
||||
:param args: If we cannot use kwargs, specify args directly.
|
||||
Args:
|
||||
graph: Graph to insert the defunctionalized node into
|
||||
node: The auto-functionalized node to defunctionalize
|
||||
args: If we cannot use kwargs, specify args directly.
|
||||
If an arg is a string, `node.kwargs[arg]` is used.
|
||||
""" # noqa: E501
|
||||
assert is_func(node, auto_functionalized), (
|
||||
|
||||
@@ -108,9 +108,13 @@ class NoOpEliminationPass(VllmInductorPass):
|
||||
def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
|
||||
"""
|
||||
This function checks if two dimensions are equivalent.
|
||||
:param dim: The dimension arg to reshape/slice
|
||||
:param i_dim: The corresponding dimension in the input tensor
|
||||
:return: Are the dimensions equivalent?
|
||||
|
||||
Args:
|
||||
dim: The dimension arg to reshape/slice
|
||||
i_dim: The corresponding dimension in the input tensor
|
||||
|
||||
Returns:
|
||||
Are the dimensions equivalent?
|
||||
|
||||
There are two cases in which the dimensions are equivalent:
|
||||
1. The dimensions are equal (both integers)
|
||||
|
||||
@@ -235,9 +235,10 @@ class ModelConfig:
|
||||
temperature and top_k/top_p.
|
||||
"""
|
||||
use_fp64_gumbel: bool = False
|
||||
"""Whether to use FP64 (instead of FP32) for the Gumbel noise used by the
|
||||
sampler. FP64 reduces the chance of ties in Gumbel-max sampling at the cost
|
||||
of significantly lower kernel throughput on most GPUs."""
|
||||
"""Whether to use FP64 (instead of FP32) random noise for Gumbel-max and
|
||||
equivalent exponential-race sampling. FP64 preserves lower-tail sampling
|
||||
events that fp32 uniform/exponential draws can truncate, at the cost of
|
||||
significantly lower throughput on most GPUs."""
|
||||
disable_sliding_window: bool = False
|
||||
"""Whether to disable sliding window. If True, we will disable the sliding
|
||||
window functionality of the model, capping to sliding window size. If the
|
||||
|
||||
@@ -180,7 +180,8 @@ class CuMemAllocator:
|
||||
All data in the memory allocation with the specified tag will be
|
||||
offloaded to CPU memory, and others will be discarded.
|
||||
|
||||
:param offload_tags: The tags of the memory allocation that will be
|
||||
Args:
|
||||
offload_tags: The tags of the memory allocation that will be
|
||||
offloaded. The rest of the memory allocation will be discarded.
|
||||
"""
|
||||
if offload_tags is None:
|
||||
@@ -230,7 +231,8 @@ class CuMemAllocator:
|
||||
All data that is previously offloaded will be loaded back to GPU
|
||||
memory, and the rest of the data will have empty memory.
|
||||
|
||||
:param tags: The tags of the memory allocation that will be loaded
|
||||
Args:
|
||||
tags: The tags of the memory allocation that will be loaded
|
||||
back to GPU memory. If None, all memory allocation will be loaded
|
||||
back to GPU memory.
|
||||
"""
|
||||
@@ -255,7 +257,8 @@ class CuMemAllocator:
|
||||
All memory allocation created inside the context will be allocated
|
||||
in the memory pool, and has the specified tag.
|
||||
|
||||
:param tag: The tag of the memory allocation. If None, the default tag
|
||||
Args:
|
||||
tag: The tag of the memory allocation. If None, the default tag
|
||||
will be used.
|
||||
"""
|
||||
if tag is None:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user