# tests/unittest/executor/test_rpc_worker.py
import asyncio
import os
import sys
import time
from tensorrt_llm.executor.request import GenerationRequest
from tensorrt_llm.executor.rpc_worker import RpcWorker
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, TorchLlmArgs
from tensorrt_llm.sampling_params import SamplingParams
# isort: off
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
from utils.llm_data import llm_models_root
# isort: on

model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
assert model_path.exists()
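
# Note: llm_models_root() resolves the shared model cache used by these tests;
# the module-level assert fails fast at collection time if the TinyLlama
# checkpoint is not available.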


class TestRpcWorkerTP1:

    def setup_method(self):
        self.llm_args = TorchLlmArgs(
            model=model_path,
            tensor_parallel_size=1,
            backend='pytorch',
            enable_iter_perf_stats=True,
            kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.5),
        )
        # Create the RpcWorker instance
        self.worker = RpcWorker(
            engine=model_path,
            llm_args=self.llm_args,
            hf_model_dir=model_path,
        )
        # Initialize the engine
        self.worker.setup_engine()

    def teardown_method(self):
        # Clean up the worker
        self.worker.shutdown()
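
    # The three tests below repeat the same submit-then-poll pattern. As a
    # purely illustrative sketch (a hypothetical helper, not called by the
    # tests), the shared setup could be factored out like this; it uses only
    # the RpcWorker.submit() call and the request shape already exercised
    # below.
    def _submit_dummy_request(self):
        sampling_params = SamplingParams(max_tokens=10)
        request = GenerationRequest(prompt_token_ids=[3, 4, 5],
                                    sampling_params=sampling_params)
        self.worker.submit(request)
        return request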

    def test_fetch_responses_async(self):
        """Test that fetch_responses_async can be called and returns a list."""
        # Submit a request first
        sampling_params = SamplingParams(max_tokens=10)
        request = GenerationRequest(prompt_token_ids=[3, 4, 5],
                                    sampling_params=sampling_params)
        self.worker.submit(request)

        # Sleep briefly to let the request start processing
        time.sleep(0.5)

        # Fetch responses with a timeout to prevent hanging
        responses = asyncio.run(self.worker.fetch_responses_async(timeout=1.0))
        assert isinstance(responses, list)

    def test_fetch_stats_async(self):
        """Test that fetch_stats_async can be called and returns a list."""
        # Submit a request first to generate some stats
        sampling_params = SamplingParams(max_tokens=10)
        request = GenerationRequest(prompt_token_ids=[3, 4, 5],
                                    sampling_params=sampling_params)
        self.worker.submit(request)

        # Sleep briefly to let the request start processing
        time.sleep(0.5)

        # Fetch stats
        stats = asyncio.run(self.worker.fetch_stats_async())
        assert isinstance(stats, list)

    def test_fetch_kv_cache_events_async(self):
        """Test that fetch_kv_cache_events_async can be called and returns a list."""
        # Submit a request first to generate some KV cache events
        sampling_params = SamplingParams(max_tokens=10)
        request = GenerationRequest(prompt_token_ids=[3, 4, 5],
                                    sampling_params=sampling_params)
        self.worker.submit(request)

        # Sleep briefly to let the request start processing
        time.sleep(0.5)

        # Fetch KV cache events
        events = asyncio.run(self.worker.fetch_kv_cache_events_async())
        assert isinstance(events, list)
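

# Minimal sketch of how the calls exercised above compose into a polling
# loop. This is a hypothetical example, not part of the test suite:
# `drain_responses`, `max_polls`, and `poll_timeout` are invented names, and
# only RpcWorker methods already used in the tests (submit,
# fetch_responses_async) are assumed to exist.
async def drain_responses(worker, request, max_polls=20, poll_timeout=1.0):
    """Submit a request, then poll until at least one response arrives."""
    worker.submit(request)
    collected = []
    for _ in range(max_polls):
        # Each call returns a (possibly empty) list of responses that are
        # ready; the timeout bounds how long a single poll may block.
        responses = await worker.fetch_responses_async(timeout=poll_timeout)
        collected.extend(responses)
        if collected:
            # Stop at the first batch; a real caller would keep draining
            # until the response stream signals completion.
            break
    return collected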