[None][fix] Refactoring to avoid a circular import when importing torch models (#6720)

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
rakib-hasan 2025-08-11 15:00:42 -07:00 committed by GitHub
parent c9fe07ede6
commit 7ab8112450
33 changed files with 159 additions and 105 deletions
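
The gist of the change: `LoraConfig` and a handful of pure helpers move out of `tensorrt_llm/lora_manager.py` (which pulls in bindings and layer code) into a new dependency-light `tensorrt_llm/lora_helper.py`, and the `use_lora` wrapper defers its `lora_manager` imports into the function body. A minimal sketch of the pattern, with hypothetical flat module names standing in for the real package:

```python
# lora_helper.py -- leaf module: standard-library imports only, so any other
# module in the package can import it without creating a cycle.
from dataclasses import dataclass, field
from typing import List


@dataclass
class LoraConfig:
    lora_dir: List[str] = field(default_factory=list)
    lora_ckpt_source: str = "hf"


def use_lora(model, lora_config: LoraConfig):
    # Deferred import: the heavy module is loaded only when use_lora is
    # actually called, i.e. after package initialization has finished.
    from lora_manager import load_hf_lora
    load_hf_lora(model, lora_config)
```

Callers that only need the config type now import the leaf module, so the import graph stays acyclic even though `lora_manager` itself still imports models and bindings.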

@@ -33,7 +33,7 @@ The PyTorch backend provides LoRA support, allowing you to:
 ```python
 from tensorrt_llm import LLM
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.sampling_params import SamplingParams

@@ -5,7 +5,7 @@ from huggingface_hub import snapshot_download
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import LoRARequest
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig

 def main():

@@ -33,6 +33,7 @@ import sys
 # otherwise `MemoryError: std::bad_alloc` pattern error will be raised.
 import xgrammar  # noqa
+import tensorrt_llm._torch.models as torch_models
 import tensorrt_llm.functional as functional
 import tensorrt_llm.math_utils as math_utils
 import tensorrt_llm.models as models
@@ -82,6 +83,7 @@ __all__ = [
     'default_trtnet',
     'precision',
     'net_guard',
+    'torch_models',
     'Network',
     'Mapping',
     'MnnvlMemory',
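
With the `__init__.py` hunks above, the torch models subpackage is imported eagerly during package init and re-exported. Assuming an installed build, it becomes reachable from the top level:

```python
import tensorrt_llm

# `tensorrt_llm._torch.models` is imported during package init and listed in
# __all__, so it is now addressable via the top-level alias:
print(tensorrt_llm.torch_models)
```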

@@ -22,7 +22,7 @@ from ...executor.request import LoRARequest
 from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                        register_input_processor)
 from ...logger import logger
-from ...lora_manager import LoraConfig
+from ...lora_helper import LoraConfig
 from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig

@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
-from tensorrt_llm import logger
+import tensorrt_llm.logger as trtllm_logger
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.quantization.utils.fp4_utils import (
     float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices,
@@ -743,7 +743,7 @@ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
         if int(name.split(".")[0]) not in expert_ids:
             continue
         weight_name = name.replace("weight_scale_inv", "weight")
-        logger.debug(f"Resmoothing {weight_name}")
+        trtllm_logger.logger.debug(f"Resmoothing {weight_name}")
         weight = weights[weight_name][:]
         scale = weights[name][:]
         weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
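
A note on the logger hunks: binding the submodule directly is what makes this safe while the package is only partially imported. `from tensorrt_llm import logger` needs the package root to have already bound a `logger` attribute, which is not guaranteed mid-initialization, whereas `import tensorrt_llm.logger as trtllm_logger` resolves the submodule itself. A small sketch of the two styles:

```python
# Robust during circular imports: binds the logger submodule directly and
# does not require tensorrt_llm/__init__.py to have finished executing.
import tensorrt_llm.logger as trtllm_logger

trtllm_logger.logger.debug("resmoothing weights")

# Fragile equivalent: `from tensorrt_llm import logger` raises ImportError
# if the partially-initialized package has not bound `logger` yet.
```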

@@ -13,9 +13,9 @@ from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       load_torch_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_manager import load_torch_lora
 from tensorrt_llm.mapping import Mapping
 from ..model_config import ModelConfig

@@ -27,7 +27,8 @@ from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraModelConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo
 from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2

@@ -13,7 +13,7 @@ from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization import QuantAlgo

@@ -10,7 +10,8 @@ import torch
 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
-from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.sampling_params import SamplingParams
 from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range

@@ -36,7 +36,7 @@ from .bindings import KVCacheType
 from .functional import PositionEmbeddingType
 from .graph_rewriting import optimize
 from .logger import logger
-from .lora_manager import LoraConfig
+from .lora_helper import LoraConfig
 from .models import PretrainedConfig, PretrainedModel
 from .models.modeling_utils import SpeculativeDecodingMode, optimize_model
 from .network import Network, net_guard

@@ -31,7 +31,8 @@ from tensorrt_llm.auto_parallel.cluster_info import cluster_infos
 from tensorrt_llm.bindings import KVCacheType
 from tensorrt_llm.builder import BuildConfig, Engine, build
 from tensorrt_llm.logger import logger, severity_map
-from tensorrt_llm.lora_manager import LoraConfig, LoraManager
+from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.models import MODEL_MAP, PretrainedConfig
 from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 from tensorrt_llm.plugin import PluginConfig, add_plugin_argument

@@ -1,6 +1,11 @@
 from dataclasses import dataclass
 from typing import List, Optional

+# isort: off
+# needed before trying to import bindings to load tensorrt_libs
+import tensorrt as trt  # noqa
+# isort: on
 from tensorrt_llm.bindings import executor as tllme
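
The isort-guarded hunk above encodes a load-order requirement: importing `tensorrt` first loads the bundled TensorRT shared libraries, so the compiled `tensorrt_llm` bindings imported next can resolve their native dependencies. As a standalone illustration of the same side-effect import:

```python
# isort: off
# Side-effect import only: pulling in `tensorrt` loads tensorrt_libs into
# the process before any compiled bindings are imported.
import tensorrt as trt  # noqa
# isort: on
from tensorrt_llm.bindings import executor as tllme
```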

@@ -15,7 +15,7 @@ import torch
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 from tensorrt_llm.logger import logger, set_level
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from .._utils import mpi_world_size
 from ..bindings import executor as tllm

@@ -24,7 +24,8 @@ from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
                             print_traceback_on_error)
-from ..lora_manager import LoraConfig, LoraManager
+from ..lora_helper import LoraConfig
+from ..lora_manager import LoraManager
 from ..metrics import RequestEventTiming
 from ..prompt_adapter_manager import PromptAdapterManager
 from ..runtime import ModelConfig

@@ -12,7 +12,7 @@ from typing import Any, List, Optional
 import filelock
 import tensorrt_llm
-from tensorrt_llm import BuildConfig
+from tensorrt_llm.builder import BuildConfig
 from tensorrt_llm.llmapi.utils import enable_llm_debug, print_colored
 from tensorrt_llm.logger import logger

@@ -19,8 +19,8 @@ from pydantic import PrivateAttr, field_validator, model_validator
 from strenum import StrEnum
 from transformers import PreTrainedTokenizerBase
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
 from .._utils import mpi_rank
 from ..auto_parallel import AutoParallelConfig, infer_cluster_config

tensorrt_llm/lora_helper.py (new file, 101 lines)

@@ -0,0 +1,101 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+from ._utils import DictConversion
+
+
+def get_missing_qkv_modules_from_lora_modules(
+        lora_target_modules: List[str]) -> List[str]:
+    """Get missing QKV modules from LoRA target modules.
+
+    In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
+    all disabled at the same time. However, some lora checkpoints (e.g. BART) only contain two of them,
+    so we use zero tensor to fill the missing ones.
+    """
+    missing_qkv_modules = []
+    if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
+        for lora_module in ["attn_q", "attn_k", "attn_v"]:
+            if lora_module not in lora_target_modules:
+                missing_qkv_modules.append(lora_module)
+    if any(x in lora_target_modules
+           for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
+        for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
+            if lora_module not in lora_target_modules:
+                missing_qkv_modules.append(lora_module)
+    return missing_qkv_modules
+
+
+def get_default_trtllm_modules_to_hf_modules():
+    """Get default mapping from TensorRT-LLM module names to HuggingFace module names."""
+    return {
+        "attn_q": "q_proj",
+        "attn_k": "k_proj",
+        "attn_v": "v_proj",
+        "attn_dense": "o_proj",
+        "mlp_h_to_4h": "gate_proj",
+        "mlp_4h_to_h": "down_proj",
+        "mlp_gate": "up_proj",
+        "mlp_gate_up": "gate_up_proj",
+        "moe_h_to_4h": "w1",
+        "moe_4h_to_h": "w2",
+        "moe_gate": "w3",
+        "moe_router": "gate",
+    }
+
+
+def use_lora(
+    model,
+    lora_config: "LoraConfig",
+    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
+):
+    """Use LoRA with the given model and configuration.
+
+    This function is a wrapper that delegates to the appropriate loading function
+    based on the LoRA checkpoint source.
+    """
+    if lora_config.lora_ckpt_source == "nemo":
+        from .lora_manager import load_nemo_lora
+        load_nemo_lora(model, lora_config)
+    elif lora_config.lora_ckpt_source == "hf":
+        from .lora_manager import load_hf_lora
+        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
+    else:
+        raise ValueError(
+            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
+
+
+@dataclass
+class LoraConfig(DictConversion):
+    lora_dir: List[str] = field(default_factory=list)
+    lora_ckpt_source: str = "hf"
+    max_lora_rank: int = 64
+    lora_target_modules: List[str] = field(default_factory=list)
+    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
+    max_loras: Optional[int] = None
+    max_cpu_loras: Optional[int] = None
+
+    def __post_init__(self):
+        assert self.lora_ckpt_source in [
+            "hf", "nemo"
+        ], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
+            )
+
+    @property
+    def missing_qkv_modules(self) -> List[str]:
+        return get_missing_qkv_modules_from_lora_modules(
+            self.lora_target_modules)
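
Because `lora_helper` depends only on `._utils`, it can be exercised without the compiled bindings. A quick sketch of the QKV-filling behavior (the directory path is a placeholder):

```python
from tensorrt_llm.lora_helper import (
    LoraConfig, get_missing_qkv_modules_from_lora_modules)

# A checkpoint that provides only Q and V adapters: K is reported missing
# and is later filled with a zero tensor at load time.
assert get_missing_qkv_modules_from_lora_modules(["attn_q", "attn_v"]) == ["attn_k"]

# The same logic is exposed as a property on LoraConfig.
cfg = LoraConfig(lora_dir=["/path/to/lora"],
                 lora_target_modules=["attn_q", "attn_v"])
assert cfg.missing_qkv_modules == ["attn_k"]
```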

@@ -5,7 +5,7 @@ import re
 import tarfile
 import warnings
 from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
@@ -16,8 +16,13 @@ import yaml
 from tensorrt_llm.bindings import internal as tb_internal
-from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
+from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
 from .layers.linear import ColumnLinear
+from .lora_helper import (
+    LoraConfig,
+    get_default_trtllm_modules_to_hf_modules,
+    get_missing_qkv_modules_from_lora_modules,
+)
 from .mapping import Mapping
 from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp
@@ -232,26 +237,6 @@ def norm_dora_magnitude(
     return norm_m

-@dataclass
-class LoraConfig(DictConversion):
-    lora_dir: List[str] = field(default_factory=list)
-    lora_ckpt_source: str = "hf"
-    max_lora_rank: int = 64
-    lora_target_modules: List[str] = field(default_factory=list)
-    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
-    max_loras: int | None = None
-    max_cpu_loras: int | None = None
-
-    def __post_init__(self):
-        assert self.lora_ckpt_source in ["hf", "nemo"], (
-            f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
-        )
-
-    @property
-    def missing_qkv_modules(self) -> List[str]:
-        return LoraManager.get_missing_qkv_modules(self.lora_target_modules)
-
 @dataclass
 class LoraModelConfig:
     lora_target_modules: list[str]
@@ -430,23 +415,6 @@ def load_nemo_lora(model, lora_config: LoraConfig):
     lora_config.lora_target_modules = lora_loader.lora_target_modules

-def get_default_trtllm_modules_to_hf_modules():
-    return {
-        "attn_q": "q_proj",
-        "attn_k": "k_proj",
-        "attn_v": "v_proj",
-        "attn_dense": "o_proj",
-        "mlp_h_to_4h": "gate_proj",
-        "mlp_4h_to_h": "down_proj",
-        "mlp_gate": "up_proj",
-        "mlp_gate_up": "gate_up_proj",
-        "moe_h_to_4h": "w1",
-        "moe_4h_to_h": "w2",
-        "moe_gate": "w3",
-        "moe_router": "gate",
-    }
-
 def load_torch_hf_lora(lora_config: LoraConfig):
     """This is a shortened version of load_hf_lora that is used for torch models.
@@ -628,19 +596,6 @@ def load_hf_lora(
     ).to(torch_dtype)

-def use_lora(
-    model,
-    lora_config: LoraConfig,
-    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
-):
-    if lora_config.lora_ckpt_source == "nemo":
-        load_nemo_lora(model, lora_config)
-    elif lora_config.lora_ckpt_source == "hf":
-        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
-    else:
-        raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
-
 def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]:
     """Unpack model config and weights from a NeMo .nemo archive file.
@@ -762,21 +717,8 @@ class LoraManager(object):
         )

     @staticmethod
-    def get_missing_qkv_modules(lora_target_modules):
-        # In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
-        # all disabled at the same time.
-        # However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor
-        # to fill the missing ones.
-        missing_qkv_modules = []
-        if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
-            for lora_module in ["attn_q", "attn_k", "attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
-            for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        return missing_qkv_modules
+    def get_missing_qkv_modules(lora_target_modules: List[str]) -> List[str]:
+        return get_missing_qkv_modules_from_lora_modules(lora_target_modules)

     @property
     def missing_qkv_modules(self) -> List[str]:
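
Note that `LoraManager.get_missing_qkv_modules` is kept as a thin delegate rather than removed, so existing call sites continue to work. A sketch of the intended equivalence (running it assumes the compiled bindings are importable):

```python
from tensorrt_llm.lora_helper import get_missing_qkv_modules_from_lora_modules
from tensorrt_llm.lora_manager import LoraManager

mods = ["cross_attn_q", "cross_attn_k"]
# The old entry point and the new helper share one implementation.
assert (LoraManager.get_missing_qkv_modules(mods)
        == get_missing_qkv_modules_from_lora_modules(mods)
        == ["cross_attn_v"])
```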

@@ -36,9 +36,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                  LanguageAdapterConfig, LayerNorm, LoraParams,
                                  PromptTuningEmbedding, RmsNorm)
 # yapf: enable
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       use_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules,
+                                      use_lora)
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel
 from tensorrt_llm.module import Module, ModuleList

@@ -28,7 +28,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, LayerNormType,
 from ...layers import (Attention, AttentionMaskType, AttentionParams,
                        ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
                        LoraParams, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,

@@ -21,7 +21,7 @@ from ...functional import (Tensor, is_gated_activation, non_gated_version, recv,
 from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, LayerNorm, MoeConfig,
                        PositionEmbeddingType)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization import QuantMode

@@ -18,7 +18,7 @@ from ..._utils import pad_vocab_size
 from ...functional import Tensor, recv, send
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, MoeConfig, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,

@@ -25,7 +25,7 @@ from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor,
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, FusedGatedMLP, GatedMLP,
                        PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization.functional import fused_layernorm

@@ -32,9 +32,9 @@ from tensorrt_llm.layers import (MLP, Attention, AttentionMaskParams,
                                  ColumnLinear, Embedding, FusedGatedMLP,
                                  GatedMLP, GroupNorm, KeyValueCacheParams,
                                  LayerNorm, LoraParams, RmsNorm)
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       use_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules,
+                                      use_lora)
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
 from tensorrt_llm.models.modeling_utils import PretrainedModel, QuantConfig

@@ -20,7 +20,7 @@ from ..._utils import pad_vocab_size
 from ...functional import Tensor
 from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, LayerNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,

@@ -8,7 +8,7 @@ from ...functional import PositionEmbeddingType, Tensor
 from ...layers import (MLP, MOE, Attention, AttentionMaskType,
                        BlockSparseAttnParams, ColumnLinear, Embedding,
                        LayerNorm, MoeConfig, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,

@@ -26,8 +26,8 @@ from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, RmsNorm, SharedMoE)
 from ...layers.moe import MOEWeightWrapper
 from ...logger import logger
-from ...lora_manager import (LoraConfig,
-                             get_default_trtllm_modules_to_hf_modules, use_lora)
+from ...lora_helper import (LoraConfig,
+                            get_default_trtllm_modules_to_hf_modules, use_lora)
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization import QuantAlgo

@@ -15,7 +15,7 @@
 from typing import Optional

-from .lora_manager import LoraConfig
+from .lora_helper import LoraConfig
 from .mapping import Mapping
 from .plugin.plugin import PluginConfig

@@ -17,7 +17,7 @@ from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig

 DataType = tensorrt_llm.bindings.DataType
 LoraModule = tensorrt_llm.bindings.LoraModule

@@ -42,7 +42,7 @@ from tensorrt_llm.llmapi.llm_utils import (BuildConfig, QuantAlgo, QuantConfig,
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase, TransformersTokenizer,
                                            load_hf_tokenizer)
 from tensorrt_llm.llmapi.utils import get_total_gpu_memory
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.models.automodel import AutoConfig, AutoModelForCausalLM
 from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
 from tensorrt_llm.sampling_params import (BatchedLogitsProcessor,

@@ -12,7 +12,7 @@ from tensorrt_llm._tensorrt_engine import LLM
 from tensorrt_llm.executor import GenerationExecutorProxy
 from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig, SamplingParams
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models import PretrainedConfig
 from tensorrt_llm.models.llama.model import LLaMAForCausalLM

@@ -4,7 +4,7 @@ import pytest
 from .test_llm import tinyllama_logits_processor_test_harness
 from tensorrt_llm import LLM
 from tensorrt_llm.llmapi import KvCacheConfig
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness
 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness
 from .test_llm import _test_llm_capture_request_error

@@ -25,7 +25,7 @@ from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb)
 from utils.llm_data import llm_models_root
-from tensorrt_llm.lora_manager import LoraConfig
+from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.models.modeling_utils import QuantConfig
 from tensorrt_llm.quantization.mode import QuantAlgo