Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Synced 2026-01-13 22:18:36 +08:00

[refactor] Simplification of Speculative decoding configs (#5639)

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
Co-authored-by: wili-65535 <wili-65535@users.noreply.github.com>

Parent commit: 67a39dbd63
This commit: 2e3cf42e03
@ -68,7 +68,7 @@ docker run -d --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
-p 8000:8000 --gpus=all -e "TRTLLM_ENABLE_PDL=1" \
|
||||
-v /path/to/maverick:/config/models/maverick -v /path/to/eagle:/config/models/eagle \
|
||||
docker.io/<username>/tensorrt_llm:main sh \
|
||||
-c "echo -e 'enable_attention_dp: false\nenable_min_latency: true\nenable_autotuner: false\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle\n max_draft_len: 3\n pytorch_weights_path: /config/models/eagle\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
|
||||
-c "echo -e 'enable_attention_dp: false\nenable_min_latency: true\nenable_autotuner: false\ncuda_graph_config:\n max_batch_size: 8\nspeculative_config:\n decoding_type: Eagle\n max_draft_len: 3\n speculative_model_dir: /config/models/eagle\nkv_cache_config:\n enable_block_reuse: false' > c.yaml && \
|
||||
TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL=True \
|
||||
trtllm-serve /config/models/maverick \
|
||||
--host 0.0.0.0 --port 8000 \
|
||||
|
||||
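For reference, the renamed field can also be set through the Python LLM API rather than the inline YAML. The following is a minimal, hedged sketch; it assumes the `EagleDecodingConfig` and `KvCacheConfig` classes touched elsewhere in this diff are importable from `tensorrt_llm.llmapi`, and the model paths are placeholders, not real checkpoints.

# Hedged sketch: the same speculative_config as the YAML above, via the LLM API.
# `speculative_model_dir` is the post-refactor name of `pytorch_weights_path`.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig, KvCacheConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model_dir="/config/models/eagle",  # placeholder path
)
llm = LLM(
    model="/config/models/maverick",  # placeholder path
    speculative_config=spec_config,
    kv_cache_config=KvCacheConfig(enable_block_reuse=False),
)
print(llm.generate("Hello, my name is").outputs[0].text)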
@ -108,7 +108,7 @@ When using EAGLE-2, please enable `--eagle_use_dynamic_tree`, which indicates wh
|
||||
- In EagleNet2, each of the `N` output nodes of EagleNet1 is expanded, and each node contributes `N` new draft tokens, so this layer also produces `N * N` draft tokens in total; the top `N` of them are selected as the output of this layer.
|
||||
- And so on for the remaining EagleNet layers.
Finally, after `num_eagle_layer` EagleNets, `N + N * N * (num_eagle_layer - 1)` draft tokens are generated. We will rebuild the final tree based on all draft tokens and their scores. The final generated tree will have `min(N + N * N * (num_eagle_layer - 1), max_draft_tokens)` nodes.
|
||||
Finally, after `num_eagle_layer` EagleNets, `N + N * N * (num_eagle_layer - 1)` draft tokens are generated. We will rebuild the final tree based on all draft tokens and their scores. The final generated tree will have `min(N + N * N * (num_eagle_layer - 1), max_draft_len)` nodes.
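To make the counting concrete, here is a small illustrative helper (not part of this commit) that evaluates the formula with the renamed `max_draft_len` bound:

# Illustrative only: node count of the rebuilt EAGLE-2 draft tree.
def eagle2_final_tree_nodes(top_n: int, num_eagle_layer: int, max_draft_len: int) -> int:
    # N draft tokens from the first EagleNet, N * N from each of the remaining layers.
    total_draft_tokens = top_n + top_n * top_n * (num_eagle_layer - 1)
    return min(total_draft_tokens, max_draft_len)

# Example: N = 4, num_eagle_layer = 3, max_draft_len = 63 -> min(4 + 16 * 2, 63) = 36 nodes.
assert eagle2_final_tree_nodes(4, 3, 63) == 36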
@ -23,12 +23,12 @@ def main():
|
||||
model = "lmsys/vicuna-7b-v1.3"
|
||||
|
||||
# The end user can customize the eagle decoding configuration by specifying the
|
||||
# speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
|
||||
# speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
|
||||
# greedy_sampling,posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
|
||||
# with the EagleDecodingConfig class
|
||||
|
||||
speculative_config = EagleDecodingConfig(
|
||||
speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
|
||||
speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
|
||||
max_draft_len=63,
|
||||
num_eagle_layers=4,
|
||||
max_non_leaves_per_layer=10,
|
||||
|
||||
@ -23,12 +23,12 @@ def main():
|
||||
model = "lmsys/vicuna-7b-v1.3"
|
||||
|
||||
# The end user can customize the eagle decoding configuration by specifying the
|
||||
# speculative_model, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
|
||||
# speculative_model_dir, max_draft_len, num_eagle_layers, max_non_leaves_per_layer, eagle_choices
|
||||
# greedy_sampling,posterior_threshold, use_dynamic_tree and dynamic_tree_max_topK
|
||||
# with the EagleDecodingConfig class
|
||||
|
||||
speculative_config = EagleDecodingConfig(
|
||||
speculative_model="yuhuili/EAGLE-Vicuna-7B-v1.3",
|
||||
speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",
|
||||
max_draft_len=63,
|
||||
num_eagle_layers=4,
|
||||
max_non_leaves_per_layer=10,
|
||||
|
||||
@ -48,10 +48,10 @@ def run_medusa_decoding(use_modelopt_ckpt=False, model_dir=None):
|
||||
model = "lmsys/vicuna-7b-v1.3"
|
||||
|
||||
# The end user can customize the medusa decoding configuration by specifying the
|
||||
# speculative_model, max_draft_len, medusa heads num and medusa choices
|
||||
# speculative_model_dir, max_draft_len, medusa heads num and medusa choices
|
||||
# with the MedusaDecodingConfig class
|
||||
speculative_config = MedusaDecodingConfig(
|
||||
speculative_model="FasterDecoding/medusa-vicuna-7b-v1.3",
|
||||
speculative_model_dir="FasterDecoding/medusa-vicuna-7b-v1.3",
|
||||
max_draft_len=63,
|
||||
num_medusa_heads=4,
|
||||
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
|
||||
|
||||
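Across these example scripts the only user-visible change is the constructor-argument rename from `speculative_model` to `speculative_model_dir`. A hedged sketch of the updated calls, assuming the config classes are importable from `tensorrt_llm.llmapi` as these examples do (choice lists omitted for brevity):

# Post-refactor constructor calls, condensed from the examples above.
from tensorrt_llm.llmapi import EagleDecodingConfig, MedusaDecodingConfig

eagle_config = EagleDecodingConfig(
    speculative_model_dir="yuhuili/EAGLE-Vicuna-7B-v1.3",  # formerly `speculative_model`
    max_draft_len=63,
    num_eagle_layers=4,
    max_non_leaves_per_layer=10,
)
medusa_config = MedusaDecodingConfig(
    speculative_model_dir="FasterDecoding/medusa-vicuna-7b-v1.3",  # formerly `speculative_model`
    max_draft_len=63,
    num_medusa_heads=4,
)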
@ -35,7 +35,7 @@ def run_MTP(model: Optional[str] = None):
|
||||
def run_Eagle3():
|
||||
spec_config = EagleDecodingConfig(
|
||||
max_draft_len=3,
|
||||
pytorch_weights_path="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||
speculative_model_dir="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||
eagle3_one_model=True)
|
||||
|
||||
llm = LLM(
|
||||
@ -50,7 +50,7 @@ def run_Eagle3():
|
||||
|
||||
def run_ngram():
|
||||
spec_config = NGramDecodingConfig(
|
||||
prompt_lookup_num_tokens=3,
|
||||
max_draft_len=3,
|
||||
max_matching_ngram_size=3,
|
||||
is_keep_all=True,
|
||||
is_use_oldest=True,
|
||||
|
||||
@ -169,15 +169,15 @@ def setup_llm(args):
|
||||
elif spec_decode_algo == "EAGLE3":
|
||||
spec_config = EagleDecodingConfig(
|
||||
max_draft_len=args.spec_decode_nextn,
|
||||
pytorch_weights_path=args.draft_model_dir,
|
||||
speculative_model_dir=args.draft_model_dir,
|
||||
eagle3_one_model=args.use_one_model)
|
||||
elif spec_decode_algo == "DRAFT_TARGET":
|
||||
spec_config = DraftTargetDecodingConfig(
|
||||
max_draft_len=args.spec_decode_nextn,
|
||||
pytorch_weights_path=args.draft_model_dir)
|
||||
speculative_model_dir=args.draft_model_dir)
|
||||
elif spec_decode_algo == "NGRAM":
|
||||
spec_config = NGramDecodingConfig(
|
||||
prompt_lookup_num_tokens=args.spec_decode_nextn,
|
||||
max_draft_len=args.spec_decode_nextn,
|
||||
max_matching_ngram_size=args.max_matching_ngram_size,
|
||||
is_keep_all=True,
|
||||
is_use_oldest=True,
|
||||
|
||||
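The branches above now share a uniform surface: every config accepts `max_draft_len`, and the draft-model-based ones accept `speculative_model_dir`. A hedged, illustrative dispatch helper built on that assumption (the helper itself is not part of the commit; imports are assumed available from `tensorrt_llm.llmapi` as in this example script):

from typing import Optional

from tensorrt_llm.llmapi import (DraftTargetDecodingConfig, EagleDecodingConfig,
                                 NGramDecodingConfig)


def build_spec_config(algo: str, max_draft_len: int,
                      draft_model_dir: Optional[str] = None):
    # Illustrative dispatch over the unified post-refactor fields.
    if algo == "EAGLE3":
        return EagleDecodingConfig(max_draft_len=max_draft_len,
                                   speculative_model_dir=draft_model_dir,
                                   eagle3_one_model=True)
    if algo == "DRAFT_TARGET":
        return DraftTargetDecodingConfig(max_draft_len=max_draft_len,
                                         speculative_model_dir=draft_model_dir)
    if algo == "NGRAM":
        return NGramDecodingConfig(max_draft_len=max_draft_len,
                                   max_matching_ngram_size=3,
                                   is_keep_all=True,
                                   is_use_oldest=True)
    raise ValueError(f"Unknown speculative decoding algorithm: {algo}")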
@ -261,8 +261,8 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
|
||||
|
||||
max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
|
||||
# some derivative properties
|
||||
max_draft_tokens = (
|
||||
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens
|
||||
max_draft_len = (
|
||||
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len
|
||||
)
|
||||
|
||||
# initialize model engine
|
||||
@ -299,7 +299,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
|
||||
# correctly for models as needed.
|
||||
sampler_args = TorchSampler.Args(
|
||||
max_seq_len=ad_config.max_seq_len,
|
||||
max_draft_tokens=max_draft_tokens,
|
||||
max_draft_len=max_draft_len,
|
||||
max_num_sequences=max_num_sequences,
|
||||
max_beam_width=executor_config.max_beam_width,
|
||||
enable_mixed_sampler=ad_config.enable_mixed_sampler,
|
||||
@ -317,7 +317,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
|
||||
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
|
||||
max_input_len=ad_config.max_input_len,
|
||||
max_batch_size=ad_config.max_batch_size,
|
||||
max_draft_tokens=max_draft_tokens,
|
||||
max_draft_len=max_draft_len,
|
||||
max_beam_width=ad_config.max_beam_width,
|
||||
)
|
||||
return py_executor
|
||||
|
||||
@ -70,7 +70,7 @@ class ModelConfig(Generic[TConfig]):
|
||||
# to support mixed quantization.
|
||||
skip_create_weights_in_init: bool = False
|
||||
|
||||
spec_config: Optional["SpecConfig"] = None
|
||||
spec_config: Optional["DecodingBaseConfig"] = None
|
||||
lora_config: Optional["LoraConfig"] = None
|
||||
|
||||
is_generation: bool = True
|
||||
|
||||
@ -340,7 +340,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
|
||||
model_config, 'spec_config', None
|
||||
) and model_config.spec_config.spec_dec_mode.use_one_engine():
|
||||
draft_config = ModelConfig.from_pretrained(
|
||||
model_config.spec_config.draft_model_path,
|
||||
model_config.spec_config.speculative_model_dir,
|
||||
trust_remote_code=True,
|
||||
attn_backend=model_config.attn_backend,
|
||||
moe_backend=model_config.moe_backend,
|
||||
|
||||
@ -157,10 +157,10 @@ class KvCacheCreator:
|
||||
if not pytorch_backend_config.disable_overlap_scheduler:
|
||||
num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
|
||||
if spec_cfg is not None:
|
||||
num_extra_tokens_per_seq += spec_cfg.max_draft_tokens
|
||||
num_extra_tokens_per_seq += spec_cfg.max_draft_len
|
||||
|
||||
if spec_cfg is not None:
|
||||
num_extra_tokens_per_seq += spec_cfg.max_draft_tokens
|
||||
num_extra_tokens_per_seq += spec_cfg.max_draft_len
|
||||
num_extra_tokens_per_seq += spec_cfg.num_extra_kv_tokens
|
||||
for req in self._dummy_reqs:
|
||||
num_req_tokens = len(req.input_token_ids) + num_extra_tokens_per_seq
|
||||
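The token budgeting above now reads `max_draft_len` straight from the config. A hedged standalone restatement of that accounting (the function is illustrative; the initial value of the counter falls outside this hunk and is passed in as `base`):

def extra_tokens_per_seq(base: int, overlap_scheduler_enabled: bool,
                         max_draft_len: int = 0, num_extra_kv_tokens: int = 0) -> int:
    # Pass max_draft_len / num_extra_kv_tokens as 0 when no speculative config is set.
    extra = base
    if overlap_scheduler_enabled:
        extra += 1 + max_draft_len
    extra += max_draft_len + num_extra_kv_tokens
    return extra

# Example: overlap scheduler on, max_draft_len = 3, num_extra_kv_tokens = 2 -> base + 9.
assert extra_tokens_per_seq(0, True, 3, 2) == 9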
@ -538,7 +538,7 @@ def create_py_executor_instance(
|
||||
disable_overlap_scheduler,
|
||||
max_batch_size=executor_config.max_batch_size,
|
||||
max_beam_width=executor_config.max_beam_width,
|
||||
max_draft_tokens=spec_config.max_draft_tokens
|
||||
max_draft_len=spec_config.max_draft_len
|
||||
if spec_config is not None else 0,
|
||||
kv_cache_transceiver=kv_cache_transceiver,
|
||||
draft_model_engine=draft_model_engine,
|
||||
@ -549,11 +549,11 @@ def create_py_executor_instance(
|
||||
def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
|
||||
*, max_seq_len: int, enable_mixed_sampler: bool):
|
||||
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
|
||||
max_draft_tokens = (0 if executor_config.speculative_config is None else
|
||||
executor_config.speculative_config.max_draft_tokens)
|
||||
max_draft_len = (0 if executor_config.speculative_config is None else
|
||||
executor_config.speculative_config.max_draft_len)
|
||||
return TorchSampler.Args(
|
||||
max_seq_len=max_seq_len,
|
||||
max_draft_tokens=max_draft_tokens,
|
||||
max_draft_len=max_draft_len,
|
||||
max_num_sequences=max_num_sequences,
|
||||
max_beam_width=executor_config.max_beam_width,
|
||||
enable_mixed_sampler=enable_mixed_sampler,
|
||||
|
||||
@ -8,7 +8,6 @@ from ...llmapi.llm_args import LoadFormat
|
||||
from ...logger import logger
|
||||
from ...mapping import Mapping
|
||||
from ..model_config import MoeLoadBalancerConfig
|
||||
from ..speculative import SpecConfig
|
||||
from .resource_manager import BaseResourceManager
|
||||
|
||||
|
||||
@ -110,7 +109,7 @@ def update_executor_config(
|
||||
pytorch_backend_config: Optional[PyTorchConfig] = None,
|
||||
mapping: Optional[Mapping] = None,
|
||||
build_config: Optional[BuildConfig] = None,
|
||||
speculative_config: Optional[SpecConfig] = None,
|
||||
speculative_config: Optional["DecodingBaseConfig"] = None,
|
||||
hf_model_dir: Optional[str] = None,
|
||||
max_input_len: Optional[int] = None,
|
||||
max_seq_len: Optional[int] = None):
|
||||
|
||||
@ -53,7 +53,7 @@ class DecodingCUDAGraphRunner:
|
||||
# [CUDA graph spec decode padding]
|
||||
# We pad input IDs/position IDs to the maximum draft length (token per request).
|
||||
# We're forced to do this because we cannot reallocate inputs over many graph runs.
|
||||
token_per_request = spec_metadata.max_draft_tokens + 1 if spec_metadata is not None else 1
|
||||
token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1
|
||||
|
||||
# Using ones instead of zeros prevents NaNs in e.g. Deepseek
|
||||
self.input_ids = torch.ones((batch_size * token_per_request, ),
|
||||
|
||||
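The padding rule is now stated in terms of `max_draft_len`. A hedged sketch of the buffer shape it implies, assuming one target token plus up to `max_draft_len` draft tokens per request:

import torch


def allocate_padded_input_ids(batch_size: int, max_draft_len: int,
                              device: str = "cuda") -> torch.Tensor:
    # Pad every request to the same width so one captured CUDA graph can be replayed.
    tokens_per_request = max_draft_len + 1
    # Ones rather than zeros, mirroring the comment above about avoiding NaNs.
    return torch.ones((batch_size * tokens_per_request, ), dtype=torch.int, device=device)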
@ -51,7 +51,7 @@ from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode,
|
||||
run_concurrently, timing)
|
||||
from ..modules.fused_moe.moe_load_balancer import (
|
||||
MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer)
|
||||
from ..speculative import SpecConfig, SpecMetadata, get_spec_metadata
|
||||
from ..speculative import SpecMetadata, get_spec_metadata
|
||||
from ..utils import (get_model_extra_attrs, set_torch_compiling,
|
||||
with_model_extra_attrs)
|
||||
from .config import LoadFormat, PyTorchConfig
|
||||
@ -353,7 +353,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
mapping: Optional[Mapping] = None,
|
||||
attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
|
||||
dist: Optional[MPIDist] = None,
|
||||
spec_config: Optional[SpecConfig] = None,
|
||||
spec_config: Optional["DecodingBaseConfig"] = None,
|
||||
guided_decoding_config: Optional[GuidedDecodingConfig] = None,
|
||||
lora_config: Optional[LoraConfig] = None,
|
||||
is_draft_model: bool = False,
|
||||
@ -456,7 +456,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
if self.is_spec_decode:
|
||||
self.spec_metadata = None
|
||||
self.spec_config.update_from_model_config(self.model.config)
|
||||
max_num_draft_tokens = self.spec_config.max_draft_tokens * batch_size
|
||||
max_num_draft_tokens = self.spec_config.max_draft_len * batch_size
|
||||
self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
|
||||
dtype=torch.int,
|
||||
device='cuda')
|
||||
@ -472,7 +472,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
device='cuda')
|
||||
self.without_logits = self.spec_config.spec_dec_mode.without_logits(
|
||||
)
|
||||
self.max_draft_len = spec_config.max_draft_tokens
|
||||
self.max_draft_len = spec_config.max_draft_len
|
||||
else:
|
||||
self.without_logits = False
|
||||
self.max_draft_len = 0
|
||||
@ -858,6 +858,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
if no_cache:
|
||||
return get_spec_metadata(
|
||||
self.spec_config,
|
||||
self.model.config,
|
||||
self.batch_size,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
spec_resource_manager=spec_resource_manager,
|
||||
@ -867,6 +868,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
return self.spec_metadata
|
||||
self.spec_metadata = get_spec_metadata(
|
||||
self.spec_config,
|
||||
self.model.config,
|
||||
self.batch_size,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
spec_resource_manager=spec_resource_manager,
|
||||
@ -951,7 +953,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
def _maybe_get_cuda_graph(
|
||||
self,
|
||||
batch: ScheduledRequests,
|
||||
spec_config: Optional[SpecConfig] = None
|
||||
spec_config: Optional["DecodingBaseConfig"] = None
|
||||
) -> Optional[DecodingCUDAGraphRunner]:
|
||||
"""
|
||||
Get a CUDA graph runner or return None (e.g. if CUDA graphs are disabled
|
||||
@ -961,7 +963,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
if ExpertStatistic.set_iter(self.iter_counter):
|
||||
return None
|
||||
|
||||
spec_max_draft_tokens = spec_config.max_draft_tokens if self.is_spec_decode else 0
|
||||
spec_max_draft_tokens = spec_config.max_draft_len if self.is_spec_decode else 0
|
||||
can_run_cuda_graph = batch.can_run_cuda_graph
|
||||
batch_size = len(batch.generation_requests)
|
||||
if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
|
||||
@ -1078,7 +1080,8 @@ class PyTorchModelEngine(ModelEngine):
|
||||
|
||||
if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights(
|
||||
):
|
||||
weights = load_weights(self.spec_config.draft_model_path)
|
||||
weights = load_weights(
|
||||
self.spec_config.speculative_model_dir)
|
||||
model.load_draft_weights(weights)
|
||||
|
||||
elif load_format == LoadFormat.DUMMY:
|
||||
@ -1261,9 +1264,8 @@ class PyTorchModelEngine(ModelEngine):
|
||||
extend_requests += extend_dummy_requests
|
||||
|
||||
if not self._disable_overlap_scheduler and self.is_spec_decode:
|
||||
spec_dec_mode = self.spec_config.spec_dec_mode
|
||||
assert spec_dec_mode.support_overlap_scheduler(
|
||||
), f"{self.spec_config.spec_dec_name} does not support overlap scheduler"
|
||||
assert self.spec_config.spec_dec_mode.support_overlap_scheduler(
|
||||
), f"{self.spec_config.decoding_type} does not support overlap scheduler"
|
||||
|
||||
# will contain previous batch indices of generation requests
|
||||
previous_batch_indices = []
|
||||
@ -2078,7 +2080,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(),
|
||||
spec_metadata.is_spec_dec_tree,
|
||||
spec_metadata.is_spec_dec_dynamic_tree,
|
||||
spec_metadata.max_draft_tokens)
|
||||
spec_metadata.max_draft_len)
|
||||
else:
|
||||
spec_metadata = None
|
||||
|
||||
|
||||
@ -173,7 +173,7 @@ class PyExecutor:
|
||||
max_input_len: int = 2048,
|
||||
max_batch_size: int = 8,
|
||||
max_beam_width: int = 1,
|
||||
max_draft_tokens: int = 0,
|
||||
max_draft_len: int = 0,
|
||||
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
|
||||
draft_model_engine: Optional[ModelEngine] = None,
|
||||
garbage_collection_gen0_threshold: Optional[int] = None,
|
||||
@ -207,7 +207,7 @@ class PyExecutor:
|
||||
self.active = True
|
||||
self.next_req_id = max_batch_size # The first max_batch_size request IDs are reserved for dummy requests
|
||||
self.max_beam_width = max_beam_width
|
||||
self.max_draft_tokens = max_draft_tokens
|
||||
self.max_draft_len = max_draft_len
|
||||
self.print_log = model_engine.pytorch_backend_config.print_iter_log
|
||||
self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
|
||||
self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
|
||||
@ -979,7 +979,7 @@ class PyExecutor:
|
||||
LlmRequestState.DISAGG_GENERATION_INIT):
|
||||
continue
|
||||
req.py_last_draft_tokens = req.py_draft_tokens
|
||||
max_draft_len = self.model_engine.spec_config.max_draft_tokens
|
||||
max_draft_len = self.model_engine.spec_config.max_draft_len
|
||||
|
||||
if max_draft_len > 0:
|
||||
req.py_draft_tokens = [0] * max_draft_len
|
||||
@ -1523,7 +1523,7 @@ class PyExecutor:
|
||||
request_ids=[0],
|
||||
is_gen=not self.has_context_request,
|
||||
prepare_resource=not self.has_context_request,
|
||||
max_num_draft_tokens=self.max_draft_tokens,
|
||||
max_num_draft_tokens=self.max_draft_len,
|
||||
)[0]
|
||||
llm_request.is_attention_dp_dummy = True
|
||||
spec_resource_manager = self.resource_manager.get_resource_manager(
|
||||
@ -1871,15 +1871,15 @@ class PyExecutor:
|
||||
# this? Just needs proper kernel support.
|
||||
def _pad_to_max_draft_tokens():
|
||||
for req in scheduled_requests.generation_requests:
|
||||
max_draft_tokens = self.max_draft_tokens
|
||||
max_draft_len = self.max_draft_len
|
||||
num_draft_tokens = len(req.py_draft_tokens)
|
||||
req.py_draft_tokens.extend(
|
||||
0 for _ in range(max_draft_tokens - num_draft_tokens))
|
||||
0 for _ in range(max_draft_len - num_draft_tokens))
|
||||
|
||||
draft_batch.generation_requests = draft_batch.context_requests + draft_batch.generation_requests
|
||||
draft_batch.context_requests = []
|
||||
|
||||
for i in range(self.max_draft_tokens - 1):
|
||||
for i in range(self.max_draft_len - 1):
|
||||
if len(draft_batch.generation_requests) == 0:
|
||||
break
@ -247,10 +247,10 @@ def create_py_executor(
|
||||
draft_spec_config = copy.copy(spec_config)
|
||||
# The draft model won't have any draft tokens attached to
|
||||
# generation requests when we invoke it autoregressively
|
||||
draft_spec_config.max_draft_tokens = 0
|
||||
draft_spec_config.max_draft_len = 0
|
||||
|
||||
draft_model_engine = PyTorchModelEngine(
|
||||
model_path=spec_config.draft_model_path,
|
||||
model_path=spec_config.speculative_model_dir,
|
||||
pytorch_backend_config=pytorch_backend_config,
|
||||
batch_size=executor_config.max_batch_size,
|
||||
max_beam_width=executor_config.max_beam_width,
|
||||
@ -276,11 +276,11 @@ def create_py_executor(
|
||||
if not pytorch_backend_config.disable_overlap_scheduler:
|
||||
max_seq_len = model_engine.max_seq_len + 1
|
||||
if spec_config is not None:
|
||||
max_seq_len += spec_config.max_draft_tokens
|
||||
max_seq_len += spec_config.max_draft_len
|
||||
|
||||
if spec_config is not None:
|
||||
max_seq_len += spec_config.num_extra_kv_tokens
|
||||
max_seq_len += spec_config.max_draft_tokens
|
||||
max_seq_len += spec_config.max_draft_len
|
||||
|
||||
executor_config.max_seq_len = max_seq_len
|
||||
executor_config.max_num_tokens = model_engine.max_num_tokens
|
||||
|
||||
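Putting the two adjustments together: with the overlap scheduler enabled and an EAGLE3 one-model config (where `num_extra_kv_tokens = max_draft_len - 1` in this codebase), the effective sequence budget grows as in this hedged worked example (the nesting of the two blocks is inferred from the surrounding code):

# Worked example, illustrative values: engine max_seq_len = 4096, max_draft_len = 3.
engine_max_seq_len = 4096
max_draft_len = 3
num_extra_kv_tokens = max_draft_len - 1  # EAGLE3 one-model convention in this diff

max_seq_len = engine_max_seq_len + 1     # overlap scheduler reserves one extra token
max_seq_len += max_draft_len             # ...plus the draft tokens it may carry over
max_seq_len += num_extra_kv_tokens       # speculative KV padding
max_seq_len += max_draft_len             # room for the draft tokens themselves
assert max_seq_len == 4105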
@ -2,7 +2,7 @@ import enum
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -22,9 +22,6 @@ if ENABLE_MULTI_DEVICE:
|
||||
|
||||
from tensorrt_llm._utils import mpi_comm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..speculative.interface import SpecConfig
|
||||
|
||||
BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager
|
||||
KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
|
||||
KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig
|
||||
@ -83,7 +80,7 @@ class BaseResourceManager(ABC):
|
||||
def get_pp_layers(
|
||||
num_layers: int,
|
||||
mapping: Mapping,
|
||||
spec_config: Optional["SpecConfig"] = None,
|
||||
spec_config: Optional["DecodingBaseConfig"] = None,
|
||||
layer_mask: Optional[List[bool]] = None,
|
||||
) -> Tuple[List[int], int]:
|
||||
from ..speculative.utils import get_num_spec_layers
|
||||
@ -127,7 +124,7 @@ class KVCacheManager(BaseResourceManager):
|
||||
max_batch_size: int,
|
||||
mapping: Mapping,
|
||||
dtype: DataType = DataType.HALF,
|
||||
spec_config: Optional["SpecConfig"] = None,
|
||||
spec_config: Optional["DecodingBaseConfig"] = None,
|
||||
layer_mask: Optional[List[bool]] = None,
|
||||
max_num_tokens: int = 8192,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
@ -902,7 +899,7 @@ class MambaHybridCacheManager(KVCacheManager, MambaCacheManager):
|
||||
max_batch_size: int,
|
||||
mapping: Mapping,
|
||||
dtype: DataType = DataType.HALF,
|
||||
spec_config: Optional["SpecConfig"] = None,
|
||||
spec_config: Optional["DecodingBaseConfig"] = None,
|
||||
) -> None:
|
||||
|
||||
# mamba hybrid cache requires block reuse to be disabled in KV cache config
|
||||
|
||||
@ -220,7 +220,7 @@ class TorchSampler(Sampler):
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class Args:
|
||||
max_seq_len: int
|
||||
max_draft_tokens: int
|
||||
max_draft_len: int
|
||||
max_num_sequences: int
|
||||
max_beam_width: int
|
||||
enable_mixed_sampler: bool
|
||||
@ -228,7 +228,7 @@ class TorchSampler(Sampler):
|
||||
def __init__(self, args: Args):
|
||||
self.max_seq_len = args.max_seq_len
|
||||
self.enable_mixed_sampler = args.enable_mixed_sampler
|
||||
self.max_tokens = args.max_draft_tokens + 1
|
||||
self.max_tokens = args.max_draft_len + 1
|
||||
assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1"
|
||||
self.num_seq_slots = args.max_num_sequences
|
||||
|
||||
|
||||
@ -1,27 +1,20 @@
|
||||
from .draft_target import DraftTargetConfig
|
||||
from .eagle3 import Eagle3Config, Eagle3SpecMetadata
|
||||
from .eagle3 import Eagle3SpecMetadata
|
||||
from .interface import SpecConfig, SpecMetadata
|
||||
from .mtp import MTPConfig, MTPEagleWorker, MTPSpecMetadata, MTPWorker
|
||||
from .ngram import NGramConfig, NGramDrafter, NGramPoolManager
|
||||
from .user_provided import UserProvidedConfig
|
||||
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
|
||||
from .ngram import NGramDrafter, NGramPoolManager
|
||||
from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_drafter,
|
||||
get_spec_metadata, get_spec_resource_manager,
|
||||
get_spec_worker)
|
||||
|
||||
__all__ = [
|
||||
"DraftTargetConfig",
|
||||
"Eagle3Config",
|
||||
"Eagle3SpecMetadata",
|
||||
"MTPConfig",
|
||||
"MTPEagleWorker",
|
||||
"MTPSpecMetadata",
|
||||
"MTPWorker",
|
||||
"NGramConfig",
|
||||
"NGramDrafter",
|
||||
"NGramPoolManager",
|
||||
"SpecConfig",
|
||||
"SpecMetadata",
|
||||
"UserProvidedConfig",
|
||||
"get_num_spec_layers",
|
||||
"get_spec_decoder",
|
||||
"get_spec_drafter",
|
||||
|
||||
@ -1,35 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class DraftTargetConfig(SpecConfig):
|
||||
spec_dec_name: str = "DRAFT_TARGET"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.draft_model_path is None:
|
||||
raise ValueError("Path to Draft weights must be specified.")
|
||||
|
||||
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
|
||||
self.spec_dec_name)
|
||||
self.num_extra_kv_tokens = 0
|
||||
|
||||
def update_from_model_config(self, model_config):
|
||||
pass
|
||||
|
||||
def get_draft_model_prompt(self,
|
||||
input_tokens: torch.Tensor) -> torch.Tensor:
|
||||
return input_tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
class DraftTargetSpecMetadata(SpecMetadata):
|
||||
|
||||
def __post_init__(self):
|
||||
pass
|
||||
|
||||
def prepare(self):
|
||||
pass
|
||||
@ -4,7 +4,6 @@ from typing import List, Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from ..attention_backend import AttentionMetadata
|
||||
@ -12,42 +11,10 @@ from ..pyexecutor.llm_request import LlmRequest
|
||||
from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
|
||||
from ..pyexecutor.sampler import TorchSampler
|
||||
from ..pyexecutor.scheduler import ScheduledRequests
|
||||
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
|
||||
from .interface import SpecMetadata
|
||||
from .mtp import MTPSampler
|
||||
|
||||
|
||||
@dataclass
|
||||
class Eagle3Config(SpecConfig):
|
||||
spec_dec_name: str = "EAGLE3"
|
||||
num_layers: int = 0
|
||||
hidden_size: int = 0
|
||||
eagle3_one_model: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.draft_model_path is None:
|
||||
raise ValueError("Path to EAGLE3 weights must be specified.")
|
||||
|
||||
if self.eagle3_one_model:
|
||||
self.spec_dec_mode = SpeculativeDecodingMode.EAGLE3_ONE_MODEL
|
||||
self.num_extra_kv_tokens = self.max_draft_tokens - 1
|
||||
else:
|
||||
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
|
||||
self.spec_dec_name)
|
||||
logger.info(f"EAGLE3 Config: {self}")
|
||||
|
||||
def update_from_model_config(self, model_config):
|
||||
self.num_layers = model_config.num_hidden_layers
|
||||
self.hidden_size = model_config.hidden_size
|
||||
self.dtype = model_config.torch_dtype
|
||||
|
||||
def get_draft_model_prompt(self,
|
||||
input_tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Eagle3 always throws away the first token when processing draft inputs
|
||||
"""
|
||||
return input_tokens[1:]
|
||||
|
||||
|
||||
class Eagle3ResourceManager(BaseResourceManager):
|
||||
"""
|
||||
Eagle3 needs to save the hidden states for the draft model. When using
|
||||
@ -55,11 +22,11 @@ class Eagle3ResourceManager(BaseResourceManager):
|
||||
and one for the draft model. Use this class to manage the hidden states.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Eagle3Config, dtype: torch.dtype,
|
||||
def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype,
|
||||
hidden_size: int, max_num_requests: int, max_seq_len: int,
|
||||
max_num_tokens: int):
|
||||
self.dtype = dtype
|
||||
self.max_draft_tokens = config.max_draft_tokens
|
||||
self.max_draft_len = config.max_draft_len
|
||||
self.hidden_size = hidden_size
|
||||
self.max_num_requests = max_num_requests
|
||||
self.max_seq_len = max_seq_len
|
||||
@ -268,7 +235,7 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
|
||||
pin_memory=True)
|
||||
self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
|
||||
non_blocking=True)
|
||||
self.num_tokens -= (self.num_generations) * self.max_draft_tokens
|
||||
self.num_tokens -= (self.num_generations) * self.max_draft_len
|
||||
|
||||
def maybe_capture_hidden_states(
|
||||
self,
|
||||
@ -288,15 +255,15 @@ class Eagle3OneModelSpecMetadata(SpecMetadata):
|
||||
class Eagle3OneModelSampler(MTPSampler):
|
||||
|
||||
def __init__(self, args: TorchSampler.Args):
|
||||
super().__init__(args, nextn=args.max_draft_tokens)
|
||||
super().__init__(args, nextn=args.max_draft_len)
|
||||
|
||||
|
||||
class Eagle3OneModelWorker(nn.Module):
|
||||
|
||||
def __init__(self, spec_config: Eagle3Config, mapping: Mapping):
|
||||
def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
|
||||
super().__init__()
|
||||
self.spec_config = spec_config
|
||||
self.max_draft_tokens = self.spec_config.max_draft_tokens
|
||||
self.max_draft_len = self.spec_config.max_draft_len
|
||||
self.mapping = mapping
|
||||
|
||||
@torch.compile(mode="max-autotune-no-cudagraphs")
|
||||
@ -333,7 +300,7 @@ class Eagle3OneModelWorker(nn.Module):
|
||||
|
||||
# Predict draft tokens
|
||||
next_draft_tokens = []
|
||||
for i in range(self.max_draft_tokens):
|
||||
for i in range(self.max_draft_len):
|
||||
hidden_states, hidden_states_to_save = draft_model.model(**inputs)
|
||||
|
||||
# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step.
|
||||
@ -344,7 +311,7 @@ class Eagle3OneModelWorker(nn.Module):
|
||||
attn_metadata.use_spec_decoding = False
|
||||
if i == 0:
|
||||
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
|
||||
(self.max_draft_tokens + 1)).long()
|
||||
(self.max_draft_len + 1)).long()
|
||||
gather_ids_gen = (start_ids_gen +
|
||||
num_accepted_tokens[num_contexts:] - 1 +
|
||||
attn_metadata.num_ctx_tokens)
|
||||
@ -374,8 +341,7 @@ class Eagle3OneModelWorker(nn.Module):
|
||||
# update kv_lens_cuda
|
||||
if hasattr(attn_metadata, 'kv_lens_cuda'):
|
||||
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
|
||||
self.max_draft_tokens -
|
||||
num_accepted_tokens[num_contexts:])
|
||||
self.max_draft_len - num_accepted_tokens[num_contexts:])
|
||||
attn_metadata.kv_lens_cuda[:num_contexts] += 1
|
||||
elif hasattr(attn_metadata, 'kv_lens_cuda'):
|
||||
attn_metadata.kv_lens_cuda[:batch_size] += 1
|
||||
@ -428,7 +394,7 @@ class Eagle3OneModelWorker(nn.Module):
|
||||
logits = logits.unsqueeze(0)
|
||||
|
||||
# The return buffer
|
||||
accepted_tokens = torch.empty((batch_size, (self.max_draft_tokens + 1)),
|
||||
accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
num_accepted_tokens = torch.ones(batch_size,
|
||||
@ -442,13 +408,13 @@ class Eagle3OneModelWorker(nn.Module):
|
||||
|
||||
# generation
|
||||
gen_target_tokens = target_tokens[num_contexts:].reshape(
|
||||
num_gens, self.max_draft_tokens + 1)
|
||||
num_gens, self.max_draft_len + 1)
|
||||
accepted_tokens[num_contexts:, :] = gen_target_tokens
|
||||
draft_tokens = spec_metadata.draft_tokens.reshape(
|
||||
num_gens, self.max_draft_tokens)
|
||||
num_accepted_tokens[num_contexts:] += torch.cumprod((
|
||||
draft_tokens == gen_target_tokens[:, :self.max_draft_tokens]).int(),
|
||||
dim=-1).sum(1)
|
||||
num_gens, self.max_draft_len)
|
||||
num_accepted_tokens[num_contexts:] += torch.cumprod(
|
||||
(draft_tokens == gen_target_tokens[:, :self.max_draft_len]).int(),
|
||||
dim=-1).sum(1)
|
||||
return accepted_tokens, num_accepted_tokens
|
||||
|
||||
def draft_decoder(
|
||||
@ -468,7 +434,7 @@ class Eagle3OneModelWorker(nn.Module):
|
||||
|
||||
Returns:
|
||||
draft_tokens: torch.Tensor
|
||||
[batch_size * max_draft_tokens]
|
||||
[batch_size * max_draft_len]
|
||||
Draft token ids. Flattened.
|
||||
'''
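The `torch.cumprod` expression above counts, per sequence, how many leading draft tokens match the sampled target tokens; one is added for the real token that is always accepted. A self-contained toy illustration of the same computation:

import torch

max_draft_len = 3
draft_tokens = torch.tensor([[5, 7, 9],
                             [5, 8, 9]])
# Target tokens have shape [num_gens, max_draft_len + 1].
gen_target_tokens = torch.tensor([[5, 7, 2, 4],
                                  [5, 7, 9, 4]])

num_accepted = 1 + torch.cumprod(
    (draft_tokens == gen_target_tokens[:, :max_draft_len]).int(), dim=-1).sum(dim=1)
print(num_accepted)  # tensor([3, 2]): the rows accept 2 and 1 draft tokens respectively.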
@ -7,7 +7,6 @@ import torch
|
||||
|
||||
from ..._utils import get_sm_version
|
||||
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
|
||||
from ..model_config import TConfig
|
||||
|
||||
|
||||
class SpeculativeDecodingMode(IntEnum):
|
||||
@ -110,32 +109,11 @@ class SpeculativeDecodingMode(IntEnum):
|
||||
class SpecConfig:
|
||||
"""
|
||||
Configuration for speculative decoding.
|
||||
This class is deprecated, but thread-leak of pytest raises flaky error if removing it.
|
||||
TODO: remove this class safely.
|
||||
"""
|
||||
# The name of speculative decoding.
|
||||
spec_dec_name = None
|
||||
# The mode of speculative decoding.
|
||||
spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
|
||||
# The max number of draft tokens
|
||||
max_draft_tokens: int = 1024
|
||||
# The path to the draft model
|
||||
draft_model_path: Optional[str] = None
|
||||
# The number of extra kv tokens
|
||||
num_extra_kv_tokens: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
|
||||
self.spec_dec_name)
|
||||
|
||||
def update_from_model_config(self, model_config: TConfig):
|
||||
pass
|
||||
|
||||
def get_draft_model_prompt(self,
|
||||
input_tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Override for spec dec modes that need to preprocess prompt
|
||||
tokens before passing them to the draft model.
|
||||
"""
|
||||
return input_tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -146,7 +124,7 @@ class SpecMetadata:
|
||||
# The max number of requests in a single batch.
|
||||
max_num_requests: int
|
||||
# The max number of draft tokens.
|
||||
max_draft_tokens: int
|
||||
max_draft_len: int
|
||||
# The number of gen-phase sequences in the batch.
|
||||
num_generations: int = 0
|
||||
# Whether CUDA graph is enabled.
|
||||
@ -180,7 +158,7 @@ class SpecMetadata:
|
||||
# Some speculative decoding methods need to use different kv lengths for the
|
||||
# draft/target layers. But KVCacheManager can only support kv caches with the
|
||||
# same kv lengths for different layers. Add extra kv token in kv cache manager
|
||||
# to haddle this issue.
|
||||
# to handle this issue.
|
||||
num_extra_kv_tokens: Optional[int] = 0 # Number of layers in target model
|
||||
# The number of layers
|
||||
num_layers: int = 0
|
||||
|
||||
@ -10,7 +10,7 @@ from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
|
||||
from ..pyexecutor.sampler import (SampleState, SampleStateTensors, TorchSampler,
|
||||
add_token, int_tensor)
|
||||
from ..pyexecutor.scheduler import ScheduledRequests
|
||||
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
|
||||
from .interface import SpecMetadata
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -25,52 +25,10 @@ class SampleStateMTP(SampleState):
|
||||
host: SampleStateTensorsMTP
|
||||
|
||||
|
||||
@dataclass
|
||||
class MTPConfig(SpecConfig):
|
||||
"""
|
||||
Configuration for MTP.
|
||||
"""
|
||||
# The name of speculative decoding.
|
||||
spec_dec_name = "MTP"
|
||||
# The number of MTP modules
|
||||
num_nextn_predict_layers: int = 1
|
||||
# The number of max batch size
|
||||
max_batch_size: int = 8
|
||||
|
||||
# Whether to use relaxed acceptance during thinking phase for reasoning model
|
||||
use_relaxed_acceptance_for_thinking: bool = False
|
||||
# The top-N tokens are sampled from logits to obtain a candidate set.
|
||||
relaxed_topk: int = 1
|
||||
# The threshold to further filter the candidate set.
|
||||
# Filter out tokens with a large probability gap between the top-1 token's log probability.
|
||||
relaxed_delta: float = 0.
|
||||
|
||||
# Whether to use vanilla MTP
|
||||
use_mtp_vanilla: bool = False
|
||||
|
||||
# TODO: Hard code for DeepSeek R1
|
||||
# When encounter <think>, start thinking phase.
|
||||
# When encounter </think>, end thinking phase.
|
||||
# <think> [thinking phase] </think> [real output]
|
||||
BEGIN_THINKING_PHASE_TOKEN: int = 128798
|
||||
END_THINKING_PHASE_TOKEN: int = 128799
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
|
||||
self.spec_dec_name)
|
||||
self.max_draft_tokens = self.num_nextn_predict_layers
|
||||
|
||||
def update_from_model_config(self, model_config):
|
||||
assert self.num_nextn_predict_layers > 0
|
||||
if model_config.num_nextn_predict_layers == 1 and not self.use_mtp_vanilla:
|
||||
self.spec_dec_mode = SpeculativeDecodingMode.MTP_EAGLE
|
||||
self.num_extra_kv_tokens = self.num_nextn_predict_layers - 1
|
||||
|
||||
|
||||
class MTPHiddenStatesManager(BaseResourceManager):
|
||||
|
||||
def __init__(self, config: MTPConfig, dtype: torch.dtype, hidden_size: int,
|
||||
max_num_requests: int):
|
||||
def __init__(self, config: "MTPDecodingConfig", dtype: torch.dtype,
|
||||
hidden_size: int, max_num_requests: int):
|
||||
self.dtype = dtype
|
||||
self.num_nextn_predict_layers = config.num_nextn_predict_layers
|
||||
self.hidden_size = hidden_size
|
||||
@ -199,8 +157,8 @@ class MTPSpecMetadata(SpecMetadata):
|
||||
pin_memory=True)
|
||||
self.batch_indices_cuda[:num_seqs].copy_(batch_indices,
|
||||
non_blocking=True)
|
||||
# MTP vanilla worker uses total max_draft_tokens input tokens in generation phase,
|
||||
# while MTP Eagle worker uses (max_draft_tokens + 1) input tokens in the 1st draft
|
||||
# MTP vanilla worker uses total max_draft_len input tokens in generation phase,
|
||||
# while MTP Eagle worker uses (max_draft_len + 1) input tokens in the 1st draft
|
||||
# forward and only one input token in the following draft forward.
|
||||
# This num_tokens is used to set the all_rank_num_tokens for attention dp.
|
||||
if not self.spec_dec_mode.is_mtp_eagle():
|
||||
@ -353,7 +311,7 @@ class MTPSampler(TorchSampler):
|
||||
|
||||
class MTPWorker(nn.Module):
|
||||
|
||||
def __init__(self, spec_config: MTPConfig):
|
||||
def __init__(self, spec_config: "MTPDecodingConfig"):
|
||||
super().__init__()
|
||||
self.spec_config = spec_config
|
||||
self.is_thop = False
|
||||
@ -742,7 +700,7 @@ class MTPWorker(nn.Module):
|
||||
|
||||
Returns:
|
||||
accepted_tokens: torch.Tensor
|
||||
[batch_size, (max_draft_tokens + 1)]
|
||||
[batch_size, (max_draft_len + 1)]
|
||||
Accepted token ids. Flattened.
|
||||
|
||||
num_accepted_tokens: torch.Tensor
|
||||
@ -942,7 +900,7 @@ class MTPWorker(nn.Module):
|
||||
Target model's hidden states.
|
||||
|
||||
accepted_tokens: torch.Tensor
|
||||
[batch_size, max_draft_tokens + 1]
|
||||
[batch_size, max_draft_len + 1]
|
||||
Accepted token ids. Flattened.
|
||||
|
||||
num_accepted_tokens: torch.Tensor
|
||||
@ -1080,7 +1038,7 @@ class MTPWorker(nn.Module):
|
||||
|
||||
Returns:
|
||||
draft_tokens: torch.Tensor
|
||||
[batch_size * max_draft_tokens]
|
||||
[batch_size * max_draft_len]
|
||||
Draft token ids. Flattened.
|
||||
'''
|
||||
|
||||
@ -1090,7 +1048,7 @@ class MTPWorker(nn.Module):
|
||||
|
||||
class MTPEagleWorker(MTPWorker):
|
||||
|
||||
def __init__(self, spec_config: MTPConfig):
|
||||
def __init__(self, spec_config: "MTPDecodingConfig"):
|
||||
super().__init__(spec_config)
|
||||
self.mtp_num_modules = spec_config.num_nextn_predict_layers
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
|
||||
from ordered_set import OrderedSet
|
||||
@ -9,34 +8,6 @@ from ..pyexecutor.llm_request import *
|
||||
from ..pyexecutor.resource_manager import BaseResourceManager
|
||||
from ..pyexecutor.scheduler import ScheduledRequests
|
||||
from .drafter import Drafter
|
||||
from .interface import SpecConfig, SpeculativeDecodingMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class NGramConfig(SpecConfig):
|
||||
"""
|
||||
Configuration for NGram drafter.
|
||||
"""
|
||||
# The name of speculative decoding.
|
||||
spec_dec_name = "NGRAM"
|
||||
|
||||
num_extra_kv_tokens: int = 0
|
||||
max_draft_tokens: int = 0
|
||||
|
||||
prompt_lookup_num_tokens: int = 5
|
||||
max_matching_ngram_size: int = 5
|
||||
end_id: int = -1
|
||||
is_keep_all: bool = True
|
||||
is_use_oldest: bool = True
|
||||
is_public_pool: bool = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
|
||||
self.spec_dec_name)
|
||||
self.max_draft_tokens = self.prompt_lookup_num_tokens
|
||||
|
||||
def update_from_model_config(self, model_config):
|
||||
pass
|
||||
|
||||
|
||||
class NGramPoolManager(BaseResourceManager):
|
||||
@ -52,7 +23,7 @@ class NGramPoolManager(BaseResourceManager):
|
||||
`matches` is a list of candidate draft token ids attaching to a pattern.
|
||||
|
||||
Arguments:
|
||||
prompt_lookup_num_tokens: int
|
||||
max_draft_len: int
|
||||
The length maximum of draft tokens (can be understood as length maximum of output draft tokens).
|
||||
|
||||
max_matching_ngram_size: int
|
||||
@ -76,8 +47,9 @@ class NGramPoolManager(BaseResourceManager):
|
||||
It maps from request ID to the index of the prompt to update the pool in the next step.
|
||||
"""
|
||||
|
||||
def __init__(self, spec_config: SpecConfig, max_num_requests: int):
|
||||
self.prompt_lookup_num_tokens = spec_config.prompt_lookup_num_tokens
|
||||
def __init__(self, spec_config: "NGramDecodingConfig",
|
||||
max_num_requests: int):
|
||||
self.max_draft_len = spec_config.max_draft_len
|
||||
self.max_matching_ngram_size = spec_config.max_matching_ngram_size
|
||||
self.is_keep_all = spec_config.is_keep_all
|
||||
self.is_use_oldest = spec_config.is_use_oldest # TODO: remove this if updating strategy is supported
|
||||
@ -133,7 +105,7 @@ class NGramPoolManager(BaseResourceManager):
|
||||
-1):
|
||||
# Find each possible pattern-match combination, and use tuple for hash
|
||||
for l in range(len(sequence) - size):
|
||||
r = min(l + size + self.prompt_lookup_num_tokens, len(sequence))
|
||||
r = min(l + size + self.max_draft_len, len(sequence))
|
||||
pattern = tuple(sequence[l:l + size])
|
||||
new_match = tuple(sequence[l + size:r])
|
||||
if pattern not in pool or \
|
||||
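The pool update above windows each match to at most `max_draft_len` tokens. For intuition, a hedged and much-simplified sketch of prompt-lookup drafting over a plain Python list (the real pool manager keeps multiple candidate matches per pattern and per-request state):

from typing import List


def lookup_draft_tokens(sequence: List[int], max_matching_ngram_size: int,
                        max_draft_len: int) -> List[int]:
    # Try the longest suffix n-gram first; propose the tokens that followed its
    # most recent earlier occurrence, capped at max_draft_len.
    for size in range(min(max_matching_ngram_size, len(sequence) - 1), 0, -1):
        pattern = tuple(sequence[-size:])
        for l in range(len(sequence) - size - 1, -1, -1):
            if tuple(sequence[l:l + size]) == pattern:
                r = min(l + size + max_draft_len, len(sequence))
                draft = sequence[l + size:r]
                if draft:
                    return draft
    return []


# The suffix (3, 4) occurred earlier followed by 5, 6, 3 -> draft tokens [5, 6, 3].
print(lookup_draft_tokens([1, 2, 3, 4, 5, 6, 3, 4],
                          max_matching_ngram_size=2, max_draft_len=3))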
@ -165,7 +137,7 @@ class NGramPoolManager(BaseResourceManager):
|
||||
# Update start_index
|
||||
self.start_index[request_id] = max(
|
||||
0, prefix_len -
|
||||
(self.prompt_lookup_num_tokens + self.max_matching_ngram_size - 1))
|
||||
(self.max_draft_len + self.max_matching_ngram_size - 1))
|
||||
|
||||
return draft_tokens
|
||||
|
||||
@ -191,12 +163,12 @@ class NGramDrafter(Drafter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_config: SpecConfig,
|
||||
spec_config: "NGramDecodingConfig",
|
||||
ngram_pool_manager: NGramPoolManager = None,
|
||||
):
|
||||
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
|
||||
super().__init__(spec_resource_manager=ngram_pool_manager)
|
||||
self.max_num_draft_tokens = spec_config.max_draft_tokens
|
||||
self.max_draft_len = spec_config.max_draft_len
|
||||
|
||||
def prepare_draft_tokens(
|
||||
self,
|
||||
@ -220,8 +192,8 @@ class NGramDrafter(Drafter):
|
||||
request.py_end_id,
|
||||
request.py_orig_prompt_len + request.py_max_new_tokens,
|
||||
)
|
||||
# Pad length to `self.max_num_draft_tokens`
|
||||
# Pad length to `self.max_draft_len`
|
||||
if len(draft_tokens) > 0:
|
||||
pad_length = self.max_num_draft_tokens - len(draft_tokens)
|
||||
pad_length = self.max_draft_len - len(draft_tokens)
|
||||
draft_tokens.extend([request.py_end_id] * pad_length)
|
||||
request.py_draft_tokens = draft_tokens
|
||||
|
||||
@ -1,26 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from tensorrt_llm._torch.speculative.drafter import Drafter
|
||||
|
||||
from .interface import SpecConfig, SpeculativeDecodingMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserProvidedConfig(SpecConfig):
|
||||
"""
|
||||
Configuration for user provided speculative decoding.
|
||||
"""
|
||||
# The name of speculative decoding.
|
||||
spec_dec_name = "USER_PROVIDED"
|
||||
|
||||
num_extra_kv_tokens: int = 0
|
||||
max_draft_tokens: int = 0
|
||||
drafter: Optional[Drafter] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
|
||||
self.spec_dec_name)
|
||||
|
||||
def update_from_model_config(self, model_config):
|
||||
pass
|
||||
@ -1,7 +1,6 @@
|
||||
from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
|
||||
from tensorrt_llm._torch.speculative.interface import SpecConfig, SpecMetadata
|
||||
from tensorrt_llm._torch.speculative.interface import SpecMetadata
|
||||
|
||||
from .draft_target import DraftTargetSpecMetadata
|
||||
from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata,
|
||||
Eagle3OneModelWorker, Eagle3ResourceManager,
|
||||
Eagle3SpecMetadata)
|
||||
@ -11,13 +10,14 @@ from .ngram import NGramDrafter, NGramPoolManager
|
||||
|
||||
|
||||
def get_spec_metadata(spec_config,
|
||||
model_config,
|
||||
max_num_requests,
|
||||
max_num_tokens,
|
||||
spec_resource_manager=None,
|
||||
is_draft_model=False):
|
||||
if spec_config.spec_dec_mode.is_mtp():
|
||||
return MTPSpecMetadata(
|
||||
max_draft_tokens=spec_config.max_draft_tokens,
|
||||
max_draft_len=spec_config.max_draft_len,
|
||||
spec_dec_mode=spec_config.spec_dec_mode,
|
||||
mtp_num_modules=spec_config.num_nextn_predict_layers,
|
||||
max_num_requests=max_num_requests,
|
||||
@ -25,35 +25,30 @@ def get_spec_metadata(spec_config,
|
||||
)
|
||||
if spec_config.spec_dec_mode.is_eagle3():
|
||||
return Eagle3SpecMetadata(
|
||||
max_draft_tokens=spec_config.max_draft_tokens,
|
||||
max_draft_len=spec_config.max_draft_len,
|
||||
spec_dec_mode=spec_config.spec_dec_mode,
|
||||
max_num_requests=max_num_requests,
|
||||
num_layers=spec_config.num_layers,
|
||||
hidden_size=spec_config.hidden_size,
|
||||
num_layers=model_config.num_hidden_layers,
|
||||
hidden_size=model_config.hidden_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
dtype=spec_config.dtype,
|
||||
dtype=model_config.torch_dtype,
|
||||
is_draft_model=is_draft_model,
|
||||
eagle3_resource_manager=spec_resource_manager,
|
||||
)
|
||||
if spec_config.spec_dec_mode.is_eagle3_one_model():
|
||||
return Eagle3OneModelSpecMetadata(
|
||||
max_draft_tokens=spec_config.max_draft_tokens,
|
||||
max_draft_len=spec_config.max_draft_len,
|
||||
spec_dec_mode=spec_config.spec_dec_mode,
|
||||
max_num_requests=max_num_requests,
|
||||
num_layers=spec_config.num_layers,
|
||||
hidden_size=spec_config.hidden_size,
|
||||
num_layers=model_config.num_hidden_layers,
|
||||
hidden_size=model_config.hidden_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
)
|
||||
if spec_config.spec_dec_mode.is_draft_target():
|
||||
return DraftTargetSpecMetadata(
|
||||
max_draft_tokens=spec_config.max_draft_tokens,
|
||||
spec_dec_mode=spec_config.spec_dec_mode,
|
||||
max_num_requests=max_num_requests,
|
||||
)
|
||||
if spec_config.spec_dec_mode.is_ngram(
|
||||
) or spec_config.spec_dec_mode.is_user_provided():
|
||||
if spec_config.spec_dec_mode.is_draft_target() or \
|
||||
spec_config.spec_dec_mode.is_ngram() or \
|
||||
spec_config.spec_dec_mode.is_user_provided():
|
||||
return SpecMetadata(
|
||||
max_draft_tokens=spec_config.max_draft_tokens,
|
||||
max_draft_len=spec_config.max_draft_len,
|
||||
spec_dec_mode=spec_config.spec_dec_mode,
|
||||
max_num_requests=max_num_requests,
|
||||
)
|
||||
@ -104,7 +99,8 @@ def get_spec_resource_manager(model_engine,
|
||||
return None
|
||||
|
||||
|
||||
def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig):
|
||||
def get_spec_decoder(sampler_args: TorchSampler.Args,
|
||||
spec_config: "DecodingBaseConfig"):
|
||||
if spec_config.spec_dec_mode.is_mtp():
|
||||
return MTPSampler(sampler_args,
|
||||
nextn=spec_config.num_nextn_predict_layers)
|
||||
@ -133,18 +129,16 @@ def get_spec_drafter(model_engine):
|
||||
def get_num_spec_layers(spec_config):
|
||||
if spec_config.spec_dec_mode.is_mtp():
|
||||
return spec_config.num_nextn_predict_layers
|
||||
elif spec_config.spec_dec_mode.is_eagle3_one_model():
|
||||
if spec_config.spec_dec_mode.is_eagle3_one_model():
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
return 0
|
||||
|
||||
|
||||
def get_spec_worker(spec_config, mapping):
|
||||
if spec_config.spec_dec_mode.is_mtp():
|
||||
return MTPWorker(spec_config)
|
||||
elif spec_config.spec_dec_mode.is_mtp_eagle():
|
||||
if spec_config.spec_dec_mode.is_mtp_eagle():
|
||||
return MTPEagleWorker(spec_config)
|
||||
elif spec_config.spec_dec_mode.is_eagle3_one_model():
|
||||
if spec_config.spec_dec_mode.is_eagle3_one_model():
|
||||
return Eagle3OneModelWorker(spec_config, mapping)
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
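These helpers now use early returns instead of `elif`/`else` chains, and the EAGLE3 metadata pulls `num_layers`, `hidden_size`, and `dtype` from the model config rather than from the speculative config. A hedged toy illustration of the early-return dispatch (the `Mode` enum is a stand-in for `SpeculativeDecodingMode`, for illustration only):

from enum import IntEnum, auto


class Mode(IntEnum):
    # Stand-in for SpeculativeDecodingMode in this sketch.
    MTP = auto()
    EAGLE3_ONE_MODEL = auto()
    NGRAM = auto()


def get_num_spec_layers_example(mode: Mode, num_nextn_predict_layers: int = 1) -> int:
    if mode == Mode.MTP:
        return num_nextn_predict_layers
    if mode == Mode.EAGLE3_ONE_MODEL:
        return 1
    return 0


assert get_num_spec_layers_example(Mode.MTP, 3) == 3
assert get_num_spec_layers_example(Mode.EAGLE3_ONE_MODEL) == 1
assert get_num_spec_layers_example(Mode.NGRAM) == 0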
@ -1,4 +1,5 @@
|
||||
import copy
|
||||
import functools
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@ -222,7 +223,8 @@ class _ModelFormatKind(Enum):
|
||||
|
||||
class DecodingBaseConfig(BaseModel):
|
||||
max_draft_len: Optional[int] = None
|
||||
speculative_model: Optional[Union[str, Path]] = None
|
||||
speculative_model_dir: Optional[Union[str, Path]] = None
|
||||
num_extra_kv_tokens: int = 0
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
@ -247,6 +249,35 @@ class DecodingBaseConfig(BaseModel):
|
||||
def _check_fields(self):
|
||||
pass
|
||||
|
||||
def supports_backend(self, backend: str) -> bool:
|
||||
"""
|
||||
Override if the speculation algorithm does not support
|
||||
a subset of the possible backends.
|
||||
"""
|
||||
return True
|
||||
|
||||
def validate(self) -> None:
|
||||
"""
|
||||
Do any additional error checking here.
|
||||
"""
|
||||
|
||||
@functools.cached_property
|
||||
def spec_dec_mode(self):
|
||||
# spec_dec_mode has more functionality than the raw decoding_mode string.
|
||||
# Use an alias for the import here to avoid name collisions with the one for the
|
||||
# TRT backend.
|
||||
from tensorrt_llm._torch.speculative.interface import \
|
||||
SpeculativeDecodingMode as TorchSpeculativeDecodingMode
|
||||
return TorchSpeculativeDecodingMode.from_string(
|
||||
self.decoding_type.upper())
|
||||
|
||||
def update_from_model_config(self, model_config):
|
||||
pass
|
||||
|
||||
def get_draft_model_prompt(self,
|
||||
input_tokens: torch.Tensor) -> torch.Tensor:
|
||||
return input_tokens
|
||||
|
||||
|
||||
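The new base-class hooks (`supports_backend`, `validate`, `spec_dec_mode`, `get_draft_model_prompt`, `update_from_model_config`) give every decoding config one place to express backend and mode behavior. A hedged sketch of a hypothetical subclass using them; the class itself is NOT part of this commit, and the import path assumes `DecodingBaseConfig` lives in `tensorrt_llm.llmapi.llm_args` as shown in this diff:

from typing import ClassVar

from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig


class MyDraftModelDecodingConfig(DecodingBaseConfig):
    # Hypothetical subclass for illustration; it reuses an existing decoding-type
    # string so that the inherited spec_dec_mode property can resolve it.
    decoding_type: ClassVar[str] = "Draft_Target"

    def supports_backend(self, backend: str) -> bool:
        # Restrict this illustrative algorithm to the PyTorch backend.
        return backend == "pytorch"

    def validate(self) -> None:
        if self.speculative_model_dir is None:
            raise ValueError("A draft model directory must be provided.")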
class MedusaDecodingConfig(DecodingBaseConfig):
|
||||
medusa_choices: Optional[List[List[int]]] = None
|
||||
@ -258,6 +289,9 @@ class MedusaDecodingConfig(DecodingBaseConfig):
|
||||
|
||||
decoding_type: ClassVar[str] = "Medusa"
|
||||
|
||||
def supports_backend(self, backend: str) -> bool:
|
||||
return backend not in ("pytorch", "_autodeploy")
|
||||
|
||||
|
||||
class EagleDecodingConfig(DecodingBaseConfig):
|
||||
eagle_choices: Optional[List[List[int]]] = None
|
||||
@ -267,7 +301,6 @@ class EagleDecodingConfig(DecodingBaseConfig):
|
||||
dynamic_tree_max_topK: Optional[int] = None
|
||||
num_eagle_layers: Optional[int] = None
|
||||
max_non_leaves_per_layer: Optional[int] = None
|
||||
pytorch_weights_path: Optional[str] = None
|
||||
eagle3_one_model: Optional[bool] = True
|
||||
|
||||
@classmethod
|
||||
@ -276,6 +309,25 @@ class EagleDecodingConfig(DecodingBaseConfig):
|
||||
|
||||
decoding_type: ClassVar[str] = "Eagle"
|
||||
|
||||
def validate(self) -> None:
|
||||
if self.speculative_model_dir is None:
|
||||
raise ValueError("Draft model must be provided for EAGLE")
|
||||
|
||||
@functools.cached_property
|
||||
def spec_dec_mode(self):
|
||||
from tensorrt_llm._torch.speculative.interface import \
|
||||
SpeculativeDecodingMode as TorchSpeculativeDecodingMode
|
||||
if self.eagle3_one_model:
|
||||
return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
|
||||
return TorchSpeculativeDecodingMode.EAGLE3
|
||||
|
||||
def get_draft_model_prompt(self,
|
||||
input_tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Eagle3 always throws away the first token when processing draft inputs
|
||||
"""
|
||||
return input_tokens[1:]
|
||||
|
||||
|
||||
class UserProvidedDecodingConfig(DecodingBaseConfig):
|
||||
# Type should be Drafter, but it leads to circular import
|
||||
@ -285,7 +337,7 @@ class UserProvidedDecodingConfig(DecodingBaseConfig):
|
||||
def from_dict(cls, data: dict):
|
||||
return cls(**data)
|
||||
|
||||
decoding_type: ClassVar[str] = "UserProvided"
|
||||
decoding_type: ClassVar[str] = "User_Provided"
|
||||
|
||||
|
||||
class NGramDecodingConfig(DecodingBaseConfig):
|
||||
@ -293,7 +345,7 @@ class NGramDecodingConfig(DecodingBaseConfig):
|
||||
Configuration for NGram drafter speculative decoding.
|
||||
|
||||
Arguments:
|
||||
prompt_lookup_num_tokens: int
|
||||
max_draft_len: int
|
||||
The length maximum of draft tokens (can be understood as length maximum of output draft tokens).
|
||||
|
||||
max_matching_ngram_size: int
|
||||
@ -309,7 +361,6 @@ class NGramDecodingConfig(DecodingBaseConfig):
|
||||
Whether to use a common pool for all requests, or the pool is private for each request if False.
|
||||
"""
|
||||
|
||||
prompt_lookup_num_tokens: int = 2
|
||||
max_matching_ngram_size: int = 4
|
||||
is_keep_all: bool = True
|
||||
is_use_oldest: bool = True
|
||||
@ -321,15 +372,20 @@ class NGramDecodingConfig(DecodingBaseConfig):
|
||||
|
||||
decoding_type: ClassVar[str] = "NGram"
|
||||
|
||||
def supports_backend(self, backend: str) -> bool:
|
||||
return backend == "pytorch"
|
||||
|
||||
|
||||
class DraftTargetDecodingConfig(DecodingBaseConfig):
|
||||
pytorch_weights_path: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
return cls(**data)
|
||||
|
||||
decoding_type: ClassVar[str] = "DraftTarget"
|
||||
decoding_type: ClassVar[str] = "Draft_Target"
|
||||
|
||||
def supports_backend(self, backend: str) -> bool:
|
||||
return backend == "pytorch"
|
||||
|
||||
|
||||
class MTPDecodingConfig(DecodingBaseConfig):
|
||||
@ -339,12 +395,39 @@ class MTPDecodingConfig(DecodingBaseConfig):
|
||||
relaxed_delta: Optional[float] = 0.
|
||||
use_mtp_vanilla: Optional[bool] = False
|
||||
|
||||
# TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
|
||||
# Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.
|
||||
num_nextn_predict_layers_from_model_config: Optional[int] = 1
|
||||
|
||||
# TODO: Hard code for DeepSeek R1
|
||||
# When encounter <think>, start thinking phase.
|
||||
# When encounter </think>, end thinking phase.
|
||||
# <think> [thinking phase] </think> [real output]
|
||||
BEGIN_THINKING_PHASE_TOKEN: int = 128798
|
||||
END_THINKING_PHASE_TOKEN: int = 128799
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
return cls(**data)
|
||||
|
||||
decoding_type: ClassVar[str] = "MTP"
|
||||
|
||||
def supports_backend(self, backend: str) -> bool:
|
||||
return backend == "pytorch"
|
||||
|
||||
@functools.cached_property
|
||||
def spec_dec_mode(self):
|
||||
from tensorrt_llm._torch.speculative.interface import \
|
||||
SpeculativeDecodingMode as TorchSpeculativeDecodingMode
|
||||
if self.num_nextn_predict_layers_from_model_config == 1 and not self.use_mtp_vanilla:
|
||||
return TorchSpeculativeDecodingMode.MTP_EAGLE
|
||||
return TorchSpeculativeDecodingMode.MTP
|
||||
|
||||
def update_from_model_config(self, model_config):
|
||||
assert self.num_nextn_predict_layers > 0
|
||||
if model_config.num_nextn_predict_layers == 1 and not self.use_mtp_vanilla:
|
||||
self.num_extra_kv_tokens = self.num_nextn_predict_layers - 1
|
||||
|
||||
|
||||
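With the cached `spec_dec_mode` property, an MTP config resolves to the Eagle-style MTP path only when the model exposes a single next-n prediction layer and vanilla MTP is not forced. A hedged sketch of that resolution, assuming `MTPDecodingConfig` is importable from `tensorrt_llm.llmapi` and that `num_nextn_predict_layers` remains a field of the config:

from tensorrt_llm.llmapi import MTPDecodingConfig

# num_nextn_predict_layers_from_model_config defaults to 1, so this resolves to MTP_EAGLE.
cfg = MTPDecodingConfig(num_nextn_predict_layers=1)
print(cfg.spec_dec_mode)  # expected: SpeculativeDecodingMode.MTP_EAGLE

# Forcing vanilla MTP keeps the classic MTP worker.
cfg_vanilla = MTPDecodingConfig(num_nextn_predict_layers=3, use_mtp_vanilla=True)
print(cfg_vanilla.spec_dec_mode)  # expected: SpeculativeDecodingMode.MTP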
class PybindMirror(ABC):
|
||||
''' A class containing the utilities for mirroring Python classes to
|
||||
@ -635,6 +718,9 @@ class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror):
|
||||
self.max_ngram_size,
|
||||
self.max_verification_set_size)
|
||||
|
||||
def supports_backend(self, backend: str) -> bool:
|
||||
return backend not in ("pytorch", "_autodeploy")
|
||||
|
||||
decoding_type: ClassVar[str] = "Lookahead"
|
||||
|
||||
|
||||
@ -1037,7 +1123,7 @@ class BaseLlmArgs(BaseModel):
|
||||
return self._model_format
|
||||
|
||||
@property
|
||||
def speculative_model(self) -> Optional[_ModelFormatKind]:
|
||||
def speculative_model_dir(self) -> Optional[_ModelFormatKind]:
|
||||
return self._speculative_model
|
||||
|
||||
@property
|
||||
@ -1314,33 +1400,40 @@ class BaseLlmArgs(BaseModel):
@model_validator(mode="after")
def validate_speculative_config(self):
if self.speculative_config:
if isinstance(self.speculative_config, LookaheadDecodingConfig):
lookahead_config = self.speculative_config
# Update the build config
_, _, max_draft_tokens, _ = lookahead_config.calculate_speculative_resource(
)
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.LOOKAHEAD_DECODING
if max_draft_tokens > self.build_config.max_draft_len:
self.build_config.max_draft_len = max_draft_tokens
if not self.speculative_config.supports_backend(self.backend):
raise ValueError(
f"Speculation type {self.speculative_config.decoding_type} does not "
f"support backend {self.backend}")

# Below, we only need to set speculative_decoding_mode/decoding_config for speculation
# on the TRT backend.
if isinstance(self.speculative_config, LookaheadDecodingConfig):
max_draft_len = self.speculative_config.calculate_speculative_resource(
)[2]
assert max_draft_len > 0
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.LOOKAHEAD_DECODING
self.build_config.max_draft_len = max(
self.build_config.max_draft_len, max_draft_len)
self.decoding_config = DecodingConfig(
decoding_mode=DecodingMode.Lookahead(),
lookahead_decoding_config=PybindMirror.maybe_to_pybind(
lookahead_config))
elif isinstance(self.speculative_config, MedusaDecodingConfig):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.MEDUSA
self.speculative_config))

elif isinstance(self.speculative_config, MedusaDecodingConfig):
assert self.speculative_config.max_draft_len > 0
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.MEDUSA
self.build_config.max_draft_len = self.speculative_config.max_draft_len
self.decoding_config = DecodingConfig(
decoding_mode=DecodingMode.Medusa(),
medusa_choices=self.speculative_config.medusa_choices)

elif isinstance(self.speculative_config, EagleDecodingConfig):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
assert self.speculative_config.max_draft_len > 0

assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
self.build_config.max_draft_len = self.speculative_config.max_draft_len

self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
if self.speculative_config.eagle3_one_model:
self.num_extra_kv_tokens = self.max_draft_len - 1
if self.backend not in ['pytorch', '_autodeploy']:
eagle_config = _EagleConfig(
self.speculative_config.eagle_choices,
@ -1351,68 +1444,39 @@ class BaseLlmArgs(BaseModel):
self.decoding_config = DecodingConfig(
decoding_mode=DecodingMode.Eagle(),
eagle_config=eagle_config)
else:
from tensorrt_llm._torch.speculative import Eagle3Config
self.speculative_config = Eagle3Config(
max_draft_tokens=self.speculative_config.max_draft_len,
draft_model_path=self.speculative_config.
pytorch_weights_path,
eagle3_one_model=self.speculative_config.
eagle3_one_model)

elif isinstance(self.speculative_config, NGramDecodingConfig):
assert self.backend in ['pytorch', '_autodeploy']
assert self.speculative_config.prompt_lookup_num_tokens > 0 and self.speculative_config.max_matching_ngram_size > 0
assert self.speculative_config.max_draft_len > 0 and self.speculative_config.max_matching_ngram_size > 0
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.NGRAM
self.build_config.max_draft_len = self.speculative_config.max_draft_len
from tensorrt_llm._torch.speculative import NGramConfig
self.speculative_config = NGramConfig(
prompt_lookup_num_tokens=self.speculative_config.
prompt_lookup_num_tokens,
max_matching_ngram_size=self.speculative_config.
max_matching_ngram_size,
is_keep_all=self.speculative_config.is_keep_all,
is_use_oldest=self.speculative_config.is_use_oldest,
is_public_pool=self.speculative_config.is_public_pool,
)

elif isinstance(self.speculative_config, DraftTargetDecodingConfig):
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
assert self.backend == 'pytorch'
assert self.backend in ['pytorch']
assert self.speculative_config.max_draft_len > 0
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
self.build_config.max_draft_len = self.speculative_config.max_draft_len
from tensorrt_llm._torch.speculative import DraftTargetConfig
self.speculative_config = DraftTargetConfig(
max_draft_tokens=self.speculative_config.max_draft_len,
draft_model_path=self.speculative_config.
pytorch_weights_path)

elif isinstance(self.speculative_config, MTPDecodingConfig):
from tensorrt_llm._torch.speculative import MTPConfig
self.speculative_config = MTPConfig(
num_nextn_predict_layers=self.speculative_config.
num_nextn_predict_layers,
max_batch_size=self.build_config.max_batch_size,
use_relaxed_acceptance_for_thinking=self.speculative_config.
use_relaxed_acceptance_for_thinking,
relaxed_topk=self.speculative_config.relaxed_topk,
relaxed_delta=self.speculative_config.relaxed_delta,
use_mtp_vanilla=self.speculative_config.use_mtp_vanilla)
assert self.speculative_config.num_nextn_predict_layers > 0
self.speculative_config.max_draft_len = self.speculative_config.num_nextn_predict_layers

elif isinstance(self.speculative_config,
UserProvidedDecodingConfig):
assert self.backend in ['pytorch', '_autodeploy']
from tensorrt_llm._torch.speculative import UserProvidedConfig
self.speculative_config = UserProvidedConfig(
max_draft_tokens=self.speculative_config.max_draft_len,
drafter=self.speculative_config.drafter)
self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.USER_PROVIDED
self.build_config.max_draft_len = self.speculative_config.max_draft_tokens
self.build_config.max_draft_len = self.speculative_config.max_draft_len

else:
raise ValueError(
f"Speculative config type not recognized: {self.speculative_config}"
f"Unrecognized speculative config type {type(self.speculative_config)}"
)

else:
self.decoding_config = None

self._speculative_model = getattr(self.speculative_config,
"speculative_model", None)
"speculative_model_dir", None)
speculative_model_obj = _ModelWrapper(
self._speculative_model
) if self._speculative_model is not None else None

@ -110,8 +110,8 @@ class ModelLoader:

self.model_obj = _ModelWrapper(self.llm_args.model)
self.speculative_model_obj = _ModelWrapper(
self.llm_args.speculative_model
) if self.llm_args.speculative_model is not None else None
self.llm_args.speculative_model_dir
) if self.llm_args.speculative_model_dir is not None else None

if isinstance(self.llm_args, TrtLlmArgs):
self.convert_checkpoint_options = self.llm_args._convert_checkpoint_options
@ -440,8 +440,8 @@ class ModelLoader:
model_cls = AutoModelForCausalLM.get_trtllm_model_class(
self._model_dir, self.llm_args.trust_remote_code,
self.llm_args.decoding_config.decoding_mode
if hasattr(self.llm_args, "speculative_model")
and self.llm_args.speculative_model else None)
if hasattr(self.llm_args, "speculative_model_dir")
and self.llm_args.speculative_model_dir else None)

prequantized = self._update_from_hf_quant_config()

@ -484,7 +484,7 @@ class ModelLoader:
load_model_on_cpu=
True, # TODO:TRTLLM-195 to enhance the weights loading memory usage and chose best location
trust_remote_code=self.llm_args.trust_remote_code,
speculative_model=self._speculative_model_dir,
speculative_model_dir=self._speculative_model_dir,
speculative_config=self.llm_args.speculative_config
if not isinstance(self.llm_args.speculative_config,
LookaheadDecodingConfig) else None,

@ -59,7 +59,7 @@ class EagleConfig(LLaMAConfig):
**kwargs):
import transformers
trust_remote_code = kwargs.pop('trust_remote_code', True)
speculative_config_or_dir = kwargs.pop('speculative_model', None)
speculative_config_or_dir = kwargs.pop('speculative_model_dir', None)

if isinstance(hf_config_or_dir, transformers.PretrainedConfig):
hf_config = hf_config_or_dir

@ -948,11 +948,11 @@ class EagleForCausalLM(LLaMAForCausalLM):
spec_decoding_position_offsets: [bs, max_gen_tokens]
spec_decoding_packed_mask: [bs, max_draft_len, packed_length] **
eagle_temperature: [bs]
rand_data_validation: [bs, max_draft_tokens]
rand_data_validation: [bs, max_draft_len]

** The mask is tricky since the boolean mask will need to be
packed in runtime. So, the last dim will be:
packed_length = ceil((max_draft_tokens+1)/32)
packed_length = ceil((max_draft_len+1)/32)
"""
default_range = GenerationMixin.default_range
remove_input_padding = default_net().plugin_config.remove_input_padding
@ -1228,7 +1228,7 @@ class EagleForCausalLM(LLaMAForCausalLM):
quant_config: Optional[QuantConfig] = None,
**kwargs):
assert hf_model_or_dir is not None
speculative_model_dir = kwargs.get('speculative_model', None)
speculative_model_dir = kwargs.get('speculative_model_dir', None)
tllm_config = EagleConfig.from_hugging_face(hf_model_or_dir,
dtype=dtype,
mapping=mapping,

@ -70,7 +70,7 @@ class MedusaConfig(PretrainedConfig):
import transformers

trust_remote_code = kwargs.pop('trust_remote_code', True)
speculative_config_or_dir = kwargs.pop('speculative_model', None)
speculative_config_or_dir = kwargs.pop('speculative_model_dir', None)
speculative_config = kwargs.pop("speculative_config", None)

if isinstance(hf_config_or_dir, transformers.PretrainedConfig):

@ -191,7 +191,7 @@ class MedusaForCausalLm(PretrainedModel):
import transformers

assert hf_model_or_dir is not None
speculative_model_dir = kwargs.get('speculative_model', None)
speculative_model_dir = kwargs.get('speculative_model_dir', None)

use_preloading = isinstance(hf_model_or_dir,
transformers.PreTrainedModel)

@ -179,15 +179,15 @@ class ReDrafterMixin:
bb_range = default_range(max_batch_size)
bb0_range = default_range(max_batch_size, min_range=0, opt_offset=1)
num_beam_tokens = self.num_beams * self.beam_length
max_draft_tokens = num_beam_tokens - self.num_beams # ignore the true token
max_gen_token_len = 1 + max_draft_tokens # for the true token
max_draft_len = num_beam_tokens - self.num_beams # ignore the true token
max_gen_token_len = 1 + max_draft_len # for the true token
max_gen_token_len_range = default_range(max_gen_token_len)
bb_max_gen_token_len_range = default_range(max_gen_token_len *
max_batch_size,
min_range=0)

kwargs['speculative_decoding_draft_tokens_external'] = False
kwargs['max_draft_len'] = max_draft_tokens
kwargs['max_draft_len'] = max_draft_len
kwargs['spec_decoding_is_generation_length_variable'] = True
inputs = super().prepare_inputs(*args, **kwargs)
assert inputs['spec_decoding_params'] is not None

@ -157,6 +157,8 @@ class AccuracyTask:
elif isinstance(llm.args.speculative_config, DecodingBaseConfig):
spec_dec_algo = llm.args.speculative_config.decoding_type
elif isinstance(llm.args.speculative_config, SpecConfig):
# This branch is deprecated, but thread-leak of pytest raises flaky error if removing it.
# TODO: remove this branch safely.
spec_dec_algo = llm.args.speculative_config.spec_dec_name
else:
raise ValueError(

@ -214,7 +214,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
def test_ngram(self):
speculative_decoding_config = {
"decoding_type": "NGram",
"prompt_lookup_num_tokens": 4,
"max_draft_len": 4,
"max_matching_ngram_size": 4,
"is_keep_all": True,
"is_use_oldest": True,

@ -393,7 +393,7 @@ class TestEagleVicuna_7B_v1_3(LlmapiAccuracyTestHarness):

speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
speculative_model_dir=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
num_eagle_layers=4,
max_non_leaves_per_layer=10,
eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
@ -419,7 +419,7 @@ class TestEagle2Vicuna_7B_v1_3(LlmapiAccuracyTestHarness):

speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
speculative_model_dir=f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3",
num_eagle_layers=4,
max_non_leaves_per_layer=10,
use_dynamic_tree=True,

@ -243,7 +243,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):

draft_len = 4
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
pytorch_weights_path=eagle_model_dir)
speculative_model_dir=eagle_model_dir)

llm = LLM(model=target_model_dir,
**pytorch_config,
@ -262,7 +262,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):

draft_len = 4
spec_config = NGramDecodingConfig(
prompt_lookup_num_tokens=draft_len,
max_draft_len=draft_len,
max_matching_ngram_size=draft_len,
is_keep_all=True,
is_use_oldest=True,

@ -18,7 +18,7 @@ generation_servers:
- "localhost:8002"
speculative_config:
decoding_type: NGram
prompt_lookup_num_tokens: 4
max_draft_len: 4
max_matching_ngram_size: 4
is_keep_all: True
is_use_oldest: True

@ -22,45 +22,43 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
pytest.skip("Not enough memory to load target model")

models_path = llm_models_root()

kv_cache_config = KvCacheConfig(enable_block_reuse=False, max_tokens=2080)

sampling_params = SamplingParams(
max_tokens=32,
temperature=0,
)
max_batch_size = 1

target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
draft_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

max_batch_size = 2
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None

llm_common_config = dict(
model=target_model_dir,
backend='pytorch',
attn_backend=attn_backend,
disable_overlap_scheduler=True,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
max_num_tokens=2048,
)

draft_len = 4
spec_config = DraftTargetDecodingConfig(
max_draft_len=draft_len, pytorch_weights_path=draft_model_dir)
llm_spec = LLM(model=target_model_dir,
max_batch_size=max_batch_size,
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(
cuda_graph_batch_sizes=[1]) if use_cuda_graph else None,
attn_backend=attn_backend,
kv_cache_config=kv_cache_config,
speculative_config=spec_config)
max_draft_len=max_draft_len,
speculative_model_dir=draft_model_dir,
)

prompts = [
"The capital of France is", "The president of the United States is"
"The capital of France is",
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=32)

llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()

llm_ref = LLM(model=target_model_dir,
max_batch_size=max_batch_size,
disable_overlap_scheduler=True,
cuda_graph_config=CudaGraphConfig(
cuda_graph_batch_sizes=[1]) if use_cuda_graph else None,
attn_backend=attn_backend,
kv_cache_config=kv_cache_config)

llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()

@ -32,81 +32,70 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
pytest.skip("Not enough memory to load target + draft model")

models_path = llm_models_root()

pytorch_config = dict(
disable_overlap_scheduler=disable_overlap_scheduler,
# Only create a single CUDA graph to prevent OOM in CI
attn_backend=attn_backend,
)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None

kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, )

eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

draft_len = 4
spec_config = EagleDecodingConfig(
max_draft_len=draft_len,
pytorch_weights_path=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model)
# bs > 1 gives non-deterministic when doing IFB. There are slight chances
# that ref and spec does not match 100%
max_batch_size = 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
free_gpu_memory_fraction=0.5)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None

llm_spec = LLM(
llm_common_config = dict(
model=target_model_dir,
**pytorch_config,
# bs > 1 gives non-deterministic when doing IFB. There are slight chances
# that ref and spec does not match 100%
max_batch_size=1,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
kv_cache_config=kv_cache_config,
cuda_graph_config=cuda_graph_config,
speculative_config=spec_config)

sampling_params = SamplingParams(
max_tokens=10,
temperature=0,
)

# First make sure the acceptance rate is reasonable.
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)

llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

# Acceptance rate tests
tok_ids = llm_spec.tokenizer.encode("The future of AI is")
num_tokens = 0

num_drafted = 0
num_accepted = 0

sampling_params = SamplingParams(max_tokens=128, temperature=0)
for output in llm_spec.generate_async(tok_ids,
SamplingParams(max_tokens=128,
temperature=0),
sampling_params,
streaming=True):
beam = output.outputs[0]
new_tokens = beam.token_ids

num_drafted += draft_len
new_tokens = output.outputs[0].token_ids
num_drafted += max_draft_len
num_accepted += len(new_tokens) - num_tokens - 1

num_tokens = len(new_tokens)

accept_rate = num_accepted / num_drafted
assert accept_rate > 0.15

# Output tests
prompts = [
"The capital of France is", "The president of the United States is"
"The capital of France is",
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=10, temperature=0)

results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()

llm_ref = LLM(model=target_model_dir,
**pytorch_config,
kv_cache_config=kv_cache_config,
cuda_graph_config=cuda_graph_config)

llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()

@ -6,9 +6,9 @@ from parameterized import parameterized
import tensorrt_llm
from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.speculative.mtp import (MTPConfig,
MTPHiddenStatesManager,
from tensorrt_llm._torch.speculative.mtp import (MTPHiddenStatesManager,
MTPSpecMetadata, MTPWorker)
from tensorrt_llm.llmapi import MTPDecodingConfig


def unittest_name_func(testcase_func, param_num, param):
@ -40,15 +40,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
device="cuda") # [num_tokens, vocab_size]

draft_tokens = torch.tensor(
[], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
[], dtype=torch.int, device="cuda") # [batch_size * max_draft_len]

draft_len = torch.tensor([0], dtype=torch.int,
device="cuda") # [batch_size]

ref_accepted_tokens = torch.tensor(
[[1, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

ref_num_accepted_tokens = torch.tensor([1],
dtype=torch.int,
@ -74,15 +73,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
device="cuda") # [num_tokens, vocab_size]

draft_tokens = torch.tensor(
[], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
[], dtype=torch.int, device="cuda") # [batch_size * max_draft_len]

draft_len = torch.tensor([0, 0, 0, 0], dtype=torch.int,
device="cuda") # [batch_size]

ref_accepted_tokens = torch.tensor(
[[1, 0], [3, 0], [3, 0], [6, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

ref_num_accepted_tokens = torch.tensor([1, 1, 1, 1],
dtype=torch.int,
@ -111,14 +109,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):

draft_tokens = torch.tensor(
[1, 3, 4], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

draft_len = torch.tensor([3], dtype=torch.int,
device="cuda") # [batch_size]

ref_accepted_tokens = torch.tensor(
[[1, 3, 2, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

ref_num_accepted_tokens = torch.tensor([3],
dtype=torch.int,
@ -147,14 +145,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):

draft_tokens = torch.tensor(
[1, 5], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

draft_len = torch.tensor([1, 1], dtype=torch.int,
device="cuda") # [batch_size]

ref_accepted_tokens = torch.tensor(
[[1, 3], [4, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

ref_num_accepted_tokens = torch.tensor([2, 1],
dtype=torch.int,
@ -187,14 +185,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):

draft_tokens = torch.tensor(
[1, 3, 4, 4, 7, 3], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

draft_len = torch.tensor([3, 3], dtype=torch.int,
device="cuda") # [batch_size]

ref_accepted_tokens = torch.tensor(
[[1, 3, 2, 0], [4, 6, 0, 0]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

ref_num_accepted_tokens = torch.tensor([3, 2],
dtype=torch.int,
@ -231,7 +229,7 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):

draft_tokens = torch.tensor(
[1, 3, 5, 4, 6, 5, 5, 7, 4], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

draft_len = torch.tensor([3, 3, 3], dtype=torch.int,
device="cuda") # [batch_size]
@ -239,7 +237,7 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
ref_accepted_tokens = torch.tensor(
[[1, 3, 2, 0], [4, 6, 5, 2], [4, 0, 0, 0]],
dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

ref_num_accepted_tokens = torch.tensor([3, 4, 1],
dtype=torch.int,
@ -267,15 +265,14 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
device="cuda") # [num_tokens, vocab_size]

draft_tokens = torch.tensor(
[4], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
[4], dtype=torch.int, device="cuda") # [batch_size * max_draft_len]

draft_len = torch.tensor([0, 1], dtype=torch.int,
device="cuda") # [batch_size]

ref_accepted_tokens = torch.tensor(
[[1, 0], [4, 6]], dtype=torch.int,
device="cuda") # [batch_size * max_draft_tokens]
device="cuda") # [batch_size * max_draft_len]

ref_num_accepted_tokens = torch.tensor([1, 2],
dtype=torch.int,
@ -297,7 +294,8 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
ref_accepted_tokens,
ref_num_accepted_tokens):
batch_size = len(draft_len)
spec_config = MTPConfig(num_nextn_predict_layers=mtp_num_modules)
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=mtp_num_modules)

# attention metedata
attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size,
@ -310,7 +308,7 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
# speculative decoding metadata
spec_metadata = MTPSpecMetadata(max_num_requests=32,
spec_dec_mode=spec_config.spec_dec_mode,
max_draft_tokens=mtp_num_modules,
max_draft_len=mtp_num_modules,
mtp_num_modules=mtp_num_modules)
spec_metadata.draft_tokens = draft_tokens

@ -871,7 +869,7 @@ class TestMTPUpdateMTPHiddenStates(unittest.TestCase):
batch_size = len(request_ids)
batch_size - num_context_request
hidden_size = hidden_states.shape[1]
spec_config = MTPConfig(
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=num_nextn_predict_layers)

attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size,
@ -892,7 +890,7 @@ class TestMTPUpdateMTPHiddenStates(unittest.TestCase):
spec_metadata = MTPSpecMetadata(
max_num_requests=32,
spec_dec_mode=spec_config.spec_dec_mode,
max_draft_tokens=num_nextn_predict_layers,
max_draft_len=num_nextn_predict_layers,
mtp_num_modules=num_nextn_predict_layers,
mtp_hidden_states_manager=spec_manager)
spec_metadata.request_ids = request_ids
@ -1362,7 +1360,7 @@ class TestMTPPrepareDrafterInputs(unittest.TestCase):
hidden_size = previous_layer_hidden_states.shape[1]
else:
hidden_size = 10
spec_config = MTPConfig(
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=num_nextn_predict_layers)

if attn_metadata is None:
@ -1387,7 +1385,7 @@ class TestMTPPrepareDrafterInputs(unittest.TestCase):
spec_metadata = MTPSpecMetadata(
max_num_requests=32,
spec_dec_mode=spec_config.spec_dec_mode,
max_draft_tokens=num_nextn_predict_layers,
max_draft_len=num_nextn_predict_layers,
mtp_num_modules=num_nextn_predict_layers,
mtp_hidden_states_manager=spec_manager)
spec_metadata.request_ids = request_ids

@ -24,6 +24,10 @@ def test_llama_ngram(disable_overlap_scheduler: bool, use_cuda_graph: bool,
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 20:
pytest.skip("Not enough memory to load target model")

max_batch_size = 2
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None

@ -33,13 +37,13 @@ def test_llama_ngram(disable_overlap_scheduler: bool, use_cuda_graph: bool,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=4,
kv_cache_config=KvCacheConfig(enable_block_reuse=False),
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
max_num_tokens=2048,
)

spec_config = NGramDecodingConfig(
prompt_lookup_num_tokens=4,
max_draft_len=max_draft_len,
max_matching_ngram_size=2,
is_keep_all=True,
is_use_oldest=True,

@ -6,9 +6,9 @@ import pytest
import torch

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.speculative.ngram import (NGramConfig, NGramDrafter,
NGramPoolManager)
from tensorrt_llm._torch.speculative.ngram import NGramDrafter, NGramPoolManager
from tensorrt_llm.llmapi import (CudaGraphConfig, KvCacheConfig,
NGramDecodingConfig,
UserProvidedDecodingConfig)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
@ -25,11 +25,13 @@ def test_llama_user_provided(disable_overlap_scheduler: bool,
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 20:
pytest.skip("Not enough memory to load target model")

max_batch_size = 2
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=False)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None

max_batch_size = 4

llm_common_config = dict( \
model=llm_models_root() / "llama-3.1-model" /"Meta-Llama-3.1-8B",
backend='pytorch',
@ -37,27 +39,32 @@ def test_llama_user_provided(disable_overlap_scheduler: bool,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=KvCacheConfig(enable_block_reuse=False),
kv_cache_config=kv_cache_config,
max_num_tokens=2048,
)

draft_len = 4

ngram_config = NGramConfig(
prompt_lookup_num_tokens=draft_len,
ngram_config = NGramDecodingConfig(
max_draft_len=max_draft_len,
max_matching_ngram_size=2,
is_keep_all=True,
is_use_oldest=True,
is_public_pool=True,
)

drafter = NGramDrafter(spec_config=ngram_config,
ngram_pool_manager=NGramPoolManager(
spec_config=ngram_config,
max_num_requests=max_batch_size))
ngram_pool_manager = NGramPoolManager(
spec_config=ngram_config,
max_num_requests=max_batch_size,
)

spec_config = UserProvidedDecodingConfig(max_draft_len=draft_len,
drafter=drafter)
drafter = NGramDrafter(
spec_config=ngram_config,
ngram_pool_manager=ngram_pool_manager,
)

spec_config = UserProvidedDecodingConfig(
max_draft_len=max_draft_len,
drafter=drafter,
)

prompts = [
"The capital of France is",

@ -1128,7 +1128,7 @@ def test_llm_api_medusa():

speculative_config = MedusaDecodingConfig(num_medusa_heads=4,
max_draft_len=63,
speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
speculative_model_dir=get_model_path("medusa-vicuna-7b-v1.3"),
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
[0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
[0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
@ -1167,7 +1167,7 @@ def test_llm_api_medusa_tp2():

speculative_config = MedusaDecodingConfig(num_medusa_heads=4,
max_draft_len=63,
speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
speculative_model_dir=get_model_path("medusa-vicuna-7b-v1.3"),
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
[0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], \
[0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], \
@ -1205,7 +1205,7 @@ def test_llm_api_eagle(**llm_kwargs):

speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
speculative_model_dir=get_model_path("EAGLE-Vicuna-7B-v1.3"),
num_eagle_layers=4,
max_non_leaves_per_layer=10,
eagle_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], \
@ -1252,7 +1252,7 @@ def test_llm_api_eagle2(**llm_kwargs):

speculative_config = EagleDecodingConfig(
max_draft_len=63,
speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
speculative_model_dir=get_model_path("EAGLE-Vicuna-7B-v1.3"),
num_eagle_layers=4,
max_non_leaves_per_layer=10,
use_dynamic_tree=True,
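
For reference, a minimal usage sketch of the configuration style this change moves to: the draft checkpoint is passed via speculative_model_dir (formerly pytorch_weights_path / speculative_model) and the draft length via max_draft_len. This sketch is not part of the diff above, and the model paths are placeholders.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig

# Placeholder paths; point these at your own target model and EAGLE3 draft checkpoint.
spec_config = EagleDecodingConfig(
    max_draft_len=4,
    speculative_model_dir="/path/to/eagle3-draft",  # renamed field; was pytorch_weights_path
    eagle3_one_model=False,
)
llm = LLM(model="/path/to/target-model", speculative_config=spec_config)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0))
llm.shutdown()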