[None][feat] Draft: Save state first pass (#7012)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Izzy Putterman 2025-10-01 15:40:55 -07:00 committed by GitHub
parent e107749a69
commit 1ad7bc4c78
12 changed files with 367 additions and 4 deletions

View File

@@ -1110,6 +1110,10 @@ class PyExecutor:
        sample_state = self._sample_async(scheduled_batch,
                                          batch_outputs)
        if self.drafter is not None:
            self.drafter.run_drafter_post(scheduled_batch,
                                          self.resource_manager,
                                          self.is_warmup)
        self._update_request_states(scheduled_batch)
        self._update_requests(sample_state, self.resource_manager)

View File

@@ -3,6 +3,7 @@ from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter
from .spec_tree_manager import SpecTreeManager
from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
                    get_spec_decoder, get_spec_drafter, get_spec_metadata,
@@ -16,6 +17,7 @@ __all__ = [
    "MTPWorker",
    "NGramDrafter",
    "NGramPoolManager",
    "SaveHiddenStatesDrafter",
    "SpecMetadata",
    "get_num_extra_kv_tokens",
    "get_num_spec_layers",

View File

@@ -67,3 +67,15 @@ class Drafter(ABC):
            num_draft_tokens = get_draft_token_length(req)
            req.py_draft_tokens.extend(
                0 for _ in range(max_draft_tokens - num_draft_tokens))

    def run_drafter_post(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
        is_warmup: bool = False,
    ) -> None:
        """
        Override this to run draft-side work directly after the target model
        forward. Used by SaveHiddenStatesDrafter (to ensure correct input_ids).
        """

View File

@@ -126,6 +126,10 @@ class Eagle3SpecMetadata(SpecMetadata):
                                      self.num_layers - 4)
        else:
            self.layers_to_capture = sorted(list(self.layers_to_capture))
            # Move a leading -1 (the final post-norm hidden state) to the end
            # of the capture list so it lands in the last hidden_size slice.
            if self.layers_to_capture[0] == -1:
                self.layers_to_capture = self.layers_to_capture[1:] + [
                    self.layers_to_capture.pop(0)
                ]
        self.num_capture_layers = len(self.layers_to_capture)

        # Initialize to 0 to avoid reading uninitialized memory during warmup
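
For illustration, the reordering above behaves as follows (layer indices are illustrative; -1 denotes the final post-norm hidden state):

layers_to_capture = sorted([10, 11, 12, -1])  # -> [-1, 10, 11, 12]
if layers_to_capture[0] == -1:
    # The left slice is evaluated before pop(0), so -1 ends up last.
    layers_to_capture = layers_to_capture[1:] + [layers_to_capture.pop(0)]
print(layers_to_capture)  # [10, 11, 12, -1]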

View File

@@ -19,6 +19,7 @@ class SpeculativeDecodingMode(IntEnum):
    NGRAM = auto()
    DRAFT_TARGET = auto()
    USER_PROVIDED = auto()
    SAVE_HIDDEN_STATES = auto()
    NONE = auto()
    AUTO = auto()

@@ -55,6 +56,9 @@ class SpeculativeDecodingMode(IntEnum):
    def is_draft_target(self):
        return self == SpeculativeDecodingMode.DRAFT_TARGET

    def is_save_hidden_states(self):
        return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES

    def without_logits(self):
        return self.is_mtp_one_model() or self.is_eagle3_one_model()

@@ -95,8 +99,9 @@ class SpeculativeDecodingMode(IntEnum):
        ) or self.is_eagle3_one_model()

    def has_spec_drafter(self):
        return self.is_eagle3(
        ) or self.is_draft_target() or self.is_ngram() or self.is_user_provided(
        ) or self.is_mtp_eagle() or self.is_save_hidden_states()

    def extend_ctx(self, attention_backend: Type[AttentionBackend]):
        """

View File

@@ -0,0 +1,99 @@
import os
from typing import Optional

import torch

from tensorrt_llm._utils import local_mpi_rank

from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .drafter import Drafter


class SaveHiddenStatesDrafter(Drafter):

    def __init__(
        self,
        spec_config: "SaveHiddenStatesDecodingConfig",
        spec_resource_manager,
    ):
        super().__init__(spec_config.max_concurrency)
        self.spec_config = spec_config
        self.max_draft_len = spec_config.max_draft_len
        self._iter = 1
        self._output_directory = spec_config.output_directory
        self._file_prefix = spec_config.file_prefix
        self._write_interval = spec_config.write_interval
        self._saved_state = []
        self.spec_resource_manager = spec_resource_manager
        os.makedirs(self._output_directory, exist_ok=True)

    def _process_request(self, request: LlmRequest, resource_manager) -> None:
        out_dict = {}
        if local_mpi_rank() == 0:
            input_ids = torch.tensor(list(request.get_tokens(0)),
                                     dtype=torch.long,
                                     device='cpu')
            hidden_size = resource_manager.hidden_size
            num_tokens = input_ids.shape[0]
            # The last hidden_size columns of the capture buffer hold the
            # final (post-norm) hidden state, since Eagle3SpecMetadata orders
            # -1 last in layers_to_capture.
            hidden_states = resource_manager.hidden_states[
                :num_tokens, -hidden_size:].cpu().clone()

            out_dict = {
                "id": self._iter,
                "input_ids": input_ids,
                "hidden_state": hidden_states,
            }
            if len(self.spec_config.eagle3_layers_to_capture) > 1:
                if self.spec_config._last_hidden_in_save:
                    # -1 was requested explicitly, so keep the final hidden
                    # state in the auxiliary hidden states as well.
                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
                        :num_tokens, :].cpu().clone()
                else:
                    # -1 was appended implicitly; strip it from the auxiliary
                    # hidden states.
                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
                        :num_tokens, :-hidden_size].cpu().clone()

            self._saved_state.append(out_dict)

    def _write_to_file(self) -> None:
        if local_mpi_rank() == 0:
            output_path = os.path.join(self._output_directory,
                                       f"{self._file_prefix}_{self._iter}.pt")
            torch.save(self._saved_state, output_path)
        self._saved_state = []

    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
    ) -> None:
        # Only one forward pass of the target model is needed per request.
        for request in sorted(
                scheduled_requests.context_requests,
                key=lambda r: (r.py_batch_idx is None,
                               r.py_batch_idx or r.request_id),
        ):
            request.py_max_new_tokens = 1

    def run_drafter_post(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
        is_warmup: bool = False,
    ) -> None:
        if is_warmup:
            return
        for request in sorted(
                scheduled_requests.context_requests,
                key=lambda r: (r.py_batch_idx is None,
                               r.py_batch_idx or r.request_id),
        ):
            self._process_request(request, self.spec_resource_manager)
        if self._iter % self._write_interval == 0:
            self._write_to_file()
        self._iter += 1
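
Each file written by _write_to_file above is a torch.save'd list of per-request dictionaries. A minimal sketch of reading one back (the directory is illustrative; "data" is the default file_prefix):

import torch

records = torch.load("hidden_states/data_1.pt")  # list of per-request dicts
for rec in records:
    print(rec["id"], rec["input_ids"].shape, rec["hidden_state"].shape)
    # "aux_hidden_states" is present only when more than one layer is captured.
    if "aux_hidden_states" in rec:
        print(rec["aux_hidden_states"].shape)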

View File

@@ -11,6 +11,7 @@ from .model_drafter import ModelDrafter
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
                  MTPSpecMetadata, MTPWorker)
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter


def get_spec_metadata(spec_config,
@@ -55,6 +56,25 @@ def get_spec_metadata(spec_config,
            max_num_tokens=max_num_tokens,
            layers_to_capture=spec_config.eagle3_layers_to_capture,
        )
    if spec_config.spec_dec_mode.is_save_hidden_states():
        if spec_config.eagle3_layers_to_capture is None:
            spec_config.eagle3_layers_to_capture = {
                1, model_config.num_hidden_layers // 2 - 1,
                model_config.num_hidden_layers - 4, -1
            }
        return Eagle3SpecMetadata(
            max_draft_len=spec_config.max_draft_len,
            spec_dec_mode=spec_config.spec_dec_mode,
            max_num_requests=max_num_requests,
            num_layers=model_config.num_hidden_layers,
            hidden_size=model_config.hidden_size,
            max_num_tokens=max_num_tokens,
            dtype=model_config.torch_dtype,
            is_draft_model=is_draft_model,
            eagle3_resource_manager=spec_resource_manager,
            layers_to_capture=spec_config.eagle3_layers_to_capture,
            max_total_draft_tokens=1,
        )
    if spec_config.spec_dec_mode.is_draft_target() or \
            spec_config.spec_dec_mode.is_ngram() or \
            spec_config.spec_dec_mode.is_user_provided():
@@ -102,6 +122,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
            max_seq_len,
            max_num_tokens,
        )
    if spec_dec_mode.is_save_hidden_states():
        return Eagle3ResourceManager(
            spec_config,
            model_engine.model.config.torch_dtype,
            model_config.hidden_size,
            max_num_requests,
            max_seq_len,
            max_num_tokens,
        )
    if spec_dec_mode.is_ngram():
        return NGramPoolManager(spec_config, max_num_requests)
    if spec_dec_mode.is_user_provided():
@@ -151,6 +180,9 @@ def get_spec_drafter(model_engine,
    if spec_config.spec_dec_mode.is_ngram():
        return NGramDrafter(spec_config, spec_resource_manager)

    if spec_config.spec_dec_mode.is_save_hidden_states():
        return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)

    return None

View File

@@ -11,7 +11,8 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
                        DynamicBatchConfig, EagleDecodingConfig,
                        ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
                        MTPDecodingConfig, NGramDecodingConfig,
                        SaveHiddenStatesDecodingConfig, SchedulerConfig,
                        TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
                        UserProvidedDecodingConfig)
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
@@ -59,4 +60,5 @@ __all__ = [
    'AutoDecodingConfig',
    'AttentionDpConfig',
    'LoRARequest',
    'SaveHiddenStatesDecodingConfig',
]

View File

@@ -380,6 +380,7 @@ class DecodingBaseConfig(StrictBaseModel):
            "Lookahead": LookaheadDecodingConfig,
            "NGram": NGramDecodingConfig,
            "DraftTarget": DraftTargetDecodingConfig,
            "SaveState": SaveHiddenStatesDecodingConfig,
            "UserProvided": UserProvidedDecodingConfig,
            "AUTO": AutoDecodingConfig,
        }
@@ -562,6 +563,52 @@ class EagleDecodingConfig(DecodingBaseConfig):
        return 3


class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
    output_directory: str
    write_interval: int = 20
    file_prefix: str = "data"
    eagle3_layers_to_capture: Optional[Set[int]] = None

    max_total_draft_tokens: Optional[int] = Field(default=1, init=False)
    eagle_choices: Optional[List[List[int]]] = Field(default=None, init=False)

    def model_post_init(self, __context):
        # Track whether -1 (the final post-norm hidden state) was requested
        # explicitly; if a capture set is given without it, add it so the
        # final hidden state is always captured.
        self._last_hidden_in_save = True
        if self.eagle3_layers_to_capture is None:
            self._last_hidden_in_save = False
        elif -1 not in self.eagle3_layers_to_capture:
            self._last_hidden_in_save = False
            self.eagle3_layers_to_capture.add(-1)

    @classmethod
    def from_dict(cls, data: dict):
        return cls(**data)

    decoding_type: ClassVar[str] = "SaveState"

    def validate(self) -> None:
        if self.output_directory is None or not self.eagle3_layers_to_capture:
            raise ValueError(
                "Save directory and layers to capture must be provided")

    @functools.cached_property
    def spec_dec_mode(self):
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES

    @functools.cached_property
    def num_capture_layers(self):
        """
        Returns the number of layers to capture from the target model.
        If eagle3_layers_to_capture is not None, return the length of the set.
        Otherwise, assume the Eagle3 base set and return 3 + 1 (for the
        post-norm last hidden state).
        """
        if self.eagle3_layers_to_capture is None:
            return 4
        return len(self.eagle3_layers_to_capture)


class UserProvidedDecodingConfig(DecodingBaseConfig):
    # Cannot use real type annotations due to circular imports
    drafter: object  # Type is Drafter
@@ -1050,6 +1097,7 @@ SpeculativeConfig: TypeAlias = Optional[Union[
    MTPDecodingConfig,
    NGramDecodingConfig,
    UserProvidedDecodingConfig,
    SaveHiddenStatesDecodingConfig,
    AutoDecodingConfig,
]]
@@ -1869,6 +1917,20 @@ class BaseLlmArgs(StrictBaseModel):
            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
            self.build_config.max_draft_len = self.speculative_config.max_draft_len
        elif isinstance(self.speculative_config,
                        SaveHiddenStatesDecodingConfig):
            assert self.backend in ['pytorch']
            logger.warning(
                "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None"
            )
            self.build_config.max_batch_size = 1
            self.max_batch_size = 1
            self.disable_overlap_scheduler = True
            self.cuda_graph_config = None
            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
            self.build_config.max_draft_len = 1
            self.speculative_config.max_draft_len = 1
        else:
            raise ValueError(
                f"Unrecognized speculative config type {type(self.speculative_config)}"

View File

@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
    EAGLE = auto()
    NGRAM = auto()
    USER_PROVIDED = auto()
    SAVE_HIDDEN_STATES = auto()
    AUTO = auto()

    @staticmethod
@@ -120,6 +121,8 @@ class SpeculativeDecodingMode(IntFlag):
            return SpeculativeDecodingMode.USER_PROVIDED
        elif args.speculative_decoding_mode == "auto":
            return SpeculativeDecodingMode.AUTO
        elif args.speculative_decoding_mode == "save_hidden_states":
            return SpeculativeDecodingMode.SAVE_HIDDEN_STATES
        else:
            assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode

View File

@@ -0,0 +1,138 @@
import os
import sys
import tempfile
import unittest

import pytest
import torch
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
                                 SaveHiddenStatesDecodingConfig)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


def test_multi_save_state():
    use_cuda_graph = True
    attn_backend = "TRTLLM"
    disable_overlap_scheduler = False
    enable_block_reuse = False
    enable_chunked_prefill = False
    layers_to_capture = {10, 11, 12}

    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 80:
        pytest.skip("Not enough memory to load target + draft model")

    models_path = llm_models_root()
    with tempfile.TemporaryDirectory() as temp_dir:
        target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct"

        max_batch_size = 16
        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                        free_gpu_memory_fraction=0.5)
        cuda_graph_config = CudaGraphConfig(
            batch_sizes=[1, 2, 4]) if use_cuda_graph else None
        llm_common_config = dict(
            model=target_model_dir,
            attn_backend=attn_backend,
            disable_overlap_scheduler=disable_overlap_scheduler,
            cuda_graph_config=cuda_graph_config,
            max_batch_size=max_batch_size,
            kv_cache_config=kv_cache_config,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        spec_config = SaveHiddenStatesDecodingConfig(
            output_directory=temp_dir,
            write_interval=1,
            file_prefix="data",
            eagle3_layers_to_capture=layers_to_capture)

        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
        sampling_params = SamplingParams(max_tokens=32, temperature=0)
        for output in llm_spec.generate_async(tok_ids,
                                              sampling_params,
                                              streaming=True):
            pass
        llm_spec.shutdown()

        assert os.path.exists(os.path.join(temp_dir, "data_1.pt"))
        # Read in .pt file
        saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0]
        assert saved_data["aux_hidden_states"].shape == (
            len(tok_ids), 2048 * len(layers_to_capture))
        assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
        assert saved_data["input_ids"].tolist() == tok_ids


@pytest.mark.parametrize("layers_to_capture", [{-1}, None])
def test_save_state(layers_to_capture):
    use_cuda_graph = True
    attn_backend = "TRTLLM"
    disable_overlap_scheduler = False
    enable_block_reuse = False
    enable_chunked_prefill = False

    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 80:
        pytest.skip("Not enough memory to load target + draft model")

    models_path = llm_models_root()
    with tempfile.TemporaryDirectory() as temp_dir:
        target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct"

        max_batch_size = 16
        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                        free_gpu_memory_fraction=0.5)
        cuda_graph_config = CudaGraphConfig(
            batch_sizes=[1, 2, 4]) if use_cuda_graph else None
        llm_common_config = dict(
            model=target_model_dir,
            attn_backend=attn_backend,
            disable_overlap_scheduler=disable_overlap_scheduler,
            cuda_graph_config=cuda_graph_config,
            max_batch_size=max_batch_size,
            kv_cache_config=kv_cache_config,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        spec_config = SaveHiddenStatesDecodingConfig(
            output_directory=temp_dir,
            write_interval=1,
            file_prefix="data",
            eagle3_layers_to_capture=layers_to_capture)

        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
        sampling_params = SamplingParams(max_tokens=32, temperature=0)
        for output in llm_spec.generate_async(tok_ids,
                                              sampling_params,
                                              streaming=True):
            pass
        llm_spec.shutdown()

        assert os.path.exists(os.path.join(temp_dir, "data_1.pt"))
        # Read in .pt file
        saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0]
        if layers_to_capture is None:
            assert saved_data["aux_hidden_states"].shape == (len(tok_ids),
                                                             2048 * 3)
            assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
            assert saved_data["input_ids"].tolist() == tok_ids
        else:
            assert "aux_hidden_states" not in saved_data
            assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
            assert saved_data["input_ids"].tolist() == tok_ids


if __name__ == "__main__":
    unittest.main()

View File

@@ -59,7 +59,7 @@ methods:
        default: null
      # Speculative decoding
      speculative_config:
        annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, tensorrt_llm.llmapi.llm_args.SaveHiddenStatesDecodingConfig, NoneType]
        default: null
      # generation constraints
      max_batch_size: