Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

[None][feat] Draft: Save state first pass (#7012)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>

This commit is contained in:
parent e107749a69
commit 1ad7bc4c78
@@ -1110,6 +1110,10 @@ class PyExecutor:
             sample_state = self._sample_async(scheduled_batch,
                                               batch_outputs)
 
+            if self.drafter is not None:
+                self.drafter.run_drafter_post(scheduled_batch,
+                                              self.resource_manager,
+                                              self.is_warmup)
 
             self._update_request_states(scheduled_batch)
             self._update_requests(sample_state, self.resource_manager)
@@ -3,6 +3,7 @@ from .eagle3 import Eagle3SpecMetadata
 from .interface import SpecMetadata
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
+from .save_hidden_state import SaveHiddenStatesDrafter
 from .spec_tree_manager import SpecTreeManager
 from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
                     get_spec_decoder, get_spec_drafter, get_spec_metadata,
@@ -16,6 +17,7 @@ __all__ = [
     "MTPWorker",
     "NGramDrafter",
     "NGramPoolManager",
+    "SaveHiddenStatesDrafter",
     "SpecMetadata",
     "get_num_extra_kv_tokens",
     "get_num_spec_layers",
@@ -67,3 +67,15 @@ class Drafter(ABC):
             num_draft_tokens = get_draft_token_length(req)
             req.py_draft_tokens.extend(
                 0 for _ in range(max_draft_tokens - num_draft_tokens))
+
+    def run_drafter_post(
+        self,
+        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
+        is_warmup: bool = False,
+    ) -> None:
+        """
+        If the draft forward needs to be run directly after the target model
+        forward, this method can be overridden to do that.
+        Used in SaveHiddenStatesDrafter (to ensure correct input_ids).
+        """
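The new run_drafter_post hook is a no-op on the base Drafter; PyExecutor (first hunk above) calls it right after the target forward and sampling step. Below is a minimal sketch, not part of this commit, of how a custom drafter could use the hook; the class name MyLoggingDrafter and its behavior are hypothetical, while the real override added by this commit is SaveHiddenStatesDrafter further down.

# Illustrative sketch only (hypothetical drafter using the new hook).
from typing import Optional

from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm._torch.speculative.drafter import Drafter


class MyLoggingDrafter(Drafter):

    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
    ) -> None:
        # This drafter proposes no draft tokens.
        pass

    def run_drafter_post(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
        is_warmup: bool = False,
    ) -> None:
        # Invoked by PyExecutor right after the target forward + sampling,
        # so per-request token state reflects the current step.
        if is_warmup:
            return
        for request in scheduled_requests.context_requests:
            print(request.request_id, len(request.get_tokens(0)))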
@@ -126,6 +126,10 @@ class Eagle3SpecMetadata(SpecMetadata):
                                          self.num_layers - 4)
         else:
             self.layers_to_capture = sorted(list(self.layers_to_capture))
+            if self.layers_to_capture[0] == -1:
+                self.layers_to_capture = self.layers_to_capture[1:] + [
+                    self.layers_to_capture.pop(0)
+                ]
         self.num_capture_layers = len(self.layers_to_capture)
 
         # Initialize to 0 to avoid reading uninitialized memory during warmup
@@ -19,6 +19,7 @@ class SpeculativeDecodingMode(IntEnum):
     NGRAM = auto()
     DRAFT_TARGET = auto()
     USER_PROVIDED = auto()
+    SAVE_HIDDEN_STATES = auto()
     NONE = auto()
     AUTO = auto()
 
@@ -55,6 +56,9 @@ class SpeculativeDecodingMode(IntEnum):
     def is_draft_target(self):
         return self == SpeculativeDecodingMode.DRAFT_TARGET
 
+    def is_save_hidden_states(self):
+        return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES
+
     def without_logits(self):
         return self.is_mtp_one_model() or self.is_eagle3_one_model()
 
@@ -95,8 +99,9 @@ class SpeculativeDecodingMode(IntEnum):
         ) or self.is_eagle3_one_model()
 
     def has_spec_drafter(self):
-        return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
-        ) or self.is_user_provided() or self.is_mtp_eagle()
+        return self.is_eagle3(
+        ) or self.is_draft_target() or self.is_ngram() or self.is_user_provided(
+        ) or self.is_mtp_eagle() or self.is_save_hidden_states()
 
     def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         """
tensorrt_llm/_torch/speculative/save_hidden_state.py (new file, 99 lines)
@@ -0,0 +1,99 @@
+import os
+from typing import Optional
+
+import torch
+
+from tensorrt_llm._utils import local_mpi_rank
+
+from ..pyexecutor.llm_request import LlmRequest
+from ..pyexecutor.resource_manager import ResourceManager
+from ..pyexecutor.scheduler import ScheduledRequests
+from .drafter import Drafter
+
+
+class SaveHiddenStatesDrafter(Drafter):
+
+    def __init__(
+        self,
+        spec_config: "SaveHiddenStatesDecodingConfig",
+        spec_resource_manager,
+    ):
+        super().__init__(spec_config.max_concurrency)
+        self.spec_config = spec_config
+        self.max_draft_len = spec_config.max_draft_len
+        self._iter = 1
+        self._output_directory = spec_config.output_directory
+        self._file_prefix = spec_config.file_prefix
+        self._write_interval = spec_config.write_interval
+        self._saved_state = []
+        self.spec_resource_manager = spec_resource_manager
+        os.makedirs(self._output_directory, exist_ok=True)
+
+    def _process_request(self, request: LlmRequest, resource_manager) -> None:
+        out_dict = {}
+        if local_mpi_rank() == 0:
+            input_ids = torch.tensor(list(request.get_tokens(0)),
+                                     dtype=torch.long,
+                                     device='cpu')
+            hidden_size = resource_manager.hidden_size
+            num_tokens = input_ids.shape[0]
+            hidden_states = resource_manager.hidden_states[:num_tokens,
+                                                           -hidden_size:].cpu(
+                                                           ).clone()
+
+            out_dict = {
+                "id": self._iter,
+                "input_ids": input_ids,
+                "hidden_state": hidden_states,
+            }
+            if len(self.spec_config.eagle3_layers_to_capture) > 1:
+                if self.spec_config._last_hidden_in_save:
+                    out_dict[
+                        "aux_hidden_states"] = resource_manager.hidden_states[:num_tokens, :].cpu(
+                        ).clone()
+                else:
+                    out_dict[
+                        "aux_hidden_states"] = resource_manager.hidden_states[:
+                                                                              num_tokens, :
+                                                                              -hidden_size].cpu(
+                                                                              ).clone(
+                                                                              )
+
+        self._saved_state.append(out_dict)
+
+    def _write_to_file(self) -> None:
+        if local_mpi_rank() == 0:
+            output_path = os.path.join(self._output_directory,
+                                       f"{self._file_prefix}_{self._iter}.pt")
+            torch.save(self._saved_state, output_path)
+        self._saved_state = []
+
+    def prepare_draft_tokens(
+        self,
+        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
+    ) -> None:
+        for request in sorted(
+                scheduled_requests.context_requests,
+                key=lambda r:
+            (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
+        ):
+            request.py_max_new_tokens = 1
+
+    def run_drafter_post(
+        self,
+        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
+        is_warmup: bool = False,
+    ) -> None:
+        if is_warmup:
+            return
+        for request in sorted(
+                scheduled_requests.context_requests,
+                key=lambda r:
+            (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
+        ):
+            self._process_request(request, self.spec_resource_manager)
+        if self._iter % self._write_interval == 0:
+            self._write_to_file()
+        self._iter += 1
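Each call to _write_to_file dumps the accumulated per-request dicts ("id", "input_ids", "hidden_state", and optionally "aux_hidden_states") to {file_prefix}_{iter}.pt in the output directory. A minimal sketch of reading a dump back, assuming the default prefix "data" and a placeholder output directory:

# Sketch only: inspect a dump written by SaveHiddenStatesDrafter.
# "./hidden_states/data_1.pt" stands in for output_directory/file_prefix.
import torch

records = torch.load("./hidden_states/data_1.pt")  # list of per-request dicts
for rec in records:
    print(rec["id"])                  # iteration counter at capture time
    print(rec["input_ids"].shape)     # (num_tokens,) prompt token ids on CPU
    print(rec["hidden_state"].shape)  # (num_tokens, hidden_size) last hidden state
    if "aux_hidden_states" in rec:    # present only when more than one layer is captured
        print(rec["aux_hidden_states"].shape)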
@@ -11,6 +11,7 @@ from .model_drafter import ModelDrafter
 from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
                   MTPSpecMetadata, MTPWorker)
 from .ngram import NGramDrafter, NGramPoolManager
+from .save_hidden_state import SaveHiddenStatesDrafter
 
 
 def get_spec_metadata(spec_config,
@@ -55,6 +56,25 @@ def get_spec_metadata(spec_config,
             max_num_tokens=max_num_tokens,
             layers_to_capture=spec_config.eagle3_layers_to_capture,
         )
+    if spec_config.spec_dec_mode.is_save_hidden_states():
+        if spec_config.eagle3_layers_to_capture is None:
+            spec_config.eagle3_layers_to_capture = {
+                1, model_config.num_hidden_layers // 2 - 1,
+                model_config.num_hidden_layers - 4, -1
+            }
+        return Eagle3SpecMetadata(
+            max_draft_len=spec_config.max_draft_len,
+            spec_dec_mode=spec_config.spec_dec_mode,
+            max_num_requests=max_num_requests,
+            num_layers=model_config.num_hidden_layers,
+            hidden_size=model_config.hidden_size,
+            max_num_tokens=max_num_tokens,
+            dtype=model_config.torch_dtype,
+            is_draft_model=is_draft_model,
+            eagle3_resource_manager=spec_resource_manager,
+            layers_to_capture=spec_config.eagle3_layers_to_capture,
+            max_total_draft_tokens=1,
+        )
     if spec_config.spec_dec_mode.is_draft_target() or \
             spec_config.spec_dec_mode.is_ngram() or \
             spec_config.spec_dec_mode.is_user_provided():
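When eagle3_layers_to_capture is left unset in this mode, get_spec_metadata falls back to the default set shown above. A small worked example, assuming a hypothetical 32-layer target model:

# Worked example of the default capture set (hypothetical 32-layer model).
num_hidden_layers = 32
default_capture = {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4, -1}
print(sorted(default_capture))  # [-1, 1, 15, 28]; -1 denotes the final hidden state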
@@ -102,6 +122,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
             max_seq_len,
             max_num_tokens,
         )
+    if spec_dec_mode.is_save_hidden_states():
+        return Eagle3ResourceManager(
+            spec_config,
+            model_engine.model.config.torch_dtype,
+            model_config.hidden_size,
+            max_num_requests,
+            max_seq_len,
+            max_num_tokens,
+        )
     if spec_dec_mode.is_ngram():
         return NGramPoolManager(spec_config, max_num_requests)
     if spec_dec_mode.is_user_provided():
@@ -151,6 +180,9 @@ def get_spec_drafter(model_engine,
     if spec_config.spec_dec_mode.is_ngram():
         return NGramDrafter(spec_config, spec_resource_manager)
 
+    if spec_config.spec_dec_mode.is_save_hidden_states():
+        return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)
+
     return None
 
 
@@ -11,7 +11,8 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
                        DynamicBatchConfig, EagleDecodingConfig,
                        ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
-                       MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
+                       MTPDecodingConfig, NGramDecodingConfig,
+                       SaveHiddenStatesDecodingConfig, SchedulerConfig,
                        TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
                        UserProvidedDecodingConfig)
 from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
@@ -59,4 +60,5 @@ __all__ = [
     'AutoDecodingConfig',
     'AttentionDpConfig',
     'LoRARequest',
+    'SaveHiddenStatesDecodingConfig',
 ]
@@ -380,6 +380,7 @@ class DecodingBaseConfig(StrictBaseModel):
         "Lookahead": LookaheadDecodingConfig,
         "NGram": NGramDecodingConfig,
         "DraftTarget": DraftTargetDecodingConfig,
+        "SaveState": SaveHiddenStatesDecodingConfig,
         "UserProvided": UserProvidedDecodingConfig,
         "AUTO": AutoDecodingConfig,
     }
@@ -562,6 +563,52 @@ class EagleDecodingConfig(DecodingBaseConfig):
         return 3
 
 
+class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
+    output_directory: str
+    write_interval: int = 20
+    file_prefix: str = "data"
+    eagle3_layers_to_capture: Optional[Set[int]] = None
+
+    max_total_draft_tokens: Optional[int] = Field(default=1, init=False)
+    eagle_choices: Optional[List[List[int]]] = Field(default=None, init=False)
+
+    def model_post_init(self, __context):
+        self._last_hidden_in_save = True
+        if self.eagle3_layers_to_capture is None:
+            self._last_hidden_in_save = False
+        elif -1 not in self.eagle3_layers_to_capture:
+            self._last_hidden_in_save = False
+            self.eagle3_layers_to_capture.add(-1)
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+    decoding_type: ClassVar[str] = "SaveState"
+
+    def validate(self) -> None:
+        if self.output_directory is None or not self.eagle3_layers_to_capture:
+            raise ValueError(
+                "Save directory and layers to capture must be provided")
+
+    @functools.cached_property
+    def spec_dec_mode(self):
+        from tensorrt_llm._torch.speculative.interface import \
+            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
+        return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES
+
+    @functools.cached_property
+    def num_capture_layers(self):
+        """
+        Returns the number of target-model layers to capture.
+        If eagle3_layers_to_capture is not None, return the length of the set.
+        Otherwise, assume the Eagle3 base set and return 3 + 1 (for the post-norm last hidden state).
+        """
+        if self.eagle3_layers_to_capture is None:
+            return 4
+        return len(self.eagle3_layers_to_capture)
+
+
 class UserProvidedDecodingConfig(DecodingBaseConfig):
     # Cannot use real type annotations due to circular imports
     drafter: object  # Type is Drafter
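A usage sketch of the new config, mirroring the unit test added later in this diff; the model path and output directory are placeholders. Note that when this config is active, BaseLlmArgs (see the hunk further down) forces max_batch_size to 1, disables the overlap scheduler, and drops the CUDA graph config.

# Sketch only; model path and output directory are placeholders.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import SaveHiddenStatesDecodingConfig

spec_config = SaveHiddenStatesDecodingConfig(
    output_directory="./hidden_states",     # .pt dumps are written here
    write_interval=1,                       # flush to disk after every iteration
    file_prefix="data",
    eagle3_layers_to_capture={10, 11, 12},  # -1 (last hidden state) is added automatically
)

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          speculative_config=spec_config)
llm.generate(["The future of AI is"], SamplingParams(max_tokens=32, temperature=0))
llm.shutdown()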
@@ -1050,6 +1097,7 @@ SpeculativeConfig: TypeAlias = Optional[Union[
     MTPDecodingConfig,
     NGramDecodingConfig,
     UserProvidedDecodingConfig,
+    SaveHiddenStatesDecodingConfig,
     AutoDecodingConfig,
 ]]
 
@@ -1869,6 +1917,20 @@ class BaseLlmArgs(StrictBaseModel):
             self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
             self.build_config.max_draft_len = self.speculative_config.max_draft_len
 
+        elif isinstance(self.speculative_config,
+                        SaveHiddenStatesDecodingConfig):
+            assert self.backend in ['pytorch']
+            logger.warning(
+                "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None"
+            )
+            self.build_config.max_batch_size = 1
+            self.max_batch_size = 1
+            self.disable_overlap_scheduler = True
+            self.cuda_graph_config = None
+            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
+            self.build_config.max_draft_len = 1
+            self.speculative_config.max_draft_len = 1
+
         else:
             raise ValueError(
                 f"Unrecognized speculative config type {type(self.speculative_config)}"
@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
     EAGLE = auto()
     NGRAM = auto()
     USER_PROVIDED = auto()
+    SAVE_HIDDEN_STATES = auto()
     AUTO = auto()
 
     @staticmethod
@@ -120,6 +121,8 @@ class SpeculativeDecodingMode(IntFlag):
             return SpeculativeDecodingMode.USER_PROVIDED
         elif args.speculative_decoding_mode == "auto":
             return SpeculativeDecodingMode.AUTO
+        elif args.speculative_decoding_mode == "save_hidden_states":
+            return SpeculativeDecodingMode.SAVE_HIDDEN_STATES
         else:
             assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode
 
tests/unittest/_torch/speculative/test_save_state.py (new file, 138 lines)
@@ -0,0 +1,138 @@
+import os
+import sys
+import tempfile
+import unittest
+
+import pytest
+import torch
+from utils.llm_data import llm_models_root
+
+from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
+                                 SaveHiddenStatesDecodingConfig)
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+
+
+def test_multi_save_state():
+    use_cuda_graph = True
+    attn_backend = "TRTLLM"
+    disable_overlap_scheduler = False
+    enable_block_reuse = False
+    enable_chunked_prefill = False
+    layers_to_capture = {10, 11, 12}
+
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 80:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    with tempfile.TemporaryDirectory() as temp_dir:
+
+        target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct"
+
+        max_batch_size = 16
+        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
+                                        free_gpu_memory_fraction=0.5)
+        cuda_graph_config = CudaGraphConfig(
+            batch_sizes=[1, 2, 4]) if use_cuda_graph else None
+
+        llm_common_config = dict(
+            model=target_model_dir,
+            attn_backend=attn_backend,
+            disable_overlap_scheduler=disable_overlap_scheduler,
+            cuda_graph_config=cuda_graph_config,
+            max_batch_size=max_batch_size,
+            kv_cache_config=kv_cache_config,
+            enable_chunked_prefill=enable_chunked_prefill,
+        )
+        spec_config = SaveHiddenStatesDecodingConfig(
+            output_directory=temp_dir,
+            write_interval=1,
+            file_prefix="data",
+            eagle3_layers_to_capture=layers_to_capture)
+
+        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+
+        sampling_params = SamplingParams(max_tokens=32, temperature=0)
+        for output in llm_spec.generate_async(tok_ids,
+                                              sampling_params,
+                                              streaming=True):
+            pass
+        llm_spec.shutdown()
+        assert os.path.exists(os.path.join(temp_dir, "data_1.pt"))
+        # Read in .pt file
+        saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0]
+
+        assert saved_data["aux_hidden_states"].shape == (len(tok_ids), 2048 *
+                                                         len(layers_to_capture))
+        assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
+        assert saved_data["input_ids"].tolist() == tok_ids
+
+
+@pytest.mark.parametrize("layers_to_capture", [{-1}, None])
+def test_save_state(layers_to_capture):
+    use_cuda_graph = True
+    attn_backend = "TRTLLM"
+    disable_overlap_scheduler = False
+    enable_block_reuse = False
+    enable_chunked_prefill = False
+
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 80:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    with tempfile.TemporaryDirectory() as temp_dir:
+
+        target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct"
+
+        max_batch_size = 16
+        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
+                                        free_gpu_memory_fraction=0.5)
+        cuda_graph_config = CudaGraphConfig(
+            batch_sizes=[1, 2, 4]) if use_cuda_graph else None
+
+        llm_common_config = dict(
+            model=target_model_dir,
+            attn_backend=attn_backend,
+            disable_overlap_scheduler=disable_overlap_scheduler,
+            cuda_graph_config=cuda_graph_config,
+            max_batch_size=max_batch_size,
+            kv_cache_config=kv_cache_config,
+            enable_chunked_prefill=enable_chunked_prefill,
+        )
+        spec_config = SaveHiddenStatesDecodingConfig(
+            output_directory=temp_dir,
+            write_interval=1,
+            file_prefix="data",
+            eagle3_layers_to_capture=layers_to_capture)
+
+        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+
+        sampling_params = SamplingParams(max_tokens=32, temperature=0)
+        for output in llm_spec.generate_async(tok_ids,
+                                              sampling_params,
+                                              streaming=True):
+            pass
+        llm_spec.shutdown()
+        assert os.path.exists(os.path.join(temp_dir, "data_1.pt"))
+        # Read in .pt file
+        saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0]
+        if layers_to_capture is None:
+            assert saved_data["aux_hidden_states"].shape == (len(tok_ids),
+                                                             2048 * 3)
+            assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
+            assert saved_data["input_ids"].tolist() == tok_ids
+        else:
+            assert "aux_hidden_states" not in saved_data
+            assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
+            assert saved_data["input_ids"].tolist() == tok_ids
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -59,7 +59,7 @@ methods:
       default: null
     # Speculative decoding
     speculative_config:
-      annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, NoneType]
+      annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, tensorrt_llm.llmapi.llm_args.SaveHiddenStatesDecodingConfig, NoneType]
       default: null
     # generation constraints
     max_batch_size: