mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[TRTLLM-8188][chore] refactor GenerationExecutorWorker with WorkerBase for better code reusing (#7840)
Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com>
This commit is contained in:
parent
e10a027a03
commit
4509d97780
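The shape of the refactor, in brief: the engine setup, LoRA/prompt-adapter management, request enqueueing, and response-dispatch plumbing that used to live in GenerationExecutorWorker move into a new BaseWorker class in tensorrt_llm/executor/base_worker.py, and GenerationExecutorWorker becomes a thin subclass that only wires up its threads. The following is an illustrative sketch of that structure only, not code from this commit; the subclass name and bodies are hypothetical, while BaseWorker, setup_engine, start, and submit are the names introduced in the diff below.

# Illustrative sketch only -- not part of this commit.
from tensorrt_llm.executor.base_worker import BaseWorker

class MyWorker(BaseWorker):
    """Hypothetical subclass reusing BaseWorker's shared plumbing."""

    def __init__(self, *args, **kwargs):
        # BaseWorker.__init__ only records the inputs; the engine is built
        # explicitly by the subclass, as GenerationExecutorWorker now does.
        super().__init__(*args, **kwargs)
        self.setup_engine()  # creates the C++ executor or PyTorch py_executor

    def start(self):
        # BaseWorker.submit() calls start() before enqueueing a request;
        # a concrete worker launches its await-response/dispatch threads here.
        pass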
822
tensorrt_llm/executor/base_worker.py
Normal file
@@ -0,0 +1,822 @@
import copy
import datetime
import enum
import json
import weakref
from pathlib import Path
from queue import Queue
from typing import Dict, List, Optional, Tuple, Union

import torch

from tensorrt_llm.logger import logger

from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank,
                      nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import _SyncQueue, print_colored_debug
from ..lora_helper import LoraConfig
from ..lora_manager import LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
from ..sampling_params import BatchedLogitsProcessor, SamplingParams
from .executor import GenerationExecutor, IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocParams, PostprocWorker,
                              PostprocWorkerConfig)
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
from .result import (GenerationResult, LogProbsResult, ResponseWrapper,
                     compute_logprobs)
from .utils import (ErrorResponse, IntraProcessQueue, RequestError,
                    is_llm_response)

__all__ = [
    "BaseWorker",
]


class BaseWorker(GenerationExecutor):

    class WorkerExit(GeneratorExit):
        pass

    def __init__(
        self,
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
        postproc_worker_config: Optional[PostprocWorkerConfig] = None,
        is_llm_executor: Optional[bool] = None,
        lora_config: Optional[LoraConfig] = None,
        kv_connector_config: Optional[KvCacheConnectorConfig] = None,
        hf_model_dir: Optional[Path] = None,
        tokenizer: Optional[TokenizerBase] = None,
        llm_args: Optional[BaseLlmArgs] = None,
    ) -> None:
        postproc_config = postproc_worker_config or PostprocWorkerConfig()
        super().__init__(
            num_postprocess_workers=postproc_config.num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )

        # inputs
        self._engine = engine
        self._executor_config = executor_config
        self._batched_logits_processor = batched_logits_processor
        self._postproc_worker_config = postproc_worker_config
        self._is_llm_executor = is_llm_executor
        self._lora_config = lora_config
        self._kv_connector_config = kv_connector_config
        self._hf_model_dir = hf_model_dir
        self._tokenizer = tokenizer
        self.llm_args = llm_args

        self.engine = None
        self.result_queue: Optional[IpcQueue] = None
        self.postproc_queues: Optional[List[IpcQueue]] = None
        self.rank = mpi_rank()
        self.global_rank = global_mpi_rank()
        # mapping: client_id -> GenerationResult
        self._results: Dict[int, GenerationResult] = {}
        # mapping: client_id from Proxy -> request_id returned from runtime backend
        self._client_id_to_request_id: Dict[int, int] = {}
        self._await_response_helper = AwaitResponseHelper(weakref.proxy(self))
        self._is_pytorch_backend = llm_args is not None and llm_args.backend in [
            "pytorch", "_autodeploy"
        ]

        if not self._is_pytorch_backend and kv_connector_config is not None:
            raise ValueError(
                "KV connector config is only supported for PyTorch backend")

        if global_mpi_size() > 1:
            logger.set_rank(self.global_rank)

    def setup_engine(self):
        """
        Setup the engine for the worker.
        """

        if isinstance(self._engine, list):
            self._engine = self._engine[self.rank]

        def _get_comm_ranks_device_id():
            device_id = self.global_rank % torch.cuda.device_count()
            torch.cuda.set_device(device_id)
            # Make sure C++ executor would use same devices/ranks as py_executor
            global_rank = global_mpi_rank()
            comm_ranks = mpi_comm().allgather(global_rank)
            device_ids = mpi_comm().allgather(device_id)
            return comm_ranks, device_ids

        def _create_py_executor():
            args = {}
            assert hasattr(
                self.llm_args, "backend"
            ), "llm_args should be with backend in _create_py_executor"
            _ = _get_comm_ranks_device_id()
            if self.llm_args.backend == "pytorch":
                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                    create_py_executor
                create_executor = create_py_executor
                args["llm_args"] = self.llm_args
                args["checkpoint_dir"] = self._hf_model_dir
                args["tokenizer"] = self._tokenizer
                args["lora_config"] = self._lora_config
                args["kv_connector_config"] = self._kv_connector_config
            elif self.llm_args.backend == "_autodeploy":
                from tensorrt_llm._torch.auto_deploy.llm_args import \
                    LlmArgs as ADLlmArgs
                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                    create_autodeploy_executor
                create_executor = create_autodeploy_executor
                assert isinstance(self.llm_args, ADLlmArgs)
                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
            else:
                raise ValueError(
                    f"Unsupported backend config: {self.llm_args.backend}")

            # Define additional attributes that can be used later, such as in _deduce_max_tokens
            self.mapping = self.llm_args.parallel_config.to_mapping()
            self.checkpoint_loader = None
            if self.llm_args.backend == "pytorch":
                from tensorrt_llm._torch.pyexecutor.config import \
                    _construct_checkpoint_loader
                self.checkpoint_loader = _construct_checkpoint_loader(
                    self.llm_args.backend, self.llm_args.checkpoint_loader,
                    self.llm_args.checkpoint_format)

            _executor = create_executor(**args)
            self.max_seq_len = self.llm_args.max_seq_len
            if _executor.max_seq_len is not None:
                # max_seq_len might be updated by model engine as in create_py_executor
                self.max_seq_len = _executor.max_seq_len
            return _executor

        def _create_engine(executor_config):
            engine = self._engine
            if executor_config is None:
                executor_config = tllm.ExecutorConfig(1)
            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
                processor_batched=self._batched_logits_processor,
                replicate=False)
            comm_ranks, device_ids = _get_comm_ranks_device_id()
            executor_config.parallel_config = tllm.ParallelConfig(
                participant_ids=comm_ranks, device_ids=device_ids)

            if isinstance(engine, Engine):
                return tllm.Executor(engine.engine,
                                     json.dumps(engine.config.to_dict(),
                                                cls=ConfigEncoder),
                                     tllm.ModelType.DECODER_ONLY,
                                     executor_config=executor_config,
                                     managed_weights=engine.managed_weights)

            assert not hasattr(executor_config, "backend")
            return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
                                 executor_config)

        self.engine = _create_py_executor(
        ) if self.llm_args is not None else _create_engine(
            self._executor_config)

        self._lora_manager: Optional[LoraManager] = None
        self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
        self._runtime_model_config: Optional[ModelConfig] = None
        if self.rank == 0 and isinstance(self.engine, tllm.Executor):
            if isinstance(self.engine, Engine):
                engine_config = self.engine.config
            else:
                engine_config = EngineConfig.from_json_file(
                    f"{self._engine}/config.json")
            self._runtime_model_config = _engine_config_to_model_config(
                engine_config)
            if engine_config.build_config.plugin_config.lora_plugin:
                # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization
                # (see LoraManager constructor docstring). Getting the peft cache manager from this
                # point in the TRT flow is currently not supported (it's at the CPP
                # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
                # optimization is not available in TRT-python flow.
                self._lora_manager = LoraManager(cpp_peft_cache_manager=None)
            if engine_config.build_config.max_prompt_embedding_table_size > 0:
                self._prompt_adapter_manager = PromptAdapterManager()

        if self.llm_args and getattr(
                self.llm_args, "backend",
                "") == "pytorch" and self._lora_config is not None:
            from tensorrt_llm._torch.pyexecutor.resource_manager import \
                ResourceManagerType
            peft_cache_manager = self.engine.resource_manager.resource_managers.get(
                ResourceManagerType.PEFT_CACHE_MANAGER)
            self._lora_manager = LoraManager(
                cpp_peft_cache_manager=peft_cache_manager.impl)
            lora_model_config = self.engine.model_engine.lora_model_config
            assert lora_model_config is not None
            self._lora_model_config = lora_model_config

    def await_responses(self, timeout: Optional[float] = None) -> list:
        return self.engine.await_responses(timeout=datetime.timedelta(
            seconds=timeout) if timeout is not None else None)

    def fetch_stats(self) -> list:
        if isinstance(self.engine, tllm.Executor):
            iter_stats = self.engine.get_latest_iteration_stats()
            #TODO: Support req stats with TRT engine
            # This would require ensuring iter and req stats have same size
            return [(iter_stat, None) for iter_stat in iter_stats]
        else:
            return self.engine.get_latest_iteration_stats()

    def set_result_queue(self, queue):
        """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
        assert self.postproc_queues is None
        self.result_queue = queue

    def set_postproc_queues(self, queues: List["IpcQueue"]):
        """ Set the IPC queues for feeding post-processing processes. """
        assert self.result_queue is None
        self.postproc_queues = queues

    def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
                                    queue: Union[Queue, FusedIpcQueue,
                                                 IntraProcessQueue]):
        assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
        it_result_queue.is_initialized = True
        it_result_queue.queue = queue
        it_result_queue.aqueue = None

    def return_queue(self, client_id: int):
        """ If a centralized result queue is registered (used for communication with the proxy)
        send the message there.
        Otherwise, push the result directly in the GenerationResult queue.
        """
        if self.result_queue is not None:
            return self.result_queue
        return self._results[client_id].queue

    def abort_request(self, client_id: int) -> None:
        # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id
        if self.engine.can_enqueue_requests():
            request_id = self._client_id_to_request_id.get(client_id, None)
            if request_id is None:
                logger.warning(
                    f"Request of client_id {client_id} is finished, cannot abort it."
                )
                return
            self.engine.cancel_request(request_id)

    def _engine_response_callback(self, response: tllm.Response):
        return response

    def _has_background_error(self) -> bool:
        return not self._error_queue.empty()

    def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
        bck_error = self._error_queue.get_nowait()
        assert isinstance(bck_error, Exception)
        return ErrorResponse(response.client_id, str(bck_error),
                             response.request_id)

    def start(self):
        raise NotImplementedError(
            "start method is not implemented in BaseWorker")

    def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
        """Returns True if the adapter was loaded by this call, False if it was already loaded"""
        adapter_id = str(lora_request.adapter_id)
        newly_loaded_uids = self._lora_manager.load_from_ckpt(
            [lora_request.path],
            model_config=self._runtime_model_config if
            self._runtime_model_config is not None else self._lora_model_config,
            runtime_mapping=None,
            uids=[adapter_id],
            ckpt_source=lora_request.ckpt_source)
        return adapter_id in newly_loaded_uids

    def _load_prompt_adapter(self,
                             prompt_adapter_request: PromptAdapterRequest):
        self._prompt_adapter_manager.load_from_ckpt(
            [prompt_adapter_request.local_path],
            model_config=self._runtime_model_config,
            uids=[str(prompt_adapter_request.adapter_id)])

    def _enqueue_request(self, request: GenerationRequest) -> int:
        assert request.id is not None
        py_lora_path = None
        if self._lora_manager is not None and request.lora_request is not None:
            adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
                request.lora_request.adapter_id)
            self._load_lora_adapter(request.lora_request)
            uid = str(request.lora_request.adapter_id)
            lora_config = tllm.LoraConfig(
                task_id=request.lora_request.adapter_id,
                weights=self._lora_manager.cpp_lora_weights[uid]
                if not adapter_in_cache else None,
                config=self._lora_manager.cpp_lora_config[uid])
            py_lora_path = request.lora_request.lora_path
        else:
            lora_config = None

        prompt_token_ids = copy.deepcopy(request.prompt_token_ids)
        prompt_tuning_config = None
        if request.prompt_adapter_request is not None:
            self._load_prompt_adapter(request.prompt_adapter_request)
            uid = str(request.prompt_adapter_request.adapter_id)
            prompt_tuning_config = tllm.PromptTuningConfig(
                self._prompt_adapter_manager.uid_to_weights[uid])
            vocab_size = self._runtime_model_config.vocab_size
            pa_length = prompt_tuning_config.embedding_table.size(0)
            prompt_token_ids = list(range(
                vocab_size, vocab_size + pa_length)) + prompt_token_ids

        # MULTIMODAL
        # NOTE: Since, we only support PyTorch backend for multimodal, we will send multimodal_data through the 'py_multimodal_data' field
        # except `multimodal_input` as it needs to go through the C++ runtime.
        multimodal_input = None
        if request.multimodal_params is not None and request.multimodal_params.has_content(
        ):
            if request.multimodal_params.multimodal_input is not None:
                multimodal_input = tllm.MultimodalInput(
                    multimodal_hashes=request.multimodal_params.
                    multimodal_input.multimodal_hashes,
                    multimodal_positions=request.multimodal_params.
                    multimodal_input.multimodal_positions,
                    multimodal_lengths=request.multimodal_params.
                    multimodal_input.multimodal_lengths)
            # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field
            request.multimodal_params.multimodal_input = None

        context_phase_params = None
        request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
        if request.disaggregated_params is not None:
            assert (
                not self._is_pytorch_backend
                or self.engine.kv_cache_transceiver is not None
            ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:<backend_type>` in config file for disaggregated serving"
            request_type = request.disaggregated_params.get_request_type()
            if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
                context_phase_params = request.disaggregated_params.get_context_phase_params(
                )

        if self._is_pytorch_backend:
            if not self.llm_args.disable_overlap_scheduler:
                is_disaggregated = self.engine.kv_cache_transceiver is not None
                if is_disaggregated and (
                        request_type
                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
                    raise ValueError(
                        "Context only requests are not supported in pytorch backend when overlap is enabled."
                    )

        assert request.id is not None

        def _deduce_max_tokens(request: GenerationRequest,
                               executor_config: tllm.ExecutorConfig,
                               llm_args: Optional[BaseLlmArgs] = None) -> int:
            # deduce max_tokens when it's not set by user
            max_tokens = request.sampling_params.max_tokens
            query_token_len = len(
                request.query_token_ids) if request.query_token_ids else 0

            cp_size = 1
            max_seq_len = None
            if llm_args is not None:
                # deduce max_tokens by llm args
                assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
                if hasattr(self,
                           "mapping") and self.mapping.cp_size is not None:
                    cp_size = self.mapping.cp_size
                max_seq_len = getattr(self, "max_seq_len", None)
            else:
                # deduce max_tokens by executor config
                if hasattr(executor_config, "mapping"
                           ) and executor_config.mapping.cp_size is not None:
                    cp_size = executor_config.mapping.cp_size
                max_seq_len = getattr(executor_config, "max_seq_len", None)
            if max_seq_len is None:
                logger.warning("`default_max_tokens` cannot be deduced")
                if max_tokens is None:
                    raise ValueError(
                        "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                    )
                else:
                    # use max_tokens if can't deduce default_max_tokens
                    return max_tokens
            if executor_config is not None:
                assert (
                    len(prompt_token_ids) <= executor_config.max_seq_len
                ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
            splited_prompt_len = int(len(prompt_token_ids) / cp_size)
            default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
            if default_max_tokens <= 0:
                logger.warning(
                    f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
                    f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
                    f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
                )
                if max_tokens is None:
                    raise ValueError(
                        "`max_tokens` must be set when `default_max_tokens` is illegal"
                    )
            # default_max_tokens is the biggest available value
            if max_tokens is None:
                return default_max_tokens
            elif max_tokens > default_max_tokens:
                logger.warning(
                    f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
                    f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
                )
                return default_max_tokens
            return max_tokens

        try:
            executor_request = tllm.Request(
                client_id=request.id,
                input_token_ids=prompt_token_ids,
                max_tokens=_deduce_max_tokens(
                    request,
                    self._executor_config if not self.llm_args else None,
                    self.llm_args),
                streaming=request.streaming,
                sampling_config=request.sampling_params._get_sampling_config(),
                end_id=-1 if request.sampling_params.ignore_eos else
                request.sampling_params.end_id,
                pad_id=request.sampling_params.pad_id,
                output_config=request.sampling_params._get_output_config(
                    is_pytorch_backend=self._is_pytorch_backend),
                # Beam search enforces return_all_generated_tokens=True regardless of the passed value
                return_all_generated_tokens=False,
                # convert python config into pybind config
                lookahead_config=PybindMirror.maybe_to_pybind(
                    request.sampling_params.lookahead_config),
                guided_decoding_params=request.sampling_params.
                _get_guided_decoding_params(),
                bad_words=request.sampling_params._get_bad_words(),
                stop_words=request.sampling_params._get_stop_words(),
                embedding_bias=request.sampling_params.embedding_bias,
                lora_config=lora_config,
                prompt_tuning_config=prompt_tuning_config,
                multimodal_input=multimodal_input,
                # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`.
                multimodal_embedding=None,
                mrope_config=None,
                logits_post_processor_name=(
                    tllm.Request.BATCHED_POST_PROCESSOR_NAME
                    if request.sampling_params.apply_batched_logits_processor
                    else None),
                logits_post_processor=None if self._is_pytorch_backend else
                request.sampling_params.logits_processor,
                kv_cache_retention_config=request.kv_cache_retention_config,
                context_phase_params=context_phase_params,
                type=request_type,
                cache_salt_id=request.cache_salt_id)
            executor_request.py_lora_path = py_lora_path

            if self._is_pytorch_backend and request.multimodal_params is not None:
                if request.multimodal_params.multimodal_data is not None:
                    # NOTE: Deserialize SharedTensor handle to actual tensor
                    request.multimodal_params.to_tensor("multimodal_data")
                executor_request.py_multimodal_data = request.multimodal_params.multimodal_data

            if self._is_pytorch_backend and request.sampling_params.logits_processor:
                # For PyTorch backend, we attach logits processors as a dynamic Python attribute
                # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues.
                lp = request.sampling_params.logits_processor
                executor_request.py_logits_post_processors = lp if isinstance(
                    lp, list) else [lp]

            executor_request.py_scheduling_params = None
            if self._is_pytorch_backend and request.scheduling_params is not None:
                executor_request.py_scheduling_params = request.scheduling_params

            if request.arrival_time is not None:
                executor_request.py_arrival_time = request.arrival_time

            if request.query_token_ids is not None:
                # pytorch star attention workflow
                # a workaround to avoid public interface update
                req_id = self.engine.enqueue_request(executor_request,
                                                     request.query_token_ids)
            else:
                req_id = self.engine.enqueue_request(executor_request)
            return req_id
        except Exception as e:
            raise RequestError(str(e)) from e

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """ Low-level API to the executor. Return a "future" GenerationResult which can be waited. """
        self.start()

        if self.rank != 0:
            raise RuntimeError(
                "Only rank 0 can submit requests.\n"
                "To fix this, ensure that the llm.generate(...) method is "
                "guarded with the `if __name__ == '__main__':` block.")

        client_id = request.id if request.id is not None else self._get_next_client_id(
        )
        if request.id is None:
            request.set_id(client_id)

        logprob_params = self._get_logprob_params(request)

        result = GenerationResult(
            request,
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params,
            logprob_params=logprob_params)

        self._results[client_id] = result

        request_id = self._enqueue_request(request)
        # request_id returned from backend is necessary for the abort_request method.
        self._client_id_to_request_id[client_id] = request_id

        self._handle_background_error()

        return result

    def _pop_result(self, client_id: int):
        self._results.pop(client_id, None)
        self._client_id_to_request_id.pop(client_id, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        self.shutdown()
        return exc_type is None or exc_type == BaseWorker.WorkerExit

    def __del__(self):
        self.shutdown()


class AwaitResponseHelper:
    ''' Multiple-implementations for await_response for performance. '''

    class HandlerKind(enum.Enum):
        unknown = 0
        single_process_worker = 1
        ipc_batched = 2

    def __init__(self, worker: "BaseWorker"):
        # TODO: make worker weakref
        self.worker = worker
        self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown
        self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel
        # The error responses when submit request failed will be put here
        self.temp_error_responses = Queue()

    def responses_handler(self, responses: List[tllm.Response]):
        HandlerKind = AwaitResponseHelper.HandlerKind

        if self.handler_kind is HandlerKind.unknown:
            if not (self.worker.result_queue is not None
                    or self.worker.postproc_queues is not None):
                print_colored_debug(
                    f"creating await_response helper for Worker\n",
                    color="yellow")
                # When ExecutorBindingWorker is used in the main process
                # aka the single process mode
                self.handler_kind = HandlerKind.single_process_worker
            elif self.worker.result_queue is not None or self.worker.postproc_queues is not None:
                # The ExecutorBindingProxy is used
                print_colored_debug(f"creating await_response helper for IPC\n",
                                    color="yellow")
                self.handler_kind = HandlerKind.ipc_batched
            else:
                raise NotImplementedError

        match self.handler_kind:
            case HandlerKind.single_process_worker:
                return self.handle_for_worker(responses)
            case HandlerKind.ipc_batched:
                return self.handle_for_ipc_batched(responses)
            case _:
                raise NotImplementedError

    def __call__(self, timeout: Optional[float] = None) -> bool:
        ''' This method should be called by a ManagedThread. '''
        timeout = timeout or 0.1
        responses = self.worker.engine.await_responses(
            timeout=datetime.timedelta(seconds=timeout))
        # filter since The _engine_response_callback may return None
        responses = list(
            filter(
                lambda _: _,
                [self.worker._engine_response_callback(r) for r in responses]))

        # append the error responses to the temp_error_responses
        while not self.temp_error_responses.empty():
            responses.append(self.temp_error_responses.get())

        with nvtx_range_debug(f"await_response-{len(responses)}",
                              color="red",
                              category="Worker"):
            self.responses_handler(responses)
        return True

    def handle_for_worker(self, responses: List[tllm.Response]) -> None:
        ''' Return the responses to asyncio.event_loop. '''
        event_loop = None
        async_queues = []
        for response in responses:
            assert response is not None
            queue = self.worker.return_queue(response.client_id)

            response = _maybe_wrap_response(self.worker, response,
                                            self.worker._is_pytorch_backend)

            # For AsyncQueue.sync_q, we will batch the events to avoid too many
            # event notifications, thus put without wait here.
            if isinstance(queue, _SyncQueue):
                global_tracer().log_instant("worker-rsp.put")
                queue.put_nowait(response)
                async_queues.append(queue)
                # all the loops are identical
                event_loop = event_loop or queue.loop
            else:
                queue.put(response)

            if response.has_error() or response.result.is_final:
                self.worker._pop_result(response.client_id)

        # Notify the events in bulk for performance.
        if async_queues:
            _SyncQueue.notify_many(event_loop, async_queues)

    def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
        ''' Perform the IPC in batch explicitly. '''
        postproc_batches = [
            []
            for _ in range(self.worker.postproc_config.num_postprocess_workers)
        ] if self.enable_postprocprocess_parallel else None
        rsp_batch = [] if not self.enable_postprocprocess_parallel else None

        for response in responses:

            if isinstance(response, ErrorResponse):
                pass  # send ErrorResponse directly
            elif self.worker._has_background_error():
                response = self.worker._create_error_response(response)
            elif response.has_error():
                # Convert to ErrorResponse, because tllm.Response cannot be
                # serialized when it has error.
                response = ErrorResponse(response.client_id, response.error_msg,
                                         response.request_id)
            else:
                response = _maybe_wrap_response(self.worker, response,
                                                self.worker._is_pytorch_backend)

            _send_rsp(self.worker,
                      response,
                      postproc_batches=postproc_batches,
                      rsp_batch=rsp_batch)

        if postproc_batches:
            for wid, batch in enumerate(postproc_batches):
                if batch:
                    self.worker.postproc_queues[wid].put(batch)

        if rsp_batch:
            self.worker.result_queue.put(rsp_batch)


def _get_params_for_first_rsp(
        worker,
        client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]:
    res = worker._results.get(client_id, None)
    assert res is not None
    if not res._params_transmitted:
        res._params_transmitted = True
        return res.sampling_params, res.postproc_params
    return None, None


def _get_logprobs(worker,
                  response: tllm.Response,
                  is_pytorch_backend=False) -> Optional[LogProbsResult]:
    """Compute logprob and prompt logprob and clear out logits if applicable.
    """
    if is_pytorch_backend:
        # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime.
        # In the PyTorch backend, logprobs are already computed during runtime if requested.
        return None

    logprobs_result = None
    generation_result = worker._results.get(response.client_id, None)

    if not generation_result:
        return

    logprob_params = getattr(generation_result, "_logprob_params", None)
    if logprob_params:
        logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                           logprob_params.logprobs,
                                           response.result.context_logits,
                                           response.result.generation_logits,
                                           response.result.output_token_ids[0])

        if logprob_params.drop_context_logits:
            response.clear_context_logits()

        if logprob_params.drop_generation_logits:
            response.clear_generation_logits()

    if response.result.is_final:
        generation_result.clear_logprob_params()

    return logprobs_result


def _send_rsp(
        worker,
        response: Union[tllm.Response, ResponseWrapper, ErrorResponse],
        postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None,
        rsp_batch: Optional[List[tllm.Response]] = None):
    # if postproc_batches is set, append to batch instead of putting to IpcQueue

    if worker.result_queue is not None:
        if rsp_batch is not None:
            rsp_batch.append(response)
        else:
            worker.result_queue.put(response)
    else:
        sampling_params, postproc_params = _get_params_for_first_rsp(
            worker, response.client_id)
        inp = PostprocWorker.Input(
            response,
            # sampling_params is necessary for creating fake GenerationResult
            # instances in the postproc processes. They are for incremental
            # detokenize. They should be transmitted only once for each
            # Request.
            sampling_params=sampling_params,
            postproc_params=postproc_params,
            streaming=worker._results.get(response.client_id, None)._streaming)

        pid = response.client_id % worker.postproc_config.num_postprocess_workers

        if not postproc_batches:
            # Group the responses into buckets for the postprocessing steps.
            # Bucketing is used instead of random dispatching because the
            # incremental detokenization during postprocessing relies on the
            # prior CompletionOutput of a given request.
            worker.postproc_queues[pid].put(inp)
        else:
            postproc_batches[pid].append(inp)

    # Eliminate the finished GenerationRequest instances timely, which may
    # take considerable memory.
    if is_llm_response(response):
        if response.has_error() or response.result.is_final:
            worker._pop_result(response.client_id)
    elif isinstance(response, ErrorResponse):
        worker._pop_result(response.client_id)
    else:
        raise ValueError(f"Unknown response type: {response}")


def _get_metrics_dict(
        response: tllm.Response) -> dict[RequestEventTiming, float]:
    req_perf_metrics, metrics_dict = None, {}
    res = response.result
    if res:
        if hasattr(res, '_result'):
            if result := res.get_result():
                req_perf_metrics = result.request_perf_metrics
        else:
            req_perf_metrics = res.request_perf_metrics
        if req_perf_metrics and req_perf_metrics.timing_metrics:
            metrics_dict = {
                RequestEventTiming.ARRIVAL_TIME:
                req_perf_metrics.timing_metrics.arrival_time.total_seconds(),
                RequestEventTiming.FIRST_TOKEN_TIME:
                req_perf_metrics.timing_metrics.first_token_time.total_seconds(
                ),
                RequestEventTiming.FIRST_SCHEDULED_TIME:
                req_perf_metrics.timing_metrics.first_scheduled_time.
                total_seconds(),
                RequestEventTiming.LAST_TOKEN_TIME:
                req_perf_metrics.timing_metrics.last_token_time.total_seconds()
            }
    return metrics_dict


def _maybe_wrap_response(
        worker,
        response: tllm.Response,
        is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:

    logprobs_result = _get_logprobs(worker, response, is_pytorch_backend)
    req_perf_metrics = _get_metrics_dict(response)
    if logprobs_result or req_perf_metrics:
        response = ResponseWrapper(response, logprobs_result, req_perf_metrics)
    return response
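For orientation before the worker.py hunks below: the request lifecycle implemented above is submit() assigning a client id, building a GenerationResult, and calling _enqueue_request(), which translates the GenerationRequest into a tllm.Request and hands it to the engine; responses come back through AwaitResponseHelper, which either feeds them to local result queues or batches them over IPC via _send_rsp(). The snippet below is a hedged sketch of how calling code might drive a worker; it is not part of the diff, the concrete worker class and its constructor arguments are hypothetical, and only submit(), await_responses(), and _pop_result() are names defined above.

# Illustrative sketch, not part of this commit.
worker = SomeWorkerSubclass(engine_path)   # hypothetical concrete worker
result = worker.submit(prepared_request)   # returns a "future" GenerationResult
# AwaitResponseHelper, driven from the subclass's dispatch thread, keeps
# calling engine.await_responses() and routes each response into result's
# queue until response.result.is_final, after which the worker drops its
# bookkeeping via _pop_result().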
tensorrt_llm/executor/worker.py
@@ -1,6 +1,3 @@
import copy
import datetime
import enum
import json
import os
import time
@@ -8,48 +5,40 @@ import traceback
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from queue import Queue
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import torch
import zmq

from tensorrt_llm.logger import logger

from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size,
                      mpi_comm, mpi_rank, nvtx_range_debug)
from .._utils import KVCacheEventSerializer, mpi_comm, mpi_rank
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror
from ..builder import Engine
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
from ..llmapi.tracer import VizTracer, set_global_tracer
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                            clear_sched_affinity, print_colored_debug,
                            print_traceback_on_error)
from ..lora_helper import LoraConfig
from ..lora_manager import LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
from ..sampling_params import BatchedLogitsProcessor, SamplingParams
from .executor import GenerationExecutor, IterationResultQueue
from ..sampling_params import BatchedLogitsProcessor
from .base_worker import BaseWorker
from .executor import IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocParams, PostprocWorker,
                              PostprocWorkerConfig, postproc_worker_main)
from .request import (CancellingRequest, GenerationRequest, LoRARequest,
                      PromptAdapterRequest)
from .result import (GenerationResult, IterationResult, LogProbsResult,
                     ResponseWrapper, compute_logprobs)
from .utils import (ErrorResponse, IntraProcessQueue, RequestError,
                    WorkerCommIpcAddrs, has_event_loop, is_llm_response)
from .postproc_worker import (PostprocWorker, PostprocWorkerConfig,
                              postproc_worker_main)
from .request import CancellingRequest, GenerationRequest
from .result import IterationResult
from .utils import (ErrorResponse, RequestError, WorkerCommIpcAddrs,
                    has_event_loop)

__all__ = [
    "GenerationExecutorWorker",
]


class GenerationExecutorWorker(GenerationExecutor):
class GenerationExecutorWorker(BaseWorker):

    class WorkerExit(GeneratorExit):
        pass

@@ -67,150 +56,20 @@ class GenerationExecutorWorker(GenerationExecutor):
        tokenizer: Optional[TokenizerBase] = None,
        llm_args: Optional[BaseLlmArgs] = None,
    ) -> None:
        postproc_config = postproc_worker_config or PostprocWorkerConfig()
        super().__init__(
            num_postprocess_workers=postproc_config.num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir,
            engine=engine,
            executor_config=executor_config,
            batched_logits_processor=batched_logits_processor,
            postproc_worker_config=postproc_worker_config,
            is_llm_executor=is_llm_executor,
            lora_config=lora_config,
            kv_connector_config=kv_connector_config,
            hf_model_dir=hf_model_dir,
            tokenizer=tokenizer,
            llm_args=llm_args,
        )

        self.engine = None
        self.result_queue: Optional[IpcQueue] = None
        self.postproc_queues: Optional[List[IpcQueue]] = None
        self.rank = mpi_rank()
        self.global_rank = global_mpi_rank()
        # mapping: client_id -> GenerationResult
        self._results: Dict[int, GenerationResult] = {}
        # mapping: client_id from Proxy -> request_id returned from runtime backend
        self._client_id_to_request_id: Dict[int, int] = {}
        self._await_response_helper = AwaitResponseHelper(
            self)  # TODO: make it weakref
        self._executor_config = executor_config
        self._is_pytorch_backend = llm_args is not None and llm_args.backend in [
            "pytorch", "_autodeploy"
        ]
        self.llm_args = llm_args

        if not self._is_pytorch_backend and kv_connector_config is not None:
            raise ValueError(
                "KV connector config is only supported for PyTorch backend")

        if global_mpi_size() > 1:
            logger.set_rank(self.global_rank)

        if isinstance(engine, list):
            engine = engine[self.rank]

        def _get_comm_ranks_device_id():
            device_id = self.global_rank % torch.cuda.device_count()
            torch.cuda.set_device(device_id)
            # Make sure C++ executor would use same devices/ranks as py_executor
            global_rank = global_mpi_rank()
            comm_ranks = mpi_comm().allgather(global_rank)
            device_ids = mpi_comm().allgather(device_id)
            return comm_ranks, device_ids

        def _create_py_executor():
            args = {}
            assert hasattr(
                self.llm_args, "backend"
            ), "llm_args should be with backend in _create_py_executor"
            _ = _get_comm_ranks_device_id()
            if self.llm_args.backend == "pytorch":
                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                    create_py_executor
                create_executor = create_py_executor
                args["llm_args"] = self.llm_args
                args["checkpoint_dir"] = hf_model_dir
                args["tokenizer"] = tokenizer
                args["lora_config"] = lora_config
                args["kv_connector_config"] = kv_connector_config
            elif self.llm_args.backend == "_autodeploy":
                from tensorrt_llm._torch.auto_deploy.llm_args import \
                    LlmArgs as ADLlmArgs
                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                    create_autodeploy_executor
                create_executor = create_autodeploy_executor
                assert isinstance(self.llm_args, ADLlmArgs)
                args["ad_config"] = self.llm_args.get_pytorch_backend_config()
            else:
                raise ValueError(
                    f"Unsupported backend config: {self.llm_args.backend}")

            # Define additional attributes that can be used later, such as in _deduce_max_tokens
            self.mapping = self.llm_args.parallel_config.to_mapping()
            self.checkpoint_loader = None
            if self.llm_args.backend == "pytorch":
                from tensorrt_llm._torch.pyexecutor.config import \
                    _construct_checkpoint_loader
                self.checkpoint_loader = _construct_checkpoint_loader(
                    self.llm_args.backend, self.llm_args.checkpoint_loader,
                    self.llm_args.checkpoint_format)

            _executor = create_executor(**args)
            self.max_seq_len = self.llm_args.max_seq_len
            if _executor.max_seq_len is not None:
                # max_seq_len might be updated by model engine as in create_py_executor
                self.max_seq_len = _executor.max_seq_len
            return _executor

        def _create_engine(executor_config):
            if executor_config is None:
                executor_config = tllm.ExecutorConfig(1)
            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
                processor_batched=batched_logits_processor, replicate=False)
            comm_ranks, device_ids = _get_comm_ranks_device_id()
            executor_config.parallel_config = tllm.ParallelConfig(
                participant_ids=comm_ranks, device_ids=device_ids)

            if isinstance(engine, Engine):
                return tllm.Executor(engine.engine,
                                     json.dumps(engine.config.to_dict(),
                                                cls=ConfigEncoder),
                                     tllm.ModelType.DECODER_ONLY,
                                     executor_config=executor_config,
                                     managed_weights=engine.managed_weights)

            assert not hasattr(executor_config, "backend")
            return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
                                 executor_config)

        self.engine = _create_py_executor(
        ) if self.llm_args is not None else _create_engine(executor_config)

        self._lora_manager: Optional[LoraManager] = None
        self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
        self._runtime_model_config: Optional[ModelConfig] = None
        if self.rank == 0 and isinstance(self.engine, tllm.Executor):
            if isinstance(engine, Engine):
                engine_config = engine.config
            else:
                engine_config = EngineConfig.from_json_file(
                    f"{engine}/config.json")
            self._runtime_model_config = _engine_config_to_model_config(
                engine_config)
            if engine_config.build_config.plugin_config.lora_plugin:
                # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization
                # (see LoraManager constructor docstring). Getting the peft cache manager from this
                # point in the TRT flow is currently not supported (it's at the CPP
                # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
                # optimization is not available in TRT-python flow.
                self._lora_manager = LoraManager(cpp_peft_cache_manager=None)
            if engine_config.build_config.max_prompt_embedding_table_size > 0:
                self._prompt_adapter_manager = PromptAdapterManager()

        if self.llm_args and getattr(
                self.llm_args, "backend",
                "") == "pytorch" and lora_config is not None:
            from tensorrt_llm._torch.pyexecutor.resource_manager import \
                ResourceManagerType
            peft_cache_manager = self.engine.resource_manager.resource_managers.get(
                ResourceManagerType.PEFT_CACHE_MANAGER)
            self._lora_manager = LoraManager(
                cpp_peft_cache_manager=peft_cache_manager.impl)
            lora_model_config = self.engine.model_engine.lora_model_config
            assert lora_model_config is not None
            self._lora_model_config = lora_model_config
        self.setup_engine()

        self.await_response_thread = ManagedThread(
            self.await_response_task,
@@ -227,16 +86,6 @@ class GenerationExecutorWorker(GenerationExecutor):
            error_queue=self._error_queue,
            name="dispatch_kv_cache_events_thread")

    def set_result_queue(self, queue):
        """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
        assert self.postproc_queues is None
        self.result_queue = queue

    def set_postproc_queues(self, queues: List["IpcQueue"]):
        """ Set the IPC queues for feeding post-processing processes. """
        assert self.result_queue is None
        self.postproc_queues = queues

    def _create_iteration_result_queue(self,
                                       it_result_queue: IterationResultQueue):
        if not it_result_queue.is_initialized:
@@ -251,53 +100,13 @@ class GenerationExecutorWorker(GenerationExecutor):
            it_result_queue.queue = _queue
            it_result_queue.aqueue = None

    def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
                                    queue: Union[Queue, FusedIpcQueue,
                                                 IntraProcessQueue]):
        assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
        it_result_queue.is_initialized = True
        it_result_queue.queue = queue
        it_result_queue.aqueue = None

    def return_queue(self, client_id: int):
        """ If a centralized result queue is registered (used for communication with the proxy)
        send the message there.
        Otherwise, push the result directly in the GenerationResult queue.
        """
        if self.result_queue is not None:
            return self.result_queue
        return self._results[client_id].queue

    def start_thread(self, thread: ManagedThread):
        if self.engine.can_enqueue_requests() and not thread.is_alive():
            thread.start()

    def abort_request(self, client_id: int) -> None:
        # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id
        if self.engine.can_enqueue_requests():
            request_id = self._client_id_to_request_id.get(client_id, None)
            if request_id is None:
                logger.warning(
                    f"Request of client_id {client_id} is finished, cannot abort it."
                )
                return
            self.engine.cancel_request(request_id)

    def _engine_response_callback(self, response: tllm.Response):
        return response

    def await_response_task(self) -> bool:
        return self._await_response_helper()

    def _has_background_error(self) -> bool:
        return not self._error_queue.empty()

    def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
        bck_error = self._error_queue.get_nowait()
        assert isinstance(bck_error, Exception)
        return ErrorResponse(response.client_id, str(bck_error),
                             response.request_id)

    def _iteration_result_task(self, it_result_queue: IterationResultQueue,
                               engine_get_result_api: Callable,
                               result_singleton: IterationResult,
@@ -350,16 +159,7 @@ class GenerationExecutorWorker(GenerationExecutor):
            # Convert back to JSON string
            return json.dumps(stats_dict)

        def get_stats():
            if isinstance(self.engine, tllm.Executor):
                iter_stats = self.engine.get_latest_iteration_stats()
                #TODO: Support req stats with TRT engine
                # This would require ensuring iter and req stats have same size
                return [(iter_stat, None) for iter_stat in iter_stats]
            else:
                return self.engine.get_latest_iteration_stats()

        return self._iteration_result_task(self.stats_queues, get_stats,
        return self._iteration_result_task(self.stats_queues, self.fetch_stats,
                                           self._iter_stats_result,
                                           stats_serializer)

@@ -392,264 +192,6 @@ class GenerationExecutorWorker(GenerationExecutor):
        if mpi_rank() == 0:
            self.start_thread(self.dispatch_stats_thread)

    def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
        """Returns True if the adapter was loaded by this call, False if it was already loaded"""
        adapter_id = str(lora_request.adapter_id)
        newly_loaded_uids = self._lora_manager.load_from_ckpt(
            [lora_request.path],
            model_config=self._runtime_model_config if
            self._runtime_model_config is not None else self._lora_model_config,
            runtime_mapping=None,
            uids=[adapter_id],
            ckpt_source=lora_request.ckpt_source)
        return adapter_id in newly_loaded_uids

    def _load_prompt_adapter(self,
                             prompt_adapter_request: PromptAdapterRequest):
        self._prompt_adapter_manager.load_from_ckpt(
            [prompt_adapter_request.local_path],
            model_config=self._runtime_model_config,
            uids=[str(prompt_adapter_request.adapter_id)])

    def _enqueue_request(self, request: GenerationRequest) -> int:
        assert request.id is not None
        py_lora_path = None
        if self._lora_manager is not None and request.lora_request is not None:
            adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
                request.lora_request.adapter_id)
            self._load_lora_adapter(request.lora_request)
            uid = str(request.lora_request.adapter_id)
            lora_config = tllm.LoraConfig(
                task_id=request.lora_request.adapter_id,
                weights=self._lora_manager.cpp_lora_weights[uid]
                if not adapter_in_cache else None,
                config=self._lora_manager.cpp_lora_config[uid])
            py_lora_path = request.lora_request.lora_path
        else:
            lora_config = None

        prompt_token_ids = copy.deepcopy(request.prompt_token_ids)
        prompt_tuning_config = None
        if request.prompt_adapter_request is not None:
            self._load_prompt_adapter(request.prompt_adapter_request)
            uid = str(request.prompt_adapter_request.adapter_id)
            prompt_tuning_config = tllm.PromptTuningConfig(
                self._prompt_adapter_manager.uid_to_weights[uid])
            vocab_size = self._runtime_model_config.vocab_size
            pa_length = prompt_tuning_config.embedding_table.size(0)
            prompt_token_ids = list(range(
                vocab_size, vocab_size + pa_length)) + prompt_token_ids

        # MULTIMODAL
        # NOTE: Since, we only support PyTorch backend for multimodal, we will send multimodal_data through the 'py_multimodal_data' field
        # except `multimodal_input` as it needs to go through the C++ runtime.
        multimodal_input = None
        if request.multimodal_params is not None and request.multimodal_params.has_content(
        ):
            if request.multimodal_params.multimodal_input is not None:
                multimodal_input = tllm.MultimodalInput(
                    multimodal_hashes=request.multimodal_params.
                    multimodal_input.multimodal_hashes,
                    multimodal_positions=request.multimodal_params.
                    multimodal_input.multimodal_positions,
                    multimodal_lengths=request.multimodal_params.
                    multimodal_input.multimodal_lengths)
            # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field
            request.multimodal_params.multimodal_input = None

        context_phase_params = None
        request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
        if request.disaggregated_params is not None:
            assert (
                not self._is_pytorch_backend
                or self.engine.kv_cache_transceiver is not None
            ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:<backend_type>` in config file for disaggregated serving"
            request_type = request.disaggregated_params.get_request_type()
            if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
                context_phase_params = request.disaggregated_params.get_context_phase_params(
                )

        if self._is_pytorch_backend:
            if not self.llm_args.disable_overlap_scheduler:
                is_disaggregated = self.engine.kv_cache_transceiver is not None
                if is_disaggregated and (
                        request_type
                        == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY):
                    raise ValueError(
                        "Context only requests are not supported in pytorch backend when overlap is enabled."
                    )

        assert request.id is not None

def _deduce_max_tokens(request: GenerationRequest,
|
||||
executor_config: tllm.ExecutorConfig,
|
||||
llm_args: Optional[BaseLlmArgs] = None) -> int:
|
||||
# deduce max_tokens when it's not set by user
|
||||
max_tokens = request.sampling_params.max_tokens
|
||||
query_token_len = len(
|
||||
request.query_token_ids) if request.query_token_ids else 0
|
||||
|
||||
cp_size = 1
|
||||
max_seq_len = None
|
||||
if llm_args is not None:
|
||||
# deduce max_tokens by llm args
|
||||
assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
|
||||
if hasattr(self,
|
||||
"mapping") and self.mapping.cp_size is not None:
|
||||
cp_size = self.mapping.cp_size
|
||||
max_seq_len = getattr(self, "max_seq_len", None)
|
||||
else:
|
||||
# deduce max_tokens by executor config
|
||||
if hasattr(executor_config, "mapping"
|
||||
) and executor_config.mapping.cp_size is not None:
|
||||
cp_size = executor_config.mapping.cp_size
|
||||
max_seq_len = getattr(executor_config, "max_seq_len", None)
|
||||
if max_seq_len is None:
|
||||
logger.warning("`default_max_tokens` cannot be deduced")
|
||||
if max_tokens is None:
|
||||
raise ValueError(
|
||||
"`max_tokens` must be set when `default_max_tokens` cannot be deduced"
|
||||
)
|
||||
else:
|
||||
# use max_tokens if can't deduce default_max_tokens
|
||||
return max_tokens
|
||||
if executor_config is not None:
|
||||
assert (
|
||||
len(prompt_token_ids) <= executor_config.max_seq_len
|
||||
), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
|
||||
splited_prompt_len = int(len(prompt_token_ids) / cp_size)
|
||||
default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
|
||||
if default_max_tokens <= 0:
|
||||
logger.warning(
|
||||
f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
|
||||
f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
|
||||
f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
|
||||
)
|
||||
if max_tokens is None:
|
||||
raise ValueError(
|
||||
"`max_tokens` must be set when `default_max_tokens` is illegal"
|
||||
)
|
||||
# default_max_tokens is the biggest available value
|
||||
if max_tokens is None:
|
||||
return default_max_tokens
|
||||
elif max_tokens > default_max_tokens:
|
||||
logger.warning(
|
||||
f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
|
||||
f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
|
||||
)
|
||||
return default_max_tokens
|
||||
return max_tokens
|
||||
|
        try:
            executor_request = tllm.Request(
                client_id=request.id,
                input_token_ids=prompt_token_ids,
                max_tokens=_deduce_max_tokens(request, self._executor_config,
                                              self.llm_args),
                streaming=request.streaming,
                sampling_config=request.sampling_params._get_sampling_config(),
                end_id=-1 if request.sampling_params.ignore_eos else
                request.sampling_params.end_id,
                pad_id=request.sampling_params.pad_id,
                output_config=request.sampling_params._get_output_config(
                    is_pytorch_backend=self._is_pytorch_backend),
                # Beam search enforces return_all_generated_tokens=True regardless of the passed value
                return_all_generated_tokens=False,
                # convert python config into pybind config
                lookahead_config=PybindMirror.maybe_to_pybind(
                    request.sampling_params.lookahead_config),
                guided_decoding_params=request.sampling_params.
                _get_guided_decoding_params(),
                bad_words=request.sampling_params._get_bad_words(),
                stop_words=request.sampling_params._get_stop_words(),
                embedding_bias=request.sampling_params.embedding_bias,
                lora_config=lora_config,
                prompt_tuning_config=prompt_tuning_config,
                multimodal_input=multimodal_input,
                # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`.
                multimodal_embedding=None,
                mrope_config=None,
                logits_post_processor_name=(
                    tllm.Request.BATCHED_POST_PROCESSOR_NAME
                    if request.sampling_params.apply_batched_logits_processor
                    else None),
                logits_post_processor=None if self._is_pytorch_backend else
                request.sampling_params.logits_processor,
                kv_cache_retention_config=request.kv_cache_retention_config,
                context_phase_params=context_phase_params,
                type=request_type,
                cache_salt_id=request.cache_salt_id)
            executor_request.py_lora_path = py_lora_path

            if self._is_pytorch_backend and request.multimodal_params is not None:
                if request.multimodal_params.multimodal_data is not None:
                    # NOTE: Deserialize SharedTensor handle to actual tensor
                    request.multimodal_params.to_tensor("multimodal_data")
                executor_request.py_multimodal_data = request.multimodal_params.multimodal_data

            if self._is_pytorch_backend and request.sampling_params.logits_processor:
                # For PyTorch backend, we attach logits processors as a dynamic Python attribute
                # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues.
                lp = request.sampling_params.logits_processor
                executor_request.py_logits_post_processors = lp if isinstance(
                    lp, list) else [lp]

            executor_request.py_scheduling_params = None
            if self._is_pytorch_backend and request.scheduling_params is not None:
                executor_request.py_scheduling_params = request.scheduling_params

            if request.arrival_time is not None:
                executor_request.py_arrival_time = request.arrival_time

            if request.query_token_ids is not None:
                # pytorch star attention workflow
                # a workaround to avoid public interface update
                req_id = self.engine.enqueue_request(executor_request,
                                                     request.query_token_ids)
            else:
                req_id = self.engine.enqueue_request(executor_request)
            return req_id
        except Exception as e:
            raise RequestError(str(e)) from e

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """ Low-level API to the executor. Return a "future" GenerationResult which can be waited. """
        self.start()

        if self.rank != 0:
            raise RuntimeError(
                "Only rank 0 can submit requests.\n"
                "To fix this, ensure that the llm.generate(...) method is "
                "guarded with the `if __name__ == '__main__':` block.")

        client_id = request.id if request.id is not None else self._get_next_client_id(
        )
        if request.id is None:
            request.set_id(client_id)

        logprob_params = self._get_logprob_params(request)

        result = GenerationResult(
            request,
            background_error_handler=self._handle_background_error,
            executor=self,
            disaggregated_params=request.disaggregated_params,
            logprob_params=logprob_params)

        self._results[client_id] = result

        request_id = self._enqueue_request(request)
        # request_id returned from backend is necessary for the abort_request method.
        self._client_id_to_request_id[client_id] = request_id

        self._handle_background_error()

        return result

    def _pop_result(self, client_id: int):
        self._results.pop(client_id, None)
        self._client_id_to_request_id.pop(client_id, None)

    def shutdown(self):

        if self.doing_shutdown:
@ -705,16 +247,6 @@ class GenerationExecutorWorker(GenerationExecutor):
        if isinstance(self.engine, PyExecutor):
            self.engine.wait_shutdown()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        self.shutdown()
        return exc_type is None or exc_type == GenerationExecutorWorker.WorkerExit

    def __del__(self):
        self.shutdown()


@print_traceback_on_error
def worker_main(
@ -924,265 +456,3 @@ def worker_main(
        logger.error(traceback.format_exc())
        # This will be captured by mpi4py and handled by future.done_callback
        raise e


class AwaitResponseHelper:
    ''' Multiple-implementations for await_response for performance. '''

    class HandlerKind(enum.Enum):
        unknown = 0
        single_process_worker = 1
        ipc_batched = 2

    def __init__(self, worker: "GenerationExecutorWorker"):
        # TODO: make worker weakref
        self.worker = worker
        self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown
        self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel
        # The error responses when submit request failed will be put here
        self.temp_error_responses = Queue()

    def responses_handler(self, responses: List[tllm.Response]):
        HandlerKind = AwaitResponseHelper.HandlerKind

        if self.handler_kind is HandlerKind.unknown:
            if not (self.worker.result_queue is not None
                    or self.worker.postproc_queues is not None):
                print_colored_debug(
                    f"creating await_response helper for Worker\n",
                    color="yellow")
                # When ExecutorBindingWorker is used in the main process
                # aka the single process mode
                self.handler_kind = HandlerKind.single_process_worker
            elif self.worker.result_queue is not None or self.worker.postproc_queues is not None:
                # The ExecutorBindingProxy is used
                print_colored_debug(f"creating await_response helper for IPC\n",
                                    color="yellow")
                self.handler_kind = HandlerKind.ipc_batched
            else:
                raise NotImplementedError

        match self.handler_kind:
            case HandlerKind.single_process_worker:
                return self.handle_for_worker(responses)
            case HandlerKind.ipc_batched:
                return self.handle_for_ipc_batched(responses)
            case _:
                raise NotImplementedError

    def __call__(self) -> bool:
        ''' This method should be called by a ManagedThread. '''
        responses = self.worker.engine.await_responses(
            timeout=datetime.timedelta(milliseconds=100))
        # filter since The _engine_response_callback may return None
        responses = list(
            filter(
                lambda _: _,
                [self.worker._engine_response_callback(r) for r in responses]))

        # append the error responses to the temp_error_responses
        while not self.temp_error_responses.empty():
            responses.append(self.temp_error_responses.get())

        with nvtx_range_debug(f"await_response-{len(responses)}",
                              color="red",
                              category="Worker"):
            self.responses_handler(responses)
        return True

    def handle_for_worker(self, responses: List[tllm.Response]) -> None:
        ''' Return the responses to asyncio.event_loop. '''
        event_loop = None
        async_queues = []
        for response in responses:
            assert response is not None
            queue = self.worker.return_queue(response.client_id)

            response = _maybe_wrap_response(self.worker, response,
                                            self.worker._is_pytorch_backend)

            # For AsyncQueue.sync_q, we will batch the events to avoid too many
            # event notifications, thus put without wait here.
            if isinstance(queue, _SyncQueue):
                global_tracer().log_instant("worker-rsp.put")
                queue.put_nowait(response)
                async_queues.append(queue)
                # all the loops are identical
                event_loop = event_loop or queue.loop
            else:
                queue.put(response)

            if response.has_error() or response.result.is_final:
                self.worker._pop_result(response.client_id)

        # Notify the events in bulk for performance.
        if async_queues:
            _SyncQueue.notify_many(event_loop, async_queues)

    def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None:
        ''' Perform the IPC in batch explicitly. '''
        postproc_batches = [
            []
            for _ in range(self.worker.postproc_config.num_postprocess_workers)
        ] if self.enable_postprocprocess_parallel else None
        rsp_batch = [] if not self.enable_postprocprocess_parallel else None

        for response in responses:

            if isinstance(response, ErrorResponse):
                pass  # send ErrorResponse directly
            elif self.worker._has_background_error():
                response = self.worker._create_error_response(response)
            elif response.has_error():
                # Convert to ErrorResponse, because tllm.Response cannot be
                # serialized when it has error.
                response = ErrorResponse(response.client_id, response.error_msg,
                                         response.request_id)
            else:
                response = _maybe_wrap_response(self.worker, response,
                                                self.worker._is_pytorch_backend)

            _send_rsp(self.worker,
                      response,
                      postproc_batches=postproc_batches,
                      rsp_batch=rsp_batch)

        if postproc_batches:
            for wid, batch in enumerate(postproc_batches):
                if batch:
                    self.worker.postproc_queues[wid].put(batch)

        if rsp_batch:
            self.worker.result_queue.put(rsp_batch)


def _get_params_for_first_rsp(
        worker,
        client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]:
    res = worker._results.get(client_id, None)
    assert res is not None
    if not res._params_transmitted:
        res._params_transmitted = True
        return res.sampling_params, res.postproc_params
    return None, None


def _get_logprobs(worker,
                  response: tllm.Response,
                  is_pytorch_backend=False) -> Optional[LogProbsResult]:
    """Compute logprob and prompt logprob and clear out logits if applicable.
    """
    if is_pytorch_backend:
        # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime.
        # In the PyTorch backend, logprobs are already computed during runtime if requested.
        return None

    logprobs_result = None
    generation_result = worker._results.get(response.client_id, None)

    if not generation_result:
        return

    logprob_params = getattr(generation_result, "_logprob_params", None)
    if logprob_params:
        logprobs_result = compute_logprobs(logprob_params.prompt_logprobs,
                                           logprob_params.logprobs,
                                           response.result.context_logits,
                                           response.result.generation_logits,
                                           response.result.output_token_ids[0])

        if logprob_params.drop_context_logits:
            response.clear_context_logits()

        if logprob_params.drop_generation_logits:
            response.clear_generation_logits()

    if response.result.is_final:
        generation_result.clear_logprob_params()

    return logprobs_result


def _send_rsp(
        worker,
        response: Union[tllm.Response, ResponseWrapper, ErrorResponse],
        postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None,
        rsp_batch: Optional[List[tllm.Response]] = None):
    # if postproc_batches is set, append to batch instead of putting to IpcQueue

    if worker.result_queue is not None:
        if rsp_batch is not None:
            rsp_batch.append(response)
        else:
            worker.result_queue.put(response)
    else:
        sampling_params, postproc_params = _get_params_for_first_rsp(
            worker, response.client_id)
        inp = PostprocWorker.Input(
            response,
            # sampling_params is necessary for creating fake GenerationResult
            # instances in the postproc processes. They are for incremental
            # detokenize. They should be transmitted only once for each
            # Request.
            sampling_params=sampling_params,
            postproc_params=postproc_params,
            streaming=worker._results.get(response.client_id, None)._streaming)

        pid = response.client_id % worker.postproc_config.num_postprocess_workers

        if not postproc_batches:
            # Group the responses into buckets for the postprocessing steps.
            # Bucketing is used instead of random dispatching because the
            # incremental detokenization during postprocessing relies on the
            # prior CompletionOutput of a given request.
            worker.postproc_queues[pid].put(inp)
        else:
            postproc_batches[pid].append(inp)

    # Eliminate the finished GenerationRequest instances timely, which may
    # take considerable memory.
    if is_llm_response(response):
        if response.has_error() or response.result.is_final:
            worker._pop_result(response.client_id)
    elif isinstance(response, ErrorResponse):
        worker._pop_result(response.client_id)
    else:
        raise ValueError(f"Unknown response type: {response}")


def _get_metrics_dict(
        response: tllm.Response) -> dict[RequestEventTiming, float]:
    req_perf_metrics, metrics_dict = None, {}
    res = response.result
    if res:
        if hasattr(res, '_result'):
            if result := res.get_result():
                req_perf_metrics = result.request_perf_metrics
        else:
            req_perf_metrics = res.request_perf_metrics
        if req_perf_metrics and req_perf_metrics.timing_metrics:
            metrics_dict = {
                RequestEventTiming.ARRIVAL_TIME:
                req_perf_metrics.timing_metrics.arrival_time.total_seconds(),
                RequestEventTiming.FIRST_TOKEN_TIME:
                req_perf_metrics.timing_metrics.first_token_time.total_seconds(
                ),
                RequestEventTiming.FIRST_SCHEDULED_TIME:
                req_perf_metrics.timing_metrics.first_scheduled_time.
                total_seconds(),
                RequestEventTiming.LAST_TOKEN_TIME:
                req_perf_metrics.timing_metrics.last_token_time.total_seconds()
            }
    return metrics_dict


def _maybe_wrap_response(
        worker,
        response: tllm.Response,
        is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]:

    logprobs_result = _get_logprobs(worker, response, is_pytorch_backend)
    req_perf_metrics = _get_metrics_dict(response)
    if logprobs_result or req_perf_metrics:
        response = ResponseWrapper(response, logprobs_result, req_perf_metrics)
    return response
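
Reviewer note, not part of the diff: the token budget computed by `_deduce_max_tokens` above reduces to max_seq_len minus the per-rank prompt length minus the query length, with a user-supplied `max_tokens` clamped to that budget. The sketch below restates that logic with hypothetical names (`deduce_max_tokens_sketch`, plain integers in place of the request and executor-config objects) purely to illustrate the clamping behaviour; it is not code from this commit.

# Standalone sketch of the budget logic in _deduce_max_tokens (hypothetical helper,
# not part of this commit); plain ints replace the request/config objects.
from typing import Optional


def deduce_max_tokens_sketch(prompt_len: int,
                             max_seq_len: Optional[int],
                             cp_size: int = 1,
                             query_token_len: int = 0,
                             requested_max_tokens: Optional[int] = None) -> int:
    if max_seq_len is None:
        # No default can be deduced; the user-specified value is required.
        if requested_max_tokens is None:
            raise ValueError("`max_tokens` must be set when it cannot be deduced")
        return requested_max_tokens

    # The prompt is split across context-parallel ranks, hence the division.
    splited_prompt_len = prompt_len // cp_size
    default_max_tokens = max_seq_len - splited_prompt_len - query_token_len

    if requested_max_tokens is None:
        return default_max_tokens
    # default_max_tokens is the largest budget the engine can honor, so clamp.
    return min(requested_max_tokens, default_max_tokens)


# Example: max_seq_len=40 and a 12-token prompt with cp_size=1 leave a budget of 28,
# so a request asking for 100 new tokens is clamped to 28.
assert deduce_max_tokens_sketch(12, 40, requested_max_tokens=100) == 28
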
193
tests/unittest/executor/test_base_worker.py
Normal file
@ -0,0 +1,193 @@
import os
import sys
import time

import pytest
import torch

from tensorrt_llm._utils import mpi_comm, mpi_rank, mpi_world_size
from tensorrt_llm.bindings import executor as tllm
from tensorrt_llm.llmapi.mpi_session import MpiPoolSession

# isort: off
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
from utils.llm_data import llm_models_root
# isort: on

from tensorrt_llm._torch.pyexecutor.config import update_executor_config
from tensorrt_llm.executor.base_worker import BaseWorker
from tensorrt_llm.executor.request import GenerationRequest
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from tensorrt_llm.sampling_params import SamplingParams

default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
model_path = llm_models_root() / default_model_name


class FakeWorker(BaseWorker):

    def __init__(self, engine: str, tp_size: int = 1):
        llm_args, executor_config = create_fake_executor_config(engine, tp_size)
        super().__init__(
            engine=engine,
            llm_args=llm_args,
            hf_model_dir=engine,
            executor_config=executor_config,
        )
        # Note: BaseWorker doesn't call setup_engine() automatically,
        # unlike GenerationExecutorWorker, so we need to call it manually
        self.setup_engine()
        self._started = False

    def start(self):
        """Override start to mark as started - no background threads needed for test."""
        if not self._started:
            self._started = True
            # For testing, we don't need background threads
            # The engine's await_responses will handle the mock responses

    def shutdown(self):
        self._started = False
        if self.engine is not None:
            self.engine.shutdown()
            self.engine = None


class TestWorkerBase:

    def test_create_engine(self):
        with FakeWorker(engine=model_path) as worker:
            print(f"Created engine: {worker.engine}")

    def test_submit_request(self):
        sampling_params = SamplingParams(max_tokens=10)
        request = GenerationRequest(prompt_token_ids=[3, 4, 5],
                                    sampling_params=sampling_params)
        with FakeWorker(engine=model_path) as worker:
            print(f"Created engine: {worker.engine}")
            result = worker.submit(request)

            # For PyTorch backend, the engine handles requests internally
            # We just need to give it some time to process
            timeout = 15.0  # 15 seconds timeout
            start_time = time.time()

            while not result.finished and (time.time() - start_time) < timeout:
                # Call await_responses with timeout to prevent hanging
                responses = worker.await_responses(timeout=0.5)
                time.sleep(0.1)

            if not result.finished:
                print(f"Request did not complete within {timeout} seconds")
            else:
                print(f"Request completed successfully")
            print(f"Result: {result}")

    def test_fetch_stats(self):
        request = GenerationRequest(
            prompt_token_ids=[3, 4, 5],
            sampling_params=SamplingParams(max_tokens=10))
        with FakeWorker(engine=model_path) as worker:
            result = worker.submit(request)

            # Give the engine time to start processing
            time.sleep(1)

            # Fetch stats while request is processing
            stats = worker.fetch_stats()
            print(f"Stats: {stats}")

            # Continue processing until completion or timeout
            timeout = 10.0
            start_time = time.time()
            while not result.finished and (time.time() - start_time) < timeout:
                worker.await_responses(timeout=0.5)
                time.sleep(0.1)

    @pytest.mark.parametrize("timeout", [0.1, 0.2, 1])
    def test_fetch_responses_timeout(self, timeout: float):
        with FakeWorker(engine=model_path) as worker:
            # Not submit any request, and let the await_responses timeout.
            start_time = time.time()
            results = worker.await_responses(timeout=timeout)
            elapsed = time.time() - start_time
            print(f"await_responses latency: {elapsed:.3f} seconds")
            assert timeout / 2 <= elapsed <= timeout * 2, f"Latency out of expected range: {elapsed}"


def create_fake_executor_config(model_path, tp_size=1):
    # Use TorchLlmArgs for PyTorch backend tests
    llm_args = TorchLlmArgs(model=model_path,
                            tensor_parallel_size=tp_size,
                            backend='pytorch')

    executor_config = tllm.ExecutorConfig(1)
    executor_config.max_batch_size = 1
    executor_config.model_world_size = tp_size

    update_executor_config(
        executor_config,
        pytorch_backend_config=llm_args.get_pytorch_backend_config(),
        mapping=llm_args.parallel_config.to_mapping(),
        speculative_config=llm_args.speculative_config,
        hf_model_dir=model_path,
        max_input_len=20,
        max_seq_len=40,
        checkpoint_format=llm_args.checkpoint_format,
        checkpoint_loader=llm_args.checkpoint_loader,
    )

    return llm_args, executor_config


class TestRpcWorkerBaseTP2:

    def setup_method(self):
        # Use TorchLlmArgs for PyTorch backend with TP2
        self.llm_args = TorchLlmArgs(model=model_path,
                                     tensor_parallel_size=2,
                                     backend='pytorch')
        self.session = self.create_worker_session()

    def create_worker_session(self):
        session = MpiPoolSession(n_workers=2)
        return session

    def test_create_executor(self):
        futures = self.session.submit(
            TestRpcWorkerBaseTP2.create_executor,
            engine=model_path,
            llm_args=self.llm_args,
        )
        # Wait for completion
        for future in futures:
            future.result()

        self.session.shutdown()

    @staticmethod
    def create_executor(engine, llm_args):
        rank = mpi_rank()
        world_size = mpi_world_size()
        device_id = rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)

        print(f"[Test] Rank {rank}/{world_size} using device {device_id}")

        # Synchronize all workers before creating executor
        mpi_comm().barrier()

        print(f"[Test] Rank {rank} creating FakeWorker...")
        executor = FakeWorker(engine=engine, tp_size=2)

        # Note: setup_engine is already called in FakeWorker.__init__
        print(
            f"[Test] Rank {rank} FakeWorker created and setup_engine completed successfully"
        )

        executor.shutdown()


if __name__ == "__main__":
    test_worker_base = TestWorkerBase()
    test_worker_base.test_submit_request()
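
Reviewer note, not part of the diff: `test_submit_request` and `test_fetch_stats` both hand-roll the same polling loop around `worker.await_responses`. A hypothetical helper along the following lines (the name `wait_for_result` is invented here) factors that pattern out; it is only a sketch of what the tests already do, not code from this commit.

# Hypothetical helper (not part of this commit): drain responses from the worker
# until the GenerationResult reports completion or the deadline passes.
import time


def wait_for_result(worker, result, timeout: float = 15.0, poll: float = 0.5) -> bool:
    deadline = time.time() + timeout
    while not result.finished and time.time() < deadline:
        # await_responses(timeout=...) returns after at most `poll` seconds,
        # which keeps the loop from hanging when the engine produces nothing.
        worker.await_responses(timeout=poll)
        time.sleep(0.1)
    return result.finished

With such a helper, each of the two tests would reduce to submitting the request and asserting on `wait_for_result(worker, result)`; the latency assertion in `test_fetch_responses_timeout` would stay unchanged.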