mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[TRTLLM-7384][feat] enable rejection sampling for CDL (#7731)
Signed-off-by: linquanh <linquanh@nvidia.com>
parent 5798a12199
commit a7ea544dbe
@@ -456,6 +456,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.use_draft_model = is_draft
# Whether the request is for the first forward of the draft model.
self.py_is_first_draft = is_first_draft
self.d2t = None
self.py_draft_use_greedy_sampling = False

# Chunked logits parameters
self.py_use_chunked_generation_logits = use_chunked_generation_logits
@@ -31,7 +31,6 @@ from ..attention_backend.interface import AttentionRuntimeFeatures
from ..distributed import MPIDist, TorchDist
from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
get_spec_resource_manager)
from ..utils import _get_allow_chain_drafter
from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
create_py_executor_instance, instantiate_sampler, is_mla,
validate_feature_combination)
@@ -344,13 +343,11 @@ def create_py_executor(
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
draft_spec_config = copy.copy(spec_config)

if _get_allow_chain_drafter():
use_chain_drafter = (
guided_decoding_config is None
and draft_spec_config._allow_greedy_draft_tokens
and pytorch_backend_config.attn_backend == "TRTLLM")
else:
use_chain_drafter = False
use_chain_drafter = (
guided_decoding_config is None
and draft_spec_config._allow_chain_drafter
and draft_spec_config._allow_greedy_draft_tokens
and pytorch_backend_config.attn_backend == "TRTLLM")

logger.debug(f"USE CHAIN DRAFTER: {use_chain_drafter}")
if use_chain_drafter:
@@ -310,10 +310,15 @@ def greedy_search_sampling_batch(
softmax_indices: Optional[torch.IntTensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
next_tokens = torch.argmax(logits, dim=-1)
index_to_scatter = next_tokens
if softmax_indices is not None:
logits = logits[softmax_indices.to(logits.device, non_blocking=True)]
softmax = torch.softmax(logits, dim=-1)
return next_tokens, softmax
logits = logits[softmax_indices]
index_to_scatter = next_tokens[softmax_indices]
probs = torch.zeros_like(logits)
probs.scatter_(dim=-1,
index=index_to_scatter.unsqueeze(-1),
src=torch.ones_like(logits))
return next_tokens, probs


def get_rejected_indices(draft_probs: torch.Tensor, target_probs: torch.Tensor,
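When the drafter sampled greedily, rejection sampling on the target side still needs a draft probability tensor, and the scatter above builds it as a one-hot distribution over the argmax tokens. A minimal standalone sketch of that idea (toy logits and illustrative names, not the module's function):

    import torch

    def greedy_with_one_hot_probs(logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Greedy token per row.
        next_tokens = torch.argmax(logits, dim=-1)
        # One-hot "probabilities": 1.0 at the argmax index, 0.0 elsewhere.
        probs = torch.zeros_like(logits)
        probs.scatter_(dim=-1, index=next_tokens.unsqueeze(-1), src=torch.ones_like(logits))
        return next_tokens, probs

    tokens, probs = greedy_with_one_hot_probs(torch.randn(2, 8))
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2))

With a one-hot draft distribution, the acceptance test later on reduces to accepting the draft token with probability equal to the target model's probability for that token.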
@@ -1127,6 +1132,7 @@ class TorchSampler(Sampler):

return new_draft_tokens_host

@torch.inference_mode()
def _process_draft_tokens_rejection_sampling(
self, request: LlmRequest, new_tokens: list[list[list[int]]],
new_tokens_tensor: torch.Tensor) -> int:
@@ -1134,13 +1140,30 @@ class TorchSampler(Sampler):
# filtering of vocab_size logits, out of vocab_size in
# total. The 'sample' below should generally be avoided
# by retaining the draft_probs during drafting (TRTLLM-7772).
sampling_strategy = _request_strategy(request, vocab_size=2**31)
draft_sampling_strategy = (
"greedy", None
) if request.py_draft_use_greedy_sampling else _request_strategy(
request, vocab_size=2**31)
generator = self.get_generator(request.py_draft_logits.device)
_, draft_probs = sample(sampling_strategy,
_, draft_probs = sample(draft_sampling_strategy,
request.py_draft_logits,
generator=generator)
draft_probs = draft_probs.squeeze(0)
target_probs = request.py_target_probs
d2t = getattr(request, "d2t", None)
if d2t is not None:
vocab_d = draft_probs.shape[-1]
vocab_t = target_probs.shape[-1]
assert d2t.numel(
) == vocab_d, f"d2t size mismatch: {d2t.numel()} != {vocab_d}"
assert d2t.device == draft_probs.device, f"d2t device mismatch: {d2t.device} != {draft_probs.device}"
aligned_draft_probs = torch.zeros(
(*draft_probs.shape[:-1], vocab_t),
device=draft_probs.device,
dtype=draft_probs.dtype)
source_indices = torch.arange(vocab_d, device=draft_probs.device)
target_indices = (source_indices + d2t) % vocab_t
aligned_draft_probs[..., target_indices] = draft_probs
draft_probs = aligned_draft_probs
rejected_indices = get_rejected_indices(draft_probs, target_probs,
generator,
request.py_draft_tokens)
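The d2t branch above handles drafters whose vocabulary is a subset (or reordering) of the target vocabulary: each draft-vocab index i is mapped to target index (i + d2t[i]) % vocab_t, and the draft probabilities are scattered into a target-sized tensor before the two distributions are compared. A toy sketch of that remapping, with made-up sizes and a hypothetical offset table (illustrative names only):

    import torch

    vocab_d, vocab_t = 4, 10                          # toy vocab sizes (assumed)
    draft_probs = torch.tensor([0.1, 0.2, 0.3, 0.4])  # distribution over the draft vocab
    d2t = torch.tensor([0, 2, 5, 6])                  # hypothetical draft-to-target offsets

    source_indices = torch.arange(vocab_d)
    target_indices = (source_indices + d2t) % vocab_t  # -> tensor([0, 3, 7, 9])

    aligned = torch.zeros(vocab_t, dtype=draft_probs.dtype)
    aligned[target_indices] = draft_probs
    # Tokens the drafter can never emit keep probability 0 in the aligned tensor,
    # so the rejection test treats them as never proposed.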
@@ -1181,7 +1204,8 @@ class TorchSampler(Sampler):
new_tokens: list[list[list[int]]],
new_tokens_tensor: torch.Tensor,
resource_manager: Optional[ResourceManager] = None) -> int:
if request.py_draft_logits is None:
if _request_strategy(request, vocab_size=2**
31) == GREEDY or request.py_draft_logits is None:
spec_tree_manager = self.get_spec_tree_manager(resource_manager)
if spec_tree_manager is not None:
num_accepted = self._process_draft_tokens_tree(
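The dispatch above routes a request through rejection sampling only when the target request samples non-greedily and per-step draft logits are available; otherwise the existing greedy/tree matching path is used. get_rejected_indices applies the usual speculative-decoding acceptance rule, sketched below in isolation (standalone toy function, not the library's signature): draft token t_i is accepted with probability min(1, p_target(t_i) / p_draft(t_i)), and drafting stops at the first rejection.

    import torch

    def first_rejected_step(draft_probs: torch.Tensor,    # [num_steps, vocab]
                            target_probs: torch.Tensor,   # [num_steps, vocab]
                            draft_tokens: list[int],
                            generator: torch.Generator) -> int:
        # Returns the index of the first rejected draft token
        # (== len(draft_tokens) if every draft token is accepted).
        for i, tok in enumerate(draft_tokens):
            p_target = target_probs[i, tok]
            p_draft = draft_probs[i, tok].clamp_min(1e-20)  # avoid division by zero
            accept_prob = (p_target / p_draft).clamp(max=1.0)
            if torch.rand((), generator=generator) > accept_prob:
                return i
        return len(draft_tokens)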
@@ -116,8 +116,7 @@ class ChainDrafter(torch.nn.Module):

def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata,
**kwargs) -> torch.Tensor:
**kwargs) -> dict[str, torch.Tensor]:
logits = self.draft_model.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
@@ -126,6 +125,7 @@ class ChainDrafter(torch.nn.Module):
logits = logits[spec_metadata.gather_ids]

new_draft_tokens = [self.sample(logits)]
draft_logits = [logits]
with save_metadata_state(attn_metadata, spec_metadata):
batch_size = attn_metadata.num_seqs
@@ -139,13 +139,17 @@ class ChainDrafter(torch.nn.Module):
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
new_draft_tokens.append(self.sample(logits))
draft_logits.append(logits)
new_position_ids += 1
attn_metadata.kv_lens_cuda[:batch_size] += 1
if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
spec_metadata.hidden_states_write_indices[:batch_size])

return torch.stack(new_draft_tokens)
return {
"new_draft_tokens": torch.stack(new_draft_tokens),
"draft_logits": torch.stack(draft_logits)
}

def sample(self, logits: torch.Tensor) -> torch.Tensor:
# TODO: inject the sampler here so we can support non-greedy
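With this change the static draft loop returns the per-step logits alongside the sampled draft tokens. A rough sketch of the resulting shape contract, using toy sizes and greedy per-step sampling (illustrative only, not the ChainDrafter code path):

    import torch

    num_steps, batch, vocab = 3, 2, 8       # toy sizes
    per_step_logits = [torch.randn(batch, vocab) for _ in range(num_steps)]
    per_step_tokens = [step_logits.argmax(dim=-1) for step_logits in per_step_logits]

    outputs = {
        "new_draft_tokens": torch.stack(per_step_tokens),  # [num_steps, batch]
        "draft_logits": torch.stack(per_step_logits),      # [num_steps, batch, vocab]
    }
    assert outputs["new_draft_tokens"].shape == (num_steps, batch)
    assert outputs["draft_logits"].shape == (num_steps, batch, vocab)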
@@ -226,6 +226,12 @@ class ModelDrafter(Drafter):
ScheduledRequests: The prepared draft batch
"""
try:
for req in scheduled_requests.all_requests():
draft_model = self.draft_model_engine.model.draft_model if self.use_static_draft_loop else self.draft_model_engine.model
if hasattr(draft_model.model, "d2t"):
req.d2t = draft_model.model.d2t.data
req.py_draft_use_greedy_sampling = self.use_static_draft_loop

draft_batch = ScheduledRequests()

for request in scheduled_requests.context_requests:
@@ -526,7 +532,8 @@ class ModelDrafter(Drafter):
return draft_batch, req_id_to_old_request

def process_static_draft_outputs(
self, outputs: torch.Tensor | SampleState,
self,
outputs: dict[str, torch.Tensor] | tuple[torch.Tensor, SampleState],
draft_batch: ScheduledRequests,
req_id_to_old_request: Dict[int, LlmRequest]) -> None:
"""
@@ -537,23 +544,26 @@ class ModelDrafter(Drafter):
draft_batch: The draft batch that was processed
req_id_to_old_request: Mapping from draft request ID to original request
"""
if isinstance(outputs, torch.Tensor):
# For non-overlap scheduler path.
outputs_host = outputs.cpu()

if isinstance(outputs, dict):
draft_tokens_host = outputs["new_draft_tokens"].cpu()
draft_logits = outputs["draft_logits"]
else:
outputs_host = outputs.host.new_tokens
outputs.sampler_event.synchronize()
draft_logits = outputs[0]
draft_tokens_host = outputs[1].host.new_tokens
outputs[1].sampler_event.synchronize()

for token_idx in range(self.max_draft_tokens):
for req_idx, req in enumerate(draft_batch.all_requests()):
target_model_req = req_id_to_old_request[req.py_request_id]
if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
# Chunked prefill request in progress; no need to append draft tokens
continue

target_req = req_id_to_old_request[req.py_request_id]
target_req.py_draft_tokens.append(
outputs_host[token_idx][req_idx])
for req_idx, req in enumerate(draft_batch.all_requests()):
target_model_req = req_id_to_old_request[req.py_request_id]
if target_model_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
# Chunked prefill request in progress; no need to append draft tokens
continue
py_draft_logits = []
for token_idx in range(self.max_draft_tokens):
target_model_req.py_draft_tokens.append(
draft_tokens_host[token_idx][req_idx])
py_draft_logits.append(draft_logits[token_idx][req_idx])
target_model_req.py_draft_logits = torch.stack(py_draft_logits)

# Clean up draft resources
for req in draft_batch.all_requests():
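The rewritten loop walks requests first and draft steps second, so each target request ends up with its full list of draft tokens plus a stacked [num_steps, vocab] logits tensor that the sampler later consumes as py_draft_logits. A toy sketch of that reshaping (standalone names, not the drafter's attributes):

    import torch

    num_steps, batch, vocab = 3, 2, 8
    draft_tokens_host = torch.randint(0, vocab, (num_steps, batch))  # stacked tokens on CPU
    draft_logits = torch.randn(num_steps, batch, vocab)              # stacked per-step logits

    for req_idx in range(batch):
        tokens = [int(draft_tokens_host[step, req_idx]) for step in range(num_steps)]
        logits = torch.stack([draft_logits[step, req_idx] for step in range(num_steps)])
        # tokens  -> appended to the request's draft-token list
        # logits  -> stored as the request's [num_steps, vocab] draft logits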
@@ -706,23 +716,26 @@ class ModelDrafter(Drafter):
# Only update target inputs, cleanup will be done in executor loop
self._update_target_inputs_with_draft_tokens(
target_inputs,
outputs,
outputs["new_draft_tokens"],
draft_position=0,
draft_length=self.max_draft_tokens,
draft_batch=draft_batch,
req_id_to_old_request=req_id_to_old_request)

new_tokens_host = outputs.to(device='cpu', non_blocking=True)
new_tokens_host = outputs["new_draft_tokens"].to(device='cpu',
non_blocking=True)
sampler_event = torch.cuda.Event()
sampler_event.record()

outputs = SampleState(
sample_state = SampleState(
scheduled_requests=draft_batch,
device=SampleStateTensors(new_tokens=outputs),
device=SampleStateTensors(
new_tokens=outputs["new_draft_tokens"]),
host=SampleStateTensors(new_tokens=new_tokens_host),
sampler_event=sampler_event)

return target_inputs, outputs, draft_batch
return target_inputs, (outputs["draft_logits"],
sample_state), draft_batch

# Handle guided decoder and sampling for non-static loop
if self.guided_decoder is not None:
@@ -308,12 +308,6 @@ def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
)


# Development function to control chain drafter feature.
# It's here so that unit tests can mock it and turn it off.
def _get_allow_chain_drafter() -> bool:
return True


def get_device_uuid(device_idx: int) -> str:
"""Get the UUID of a CUDA device using torch cuda api"""
@@ -366,6 +366,8 @@ class DecodingBaseConfig(StrictBaseModel):

load_format: Optional[str] = None

# If set, drafting is allowed to use chain drafter.
_allow_chain_drafter: bool = PrivateAttr(True)
# If set, drafting uses greedy sampling, irrespective of sampling parameters.
_allow_greedy_draft_tokens: bool = PrivateAttr(True)
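Both knobs are Pydantic private attributes, so they never appear in the public config schema but can still be flipped on an instance after construction, which is what the updated unit test does with _allow_chain_drafter. A minimal sketch of that pattern with a stand-in model (assumes Pydantic v2; not the real DecodingBaseConfig):

    from pydantic import BaseModel, PrivateAttr

    class ToyDecodingConfig(BaseModel):
        max_draft_len: int = 4
        # Private attributes: excluded from the schema and serialization,
        # but still settable on a per-instance basis.
        _allow_chain_drafter: bool = PrivateAttr(True)
        _allow_greedy_draft_tokens: bool = PrivateAttr(True)

    cfg = ToyDecodingConfig()
    cfg._allow_chain_drafter = False              # e.g. force the non-chained drafter path
    assert "_allow_chain_drafter" not in cfg.model_dump()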
@@ -4,7 +4,6 @@ import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import pytest
import torch
@@ -58,11 +57,6 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool, multi_batch: bool,
attention_dp: bool, request):
# Use enforce_single_worker fixture only when use_chain_drafter is False.
# Otherwise, we can't modify the returned value of _get_allow_chain_drafter in multiprocessing.
if not use_chain_drafter:
request.getfixturevalue('enforce_single_worker')

# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
@@ -72,106 +66,94 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

# Mock _get_allow_chain_drafter to return False when use_chain_drafter is False
if not use_chain_drafter:
patch_context = patch(
'tensorrt_llm._torch.utils._get_allow_chain_drafter',
return_value=False)
else:
patch_context = patch(
'tensorrt_llm._torch.utils._get_allow_chain_drafter',
return_value=True)
# bs > 1 gives non-deterministic when doing IFB. There are slight chances
# that ref and spec does not match 100%
max_batch_size = 4 if multi_batch else 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[i for i in range(1, max_batch_size +
1)]) if use_cuda_graph else None

with patch_context:
# bs > 1 gives non-deterministic when doing IFB. There are slight chances
# that ref and spec does not match 100%
max_batch_size = 4 if multi_batch else 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[i for i in range(1, max_batch_size +
1)]) if use_cuda_graph else None
llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
enable_attention_dp=attention_dp,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
if enable_chunked_prefill:
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64

llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
enable_attention_dp=attention_dp,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
if enable_chunked_prefill:
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)
spec_config._allow_chain_drafter = use_chain_drafter

spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)
# Create the LLM instance
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

# Create the LLM instance
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

# Acceptance rate tests
if enable_chunked_prefill:
# Use a long prompt for chunked prefill tests.
prompts = [
"The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
]
tok_ids = [llm_spec.tokenizer.encode(prompts[0])]
else:
prompts = [
"The capital of France is",
"The president of the United States is",
]
tok_ids = [llm_spec.tokenizer.encode("The future of AI is")]
if multi_batch:
tok_ids.append(llm_spec.tokenizer.encode(prompts))

sampling_params = SamplingParams(max_tokens=128, temperature=0)
for i in range(len(tok_ids)):
num_tokens = 0
num_drafted = 0
num_accepted = 0

for output in llm_spec.generate_async(tok_ids[i],
sampling_params,
streaming=True):
new_tokens = output.outputs[0].token_ids
num_drafted += max_draft_len
num_accepted += len(new_tokens) - num_tokens - 1
num_tokens = len(new_tokens)

accept_rate = num_accepted / num_drafted
assert accept_rate > 0.15

# Output tests
sampling_params = SamplingParams(max_tokens=10, temperature=0)

results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [
result.outputs[0].text for result in results_spec
# Acceptance rate tests
if enable_chunked_prefill:
# Use a long prompt for chunked prefill tests.
prompts = [
"The capital of France is a city of romance, art, fashion, and cuisine. Paris is a must-visit destination for anyone who loves history, architecture, and culture. From the iconic Eiffel Tower to the world-famous Louvre Museum, Paris has something to offer for every interest and age.\nThe city is divided into 20 arrondissements, each with its own unique character and charm. The Latin Quarter is a popular area for students and young travelers, while the Champs-Élysées is a hub for shopping and dining. The Montmartre neighborhood is famous for its bohemian vibe and stunning views of the city.\nParis is also known for its beautiful parks and gardens, such as the Luxembourg Gardens and the Tuileries Garden. The city has a rich history, with landmarks like the Notre-Dame Cathedral and the Arc de Triomphe. Visitors can also explore the city's many museums, including the Musée d'Orsay and the Musée Rodin.\nIn addition to its cultural and historical attractions, Paris is also a great destination for foodies. The city is famous for its cuisine, including croissants, baguettes, and cheese. Visitors can sample the city's famous dishes at one of the many restaurants, cafes, and "
]
llm_spec.shutdown()
tok_ids = [llm_spec.tokenizer.encode(prompts[0])]
else:
prompts = [
"The capital of France is",
"The president of the United States is",
]
tok_ids = [llm_spec.tokenizer.encode("The future of AI is")]
if multi_batch:
tok_ids.append(llm_spec.tokenizer.encode(prompts))

llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()
sampling_params = SamplingParams(max_tokens=128, temperature=0)
for i in range(len(tok_ids)):
num_tokens = 0
num_drafted = 0
num_accepted = 0

for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref
for output in llm_spec.generate_async(tok_ids[i],
sampling_params,
streaming=True):
new_tokens = output.outputs[0].token_ids
num_drafted += max_draft_len
num_accepted += len(new_tokens) - num_tokens - 1
num_tokens = len(new_tokens)

accept_rate = num_accepted / num_drafted
assert accept_rate > 0.15

# Output tests
sampling_params = SamplingParams(max_tokens=10, temperature=0)

results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()

llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
llm_ref.shutdown()

for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref


def test_deepseek_eagle3():
@@ -436,5 +418,55 @@ def test_eagle3_cuda_graph_padding(disable_overlap_scheduler: bool):
llm_spec.shutdown()


@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
def test_eagle3_cdl_sampling(disable_overlap_scheduler: bool):
"""Test CDL sampling with 2 requests and max_batch_size=2."""
attn_backend = "TRTLLM"
enable_block_reuse = False
use_one_model = False
enable_chunked_prefill = False

total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")

models_path = llm_models_root()
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

max_batch_size = 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(batch_sizes=[1, 2, 4],
enable_padding=True)

llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)

spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=use_one_model,
)

# Create the LLM instance
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

prompts = ["The president of the United States is"]

sampling_params = SamplingParams(max_tokens=20, temperature=0, top_p=0.9)
llm_spec.generate(prompts, sampling_params)
llm_spec.shutdown()


if __name__ == "__main__":
unittest.main()