mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Perf] Optimize hidden state extraction logic (#37374)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
f743254143
commit
4e2eba28be
@@ -2,9 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import tempfile
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
example_hidden_states_connector,
|
||||
)
|
||||
|
||||
# NOTE: If changing the interface of the ExampleHiddenStatesConnector, please also
|
||||
# update the benchmark in benchmarks/benchmark_hidden_state_extraction.py
|
||||
# and the docs in docs/features/speculative_decoding/extract_hidden_states.md
|
||||
|
||||
# Example: Using the custom "extract_hidden_states" speculator method and
|
||||
# ExampleHiddenStatesConnector to extract and save hidden states from vLLM
|
||||
@@ -12,6 +18,7 @@ from vllm import LLM, SamplingParams
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen3-8B", # Your target model
|
||||
enable_chunked_prefill=False, # required
|
||||
speculative_config={
|
||||
"method": "extract_hidden_states",
|
||||
"num_speculative_tokens": 1,
|
||||
@@ -23,16 +30,16 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
3,
|
||||
4,
|
||||
],
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
kv_transfer_config={
|
||||
"kv_connector": "ExampleHiddenStatesConnector",
|
||||
"kv_role": "kv_producer",
|
||||
"kv_connector_extra_config": {
|
||||
kv_transfer_config=KVTransferConfig(
|
||||
kv_connector="ExampleHiddenStatesConnector",
|
||||
kv_role="kv_producer",
|
||||
kv_connector_extra_config={
|
||||
"shared_storage_path": tmpdirname,
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
prompts = ["Generate a sentence with hidden states", "Write a python function"]
|
||||
@@ -47,12 +54,14 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
assert hidden_states_path is not None
|
||||
print("Prompt hidden states path:", hidden_states_path)
|
||||
|
||||
with safe_open(hidden_states_path, "pt") as f:
|
||||
token_ids = f.get_tensor("token_ids")
|
||||
hidden_states = f.get_tensor("hidden_states")
|
||||
obj = example_hidden_states_connector.load_hidden_states(hidden_states_path)
|
||||
token_ids = obj["token_ids"]
|
||||
hidden_states = obj["hidden_states"]
|
||||
|
||||
print("Extracted token ids:", token_ids) # Matches prompt token ids
|
||||
print(
|
||||
"Extracted hidden states shape:", hidden_states.shape
|
||||
) # [prompt len, num_hidden_layers, hidden size]
|
||||
print("Extracted hidden states:", hidden_states)
|
||||
print("Extracted token ids:", token_ids) # Matches prompt token ids
|
||||
print(
|
||||
"Extracted hidden states shape:", hidden_states.shape
|
||||
) # [prompt_len, num_extracted_layers, hidden_size]
|
||||
print("Extracted hidden states:", hidden_states)
|
||||
|
||||
example_hidden_states_connector.cleanup_hidden_states(hidden_states_path)
|
||||
|
||||
Reference in New Issue
Block a user