import asyncio
import datetime
import gc
import json
import os
import random
import shutil
import sys
import tempfile
from typing import List, Optional, Union

import datasets
import pytest
import torch
import transformers
from pydantic import BaseModel
from utils.util import skip_single_gpu

from tensorrt_llm.bindings import executor as tllm
from tensorrt_llm.executor import (ExecutorBindingsWorker, LoRARequest,
                                   PromptAdapterRequest, RequestError)
from tensorrt_llm.llmapi import (LLM, BuildCacheConfig, EagleDecodingConfig,
                                 GuidedDecodingParams, KvCacheConfig,
                                 KvCacheRetentionConfig,
                                 LookaheadDecodingConfig, MedusaDecodingConfig,
                                 RequestOutput)
from tensorrt_llm.llmapi._perf_evaluator import perform_faked_oai_postprocess
from tensorrt_llm.llmapi.llm_args import DynamicBatchConfig, SchedulerConfig
from tensorrt_llm.llmapi.llm_utils import (BuildConfig, LlmArgs, QuantAlgo,
                                           QuantConfig, _ParallelConfig)
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, TransformersTokenizer
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.models.automodel import AutoConfig, AutoModelForCausalLM
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
from tensorrt_llm.sampling_params import (BatchedLogitsProcessor,
                                          LogitsProcessor, SamplingParams)

# isort: off
from utils.llm_data import llm_models_root
from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_pre_hopper
# isort: on
# The unit tests below are based on TinyLlama, which is fast to build and run.
# There are other tests based on the Llama-7B model, such as the end-to-end
# tests in test_e2e.py and the parallel tests in test_llm_multi_gpu.py.

pytestmark = pytest.mark.threadleak(enabled=False)
def get_model_path(model_name):
    engine_dir = os.environ.get('LLM_ENGINE_DIR', None)
    if engine_dir:
        return engine_dir
    return str(llm_models_root() / model_name)


def get_reference_count(obj):
    '''
    Get the reference count.
    '''
    return sys.getrefcount(obj) - 1


def check_output(outputs: List[RequestOutput],
                 references: Union[List[str], List[List[str]]],
                 *,
                 similar_threshold: float = 0.8,
                 finish_reasons: Optional[List[str]] = None,
                 stop_reasons: Optional[List[Union[int, str]]] = None):
    assert len(outputs) == len(references)

    for i, (output, reference) in enumerate(zip(outputs, references)):
        if isinstance(reference, list):
            # N output
            assert len(output.outputs) == len(reference)
            for j, (out, ref) in enumerate(zip(output.outputs, reference)):
                assert similar(out.text, ref, threshold=similar_threshold)
                if finish_reasons is not None:
                    assert out.finish_reason == finish_reasons[i][j]
                if stop_reasons is not None:
                    assert out.stop_reason == stop_reasons[i][j]
        else:
            out = output.outputs[0]
            assert similar(out.text, reference, threshold=similar_threshold)
            if finish_reasons is not None:
                assert out.finish_reason == finish_reasons[i]
            if stop_reasons is not None:
                assert out.stop_reason == stop_reasons[i]
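# Note on check_output() above: `references` is either a flat list of strings
# (one per prompt) or a list of lists when a prompt yields several candidates
# (e.g. n > 1 or beam search); in the latter case `finish_reasons` and
# `stop_reasons` are indexed as [prompt][candidate].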
def llm_test_harness(model_dir: str,
                     inputs: List[str],
                     references: List[str],
                     *,
                     sampling_params: Optional[SamplingParams] = None,
                     similar_threshold: float = 0.8,
                     **llm_kwargs):

    tp_size = llm_kwargs.get('tensor_parallel_size', 1)
    pp_size = llm_kwargs.get('pipeline_parallel_size', 1)
    world_size = tp_size * pp_size
    if world_size > torch.cuda.device_count():
        pytest.skip(
            f"world_size ({world_size}) is greater than available GPUs ({torch.cuda.device_count()})"
        )

    tokenizer = llm_kwargs.pop('tokenizer', None)
    if tokenizer is None:
        tokenizer = model_dir

    llm = LLM(model_dir, tokenizer=tokenizer, **llm_kwargs)
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    print(outputs)
    check_output(outputs, references, similar_threshold=similar_threshold)

    assert gc.is_tracked(llm)
    assert len(
        gc.get_referrers(llm)) == 0, f"the references: {gc.get_referrers(llm)}"


def llm_check_output(llm: LLM,
                     inputs: List[str],
                     references: List[str],
                     *,
                     sampling_params: Optional[SamplingParams] = None,
                     similar_threshold: float = 0.8,
                     finish_reasons: Optional[List[str]] = None,
                     stop_reasons: Optional[List[Union[int, str]]] = None,
                     **gen_kwargs):
    outputs = llm.generate(inputs,
                           sampling_params=sampling_params,
                           **gen_kwargs)
    print(outputs)
    check_output(outputs,
                 references,
                 similar_threshold=similar_threshold,
                 finish_reasons=finish_reasons,
                 stop_reasons=stop_reasons)
default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
mixtral_model_name = "Mixtral-8x7B-v0.1"

llama_model_path = get_model_path(default_model_name)
llm_engine_dir = os.environ.get('LLM_ENGINE_DIR', './tmp.engine')

cnn_dailymail_path = str(llm_models_root() / "datasets" / "cnn_dailymail")
alpaca_chinese_path = str(llm_models_root() / "datasets" / "silk-road" /
                          "alpaca-data-gpt4-chinese")

prompts = ["A B C"]
global_kvcache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)

# The Python API does not seem to support the extra tokens needed for
# prompt tuning + block reuse, so block reuse is disabled for those tests.
# TODO: Add extra tokens to the prompt tuning unit tests.
global_kvcache_config_no_reuse = KvCacheConfig(free_gpu_memory_fraction=0.4,
                                               enable_block_reuse=False)
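# Illustrative sketch only (not collected by pytest): the call pattern most
# tests below follow -- load TinyLlama, generate from `prompts`, and
# fuzzy-compare the outputs against a reference continuation.
def _example_harness_usage():
    llm_test_harness(llama_model_path,
                     prompts, ["D E F G H I J K"],
                     sampling_params=SamplingParams(max_tokens=8),
                     kv_cache_config=global_kvcache_config)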
@pytest.mark.part0
@force_ampere
def test_llm_build_config():
    build_config = BuildConfig()
    # change some building parameters
    build_config.max_batch_size = 129
    build_config.max_beam_width = 4
    build_config.max_num_tokens = 888
    build_config.strongly_typed = True
    build_config.max_seq_len = 333

    llm = LLM(model=llama_model_path,
              build_config=build_config,
              kv_cache_config=global_kvcache_config,
              fast_build=True)
    tmpdir = tempfile.TemporaryDirectory()
    llm.save(tmpdir.name)

    with open(os.path.join(tmpdir.name, "config.json"), "r") as f:
        # read the build_config and check if the parameters are correctly saved
        engine_config = json.load(f)

        build_config1 = BuildConfig.from_dict(engine_config["build_config"])

        # Known issue: this is converted to None after saving the engine for single-GPU
        build_config1.plugin_config.nccl_plugin = 'float16'
        assert build_config1.max_batch_size == build_config.max_batch_size
        assert build_config1.max_beam_width == build_config.max_beam_width
        assert build_config1.max_num_tokens == build_config.max_num_tokens
        assert build_config1.strongly_typed == build_config.strongly_typed
        assert build_config1.max_seq_len == build_config.max_seq_len
@pytest.mark.part0
def test_llm_args_invalid_usage():
    runtime_max_batch_size = 3
    runtime_max_num_tokens = 2

    # Update build_config with warning msg if runtime arguments are passed.
    llm_args = LlmArgs.from_kwargs(model='test-model',
                                   max_batch_size=runtime_max_batch_size,
                                   max_num_tokens=runtime_max_num_tokens)
    assert llm_args.build_config.max_batch_size == runtime_max_batch_size
    assert llm_args.build_config.max_num_tokens == runtime_max_num_tokens

    # Conflict between build_config and runtime_params
    build_config = BuildConfig(max_batch_size=5, max_num_tokens=7)
    llm_args = LlmArgs.from_kwargs(model='test-model',
                                   build_config=build_config,
                                   max_batch_size=runtime_max_batch_size,
                                   max_num_tokens=runtime_max_num_tokens)
    assert llm_args.build_config.max_batch_size == build_config.max_batch_size
    assert llm_args.build_config.max_num_tokens == build_config.max_num_tokens
@pytest.mark.part0
def test_llm_loading_from_hf():
    sampling_params = SamplingParams(max_tokens=8)
    llm_test_harness(llama_model_path,
                     prompts, ["D E F G H I J K"],
                     sampling_params=sampling_params,
                     kv_cache_config=global_kvcache_config)


@force_ampere
@pytest.mark.part0
def test_llm_loading_from_ckpt():
    tokenizer = TransformersTokenizer.from_pretrained(llama_model_path)
    assert tokenizer is not None

    ckpt_dir = tempfile.TemporaryDirectory()
    llama = AutoModelForCausalLM.from_hugging_face(llama_model_path)
    llama.save_checkpoint(ckpt_dir.name)
    del llama

    llm_test_harness(ckpt_dir.name,
                     prompts, ["D E F G H I J K"],
                     tokenizer=tokenizer,
                     kv_cache_config=global_kvcache_config,
                     sampling_params=SamplingParams(max_tokens=8))


@pytest.mark.parametrize('model_format', ['hf', 'ckpt'])
@pytest.mark.part0
def test_llm_with_dummy_weights(model_format):
    # dummy_dir contains config.json and tokenizer files only
    # the test fails if load_format != 'dummy'
    dummy_dir = tempfile.TemporaryDirectory()
    if model_format == 'hf':
        hf_config = transformers.AutoConfig.from_pretrained(llama_model_path)
        hf_config.save_pretrained(dummy_dir.name)
    else:
        config = AutoConfig.from_hugging_face(llama_model_path, dtype='float16')
        config.to_json_file(os.path.join(dummy_dir.name, 'config.json'))
    tokenizer = transformers.AutoTokenizer.from_pretrained(llama_model_path)
    tokenizer.save_pretrained(dummy_dir.name)

    sampling_params = SamplingParams(max_tokens=8)
    llm_test_harness(dummy_dir.name,
                     prompts,
                     ["A placeholder reference for dummy-weight engine."],
                     sampling_params=sampling_params,
                     similar_threshold=0.0,
                     load_format='dummy',
                     kv_cache_config=global_kvcache_config)
class MyTokenizer(TokenizerBase):
    ''' A wrapper for the Transformers' tokenizer.
    This is the default tokenizer for LLM. '''

    @classmethod
    def from_pretrained(cls, pretrained_model_dir: str, **kwargs):
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model_dir, **kwargs)
        return MyTokenizer(tokenizer)

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self) -> int:
        return self.tokenizer.pad_token_id

    def encode(self, text: str, **kwargs) -> List[int]:
        return self.tokenizer.encode(text, **kwargs)

    def decode(self, token_ids: List[int], **kwargs) -> str:
        return self.tokenizer.decode(token_ids, **kwargs)

    def batch_encode_plus(self, texts: List[str], **kwargs) -> dict:
        return self.tokenizer.batch_encode_plus(texts, **kwargs)
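# MyTokenizer above shows the minimal TokenizerBase surface these tests rely
# on: encode/decode, batch_encode_plus, and the eos/pad token id properties.
# test_llm_with_customized_tokenizer below passes an instance of it in place
# of the default tokenizer.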
@pytest.mark.part0
def test_llm_with_customized_tokenizer():
    llm = LLM(
        model=llama_model_path,
        # a customized tokenizer is passed to override the default one
        tokenizer=MyTokenizer.from_pretrained(llama_model_path),
        kv_cache_config=global_kvcache_config,
        fast_build=True,
    )

    for output in llm.generate(prompts):
        print(output)


@pytest.mark.part0
def test_llm_without_tokenizer():
    llm = LLM(
        model=llama_model_path,
        skip_tokenizer_init=True,
        kv_cache_config=global_kvcache_config,
        fast_build=True,
    )

    sampling_params = SamplingParams(end_id=2, pad_id=2, max_tokens=8)

    prompts = [[23, 14, 3]]

    for output in llm.generate(prompts, sampling_params=sampling_params):
        assert not output.outputs[0].text, \
            "The output should be empty since the tokenizer is missing"
        print(output)


@pytest.mark.part0
def test_llm_with_kv_cache_retention_config():
    kv_cache_retention_config = KvCacheRetentionConfig([
        KvCacheRetentionConfig.TokenRangeRetentionConfig(
            0, 2, 30, datetime.timedelta(seconds=30))
    ], 80)

    llm = LLM(model=llama_model_path,
              kv_cache_config=global_kvcache_config,
              fast_build=True)

    for output in llm.generate(
            prompts, kv_cache_retention_config=kv_cache_retention_config):
        print(output)
@pytest.mark.parametrize(
    'tokenizer_dir, threshold',
    [
        (get_model_path('gpt2'), 0.95),  # BPE
        (get_model_path('bert/bert-base-uncased'), 0.95),  # WordPiece
        (get_model_path('t5-small'), 0.95),  # SentencePiece
        (get_model_path('starcoder2-3b'), 0.95),
        (get_model_path('falcon-7b-instruct'), 0.95),
        (get_model_path('llama-models-v2/llama-v2-7b-hf'), 0.95),
        (get_model_path('codellama/CodeLlama-7b-Instruct-hf'), 0.95),
        (llama_model_path, 0.95),
        (get_model_path(mixtral_model_name), 0.95)
    ])
@pytest.mark.part0
def test_tokenizer_decode_incrementally(tokenizer_dir: str, threshold: float):
    random.seed(42)

    num_samples = 100
    cnn_dailymail = datasets.load_dataset(cnn_dailymail_path,
                                          name='3.0.0',
                                          split='train',
                                          trust_remote_code=True)
    alpaca_chinese = datasets.load_dataset(alpaca_chinese_path,
                                           split='train',
                                           trust_remote_code=True)
    dataset = cnn_dailymail['article'][:num_samples // 2] + alpaca_chinese[
        'output_zh'][:num_samples // 2]

    tokenizer = TransformersTokenizer.from_pretrained(tokenizer_dir,
                                                      legacy=False,
                                                      padding_side='left',
                                                      truncation_side='left',
                                                      trust_remote_code=True,
                                                      use_fast=True)

    num_perfect = 0
    for text in dataset:
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        seq_len = len(token_ids)
        prompt_len = random.randint(1, seq_len // 2)
        decoded_text, states = tokenizer.decode_incrementally(
            token_ids[:prompt_len])
        for i in range(prompt_len, len(token_ids)):
            decoded_text, states = tokenizer.decode_incrementally(
                [token_ids[i]], decoded_text, states)

        if tokenizer_dir.endswith(
                'bert-base-uncased') and tokenizer.clean_up_tokenization_spaces:
            decoded_text = tokenizer.clean_up_tokenization(decoded_text)
        reference = tokenizer.decode(token_ids)
        if decoded_text == reference:
            num_perfect += 1
        else:
            # For non-perfect matching cases, decoded_text should also be very similar to the reference
            assert similar(decoded_text, reference, 0.99)
    print(f"Perfect matching ratio: {num_perfect / num_samples * 100}%")
    assert num_perfect / num_samples >= threshold
# TODO[chunweiy]: Move the Mixtral test to the e2e tests
def is_memory_enough_for_mixtral():
    if torch.cuda.device_count() < 2:
        return False
    try:
        total_memory = get_total_gpu_memory(0) + get_total_gpu_memory(1)
        if total_memory >= 160 * 1024**3:
            return True
    except Exception:
        return False
    return False
@pytest.mark.part0
def test_llm_generate_async():
    _test_llm_generate_async()


def _test_llm_generate_async(model_name=default_model_name,
                             tp_size: int = 1,
                             use_auto_parallel: bool = False,
                             tokenizer=None):
    if "Mixtral" in model_name and use_auto_parallel:
        pytest.skip("Auto parallel is not supported for Mixtral models")

    tp_size = tp_size if not use_auto_parallel else 1
    world_size = tp_size if use_auto_parallel else None

    llm = LLM(
        model=get_model_path(model_name),
        tokenizer=tokenizer,
        kv_cache_config=global_kvcache_config,
        tensor_parallel_size=tp_size,
        auto_parallel=use_auto_parallel,
        auto_parallel_world_size=world_size,
        fast_build=True,
    )

    sampling_params = SamplingParams(max_tokens=6)

    def test_async(streaming: bool):

        async def task(prompt: str):
            outputs = []
            async for output in llm.generate_async(
                    prompt, streaming=streaming,
                    sampling_params=sampling_params):
                print('output', output)
                outputs.append(output.outputs[0].text)
            print(' '.join(outputs))

        async def main():
            tasks = [task(prompt) for prompt in prompts]
            await asyncio.gather(*tasks)

        asyncio.run(main())

    def test_wait(streaming: bool):
        for prompt in prompts:
            future = llm.generate_async(prompt,
                                        streaming=streaming,
                                        sampling_params=sampling_params)
            for output in future:
                print('wait', output)

    def test_non_streaming_usage_wait():
        for prompt in prompts:
            output = llm.generate_async(prompt,
                                        streaming=False,
                                        sampling_params=sampling_params)
            print(output.outputs[0].text)

    def test_future(streaming: bool):
        for prompt in prompts:
            future = llm.generate_async(prompt,
                                        streaming=streaming,
                                        sampling_params=sampling_params)
            if streaming is True:
                for output in future:
                    # Do something else and then wait for the result if needed
                    output = output.result(timeout=10)
                    print('future', output.outputs[0].text)
            else:
                # Do something else and then wait for the result if needed
                output = future.result(timeout=10)
                print('future', output.outputs[0].text)

    def test_future_async():

        async def task(prompt: str):
            future = llm.generate_async(prompt,
                                        streaming=False,
                                        sampling_params=sampling_params)
            output = await future.aresult()
            print('future', output.outputs[0].text)

        async def main():
            tasks = [task(prompt) for prompt in prompts]
            await asyncio.gather(*tasks)

        asyncio.run(main())

    test_async(streaming=True)
    test_async(streaming=False)
    test_wait(streaming=True)
    test_wait(streaming=False)
    test_future(streaming=True)
    test_future(streaming=False)
    test_future_async()
    test_non_streaming_usage_wait()
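# _test_llm_generate_async above exercises four consumption patterns of
# generate_async(): iterating the streaming generator with `async for`,
# iterating the returned future synchronously, blocking on .result(), and
# awaiting .aresult() from a coroutine.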
@pytest.fixture(scope="module")
def llm_for_sampling_params():
    build_config = BuildConfig(max_beam_width=3)
    llm = LLM(
        model=llama_model_path,
        build_config=build_config,
        kv_cache_config=global_kvcache_config,
        fast_build=True,
    )
    yield llm
    llm.shutdown()
@pytest.mark.part0
def test_user_specify_workspace():
    user_specified_ws_path = '/tmp/specified_workspace'
    shutil.rmtree(user_specified_ws_path, ignore_errors=True)
    os.mkdir(user_specified_ws_path)
    llm = LLM(model=llama_model_path,
              kv_cache_config=global_kvcache_config,
              workspace=user_specified_ws_path,
              fast_build=True)
    pre_built_engine_cfg = llm.args.model / 'config.json'
    assert pre_built_engine_cfg.exists()
    del llm
    gc.collect()
    assert not pre_built_engine_cfg.exists()


@force_ampere
@pytest.mark.part0
def test_generate_with_sampling_params_per_prompt(llm_for_sampling_params: LLM):
    llm = llm_for_sampling_params
    sampling_params_list = [
        SamplingParams(end_id=-1, pad_id=-1) for _ in range(2)
    ]
    sampling_params_list[0].max_tokens = 4
    sampling_params_list[1].max_tokens = 8

    for i, output in enumerate(
            llm.generate(prompts, sampling_params=sampling_params_list)):
        output_len = len(output.outputs[0].token_ids)
        print(f"output_len: {output_len}")
        assert output_len <= sampling_params_list[i].max_tokens
@force_ampere
@pytest.mark.parametrize(
    "sampling_params",
    [
        # temperature
        SamplingParams(
            max_tokens=6, temperature=0.5, beam_search_diversity_rate=0.5),
        # topK
        SamplingParams(max_tokens=6, top_k=10, top_p=0.92),
        # topP
        SamplingParams(max_tokens=6, top_p=0.92),
        # penalty
        SamplingParams(max_tokens=8,
                       length_penalty=1.0,
                       presence_penalty=0.0,
                       repetition_penalty=1.0,
                       min_tokens=5),
        # early stopping
        SamplingParams(max_tokens=6, early_stopping=5),
        # n-returns
        SamplingParams(max_tokens=6, n=2, top_k=2),
        SamplingParams(max_tokens=6, n=2, top_k=2, best_of=3),
        SamplingParams(max_tokens=6, n=3, use_beam_search=True),
        SamplingParams(max_tokens=6, n=2, best_of=3, use_beam_search=True),
    ])
@pytest.mark.part0
def test_generate_with_SamplingConfig(llm_for_sampling_params: LLM,
                                      sampling_params: SamplingParams):
    llm = llm_for_sampling_params

    for output in llm.generate(prompts, sampling_params=sampling_params):
        print(output)
        assert len(output.outputs) == sampling_params.n


@force_ampere
@pytest.mark.part0
def test_generate_with_seed(llm_for_sampling_params: LLM):
    prompts = ["The capital of France is"] * 10
    # Use a high temperature and large max_tokens to increase the diversity
    sampling_params = [
        SamplingParams(temperature=100, top_k=100, max_tokens=100)
        for _ in range(10)
    ]
    # Fix the seed for the first 5 prompts
    for i in range(5):
        sampling_params[i].seed = 515

    llm = llm_for_sampling_params
    generated_texts = []
    for output in llm.generate(prompts, sampling_params):
        generated_texts.append(output.outputs[0].text)
    for output in llm.generate(prompts, sampling_params):
        generated_texts.append(output.outputs[0].text)

    assert len(generated_texts) == 20
    assert len(set(generated_texts)) == 11
@force_ampere
@pytest.mark.part0
def test_generate_with_beam_search(llm_for_sampling_params: LLM):
    llm = llm_for_sampling_params
    references = [["D E F G H I", "D E F G I J"]]
    sampling_params = SamplingParams(max_tokens=6, beam_width=2)

    # Non-streaming mode
    outputs = llm.generate(prompts, sampling_params)
    print(outputs)
    check_output(outputs, references)

    # Streaming mode
    outputs = [
        llm.generate_async(prompt, sampling_params, streaming=True)
        for prompt in prompts
    ]
    outputs = [output.result() for output in outputs]
    print(outputs)
    check_output(outputs, references)
@force_ampere
@pytest.mark.part0
def test_generate_with_streaming_llm():
    # TODO[chunweiy]: Test with larger size when the underlying support is ready
    build_config = BuildConfig()
    build_config.plugin_config.streamingllm = True
    build_config.max_batch_size = 8
    build_config.max_seq_len = 512
    kv_cache_config = KvCacheConfig(max_attention_window=[64],
                                    sink_token_length=4)

    # Check the plugin config is correctly set
    assert build_config.plugin_config.streamingllm is True

    sampling_params = SamplingParams(max_tokens=4)

    llm_test_harness(llama_model_path,
                     prompts, ["D E F G"],
                     sampling_params=sampling_params,
                     build_config=build_config,
                     kv_cache_config=kv_cache_config)


@pytest.mark.part0
def test_parallel_config():
    config = _ParallelConfig()
    config.tp_size = 2
    config.pp_size = 2
    assert config.world_size == 4
    config.world_size = 4  # should not raise exception

    with pytest.raises(ValueError):
        config.world_size = 5
@force_ampere  # Save H100 resource
@pytest.mark.parametrize("gather_context_logits", [True, False])
@pytest.mark.parametrize("gather_generation_logits", [True, False])
@pytest.mark.parametrize("return_log_probs", [True])  # prune space
@pytest.mark.part0
def test_generate_with_OutputConfig(gather_context_logits: bool,
                                    gather_generation_logits: bool,
                                    return_log_probs: bool):
    if not (gather_context_logits or gather_generation_logits):  # prune space
        return

    build_config = BuildConfig()
    build_config.max_batch_size = 128  # reduce buffer sizes, especially for generation logits
    build_config.gather_context_logits = gather_context_logits

    llm = LLM(
        model=llama_model_path,
        kv_cache_config=global_kvcache_config,
        build_config=build_config,
        gather_generation_logits=gather_generation_logits,
        fast_build=True,
    )
    sampling_params = SamplingParams(
        max_tokens=8,
        return_context_logits=gather_context_logits,
        return_generation_logits=gather_generation_logits,
        return_log_probs=return_log_probs)

    for output in llm.generate(prompts, sampling_params=sampling_params):
        if gather_context_logits:
            assert output.context_logits is not None
            assert len(prompts[0].split()) + \
                1 == output.context_logits.shape[0]
        if gather_generation_logits:
            assert output.outputs[0].generation_logits is not None
            assert sampling_params.max_tokens == output.outputs[
                0].generation_logits.shape[0]
        if return_log_probs:
            assert output.outputs[0].logprobs is not None

        print(output)
@force_ampere
@pytest.mark.part0
def test_generate_with_stop_words():
    llm = LLM(
        model=llama_model_path,
        kv_cache_config=global_kvcache_config,
        fast_build=True,
    )
    stop_id = llm.tokenizer.encode("N", add_special_tokens=False)[-1]

    llm_check_output(llm,
                     prompts, ["D E F G H I J K L M"],
                     sampling_params=SamplingParams(end_id=stop_id),
                     finish_reasons=['stop'],
                     stop_reasons=[None])

    llm_check_output(llm,
                     prompts, ["D E F G H"],
                     sampling_params=SamplingParams(max_tokens=5),
                     finish_reasons=['length'],
                     stop_reasons=[None])

    llm_check_output(llm,
                     prompts, ["D E F G H I J K L M"],
                     sampling_params=SamplingParams(stop_token_ids=[stop_id]),
                     finish_reasons=['stop'],
                     stop_reasons=[stop_id])

    llm_check_output(llm,
                     prompts, ["D E F G H I J K L M N"],
                     sampling_params=SamplingParams(
                         stop_token_ids=[stop_id],
                         include_stop_str_in_output=True),
                     finish_reasons=['stop'],
                     stop_reasons=[stop_id])

    llm_check_output(llm,
                     prompts, ["D E F G H"],
                     sampling_params=SamplingParams(stop="I J"),
                     finish_reasons=['stop'],
                     stop_reasons=["I J"])

    llm_check_output(llm,
                     prompts, ["D E F G H I J K L M"],
                     sampling_params=SamplingParams(stop="I E", max_tokens=10),
                     finish_reasons=['length'],
                     stop_reasons=[None])

    llm_check_output(llm,
                     prompts, ["D E F G H I J"],
                     sampling_params=SamplingParams(
                         stop="I J", include_stop_str_in_output=True),
                     finish_reasons=['stop'],
                     stop_reasons=["I J"])

    llm_check_output(llm,
                     prompts, ["D E F G H"],
                     sampling_params=SamplingParams(stop=["F E", "I J"],
                                                    stop_token_ids=[stop_id]),
                     finish_reasons=['stop'],
                     stop_reasons=["I J"])
@force_ampere
@pytest.mark.part0
def test_generate_with_bad_words():
    llm = LLM(
        model=llama_model_path,
        kv_cache_config=global_kvcache_config,
        fast_build=True,
    )

    bad_id = llm.tokenizer.encode("N", add_special_tokens=False)[-1]

    llm_check_output(llm,
                     prompts, ["D E F G H I J K L M\n\nI hope this"],
                     sampling_params=SamplingParams(max_tokens=15,
                                                    bad_token_ids=[bad_id]))

    llm_check_output(llm,
                     prompts, ["D E F G H I K L M N O P Q R S"],
                     sampling_params=SamplingParams(max_tokens=15, bad="I J"))

    llm_check_output(llm,
                     prompts, ["D E F G H I K L M N O P Q R S"],
                     sampling_params=SamplingParams(max_tokens=15,
                                                    bad=["F E", "I J"]))
@force_ampere
@pytest.mark.part0
def test_generate_with_sampling_params_misc():
    llm = LLM(
        model=llama_model_path,
        tokenizer_mode='slow',
        kv_cache_config=global_kvcache_config,
        fast_build=True,
    )

    fake_end_id = llm.tokenizer.encode("N", add_special_tokens=False)[-1]

    llm_check_output(llm,
                     prompts, ["D E F G H I J K L M"],
                     sampling_params=SamplingParams(max_tokens=15,
                                                    end_id=fake_end_id))

    llm_check_output(llm,
                     prompts, ["D E F G H I K L M N O P Q R S"],
                     sampling_params=SamplingParams(max_tokens=15,
                                                    end_id=fake_end_id,
                                                    ignore_eos=True))

    llm_check_output(llm,
                     prompts, [""],
                     sampling_params=SamplingParams(max_tokens=15,
                                                    end_id=fake_end_id,
                                                    detokenize=False))

    outputs = llm.generate(prompts)
    assert outputs[0].prompt_token_ids == [1, 319, 350, 315]

    outputs = llm.generate(prompts, SamplingParams(add_special_tokens=False))
    assert outputs[0].prompt_token_ids == [319, 350, 315]

    outputs = llm.generate(prompts, SamplingParams(truncate_prompt_tokens=2))
    assert outputs[0].prompt_token_ids == [1, 315]

    # Use embedding bias to force the output tokens to be special tokens
    unk_id = llm.tokenizer.encode('<unk>', add_special_tokens=False)[-1]
    vocab_size_padded = 32000
    embedding_bias = torch.zeros(vocab_size_padded)
    embedding_bias[unk_id] = torch.finfo(torch.float32).max

    outputs = llm.generate(
        prompts, SamplingParams(max_tokens=5, embedding_bias=embedding_bias))
    assert outputs[0].outputs[0].text == ""

    outputs = llm.generate(
        prompts,
        SamplingParams(max_tokens=5,
                       embedding_bias=embedding_bias,
                       skip_special_tokens=False,
                       spaces_between_special_tokens=False))
    assert outputs[0].outputs[0].text == "<unk><unk><unk><unk><unk>"

    outputs = llm.generate(
        prompts,
        SamplingParams(max_tokens=5,
                       embedding_bias=embedding_bias,
                       skip_special_tokens=False,
                       spaces_between_special_tokens=True))
    assert outputs[0].outputs[0].text == "<unk> <unk> <unk> <unk> <unk>"
@force_ampere
@pytest.mark.part0
def test_generate_with_embedding_bias():
    tokenizer = transformers.AutoTokenizer.from_pretrained(llama_model_path)
    biased_word_id = tokenizer.encode("Z", add_special_tokens=False)[-1]
    vocab_size_padded = 32000
    embedding_bias = torch.zeros(vocab_size_padded)
    embedding_bias[biased_word_id] = torch.finfo(torch.float32).max

    sampling_params = SamplingParams(max_tokens=6,
                                     embedding_bias=embedding_bias)

    llm_test_harness(
        llama_model_path,
        prompts, ["Z Z Z Z Z Z"],
        sampling_params=sampling_params,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4))
@force_ampere
@pytest.mark.part0
def test_invalid_embedding_bias():
    tokenizer = transformers.AutoTokenizer.from_pretrained(llama_model_path)
    biased_word_id = tokenizer.encode("Z", add_special_tokens=False)[-1]
    vocab_size_padded = 32000

    # Should raise "Embedding bias data type must be same as model logits type"
    embedding_bias = torch.zeros(vocab_size_padded, dtype=torch.float16)
    embedding_bias[biased_word_id] = torch.finfo(torch.float16).max

    llm = LLM(llama_model_path, fast_build=True)
    sampling_params = SamplingParams(max_tokens=6,
                                     embedding_bias=embedding_bias)

    try:
        llm.generate(["A B C"], sampling_params=sampling_params)
    except RequestError:
        return

    assert False, "expected RequestError to be raised"
@skip_pre_hopper
@pytest.mark.part0
def test_generate_with_embedding_bias_fp8():
    tokenizer = transformers.AutoTokenizer.from_pretrained(llama_model_path)
    biased_word_id = tokenizer.encode("Z", add_special_tokens=False)[-1]
    vocab_size_padded = 32000

    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
                               kv_cache_quant_algo=QuantAlgo.FP8)
    assert quant_config.quant_mode.has_any_quant()

    llm = LLM(llama_model_path, quant_config=quant_config, fast_build=True)

    # FP32 embedding bias input (will be converted to FP16)
    embedding_bias = torch.zeros(vocab_size_padded)
    embedding_bias[biased_word_id] = torch.finfo(torch.float32).max
    sampling_params = SamplingParams(max_tokens=6,
                                     embedding_bias=embedding_bias)

    for output in llm.generate(["A B C"], sampling_params=sampling_params):
        print(output)
        assert output.outputs[0].text == "Z Z Z Z Z Z"

    # FP16 embedding bias input
    embedding_bias = torch.zeros(vocab_size_padded, dtype=torch.float16)
    embedding_bias[biased_word_id] = torch.finfo(torch.float16).max

    sampling_params = SamplingParams(max_tokens=6,
                                     embedding_bias=embedding_bias)

    for output in llm.generate(["A B C"], sampling_params=sampling_params):
        print(output)
        assert output.outputs[0].text == "Z Z Z Z Z Z"
@skip_pre_hopper
@pytest.mark.part0
def test_invalid_embedding_bias_fp8():
    tokenizer = transformers.AutoTokenizer.from_pretrained(llama_model_path)
    biased_word_id = tokenizer.encode("Z", add_special_tokens=False)[-1]
    vocab_size_padded = 32000

    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
                               kv_cache_quant_algo=QuantAlgo.FP8)
    assert quant_config.quant_mode.has_any_quant()

    llm = LLM(llama_model_path, quant_config=quant_config, fast_build=True)

    # Should raise "Embedding bias tensor needs to be in CPU memory for casting"
    embedding_bias = torch.zeros(vocab_size_padded, device='cuda')
    embedding_bias[biased_word_id] = torch.finfo(torch.float32).max
    sampling_params = SamplingParams(max_tokens=6,
                                     embedding_bias=embedding_bias)

    try:
        llm.generate(["A B C"], sampling_params=sampling_params)
    except RequestError:
        return

    assert False, "expected RequestError to be raised"
class MyLogitsProcessor(LogitsProcessor):

    def __init__(self, biased_word_id):
        self.biased_word_id = biased_word_id

    def __call__(self, req_id: int, logits: torch.Tensor, ids: List[List[int]],
                 stream_ptr: int, client_id: Optional[int]):
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            logits[:] = float("-inf")
            logits[..., self.biased_word_id] = 0


def tinyllama_logits_processor_test_harness(**llm_kwargs):
    tokenizer = TransformersTokenizer.from_pretrained(llama_model_path)
    biased_word_id = tokenizer.encode("Z", add_special_tokens=False)[-1]
    sampling_params = SamplingParams(
        max_tokens=6, logits_processor=MyLogitsProcessor(biased_word_id))

    llm_test_harness(
        llama_model_path,
        prompts, ["Z Z Z Z Z Z"],
        sampling_params=sampling_params,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
        **llm_kwargs)


@force_ampere
@pytest.mark.part0
def test_tinyllama_logits_processor():
    tinyllama_logits_processor_test_harness()
class MyBatchedLogitsProcessor(BatchedLogitsProcessor):

    def __init__(self, biased_word_id):
        self.biased_word_id = biased_word_id

    def __call__(self, req_ids_batch: List[int],
                 logits_batch: List[torch.Tensor],
                 token_ids_batch: List[List[List[int]]], stream_ptr: int,
                 client_ids_batch: List[Optional[int]]):
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            for logits in logits_batch:
                logits[:] = float("-inf")
                logits[..., self.biased_word_id] = 0


def tinyllama_logits_processor_batched_test_harness(**llm_kwargs):
    tokenizer = TransformersTokenizer.from_pretrained(llama_model_path)
    biased_word_id = tokenizer.encode("Z", add_special_tokens=False)[-1]
    sampling_params = SamplingParams(max_tokens=6,
                                     apply_batched_logits_processor=True)

    llm_test_harness(
        llama_model_path,
        prompts, ["Z Z Z Z Z Z"],
        sampling_params=sampling_params,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
        batched_logits_processor=MyBatchedLogitsProcessor(biased_word_id),
        **llm_kwargs)


@force_ampere
@pytest.mark.part0
def test_tinyllama_logits_processor_batched():
    tinyllama_logits_processor_batched_test_harness()
def tinyllama_guided_decoding_test_harness(**llm_kwargs):
    prompts = [
        "What is 1+1? Answer formatted in a dict in json format: ",
        "What is the year after 2024? Answer: ",
    ]

    class Answer(BaseModel):
        answer: int

    json_schema = json.dumps(Answer.model_json_schema())
    regex = r"\d+"
    ebnf_grammar = "root ::= [0-9]+"

    sampling_params = [
        SamplingParams(max_tokens=10),
        SamplingParams(max_tokens=10,
                       guided_decoding=GuidedDecodingParams(json_object=True)),
        SamplingParams(max_tokens=10,
                       guided_decoding=GuidedDecodingParams(json=json_schema)),
        SamplingParams(max_tokens=10,
                       guided_decoding=GuidedDecodingParams(regex=regex)),
        SamplingParams(
            max_tokens=10,
            guided_decoding=GuidedDecodingParams(grammar=ebnf_grammar)),
    ]

    num_prompts, num_sampling_params = len(prompts), len(sampling_params)
    prompts = [p for p in prompts for _ in range(num_sampling_params)]
    sampling_params = [sp for _ in range(num_prompts) for sp in sampling_params]
    references = [
        '\n\n```\n{\n "1":',
        '{"1": "1", "2": "',
        '{"answer": 1}',
        '1',
        '1',
        '2025\n\nQuestion 3:',
        '[2025]',
        '{"answer": 202',
        '2025',
        '2025',
    ]
    llm_test_harness(llama_model_path,
                     prompts,
                     references,
                     sampling_params=sampling_params,
                     guided_decoding_backend='xgrammar',
                     similar_threshold=0.7,
                     **llm_kwargs)
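# Illustrative sketch only (not collected by pytest): a single guided-decoding
# request outside the harness, using the same knobs exercised above
# (an xgrammar backend on the LLM plus GuidedDecodingParams on the request).
def _example_guided_decoding_request():
    llm = LLM(model=llama_model_path,
              guided_decoding_backend='xgrammar',
              kv_cache_config=global_kvcache_config,
              fast_build=True)
    sampling_params = SamplingParams(
        max_tokens=10, guided_decoding=GuidedDecodingParams(regex=r"\d+"))
    return llm.generate(["What is the year after 2024? Answer: "],
                        sampling_params=sampling_params)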
@force_ampere
@pytest.mark.part0
@pytest.mark.parametrize("backend", ['tensorrt', 'pytorch'])
def test_tinyllama_guided_decoding(backend: str):
    llm_kwargs = {}
    if backend == 'pytorch':
        llm_kwargs['backend'] = 'pytorch'
    tinyllama_guided_decoding_test_harness(**llm_kwargs)
@pytest.mark.part0
def test_llm_api_medusa():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    build_config = BuildConfig(
        max_batch_size=1,
        max_seq_len=1024,
    )

    kv_cache_config = KvCacheConfig(enable_block_reuse=True)

    speculative_config = MedusaDecodingConfig(
        num_medusa_heads=4,
        max_draft_len=63,
        speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
        medusa_choices=[
            [0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0],
            [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0],
            [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5],
            [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9],
            [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0],
            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]
        ])
    llm = LLM(model=get_model_path("vicuna-7b-v1.3"),
              build_config=build_config,
              kv_cache_config=kv_cache_config,
              speculative_config=speculative_config,
              fast_build=True)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@skip_single_gpu
@pytest.mark.part0
def test_llm_api_medusa_tp2():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    build_config = BuildConfig(max_batch_size=1, max_seq_len=1024)

    kv_cache_config = KvCacheConfig(enable_block_reuse=True)

    speculative_config = MedusaDecodingConfig(
        num_medusa_heads=4,
        max_draft_len=63,
        speculative_model=get_model_path("medusa-vicuna-7b-v1.3"),
        medusa_choices=[
            [0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0],
            [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0],
            [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5],
            [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9],
            [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0],
            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]
        ])
    llm = LLM(model=get_model_path("vicuna-7b-v1.3"),
              build_config=build_config,
              kv_cache_config=kv_cache_config,
              speculative_config=speculative_config,
              tensor_parallel_size=2,
              fast_build=True)

    # Tensor parallelism is already configured on the LLM above, so generate()
    # takes only the prompts and sampling params.
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.part0
def test_llm_api_eagle():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    build_config = BuildConfig(max_batch_size=1, max_seq_len=1024)

    kv_cache_config = KvCacheConfig(enable_block_reuse=True)

    speculative_config = EagleDecodingConfig(
        max_draft_len=63,
        speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
        num_eagle_layers=4,
        max_non_leaves_per_layer=10,
        eagle_choices=[
            [0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0],
            [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0],
            [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5],
            [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9],
            [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0],
            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]
        ])
    llm = LLM(model=get_model_path("vicuna-7b-v1.3"),
              build_config=build_config,
              kv_cache_config=kv_cache_config,
              speculative_config=speculative_config,
              fast_build=True)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@skip_single_gpu
@pytest.mark.part0
def test_llm_api_eagle_tp2():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    build_config = BuildConfig(max_batch_size=1, max_seq_len=1024)

    kv_cache_config = KvCacheConfig(enable_block_reuse=True)

    speculative_config = EagleDecodingConfig(
        max_draft_len=63,
        speculative_model=get_model_path("EAGLE-Vicuna-7B-v1.3"),
        num_eagle_layers=4,
        max_non_leaves_per_layer=10,
        eagle_choices=[
            [0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0],
            [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0],
            [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5],
            [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9],
            [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0],
            [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]
        ])
    llm = LLM(model=get_model_path("vicuna-7b-v1.3"),
              build_config=build_config,
              kv_cache_config=kv_cache_config,
              speculative_config=speculative_config,
              tensor_parallel_size=2,
              fast_build=True)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
def tinyllama_lookahead_decoding_test_harness(**llm_kwargs):
    prompts = [
        "A B C",
    ]
    lookahead_config = LookaheadDecodingConfig(max_window_size=3,
                                               max_ngram_size=3,
                                               max_verification_set_size=3)

    build_config = BuildConfig(max_batch_size=8,
                               max_num_tokens=128,
                               max_input_len=32,
                               max_seq_len=64)

    sampling_params = [
        SamplingParams(max_tokens=8, lookahead_config=lookahead_config),
    ]

    num_prompts, num_sampling_params = len(prompts), len(sampling_params)
    prompts = [p for p in prompts for _ in range(num_sampling_params)]
    sampling_params = [sp for _ in range(num_prompts) for sp in sampling_params]
    references = [
        'D E F G H I J K',
    ]
    llm_test_harness(llama_model_path,
                     prompts,
                     references,
                     sampling_params=sampling_params,
                     speculative_config=lookahead_config,
                     build_config=build_config,
                     kv_cache_config=global_kvcache_config,
                     **llm_kwargs)


@force_ampere
def test_tinyllama_lookahead_decoding():
    tinyllama_lookahead_decoding_test_harness()


@force_ampere
def test_executor_lookahead_decoding_config():
    lookahead_config = LookaheadDecodingConfig(max_window_size=10,
                                               max_ngram_size=9,
                                               max_verification_set_size=8)
    sampling_params = SamplingParams(max_tokens=3,
                                     lookahead_config=lookahead_config)

    assert sampling_params.lookahead_config.max_window_size == 10
    assert sampling_params.lookahead_config.max_ngram_size == 9
    assert sampling_params.lookahead_config.max_verification_set_size == 8
def llama_v2_13b_lora_test_harness(**llm_kwargs):
    hf_model_dir = get_model_path("llama-models-v2/llama-v2-13b-hf")
    hf_lora_dir = get_model_path("llama-models-v2/chinese-llama-2-lora-13b")

    # For LoRA checkpoints with finetuned embedding and lm_head, lora_dir must be provided at build time.
    build_config = BuildConfig(lora_config=LoraConfig(lora_dir=[hf_lora_dir]))
    llm = LLM(hf_model_dir,
              tokenizer=hf_lora_dir,
              enable_lora=True,
              max_lora_rank=64,
              build_config=build_config,
              fast_build=True,
              **llm_kwargs)

    prompts = [
        "今天天气很好,我到公园的时候,",
        "今天天气很好,我到公园的时候,",
    ]
    references = [
        "看见好多人们都看书,看书书看书书,看书书看书书书书书书",
        "发现公园里到处都是人,有的在跑步,有的在打羽毛球,还有的",
    ]
    lora_req = LoRARequest("Chinese", 1, hf_lora_dir)
    sampling_params = SamplingParams(max_tokens=20, add_special_tokens=False)
    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=[None, lora_req])
    for output, ref in zip(outputs, references):
        assert similar(output.outputs[0].text, ref)
def llama_7b_multi_lora_test_harness(**llm_kwargs):
    hf_model_dir = get_model_path("llama-models/llama-7b-hf")
    hf_lora_dir1 = get_model_path("llama-models/luotuo-lora-7b-0.1")
    hf_lora_dir2 = get_model_path("llama-models/Japanese-Alpaca-LoRA-7b-v0")

    # For LoRA checkpoints without finetuned embedding and lm_head, we can either:
    # (1) specify lora_target_modules, or
    # (2) provide a lora_dir to infer the lora_target_modules.
    build_config = BuildConfig(lora_config=LoraConfig(
        lora_target_modules=['attn_q', 'attn_k', 'attn_v']))
    llm = LLM(hf_model_dir,
              enable_lora=True,
              max_lora_rank=8,
              build_config=build_config,
              fast_build=True,
              **llm_kwargs)

    prompts = [
        "美国的首都在哪里? \n答案:",
        "美国的首都在哪里? \n答案:",
        "美国的首都在哪里? \n答案:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
    ]
    references = [
        "沃尔玛\n\n## 新闻\n\n* ",
        "美国的首都是华盛顿。\n\n美国的",
        "纽约\n\n### カンファレンスの",
        "Washington, D.C.\nWashington, D.C. is the capital of the United",
        "华盛顿。\n\n英国の首都是什",
        "ワシントン\nQ1. アメリカ合衆国",
    ]
    lora_req1 = LoRARequest("luotuo", 1, hf_lora_dir1)
    lora_req2 = LoRARequest("Japanese", 2, hf_lora_dir2)
    sampling_params = SamplingParams(max_tokens=20)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=[None, lora_req1, lora_req2, None, lora_req1, lora_req2])
    for output, ref in zip(outputs, references):
        assert similar(output.outputs[0].text, ref)
@skip_gpu_memory_less_than_40gb
def test_llama_v2_13b_lora():
    llama_v2_13b_lora_test_harness()


@skip_gpu_memory_less_than_40gb
def test_llama_7b_multi_lora():
    llama_7b_multi_lora_test_harness(max_loras=1, max_cpu_loras=8)


def llama_v2_7b_prompt_adapter_test_harness(**llm_kwargs):
    hf_model_dir = get_model_path("llama-models-v2/llama-v2-7b-hf")
    hf_prompt_adapter_dir = get_model_path("llama-models-v2/llama_tweet_ptune")
    llm = LLM(hf_model_dir,
              enable_prompt_adapter=True,
              max_prompt_adapter_token=8,
              fast_build=True,
              **llm_kwargs)

    prompts = [
        "Born in north-east France, Soyer trained as a",
        "Born in north-east France, Soyer trained as a",
        "Tweet text: I have complaints! Label: ",
        "Tweet text: I have complaints! Label: ",
        "Tweet text: I have no problems Label: ",
        "Tweet text: I have no problems Label: ",
    ]
    references = [
        "painter at the École des Beaux-Arts in Paris. He was a member of the",
        "chef and has worked in the restaurant industry for 15 years.Ћ\nBorn in north",
        "1999.\nTweet text: I have complaints! Label: 19",
        "no complaint",
        "100%\nI have no problems Label: 100%\nI have no",
        "no complaint",
    ]
    pa_req = PromptAdapterRequest('tweet', 1, hf_prompt_adapter_dir)
    sampling_params = SamplingParams(max_tokens=20)
    outputs = llm.generate(
        prompts,
        sampling_params,
        prompt_adapter_request=[None, pa_req, None, pa_req, None, pa_req])
    for output, ref in zip(outputs, references):
        assert similar(output.outputs[0].text, ref)
@skip_gpu_memory_less_than_40gb
def test_llama_v2_7b_prompt_adapter():
    llama_v2_7b_prompt_adapter_test_harness(
        kv_cache_config=global_kvcache_config_no_reuse)


@force_ampere
def test_generate_block_reuse():
    build_config = BuildConfig()
    build_config.plugin_config._use_paged_context_fmha = True
    build_config.plugin_config._paged_kv_cache = True
    llm = LLM(model=llama_model_path,
              kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4,
                                            enable_block_reuse=True),
              build_config=build_config,
              fast_build=True)

    sampling_params = SamplingParams(max_tokens=6)

    prompts = ["A B C", "A B C D"]
    for output in llm.generate(prompts, sampling_params=sampling_params):
        print(output)
def test_executor_results_cleanup():
    llm = LLM(model=llama_model_path,
              kv_cache_config=global_kvcache_config,
              fast_build=True)
    sampling_params = SamplingParams(max_tokens=6)
    for i in range(20):
        llm.generate(prompts, sampling_params=sampling_params)

    num_remaining_results = len(llm._executor._results)
    print(f"result.size: {num_remaining_results}")
    assert num_remaining_results == 0
@pytest.mark.parametrize("trust_remote_code", [True, False])
def _test_llm_trust_remote_code(trust_remote_code: bool):
    # OOM when tested with other cases
    # TODO[chunweiy]: Enable this later

    if trust_remote_code:
        internlm_model_path = get_model_path("internlm-chat-7b")
        llm = LLM(model=internlm_model_path,
                  trust_remote_code=trust_remote_code,
                  tokenizer=TransformersTokenizer.from_pretrained(
                      internlm_model_path, trust_remote_code=trust_remote_code),
                  kv_cache_config=global_kvcache_config,
                  fast_build=True)
        sampling_params = SamplingParams(max_tokens=6,
                                         temperature=0.8,
                                         top_p=0.95)
        prompts = [
            "The future of AI is",
        ]

        for output in llm.generate(prompts, sampling_params=sampling_params):
            print(output)
    else:
        with pytest.raises(ValueError):
            llm = LLM(model="internlm/internlm-chat-7b",
                      trust_remote_code=trust_remote_code,
                      tokenizer="internlm/internlm-chat-7b",
                      kv_cache_config=global_kvcache_config,
                      fast_build=True)
def test_llm_build_cache():
    # Activate the build-cache
    cache_config = BuildCacheConfig(max_records=1, max_cache_storage_gb=10)
    sampling_params = SamplingParams(max_tokens=6)

    def first_run():
        llm = LLM(model=llama_model_path,
                  kv_cache_config=global_kvcache_config,
                  enable_build_cache=cache_config,
                  fast_build=True)
        llm_check_output(llm,
                         prompts, ["D E F G H I J K"],
                         sampling_params=sampling_params)

    def second_run():
        llm = LLM(model=llama_model_path,
                  kv_cache_config=global_kvcache_config,
                  enable_build_cache=cache_config,
                  fast_build=True)
        llm_check_output(llm,
                         prompts, ["D E F G H I J K"],
                         sampling_params=sampling_params)

        # the cache should be hit
        assert llm.llm_build_stats.cache_hitted, llm.llm_build_stats.cache_info

    first_run()
    second_run()
class DummyError(Exception):
    pass


class DummyExecutorMeta(type):

    def __new__(cls, name, bases, dic, worker_cls):
        new_cls = super().__new__(cls, name, bases, dic)

        @classmethod
        def create(cls, engine, executor_config, *args, **kwargs):
            return worker_cls(engine=engine, executor_config=executor_config)

        new_cls.create = create
        return new_cls
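# DummyExecutorMeta builds executor classes whose `create()` classmethod
# returns the given worker_cls instead of the default worker; DummyExecutor3
# below uses it to inject DummyExecutorWorker3 via LLM(executor_cls=...).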
def check_llm_return_context_logits(tp_size=1):
    build_config = BuildConfig(gather_context_logits=True)

    llm = LLM(
        llama_model_path,
        tensor_parallel_size=tp_size,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
        build_config=build_config,
        fast_build=True,
    )

    sampling_params = SamplingParams(max_tokens=8, return_context_logits=True)

    prompts = ["A B C D E F G H I J K"] * 8

    for output in llm.generate(prompts, sampling_params=sampling_params):
        assert isinstance(output.context_logits, torch.Tensor)
        print(output)

    # Check the WAR for returning logits performance
    if tp_size == 1:
        assert isinstance(llm._executor, ExecutorBindingsWorker)


def check_llm_return_generation_logits(tp_size=1):

    llm = LLM(
        llama_model_path,
        tensor_parallel_size=tp_size,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
        gather_generation_logits=True,
        fast_build=True,
    )

    sampling_params = SamplingParams(max_tokens=8,
                                     return_generation_logits=True)

    prompts = ["A B C D E F G H I J K"] * 8

    for output in llm.generate(prompts, sampling_params=sampling_params):
        assert isinstance(output.outputs[0].generation_logits, torch.Tensor)
        print(output)

    # Check the WAR for returning logits performance
    if tp_size == 1:
        assert isinstance(llm._executor, ExecutorBindingsWorker)
def test_llm_return_context_logits():
    check_llm_return_context_logits(tp_size=1)


def test_llm_return_generation_logits():
    check_llm_return_generation_logits(tp_size=1)
class DummyExecutorWorker3(ExecutorBindingsWorker):
    should_raise_error = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.counter = 0
        self.failed_requests = set()

    def _engine_response_callback(self, response: tllm.Response):
        if response.client_id in self.failed_requests:
            return response
        # Make the first response fail, and let the subsequent responses succeed
        if DummyExecutorWorker3.should_raise_error:
            DummyExecutorWorker3.should_raise_error = False
            print(f"Raise error for {response.client_id}")
            self.failed_requests.add(response.client_id)
            return tllm.Response(
                request_id=0,  # dummy value
                client_id=response.client_id,
                error_msg="Test error")
        else:
            return response


DummyExecutor3 = DummyExecutorMeta("DummyExecutor3", (), {},
                                   worker_cls=DummyExecutorWorker3)
@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5063025")
def test_llm_handling_per_request_error():
    llm = LLM(
        model=llama_model_path,
        executor_cls=DummyExecutor3,
        kv_cache_config=global_kvcache_config,
        fast_build=True,
    )
    # The dummy executor will delay the responses
    sampling_params = SamplingParams(max_tokens=6)

    def batch_task():
        DummyExecutorWorker3.should_raise_error = True
        with pytest.raises(RequestError):
            for output in llm.generate(prompts,
                                       sampling_params=sampling_params):
                print(output)

        for output in llm.generate(prompts, sampling_params=sampling_params):
            print(output)

    batch_task()
|
|
|
|
|
|


@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5063025")
def test_llm_handling_per_requeust_error_async():
    llm = LLM(
        model=llama_model_path,
        executor_cls=DummyExecutor3,
        kv_cache_config=global_kvcache_config,
        fast_build=True,
    )
    # The dummy executor fails the first response on purpose
    sampling_params = SamplingParams(max_tokens=6)

    # test in streaming mode
    async def task():
        # The first request hits a per-request error while the LLM instance
        # itself stays alive and can serve subsequent requests
        with pytest.raises(RequestError):
            DummyExecutorWorker3.should_raise_error = True
            async for output in llm.generate_async(
                    prompts[0], streaming=True,
                    sampling_params=sampling_params):
                print(output)

        DummyExecutorWorker3.should_raise_error = False
        async for output in llm.generate_async(prompts[0],
                                               streaming=True,
                                               sampling_params=sampling_params):
            print(output)

    asyncio.run(task())


def validate_stats(pytorch_backend, results, max_tokens):
    assert results
    assert len(results) == (max_tokens if pytorch_backend else max_tokens + 1)
    for iter, result in enumerate(results):
        ifbStats = result["inflightBatchingStats"]
        expected_num_scheduled = 1 if (iter < max_tokens) else 0
        assert ifbStats["numScheduledRequests"] == expected_num_scheduled
        if iter == 0:
            assert ifbStats["numContextRequests"] == 1
            assert ifbStats["numGenRequests"] == 0
            assert result["numActiveRequests"] == 1
        elif iter == max_tokens:
            assert ifbStats["numContextRequests"] == 0
            assert ifbStats["numGenRequests"] == 0
            assert result["numActiveRequests"] == 0
        else:
            assert ifbStats["numContextRequests"] == 0
            assert ifbStats["numGenRequests"] == 1
            assert result["numActiveRequests"] == 1

        # TODO: For some reason, without the PyTorch backend,
        # numCompletedRequests is always 0; need to revisit this.
        expected_num_completed = 1 if iter == len(results) - 1 else 0
        assert result["numCompletedRequests"] == expected_num_completed
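
# For reference, each per-iteration entry validated above is a dict that, per
# the assertions, contains at least the following keys (additional fields may
# be present and are not relied upon here):
#
#     {
#         "inflightBatchingStats": {
#             "numScheduledRequests": ...,
#             "numContextRequests": ...,
#             "numGenRequests": ...,
#         },
#         "numActiveRequests": ...,
#         "numCompletedRequests": ...,
#     }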


def llm_get_stats_test_harness(tp_size: int = 1,
                               return_context_logits: bool = False,
                               pytorch_backend: bool = False,
                               use_overlap: bool = False):
    llm_args_extra = {}
    sampling_args_extra = {}
    if return_context_logits:
        llm_args_extra["build_config"] = BuildConfig(gather_context_logits=True)
        sampling_args_extra["return_context_logits"] = True

    if pytorch_backend:
        print("Use PyTorch path...")
        from tensorrt_llm._torch import LLM as LLM_torch
        from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
        llm_args_extra["pytorch_backend_config"] = PyTorchConfig(
            enable_iter_perf_stats=True, enable_overlap_scheduler=use_overlap)
        LLM_CLASS = LLM_torch
    else:
        LLM_CLASS = LLM

    llm = LLM_CLASS(model=llama_model_path,
                    kv_cache_config=global_kvcache_config,
                    tensor_parallel_size=tp_size,
                    fast_build=True,
                    **llm_args_extra)

    max_tokens = 5
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     **sampling_args_extra)

    for output in llm.generate(prompts, sampling_params=sampling_params):
        print(output)

    results = llm.get_stats(2)

    validate_stats(pytorch_backend, results, max_tokens)

    assert not llm.get_stats(2)

    # test that IterationResult()._done is properly set
    _ = llm.generate(prompts, sampling_params=sampling_params)
    assert llm.get_stats(2)


@pytest.mark.parametrize("return_context_logits, pytorch_backend, use_overlap",
                         [
                             (True, False, False),
                             (False, False, False),
                             (False, True, False),
                             (False, True, True),
                         ])
def test_llm_get_stats(return_context_logits, pytorch_backend, use_overlap):
    llm_get_stats_test_harness(tp_size=1,
                               return_context_logits=return_context_logits,
                               pytorch_backend=pytorch_backend,
                               use_overlap=use_overlap)


def llm_get_stats_async_test_harness(tp_size: int = 1,
                                     return_context_logits: bool = False,
                                     pytorch_backend: bool = False,
                                     use_overlap: bool = False):
    llm_args_extra = {}
    sampling_args_extra = {}
    if return_context_logits:
        llm_args_extra["build_config"] = BuildConfig(gather_context_logits=True)
        sampling_args_extra["return_context_logits"] = True

    if pytorch_backend:
        print("Use PyTorch path...")
        from tensorrt_llm._torch import LLM as LLM_torch
        from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
        llm_args_extra["pytorch_backend_config"] = PyTorchConfig(
            enable_iter_perf_stats=True, enable_overlap_scheduler=use_overlap)
        LLM_CLASS = LLM_torch
    else:
        LLM_CLASS = LLM

    llm = LLM_CLASS(model=llama_model_path,
                    kv_cache_config=global_kvcache_config,
                    tensor_parallel_size=tp_size,
                    fast_build=True,
                    **llm_args_extra)

    max_tokens = 6
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     **sampling_args_extra)

    async def task0():
        async for output in llm.generate_async(prompts[0],
                                               streaming=True,
                                               sampling_params=sampling_params):
            print(output)

    async def task1():
        results = []
        await asyncio.sleep(
            3)  # ensure there's stats to collect for the assertion
        async for stats in llm.get_stats_async(timeout=2):
            print(stats)
            results.append(stats)

        assert results

    async def main():
        for i in range(2):  # test recurrent usage
            await asyncio.gather(task0(), task1())

    asyncio.run(main())


@pytest.mark.parametrize("return_context_logits, pytorch_backend, use_overlap",
                         [
                             (True, False, False),
                             (False, False, False),
                             (False, True, False),
                             (False, True, True),
                         ])
def test_llm_get_stats_async(return_context_logits, pytorch_backend,
                             use_overlap):
    llm_get_stats_async_test_harness(
        tp_size=1,
        return_context_logits=return_context_logits,
        pytorch_backend=pytorch_backend,
        use_overlap=use_overlap)
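
# Note: in the two stats harnesses above, llm.get_stats(2) and
# llm.get_stats_async(timeout=2) are assumed to wait up to roughly 2 seconds
# while draining the queued per-iteration stats; the exact timeout semantics
# are not asserted by these tests.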


def test_llm_chunked_prefill():
    sampling_params = SamplingParams(max_tokens=8)
    build_config = BuildConfig()
    build_config.plugin_config.use_paged_context_fmha = True
    build_config.max_num_tokens = 64
    new_tokens = 8
    build_config.max_seq_len = build_config.max_num_tokens + new_tokens

    def fail_path():
        sampling_params = SamplingParams(max_tokens=8)
        llm = LLM(model=llama_model_path,
                  kv_cache_config=global_kvcache_config,
                  build_config=build_config,
                  enable_chunked_prefill=False,
                  fast_build=True)

        with pytest.raises(ValueError):
            output = llm.generate_async(
                "A " * build_config.max_num_tokens,
                sampling_params=sampling_params,
            ).result()

    def success_path():
        llm = LLM(
            model=llama_model_path,
            kv_cache_config=global_kvcache_config,
            build_config=build_config,
            enable_chunked_prefill=True,
            fast_build=True,
        )

        output = llm.generate_async(
            "A " * build_config.max_num_tokens,
            sampling_params=sampling_params,
        ).result()

    fail_path()
    success_path()
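
# What the chunked-prefill test above exercises: with
# enable_chunked_prefill=False, a prompt whose token count exceeds
# build_config.max_num_tokens (64 here) is rejected with a ValueError, whereas
# with enable_chunked_prefill=True the same prompt is accepted because the
# context phase is assumed to be processed in chunks of at most max_num_tokens
# tokens (paged context FMHA is enabled in the build config for this path).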


def _test_llm_capture_request_error(tp_size: int = 1):
    build_config = BuildConfig()
    build_config.max_num_tokens = 64

    llm = LLM(
        model=llama_model_path,
        build_config=build_config,
        fast_build=True,
    )

    prompt = 'A ' * 65  # the prompt exceeds max_num_tokens (64)

    with pytest.raises(RequestError):
        llm.generate(prompt)


def test_llm_capture_request_error():
    _test_llm_capture_request_error(tp_size=1)


def test_llm_api_jupyter_scenario():

    with LLM(
            model=llama_model_path,
            kv_cache_config=global_kvcache_config,
            fast_build=True,
    ) as llm:

        sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

        async def task():
            return llm.generate(["A", "B", "C", "D"], sampling_params)

        outputs = asyncio.run(task())
        for output in outputs:
            print(output)


def test_llm_dynamic_batch_config():
    scheduler_config = SchedulerConfig(dynamic_batch_config=DynamicBatchConfig(
        enable_batch_size_tuning=True,
        enable_max_num_tokens_tuning=True,
        dynamic_batch_moving_average_window=128))
    llm_test_harness(llama_model_path,
                     prompts, ["D E F G H I J K"],
                     sampling_params=SamplingParams(max_tokens=9),
                     scheduler_config=scheduler_config)


def run_llm_with_postprocess_parallel(tp_size: int = 1):
    sampling_params = SamplingParams(max_tokens=6)

    postproc_settings = dict(_num_postprocess_workers=2,
                             _postprocess_tokenizer_dir=llama_model_path)

    llm_test_harness(llama_model_path,
                     prompts, ["D E F G H I J K"],
                     sampling_params=sampling_params,
                     kv_cache_config=global_kvcache_config,
                     tensor_parallel_size=tp_size,
                     **postproc_settings)


def test_llm_with_postprocess_parallel():
    run_llm_with_postprocess_parallel(tp_size=1)


def run_llm_with_postprocess_parallel_and_result_handler(
        streaming, backend, tp_size: int = 1):
    sampling_params = SamplingParams(max_tokens=6)
    from tensorrt_llm.executor.postproc_worker import (PostprocArgs,
                                                       PostprocParams)
    post_proc_args = PostprocArgs(tokenizer=llama_model_path)
    post_proc_params = PostprocParams(
        post_processor=perform_faked_oai_postprocess,
        postproc_args=post_proc_args)
    llm = LLM(model=llama_model_path,
              backend=backend,
              kv_cache_config=global_kvcache_config,
              tensor_parallel_size=tp_size,
              _num_postprocess_workers=2,
              _postprocess_tokenizer_dir=llama_model_path,
              fast_build=True)
    golden_result = "DEFGHI"
    for i, output in enumerate(
            llm.generate_async(prompts[0],
                               sampling_params=sampling_params,
                               _postproc_params=post_proc_params,
                               streaming=streaming)):
        if i < len(golden_result) - 1:
            assert golden_result[i] in output.outputs[0]._postprocess_result[-1]
        else:
            assert golden_result[i] in output.outputs[0]._postprocess_result[
                -2]  # EOS


@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("backend", [None, "pytorch"])
def test_llm_with_postprocess_parallel_and_result_handler(streaming, backend):
    run_llm_with_postprocess_parallel_and_result_handler(streaming,
                                                         backend,
                                                         tp_size=1)
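
# Note on the postprocessing tests above: _num_postprocess_workers=2 is assumed
# to offload result post-processing (e.g. detokenization and the faked OAI
# formatting used here) to separate worker processes, which is why
# _postprocess_tokenizer_dir is provided so those workers can load the
# tokenizer themselves.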


def run_llm_abort_request(llm: LLM, sampling_params: SamplingParams):
    # Make the generation long enough that the abort takes effect before the
    # request completes
    sampling_params.max_tokens = 100
    sampling_params.end_id = -1  # let it run for a while

    async def task():
        result = llm.generate_async(prompts[0],
                                    sampling_params=sampling_params,
                                    streaming=True)
        print("to abort")
        result.abort()

        print("waiting for the result")
        # Before the abort actually takes effect, some outputs may still arrive
        outputs = []
        async for output in result:
            print(f"get output: {output}")
            outputs.append(output)
        print(f"get {len(outputs)} remaining outputs")
        print(f"outputs: {outputs}")
        print(f"finish_reason: {outputs[-1].outputs[0].finish_reason}")
        assert 1 <= len(
            outputs) < 1000  # the request should be aborted before completing
        # NOTE: known issue: only the last output is finished and gets the finish_reason
        assert outputs[-1].outputs[-1].finish_reason == "cancelled"

    asyncio.run(task())


sampling_params_for_aborting_request = [
    SamplingParams(),
    # n-returns
    SamplingParams(n=2, top_k=2),
    SamplingParams(n=2, top_k=2, best_of=3),
    SamplingParams(n=3, use_beam_search=True),
    SamplingParams(n=2, best_of=3, use_beam_search=True),
]


@force_ampere
@pytest.mark.parametrize("sampling_params",
                         sampling_params_for_aborting_request)
def test_llm_abort_request(llm_for_sampling_params,
                           sampling_params: SamplingParams):
    run_llm_abort_request(llm=llm_for_sampling_params,
                          sampling_params=sampling_params)


@force_ampere
@pytest.mark.parametrize(
    "sampling_params",
    [
        SamplingParams()  # pytorch only supports n=1
    ])
def test_llm_abort_request_pytorch(sampling_params):
    from tensorrt_llm._torch import LLM as LLM_torch
    llm = LLM_torch(model=llama_model_path,
                    kv_cache_config=global_kvcache_config)
    run_llm_abort_request(llm=llm, sampling_params=sampling_params)


def test_llm_reward_model_pytorch():
    rm_model_path = get_model_path("Qwen2.5-Math-PRM-7B")
    tokenizer = TransformersTokenizer.from_pretrained(rm_model_path)
    tokenized_input = tokenizer(prompts, return_tensors="pt")["input_ids"]

    from tensorrt_llm._torch import LLM as LLM_torch
    from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
    llm = LLM_torch(
        model=rm_model_path,
        pytorch_backend_config=PyTorchConfig(attn_backend="VANILLA"))

    sampling_params = SamplingParams(return_context_logits=True)

    outputs = llm.generate(prompts, sampling_params)
    scores = outputs[0].context_logits

    print(scores)

    assert scores.shape == (tokenized_input.shape[1], 2)
    assert not outputs[0].outputs[0].text


def test_llm_sampling_params_n_lt_max_batch_size():
    sampling_params = SamplingParams(n=2, best_of=1)
    build_config = BuildConfig(max_batch_size=1, max_seq_len=1024)
    llm = LLM(model=llama_model_path,
              kv_cache_config=global_kvcache_config,
              build_config=build_config,
              fast_build=True)

    with pytest.raises(ValueError):
        llm.generate_async(prompts[0], sampling_params=sampling_params)


def test_llm_api_draft_target():
    sampling_params = SamplingParams(max_tokens=4)

    build_config = BuildConfig(
        speculative_decoding_mode=SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL,
        max_draft_len=4,
        max_batch_size=2,
        max_beam_width=1,
        max_seq_len=128,
        max_num_tokens=64)

    llm = LLM(llama_model_path,
              build_config=build_config,
              kv_cache_config=global_kvcache_config,
              fast_build=True)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    test_llm_with_postprocess_parallel_and_result_handler(True, "pytorch")