[TRTLLM-9065][chore] remove PyTorchConfig completely (#8856)

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
QI JUN authored on 2025-11-07 14:37:03 +08:00, committed by GitHub
parent b26e1617f2
commit 1c6e490894
15 changed files with 90 additions and 320 deletions
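
The recurring pattern in this diff, sketched below (identifiers taken from the changes that follow; this is not a verbatim excerpt): call sites stop materializing a PyTorchConfig through get_pytorch_backend_config() and instead read the same options directly off the TorchLlmArgs object they already hold.

# before (removed in this commit)
pytorch_backend_config = llm_args.get_pytorch_backend_config()
use_overlap = not pytorch_backend_config.disable_overlap_scheduler

# after
use_overlap = not llm_args.disable_overlap_scheduler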

View File

@ -34,10 +34,10 @@ from tqdm import tqdm
import tensorrt_llm
from tensorrt_llm import LLM as TORCH_LLM
from tensorrt_llm._tensorrt_engine import LLM as TRT_LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.bindings.executor import DecodingConfig
from tensorrt_llm.llmapi import KvCacheConfig as TRT_KvCacheConfig
from tensorrt_llm.llmapi import RequestOutput, SamplingParams
from tensorrt_llm.llmapi.llm_args import MoeConfig
logger = logging.getLogger(__name__)
@ -98,10 +98,8 @@ class TRTLLMEvalBase(TemplateLM):
pytorch_config_params = {
'cuda_graph_config': {} if use_cuda_graph else None,
"print_iter_log": False,
'moe_config': MoeConfig(backend=self.moe_backend)
}
if hasattr(PyTorchConfig, "moe_backend"):
pytorch_config_params["moe_backend"] = self.moe_backend
print(f"Info: moe_backend is set to {self.moe_backend}")
# stop words not currently supported by torch backend
self.use_stop_words = False
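
With the hasattr(PyTorchConfig, "moe_backend") guard removed, the MoE backend is configured solely through MoeConfig. A hedged sketch of how these kwargs would be consumed by the torch-backend LLM (the model path is an illustrative placeholder):

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import MoeConfig

pytorch_config_params = {
    "cuda_graph_config": {},                     # CUDA graphs with default settings
    "print_iter_log": False,
    "moe_config": MoeConfig(backend="CUTLASS"),
}
llm = LLM(model="/models/mixtral-8x7b", **pytorch_config_params)  # placeholder path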

View File

@ -175,7 +175,7 @@ class DemoLLM(LLM):
self._executor = DemoGenerationExecutor(
world_size=self.args.world_size,
tokenizer=self.tokenizer,
ad_config=self.args.get_pytorch_backend_config(),
ad_config=self.args,
)
def __del__(self):

View File

@ -403,13 +403,6 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
"""Skip tokenizer initialization in config. We do this in the AutoDeploy LLM class."""
return self
### UTILITY METHODS ############################################################################
# TODO: Remove this after the PyTorch backend is fully migrated to LlmArgs from ExecutorConfig
def get_pytorch_backend_config(self) -> "LlmArgs":
"""Return the LlmArgs (self) object."""
# TODO: can we just pass through self directly??
return type(self)(**self.to_llm_kwargs())
def to_dict(self) -> Dict:
"""Convert model to a dictionary such that cls(**self.to_dict()) == self."""
self_dict = super().to_dict()

View File

@ -326,8 +326,6 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
dist.initialize_or_skip(rank, world_size, port)
# some config
msg = "pytorch_backend_config must be an AD LlmArgs object"
assert isinstance(ad_config, LlmArgs), msg
assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"
max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
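
Since AutoDeploy's LlmArgs no longer exposes get_pytorch_backend_config(), callers pass the args object through unchanged. A sketch against the signature shown above, assuming llm_args is an AutoDeploy LlmArgs instance (tokenizer=None relies on its Optional typing):

executor = create_autodeploy_executor(ad_config=llm_args, tokenizer=None)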

View File

@ -25,7 +25,6 @@ from tensorrt_llm.mapping import CpType, Mapping
from ..attention_backend import get_sparse_attn_kv_cache_manager
from ..model_config import ModelConfig
from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from .config import PyTorchConfig
from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
@ -73,7 +72,7 @@ class KvCacheCreator:
max_seq_len: int,
max_batch_size: int,
kv_cache_config: KvCacheConfig,
pytorch_backend_config: PyTorchConfig,
llm_args: TorchLlmArgs,
speculative_config: SpeculativeConfig,
sparse_attention_config: SparseAttentionConfig,
profiling_stage_data: Optional[dict],
@ -86,7 +85,7 @@ class KvCacheCreator:
self._max_num_tokens = max_num_tokens
self._max_beam_width = max_beam_width
self._kv_connector_manager = kv_connector_manager
self._pytorch_backend_config = pytorch_backend_config
self._llm_args = llm_args
self._speculative_config = speculative_config
self._sparse_attention_config = sparse_attention_config
self._tokens_per_block = tokens_per_block
@ -248,9 +247,8 @@ class KvCacheCreator:
# estimate_max_kv_cache_tokens submits self._dummy_reqs
num_cache_blocks = 0
num_extra_tokens_per_seq = 1 # account for generated tokens
pytorch_backend_config = self._pytorch_backend_config
spec_cfg = self._speculative_config
if not pytorch_backend_config.disable_overlap_scheduler:
if not self._llm_args.disable_overlap_scheduler:
num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
if spec_cfg is not None:
num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
@ -653,7 +651,7 @@ def create_py_executor_instance(
dist,
resources,
mapping,
pytorch_backend_config,
llm_args,
ctx_chunk_config,
model_engine,
start_worker,
@ -680,7 +678,7 @@ def create_py_executor_instance(
f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
)
for key, value in pytorch_backend_config.extra_resource_managers.items():
for key, value in llm_args.extra_resource_managers.items():
if key in resources:
raise ValueError(
f"Cannot overwrite existing resource manager {key}.")
@ -805,8 +803,7 @@ def create_py_executor_instance(
drafter=drafter,
dist=dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=pytorch_backend_config.
disable_overlap_scheduler,
disable_overlap_scheduler=llm_args.disable_overlap_scheduler,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
max_draft_len=spec_config.max_draft_len
@ -842,13 +839,11 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
)
def instantiate_sampler(engine: PyTorchModelEngine,
pytorch_backend_config: PyTorchConfig, mapping: Mapping,
max_batch_size: int, max_beam_width: int,
max_seq_len: int, mm_encoder_only: bool,
speculative_config: SpeculativeConfig,
decoding_config: trtllm.DecodingConfig,
kv_cache_config: KvCacheConfig):
def instantiate_sampler(
engine: PyTorchModelEngine, llm_args: TorchLlmArgs, mapping: Mapping,
max_batch_size: int, max_beam_width: int, max_seq_len: int,
mm_encoder_only: bool, speculative_config: SpeculativeConfig,
decoding_config: trtllm.DecodingConfig, kv_cache_config: KvCacheConfig):
sampler_args = create_torch_sampler_args(
mapping,
max_seq_len=engine.max_seq_len,
@ -858,7 +853,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
decoding_mode = get_decoding_mode(decoding_config=decoding_config,
max_beam_width=max_beam_width)
if mapping.cp_config.get('cp_type') == CpType.STAR:
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
assert llm_args.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
return TorchSampler(sampler_args)
if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
):
@ -867,15 +862,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
if mm_encoder_only:
# NOTE: handle model outputs specially for mm encoder executor/engine
return EarlyStopWithMMResult()
if pytorch_backend_config.sampler_type == SamplerType.TRTLLMSampler or (
pytorch_backend_config.sampler_type == SamplerType.auto
if llm_args.sampler_type == SamplerType.TRTLLMSampler or (
llm_args.sampler_type == SamplerType.auto
and decoding_mode.isBeamSearch()):
logger.debug(f"DecodingMode: {decoding_mode.name}")
return TRTLLMSampler(engine.model,
engine.dtype,
mapping,
decoding_mode,
pytorch_backend_config.disable_overlap_scheduler,
llm_args.disable_overlap_scheduler,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
@ -937,7 +932,12 @@ def _try_infer_num_experts(model_config: ModelConfig) -> int:
return num_experts
def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
def _adjust_torch_mem_fraction():
# If true, adjust PyTorch CUDA memory fraction to correspond to the
# total GPU memory minus the statically allocated engine memory.
# If false, set the PyTorch CUDA memory fraction to 1.0.
_limit_torch_cuda_mem_fraction: bool = True
# FIXME: PyTorch only uses the garbage_collection_threshold setting
# if a memory fraction is set, cf.
# https://github.com/pytorch/pytorch/blob/cd995bfb2aac8891465809be3ce29543bd524287/c10/cuda/CUDACachingAllocator.cpp#L1357
@ -966,7 +966,7 @@ def _adjust_torch_mem_fraction(pytorch_backend_config: PyTorchConfig):
# lead PyTorch to release all unused memory before hitting the set fraction. This
# still mitigates OOM, although at a higher performance impact, because it
# effectively resets the allocator cache.
if not pytorch_backend_config._limit_torch_cuda_mem_fraction:
if not _limit_torch_cuda_mem_fraction:
return
mem_reserved = torch.cuda.memory_reserved()
mem_free, mem_total = torch.cuda.mem_get_info()
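
For context on what _adjust_torch_mem_fraction does with these readings, a minimal sketch of the idea (not the exact production formula): cap PyTorch's caching allocator at the memory still available to it, so statically allocated engine memory stays outside the fraction.

import torch

def sketch_adjust_torch_mem_fraction(limit_fraction: bool = True) -> None:
    if not limit_fraction:
        # Documented fallback: let PyTorch use the whole device.
        torch.cuda.set_per_process_memory_fraction(1.0)
        return
    mem_reserved = torch.cuda.memory_reserved()      # bytes already cached by PyTorch
    mem_free, mem_total = torch.cuda.mem_get_info()  # free / total device memory in bytes
    # Budget PyTorch to its current cache plus whatever is still free;
    # engine weights and KV cache already claimed the rest.
    torch.cuda.set_per_process_memory_fraction((mem_free + mem_reserved) / mem_total)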

View File

@ -1,142 +0,0 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
BaseCheckpointLoader
from ...llmapi.llm_args import LoadFormat, SamplerType
from ..model_config import MoeLoadBalancerConfig
from .resource_manager import BaseResourceManager
@dataclass
class PyTorchConfig:
"""
Extra arguments for the pytorch backend.
"""
# Extra resource managers to use in addition to the KV cache manager.
# Each manager's prepare_resources method is called before the forward pass,
# and update_resources() is called after the pass finishes. free_resources()
# is called when a request finishes.
# The KV cache manager is guaranteed to be invoked after all of these extra
# managers in all stages.
extra_resource_managers: Dict[str, BaseResourceManager] = field(
default_factory=dict)
# If true, use CUDA graphs for decoding. CUDA graphs are only created
# for the batch sizes in cuda_graph_batch_sizes, and are enabled for
# batches that consist of decoding requests *only* (the reason is that
# it's hard to capture a single graph with prefill requests since the
# input shapes are a function of the sequence lengths).
# Note that each CUDA graph can use up to 200 MB of extra memory.
use_cuda_graph: bool = True
cuda_graph_batch_sizes: Optional[List[int]] = None
cuda_graph_max_batch_size: int = 0
# If true, batches are rounded up to the nearest cuda_graph_batch_size.
# This is usually a net win for performance.
cuda_graph_padding_enabled: bool = False
disable_overlap_scheduler: bool = False
# If set, at most moe_max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time.
# If the number of tokens exceeds moe_max_num_tokens, the input tensors will be split into chunks and a for loop will be used.
moe_max_num_tokens: Optional[int] = None
moe_load_balancer: Optional[Union[MoeLoadBalancerConfig, dict, str]] = None
attention_dp_enable_balance: bool = False
attention_dp_time_out_iters: int = 50
attention_dp_batching_wait_iters: int = 10
max_num_tokens: int = 8192
batch_wait_timeout_ms: float = 0
# Iterations to wait before scheduling context even if token budget not reached (0 disables).
batch_wait_timeout_iters: int = 0
# Threshold ratio of max_num_tokens for token accumulation before scheduling context.
# Value range: [0, 1] (0 disables).
batch_wait_max_tokens_ratio: float = 0.0
attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS'
moe_disable_finalize_fusion: bool = False
use_low_precision_moe_combine: bool = False
sampler_type: SamplerType = SamplerType.auto
"""
The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto.
Defaults to auto, which will use TorchSampler unless BeamSearch is requested.
"""
kv_cache_dtype: str = "auto"
mamba_ssm_cache_dtype: str = "auto"
enable_iter_perf_stats: bool = False
# If true, enables per request stats per iteration
# Must also set enable_iter_perf_stats to true to get request stats
enable_iter_req_stats: bool = False
print_iter_log: bool = False
torch_compile_enabled: bool = False
torch_compile_fullgraph: bool = True
torch_compile_inductor_enabled: bool = False
torch_compile_piecewise_cuda_graph: bool = False
torch_compile_piecewise_cuda_graph_num_tokens: Optional[List[int]] = None
# When torch compile is enabled, userbuffers is enabled by default
torch_compile_enable_userbuffers: bool = True
torch_compile_max_num_streams: int = 1
# Enable autotuner only when torch compile is enabled
# TODO: revisit once it works stably in the warmup stage
enable_autotuner: bool = True
# If true, enable layerwise nvtx marker
enable_layerwise_nvtx_marker: bool = False
# How to load the model weights. By default, detect the weight type
# from the model checkpoint.
load_format: Union[str, LoadFormat] = 'auto'
# If true, enable min-latency mode. Currently only used for Llama4.
enable_min_latency: bool = False
allreduce_strategy: str = "AUTO"
# The iteration interval to create responses under the streaming mode.
# TODO: make this a per-request parameter
stream_interval: int = 1
force_dynamic_quantization: bool = False
# If true, ONLY the vision encoder part of the full model is loaded/executed.
mm_encoder_only: bool = False
# Enable extra setup to support sleep feature.
enable_sleep: bool = False
# If true, adjust PyTorch CUDA memory fraction to correspond to the
# total GPU memory minus the statically allocated engine memory.
# If false, set the PyTorch CUDA memory fraction to 1.0.
_limit_torch_cuda_mem_fraction: bool = True
def _construct_checkpoint_loader(
backend: str, checkpoint_loader: Optional[BaseCheckpointLoader],
checkpoint_format: Optional[str]) -> Optional[BaseCheckpointLoader]:
if backend == "_autodeploy":
return None
from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
BaseCheckpointLoader
from tensorrt_llm._torch.models.modeling_utils import (
get_checkpoint_weight_loader, get_config_loader)
if checkpoint_loader is None:
checkpoint_weight_loader = get_checkpoint_weight_loader(
checkpoint_format)()
config_loader = get_config_loader(checkpoint_format)()
checkpoint_loader = BaseCheckpointLoader.get(
checkpoint_format=checkpoint_format,
weight_loader=checkpoint_weight_loader,
weight_mapper=None,
config_loader=config_loader)
return checkpoint_loader
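
The options deleted with this file are not lost: each corresponds to a TorchLlmArgs field or nested config referenced elsewhere in this diff. A hedged construction sketch (the model path is a placeholder; only fields visible in this commit are used):

from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MoeConfig
from tensorrt_llm.llmapi.llm_args import SamplerType, TorchLlmArgs

llm_args = TorchLlmArgs(
    model="/models/llama",                                     # placeholder checkpoint path
    cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4]),  # replaces use_cuda_graph et al.
    moe_config=MoeConfig(backend="CUTLASS"),                   # replaces moe_backend
    kv_cache_config=KvCacheConfig(dtype="auto"),               # replaces kv_cache_dtype
    attn_backend="TRTLLM",
    sampler_type=SamplerType.auto,
    disable_overlap_scheduler=False,
)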

View File

@ -54,13 +54,12 @@ from ..speculative.utils import SpecDecodingTensor
from ..utils import (get_model_extra_attrs,
set_per_request_piecewise_cuda_graph_flag,
set_torch_compiling, with_model_extra_attrs)
from .config import _construct_checkpoint_loader
from .config_utils import is_mla
from .cuda_graph_runner import CUDAGraphRunner
from .guided_decoder import CapturableGuidedDecoder
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .llm_request import get_draft_token_length
from .model_loader import ModelLoader
from .model_loader import ModelLoader, _construct_checkpoint_loader
from .resource_manager import (BaseResourceManager, KVCacheManager,
ResourceManager, ResourceManagerType)
from .sampler import SampleStateTensors

View File

@ -6,6 +6,8 @@ from typing import Callable, Optional, Tuple
import torch
from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
BaseCheckpointLoader
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from tensorrt_llm.logger import logger
@ -14,13 +16,13 @@ from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantAlgo
from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2
from ...llmapi.llm_args import LoadFormat
from ..model_config import ModelConfig
from ..models import AutoModelForCausalLM
from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
from ..models.modeling_utils import MetaInitMode, timing
from ..modules.fused_moe.moe_load_balancer import (
MoeLoadBalancer, maybe_create_moe_load_balancer)
from .config import LoadFormat
_KV_CACHE_MAP = {
"fp8": QuantAlgo.FP8.value,
@ -63,7 +65,7 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig,
if not valid_pyt_quant:
raise ValueError(
"Overriding KV cache quantization with an invalid type "
f'"PyTorchConfig.kv_cache_dtype="{pyt_kv_cache_dtype}" '
f'"llm_args.KvCacheConfig.dtype="{pyt_kv_cache_dtype}" '
f'Accepted types are "{_VALID_KV_CACHE_DTYPES}".')
# If we get to this point we have a valid quantization setting, but if
@ -71,7 +73,7 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig,
if kv_cache_quant is not None and mapped_pyt_quant != kv_cache_quant:
raise RuntimeError(
"Attempting to override KV cache quantization "
f'"{kv_cache_quant}" with PyTorchConfig.kv_cache_dtype='
f'"{kv_cache_quant}" with llm_args.KvCacheConfig.dtype='
f'"{pyt_kv_cache_dtype}". You cannot override a checkpoint with a '
"pre-quantized KV cache that doesn't match.")
@ -151,6 +153,31 @@ def get_rank_model_storage(model):
return total_bytes
def _construct_checkpoint_loader(
backend: str, checkpoint_loader: Optional[BaseCheckpointLoader],
checkpoint_format: Optional[str]) -> Optional[BaseCheckpointLoader]:
if backend == "_autodeploy":
return None
from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
BaseCheckpointLoader
from tensorrt_llm._torch.models.modeling_utils import (
get_checkpoint_weight_loader, get_config_loader)
if checkpoint_loader is None:
checkpoint_weight_loader = get_checkpoint_weight_loader(
checkpoint_format)()
config_loader = get_config_loader(checkpoint_format)()
checkpoint_loader = BaseCheckpointLoader.get(
checkpoint_format=checkpoint_format,
weight_loader=checkpoint_weight_loader,
weight_mapper=None,
config_loader=config_loader)
return checkpoint_loader
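
A usage sketch for the relocated helper, matching the signature above (the checkpoint format value is an assumption):

from tensorrt_llm._torch.pyexecutor.model_loader import _construct_checkpoint_loader

# Returns None for the "_autodeploy" backend; otherwise builds a loader
# for the given checkpoint format when none was supplied.
loader = _construct_checkpoint_loader(
    backend="pytorch",
    checkpoint_loader=None,
    checkpoint_format="HF",  # assumed format name
)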
class ModelLoader:
"""
Handles the loading, configuration, and weight initialization of a PyTorch model.

View File

@ -33,7 +33,6 @@ from ..virtual_memory import scope as virtual_memory_scope
from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
create_py_executor_instance, instantiate_sampler, is_mla,
validate_feature_combination)
from .config import PyTorchConfig
from .config_utils import is_mla
from .guided_decoder import CapturableGuidedDecoder, GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
@ -225,10 +224,6 @@ def create_py_executor(
lora_config = llm_args.lora_config
kv_connector_config = llm_args.kv_connector_config
pytorch_backend_config = llm_args.get_pytorch_backend_config()
if pytorch_backend_config is None:
pytorch_backend_config = PyTorchConfig()
scheduler_config = llm_args.scheduler_config
# Since peft_cache_config may be subject to change, avoid these changes propagate back
@ -257,23 +252,19 @@ def create_py_executor(
) = llm_args.get_runtime_sizes()
tokens_per_block = kv_cache_config.tokens_per_block
if pytorch_backend_config.attn_backend == "VANILLA":
if llm_args.attn_backend == "VANILLA":
tokens_per_block = max_num_tokens
if pytorch_backend_config.attn_backend in [
"FLASHINFER", "FLASHINFER_STAR_ATTENTION"
]:
if llm_args.attn_backend in ["FLASHINFER", "FLASHINFER_STAR_ATTENTION"]:
# Workaround for flashinfer and star attention
if kv_cache_config.enable_block_reuse:
logger.warning(
f"Disabling block reuse for {pytorch_backend_config.attn_backend} backend"
)
f"Disabling block reuse for {llm_args.attn_backend} backend")
kv_cache_config.enable_block_reuse = False
if pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION" and enable_chunked_context:
if llm_args.attn_backend == "FLASHINFER_STAR_ATTENTION" and enable_chunked_context:
logger.warning(
f"Disabling chunked context for {pytorch_backend_config.attn_backend} backend"
)
f"Disabling chunked context for {llm_args.attn_backend} backend")
enable_chunked_context = False
spec_config = llm_args.speculative_config
@ -281,28 +272,22 @@ def create_py_executor(
from tensorrt_llm._torch.speculative import suggest_spec_config
spec_config = suggest_spec_config(max_batch_size)
if not pytorch_backend_config.disable_overlap_scheduler and spec_config is not None:
if not llm_args.disable_overlap_scheduler and spec_config is not None:
if not spec_config.spec_dec_mode.support_overlap_scheduler():
logger.warning(
f"Disable overlap scheduler for speculation mode {spec_config.spec_dec_mode.name}"
)
# TODO(qijun): clean up pytorch_backend_config later
pytorch_backend_config.disable_overlap_scheduler = True
llm_args.disable_overlap_scheduler = True
if mm_encoder_only:
# TODO(qijun): clean up pytorch_backend_config later
pytorch_backend_config.mm_encoder_only = True
llm_args.mm_encoder_only = True
llm_args.disable_overlap_scheduler = True
# Disable overlap scheduler for multimodal encoder-only mode
logger.warning(
"Disabling overlap scheduler for multimodal encoder-only mode. "
"The overlap scheduler is designed for generation models and is not needed "
"when only processing vision encoder inputs.")
pytorch_backend_config.disable_overlap_scheduler = True
llm_args.mm_encoder_only = True
llm_args.disable_overlap_scheduler = True
mapping = _get_mapping(llm_args.parallel_config.to_mapping())
if mpi_disabled():
@ -311,7 +296,7 @@ def create_py_executor(
dist = MPIDist(mapping=mapping)
vm_pools = {}
enable_sleep = pytorch_backend_config.enable_sleep
enable_sleep = llm_args.enable_sleep
cache_transceiver_config = llm_args.cache_transceiver_config
@ -357,19 +342,17 @@ def create_py_executor(
spec_config=spec_config,
)
validate_feature_combination(llm_args, model_engine,
pytorch_backend_config.sampler_type)
validate_feature_combination(llm_args, model_engine, llm_args.sampler_type)
if has_draft_model_engine:
with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_DRAFT,
RestoreMode.PINNED):
draft_spec_config = copy.copy(spec_config)
use_chain_drafter = (
guided_decoding_config is None
and draft_spec_config._allow_chain_drafter
and draft_spec_config._allow_greedy_draft_tokens
and pytorch_backend_config.attn_backend == "TRTLLM")
use_chain_drafter = (guided_decoding_config is None
and draft_spec_config._allow_chain_drafter and
draft_spec_config._allow_greedy_draft_tokens
and llm_args.attn_backend == "TRTLLM")
logger.debug(f"USE CHAIN DRAFTER: {use_chain_drafter}")
if use_chain_drafter:
@ -384,11 +367,8 @@ def create_py_executor(
else:
drafting_loop_wrapper = None
# TODO(qijun): clean up pytorch_backend_config later
draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
draft_llm_args = copy.copy(llm_args)
if spec_config.load_format == "dummy":
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
draft_llm_args.load_format = LoadFormat.DUMMY
draft_model_engine = PyTorchModelEngine(
@ -413,7 +393,7 @@ def create_py_executor(
# PyTorchModelEngine modifies these fields, update them
model_engine_max_seq_len = model_engine.max_seq_len
net_max_seq_len = model_engine_max_seq_len
if not pytorch_backend_config.disable_overlap_scheduler:
if not llm_args.disable_overlap_scheduler:
model_engine_max_seq_len = model_engine.max_seq_len + 1
if spec_config is not None:
model_engine_max_seq_len += spec_config.max_total_draft_tokens
@ -514,7 +494,7 @@ def create_py_executor(
with allocation_scope(ExecutorMemoryType.SAMPLER, RestoreMode.PINNED):
sampler = instantiate_sampler(model_engine,
pytorch_backend_config,
llm_args,
mapping,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
@ -592,7 +572,7 @@ def create_py_executor(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
kv_cache_config=kv_cache_config,
pytorch_backend_config=pytorch_backend_config,
llm_args=llm_args,
speculative_config=spec_config,
profiling_stage_data=profiling_stage_data,
sparse_attention_config=sparse_attention_config,
@ -634,7 +614,7 @@ def create_py_executor(
dist=dist,
resources=resources,
mapping=mapping,
pytorch_backend_config=pytorch_backend_config,
llm_args=llm_args,
ctx_chunk_config=ctx_chunk_config,
model_engine=model_engine,
start_worker=False,
@ -681,7 +661,7 @@ def create_py_executor(
if eng is None:
continue
if eng.attn_metadata is not None:
if pytorch_backend_config.use_cuda_graph:
if llm_args.cuda_graph_config is not None:
eng._release_cuda_graphs()
eng.attn_metadata = None
@ -691,7 +671,7 @@ def create_py_executor(
dist=dist,
resources=resources,
mapping=mapping,
pytorch_backend_config=pytorch_backend_config,
llm_args=llm_args,
ctx_chunk_config=ctx_chunk_config,
model_engine=model_engine,
start_worker=False,
@ -712,7 +692,7 @@ def create_py_executor(
virtual_memory_pools=vm_pools,
)
_adjust_torch_mem_fraction(pytorch_backend_config)
_adjust_torch_mem_fraction()
py_executor.start_worker()
return py_executor

View File

@ -8,7 +8,6 @@ from pydantic import (BaseModel, Field, PositiveFloat, field_validator,
model_validator)
import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.llmapi import (BatchingType, CapacitySchedulerPolicy,
ContextChunkingPolicy, DynamicBatchConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig,
@ -126,7 +125,7 @@ class PerformanceOptions:
return config
def get_pytorch_perf_config(self) -> PyTorchConfig:
def get_pytorch_perf_config(self):
return self.pytorch_config
def get_autodeploy_perf_config(self) -> Dict:

View File

@ -175,7 +175,7 @@ class BaseWorker(GenerationExecutor):
create_autodeploy_executor
create_executor = create_autodeploy_executor
assert isinstance(self.llm_args, ADLlmArgs)
args["ad_config"] = self.llm_args.get_pytorch_backend_config()
args["ad_config"] = self.llm_args
args["tokenizer"] = self._tokenizer
else:
raise ValueError(f"Unsupported backend config: {self._backend}")
@ -184,7 +184,7 @@ class BaseWorker(GenerationExecutor):
self.mapping = self.llm_args.parallel_config.to_mapping()
self.checkpoint_loader = None
if self._backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.config import \
from tensorrt_llm._torch.pyexecutor.model_loader import \
_construct_checkpoint_loader
self.checkpoint_loader = _construct_checkpoint_loader(
self.llm_args.backend, self.llm_args.checkpoint_loader,

View File

@ -8,9 +8,8 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
Set, Tuple, Type, TypeAlias, TypeVar, Union, get_args,
get_origin)
from typing import (Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple,
Type, TypeAlias, TypeVar, Union, get_args, get_origin)
import torch
import yaml
@ -25,9 +24,6 @@ from tensorrt_llm.lora_helper import (LoraConfig,
from .._utils import mpi_rank
if TYPE_CHECKING:
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
# yapf: disable
# isort: off
from ..bindings.executor import (BatchingType as _BatchingType,
@ -2838,79 +2834,6 @@ class TorchLlmArgs(BaseLlmArgs):
executor_config.mm_encoder_only = self.mm_encoder_only
return executor_config
# TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
def get_pytorch_backend_config(self) -> "PyTorchConfig":
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
return PyTorchConfig(
extra_resource_managers=self.extra_resource_managers,
use_cuda_graph=bool(self.cuda_graph_config is not None),
cuda_graph_batch_sizes=self.cuda_graph_config.batch_sizes
if self.cuda_graph_config else
CudaGraphConfig.model_fields['batch_sizes'].default,
cuda_graph_max_batch_size=self.cuda_graph_config.max_batch_size
if self.cuda_graph_config else
CudaGraphConfig.model_fields['max_batch_size'].default,
cuda_graph_padding_enabled=self.cuda_graph_config.enable_padding
if self.cuda_graph_config else
CudaGraphConfig.model_fields['enable_padding'].default,
disable_overlap_scheduler=self.disable_overlap_scheduler,
moe_max_num_tokens=self.moe_config.max_num_tokens,
moe_load_balancer=self.moe_config.load_balancer,
attn_backend=self.attn_backend,
moe_backend=self.moe_config.backend,
use_low_precision_moe_combine=self.moe_config.
use_low_precision_moe_combine,
sampler_type=self.sampler_type,
kv_cache_dtype=self.kv_cache_config.dtype,
mamba_ssm_cache_dtype=self.kv_cache_config.mamba_ssm_cache_dtype,
enable_iter_perf_stats=self.enable_iter_perf_stats,
enable_iter_req_stats=self.enable_iter_req_stats,
print_iter_log=self.print_iter_log,
torch_compile_enabled=bool(self.torch_compile_config is not None),
torch_compile_fullgraph=self.torch_compile_config.enable_fullgraph
if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['enable_fullgraph'].default,
torch_compile_inductor_enabled=self.torch_compile_config.
enable_inductor if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['enable_inductor'].default,
torch_compile_piecewise_cuda_graph=self.torch_compile_config.
enable_piecewise_cuda_graph
if self.torch_compile_config is not None else TorchCompileConfig.
model_fields['enable_piecewise_cuda_graph'].default,
torch_compile_piecewise_cuda_graph_num_tokens=self.
torch_compile_config.capture_num_tokens
if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['capture_num_tokens'].default,
torch_compile_enable_userbuffers=self.torch_compile_config.
enable_userbuffers if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['enable_userbuffers'].default,
torch_compile_max_num_streams=self.torch_compile_config.
max_num_streams if self.torch_compile_config is not None else
TorchCompileConfig.model_fields['max_num_streams'].default,
enable_autotuner=self.enable_autotuner,
enable_layerwise_nvtx_marker=self.enable_layerwise_nvtx_marker,
load_format=self.load_format,
enable_min_latency=self.enable_min_latency,
moe_disable_finalize_fusion=self.moe_config.disable_finalize_fusion,
stream_interval=self.stream_interval,
force_dynamic_quantization=self.force_dynamic_quantization,
allreduce_strategy=self.allreduce_strategy,
attention_dp_enable_balance=bool(
self.attention_dp_config is not None
and self.attention_dp_config.enable_balance),
attention_dp_time_out_iters=self.attention_dp_config.timeout_iters
if self.attention_dp_config is not None else
AttentionDpConfig.model_fields['timeout_iters'].default,
attention_dp_batching_wait_iters=self.attention_dp_config.
batching_wait_iters if self.attention_dp_config is not None else
AttentionDpConfig.model_fields['batching_wait_iters'].default,
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
batch_wait_timeout_iters=self.batch_wait_timeout_iters,
batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
enable_sleep=self.enable_sleep,
)
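
With the builder above gone, the same information is read straight from TorchLlmArgs; a brief sketch using only fields that appear in this diff:

use_cuda_graph = llm_args.cuda_graph_config is not None   # former use_cuda_graph flag
attn_backend = llm_args.attn_backend
sampler_type = llm_args.sampler_type
moe_backend = llm_args.moe_config.backend
overlap_enabled = not llm_args.disable_overlap_scheduler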
def update_llm_args_with_extra_dict(
llm_args: Dict,

View File

@ -19,9 +19,9 @@ import torch
from defs.conftest import get_sm_version, is_sm_100f
from tensorrt_llm import LLM
from tensorrt_llm._torch.model_config import MoeLoadBalancerConfig
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
EagleDecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,

View File

@ -65,12 +65,6 @@ def test_free_mem_ratio_validation():
InferenceOptimizer(None, get_transform_config(1.1))
def test_get_pytorch_backend_config():
"""Test that get_pytorch_backend_config returns self."""
args = LlmArgs(model="test-model")
assert args.get_pytorch_backend_config() == args
# ================================
# Config Flow Tests
# ================================

View File

@ -20,7 +20,6 @@ from tensorrt_llm._torch.models.checkpoints.hf.llama4_weight_mapper import \
Llama4HfWeightMapper
from tensorrt_llm._torch.models.modeling_llama import \
Llama4ForConditionalGeneration
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
@ -158,8 +157,9 @@ class TestLlama4MinLatency(unittest.TestCase):
with torch.device(device), default_dtype(dtype):
model_config = ModelConfig(pretrained_config=llama_config,
quant_config=quant_config)
model_config.pytorch_backend_config = PyTorchConfig(
enable_min_latency=enable_min_latency)
model_config.enable_min_latency = enable_min_latency
# TODO: enable llama4 min latency test
model_config.enable_min_latency = False
llama = Llama4ForConditionalGeneration(model_config)
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
@ -291,8 +291,9 @@ class TestLlama4MinLatency(unittest.TestCase):
model_config = ModelConfig(pretrained_config=llama_config,
attn_backend=attention_backend)
model_config.pytorch_backend_config = PyTorchConfig(
enable_min_latency=enable_min_latency)
model_config.enable_min_latency = enable_min_latency
# TODO: enable llama4 min latency test
model_config.enable_min_latency = False
llama = Llama4ForConditionalGeneration(model_config)
weight_mapper = Llama4HfWeightMapper()
weight_mapper.init_model_and_config(llama, model_config)