Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-8483][chore] Refine scheduler_config and peft_cache_config in create_py_executor (#8451)
Signed-off-by: leslie-fang25 <leslief@nvidia.com>
Commit 50d4e5bc06 (parent bac9e8c2ad)
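At a glance, this commit moves create_py_executor and the helpers it calls from the tensorrt_llm.bindings.executor config types to their llmapi (pydantic) counterparts (SchedulerConfig, PeftCacheConfig, CapacitySchedulerPolicy, ContextChunkingPolicy) and defers the conversion to pybind objects to the call sites that actually construct C++ objects. A minimal sketch of that conversion pattern, using only helpers that appear in the diff below (the field value is illustrative):

from tensorrt_llm.llmapi.llm_args import PeftCacheConfig, PybindMirror

# The llmapi-level (pydantic) config travels through the Python executor code.
peft_cache_config = PeftCacheConfig()
peft_cache_config.num_device_module_layer = 8  # illustrative value

# Conversion to the bindings-level type happens only at the binding boundary.
peft_cache_config_cpp = peft_cache_config._to_pybind()

# PybindMirror.maybe_to_pybind is the variant used by the old code path.
also_cpp = PybindMirror.maybe_to_pybind(peft_cache_config)

# A bindings-level object can also be wrapped back into the pydantic type.
round_tripped = PeftCacheConfig.from_pybind(peft_cache_config_cpp)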
@@ -3,14 +3,15 @@ from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple

 import torch
+from strenum import StrEnum
 from torch._prims_common import DeviceLikeType

 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._utils import nvtx_range
+from tensorrt_llm.llmapi.llm_args import ContextChunkingPolicy

 from ...._utils import mpi_rank, mpi_world_size
-from ....bindings.executor import ContextChunkingPolicy
-from ....bindings.internal.batch_manager import CacheType, ContextChunkingConfig
+from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
 from ...distributed import MPIDist
 from ...pyexecutor.model_engine import ModelEngine
@@ -376,7 +377,7 @@ def create_autodeploy_executor(ad_config: LlmArgs):

     if ad_config.enable_chunked_prefill:
         chunk_unit_size = ad_config.attn_page_size
         chunking_policy = ContextChunkingPolicy.FIRST_COME_FIRST_SERVED
-        ctx_chunk_config = ContextChunkingConfig(chunking_policy, chunk_unit_size)
+        ctx_chunk_config: Tuple[StrEnum, int] = (chunking_policy, chunk_unit_size)
     else:
         ctx_chunk_config = None
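Note that ctx_chunk_config in the autodeploy path is now a plain (policy, chunk size) tuple rather than a bindings-level ContextChunkingConfig; the conversion to tb_internal.batch_manager.ContextChunkingConfig now happens inside BindMicroBatchScheduler (see the scheduler hunk further down, and the sketch after it). Roughly, assuming an attn_page_size of 64:

ctx_chunk_config = (ContextChunkingPolicy.FIRST_COME_FIRST_SERVED, 64)  # (policy, chunk_unit_size)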
@@ -13,7 +13,8 @@ from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
                                           MTPDecodingConfig, PeftCacheConfig,
-                                          SamplerType, SparseAttentionConfig,
+                                          SamplerType, SchedulerConfig,
+                                          SparseAttentionConfig,
                                           SpeculativeConfig, TorchLlmArgs)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
@@ -663,8 +664,8 @@ def create_py_executor_instance(
     max_batch_size: Optional[int] = None,
     max_beam_width: Optional[int] = None,
     max_num_tokens: Optional[int] = None,
-    peft_cache_config: Optional[trtllm.PeftCacheConfig] = None,
-    scheduler_config: Optional[trtllm.SchedulerConfig] = None,
+    peft_cache_config: Optional[PeftCacheConfig] = None,
+    scheduler_config: Optional[SchedulerConfig] = None,
     cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
@@ -728,16 +729,14 @@ def create_py_executor_instance(
         num_lora_modules = model_engine.model.model_config.pretrained_config.num_hidden_layers * \
             len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)

-        peft_cache_config_model = PeftCacheConfig.from_pybind(
-            peft_cache_config
-        ) if peft_cache_config is not None else PeftCacheConfig()
+        peft_cache_config_model = PeftCacheConfig(
+        ) if peft_cache_config is None else peft_cache_config
         if lora_config.max_loras is not None:
             peft_cache_config_model.num_device_module_layer = \
                 max_lora_rank * num_lora_modules * lora_config.max_loras
         if lora_config.max_cpu_loras is not None:
             peft_cache_config_model.num_host_module_layer = \
                 max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
-        peft_cache_config = peft_cache_config_model._to_pybind()

         from tensorrt_llm.bindings import WorldConfig
         world_config = WorldConfig(
@@ -748,7 +747,7 @@ def create_py_executor_instance(
             gpus_per_node=dist.mapping.gpus_per_node,
         )
         peft_cache_manager = PeftCacheManager(
-            peft_cache_config=peft_cache_config,
+            peft_cache_config=peft_cache_config_model,
             lora_config=lora_config,
             model_config=model_binding_config,
             world_config=world_config,
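Taken together with the resource_manager hunk below, the LoRA path in create_py_executor_instance now sizes the pydantic PeftCacheConfig directly and hands it to PeftCacheManager, which performs the pybind conversion itself. A self-contained sketch of the sizing logic, with all numbers made up:

from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

max_lora_rank = 64
num_lora_modules = 7 * 32          # illustrative: modules per layer * hidden layers
max_loras, max_cpu_loras = 4, 8    # illustrative LoRA cache limits

peft_cache_config_model = PeftCacheConfig()
peft_cache_config_model.num_device_module_layer = max_lora_rank * num_lora_modules * max_loras
peft_cache_config_model.num_host_module_layer = max_lora_rank * num_lora_modules * max_cpu_loras

# The pydantic object is passed to PeftCacheManager as-is; the manager calls
# _to_pybind() internally before building the C++ PeftCacheManagerConfig.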
@@ -27,11 +27,12 @@ from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
                                             FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
-                                            PeftCacheConfig, RequestStage,
-                                            RequestStats, SpecDecodingStats,
+                                            RequestStage, RequestStats,
+                                            SpecDecodingStats,
                                             StaticBatchingStats)
 from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
                                                           ReqIdsSet)
+from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT
@@ -6,18 +6,18 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from dataclasses import dataclass
 from itertools import chain
-from typing import Optional
+from typing import Optional, Tuple

 import torch
+from strenum import StrEnum

 import tensorrt_llm
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._utils import get_sm_version, mpi_disabled
-from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
-                                            ContextChunkingPolicy,
-                                            GuidedDecodingConfig)
-from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
-from tensorrt_llm.llmapi.llm_args import LoadFormat, PybindMirror, TorchLlmArgs
+from tensorrt_llm.bindings.executor import GuidedDecodingConfig
+from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
+                                          ContextChunkingPolicy, LoadFormat,
+                                          PybindMirror, TorchLlmArgs)
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase,
                                            _llguidance_tokenizer_info,
                                            _xgrammar_tokenizer_info)
@@ -214,12 +214,11 @@ def create_py_executor(
     if pytorch_backend_config is None:
         pytorch_backend_config = PyTorchConfig()

-    scheduler_config = PybindMirror.maybe_to_pybind(llm_args.scheduler_config)
+    scheduler_config = llm_args.scheduler_config

-    peft_cache_config = None
-    if llm_args.peft_cache_config is not None:
-        peft_cache_config = PybindMirror.maybe_to_pybind(
-            llm_args.peft_cache_config)
+    # Since peft_cache_config may be subject to change, avoid letting these
+    # changes propagate back to llm_args.peft_cache_config.
+    peft_cache_config = copy.deepcopy(llm_args.peft_cache_config)

     assert llm_args.kv_cache_config, "Expect llm_args.kv_cache_config is not None"
     kv_cache_config = llm_args.kv_cache_config
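The deepcopy matters because later code (for example, the LoRA sizing shown earlier) mutates fields on peft_cache_config; copying first keeps llm_args.peft_cache_config untouched. A tiny illustration, assuming the default num_device_module_layer differs from the value written below:

import copy

from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

original = PeftCacheConfig()
working_copy = copy.deepcopy(original)
working_copy.num_device_module_layer = 123  # mutation stays on the copy
assert original.num_device_module_layer != 123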
@@ -457,8 +456,8 @@ def create_py_executor(
             scheduler_config.context_chunking_policy is not None
             else ContextChunkingPolicy.FIRST_COME_FIRST_SERVED)
         assert chunk_unit_size is not None, "chunk_unit_size must be set"
-        ctx_chunk_config = ContextChunkingConfig(chunking_policy,
-                                                 chunk_unit_size)
+        ctx_chunk_config: Tuple[StrEnum,
+                                int] = (chunking_policy, chunk_unit_size)
     else:
         ctx_chunk_config = None
@@ -11,7 +11,8 @@ import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
-from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PybindMirror
+from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
+                                          PybindMirror)
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
@@ -39,7 +40,6 @@ DataType = tensorrt_llm.bindings.DataType
 KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager
 RequestList = list[LlmRequest]
 PeftCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.PeftCacheManager
-PeftCacheConfig = tensorrt_llm.bindings.executor.PeftCacheConfig
 WorldConfig = tensorrt_llm.bindings.WorldConfig
 TempAttentionWindowInputs = tensorrt_llm.bindings.internal.batch_manager.TempAttentionWindowInputs
 BlocksPerWindow = Dict[int, Tuple[
@@ -1164,6 +1164,8 @@ class PeftCacheManager(BaseResourceManager):
                  world_config: WorldConfig | None = None):
         import tensorrt_llm.bindings as _tb

+        peft_cache_config = peft_cache_config._to_pybind()
+
         peft_cache_manager_config = _tb.PeftCacheManagerConfig(
             num_host_module_layer=peft_cache_config.num_host_module_layer,
             num_device_module_layer=peft_cache_config.num_device_module_layer,
@@ -1,9 +1,11 @@
 from abc import ABC, abstractmethod
 from collections import namedtuple
-from typing import Optional
+from typing import Optional, Tuple
+
+from strenum import StrEnum

 from tensorrt_llm.bindings import executor as tb_executor
 from tensorrt_llm.bindings import internal as tb_internal
+from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy

 from .llm_request import LlmRequest, LlmRequestState
@@ -74,8 +76,8 @@ class BindCapacityScheduler(CapacityScheduler):
         max_num_requests: int,
         kv_cache_manager,
         peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None,
-        scheduler_policy: tb_executor.CapacitySchedulerPolicy = tb_executor.
-        CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
+        scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.
+        GUARANTEED_NO_EVICT,
         two_step_lookahead: bool = False,
     ):
         super(BindCapacityScheduler, self).__init__()
@@ -84,7 +86,7 @@ class BindCapacityScheduler(CapacityScheduler):

         self.impl = tb_internal.algorithms.CapacityScheduler(
             max_num_requests=max_num_requests,
-            capacity_scheduler_policy=scheduler_policy,
+            capacity_scheduler_policy=scheduler_policy._to_pybind(),
             has_kv_cache_manager=kv_cache_manager is not None,
             two_step_lookahead=two_step_lookahead,
             no_schedule_until_state=LlmRequestState.CONTEXT_INIT,
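BindCapacityScheduler now accepts the llmapi CapacitySchedulerPolicy and converts it only when constructing the C++ CapacityScheduler. A brief sketch of that conversion, assuming _to_pybind() maps to the corresponding bindings-level value:

from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy

policy = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
policy_cpp = policy._to_pybind()  # presumably tensorrt_llm.bindings.executor.CapacitySchedulerPolicy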
@@ -172,14 +174,19 @@ class BindMicroBatchScheduler(MicroBatchScheduler):
         self,
         max_batch_size: int,
         max_num_tokens: int = None,
-        ctx_chunk_config: Optional[
-            tb_internal.batch_manager.ContextChunkingConfig] = None,
+        ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None,
     ) -> None:
         super(BindMicroBatchScheduler, self).__init__()
         self.max_batch_size = max_batch_size
         self.max_num_tokens = max_num_tokens
+
+        ctx_chunk_config_cpp = None
+        if ctx_chunk_config is not None:
+            ctx_chunk_config_cpp = tb_internal.batch_manager.ContextChunkingConfig(
+                ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1])
+
         self.impl = tb_internal.algorithms.MicroBatchScheduler(
-            ctx_chunk_config, max_num_tokens)
+            ctx_chunk_config_cpp, max_num_tokens)

     def schedule(
         self, active_requests: RequestList, inflight_request_ids: set[int]
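The tuple-to-binding conversion that the earlier hunks rely on now lives here. A small sketch of what BindMicroBatchScheduler does with the (policy, chunk size) tuple; the chunk size of 64 is illustrative:

from tensorrt_llm.bindings import internal as tb_internal
from tensorrt_llm.llmapi.llm_args import ContextChunkingPolicy

ctx_chunk_config = (ContextChunkingPolicy.FIRST_COME_FIRST_SERVED, 64)

ctx_chunk_config_cpp = None
if ctx_chunk_config is not None:
    ctx_chunk_config_cpp = tb_internal.batch_manager.ContextChunkingConfig(
        ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1])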
@@ -13,14 +13,13 @@ import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
-                                                              PeftCacheConfig,
                                                               PeftCacheManager)
 from tensorrt_llm.bindings import LayerType
 from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
-from tensorrt_llm.llmapi.llm_args import KvCacheConfig
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
@@ -234,7 +233,7 @@ class TestResourceManager(unittest.TestCase):
             num_ensure_workers=mock_config.ensure_thread_count,
         )

-        return peft_cache_config
+        return PeftCacheConfig.from_pybind(peft_cache_config)

     def _create_request(self,
                         request_id,
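The test helper now returns the pydantic PeftCacheConfig wrapper instead of the raw pybind object, matching what PeftCacheManager expects after this change. A small sketch of the wrapping step, assuming the bindings-level PeftCacheConfig is default-constructible:

from tensorrt_llm.bindings import executor as tllm
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig

peft_cache_config_cpp = tllm.PeftCacheConfig()
peft_cache_config = PeftCacheConfig.from_pybind(peft_cache_config_cpp)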