[None][chore] Remove executor config in create_py_executor (#7599)

Signed-off-by: leslie-fang25 <leslief@nvidia.com>
Leslie Fang 2025-09-18 14:24:58 +08:00 committed by GitHub
parent b6e916b762
commit 870cfcf9a0
3 changed files with 219 additions and 243 deletions

View File

@@ -15,10 +15,13 @@ from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
ContextChunkingPolicy,
ExecutorConfig)
GuidedDecodingConfig)
from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from tensorrt_llm.llmapi.llm_args import (KvCacheConnectorConfig, LoadFormat,
PybindMirror, TorchLlmArgs)
from tensorrt_llm.llmapi.tokenizer import (TokenizerBase,
_llguidance_tokenizer_info,
_xgrammar_tokenizer_info)
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.mapping import Mapping
@@ -30,7 +33,7 @@ from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
get_spec_resource_manager)
from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
create_py_executor_instance, instantiate_sampler, is_mla)
from .config import LoadFormat, PyTorchConfig
from .config import PyTorchConfig, _construct_checkpoint_loader
from .config_utils import is_mla
from .guided_decoder import CapturableGuidedDecoder, GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
@@ -161,58 +164,14 @@ class _ExecutorMemoryMonitor():
))
def _mangle_executor_config(executor_config: ExecutorConfig):
if executor_config.pytorch_backend_config is None:
executor_config.pytorch_backend_config = PyTorchConfig()
pytorch_backend_config = executor_config.pytorch_backend_config
if executor_config.max_num_tokens is None:
executor_config.max_num_tokens = 8192
if pytorch_backend_config.attn_backend in [
"FLASHINFER", "FLASHINFER_STAR_ATTENTION"
]:
# Workaround for flashinfer and star attention
if executor_config.kv_cache_config.enable_block_reuse:
logger.warning(
f"Disabling block reuse for {pytorch_backend_config.attn_backend} backend"
)
executor_config.kv_cache_config.enable_block_reuse = False
if pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION" and executor_config.enable_chunked_context:
logger.warning(
f"Disabling chunked context for {pytorch_backend_config.attn_backend} backend"
)
executor_config.enable_chunked_context = False
spec_config = executor_config.speculative_config
if not executor_config.pytorch_backend_config.disable_overlap_scheduler and spec_config is not None:
if not spec_config.spec_dec_mode.support_overlap_scheduler():
logger.warning(
f"Disable overlap scheduler for speculation mode {spec_config.spec_dec_mode.name}"
)
executor_config.pytorch_backend_config.disable_overlap_scheduler = True
if executor_config.mm_encoder_only:
from tensorrt_llm.llmapi.llm_args import LoadFormat
pytorch_backend_config.mm_encoder_only = True
pytorch_backend_config.load_format = LoadFormat.VISION_ONLY
# Disable overlap scheduler for multimodal encoder-only mode
logger.warning(
"Disabling overlap scheduler for multimodal encoder-only mode. "
"The overlap scheduler is designed for generation models and is not needed "
"when only processing vision encoder inputs.")
pytorch_backend_config.disable_overlap_scheduler = True
def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
if executor_config.mapping is None:
def _get_mapping(_mapping: Mapping) -> Mapping:
if _mapping is None:
mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(),
tp_size=tensorrt_llm.mpi_world_size(),
gpus_per_node=tensorrt_llm.default_gpus_per_node(),
rank=tensorrt_llm.mpi_rank())
else:
mapping = copy.deepcopy(executor_config.mapping)
mapping = copy.deepcopy(_mapping)
mapping.rank = tensorrt_llm.mpi_rank()
return mapping
@@ -230,6 +189,25 @@ def update_sampler_max_seq_len(max_seq_len, sampler):
sampler.max_seq_len = max_seq_len
def get_guided_decoding_config(guided_decoding_backend: str,
tokenizer: Optional[TokenizerBase] = None):
guided_decoding_config = None
if guided_decoding_backend == 'xgrammar':
assert tokenizer is not None
guided_decoding_config = GuidedDecodingConfig(
backend=GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
**_xgrammar_tokenizer_info(tokenizer))
elif guided_decoding_backend == 'llguidance':
assert tokenizer is not None
guided_decoding_config = GuidedDecodingConfig(
backend=GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE,
**_llguidance_tokenizer_info(tokenizer))
elif guided_decoding_backend is not None:
raise ValueError(
f"Unsupported guided decoding backend {guided_decoding_backend}")
return guided_decoding_config
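The helper above reproduces the backend dispatch that previously lived in BaseLlmArgs.get_executor_config (removed in the second file below): 'xgrammar' and 'llguidance' require a tokenizer and produce a GuidedDecodingConfig, None disables guided decoding, and anything else is rejected. A self-contained sketch of that control flow, with the TensorRT-LLM types stubbed out so it runs in isolation (pick_backend and its tuple return value are illustrative, not the real API):

# Illustrative stand-in for the dispatch in get_guided_decoding_config. The
# real helper builds a GuidedDecodingConfig from tokenizer info; here a plain
# tuple takes its place so the branches can be exercised without TensorRT-LLM.
from typing import Any, Optional, Tuple

def pick_backend(backend: Optional[str],
                 tokenizer: Optional[Any]) -> Optional[Tuple[str, Any]]:
    if backend in ("xgrammar", "llguidance"):
        assert tokenizer is not None, "guided decoding requires a tokenizer"
        return (backend, tokenizer)   # real code: GuidedDecodingConfig(...)
    if backend is not None:
        raise ValueError(f"Unsupported guided decoding backend {backend}")
    return None                       # guided decoding disabled

print(pick_backend(None, None))            # None
print(pick_backend("xgrammar", object()))  # ('xgrammar', <object object at ...>)

Keeping this mapping next to create_py_executor lets llm_args.py drop its guided-decoding tokenizer-info imports, as the second file shows.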
def create_py_executor(
llm_args: TorchLlmArgs,
checkpoint_dir: str = None,
@@ -238,17 +216,96 @@ def create_py_executor(
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
) -> PyExecutor:
executor_config = llm_args.get_executor_config(checkpoint_dir, tokenizer)
garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
_mangle_executor_config(executor_config)
pytorch_backend_config = executor_config.pytorch_backend_config
pytorch_backend_config = llm_args.get_pytorch_backend_config()
if pytorch_backend_config is None:
pytorch_backend_config = PyTorchConfig()
mapping = _get_mapping(executor_config)
scheduler_config = PybindMirror.maybe_to_pybind(llm_args.scheduler_config)
peft_cache_config = None
if llm_args.peft_cache_config is not None:
peft_cache_config = PybindMirror.maybe_to_pybind(
llm_args.peft_cache_config)
assert llm_args.kv_cache_config, "Expect llm_args.kv_cache_config is not None"
kv_cache_config = PybindMirror.maybe_to_pybind(llm_args.kv_cache_config)
if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
# Disable KV cache reuse for deterministic mode
kv_cache_config.enable_block_reuse = False
kv_cache_config.enable_partial_reuse = False
decoding_config = llm_args.decoding_config
guided_decoding_config = get_guided_decoding_config(
llm_args.guided_decoding_backend, tokenizer)
mm_encoder_only = llm_args.mm_encoder_only
enable_chunked_context = llm_args.enable_chunked_prefill
assert llm_args.backend == "pytorch", "_construct_checkpoint_loader expects different parameters for autodeploy"
checkpoint_loader = _construct_checkpoint_loader(llm_args.backend,
llm_args.checkpoint_loader,
llm_args.checkpoint_format)
(
max_beam_width,
max_num_tokens,
max_seq_len,
max_batch_size,
) = llm_args.get_runtime_sizes()
if max_num_tokens is None:
max_num_tokens = 8192
tokens_per_block = llm_args.kv_cache_config.tokens_per_block
if pytorch_backend_config.attn_backend in [
"FLASHINFER", "FLASHINFER_STAR_ATTENTION"
]:
# Workaround for flashinfer and star attention
if kv_cache_config.enable_block_reuse:
logger.warning(
f"Disabling block reuse for {pytorch_backend_config.attn_backend} backend"
)
kv_cache_config.enable_block_reuse = False
if pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION" and enable_chunked_context:
logger.warning(
f"Disabling chunked context for {pytorch_backend_config.attn_backend} backend"
)
enable_chunked_context = False
spec_config = llm_args.speculative_config
if spec_config is not None and spec_config.decoding_type == "AUTO":
from tensorrt_llm._torch.speculative import suggest_spec_config
spec_config = suggest_spec_config(max_batch_size)
if not pytorch_backend_config.disable_overlap_scheduler and spec_config is not None:
if not spec_config.spec_dec_mode.support_overlap_scheduler():
logger.warning(
f"Disable overlap scheduler for speculation mode {spec_config.spec_dec_mode.name}"
)
pytorch_backend_config.disable_overlap_scheduler = True
if mm_encoder_only:
pytorch_backend_config.mm_encoder_only = True
pytorch_backend_config.load_format = LoadFormat.VISION_ONLY
# Disable overlap scheduler for multimodal encoder-only mode
logger.warning(
"Disabling overlap scheduler for multimodal encoder-only mode. "
"The overlap scheduler is designed for generation models and is not needed "
"when only processing vision encoder inputs.")
pytorch_backend_config.disable_overlap_scheduler = True
mapping = _get_mapping(llm_args.parallel_config.to_mapping())
dist = MPIDist(mapping=mapping)
cache_transceiver_config = executor_config.cache_transceiver_config
spec_config = executor_config.speculative_config
cache_transceiver_config = None
if llm_args.cache_transceiver_config is not None:
cache_transceiver_config = PybindMirror.maybe_to_pybind(
llm_args.cache_transceiver_config)
has_draft_model_engine = False
has_spec_drafter = False
if spec_config is not None:
@@ -257,10 +314,10 @@ def create_py_executor(
# chunk_unit_size may be changed to 64 when using flash mla
attn_runtime_features = AttentionRuntimeFeatures(
chunked_prefill=executor_config.enable_chunked_context,
cache_reuse=executor_config.kv_cache_config.enable_block_reuse,
chunked_prefill=enable_chunked_context,
cache_reuse=kv_cache_config.enable_block_reuse,
has_speculative_draft_tokens=has_draft_model_engine or has_spec_drafter,
chunk_size=executor_config.max_num_tokens,
chunk_size=max_num_tokens,
)
logger.info("ATTENTION RUNTIME FEATURES: ", attn_runtime_features)
@@ -270,16 +327,16 @@ def create_py_executor(
model_engine = PyTorchModelEngine(
model_path=checkpoint_dir,
pytorch_backend_config=pytorch_backend_config,
batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_num_tokens=executor_config.max_num_tokens,
max_seq_len=executor_config.max_seq_len,
batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_num_tokens=max_num_tokens,
max_seq_len=max_seq_len,
mapping=mapping,
attn_runtime_features=attn_runtime_features,
dist=dist,
spec_config=spec_config,
lora_config=lora_config,
checkpoint_loader=executor_config.checkpoint_loader,
checkpoint_loader=checkpoint_loader,
)
if has_draft_model_engine:
@@ -292,7 +349,7 @@ def create_py_executor(
if _get_allow_chain_drafter():
use_chain_drafter = (
executor_config.guided_decoding_config is None
guided_decoding_config is None
and not pytorch_backend_config.enable_mixed_sampler
and pytorch_backend_config.attn_backend == "TRTLLM")
else:
@@ -316,17 +373,17 @@ def create_py_executor(
draft_model_engine = PyTorchModelEngine(
model_path=spec_config.speculative_model_dir,
pytorch_backend_config=draft_pytorch_backend_config,
batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_num_tokens=executor_config.max_num_tokens,
batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_num_tokens=max_num_tokens,
# Note: The draft model engine will infer its own max_seq_len.
# We'll stop drafting when we hit the max.
max_seq_len=executor_config.max_seq_len,
max_seq_len=max_seq_len,
mapping=mapping,
attn_runtime_features=attn_runtime_features,
dist=dist,
spec_config=draft_spec_config,
checkpoint_loader=executor_config.checkpoint_loader,
checkpoint_loader=checkpoint_loader,
is_draft_model=True,
drafting_loop_wrapper=drafting_loop_wrapper,
)
@@ -336,71 +393,69 @@ def create_py_executor(
else:
draft_model_engine = None
# PyTorchModelEngine modifies these fields, update them to executor_config
max_seq_len = model_engine.max_seq_len
net_max_seq_len = max_seq_len
# PyTorchModelEngine modifies these fields, update them
model_engine_max_seq_len = model_engine.max_seq_len
net_max_seq_len = model_engine_max_seq_len
if not pytorch_backend_config.disable_overlap_scheduler:
max_seq_len = model_engine.max_seq_len + 1
model_engine_max_seq_len = model_engine.max_seq_len + 1
if spec_config is not None:
max_seq_len += spec_config.max_draft_len
model_engine_max_seq_len += spec_config.max_draft_len
if spec_config is not None:
max_seq_len += get_num_extra_kv_tokens(spec_config)
max_seq_len += spec_config.max_draft_len
model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
model_engine_max_seq_len += spec_config.max_draft_len
executor_config.max_seq_len = max_seq_len
executor_config.max_num_tokens = model_engine.max_num_tokens
max_seq_len = model_engine_max_seq_len
max_num_tokens = model_engine.max_num_tokens
config = model_engine.model.model_config.pretrained_config
if is_mla(config):
if model_engine.model.model_config.enable_flash_mla:
executor_config.tokens_per_block = 64
tokens_per_block = 64
logger.info(
f"Change tokens_per_block to: {executor_config.tokens_per_block} for using FlashMLA"
f"Change tokens_per_block to: {tokens_per_block} for using FlashMLA"
)
sm_version = get_sm_version()
if executor_config.kv_cache_config.enable_block_reuse and sm_version not in [
if kv_cache_config.enable_block_reuse and sm_version not in [
90, 100, 103, 120
]:
logger.warning(
f"KV cache reuse for MLA can only be enabled on SM90/SM100/SM103/SM120, "
f"disable enable_block_reuse for SM{sm_version}")
executor_config.kv_cache_config.enable_block_reuse = False
kv_cache_config.enable_block_reuse = False
kv_cache_quant_algo = model_engine.model.model_config.quant_config.kv_cache_quant_algo
if executor_config.kv_cache_config.enable_block_reuse and not (
if kv_cache_config.enable_block_reuse and not (
kv_cache_quant_algo is None or kv_cache_quant_algo
== QuantAlgo.NO_QUANT or kv_cache_quant_algo == QuantAlgo.FP8):
logger.warning(
f"KV cache reuse for MLA can only be enabled without KV cache quantization or with FP8 quantization, "
f"disable enable_block_reuse for KV cache quant algorithm: {kv_cache_quant_algo}"
)
executor_config.kv_cache_config.enable_block_reuse = False
if executor_config.enable_chunked_context and sm_version not in [
90, 100, 103, 120
]:
kv_cache_config.enable_block_reuse = False
if enable_chunked_context and sm_version not in [90, 100, 103, 120]:
logger.warning(
"Chunked Prefill for MLA can only be enabled on SM90/SM100/SM103/SM120, "
f"disable enable_chunked_context for SM{sm_version}")
executor_config.enable_chunked_context = False
enable_chunked_context = False
model_engine.attn_runtime_features.chunked_prefill = False
if draft_model_engine is not None:
draft_model_engine.attn_runtime_features.chunked_prefill = False
if executor_config.enable_chunked_context:
chunk_unit_size = executor_config.tokens_per_block
max_attention_window = executor_config.kv_cache_config.max_attention_window
if max_attention_window and max_seq_len > min(max_attention_window):
if enable_chunked_context:
chunk_unit_size = tokens_per_block
max_attention_window = kv_cache_config.max_attention_window
if max_attention_window and model_engine_max_seq_len > min(
max_attention_window):
# maxKvStepSizeInFmha = 256
chunk_unit_size = max(256, chunk_unit_size)
logger.info(
f"ChunkUnitSize is set to {chunk_unit_size} as sliding window attention is used."
)
chunking_policy = (
executor_config.scheduler_config.context_chunking_policy
if executor_config.scheduler_config.context_chunking_policy
is not None else ContextChunkingPolicy.FIRST_COME_FIRST_SERVED)
chunking_policy = (scheduler_config.context_chunking_policy if
scheduler_config.context_chunking_policy is not None
else ContextChunkingPolicy.FIRST_COME_FIRST_SERVED)
assert chunk_unit_size is not None, "chunk_unit_size must be set"
ctx_chunk_config = ContextChunkingConfig(chunking_policy,
chunk_unit_size)
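The chunked-context branch above derives the chunk unit size from tokens_per_block and enlarges it when sliding-window attention is active. A minimal, dependency-free sketch of that rule (pick_chunk_unit_size is an illustrative name; the real logic is inlined in create_py_executor, and 256 comes from the maxKvStepSizeInFmha comment above):

# Minimal sketch of the chunk-unit-size rule shown above. Pure Python, no
# TensorRT-LLM imports; only the decision is reproduced.
from typing import Optional, Sequence

def pick_chunk_unit_size(tokens_per_block: int,
                         max_attention_window: Optional[Sequence[int]],
                         model_max_seq_len: int) -> int:
    chunk_unit_size = tokens_per_block
    if max_attention_window and model_max_seq_len > min(max_attention_window):
        # Sliding-window attention: FMHA consumes the KV cache in steps of up
        # to 256 tokens, so the chunk unit must be at least that large.
        chunk_unit_size = max(256, chunk_unit_size)
    return chunk_unit_size

print(pick_chunk_unit_size(32, None, 4096))    # 32: no sliding window configured
print(pick_chunk_unit_size(32, [1024], 4096))  # 256: window smaller than max_seq_len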
@@ -410,12 +465,11 @@ def create_py_executor(
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.GUIDED_DECODER):
guided_decoder: Optional[GuidedDecoder] = None
if executor_config.guided_decoding_config is not None:
if guided_decoding_config is not None:
if mapping.is_last_pp_rank():
kwargs = {
"guided_decoding_config":
executor_config.guided_decoding_config,
"max_num_sequences": executor_config.max_batch_size,
"guided_decoding_config": guided_decoding_config,
"max_num_sequences": max_batch_size,
"vocab_size_padded": model_engine.model.vocab_size_padded
}
if spec_config is not None:
@@ -440,17 +494,16 @@ def create_py_executor(
)
with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER):
sampler = instantiate_sampler(
model_engine,
pytorch_backend_config,
mapping,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_seq_len=executor_config.max_seq_len,
mm_encoder_only=executor_config.mm_encoder_only,
speculative_config=executor_config.speculative_config,
decoding_config=executor_config.decoding_config,
kv_cache_config=executor_config.kv_cache_config)
sampler = instantiate_sampler(model_engine,
pytorch_backend_config,
mapping,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_seq_len=max_seq_len,
mm_encoder_only=mm_encoder_only,
speculative_config=spec_config,
decoding_config=decoding_config,
kv_cache_config=kv_cache_config)
logger.info(f"Using Sampler: {type(sampler).__name__}")
if kv_connector_config is not None:
@@ -461,7 +514,7 @@ def create_py_executor(
raise NotImplementedError(
"CUDA graphs are not supported with KV connector hooks.")
if executor_config.scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
if scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
raise NotImplementedError(
"KV connector is only supported with guaranteed no evict scheduler policy."
)
@@ -510,25 +563,24 @@ def create_py_executor(
mapping=mapping,
net_max_seq_len=net_max_seq_len,
kv_connector_manager=kv_connector_manager,
max_num_tokens=executor_config.max_num_tokens,
max_beam_width=executor_config.max_beam_width,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
kv_cache_config=executor_config.kv_cache_config,
pytorch_backend_config=executor_config.pytorch_backend_config,
speculative_config=executor_config.speculative_config,
max_num_tokens=max_num_tokens,
max_beam_width=max_beam_width,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
pytorch_backend_config=pytorch_backend_config,
speculative_config=spec_config,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.INIT_KV_CACHE
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
kv_cache_creator.build_managers(resources, estimating_kv_cache)
# Originally, executor_config.max_seq_len might be changed inside build_managers and used
# below in create_py_executor_instance. Since now, we are changing
# kv_cache_creator._max_seq_len instead, restore executor_config.max_seq_len.
executor_config.max_seq_len = kv_cache_creator._max_seq_len
update_sampler_max_seq_len(executor_config.max_seq_len, sampler)
# Originally, max_seq_len might be mutated inside build_managers as field of executor config.
# Since now, we are changing kv_cache_creator._max_seq_len instead. Restore max_seq_len here.
max_seq_len = kv_cache_creator._max_seq_len
update_sampler_max_seq_len(max_seq_len, sampler)
# Resource managers for speculative decoding
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
@@ -565,17 +617,17 @@ def create_py_executor(
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_num_tokens=executor_config.max_num_tokens,
peft_cache_config=executor_config.peft_cache_config,
scheduler_config=executor_config.scheduler_config,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_num_tokens=max_num_tokens,
peft_cache_config=peft_cache_config,
scheduler_config=scheduler_config,
cache_transceiver_config=cache_transceiver_config,
)
# Modify the executor_config.peft_cache_config which might be mutated
# inside create_py_executor_instance
executor_config.peft_cache_config = py_executor.peft_cache_config
# Originally, peft_cache_config might be mutated inside
# create_py_executor_instance. Restore it here.
peft_cache_config = py_executor.peft_cache_config
if estimating_kv_cache:
assert kv_cache_creator is not None
@@ -590,13 +642,12 @@ def create_py_executor(
# Before estimating KV cache size, a minimal KV cache has been allocated using
# create_kv_cache_manager above, which caps kv_cache_creator.max_seq_len. Restoring
# the original value before creating the final KV cache.
kv_cache_creator._max_seq_len = max_seq_len
kv_cache_creator._max_seq_len = model_engine_max_seq_len
kv_cache_creator.build_managers(resources, False)
# Originally, executor_config.max_seq_len might be changed again inside build_managers
# Since now, we are changing kv_cache_creator.max_seq_len instead.
# Restore executor_config.max_seq_len which has been used in create_py_executor_instance
executor_config.max_seq_len = kv_cache_creator._max_seq_len
update_sampler_max_seq_len(executor_config.max_seq_len, sampler)
# Originally, max_seq_len might be mutated inside build_managers as field of executor config.
# Since now, we are changing kv_cache_creator._max_seq_len instead. Restore max_seq_len here.
max_seq_len = kv_cache_creator._max_seq_len
update_sampler_max_seq_len(max_seq_len, sampler)
for eng in [model_engine, draft_model_engine]:
if eng is None:
@@ -623,16 +674,16 @@ def create_py_executor(
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_num_tokens=executor_config.max_num_tokens,
peft_cache_config=executor_config.peft_cache_config,
scheduler_config=executor_config.scheduler_config,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_num_tokens=max_num_tokens,
peft_cache_config=peft_cache_config,
scheduler_config=scheduler_config,
cache_transceiver_config=cache_transceiver_config,
)
_adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
_adjust_torch_mem_fraction(pytorch_backend_config)
py_executor.start_worker()
return py_executor
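Taken together, the changes in this file stop reading from a shared ExecutorConfig and instead assemble everything create_py_executor needs directly from TorchLlmArgs. A condensed, hedged sketch of that wiring; collect_executor_inputs is a hypothetical helper name, the real code inlines these steps at the top of create_py_executor:

import os

from tensorrt_llm.llmapi.llm_args import PybindMirror, TorchLlmArgs

def collect_executor_inputs(llm_args: TorchLlmArgs):
    # Convert the pydantic configs on llm_args into their pybind counterparts,
    # as the diff above does inline.
    assert llm_args.kv_cache_config is not None
    kv_cache_config = PybindMirror.maybe_to_pybind(llm_args.kv_cache_config)
    if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
        # Deterministic mode disables KV cache reuse.
        kv_cache_config.enable_block_reuse = False
        kv_cache_config.enable_partial_reuse = False
    scheduler_config = PybindMirror.maybe_to_pybind(llm_args.scheduler_config)
    peft_cache_config = (PybindMirror.maybe_to_pybind(llm_args.peft_cache_config)
                         if llm_args.peft_cache_config is not None else None)
    cache_transceiver_config = (
        PybindMirror.maybe_to_pybind(llm_args.cache_transceiver_config)
        if llm_args.cache_transceiver_config is not None else None)
    return (kv_cache_config, scheduler_config, peft_cache_config,
            cache_transceiver_config)

Because these are now local variables, later adjustments (disabling block reuse for the FLASHINFER backends, or restoring max_seq_len after KV-cache estimation) no longer mutate a config object that other components also read.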

View File

@@ -9,7 +9,8 @@ from dataclasses import dataclass, field
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
Set, Type, TypeAlias, TypeVar, Union, get_args, get_origin)
Set, Tuple, Type, TypeAlias, TypeVar, Union, get_args,
get_origin)
import torch
import yaml
@@ -57,8 +58,7 @@ from ..models.modeling_utils import (PretrainedConfig, QuantAlgo, QuantConfig,
SpeculativeDecodingMode)
from ..sampling_params import BatchedLogitsProcessor
from .build_cache import BuildCacheConfig
from .tokenizer import (TokenizerBase, _llguidance_tokenizer_info,
_xgrammar_tokenizer_info, tokenizer_factory)
from .tokenizer import TokenizerBase, tokenizer_factory
from .utils import generate_api_docs_as_docstring, get_type_repr
# TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
@@ -1899,84 +1899,14 @@ class BaseLlmArgs(StrictBaseModel):
moe_tp_size=moe_tp_size,
moe_ep_size=moe_ep_size)
def get_executor_config(
self,
_hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
) -> _ExecutorConfig:
executor_config = _ExecutorConfig(
max_beam_width=self.max_beam_width,
scheduler_config=PybindMirror.maybe_to_pybind(
self.scheduler_config),
max_batch_size=self.max_batch_size,
max_num_tokens=self.max_num_tokens,
gather_generation_logits=self.gather_generation_logits,
fail_fast_on_attention_window_too_large=getattr(
self, 'fail_fast_on_attention_window_too_large', False),
def get_runtime_sizes(self, ) -> Tuple[int, int, int, int]:
return (
self.max_beam_width,
self.max_num_tokens,
self.max_seq_len,
self.max_batch_size,
)
if self.kv_cache_config is not None:
executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
self.kv_cache_config)
if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
# Disable KV cache reuse for deterministic mode
executor_config.kv_cache_config.enable_block_reuse = False
executor_config.kv_cache_config.enable_partial_reuse = False
if self.peft_cache_config is not None:
executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
self.peft_cache_config)
if self.decoding_config is not None:
executor_config.decoding_config = self.decoding_config
if self.guided_decoding_backend == 'xgrammar':
assert tokenizer is not None
executor_config.guided_decoding_config = _GuidedDecodingConfig(
backend=_GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
**_xgrammar_tokenizer_info(tokenizer))
elif self.guided_decoding_backend == 'llguidance':
assert tokenizer is not None
executor_config.guided_decoding_config = _GuidedDecodingConfig(
backend=_GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE,
**_llguidance_tokenizer_info(tokenizer))
elif self.guided_decoding_backend is not None:
raise ValueError(
f"Unsupported guided decoding backend {self.guided_decoding_backend}"
)
executor_config.enable_chunked_context = self.enable_chunked_prefill
executor_config.max_beam_width = self.max_beam_width
if self.cache_transceiver_config is not None:
executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
self.cache_transceiver_config)
from tensorrt_llm._torch.pyexecutor.config import update_executor_config
spec_config = self.speculative_config
max_batch_size = executor_config.max_batch_size
if spec_config is not None and spec_config.decoding_type == "AUTO":
from tensorrt_llm._torch.speculative import suggest_spec_config
spec_config = suggest_spec_config(max_batch_size)
if self.kv_cache_config is not None:
executor_config.tokens_per_block = self.kv_cache_config.tokens_per_block
update_executor_config(
executor_config,
backend=self.backend,
pytorch_backend_config=self.get_pytorch_backend_config()
if self.backend in ["pytorch", "_autodeploy"] else None,
mapping=self.parallel_config.to_mapping(),
speculative_config=spec_config,
hf_model_dir=_hf_model_dir,
max_input_len=self.max_input_len,
max_seq_len=self.max_seq_len,
checkpoint_format=None
if self.backend == "_autodeploy" else self.checkpoint_format,
checkpoint_loader=None
if self.backend == "_autodeploy" else self.checkpoint_loader)
return executor_config
class TrtLlmArgs(BaseLlmArgs):
@@ -2542,15 +2472,6 @@ class TorchLlmArgs(BaseLlmArgs):
raise ValueError("batch_wait_timeout_ms must be greater than 0")
return self
def get_executor_config(
self,
_hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
) -> _ExecutorConfig:
executor_config = super().get_executor_config(_hf_model_dir, tokenizer)
executor_config.mm_encoder_only = self.mm_encoder_only
return executor_config
# TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
def get_pytorch_backend_config(self) -> "PyTorchConfig":
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

View File

@@ -438,12 +438,16 @@ class TestTorchLlmArgs:
assert llm.args.max_seq_len == 128
assert llm.args.max_batch_size == 8
executor_config = llm.args.get_executor_config(
llm._hf_model_dir, llm.tokenizer)
assert executor_config.max_beam_width == 1
assert executor_config.max_num_tokens == 256
assert executor_config.max_seq_len == 128
assert executor_config.max_batch_size == 8
(
max_beam_width,
max_num_tokens,
max_seq_len,
max_batch_size,
) = llm.args.get_runtime_sizes()
assert max_beam_width == 1
assert max_num_tokens == 256
assert max_seq_len == 128
assert max_batch_size == 8
def test_dynamic_setattr(self):
with pytest.raises(pydantic_core._pydantic_core.ValidationError):