[#2730][fix] Fix circular import bug in medusa/weight.py (#9866)

Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
This commit is contained in:
Kanghwan 2025-12-10 21:51:08 -08:00 committed by GitHub
parent 454e7e59e5
commit d147ad053e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,8 +11,8 @@ from tqdm import tqdm
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.pytorch_utils import Conv1D
from tensorrt_llm import logger
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import (dup_kv_weight, generate_int8,
smooth_gemm,
@ -51,7 +51,7 @@ def load_medusa_hf(medusa_path: str,
use_weight_only=False,
plugin_weight_only_quant_type=None,
is_modelopt_ckpt=False):
# logger.info("Loading Medusa heads' weights ...")
logger.info("Loading Medusa heads' weights ...")
if is_modelopt_ckpt:
from safetensors.torch import load_file