From 4e2eba28beec9972445c338e8ad2080b3cab3246 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Fri, 22 May 2026 18:23:08 -0400 Subject: [PATCH] [Perf] Optimize hidden state extraction logic (#37374) Signed-off-by: Benjamin Chislett Signed-off-by: Benjamin Chislett Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../benchmark_hidden_state_extraction.py | 415 ++++++++++++++++++ docs/features/speculative_decoding/README.md | 1 + .../extract_hidden_states.md | 86 ++++ .../extract_hidden_states_offline.py | 41 +- .../test_extraction.py | 2 + .../spec_decode/test_extract_hidden_states.py | 2 + vllm/config/vllm.py | 11 + .../v1/example_hidden_states_connector.py | 250 +++++++++-- vllm/v1/spec_decode/extract_hidden_states.py | 29 +- 9 files changed, 772 insertions(+), 65 deletions(-) create mode 100644 benchmarks/benchmark_hidden_state_extraction.py create mode 100644 docs/features/speculative_decoding/extract_hidden_states.md diff --git a/benchmarks/benchmark_hidden_state_extraction.py b/benchmarks/benchmark_hidden_state_extraction.py new file mode 100644 index 00000000000..6056fcdd072 --- /dev/null +++ b/benchmarks/benchmark_hidden_state_extraction.py @@ -0,0 +1,415 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark hidden state extraction throughput. + +Measures two modes: + 1. Baseline: bulk inference with max_tokens=1, no extraction. + 2. Extract: async hidden state extraction via ExampleHiddenStatesConnector + with N concurrent clients, each consuming hidden states as + soon as their request finishes (overlapping I/O with generation). + +Reports tokens/s and prompts/s for each mode. + +Usage: + python benchmarks/benchmark_hidden_state_extraction.py \ + --model Qwen/Qwen3-0.6B \ + --num-prompts 64 \ + --num-clients 8 \ + --prompt-len 8192 \ + --layers 1 2 3 4 +""" + +import argparse +import asyncio +import time +from concurrent.futures import ThreadPoolExecutor + +import torch +from transformers import AutoConfig + +from vllm import LLM, SamplingParams +from vllm.config.kv_transfer import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + example_hidden_states_connector, +) +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.sampling_params import RequestOutputKind +from vllm.v1.engine.async_llm import AsyncLLM + + +def _make_profiler_config(profile_dir: str) -> dict: + """Build a profiler_config dict for torch profiling.""" + return { + "profiler": "torch", + "torch_profiler_dir": profile_dir, + "torch_profiler_with_stack": True, + } + + +def make_random_prompts( + num_prompts: int, prompt_len: int, vocab_size: int, seed: int = 42 +) -> list[list[int]]: + """Generate lists of random token IDs.""" + # Set seed for reproducibility + torch.manual_seed(seed) + return [ + torch.randint(0, vocab_size, (prompt_len,)).tolist() for _ in range(num_prompts) + ] + + +def consume_hidden_states(path: str) -> float: + """Load hidden states from disk and compute per-position mean. + + Returns a single float: the grand mean of all hidden state values. + This forces the benchmark to actually read and reduce the data. + + Uses :func:`load_hidden_states` which acquires a shared flock, + blocking (without polling) until the async writer releases its + exclusive lock. + """ + obj = example_hidden_states_connector.load_hidden_states(path) + hs = obj["hidden_states"] + total = hs.mean().item() + + example_hidden_states_connector.cleanup_hidden_states(path) + + return total + + +def run_baseline( + model: str, + prompts: list[list[int]], + extra_args: dict, + profile_dir: str | None = None, +) -> dict: + """Baseline: bulk inference, no hidden state extraction.""" + if profile_dir: + extra_args = { + **extra_args, + "profiler_config": _make_profiler_config(profile_dir), + } + llm = LLM( + model=model, + enable_prefix_caching=False, + enable_chunked_prefill=False, + **extra_args, + ) + sampling_params = SamplingParams(max_tokens=1) + prompt_inputs = [{"prompt_token_ids": p} for p in prompts] + + # Warmup + llm.generate(prompt_inputs[:4], sampling_params, use_tqdm=False) + + if profile_dir: + llm.start_profile() + + t0 = time.perf_counter() + outputs = llm.generate(prompt_inputs, sampling_params, use_tqdm=True) + elapsed = time.perf_counter() - t0 + + if profile_dir: + llm.stop_profile() + + total_prompt_tokens = sum(len(o.prompt_token_ids) for o in outputs) + num_prompts = len(outputs) + + del llm + torch.accelerator.empty_cache() + + return { + "mode": "baseline", + "elapsed_s": elapsed, + "num_prompts": num_prompts, + "total_prompt_tokens": total_prompt_tokens, + "tokens_per_s": total_prompt_tokens / elapsed, + "prompts_per_s": num_prompts / elapsed, + } + + +# ---- Async extraction benchmark ---- + + +async def _client_loop( + engine: AsyncLLM, + prompt_queue: asyncio.Queue, + consume_pool: ThreadPoolExecutor, + results: list[dict], + client_id: int, +): + """A single async client: pulls prompts, submits to engine, consumes + hidden states as soon as each request finishes.""" + loop = asyncio.get_event_loop() + while True: + item = await prompt_queue.get() + if item is None: + prompt_queue.task_done() + break + idx, token_ids = item + + request_id = f"req-{idx}" + sampling_params = SamplingParams( + max_tokens=1, + output_kind=RequestOutputKind.FINAL_ONLY, + ) + + final_output = None + async for output in engine.generate( + request_id=request_id, + prompt={"prompt_token_ids": token_ids}, + sampling_params=sampling_params, + ): + if output.finished: + final_output = output + + # Consume hidden states on a thread (disk I/O) + path = final_output.kv_transfer_params["hidden_states_path"] + mean_val = await loop.run_in_executor(consume_pool, consume_hidden_states, path) + num_tokens = len(final_output.prompt_token_ids) + + results.append( + { + "request_id": request_id, + "num_prompt_tokens": num_tokens, + "mean_hidden_value": mean_val, + } + ) + prompt_queue.task_done() + + +async def _run_extraction_async( + model: str, + prompts: list[list[int]], + num_clients: int, + layers: list[int], + tmpdir: str, + extra_args: dict, + profile_dir: str | None = None, +) -> dict: + if profile_dir: + extra_args = { + **extra_args, + "profiler_config": _make_profiler_config(profile_dir), + } + engine_args = AsyncEngineArgs( + model=model, + enable_prefix_caching=False, + enable_chunked_prefill=False, + max_num_batched_tokens=40960, + max_model_len=40960, + speculative_config={ + "method": "extract_hidden_states", + "num_speculative_tokens": 1, + "draft_model_config": { + "hf_config": { + "eagle_aux_hidden_state_layer_ids": layers, + }, + }, + }, + kv_transfer_config=KVTransferConfig( + kv_connector="ExampleHiddenStatesConnector", + kv_role="kv_producer", + kv_connector_extra_config={ + "shared_storage_path": tmpdir, + }, + ), + **extra_args, + ) + engine = AsyncLLM.from_engine_args(engine_args) + + try: + # Warmup: run a few prompts sequentially, cleaning up generated files + for i in range(min(4, len(prompts))): + sp = SamplingParams(max_tokens=1, output_kind=RequestOutputKind.FINAL_ONLY) + final_output = None + async for output in engine.generate( + request_id=f"warmup-{i}", + prompt={"prompt_token_ids": prompts[i]}, + sampling_params=sp, + ): + if output.finished: + final_output = output + if final_output and final_output.kv_transfer_params: + path = final_output.kv_transfer_params.get("hidden_states_path") + if path: + example_hidden_states_connector.cleanup_hidden_states(path) + + if profile_dir: + await engine.start_profile() + + # Fill prompt queue + prompt_queue: asyncio.Queue = asyncio.Queue() + for idx, token_ids in enumerate(prompts): + prompt_queue.put_nowait((idx, token_ids)) + # Sentinel per client + for _ in range(num_clients): + prompt_queue.put_nowait(None) + + results: list[dict] = [] + consume_pool = ThreadPoolExecutor(max_workers=num_clients) + + t0 = time.perf_counter() + tasks = [ + asyncio.create_task( + _client_loop(engine, prompt_queue, consume_pool, results, i) + ) + for i in range(num_clients) + ] + await asyncio.gather(*tasks) + elapsed = time.perf_counter() - t0 + + consume_pool.shutdown(wait=True) + + if profile_dir: + await engine.stop_profile() + + total_prompt_tokens = sum(r["num_prompt_tokens"] for r in results) + num_prompts = len(results) + mean_hidden = sum(r["mean_hidden_value"] for r in results) / max( + len(results), 1 + ) + + return { + "mode": "extract", + "elapsed_s": elapsed, + "num_prompts": num_prompts, + "total_prompt_tokens": total_prompt_tokens, + "tokens_per_s": total_prompt_tokens / elapsed, + "prompts_per_s": num_prompts / elapsed, + "mean_hidden_value": mean_hidden, + } + finally: + engine.shutdown() + + +def run_extraction( + model: str, + prompts: list[list[int]], + num_clients: int, + layers: list[int], + extra_args: dict, + profile_dir: str | None = None, +) -> dict: + return asyncio.run( + _run_extraction_async( + model, + prompts, + num_clients, + layers, + "/dev/shm", + extra_args, + profile_dir=profile_dir, + ) + ) + + +def print_results(results: dict): + mode = results["mode"] + print(f"\n{'=' * 60}") + print(f" {mode.upper()} RESULTS") + print(f"{'=' * 60}") + print(f" Prompts: {results['num_prompts']}") + print(f" Total prompt tokens: {results['total_prompt_tokens']:,}") + print(f" Wall time: {results['elapsed_s']:.2f}s") + print(f" Tokens/s: {results['tokens_per_s']:,.0f}") + print(f" Prompts/s: {results['prompts_per_s']:.2f}") + if mode == "extract": + print(f" Mean hidden value: {results['mean_hidden_value']:.6f}") + print(f"{'=' * 60}\n") + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark hidden state extraction throughput" + ) + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--num-prompts", type=int, default=64) + parser.add_argument("--num-clients", type=int, default=8) + parser.add_argument("--prompt-len", type=int, default=8192) + parser.add_argument("--layers", type=int, nargs="+", default=[1, 2, 3, 4]) + parser.add_argument("--skip-baseline", action="store_true") + parser.add_argument("--skip-extract", action="store_true") + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) + parser.add_argument("--max-num-batched-tokens", type=int, default=None) + parser.add_argument("--max-cudagraph-capture-size", type=int, default=None) + parser.add_argument("--max-model-len", type=int, default=None) + parser.add_argument("--enforce-eager", action="store_true") + parser.add_argument("--load-format", type=str, default=None) + parser.add_argument( + "--profile", + action="store_true", + help="Enable torch profiler for both baseline and extraction runs.", + ) + parser.add_argument( + "--torch-profiler-dir", + type=str, + default="./vllm_profile", + help="Directory to save torch profiler traces (default: ./vllm_profile).", + ) + parser.add_argument( + "--enable-flashinfer-autotune", + action="store_true", + default=False, + help="Enable FlashInfer autotuning (can be slow).", + ) + args = parser.parse_args() + + extra_args = { + "gpu_memory_utilization": args.gpu_memory_utilization, + } + if args.max_model_len is not None: + extra_args["max_model_len"] = args.max_model_len + if args.max_num_batched_tokens is not None: + extra_args["max_num_batched_tokens"] = args.max_num_batched_tokens + if args.max_model_len and args.max_num_batched_tokens < args.max_model_len: + raise ValueError( + "max_num_batched_tokens must be >= max_model_len since chunked prefill" + " is not supported by hidden state extraction." + ) + if args.enforce_eager: + extra_args["enforce_eager"] = True + if args.load_format is not None: + extra_args["load_format"] = args.load_format + if args.max_cudagraph_capture_size is not None: + extra_args["max_cudagraph_capture_size"] = args.max_cudagraph_capture_size + extra_args["enable_flashinfer_autotune"] = args.enable_flashinfer_autotune + + # Get vocab size from HF config without loading the full model + hf_config = AutoConfig.from_pretrained(args.model, trust_remote_code=True) + vocab_size = hf_config.vocab_size + prompts = make_random_prompts(args.num_prompts, args.prompt_len, vocab_size) + print( + f"Generated {args.num_prompts} prompts, " + f"{args.prompt_len} tokens each (vocab {vocab_size})" + ) + + profile_dir = args.torch_profiler_dir if args.profile else None + if profile_dir: + print(f"Torch profiler enabled, traces will be saved to {profile_dir}/") + + if not args.skip_baseline: + baseline_profile_dir = f"{profile_dir}/baseline" if profile_dir else None + baseline = run_baseline( + args.model, prompts, extra_args, profile_dir=baseline_profile_dir + ) + print_results(baseline) + + if not args.skip_extract: + extract_profile_dir = f"{profile_dir}/extract" if profile_dir else None + extract = run_extraction( + args.model, + prompts, + args.num_clients, + args.layers, + extra_args, + profile_dir=extract_profile_dir, + ) + print_results(extract) + + if not args.skip_baseline and not args.skip_extract: + slowdown = baseline["tokens_per_s"] / extract["tokens_per_s"] + print("Extraction slowdown factor: {:.2f}x".format(slowdown)) + + +if __name__ == "__main__": + main() diff --git a/docs/features/speculative_decoding/README.md b/docs/features/speculative_decoding/README.md index 056a5e96a99..768e9f78d40 100644 --- a/docs/features/speculative_decoding/README.md +++ b/docs/features/speculative_decoding/README.md @@ -15,6 +15,7 @@ vLLM supports a variety of methods of speculative decoding. Model-based methods - [Multi-Layer Perceptron](mlp.md) - [N-Gram](n_gram.md) - [Suffix Decoding](suffix.md) +- [Hidden State Extraction](extract_hidden_states.md) - [Custom Proposer Backend (Experimental)](#custom-proposer-backend-experimental) ## Method Selection at a Glance diff --git a/docs/features/speculative_decoding/extract_hidden_states.md b/docs/features/speculative_decoding/extract_hidden_states.md new file mode 100644 index 00000000000..2184a71f489 --- /dev/null +++ b/docs/features/speculative_decoding/extract_hidden_states.md @@ -0,0 +1,86 @@ +# Hidden State Extraction + +The Hidden State Extraction feature allows vLLM to save intermediate layer activations from a target model during inference. This is useful for training [EAGLE](eagle.md)-style draft models, knowledge distillation, or offline analysis of model internals. + +!!! note + It is possible to save the last-layer's output hidden states by passing `num_hidden_layers` as a layer id. Note that these are _not_ normalized using the output norm. + +## Offline Example + +```python +import tempfile + +from vllm import LLM, SamplingParams +from vllm.config.kv_transfer import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + example_hidden_states_connector, +) + +with tempfile.TemporaryDirectory() as tmpdir: + llm = LLM( + model="Qwen/Qwen3-8B", + enable_chunked_prefill=False, + speculative_config={ + "method": "extract_hidden_states", + "num_speculative_tokens": 1, + "draft_model_config": { + "hf_config": { + "eagle_aux_hidden_state_layer_ids": [1, 2, 3, 4], + }, + }, + }, + kv_transfer_config=KVTransferConfig( + kv_connector="ExampleHiddenStatesConnector", + kv_role="kv_producer", + kv_connector_extra_config={ + "shared_storage_path": tmpdir, + }, + ), + ) + + outputs = llm.generate( + ["The future of AI is"], + SamplingParams(max_tokens=1), + ) + + for output in outputs: + path = output.kv_transfer_params["hidden_states_path"] + obj = example_hidden_states_connector.load_hidden_states(path) + print(f"token_ids: {obj['token_ids'].shape}") + print(f"hidden_states: {obj['hidden_states'].shape}") +``` + +A complete example is available at [`examples/features/speculative_decoding/extract_hidden_states_offline.py`](../../../examples/features/speculative_decoding/extract_hidden_states_offline.py). + +## Online Example + +For improved performance, it is recommended to use a RAM-mounted file system such as `/dev/shm/` for online usage in which the client cleans up the files soon after they are generated. + +```bash +vllm serve Qwen/Qwen3-8B \ + --speculative_config '{"method": "extract_hidden_states", "num_speculative_tokens": 1, "draft_model_config": {"hf_config": {"eagle_aux_hidden_state_layer_ids": [1, 2, 3, 4]}}}' \ + --kv_transfer_config '{"kv_connector": "ExampleHiddenStatesConnector", "kv_role": "kv_producer", "kv_connector_extra_config": {"shared_storage_path": "/dev/shm/hidden_states"}}' \ + --no-enable-chunked-prefill +``` + +## Configuration + +The `kv_connector_extra_config` dict accepts these options: + +| Parameter | Default | Description | +| --- | --- | --- | +| `shared_storage_path` | `/tmp` | Directory where hidden state files are saved | +| `num_writer_threads` | `8` | Thread pool size for async disk writes | +| `use_synchronization_lock` | `True` | Use file locks so concurrent readers block until writes complete. Can be disabled for batch generation where synchronization is not needed. | + +## Output Format + +Each request produces a `.safetensors` file containing: + +- **`hidden_states`** — shape `[num_tokens, num_extracted_layers, hidden_size]` +- **`token_ids`** — shape `[num_tokens]` + +The file path is returned in `output.kv_transfer_params["hidden_states_path"]`. Use `load_hidden_states()` from the connector module to read the file with proper synchronization. + +!!! note + Chunked prefill is not compatible with this feature and must be disabled. diff --git a/examples/features/speculative_decoding/extract_hidden_states_offline.py b/examples/features/speculative_decoding/extract_hidden_states_offline.py index 551f137614f..f8909566f40 100644 --- a/examples/features/speculative_decoding/extract_hidden_states_offline.py +++ b/examples/features/speculative_decoding/extract_hidden_states_offline.py @@ -2,9 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import tempfile -from safetensors import safe_open - from vllm import LLM, SamplingParams +from vllm.config.kv_transfer import KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + example_hidden_states_connector, +) + +# NOTE: If changing the interface of the ExampleHiddenStatesConnector, please also +# update the benchmark in benchmarks/benchmark_hidden_state_extraction.py +# and the docs in docs/features/speculative_decoding/extract_hidden_states.md # Example: Using the custom "extract_hidden_states" speculator method and # ExampleHiddenStatesConnector to extract and save hidden states from vllm @@ -12,6 +18,7 @@ from vllm import LLM, SamplingParams with tempfile.TemporaryDirectory() as tmpdirname: llm = LLM( model="Qwen/Qwen3-8B", # Your target model + enable_chunked_prefill=False, # required speculative_config={ "method": "extract_hidden_states", "num_speculative_tokens": 1, @@ -23,16 +30,16 @@ with tempfile.TemporaryDirectory() as tmpdirname: 3, 4, ], - } + }, }, }, - kv_transfer_config={ - "kv_connector": "ExampleHiddenStatesConnector", - "kv_role": "kv_producer", - "kv_connector_extra_config": { + kv_transfer_config=KVTransferConfig( + kv_connector="ExampleHiddenStatesConnector", + kv_role="kv_producer", + kv_connector_extra_config={ "shared_storage_path": tmpdirname, }, - }, + ), ) prompts = ["Generate a sentence with hidden states", "Write a python function"] @@ -47,12 +54,14 @@ with tempfile.TemporaryDirectory() as tmpdirname: assert hidden_states_path is not None print("Prompt hidden states path:", hidden_states_path) - with safe_open(hidden_states_path, "pt") as f: - token_ids = f.get_tensor("token_ids") - hidden_states = f.get_tensor("hidden_states") + obj = example_hidden_states_connector.load_hidden_states(hidden_states_path) + token_ids = obj["token_ids"] + hidden_states = obj["hidden_states"] - print("Extracted token ids:", token_ids) # Matches prompt token ids - print( - "Extracted hidden states shape:", hidden_states.shape - ) # [prompt len, num_hidden_layers, hidden size] - print("Extracted hidden states:", hidden_states) + print("Extracted token ids:", token_ids) # Matches prompt token ids + print( + "Extracted hidden states shape:", hidden_states.shape + ) # [prompt_len, num_extracted_layers, hidden_size] + print("Extracted hidden states:", hidden_states) + + example_hidden_states_connector.cleanup_hidden_states(hidden_states_path) diff --git a/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py b/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py index 1426beb9d06..5cc19247f51 100644 --- a/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py +++ b/tests/v1/kv_connector/extract_hidden_states_integration/test_extraction.py @@ -120,6 +120,7 @@ def test_extract_hidden_states_with_predictable_dummy_model( }, max_model_len=128, enforce_eager=True, + enable_chunked_prefill=False, trust_remote_code=True, load_format="dummy", # Don't try to load real weights ) @@ -184,6 +185,7 @@ def test_extract_hidden_states_qwen35_hybrid_smoke(tmp_path): }, max_model_len=256, enforce_eager=True, + enable_chunked_prefill=False, gpu_memory_utilization=0.4, load_format="dummy", ) diff --git a/tests/v1/spec_decode/test_extract_hidden_states.py b/tests/v1/spec_decode/test_extract_hidden_states.py index 2a67257b091..b568d0b204f 100644 --- a/tests/v1/spec_decode/test_extract_hidden_states.py +++ b/tests/v1/spec_decode/test_extract_hidden_states.py @@ -69,6 +69,7 @@ def _create_proposer( scheduler_config=SchedulerConfig( max_model_len=model_config.max_model_len, is_encoder_decoder=model_config.is_encoder_decoder, + enable_chunked_prefill=False, ), attention_config=AttentionConfig(), ) @@ -119,6 +120,7 @@ def test_proposer_initialization_missing_layer_ids(): scheduler_config=SchedulerConfig( max_model_len=model_config.max_model_len, is_encoder_decoder=model_config.is_encoder_decoder, + enable_chunked_prefill=False, ), attention_config=AttentionConfig(), ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 2c020e9cffd..f009dd6f154 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -734,6 +734,17 @@ class VllmConfig: Right now, this function reads the offloading settings from CacheConfig and configures the KVTransferConfig accordingly. """ + # Check if KV connector requires chunked prefill to be disabled. + if ( + self.kv_transfer_config is not None + and self.kv_transfer_config.kv_connector == "ExampleHiddenStatesConnector" + and self.scheduler_config.enable_chunked_prefill + ): + raise ValueError( + "ExampleHiddenStatesConnector does not support chunked prefill. " + "Please disable chunked prefill (--no-enable-chunked-prefill)." + ) + # KV offloading is only activated when kv_offloading_size is set. if (kv_offloading_size := self.cache_config.kv_offloading_size) is None: return diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py index 98b03c4ebb1..3e4e6750858 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py @@ -1,11 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import fcntl import os +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass, field +from functools import partial +from importlib.metadata import version from typing import TYPE_CHECKING, Any -import safetensors import torch +from packaging.version import Version +from safetensors.torch import load_file, save_file from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorRole, SupportsHMA, ) +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput @@ -36,6 +42,39 @@ def extract_from_kv_cache( return kv_cache[slot_mapping // block_size, slot_mapping % block_size][:num_tokens] +def load_hidden_states(path: str) -> dict[str, torch.Tensor]: + """Load hidden states written by ExampleHiddenStatesConnector. + + Blocks (without polling) until the async write is complete by + acquiring a shared flock on the companion lock file. The kernel + puts the caller to sleep until the writer releases its exclusive lock. + + Args: + path: The file path returned in kv_transfer_params["hidden_states_path"]. + + Returns: + Dict with "hidden_states" and "token_ids" tensors. + """ + lock_path = path + ".lock" + with open(lock_path) as lf: + fcntl.flock(lf, fcntl.LOCK_SH) # sleeps until writer releases LOCK_EX + data = load_file(path, device="cpu") + return data + + +def cleanup_hidden_states(path: str, keep_hidden_states: bool = False) -> None: + """Clean up hidden states file and lock file after loading. + + If keep_hidden_states is True, only removes the lock file + and keeps the hidden states file. + """ + lock_path = path + ".lock" + if os.path.exists(lock_path): + os.remove(lock_path) + if not keep_hidden_states and os.path.exists(path): + os.remove(path) + + @dataclass class ReqMeta: # Request ID @@ -112,6 +151,13 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA): logger.info(self._kv_transfer_config) logger.info("Shared storage path is %s", self._storage_path) + if Version(version("safetensors")) < Version("0.8.0"): + logger.warning( + "safetensors < 0.8.0 holds the GIL during save_file, which " + "serializes the writer thread pool and hurts throughput. " + "Upgrade to safetensors >= 0.8.0 for better performance." + ) + assert self._vllm_config.speculative_config is not None, ( "ExampleHiddenStatesConnector only works when using " "'extract_hidden_states' speculative method" @@ -125,17 +171,97 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA): self._active_requests: dict[str, NewRequestData] = {} self._req_blocks: dict[str, list[int]] = {} + # Async write infrastructure (worker-side). + # Dedicated CUDA stream for DtoH copies so they don't block + # the default stream (model forward). Thread pool for disk writes. + self._copy_stream: torch.cuda.Stream | None = None # lazy init + self._executor = ThreadPoolExecutor( + max_workers=self._kv_transfer_config.get_from_extra_config( + "num_writer_threads", 8 + ), + thread_name_prefix="vllm-hs-save", + ) + # Whether to use a filesystem lock when writing files to shared storage. + # This is necessary for online transfer clients to avoid incomplete reads, + # but can be disabled for offline tasks that run tasks in batches to completion + self.use_lock = self._kv_transfer_config.get_from_extra_config( + "use_synchronization_lock", True + ) + # (tensors_dict, copy_done_event, filename, req_id) queued by + # save_kv_layer, submitted to thread pool by wait_for_save. + self._pending_copies: list[ + tuple[dict[str, torch.Tensor], torch.cuda.Event, str, str] + ] = [] + # req_id → in-flight disk-write Future for that req_id. + self._req_futures: dict[str, Future] = {} + # req_id → CUDA event marking completion of the DtoH copy. Once + # this event is complete the request is considered "done sending" + # by get_finished; clients block on the per-file flock to wait for + # the disk write itself. + self._req_copy_events: dict[str, torch.cuda.Event] = {} + # req_ids reported as finished-generating by the scheduler, + # accumulated across get_finished calls. + self._accumulated_finished_req_ids: set[str] = set() + + def _get_copy_stream(self) -> torch.cuda.Stream: + """Lazily create the copy stream (CUDA must be initialized).""" + if self._copy_stream is None: + self._copy_stream = torch.cuda.Stream() + return self._copy_stream + # ============================== # Worker-side methods # ============================== def start_load_kv(self, *args, **kwargs: Any) -> None: - pass # Empty implementation of abstract method + pass # Store-only connector — nothing to load def wait_for_layer_load(self, layer_name: str) -> None: - pass # Empty implementation of abstract method + pass # Store-only connector — nothing to load def wait_for_save(self): - pass # Empty implementation of abstract method + """Submit pending async copies to the thread pool for disk write. + + For each pending write we acquire an exclusive flock on a + companion ``.lock`` file **before** submitting to the thread pool. + The thread worker releases the lock after the data file is fully + written. Clients call :func:`load_hidden_states` which takes a + shared flock — the kernel sleeps the client until the writer is + done. Because ``wait_for_save`` runs before the worker returns + output to the scheduler, the lock file is guaranteed to exist + (and be held) by the time the client receives the path. + + The lock can be disabled via the "use_synchronization_lock" extra config. + """ + for tensors, event, filename, req_id in self._pending_copies: + prior = self._req_futures.get(req_id) + assert prior is None, "Found another KV transfer request with same req_id!" + + lock_fd = None + if self.use_lock: + # Create/open the lock file and acquire an exclusive lock. + # The lock is held by this fd; the thread worker will close + # the fd after writing, which releases the lock. + lock_path = filename + ".lock" + lock_fd = os.open( + lock_path, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o644 + ) + fcntl.flock(lock_fd, fcntl.LOCK_EX) + + future = self._executor.submit( + self._write_tensors, tensors, event, filename, lock_fd + ) + self._req_copy_events[req_id] = event + self._req_futures[req_id] = future + future.add_done_callback(partial(self._on_write_done, req_id)) + self._pending_copies.clear() + + def _on_write_done(self, req_id: str, future: Future) -> None: + """Surface any exception from the disk-write thread and drop the + completed future from the in-flight tracking dict.""" + self._req_futures.pop(req_id, None) + exc = future.exception() + if exc is not None: + logger.error("Hidden-states write failed for req_id=%s: %r", req_id, exc) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): from vllm.model_executor.models.extract_hidden_states import ( @@ -151,6 +277,26 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA): f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}" ) + @staticmethod + def _write_tensors( + tensors: dict[str, torch.Tensor], + event: torch.cuda.Event, + filename: str, + lock_fd: int | None, + ) -> None: + """Thread worker: wait for async DtoH copy, write to disk, release lock. + + ``lock_fd`` is an open file descriptor on the companion ``.lock`` + file with ``LOCK_EX`` already held. Closing it releases the lock, + which unblocks any client sleeping on ``LOCK_SH``. + """ + try: + event.synchronize() + save_file(tensors, filename) + finally: + if lock_fd is not None: + os.close(lock_fd) # releases LOCK_EX + def save_kv_layer( self, layer_name: str, @@ -161,6 +307,10 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA): """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. + Launches an async DtoH copy on a dedicated CUDA stream. The + actual disk write is deferred to wait_for_save() which submits + it to a thread pool. + Args: layer_name (str): the name of the layer. kv_layer (torch.Tensor): the paged KV buffer of the current @@ -184,21 +334,46 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA): os.makedirs(self._storage_path, exist_ok=True) - slot_mapping = attn_metadata.slot_mapping + copy_stream = self._get_copy_stream() + + # Ensure the copy stream sees all prior writes on the default stream. + ready_event = torch.cuda.Event() + ready_event.record() + copy_stream.wait_event(ready_event) + + slot_mapping = get_forward_context().slot_mapping[layer_name] # type: ignore offset = 0 for request in connector_metadata.requests: num_tokens = request.token_ids.shape[0] - req_slot_mapping = slot_mapping[offset : offset + num_tokens] - offset += num_tokens + with torch.cuda.stream(copy_stream): + req_slot_mapping_gpu = slot_mapping[offset : offset + num_tokens] + assert req_slot_mapping_gpu.device == kv_layer.device + offset += num_tokens - hidden_states = extract_from_kv_cache( - kv_layer, req_slot_mapping, num_tokens + hidden_states_gpu = extract_from_kv_cache( + kv_layer, req_slot_mapping_gpu, num_tokens + ) + # Async DtoH copy into pinned host memory. + pinned_hs = torch.empty_like( + hidden_states_gpu, device="cpu", pin_memory=True + ) + pinned_hs.copy_(hidden_states_gpu, non_blocking=True) + + # Record completion of this copy on the copy stream. + copy_done = torch.cuda.Event() + copy_done.record(copy_stream) + + # token_ids is already on CPU (created in ReqMeta.make_meta). + assert not request.token_ids.is_cuda, ( + "Expected token_ids on CPU, got CUDA tensor" ) tensors = { - "hidden_states": hidden_states.detach().cpu(), - "token_ids": request.token_ids.detach().cpu(), + "hidden_states": pinned_hs, + "token_ids": request.token_ids.clone(), } - safetensors.torch.save_file(tensors, request.filename) + self._pending_copies.append( + (tensors, copy_done, request.filename, request.req_id) + ) # ============================== # Scheduler-side methods @@ -258,31 +433,6 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA): self._active_requests[new_req.req_id] = new_req self._req_blocks[new_req.req_id] = list(new_req.block_ids[0]) - cached_reqs = scheduler_output.scheduled_cached_reqs - for i, req_id in enumerate(cached_reqs.req_ids): - if req_id not in self._active_requests: - continue - - new_block_ids = cached_reqs.new_block_ids[i] - - cached_req = self._active_requests[req_id] - req_block_ids = self._req_blocks[req_id] - - if new_block_ids is None: - continue - - block_ids = new_block_ids[0] - - req_block_ids.extend(block_ids) - filename = os.path.join(self._storage_path, f"{req_id}.safetensors") - - meta.add_request( - req_id=req_id, - filename=filename, - token_ids=cached_req.prompt_token_ids or [], - new_req=False, - ) - return meta def request_finished( @@ -309,7 +459,31 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1, SupportsHMA): _ = self._active_requests.pop(req_id, None) _ = self._req_blocks.pop(req_id, None) - return False, {"hidden_states_path": req_filename} + return True, {"hidden_states_path": req_filename} + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + """Poll DtoH-copy completion for requests that finished generating. + + The scheduler passes finished_req_ids to tell the worker which + requests are done generating. We accumulate these across calls + and return a request as "finished sending" once its DtoH copy + event is complete (or if it never had a pending copy). The + subsequent disk write may still be in flight; clients block on + the per-file flock to wait for it. + """ + self._accumulated_finished_req_ids.update(finished_req_ids) + + done_sending: set[str] = set() + for req_id in list(self._accumulated_finished_req_ids): + event = self._req_copy_events.get(req_id) + if event is None or event.query(): + self._req_copy_events.pop(req_id, None) + done_sending.add(req_id) + self._accumulated_finished_req_ids.discard(req_id) + + return done_sending or None, None def request_finished_all_groups( self, diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py index 26ba0352e2a..c3cb3c8aaea 100644 --- a/vllm/v1/spec_decode/extract_hidden_states.py +++ b/vllm/v1/spec_decode/extract_hidden_states.py @@ -12,8 +12,10 @@ from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model +from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -50,6 +52,14 @@ class ExtractHiddenStatesProposer: vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size ) + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) + self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None) if not layer_ids: @@ -303,18 +313,15 @@ class ExtractHiddenStatesProposer: (if valid and not discarded) or a backup token from the request state. """ num_reqs = gpu_input_batch.num_reqs - device = sampled_token_ids.device - # Compute backup tokens for discarded / invalid requests - seq_lens_list = (gpu_input_batch.num_tokens_no_spec[:num_reqs] - 1).tolist() - backup_tokens_gpu = torch.tensor( - [ - requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i]) - for i in range(num_reqs) - ], - dtype=torch.int32, - device=device, - ) + # Precompute backup token IDs for discarded requests. + num_reqs = gpu_input_batch.num_reqs + for i in range(num_reqs): + self.backup_next_token_ids.np[i] = requests[ + gpu_input_batch.req_ids[i] + ].get_token_id(gpu_input_batch.num_tokens_no_spec[i] - 1) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + backup_tokens_gpu = self.backup_next_token_ids.gpu[:num_reqs] assert discard_request_mask.dtype == torch.bool