mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
This commit is contained in:
parent
454e7e59e5
commit
d147ad053e
@ -11,8 +11,8 @@ from tqdm import tqdm
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm._utils import str_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.convert_utils import (dup_kv_weight, generate_int8,
|
||||
smooth_gemm,
|
||||
@ -51,7 +51,7 @@ def load_medusa_hf(medusa_path: str,
|
||||
use_weight_only=False,
|
||||
plugin_weight_only_quant_type=None,
|
||||
is_modelopt_ckpt=False):
|
||||
# logger.info("Loading Medusa heads' weights ...")
|
||||
logger.info("Loading Medusa heads' weights ...")
|
||||
|
||||
if is_modelopt_ckpt:
|
||||
from safetensors.torch import load_file
|
||||
|
||||
Loading…
Reference in New Issue
Block a user