Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-8160][feat] Add max_total_draft_tokens (#8366)
Signed-off-by: Yue Weng <25103990+yweng0828@users.noreply.github.com>
This commit is contained in: parent a0024f4d34, commit 8dc4aac5b6
@@ -320,13 +320,11 @@ def create_autodeploy_executor(ad_config: LlmArgs):
     max_draft_len = (
         0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len
     )
-    max_total_draft_tokens = 0
-    if ad_config.speculative_config is None:
-        max_total_draft_tokens = 0
-    elif hasattr(ad_config.speculative_config, "max_total_draft_tokens"):
-        max_total_draft_tokens = ad_config.speculative_config.max_total_draft_tokens
-    else:
-        max_total_draft_tokens = max_draft_len
+    max_total_draft_tokens = (
+        0
+        if ad_config.speculative_config is None
+        else ad_config.speculative_config.max_total_draft_tokens
+    )
 
     # initialize model engine
     engine = ADEngine.build_from_config(ad_config=ad_config)

@@ -417,6 +415,7 @@ def create_autodeploy_executor(ad_config: LlmArgs):
         max_input_len=ad_config.max_input_len,
         max_batch_size=ad_config.max_batch_size,
         max_draft_len=max_draft_len,
+        max_total_draft_tokens=max_total_draft_tokens,
         max_beam_width=ad_config.max_beam_width,
     )
     return py_executor

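Taken together, the two hunks above stop probing for the new field with hasattr, read it directly, and thread both limits into the executor. A minimal stand-alone sketch of the resulting resolution logic (the helper name is hypothetical, not part of this commit):

# Hypothetical helper, not part of the patch: resolve both draft limits from an
# optional speculative config, assuming the config now always carries
# max_total_draft_tokens alongside max_draft_len.
def resolve_draft_limits(speculative_config):
    if speculative_config is None:
        return 0, 0
    # For a linear chain the two values coincide; for a draft-token tree the
    # total counts every tree node, so it can exceed the number of draft layers.
    return (speculative_config.max_draft_len,
            speculative_config.max_total_draft_tokens)

# Example: no speculation configured.
assert resolve_draft_limits(None) == (0, 0)
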
@@ -510,7 +510,7 @@ class DeepseekV3Attention(MLA):
         aux_stream: Optional[torch.cuda.Stream] = None,
     ):
         config = model_config.pretrained_config
-        predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
+        predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
         super().__init__(hidden_size=config.hidden_size,
                          num_attention_heads=config.num_attention_heads,
                          num_key_value_heads=config.num_key_value_heads,

@@ -250,10 +250,10 @@ class KvCacheCreator:
         if not pytorch_backend_config.disable_overlap_scheduler:
             num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
             if spec_cfg is not None:
-                num_extra_tokens_per_seq += spec_cfg.max_draft_len
+                num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
 
         if spec_cfg is not None:
-            num_extra_tokens_per_seq += spec_cfg.max_draft_len
+            num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
             num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)
 
         if self._dummy_reqs is None:

@@ -808,6 +808,8 @@ def create_py_executor_instance(
         max_beam_width=max_beam_width,
         max_draft_len=spec_config.max_draft_len
         if spec_config is not None else 0,
+        max_total_draft_tokens=spec_config.max_total_draft_tokens
+        if spec_config is not None else 0,
         kv_cache_transceiver=kv_cache_transceiver,
         guided_decoder=guided_decoder,
         start_worker=start_worker,

@@ -824,13 +826,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
-    max_total_draft_tokens = 0
-    if speculative_config is None:
-        max_total_draft_tokens = 0
-    elif hasattr(speculative_config, 'max_total_draft_tokens'):
-        max_total_draft_tokens = speculative_config.max_total_draft_tokens
-    else:
-        max_total_draft_tokens = max_draft_len
+    max_total_draft_tokens = (0 if speculative_config is None else
+                              speculative_config.max_total_draft_tokens)
 
     return TorchSampler.Args(
         max_seq_len=max_seq_len,

@@ -93,7 +93,8 @@ class CUDAGraphRunner:
     @property
     def max_possible_draft_len(self):
         engine = self._get_engine()
-        return (engine.original_max_draft_len if self.enable_spec_decode else 0)
+        return (engine.original_max_total_draft_tokens
+                if self.enable_spec_decode else 0)
 
     def get_graph_key(
         self,

@@ -102,10 +103,12 @@ class CUDAGraphRunner:
         engine = self._get_engine()
         if engine.is_draft_model and spec_resource_manager is not None and isinstance(
                 spec_resource_manager, Eagle3ResourceManager):
+            # If 'is_first_draft' is True, even with tree decoding, the length of draft_len will only be 'max_draft_len', not 'max_total_draft_token'.
+            # Because we will pad the input to 'max_draft_len' length for the first draft layer.
             draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
             key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
         else:
-            draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+            draft_len = self.spec_config.max_total_draft_tokens if self.enable_spec_decode else 0
             key = (batch_size, draft_len, False)
         return key

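The non-draft branch above now keys captured CUDA graphs on the total draft width rather than the chain depth. A toy illustration of the key it produces (the numbers are made up, not taken from the patch):

# Toy values only: with spec decode enabled and a 15-node draft tree, the
# target model looks up the graph captured for 15 draft tokens per step.
batch_size = 8
enable_spec_decode = True
max_total_draft_tokens = 15
draft_len = max_total_draft_tokens if enable_spec_decode else 0
key = (batch_size, draft_len, False)  # (batch, draft tokens per step, is_first_draft)
assert key == (8, 15, False)
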
@@ -93,11 +93,11 @@ class ModelEngine(ABC):
 
 def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
                                    max_batch_size: int, max_num_tokens: int,
-                                   max_draft_len: int,
+                                   max_total_draft_tokens: int,
                                    enable_padding: bool) -> list[int]:
     # This is the largest possible batch size for a pure decoding batch.
     max_cuda_graph_bs = min(max_batch_size,
-                            int(max_num_tokens / (1 + max_draft_len)))
+                            int(max_num_tokens / (1 + max_total_draft_tokens)))
 
     result = []
     # This function assumes cuda_graph_batch_sizes is sorted

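The bound above now divides the token budget by the full per-sequence decode width. A worked example with illustrative numbers (not from the patch):

# Each pure-decode step carries 1 target token plus max_total_draft_tokens
# draft tokens per sequence, so the largest CUDA-graph batch is capped by the
# token budget divided by that width.
max_batch_size = 256
max_num_tokens = 8192
max_total_draft_tokens = 15
max_cuda_graph_bs = min(max_batch_size,
                        int(max_num_tokens / (1 + max_total_draft_tokens)))
assert max_cuda_graph_bs == 256  # 8192 // 16 = 512, so max_batch_size wins
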
@@ -162,11 +162,13 @@ class PyTorchModelEngine(ModelEngine):
             ExpertStatistic.create(self.dist.rank)
         self.pytorch_backend_config = pytorch_backend_config
         self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
+        self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
 
         # The draft model won't have any draft tokens attached to
         # generation requests when we invoke it autoregressively
         if spec_config is not None and is_draft_model:
             spec_config.max_draft_len = 0
+            spec_config.max_total_draft_tokens = 0
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
         self.sparse_attention_config = sparse_attention_config

@@ -277,7 +279,7 @@ class PyTorchModelEngine(ModelEngine):
         self.spec_metadata = None
         update_spec_config_from_model_config(self.spec_config,
                                              self.model.config)
-        max_num_draft_tokens = self.original_max_draft_len * batch_size
+        max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size
         self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                              dtype=torch.int,
                                              device='cuda')

@@ -297,9 +299,11 @@ class PyTorchModelEngine(ModelEngine):
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
+            self.max_total_draft_tokens = spec_config.max_total_draft_tokens
         else:
             self.without_logits = False
             self.max_draft_len = 0
+            self.max_total_draft_tokens = 0
 
         self.guided_decoder: Optional[CapturableGuidedDecoder] = None

@@ -320,7 +324,7 @@ class PyTorchModelEngine(ModelEngine):
 
         self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
             pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
-            self.max_num_tokens, self.original_max_draft_len,
+            self.max_num_tokens, self.original_max_total_draft_tokens,
             self._cuda_graph_padding_enabled
         ) if pytorch_backend_config.cuda_graph_batch_sizes else []

@@ -364,7 +368,7 @@ class PyTorchModelEngine(ModelEngine):
 
     @property
     def runtime_draft_len(self):
-        return self.max_draft_len if self.enable_spec_decode else 0
+        return self.max_total_draft_tokens if self.enable_spec_decode else 0
 
     def set_lora_model_config(self,
                               lora_target_modules: list[str],

@@ -585,20 +589,20 @@ class PyTorchModelEngine(ModelEngine):
             if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
                     spec_resource_manager, Eagle3ResourceManager):
                 # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
-                draft_lengths.append(self.original_max_draft_len)
+                draft_lengths.append(self.original_max_total_draft_tokens)
             else:
-                draft_lengths.append(self.max_draft_len)
+                draft_lengths.append(self.max_total_draft_tokens)
         else:
             # For non-draft model, we also capture the CUDA graph instance for draft length 0,
             # so that when we disable spec decode at runtime, we can still run the captured graph.
             # Note that for one engine mode, we are not able to turn off spec decode at runtime.
-            if (self.max_draft_len > 0
+            if (self.max_total_draft_tokens > 0
                     and not self.spec_config.spec_dec_mode.use_one_engine()
                     # Assume that speculation is always on if the user didn't give us a max_concurrency
                     # value. This will save on memory.
                     and self.spec_config.max_concurrency is not None):
                 draft_lengths.append(0)
-            draft_lengths = [self.max_draft_len]
+            draft_lengths = [self.max_total_draft_tokens]
 
         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:

@@ -757,7 +761,7 @@ class PyTorchModelEngine(ModelEngine):
                                           num_ctx_requests + num_gen_tokens)),
                 token_nums=[1] * num_gen_tokens,
                 is_gen=True,
-                max_num_draft_tokens=self.max_draft_len,
+                max_num_draft_tokens=self.max_total_draft_tokens,
                 use_mrope=self.use_mrope)
             if spec_resource_manager is not None:
                 spec_resource_manager.add_dummy_requests(request_ids=list(

@@ -830,7 +834,7 @@ class PyTorchModelEngine(ModelEngine):
     def _get_cuda_graph_draft_lengths(
             self, resource_manager: ResourceManager) -> List[int]:
         """Determines the draft lengths for which to capture CUDA graphs."""
-        draft_lengths = [self.max_draft_len]
+        draft_lengths = [self.max_total_draft_tokens]
         spec_resource_manager = resource_manager.get_resource_manager(
             ResourceManagerType.SPEC_RESOURCE_MANAGER)

@@ -1027,7 +1031,7 @@ class PyTorchModelEngine(ModelEngine):
         """
         if self.enable_spec_decode and not self._disable_overlap_scheduler:
             # When enabling overlap scheduler, the kv cache for draft tokens will
-            # be prepared in advance by using the max_draft_len. But we need to use
+            # be prepared in advance by using the max_total_draft_tokens. But we need to use
             # new_tokens_lens_device to get the real past kv lengths and the
             # correct position ids. And to avoid blocking the async data transfer,
             # we need to preprocess the inputs in forward to update the position_ids and

@@ -2252,7 +2256,7 @@ class PyTorchModelEngine(ModelEngine):
         # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
         is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
             spec_resource_manager, self.is_draft_model, self.attn_backend,
-            self.model_is_wrapped)
+            self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
         attn_metadata.update_spec_dec_param(
             is_spec_dec_mode, spec_metadata.is_spec_dec_tree,
             spec_metadata.is_spec_dec_dynamic_tree,

@@ -160,6 +160,7 @@ class PyExecutor:
                  max_batch_size: int = 8,
                  max_beam_width: int = 1,
                  max_draft_len: int = 0,
+                 max_total_draft_tokens: int = 0,
                  kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
                  guided_decoder: Optional[GuidedDecoder] = None,
                  garbage_collection_gen0_threshold: Optional[int] = None,

@@ -195,6 +196,7 @@ class PyExecutor:
         self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
+        self.max_total_draft_tokens = max_total_draft_tokens
         self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats

@@ -1040,7 +1042,7 @@ class PyExecutor:
             self.use_spec_decode = self.drafter.should_use_spec_decode(
                 self.active_requests, self.max_batch_size,
                 self.model_engine.max_num_tokens,
-                self.model_engine.spec_config.max_draft_len)
+                self.model_engine.spec_config.max_total_draft_tokens)
             logger.debug(f"Use spec decode: {self.use_spec_decode}")
             self.model_engine.enable_spec_decode = self.use_spec_decode

@@ -1050,10 +1052,10 @@ class PyExecutor:
                     LlmRequestState.GENERATION_IN_PROGRESS,
                     LlmRequestState.DISAGG_GENERATION_INIT):
                 continue
-            max_draft_len = self.model_engine.spec_config.max_draft_len
+            max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens
             request.draft_tokens = [
                 0
-            ] * max_draft_len if max_draft_len > 0 else []
+            ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []
 
         # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
         # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.

@@ -1223,11 +1225,11 @@ class PyExecutor:
                 continue
 
             req.py_last_draft_tokens = req.py_draft_tokens
-            max_draft_len = self.model_engine.spec_config.max_draft_len
+            max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens
 
-            if max_draft_len > 0 and self.use_spec_decode:
-                req.py_draft_tokens = [0] * max_draft_len
-                req.py_draft_pages_allocated = max_draft_len
+            if max_total_draft_tokens > 0 and self.use_spec_decode:
+                req.py_draft_tokens = [0] * max_total_draft_tokens
+                req.py_draft_pages_allocated = max_total_draft_tokens
             else:
                 req.py_draft_tokens = []
                 req.py_draft_pages_allocated = 0

@@ -1615,7 +1617,7 @@ class PyExecutor:
             request_ids=[0],
             is_gen=True,
             prepare_resource=True,
-            max_num_draft_tokens=self.max_draft_len,
+            max_num_draft_tokens=self.max_total_draft_tokens,
         )[0]
         llm_request.is_attention_dp_dummy = True
         spec_resource_manager = self.resource_manager.get_resource_manager(

@@ -357,7 +357,9 @@ def create_py_executor(
             from tensorrt_llm._torch.speculative.drafting_loops import \
                 ChainDrafter
 
-            return ChainDrafter(spec_config.max_draft_len, model)
+            return ChainDrafter(spec_config.max_draft_len,
+                                spec_config.max_total_draft_tokens,
+                                model)
         else:
             drafting_loop_wrapper = None

@@ -397,11 +399,11 @@ def create_py_executor(
     if not pytorch_backend_config.disable_overlap_scheduler:
         model_engine_max_seq_len = model_engine.max_seq_len + 1
         if spec_config is not None:
-            model_engine_max_seq_len += spec_config.max_draft_len
+            model_engine_max_seq_len += spec_config.max_total_draft_tokens
 
     if spec_config is not None:
         model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
-        model_engine_max_seq_len += spec_config.max_draft_len
+        model_engine_max_seq_len += spec_config.max_total_draft_tokens
 
     max_seq_len = model_engine_max_seq_len
     max_num_tokens = model_engine.max_num_tokens

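The two additions above grow the engine's sequence-length budget by the total draft width instead of the chain depth. A worked example with illustrative numbers (the extra-KV-token count is an assumed value here, not taken from the patch):

# Overlap scheduling reserves one extra slot; speculation reserves
# max_total_draft_tokens once for the overlap step and once for the regular
# draft budget, plus whatever extra KV tokens the speculative mode needs.
max_seq_len = 4096
max_total_draft_tokens = 15
extra_kv_tokens = 0  # assumed result of get_num_extra_kv_tokens(spec_config)
model_engine_max_seq_len = max_seq_len + 1
model_engine_max_seq_len += max_total_draft_tokens
model_engine_max_seq_len += extra_kv_tokens + max_total_draft_tokens
assert model_engine_max_seq_len == 4096 + 1 + 15 + 15
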
@@ -471,7 +473,8 @@ def create_py_executor(
         "vocab_size_padded": model_engine.model.vocab_size_padded
     }
     if spec_config is not None:
-        kwargs["max_num_draft_tokens"] = spec_config.max_draft_len
+        kwargs[
+            "max_num_draft_tokens"] = spec_config.max_total_draft_tokens
 
     if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder(
     ):

@@ -589,7 +589,7 @@ class TorchSampler(Sampler):
 
     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
-        self.max_tokens = args.max_draft_len + 1
+        self.max_tokens = args.max_total_draft_tokens + 1
         assert args.max_beam_width == self.MAX_BEAM_WIDTH, (
             "TorchSampler only supports beam_width = 1"
         )

@@ -738,9 +738,9 @@ class TorchSampler(Sampler):
         we can find the longest match by comparing all the paths.
         Args:
             request: LlmRequest. The request with draft tokens.
-            new_tokens: torch.Tensor. [max_draft_len + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer.
+            new_tokens: torch.Tensor. [max_total_draft_tokens + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer.
                 The tokens generated by the target model
-                The relationship between [max_draft_len + 1] and the draft token tree:
+                The relationship between [max_total_draft_tokens + 1] and the draft token tree:
                 If the current node is accepted, what is the NEXT token_id that the target model will generate?
                 For example, new_tokens[0, req_idx, 1] indicates the NEXT token_id sampled from the root
                 node in the draft token tree if it is accepted.

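The sampler now sizes its buffers from the total draft width: one slot per draft-tree node plus one for the root, as the updated docstring describes. A minimal shape sketch (sizes and dtype are illustrative assumptions):

import torch

max_total_draft_tokens = 7   # e.g. a small draft tree with 7 nodes
max_num_sequences = 16
MAX_BEAM_WIDTH = 1
new_tokens = torch.zeros(max_total_draft_tokens + 1, max_num_sequences,
                         MAX_BEAM_WIDTH, dtype=torch.int)
# new_tokens[0] holds the token the target model samples after the root;
# new_tokens[i] holds the token it samples if draft-tree node i is accepted.
assert new_tokens.shape == (8, 16, 1)
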
@@ -29,14 +29,14 @@ class Drafter(ABC):
     @final
     def should_use_spec_decode(self, requests: List[LlmRequest],
                                max_batch_size: int, max_num_tokens: int,
-                               max_draft_len: int) -> bool:
+                               max_total_draft_tokens: int) -> bool:
         """
         You probably don't want to override this. ModelEngine
         assumes that speculation is always on if max_concurrency
         is not specified by the user's spec config.
         """
 
-        # Inputs typically validated upstream: max_batch_size>0, max_num_tokens>0, max_draft_len>=0
+        # Inputs typically validated upstream: max_batch_size>0, max_num_tokens>0, max_total_draft_tokens>=0
 
         if self.max_concurrency is None:
             return True

@@ -45,7 +45,7 @@ class Drafter(ABC):
         if not requests or max_batch_size <= 0 or max_num_tokens <= 0:
             return False
 
-        tokens_per_request = 1 + max_draft_len
+        tokens_per_request = 1 + max_total_draft_tokens
         token_cap = max_num_tokens // tokens_per_request
         if token_cap <= 0:
             return False

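The gate above divides the token budget by the per-request width. Re-doing the arithmetic with the numbers used in the unit test later in this diff:

# "Small token budget ON case" from the test: 28 tokens, draft width 4.
max_num_tokens = 28
max_total_draft_tokens = 4
tokens_per_request = 1 + max_total_draft_tokens
token_cap = max_num_tokens // tokens_per_request
assert token_cap == 5  # at most 5 requests fit with speculation enabled
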
@@ -63,7 +63,7 @@ class Drafter(ABC):
             scheduled_requests: The scheduled requests to pad
         """
         for req in scheduled_requests.generation_requests:
-            max_draft_tokens = self.max_draft_tokens
+            max_draft_tokens = self.max_draft_len
             num_draft_tokens = get_draft_token_length(req)
             req.py_draft_tokens.extend(
                 0 for _ in range(max_draft_tokens - num_draft_tokens))

@@ -107,12 +107,14 @@ def prepare_for_generation(attn_metadata: AttentionMetadata,
 
 class ChainDrafter(torch.nn.Module):
 
-    def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
+    def __init__(self, max_draft_len: int, max_total_draft_tokens: int,
+                 draft_model: torch.nn.Module):
         super().__init__()
         self.draft_model = draft_model
         self.config = self.draft_model.config
         self.model_config = self.draft_model.model_config
         self.max_draft_len = max_draft_len
+        self.max_total_draft_tokens = max_total_draft_tokens
 
     def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
                 attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata,

@@ -131,6 +131,7 @@ class SpeculativeDecodingMode(IntEnum):
         is_draft_model: bool,
         attention_backend: Type[AttentionBackend],
         use_chain_drafter: bool,
+        is_spec_dec_tree: bool,
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).

@@ -154,8 +155,10 @@ class SpecMetadata:
     """
     # The max number of requests in a single batch.
     max_num_requests: int
-    # The max number of draft tokens.
+    # The number of draft layers. (Also the number of draft tokens for the linear tree.)
     max_draft_len: int
+    # The max number of draft tokens for the static tree and dynamic tree.
+    max_total_draft_tokens: int
     # The number of gen-phase sequences in the batch.
     num_generations: int = 0
     # Whether CUDA graph is enabled.

@@ -191,9 +194,13 @@ class SpecMetadata:
     # The number of layers
     num_layers: int = 0
 
-    # if spec-dec tree is a tree or a chain (linear tree)
-    is_spec_dec_tree: bool = False
-    # if spec-dec tree wouldn't be changed at all, the mask won't be computed every step.
+    # NOTE: For the linear tree, though it can be treated as a special case of static tree.
+    # NOTE: But we do not set `is_spec_dec_tree` to True for this cases.
+    # NOTE: i.e., for the linear tree, is_spec_dec_tree == False and is_spec_dec_dynamic_tree == False.
+    # whether the spec-dec mode is a tree (can be static tree or dynamic tree).
+    is_spec_dec_tree: bool = False
+    # whether the spec-dec mode is a dynamic tree.
     is_spec_dec_dynamic_tree: bool = False
 
     def __post_init__(self):

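The NOTEs above pin down how the two flags classify the draft-token structure. A small hypothetical helper (not part of the patch) that restates the same rules:

def classify_draft_structure(is_spec_dec_tree: bool,
                             is_spec_dec_dynamic_tree: bool) -> str:
    # Linear chain: both flags stay False, even though a chain could be seen
    # as a degenerate static tree.
    if not is_spec_dec_tree:
        return "linear chain"
    return "dynamic tree" if is_spec_dec_dynamic_tree else "static tree"

assert classify_draft_structure(False, False) == "linear chain"
assert classify_draft_structure(True, False) == "static tree"
assert classify_draft_structure(True, True) == "dynamic tree"
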
@@ -51,7 +51,8 @@ class ModelDrafter(Drafter):
         self,
         spec_config: "DecodingBaseConfig",
         draft_model_engine: "ModelEngine",
-        max_draft_tokens: int,
+        max_draft_len: int,
+        max_total_draft_tokens: int,
         draft_seq_slot_manager: SeqSlotManager,
         sampler: Sampler,
         spec_resource_manager: Optional[BaseResourceManager] = None,

@@ -62,8 +63,11 @@ class ModelDrafter(Drafter):
         # Validate required parameters
         if draft_model_engine is None:
             raise ValueError("draft_model_engine cannot be None")
-        if max_draft_tokens < 0:
-            raise ValueError("max_draft_tokens must be >= 0")
+        if max_draft_len < 0:
+            raise ValueError("max_draft_len must be >= 0")
+        if max_total_draft_tokens < 0:
+            raise ValueError("max_total_draft_tokens must be >= 0")
+        assert max_draft_len <= max_total_draft_tokens
 
         # Model and resource management
         self.draft_model_engine = draft_model_engine

@@ -72,7 +76,8 @@ class ModelDrafter(Drafter):
 
         # Configuration
         self.spec_config = spec_config
-        self.max_draft_tokens = max_draft_tokens
+        self.max_draft_len = max_draft_len
+        self.max_total_draft_tokens = max_total_draft_tokens
         # Sampling
         self.sampler = sampler
         self.guided_decoder = guided_decoder

@@ -153,9 +158,11 @@ class ModelDrafter(Drafter):
         Create a chunked context request for accepted tokens.
         Only applicable if the draft model needs to recompute KV cache for accepted tokens (e.g. eagle 3)
         """
-        # Pad input_tokens to max_draft_tokens
+        # Pad input_tokens to max_draft_len
+        # We use max_draft_len instead of max_total_draft_tokens here,
+        # because at most max_draft_len draft tokens are accepted.
         input_tokens.extend(
-            0 for _ in range(self.max_draft_tokens - num_accepted_tokens))
+            0 for _ in range(self.max_draft_len - num_accepted_tokens))
         new_request = self._create_draft_request(request, input_tokens)
         new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
         new_request.py_num_accepted_draft_tokens = request.py_num_accepted_draft_tokens

@@ -469,7 +476,7 @@ class ModelDrafter(Drafter):
         # We already updated the target state, so the new_tokens_lens should be all ones.
         new_tokens_lens = torch.ones(batch_size, device=device)
         next_draft_tokens = torch.zeros(batch_size,
-                                        self.max_draft_tokens,
+                                        self.max_draft_len,
                                         device=device)
 
         # Create a new SampleStateTensorsMTP object with the additional fields

@@ -563,7 +570,7 @@ class ModelDrafter(Drafter):
                 # Chunked prefill request in progress; no need to append draft tokens
                 continue
             py_draft_logits = []
-            for token_idx in range(self.max_draft_tokens):
+            for token_idx in range(self.max_draft_len):
                 target_model_req.py_draft_tokens.append(
                     draft_tokens_host[token_idx][req_idx])
                 py_draft_logits.append(draft_logits[token_idx][req_idx])

@@ -646,7 +653,7 @@ class ModelDrafter(Drafter):
         previous_draft_state = initial_draft_state
 
         # Generate remaining draft tokens iteratively
-        for i in range(self.max_draft_tokens - 1):
+        for i in range(self.max_draft_len - 1):
             if len(draft_batch.generation_requests) == 0:
                 break

@@ -722,7 +729,7 @@ class ModelDrafter(Drafter):
             target_inputs,
             outputs["new_draft_tokens"],
             draft_position=0,
-            draft_length=self.max_draft_tokens,
+            draft_length=self.max_draft_len,
             draft_batch=draft_batch,
             req_id_to_old_request=req_id_to_old_request)

@@ -170,7 +170,7 @@ class NGramDrafter(Drafter):
         super().__init__(spec_config.max_concurrency)
         assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
         self.spec_config = spec_config
-        self.max_draft_tokens = spec_config.max_draft_len
+        self.max_draft_len = spec_config.max_draft_len
         self.spec_resource_manager = ngram_pool_manager
 
     def prepare_draft_tokens(

@@ -21,6 +21,7 @@ class SaveHiddenStatesDrafter(Drafter):
         super().__init__(spec_config.max_concurrency)
         self.spec_config = spec_config
         self.max_draft_len = spec_config.max_draft_len
+        self.max_total_draft_tokens = spec_config.max_total_draft_tokens
         self._iter = 1
         self._output_directory = spec_config.output_directory
         self._file_prefix = spec_config.file_prefix

@@ -23,6 +23,7 @@ def get_spec_metadata(spec_config,
     if spec_config.spec_dec_mode.is_mtp_one_model():
         return MTPSpecMetadata(
             max_draft_len=spec_config.max_draft_len,
+            max_total_draft_tokens=spec_config.max_total_draft_tokens,
             spec_dec_mode=spec_config.spec_dec_mode,
             mtp_num_modules=spec_config.num_nextn_predict_layers,
             max_num_requests=max_num_requests,

@@ -31,6 +32,7 @@ def get_spec_metadata(spec_config,
     if spec_config.spec_dec_mode.is_mtp_eagle():
         return Eagle3SpecMetadata(
             max_draft_len=spec_config.max_draft_len,
+            max_total_draft_tokens=spec_config.max_total_draft_tokens,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
             num_layers=model_config.num_hidden_layers,

@@ -45,6 +47,7 @@ def get_spec_metadata(spec_config,
     if spec_config.spec_dec_mode.is_eagle3():
         return Eagle3SpecMetadata(
             max_draft_len=spec_config.max_draft_len,
+            max_total_draft_tokens=spec_config.max_total_draft_tokens,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
             num_layers=model_config.num_hidden_layers,

@@ -55,14 +58,15 @@ def get_spec_metadata(spec_config,
             eagle3_resource_manager=spec_resource_manager,
             layers_to_capture=spec_config.eagle3_layers_to_capture,
             is_mtp_eagle=False,
-            max_total_draft_tokens=spec_config.max_total_draft_tokens,
             eagle_choices=spec_config.eagle_choices,
-            is_spec_dec_tree=spec_config.eagle_choices is not None,
+            is_spec_dec_tree=spec_config.eagle_choices is not None
+            or spec_config.use_dynamic_tree,
             is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree,
         )
     if spec_config.spec_dec_mode.is_eagle3_one_model():
         return Eagle3OneModelSpecMetadata(
             max_draft_len=spec_config.max_draft_len,
+            max_total_draft_tokens=spec_config.max_total_draft_tokens,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
             num_layers=model_config.num_hidden_layers,

@@ -78,6 +82,7 @@ def get_spec_metadata(spec_config,
         }
         return Eagle3SpecMetadata(
             max_draft_len=spec_config.max_draft_len,
+            max_total_draft_tokens=1,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
             num_layers=model_config.num_hidden_layers,

@@ -87,13 +92,13 @@ def get_spec_metadata(spec_config,
             is_draft_model=is_draft_model,
             eagle3_resource_manager=spec_resource_manager,
             layers_to_capture=spec_config.eagle3_layers_to_capture,
-            max_total_draft_tokens=1,
         )
     if spec_config.spec_dec_mode.is_draft_target() or \
         spec_config.spec_dec_mode.is_ngram() or \
         spec_config.spec_dec_mode.is_user_provided():
         return SpecMetadata(
             max_draft_len=spec_config.max_draft_len,
+            max_total_draft_tokens=spec_config.max_total_draft_tokens,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
         )

@@ -186,6 +191,7 @@ def get_spec_drafter(model_engine,
     return ModelDrafter(spec_config,
                         draft_model_engine,
                         spec_config.max_draft_len,
                         spec_config.max_total_draft_tokens,
                         SeqSlotManager(max_num_requests),
                         sampler,
                         spec_resource_manager=spec_resource_manager,

@@ -413,7 +413,14 @@ class _ModelFormatKind(Enum):
 
 
 class DecodingBaseConfig(StrictBaseModel):
+    # The number of the drafter layers.
     max_draft_len: Optional[int] = None
+    # The number of draft tokens in the draft tokens tree.
+    # If it's a linear tree, each draft layer will only generate one draft token.
+    # In this case, max_draft_len == max_total_draft_tokens.
+    # If it's a static or dynamic tree, each draft layer may generate more than one draft token.
+    # In this case, max_total_draft_tokens >= max_draft_len.
+    max_total_draft_tokens: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None
 
     # PyTorch only.

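The new comments describe how the two fields relate for chains versus trees. A hedged illustration: assuming an eagle_choices-style encoding where every entry is one tree node and the entry's path length is its depth (this encoding is an assumption here, not spelled out in the diff), the two limits could be derived like this:

# Toy 5-node tree: depth-1 nodes [0], [1]; depth-2 nodes [0, 0], [0, 1];
# depth-3 node [0, 0, 0].
eagle_choices = [[0], [1], [0, 0], [0, 1], [0, 0, 0]]
max_total_draft_tokens = len(eagle_choices)               # 5 nodes in the tree
max_draft_len = max(len(path) for path in eagle_choices)  # depth 3 -> 3 draft layers
assert max_total_draft_tokens >= max_draft_len
# For a linear chain the two collapse to the same value:
chain = [[0], [0, 0], [0, 0, 0]]
assert len(chain) == max(len(path) for path in chain) == 3
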
@@ -526,6 +533,10 @@ class MedusaDecodingConfig(DecodingBaseConfig):
     medusa_choices: Optional[List[List[int]]] = None
     num_medusa_heads: Optional[int] = None
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.max_total_draft_tokens = self.max_draft_len  # Current Medusa only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)

@@ -544,8 +555,6 @@ class EagleDecodingConfig(DecodingBaseConfig):
     use_dynamic_tree: Optional[bool] = False
     # The topK value for each layer when enable dynamic tree.
     dynamic_tree_max_topK: Optional[int] = None
-    # The number of draft tokens in the draft tokens tree.
-    max_total_draft_tokens: Optional[int] = None
     # The number of eagle layer. will not be used in pytorch flow, just for compatibility with TRT flow
     num_eagle_layers: Optional[int] = None
     # The number of non-leaves in each layer.

@@ -580,7 +589,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
         # Checks whether the input eagle choices is valid
         # and reset the max_draft_len and num_eagle_layers if necessary
         if self.eagle_choices is not None:
-            # If eagle_choices is provided, use_dynamic_tree will not be used
+            # If eagle_choices is provided, use_dynamic_tree should not be used
             assert not self.use_dynamic_tree, "If eagle_choices is provided, use_dynamic_tree need to be False"
 
             # Get num_eagle_layers from eagle_choices

@@ -651,6 +660,12 @@ class EagleDecodingConfig(DecodingBaseConfig):
             return len(self.eagle3_layers_to_capture)
         return 3
 
+    @functools.cached_property
+    def is_linear_tree(self) -> bool:
+        if self.eagle_choices is None and self.use_dynamic_tree is False:
+            return True
+        return False
+
 
 class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
     output_directory: str

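The new cached property gives callers a single predicate for the chain-only case. A hedged usage sketch (the helper below is hypothetical, not part of the patch):

def needs_tree_attention(cfg) -> bool:
    # cfg is an EagleDecodingConfig; only a static or dynamic tree needs a
    # tree-aware spec-dec attention path, a plain chain does not.
    return not cfg.is_linear_tree
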
@@ -703,6 +718,10 @@ class UserProvidedDecodingConfig(DecodingBaseConfig):
     drafter: object  # Type is Drafter
     resource_manager: object = None  # Type is Optional[ResourceManager]
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.max_total_draft_tokens = self.max_draft_len  # Current UserProvided only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)

@@ -735,6 +754,10 @@ class NGramDecodingConfig(DecodingBaseConfig):
     is_use_oldest: bool = True
     is_public_pool: bool = True
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.max_total_draft_tokens = self.max_draft_len  # Current NGram only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)

@@ -747,6 +770,10 @@ class NGramDecodingConfig(DecodingBaseConfig):
 
 class DraftTargetDecodingConfig(DecodingBaseConfig):
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.max_total_draft_tokens = self.max_draft_len  # Current DraftTarget only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)

@@ -776,10 +803,18 @@ class MTPDecodingConfig(DecodingBaseConfig):
     BEGIN_THINKING_PHASE_TOKEN: int = 128798
     END_THINKING_PHASE_TOKEN: int = 128799
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if 'num_nextn_predict_layers' in kwargs:
+            self.max_draft_len = kwargs['num_nextn_predict_layers']
+            self.max_total_draft_tokens = kwargs[
+                'num_nextn_predict_layers']  # Current MTP only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
         out = cls(**data)
         out.max_draft_len = out.num_nextn_predict_layers
+        out.max_total_draft_tokens = out.num_nextn_predict_layers  # Current MTP only support linear tree
         return out
 
     decoding_type: ClassVar[str] = "MTP"

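With MTP the draft structure is always a chain, so both limits collapse to the number of next-N prediction layers. A hedged construction example (the argument value is illustrative, and any other fields are assumed to take their defaults):

cfg = MTPDecodingConfig(num_nextn_predict_layers=3)
assert cfg.max_draft_len == cfg.max_total_draft_tokens == 3
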
@@ -814,6 +849,10 @@ class AutoDecodingConfig(DecodingBaseConfig):
     Attributes that are inherited from the base class are ignored.
     """
 
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.max_total_draft_tokens = self.max_draft_len  # Current Auto only support linear tree
+
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)

@@ -1156,6 +1195,7 @@ class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror):
 
     def __init__(self, **data):
         super().__init__(**data)
+        self.max_total_draft_tokens = self.max_draft_len  # Current Lookahead only support linear tree
         self._check_fields()
 
     def calculate_speculative_resource(self):

@@ -73,7 +73,7 @@ def test_dynamic_spec_decode(enforce_single_worker,
 
     # Mock should_use_spec_decode to turn on/off spec decode dynamically.
     def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens,
-                                    max_draft_len):
+                                    max_total_draft_tokens):
         for req in requests:
             if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
                 continue

@@ -198,40 +198,42 @@ def test_should_use_spec_decode():
     assert drafter.should_use_spec_decode(active_requests,
                                           max_batch_size=8,
                                           max_num_tokens=4096 * 8,
-                                          max_draft_len=4)
+                                          max_total_draft_tokens=4)
 
     # Small batch size ON case: num_effective_requests = min(12, 5, very_large) = 5 <= 6 → True
     active_requests = [object()] * 12
     assert drafter.should_use_spec_decode(active_requests,
                                           max_batch_size=5,
                                           max_num_tokens=4096 * 8,
-                                          max_draft_len=4)
+                                          max_total_draft_tokens=4)
 
     # Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(12, 8, 5) = 5 <= 6 → True
     active_requests = [object()] * 12
     assert drafter.should_use_spec_decode(active_requests,
                                           max_batch_size=8,
                                           max_num_tokens=28,
-                                          max_draft_len=4)
+                                          max_total_draft_tokens=4)
 
     # Generic OFF case: num_effective_requests = min(12, 8, very_large) = 8 > 6 → False
     active_requests = [object()] * 12
     assert not drafter.should_use_spec_decode(active_requests,
                                               max_batch_size=8,
                                               max_num_tokens=4096 * 8,
-                                              max_draft_len=4)
+                                              max_total_draft_tokens=4)
 
     # Edge case - None active requests OFF case
     active_requests = []
     assert not drafter.should_use_spec_decode(active_requests,
                                               max_batch_size=8,
                                               max_num_tokens=4096 * 8,
-                                              max_draft_len=4)
+                                              max_total_draft_tokens=4)
 
     # Edge case - Token cap equals 0 OFF case: token_cap = 4 // (1+4) = 0 → min(12, 8, 0) = 0 <= 6 → False
     active_requests = [object()] * 12
-    assert not drafter.should_use_spec_decode(
-        active_requests, max_batch_size=8, max_num_tokens=4, max_draft_len=4)
+    assert not drafter.should_use_spec_decode(active_requests,
+                                              max_batch_size=8,
+                                              max_num_tokens=4,
+                                              max_total_draft_tokens=4)
 
 
 if __name__ == "__main__":

@@ -309,6 +309,7 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
         spec_metadata = MTPSpecMetadata(max_num_requests=32,
                                         spec_dec_mode=spec_config.spec_dec_mode,
                                         max_draft_len=mtp_num_modules,
+                                        max_total_draft_tokens=mtp_num_modules,
                                         mtp_num_modules=mtp_num_modules)
         spec_metadata.draft_tokens = draft_tokens

@@ -891,6 +892,7 @@ class TestMTPUpdateMTPHiddenStates(unittest.TestCase):
             max_num_requests=32,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_draft_len=num_nextn_predict_layers,
+            max_total_draft_tokens=num_nextn_predict_layers,
             mtp_num_modules=num_nextn_predict_layers,
             mtp_hidden_states_manager=spec_manager)
         spec_metadata.request_ids = request_ids

@@ -1386,6 +1388,7 @@ class TestMTPPrepareDrafterInputs(unittest.TestCase):
             max_num_requests=32,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_draft_len=num_nextn_predict_layers,
+            max_total_draft_tokens=num_nextn_predict_layers,
             mtp_num_modules=num_nextn_predict_layers,
             mtp_hidden_states_manager=spec_manager)
         spec_metadata.request_ids = request_ids