TensorRT-LLMs/tests/integration/defs/accuracy/test_disaggregated_serving.py
Ivy Zhang 29cdcdb56a [None][fix] update skip config (#6891)
Signed-off-by: Ivy Zhang <25222398+crazydemo@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
2025-09-01 11:02:31 +08:00

919 lines
34 KiB
Python

# Accuracy tests for disaggregated serving.
# This is done by creating an object that mimics the LLM class: instead of implementing the
# actual methods, it sends OpenAI-compatible (OAI) requests to the disaggregated serving endpoint.
# See the existing test_llm_api_pytorch.py file for reference.
import concurrent
import contextlib
import json
import os
import tempfile
import time
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
import openai
import pytest
import requests
import yaml
from tensorrt_llm.executor.result import GenerationResultBase
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
from tensorrt_llm.llmapi.llm_args import LlmArgs
from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
skip_no_hopper, skip_pre_blackwell, skip_pre_hopper)
from ..trt_test_alternative import popen
from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
LlmapiAccuracyTestHarness, get_accuracy_task)
class Result(GenerationResultBase):
    """Generation result whose outputs are supplied up front.

    Instances are constructed with an already-complete list of
    ``CompletionOutput`` objects, so ``result()`` has nothing to wait for.
    """

    def __init__(self, id: int, sampling_params: SamplingParams,
                 outputs: List[CompletionOutput]):
        super().__init__(id, sampling_params)
        # Outputs are given in full at construction time.
        self._outputs = outputs
        # Never streamed: the result is complete on creation.
        self._streaming = False

    @property
    def outputs(self) -> List[CompletionOutput]:
        """Completion outputs passed to the constructor."""
        return self._outputs

    def result(self):
        """Return this object itself; no waiting is needed."""
        return self
# Duck-typed stand-in for an LLM object: a plain record holding ``args``,
# ``tokenizer`` and the ``generate_async`` callable.
DuckLLM = namedtuple("DuckLLM", ("args", "tokenizer", "generate_async"))
class MyThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that joins (or cancels) caller-registered futures on exit.

    Callers append the futures they submit to ``self.futures``. On a clean
    exit every tracked future is waited on, so an exception raised in a
    worker propagates out of the ``with`` block. On an exceptional exit, or
    when a tracked future fails, outstanding work is cancelled and the pool
    is shut down before the exception propagates.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Futures registered by callers; joined or cancelled in __exit__.
        self.futures: list[concurrent.futures.Future[RequestOutput]] = []

    def _cancel_all(self) -> None:
        """Cancel every tracked future and shut the pool down, waiting for running work."""
        for future in self.futures:
            future.cancel()
        self.shutdown(wait=True, cancel_futures=True)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            try:
                for future in self.futures:
                    # Re-raises any exception raised inside the worker.
                    future.result()
            except BaseException:
                # A worker failed: don't leave the remaining futures running
                # (and the pool alive) while the exception propagates.
                self._cancel_all()
                raise
            return super().__exit__(exc_type, exc_val, exc_tb)
        self._cancel_all()
        # Never suppress the exception raised in the ``with`` body.
        return False
def _port_from_url(url: str) -> int:
    """Return the port of a ``host:port`` URL.

    Takes the text after the last colon, so a scheme prefix such as
    ``http://`` does not break parsing.
    """
    return int(url.rsplit(":", 1)[1])


def _build_server_launches(base_args: List[str], server_config: Dict[str, Any],
                           config_path: str, urls: List[str], tp: int, pp: int,
                           gpu_offset: int):
    """Build the ``(env, argv)`` launch spec for each worker server in a group.

    Each instance is pinned, via ``CUDA_VISIBLE_DEVICES``, to its own
    contiguous range of ``tp * pp`` GPUs starting at ``gpu_offset``.

    Returns:
        A tuple of the list of ``(env, argv)`` pairs and the first GPU index
        left unassigned.
    """
    gpus_per_server = tp * pp
    launches = []
    for url in urls:
        env = os.environ.copy()
        env["TRTLLM_USE_UCX_KVCACHE"] = "1"
        gpu_ids = range(gpu_offset, gpu_offset + gpus_per_server)
        env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, gpu_ids))
        gpu_offset += gpus_per_server
        argv = base_args + [
            "--port",
            str(_port_from_url(url)), "--extra_llm_api_options", config_path,
            f"--tp_size={tp}", f"--pp_size={pp}"
        ]
        if "max_num_tokens" in server_config:
            argv.append(f"--max_num_tokens={server_config['max_num_tokens']}")
        launches.append((env, argv))
    return launches, gpu_offset


