Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
test: Add disaggregated serving accuracy tests (#4036)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
Parent: 5ee38ad92a
Commit: 85867d76dd
@@ -14,7 +14,7 @@
 # limitations under the License.
 import random
 from abc import ABC, abstractmethod
-from typing import Iterable, List, Optional, Union
+from typing import Any, Iterable, List, Optional

 import numpy as np
 import torch
@@ -22,8 +22,7 @@ from tqdm import tqdm

 import tensorrt_llm.profiler as profiler

-from .._torch import LLM as PyTorchLLM
-from ..llmapi import LLM, RequestOutput
+from ..llmapi import RequestOutput
 from ..logger import logger
 from ..sampling_params import SamplingParams

@@ -49,8 +48,7 @@ class Evaluator(ABC):
                        *auxiliaries) -> float:
         raise NotImplementedError()

-    def do_apply_chat_template(self, llm: Union[LLM, PyTorchLLM],
-                               prompt: str) -> str:
+    def do_apply_chat_template(self, llm: Any, prompt: str) -> str:
         messages = [{"role": "user", "content": prompt}]
         if self.system_prompt is not None:
             messages = [{
@@ -62,7 +60,7 @@ class Evaluator(ABC):
             add_generation_prompt=True)

     def evaluate(self,
-                 llm: Union[LLM, PyTorchLLM],
+                 llm: Any,
                  sampling_params: Optional[SamplingParams] = None) -> float:
         profiler.start("trtllm exec")
         outputs, references, auxiliaries = [], [], []
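The change above relaxes the llm parameter of do_apply_chat_template and evaluate from Union[LLM, PyTorchLLM] to Any, so an evaluator no longer requires an in-process LLM: any object exposing the same generate_async surface can be scored, including the OpenAI-based client added in the new test file below. As a rough illustration of that duck-typed contract, here is a minimal sketch that mirrors the Result/RequestOutput wrapping used by the new test; the CannedClient class and its fixed reply are hypothetical and not part of the commit.

# Hypothetical sketch (not part of the commit): a duck-typed stand-in that
# Evaluator.evaluate() can consume now that the parameter is typed as Any.
# It mirrors the wrapping done by OpenAIServerClient in the new test file.
from typing import List, Optional

from tensorrt_llm.executor.result import GenerationResultBase
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams


class _CannedResult(GenerationResultBase):

    def __init__(self, id: int, sampling_params: SamplingParams,
                 outputs: List[CompletionOutput]):
        super().__init__(id, sampling_params)
        self._outputs = outputs
        self._streaming = False

    @property
    def outputs(self) -> List[CompletionOutput]:
        return self._outputs

    def result(self):
        return self


class CannedClient:
    # Stand-in "LLM" that always answers with a fixed string (illustration only).

    def generate_async(self,
                       prompt: str,
                       sampling_params: Optional[SamplingParams] = None):
        result = _CannedResult(
            id=0,
            sampling_params=sampling_params,
            outputs=[CompletionOutput(text="A", index=0)])
        request_output = RequestOutput._from_generation_result(result,
                                                               prompt=prompt)
        setattr(request_output, "result", result.result)
        return request_output

Whether such a stub satisfies every accuracy harness is not guaranteed; the point is only that evaluate() now depends on behavior rather than on a concrete LLM type.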
tests/integration/defs/accuracy/test_disaggregated_serving.py (new file, 176 lines)
@@ -0,0 +1,176 @@
# I want to create accuracy tests for disaggregated serving.
# I need to do this by creating a new class that mimics the LLM class. Instead of implementing the
# actual methods, it will send OAI requests to the disaggregated serving endpoint.
# Please take a look at the existing test_llm_api_pytorch.py file for reference.

import os
import shutil
import subprocess
import tempfile
import time
from typing import Any, Dict, List, Optional

import openai
import pytest
import requests
import yaml

from tensorrt_llm._torch import LLM
from tensorrt_llm.executor.result import GenerationResultBase
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams

from ..conftest import llm_models_root
from .accuracy_core import MMLU, LlmapiAccuracyTestHarness


class Result(GenerationResultBase):

    def __init__(self, id: int, sampling_params: SamplingParams,
                 outputs: List[CompletionOutput]):
        super().__init__(id, sampling_params)
        self._outputs = outputs
        self._streaming = False

    @property
    def outputs(self) -> List[CompletionOutput]:
        return self._outputs

    def result(self):
        return self


class OpenAIServerClient:

    def __init__(self, disaggregated_server_config: Dict[str, Any],
                 ctx_server_config: Dict[str, Any],
                 gen_server_config: Dict[str, Any], model_name: str):
        self.temp_dir = tempfile.mkdtemp()
        self.disaggregated_serving_config_path = os.path.join(
            self.temp_dir, "disaggregated_serving_config.yaml")
        with open(self.disaggregated_serving_config_path, "w") as f:
            yaml.dump(disaggregated_server_config, f)
        ctx_server_config_path = os.path.join(self.temp_dir,
                                              "ctx_server_config.yaml")
        with open(ctx_server_config_path, "w") as f:
            yaml.dump(ctx_server_config, f)
        gen_server_config_path = os.path.join(self.temp_dir,
                                              "gen_server_config.yaml")
        with open(gen_server_config_path, "w") as f:
            yaml.dump(gen_server_config, f)

        with LLM(model_name) as llm:
            self.args = llm.args

        trtllm_serve_path = "trtllm-serve"
        # Common arguments for both servers
        common_args = [
            trtllm_serve_path, model_name, "--host", "localhost", "--backend",
            "pytorch"
        ]
        env_ctx = os.environ.copy()
        env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"

        # Start the context server
        self._ctx_server = subprocess.Popen(common_args + [
            "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
        ],
                                            env=env_ctx)
        # Start the generation server
        env_gen = os.environ.copy()
        env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
        self._gen_server = subprocess.Popen(common_args + [
            "--port", "8002", "--extra_llm_api_options", gen_server_config_path
        ],
                                            env=env_gen)

        # Start the disaggregated server
        self._disaggregated_server = subprocess.Popen([
            trtllm_serve_path, "disaggregated", "-c",
            self.disaggregated_serving_config_path
        ])
        self.model_name = model_name

        while True:
            time.sleep(1)
            try:
                print("Checking health endpoint")
                response = requests.get("http://localhost:8000/health")
                if response.status_code == 200:
                    break
            except requests.exceptions.ConnectionError:
                continue

        self.client = openai.OpenAI(api_key="1234567890",
                                    base_url=f"http://localhost:8000/v1")

    def generate_async(self,
                       prompt: str,
                       sampling_params: Optional[SamplingParams] = None):
        # TODO: Make this async
        response = self.client.completions.create(
            model=self.model_name,
            prompt=prompt,
            stream=False,
            **({
                "max_tokens": sampling_params.max_tokens,
                "temperature": sampling_params.temperature,
                "top_p": sampling_params.top_p,
                "stop": sampling_params.stop,
                "seed": sampling_params.seed
            } if sampling_params else {}))
        result = Result(
            id=0,
            sampling_params=sampling_params,
            outputs=[CompletionOutput(text=response.choices[0].text, index=0)])
        requested_output = RequestOutput._from_generation_result(result,
                                                                 prompt=prompt)
        setattr(requested_output, "result", result.result)
        return requested_output

    def __del__(self):
        shutil.rmtree(self.temp_dir)
        self._ctx_server.terminate()
        self._gen_server.terminate()
        self._disaggregated_server.terminate()

        self._ctx_server.wait()
        self._gen_server.wait()
        self._disaggregated_server.wait()


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
    MODEL_NAME = "meta-llama/Llama-3.1-8B"
    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"

    @pytest.mark.skip_less_device_memory(32000)
    @pytest.mark.skip_device_not_contain(["H100"])
    @pytest.mark.parametrize("overlap_scheduler", [False, True])
    def test_auto_dtype(self, overlap_scheduler):
        ctx_server_config = {
            "pytorch_backend_config": {
                "enable_overlap_scheduler": False
            }
        }
        gen_server_config = {
            "pytorch_backend_config": {
                "enable_overlap_scheduler": overlap_scheduler
            }
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        client = OpenAIServerClient(disaggregated_server_config,
                                    ctx_server_config, gen_server_config,
                                    self.MODEL_PATH)
        task = MMLU(self.MODEL_NAME)
        task.evaluate(client)
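For orientation, the sketch below shows how the OpenAIServerClient defined above could be driven outside the MMLU harness: the configs are the same minimal ones used by test_auto_dtype (context server on port 8001, generation server on port 8002, disaggregated server on port 8000), and a single completion is requested through generate_async. The standalone usage and the model path placeholder are assumptions for illustration; in practice the pytest case above owns the client's lifecycle, and a GPU node with a local model checkout is required.

# Hypothetical standalone use of OpenAIServerClient (illustration only);
# assumes it runs from, or imports from, test_disaggregated_serving.py.
from tensorrt_llm.llmapi import SamplingParams

ctx_server_config = {"pytorch_backend_config": {"enable_overlap_scheduler": False}}
gen_server_config = {"pytorch_backend_config": {"enable_overlap_scheduler": True}}
disaggregated_server_config = {
    "hostname": "localhost",
    "port": 8000,
    "backend": "pytorch",
    "context_servers": {"num_instances": 1, "urls": ["localhost:8001"]},
    "generation_servers": {"num_instances": 1, "urls": ["localhost:8002"]},
}

# Placeholder path; substitute a local Meta-Llama-3.1-8B checkout.
client = OpenAIServerClient(disaggregated_server_config, ctx_server_config,
                            gen_server_config, "/path/to/Meta-Llama-3.1-8B")
output = client.generate_async(
    "The capital of France is",
    SamplingParams(max_tokens=8, temperature=0.0)).result()
print(output.outputs[0].text)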
@@ -44,6 +44,8 @@ l0_h100:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
+- accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False]
+- accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True]
 - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
 - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu[DeepSeek-V3-Lite-fp8]