mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
1020 lines · 44 KiB · Python
import copy
import datetime
import enum
import json
import os
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from queue import Queue
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch

import tensorrt_llm.executor.serialization as serialization
from tensorrt_llm.logger import logger

from .._utils import (KVCacheEventSerializer, global_mpi_rank, mpi_comm,
                      mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import PybindMirror
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                            clear_sched_affinity, print_colored_debug,
                            print_traceback_on_error)
from ..lora_manager import LoraConfig, LoraManager
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
from ..sampling_params import BatchedLogitsProcessor, SamplingParams
from .executor import GenerationExecutor, IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocParams, PostprocWorker,
                              PostprocWorkerConfig, postproc_worker_main)
from .request import (CancellingRequest, GenerationRequest, LoRARequest,
                      PromptAdapterRequest)
from .result import (GenerationResult, IterationResult, LogProbsResult,
                     ResponseWrapper, compute_logprobs)
from .utils import (PERIODICAL_RESP_IN_AWAIT, ErrorResponse, IntraProcessQueue,
                    RequestError, WorkerCommIpcAddrs, has_event_loop,
                    is_llm_response)

__all__ = [
    "GenerationExecutorWorker",
]


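# Illustrative usage sketch (assumptions: an engine path/Engine and an
# ExecutorConfig are already available; `generation_request` is a placeholder;
# see `worker_main` below for the real construction path):
#
#     worker = GenerationExecutorWorker(engine, executor_config)
#     with worker:
#         result = worker.submit(generation_request)  # only rank 0 may submit
#     # leaving the with-block calls shutdown()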
class GenerationExecutorWorker(GenerationExecutor):

    class WorkerExit(GeneratorExit):
        pass

    def __init__(
        self,
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
        lora_config: Optional[LoraConfig] = None,
    ) -> None:
        postproc_config = postproc_worker_config or PostprocWorkerConfig()
        super().__init__(
            num_postprocess_workers=postproc_config.num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )

        self.engine = None
        self.result_queue: Optional[IpcQueue] = None
        self.postproc_queues: Optional[List[IpcQueue]] = None
        self.rank = mpi_rank()
        self.global_rank = global_mpi_rank()
        # mapping: client_id -> GenerationResult
        self._results: Dict[int, GenerationResult] = {}
        # mapping: client_id from Proxy -> request_id returned from runtime backend
        self._client_id_to_request_id: Dict[int, int] = {}
        self._await_response_helper = AwaitResponseHelper(
            self)  # TODO: make it weakref
        self._executor_config = executor_config
        self._is_pytorch_backend = getattr(self._executor_config, "backend",
                                           None) == "pytorch"

        if isinstance(engine, list):
            engine = engine[self.rank]

        if executor_config is None:
            executor_config = tllm.ExecutorConfig(1)

        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
            processor_batched=batched_logits_processor, replicate=False)

        def _create_engine():
            device_id = self.global_rank % torch.cuda.device_count()
            torch.cuda.set_device(device_id)

            # Make sure C++ executor would use same devices/ranks as py_executor
            global_rank = global_mpi_rank()
            comm_ranks = mpi_comm().allgather(global_rank)
            device_ids = mpi_comm().allgather(device_id)
            executor_config.parallel_config = tllm.ParallelConfig(
                participant_ids=comm_ranks, device_ids=device_ids)

            if isinstance(engine, Engine):
                return tllm.Executor(engine.engine,
                                     json.dumps(engine.config.to_dict(),
                                                cls=ConfigEncoder),
                                     tllm.ModelType.DECODER_ONLY,
                                     executor_config=executor_config,
                                     managed_weights=engine.managed_weights)

            if not hasattr(executor_config, "backend"):
                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
                                     executor_config)
            args = {
                "executor_config": executor_config,
                "checkpoint_dir": executor_config.hf_model_dir,
                "engine_dir": executor_config.trt_engine_dir,
            }
            if executor_config.backend == "pytorch":
                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                    create_py_executor
                create_executor = create_py_executor
                args["lora_config"] = lora_config
            elif executor_config.backend == "autodeploy":
                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                    create_autodeploy_executor
                create_executor = create_autodeploy_executor
            else:
                raise ValueError(
                    f"Unsupported backend config: {executor_config.backend}")

            return create_executor(**args)

        self.engine = _create_engine()

        self._lora_manager: Optional[LoraManager] = None
        self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
        self._runtime_model_config: Optional[ModelConfig] = None
        if self.rank == 0 and isinstance(self.engine, tllm.Executor):
            if isinstance(engine, Engine):
                engine_config = engine.config
            else:
                engine_config = EngineConfig.from_json_file(
                    f"{engine}/config.json")
            self._runtime_model_config = _engine_config_to_model_config(
                engine_config)
            if engine_config.build_config.plugin_config.lora_plugin:
                self._lora_manager = LoraManager()
            if engine_config.build_config.max_prompt_embedding_table_size > 0:
                self._prompt_adapter_manager = PromptAdapterManager()

        if getattr(executor_config, "backend",
                   "") == "pytorch" and lora_config is not None:
            self._lora_manager = LoraManager()
            lora_model_config = self.engine.model_engine.lora_model_config
            assert lora_model_config is not None
            self._lora_model_config = lora_model_config

        self.await_response_thread = ManagedThread(
            self.await_response_task,
            error_queue=self._error_queue,
            name="await_response_thread")

        self.dispatch_stats_thread = ManagedThread(
            self.dispatch_stats_task,
            error_queue=self._error_queue,
            name="dispatch_stats_thread")

        self.dispatch_kv_cache_events_thread = ManagedThread(
            self.dispatch_kv_cache_events_task,
            error_queue=self._error_queue,
            name="dispatch_kv_cache_events_thread")

    def set_result_queue(self, queue):
        """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
        assert self.postproc_queues is None
        self.result_queue = queue

    def set_postproc_queues(self, queues: List["IpcQueue"]):
        """ Set the IPC queues for feeding post-processing processes. """
        assert self.result_queue is None
        self.postproc_queues = queues

    def _create_iteration_result_queue(self,
                                       it_result_queue: IterationResultQueue):
        if not it_result_queue.is_initialized:
            # not yet initialized
            it_result_queue.is_initialized = True
            if has_event_loop():
                _queue = AsyncQueue()
                it_result_queue.queue = _queue.sync_q
                it_result_queue.aqueue = _queue
            else:
                _queue = Queue()
                it_result_queue.queue = _queue
                it_result_queue.aqueue = None

    def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
                                    queue: Union[Queue, FusedIpcQueue,
                                                 IntraProcessQueue]):
        assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
        it_result_queue.is_initialized = True
        it_result_queue.queue = queue
        it_result_queue.aqueue = None

    def return_queue(self, client_id: int):
        """ If a centralized result queue is registered (used for communication
        with the proxy), send the message there.
        Otherwise, push the result directly into the GenerationResult queue.
        """
        if self.result_queue is not None:
            return self.result_queue
        return self._results[client_id].queue

    def start_thread(self, thread: ManagedThread):
        if self.engine.can_enqueue_requests() and not thread.is_alive():
            thread.start()

    def abort_request(self, client_id: int) -> None:
        # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id
        if self.engine.can_enqueue_requests():
            request_id = self._client_id_to_request_id.get(client_id, None)
            if request_id is None:
                logger.warning(
                    f"Request of client_id {client_id} is finished, cannot abort it."
                )
                return
            self.engine.cancel_request(request_id)

    def _engine_response_callback(self, response: tllm.Response):
        return response

    def await_response_task(self) -> bool:
        return self._await_response_helper()

    def _has_background_error(self) -> bool:
        return not self._error_queue.empty()

    def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
        bck_error = self._error_queue.get_nowait()
        assert isinstance(bck_error, Exception)
        return ErrorResponse(response.client_id, str(bck_error),
                             response.request_id)

    def _iteration_result_task(self, it_result_queue: IterationResultQueue,
                               engine_get_result_api: Callable,
                               result_singleton: IterationResult,
                               result_serializer: Callable):
        time.sleep(0.2)
        async_queues = []
        queue = result_singleton.queue if self._is_llm_executor and result_singleton else it_result_queue.queue
        try:
            for results in engine_get_result_api():
                res = result_serializer(results)
                if self._is_llm_executor and result_singleton:
                    # In this case, there's no ExecutorBindingProxy.
                    # Worker needs to take care of putting to result queue.
                    while queue.full():
                        queue.get()
                    if isinstance(queue, _SyncQueue):
                        queue.put_nowait(res)
                        async_queues.append(queue)
                    else:
                        queue.put(res)
                else:
                    # Send to ExecutorBindingProxy via IPC
                    queue.put(res)

            if async_queues:
                _SyncQueue.notify_many(queue.loop, async_queues)
        except AsyncQueue.EventLoopShutdownError:
            # This happens in the last results loop while the generate workflow is stopped.
            logger.debug("worker.py: EventLoopShutdownError")
        except Exception as e:
            logger.error(f"worker.py: Error in _iteration_result_task: {e}")
            raise e

        return True  # success

    def dispatch_stats_task(self) -> bool:

        # Define a Callable to join iteration and request stats
        def stats_serializer(
                stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str:
            iteration_stats, req_stats = stats
            stats_dict = json.loads(iteration_stats.to_json_str())

            if req_stats is not None and len(req_stats) > 0:
                stats_dict["requestStats"] = []
                for req_stat in req_stats:
                    stats_dict["requestStats"].append(
                        json.loads(req_stat.to_json_str()))

            # Convert back to JSON string
            return json.dumps(stats_dict)

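        # Shape of the joined payload (illustrative; all fields other than
        # "requestStats" come from IterationStats.to_json_str()):
        #   '{..., "requestStats": [{...}, {...}]}'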
        def get_stats():
            if isinstance(self.engine, tllm.Executor):
                iter_stats = self.engine.get_latest_iteration_stats()
                #TODO: Support req stats with TRT engine
                # This would require ensuring iter and req stats have same size
                return [(iter_stat, None) for iter_stat in iter_stats]
            else:
                return self.engine.get_latest_iteration_stats()

        return self._iteration_result_task(self.stats_queues, get_stats,
                                           self._iter_stats_result,
                                           stats_serializer)

    def dispatch_kv_cache_events_task(self) -> bool:
        if isinstance(self.engine, tllm.Executor):
            # Check if the engine has a kv cache event manager
            # If not, return an empty list for the events which will cause the thread to exit early.
            event_manager = self.engine.get_kv_cache_event_manager()
            if event_manager is None:
                events_api = lambda: [None]
            else:
                events_api = event_manager.get_latest_events
            return self._iteration_result_task(
                self.kv_events_queues, events_api, self._iter_kv_events_result,
                lambda x: json.dumps(KVCacheEventSerializer.serialize(x)))
        else:
            return self._iteration_result_task(
                self.kv_events_queues, self.engine.get_latest_kv_cache_events,
                self._iter_kv_events_result,
                lambda x: json.dumps(KVCacheEventSerializer.serialize(x)))

    def start(self):
        # create iteration result queues
        self._create_iteration_result_queue(self.stats_queues)
        self._create_iteration_result_queue(self.kv_events_queues)

        # start threads
        self.start_thread(self.await_response_thread)
        self.start_thread(self.dispatch_kv_cache_events_thread)
        if mpi_rank() == 0:
            self.start_thread(self.dispatch_stats_thread)

    def _load_lora_adapter(self, lora_request: LoRARequest):
        self._lora_manager.load_from_ckpt(
            [lora_request.path],
            model_config=self._runtime_model_config if
            self._runtime_model_config is not None else self._lora_model_config,
            runtime_mapping=None,
            uids=[str(lora_request.adapter_id)])

    def _load_prompt_adapter(self,
                             prompt_adapter_request: PromptAdapterRequest):
        self._prompt_adapter_manager.load_from_ckpt(
            [prompt_adapter_request.local_path],
            model_config=self._runtime_model_config,
            uids=[str(prompt_adapter_request.adapter_id)])

    def _enqueue_request(self, request: GenerationRequest) -> int:
        assert request.id is not None
        if self._lora_manager is not None and request.lora_request is not None:
            self._load_lora_adapter(request.lora_request)
            uid = str(request.lora_request.adapter_id)
            lora_config = tllm.LoraConfig(
                task_id=request.lora_request.adapter_id,
                weights=self._lora_manager.cpp_lora_weights[uid],
                config=self._lora_manager.cpp_lora_config[uid])
        else:
            lora_config = None

        prompt_token_ids = copy.deepcopy(request.prompt_token_ids)
        prompt_tuning_config = None
        multimodal_embedding = None
        mrope_config = None
        if request.multimodal_embedding is not None:
            multimodal_embedding = request.multimodal_embedding
        if request.prompt_adapter_request is not None:
            self._load_prompt_adapter(request.prompt_adapter_request)
            uid = str(request.prompt_adapter_request.adapter_id)
            prompt_tuning_config = tllm.PromptTuningConfig(
                self._prompt_adapter_manager.uid_to_weights[uid])
            vocab_size = self._runtime_model_config.vocab_size
            pa_length = prompt_tuning_config.embedding_table.size(0)
            prompt_token_ids = list(range(
                vocab_size, vocab_size + pa_length)) + prompt_token_ids

        if request.mrope_config is not None:
            mrope_config = tllm.MropeConfig(**request.mrope_config)

        context_phase_params = None
        request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
        if request.disaggregated_params is not None:
            request_type = request.disaggregated_params.get_request_type()
            if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
                context_phase_params = request.disaggregated_params.get_context_phase_params(
                )

        is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler
        if is_overlap_enabled:
            is_disaggregated = self.engine.kv_cache_transceiver is not None
            if is_disaggregated and (
                    request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
                raise ValueError(
                    "Context only requests are not supported in pytorch backend when overlap is enabled."
                )

        assert request.id is not None

        def _deduce_max_tokens(request: GenerationRequest,
                               executor_config: tllm.ExecutorConfig) -> int:
            if request.sampling_params.max_tokens:
                return request.sampling_params.max_tokens
            # deduce max_tokens when it's not set by user
            query_token_len = len(
                request.query_token_ids) if request.query_token_ids else 0
            cp_size = 1 if (not hasattr(executor_config, "mapping")
                            or executor_config.mapping.cp_size
                            is None) else executor_config.mapping.cp_size
            if not hasattr(executor_config, "max_seq_len"):
                raise RuntimeError(
                    "max_tokens for sampling is not set and cannot be deduced")
            splited_prompt_len = int(len(prompt_token_ids) / cp_size)
            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
            if default_max_tokens < 0:
                raise ValueError(
                    f"Deduced max_tokens {default_max_tokens} is less than 0, because "
                    f"prompt length {splited_prompt_len} plus query length {query_token_len} "
                    f"is larger than max_seq_len {executor_config.max_seq_len}")
            return default_max_tokens

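        # Worked example of the deduction above (illustrative numbers): with
        # max_seq_len=4096, cp_size=2, a 1024-token prompt and no query tokens,
        # max_tokens is deduced as 4096 - int(1024 / 2) - 0 = 3584.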
        try:
            executor_request = tllm.Request(
                client_id=request.id,
                input_token_ids=prompt_token_ids,
                max_tokens=_deduce_max_tokens(request, self._executor_config),
                streaming=request.streaming,
                sampling_config=request.sampling_params._get_sampling_config(),
                end_id=-1 if request.sampling_params.ignore_eos else
                request.sampling_params.end_id,
                pad_id=request.sampling_params.pad_id,
                output_config=request.sampling_params._get_output_config(
                    is_pytorch_backend=self._is_pytorch_backend),
                # Beam search enforces return_all_generated_tokens=True regardless of the passed value
                return_all_generated_tokens=False,
                # convert python config into pybind config
                lookahead_config=PybindMirror.maybe_to_pybind(
                    request.sampling_params.lookahead_config),
                guided_decoding_params=request.sampling_params.
                _get_guided_decoding_params(),
                bad_words=request.sampling_params._get_bad_words(),
                stop_words=request.sampling_params._get_stop_words(),
                embedding_bias=request.sampling_params.embedding_bias,
                lora_config=lora_config,
                prompt_tuning_config=prompt_tuning_config,
                multimodal_embedding=multimodal_embedding,
                mrope_config=mrope_config,
                logits_post_processor_name=(
                    tllm.Request.BATCHED_POST_PROCESSOR_NAME
                    if request.sampling_params.apply_batched_logits_processor
                    else None),
                logits_post_processor=None if self._is_pytorch_backend else
                request.sampling_params.logits_processor,
                kv_cache_retention_config=request.kv_cache_retention_config,
                context_phase_params=context_phase_params,
                type=request_type)

            if self._is_pytorch_backend and request.sampling_params.logits_processor:
                # For PyTorch backend, we attach logits processors as a dynamic Python attribute
                # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues.
                lp = request.sampling_params.logits_processor
                executor_request.py_logits_post_processors = lp if isinstance(
                    lp, list) else [lp]

            if request.query_token_ids is not None:
                # pytorch star attention workflow
                # a workaround to avoid public interface update
                req_id = self.engine.enqueue_request(executor_request,
                                                     request.query_token_ids)
            else:
                req_id = self.engine.enqueue_request(executor_request)
            return req_id
        except Exception as e:
            raise RequestError(str(e)) from e

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """ Low-level API to the executor. Return a "future" GenerationResult which can be waited on. """
        self.start()

        if self.rank != 0:
            raise RuntimeError(
                "Only rank 0 can submit requests.\n"
                "To fix this, ensure that the llm.generate(...) method is "
                "guarded with the `if __name__ == '__main__':` block.")

        client_id = request.id if request.id is not None else self._get_next_client_id(
        )
        if request.id is None:
            request.set_id(client_id)

        logprob_params = self._get_logprob_params(request)

        result = GenerationResult(
            request,
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params,
            logprob_params=logprob_params)

        self._results[client_id] = result

        request_id = self._enqueue_request(request)
        # request_id returned from backend is necessary for the abort_request method.
        self._client_id_to_request_id[client_id] = request_id

        self._handle_background_error()

        return result

    def _pop_result(self, client_id: int):
        self._results.pop(client_id, None)
        self._client_id_to_request_id.pop(client_id, None)

    def shutdown(self):

        if self.doing_shutdown:
            return
        else:
            self.doing_shutdown = True

        print_colored_debug(f'Worker {mpi_rank()} shutdown...\n', "yellow")

        if self.engine is not None:
            if self.engine.can_enqueue_requests():

                if self.await_response_thread.is_alive():
                    self.await_response_thread.stop()
                    self.await_response_thread.join()
                if self.dispatch_stats_thread.is_alive():
                    self.dispatch_stats_thread.stop()
                    self.dispatch_stats_thread.join()
                if self.dispatch_kv_cache_events_thread.is_alive():
                    self.dispatch_kv_cache_events_thread.stop()
                    self.dispatch_kv_cache_events_thread.join()

            self.engine.shutdown()
            self.engine = None

        # Check if there are any errors from the threads before shutdown.
        self._handle_background_error()

        print_colored_debug(f"Worker {mpi_rank()} shutdown done.\n", "yellow")

    def block_subordinates(self):
        if self.rank != 0:
            if isinstance(self.engine, tllm.Executor):
                self.shutdown()
                raise self.WorkerExit(
                    "block_subordinates() should be used in a `with GenerationExecutorWorker() as ...:` block"
                )
            from tensorrt_llm._torch.pyexecutor.py_executor import PyExecutor
            if isinstance(self.engine, PyExecutor):
                self.engine.wait_shutdown()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        self.shutdown()
        return exc_type is None or exc_type == GenerationExecutorWorker.WorkerExit

    def __del__(self):
        self.shutdown()


@print_traceback_on_error
def worker_main(
        engine: Path | Engine,
        worker_queues: WorkerCommIpcAddrs,
        log_level: str,
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        worker_cls: type = GenerationExecutorWorker,
        tracer_init_kwargs: Optional[dict] = None,
        _torch_model_class_mapping: Optional[dict] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        ready_signal: Optional[str] = None,
        is_llm_executor: Optional[
            bool] = True,  # whether it's the main executor instance
        lora_config: Optional[LoraConfig] = None,
        BASE_ZMQ_CLASSES: Dict = serialization.BASE_ZMQ_CLASSES,
) -> None:
    # The base classes for ZMQ serialization. Passed through from the parent process to ensure
    # that child processes include any classes added at runtime (such as those from `register_approved_ipc_class`).
    serialization.BASE_ZMQ_CLASSES = BASE_ZMQ_CLASSES
    mpi_comm().barrier()
    print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
                        "green")

    pid = os.getpid()
    cpus = os.sched_getaffinity(pid)
    if cpus:
        logger.warning(
            f"Found worker process {pid} was bound to {cpus}; this may harm "
            "performance.", )
        logger.warning("Will clear the cpu affinity.")
        clear_sched_affinity(pid)

    result_queue: Optional[IpcQueue] = None
    result_queues: Optional[List[IpcQueue]] = None

    postproc_worker_config = postproc_worker_config or PostprocWorkerConfig()

    is_leader: bool = mpi_rank() == 0
    if tracer_init_kwargs is not None and is_leader:
        tracer = VizTracer(**tracer_init_kwargs)
        tracer.register_exit()
        tracer.start()
        set_global_tracer(tracer)

    if _torch_model_class_mapping is not None:
        from tensorrt_llm._torch.models.modeling_auto import MODEL_CLASS_MAPPING
        MODEL_CLASS_MAPPING.update(**_torch_model_class_mapping)

    set_mpi_session_cpp(mpi_comm())

    if is_leader:
        # Only set the log level for the leader process; the other processes
        # will inherit the log level from the "TLLM_LOG_LEVEL" environment variable.
        logger.set_level(log_level)
        request_queue = IpcQueue(worker_queues.request_queue_addr,
                                 is_server=False,
                                 name="worker_request_queue")
        request_error_queue = IpcQueue(worker_queues.request_error_queue_addr,
                                       is_server=False,
                                       name="worker_request_error_queue")
        mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr,
                                       is_server=False,
                                       fuse_message=True,
                                       name="worker_stats_queue")
        kv_cache_events_queue = FusedIpcQueue(
            worker_queues.kv_cache_events_queue_addr,
            is_server=False,
            fuse_message=False,
            name="worker_kv_cache_events_queue")

        if postproc_worker_config.enabled:
            # IPC queues for sending inputs to the postprocess parallel
            # processes, each one is a PAIR zmq socket
            result_queues = [
                FusedIpcQueue(is_server=True,
                              fuse_message=PERIODICAL_RESP_IN_AWAIT,
                              name=f"postprocess_{i}_feedin_queue")
                for i in range(postproc_worker_config.num_postprocess_workers)
            ]
        else:
            # IPC queue for sending results back to the proxy, letting the
            # Proxy process handle the postprocessing
            result_queue = FusedIpcQueue(worker_queues.result_queue_addr,
                                         is_server=False,
                                         fuse_message=PERIODICAL_RESP_IN_AWAIT,
                                         name="worker_result_queue")

    def notify_proxy_threads_to_quit():
        # Signal the dispatcher thread in the proxy to quit
        if result_queue is not None:
            result_queue.put(None)
        else:
            assert result_queues is not None
            for q in result_queues:
                q.put(None)
        # Signal the stats thread in the proxy to quit
        mp_stats_queue.put(None)
        kv_cache_events_queue.put(None)

    postprocess_worker_futures = []
    if is_leader and postproc_worker_config.enabled:
        print_colored_debug(f"initiate postprocess workers...", "yellow")

        proxy_result_queue: tuple[
            str, Optional[bytes]] = worker_queues.result_queue_addr

        assert result_queues is not None
        assert postproc_worker_config.postprocess_tokenizer_dir is not None
        postproc_worker_pool = ProcessPoolExecutor(
            max_workers=postproc_worker_config.num_postprocess_workers)
        assert isinstance(proxy_result_queue, tuple)
        for i in range(postproc_worker_config.num_postprocess_workers):
            fut = postproc_worker_pool.submit(
                postproc_worker_main, result_queues[i].address,
                proxy_result_queue,
                postproc_worker_config.postprocess_tokenizer_dir,
                PostprocWorker.default_record_creator,
                serialization.BASE_ZMQ_CLASSES)
            postprocess_worker_futures.append(fut)

    # Error handling in the Worker/MPI process
    #   1. During Executor initialization, the errors will be captured and
    #      sent back via request_error_queue.
    #   2. During execution, the errors will be captured by ManagedThreads
    #      a) For per-request errors, the error will be sent back via
    #         result_queue, and eventually raised in handle_response() in
    #         the main thread.
    #      b) For system errors, the error will be raised in the MPI process
    #         and handled by future.done_callback, which will propagate the
    #         error to the error_queue in the main thread.

    mpi_comm().barrier()
    print_colored_debug(f"Worker {mpi_rank()} ready to setup backend...\n",
                        "green")

    try:
        worker: GenerationExecutorWorker = worker_cls(
            engine,
            executor_config,
            batched_logits_processor,
            postproc_worker_config=postproc_worker_config,
            is_llm_executor=is_llm_executor,
            lora_config=lora_config)
    except Exception as e:
        logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
        logger.error(traceback.format_exc())
        print_colored_debug(f"error: {traceback.format_exc()}", "red")
        if is_leader:
            request_error_queue.put(e)
        return

    with worker:
        try:
            worker.block_subordinates()

            if is_leader:
                if postproc_worker_config.enabled:
                    worker.set_postproc_queues(result_queues)
                else:
                    worker.set_result_queue(result_queue)

                # initialize the iteration result queues
                worker._set_iteration_result_queue(worker.stats_queues,
                                                   mp_stats_queue)
                worker._set_iteration_result_queue(worker.kv_events_queues,
                                                   kv_cache_events_queue)
                request_error_queue.put(ready_signal)
                while (req := request_queue.get()) is not None:
                    if isinstance(req, CancellingRequest):
                        worker.abort_request(req.id)
                    elif isinstance(req, GenerationRequest):
                        try:
                            worker.submit(req)
                            request_error_queue.put(None)  # None means success
                        except RequestError as e:
                            request_error_queue.put(e)
                    else:
                        raise ValueError(f"Unknown request type: {type(req)}")

                notify_proxy_threads_to_quit()

        except GenerationExecutorWorker.WorkerExit as e:
            # This will be captured by the with-statement and exit normally.
            raise e

        except Exception as e:  # other critical errors
            if is_leader:
                notify_proxy_threads_to_quit()
            err = Exception(f"Failed during generation: {e}")
            logger.error(traceback.format_exc())
            if is_leader:
                request_error_queue.put(err)


class AwaitResponseHelper:
    ''' Multiple implementations of await_response for performance. '''

    class HandlerKind(enum.Enum):
        unknown = 0
        single_process_worker = 1
        ipc_periodically = 2
        ipc_batched = 3

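    # Selection summary (inferred from responses_handler below):
    # single_process_worker is used when neither a result queue nor postproc
    # queues are attached to the worker; otherwise the IPC path is taken,
    # periodically fused or explicitly batched per PERIODICAL_RESP_IN_AWAIT.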
    def __init__(self, worker: "GenerationExecutorWorker"):
        # TODO: make worker weakref
        self.worker = worker
        self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown
        self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel

    def responses_handler(self, responses: List[tllm.Response]):
        HandlerKind = AwaitResponseHelper.HandlerKind

        if self.handler_kind is HandlerKind.unknown:
            if not (self.worker.result_queue is not None
                    or self.worker.postproc_queues is not None):
                print_colored_debug(
                    f"creating await_response helper for Worker\n",
                    color="yellow")
                # When ExecutorBindingWorker is used in the main process
                # aka the single process mode
                self.handler_kind = HandlerKind.single_process_worker
            elif self.worker.result_queue is not None or self.worker.postproc_queues is not None:
                # The ExecutorBindingProxy is used
                print_colored_debug(f"creating await_response helper for IPC\n",
                                    color="yellow")
                if PERIODICAL_RESP_IN_AWAIT:
                    self.handler_kind = HandlerKind.ipc_periodically
                else:
                    self.handler_kind = HandlerKind.ipc_batched
            else:
                raise NotImplementedError

        match self.handler_kind:
            case HandlerKind.single_process_worker:
                return self.handle_for_worker(responses)
            case HandlerKind.ipc_batched:
                return self.handle_for_ipc_batched(responses)
            case HandlerKind.ipc_periodically:
                return self.handle_for_ipc_periodically(responses)
            case _:
                raise NotImplementedError

    def __call__(self) -> bool:
        ''' This method should be called by a ManagedThread. '''
        responses = self.worker.engine.await_responses(
            timeout=datetime.timedelta(milliseconds=100))
        # filter out None, since _engine_response_callback may return None
        responses = list(
            filter(
                lambda _: _,
                [self.worker._engine_response_callback(r) for r in responses]))

        with nvtx_range_debug(f"await_response-{len(responses)}",
                              color="red",
                              category="Worker"):
            self.responses_handler(responses)
        return True

    def handle_for_worker(self, responses: List[tllm.Response]) -> None:
        ''' Return the responses to asyncio.event_loop. '''
        event_loop = None
        async_queues = []
        for response in responses:
            assert response is not None
            queue = self.worker.return_queue(response.client_id)

            logprobs_result = _get_logprobs(self.worker, response,
                                            self.worker._is_pytorch_backend)
            if logprobs_result:
                response = ResponseWrapper(response, logprobs_result)

            # For AsyncQueue.sync_q, we will batch the events to avoid too many
            # event notifications, thus put without wait here.
            if isinstance(queue, _SyncQueue):
                global_tracer().log_instant("worker-rsp.put")
                queue.put_nowait(response)
                async_queues.append(queue)
                # all the loops are identical
                event_loop = event_loop or queue.loop
            else:
                queue.put(response)

            if response.has_error() or response.result.is_final:
                self.worker._pop_result(response.client_id)

        # Notify the events in bulk for performance.
        if async_queues:
            _SyncQueue.notify_many(event_loop, async_queues)

    def handle_for_ipc_periodically(self,
                                    responses: List[tllm.Response]) -> None:
        ''' Return the responses to Proxy via IPC. This will put Rsp to a Queue
        in a FusedIpcQueue, and a background thread will batch them and invoke
        IPC periodically. '''

        with nvtx_range_debug(f"handle_for_ipc_periodically-{len(responses)}",
                              color="red",
                              category="Worker"):

            for response in responses:

                if self.worker._has_background_error():
                    response = self.worker._create_error_response(response)
                elif response.has_error():
                    response = ErrorResponse(response.client_id,
                                             response.error_msg,
                                             response.request_id)
                else:
                    logprobs_result = _get_logprobs(
                        self.worker, response, self.worker._is_pytorch_backend)
                    if logprobs_result:
                        response = ResponseWrapper(response, logprobs_result)

                # TODO: To verify the performance of using ZMQ instead of SharedMemory
                # to send the logits tensor back to the Proxy process.
                _send_rsp(self.worker, response)

    def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
        ''' Perform the IPC in batch explicitly. '''
        postproc_batches = [
            []
            for _ in range(self.worker.postproc_config.num_postprocess_workers)
        ] if self.enable_postprocprocess_parallel else None
        rsp_batch = [] if not self.enable_postprocprocess_parallel else None

        for response in responses:

            if self.worker._has_background_error():
                response = self.worker._create_error_response(response)
            elif response.has_error():
                # Convert to ErrorResponse, because tllm.Response cannot be
                # serialized when it has an error.
                response = ErrorResponse(response.client_id, response.error_msg,
                                         response.request_id)
            else:
                logprobs_result = _get_logprobs(self.worker, response,
                                                self.worker._is_pytorch_backend)
                if logprobs_result:
                    response = ResponseWrapper(response, logprobs_result)

            _send_rsp(self.worker,
                      response,
                      postproc_batches=postproc_batches,
                      rsp_batch=rsp_batch)

        if postproc_batches:
            for wid, batch in enumerate(postproc_batches):
                if batch:
                    self.worker.postproc_queues[wid].put(batch)

        if rsp_batch:
            self.worker.result_queue.put(rsp_batch)


def _get_params_for_first_rsp(
        worker,
        client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]:
    res = worker._results.get(client_id, None)
    assert res is not None
    if not res._params_transmitted:
        res._params_transmitted = True
        return res.sampling_params, res.postproc_params
    return None, None


def _get_logprobs(worker,
                  response: tllm.Response,
                  is_pytorch_backend=False) -> Optional[LogProbsResult]:
    """Compute logprobs and prompt logprobs, and clear out logits if applicable.
    """
    if is_pytorch_backend:
        # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime.
        # In the PyTorch backend, logprobs are already computed during runtime if requested.
        return None

    logprobs_result = None
    generation_result = worker._results.get(response.client_id, None)

    if not generation_result:
        return

    logprob_params = getattr(generation_result, "_logprob_params", None)
    if logprob_params:
        logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                           logprob_params.logprobs,
                                           response.result.context_logits,
                                           response.result.generation_logits,
                                           response.result.output_token_ids[0])

        if logprob_params.drop_context_logits:
            response.clear_context_logits()

        if logprob_params.drop_generation_logits:
            response.clear_generation_logits()

        if response.result.is_final:
            generation_result.clear_logprob_params()

    return logprobs_result


def _send_rsp(
        worker,
        response: Union[tllm.Response, ResponseWrapper, ErrorResponse],
        postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None,
        rsp_batch: Optional[List[tllm.Response]] = None):
    # if postproc_batches is set, append to batch instead of putting to IpcQueue

    if worker.result_queue is not None:
        if rsp_batch is not None:
            rsp_batch.append(response)
        else:
            worker.result_queue.put(response)
    else:
        sampling_params, postproc_params = _get_params_for_first_rsp(
            worker, response.client_id)
        inp = PostprocWorker.Input(
            response,
            # sampling_params is necessary for creating fake GenerationResult
            # instances in the postproc processes. They are for incremental
            # detokenization. They should be transmitted only once for each
            # request.
            sampling_params=sampling_params,
            postproc_params=postproc_params,
            streaming=worker._results.get(response.client_id, None)._streaming)

        pid = response.client_id % worker.postproc_config.num_postprocess_workers

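        # Illustrative example: with 4 postprocess workers, client_ids 0..5 map
        # to workers 0, 1, 2, 3, 0, 1, so every response of a given request is
        # always routed to the same worker.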
        if not postproc_batches:
            # Group the responses into buckets for the postprocessing steps.
            # Bucketing is used instead of random dispatching because the
            # incremental detokenization during postprocessing relies on the
            # prior CompletionOutput of a given request.
            worker.postproc_queues[pid].put(inp)
        else:
            postproc_batches[pid].append(inp)

    # Eliminate the finished GenerationRequest instances promptly, since they
    # may take considerable memory.
    if is_llm_response(response):
        if response.has_error() or response.result.is_final:
            worker._pop_result(response.client_id)
    elif isinstance(response, ErrorResponse):
        worker._pop_result(response.client_id)
    else:
        raise ValueError(f"Unknown response type: {response}")