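"""Unit tests for the additional_model_outputs feature of SamplingParams.

A dummy model and dummy checkpoint loaders are registered so the
PyTorch-backend LLM can run without real weights; the tests check that the
requested extra tensors are returned alongside the generated tokens.
"""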
import pathlib as _pl
from typing import Any, Dict, List, Optional

import pytest
import torch
from transformers.configuration_utils import PretrainedConfig

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints import HfCheckpointLoader
from tensorrt_llm._torch.models.checkpoints.base_config_loader import \
    BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
    BaseWeightLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
    BaseWeightMapper
from tensorrt_llm._torch.models.modeling_utils import (
    register_auto_model, register_checkpoint_weight_loader,
    register_config_loader)
from tensorrt_llm.llmapi import KvCacheConfig


class DummyConfig(PretrainedConfig):

    def __init__(self):
        super().__init__()
        self.architectures: list[str] = ["DummyModel"]
        self.torch_dtype: torch.dtype = torch.float16
        self.num_key_value_heads: int = 16
        self.num_attention_heads: int = 16
        self.hidden_size: int = 256
        self.vocab_size: int = 1000
        self.num_hidden_layers: int = 1

    @property
    def head_dim(self) -> int:
        return self.hidden_size // self.num_attention_heads


@register_auto_model("DummyModel")
class DummyModel(torch.nn.Module):

    def __init__(self, model_config: ModelConfig):
        super().__init__()
        self.dtype = model_config.pretrained_config.torch_dtype
        self.model_config = model_config
        self.recorded_position_ids = None

    def infer_max_seq_len(self):
        return 2048

    @property
    def config(self):
        return self.model_config.pretrained_config

    def forward(self,
                *args,
                input_ids: torch.Tensor,
                position_ids: torch.Tensor,
                attn_metadata: AttentionMetadata,
                return_context_logits: bool = False,
                **kwargs) -> Dict[str, torch.Tensor]:
        num_batch_tokens = input_ids.size(0)

        vocab_size = self.config.vocab_size
        hidden_size = self.config.hidden_size
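
        # seq_lens_cuda is assumed to hold the per-sequence token counts of the
        # packed batch, so the cumulative sum minus one indexes the last token
        # of each sequence.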
        last_tokens = torch.cumsum(
            attn_metadata.seq_lens_cuda,
            dim=0,
            dtype=torch.long,
        ) - 1

        # Logits: fixed values for testing
        logits = torch.ones(
            (num_batch_tokens, vocab_size), device='cuda') * 0.1

        # Logits shape depends on the return_context_logits flag: keep the
        # logits of every position when context logits are requested, otherwise
        # keep only the last token of each sequence.
        if not return_context_logits:
            logits = logits[last_tokens]

        # Context output: values depend on position for testing, one output per
        # input token
        context_output = (
            position_ids.reshape(-1, 1).expand(num_batch_tokens, hidden_size) *
            0.2)

        # Generation output: values depend on position for testing, one output
        # per sequence
        generation_output = (
            position_ids.reshape(-1, 1).expand(num_batch_tokens, hidden_size) *
            0.3)
        generation_output = generation_output[last_tokens]

        return {
            "logits": logits,
            "context_output": context_output,
            "generation_output": generation_output
        }

    def load_weights(self,
                     weights: Dict,
                     weight_mapper: Optional[BaseWeightMapper] = None,
                     skip_modules: List[str] = []):
        # The dummy model has no parameters, so loading is a no-op.
        pass


@register_checkpoint_weight_loader("DUMMY_FORMAT")
class DummyWeightLoader(BaseWeightLoader):

    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
        """Load weights for the dummy checkpoint format.

        Args:
            checkpoint_dir: Directory containing checkpoint files
            **kwargs: Additional loading parameters

        Returns:
            Dictionary mapping parameter names to tensors
        """
        # No real checkpoint exists, so an empty weight dictionary is returned.
        weights = {}

        return weights


@register_config_loader("DUMMY_FORMAT")
class DummyConfigLoader(BaseConfigLoader):

    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
        """Load and parse the configuration for the dummy checkpoint format.

        Args:
            checkpoint_dir: Directory containing configuration files
            **kwargs: Additional loading parameters

        Returns:
            ModelConfig object containing the parsed configuration
        """
        return ModelConfig(pretrained_config=DummyConfig())


@pytest.mark.gpu1
def test_additional_model_outputs_sampling_params():
    """Test that additional_model_outputs can be configured in SamplingParams."""
    # Create sampling params with additional outputs
    additional_outputs = ["context_output", "generation_output"]

    sampling_params = SamplingParams(
        max_tokens=5,
        temperature=0.0,  # Deterministic for testing
        additional_model_outputs=additional_outputs)

    # Verify the sampling params are configured correctly
    assert sampling_params.additional_model_outputs is not None
    assert len(sampling_params.additional_model_outputs) == 2
    assert sampling_params.additional_model_outputs[0] == "context_output"
    assert sampling_params.additional_model_outputs[1] == "generation_output"


@pytest.mark.gpu1
def test_additional_model_outputs_no_outputs():
    """Test that no additional outputs are configured when not requested."""
    # Create sampling params without additional outputs
    sampling_params = SamplingParams(
        max_tokens=5,
        temperature=0.0,  # Deterministic for testing
    )

    # Verify that no additional outputs are configured
    assert sampling_params.additional_model_outputs is None


def _test_additional_model_outputs_integration(pp_size: int):
    """Integration test for additional_model_outputs.

    This test uses a dummy model to exercise the additional_model_outputs
    feature end to end.
    """
    # Create sampling params with additional outputs
    additional_outputs = ["context_output", "generation_output"]

    num_generated_tokens = 5

    sampling_params = SamplingParams(
        max_tokens=num_generated_tokens,
        temperature=0.0,  # Deterministic for testing
        end_id=-1,  # Use -1 to indicate that the end token is not used
        additional_model_outputs=additional_outputs)

    # Create the LLM with a placeholder checkpoint path; the dummy loaders
    # ignore it.
    llm = LLM(model=_pl.Path("dummy_path"),
              backend='pytorch',
              pipeline_parallel_size=pp_size,
              max_batch_size=2,
              max_seq_len=128,
              max_num_tokens=5,
              kv_cache_config=KvCacheConfig(enable_block_reuse=False),
              disable_overlap_scheduler=True,
              checkpoint_loader=HfCheckpointLoader(
                  weight_loader=DummyWeightLoader(),
                  config_loader=DummyConfigLoader()))

    # Test prompts
    prompts = [[1, 2, 3], [4, 5, 6]]
    num_prompts = len(prompts)
    prompt_lens = [len(prompt) for prompt in prompts]

    config = DummyConfig()
    max_beam_width = 1
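
    # Additional per-step outputs are expected to carry a beam dimension; the
    # test runs with the default single-beam sampling configuration.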

    with llm:
        # Generate outputs
        outputs = llm.generate(prompts, sampling_params=sampling_params)

        # Verify outputs
        assert len(outputs) == num_prompts
        for i in range(num_prompts):
            output = outputs[i]
            assert len(output.outputs) == 1
            sequence = output.outputs[0]

            # Check that additional outputs are present
            assert sequence.additional_context_outputs is not None
            assert sequence.additional_generation_outputs is not None

            # Check that the requested outputs are present
            assert "context_output" in sequence.additional_context_outputs
            assert "generation_output" in sequence.additional_generation_outputs

            # Verify tensor shapes are reasonable
            context_output = sequence.additional_context_outputs[
                "context_output"]
            context_generation_output = sequence.additional_generation_outputs[
                "context_output"]
            generation_output = sequence.additional_generation_outputs[
                "generation_output"]

            # Verify that the outputs are tensors
            assert isinstance(context_output, torch.Tensor)
            assert isinstance(context_generation_output, torch.Tensor)
            assert isinstance(generation_output, torch.Tensor)

            # Verify context output shape
            assert context_output.dim() == 2
            assert context_output.shape[0] == prompt_lens[i]
            assert context_output.shape[1] == config.hidden_size
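
            # The dummy model emits position_id * 0.2 for every prompt token,
            # so the context output for positions 0..prompt_len-1 can be
            # reconstructed exactly.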
            expected_context_output = torch.arange(
                prompt_lens[i], dtype=torch.float32).unsqueeze(1).expand(
                    prompt_lens[i], config.hidden_size) * 0.2

            assert torch.equal(context_output, expected_context_output)

            # Verify context generation output shape
            assert context_generation_output.dim() == 3
            assert context_generation_output.shape[0] == num_generated_tokens
            assert context_generation_output.shape[1] == max_beam_width
            assert context_generation_output.shape[2] == config.hidden_size
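
            # The first generated token is produced while processing the last
            # prompt token, so the recorded generation positions start at
            # prompt_len - 1 and advance by one per generated token.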
            gen_start_idx = prompt_lens[i] - 1
            gen_end_idx = gen_start_idx + num_generated_tokens

            expected_context_generation_output = torch.arange(
                gen_start_idx, gen_end_idx,
                dtype=torch.float32).unsqueeze(1).expand(
                    num_generated_tokens, config.hidden_size) * 0.2

            assert torch.equal(context_generation_output,
                               expected_context_generation_output.unsqueeze(1))

            # Verify generation output shape
            assert generation_output.dim() == 3
            assert generation_output.shape[0] == num_generated_tokens
            assert generation_output.shape[1] == max_beam_width
            assert generation_output.shape[2] == config.hidden_size

            expected_generation_output = torch.arange(
                gen_start_idx, gen_end_idx,
                dtype=torch.float32).unsqueeze(1).expand(
                    num_generated_tokens, config.hidden_size) * 0.3

            assert torch.equal(generation_output,
                               expected_generation_output.unsqueeze(1))


@pytest.mark.gpu1
def test_additional_model_outputs_integration():
    _test_additional_model_outputs_integration(1)


@pytest.mark.gpu2
def test_additional_model_outputs_integration_pp2():
    _test_additional_model_outputs_integration(2)