test: Add disaggregated serving accuracy tests (#4036)

Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Iman Tabrizian 2025-05-05 08:56:59 -07:00 committed by GitHub
parent 5ee38ad92a
commit 85867d76dd
3 changed files with 182 additions and 6 deletions

View File

@@ -14,7 +14,7 @@
 # limitations under the License.
 import random
 from abc import ABC, abstractmethod
-from typing import Iterable, List, Optional, Union
+from typing import Any, Iterable, List, Optional

 import numpy as np
 import torch
@@ -22,8 +22,7 @@ from tqdm import tqdm

 import tensorrt_llm.profiler as profiler

-from .._torch import LLM as PyTorchLLM
-from ..llmapi import LLM, RequestOutput
+from ..llmapi import RequestOutput
 from ..logger import logger
 from ..sampling_params import SamplingParams

@@ -49,8 +48,7 @@ class Evaluator(ABC):
                       *auxiliaries) -> float:
         raise NotImplementedError()

-    def do_apply_chat_template(self, llm: Union[LLM, PyTorchLLM],
-                               prompt: str) -> str:
+    def do_apply_chat_template(self, llm: Any, prompt: str) -> str:
         messages = [{"role": "user", "content": prompt}]
         if self.system_prompt is not None:
             messages = [{
@@ -62,7 +60,7 @@ class Evaluator(ABC):
                                          add_generation_prompt=True)

     def evaluate(self,
-                 llm: Union[LLM, PyTorchLLM],
+                 llm: Any,
                  sampling_params: Optional[SamplingParams] = None) -> float:
         profiler.start("trtllm exec")
         outputs, references, auxiliaries = [], [], []
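
The annotation is relaxed to Any because Evaluator.evaluate() only uses the llm argument through duck typing: any object that exposes a generate_async(prompt, sampling_params) method returning a RequestOutput-compatible result can be evaluated. A minimal sketch of such a client (hypothetical name, not part of this commit; the OpenAIServerClient added below is the real instance of this pattern):

class RemoteLLMClient:
    # Hypothetical duck-typed client: anything with this method can now be
    # passed to Evaluator.evaluate(), since llm is typed as Any.
    def generate_async(self, prompt, sampling_params=None):
        # Must return an object shaped like RequestOutput; see the Result
        # wrapper in the new test file below for one way to build it.
        raise NotImplementedError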

View File

@@ -0,0 +1,176 @@
# Accuracy tests for disaggregated serving. Instead of implementing the actual
# LLM methods, OpenAIServerClient below mimics the LLM class and sends OpenAI
# requests to the disaggregated serving endpoint. See the existing
# test_llm_api_pytorch.py for reference.
import os
import shutil
import subprocess
import tempfile
import time
from typing import Any, Dict, List, Optional

import openai
import pytest
import requests
import yaml

from tensorrt_llm._torch import LLM
from tensorrt_llm.executor.result import GenerationResultBase
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams

from ..conftest import llm_models_root
from .accuracy_core import MMLU, LlmapiAccuracyTestHarness
class Result(GenerationResultBase):

    def __init__(self, id: int, sampling_params: SamplingParams,
                 outputs: List[CompletionOutput]):
        super().__init__(id, sampling_params)
        self._outputs = outputs
        self._streaming = False

    @property
    def outputs(self) -> List[CompletionOutput]:
        return self._outputs

    def result(self):
        return self
class OpenAIServerClient:
    """Mimics the LLM class, but serves requests by sending OpenAI API calls
    to a disaggregated serving endpoint instead of generating in-process."""

    def __init__(self, disaggregated_server_config: Dict[str, Any],
                 ctx_server_config: Dict[str, Any],
                 gen_server_config: Dict[str, Any], model_name: str):
        self.temp_dir = tempfile.mkdtemp()
        self.disaggregated_serving_config_path = os.path.join(
            self.temp_dir, "disaggregated_serving_config.yaml")
        with open(self.disaggregated_serving_config_path, "w") as f:
            yaml.dump(disaggregated_server_config, f)
        ctx_server_config_path = os.path.join(self.temp_dir,
                                              "ctx_server_config.yaml")
        with open(ctx_server_config_path, "w") as f:
            yaml.dump(ctx_server_config, f)
        gen_server_config_path = os.path.join(self.temp_dir,
                                              "gen_server_config.yaml")
        with open(gen_server_config_path, "w") as f:
            yaml.dump(gen_server_config, f)

        with LLM(model_name) as llm:
            self.args = llm.args

        trtllm_serve_path = "trtllm-serve"
        # Common arguments for both servers
        common_args = [
            trtllm_serve_path, model_name, "--host", "localhost", "--backend",
            "pytorch"
        ]

        env_ctx = os.environ.copy()
        env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
        # Start the context server
        self._ctx_server = subprocess.Popen(common_args + [
            "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
        ],
                                            env=env_ctx)
        # Start the generation server
        env_gen = os.environ.copy()
        env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
        self._gen_server = subprocess.Popen(common_args + [
            "--port", "8002", "--extra_llm_api_options", gen_server_config_path
        ],
                                            env=env_gen)
        # Start the disaggregated server
        self._disaggregated_server = subprocess.Popen([
            trtllm_serve_path, "disaggregated", "-c",
            self.disaggregated_serving_config_path
        ])
        self.model_name = model_name

        # Poll the router's health endpoint until all servers are up
        while True:
            time.sleep(1)
            try:
                print("Checking health endpoint")
                response = requests.get("http://localhost:8000/health")
                if response.status_code == 200:
                    break
            except requests.exceptions.ConnectionError:
                continue

        self.client = openai.OpenAI(api_key="1234567890",
                                    base_url="http://localhost:8000/v1")

    def generate_async(self,
                       prompt: str,
                       sampling_params: Optional[SamplingParams] = None):
        # TODO: Make this async
        response = self.client.completions.create(
            model=self.model_name,
            prompt=prompt,
            stream=False,
            **({
                "max_tokens": sampling_params.max_tokens,
                "temperature": sampling_params.temperature,
                "top_p": sampling_params.top_p,
                "stop": sampling_params.stop,
                "seed": sampling_params.seed
            } if sampling_params else {}))
        result = Result(
            id=0,
            sampling_params=sampling_params,
            outputs=[CompletionOutput(text=response.choices[0].text, index=0)])
        requested_output = RequestOutput._from_generation_result(result,
                                                                 prompt=prompt)
        setattr(requested_output, "result", result.result)
        return requested_output

    def __del__(self):
        # Tear down the servers and remove the temporary config directory.
        shutil.rmtree(self.temp_dir)
        self._ctx_server.terminate()
        self._gen_server.terminate()
        self._disaggregated_server.terminate()
        self._ctx_server.wait()
        self._gen_server.wait()
        self._disaggregated_server.wait()
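
generate_async above is synchronous despite its name, hence the TODO. One possible way to make it truly asynchronous, sketched here as a suggestion rather than part of this commit, is the openai AsyncOpenAI client, which exposes the same completions API as awaitable calls:

# Sketch only: an async variant of the method above, assuming the accuracy
# harness can await the returned coroutine.
async def generate_async(self, prompt, sampling_params=None):
    client = openai.AsyncOpenAI(api_key="1234567890",
                                base_url="http://localhost:8000/v1")
    # Wrapping the response into Result/RequestOutput would stay the same
    # as in the synchronous version above.
    return await client.completions.create(model=self.model_name,
                                            prompt=prompt,
                                            stream=False)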
class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
    MODEL_NAME = "meta-llama/Llama-3.1-8B"
    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"

    @pytest.mark.skip_less_device_memory(32000)
    @pytest.mark.skip_device_not_contain(["H100"])
    @pytest.mark.parametrize("overlap_scheduler", [False, True])
    def test_auto_dtype(self, overlap_scheduler):
        ctx_server_config = {
            "pytorch_backend_config": {
                "enable_overlap_scheduler": False
            }
        }
        gen_server_config = {
            "pytorch_backend_config": {
                "enable_overlap_scheduler": overlap_scheduler
            }
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        client = OpenAIServerClient(disaggregated_server_config,
                                    ctx_server_config, gen_server_config,
                                    self.MODEL_PATH)
        task = MMLU(self.MODEL_NAME)
        task.evaluate(client)
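
Each parametrization of test_auto_dtype launches three trtllm-serve processes (the context worker on port 8001, the generation worker on port 8002, and the disaggregated router on port 8000), runs the full MMLU task through the router's OpenAI completions endpoint, and relies on OpenAIServerClient.__del__ to terminate the servers and remove the temporary config files. The two parametrizations, [False] and [True], are the entries added to the l0_h100 test list below; per the markers above they run only on H100 devices with sufficient device memory (skip_less_device_memory(32000)).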

View File

@@ -44,6 +44,8 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu[DeepSeek-V3-Lite-fp8]