[None][feat] Draft: Save state first pass (#7012)
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
parent e107749a69
commit 1ad7bc4c78
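For orientation, a minimal usage sketch condensed from the test added in this commit (model path, output directory, and capture layers are placeholders; per the BaseLlmArgs change below, this config forces max_batch_size to 1, disables the overlap scheduler, and turns off CUDA graphs):

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import SaveHiddenStatesDecodingConfig

# Placeholder values for illustration only.
spec_config = SaveHiddenStatesDecodingConfig(
    output_directory="/tmp/hidden_states",
    write_interval=1,
    file_prefix="data",
    eagle3_layers_to_capture={1, 15, 28, -1})

llm = LLM(model="/path/to/target-model", speculative_config=spec_config)
tok_ids = llm.tokenizer.encode("The future of AI is")
sampling_params = SamplingParams(max_tokens=32, temperature=0)
for output in llm.generate_async(tok_ids, sampling_params, streaming=True):
    pass
llm.shutdown()
# Dumps land in /tmp/hidden_states/data_<iter>.pt, one list of per-request dicts per file.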
@@ -1110,6 +1110,10 @@ class PyExecutor:
        sample_state = self._sample_async(scheduled_batch,
                                          batch_outputs)
        if self.drafter is not None:
            self.drafter.run_drafter_post(scheduled_batch,
                                          self.resource_manager,
                                          self.is_warmup)

        self._update_request_states(scheduled_batch)
        self._update_requests(sample_state, self.resource_manager)
@@ -3,6 +3,7 @@ from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter
from .spec_tree_manager import SpecTreeManager
from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
                    get_spec_decoder, get_spec_drafter, get_spec_metadata,

@@ -16,6 +17,7 @@ __all__ = [
    "MTPWorker",
    "NGramDrafter",
    "NGramPoolManager",
    "SaveHiddenStatesDrafter",
    "SpecMetadata",
    "get_num_extra_kv_tokens",
    "get_num_spec_layers",
@@ -67,3 +67,15 @@ class Drafter(ABC):
            num_draft_tokens = get_draft_token_length(req)
            req.py_draft_tokens.extend(
                0 for _ in range(max_draft_tokens - num_draft_tokens))

    def run_drafter_post(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
        is_warmup: bool = False,
    ) -> None:
        """
        If the draft forward needs to be run directly after the target model forward,
        this method can be overridden to do that.
        Used in SaveHiddenStatesDrafter (to ensure correct input_ids).
        """
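As a sketch of this extension point (a hypothetical subclass, not part of this commit), a drafter can override run_drafter_post to run work immediately after the target forward, the same hook SaveHiddenStatesDrafter uses below:

from typing import Optional

from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm._torch.speculative.drafter import Drafter


class PostForwardLoggingDrafter(Drafter):
    """Hypothetical example: inspects context requests right after the target forward."""

    def prepare_draft_tokens(
            self,
            scheduled_requests: ScheduledRequests,
            resource_manager: Optional[ResourceManager] = None) -> None:
        # This sketch proposes no draft tokens.
        pass

    def run_drafter_post(self,
                         scheduled_requests: ScheduledRequests,
                         resource_manager: Optional[ResourceManager] = None,
                         is_warmup: bool = False) -> None:
        if is_warmup:
            return
        # Called by PyExecutor right after _sample_async (see the hunk above),
        # so per-step state such as captured hidden states is still current.
        for request in scheduled_requests.context_requests:
            print(f"post-forward hook for request {request.request_id}")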
@@ -126,6 +126,10 @@ class Eagle3SpecMetadata(SpecMetadata):
                                     self.num_layers - 4)
        else:
            self.layers_to_capture = sorted(list(self.layers_to_capture))
            if self.layers_to_capture[0] == -1:
                self.layers_to_capture = self.layers_to_capture[1:] + [
                    self.layers_to_capture.pop(0)
                ]
        self.num_capture_layers = len(self.layers_to_capture)

        # Initialize to 0 to avoid reading uninitialized memory during warmup
@@ -19,6 +19,7 @@ class SpeculativeDecodingMode(IntEnum):
    NGRAM = auto()
    DRAFT_TARGET = auto()
    USER_PROVIDED = auto()
    SAVE_HIDDEN_STATES = auto()
    NONE = auto()
    AUTO = auto()

@@ -55,6 +56,9 @@ class SpeculativeDecodingMode(IntEnum):
    def is_draft_target(self):
        return self == SpeculativeDecodingMode.DRAFT_TARGET

    def is_save_hidden_states(self):
        return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES

    def without_logits(self):
        return self.is_mtp_one_model() or self.is_eagle3_one_model()

@@ -95,8 +99,9 @@ class SpeculativeDecodingMode(IntEnum):
        ) or self.is_eagle3_one_model()

    def has_spec_drafter(self):
-        return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
-        ) or self.is_user_provided() or self.is_mtp_eagle()
+        return self.is_eagle3(
+        ) or self.is_draft_target() or self.is_ngram() or self.is_user_provided(
+        ) or self.is_mtp_eagle() or self.is_save_hidden_states()

    def extend_ctx(self, attention_backend: Type[AttentionBackend]):
        """
tensorrt_llm/_torch/speculative/save_hidden_state.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import os
from typing import Optional

import torch

from tensorrt_llm._utils import local_mpi_rank

from ..pyexecutor.llm_request import LlmRequest
from ..pyexecutor.resource_manager import ResourceManager
from ..pyexecutor.scheduler import ScheduledRequests
from .drafter import Drafter


class SaveHiddenStatesDrafter(Drafter):

    def __init__(
        self,
        spec_config: "SaveHiddenStatesDecodingConfig",
        spec_resource_manager,
    ):
        super().__init__(spec_config.max_concurrency)
        self.spec_config = spec_config
        self.max_draft_len = spec_config.max_draft_len
        self._iter = 1
        self._output_directory = spec_config.output_directory
        self._file_prefix = spec_config.file_prefix
        self._write_interval = spec_config.write_interval
        self._saved_state = []
        self.spec_resource_manager = spec_resource_manager
        os.makedirs(self._output_directory, exist_ok=True)

    def _process_request(self, request: LlmRequest, resource_manager) -> None:
        out_dict = {}
        if local_mpi_rank() == 0:
            input_ids = torch.tensor(list(request.get_tokens(0)),
                                     dtype=torch.long,
                                     device='cpu')
            hidden_size = resource_manager.hidden_size
            num_tokens = input_ids.shape[0]
            hidden_states = resource_manager.hidden_states[
                :num_tokens, -hidden_size:].cpu().clone()

            out_dict = {
                "id": self._iter,
                "input_ids": input_ids,
                "hidden_state": hidden_states,
            }
            if len(self.spec_config.eagle3_layers_to_capture) > 1:
                if self.spec_config._last_hidden_in_save:
                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
                        :num_tokens, :].cpu().clone()
                else:
                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
                        :num_tokens, :-hidden_size].cpu().clone()

        self._saved_state.append(out_dict)

    def _write_to_file(self) -> None:
        if local_mpi_rank() == 0:
            output_path = os.path.join(self._output_directory,
                                       f"{self._file_prefix}_{self._iter}.pt")
            torch.save(self._saved_state, output_path)
        self._saved_state = []

    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
    ) -> None:
        for request in sorted(
                scheduled_requests.context_requests,
                key=lambda r:
                (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
        ):
            request.py_max_new_tokens = 1

    def run_drafter_post(
        self,
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
        is_warmup: bool = False,
    ) -> None:
        if is_warmup:
            return
        for request in sorted(
                scheduled_requests.context_requests,
                key=lambda r:
                (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
        ):
            self._process_request(request, self.spec_resource_manager)
        if self._iter % self._write_interval == 0:
            self._write_to_file()
        self._iter += 1
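The dump format is simply torch.save applied to a list of per-request dicts, so reading the output back is trivial. A minimal reader sketch (directory and prefix are placeholders matching the defaults used above):

import glob
import os

import torch

output_directory = "/tmp/hidden_states"  # placeholder, matches output_directory above
for path in sorted(glob.glob(os.path.join(output_directory, "data_*.pt"))):
    for entry in torch.load(path):
        # "aux_hidden_states" is present only when more than one layer is captured.
        aux = entry.get("aux_hidden_states")
        print(entry["id"], entry["input_ids"].shape, entry["hidden_state"].shape,
              None if aux is None else aux.shape)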
@@ -11,6 +11,7 @@ from .model_drafter import ModelDrafter
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
                  MTPSpecMetadata, MTPWorker)
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter


def get_spec_metadata(spec_config,
@@ -55,6 +56,25 @@ def get_spec_metadata(spec_config,
            max_num_tokens=max_num_tokens,
            layers_to_capture=spec_config.eagle3_layers_to_capture,
        )
    if spec_config.spec_dec_mode.is_save_hidden_states():
        if spec_config.eagle3_layers_to_capture is None:
            spec_config.eagle3_layers_to_capture = {
                1, model_config.num_hidden_layers // 2 - 1,
                model_config.num_hidden_layers - 4, -1
            }
        return Eagle3SpecMetadata(
            max_draft_len=spec_config.max_draft_len,
            spec_dec_mode=spec_config.spec_dec_mode,
            max_num_requests=max_num_requests,
            num_layers=model_config.num_hidden_layers,
            hidden_size=model_config.hidden_size,
            max_num_tokens=max_num_tokens,
            dtype=model_config.torch_dtype,
            is_draft_model=is_draft_model,
            eagle3_resource_manager=spec_resource_manager,
            layers_to_capture=spec_config.eagle3_layers_to_capture,
            max_total_draft_tokens=1,
        )
    if spec_config.spec_dec_mode.is_draft_target() or \
            spec_config.spec_dec_mode.is_ngram() or \
            spec_config.spec_dec_mode.is_user_provided():
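When eagle3_layers_to_capture is left unset in this mode, the default set above picks the usual Eagle3 layers plus the final post-norm hidden state. A quick check of that selection, assuming a hypothetical 32-layer target model:

num_hidden_layers = 32  # hypothetical target model depth
layers_to_capture = {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4, -1}
print(sorted(layers_to_capture))  # [-1, 1, 15, 28]; -1 marks the post-norm last hidden state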
@@ -102,6 +122,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
            max_seq_len,
            max_num_tokens,
        )
    if spec_dec_mode.is_save_hidden_states():
        return Eagle3ResourceManager(
            spec_config,
            model_engine.model.config.torch_dtype,
            model_config.hidden_size,
            max_num_requests,
            max_seq_len,
            max_num_tokens,
        )
    if spec_dec_mode.is_ngram():
        return NGramPoolManager(spec_config, max_num_requests)
    if spec_dec_mode.is_user_provided():
@@ -151,6 +180,9 @@ def get_spec_drafter(model_engine,
    if spec_config.spec_dec_mode.is_ngram():
        return NGramDrafter(spec_config, spec_resource_manager)

    if spec_config.spec_dec_mode.is_save_hidden_states():
        return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)

    return None
@@ -11,7 +11,8 @@ from .llm_args import (AttentionDpConfig, AutoDecodingConfig, BatchingType,
                        DynamicBatchConfig, EagleDecodingConfig,
                        ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
-                       MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
+                       MTPDecodingConfig, NGramDecodingConfig,
+                       SaveHiddenStatesDecodingConfig, SchedulerConfig,
                        TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
                        UserProvidedDecodingConfig)
from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
@@ -59,4 +60,5 @@ __all__ = [
    'AutoDecodingConfig',
    'AttentionDpConfig',
    'LoRARequest',
    'SaveHiddenStatesDecodingConfig',
]
@@ -380,6 +380,7 @@ class DecodingBaseConfig(StrictBaseModel):
        "Lookahead": LookaheadDecodingConfig,
        "NGram": NGramDecodingConfig,
        "DraftTarget": DraftTargetDecodingConfig,
        "SaveState": SaveHiddenStatesDecodingConfig,
        "UserProvided": UserProvidedDecodingConfig,
        "AUTO": AutoDecodingConfig,
    }
@@ -562,6 +563,52 @@ class EagleDecodingConfig(DecodingBaseConfig):
        return 3


class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
    output_directory: str
    write_interval: int = 20
    file_prefix: str = "data"
    eagle3_layers_to_capture: Optional[Set[int]] = None

    max_total_draft_tokens: Optional[int] = Field(default=1, init=False)
    eagle_choices: Optional[List[List[int]]] = Field(default=None, init=False)

    def model_post_init(self, __context):
        self._last_hidden_in_save = True
        if self.eagle3_layers_to_capture is None:
            self._last_hidden_in_save = False
        elif -1 not in self.eagle3_layers_to_capture:
            self._last_hidden_in_save = False
            self.eagle3_layers_to_capture.add(-1)

    @classmethod
    def from_dict(cls, data: dict):
        return cls(**data)

    decoding_type: ClassVar[str] = "SaveState"

    def validate(self) -> None:
        if self.output_directory is None or not self.eagle3_layers_to_capture:
            raise ValueError(
                "Save directory and layers to capture must be provided")

    @functools.cached_property
    def spec_dec_mode(self):
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES

    @functools.cached_property
    def num_capture_layers(self):
        """
        Returns the number of layers to capture of the target model.
        If eagle3_layers_to_capture is not None, return the length of the set.
        Otherwise, assume the Eagle3 base set and return 3 + 1 (for the post-norm last hidden state).
        """
        if self.eagle3_layers_to_capture is None:
            return 4
        return len(self.eagle3_layers_to_capture)


class UserProvidedDecodingConfig(DecodingBaseConfig):
    # Cannot use real type annotations due to circular imports
    drafter: object  # Type is Drafter
@@ -1050,6 +1097,7 @@ SpeculativeConfig: TypeAlias = Optional[Union[
    MTPDecodingConfig,
    NGramDecodingConfig,
    UserProvidedDecodingConfig,
    SaveHiddenStatesDecodingConfig,
    AutoDecodingConfig,
]]
@@ -1869,6 +1917,20 @@ class BaseLlmArgs(StrictBaseModel):
            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
            self.build_config.max_draft_len = self.speculative_config.max_draft_len

        elif isinstance(self.speculative_config,
                        SaveHiddenStatesDecodingConfig):
            assert self.backend in ['pytorch']
            logger.warning(
                "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None"
            )
            self.build_config.max_batch_size = 1
            self.max_batch_size = 1
            self.disable_overlap_scheduler = True
            self.cuda_graph_config = None
            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
            self.build_config.max_draft_len = 1
            self.speculative_config.max_draft_len = 1

        else:
            raise ValueError(
                f"Unrecognized speculative config type {type(self.speculative_config)}"
@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
    EAGLE = auto()
    NGRAM = auto()
    USER_PROVIDED = auto()
    SAVE_HIDDEN_STATES = auto()
    AUTO = auto()

    @staticmethod

@@ -120,6 +121,8 @@ class SpeculativeDecodingMode(IntFlag):
            return SpeculativeDecodingMode.USER_PROVIDED
        elif args.speculative_decoding_mode == "auto":
            return SpeculativeDecodingMode.AUTO
        elif args.speculative_decoding_mode == "save_hidden_states":
            return SpeculativeDecodingMode.SAVE_HIDDEN_STATES
        else:
            assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode
tests/unittest/_torch/speculative/test_save_state.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import os
import sys
import tempfile
import unittest

import pytest
import torch
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
                                 SaveHiddenStatesDecodingConfig)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


def test_multi_save_state():
    use_cuda_graph = True
    attn_backend = "TRTLLM"
    disable_overlap_scheduler = False
    enable_block_reuse = False
    enable_chunked_prefill = False
    layers_to_capture = {10, 11, 12}

    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 80:
        pytest.skip("Not enough memory to load target + draft model")

    models_path = llm_models_root()
    with tempfile.TemporaryDirectory() as temp_dir:

        target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct"

        max_batch_size = 16
        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                        free_gpu_memory_fraction=0.5)
        cuda_graph_config = CudaGraphConfig(
            batch_sizes=[1, 2, 4]) if use_cuda_graph else None

        llm_common_config = dict(
            model=target_model_dir,
            attn_backend=attn_backend,
            disable_overlap_scheduler=disable_overlap_scheduler,
            cuda_graph_config=cuda_graph_config,
            max_batch_size=max_batch_size,
            kv_cache_config=kv_cache_config,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        spec_config = SaveHiddenStatesDecodingConfig(
            output_directory=temp_dir,
            write_interval=1,
            file_prefix="data",
            eagle3_layers_to_capture=layers_to_capture)

        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

        tok_ids = llm_spec.tokenizer.encode("The future of AI is")

        sampling_params = SamplingParams(max_tokens=32, temperature=0)
        for output in llm_spec.generate_async(tok_ids,
                                              sampling_params,
                                              streaming=True):
            pass
        llm_spec.shutdown()
        assert os.path.exists(os.path.join(temp_dir, "data_1.pt"))
        # Read in .pt file
        saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0]

        assert saved_data["aux_hidden_states"].shape == (len(tok_ids), 2048 *
                                                         len(layers_to_capture))
        assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
        assert saved_data["input_ids"].tolist() == tok_ids


@pytest.mark.parametrize("layers_to_capture", [{-1}, None])
def test_save_state(layers_to_capture):
    use_cuda_graph = True
    attn_backend = "TRTLLM"
    disable_overlap_scheduler = False
    enable_block_reuse = False
    enable_chunked_prefill = False

    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if total_mem_gb < 80:
        pytest.skip("Not enough memory to load target + draft model")

    models_path = llm_models_root()
    with tempfile.TemporaryDirectory() as temp_dir:

        target_model_dir = f"{models_path}/llama-3.2-models/Llama-3.2-1B-Instruct"

        max_batch_size = 16
        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                        free_gpu_memory_fraction=0.5)
        cuda_graph_config = CudaGraphConfig(
            batch_sizes=[1, 2, 4]) if use_cuda_graph else None

        llm_common_config = dict(
            model=target_model_dir,
            attn_backend=attn_backend,
            disable_overlap_scheduler=disable_overlap_scheduler,
            cuda_graph_config=cuda_graph_config,
            max_batch_size=max_batch_size,
            kv_cache_config=kv_cache_config,
            enable_chunked_prefill=enable_chunked_prefill,
        )
        spec_config = SaveHiddenStatesDecodingConfig(
            output_directory=temp_dir,
            write_interval=1,
            file_prefix="data",
            eagle3_layers_to_capture=layers_to_capture)

        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

        tok_ids = llm_spec.tokenizer.encode("The future of AI is")

        sampling_params = SamplingParams(max_tokens=32, temperature=0)
        for output in llm_spec.generate_async(tok_ids,
                                              sampling_params,
                                              streaming=True):
            pass
        llm_spec.shutdown()
        assert os.path.exists(os.path.join(temp_dir, "data_1.pt"))
        # Read in .pt file
        saved_data = torch.load(os.path.join(temp_dir, "data_1.pt"))[0]
        if layers_to_capture is None:
            assert saved_data["aux_hidden_states"].shape == (len(tok_ids),
                                                             2048 * 3)
            assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
            assert saved_data["input_ids"].tolist() == tok_ids
        else:
            assert "aux_hidden_states" not in saved_data
            assert saved_data["hidden_state"].shape == (len(tok_ids), 2048)
            assert saved_data["input_ids"].tolist() == tok_ids


if __name__ == "__main__":
    unittest.main()
@@ -59,7 +59,7 @@ methods:
    default: null
  # Speculative decoding
  speculative_config:
-    annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, NoneType]
+    annotation: Union[tensorrt_llm.llmapi.llm_args.DraftTargetDecodingConfig, tensorrt_llm.llmapi.llm_args.EagleDecodingConfig,tensorrt_llm.llmapi.llm_args.LookaheadDecodingConfig, tensorrt_llm.llmapi.llm_args.MedusaDecodingConfig, tensorrt_llm.llmapi.llm_args.MTPDecodingConfig, tensorrt_llm.llmapi.llm_args.NGramDecodingConfig, tensorrt_llm.llmapi.llm_args.UserProvidedDecodingConfig, tensorrt_llm.llmapi.llm_args.AutoDecodingConfig, tensorrt_llm.llmapi.llm_args.SaveHiddenStatesDecodingConfig, NoneType]
    default: null
  # generation constraints
  max_batch_size: