[TRTLLM-7384][feat] enable rejection sampling for CDL (#7731)

Signed-off-by: linquanh <linquanh@nvidia.com>
kris1025 2025-10-12 20:38:48 +08:00 committed by GitHub
parent 5798a12199
commit a7ea544dbe
8 changed files with 212 additions and 144 deletions

View File

@@ -456,6 +456,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.use_draft_model = is_draft
# Whether the request is for the first forward of the draft model.
self.py_is_first_draft = is_first_draft
self.d2t = None
self.py_draft_use_greedy_sampling = False
# Chunked logits parameters
self.py_use_chunked_generation_logits = use_chunked_generation_logits
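The two new attributes default to "no draft-to-target vocab mapping" and "non-greedy draft sampling"; the drafter overwrites them per request before drafting (see the ModelDrafter hunk further down). A minimal sketch of that contract, using a stand-in class rather than the real LlmRequest binding:

# Sketch only: a stand-in for LlmRequest showing how the two new fields are
# meant to be used; the real class comes from the batch-manager bindings.
import torch


class _FakeRequest:

    def __init__(self):
        # Draft-to-target vocab mapping; None means draft and target models
        # share a vocabulary.
        self.d2t = None
        # True when draft tokens were produced with greedy sampling (the
        # static chain-drafter loop always samples greedily).
        self.py_draft_use_greedy_sampling = False


req = _FakeRequest()
# A drafter with a reduced draft vocabulary would attach its offset table;
# the size and dtype below are made up for the example.
req.d2t = torch.zeros(32_000, dtype=torch.long)
req.py_draft_use_greedy_sampling = True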

View File

@@ -31,7 +31,6 @@ from ..attention_backend.interface import AttentionRuntimeFeatures
from ..distributed import MPIDist, TorchDist
from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
get_spec_resource_manager)
from ..utils import _get_allow_chain_drafter
from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
create_py_executor_instance, instantiate_sampler, is_mla,
validate_feature_combination)
@@ -344,13 +343,11 @@ def create_py_executor(
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
draft_spec_config = copy.copy(spec_config)
if _get_allow_chain_drafter():
use_chain_drafter = (
guided_decoding_config is None
and draft_spec_config._allow_greedy_draft_tokens
and pytorch_backend_config.attn_backend == "TRTLLM")
else:
use_chain_drafter = False
use_chain_drafter = (
guided_decoding_config is None
and draft_spec_config._allow_chain_drafter
and draft_spec_config._allow_greedy_draft_tokens
and pytorch_backend_config.attn_backend == "TRTLLM")
logger.debug(f"USE CHAIN DRAFTER: {use_chain_drafter}")
if use_chain_drafter:
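With the _get_allow_chain_drafter helper removed (see the utils hunk below), the decision now depends only on the spec config's private attributes and the attention backend. A hedged sketch of the consolidated condition, with SimpleNamespace standing in for the real config objects:

# Sketch of the consolidated gating condition; the SimpleNamespace objects
# stand in for the real guided-decoding/spec/backend configs.
from types import SimpleNamespace


def should_use_chain_drafter(guided_decoding_config, draft_spec_config,
                             pytorch_backend_config) -> bool:
    # The chain drafter only runs when no guided decoding is requested, the
    # spec config allows it, greedy draft sampling is allowed, and the TRTLLM
    # attention backend is active.
    return (guided_decoding_config is None
            and draft_spec_config._allow_chain_drafter
            and draft_spec_config._allow_greedy_draft_tokens
            and pytorch_backend_config.attn_backend == "TRTLLM")


spec = SimpleNamespace(_allow_chain_drafter=True, _allow_greedy_draft_tokens=True)
backend = SimpleNamespace(attn_backend="TRTLLM")
assert should_use_chain_drafter(None, spec, backend)
# Tests can now flip the private attribute instead of mocking a module helper:
spec._allow_chain_drafter = False
assert not should_use_chain_drafter(None, spec, backend)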

View File

@@ -310,10 +310,15 @@ def greedy_search_sampling_batch(
softmax_indices: Optional[torch.IntTensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
next_tokens = torch.argmax(logits, dim=-1)
index_to_scatter = next_tokens
if softmax_indices is not None:
logits = logits[softmax_indices.to(logits.device, non_blocking=True)]
softmax = torch.softmax(logits, dim=-1)
return next_tokens, softmax
logits = logits[softmax_indices]
index_to_scatter = next_tokens[softmax_indices]
probs = torch.zeros_like(logits)
probs.scatter_(dim=-1,
index=index_to_scatter.unsqueeze(-1),
src=torch.ones_like(logits))
return next_tokens, probs
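Under greedy decoding the argmax token carries probability one, so the updated code materializes a one-hot distribution with scatter_ instead of taking a softmax. A standalone illustration with a made-up batch of logits:

# Standalone illustration of the one-hot construction above; the logits and
# shapes are invented for the example.
import torch

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 3.0, 0.2]])       # [batch, vocab]
next_tokens = torch.argmax(logits, dim=-1)     # tensor([0, 1])
probs = torch.zeros_like(logits)
probs.scatter_(dim=-1,
               index=next_tokens.unsqueeze(-1),
               src=torch.ones_like(logits))
print(probs)  # [[1., 0., 0.], [0., 1., 0.]] -- one-hot rows at the argmax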
def get_rejected_indices(draft_probs: torch.Tensor, target_probs: torch.Tensor,
@@ -1127,6 +1132,7 @@ class TorchSampler(Sampler):
return new_draft_tokens_host
@torch.inference_mode()
def _process_draft_tokens_rejection_sampling(
self, request: LlmRequest, new_tokens: list[list[list[int]]],
new_tokens_tensor: torch.Tensor) -> int:
@@ -1134,13 +1140,30 @@ class TorchSampler(Sampler):
# filtering of vocab_size logits, out of vocab_size in
# total. The 'sample' below should generally be avoided
# by retaining the draft_probs during drafting (TRTLLM-7772).
sampling_strategy = _request_strategy(request, vocab_size=2**31)
draft_sampling_strategy = (
"greedy", None
) if request.py_draft_use_greedy_sampling else _request_strategy(
request, vocab_size=2**31)
generator = self.get_generator(request.py_draft_logits.device)
_, draft_probs = sample(sampling_strategy,
_, draft_probs = sample(draft_sampling_strategy,
request.py_draft_logits,
generator=generator)
draft_probs = draft_probs.squeeze(0)
target_probs = request.py_target_probs
d2t = getattr(request, "d2t", None)
if d2t is not None:
vocab_d = draft_probs.shape[-1]
vocab_t = target_probs.shape[-1]
assert d2t.numel(
) == vocab_d, f"d2t size mismatch: {d2t.numel()} != {vocab_d}"
assert d2t.device == draft_probs.device, f"d2t device mismatch: {d2t.device} != {draft_probs.device}"
aligned_draft_probs = torch.zeros(
(*draft_probs.shape[:-1], vocab_t),
device=draft_probs.device,
dtype=draft_probs.dtype)
source_indices = torch.arange(vocab_d, device=draft_probs.device)
target_indices = (source_indices + d2t) % vocab_t
aligned_draft_probs[..., target_indices] = draft_probs
draft_probs = aligned_draft_probs
rejected_indices = get_rejected_indices(draft_probs, target_probs,
generator,
request.py_draft_tokens)
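When the draft model has a reduced vocabulary, d2t[i] is the offset that maps draft-vocab index i to its target-vocab id, so the draft probabilities must be re-indexed into target-vocab space before they can be compared against py_target_probs. A self-contained toy version of that alignment (the vocab sizes and offset table are invented):

# Toy version of the d2t alignment performed above; all numbers are made up.
import torch

vocab_d, vocab_t = 4, 10
draft_probs = torch.tensor([[0.7, 0.1, 0.1, 0.1]])   # [steps, vocab_d]
d2t = torch.tensor([0, 2, 5, 6])                     # draft id i maps to target id i + d2t[i]

aligned = torch.zeros((*draft_probs.shape[:-1], vocab_t), dtype=draft_probs.dtype)
source_indices = torch.arange(vocab_d)
target_indices = (source_indices + d2t) % vocab_t    # tensor([0, 3, 7, 9])
aligned[..., target_indices] = draft_probs
print(aligned)  # probability mass now sits at the mapped target-vocab slots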
@@ -1181,7 +1204,8 @@ class TorchSampler(Sampler):
new_tokens: list[list[list[int]]],
new_tokens_tensor: torch.Tensor,
resource_manager: Optional[ResourceManager] = None) -> int:
if request.py_draft_logits is None:
if _request_strategy(request, vocab_size=2**
31) == GREEDY or request.py_draft_logits is None:
spec_tree_manager = self.get_spec_tree_manager(resource_manager)
if spec_tree_manager is not None:
num_accepted = self._process_draft_tokens_tree(

View File

@@ -116,8 +116,7 @@ class ChainDrafter(torch.nn.Module):
def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata,
**kwargs) -> torch.Tensor:
**kwargs) -> dict[str, torch.Tensor]:
logits = self.draft_model.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
@@ -126,6 +125,7 @@ class ChainDrafter(torch.nn.Module):
logits = logits[spec_metadata.gather_ids]
new_draft_tokens = [self.sample(logits)]
draft_logits = [logits]
with save_metadata_state(attn_metadata, spec_metadata):
batch_size = attn_metadata.num_seqs
@@ -139,13 +139,17 @@ class ChainDrafter(torch.nn.Module):
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
new_draft_tokens.append(self.sample(logits))
draft_logits.append(logits)
new_position_ids += 1
attn_metadata.kv_lens_cuda[:batch_size] += 1
if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
spec_metadata.hidden_states_write_indices[:batch_size])
return torch.stack(new_draft_tokens)
return {
"new_draft_tokens": torch.stack(new_draft_tokens),
"draft_logits": torch.stack(draft_logits)
}
def sample(self, logits: torch.Tensor) -> torch.Tensor:
# TODO: inject the sampler here so we can support non-greedy
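The forward pass now returns a dict carrying both the stacked draft tokens and the logits that produced them, so the target-side sampler can later run rejection sampling against those logits. A shape-level sketch of that contract (the dummy forward pass and sizes are invented):

# Minimal sketch of the new return contract: per-step tokens and logits are
# collected and stacked together.
import torch

max_draft_len, batch_size, vocab_size = 4, 2, 16
new_draft_tokens, draft_logits = [], []
for _ in range(max_draft_len):
    step_logits = torch.randn(batch_size, vocab_size)   # stand-in for a forward pass
    draft_logits.append(step_logits)
    new_draft_tokens.append(torch.argmax(step_logits, dim=-1))  # greedy sample
outputs = {
    "new_draft_tokens": torch.stack(new_draft_tokens),   # [max_draft_len, batch]
    "draft_logits": torch.stack(draft_logits),            # [max_draft_len, batch, vocab]
}
assert outputs["new_draft_tokens"].shape == (max_draft_len, batch_size)
assert outputs["draft_logits"].shape == (max_draft_len, batch_size, vocab_size)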

View File

@@ -226,6 +226,12 @@ class ModelDrafter(Drafter):
ScheduledRequests: The prepared draft batch
"""
try:
for req in scheduled_requests.all_requests():
draft_model = self.draft_model_engine.model.draft_model if self.use_static_draft_loop else self.draft_model_engine.model
if hasattr(draft_model.model, "d2t"):
req.d2t = draft_model.model.d2t.data
req.py_draft_use_greedy_sampling = self.use_static_draft_loop
draft_batch = ScheduledRequests()
for request in scheduled_requests.context_requests:
@@ -526,7 +532,8 @@ class ModelDrafter(Drafter):
return draft_batch, req_id_to_old_request
def process_static_draft_outputs(
self, outputs: torch.Tensor | SampleState,
self,
outputs: dict[str, torch.Tensor] | tuple[torch.Tensor, SampleState],
draft_batch: ScheduledRequests,
req_id_to_old_request: Dict[int, LlmRequest]) -> None:
"""
@@ -537,23 +544,26 @@ class ModelDrafter(Drafter):
draft_batch: The draft batch that was processed
req_id_to_old_request: Mapping from draft request ID to original request
"""
if isinstance(outputs, torch.Tensor):
# For non-overlap scheduler path.
outputs_host = outputs.cpu()
if isinstance(outputs, dict):
draft_tokens_host = outputs["new_draft_tokens"].cpu()
draft_logits = outputs["draft_logits"]
else:
outputs_host = outputs.host.new_tokens
outputs.sampler_event.synchronize()
draft_logits = outputs[0]
draft_tokens_host = outputs[1].host.new_tokens
outputs[1].sampler_event.synchronize()
for token_idx in range(self.max_draft_tokens):
for req_idx, req in enumerate(draft_batch.all_requests()):
target_model_req = req_id_to_old_request[req.py_request_id]
if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
# Chunked prefill request in progress; no need to append draft tokens
continue
target_req = req_id_to_old_request[req.py_request_id]
target_req.py_draft_tokens.append(
outputs_host[token_idx][req_idx])
for req_idx, req in enumerate(draft_batch.all_requests()):
target_model_req = req_id_to_old_request[req.py_request_id]
if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
# Chunked prefill request in progress; no need to append draft tokens
continue
py_draft_logits = []
for token_idx in range(self.max_draft_tokens):
target_model_req.py_draft_tokens.append(
draft_tokens_host[token_idx][req_idx])
py_draft_logits.append(draft_logits[token_idx][req_idx])
target_model_req.py_draft_logits = torch.stack(py_draft_logits)
# Clean up draft resources
for req in draft_batch.all_requests():
@@ -706,23 +716,26 @@ class ModelDrafter(Drafter):
# Only update target inputs, cleanup will be done in executor loop
self._update_target_inputs_with_draft_tokens(
target_inputs,
outputs,
outputs["new_draft_tokens"],
draft_position=0,
draft_length=self.max_draft_tokens,
draft_batch=draft_batch,
req_id_to_old_request=req_id_to_old_request)
new_tokens_host = outputs.to(device='cpu', non_blocking=True)
new_tokens_host = outputs["new_draft_tokens"].to(device='cpu',
non_blocking=True)
sampler_event = torch.cuda.Event()
sampler_event.record()
outputs = SampleState(
sample_state = SampleState(
scheduled_requests=draft_batch,
device=SampleStateTensors(new_tokens=outputs),
device=SampleStateTensors(
new_tokens=outputs["new_draft_tokens"]),
host=SampleStateTensors(new_tokens=new_tokens_host),
sampler_event=sampler_event)
return target_inputs, outputs, draft_batch
return target_inputs, (outputs["draft_logits"],
sample_state), draft_batch
# Handle guided decoder and sampling for non-static loop
if self.guided_decoder is not None:
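On the non-overlap path the outputs dict is read on the host directly; on the overlap path the tokens are copied to the host asynchronously and paired with a CUDA event that the consumer synchronizes on before reading. A CPU-only sketch of the per-request unpacking, with stand-in request objects instead of the real LlmRequest:

# CPU-only sketch of the per-request unpacking; the request class and the
# outputs dict are fabricated for the example.
import torch


class _FakeTargetRequest:

    def __init__(self):
        self.py_draft_tokens = []
        self.py_draft_logits = None


max_draft_len, batch_size, vocab_size = 4, 2, 16
outputs = {
    "new_draft_tokens": torch.randint(0, vocab_size, (max_draft_len, batch_size)),
    "draft_logits": torch.randn(max_draft_len, batch_size, vocab_size),
}
draft_tokens_host = outputs["new_draft_tokens"].cpu()
requests = [_FakeTargetRequest() for _ in range(batch_size)]

for req_idx, req in enumerate(requests):
    per_step_logits = []
    for token_idx in range(max_draft_len):
        req.py_draft_tokens.append(int(draft_tokens_host[token_idx][req_idx]))
        per_step_logits.append(outputs["draft_logits"][token_idx][req_idx])
    # Stack the per-step logits so rejection sampling can consume them later.
    req.py_draft_logits = torch.stack(per_step_logits)

assert requests[0].py_draft_logits.shape == (max_draft_len, vocab_size)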

View File

@@ -308,12 +308,6 @@ def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
)
# Development function to control chain drafter feature.
# It's here so that unit tests can mock it and turn it off.
def _get_allow_chain_drafter() -> bool:
return True
def get_device_uuid(device_idx: int) -> str:
"""Get the UUID of a CUDA device using torch cuda api"""

View File

@@ -366,6 +366,8 @@ class DecodingBaseConfig(StrictBaseModel):
load_format: Optional[str] = None
# If set, drafting is allowed to use chain drafter.
_allow_chain_drafter: bool = PrivateAttr(True)
# If set, drafting uses greedy sampling, irrespective of sampling parameters.
_allow_greedy_draft_tokens: bool = PrivateAttr(True)
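Both switches are Pydantic private attributes, so they never appear in the serialized config and are flipped directly on an instance (as the updated Eagle3 test does with spec_config._allow_chain_drafter). A stand-in sketch of the pattern, not the real DecodingBaseConfig:

# Stand-in config class demonstrating the PrivateAttr pattern.
from pydantic import BaseModel, PrivateAttr


class _SketchDecodingConfig(BaseModel):
    max_draft_len: int = 4
    # If set, drafting is allowed to use the chain drafter.
    _allow_chain_drafter: bool = PrivateAttr(True)
    # If set, drafting uses greedy sampling, irrespective of sampling parameters.
    _allow_greedy_draft_tokens: bool = PrivateAttr(True)


cfg = _SketchDecodingConfig()
assert "_allow_chain_drafter" not in cfg.model_dump()  # private, not serialized
# Tests flip the private attributes directly on the instance:
cfg._allow_chain_drafter = False
cfg._allow_greedy_draft_tokens = False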

View File

@@ -4,7 +4,6 @@ import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import pytest
import torch
@@ -58,11 +57,6 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool, multi_batch: bool,
attention_dp: bool, request):
# Use enforce_single_worker fixture only when use_chain_drafter is False.
# Otherwise, we can't modify the returned value of _get_allow_chain_drafter in multiprocessing.
if not use_chain_drafter:
request.getfixturevalue('enforce_single_worker')
# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
@@ -72,106 +66,94 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
# Mock _get_allow_chain_drafter to return False when use_chain_drafter is False
if not use_chain_drafter:
patch_context = patch(
'tensorrt_llm._torch.utils._get_allow_chain_drafter',
return_value=False)
else:
patch_context = patch(
'tensorrt_llm._torch.utils._get_allow_chain_drafter',
return_value=True)
# bs > 1 gives non-deterministic results when doing IFB. There is a slight chance
# that ref and spec do not match 100%.
max_batch_size = 4 if multi_batch else 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[i for i in range(1, max_batch_size +
1)]) if use_cuda_graph else None
with patch_context:
# bs > 1 gives non-deterministic results when doing IFB. There is a slight chance
# that ref and spec do not match 100%.
max_batch_size = 4 if multi_batch else 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[i for i in range(1, max_batch_size +
1)]) if use_cuda_graph else None
llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
enable_attention_dp=attention_dp,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
if enable_chunked_prefill:
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64
llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
enable_attention_dp=attention_dp,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
if enable_chunked_prefill:
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)
spec_config._allow_chain_drafter = use_chain_drafter
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)
# Create the LLM instance
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
# Create the LLM instance
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
# Acceptance rate tests
if enable_chunked_prefill:
# Use a long prompt for chunked prefill tests.
prompts = [
"The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
]
tok_ids = [llm_spec.tokenizer.encode(prompts[0])]
else:
prompts = [
"The capital of France is",
"The president of the United States is",
]
tok_ids = [llm_spec.tokenizer.encode("The future of AI is")]
if multi_batch:
tok_ids.append(llm_spec.tokenizer.encode(prompts))
sampling_params = SamplingParams(max_tokens=128, temperature=0)
for i in range(len(tok_ids)):
num_tokens = 0
num_drafted = 0
num_accepted = 0
for output in llm_spec.generate_async(tok_ids[i],
sampling_params,
streaming=True):
new_tokens = output.outputs[0].token_ids
num_drafted += max_draft_len
num_accepted += len(new_tokens) - num_tokens - 1
num_tokens = len(new_tokens)
accept_rate = num_accepted / num_drafted
assert accept_rate > 0.15
# Output tests
sampling_params = SamplingParams(max_tokens=10, temperature=0)
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [
result.outputs[0].text for result in results_spec
# Acceptance rate tests
if enable_chunked_prefill:
# Use a long prompt for chunked prefill tests.
prompts = [
"The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
]
llm_spec.shutdown()
tok_ids = [llm_spec.tokenizer.encode(prompts[0])]
else:
prompts = [
"The capital of France is",
"The president of the United States is",
]
tok_ids = [llm_spec.tokenizer.encode("The future of AI is")]
if multi_batch:
tok_ids.append(llm_spec.tokenizer.encode(prompts))
llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()
sampling_params = SamplingParams(max_tokens=128, temperature=0)
for i in range(len(tok_ids)):
num_tokens = 0
num_drafted = 0
num_accepted = 0
for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref
for output in llm_spec.generate_async(tok_ids[i],
sampling_params,
streaming=True):
new_tokens = output.outputs[0].token_ids
num_drafted += max_draft_len
num_accepted += len(new_tokens) - num_tokens - 1
num_tokens = len(new_tokens)
accept_rate = num_accepted / num_drafted
assert accept_rate > 0.15
# Output tests
sampling_params = SamplingParams(max_tokens=10, temperature=0)
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()
llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()
for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref
def test_deepseek_eagle3():
@@ -436,5 +418,55 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
llm_spec.shutdown()
@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool):
"""Test CDL sampling with 2 requests and max_batch_size=2."""
attn_backend = "TRTLLM"
enable_block_reuse = False
use_one_model = False
enable_chunked_prefill = False
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")
models_path = llm_models_root()
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
max_batch_size = 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(batch_sizes=[1, 2, 4],
enable_padding=True)
llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=use_one_model,
)
# Create the LLM instance
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
prompts = ["The president of the United States is"]
sampling_params = SamplingParams(max_tokens=20, temperature=0, top_p=0.9)
llm_spec.generate(prompts, sampling_params)
llm_spec.shutdown()
if __name__ == "__main__":
unittest.main()