# tests/integration/defs/examples/test_jais.py

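"""Integration tests for the Jais examples.

Each test converts a Jais chat checkpoint with the GPT example scripts,
builds a TensorRT-LLM engine with `trtllm-build`, and runs the summarize
example, checking the rouge1 score against a fixed threshold."""
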
import os

import pytest
from defs.common import (convert_weights, generate_summary_cmd, venv_check_call,
                         venv_mpi_check_call)
from defs.trt_test_alternative import check_call
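

# Single-GPU summarization test: sweeps the attention plugin, GEMM plugin,
# context-FMHA mode, and compute dtype; requires a device with at least
# 80 GB of memory.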
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize(
    "use_attention_plugin",
    [True, False],
    ids=["enable_attention_plugin", "disable_attention_plugin"],
)
@pytest.mark.parametrize(
    "use_gemm_plugin",
    [True, False],
    ids=["enable_gemm_plugin", "disable_gemm_plugin"],
)
@pytest.mark.parametrize(
    "jais_model_root",
    ["jais-13b-chat"],
    indirect=True,
)
@pytest.mark.parametrize(
    "context_fmha_type",
    ['enabled', 'enabled_with_fp32_acc', 'disabled'],
    ids=[
        "enable_context_fmha", "enable_context_fmha_with_fp32_acc",
        "disable_context_fmha"
    ],
)
@pytest.mark.parametrize(
    "dtype",
    ["float16", "bfloat16"],
)
def test_llm_jais_single_gpu_summary(jais_example_root, cmodel_dir,
                                     jais_model_root, llm_venv, engine_dir,
                                     use_attention_plugin, use_gemm_plugin,
                                     context_fmha_type, dtype):
    context_fmha = (context_fmha_type == 'enabled')
    context_fmha_fp32_acc = (context_fmha_type == 'enabled_with_fp32_acc')
    max_input_len = max_output_len = 512
    max_batch_size = 32
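    # Context FMHA is implemented inside the GPT attention plugin, so it
    # cannot be used when the plugin is disabled.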
    if (not use_attention_plugin) and (context_fmha or context_fmha_fp32_acc):
        pytest.skip(
            "invalid combination: when the attention plugin is disabled, "
            "context_fmha cannot be used")
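    # Jais checkpoints are converted with the GPT example's conversion script.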
    model_name = os.path.basename(jais_model_root)
    model_dir = convert_weights(llm_venv=llm_venv,
                                example_root=f"{jais_example_root}/../gpt",
                                cmodel_dir=cmodel_dir,
                                model=model_name,
                                model_path=jais_model_root,
                                data_type=dtype)
print("Building engines...")
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={model_dir}",
f"--max_batch_size={max_batch_size}",
f"--max_input_len={max_input_len}",
f"--max_seq_len={max_output_len + max_input_len}",
f"--output_dir={engine_dir}",
]
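    # Plugin flags take the compute dtype when enabled; context FMHA is a
    # plain enable/disable switch.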
    if use_attention_plugin:
        build_cmd.append(f"--gpt_attention_plugin={dtype}")
    if use_gemm_plugin:
        build_cmd.append(f"--gemm_plugin={dtype}")
    else:
        build_cmd.append("--gemm_plugin=disable")
    if context_fmha:
        build_cmd.append("--context_fmha=enable")
    else:
        build_cmd.append("--context_fmha=disable")
    check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)

    print("Run engines...")
    summary_cmd = generate_summary_cmd(
        jais_example_root,
        hf_model_dir=jais_model_root,
        engine_dir=engine_dir,
        data_type=dtype,
        max_input_length=max_input_len,
        output_len=max_output_len,
        batch_size=max_batch_size,
        tensorrt_llm_rouge1_threshold=19,
        eval_task="summarize",
    )
    if context_fmha_fp32_acc:
        summary_cmd.append("--enable_context_fmha_fp32_acc")
    venv_check_call(llm_venv, summary_cmd)
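

# 2-GPU variant for the 30B chat checkpoint: the compute dtype is fixed to
# float16 and the weights are converted for two ranks.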
@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize(
    "use_attention_plugin",
    [True, False],
    ids=["enable_attention_plugin", "disable_attention_plugin"],
)
@pytest.mark.parametrize(
    "use_gemm_plugin",
    [True, False],
    ids=["enable_gemm_plugin", "disable_gemm_plugin"],
)
@pytest.mark.parametrize(
    "jais_model_root",
    ["jais-30b-chat-v3"],
    indirect=True,
)
@pytest.mark.parametrize(
    "context_fmha_type",
    ['enabled', 'enabled_with_fp32_acc', 'disabled'],
    ids=[
        "enable_context_fmha", "enable_context_fmha_with_fp32_acc",
        "disable_context_fmha"
    ],
)
def test_llm_jais_1node_2gpus_summary(jais_example_root, cmodel_dir,
                                      jais_model_root, llm_venv, engine_dir,
                                      use_attention_plugin, use_gemm_plugin,
                                      context_fmha_type):
dtype = "float16"
context_fmha = (context_fmha_type == 'enabled')
context_fmha_fp32_acc = (context_fmha_type == 'enabled_with_fp32_acc')
max_input_len = max_output_len = 512
max_batch_size = 32
    if (not use_attention_plugin) and (context_fmha or context_fmha_fp32_acc):
        pytest.skip(
            "invalid combination: when the attention plugin is disabled, "
            "context_fmha cannot be used")
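    # Convert the checkpoint for two ranks (gpus=2) with the GPT example's
    # conversion script.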
    model_name = os.path.basename(jais_model_root)
    model_dir = convert_weights(llm_venv=llm_venv,
                                example_root=f"{jais_example_root}/../gpt",
                                cmodel_dir=cmodel_dir,
                                model=model_name,
                                model_path=jais_model_root,
                                data_type=dtype,
                                gpus=2)
print("Building engines...")
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={model_dir}",
f"--max_batch_size={max_batch_size}",
f"--max_input_len={max_input_len}",
f"--max_seq_len={max_output_len + max_input_len}",
f"--output_dir={engine_dir}",
f"--workers={2}",
]
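    # Features implemented inside the attention plugin (context FMHA, paged
    # KV cache, packed/padding-removed inputs) must be disabled with it.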
    if use_attention_plugin:
        build_cmd.append(f"--gpt_attention_plugin={dtype}")
        if context_fmha:
            build_cmd.append("--context_fmha=enable")
        else:
            build_cmd.append("--context_fmha=disable")
    else:
        build_cmd.extend([
            "--gpt_attention_plugin=disable",
            "--context_fmha=disable",
            "--paged_kv_cache=disable",
            "--remove_input_padding=disable",
        ])
    if use_gemm_plugin:
        build_cmd.append(f"--gemm_plugin={dtype}")
    else:
        build_cmd.append("--gemm_plugin=disable")
    check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)

    print("Run engines...")
    summary_cmd = generate_summary_cmd(
        jais_example_root,
        hf_model_dir=jais_model_root,
        engine_dir=engine_dir,
        data_type=dtype,
        max_input_length=max_input_len,
        output_len=max_output_len,
        batch_size=max_batch_size,
        tensorrt_llm_rouge1_threshold=19,
        eval_task="summarize",
    )
    if context_fmha_fp32_acc:
        summary_cmd.append("--enable_context_fmha_fp32_acc")
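    # Run the summarize example under MPI, one rank per GPU.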
    venv_mpi_check_call(llm_venv, ["mpirun", "-n", "2", "--allow-run-as-root"],
                        summary_cmd)