diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py index 5aa7f7ce70..f028d41e55 100644 --- a/examples/llm-api/quickstart_advanced.py +++ b/examples/llm-api/quickstart_advanced.py @@ -143,6 +143,9 @@ def add_llm_args(parser): default=False, action='store_true') parser.add_argument('--dynamic_tree_max_topK', type=int, default=None) + parser.add_argument('--allow_advanced_sampling', + default=False, + action='store_true') # Relaxed acceptance parser.add_argument('--use_relaxed_acceptance_for_thinking', @@ -210,7 +213,9 @@ def setup_llm(args, **kwargs): eagle3_one_model=args.use_one_model, eagle_choices=args.eagle_choices, use_dynamic_tree=args.use_dynamic_tree, - dynamic_tree_max_topK=args.dynamic_tree_max_topK) + dynamic_tree_max_topK=args.dynamic_tree_max_topK, + allow_advanced_sampling=args.allow_advanced_sampling) + elif spec_decode_algo == "DRAFT_TARGET": spec_config = DraftTargetDecodingConfig( max_draft_len=args.spec_decode_max_draft_len, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index be6ae4bf3c..5da64a5569 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -48,7 +48,8 @@ from ..speculative import (SpecMetadata, get_num_extra_kv_tokens, get_spec_metadata, update_spec_config_from_model_config) from ..speculative.drafting_loops import BaseDraftingLoopWrapper -from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata +from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata, + Eagle3ResourceManager, Eagle3SpecMetadata) from ..speculative.mtp import SampleStateTensorsMTP from ..speculative.utils import SpecDecodingTensor from ..utils import (get_model_extra_attrs, @@ -2093,6 +2094,9 @@ class PyTorchModelEngine(ModelEngine): num_accepted_draft_tokens)] if isinstance(spec_metadata, Eagle3SpecMetadata): spec_metadata.request_accepted_path = request_accepted_path + if isinstance(spec_metadata, Eagle3OneModelSpecMetadata): + spec_metadata.populate_sampling_params_for_one_model( + scheduled_requests.all_requests()) spec_metadata.prepare() inputs['spec_metadata'] = spec_metadata diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 3fc0027d63..a908ba251f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -281,6 +281,17 @@ def create_py_executor( ) llm_args.disable_overlap_scheduler = True + if spec_config is not None and spec_config.spec_dec_mode.use_one_engine(): + if not spec_config.allow_advanced_sampling: + logger.warning( + f"Falling back to greedy decoding for {spec_config.decoding_type}. If you " + "want to use non-greedy sampling, please set allow_advanced_sampling=True." + ) + elif spec_config.spec_dec_mode.is_mtp_one_model(): + logger.warning( + "Advanced sampling is not supported for MTP yet - this will be added soon." 
+ ) + if mm_encoder_only: llm_args.mm_encoder_only = True llm_args.disable_overlap_scheduler = True diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 89b1ff0ff1..18052f617c 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -14,6 +14,7 @@ from ..pyexecutor.sampler import TorchSampler from ..pyexecutor.scheduler import ScheduledRequests from .interface import SpecMetadata, get_force_num_accepted_tokens from .mtp import MTPSampler +from .one_model_sampler import sampling_batch_spec_dec_one_model from .spec_tree_manager import SpecTreeManager if TYPE_CHECKING: @@ -493,6 +494,40 @@ class Eagle3OneModelWorker(nn.Module): 'next_new_tokens': next_new_tokens, } + def _sample_tokens_for_batch( + self, + logits: torch.Tensor, + spec_metadata: Eagle3OneModelSpecMetadata, + num_contexts: int, + batch_size: int, + ) -> torch.Tensor: + """ + Sample tokens from logits using per-request sampling parameters. + Supports both greedy and non-greedy sampling. + + Args: + logits: [num_tokens, vocab_size] - Logits to sample from + spec_metadata: Metadata containing sampling parameters + batch_size: Number of requests in the batch + + Returns: + sampled_tokens: [num_tokens] - Sampled token ids + """ + if spec_metadata.allow_advanced_sampling: + num_gens = batch_size - num_contexts + num_tokens = num_contexts + num_gens * (self.max_draft_len + 1) + + temperatures = spec_metadata.temperatures[:num_tokens] + top_ks = spec_metadata.top_ks[:num_tokens] + top_ps = spec_metadata.top_ps[:num_tokens] + + sampled_tokens = sampling_batch_spec_dec_one_model( + logits, temperatures, top_ks, top_ps) + else: + sampled_tokens = torch.argmax(logits, dim=-1) + + return sampled_tokens + def sample_and_accept_draft_tokens( self, logits: torch.Tensor, @@ -514,8 +549,9 @@ class Eagle3OneModelWorker(nn.Module): dtype=torch.int, device=logits.device) - # Do greedy sampling for the input logits - target_tokens = torch.argmax(logits, dim=-1) + # Sample tokens using per-request sampling parameters + target_tokens = self._sample_tokens_for_batch(logits, spec_metadata, + num_contexts, batch_size) # context accepted_tokens[:num_contexts, 0] = target_tokens[:num_contexts] @@ -557,6 +593,9 @@ class Eagle3OneModelWorker(nn.Module): Draft token ids. Flattened. ''' + # Note: using greedy for draft tokens is a bit easier to implement and + # faster. It doesn't affect the final output and seems to have a negligible + # impact on AR. draft_tokens = torch.argmax(logits, dim=-1) # Apply d2t (offsets between draft model dictionary and main model dictionary). diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index a02640f420..9bf262b3cb 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -229,6 +229,13 @@ class SpecMetadata: # whether the spec-dec mode is a dynamic tree. is_spec_dec_dynamic_tree: bool = False + # For non-greedy sampling on 1-model. + allow_advanced_sampling: bool = False + # Sampling parameters for non-greedy sampling (per-request) + temperatures: Optional[torch.Tensor] = None + top_ks: Optional[torch.Tensor] = None + top_ps: Optional[torch.Tensor] = None + def __post_init__(self): pass @@ -264,3 +271,83 @@ class SpecMetadata: Some spec decode algorithms require hidden states from the target model. Use this method to record them. By default, does nothing. 
""" + + def populate_sampling_params_for_one_model( + self, requests: list["LlmRequest"]) -> None: + """ + Set up topp/topk/temperatures for 1-model sampler. + """ + from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState + from tensorrt_llm.sampling_params import SamplingParams + + if not self.allow_advanced_sampling or not self.spec_dec_mode.use_one_engine( + ): + return + + if self.temperatures is None: + # Ensures determinism across ranks. + torch.manual_seed(0) + + temperatures = [] + top_ks = [] + top_ps = [] + + # Need to use a very small value for temperature when disabled to avoid division by 0 + DISABLE_TEMP_VAL = 1e-5 + # Very large values disable topk. + DISABLE_TOPK_VAL = torch.iinfo(torch.int32).max + DISABLE_TOPP_VAL = 1.0 + + for request in requests: + sampling_config = request.sampling_config + temp = sampling_config.temperature + temp_val = temp[0] if temp is not None and len(temp) > 0 else None + + tk = sampling_config.top_k + tk_val = tk[0] if tk is not None and len(tk) > 0 else None + + tp = sampling_config.top_p + tp_val = tp[0] if tp is not None and len(tp) > 0 else None + + # Context requests have no draft tokens yet. + num_tokens = 1 + self.max_draft_len if request.state == LlmRequestState.GENERATION_IN_PROGRESS else 1 + + is_greedy = SamplingParams.params_imply_greedy_decoding( + temperature=temp_val, + top_k=tk_val, + top_p=tp_val, + use_beam_search=False) + + temp_val = DISABLE_TEMP_VAL if is_greedy or temp_val is None or temp_val == 0 else temp_val + tk_val = DISABLE_TOPK_VAL if is_greedy or tk_val is None or tk_val <= 0 else tk_val + tp_val = DISABLE_TOPP_VAL if is_greedy or tp_val is None else tp_val + + temperatures.extend(temp_val for _ in range(num_tokens)) + top_ks.extend(tk_val for _ in range(num_tokens)) + top_ps.extend(tp_val for _ in range(num_tokens)) + + if self.temperatures is None: + self.temperatures = torch.ones( + (self.max_draft_len + 1) * self.max_num_requests, + dtype=torch.float32, + device='cuda') + self.top_ks = torch.zeros( + (self.max_draft_len + 1) * self.max_num_requests, + dtype=torch.int32, + device='cuda') + self.top_ps = torch.ones( + (self.max_draft_len + 1) * self.max_num_requests, + dtype=torch.float32, + device='cuda') + + self.temperatures[:len(temperatures)].copy_(torch.tensor( + temperatures, dtype=torch.float32, pin_memory=True), + non_blocking=True) + self.top_ks[:len(top_ks)].copy_(torch.tensor(top_ks, + dtype=torch.int32, + pin_memory=True), + non_blocking=True) + self.top_ps[:len(top_ps)].copy_(torch.tensor(top_ps, + dtype=torch.float32, + pin_memory=True), + non_blocking=True) diff --git a/tensorrt_llm/_torch/speculative/one_model_sampler.py b/tensorrt_llm/_torch/speculative/one_model_sampler.py new file mode 100644 index 0000000000..ca48c03f28 --- /dev/null +++ b/tensorrt_llm/_torch/speculative/one_model_sampler.py @@ -0,0 +1,91 @@ +from typing import Optional + +import torch + + +def forward_native( + logits: torch.Tensor, + k: Optional[torch.Tensor], + p: Optional[torch.Tensor], +) -> torch.Tensor: + """ + PyTorch-native implementation of top-k and top-p sampling. + + The logits tensor may be updated in-place. + """ + logits = apply_top_k_top_p(logits, k, p) + probs = logits.softmax(dim=-1, dtype=torch.float32) + return random_sample(probs) + + +def random_sample( + probs: torch.Tensor, +) -> torch.Tensor: + """Randomly sample from the probabilities. + + We use this function instead of torch.multinomial because torch.multinomial + causes CPU-GPU synchronization. 
+    """
+    q = torch.empty_like(probs).exponential_()
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    If a top-p is used, this function will sort the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
+    """
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+    if k is not None:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
+        top_k_mask = top_k_mask.clamp(min=0)
+        # Gather the k-th largest logit per row to use as the cutoff.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # Always keep at least one token.
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+    # Scatter the masked logits back to their original token order.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def apply_temperature(
+    logits: torch.Tensor,
+    temp: torch.Tensor,
+) -> torch.Tensor:
+    return logits.div_(temp.unsqueeze(dim=1))
+
+
+@torch.compile(options={"max-autotune": True})
+def sampling_batch_spec_dec_one_model(
+    logits: torch.Tensor,
+    temperatures: torch.Tensor,
+    top_k: torch.Tensor,
+    top_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    CUDA-graph compatible sampling. Supports mixed sampling params.
+
+    We can't do dynamic kernel selection inside graphs, so this might be
+    slower than torch.argmax for greedy requests. This is why advanced
+    sampling is opt-in for now.
+    """
+    logits = apply_temperature(logits, temperatures)
+    random_sampled = forward_native(logits, top_k, top_p)
+    return random_sampled
diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py
index 4ef4ff4296..139787df44 100644
--- a/tensorrt_llm/_torch/speculative/utils.py
+++ b/tensorrt_llm/_torch/speculative/utils.py
@@ -76,6 +76,7 @@ def get_spec_metadata(spec_config,
             hidden_size=model_config.hidden_size,
             max_num_tokens=max_num_tokens,
             layers_to_capture=spec_config.eagle3_layers_to_capture,
+            allow_advanced_sampling=spec_config.allow_advanced_sampling,
         )
     if spec_config.spec_dec_mode.is_save_hidden_states():
         if spec_config.eagle3_layers_to_capture is None:
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index cd7858c6f4..b790dc141d 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -619,6 +619,10 @@ class DecodingBaseConfig(StrictBaseModel):
     # (N = acceptance_window) drops below this value.
     acceptance_length_threshold: Optional[float] = None
 
+    # Prototype. If true, allows non-greedy sampling when speculation is used. Only applicable
+    # to 1-model code paths; non-greedy sampling is always enabled on 2-model paths.
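+    # Note: draft-token proposals are still generated greedily even when this is enabled;
+    # the flag only changes how the target model's tokens are sampled in the 1-model path.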
+ allow_advanced_sampling: bool = False + # Validate acceptance controls at field level so they run on model creation @field_validator('acceptance_window') @classmethod diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 3b667b15c9..2f27c5dc18 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -4307,7 +4307,8 @@ class TestGPTOSS(LlmapiAccuracyTestHarness): draft_len = 3 spec_config = EagleDecodingConfig(max_draft_len=draft_len, speculative_model_dir=eagle_model_dir, - eagle3_one_model=one_model) + eagle3_one_model=one_model, + allow_advanced_sampling=True) max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN llm = LLM(self.MODEL_PATH,