TensorRT-LLMs/tests/integration/defs/examples/test_ad_speculative_decoding.py
gramnarayan a9eb5afc9f
[#9241][feat] AutoDeploy: Support Eagle3 Speculative Decoding (#9869)
Support two model flow with no overlap scheduler or chain drafter. Drafting model is in PyTorch backend.

Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
2025-12-24 23:30:42 -05:00

279 lines
9.7 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from build_and_run_ad import ExperimentConfig, main
from defs.conftest import llm_models_root
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.auto_deploy.llm import LLM
from tensorrt_llm.llmapi import DraftTargetDecodingConfig, EagleDecodingConfig, KvCacheConfig
# Prompts shared by the tests in this file; callers take a leading slice of this list.
prompts = [
    "What is the capital of France?",
    "Please explain the concept of gravity in simple words and a single sentence.",
]

# Model locations, relative to the root directory returned by llm_models_root().
LLAMA_BASE_SUBPATH = "llama-3.1-model/Llama-3.1-8B-Instruct"
EAGLE_MODEL_SUBPATH = "EAGLE3-LLaMA3.1-Instruct-8B"

# max_draft_len values handed to the speculative decoding configs built below.
EAGLE_MAX_DRAFT_LEN = 3
DRAFT_TARGET_MAX_DRAFT_LEN = 3
def get_model_paths():
    """Resolve the model directories used by the tests in this file.

    Each path is built by joining a relative sub-path onto ``llm_models_root()``.
    The resolved paths are printed before they are returned.

    Returns:
        A ``(base_model, draft_target_model, eagle_model)`` tuple of paths.
    """
    root = llm_models_root()

    def under_root(subpath):
        # Every model lives below the same models root.
        return os.path.join(root, subpath)

    base_model = under_root(LLAMA_BASE_SUBPATH)
    draft_target_model = under_root("llama-models-v2/TinyLlama-1.1B-Chat-v1.0")
    eagle_model = under_root(EAGLE_MODEL_SUBPATH)

    print(f"Base model path: {base_model}")
    print(f"DraftTarget draft model path: {draft_target_model}")
    print(f"EAGLE model path: {eagle_model}")

    return base_model, draft_target_model, eagle_model
def make_draft_target_config(spec_model_path: str):
    """Build the draft-target speculative decoding config.

    Args:
        spec_model_path: Passed through as ``speculative_model_dir``.

    Returns:
        A ``DraftTargetDecodingConfig`` whose ``max_draft_len`` is
        ``DRAFT_TARGET_MAX_DRAFT_LEN``.
    """
    config_kwargs = {
        "max_draft_len": DRAFT_TARGET_MAX_DRAFT_LEN,
        "speculative_model_dir": spec_model_path,
    }
    return DraftTargetDecodingConfig(**config_kwargs)
def make_eagle3_config(spec_model_path: str):
    """Build the Eagle3 speculative decoding config.

    Args:
        spec_model_path: Passed through as ``speculative_model_dir``.

    Returns:
        An ``EagleDecodingConfig`` with ``max_draft_len`` set to
        ``EAGLE_MAX_DRAFT_LEN``, ``eagle3_one_model`` disabled and
        ``eagle3_layers_to_capture`` left as ``None``.
    """
    config_kwargs = {
        "max_draft_len": EAGLE_MAX_DRAFT_LEN,
        "speculative_model_dir": spec_model_path,
        "eagle3_one_model": False,
        "eagle3_layers_to_capture": None,
    }
    return EagleDecodingConfig(**config_kwargs)
def run_with_autodeploy(model, speculative_config, batch_size):
    """Run the AutoDeploy experiment entry point once and return its generations.

    Args:
        model: Path to the base model.
        speculative_config: Speculative decoding config, or ``None`` for a baseline run.
        batch_size: Number of prompts to take from the module-level ``prompts`` list.

    Returns:
        The ``"prompts_and_outputs"`` entry of the result returned by ``main``
        (the caller in this file unpacks its items as ``(prompt, output)`` pairs).
    """
    # Leading slice of the shared prompt list.
    queries = prompts[:batch_size]

    # KV cache is restricted to a small fraction of the free GPU memory.
    kv_cache = KvCacheConfig(free_gpu_memory_fraction=0.01)

    cfg = ExperimentConfig(
        args={
            "model": model,
            "skip_loading_weights": False,
            "speculative_config": speculative_config,
            "runtime": "trtllm",
            "world_size": 1,
            "kv_cache_config": kv_cache,
            "disable_overlap_scheduler": True,
            "max_num_tokens": 64,
        },
        benchmark={"enabled": False},
        prompt={"batch_size": batch_size, "queries": queries},
    )

    # Sampling parameters are attached after construction: temperature 0.0 with a fixed seed.
    cfg.prompt.sp_kwargs = dict(max_tokens=50, top_k=None, temperature=0.0, seed=42)

    outcome = main(cfg)
    assert "prompts_and_outputs" in outcome, "Result should contain 'prompts_and_outputs'"
    return outcome["prompts_and_outputs"]
# Note: This test tests exact equality of outputs between speculative and baseline modes.
# This can fail for larger batch sizes due to nondeterminism with in flight batching.
# TODO: Figure out a robust test for output correctness that can pass for larger batch sizes.
@pytest.mark.parametrize("spec_dec_mode", ["draft_target", "eagle3"])
def test_autodeploy_spec_dec_output(spec_dec_mode):
"""Test AutoDeploy speculative decoding output correctness.
Runs with and without speculative decoding and verifies outputs are identical.
"""
print("\n" + "=" * 80)
print(f"Testing AutoDeploy Speculative Decoding ({spec_dec_mode}) - Output Correctness")
print("=" * 80)
base_model, draft_target_model, eagle_model = get_model_paths()
# Select model and config based on mode
if spec_dec_mode == "draft_target":
spec_model = draft_target_model
spec_config = make_draft_target_config(spec_model)
elif spec_dec_mode == "eagle3": # eagle3
spec_model = eagle_model
spec_config = make_eagle3_config(spec_model)
else:
raise ValueError(f"Unsupported speculative decoding mode: {spec_dec_mode}")
print(f"\nBase Model: {base_model}")
print(f"Speculative Model ({spec_dec_mode}): {spec_model}")
# Run with speculative decoding
print("\n[1/2] Running with speculative decoding enabled...")
spec_outputs = run_with_autodeploy(
model=base_model,
speculative_config=spec_config,
batch_size=1,
)
print(f"Generated {len(spec_outputs)} outputs with speculative decoding")
# Run without speculative decoding (baseline)
print("\n[2/2] Running without speculative decoding (baseline)...")
baseline_outputs = run_with_autodeploy(model=base_model, speculative_config=None, batch_size=1)
print(f"Generated {len(baseline_outputs)} outputs in baseline mode")
# Verify outputs are identical
print("\nVerifying outputs are identical...")
assert len(spec_outputs) == len(baseline_outputs), (
f"Number of outputs mismatch: spec={len(spec_outputs)}, baseline={len(baseline_outputs)}"
)
for i, ((spec_prompt, spec_output), (baseline_prompt, baseline_output)) in enumerate(
zip(spec_outputs, baseline_outputs, strict=True)
):
print(f"\n[Output {i}]")
print(f" Prompt: {spec_prompt}")
print("================================================")
print(f" Spec Output: {spec_output}")
print("================================================")
print(f" Baseline Output: {baseline_output}")
print("================================================")
assert spec_prompt == baseline_prompt, f"Prompts differ at index {i}"
assert spec_output == baseline_output, (
f"Outputs differ at index {i}:\n\n Spec: {spec_output}\n\n Baseline: {baseline_output}\n\n"
)
print("\n" + "=" * 80)
print("SUCCESS! All outputs are identical between spec-dec and baseline modes")
print("=" * 80)
def test_autodeploy_eagle3_acceptance_rate():
    """Test Eagle3 acceptance rate with the AutoDeploy engine.

    Builds an AutoDeploy ``LLM`` with Eagle3 speculative decoding, streams the
    first two shared prompts through it one request at a time, and verifies that
    each request's acceptance rate is above a minimum threshold (10%).

    The ``LLM`` is shut down in a ``finally`` block so that a failing assertion
    does not leave the engine alive for the rest of the test session.
    """
    print("\n" + "=" * 80)
    print("Testing AutoDeploy Eagle3 Acceptance Rate")
    print("=" * 80)

    base_model, _, eagle_model = get_model_paths()
    print(f"\nBase Model: {base_model}")
    print(f"Eagle3 Model: {eagle_model}")

    max_draft_len = EAGLE_MAX_DRAFT_LEN

    # Reuse the shared factory so this test runs the same Eagle3 configuration
    # (max_draft_len=EAGLE_MAX_DRAFT_LEN, two-model flow) as the output-correctness test.
    speculative_config = make_eagle3_config(eagle_model)

    # Configure KV cache
    kv_cache_config = KvCacheConfig(
        free_gpu_memory_fraction=0.01,
    )

    # Create AutoDeploy LLM with Eagle3 speculative decoding.
    # We directly instantiate the LLM class instead of using the main() function
    # so that we can stream the outputs to see acceptance rates without needing to
    # collect them in the executor.
    llm = LLM(
        model=base_model,
        skip_loading_weights=False,
        runtime="trtllm",
        world_size=1,
        kv_cache_config=kv_cache_config,
        speculative_config=speculative_config,
        disable_overlap_scheduler=True,
        max_num_tokens=64,
    )

    # Minimum acceptable acceptance rate (10%); identical for every request.
    min_acceptance_rate = 0.10

    try:
        # Tokenize 2 prompts to test multiple sequential requests
        batch_tok_ids = [llm.tokenizer.encode(p) for p in prompts[:2]]
        sampling_params = SamplingParams(max_tokens=128, temperature=0, seed=42)

        print("\nRunning Eagle3 speculative decoding with streaming...")

        # Process each request sequentially and verify acceptance rate
        for request_num, tok_ids in enumerate(batch_tok_ids, start=1):
            num_tokens = 0
            num_drafted = 0
            num_accepted = 0

            for output in llm.generate_async(tok_ids, sampling_params, streaming=True):
                new_tokens = output.outputs[0].token_ids
                # Each streamed output is counted as one step that drafted
                # max_draft_len tokens.
                num_drafted += max_draft_len
                # token_ids is treated as the cumulative token list: its growth since
                # the previous step, minus one token, is counted as accepted drafts.
                num_accepted += len(new_tokens) - num_tokens - 1
                num_tokens = len(new_tokens)

            # Fail with a clear message instead of a ZeroDivisionError when the
            # request streamed no outputs at all.
            assert num_drafted > 0, (
                f"Request {request_num}: no outputs were streamed, cannot compute acceptance rate"
            )
            accept_rate = num_accepted / num_drafted

            print(f"\nRequest {request_num} Acceptance Rate Statistics:")
            print(f" Total tokens drafted: {num_drafted}")
            print(f" Total tokens accepted: {num_accepted}")
            print(f" Acceptance rate: {accept_rate:.2%}")

            assert accept_rate > min_acceptance_rate, (
                f"Request {request_num}: Acceptance rate {accept_rate:.2%} is below minimum threshold {min_acceptance_rate:.0%}"
            )
    finally:
        # Release the engine even when an assertion above fails.
        llm.shutdown()

    print("\n" + "=" * 80)
    print("SUCCESS! All requests passed acceptance rate threshold")
    print("=" * 80)