TensorRT-LLMs/tests/bindings/test_bindings.py
Kaiyu Xie f044eb8d94
Update TensorRT-LLM (#302)
* Update TensorRT-LLM

---------

Co-authored-by: wangruohui <12756472+wangruohui@users.noreply.github.com>
2023-11-07 19:51:58 +08:00

344 lines
13 KiB
Python

import json
import tempfile
from pathlib import Path
import torch
import tensorrt_llm.bindings as _tb
def test_generation_output():
    """Check the tensor fields exposed by ``_tb.GenerationOutput``.

    The object is built from ``ids`` and ``lengths``; ``log_probs`` and
    ``context_logits`` are expected to be ``None`` until assigned and to
    return the assigned tensor afterwards.
    """
    ids = torch.ones(1)
    lengths = torch.ones(2)
    gen_output = _tb.GenerationOutput(ids, lengths)
    assert torch.equal(gen_output.ids, ids)
    assert torch.equal(gen_output.lengths, lengths)

    assert gen_output.log_probs is None
    log_probs = torch.ones(1)
    gen_output.log_probs = log_probs
    # torch.equal compares shape and values and returns a plain bool, matching
    # the tensor comparisons used elsewhere in this file.
    assert torch.equal(gen_output.log_probs, log_probs)

    assert gen_output.context_logits is None
    # Use a dedicated tensor whose values differ from log_probs so that a
    # binding that mixed up the two fields would be detected.
    context_logits = torch.zeros(1)
    gen_output.context_logits = context_logits
    assert torch.equal(gen_output.context_logits, context_logits)
def test_generation_input():
    """Exercise the constructor arguments and optional fields of ``_tb.GenerationInput``.

    Covers the positional constructor values, the optional ``max_new_tokens``
    and ``embedding_bias`` fields, and the nested ``PromptTuningParams``.
    """
    expected_end_id, expected_pad_id = 42, 13
    input_ids = torch.ones(1)
    input_lengths = torch.ones(2)
    is_packed = True

    gen_input = _tb.GenerationInput(expected_end_id, expected_pad_id,
                                    input_ids, input_lengths, is_packed)
    assert gen_input.end_id == expected_end_id
    assert gen_input.pad_id == expected_pad_id
    assert torch.equal(gen_input.ids, input_ids)
    assert torch.equal(gen_input.lengths, input_lengths)
    assert gen_input.packed == is_packed

    # Optional scalar field: unset first, then round-trips the assigned value.
    assert gen_input.max_new_tokens is None
    token_budget = 100
    gen_input.max_new_tokens = token_budget
    assert gen_input.max_new_tokens == token_budget

    # Optional tensor field.
    assert gen_input.embedding_bias is None
    bias = torch.ones(3)
    gen_input.embedding_bias = bias
    assert torch.equal(gen_input.embedding_bias, bias)

    # The default prompt-tuning parameters carry no tensors.
    for field in ("embedding_table", "tasks", "vocab_size"):
        assert getattr(gen_input.prompt_tuning_params, field) is None

    table = torch.ones(3)
    task_ids = torch.ones(2)
    vocab = torch.ones(1)
    tuning = _tb.PromptTuningParams(embedding_table=table,
                                    tasks=task_ids,
                                    vocab_size=vocab)
    assert len(tuning.prompt_tuning_enabled) == 0
    enabled_flags = [True, False]
    tuning.prompt_tuning_enabled = enabled_flags
    assert len(tuning.prompt_tuning_enabled) == 2
    assert tuning.prompt_tuning_enabled == enabled_flags

    # Attach the populated parameters and read everything back via gen_input.
    gen_input.prompt_tuning_params = tuning
    assert gen_input.prompt_tuning_params is not None
    assert torch.equal(gen_input.prompt_tuning_params.embedding_table, table)
    assert torch.equal(gen_input.prompt_tuning_params.tasks, task_ids)
    assert torch.equal(gen_input.prompt_tuning_params.vocab_size, vocab)
    assert gen_input.prompt_tuning_params.prompt_tuning_enabled == enabled_flags
def test_gpt_session_config():
    """Exercise ``_tb.KvCacheConfig`` and ``_tb.GptSessionConfig`` properties.

    Each optional property is checked for its initial state, assigned, and
    read back; the nested KV-cache config is also reset to ``None`` through
    the session config.
    """
    # Stand-alone KV cache configuration.
    cache_cfg = _tb.KvCacheConfig()
    assert cache_cfg.max_tokens is None
    token_limit = 13
    cache_cfg.max_tokens = token_limit
    assert cache_cfg.max_tokens == token_limit

    assert cache_cfg.free_gpu_memory_fraction is None
    memory_fraction = 0.5
    cache_cfg.free_gpu_memory_fraction = memory_fraction
    assert cache_cfg.free_gpu_memory_fraction == memory_fraction

    # Session configuration built from its three positional arguments.
    batch_limit = 1000
    beam_limit = 64
    sequence_limit = 2**20
    session_cfg = _tb.GptSessionConfig(batch_limit, beam_limit, sequence_limit)
    assert session_cfg.max_batch_size == batch_limit
    assert session_cfg.max_beam_width == beam_limit
    assert session_cfg.max_sequence_length == sequence_limit

    # A fresh session config owns an empty KV cache config.
    assert session_cfg.kv_cache_config is not None
    assert session_cfg.kv_cache_config.max_tokens is None
    assert session_cfg.kv_cache_config.free_gpu_memory_fraction is None

    # Replace it with the populated one, then clear the values again.
    session_cfg.kv_cache_config = cache_cfg
    assert session_cfg.kv_cache_config.max_tokens == token_limit
    assert session_cfg.kv_cache_config.free_gpu_memory_fraction == memory_fraction
    session_cfg.kv_cache_config.max_tokens = None
    assert session_cfg.kv_cache_config.max_tokens is None
    session_cfg.kv_cache_config.free_gpu_memory_fraction = None
    assert session_cfg.kv_cache_config.free_gpu_memory_fraction is None

    # Boolean switches start false.
    assert not session_cfg.decoder_per_request
    session_cfg.decoder_per_request = True
    assert session_cfg.decoder_per_request

    assert not session_cfg.cuda_graph_mode
    session_cfg.cuda_graph_mode = True
    assert session_cfg.cuda_graph_mode

    # Optional micro batch sizes.
    assert session_cfg.ctx_micro_batch_size is None
    ctx_batch = 10
    session_cfg.ctx_micro_batch_size = ctx_batch
    assert session_cfg.ctx_micro_batch_size == ctx_batch

    assert session_cfg.gen_micro_batch_size is None
    gen_batch = 20
    session_cfg.gen_micro_batch_size = gen_batch
    assert session_cfg.gen_micro_batch_size == gen_batch
def test_quant_mode():
    """Check the ``_tb.QuantMode`` factories, flag properties and operators."""
    assert _tb.QuantMode.none().value == 0

    # Every factory ``QuantMode.<flag>()`` must report ``has_<flag>``.
    flag_names = (
        "int4_weights",
        "int8_weights",
        "activations",
        "per_channel_scaling",
        "per_token_scaling",
        "per_group_scaling",
        "int8_kv_cache",
        "fp8_kv_cache",
        "fp8_qdq",
    )
    for flag in flag_names:
        mode = getattr(_tb.QuantMode, flag)()
        assert getattr(mode, f"has_{flag}")

    # Eight positional ``True`` arguments, as in the binding call under test.
    combined = _tb.QuantMode.from_description(*([True] * 8))
    assert combined.has_int4_weights

    # In-place subtraction and addition toggle a single flag.
    combined -= _tb.QuantMode.int4_weights()
    assert not combined.has_int4_weights
    combined += _tb.QuantMode.int4_weights()
    assert combined.has_int4_weights

    assert _tb.QuantMode.none() == _tb.QuantMode.none()
def test_gpt_model_config():
    """Exercise the properties of ``_tb.GptModelConfig``.

    Constructor values are read back first; every writable property is then
    checked for its initial value, assigned, and read back. The statement
    order follows the original test because later checks (for example
    ``supports_inflight_batching``) run after the plugin flags were enabled.
    """
    n_vocab = 10000
    n_layers = 12
    n_heads = 16
    d_hidden = 768
    dtype = _tb.DataType.FLOAT
    cfg = _tb.GptModelConfig(n_vocab, n_layers, n_heads, d_hidden, dtype)

    assert cfg.vocab_size == n_vocab
    assert cfg.num_layers() == n_layers
    assert cfg.num_heads == n_heads
    assert cfg.hidden_size == d_hidden
    assert cfg.data_type == dtype
    assert cfg.vocab_size_padded(1) is not None
    assert cfg.size_per_head == d_hidden // n_heads

    # The number of KV heads initially equals the number of attention heads.
    assert cfg.num_kv_heads == n_heads
    kv_heads = 1
    cfg.num_kv_heads = kv_heads
    assert cfg.num_kv_heads == kv_heads

    # Boolean plugin switches: off at first, on after assignment.
    for switch in ("use_gpt_attention_plugin", "use_packed_input",
                   "use_paged_kv_cache"):
        assert not getattr(cfg, switch)
        setattr(cfg, switch, True)
        assert getattr(cfg, switch)

    assert cfg.tokens_per_block == 64
    block_tokens = 1024
    cfg.tokens_per_block = block_tokens
    assert cfg.tokens_per_block == block_tokens

    assert cfg.quant_mode == _tb.QuantMode.none()
    cfg.quant_mode = _tb.QuantMode.int4_weights()
    assert cfg.quant_mode.has_int4_weights

    assert cfg.supports_inflight_batching

    assert cfg.max_batch_size == 0
    batch_limit = 1000
    cfg.max_batch_size = batch_limit
    assert cfg.max_batch_size == batch_limit

    assert cfg.max_input_len == 0
    input_limit = 2048
    cfg.max_input_len = input_limit
    assert cfg.max_input_len == input_limit

    assert cfg.max_num_tokens is None
    token_limit = 10000
    cfg.max_num_tokens = token_limit
    assert cfg.max_num_tokens == token_limit

    assert not cfg.compute_context_logits
    cfg.compute_context_logits = True
    assert cfg.compute_context_logits

    assert cfg.model_variant == _tb.GptModelVariant.GPT
    variant = _tb.GptModelVariant.GLM
    cfg.model_variant = variant
    assert cfg.model_variant == variant

    assert not cfg.use_custom_all_reduce
    cfg.use_custom_all_reduce = True
    assert cfg.use_custom_all_reduce
def test_world_config():
    """Check ``_tb.WorldConfig`` built explicitly and through ``WorldConfig.mpi``."""
    tp_size = 2
    pp_size = 4
    rank = 3
    node_gpus = 10

    explicit = _tb.WorldConfig(tp_size, pp_size, rank, node_gpus)
    # Constructor arguments are exposed unchanged.
    assert explicit.tensor_parallelism == tp_size
    assert explicit.pipeline_parallelism == pp_size
    assert explicit.rank == rank
    assert explicit.gpus_per_node == node_gpus
    # Derived quantities asserted against the formulas used by this test.
    assert explicit.size == tp_size * pp_size
    assert explicit.is_pipeline_parallel
    assert explicit.is_tensor_parallel
    assert explicit.device == rank % node_gpus
    assert explicit.pipeline_parallel_rank == rank // tp_size
    assert explicit.tensor_parallel_rank == rank % tp_size

    # The MPI factory is expected to describe a single-rank world here.
    from_mpi = _tb.WorldConfig.mpi(node_gpus)
    assert from_mpi.tensor_parallelism == 1
    assert from_mpi.pipeline_parallelism == 1
    assert from_mpi.gpus_per_node == node_gpus
    assert from_mpi.rank == 0
def test_sampling_config():
    """Check that each optional ``_tb.SamplingConfig`` member starts unset and stores a list."""
    beam_width = 12
    cfg = _tb.SamplingConfig(beam_width)
    assert cfg.beam_width == 12

    floats = [1., 2., 3.]
    sizes = [1, 2, 3]
    # (member name, value to assign), in the order the members are exercised.
    optional_members = (
        ("temperature", floats),
        ("min_length", sizes),
        ("repetition_penalty", floats),
        ("presence_penalty", floats),
        ("top_k", sizes),
        ("top_p", floats),
        ("random_seed", sizes),
        ("top_p_decay", floats),
        ("top_p_min", floats),
        ("top_p_reset_ids", sizes),
        ("beam_search_diversity_rate", floats),
        ("length_penalty", floats),
    )
    for member, value in optional_members:
        # Unset before assignment, equal to the assigned list afterwards.
        assert getattr(cfg, member) is None
        setattr(cfg, member, value)
        assert getattr(cfg, member) == value
def test_gpt_json_config():
    """Check construction, parsing and engine file naming of ``_tb.GptJsonConfig``.

    The config is first built directly from keyword arguments and its
    properties are compared with the inputs. An equivalent JSON document is
    then parsed from a string and from a file, and the engine file name
    derived from the file-parsed config is asserted.
    """
    model_config = {
        "vocab_size": 1000,
        "num_layers": 12,
        "num_heads": 4,
        "hidden_size": 512,
        "data_type": _tb.DataType.FLOAT,
    }
    gpt_model_config = _tb.GptModelConfig(**model_config)
    json_config = {
        "name": "gpt",
        "precision": "float32",
        "tensor_parallelism": 1,
        "pipeline_parallelism": 1,
        "model_config": gpt_model_config
    }
    gpt_json_config = _tb.GptJsonConfig(**json_config)

    def check_properties(the_object, properties, model_config):
        """Assert that ``the_object`` exposes every entry of ``properties``.

        A ``GptModelConfig`` value is compared member by member against
        ``model_config``; members that are callable are called without
        arguments before the comparison.
        """
        # ``name``/``sub_name`` avoid shadowing the ``property`` builtin.
        for name, value in properties.items():
            if isinstance(value, _tb.GptModelConfig):
                object_config = getattr(the_object, name)
                for sub_name, sub_value in model_config.items():
                    member = getattr(object_config, sub_name)
                    if callable(member):
                        member = member()
                    assert member == sub_value
            else:
                assert getattr(the_object, name) == value

    check_properties(gpt_json_config, json_config, model_config)

    json_dict = {
        "builder_config": {
            "name": json_config["name"],
            "vocab_size": model_config["vocab_size"],
            "num_layers": model_config["num_layers"],
            "num_heads": model_config["num_heads"],
            "hidden_size": model_config["hidden_size"],
            "precision": json_config["precision"],
            "tensor_parallel": json_config["tensor_parallelism"],
            "pipeline_parallel": json_config["pipeline_parallelism"],
        },
        "plugin_config": {
            "paged_kv_cache": False,
            "tokens_per_block": 0,
            "gpt_attention_plugin": False,
            "remove_input_padding": False,
            "use_custom_all_reduce": False,
        }
    }

    gpt_json_config = _tb.GptJsonConfig.parse(json.dumps(json_dict))

    # A temporary directory is removed even when parse_file raises, so a
    # failing test no longer leaves a stray file behind. The file is fully
    # written and closed before the binding opens it by name.
    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = Path(tmp_dir) / "config.json"
        config_path.write_text(json.dumps(json_dict), encoding="utf-8")
        gpt_json_config = _tb.GptJsonConfig.parse_file(str(config_path))

    rank = 3
    gpus_per_node = 10
    world_config = _tb.WorldConfig(json_config["tensor_parallelism"],
                                   json_config["pipeline_parallelism"], rank,
                                   gpus_per_node)

    # Without an explicit model name the configured name is used as prefix.
    assert gpt_json_config.engine_filename(
        world_config) == json_config["name"] + "_float32_tp1_rank3.engine"
    assert gpt_json_config.engine_filename(
        world_config, "llama") == "llama_float32_tp1_rank3.engine"