[None][feat] Implement sampling on 1-model EAGLE3 (#9885)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Mike Iovine <miovine@nvidia.com>
Mike Iovine 2025-12-13 10:38:22 -05:00 committed by GitHub
parent 079ef8ae77
commit 383b13e0e5
9 changed files with 248 additions and 5 deletions

View File

@ -143,6 +143,9 @@ def add_llm_args(parser):
default=False,
action='store_true')
parser.add_argument('--dynamic_tree_max_topK', type=int, default=None)
parser.add_argument('--allow_advanced_sampling',
default=False,
action='store_true')
# Relaxed acceptance
parser.add_argument('--use_relaxed_acceptance_for_thinking',
@ -210,7 +213,9 @@ def setup_llm(args, **kwargs):
eagle3_one_model=args.use_one_model,
eagle_choices=args.eagle_choices,
use_dynamic_tree=args.use_dynamic_tree,
dynamic_tree_max_topK=args.dynamic_tree_max_topK)
dynamic_tree_max_topK=args.dynamic_tree_max_topK,
allow_advanced_sampling=args.allow_advanced_sampling)
elif spec_decode_algo == "DRAFT_TARGET":
spec_config = DraftTargetDecodingConfig(
max_draft_len=args.spec_decode_max_draft_len,
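
For reference, the same opt-in can be expressed through the LLM API rather than the quickstart flags; a minimal sketch (checkpoint paths are placeholders, and it assumes the standard tensorrt_llm entry points):

# Minimal sketch: opting into non-greedy sampling on the 1-model EAGLE3 path,
# mirroring the new --allow_advanced_sampling flag wired up above.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model_dir="<path/to/eagle3-draft>",  # placeholder
    eagle3_one_model=True,
    allow_advanced_sampling=True,  # without this, the 1-model path falls back to greedy
)

llm = LLM("<path/to/target-model>",  # placeholder
          speculative_config=spec_config)

# Per-request non-greedy sampling parameters now take effect on the 1-model path.
out = llm.generate(["The capital of France is"],
                   SamplingParams(temperature=0.8, top_k=40, top_p=0.95))
print(out[0].outputs[0].text)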

View File

@ -48,7 +48,8 @@ from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
get_spec_metadata,
update_spec_config_from_model_config)
from ..speculative.drafting_loops import BaseDraftingLoopWrapper
from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
Eagle3ResourceManager, Eagle3SpecMetadata)
from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.utils import SpecDecodingTensor
from ..utils import (get_model_extra_attrs,
@ -2093,6 +2094,9 @@ class PyTorchModelEngine(ModelEngine):
num_accepted_draft_tokens)]
if isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.request_accepted_path = request_accepted_path
if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
spec_metadata.populate_sampling_params_for_one_model(
scheduled_requests.all_requests())
spec_metadata.prepare()
inputs['spec_metadata'] = spec_metadata

View File

@ -281,6 +281,17 @@ def create_py_executor(
        )
        llm_args.disable_overlap_scheduler = True

    if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
        if not spec_config.allow_advanced_sampling:
            logger.warning(
                f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
                "want to use non-greedy sampling, please set allow_advanced_sampling=True."
            )
        elif spec_config.spec_dec_mode.is_mtp_one_model():
            logger.warning(
                "Advanced sampling is not supported for MTP yet - this will be added soon."
            )

    if mm_encoder_only:
        llm_args.mm_encoder_only = True
        llm_args.disable_overlap_scheduler = True

View File

@ -14,6 +14,7 @@ from ..pyexecutor.sampler import TorchSampler
from ..pyexecutor.scheduler import ScheduledRequests
from .interface import SpecMetadata, get_force_num_accepted_tokens
from .mtp import MTPSampler
from .one_model_sampler import sampling_batch_spec_dec_one_model
from .spec_tree_manager import SpecTreeManager
if TYPE_CHECKING:
@ -493,6 +494,40 @@ class Eagle3OneModelWorker(nn.Module):
            'next_new_tokens': next_new_tokens,
        }

    def _sample_tokens_for_batch(
        self,
        logits: torch.Tensor,
        spec_metadata: Eagle3OneModelSpecMetadata,
        num_contexts: int,
        batch_size: int,
    ) -> torch.Tensor:
        """
        Sample tokens from logits using per-request sampling parameters.
        Supports both greedy and non-greedy sampling.

        Args:
            logits: [num_tokens, vocab_size] - Logits to sample from
            spec_metadata: Metadata containing the per-request sampling parameters
            num_contexts: Number of context requests in the batch
            batch_size: Total number of requests in the batch

        Returns:
            sampled_tokens: [num_tokens] - Sampled token ids
        """
        if spec_metadata.allow_advanced_sampling:
            num_gens = batch_size - num_contexts
            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
            temperatures = spec_metadata.temperatures[:num_tokens]
            top_ks = spec_metadata.top_ks[:num_tokens]
            top_ps = spec_metadata.top_ps[:num_tokens]
            sampled_tokens = sampling_batch_spec_dec_one_model(
                logits, temperatures, top_ks, top_ps)
        else:
            sampled_tokens = torch.argmax(logits, dim=-1)
        return sampled_tokens

    def sample_and_accept_draft_tokens(
        self,
        logits: torch.Tensor,
@ -514,8 +549,9 @@ class Eagle3OneModelWorker(nn.Module):
            dtype=torch.int,
            device=logits.device)

        # Do greedy sampling for the input logits
        target_tokens = torch.argmax(logits, dim=-1)
        # Sample tokens using per-request sampling parameters
        target_tokens = self._sample_tokens_for_batch(logits, spec_metadata,
                                                      num_contexts, batch_size)

        # context
        accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts]
@ -557,6 +593,9 @@ class Eagle3OneModelWorker(nn.Module):
            Draft token ids. Flattened.
        '''
        # Note: using greedy for the draft tokens is a bit easier to implement
        # and faster. It doesn't affect the final output and seems to have a
        # negligible impact on the acceptance rate (AR).
        draft_tokens = torch.argmax(logits, dim=-1)

        # Apply d2t (offsets between draft model dictionary and main model dictionary).
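
For intuition on the flattened layout that _sample_tokens_for_batch slices into, here is a small self-contained sketch with toy sizes (the numbers are illustrative only; the real tensors come from spec_metadata):

# Toy illustration of the token layout: each context request owns one logit
# row, each generation request owns (max_draft_len + 1) rows, and the
# pre-expanded sampling parameters are sliced to line up with those rows.
import torch

max_draft_len = 3
num_contexts = 2
num_gens = 2
vocab_size = 8
max_num_requests = 8

num_tokens = num_contexts + num_gens * (max_draft_len + 1)  # 2 + 2 * 4 = 10
logits = torch.randn(num_tokens, vocab_size)

# Pre-allocated per-token parameter buffers (see populate_sampling_params_for_one_model).
temperatures = torch.ones((max_draft_len + 1) * max_num_requests)
top_ks = torch.full(((max_draft_len + 1) * max_num_requests, ),
                    torch.iinfo(torch.int32).max, dtype=torch.int32)
top_ps = torch.ones((max_draft_len + 1) * max_num_requests)

# Slicing the first num_tokens entries matches each logit row with its params.
assert temperatures[:num_tokens].shape[0] == logits.shape[0]
# Rows [0, num_contexts) belong to context requests, the rest to generation requests.
assert logits[num_contexts:].shape == (num_gens * (max_draft_len + 1), vocab_size)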

View File

@ -229,6 +229,13 @@ class SpecMetadata:
    # whether the spec-dec mode is a dynamic tree.
    is_spec_dec_dynamic_tree: bool = False

    # For non-greedy sampling on the 1-model path.
    allow_advanced_sampling: bool = False
    # Per-request sampling parameters, expanded to one entry per token.
    temperatures: Optional[torch.Tensor] = None
    top_ks: Optional[torch.Tensor] = None
    top_ps: Optional[torch.Tensor] = None

    def __post_init__(self):
        pass
@ -264,3 +271,83 @@ class SpecMetadata:
        Some spec decode algorithms require hidden states from the target
        model. Use this method to record them. By default, does nothing.
        """

    def populate_sampling_params_for_one_model(
            self, requests: list["LlmRequest"]) -> None:
        """
        Set up the top_p/top_k/temperature tensors for the 1-model sampler.
        """
        from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
        from tensorrt_llm.sampling_params import SamplingParams

        if not self.allow_advanced_sampling or not self.spec_dec_mode.use_one_engine():
            return

        if self.temperatures is None:
            # Ensures determinism across ranks.
            torch.manual_seed(0)

        temperatures = []
        top_ks = []
        top_ps = []

        # Use a very small temperature when it is disabled to avoid division by 0.
        DISABLE_TEMP_VAL = 1e-5
        # Very large values disable top-k.
        DISABLE_TOPK_VAL = torch.iinfo(torch.int32).max
        DISABLE_TOPP_VAL = 1.0

        for request in requests:
            sampling_config = request.sampling_config
            temp = sampling_config.temperature
            temp_val = temp[0] if temp is not None and len(temp) > 0 else None
            tk = sampling_config.top_k
            tk_val = tk[0] if tk is not None and len(tk) > 0 else None
            tp = sampling_config.top_p
            tp_val = tp[0] if tp is not None and len(tp) > 0 else None

            # Context requests have no draft tokens yet.
            num_tokens = 1 + self.max_draft_len if request.state == LlmRequestState.GENERATION_IN_PROGRESS else 1

            is_greedy = SamplingParams.params_imply_greedy_decoding(
                temperature=temp_val,
                top_k=tk_val,
                top_p=tp_val,
                use_beam_search=False)

            temp_val = DISABLE_TEMP_VAL if is_greedy or temp_val is None or temp_val == 0 else temp_val
            tk_val = DISABLE_TOPK_VAL if is_greedy or tk_val is None or tk_val <= 0 else tk_val
            tp_val = DISABLE_TOPP_VAL if is_greedy or tp_val is None else tp_val

            temperatures.extend(temp_val for _ in range(num_tokens))
            top_ks.extend(tk_val for _ in range(num_tokens))
            top_ps.extend(tp_val for _ in range(num_tokens))

        if self.temperatures is None:
            self.temperatures = torch.ones(
                (self.max_draft_len + 1) * self.max_num_requests,
                dtype=torch.float32,
                device='cuda')
            self.top_ks = torch.zeros(
                (self.max_draft_len + 1) * self.max_num_requests,
                dtype=torch.int32,
                device='cuda')
            self.top_ps = torch.ones(
                (self.max_draft_len + 1) * self.max_num_requests,
                dtype=torch.float32,
                device='cuda')

        self.temperatures[:len(temperatures)].copy_(torch.tensor(
            temperatures, dtype=torch.float32, pin_memory=True),
                                                    non_blocking=True)
        self.top_ks[:len(top_ks)].copy_(torch.tensor(top_ks,
                                                     dtype=torch.int32,
                                                     pin_memory=True),
                                        non_blocking=True)
        self.top_ps[:len(top_ps)].copy_(torch.tensor(top_ps,
                                                     dtype=torch.float32,
                                                     pin_memory=True),
                                        non_blocking=True)
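
To make the expansion concrete, here is a hypothetical, heavily simplified stand-in for the mapping above (plain dicts instead of LlmRequest objects, and a crude greedy check in place of SamplingParams.params_imply_greedy_decoding): greedy requests get sentinel values that effectively disable temperature/top-k/top-p, and each request contributes either 1 slot (context) or max_draft_len + 1 slots (generation).

# Hypothetical illustration only; not the method above.
import torch

DISABLE_TEMP_VAL = 1e-5
DISABLE_TOPK_VAL = torch.iinfo(torch.int32).max
DISABLE_TOPP_VAL = 1.0
max_draft_len = 3


def expand(requests):
    temps, top_ks, top_ps = [], [], []
    for req in requests:
        # Crude stand-in for SamplingParams.params_imply_greedy_decoding.
        is_greedy = req.get("temperature") in (None, 0.0)
        # Generation requests carry max_draft_len + 1 tokens, contexts just 1.
        n = max_draft_len + 1 if req["in_generation"] else 1
        temps += [DISABLE_TEMP_VAL if is_greedy else req["temperature"]] * n
        top_ks += [DISABLE_TOPK_VAL if is_greedy else req.get("top_k", DISABLE_TOPK_VAL)] * n
        top_ps += [DISABLE_TOPP_VAL if is_greedy else req.get("top_p", DISABLE_TOPP_VAL)] * n
    return temps, top_ks, top_ps


# One greedy context request and one non-greedy generation request:
temps, top_ks, top_ps = expand([
    {"temperature": None, "in_generation": False},
    {"temperature": 0.8, "top_k": 40, "top_p": 0.95, "in_generation": True},
])
print(temps)   # [1e-05, 0.8, 0.8, 0.8, 0.8]
print(top_ks)  # [2147483647, 40, 40, 40, 40]
print(top_ps)  # [1.0, 0.95, 0.95, 0.95, 0.95]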

View File

@ -0,0 +1,91 @@
from typing import Optional

import torch


def forward_native(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    PyTorch-native implementation of top-k and top-p sampling.

    The logits tensor may be updated in-place.
    """
    logits = apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs)


def random_sample(probs: torch.Tensor) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.
    """
    q = torch.empty_like(probs).exponential_()
    return probs.div_(q).argmax(dim=-1).view(-1)


def apply_top_k_top_p(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
    """
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
        top_k_mask = top_k_mask.clamp(min=0)
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if p is not None:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # Keep at least one token per row.
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Scatter the masked logits back to their original (unsorted) positions.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def apply_temperature(
    logits: torch.Tensor,
    temp: torch.Tensor,
) -> torch.Tensor:
    return logits.div_(temp.unsqueeze(dim=1))


@torch.compile(options={"max-autotune": True})
def sampling_batch_spec_dec_one_model(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
    top_k: torch.Tensor,
    top_p: torch.Tensor,
) -> torch.Tensor:
    """
    CUDA-graph compatible sampling. Supports mixed sampling params.

    We can't do dynamic kernel selection inside graphs, so this might
    be slower than a torch.argmax for greedy requests. This is why advanced
    sampling is opt-in for now.
    """
    logits = apply_temperature(logits, temperatures)
    random_sampled = forward_native(logits, top_k, top_p)
    return random_sampled
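
Two quick sanity checks on the ideas above, as a self-contained sketch with plain torch (toy sizes, CPU; it re-implements the relevant lines rather than importing this module):

import torch

torch.manual_seed(0)

# 1) The exponential race in random_sample: argmax(p / q) with q ~ Exp(1)
#    selects index i with probability p_i, i.e. it matches multinomial
#    sampling without the CPU-GPU sync that torch.multinomial incurs.
probs = torch.tensor([[0.1, 0.2, 0.7]])
counts = torch.zeros(3)
for _ in range(20000):
    q = torch.empty_like(probs).exponential_()
    counts[probs.div(q).argmax(dim=-1)] += 1
print(counts / counts.sum())  # roughly [0.1, 0.2, 0.7]

# 2) With the "disable" sentinels (tiny temperature, huge top-k, top-p = 1.0),
#    the softmax collapses onto the max logit, so sampling reproduces argmax.
logits = torch.randn(4, 16)
temperature = torch.full((4, ), 1e-5)
scaled = logits / temperature.unsqueeze(1)
greedy_probs = scaled.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(greedy_probs).exponential_()
sampled = greedy_probs.div(q).argmax(dim=-1)
assert torch.equal(sampled, logits.argmax(dim=-1))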

View File

@ -76,6 +76,7 @@ def get_spec_metadata(spec_config,
hidden_size=model_config.hidden_size,
max_num_tokens=max_num_tokens,
layers_to_capture=spec_config.eagle3_layers_to_capture,
allow_advanced_sampling=spec_config.allow_advanced_sampling,
)
if spec_config.spec_dec_mode.is_save_hidden_states():
if spec_config.eagle3_layers_to_capture is None:

View File

@ -619,6 +619,10 @@ class DecodingBaseConfig(StrictBaseModel):
# (N = acceptance_window) drops below this value.
acceptance_length_threshold: Optional[float] = None
# Prototype. If true, allows non-greedy sampling when speculation is used. Only applicable
# to 1-model code paths; non-greedy sampling is always enabled on 2-model paths.
allow_advanced_sampling: bool = False
# Validate acceptance controls at field level so they run on model creation
@field_validator('acceptance_window')
@classmethod
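
Reading this flag together with the fallback logic in create_py_executor, the effective behaviour is roughly the following (hypothetical summary helper, for illustration only):

# Hypothetical helper that only summarizes the gating described above.
def non_greedy_sampling_active(one_model_path: bool,
                               allow_advanced_sampling: bool,
                               is_mtp_one_model: bool = False) -> bool:
    if not one_model_path:
        return True  # 2-model paths: non-greedy sampling is always available
    if is_mtp_one_model:
        return False  # MTP 1-model: advanced sampling not supported yet
    return allow_advanced_sampling  # EAGLE3 1-model: opt-in, otherwise greedy


assert non_greedy_sampling_active(False, False)
assert not non_greedy_sampling_active(True, False)
assert non_greedy_sampling_active(True, True)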

View File

@ -4307,7 +4307,8 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
draft_len = 3
spec_config = EagleDecodingConfig(max_draft_len=draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=one_model)
eagle3_one_model=one_model,
allow_advanced_sampling=True)
max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
llm = LLM(self.MODEL_PATH,