[Perf] Optimize hidden state extraction logic (#37374)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Benjamin Chislett
2026-05-22 18:23:08 -04:00
committed by GitHub
parent f743254143
commit 4e2eba28be
9 changed files with 772 additions and 65 deletions
@@ -2,9 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from safetensors import safe_open
from vllm import LLM, SamplingParams
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1 import (
example_hidden_states_connector,
)
# NOTE: If changing the interface of the ExampleHiddenStatesConnector, please also
# update the benchmark in benchmarks/benchmark_hidden_state_extraction.py
# and the docs in docs/features/speculative_decoding/extract_hidden_states.md
# Example: Using the custom "extract_hidden_states" speculator method and
# ExampleHiddenStatesConnector to extract and save hidden states from vllm
@@ -12,6 +18,7 @@ from vllm import LLM, SamplingParams
with tempfile.TemporaryDirectory() as tmpdirname:
llm = LLM(
model="Qwen/Qwen3-8B", # Your target model
enable_chunked_prefill=False, # required
speculative_config={
"method": "extract_hidden_states",
"num_speculative_tokens": 1,
@@ -23,16 +30,16 @@ with tempfile.TemporaryDirectory() as tmpdirname:
3,
4,
],
}
},
},
},
kv_transfer_config={
"kv_connector": "ExampleHiddenStatesConnector",
"kv_role": "kv_producer",
"kv_connector_extra_config": {
kv_transfer_config=KVTransferConfig(
kv_connector="ExampleHiddenStatesConnector",
kv_role="kv_producer",
kv_connector_extra_config={
"shared_storage_path": tmpdirname,
},
},
),
)
prompts = ["Generate a sentence with hidden states", "Write a python function"]
@@ -47,12 +54,14 @@ with tempfile.TemporaryDirectory() as tmpdirname:
assert hidden_states_path is not None
print("Prompt hidden states path:", hidden_states_path)
with safe_open(hidden_states_path, "pt") as f:
token_ids = f.get_tensor("token_ids")
hidden_states = f.get_tensor("hidden_states")
obj = example_hidden_states_connector.load_hidden_states(hidden_states_path)
token_ids = obj["token_ids"]
hidden_states = obj["hidden_states"]
print("Extracted token ids:", token_ids) # Matches prompt token ids
print(
"Extracted hidden states shape:", hidden_states.shape
) # [prompt len, num_hidden_layers, hidden size]
print("Extracted hidden states:", hidden_states)
print("Extracted token ids:", token_ids) # Matches prompt token ids
print(
"Extracted hidden states shape:", hidden_states.shape
) # [prompt_len, num_extracted_layers, hidden_size]
print("Extracted hidden states:", hidden_states)
example_hidden_states_connector.cleanup_hidden_states(hidden_states_path)