# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import re from dataclasses import dataclass from pathlib import Path from typing import Optional, Set import pytest import torch import torch.nn as nn from build_and_run_ad import ExperimentConfig, main from defs.conftest import llm_models_root from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.masking_utils import create_causal_mask from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.modeling_llama import LlamaModel from transformers.utils.generic import ModelOutput from tensorrt_llm import SamplingParams from tensorrt_llm._torch.auto_deploy.llm import LLM from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import ( EagleWrapper, EagleWrapperConfig, ) from tensorrt_llm._torch.auto_deploy.models.eagle import EagleDrafterFactory from tensorrt_llm.llmapi import DraftTargetDecodingConfig, Eagle3DecodingConfig, KvCacheConfig prompts = [ "What is the capital of France?", "Please explain the concept of gravity in simple words and a single sentence.", ] EAGLE_MODEL_SUBPATH = "EAGLE3-LLaMA3.1-Instruct-8B" LLAMA_BASE_SUBPATH = "llama-3.1-model/Llama-3.1-8B-Instruct" DRAFT_TARGET_MAX_DRAFT_LEN = 3 EAGLE_MAX_DRAFT_LEN = 3 def get_model_paths(): """Get model paths using llm_models_root().""" models_root = llm_models_root() base_model = os.path.join(models_root, LLAMA_BASE_SUBPATH) draft_target_model = os.path.join( models_root, "llama-models-v2/TinyLlama-1.1B-Chat-v1.0", ) eagle_model = os.path.join(models_root, EAGLE_MODEL_SUBPATH) print(f"Base model path: {base_model}") print(f"DraftTarget draft model path: {draft_target_model}") print(f"EAGLE model path: {eagle_model}") return base_model, draft_target_model, eagle_model def make_draft_target_config(spec_model_path: str): return DraftTargetDecodingConfig( max_draft_len=DRAFT_TARGET_MAX_DRAFT_LEN, speculative_model=spec_model_path ) def make_eagle3_config(spec_model_path: str): return Eagle3DecodingConfig( max_draft_len=EAGLE_MAX_DRAFT_LEN, speculative_model=spec_model_path, eagle3_one_model=False, eagle3_layers_to_capture=None, ) def run_with_autodeploy(model, speculative_config, batch_size): """Run AutoDeploy with or without speculative decoding. Args: model: Path to the base model speculative_config: Speculative decoding config (None for baseline mode) batch_size: Number of prompts to process Returns: List of (prompt, output) tuples from prompts_and_outputs """ # Select prompts based on batch size selected_prompts = prompts[:batch_size] # Configure KV cache kv_cache_config = KvCacheConfig( free_gpu_memory_fraction=0.01, ) # Configure AutoDeploy LLM arguments llm_args = { "model": model, "skip_loading_weights": False, "speculative_config": speculative_config, "runtime": "trtllm", "world_size": 1, "kv_cache_config": kv_cache_config, "disable_overlap_scheduler": True, "max_num_tokens": 64, } # Configure experiment with prompts experiment_config = { "args": llm_args, "benchmark": {"enabled": False}, "prompt": { "batch_size": batch_size, "queries": selected_prompts, }, } # Create ExperimentConfig cfg = ExperimentConfig(**experiment_config) # Add sampling parameters (deterministic with temperature=0.0 and fixed seed) cfg.prompt.sp_kwargs = { "max_tokens": 50, "top_k": None, "temperature": 0.0, "seed": 42, } # Run the experiment result = main(cfg) # Extract and return prompts_and_outputs assert "prompts_and_outputs" in result, "Result should contain 'prompts_and_outputs'" return result["prompts_and_outputs"] # Note: This test tests exact equality of outputs between speculative and baseline modes. # This can fail for larger batch sizes due to nondeterminism with in flight batching. # TODO: Figure out a robust test for output correctness that can pass for larger batch sizes. @pytest.mark.parametrize("spec_dec_mode", ["draft_target", "eagle3"]) def test_autodeploy_spec_dec_output(spec_dec_mode): """Test AutoDeploy speculative decoding output correctness. Runs with and without speculative decoding and verifies outputs are identical. """ print("\n" + "=" * 80) print(f"Testing AutoDeploy Speculative Decoding ({spec_dec_mode}) - Output Correctness") print("=" * 80) base_model, draft_target_model, eagle_model = get_model_paths() # Select model and config based on mode if spec_dec_mode == "draft_target": spec_model = draft_target_model spec_config = make_draft_target_config(spec_model) elif spec_dec_mode == "eagle3": # eagle3 spec_model = eagle_model spec_config = make_eagle3_config(spec_model) else: raise ValueError(f"Unsupported speculative decoding mode: {spec_dec_mode}") print(f"\nBase Model: {base_model}") print(f"Speculative Model ({spec_dec_mode}): {spec_model}") # Run with speculative decoding print("\n[1/2] Running with speculative decoding enabled...") spec_outputs = run_with_autodeploy( model=base_model, speculative_config=spec_config, batch_size=1, ) print(f"Generated {len(spec_outputs)} outputs with speculative decoding") # Run without speculative decoding (baseline) print("\n[2/2] Running without speculative decoding (baseline)...") baseline_outputs = run_with_autodeploy(model=base_model, speculative_config=None, batch_size=1) print(f"Generated {len(baseline_outputs)} outputs in baseline mode") # Verify outputs are identical print("\nVerifying outputs are identical...") assert len(spec_outputs) == len(baseline_outputs), ( f"Number of outputs mismatch: spec={len(spec_outputs)}, baseline={len(baseline_outputs)}" ) for i, ((spec_prompt, spec_output), (baseline_prompt, baseline_output)) in enumerate( zip(spec_outputs, baseline_outputs, strict=True) ): print(f"\n[Output {i}]") print(f" Prompt: {spec_prompt}") print("================================================") print(f" Spec Output: {spec_output}") print("================================================") print(f" Baseline Output: {baseline_output}") print("================================================") assert spec_prompt == baseline_prompt, f"Prompts differ at index {i}" assert spec_output == baseline_output, ( f"Outputs differ at index {i}:\n\n Spec: {spec_output}\n\n Baseline: {baseline_output}\n\n" ) print("\n" + "=" * 80) print("SUCCESS! All outputs are identical between spec-dec and baseline modes") print("=" * 80) def test_autodeploy_eagle3_acceptance_rate(): """Test Eagle3 acceptance rate with AutoDeploy engine. Runs Eagle3 speculative decoding with streaming and verifies that the acceptance rate is above a minimum threshold. """ print("\n" + "=" * 80) print("Testing AutoDeploy Eagle3 Acceptance Rate") print("=" * 80) base_model, _, eagle_model = get_model_paths() print(f"\nBase Model: {base_model}") print(f"Eagle3 Model: {eagle_model}") max_draft_len = EAGLE_MAX_DRAFT_LEN # Configure Eagle3 speculative decoding speculative_config = Eagle3DecodingConfig( max_draft_len=max_draft_len, speculative_model=eagle_model, eagle3_one_model=False, eagle3_layers_to_capture=None, ) # Configure KV cache kv_cache_config = KvCacheConfig( free_gpu_memory_fraction=0.01, ) # Create AutoDeploy LLM with Eagle3 speculative decoding # We directly instantiate the LLM class instead of using the main() function # so that we can stream the outputs to see acceptance rates without needing to # collect them in the executor. llm = LLM( model=base_model, skip_loading_weights=False, runtime="trtllm", world_size=1, kv_cache_config=kv_cache_config, speculative_config=speculative_config, disable_overlap_scheduler=True, max_num_tokens=64, ) # Tokenize 2 prompts to test multiple sequential requests batch_tok_ids = [llm.tokenizer.encode(p) for p in prompts[:2]] sampling_params = SamplingParams(max_tokens=128, temperature=0, seed=42) print("\nRunning Eagle3 speculative decoding with streaming...") # Process each request sequentially and verify acceptance rate for i in range(len(batch_tok_ids)): num_tokens = 0 num_drafted = 0 num_accepted = 0 for output in llm.generate_async(batch_tok_ids[i], sampling_params, streaming=True): new_tokens = output.outputs[0].token_ids num_drafted += max_draft_len num_accepted += len(new_tokens) - num_tokens - 1 num_tokens = len(new_tokens) accept_rate = num_accepted / num_drafted print(f"\nRequest {i + 1} Acceptance Rate Statistics:") print(f" Total tokens drafted: {num_drafted}") print(f" Total tokens accepted: {num_accepted}") print(f" Acceptance rate: {accept_rate:.2%}") # Verify acceptance rate is above minimum threshold (10%) min_acceptance_rate = 0.10 assert accept_rate > min_acceptance_rate, ( f"Request {i + 1}: Acceptance rate {accept_rate:.2%} is below minimum threshold {min_acceptance_rate:.0%}" ) print("\n" + "=" * 80) print("SUCCESS! All requests passed acceptance rate threshold") print("=" * 80) def load_weights(model_path: Path, model: torch.nn.Module): """Load weights from checkpoint while applying the same _checkpoint_conversion_mapping that the factory uses. Returns: tuple of (loaded_keys, missing_keys, unexpected_keys) """ # 1. Load checkpoint keys bin_path = model_path / "pytorch_model.bin" safetensors_path = model_path / "model.safetensors" if safetensors_path.exists(): from safetensors import safe_open with safe_open(safetensors_path, framework="pt") as f: checkpoint_keys_original = list(f.keys()) elif bin_path.exists(): state_dict = torch.load(bin_path, map_location="cpu", weights_only=True) checkpoint_keys_original = list(state_dict.keys()) del state_dict else: raise FileNotFoundError(f"No checkpoint found at {model_path}") # 2. Apply _checkpoint_conversion_mapping (same logic as hf.py _remap_param_names_load_hook) # This is the key part - the factory does this exact same thing in lines 496-512 of hf.py conversion_mapping = getattr(model, "_checkpoint_conversion_mapping", None) checkpoint_keys_remapped = [] for key in checkpoint_keys_original: new_key = key if conversion_mapping: for pattern, replacement in conversion_mapping.items(): new_key = re.sub(pattern, replacement, new_key) checkpoint_keys_remapped.append(new_key) # 3. Get model's expected keys model_keys = set(model.state_dict().keys()) checkpoint_keys = set(checkpoint_keys_remapped) # 4. Calculate differences loaded_keys = checkpoint_keys & model_keys missing_in_checkpoint = model_keys - checkpoint_keys unexpected_in_checkpoint = checkpoint_keys - model_keys return loaded_keys, missing_in_checkpoint, unexpected_in_checkpoint def test_eagle_model_with_weights(): """Test EagleModel forward pass with loaded weights using the EagleDrafterFactory. This test uses EagleDrafterFactory to initialize the model, which directly builds the Eagle drafter model based on the checkpoint's model_type: 1. Factory creates config via AutoConfig.from_pretrained 2. Factory selects Eagle3DrafterForCausalLM based on model_type="llama" 3. Factory creates model via _from_config 4. Factory loads weights via load_or_random_init -> _load_checkpoint This ensures the test validates the exact initialization path used in production. """ print("\n" + "=" * 80) print("Test: EagleModel forward pass with loaded weights (via EagleDrafterFactory)") print("=" * 80) _, _, eagle_model_path = get_model_paths() eagle_path = Path(eagle_model_path) if not eagle_path.exists(): pytest.skip(f"Eagle model not found at {eagle_model_path}") # Check for weights bin_path = eagle_path / "pytorch_model.bin" safetensors_path = eagle_path / "model.safetensors" if not bin_path.exists() and not safetensors_path.exists(): pytest.skip(f"Weights not found at {eagle_model_path}") # 1. Setup Device device = "cuda" if torch.cuda.is_available() else "cpu" # 2. Create factory # EagleDrafterFactory directly builds the correct drafter model based on model_type print("Creating EagleDrafterFactory...") factory = EagleDrafterFactory( model=eagle_model_path, skip_loading_weights=False, # We want to test weight loading ) # 3. Build model using factory # Factory flow: # build_model() -> prefetch_checkpoint() -> _build_model() # _build_model() -> _get_model_config() (gets base LlamaConfig) # _build_model() -> selects Eagle3DrafterForCausalLM for model_type="llama" # _build_model() -> Eagle3DrafterForCausalLM._from_config(config) print("Building model via factory.build_model('meta')...") model = factory.build_model("meta") print(f"Model type: {type(model).__name__}") print(f"Model config type: {type(model.config).__name__}") # 4. Load weights from checkpoint and compare to model's expected keys print("\n--- Weight Loading Analysis ---") loaded_keys, missing_keys, unexpected_keys = load_weights(eagle_path, model) print(f"Total model parameters: {len(loaded_keys) + len(missing_keys)}") print(f"Total checkpoint keys: {len(loaded_keys) + len(unexpected_keys)}") print(f"✅ Weights to be loaded: {len(loaded_keys)}") print(f"⚠️ Missing in checkpoint (will be random init): {len(missing_keys)}") print(f"⚠️ Unexpected in checkpoint (will be ignored): {len(unexpected_keys)}") if unexpected_keys: print("\nUnexpected keys (in checkpoint but model doesn't expect):") for key in sorted(unexpected_keys): if "t2d" in key: print(f" - {key} (expected: not used in Eagle3 for Llama3.1-8B-Instruct)") else: print(f" - {key}") if loaded_keys: print(f"\nLoaded keys ({len(loaded_keys)} total):") for key in sorted(loaded_keys)[:20]: print(f" - {key}") if len(loaded_keys) > 20: print(f" ... and {len(loaded_keys) - 20} more") print("--- End Weight Analysis ---\n") # Verify expected missing and unexpected keys # These are the keys we expect based on Eagle3 architecture: # - embed_tokens: shared from target model (not in Eagle checkpoint) # - t2d: target-to-draft mapping, not used in Eagle3 (uses d2t instead) expected_unexpected_keys = {"model.t2d"} assert len(missing_keys) == 0, ( f"Expect all keys to be loaded.\nKeys that are missing: {missing_keys}\n" ) assert unexpected_keys == expected_unexpected_keys, ( f"Unexpected keys in checkpoint.\n" f"Expected: {expected_unexpected_keys}\n" f"Got: {unexpected_keys}\n" f"Extra unexpected: {unexpected_keys - expected_unexpected_keys}\n" f"Not unexpected (but expected): {expected_unexpected_keys - unexpected_keys}" ) print("✅ Weight loading analysis matches expected missing/unexpected keys!") # 5. Load weights using factory (mimics actual pipeline) # If tensor shapes do not match with how they are used in the forward() function, we will # get an error. print("Loading weights via factory.load_or_random_init()...") factory.load_or_random_init(model, device) print("Weights loaded successfully via factory interface!") model.eval() ############################################################################### # Set up to test the prefill-only version of the EagleWrapper model in test_eagle_wrapper_forward(). # This helps us guarantee that the EagleWrapper model, before it enters AutoDeploy, is working correctly, # The test does not rely on any TRTLLM logic. ############################################################################### class PrefillOnlyEagleResourceManager: """Simple resource manager for Eagle speculative decoding (prefill-only variant). Stores hidden states for use by draft loop in EagleWrapper.forward(). """ def __init__( self, hidden_size: int, num_capture_layers: int, max_batch_size: int, max_seq_len: int, max_draft_len: int, target_dtype: torch.dtype, ): # Buffer for hidden states from target model: [max_tokens, hidden_size * num_capture_layers] # Uses flattened 2D format to match ADHiddenStateManager self.hidden_states = torch.empty( max_batch_size * (max_seq_len + max_draft_len), hidden_size * num_capture_layers, device="cuda", dtype=target_dtype, ) class LlamaModelWithCapture(LlamaModel): """LlamaModel that captures un-normalized hidden states from specified layers. Overwrites the base model's forward method to capture hidden states from specified layers. Base model's forward method is otherwise copied from LlamaModel in HuggingFace. Takes PrefillOnlyEagleResourceManager as an argument to store captured hidden states. """ def __init__( self, config, layers_to_capture: Optional[Set[int]] = None, resource_manager: Optional[PrefillOnlyEagleResourceManager] = None, ): super().__init__(config) # layers_to_capture: set of layer indices (0-indexed) to capture # If None, capture all layers if layers_to_capture is None: self.layers_to_capture = set(range(config.num_hidden_layers)) else: self.layers_to_capture = set(layers_to_capture) self.resource_manager = resource_manager # Validate layer indices for idx in self.layers_to_capture: if idx < 0 or idx >= config.num_hidden_layers: raise ValueError( f"Layer index {idx} out of range. " f"Model has {config.num_hidden_layers} layers (0 to {config.num_hidden_layers - 1})" ) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> BaseModelOutputWithPast: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: # prefill only - no past key values. cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = create_causal_mask( config=self.config, input_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, past_key_values=None, position_ids=position_ids, ) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) # Buffer to collect captured hidden states captured_hidden_states = [] for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) # Capture this layer's output if it's in our list if layer_idx in self.layers_to_capture: captured_hidden_states.append(hidden_states) # Apply final normalization for last_hidden_state last_hidden_state = self.norm(hidden_states) # Store captured hidden states in resource manager if available # Resource manager uses 2D flattened format: [max_tokens, hidden_size * num_capture_layers] if self.resource_manager is not None and captured_hidden_states: concatenated = torch.cat(captured_hidden_states, dim=-1) batch_size, seq_len, total_hidden_size = concatenated.shape assert self.resource_manager.hidden_states.shape[-1] == total_hidden_size, ( f"Resource manager buffer last dim {self.resource_manager.hidden_states.shape[-1]} " f"!= concatenated hidden states last dim {total_hidden_size}" ) # Flatten to [batch_size * seq_len, total_hidden_size] for 2D format flattened = concatenated.view(batch_size * seq_len, total_hidden_size) self.resource_manager.hidden_states[: (batch_size * seq_len), :].copy_(flattened) return BaseModelOutputWithPast( last_hidden_state=last_hidden_state, hidden_states=tuple(captured_hidden_states) if captured_hidden_states else None, ) @dataclass class LlamaForCausalLMOutput(ModelOutput): logits: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None class LlamaForCausalLMWithCapture(nn.Module): """Wrapper combining LlamaModelWithCapture with lm_head for EagleWrapper testing. EagleWrapper.forward() expects target_model(input_ids, position_ids) to return logits. This class wraps LlamaModelWithCapture (which captures hidden states to resource manager) and adds the lm_head to produce logits. """ def __init__(self, base_model, capture_model): super().__init__() self.model = capture_model # LlamaModelWithCapture with resource_manager self.lm_head = base_model.lm_head def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, **kwargs, ): outputs = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, position_ids=position_ids, **kwargs ) logits = self.lm_head(outputs.last_hidden_state) return LlamaForCausalLMOutput( logits=logits, last_hidden_state=outputs.last_hidden_state, hidden_states=outputs.hidden_states, ) def get_input_embeddings(self): return self.model.embed_tokens def get_output_embeddings(self): return self.model.lm_head @classmethod def from_pretrained( cls, model_name: str, resource_manager, capture_layers, dtype=torch.bfloat16, ): """Load a base model and create a LlamaForCausalLMWithCapture with shared weights.""" print(f"Loading {model_name}...") base_model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=dtype, device_map={"": 0}, ) base_model.eval() # Create LlamaModelWithCapture that shares weights with the base model original_llama_model = base_model.model capture_model = LlamaModelWithCapture.__new__(LlamaModelWithCapture) nn.Module.__init__(capture_model) capture_model.config = original_llama_model.config capture_model.layers_to_capture = capture_layers capture_model.resource_manager = resource_manager # Share all modules (no weight copying) capture_model.embed_tokens = original_llama_model.embed_tokens capture_model.layers = original_llama_model.layers capture_model.norm = original_llama_model.norm capture_model.rotary_emb = original_llama_model.rotary_emb capture_model.gradient_checkpointing = original_llama_model.gradient_checkpointing return cls(base_model, capture_model) def build_eagle_wrapper( base_model_path: str, eagle_model_path: str, resource_manager: PrefillOnlyEagleResourceManager, capture_layers: Set[int], max_seq_len: int, max_draft_len: int, target_dtype: torch.dtype, device: torch.device, ) -> tuple[EagleWrapper, nn.Module]: """Build an EagleWrapper model for testing. This function encapsulates the model building logic using manual model building. Returns: A tuple of (eagle_wrapper, target_model) where: - eagle_wrapper: The EagleWrapper model ready for inference. - target_model: The target model (for verification steps). """ # Build EagleWrapper manually. print("\n" + "-" * 40) print("Building EagleWrapper") print("-" * 40) # Create target model with capture target_model = LlamaForCausalLMWithCapture.from_pretrained( base_model_path, resource_manager, capture_layers, target_dtype ) print("✓ Created target model with capture") # Create draft model using EagleDrafterFactory (mimics production pipeline) # This ensures weights are loaded correctly via the same path as AutoDeploy print("\nCreating draft model via EagleDrafterFactory...") draft_factory = EagleDrafterFactory( model=eagle_model_path, skip_loading_weights=False, ) # Build model on meta device first, then load weights draft_model = draft_factory.build_model("meta") print(f" Model type: {type(draft_model).__name__}") # Load weights via factory print(" Loading weights via factory.load_or_random_init()...") draft_factory.load_or_random_init(draft_model, device) draft_model.eval() # Create EagleWrapper config wrapper_config = EagleWrapperConfig( max_draft_len=max_draft_len, load_embedding_from_target=draft_model.load_embedding_from_target, load_lm_head_from_target=draft_model.load_lm_head_from_target, ) # Build EagleWrapper (this also loads weights from target into draft model where necessary) eagle_wrapper = EagleWrapper( config=wrapper_config, target_model=target_model, draft_model=draft_model, resource_manager=resource_manager, ) eagle_wrapper.eval() print("✓ Built EagleWrapper") return eagle_wrapper, target_model def generate_target_outputs( target_model: nn.Module, input_ids: torch.Tensor, num_iterations: int, ) -> torch.Tensor: """Generate tokens from target model using greedy sampling. Runs target_model.forward() in a loop, taking the last logit from each output, applying greedy sampling with torch.argmax, and appending to input_ids. Args: target_model: Model that returns logits from forward(input_ids, position_ids). input_ids: Initial input token ids of shape [batch_size, seq_len]. num_iterations: Number of tokens to generate. Returns: output_ids: Tensor of shape [batch_size, seq_len + num_iterations] containing the original input_ids plus the generated tokens. """ device = input_ids.device init_seq_len = input_ids.shape[1] print(f"Initial sequence length: {init_seq_len}") current_ids = input_ids.clone() with torch.no_grad(): for _ in range(num_iterations): # Generate position_ids from current sequence length seq_len = current_ids.shape[1] position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) position_ids = position_ids.expand(current_ids.shape[0], -1) # Forward pass logits = target_model(current_ids, position_ids=position_ids).logits # Take the last logit and apply greedy sampling last_logits = logits[:, -1, :] # [batch_size, vocab_size] next_token = torch.argmax(last_logits, dim=-1, keepdim=True) # [batch_size, 1] # Append to current_ids current_ids = torch.cat([current_ids, next_token], dim=1) return current_ids def print_token_analysis( input_ids: torch.Tensor, num_previously_accepted: torch.Tensor, target_output_ids: torch.Tensor, tokenizer, ) -> None: """Print debug analysis of accepted vs speculative tokens for each batch. Args: input_ids: Current input token ids of shape [batch_size, seq_len]. num_previously_accepted: Number of accepted tokens per batch [batch_size]. target_output_ids: Reference output from target model [batch_size, total_seq_len]. tokenizer: Tokenizer for decoding tokens to text. """ batch_size = input_ids.shape[0] print("\n --- Token Analysis (per batch) ---") for i in range(batch_size): prev_accepted_i = num_previously_accepted[i].item() # Accepted tokens (before speculation): input_ids[i, :num_previously_accepted[i]] accepted_tokens = input_ids[i, :prev_accepted_i] # Speculative tokens: input_ids[i, num_previously_accepted[i]:] speculative_tokens = input_ids[i, prev_accepted_i:] # Target model's expected token at this position if prev_accepted_i < target_output_ids.shape[1]: target_token_at_pos = target_output_ids[i, prev_accepted_i] else: target_token_at_pos = None print(f"\n Batch {i}:") print(f" num_previously_accepted: {prev_accepted_i}") print( f" Accepted tokens ({accepted_tokens.shape[0]} tokens): {accepted_tokens.tolist()}" ) accepted_text = tokenizer.decode(accepted_tokens, skip_special_tokens=True) print(f' Accepted text: "{accepted_text}"') print( f" Speculative tokens ({speculative_tokens.shape[0]} tokens): {speculative_tokens.tolist()}" ) if speculative_tokens.shape[0] > 0: spec_text = tokenizer.decode(speculative_tokens, skip_special_tokens=False) print(f' Speculative text: "{spec_text}"') if target_token_at_pos is not None: target_tok_id = target_token_at_pos.item() target_tok_str = tokenizer.decode([target_tok_id]) print( f' Target model\'s next token at pos {prev_accepted_i}: {target_tok_id} ("{target_tok_str}")' ) def manual_sample_and_verify( next_target_inputs: list, num_accepted_tokens: torch.Tensor, target_model: nn.Module, eagle_wrapper: nn.Module, max_draft_len: int, device: torch.device, ) -> list: """Manually verify speculative tokens using sample_and_verify. This is used for batch_size > 1 where truncation prevents speculative tokens from being fed back, so we verify them manually before truncation. Args: next_target_inputs: List of tensors, one per batch element. num_accepted_tokens: Number of tokens accepted so far per batch [batch_size]. target_model: The target model for running forward pass. eagle_wrapper: The EagleWrapper containing sample_and_verify. max_draft_len: Maximum draft length (for capping counts). device: Device to run on. Returns: List of (num_accepted, num_speculative) tuples for each batch element. """ batch_size = len(next_target_inputs) # Due to our truncation trick, all sequences should have the same length seq_lens = [seq.shape[0] for seq in next_target_inputs] assert all(slen == seq_lens[0] for slen in seq_lens), ( f"All sequences should have same length due to truncation, got {seq_lens}" ) verify_seq_len = seq_lens[0] # Stack into batched tensor stacked_inputs = torch.stack(next_target_inputs, dim=0) # [batch_size, seq_len] # Run target model forward to get logits verify_position_ids = ( torch.arange(verify_seq_len, device=device, dtype=torch.long) .unsqueeze(0) .expand(batch_size, -1) ) with torch.no_grad(): verify_target_logits = target_model(stacked_inputs, position_ids=verify_position_ids).logits # new_num_previously_accepted = num_accepted_tokens + 1 # This represents the tokens accepted after target model's output from this iteration new_num_previously_accepted = num_accepted_tokens + 1 # Call sample_and_verify to get acceptance counts _, verify_newly_accepted, _, _ = eagle_wrapper.sample_and_verify( stacked_inputs, verify_target_logits, new_num_previously_accepted ) # Build results list results = [] for i in range(batch_size): num_accepted_i = min(verify_newly_accepted[i].item(), max_draft_len) num_speculative = next_target_inputs[i].shape[0] - new_num_previously_accepted[i].item() results.append((num_accepted_i, num_speculative)) return results def verify_eagle_wrapper_output(output, tokenizer, batch_size, num_previously_accepted): """Verify the output structure and values from EagleWrapper forward pass. Args: output: The output from EagleWrapper forward pass. tokenizer: The tokenizer for decoding tokens. batch_size: The batch size. num_previously_accepted: Tensor of previously accepted token counts. """ # Verify output structure print("\nOutput verification:") assert output is not None, "Output should not be None" assert hasattr(output, "new_tokens"), "Output should have new_tokens" assert hasattr(output, "new_tokens_lens"), "Output should have new_tokens_lens" print(f" new_tokens: {type(output.new_tokens)} with {len(output.new_tokens)} items") for i, tokens in enumerate(output.new_tokens): new_tokens_text = tokenizer.decode(tokens, skip_special_tokens=True) print(f" batch {i}: shape {tokens.shape}, tokens: {tokens.tolist()}") print(f' batch {i}: decoded: "{new_tokens_text}"') # Compute num_accepted_tokens from new_tokens_lens + num_previously_accepted num_accepted_tokens = num_previously_accepted + output.new_tokens_lens print(f" new_tokens_lens: {output.new_tokens_lens}") print(f" num_accepted_tokens (computed): {num_accepted_tokens}") # Verify new_tokens_lens is within expected bounds assert output.new_tokens_lens.shape == (batch_size,), ( f"new_tokens_lens shape should be ({batch_size},), got {output.new_tokens_lens.shape}" ) @pytest.mark.parametrize("batch_size", [1, 2]) def test_eagle_wrapper_forward(batch_size: int): """Test EagleWrapper forward pass with target and draft models. This test validates the full speculative decoding loop: 1. Target model processes input and captures hidden states 2. Draft model generates speculative tokens 3. EagleWrapper orchestrates verification and drafting For batch size 1, we call EagleWrapper forward in the expected way. Each iteration generates a "golden token" (target output) and draft tokens. We input all of them to the wrapper model, which verifies the draft tokens against the target output. It then outputs the accepted tokens and newly generated draft tokens, along with numbers of accepted tokens, and the process repeats. For batch size > 1, we need to work around the fact that as we run the loop described above, the sequences lengths in the batch will get out of sync. So instead, we do not provide validated draft tokens as input in each iteration - we just input the first accepted token from the previous iteration (which we know was generated by the target model), which keeps the batches in sync. To verify that the output draft tokens are reasonable, we run a manual target model verification step after each iteration. We record how many of the output draft tokens were accepted. In the end, we test that the acceptance ratio of the draft tokens generated by the EagleWrapper is reasonable. Args: batch_size: Number of prompts to process in parallel. """ print("\n" + "=" * 80) print("Test: EagleWrapper forward pass") print("=" * 80) # Set random seeds for reproducibility torch.manual_seed(42) # Get model paths using integration test conventions base_model_path, _, eagle_model_path = get_model_paths() eagle_path = Path(eagle_model_path) if not eagle_path.exists(): pytest.skip("Eagle model not found (model missing)") # Configuration capture_layers = {1, 15, 28} # Layers to capture for Eagle3 num_capture_layers = len(capture_layers) hidden_size = 4096 # Llama 3.1-8B hidden size dtype = torch.bfloat16 device = torch.device("cuda") # Test dimensions max_batch_size = 4 max_seq_len = 1024 max_draft_len = 3 # Tokenize the test prompts tokenizer = AutoTokenizer.from_pretrained(base_model_path) # Llama uses left padding for batch inference tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" if batch_size == 1: input_ids = tokenizer.encode(prompts[0], return_tensors="pt").to(device) else: tokenized = tokenizer( prompts[:batch_size], return_tensors="pt", padding=True, ) input_ids = tokenized.input_ids.to(device) print(f"input_ids: {input_ids}") seq_len = input_ids.shape[1] init_seq_len = seq_len # Store initial sequence length for final comparison print("\nTest configuration:") print(f" target_model: {base_model_path}") print(f" eagle_model: {eagle_path}") print(f" batch_size: {batch_size}, seq_len: {seq_len}") print(f" max_draft_len: {max_draft_len}") print(f" capture_layers: {capture_layers}") print(f" prompts: {prompts[:batch_size]}") print(f" input_ids: {input_ids}") # Create resource manager resource_manager = PrefillOnlyEagleResourceManager( hidden_size=hidden_size, num_capture_layers=num_capture_layers, max_batch_size=max_batch_size, max_seq_len=max_seq_len, max_draft_len=max_draft_len, target_dtype=dtype, ) print("\n✓ Created resource manager") print(f" target_hidden_states shape: {resource_manager.hidden_states.shape}") # Build eagle_wrapper and target_model using the refactored function eagle_wrapper, target_model = build_eagle_wrapper( base_model_path=base_model_path, eagle_model_path=str(eagle_path), resource_manager=resource_manager, capture_layers=capture_layers, max_seq_len=max_seq_len, max_draft_len=max_draft_len, target_dtype=dtype, device=device, ) # Create test inputs (input_ids already created from tokenizer above) position_ids = ( torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) ) # Set previously_accepted_tokens to the length of input_ids (all context tokens are accepted) # Shape should be [batch_size] - a 1D tensor with one value per batch num_previously_accepted = torch.full((batch_size,), seq_len, device=device, dtype=torch.long) print("\nTest inputs:") print(f" input_ids shape: {input_ids.shape}") print(f" input_ids: {input_ids}") print(f" position_ids shape: {position_ids.shape}") print(f" num_previously_accepted: {num_previously_accepted}") # Generate target model outputs with greedy sampling print("\nGenerating target model outputs with greedy sampling (for verification)...") target_output_ids = generate_target_outputs(target_model, input_ids, num_iterations=100) print(f" target_output_ids shape: {target_output_ids.shape}") print(f" target_output_ids: {target_output_ids}") # Decode to text as sanity check generated_text = tokenizer.decode(target_output_ids[0], skip_special_tokens=True) print(f"\n Target model greedy generation decoded text:\n {generated_text}") print("\n✓ EagleWrapper forward pass completed successfully!") print("✓ Output structure verified") print("✓ new_tokens_lens within expected bounds") print("✓ Target model greedy generation completed") print("\n================================================") num_iterations = 70 # Dictionary to track distribution of new_tokens_lens # keys: 0 to max_draft_len # newly_accepted_counts[i]: number of times the number of accepted draft tokens was i newly_accepted_counts = {i: 0 for i in range(max_draft_len + 1)} for iteration in range(num_iterations): print(f"\n{'=' * 40}") print(f"EagleWrapper forward pass - Iteration {iteration + 1}/{num_iterations}") print(f"{'=' * 40}") seq_len = input_ids.shape[1] # Debug: Print speculative tokens, accepted tokens, and target comparison print_token_analysis(input_ids, num_previously_accepted, target_output_ids, tokenizer) kwargs = { "num_previously_accepted": num_previously_accepted, } with torch.no_grad(): output = eagle_wrapper( input_ids=input_ids, position_ids=position_ids, **kwargs, ) verify_eagle_wrapper_output(output, tokenizer, batch_size, num_previously_accepted) # Prepare next_target_inputs # output.new_tokens[i] contains the full draft_input_ids tensor, but the valid prefix # has length num_accepted_tokens[i] + max_draft_len. We slice to get only valid tokens. # We then prepend the first token from the previous iteration's input_ids. # This prepending is only needed for prefill-only mode, since in the cached case, the first token # will always be in the KV cache. # Compute num_accepted_tokens from num_previously_accepted + new_tokens_lens num_accepted_tokens = num_previously_accepted + output.new_tokens_lens valid_prefix_len = num_accepted_tokens + max_draft_len next_target_inputs = [ torch.cat( (input_ids[i, 0].unsqueeze(0), output.new_tokens[i][: valid_prefix_len[i]]), dim=0, ) for i in range(batch_size) ] # Track distribution of newly accepted tokens by reading new_tokens_lens from the output. # For batch size = 1, we are inputting draft tokens to the wrapper model, so new_tokens_lens # gives the number of accepted tokens from drafts in the previous iteration. if batch_size == 1: for val in output.new_tokens_lens.tolist(): newly_accepted_counts[val] += 1 print(f" newly_accepted_counts so far: {newly_accepted_counts}") # For batch_size > 1, we use manual target model verification below instead to check which of the draft tokens # generated in *this* iteration would be accepted by the target model. else: # For batch_size > 1, verify acceptance using sample_and_verify() # before truncation (since truncation prevents speculative tokens from being fed back) verify_results = manual_sample_and_verify( next_target_inputs, num_accepted_tokens, target_model, eagle_wrapper, max_draft_len, device, ) # Update newly_accepted_counts map for i, (num_accepted_i, num_speculative) in enumerate(verify_results): newly_accepted_counts[num_accepted_i] += 1 print( f" [Batch {i}] sample_and_verify: {num_accepted_i}/{num_speculative} speculative accepted" ) # Truncate to keep shapes consistent across batches in each iteration. # We know that the first token that is generated in this iteration is accepted, so it is "safe". # All speculative tokens are truncated regardless of whether they are accepted or not. # This is a hack to prevent the sequence lengths from getting out of sync across batches in each iteration # without needing to change the padding every iteration. truncate_len = input_ids.shape[1] + 1 next_target_inputs = [seq[:truncate_len] for seq in next_target_inputs] next_target_inputs = torch.stack(next_target_inputs, dim=0) print(f" next_target_inputs: {next_target_inputs}") print(f" next_target_inputs.shape: {next_target_inputs.shape}") # Update for next iteration input_ids = next_target_inputs seq_len = input_ids.shape[1] position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) position_ids = position_ids.expand(batch_size, -1) if batch_size > 1: # For multi-batch: increment by 1 (we truncated, so just advance by one token) num_previously_accepted = num_previously_accepted + 1 else: # For single batch: accept the tokens accepted in the previous iteration, plus one # for the output token that was generated by the target. num_previously_accepted = num_accepted_tokens + 1 print(f"\n{'=' * 40}") print(f"Loop completed: {num_iterations} iterations") print("Newly accepted tokens distribution:") for k, v in newly_accepted_counts.items(): print(f" {k}: {v}") # Calculate acceptance ratio # For batch_size == 1: uses new_tokens_lens from eagle wrapper # For batch_size > 1: uses manual verification against target model (since truncation # prevents speculative tokens from being fed back) total_accepted = sum(k * v for k, v in newly_accepted_counts.items()) # First iteration has no tokens to newly accept, subsequent iterations have max_draft_len potential num_iterations_with_drafts = num_iterations - 1 if batch_size == 1 else num_iterations total_potential = max_draft_len * (num_iterations_with_drafts) * batch_size acceptance_ratio = total_accepted / total_potential if total_potential > 0 else 0.0 print(f"\nAcceptance ratio: {total_accepted}/{total_potential} = {acceptance_ratio:.3f}") if batch_size > 1: print(" (batch_size > 1: measured via manual target model verification)") assert acceptance_ratio > 0.1, ( f"Acceptance ratio {acceptance_ratio:.3f} is too low (expected > 0.1)" ) print("\n" + "=" * 80) print("FINAL OUTPUT COMPARISON") print("=" * 80) for i in range(batch_size): print(f"\n{'─' * 40}") print(f"BATCH {i}") print(f"{'─' * 40}") print(f"\n[Target Model Output] ({target_output_ids[i].shape[0]} tokens):") print(f" Tokens: {target_output_ids[i].tolist()}") print(f' Text: "{tokenizer.decode(target_output_ids[i], skip_special_tokens=True)}"') print(f"\n[Eagle Wrapper Output] ({input_ids[i].shape[0]} tokens):") print(f" Tokens: {input_ids[i].tolist()}") print(f' Text: "{tokenizer.decode(input_ids[i], skip_special_tokens=True)}"') print("\n" + "=" * 80) # Verify that the first 10 generated tokens match between target model and eagle wrapper # They seem to diverge after awhile but are semantically the same. # Note that even running the target model in decode vs prefill mode, the outputs seem to diverge similarly, # so this is not worrisome. This test provides a check that they are "similar enough" to each other. num_tokens_to_check = 10 print(f"\nVerifying first {num_tokens_to_check} generated tokens match...") for i in range(batch_size): target_generated = target_output_ids[i, init_seq_len : init_seq_len + num_tokens_to_check] eagle_generated = input_ids[i, init_seq_len : init_seq_len + num_tokens_to_check] print(f" Batch {i}:") print(f" Target: {target_generated.tolist()}") print(f" Eagle: {eagle_generated.tolist()}") assert torch.equal(target_generated, eagle_generated), ( f"Batch {i}: First {num_tokens_to_check} generated tokens do not match!\n" f" Target: {target_generated.tolist()}\n" f" Eagle: {eagle_generated.tolist()}" ) print(f"✓ First {num_tokens_to_check} generated tokens match for all batches!")