"""End-to-end tests for disaggregated serving workers: conditional
disaggregation, KV cache event tracking, and the KV-cache-aware router."""
import asyncio
import contextlib
import copy
import json
import os
import subprocess
from typing import Generator, List, Optional, Tuple

import aiohttp
import pytest
import yaml
from defs.common import revise_disagg_config_file_with_free_ports
from defs.conftest import skip_no_hopper
from defs.trt_test_alternative import popen
from transformers import AutoTokenizer

from tensorrt_llm import logger
from tensorrt_llm.serve.openai_client import OpenAIHttpClient
from tensorrt_llm.serve.openai_protocol import (CompletionRequest,
                                                DisaggregatedParams)
from tensorrt_llm.serve.router import (KvCacheAwareRouter,
                                       KvCacheAwareServerState, ServerRole,
                                       block_key_hasher)


def get_ctx_gen_server_urls_from_cfg(config_file: str):
    """Read context/generation server URLs from a disaggregated config file.

    Only the ``context_servers.urls`` and ``generation_servers.urls`` lists are
    consumed here; a minimal config therefore looks roughly like this
    (illustrative hostnames)::

        context_servers:
          urls:
            - "localhost:8001"
        generation_servers:
          urls:
            - "localhost:8002"
    """
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)
    ctx_servers = []
    gen_servers = []
    for server in config["context_servers"]["urls"]:
        ctx_servers.append("http://" + server)
    for server in config["generation_servers"]["urls"]:
        gen_servers.append("http://" + server)
    return ctx_servers, gen_servers


def run_disaggregated_workers(
    config_file: str,
    stdout=None,
    env: Optional[dict] = None,
    cwd: Optional[str] = None,
    num_ranks: Optional[int] = None
) -> Tuple[Generator[subprocess.Popen, None, None], List[str], List[str]]:
    config_file = revise_disagg_config_file_with_free_ports(config_file)
    ctx_servers, gen_servers = get_ctx_gen_server_urls_from_cfg(config_file)

    # TODO: auto detect num_ranks
    assert num_ranks is not None

    # Start workers
    workers_cmd = [
        'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
        str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
        config_file
    ]
    logger.info(f"Running workers with command: {' '.join(workers_cmd)}")
    workers_proc = popen(workers_cmd,
                         stdout=stdout,
                         stderr=subprocess.STDOUT,
                         env=env,
                         cwd=cwd)
    return workers_proc, ctx_servers, gen_servers


DEFAULT_TIMEOUT_SERVER_START = 900
DEFAULT_TIMEOUT_REQUEST = 180


async def wait_until_all_servers_ready(
    session: aiohttp.ClientSession,
    servers: List[str],
    server_start_timeout_secs: int = 180,
) -> None:

    async def check_all_servers_ready():
        elapsed_time = 0
        interval = 3
        while elapsed_time < server_start_timeout_secs:
            _, unready_servers = await OpenAIHttpClient.check_ready_for_servers(
                session, servers)
            if len(unready_servers) == 0:
                return
            await asyncio.sleep(interval)
            elapsed_time += interval
            logger.info(
                f"[{elapsed_time}] Waiting for servers, {unready_servers}...")

    try:
        await asyncio.wait_for(check_all_servers_ready(),
                               timeout=server_start_timeout_secs)
    except asyncio.TimeoutError:
        raise TimeoutError(
            f"Timeout waiting for all servers to be ready in {server_start_timeout_secs} seconds"
        )


class BasicWorkerTester:

    def __init__(self,
                 ctx_servers: List[str],
                 gen_servers: List[str],
                 req_timeout_secs: int = DEFAULT_TIMEOUT_REQUEST,
                 server_start_timeout_secs: int = DEFAULT_TIMEOUT_SERVER_START):
        self.ctx_servers = ctx_servers
        self.gen_servers = gen_servers
        self.req_timeout_secs = req_timeout_secs
        self.server_start_timeout_secs = server_start_timeout_secs

    async def new_session(self):
        session = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(force_close=True),
            timeout=aiohttp.ClientTimeout(total=self.req_timeout_secs))
        await wait_until_all_servers_ready(session,
                                           self.ctx_servers + self.gen_servers,
                                           self.server_start_timeout_secs)
        return session

    async def send_request(self, session: aiohttp.ClientSession, url: str,
                           request: dict) -> dict:
        # TODO: streaming support
        async with session.post(url + "/v1/completions",
                                json=request) as response:
            content_type = response.headers.get("Content-Type", "")
            if "text/event-stream" in content_type:
                raise ValueError(
                    "Received an event-stream although request stream was False"
                )
            response_dict = await response.json()
            if not response.ok:
                logger.error(f"Received failed response {response_dict}")
                response.raise_for_status()
            return response_dict

    async def send_disagg_request(self, session: aiohttp.ClientSession,
                                  ctx_url: str, gen_url: str,
                                  request: dict) -> dict:
        ctx_request = copy.deepcopy(request)
        gen_request = copy.deepcopy(request)

        ctx_request["disaggregated_params"] = {"request_type": "context_only"}
        ctx_response = await self.send_request(session, ctx_url, ctx_request)
        assert len(ctx_response["choices"]) == 1

        gen_request["disaggregated_params"] = ctx_response["choices"][0][
            "disaggregated_params"]
        gen_request["disaggregated_params"]["request_type"] = "generation_only"
        gen_response = await self.send_request(session, gen_url, gen_request)
        return gen_response

    async def query_kv_cache_events(self, session: aiohttp.ClientSession,
                                    url: str):
        async with session.post(url + "/kv_cache_events") as response:
            events_raw = await response.json()
        events = []
        for event_raw in events_raw:
            event = {"id": event_raw["event_id"]} | event_raw["data"]
            if event["type"] == "stored":
                for block in event["blocks"]:
                    block["token_id"] = [
                        token["token_id"] for token in block["tokens"]
                    ]
                    block["token_extra_id"] = [
                        token["token_extra_id"] for token in block["tokens"]
                    ]
                    # TODO: check by BlockKey::usesExtraIds
                    if not any(block["token_extra_id"]):
                        del block["token_extra_id"]
                    del block["tokens"]
            events.append(event)
        return events
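

# For reference, BasicWorkerTester.send_disagg_request above drives the two-phase
# disaggregated protocol with request bodies of roughly this shape (a sketch; the
# generation-phase params are whatever opaque handle the context server returned
# in its first choice):
#
#   context phase:
#     {"model": ..., "prompt": ...,
#      "disaggregated_params": {"request_type": "context_only"}}
#   generation phase:
#     {"model": ..., "prompt": ...,
#      "disaggregated_params": {<params echoed from the context response>,
#                               "request_type": "generation_only"}}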


class ConditionalWorkerTester(BasicWorkerTester):

    def __init__(self,
                 ctx_servers: List[str],
                 gen_servers: List[str],
                 req_timeout_secs: int = DEFAULT_TIMEOUT_REQUEST,
                 server_start_timeout_secs: int = DEFAULT_TIMEOUT_SERVER_START,
                 model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        super().__init__(ctx_servers, gen_servers, req_timeout_secs,
                         server_start_timeout_secs)
        self.model_name = model_name

    async def multi_round_request(self, session: aiohttp.ClientSession,
                                  init_prompt: str, max_rounds: int,
                                  threshold: float):
        request = {
            "model": self.model_name,
            "prompt": init_prompt,
            "max_tokens": 10,
            "ignore_eos": True,
            "temperature": 0.0,
        }
        prev_prompt_len = 0
        curr_prompt_len = 1
        for i in range(max_rounds):
            # conditional disaggregation by kv cache (estimated by prompt length)
            if prev_prompt_len > curr_prompt_len * threshold:
                logger.info(f"Sending normal request at iter {i}")
                response = await self.send_request(session,
                                                   self.gen_servers[0], request)
            else:
                logger.info(f"Sending disaggregated request at iter {i}")
                response = await self.send_disagg_request(
                    session, self.ctx_servers[0], self.gen_servers[0], request)
            logger.info(
                f"Received response {i}: {repr(response['choices'][0]['text'])}"
            )
            prev_prompt_len = response["usage"]["prompt_tokens"]
            curr_prompt_len = response["usage"]["total_tokens"]
            request["prompt"] += response["choices"][0]["text"]

    async def test_multi_round_request(self,
                                       init_prompts: List[str],
                                       max_rounds: int = 8,
                                       threshold: float = 0.75):
        async with await self.new_session() as session:
            chat_threads = [
                self.multi_round_request(session, prompt, max_rounds, threshold)
                for prompt in init_prompts
            ]
            await asyncio.gather(*chat_threads)


class KvCacheEventWorkerTester(BasicWorkerTester):

    def __init__(self,
                 ctx_servers: List[str],
                 gen_servers: List[str],
                 req_timeout_secs: int = DEFAULT_TIMEOUT_REQUEST,
                 server_start_timeout_secs: int = DEFAULT_TIMEOUT_SERVER_START,
                 model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
        super().__init__(ctx_servers, gen_servers, req_timeout_secs,
                         server_start_timeout_secs)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model_name = model_name
        self.kv_cache_block_maps: dict[str, KvCacheAwareServerState] = {}
        self.kv_cache_event_maps: dict[str, list[dict]] = {}
        for ctx_server in ctx_servers:
            self.kv_cache_block_maps[ctx_server] = KvCacheAwareServerState(
                ctx_server)
            self.kv_cache_event_maps[ctx_server] = []
        for gen_server in gen_servers:
            if gen_server not in self.kv_cache_block_maps:
                self.kv_cache_block_maps[gen_server] = KvCacheAwareServerState(
                    gen_server)
                self.kv_cache_event_maps[gen_server] = []

    async def send_request(self, session: aiohttp.ClientSession, url: str,
                           request: dict) -> dict:
        response = await super().send_request(session, url, request)
        events = await self.query_kv_cache_events(session, url)
        async with self.kv_cache_block_maps[url]._lock:
            self.kv_cache_block_maps[url].update_with_events(events)
            self.kv_cache_event_maps[url].extend(events)
        return response

    async def multi_round_request(self,
                                  session: aiohttp.ClientSession,
                                  init_prompt: str,
                                  max_rounds: int,
                                  check_match_count: bool = True):
        request = {
            "model": self.model_name,
            "prompt": init_prompt,
            "max_tokens": 64,
            "ignore_eos": True,
            "temperature": 0.0,
        }
        tokens_per_block = 32  # TODO: read from config
        prev_ctx_match_count = 0
        prev_gen_match_count = 0
        assert len(self.ctx_servers) == 1 and len(self.gen_servers) == 1, \
            "This test assumes 1P1D"
        ctx_server = self.ctx_servers[0]
        gen_server = self.gen_servers[0]
        ctx_blocks = self.kv_cache_block_maps[ctx_server]
        gen_blocks = self.kv_cache_block_maps[gen_server]
        ctx_events = self.kv_cache_event_maps[ctx_server]
        gen_events = self.kv_cache_event_maps[gen_server]
        for i in range(max_rounds):
            # split tokens into blocks and check block match count by hash
            tokens = self.tokenizer(request["prompt"])["input_ids"]
            block_hashes = []
            for t in range(0, len(tokens) - 1, tokens_per_block):
                t_end = min(t + tokens_per_block, len(tokens) - 1)
                if t_end - t < tokens_per_block:
                    # partial block
                    break
                block_hashes.append(
                    block_key_hasher(tokens[t:t_end],
                                     None if t == 0 else block_hashes[-1]))
            ctx_match_count = await ctx_blocks.matched_tokens([block_hashes])
            gen_match_count = await gen_blocks.matched_tokens([block_hashes])
            ctx_evicted = False
            gen_evicted = False
            for event in ctx_events:
                if event["type"] == "removed":
                    ctx_evicted = True
                    break
            for event in gen_events:
                if event["type"] == "removed":
                    gen_evicted = True
                    break
            assert ctx_evicted or ctx_match_count >= prev_ctx_match_count
            assert gen_evicted or gen_match_count >= prev_gen_match_count
            ctx_events.clear()
            gen_events.clear()

            response = await self.send_disagg_request(session, ctx_server,
                                                      gen_server, request)
            logger.info(
                f"Received response {i}: {repr(response['choices'][0]['text'])}"
            )
            prev_ctx_match_count = ctx_match_count
            prev_gen_match_count = gen_match_count
            request["prompt"] += response["choices"][0]["text"]

            if check_match_count:
                assert ctx_match_count > 0
                assert gen_match_count > 0
                assert gen_match_count >= ctx_match_count or gen_evicted
        return request["prompt"]

    async def test_multi_round_request(self,
                                       init_prompts: List[str],
                                       max_rounds: int = 8):
        async with await self.new_session() as session:
            chat_threads = [
                self.multi_round_request(session, prompt, max_rounds, False)
                for prompt in init_prompts
            ]
            prompts = await asyncio.gather(*chat_threads)
            # send a request to flush events
            await self.multi_round_request(session, init_prompts[0], 1, False)
            await asyncio.gather(*[
                self.multi_round_request(session, prompt, 1, True)
                for prompt in prompts
            ])
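

# Worked example for the block-hash loop in KvCacheEventWorkerTester.multi_round_request
# (illustrative numbers): with tokens_per_block = 32 and a 70-token prompt, only
# tokens[0:32] and tokens[32:64] form full blocks; the tail (the final token is always
# excluded via `len(tokens) - 1`) is a partial block and is skipped. Hashes are chained
# through the parent-hash argument:
#
#   h0 = block_key_hasher(tokens[0:32], None)
#   h1 = block_key_hasher(tokens[32:64], h0)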


class KvCacheAwareRouterTester(BasicWorkerTester):

    def __init__(self,
                 ctx_servers: List[str],
                 gen_servers: List[str],
                 req_timeout_secs: int = DEFAULT_TIMEOUT_REQUEST,
                 server_start_timeout_secs: int = DEFAULT_TIMEOUT_SERVER_START,
                 model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                 tokens_per_block: int = 32):
        super().__init__(ctx_servers, gen_servers, req_timeout_secs,
                         server_start_timeout_secs)
        self.ctx_router = KvCacheAwareRouter(server_role=ServerRole.CONTEXT,
                                             servers=ctx_servers,
                                             tokens_per_block=tokens_per_block)
        self.gen_router = KvCacheAwareRouter(server_role=ServerRole.GENERATION,
                                             servers=gen_servers,
                                             tokens_per_block=tokens_per_block)
        self.model_name = model_name

    async def multi_round_request(self,
                                  session: aiohttp.ClientSession,
                                  init_prompt: str,
                                  max_rounds: int = 8,
                                  check_server_match: bool = True):
        request = {
            "model": self.model_name,
            "prompt": init_prompt,
            "max_tokens": 64,
            "ignore_eos": True,
            "temperature": 0.0,
        }
        ctx_server_prev = None
        gen_server_prev = None
        ctx_match = 0
        gen_match = 0
        for i in range(max_rounds):
            openai_request = CompletionRequest(
                model=self.model_name,
                prompt=request["prompt"],
                disaggregated_params=DisaggregatedParams(
                    request_type="context_only"))
            ctx_server, ctx_info = await self.ctx_router.get_next_server(
                openai_request)
            prompt_str = request["prompt"]
            request["prompt"] = ctx_info["token_lists"][0]
            openai_request.disaggregated_params.request_type = "generation_only"
            gen_server, _ = await self.gen_router.get_next_server(
                openai_request)
            if check_server_match and ctx_server_prev is not None:
                ctx_match += int(ctx_server == ctx_server_prev)
                gen_match += int(gen_server == gen_server_prev)
            ctx_server_prev = ctx_server
            gen_server_prev = gen_server
            response = await self.send_disagg_request(session, ctx_server,
                                                      gen_server, request)
            await asyncio.gather(
                self.ctx_router.finish_request(openai_request, session),
                self.gen_router.finish_request(openai_request, session))
            logger.info(
                f"Received response {i}: {repr(response['choices'][0]['text'])}"
            )
            request["prompt"] = prompt_str + response["choices"][0]["text"]

        if check_server_match:
            assert ctx_match > max_rounds // 2
            assert gen_match > max_rounds // 2
        return request["prompt"]

    async def test_multi_round_request(self,
                                       init_prompts: List[str],
                                       max_rounds: int = 8,
                                       warm_up_rounds: int = 4):
        async with await self.new_session() as session:
            chat_threads = [
                self.multi_round_request(session, prompt, warm_up_rounds, False)
                for prompt in init_prompts
            ]
            prompts = await asyncio.gather(*chat_threads)
            logger.info("Warm up done")
            chat_threads = [
                self.multi_round_request(session, prompt, max_rounds, True)
                for prompt in prompts
            ]
            await asyncio.gather(*chat_threads)

    async def test_eviction(self):
        async with await self.new_session() as session:
            # send a dummy request for initialization
            dummy_request = {
                "model": self.model_name,
                "prompt": [3] * 2000,
                "max_tokens": 1,
                "ignore_eos": True,
                "temperature": 0.0,
            }
            assert len(self.gen_servers) == 1
            server = self.gen_servers[0]
            # only test on this server
            server_state = self.gen_router._server_state[server]
            await self.send_request(session, server, dummy_request)

            # get block pool size from created event
            events = await self.query_kv_cache_events(session, server)
            server_state.update_with_events(events)
            block_pool_size = None
            for event in events:
                if event["type"] == "created":
                    block_pool_size = event["num_blocks_per_cache_level"][0]
                    break
            assert block_pool_size is not None
            logger.info(f"Block pool size: {block_pool_size}")

            # the dummy request can be reused
            openai_request = CompletionRequest(model=self.model_name,
                                               prompt=dummy_request["prompt"])
            server, info = await self.gen_router.get_next_server(openai_request)
            first_match = info["matches"][0]
            logger.info(f"Matched blocks: {first_match}")
            assert first_match > 0
            await self.gen_router.finish_request(openai_request)

            # flood requests until eviction
            batch_size = 64
            blocks_per_request = 32
            requests = [copy.copy(dummy_request) for _ in range(batch_size)]
            has_evicted = False
            for i in range(0, block_pool_size // blocks_per_request * 2,
                           batch_size):
                logger.info(f"Flooding request {i} ~ {i + batch_size - 1}")
                prompt_len = self.gen_router._tokens_per_block * blocks_per_request - 10
                for j in range(batch_size):
                    prompt = [10 + i + j] * prompt_len
                    requests[j]["prompt"] = prompt
                await asyncio.gather(*[
                    self.send_request(session, server, request)
                    for request in requests
                ])
                events = await self.query_kv_cache_events(session, server)
                server_state.update_with_events(events)
                for event in events:
                    if event["type"] == "removed":
                        has_evicted = True
            assert has_evicted

            # the dummy request's reusable length decreases after eviction
            server, info = await self.gen_router.get_next_server(openai_request)
            logger.info(
                f"Matched blocks: {first_match} -> {info['matches'][0]}")
            assert info["matches"][0] < first_match
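

# Router usage pattern exercised by the tester above (a summary of how these tests
# drive it, not a full API description): get_next_server(request) picks a server URL
# and returns an info dict, from which the tests read "token_lists" and "matches";
# each scheduled request is later paired with a finish_request(...) call so the
# router can drop its per-request bookkeeping.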


def prepare_llama_model(llama_model_root: str, llm_venv):
    src_dst_dict = {
        llama_model_root:
        f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    }
    for src, dst in src_dst_dict.items():
        if not os.path.islink(dst):
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            os.symlink(src, dst, target_is_directory=True)


def load_default_prompts(disaggregated_example_root: str):
    prompts_file = os.path.join(disaggregated_example_root,
                                'clients/prompts.json')
    with open(prompts_file, 'r') as f:
        return json.load(f)


@contextlib.contextmanager
def background_workers(llm_venv, config_file: str,
                       num_ranks: Optional[int] = None):
    cwd = llm_venv.get_working_directory()
    os.chdir(cwd)
    with open(os.path.join(cwd, 'output_workers.log'), 'w+') as log_file:
        workers_proc, ctx_servers, gen_servers = run_disaggregated_workers(
            config_file=config_file,
            stdout=log_file,
            env=llm_venv._new_env,
            cwd=cwd,
            num_ranks=num_ranks)
        try:
            with workers_proc as proc:
                yield ctx_servers, gen_servers
        except Exception:
            log_file.seek(0)
            logger.error("-------- Worker output --------")
            logger.error(log_file.read())
            raise
        finally:
            proc.terminate()
            proc.wait()


@pytest.mark.skip(reason="https://nvbugs/5372970")
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_workers_conditional_disaggregation(disaggregated_test_root,
                                            disaggregated_example_root,
                                            llm_venv, llama_model_root):
    config_file = os.path.join(disaggregated_test_root,
                               'test_configs/disagg_config_cache_reuse.yaml')
    prepare_llama_model(llama_model_root, llm_venv)
    with background_workers(llm_venv, config_file, 2) as (ctx_servers,
                                                          gen_servers):
        tester = ConditionalWorkerTester(ctx_servers, gen_servers)
        prompts = load_default_prompts(disaggregated_example_root)
        asyncio.run(tester.test_multi_round_request(prompts))


@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
                         indirect=True)
def test_workers_conditional_disaggregation_deepseek_v3_lite_bf16(
        disaggregated_test_root, disaggregated_example_root, llm_venv,
        deepseek_v3_model_root):
    config_file = os.path.join(
        disaggregated_test_root,
        'test_configs/disagg_config_cache_reuse_deepseek_v3.yaml')
    model_root = f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16"
    src_dst_dict = {
        deepseek_v3_model_root: model_root,
    }
    for src, dst in src_dst_dict.items():
        if not os.path.islink(dst):
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            os.symlink(src, dst, target_is_directory=True)
    with background_workers(llm_venv, config_file, 2) as (ctx_servers,
                                                          gen_servers):
        tester = ConditionalWorkerTester(ctx_servers, gen_servers)
        prompts = load_default_prompts(disaggregated_example_root)
        asyncio.run(tester.test_multi_round_request(prompts))


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_workers_kv_cache_events(disaggregated_test_root,
                                 disaggregated_example_root, llm_venv,
                                 llama_model_root):
    config_file = os.path.join(disaggregated_test_root,
                               'test_configs/disagg_config_cache_reuse.yaml')
    prepare_llama_model(llama_model_root, llm_venv)
    with background_workers(llm_venv, config_file, 2) as (ctx_servers,
                                                          gen_servers):
        tester = KvCacheEventWorkerTester(ctx_servers, gen_servers)
        prompts = load_default_prompts(disaggregated_example_root)
        asyncio.run(tester.test_multi_round_request(prompts, 6))


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_workers_kv_cache_aware_router(disaggregated_test_root,
                                       disaggregated_example_root, llm_venv,
                                       llama_model_root):
    config_file = os.path.join(
        disaggregated_test_root,
        'test_configs/disagg_config_cache_aware_balance.yaml')
    prepare_llama_model(llama_model_root, llm_venv)
    with background_workers(llm_venv, config_file, 4) as (ctx_servers,
                                                          gen_servers):
        tester = KvCacheAwareRouterTester(ctx_servers, gen_servers)
        prompts = load_default_prompts(disaggregated_example_root)
        asyncio.run(tester.test_multi_round_request(prompts, 16, 4))


@skip_no_hopper
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-bf16'],
                         indirect=True)
def test_workers_kv_cache_aware_router_deepseek_v3_lite_bf16(
        disaggregated_test_root, disaggregated_example_root, llm_venv,
        deepseek_v3_model_root):
    config_file = os.path.join(
        disaggregated_test_root,
        'test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml')
    model_root = f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16"
    src_dst_dict = {
        deepseek_v3_model_root: model_root,
    }
    for src, dst in src_dst_dict.items():
        if not os.path.islink(dst):
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            os.symlink(src, dst, target_is_directory=True)
    with background_workers(llm_venv, config_file, 4) as (ctx_servers,
                                                          gen_servers):
        tester = KvCacheAwareRouterTester(ctx_servers,
                                          gen_servers,
                                          model_name="DeepSeek-V3-Lite/bf16",
                                          tokens_per_block=64)
        prompts = load_default_prompts(disaggregated_example_root)
        asyncio.run(tester.test_multi_round_request(prompts, 8, 4))


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_workers_kv_cache_aware_router_eviction(disaggregated_test_root,
                                                disaggregated_example_root,
                                                llm_venv, llama_model_root):
    config_file = os.path.join(disaggregated_test_root,
                               'test_configs/disagg_config_cache_reuse.yaml')
    prepare_llama_model(llama_model_root, llm_venv)
    with background_workers(llm_venv, config_file, 2) as (ctx_servers,
                                                          gen_servers):
        tester = KvCacheAwareRouterTester(ctx_servers, gen_servers)
        asyncio.run(tester.test_eviction())
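

# A minimal manual-run sketch for debugging outside pytest (assumes a 1P1D worker
# pair is already listening on these illustrative URLs; not part of the test suite):
#
#   if __name__ == "__main__":
#       tester = ConditionalWorkerTester(["http://localhost:8001"],
#                                        ["http://localhost:8002"])
#       asyncio.run(tester.test_multi_round_request(["Hello, my name is"]))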