# tests/unittest/executor/test_base_worker.py
import os
import sys
import time

import pytest
import torch

from tensorrt_llm._utils import mpi_comm, mpi_rank, mpi_world_size
from tensorrt_llm.bindings import executor as tllm
from tensorrt_llm.llmapi.mpi_session import MpiPoolSession

# isort: off
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
from utils.llm_data import llm_models_root
from utils.util import skip_single_gpu
# isort: on

from tensorrt_llm._torch.pyexecutor.config import update_executor_config
from tensorrt_llm.executor.base_worker import BaseWorker
from tensorrt_llm.executor.request import GenerationRequest
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from tensorrt_llm.sampling_params import SamplingParams

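# Tiny test model, resolved under the shared models root returned by llm_models_root().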
default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
model_path = llm_models_root() / default_model_name


class FakeWorker(BaseWorker):
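    """Minimal BaseWorker subclass used directly in these tests.

    Unlike GenerationExecutorWorker, BaseWorker does not set up the engine or
    spawn background response threads on its own, so this class calls
    setup_engine() itself and keeps start()/shutdown() trivial.
    """
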
    def __init__(self, engine: str, tp_size: int = 1):
        llm_args, executor_config = create_fake_executor_config(engine, tp_size)
        super().__init__(
            engine=engine,
            llm_args=llm_args,
            hf_model_dir=engine,
            executor_config=executor_config,
        )
        # Note: BaseWorker doesn't call setup_engine() automatically,
        # unlike GenerationExecutorWorker, so we need to call it manually.
        self.setup_engine()
        self._started = False

    def start(self):
        """Override start to mark the worker as started; no background threads are needed for tests."""
        if not self._started:
            self._started = True
            # For testing, we don't need background threads.
            # The engine's await_responses will handle the mock responses.

    def shutdown(self):
        self._started = False
        if self.engine is not None:
            self.engine.shutdown()
            self.engine = None


class TestWorkerBase:
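    """Single-rank tests that exercise BaseWorker against the PyTorch backend."""
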
    def test_create_engine(self):
        with FakeWorker(engine=model_path) as worker:
            print(f"Created engine: {worker.engine}")

    def test_submit_request(self):
        sampling_params = SamplingParams(max_tokens=10)
        request = GenerationRequest(prompt_token_ids=[3, 4, 5],
                                    sampling_params=sampling_params)
        with FakeWorker(engine=model_path) as worker:
            print(f"Created engine: {worker.engine}")
            result = worker.submit(request)
            # For the PyTorch backend, the engine handles requests internally;
            # we just need to give it some time to process.
            timeout = 15.0  # seconds
            start_time = time.time()
            while not result.finished and (time.time() - start_time) < timeout:
                # Call await_responses with a timeout to prevent hanging.
                worker.await_responses(timeout=0.5)
                time.sleep(0.1)
            if not result.finished:
                print(f"Request did not complete within {timeout} seconds")
            else:
                print("Request completed successfully")
            print(f"Result: {result}")

    def test_fetch_stats(self):
        request = GenerationRequest(
            prompt_token_ids=[3, 4, 5],
            sampling_params=SamplingParams(max_tokens=10))
        with FakeWorker(engine=model_path) as worker:
            result = worker.submit(request)
            # Give the engine time to start processing.
            time.sleep(1)
            # Fetch stats while the request is processing.
            stats = worker.fetch_stats()
            print(f"Stats: {stats}")
            # Continue processing until completion or timeout.
            timeout = 10.0
            start_time = time.time()
            while not result.finished and (time.time() - start_time) < timeout:
                worker.await_responses(timeout=0.5)
                time.sleep(0.1)

    @pytest.mark.parametrize("timeout", [0.1, 0.2, 1])
    def test_fetch_responses_timeout(self, timeout: float):
        with FakeWorker(engine=model_path) as worker:
            # Do not submit any request; let await_responses time out.
            start_time = time.time()
            worker.await_responses(timeout=timeout)
            elapsed = time.time() - start_time
            print(f"await_responses latency: {elapsed:.3f} seconds")
            assert timeout / 2 <= elapsed <= timeout * 2, \
                f"Latency out of expected range: {elapsed}"


def create_fake_executor_config(model_path, tp_size=1):
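    """Build the (llm_args, executor_config) pair FakeWorker needs for a small test model."""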
    # Use TorchLlmArgs for PyTorch backend tests.
    llm_args = TorchLlmArgs(
        model=model_path,
        tensor_parallel_size=tp_size,
        backend='pytorch',
        enable_iter_perf_stats=True,
    )
    executor_config = tllm.ExecutorConfig(1)
    executor_config.max_batch_size = 1
    executor_config.model_world_size = tp_size
    update_executor_config(
        executor_config,
        pytorch_backend_config=llm_args.get_pytorch_backend_config(),
        mapping=llm_args.parallel_config.to_mapping(),
        speculative_config=llm_args.speculative_config,
        hf_model_dir=model_path,
        max_input_len=20,
        max_seq_len=40,
        checkpoint_format=llm_args.checkpoint_format,
        checkpoint_loader=llm_args.checkpoint_loader,
    )
    return llm_args, executor_config


class TestRpcWorkerBaseTP2:
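    """TP=2 test that creates a BaseWorker on each rank of an MpiPoolSession."""
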
    def setup_method(self):
        # Use TorchLlmArgs for the PyTorch backend with TP2.
        self.llm_args = TorchLlmArgs(model=model_path,
                                     tensor_parallel_size=2,
                                     backend='pytorch')
        self.session = self.create_worker_session()

    def create_worker_session(self):
        session = MpiPoolSession(n_workers=2)
        return session

    @pytest.mark.gpu2
    @skip_single_gpu
    def test_create_executor(self):
        futures = self.session.submit(
            TestRpcWorkerBaseTP2.create_executor,
            engine=model_path,
            llm_args=self.llm_args,
        )
        # Wait for completion.
        for future in futures:
            future.result()
        self.session.shutdown()

    @staticmethod
    def create_executor(engine, llm_args):
        rank = mpi_rank()
        world_size = mpi_world_size()
        device_id = rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        print(f"[Test] Rank {rank}/{world_size} using device {device_id}")
        # Synchronize all workers before creating the executor.
        mpi_comm().barrier()
        print(f"[Test] Rank {rank} creating FakeWorker...")
        executor = FakeWorker(engine=engine, tp_size=2)
        # Note: setup_engine is already called in FakeWorker.__init__.
        print(
            f"[Test] Rank {rank} FakeWorker created and setup_engine completed successfully"
        )
        executor.shutdown()


if __name__ == "__main__":
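    # Run a single test directly for quick manual debugging without pytest.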
    test_worker_base = TestWorkerBase()
    test_worker_base.test_submit_request()