Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5517404][fix] Use the correct cuda graph for dynamic spec dec (#7728)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
This commit is contained in:
parent 4509d97780
commit 897c4dd23b
@@ -40,9 +40,6 @@ class CUDAGraphRunner:
self.max_beam_width = engine.max_beam_width
self.spec_config = engine.spec_config

self.max_possible_draft_len = (self.spec_config.max_draft_len
if self.enable_spec_decode else 0)

self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
self.graph_outputs: Dict[Tuple[int, int],
Callable[[], Optional[torch.Tensor]]] = {}
@@ -58,7 +55,7 @@ class CUDAGraphRunner:
"""Allocates static tensors sized for the largest possible batch."""
engine = self._get_engine()

token_per_request = self.max_possible_draft_len + 1
token_per_request = self.draft_len + 1
max_total_tokens = (self.max_supported_batch_size *
self.max_beam_width * token_per_request)
max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
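Context for the change above: the runner keys its captured graphs and sizes its static buffers by tokens per request, and after this commit that count follows the draft length actually in effect for the batch (0 when speculation is dynamically switched off) rather than the configured maximum. A minimal sketch of that keying, using hypothetical helper names (the real runner stores graphs in the Dict[Tuple[int, int], torch.cuda.CUDAGraph] shown in the first hunk; nothing else here is the actual API):

```python
# Illustrative sketch only -- not TensorRT-LLM's CUDAGraphRunner API.
from typing import Dict, Tuple

import torch


def graph_key(batch_size: int, spec_enabled: bool, max_draft_len: int) -> Tuple[int, int]:
    # With dynamic speculation the per-request token count changes between
    # iterations, so the graph must be looked up by the *active* draft length.
    draft_len = max_draft_len if spec_enabled else 0
    return (batch_size, draft_len)


def tokens_for_capture(batch_size: int, beam_width: int, draft_len: int) -> int:
    # Mirrors the sizing above: one new token plus the active draft tokens
    # per request, across all beams in the batch.
    return batch_size * beam_width * (draft_len + 1)


graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
key = graph_key(batch_size=8, spec_enabled=False, max_draft_len=4)
assert key == (8, 0) and tokens_for_capture(8, 1, key[1]) == 8
```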
@@ -78,7 +75,7 @@ class CUDAGraphRunner:

@property
def enable_spec_decode(self):
return self._get_engine().is_spec_decode
return self._get_engine().enable_spec_decode

@property
def draft_len(self):
@@ -174,7 +171,7 @@ class CUDAGraphRunner:
# [CUDA graph spec decode padding]
# We pad input IDs/position IDs to the maximum draft length (token per request).
# We're forced to do this because we cannot reallocate inputs over many graph runs.
token_per_request = self.max_possible_draft_len + 1
token_per_request = self.draft_len + 1
num_tokens_for_capture = (batch_size * self.max_beam_width *
token_per_request)
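The padding comment above is the central constraint: a captured CUDA graph replays against the same input buffers it was captured with, so smaller batches must be copied (padded) into those fixed-size tensors rather than handed fresh allocations. A self-contained sketch of that capture-and-replay pattern, with made-up sizes and a trivial stand-in for the forward pass (requires a CUDA device; this is not the runner's actual code):

```python
# Standalone illustration of the constraint described in the comment above
# (fixed input buffers across graph replays). Sizes and the "model" are made up.
import torch

assert torch.cuda.is_available(), "CUDA graphs need a GPU"

capture_tokens = 8  # e.g. batch_size * beam_width * (draft_len + 1) at capture time
static_input = torch.zeros(capture_tokens, dtype=torch.int64, device="cuda")
static_output = torch.empty_like(static_input)


def forward(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # stand-in for the real model forward


# Warm up on a side stream, then capture one graph for this fixed token count.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    forward(static_input)
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output.copy_(forward(static_input))


def run(tokens: torch.Tensor) -> torch.Tensor:
    # Inputs cannot be reallocated between replays, so pad/copy the new tokens
    # into the buffer the graph was captured with, then replay.
    static_input.zero_()
    static_input[: tokens.numel()].copy_(tokens)
    graph.replay()
    return static_output[: tokens.numel()].clone()


print(run(torch.arange(5, device="cuda", dtype=torch.int64)))  # tensor([0, 2, 4, 6, 8])
```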

@@ -1511,7 +1511,6 @@ class PyTorchModelEngine(ModelEngine):
prompt_lengths.append(1 + self.runtime_draft_len)
else:
prompt_lengths.append(request.py_prompt_len)

for request in generation_requests:
request_ids.append(request.py_request_id)
beam_width = request.sampling_config.beam_width
@@ -1534,7 +1533,6 @@ class PyTorchModelEngine(ModelEngine):
if beam == first_beam:
previous_batch_indices.append(request.py_batch_idx)
past_seen_token_num = request.max_beam_num_tokens

position_ids.append(past_seen_token_num)
num_cached_tokens_per_seq.append(past_seen_token_num)
prompt_lengths.append(request.py_prompt_len)

@@ -1198,13 +1198,18 @@ class PyExecutor:
previous_tensors = self.previous_batch and self.previous_batch.sample_state
target_inputs = None
draft_outputs = None
if self.drafter is not None and self.use_spec_decode:
# If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
# When there are accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
# so we'll set the target model's input to None and skip updating the target requests after target model forward.
use_previous_draft_tokens = self.has_previous_draft_tokens
if self.drafter is not None and (self.use_spec_decode or
use_previous_draft_tokens):
target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
scheduled_batch, previous_tensors)

# Use the draft_model's outputs if we've launched the draft model.
# Otherwise, use the previous batch's outputs.
if target_inputs is not None:
if target_inputs is not None or use_previous_draft_tokens:
previous_tensors_device = target_inputs
else:
previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device
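Restated outside the diff, the selection above boils down to the following helper-shaped paraphrase (illustrative only, not a function in the codebase):

```python
def pick_previous_device_tensors(target_inputs, use_previous_draft_tokens, previous_batch):
    # Paraphrase of the hunk above. If the draft model ran this iteration, or
    # previously issued draft tokens were just accepted into the target
    # requests, the previous batch's sampled outputs can no longer be fed
    # straight back to the target model; use the drafter-provided inputs
    # (which may be None, meaning "skip the previous-tensor path").
    if target_inputs is not None or use_previous_draft_tokens:
        return target_inputs
    # Otherwise overlap with the previous target batch as before.
    return (previous_batch and previous_batch.sample_state
            and previous_batch.sample_state.device)
```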
@@ -1215,7 +1220,7 @@ class PyExecutor:
if target_inputs is not None:
self._process_draft_results(scheduled_batch,
draft_outputs, draft_batch)
elif self.previous_batch is not None:
elif self.previous_batch is not None and not use_previous_draft_tokens:
self._update_requests(self.previous_batch.sample_state)

if self.guided_decoder is not None:
@@ -1968,19 +1973,21 @@ class PyExecutor:
self.inflight_req_ids.erase(req.request_id)

def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
with request_context(is_draft=True, scheduled_requests=scheduled_batch):
with request_context(is_draft=self.draft_model_engine is not None,
scheduled_requests=scheduled_batch):
# Do an early check to see if we need to forward the draft model.
# If needed, the overlap should happen between the target requests and the draft requests.
# Otherwise, we can still do overlap between the previous target requests and the current target requests.
has_draft_batch = (
self.previous_batch is not None
self.previous_batch is not None and self.use_spec_decode
and self.drafter.should_forward_draft_model(scheduled_batch))

if has_draft_batch:
if has_draft_batch or self.has_previous_draft_tokens:
self._update_requests(self.previous_batch.sample_state)
if self.has_previous_draft_tokens:
self._prepare_draft_requests()

if has_draft_batch:
target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
scheduled_batch, self.resource_manager,
previous_tensors.device if previous_tensors else None)
@@ -1997,26 +2004,27 @@ class PyExecutor:
"""
Append the draft tokens to the target requests, and clean up the draft resources.
"""
req_id_to_old_request = {
req.py_request_id: req
for req in scheduled_batch.all_requests()
}
with request_context(is_draft=self.draft_model_engine is not None,
scheduled_requests=scheduled_batch):
req_id_to_old_request = {
req.py_request_id: req
for req in scheduled_batch.all_requests()
}

if self.drafter.use_static_draft_loop:
self.drafter.process_static_draft_outputs(draft_outputs,
draft_batch,
req_id_to_old_request)
elif draft_outputs is not None:
self.drafter.process_dynamic_draft_outputs(draft_outputs,
req_id_to_old_request)
if self.drafter.use_static_draft_loop:
self.drafter.process_static_draft_outputs(
draft_outputs, draft_batch, req_id_to_old_request)
elif draft_outputs is not None:
self.drafter.process_dynamic_draft_outputs(
draft_outputs, req_id_to_old_request)

# Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
# add_batch must be called again to restore the target requests with updated draft tokens.
if self.guided_decoder is not None:
self.guided_decoder.add_batch(scheduled_batch)
if hasattr(self.drafter, "guided_decoder"):
self.guided_decoder.rollback_draft_tokens()
# Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
# add_batch must be called again to restore the target requests with updated draft tokens.
if self.guided_decoder is not None:
self.guided_decoder.add_batch(scheduled_batch)
if hasattr(self.drafter, "guided_decoder"):
self.guided_decoder.rollback_draft_tokens()


class DisaggPPTerminationHandler:
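For reference, the "pad draft tokens" step mentioned above is what keeps every request at a fixed draft length so the captured graph always sees the same shapes. A rough sketch of what such padding amounts to, assuming a per-request py_draft_tokens list and 0 as the pad token (both are assumptions for illustration, not the actual Drafter implementation):

```python
def pad_draft_tokens_sketch(requests, max_draft_len: int, pad_token: int = 0) -> None:
    # Hypothetical stand-in for Drafter.pad_draft_tokens_for_cuda_graph: after
    # this, every request carries exactly max_draft_len draft tokens, so the
    # replayed CUDA graph sees a constant tokens-per-request count.
    for req in requests:
        draft = list(getattr(req, "py_draft_tokens", None) or [])
        draft.extend([pad_token] * (max_draft_len - len(draft)))
        req.py_draft_tokens = draft
```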

@@ -41,11 +41,10 @@ from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor


# Development flag to control chain drafter feature
# Development function to control chain drafter feature.
# It's here so that unit tests can mock it and turn it off.
def _get_allow_chain_drafter() -> bool:
"""Get the chain drafter flag from environment variable."""
# Use environment variable for cross-process compatibility
return os.getenv("TRTLLM_ALLOW_CHAIN_DRAFTER", "0") == "1"
return True


class _ExecutorCreationStage(enum.Enum):
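Since the flag is now a plain module-level function rather than an environment-variable read, a unit test running in a single process can switch the chain drafter off by patching it, which is exactly what the eagle3 test below does. A minimal usage sketch:

```python
from unittest.mock import patch

# Disable the chain drafter for the duration of the block (single-process only;
# a spawned worker process would not see the in-process patch).
with patch(
        'tensorrt_llm._torch.pyexecutor.py_executor_creator._get_allow_chain_drafter',
        return_value=False):
    ...  # construct the LLM / executor under test here
```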

@@ -563,6 +563,9 @@ class TorchSampler(Sampler):
if get_draft_token_length(req) > 0:
req.py_num_accepted_draft_tokens = num_accepted
req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
else:
req.py_num_accepted_draft_tokens = 0
req.py_rewind_len = 0
processed += num_accepted
self.handle_logprobs(req, state, beam=self.BEAM, count=processed)
req.py_decoding_iter += 1
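A worked instance of the new bookkeeping above, with made-up numbers: if pages for 4 draft tokens were allocated and the sampler accepted 2 of them, the request records 2 accepted tokens and rewinds the remaining 2.

```python
# Illustrative numbers only.
py_draft_pages_allocated = 4
num_accepted = 2

py_num_accepted_draft_tokens = num_accepted                 # 2 draft tokens kept
py_rewind_len = py_draft_pages_allocated - num_accepted     # 2 positions rolled back
assert (py_num_accepted_draft_tokens, py_rewind_len) == (2, 2)
```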

@@ -396,6 +396,7 @@ class ModelDrafter(Drafter):
new_tokens_lens = None
next_draft_tokens = None
has_draft_tokens = False
batch_size = new_tokens.shape[1]
# Iterate through generation requests and copy tokens based on accepted draft tokens
for request in scheduled_batch.all_requests():
idx = request.py_seq_slot
@@ -411,9 +412,8 @@ class ModelDrafter(Drafter):

if has_draft_tokens:
# We already updated the target state, so the new_tokens_lens should be all ones.
new_tokens_lens = torch.ones(scheduled_batch.batch_size,
device=device)
next_draft_tokens = torch.zeros(scheduled_batch.batch_size,
new_tokens_lens = torch.ones(batch_size, device=device)
next_draft_tokens = torch.zeros(batch_size,
self.max_draft_tokens,
device=device)

@@ -438,15 +438,15 @@ class ModelDrafter(Drafter):
Update target inputs with new draft tokens from sample state.
"""
if draft_tensors is not None:
for request in draft_batch.all_requests():
for req_idx, request in enumerate(draft_batch.all_requests()):
# Skip prefill requests
if target_inputs.next_draft_tokens is None:
continue

# Get the index of the draft/target tokens in the device tensor
draft_idx = request.py_seq_slot
draft_idx = req_idx if self.use_static_draft_loop else request.py_batch_idx
target_idx = req_id_to_old_request[
request.py_request_id].py_seq_slot
request.py_request_id].py_batch_idx
target_inputs.new_tokens[draft_position + 1:draft_position +
draft_length + 1, target_idx,
0] = draft_tensors[0:draft_length,

@@ -188,6 +188,7 @@ def create_mock_engine(batch_size: int):
max_beam_width=1,
max_num_tokens=8192,
is_spec_decode=False,
enable_spec_decode=False,
spec_config=None,
_cuda_graph_mem_pool=None,
use_mrope=False,

@@ -1,22 +1,32 @@
import os
import sys
import unittest
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
import torch
from utils.llm_data import llm_models_root

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


@pytest.fixture(scope="function")
def enforce_single_worker(monkeypatch):
monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
yield


@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
@pytest.mark.high_cuda_memory
def test_dynamic_spec_decode(disable_overlap_scheduler: bool):
def test_dynamic_spec_decode(enforce_single_worker,
disable_overlap_scheduler: bool):
# mock_should_use_spec_decode doesn't work with multiple processes,
# so we use the enforce_single_worker fixture to set the environment variable.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
pytest.skip("Not enough memory to load target + draft model")
@@ -51,32 +61,42 @@ def test_dynamic_spec_decode(disable_overlap_scheduler: bool):
eagle3_one_model=False,
)

# Mock should_use_spec_decode to return True for first two calls, then False
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
sampling_params = SamplingParams(max_tokens=128, temperature=0)

# Output tests
prompts = [
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=20, temperature=0)

# Mock should_use_spec_decode to turn on/off spec decode dynamically.
def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens,
max_draft_len):
if not hasattr(mock_should_use_spec_decode, 'call_count'):
mock_should_use_spec_decode.call_count = 0
mock_should_use_spec_decode.call_count += 1
return mock_should_use_spec_decode.call_count <= 2
for req in requests:
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
continue

mock_should_use_spec_decode.call_count += 1
# Turn off spec decode after it has been called 5 times.
# In the current case, at the 5th call there are 2 accepted draft tokens,
# which gives better coverage of switching spec decode on and off.
if mock_should_use_spec_decode.call_count > 5:
return False
return True

# Create a Mock object with the mock function as side_effect
mock_should_use_spec_decode = Mock(side_effect=mock_should_use_spec_decode)
# Reset mock state before using it
mock_should_use_spec_decode.reset_mock()
mock_should_use_spec_decode.call_count = 0

with patch(
'tensorrt_llm._torch.speculative.model_drafter.ModelDrafter.should_use_spec_decode',
side_effect=mock_should_use_spec_decode):
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
sampling_params = SamplingParams(max_tokens=128, temperature=0)

# Output tests
prompts = [
"The capital of France is",
"The president of the United States is",
]
sampling_params = SamplingParams(max_tokens=10, temperature=0)

mock_should_use_spec_decode):
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [
result.outputs[0].text for result in results_spec
]
llm_spec.shutdown()
generated_text_spec = [result.outputs[0].text for result in results_spec]
llm_spec.shutdown()

llm_ref = LLM(**llm_common_config)
results_ref = llm_ref.generate(prompts, sampling_params)

@@ -4,6 +4,7 @@ import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import pytest
import torch
@@ -16,35 +17,50 @@ from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


@pytest.fixture(scope="function")
def enforce_single_worker(monkeypatch):
monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
yield


@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter",
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch",
[
[True, "TRTLLM", True, False, False, False, True],
[True, "TRTLLM", True, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True],
[False, "TRTLLM", True, False, False, False, False],
[True, "FLASHINFER", True, False, False, False, True],
[False, "FLASHINFER", True, False, False, False, True],
[False, "TRTLLM", False, True, True, False, True],
[True, "TRTLLM", False, True, True, False, True],
[True, "TRTLLM", True, False, True, True, True],
[True, "TRTLLM", True, False, True, False, True],
[True, "TRTLLM", True, False, False, False, True, False],
[True, "TRTLLM", True, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False],
[False, "TRTLLM", True, False, False, False, False, False],
[True, "FLASHINFER", True, False, False, False, True, False],
[False, "FLASHINFER", True, False, False, False, True, False],
[False, "TRTLLM", False, True, True, False, True, False],
[True, "TRTLLM", False, True, True, False, True, False],
[True, "TRTLLM", True, False, True, True, True, False],
[True, "TRTLLM", True, False, True, False, True, False],
# TODO: nvbugs/5461761
# [True, "TRTLLM", True, False, False, True, True],
[True, "TRTLLM", False, False, False, False, True],
[False, "TRTLLM", False, False, False, False, True],
[True, "TRTLLM", False, False, False, False, False],
[False, "TRTLLM", False, False, False, False, False],
[True, "TRTLLM", False, False, False, True, True],
[True, "TRTLLM", False, False, False, True, False],
[True, "FLASHINFER", False, False, False, False, True],
[False, "FLASHINFER", False, False, False, False, True],
# [True, "TRTLLM", True, False, False, True, True, False],
[True, "TRTLLM", False, False, False, False, True, False],
[False, "TRTLLM", False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, False, True],
[False, "TRTLLM", False, False, False, False, False, True],
[True, "TRTLLM", False, False, False, False, True, True],
[False, "TRTLLM", False, False, False, False, True, True],
[True, "TRTLLM", False, False, False, False, False, False],
[False, "TRTLLM", False, False, False, False, False, False],
[True, "TRTLLM", False, False, False, True, True, False],
[True, "TRTLLM", False, False, False, True, False, False],
[True, "FLASHINFER", False, False, False, False, True, False],
[False, "FLASHINFER", False, False, False, False, True, False],
])
@pytest.mark.high_cuda_memory
def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
disable_overlap_scheduler: bool, enable_block_reuse: bool,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool):
use_chain_drafter: bool, multi_batch: bool, request):
# Use the enforce_single_worker fixture only when use_chain_drafter is False.
# Otherwise, we can't modify the return value of _get_allow_chain_drafter under multiprocessing.
if not use_chain_drafter:
request.getfixturevalue('enforce_single_worker')

# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
@@ -54,46 +70,52 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

# bs > 1 gives non-deterministic results when doing IFB. There is a slight chance
# that ref and spec do not match 100%.
max_batch_size = 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None
# Mock _get_allow_chain_drafter to return False when use_chain_drafter is False
if not use_chain_drafter:
patch_context = patch(
'tensorrt_llm._torch.pyexecutor.py_executor_creator._get_allow_chain_drafter',
return_value=False)
else:
patch_context = patch(
'tensorrt_llm._torch.pyexecutor.py_executor_creator._get_allow_chain_drafter',
return_value=True)

llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
if enable_chunked_prefill:
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64
with patch_context:
# bs > 1 gives non-deterministic results when doing IFB. There is a slight chance
# that ref and spec do not match 100%.
max_batch_size = 4 if multi_batch else 1
max_draft_len = 4
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None

spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)
llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
# This max_seq_len is larger than the one specified
# in the llama 3 8B eagle's config. We want to make sure
# that the draft model won't go above its max in warmup
# in this test.
max_seq_len=8192,
enable_chunked_prefill=enable_chunked_prefill,
)
if enable_chunked_prefill:
# Use a small max_num_tokens so that the chunked prefill path gets exercised.
llm_common_config['max_num_tokens'] = 64

# Set the development flag to control use_chain_drafter behavior
original_env_value = os.environ.get("TRTLLM_ALLOW_CHAIN_DRAFTER", "0")
try:
os.environ[
"TRTLLM_ALLOW_CHAIN_DRAFTER"] = "1" if use_chain_drafter else "0"
# Create the LLM instance with the mocked flag controlling use_chain_drafter
spec_config = EagleDecodingConfig(
max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
)

# Create the LLM instance
llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

# Acceptance rate tests
@@ -142,9 +164,6 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref
finally:
# Restore the original environment variable value
os.environ["TRTLLM_ALLOW_CHAIN_DRAFTER"] = original_env_value


def test_deepseek_eagle3():