Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][feat] Implement sampling on 1-model EAGLE3 (#9885)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Mike Iovine <miovine@nvidia.com>
This commit is contained in:
parent 079ef8ae77
commit 383b13e0e5
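In short, the commit makes non-greedy (temperature / top-k / top-p) sampling available on the 1-model EAGLE3 path behind a new opt-in allow_advanced_sampling knob on the decoding config, plus a matching --allow_advanced_sampling flag in the quickstart example. A minimal LLM-API sketch of how the option would be enabled, assuming the usual tensorrt_llm.LLM / SamplingParams / EagleDecodingConfig entry points and placeholder model paths (illustrative, not taken from the diff):

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model_dir="<path-to-eagle3-draft-checkpoint>",  # placeholder
    eagle3_one_model=True,
    allow_advanced_sampling=True,  # new in this commit; defaults to False
)
llm = LLM("<path-to-target-model>", speculative_config=spec_config)

# With the flag set, per-request temperature/top-k/top-p are honored on the
# 1-model EAGLE3 path instead of falling back to greedy decoding.
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32, temperature=0.8, top_p=0.9))

Without the flag, 1-model speculation keeps its previous behaviour and falls back to greedy decoding (see the create_py_executor warning further down).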
@@ -143,6 +143,9 @@ def add_llm_args(parser):
                         default=False,
                         action='store_true')
     parser.add_argument('--dynamic_tree_max_topK', type=int, default=None)
+    parser.add_argument('--allow_advanced_sampling',
+                        default=False,
+                        action='store_true')

     # Relaxed acceptance
     parser.add_argument('--use_relaxed_acceptance_for_thinking',
@@ -210,7 +213,9 @@ def setup_llm(args, **kwargs):
            eagle3_one_model=args.use_one_model,
            eagle_choices=args.eagle_choices,
            use_dynamic_tree=args.use_dynamic_tree,
-           dynamic_tree_max_topK=args.dynamic_tree_max_topK)
+           dynamic_tree_max_topK=args.dynamic_tree_max_topK,
+           allow_advanced_sampling=args.allow_advanced_sampling)

    elif spec_decode_algo == "DRAFT_TARGET":
        spec_config = DraftTargetDecodingConfig(
            max_draft_len=args.spec_decode_max_draft_len,
@@ -48,7 +48,8 @@ from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
                            get_spec_metadata,
                            update_spec_config_from_model_config)
 from ..speculative.drafting_loops import BaseDraftingLoopWrapper
-from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
+from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
+                                  Eagle3ResourceManager, Eagle3SpecMetadata)
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.utils import SpecDecodingTensor
 from ..utils import (get_model_extra_attrs,
@@ -2093,6 +2094,9 @@ class PyTorchModelEngine(ModelEngine):
                 num_accepted_draft_tokens)]
         if isinstance(spec_metadata, Eagle3SpecMetadata):
             spec_metadata.request_accepted_path = request_accepted_path
+        if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
+            spec_metadata.populate_sampling_params_for_one_model(
+                scheduled_requests.all_requests())
         spec_metadata.prepare()
         inputs['spec_metadata'] = spec_metadata

@@ -281,6 +281,17 @@ def create_py_executor(
        )
        llm_args.disable_overlap_scheduler = True

+   if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
+       if not spec_config.allow_advanced_sampling:
+           logger.warning(
+               f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
+               "want to use non-greedy sampling, please set allow_advanced_sampling=True."
+           )
+       elif spec_config.spec_dec_mode.is_mtp_one_model():
+           logger.warning(
+               "Advanced sampling is not supported for MTP yet - this will be added soon."
+           )
+
    if mm_encoder_only:
        llm_args.mm_encoder_only = True
        llm_args.disable_overlap_scheduler = True
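To make the gating above concrete, here is a small Python illustration (not library code; the real gate only logs warnings, while the greedy fallback itself is applied later in the Eagle3 one-model worker):

# Hypothetical helper summarising the effective policy implied by the gate above
# for 1-engine speculative modes.
def effective_sampling(allow_advanced_sampling: bool, is_mtp_one_model: bool) -> str:
    if not allow_advanced_sampling:
        return "greedy decoding (fallback warning logged)"
    if is_mtp_one_model:
        return "greedy decoding (advanced sampling not yet supported for MTP)"
    return "per-request temperature / top-k / top-p sampling"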
@@ -14,6 +14,7 @@ from ..pyexecutor.sampler import TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
 from .interface import SpecMetadata, get_force_num_accepted_tokens
 from .mtp import MTPSampler
+from .one_model_sampler import sampling_batch_spec_dec_one_model
 from .spec_tree_manager import SpecTreeManager

 if TYPE_CHECKING:
@@ -493,6 +494,40 @@ class Eagle3OneModelWorker(nn.Module):
            'next_new_tokens': next_new_tokens,
        }

+   def _sample_tokens_for_batch(
+       self,
+       logits: torch.Tensor,
+       spec_metadata: Eagle3OneModelSpecMetadata,
+       num_contexts: int,
+       batch_size: int,
+   ) -> torch.Tensor:
+       """
+       Sample tokens from logits using per-request sampling parameters.
+       Supports both greedy and non-greedy sampling.
+
+       Args:
+           logits: [num_tokens, vocab_size] - Logits to sample from
+           spec_metadata: Metadata containing sampling parameters
+           batch_size: Number of requests in the batch
+
+       Returns:
+           sampled_tokens: [num_tokens] - Sampled token ids
+       """
+       if spec_metadata.allow_advanced_sampling:
+           num_gens = batch_size - num_contexts
+           num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
+
+           temperatures = spec_metadata.temperatures[:num_tokens]
+           top_ks = spec_metadata.top_ks[:num_tokens]
+           top_ps = spec_metadata.top_ps[:num_tokens]
+
+           sampled_tokens = sampling_batch_spec_dec_one_model(
+               logits, temperatures, top_ks, top_ps)
+       else:
+           sampled_tokens = torch.argmax(logits, dim=-1)
+
+       return sampled_tokens
+
    def sample_and_accept_draft_tokens(
        self,
        logits: torch.Tensor,
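A quick check of the token-count arithmetic above, with illustrative numbers (not taken from the diff): context requests contribute one logit row each, while generation requests contribute max_draft_len + 1 rows (the accepted-token position plus the draft positions).

# Illustrative numbers only.
num_contexts, batch_size, max_draft_len = 2, 5, 3
num_gens = batch_size - num_contexts                     # 3 generation requests
num_tokens = num_contexts + num_gens * (max_draft_len + 1)
assert num_tokens == 14  # 2 context rows + 3 * 4 generation rows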
@@ -514,8 +549,9 @@ class Eagle3OneModelWorker(nn.Module):
                                      dtype=torch.int,
                                      device=logits.device)

-       # Do greedy sampling for the input logits
-       target_tokens = torch.argmax(logits, dim=-1)
+       # Sample tokens using per-request sampling parameters
+       target_tokens = self._sample_tokens_for_batch(logits, spec_metadata,
+                                                     num_contexts, batch_size)
        # context
        accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]

@@ -557,6 +593,9 @@ class Eagle3OneModelWorker(nn.Module):
            Draft token ids. Flattened.
        '''

+       # Note: using greedy for draft tokens is a bit easier to implement and
+       # faster. It doesn't affect the final output and seems to have a negligible
+       # impact on AR.
        draft_tokens = torch.argmax(logits, dim=-1)

        # Apply d2t (offsets between draft model dictionary and main model dictionary).
@@ -229,6 +229,13 @@ class SpecMetadata:
    # whether the spec-dec mode is a dynamic tree.
    is_spec_dec_dynamic_tree: bool = False

+   # For non-greedy sampling on 1-model.
+   allow_advanced_sampling: bool = False
+   # Sampling parameters for non-greedy sampling (per-request)
+   temperatures: Optional[torch.Tensor] = None
+   top_ks: Optional[torch.Tensor] = None
+   top_ps: Optional[torch.Tensor] = None
+
    def __post_init__(self):
        pass
@@ -264,3 +271,83 @@ class SpecMetadata:
        Some spec decode algorithms require hidden states from the target
        model. Use this method to record them. By default, does nothing.
        """
+
+   def populate_sampling_params_for_one_model(
+           self, requests: list["LlmRequest"]) -> None:
+       """
+       Set up topp/topk/temperatures for 1-model sampler.
+       """
+       from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
+       from tensorrt_llm.sampling_params import SamplingParams
+
+       if not self.allow_advanced_sampling or not self.spec_dec_mode.use_one_engine(
+       ):
+           return
+
+       if self.temperatures is None:
+           # Ensures determinism across ranks.
+           torch.manual_seed(0)
+
+       temperatures = []
+       top_ks = []
+       top_ps = []
+
+       # Need to use a very small value for temperature when disabled to avoid division by 0
+       DISABLE_TEMP_VAL = 1e-5
+       # Very large values disable topk.
+       DISABLE_TOPK_VAL = torch.iinfo(torch.int32).max
+       DISABLE_TOPP_VAL = 1.0
+
+       for request in requests:
+           sampling_config = request.sampling_config
+           temp = sampling_config.temperature
+           temp_val = temp[0] if temp is not None and len(temp) > 0 else None
+
+           tk = sampling_config.top_k
+           tk_val = tk[0] if tk is not None and len(tk) > 0 else None
+
+           tp = sampling_config.top_p
+           tp_val = tp[0] if tp is not None and len(tp) > 0 else None
+
+           # Context requests have no draft tokens yet.
+           num_tokens = 1 + self.max_draft_len if request.state == LlmRequestState.GENERATION_IN_PROGRESS else 1
+
+           is_greedy = SamplingParams.params_imply_greedy_decoding(
+               temperature=temp_val,
+               top_k=tk_val,
+               top_p=tp_val,
+               use_beam_search=False)
+
+           temp_val = DISABLE_TEMP_VAL if is_greedy or temp_val is None or temp_val == 0 else temp_val
+           tk_val = DISABLE_TOPK_VAL if is_greedy or tk_val is None or tk_val <= 0 else tk_val
+           tp_val = DISABLE_TOPP_VAL if is_greedy or tp_val is None else tp_val
+
+           temperatures.extend(temp_val for _ in range(num_tokens))
+           top_ks.extend(tk_val for _ in range(num_tokens))
+           top_ps.extend(tp_val for _ in range(num_tokens))
+
+       if self.temperatures is None:
+           self.temperatures = torch.ones(
+               (self.max_draft_len + 1) * self.max_num_requests,
+               dtype=torch.float32,
+               device='cuda')
+           self.top_ks = torch.zeros(
+               (self.max_draft_len + 1) * self.max_num_requests,
+               dtype=torch.int32,
+               device='cuda')
+           self.top_ps = torch.ones(
+               (self.max_draft_len + 1) * self.max_num_requests,
+               dtype=torch.float32,
+               device='cuda')
+
+       self.temperatures[:len(temperatures)].copy_(torch.tensor(
+           temperatures, dtype=torch.float32, pin_memory=True),
+                                                   non_blocking=True)
+       self.top_ks[:len(top_ks)].copy_(torch.tensor(top_ks,
+                                                    dtype=torch.int32,
+                                                    pin_memory=True),
+                                       non_blocking=True)
+       self.top_ps[:len(top_ps)].copy_(torch.tensor(top_ps,
+                                                    dtype=torch.float32,
+                                                    pin_memory=True),
+                                       non_blocking=True)
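A standalone sketch of how the per-token lists above are assembled, with illustrative requests and a simplified greediness check (the real code defers to SamplingParams.params_imply_greedy_decoding), assuming max_draft_len = 3:

# Illustration only, not library code.
DISABLE_TEMP_VAL, DISABLE_TOPK_VAL, DISABLE_TOPP_VAL = 1e-5, 2**31 - 1, 1.0
max_draft_len = 3

requests = [
    {"is_gen": False, "temperature": None, "top_k": None, "top_p": None},  # greedy context request
    {"is_gen": True, "temperature": 0.8, "top_k": 20, "top_p": 0.9},       # non-greedy generation request
]

temperatures, top_ks, top_ps = [], [], []
for req in requests:
    # Generation requests get max_draft_len + 1 slots; context requests get 1.
    num_tokens = 1 + max_draft_len if req["is_gen"] else 1
    # Simplified greediness check for illustration.
    is_greedy = req["temperature"] is None and req["top_k"] is None and req["top_p"] is None
    temperatures += [DISABLE_TEMP_VAL if is_greedy else req["temperature"]] * num_tokens
    top_ks += [DISABLE_TOPK_VAL if is_greedy else req["top_k"]] * num_tokens
    top_ps += [DISABLE_TOPP_VAL if is_greedy else req["top_p"]] * num_tokens

print(temperatures)  # [1e-05, 0.8, 0.8, 0.8, 0.8]
print(top_ks)        # [2147483647, 20, 20, 20, 20]

The flattened lists are then copied into preallocated CUDA tensors of length (max_draft_len + 1) * max_num_requests, which keeps the sampler CUDA-graph friendly.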
tensorrt_llm/_torch/speculative/one_model_sampler.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+from typing import Optional
+
+import torch
+
+
+def forward_native(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    PyTorch-native implementation of top-k and top-p sampling.
+
+    The logits tensor may be updated in-place.
+    """
+    logits = apply_top_k_top_p(logits, k, p)
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    return random_sample(probs)
+
+
+def random_sample(
+    probs: torch.Tensor,
+) -> torch.Tensor:
+    """Randomly sample from the probabilities.
+
+    We use this function instead of torch.multinomial because torch.multinomial
+    causes CPU-GPU synchronization.
+    """
+    q = torch.empty_like(probs).exponential_()
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    If a top-p is used, this function will sort the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
+    """
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+    if k is not None:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
+        top_k_mask = top_k_mask.clamp(min=0)
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def apply_temperature(
+    logits: torch.Tensor,
+    temp: torch.Tensor,
+) -> torch.Tensor:
+    return logits.div_(temp.unsqueeze(dim=1))
+
+
+@torch.compile(options={"max-autotune": True})
+def sampling_batch_spec_dec_one_model(
+    logits: torch.Tensor,
+    temperatures: torch.Tensor,
+    top_k: torch.Tensor,
+    top_p: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    CUDA-graph compatible sampling. Supports mixed sampling params.
+
+    We can't do dynamic kernel selection inside graphs, so this might
+    be slower than a torch.argmax for greedy requests. This is why advanced
+    sampling is opt-in for now.
+    """
+    logits = apply_temperature(logits, temperatures)
+    random_sampled = forward_native(logits, top_k, top_p)
+    return random_sampled
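A minimal usage sketch for the new module, assuming a CUDA device, an illustrative vocabulary size, and a TensorRT-LLM build containing this file; the per-row "disable" sentinels mirror the ones used by populate_sampling_params_for_one_model above, and the first call triggers torch.compile:

import torch
from tensorrt_llm._torch.speculative.one_model_sampler import \
    sampling_batch_spec_dec_one_model

vocab_size = 32000  # illustrative
logits = torch.randn(4, vocab_size, device="cuda")  # [num_tokens, vocab_size]

# Row 0 is effectively greedy (tiny temperature, top-k/top-p disabled);
# rows 1-3 use genuine non-greedy parameters.
temperatures = torch.tensor([1e-5, 0.8, 0.7, 1.0], device="cuda")
top_ks = torch.tensor([2**31 - 1, 20, 50, 40], dtype=torch.int32, device="cuda")
top_ps = torch.tensor([1.0, 0.9, 0.95, 0.8], device="cuda")

tokens = sampling_batch_spec_dec_one_model(logits, temperatures, top_ks, top_ps)
assert tokens.shape == (4,)  # one sampled token id per logit row

The exponential-noise trick in random_sample (dividing probabilities by i.i.d. exponential samples and taking argmax) draws from the same categorical distribution as torch.multinomial while staying free of CPU-GPU synchronization, which keeps the whole path CUDA-graph compatible.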
@@ -76,6 +76,7 @@ def get_spec_metadata(spec_config,
            hidden_size=model_config.hidden_size,
            max_num_tokens=max_num_tokens,
            layers_to_capture=spec_config.eagle3_layers_to_capture,
+           allow_advanced_sampling=spec_config.allow_advanced_sampling,
        )
    if spec_config.spec_dec_mode.is_save_hidden_states():
        if spec_config.eagle3_layers_to_capture is None:
@@ -619,6 +619,10 @@ class DecodingBaseConfig(StrictBaseModel):
    # (N = acceptance_window) drops below this value.
    acceptance_length_threshold: Optional[float] = None

+   # Prototype. If true, allows non-greedy sampling when speculation is used. Only applicable
+   # to 1-model code paths; non-greedy sampling is always enabled on 2-model paths.
+   allow_advanced_sampling: bool = False
+
    # Validate acceptance controls at field level so they run on model creation
    @field_validator('acceptance_window')
    @classmethod
@@ -4307,7 +4307,8 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
        draft_len = 3
        spec_config = EagleDecodingConfig(max_draft_len=draft_len,
                                          speculative_model_dir=eagle_model_dir,
-                                         eagle3_one_model=one_model)
+                                         eagle3_one_model=one_model,
+                                         allow_advanced_sampling=True)

        max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
        llm = LLM(self.MODEL_PATH,