import os

import pytest
from defs.common import (convert_weights, generate_summary_cmd,
                         venv_check_call, venv_mpi_check_call)
from defs.trt_test_alternative import check_call


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.parametrize(
    "use_attention_plugin",
    [True, False],
    ids=["enable_attention_plugin", "disable_attention_plugin"],
)
@pytest.mark.parametrize(
    "use_gemm_plugin",
    [True, False],
    ids=["enable_gemm_plugin", "disable_gemm_plugin"],
)
@pytest.mark.parametrize(
    "jais_model_root",
    ["jais-13b-chat"],
    indirect=True,
)
@pytest.mark.parametrize(
    "context_fmha_type",
    ['enabled', 'enabled_with_fp32_acc', 'disabled'],
    ids=[
        "enable_context_fmha", "enable_context_fmha_with_fp32_acc",
        "disable_context_fmha"
    ],
)
@pytest.mark.parametrize(
    "dtype",
    ["float16", "bfloat16"],
)
def test_llm_jais_single_gpu_summary(jais_example_root, cmodel_dir,
                                     jais_model_root, llm_venv, engine_dir,
                                     use_attention_plugin, use_gemm_plugin,
                                     context_fmha_type, dtype):
    context_fmha = (context_fmha_type == 'enabled')
    context_fmha_fp32_acc = (context_fmha_type == 'enabled_with_fp32_acc')
    max_input_len = max_output_len = 512
    max_batch_size = 32

    if (not use_attention_plugin) and (context_fmha or context_fmha_fp32_acc):
        pytest.skip(
            "Invalid combination: when the attention plugin is disabled, "
            "context FMHA cannot be used.")

    model_name = os.path.basename(jais_model_root)
    model_dir = convert_weights(llm_venv=llm_venv,
                                example_root=f"{jais_example_root}/../gpt",
                                cmodel_dir=cmodel_dir,
                                model=model_name,
                                model_path=jais_model_root,
                                data_type=dtype)

    print("Building engines...")
    build_cmd = [
        "trtllm-build",
        f"--checkpoint_dir={model_dir}",
        f"--max_batch_size={max_batch_size}",
        f"--max_input_len={max_input_len}",
        f"--max_seq_len={max_output_len + max_input_len}",
        f"--output_dir={engine_dir}",
    ]
    if use_attention_plugin:
        build_cmd.append(f"--gpt_attention_plugin={dtype}")
    else:
        # The attention plugin must be disabled explicitly, together with the
        # features that depend on it (mirrors the multi-GPU test below).
        build_cmd.extend([
            "--gpt_attention_plugin=disable",
            "--context_fmha=disable",
            "--paged_kv_cache=disable",
            "--remove_input_padding=disable",
        ])
    if use_gemm_plugin:
        build_cmd.append(f"--gemm_plugin={dtype}")
    else:
        build_cmd.append("--gemm_plugin=disable")
    if context_fmha:
        build_cmd.append("--context_fmha=enable")
    else:
        build_cmd.append("--context_fmha=disable")

    check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)

    print("Run engines...")
    summary_cmd = generate_summary_cmd(jais_example_root,
                                       hf_model_dir=jais_model_root,
                                       engine_dir=engine_dir,
                                       data_type=dtype,
                                       max_input_length=max_input_len,
                                       output_len=max_output_len,
                                       batch_size=max_batch_size,
                                       tensorrt_llm_rouge1_threshold=19,
                                       eval_task="summarize")
    if context_fmha_fp32_acc:
        summary_cmd.append("--enable_context_fmha_fp32_acc")

    venv_check_call(llm_venv, summary_cmd)


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize(
    "use_attention_plugin",
    [True, False],
    ids=["enable_attention_plugin", "disable_attention_plugin"],
)
@pytest.mark.parametrize(
    "use_gemm_plugin",
    [True, False],
    ids=["enable_gemm_plugin", "disable_gemm_plugin"],
)
@pytest.mark.parametrize(
    "jais_model_root",
    ["jais-30b-chat-v3"],
    indirect=True,
)
@pytest.mark.parametrize(
    "context_fmha_type",
    ['enabled', 'enabled_with_fp32_acc', 'disabled'],
    ids=[
        "enable_context_fmha", "enable_context_fmha_with_fp32_acc",
        "disable_context_fmha"
    ],
)
def test_llm_jais_1node_2gpus_summary(jais_example_root, cmodel_dir,
                                      jais_model_root, llm_venv, engine_dir,
                                      use_attention_plugin, use_gemm_plugin,
                                      context_fmha_type):
    dtype = "float16"
    context_fmha = (context_fmha_type == 'enabled')
    context_fmha_fp32_acc = (context_fmha_type == 'enabled_with_fp32_acc')
    max_input_len = max_output_len = 512
    max_batch_size = 32

    if (not use_attention_plugin) and (context_fmha or context_fmha_fp32_acc):
        pytest.skip(
            "Invalid combination: when the attention plugin is disabled, "
            "context FMHA cannot be used.")

    model_name = os.path.basename(jais_model_root)
    model_dir = convert_weights(llm_venv=llm_venv,
                                example_root=f"{jais_example_root}/../gpt",
                                cmodel_dir=cmodel_dir,
                                model=model_name,
                                model_path=jais_model_root,
                                data_type=dtype,
                                gpus=2)

    print("Building engines...")
    build_cmd = [
        "trtllm-build",
        f"--checkpoint_dir={model_dir}",
        f"--max_batch_size={max_batch_size}",
        f"--max_input_len={max_input_len}",
        f"--max_seq_len={max_output_len + max_input_len}",
        f"--output_dir={engine_dir}",
        "--workers=2",
    ]
    if use_attention_plugin:
        build_cmd.append(f"--gpt_attention_plugin={dtype}")
    else:
        # Disabling the attention plugin also requires disabling the features
        # that depend on it.
        build_cmd.extend([
            "--gpt_attention_plugin=disable",
            "--context_fmha=disable",
            "--paged_kv_cache=disable",
            "--remove_input_padding=disable",
        ])
    if use_gemm_plugin:
        build_cmd.append(f"--gemm_plugin={dtype}")
    else:
        build_cmd.append("--gemm_plugin=disable")
    if context_fmha:
        build_cmd.append("--context_fmha=enable")
    else:
        build_cmd.append("--context_fmha=disable")

    check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env)

    print("Run engines...")
    summary_cmd = generate_summary_cmd(jais_example_root,
                                       hf_model_dir=jais_model_root,
                                       engine_dir=engine_dir,
                                       data_type=dtype,
                                       max_input_length=max_input_len,
                                       output_len=max_output_len,
                                       batch_size=max_batch_size,
                                       tensorrt_llm_rouge1_threshold=19,
                                       eval_task="summarize")
    if context_fmha_fp32_acc:
        summary_cmd.append("--enable_context_fmha_fp32_acc")

    venv_mpi_check_call(llm_venv, ["mpirun", "-n", "2", "--allow-run-as-root"],
                        summary_cmd)