@contextlib.contextmanager
def _multi_popen(server_configs):
    """Start one process per ``(env, argv)`` pair and stop them all on exit.

    Each process is registered with the ExitStack as soon as it is started.
    If a later launch fails, the servers already started are still cleaned up.
    Only launch failures are reported as such; exceptions raised by the
    caller's ``with`` body propagate unchanged.
    """
    with contextlib.ExitStack() as stack:
        try:
            processes = [
                stack.enter_context(popen(args, env=env))
                for env, args in server_configs
            ]
        except Exception as e:
            print(
                f"Failed to start disaggregated server processes in multi_popen: {e}"
            )
            raise
        yield processes


def _wait_for_server_health(health_url: str, timeout: float = 3600) -> None:
    """Poll ``health_url`` about once per second until it returns HTTP 200.

    Raises:
        TimeoutError: If no 200 response is seen within ``timeout`` seconds.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        time.sleep(1)
        try:
            print("Checking health endpoint")
            # Per-request timeout so a hung connection cannot stall the poll
            # loop indefinitely.
            response = requests.get(health_url, timeout=10)
        except (requests.exceptions.ConnectionError,
                requests.exceptions.Timeout):
            continue
        if response.status_code == 200:
            return
    raise TimeoutError(
        f"Disaggregated server at {health_url} did not become healthy "
        f"within {timeout} seconds")


@contextlib.contextmanager
def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
                             ctx_server_config: Dict[str, Any],
                             gen_server_config: Dict[str, Any],
                             model_name: str,
                             tensor_parallel_size: int = 1,
                             ctx_model: Optional[str] = None,
                             gen_model: Optional[str] = None):
    """Launch a disaggregated ``trtllm-serve`` deployment and yield an LLM stand-in.

    The three configs are written as YAML files into a temporary directory.
    One ``trtllm-serve`` worker is started for each URL listed under
    ``context_servers`` and ``generation_servers``, each on its own GPU range
    (context workers first). The ``trtllm-serve disaggregated`` front end is
    then started. Once its ``/health`` endpoint answers 200, a ``DuckLLM`` is
    yielded whose ``generate_async`` sends OpenAI completion requests to the
    front end. Servers, the request thread pool and the temporary directory
    are all torn down when the context exits.

    Args:
        disaggregated_server_config: Front-end config. Its ``hostname`` and
            ``port`` (default ``localhost`` / ``8000``) are used to reach the
            front end. ``context_servers.urls`` and
            ``generation_servers.urls`` give the worker ports.
        ctx_server_config: Extra LLM API options for the context workers.
        gen_server_config: Extra LLM API options for the generation workers.
        model_name: Model used for the tokenizer, ``LlmArgs`` and the OpenAI
            ``model`` field. Also the default model for both worker groups.
        tensor_parallel_size: Fallback TP size when a server config does not
            set ``tensor_parallel_size`` (discouraged; use server configs).
        ctx_model: Optional model for the context workers.
        gen_model: Optional model for the generation workers.

    Yields:
        ``DuckLLM(args, tokenizer, generate_async)``.

    Raises:
        TimeoutError: If the front end does not report healthy within 3600 s.
    """
    if tensor_parallel_size > 1:
        print(
            "Using unified tp parameter for testing is not recommended. Please use server configs instead."
        )

    trtllm_serve_path = "trtllm-serve"
    front_end_url = (
        f"http://{disaggregated_server_config.get('hostname', 'localhost')}:"
        f"{disaggregated_server_config.get('port', 8000)}")

    with contextlib.ExitStack() as stack:
        # Entered first, so it is removed last, after every server has exited.
        work_dir = stack.enter_context(tempfile.TemporaryDirectory())
        disaggregated_serving_config_path = os.path.join(
            work_dir, "disaggregated_serving_config.yaml")
        ctx_server_config_path = os.path.join(work_dir,
                                              "ctx_server_config.yaml")
        gen_server_config_path = os.path.join(work_dir,
                                              "gen_server_config.yaml")
        for path, config in (
            (disaggregated_serving_config_path, disaggregated_server_config),
            (ctx_server_config_path, ctx_server_config),
            (gen_server_config_path, gen_server_config),
        ):
            with open(path, "w") as f:
                yaml.dump(config, f)

        args = LlmArgs.from_kwargs(model=model_name,
                                   tensor_parallel_size=tensor_parallel_size)

        ctx_tp = ctx_server_config.get("tensor_parallel_size",
                                       tensor_parallel_size)
        ctx_pp = ctx_server_config.get("pipeline_parallel_size", 1)
        gen_tp = gen_server_config.get("tensor_parallel_size",
                                       tensor_parallel_size)
        gen_pp = gen_server_config.get("pipeline_parallel_size", 1)

        def worker_base_args(model: str) -> List[str]:
            """Arguments shared by every worker server for ``model``."""
            return [
                trtllm_serve_path,
                model,
                "--host",
                "localhost",
                "--backend",
                "pytorch",
            ]

        # Context workers take the first GPUs; generation workers follow.
        ctx_servers, next_gpu = _build_server_launches(
            worker_base_args(ctx_model or model_name), ctx_server_config,
            ctx_server_config_path,
            disaggregated_server_config["context_servers"]["urls"], ctx_tp,
            ctx_pp, 0)
        gen_servers, _ = _build_server_launches(
            worker_base_args(gen_model or model_name), gen_server_config,
            gen_server_config_path,
            disaggregated_server_config["generation_servers"]["urls"], gen_tp,
            gen_pp, next_gpu)

        thread_pool = stack.enter_context(MyThreadPoolExecutor(max_workers=16))
        stack.enter_context(_multi_popen(ctx_servers + gen_servers))
        stack.enter_context(
            popen([
                trtllm_serve_path, "disaggregated", "-c",
                disaggregated_serving_config_path, "--server_start_timeout",
                "3600"
            ]))
        _wait_for_server_health(f"{front_end_url}/health", timeout=3600)

        client = openai.OpenAI(api_key="1234567890",
                               base_url=f"{front_end_url}/v1")

        def send_request(prompt: str, sampling_params: SamplingParams,
                         streaming: bool):
            """Send one completion request and wrap the reply as a RequestOutput."""
            kwargs = {}
            if sampling_params is not None:
                kwargs.update(max_tokens=sampling_params.max_tokens,
                              temperature=sampling_params.temperature,
                              top_p=sampling_params.top_p,
                              stop=sampling_params.stop,
                              seed=sampling_params.seed)
                if (guided_decoding_params :=
                        sampling_params.guided_decoding) is not None:
                    extra_body = {}
                    if (schema := guided_decoding_params.json) is not None:
                        extra_body.update(response_format={
                            "type": "json",
                            "schema": json.loads(schema)
                        })
                    elif guided_decoding_params.json_object:
                        extra_body.update(
                            response_format={"type": "json_object"})
                    else:
                        # TODO: Support other guided decoding types
                        raise ValueError(
                            f"Unsupported guided decoding params: {guided_decoding_params}."
                        )
                    kwargs.update(extra_body=extra_body)
            # NOTE(review): with stream=True the OpenAI client returns a stream
            # of chunks rather than an object with ``.choices``; only
            # non-streaming requests appear to be supported here.
            response = client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 stream=streaming,
                                                 **kwargs)
            result = Result(id=0,
                            sampling_params=sampling_params,
                            outputs=[
                                CompletionOutput(text=response.choices[0].text,
                                                 index=0)
                            ])
            requested_output = RequestOutput._from_generation_result(
                result, prompt=prompt)
            # Expose the already-complete Result through ``.result()``.
            setattr(requested_output, "result", result.result)
            return requested_output

        def generate_async(prompt: str,
                           sampling_params: Optional[SamplingParams] = None,
                           streaming: bool = False):
            """Submit ``prompt`` to the thread pool and return its Future."""
            future = thread_pool.submit(send_request, prompt, sampling_params,
                                        streaming)
            # Tracked so the pool joins (or cancels) it on exit.
            thread_pool.futures.append(future)
            return future

        tokenizer = load_hf_tokenizer(model_name)
        yield DuckLLM(args, tokenizer, generate_async)
def run_parallel_test(model_name: str,
                      model_path: str,
                      ctx_pp: int,
                      ctx_tp: int,
                      gen_pp: int,
                      gen_tp: int,
                      ctx_instances: int,
                      gen_instances: int,
                      test_sets: List[type],
                      ctx_model: Optional[str] = None,
                      gen_model: Optional[str] = None):
    """Run accuracy tasks against a disaggregated deployment with given parallelism.

    Skips the test when the machine has fewer GPUs than
    ``ctx_tp*ctx_pp*ctx_instances + gen_tp*gen_pp*gen_instances``.

    Args:
        model_name: Name passed to each accuracy task's constructor.
        model_path: Model served by default and used for the tokenizer.
        ctx_pp, ctx_tp: Pipeline/tensor parallel sizes of each context server.
        gen_pp, gen_tp: Pipeline/tensor parallel sizes of each generation server.
        ctx_instances, gen_instances: Number of context/generation servers.
        test_sets: Accuracy task classes (e.g. ``GSM8K``). Each is instantiated
            with ``model_name`` and evaluated in order.
        ctx_model, gen_model: Optional per-role model overrides.
    """
    total_ctx_gpus = ctx_tp * ctx_pp * ctx_instances
    total_gen_gpus = gen_tp * gen_pp * gen_instances
    if total_ctx_gpus + total_gen_gpus > get_device_count():
        pytest.skip(
            f"Not enough devices for {ctx_instances} ctx instances (ctx_pp={ctx_pp}*ctx_tp={ctx_tp}) + {gen_instances} gen instances (gen_pp={gen_pp}*gen_tp={gen_tp}), total: {total_ctx_gpus + total_gen_gpus}"
        )
    kv_cache_config = {
        "free_gpu_memory_fraction": 0.5,
    }
    ctx_server_config = {
        "pipeline_parallel_size": ctx_pp,
        "tensor_parallel_size": ctx_tp,
        "disable_overlap_scheduler": True,
        "kv_cache_config": kv_cache_config,
        "cache_transceiver_config": {
            "backend": "DEFAULT"
        }
    }
    gen_server_config = {
        "tensor_parallel_size": gen_tp,
        "pipeline_parallel_size": gen_pp,
        "disable_overlap_scheduler": True,
        "kv_cache_config": kv_cache_config,
        "cache_transceiver_config": {
            "backend": "DEFAULT"
        }
    }
    # Context servers use odd ports (8001, 8003, ...) and generation servers
    # use even ports (8002, 8004, ...), so the two groups never collide.
    ctx_urls = [f"localhost:{8001 + i * 2}" for i in range(ctx_instances)]
    gen_urls = [f"localhost:{8002 + i * 2}" for i in range(gen_instances)]
    disaggregated_server_config = {
        "hostname": "localhost",
        "port": 8000,
        "backend": "pytorch",
        "context_servers": {
            "num_instances": ctx_instances,
            "urls": ctx_urls
        },
        "generation_servers": {
            "num_instances": gen_instances,
            "urls": gen_urls
        }
    }
    with launch_disaggregated_llm(disaggregated_server_config,
                                  ctx_server_config,
                                  gen_server_config,
                                  model_path,
                                  ctx_model=ctx_model,
                                  gen_model=gen_model) as llm:
        for test_set in test_sets:
            task = test_set(model_name)
            task.evaluate(llm)
@pytest.mark.timeout(3600)
class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
    """Disaggregated-serving accuracy tests for Llama-3.1-8B-Instruct.

    The single-pair tests run one context server (localhost:8001) and one
    generation server (localhost:8002) behind a front end on localhost:8000.
    Each server gets its own GPU, so they need at least two devices.
    """
    MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"

    @pytest.mark.skip_less_device(2)
    @pytest.mark.skip_less_device_memory(32000)
    @pytest.mark.parametrize("disable_overlap_scheduler", [False, True])
    def test_auto_dtype(self, disable_overlap_scheduler):
        ctx_server_config = {"disable_overlap_scheduler": True}
        gen_server_config = {
            "disable_overlap_scheduler": disable_overlap_scheduler
        }
        ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
        gen_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

    @pytest.mark.skip_less_device(2)
    @skip_pre_hopper
    def test_ngram(self):
        speculative_decoding_config = {
            "decoding_type": "NGram",
            "max_draft_len": 4,
            "max_matching_ngram_size": 4,
            "is_keep_all": True,
            "is_use_oldest": True,
            "is_public_pool": True
        }
        kv_cache_config = {
            "free_gpu_memory_fraction": 0.5,
            "enable_block_reuse": False
        }
        ctx_server_config = {
            "disable_overlap_scheduler": True,
            "kv_cache_config": kv_cache_config,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        # Speculative decoding is configured on the generation server only.
        gen_server_config = {
            "disable_overlap_scheduler": True,
            "speculative_config": speculative_decoding_config,
            "kv_cache_config": kv_cache_config,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

    @pytest.mark.skip_less_device(2)
    @skip_pre_hopper
    @parametrize_with_ids("overlap_scheduler", [True, False])
    @parametrize_with_ids("eagle3_one_model", [True, False])
    def test_eagle3(self, overlap_scheduler, eagle3_one_model):
        speculative_decoding_config = {
            "decoding_type": "Eagle",
            "max_draft_len": 4,
            "speculative_model_dir":
            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
            "eagle3_one_model": eagle3_one_model
        }
        ctx_server_config = {
            "disable_overlap_scheduler":
            True,  # BS=1 does not need overlap scheduling
            "speculative_config": speculative_decoding_config,
            "kv_cache_config": {
                "free_gpu_memory_fraction": 0.5,
                "enable_block_reuse": True  # reuse on context requests
            },
            "max_num_tokens": 13393 * 2,
            "max_batch_size": 1,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            },
            "cuda_graph_config": None,
        }
        gen_server_config = {
            "disable_overlap_scheduler": not overlap_scheduler,
            "speculative_config": speculative_decoding_config,
            "kv_cache_config": {
                "free_gpu_memory_fraction": 0.5,
                "enable_block_reuse": False
            },
            "max_num_tokens": 13393 * 2,
            "max_batch_size": 16,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            },
            "cuda_graph_config": None,
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

    @pytest.mark.skip_less_device(2)
    @pytest.mark.skip_less_device_memory(32000)
    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
    def test_guided_decoding(self, backend: str, mocker):
        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
        ctx_server_config = {
            "disable_overlap_scheduler": True,
            "guided_decoding_backend": backend,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        gen_server_config = {
            "guided_decoding_backend": backend,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = JsonModeEval(self.MODEL_NAME)
            task.evaluate(llm)

    # Two servers on separate GPUs: needs at least two devices, like the
    # other single-pair tests in this class.
    @pytest.mark.skip_less_device(2)
    @pytest.mark.skip_less_device_memory(32000)
    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
    def test_guided_decoding_with_eagle3(self, backend: str, mocker):
        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
        speculative_decoding_config = {
            "decoding_type": "Eagle",
            "max_draft_len": 3,
            "speculative_model_dir":
            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
            "eagle3_one_model": False
        }
        ctx_server_config = {
            "disable_overlap_scheduler": True,
            "speculative_config": speculative_decoding_config,
            "kv_cache_config": {
                "free_gpu_memory_fraction": 0.8,
            },
            "guided_decoding_backend": backend,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        gen_server_config = {
            "disable_overlap_scheduler": True,
            "speculative_config": speculative_decoding_config,
            "kv_cache_config": {
                "free_gpu_memory_fraction": 0.8,
            },
            "guided_decoding_backend": backend,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = JsonModeEval(self.MODEL_NAME)
            task.evaluate(llm)

    @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
                             ids=["tp1pp2", "tp2pp1", "tp2pp2"])
    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
    def test_tp_pp_symmetric(self, tp, pp, testset):
        # One ctx and one gen server, each using tp * pp GPUs.
        if tp * pp * 2 > get_device_count():
            pytest.skip(f"Not enough devices for tp={tp}*pp={pp} test")
        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
                                 tp, 1, 1, [get_accuracy_task(testset)])

    @parametrize_with_ids("ctx_pp", [2, 4])
    @parametrize_with_ids("gen_tp", [1, 2])
    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
    def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
        # One ctx server uses ctx_pp GPUs (tp=1) and one gen server uses
        # gen_tp GPUs (pp=1), so ctx_pp + gen_tp devices are required.
        if ctx_pp + gen_tp > get_device_count():
            pytest.skip(
                f"Not enough devices for ctx_pp={ctx_pp}+gen_tp={gen_tp} test")
        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
                                 gen_tp, 1, 1, [get_accuracy_task(testset)])

    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
    def test_multi_instance(self, testset):
        # Two ctx and two gen servers, each on a single GPU.
        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, 1, 1, 1, 1,
                                 2, 2, [get_accuracy_task(testset)])
class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
    """Disaggregated-serving accuracy tests for Llama-4-Scout-17B-16E-Instruct."""
    MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
    MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"

    @pytest.mark.skip_less_device_memory(140000)
    @pytest.mark.timeout(3600)
    @pytest.mark.skip_less_device(8)
    @pytest.mark.parametrize("overlap_scheduler", [False, True])
    def test_auto_dtype(self, overlap_scheduler):
        ctx_server_config = {"disable_overlap_scheduler": True}
        # ``overlap_scheduler`` says whether overlap scheduling is ENABLED on
        # the generation server, so the "disable" flag is its negation.
        gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler}
        ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
        gen_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
        # Keep this low to avoid warmup OOM in CI
        ctx_server_config["max_seq_len"] = 8192
        gen_server_config["max_seq_len"] = 8192
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config,
                                      gen_server_config,
                                      self.MODEL_PATH,
                                      tensor_parallel_size=4) as llm:
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)
@pytest.mark.timeout(3600)
class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
    """Disaggregated-serving accuracy tests for DeepSeek-V3-Lite (bf16)."""
    MODEL_NAME = "deepseek-ai/DeepSeek-V3-Lite"
    MODEL_PATH = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16"

    @pytest.mark.skip_less_device(2)
    @pytest.mark.skip_less_device_memory(60000)
    @skip_no_hopper
    def test_nixl_backend(self):
        ctx_server_config = {
            "disable_overlap_scheduler": True,
            "cache_transceiver_config": {
                "backend": "NIXL"
            }
        }
        gen_server_config = {
            "disable_overlap_scheduler": True,
            "cache_transceiver_config": {
                "backend": "NIXL"
            }
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

    @pytest.mark.skip_less_device(8)
    @parametrize_with_ids("overlap_scheduler", [True, False])
    @parametrize_with_ids("mtp_nextn",
                          [0, pytest.param(2, marks=skip_pre_hopper)])
    def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
        ctx_server_config = {"disable_overlap_scheduler": True}
        gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler}
        ctx_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
        gen_server_config["cache_transceiver_config"] = {"backend": "DEFAULT"}
        if mtp_nextn > 0:
            # Both servers must agree on the MTP configuration.
            ctx_server_config["speculative_config"] = {
                "decoding_type": "MTP",
                "num_nextn_predict_layers": mtp_nextn
            }
            gen_server_config["speculative_config"] = {
                "decoding_type": "MTP",
                "num_nextn_predict_layers": mtp_nextn
            }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config,
                                      gen_server_config,
                                      self.MODEL_PATH,
                                      tensor_parallel_size=4) as llm:
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)
@pytest.mark.timeout(3600)
class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
    """Disaggregated-serving accuracy tests for gemma-3-1b-it (currently skipped)."""
    MODEL_NAME = "google/gemma-3-1b-it"
    MODEL_PATH = f"{llm_models_root()}/gemma/gemma-3-1b-it/"

    @pytest.mark.skip_less_device(2)
    @pytest.mark.parametrize("overlap_scheduler", [False, True])
    def test_auto_dtype(self, overlap_scheduler):
        pytest.skip(
            "Currently we require full kvcache for variable sliding window. "
            "This test only transfers the kvcache inside the sliding window.")
        ctx_server_config = {
            "disable_overlap_scheduler": True,
            "cuda_graph_config": None,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        # ``overlap_scheduler`` says whether overlap scheduling is ENABLED on
        # the generation server, so the "disable" flag is its negation.
        gen_server_config = {
            "disable_overlap_scheduler": not overlap_scheduler,
            "cuda_graph_config": None,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        # NOTE(review): the per-layer sliding-window setting below stays
        # commented out, presumably until variable-sliding-window KV cache
        # transfer is supported (see the skip reason above).
        ctx_server_config["kv_cache_config"] = {
            # "max_attention_window": [512, 512, 512, 512, 512, 32768],
            "enable_block_reuse": True
        }
        gen_server_config["kv_cache_config"] = {
            # "max_attention_window": [512, 512, 512, 512, 512, 32768],
            "enable_block_reuse": True
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
@pytest.mark.timeout(3600)
class TestQwen3_8B(LlmapiAccuracyTestHarness):
    """Disaggregated-serving accuracy tests for Qwen3-8B (FP8 checkpoint).

    Every test runs one context server and one generation server on separate
    GPUs, so each needs at least two devices.
    """
    MODEL_NAME = "Qwen3/Qwen3-8B"
    MODEL_PATH = f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8"

    @pytest.mark.skip_less_device(2)
    @skip_no_hopper
    def test_nixl_backend(self):
        ctx_server_config = {
            "disable_overlap_scheduler": True,
            "cache_transceiver_config": {
                "backend": "NIXL"
            }
        }
        gen_server_config = {
            "disable_overlap_scheduler": True,
            "cache_transceiver_config": {
                "backend": "NIXL"
            }
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)

    @skip_pre_hopper
    @pytest.mark.skip_less_device(2)
    @pytest.mark.parametrize("overlap_scheduler", [False, True])
    def test_auto_dtype(self, overlap_scheduler):
        ctx_server_config = {
            "disable_overlap_scheduler": True,
            "cuda_graph_config": None,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        # ``overlap_scheduler`` says whether overlap scheduling is ENABLED on
        # the generation server, so the "disable" flag is its negation.
        gen_server_config = {
            "disable_overlap_scheduler": not overlap_scheduler,
            "cuda_graph_config": None,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)

    @pytest.mark.skip_less_device(2)
    def test_chunked_prefill(self):
        # Chunked prefill with a small token budget, on the context server only.
        ctx_server_config = {
            "disable_overlap_scheduler": True,
            "cuda_graph_config": None,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            },
            "enable_chunked_prefill": True,
            "max_num_tokens": 256,
        }
        gen_server_config = {
            "cuda_graph_config": None,
            "cache_transceiver_config": {
                "backend": "DEFAULT"
            }
        }
        disaggregated_server_config = {
            "hostname": "localhost",
            "port": 8000,
            "backend": "pytorch",
            "context_servers": {
                "num_instances": 1,
                "urls": ["localhost:8001"]
            },
            "generation_servers": {
                "num_instances": 1,
                "urls": ["localhost:8002"]
            }
        }
        with launch_disaggregated_llm(disaggregated_server_config,
                                      ctx_server_config, gen_server_config,
                                      self.MODEL_PATH) as llm:
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)
@skip_pre_blackwell
@pytest.mark.timeout(3600)
class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
    """Disaggregated serving with different checkpoints on ctx and gen servers.

    The context server runs the NVFP4 checkpoint and the generation server
    runs the FP8 checkpoint.
    """
    FP4_MODEL = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_nvfp4_hf"
    FP8_MODEL = f"{llm_models_root()}/Qwen3/saved_models_Qwen3-30B-A3B_fp8_hf"

    @pytest.mark.skip_less_device(4)
    @pytest.mark.parametrize("ctx_pp,gen_tp", [(2, 2)], ids=["ctxpp2gentp2"])
    def test_mixed_ctx_gen_model(self, ctx_pp, gen_tp):
        ctx_model = self.FP4_MODEL
        gen_model = self.FP8_MODEL
        # ctx_instances / gen_instances are required by run_parallel_test;
        # omitting them raises TypeError before any server is launched.
        return run_parallel_test("Qwen3/Qwen3-30B-A3B",
                                 ctx_model,
                                 ctx_pp=ctx_pp,
                                 ctx_tp=1,
                                 gen_pp=1,
                                 gen_tp=gen_tp,
                                 ctx_instances=1,
                                 gen_instances=1,
                                 test_sets=[GSM8K, MMLU],
                                 ctx_model=ctx_model,
                                 gen_model=gen_model)