[None][feat] Draft: Save state first pass (#7012)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Authored by Izzy Putterman on 2025-10-01 15:40:55 -07:00; committed by GitHub
parent e107749a69
commit 1ad7bc4c78
12 changed files with 367 additions and 4 deletions

View File

@@ -1110,6 +1110,10 @@ class PyExecutor:
            sample_state = self._sample_async(scheduled_batch,
                                              batch_outputs)
            if self.drafter is not None:
                self.drafter.run_drafter_post(scheduled_batch,
                                              self.resource_manager,
                                              self.is_warmup)
            self._update_request_states(scheduled_batch)
            self._update_requests(sample_state, self.resource_manager)

View File

@@ -3,6 +3,7 @@ from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter
from .spec_tree_manager import SpecTreeManager
from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
                    get_spec_decoder, get_spec_drafter, get_spec_metadata,
@@ -16,6 +17,7 @@ __all__ = [
    "MTPWorker",
    "NGramDrafter",
    "NGramPoolManager",
    "SaveHiddenStatesDrafter",
    "SpecMetadata",
    "get_num_extra_kv_tokens",
    "get_num_spec_layers",

View File

@@ -67,3 +67,15 @@ class Drafter(ABC):
            num_draft_tokens = get_draft_token_length(req)
            req.py_draft_tokens.extend(
                0 for _ in range(max_draft_tokens - num_draft_tokens))

    def run_drafter_post(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
        is_warmup: bool = False,
    ) -> None:
        """
        If the draft forward pass needs to run directly after the target model
        forward pass, override this method to do so.
        Used by SaveHiddenStatesDrafter (to ensure correct input_ids).
        """

View File

@@ -126,6 +126,10 @@ class Eagle3SpecMetadata(SpecMetadata):
                                          self.num_layers - 4)
        else:
            self.layers_to_capture = sorted(list(self.layers_to_capture))
            if self.layers_to_capture[0] == -1:
                self.layers_to_capture = self.layers_to_capture[1:] + [
                    self.layers_to_capture.pop(0)
                ]
        self.num_capture_layers = len(self.layers_to_capture)

        # Initialize to 0 to avoid reading uninitialized memory during warmup
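
In other words, when a user-supplied capture set contains -1 (the post-norm final hidden state), sorting places it first, and the added branch rotates it to the end so the last-hidden slice stays last. A minimal sketch with an illustrative capture set:

layers_to_capture = sorted(list({10, 11, 12, -1}))  # [-1, 10, 11, 12]
if layers_to_capture[0] == -1:
    # Rotate the post-norm "last hidden state" marker (-1) to the end.
    layers_to_capture = layers_to_capture[1:] + [layers_to_capture.pop(0)]
print(layers_to_capture)  # [10, 11, 12, -1]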

View File

@@ -19,6 +19,7 @@ class SpeculativeDecodingMode(IntEnum):
    NGRAM = auto()
    DRAFT_TARGET = auto()
    USER_PROVIDED = auto()
    SAVE_HIDDEN_STATES = auto()
    NONE = auto()
    AUTO = auto()
@@ -55,6 +56,9 @@ class SpeculativeDecodingMode(IntEnum):
    def is_draft_target(self):
        return self == SpeculativeDecodingMode.DRAFT_TARGET

    def is_save_hidden_states(self):
        return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES

    def without_logits(self):
        return self.is_mtp_one_model() or self.is_eagle3_one_model()
@@ -95,8 +99,9 @@ class SpeculativeDecodingMode(IntEnum):
        ) or self.is_eagle3_one_model()

    def has_spec_drafter(self):
        return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
        ) or self.is_user_provided() or self.is_mtp_eagle()
        return self.is_eagle3(
        ) or self.is_draft_target() or self.is_ngram() or self.is_user_provided(
        ) or self.is_mtp_eagle() or self.is_save_hidden_states()

    def extend_ctx(self, attention_backend: Type[AttentionBackend]):
        """

View File

@@ -0,0 +1,99 @@
import os
from typing import Optional

import torch

from tensorrt_llm._utils import local_mpi_rank

from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .drafter import Drafter


class SaveHiddenStatesDrafter(Drafter):

    def __init__(
        self,
        spec_config: "SaveHiddenStatesDecodingConfig",
        spec_resource_manager,
    ):
        super().__init__(spec_config.max_concurrency)
        self.spec_config = spec_config
        self.max_draft_len = spec_config.max_draft_len
        self._iter = 1
        self._output_directory = spec_config.output_directory
        self._file_prefix = spec_config.file_prefix
        self._write_interval = spec_config.write_interval
        self._saved_state = []
        self.spec_resource_manager = spec_resource_manager
        os.makedirs(self._output_directory, exist_ok=True)

    def _process_request(self, request: LlmRequest, resource_manager) -> None:
        # Only rank 0 materializes tensors; other ranks append an empty dict.
        out_dict = {}
        if local_mpi_rank() == 0:
            input_ids = torch.tensor(list(request.get_tokens(0)),
                                     dtype=torch.long,
                                     device='cpu')
            hidden_size = resource_manager.hidden_size
            num_tokens = input_ids.shape[0]
            # The last `hidden_size` columns hold the post-norm final hidden state.
            hidden_states = resource_manager.hidden_states[
                :num_tokens, -hidden_size:].cpu().clone()

            out_dict = {
                "id": self._iter,
                "input_ids": input_ids,
                "hidden_state": hidden_states,
            }
            if len(self.spec_config.eagle3_layers_to_capture) > 1:
                if self.spec_config._last_hidden_in_save:
                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
                        :num_tokens, :].cpu().clone()
                else:
                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
                        :num_tokens, :-hidden_size].cpu().clone()

        self._saved_state.append(out_dict)

    def _write_to_file(self) -> None:
        if local_mpi_rank() == 0:
            output_path = os.path.join(self._output_directory,
                                       f"{self._file_prefix}_{self._iter}.pt")
            torch.save(self._saved_state, output_path)
        self._saved_state = []

    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
    ) -> None:
        for request in sorted(
                scheduled_requests.context_requests,
                key=lambda r: (r.py_batch_idx is None,
                               r.py_batch_idx or r.request_id),
        ):
            request.py_max_new_tokens = 1

    def run_drafter_post(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
        is_warmup: bool = False,
    ) -> None:
        # Runs directly after the target model forward; capture hidden states
        # for every scheduled context request, then flush on the write interval.
        if is_warmup:
            return
        for request in sorted(
                scheduled_requests.context_requests,
                key=lambda r: (r.py_batch_idx is None,
                               r.py_batch_idx or r.request_id),
        ):
            self._process_request(request, self.spec_resource_manager)
        if self._iter % self._write_interval == 0:
            self._write_to_file()
        self._iter += 1
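
Each {file_prefix}_{iter}.pt written by _write_to_file holds a list with one dict per processed request, keyed by "id", "input_ids", "hidden_state", and (when more than one layer is captured) "aux_hidden_states". A hedged read-back sketch; the directory and prefix values below are illustrative, not from this diff:

import glob
import os

import torch

# Illustrative values; output_directory and file_prefix come from
# SaveHiddenStatesDecodingConfig.
output_directory = "./hidden_states"
file_prefix = "data"

for path in sorted(glob.glob(os.path.join(output_directory, f"{file_prefix}_*.pt"))):
    records = torch.load(path)  # list with one dict per processed request
    for record in records:
        input_ids = record["input_ids"]        # (num_tokens,) token ids
        hidden_state = record["hidden_state"]  # (num_tokens, hidden_size)
        aux = record.get("aux_hidden_states")  # present only when >1 layer is captured
        print(record["id"], input_ids.shape, hidden_state.shape,
              None if aux is None else aux.shape)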

View File

@@ -11,6 +11,7 @@ from .model_drafter import ModelDrafter
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
                  MTPSpecMetadata, MTPWorker)
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter


def get_spec_metadata(spec_config,
@@ -55,6 +56,25 @@ def get_spec_metadata(spec_config,
            max_num_tokens=max_num_tokens,
            layers_to_capture=spec_config.eagle3_layers_to_capture,
        )
    if spec_config.spec_dec_mode.is_save_hidden_states():
        if spec_config.eagle3_layers_to_capture is None:
            spec_config.eagle3_layers_to_capture = {
                1, model_config.num_hidden_layers // 2 - 1,
                model_config.num_hidden_layers - 4, -1
            }
        return Eagle3SpecMetadata(
            max_draft_len=spec_config.max_draft_len,
            spec_dec_mode=spec_config.spec_dec_mode,
            max_num_requests=max_num_requests,
            num_layers=model_config.num_hidden_layers,
            hidden_size=model_config.hidden_size,
            max_num_tokens=max_num_tokens,
            dtype=model_config.torch_dtype,
            is_draft_model=is_draft_model,
            eagle3_resource_manager=spec_resource_manager,
            layers_to_capture=spec_config.eagle3_layers_to_capture,
            max_total_draft_tokens=1,
        )
    if spec_config.spec_dec_mode.is_draft_target() or \
            spec_config.spec_dec_mode.is_ngram() or \
            spec_config.spec_dec_mode.is_user_provided():
@@ -102,6 +122,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
            max_seq_len,
            max_num_tokens,
        )
    if spec_dec_mode.is_save_hidden_states():
        return Eagle3ResourceManager(
            spec_config,
            model_engine.model.config.torch_dtype,
            model_config.hidden_size,
            max_num_requests,
            max_seq_len,
            max_num_tokens,
        )
    if spec_dec_mode.is_ngram():
        return NGramPoolManager(spec_config, max_num_requests)
    if spec_dec_mode.is_user_provided():
@@ -151,6 +180,9 @@ def get_spec_drafter(model_engine,
    if spec_config.spec_dec_mode.is_ngram():
        return NGramDrafter(spec_config, spec_resource_manager)

    if spec_config.spec_dec_mode.is_save_hidden_states():
        return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)

    return None

View File

@@ -11,7 +11,8 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
                        DynamicBatchConfig, EagleDecodingConfig,
                        ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
                        MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
                        MTPDecodingConfig, NGramDecodingConfig,
                        SaveHiddenStatesDecodingConfig, SchedulerConfig,
                        TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
                        UserProvidedDecodingConfig)
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
@@ -59,4 +60,5 @@ __all__ = [
    'AutoDecodingConfig',
    'AttentionDpConfig',
    'LoRARequest',
    'SaveHiddenStatesDecodingConfig',
]

View File

@@ -380,6 +380,7 @@ class DecodingBaseConfig(StrictBaseModel):
            "Lookahead": LookaheadDecodingConfig,
            "NGram": NGramDecodingConfig,
            "DraftTarget": DraftTargetDecodingConfig,
            "SaveState": SaveHiddenStatesDecodingConfig,
            "UserProvided": UserProvidedDecodingConfig,
            "AUTO": AutoDecodingConfig,
        }
@@ -562,6 +563,52 @@ class EagleDecodingConfig(DecodingBaseConfig):
        return 3


class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
    output_directory: str
    write_interval: int = 20
    file_prefix: str = "data"
    eagle3_layers_to_capture: Optional[Set[int]] = None

    max_total_draft_tokens: Optional[int] = Field(default=1, init=False)
    eagle_choices: Optional[List[List[int]]] = Field(default=None, init=False)

    def model_post_init(self, __context):
        self._last_hidden_in_save = True
        if self.eagle3_layers_to_capture is None:
            self._last_hidden_in_save = False
        elif -1 not in self.eagle3_layers_to_capture:
            self._last_hidden_in_save = False
            self.eagle3_layers_to_capture.add(-1)

    @classmethod
    def from_dict(cls, data: dict):
        return cls(**data)

    decoding_type: ClassVar[str] = "SaveState"

    def validate(self) -> None:
        if self.output_directory is None or not self.eagle3_layers_to_capture:
            raise ValueError(
                "Save directory and layers to capture must be provided")

    @functools.cached_property
    def spec_dec_mode(self):
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES

    @functools.cached_property
    def num_capture_layers(self):
        """
        Returns the number of target-model layers to capture.
        If eagle3_layers_to_capture is not None, return the size of that set.
        Otherwise, assume the default Eagle3 capture set and return 3 + 1
        (for the post-norm last hidden state).
        """
        if self.eagle3_layers_to_capture is None:
            return 4
        return len(self.eagle3_layers_to_capture)
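
For context, a hedged usage sketch of the new config; the checkpoint path and capture set below are illustrative, not from this diff. As the new test shows, the class is exported from tensorrt_llm.llmapi and passed to LLM as speculative_config:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import SaveHiddenStatesDecodingConfig

# Illustrative capture set; -1 denotes the post-norm final hidden state and is
# added automatically when a non-empty set omits it.
spec_config = SaveHiddenStatesDecodingConfig(
    output_directory="./hidden_states",
    write_interval=20,
    file_prefix="data",
    eagle3_layers_to_capture={1, 15, 28},
)

# Hypothetical local checkpoint path; for this mode BaseLlmArgs forces
# max_batch_size=1, disables the overlap scheduler, and clears cuda_graph_config.
llm = LLM(model="/path/to/Llama-3.2-1B-Instruct", speculative_config=spec_config)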
class UserProvidedDecodingConfig(DecodingBaseConfig):
    # Cannot use real type annotations due to circular imports
    drafter: object  # Type is Drafter
@@ -1050,6 +1097,7 @@ SpeculativeConfig: TypeAlias = Optional[Union[
    MTPDecodingConfig,
    NGramDecodingConfig,
    UserProvidedDecodingConfig,
    SaveHiddenStatesDecodingConfig,
    AutoDecodingConfig,
]]
@@ -1869,6 +1917,20 @@ class BaseLlmArgs(StrictBaseModel):
            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
            self.build_config.max_draft_len = self.speculative_config.max_draft_len
        elif isinstance(self.speculative_config,
                        SaveHiddenStatesDecodingConfig):
            assert self.backend in ['pytorch']
            logger.warning(
                "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None"
            )
            self.build_config.max_batch_size = 1
            self.max_batch_size = 1
            self.disable_overlap_scheduler = True
            self.cuda_graph_config = None
            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
            self.build_config.max_draft_len = 1
            self.speculative_config.max_draft_len = 1
        else:
            raise ValueError(
                f"Unrecognized speculative config type {type(self.speculative_config)}"

View File

@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
    EAGLE = auto()
    NGRAM = auto()
    USER_PROVIDED = auto()
    SAVE_HIDDEN_STATES = auto()
    AUTO = auto()

    @staticmethod
@@ -120,6 +121,8 @@ class SpeculativeDecodingMode(IntFlag):
            return SpeculativeDecodingMode.USER_PROVIDED
        elif args.speculative_decoding_mode == "auto":
            return SpeculativeDecodingMode.AUTO
        elif args.speculative_decoding_mode == "save_hidden_states":
            return SpeculativeDecodingMode.SAVE_HIDDEN_STATES
        else:
            assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode

View File

@@ -0,0 +1,138 @@
import os
import sys
import tempfile
import unittest

import pytest
import torch
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
                                 SaveHiddenStatesDecodingConfig)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


def test_multi_save_state():
    use_cuda_graph = True
    attn_backend = "TRTLLM"
    disable_overlap_scheduler = False
    enable_block_reuse = False
    enable_chunked_prefill = False
    layers_to_capture = {10, 11, 12}

    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 80:
        pytest.skip("Not enough memory to load target + draft model")

    models_path = llm_models_root()
    with tempfile.TemporaryDirectory() as temp_dir:
        target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct"

        max_batch_size = 16
        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                        free_gpu_memory_fraction=0.5)
        cuda_graph_config = CudaGraphConfig(
            batch_sizes=[1, 2, 4]) if use_cuda_graph else None

        llm_common_config = dict(
            model=target_model_dir,
            attn_backend=attn_backend,
            disable_overlap_scheduler=disable_overlap_scheduler,
            cuda_graph_config=cuda_graph_config,
            max_batch_size=max_batch_size,
            kv_cache_config=kv_cache_config,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        spec_config = SaveHiddenStatesDecodingConfig(
            output_directory=temp_dir,
            write_interval=1,
            file_prefix="data",
            eagle3_layers_to_capture=layers_to_capture)

        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
        tok_ids = llm_spec.tokenizer.encode("The future of AI is")

        sampling_params = SamplingParams(max_tokens=32, temperature=0)
        for output in llm_spec.generate_async(tok_ids,
                                              sampling_params,
                                              streaming=True):
            pass

        llm_spec.shutdown()

        assert os.path.exists(os.path.join(temp_dir, "data_1.pt"))
        # Read in .pt file
        saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0]
        assert saved_data["aux_hidden_states"].shape == (len(tok_ids), 2048 *
                                                         len(layers_to_capture))
        assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
        assert saved_data["input_ids"].tolist() == tok_ids


@pytest.mark.parametrize("layers_to_capture", [{-1}, None])
def test_save_state(layers_to_capture):
    use_cuda_graph = True
    attn_backend = "TRTLLM"
    disable_overlap_scheduler = False
    enable_block_reuse = False
    enable_chunked_prefill = False

    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 80:
        pytest.skip("Not enough memory to load target + draft model")

    models_path = llm_models_root()
    with tempfile.TemporaryDirectory() as temp_dir:
        target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct"

        max_batch_size = 16
        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                        free_gpu_memory_fraction=0.5)
        cuda_graph_config = CudaGraphConfig(
            batch_sizes=[1, 2, 4]) if use_cuda_graph else None

        llm_common_config = dict(
            model=target_model_dir,
            attn_backend=attn_backend,
            disable_overlap_scheduler=disable_overlap_scheduler,
            cuda_graph_config=cuda_graph_config,
            max_batch_size=max_batch_size,
            kv_cache_config=kv_cache_config,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        spec_config = SaveHiddenStatesDecodingConfig(
            output_directory=temp_dir,
            write_interval=1,
            file_prefix="data",
            eagle3_layers_to_capture=layers_to_capture)

        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
        tok_ids = llm_spec.tokenizer.encode("The future of AI is")

        sampling_params = SamplingParams(max_tokens=32, temperature=0)
        for output in llm_spec.generate_async(tok_ids,
                                              sampling_params,
                                              streaming=True):
            pass

        llm_spec.shutdown()

        assert os.path.exists(os.path.join(temp_dir, "data_1.pt"))
        # Read in .pt file
        saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0]
        if layers_to_capture is None:
            assert saved_data["aux_hidden_states"].shape == (len(tok_ids),
                                                             2048 * 3)
            assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
            assert saved_data["input_ids"].tolist() == tok_ids
        else:
            assert "aux_hidden_states" not in saved_data
            assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
            assert saved_data["input_ids"].tolist() == tok_ids


if __name__ == "__main__":
    unittest.main()

View File

@@ -59,7 +59,7 @@ methods:
        default: null
      # Speculative decoding
      speculative_config:
        annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, NoneType]
        annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, tensorrt_llm.llmapi.llm_args.SaveHiddenStatesDecodingConfig, NoneType]
        default: null
# generation constraints
max_batch_size: