Add llama4 disagg accuracy tests (#4336)

* Add llama4 disagg accuracy tests

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

* Make it async and add GSM8K benchmark

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>

---------

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
This commit is contained in:
Iman Tabrizian 2025-05-19 09:55:08 -04:00 committed by GitHub
parent 001704cc6a
commit c6074c47da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 107 additions and 31 deletions

View File

@ -72,14 +72,15 @@ class Evaluator(ABC):
outputs.append(output)
references.append(reference)
auxiliaries.append(aux)
results = []
for output in tqdm(outputs, desc="Fetching responses"):
output.result()
results.append(output.result())
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
profiler.reset("trtllm exec")
score = self.compute_score(outputs, references, *zip(*auxiliaries))
score = self.compute_score(results, references, *zip(*auxiliaries))
return score
@staticmethod

View File

@ -96,7 +96,7 @@ class LmEvalWrapper(TemplateLM):
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
profiler.start("trtllm exec")
outputs = []
results = []
for request in tqdm(requests,
desc="Submitting requests",
disable=disable_tqdm):
@ -104,12 +104,14 @@ class LmEvalWrapper(TemplateLM):
sampling_params = self._get_sampling_params(gen_kwargs)
output = self.llm.generate_async(prompt,
sampling_params=sampling_params)
outputs.append(output)
results.append(output)
for output in tqdm(outputs,
outputs = []
for output in tqdm(results,
desc="Fetching responses",
disable=disable_tqdm):
output.result()
outputs.append(output.result())
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")

View File

@ -8,6 +8,7 @@ import shutil
import subprocess
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
import openai
@ -20,7 +21,7 @@ from tensorrt_llm.executor.result import GenerationResultBase
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
from ..conftest import llm_models_root
from .accuracy_core import MMLU, LlmapiAccuracyTestHarness
from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
class Result(GenerationResultBase):
@ -41,10 +42,15 @@ class Result(GenerationResultBase):
class OpenAIServerClient:
def __init__(self, disaggregated_server_config: Dict[str, Any],
def __init__(self,
disaggregated_server_config: Dict[str, Any],
ctx_server_config: Dict[str, Any],
gen_server_config: Dict[str, Any], model_name: str):
gen_server_config: Dict[str, Any],
model_name: str,
tensor_parallel_size: int = 1):
self.thread_pool = ThreadPoolExecutor(max_workers=16)
self.temp_dir = tempfile.mkdtemp()
self.futures = []
self.disaggregated_serving_config_path = os.path.join(
self.temp_dir, "disaggregated_serving_config.yaml")
with open(self.disaggregated_serving_config_path, "w") as f:
@ -58,18 +64,26 @@ class OpenAIServerClient:
with open(gen_server_config_path, "w") as f:
yaml.dump(gen_server_config, f)
with LLM(model_name) as llm:
with LLM(model_name, tensor_parallel_size=tensor_parallel_size) as llm:
self.args = llm.args
cuda_device_idx = 0
cuda_devices = []
for i in range(tensor_parallel_size):
cuda_devices.append(f"{cuda_device_idx}")
cuda_device_idx += 1
trtllm_serve_path = "trtllm-serve"
# Common arguments for both servers
common_args = [
trtllm_serve_path, model_name, "--host", "localhost", "--backend",
"pytorch"
]
if tensor_parallel_size > 1:
common_args.append(f"--tp_size={tensor_parallel_size}")
env_ctx = os.environ.copy()
env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_devices)
# Start the context server
self._ctx_server = subprocess.Popen(common_args + [
"--port", "8001", "--extra_llm_api_options", ctx_server_config_path
@ -78,6 +92,11 @@ class OpenAIServerClient:
# Start the generation server
env_gen = os.environ.copy()
env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
cuda_devices = []
for i in range(tensor_parallel_size):
cuda_devices.append(f"{cuda_device_idx}")
cuda_device_idx += 1
env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_devices)
self._gen_server = subprocess.Popen(common_args + [
"--port", "8002", "--extra_llm_api_options", gen_server_config_path
],
@ -86,7 +105,8 @@ class OpenAIServerClient:
# Start the disaggregated server
self._disaggregated_server = subprocess.Popen([
trtllm_serve_path, "disaggregated", "-c",
self.disaggregated_serving_config_path
self.disaggregated_serving_config_path, "--server_start_timeout",
"3600"
])
self.model_name = model_name
@ -103,10 +123,7 @@ class OpenAIServerClient:
self.client = openai.OpenAI(api_key="1234567890",
base_url=f"http://localhost:8000/v1")
def generate_async(self,
prompt: str,
sampling_params: Optional[SamplingParams] = None):
# TODO: Make this async
def send_request(self, prompt: str, sampling_params: SamplingParams):
response = self.client.completions.create(
model=self.model_name,
prompt=prompt,
@ -127,7 +144,18 @@ class OpenAIServerClient:
setattr(requested_output, "result", result.result)
return requested_output
def __del__(self):
def generate_async(self,
prompt: str,
sampling_params: Optional[SamplingParams] = None):
future = self.thread_pool.submit(self.send_request, prompt,
sampling_params)
self.futures.append(future)
return future
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.temp_dir)
self._ctx_server.terminate()
self._gen_server.terminate()
@ -137,10 +165,14 @@ class OpenAIServerClient:
self._gen_server.wait()
self._disaggregated_server.wait()
for future in self.futures:
future.result()
self.thread_pool.shutdown(wait=True)
class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"
class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.skip_device_not_contain(["H100", "H200"])
@ -169,8 +201,49 @@ class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
"urls": ["localhost:8002"]
}
}
client = OpenAIServerClient(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH)
task = MMLU(self.MODEL_NAME)
task.evaluate(client)
with OpenAIServerClient(disaggregated_server_config, ctx_server_config,
gen_server_config, self.MODEL_PATH) as client:
task = MMLU(self.MODEL_NAME)
task.evaluate(client)
task = GSM8K(self.MODEL_NAME)
task.evaluate(client)
class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
@pytest.mark.parametrize("overlap_scheduler", [False, True])
def test_auto_dtype(self, overlap_scheduler):
ctx_server_config = {
"pytorch_backend_config": {
"disable_overlap_scheduler": True
}
}
gen_server_config = {
"pytorch_backend_config": {
"disable_overlap_scheduler": overlap_scheduler
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with OpenAIServerClient(disaggregated_server_config,
ctx_server_config,
gen_server_config,
self.MODEL_PATH,
tensor_parallel_size=4) as client:
task = MMLU(self.MODEL_NAME)
task.evaluate(client)
task = GSM8K(self.MODEL_NAME)
task.evaluate(client)

View File

@ -453,8 +453,8 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[throughput]
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[throughput_tp8]
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-]
test_e2e.py::test_llama_e2e[use_py_session-remove_input_padding-]

View File

@ -37,6 +37,8 @@ l0_dgx_h100:
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
- condition:
ranges:
system_gpu_count:

View File

@ -17,5 +17,7 @@ l0_dgx_h200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
# - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] # OOM
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] # 1h
- accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-enable_graph-tp8-trtllm-scout]
- unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora

View File

@ -47,8 +47,6 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]

View File

@ -443,8 +443,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-disable_attention_plugin-disable_context_fmha-tp:2-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5234058)
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:2-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5234058)
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:2-pp:1-float16-RobertaForQuestionAnswering-bert/roberta-base-squad2] SKIP (https://nvbugs/5234058)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False] SKIP (https://nvbugs/5266257)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True] SKIP (https://nvbugs/5266257)
disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5247271)
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugspro.nvidia.com/bug/5273945)
disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5279438)