[None][fix] Refactoring to avoid circular import when importing torch models (#6720)

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
Author: rakib-hasan, 2025-08-11 15:00:42 -07:00 (committed via GitHub)
parent c9fe07ede6
commit 7ab8112450
33 changed files with 159 additions and 105 deletions
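
The refactor moves the lightweight LoRA definitions (`LoraConfig`, `get_default_trtllm_modules_to_hf_modules`, `get_missing_qkv_modules_from_lora_modules`, and the `use_lora` wrapper) out of `tensorrt_llm/lora_manager.py` into a new `tensorrt_llm/lora_helper.py` whose only in-package dependency is `._utils`. `use_lora` now imports the heavyweight loaders from `lora_manager` lazily at call time, and callers across the tree switch their imports to the new module. For downstream code the visible change is the import path; a minimal migration sketch (the checkpoint path below is illustrative, not taken from this diff):

```python
# Migration sketch -- assumes a TensorRT-LLM build that contains this commit.
# Old import path (still resolves, since lora_manager itself now imports
# LoraConfig from lora_helper):
#     from tensorrt_llm.lora_manager import LoraConfig
# New, lightweight import path introduced here:
from tensorrt_llm.lora_helper import LoraConfig

lora_config = LoraConfig(
    lora_dir=["/path/to/hf/lora/adapter"],  # illustrative path
    lora_ckpt_source="hf",
    max_lora_rank=64,
)
print(lora_config.missing_qkv_modules)  # [] until attn_q/attn_k/attn_v are targeted
```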

View File

@@ -33,7 +33,7 @@ The PyTorch backend provides LoRA support, allowing you to:
 ```python
 from tensorrt_llm import LLM
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.sampling_params import SamplingParams

View File

@@ -5,7 +5,7 @@ from huggingface_hub import snapshot_download
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import LoRARequest
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 def main():

View File

@@ -33,6 +33,7 @@ import sys
 # otherwise `MemoryError: std::bad_alloc` pattern error will be raised.
 import xgrammar  # noqa
+import tensorrt_llm._torch.models as torch_models
 import tensorrt_llm.functional as functional
 import tensorrt_llm.math_utils as math_utils
 import tensorrt_llm.models as models
@@ -82,6 +83,7 @@ __all__ = [
     'default_trtnet',
     'precision',
     'net_guard',
+    'torch_models',
     'Network',
     'Mapping',
     'MnnvlMemory',
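
Because `tensorrt_llm/__init__.py` now imports the torch model collection eagerly and lists it in `__all__`, it is reachable as an attribute of the top-level package. A quick check (sketch, assuming an installed build that contains this commit):

```python
# Sketch: verify the new top-level alias registered by this commit.
import tensorrt_llm

assert "torch_models" in tensorrt_llm.__all__
print(tensorrt_llm.torch_models.__name__)  # "tensorrt_llm._torch.models"
```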

View File

@@ -22,7 +22,7 @@ from ...executor.request import LoRARequest
 from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                        register_input_processor)
 from ...logger import logger
-from ...lora_manager import LoraConfig
+from ...lora_helper import LoraConfig
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig

View File

@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
-from tensorrt_llm import logger
+import tensorrt_llm.logger as trtllm_logger
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.quantization.utils.fp4_utils import (
     float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices,
@@ -743,7 +743,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
             if int(name.split(".")[0]) not in expert_ids:
                 continue
             weight_name = name.replace("weight_scale_inv", "weight")
-            logger.debug(f"Resmoothing {weight_name}")
+            trtllm_logger.logger.debug(f"Resmoothing {weight_name}")
             weight = weights[weight_name][:]
             scale = weights[name][:]
             weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(

View File

@@ -13,9 +13,9 @@ from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       load_torch_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_manager import load_torch_lora
 from tensorrt_llm.mapping import Mapping
 from ..model_config import ModelConfig

View File

@@ -27,7 +27,8 @@ from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraModelConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo
 from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2

View File

@@ -13,7 +13,7 @@ from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization import QuantAlgo

View File

@@ -10,7 +10,8 @@ import torch
 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
-from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.sampling_params import SamplingParams
 from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range

View File

@@ -36,7 +36,7 @@ from .bindings import KVCacheType
 from .functional import PositionEmbeddingType
 from .graph_rewriting import optimize
 from .logger import logger
-from .lora_manager import LoraConfig
+from .lora_helper import LoraConfig
 from .models import PretrainedConfig, PretrainedModel
 from .models.modeling_utils import SpeculativeDecodingMode, optimize_model
 from .network import Network, net_guard

View File

@@ -31,7 +31,8 @@ from tensorrt_llm.auto_parallel.cluster_info import cluster_infos
 from tensorrt_llm.bindings import KVCacheType
 from tensorrt_llm.builder import BuildConfig, Engine, build
 from tensorrt_llm.logger import logger, severity_map
-from tensorrt_llm.lora_manager import LoraConfig, LoraManager
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.models import MODEL_MAP, PretrainedConfig
 from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 from tensorrt_llm.plugin import PluginConfig, add_plugin_argument

View File

@@ -1,6 +1,11 @@
 from dataclasses import dataclass
 from typing import List, Optional
+# isort: off
+# needed before trying to import bindings to load tensorrt_libs
+import tensorrt as trt  # noqa
+# isort: on
 from tensorrt_llm.bindings import executor as tllme

View File

@@ -15,7 +15,7 @@ import torch
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 from tensorrt_llm.logger import logger, set_level
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from .._utils import mpi_world_size
 from ..bindings import executor as tllm

View File

@@ -24,7 +24,8 @@ from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
                             print_traceback_on_error)
-from ..lora_manager import LoraConfig, LoraManager
+from ..lora_helper import LoraConfig
+from ..lora_manager import LoraManager
 from ..metrics import RequestEventTiming
 from ..prompt_adapter_manager import PromptAdapterManager
 from ..runtime import ModelConfig

View File

@@ -12,7 +12,7 @@ from typing import Any, List, Optional
 import filelock
 import tensorrt_llm
-from tensorrt_llm import BuildConfig
+from tensorrt_llm.builder import BuildConfig
 from tensorrt_llm.llmapi.utils import enable_llm_debug, print_colored
 from tensorrt_llm.logger import logger

View File

@@ -19,8 +19,8 @@ from pydantic import PrivateAttr, field_validator, model_validator
 from strenum import StrEnum
 from transformers import PreTrainedTokenizerBase
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
 from .._utils import mpi_rank
 from ..auto_parallel import AutoParallelConfig, infer_cluster_config

tensorrt_llm/lora_helper.py  (new file, 101 lines)
View File

@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Dict, List, Optional

from ._utils import DictConversion


def get_missing_qkv_modules_from_lora_modules(
        lora_target_modules: List[str]) -> List[str]:
    """Get missing QKV modules from LoRA target modules.

    In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
    all disabled at the same time. However, some lora checkpoints (e.g. BART) only contain two of them,
    so we use zero tensor to fill the missing ones.
    """
    missing_qkv_modules = []
    if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
        for lora_module in ["attn_q", "attn_k", "attn_v"]:
            if lora_module not in lora_target_modules:
                missing_qkv_modules.append(lora_module)
    if any(x in lora_target_modules
           for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
        for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
            if lora_module not in lora_target_modules:
                missing_qkv_modules.append(lora_module)
    return missing_qkv_modules


def get_default_trtllm_modules_to_hf_modules():
    """Get default mapping from TensorRT-LLM module names to HuggingFace module names."""
    return {
        "attn_q": "q_proj",
        "attn_k": "k_proj",
        "attn_v": "v_proj",
        "attn_dense": "o_proj",
        "mlp_h_to_4h": "gate_proj",
        "mlp_4h_to_h": "down_proj",
        "mlp_gate": "up_proj",
        "mlp_gate_up": "gate_up_proj",
        "moe_h_to_4h": "w1",
        "moe_4h_to_h": "w2",
        "moe_gate": "w3",
        "moe_router": "gate",
    }


def use_lora(
    model,
    lora_config: "LoraConfig",
    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
    """Use LoRA with the given model and configuration.

    This function is a wrapper that delegates to the appropriate loading function
    based on the LoRA checkpoint source.
    """
    if lora_config.lora_ckpt_source == "nemo":
        from .lora_manager import load_nemo_lora
        load_nemo_lora(model, lora_config)
    elif lora_config.lora_ckpt_source == "hf":
        from .lora_manager import load_hf_lora
        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
    else:
        raise ValueError(
            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")


@dataclass
class LoraConfig(DictConversion):
    lora_dir: List[str] = field(default_factory=list)
    lora_ckpt_source: str = "hf"
    max_lora_rank: int = 64
    lora_target_modules: List[str] = field(default_factory=list)
    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
    max_loras: Optional[int] = None
    max_cpu_loras: Optional[int] = None

    def __post_init__(self):
        assert self.lora_ckpt_source in [
            "hf", "nemo"
        ], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
            )

    @property
    def missing_qkv_modules(self) -> List[str]:
        return get_missing_qkv_modules_from_lora_modules(
            self.lora_target_modules)
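
A short usage sketch exercising the pieces defined above (the adapter directory is a placeholder, not taken from this diff):

```python
# Usage sketch for tensorrt_llm.lora_helper -- assumes a build with this commit.
from tensorrt_llm.lora_helper import (LoraConfig,
                                      get_default_trtllm_modules_to_hf_modules,
                                      get_missing_qkv_modules_from_lora_modules)

config = LoraConfig(
    lora_dir=["/path/to/hf/adapter"],  # placeholder path
    lora_ckpt_source="hf",
    lora_target_modules=["attn_q", "attn_k", "attn_v", "attn_dense"],
)
# All three attention projections are targeted, so nothing needs zero-filling.
assert config.missing_qkv_modules == []

# A BART-style adapter that only ships q and v would get k zero-filled.
assert get_missing_qkv_modules_from_lora_modules(["attn_q", "attn_v"]) == ["attn_k"]

# The TRT-LLM -> HF module-name mapping is unchanged, only relocated.
assert get_default_trtllm_modules_to_hf_modules()["attn_q"] == "q_proj"
```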

View File

@@ -5,7 +5,7 @@ import re
 import tarfile
 import warnings
 from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
@@ -16,8 +16,13 @@ import yaml
 from tensorrt_llm.bindings import internal as tb_internal
-from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
+from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
 from .layers.linear import ColumnLinear
+from .lora_helper import (
+    LoraConfig,
+    get_default_trtllm_modules_to_hf_modules,
+    get_missing_qkv_modules_from_lora_modules,
+)
 from .mapping import Mapping
 from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp
@@ -232,26 +237,6 @@ def norm_dora_magnitude(
     return norm_m
-@dataclass
-class LoraConfig(DictConversion):
-    lora_dir: List[str] = field(default_factory=list)
-    lora_ckpt_source: str = "hf"
-    max_lora_rank: int = 64
-    lora_target_modules: List[str] = field(default_factory=list)
-    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
-    max_loras: int | None = None
-    max_cpu_loras: int | None = None
-
-    def __post_init__(self):
-        assert self.lora_ckpt_source in ["hf", "nemo"], (
-            f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
-        )
-
-    @property
-    def missing_qkv_modules(self) -> List[str]:
-        return LoraManager.get_missing_qkv_modules(self.lora_target_modules)
 @dataclass
 class LoraModelConfig:
     lora_target_modules: list[str]
@@ -430,23 +415,6 @@ def load_nemo_lora(model, lora_config: LoraConfig):
     lora_config.lora_target_modules = lora_loader.lora_target_modules
-def get_default_trtllm_modules_to_hf_modules():
-    return {
-        "attn_q": "q_proj",
-        "attn_k": "k_proj",
-        "attn_v": "v_proj",
-        "attn_dense": "o_proj",
-        "mlp_h_to_4h": "gate_proj",
-        "mlp_4h_to_h": "down_proj",
-        "mlp_gate": "up_proj",
-        "mlp_gate_up": "gate_up_proj",
-        "moe_h_to_4h": "w1",
-        "moe_4h_to_h": "w2",
-        "moe_gate": "w3",
-        "moe_router": "gate",
-    }
 def load_torch_hf_lora(lora_config: LoraConfig):
     """This is a shortned version of load_hf_lora that is used for torch models.
@@ -628,19 +596,6 @@ def load_hf_lora(
         ).to(torch_dtype)
-def use_lora(
-    model,
-    lora_config: LoraConfig,
-    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
-):
-    if lora_config.lora_ckpt_source == "nemo":
-        load_nemo_lora(model, lora_config)
-    elif lora_config.lora_ckpt_source == "hf":
-        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
-    else:
-        raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
 def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]:
     """Unpack model config and weights from a NeMo .nemo archive file.
@@ -762,21 +717,8 @@ class LoraManager(object):
         )
     @staticmethod
-    def get_missing_qkv_modules(lora_target_modules):
-        # In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
-        # all disabled at the same time.
-        # However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor
-        # to fill the missing ones.
-        missing_qkv_modules = []
-        if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
-            for lora_module in ["attn_q", "attn_k", "attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
-            for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        return missing_qkv_modules
+    def get_missing_qkv_modules(lora_target_modules: List[str]) -> List[str]:
+        return get_missing_qkv_modules_from_lora_modules(lora_target_modules)
     @property
     def missing_qkv_modules(self) -> List[str]:
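
Existing callers of the static method keep working because it now delegates to the helper; a sketch of the equivalence (assumes a build with this commit):

```python
# Equivalence sketch: the old entry point and the new helper agree.
from tensorrt_llm.lora_helper import get_missing_qkv_modules_from_lora_modules
from tensorrt_llm.lora_manager import LoraManager

mods = ["attn_q", "attn_v"]  # e.g. an adapter that omits attn_k
assert (LoraManager.get_missing_qkv_modules(mods)
        == get_missing_qkv_modules_from_lora_modules(mods)
        == ["attn_k"])
```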

View File

@@ -36,9 +36,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                  LanguageAdapterConfig, LayerNorm, LoraParams,
                                  PromptTuningEmbedding, RmsNorm)
 # yapf: enable
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       use_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules,
+                                      use_lora)
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel
 from tensorrt_llm.module import Module, ModuleList

View File

@@ -28,7 +28,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, LayerNormType,
 from ...layers import (Attention, AttentionMaskType, AttentionParams,
                        ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
                        LoraParams, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,

View File

@@ -21,7 +21,7 @@ from ...functional import (Tensor, is_gated_activation, non_gated_version, recv,
 from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, LayerNorm, MoeConfig,
                        PositionEmbeddingType)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization import QuantMode

View File

@@ -18,7 +18,7 @@ from ..._utils import pad_vocab_size
 from ...functional import Tensor, recv, send
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, MoeConfig, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,

View File

@@ -25,7 +25,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor,
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, FusedGatedMLP, GatedMLP,
                        PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization.functional import fused_layernorm

View File

@@ -32,9 +32,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                  ColumnLinear, Embedding, FusedGatedMLP,
                                  GatedMLP, GroupNorm, KeyValueCacheParams,
                                  LayerNorm, LoraParams, RmsNorm)
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       use_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules,
+                                      use_lora)
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
 from tensorrt_llm.models.modeling_utils import PretrainedModel, QuantConfig

View File

@@ -20,7 +20,7 @@ from ..._utils import pad_vocab_size
 from ...functional import Tensor
 from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, LayerNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,

View File

@@ -8,7 +8,7 @@ from ...functional import PositionEmbeddingType, Tensor
 from ...layers import (MLP, MOE, Attention, AttentionMaskType,
                        BlockSparseAttnParams, ColumnLinear, Embedding,
                        LayerNorm, MoeConfig, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,

View File

@@ -26,8 +26,8 @@ from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, RmsNorm, SharedMoE)
 from ...layers.moe import MOEWeightWrapper
 from ...logger import logger
-from ...lora_manager import (LoraConfig,
-                             get_default_trtllm_modules_to_hf_modules, use_lora)
+from ...lora_helper import (LoraConfig,
+                            get_default_trtllm_modules_to_hf_modules, use_lora)
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization import QuantAlgo

View File

@@ -15,7 +15,7 @@
 from typing import Optional
-from .lora_manager import LoraConfig
+from .lora_helper import LoraConfig
 from .mapping import Mapping
 from .plugin.plugin import PluginConfig

View File

@@ -17,7 +17,7 @@ from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 DataType = tensorrt_llm.bindings.DataType
 LoraModule = tensorrt_llm.bindings.LoraModule

View File

@@ -42,7 +42,7 @@ from tensorrt_llm.llmapi.llm_utils import (BuildConfig, QuantAlgo, QuantConfig,
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase, TransformersTokenizer,
                                            load_hf_tokenizer)
 from tensorrt_llm.llmapi.utils import get_total_gpu_memory
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.models.automodel import AutoConfig, AutoModelForCausalLM
 from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 from tensorrt_llm.sampling_params import (BatchedLogitsProcessor,

View File

@@ -12,7 +12,7 @@ from tensorrt_llm._tensorrt_engine import LLM
 from tensorrt_llm.executor import GenerationExecutorProxy
 from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig, SamplingParams
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models import PretrainedConfig
 from tensorrt_llm.models.llama.model import LLaMAForCausalLM

View File

@@ -4,7 +4,7 @@ import pytest
 from .test_llm import tinyllama_logits_processor_test_harness
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness
 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
 from .test_llm import _test_llm_capture_request_error

View File

@@ -25,7 +25,7 @@ from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb)
 from utils.llm_data import llm_models_root
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo