diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index 60feeb1859..6e9284fa29 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -3,14 +3,15 @@ from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple
 
 import torch
+from strenum import StrEnum
 from torch._prims_common import DeviceLikeType
 
 from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._utils import nvtx_range
+from tensorrt_llm.llmapi.llm_args import ContextChunkingPolicy
 
 from ...._utils import mpi_rank, mpi_world_size
-from ....bindings.executor import ContextChunkingPolicy
-from ....bindings.internal.batch_manager import CacheType, ContextChunkingConfig
+from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
 from ...distributed import MPIDist
 from ...pyexecutor.model_engine import ModelEngine
@@ -376,7 +377,7 @@ def create_autodeploy_executor(ad_config: LlmArgs):
     if ad_config.enable_chunked_prefill:
         chunk_unit_size = ad_config.attn_page_size
         chunking_policy = ContextChunkingPolicy.FIRST_COME_FIRST_SERVED
-        ctx_chunk_config = ContextChunkingConfig(chunking_policy, chunk_unit_size)
+        ctx_chunk_config: Tuple[StrEnum, int] = (chunking_policy, chunk_unit_size)
     else:
         ctx_chunk_config = None
 
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index e8a62aad5b..4354495ef1 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -13,7 +13,8 @@ from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
                                           MTPDecodingConfig, PeftCacheConfig,
-                                          SamplerType, SparseAttentionConfig,
+                                          SamplerType, SchedulerConfig,
+                                          SparseAttentionConfig,
                                           SpeculativeConfig, TorchLlmArgs)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import (LoraConfig,
@@ -663,8 +664,8 @@ def create_py_executor_instance(
     max_batch_size: Optional[int] = None,
     max_beam_width: Optional[int] = None,
     max_num_tokens: Optional[int] = None,
-    peft_cache_config: Optional[trtllm.PeftCacheConfig] = None,
-    scheduler_config: Optional[trtllm.SchedulerConfig] = None,
+    peft_cache_config: Optional[PeftCacheConfig] = None,
+    scheduler_config: Optional[SchedulerConfig] = None,
     cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)
@@ -728,16 +729,14 @@ def create_py_executor_instance(
         num_lora_modules = model_engine.model.model_config.pretrained_config.num_hidden_layers * \
             len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
 
-        peft_cache_config_model = PeftCacheConfig.from_pybind(
-            peft_cache_config
-        ) if peft_cache_config is not None else PeftCacheConfig()
+        peft_cache_config_model = PeftCacheConfig(
+        ) if peft_cache_config is None else peft_cache_config
         if lora_config.max_loras is not None:
             peft_cache_config_model.num_device_module_layer = \
                 max_lora_rank * num_lora_modules * lora_config.max_loras
         if lora_config.max_cpu_loras is not None:
             peft_cache_config_model.num_host_module_layer = \
                 max_lora_rank * num_lora_modules * lora_config.max_cpu_loras
-        peft_cache_config = peft_cache_config_model._to_pybind()
 
         from tensorrt_llm.bindings import WorldConfig
         world_config = WorldConfig(
@@ -748,7 +747,7 @@ create_py_executor_instance(
             gpus_per_node=dist.mapping.gpus_per_node,
         )
         peft_cache_manager = PeftCacheManager(
-            peft_cache_config=peft_cache_config,
+            peft_cache_config=peft_cache_config_model,
             lora_config=lora_config,
             model_config=model_binding_config,
             world_config=world_config,
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index e4d23937ba..6edd7d1024 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -27,11 +27,12 @@ from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
                                             FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
-                                            PeftCacheConfig, RequestStage,
-                                            RequestStats, SpecDecodingStats,
+                                            RequestStage, RequestStats,
+                                            SpecDecodingStats,
                                             StaticBatchingStats)
 from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
                                                           ReqIdsSet)
+from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import CpType
 from tensorrt_llm.runtime.generation import CUASSERT
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 214ac2014c..bab9f2354e 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -6,18 +6,18 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from dataclasses import dataclass
 from itertools import chain
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
+from strenum import StrEnum
 
 import tensorrt_llm
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._utils import get_sm_version, mpi_disabled
-from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
-                                            ContextChunkingPolicy,
-                                            GuidedDecodingConfig)
-from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
-from tensorrt_llm.llmapi.llm_args import LoadFormat, PybindMirror, TorchLlmArgs
+from tensorrt_llm.bindings.executor import GuidedDecodingConfig
+from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
+                                          ContextChunkingPolicy, LoadFormat,
+                                          PybindMirror, TorchLlmArgs)
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase,
                                            _llguidance_tokenizer_info,
                                            _xgrammar_tokenizer_info)
@@ -214,12 +214,11 @@ def create_py_executor(
     if pytorch_backend_config is None:
         pytorch_backend_config = PyTorchConfig()
 
-    scheduler_config = PybindMirror.maybe_to_pybind(llm_args.scheduler_config)
+    scheduler_config = llm_args.scheduler_config
 
-    peft_cache_config = None
-    if llm_args.peft_cache_config is not None:
-        peft_cache_config = PybindMirror.maybe_to_pybind(
-            llm_args.peft_cache_config)
+    # Since peft_cache_config may be modified below, deep-copy it so the changes do not
+    # propagate back to llm_args.peft_cache_config
+    peft_cache_config = copy.deepcopy(llm_args.peft_cache_config)
 
     assert llm_args.kv_cache_config, "Expect llm_args.kv_cache_config is not None"
     kv_cache_config = llm_args.kv_cache_config
@@ -457,8 +456,8 @@ def create_py_executor(
                            scheduler_config.context_chunking_policy is not None else
                            ContextChunkingPolicy.FIRST_COME_FIRST_SERVED)
        assert chunk_unit_size is not None, "chunk_unit_size must be set"
-        ctx_chunk_config = ContextChunkingConfig(chunking_policy,
-                                                 chunk_unit_size)
+        ctx_chunk_config: Tuple[StrEnum,
+                                int] = (chunking_policy, chunk_unit_size)
     else:
         ctx_chunk_config = None
 
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index e19a04c17a..7d6d37af78 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -11,7 +11,8 @@ import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
-from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PybindMirror
+from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
+                                          PybindMirror)
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
@@ -39,7 +40,6 @@ DataType = tensorrt_llm.bindings.DataType
 KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager
 RequestList = list[LlmRequest]
 PeftCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.PeftCacheManager
-PeftCacheConfig = tensorrt_llm.bindings.executor.PeftCacheConfig
 WorldConfig = tensorrt_llm.bindings.WorldConfig
 TempAttentionWindowInputs = tensorrt_llm.bindings.internal.batch_manager.TempAttentionWindowInputs
 BlocksPerWindow = Dict[int, Tuple[
@@ -1164,6 +1164,8 @@ class PeftCacheManager(BaseResourceManager):
                  world_config: WorldConfig | None = None):
         import tensorrt_llm.bindings as _tb
 
+        peft_cache_config = peft_cache_config._to_pybind()
+
         peft_cache_manager_config = _tb.PeftCacheManagerConfig(
             num_host_module_layer=peft_cache_config.num_host_module_layer,
             num_device_module_layer=peft_cache_config.num_device_module_layer,
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py
index 81a8631c4a..c71c4596ed 100644
--- a/tensorrt_llm/_torch/pyexecutor/scheduler.py
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py
@@ -1,9 +1,11 @@
 from abc import ABC, abstractmethod
 from collections import namedtuple
-from typing import Optional
+from typing import Optional, Tuple
+
+from strenum import StrEnum
 
-from tensorrt_llm.bindings import executor as tb_executor
 from tensorrt_llm.bindings import internal as tb_internal
+from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy
 
 from .llm_request import LlmRequest, LlmRequestState
 
@@ -74,8 +76,8 @@ class BindCapacityScheduler(CapacityScheduler):
         max_num_requests: int,
         kv_cache_manager,
         peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None,
-        scheduler_policy: tb_executor.CapacitySchedulerPolicy = tb_executor.
-        CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
+        scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.
+        GUARANTEED_NO_EVICT,
         two_step_lookahead: bool = False,
     ):
         super(BindCapacityScheduler, self).__init__()
@@ -84,7 +86,7 @@ class BindCapacityScheduler(CapacityScheduler):
 
         self.impl = tb_internal.algorithms.CapacityScheduler(
             max_num_requests=max_num_requests,
-            capacity_scheduler_policy=scheduler_policy,
+            capacity_scheduler_policy=scheduler_policy._to_pybind(),
             has_kv_cache_manager=kv_cache_manager is not None,
             two_step_lookahead=two_step_lookahead,
             no_schedule_until_state=LlmRequestState.CONTEXT_INIT,
@@ -172,14 +174,19 @@ class BindMicroBatchScheduler(MicroBatchScheduler):
         self,
         max_batch_size: int,
         max_num_tokens: int = None,
-        ctx_chunk_config: Optional[
-            tb_internal.batch_manager.ContextChunkingConfig] = None,
+        ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None,
     ) -> None:
         super(BindMicroBatchScheduler, self).__init__()
         self.max_batch_size = max_batch_size
         self.max_num_tokens = max_num_tokens
+
+        ctx_chunk_config_cpp = None
+        if ctx_chunk_config is not None:
+            ctx_chunk_config_cpp = tb_internal.batch_manager.ContextChunkingConfig(
+                ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1])
+
         self.impl = tb_internal.algorithms.MicroBatchScheduler(
-            ctx_chunk_config, max_num_tokens)
+            ctx_chunk_config_cpp, max_num_tokens)
 
     def schedule(
         self, active_requests: RequestList, inflight_request_ids: set[int]
diff --git a/tests/unittest/_torch/executor/test_resource_manager.py b/tests/unittest/_torch/executor/test_resource_manager.py
index f67dfcc60e..1bf7633f80 100644
--- a/tests/unittest/_torch/executor/test_resource_manager.py
+++ b/tests/unittest/_torch/executor/test_resource_manager.py
@@ -13,14 +13,13 @@ import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
-                                                             PeftCacheConfig,
                                                              PeftCacheManager)
 from tensorrt_llm.bindings import LayerType
 from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
-from tensorrt_llm.llmapi.llm_args import KvCacheConfig
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PeftCacheConfig
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
@@ -234,7 +233,7 @@ class TestResourceManager(unittest.TestCase):
             num_ensure_workers=mock_config.ensure_thread_count,
         )
 
-        return peft_cache_config
+        return PeftCacheConfig.from_pybind(peft_cache_config)
 
     def _create_request(self,
                         request_id,
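
A minimal usage sketch of the new chunked-prefill plumbing (not part of the diff above),
assuming only the classes and enums shown in the hunks; the numeric values are
placeholders for illustration:

    from tensorrt_llm._torch.pyexecutor.scheduler import BindMicroBatchScheduler
    from tensorrt_llm.llmapi.llm_args import ContextChunkingPolicy

    # Callers now pass a plain (policy, chunk_unit_size) tuple built from the llmapi
    # StrEnum instead of constructing the pybind ContextChunkingConfig themselves.
    ctx_chunk_config = (ContextChunkingPolicy.FIRST_COME_FIRST_SERVED, 64)

    # BindMicroBatchScheduler converts the tuple to the C++ ContextChunkingConfig
    # internally via policy._to_pybind(), keeping pybind types out of caller code.
    scheduler = BindMicroBatchScheduler(max_batch_size=8,
                                        max_num_tokens=8192,
                                        ctx_chunk_config=ctx_chunk_config)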