mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
This commit is contained in:
parent
454e7e59e5
commit
d147ad053e
@ -11,8 +11,8 @@ from tqdm import tqdm
|
|||||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
||||||
from transformers.pytorch_utils import Conv1D
|
from transformers.pytorch_utils import Conv1D
|
||||||
|
|
||||||
from tensorrt_llm import logger
|
|
||||||
from tensorrt_llm._utils import str_dtype_to_torch
|
from tensorrt_llm._utils import str_dtype_to_torch
|
||||||
|
from tensorrt_llm.logger import logger
|
||||||
from tensorrt_llm.mapping import Mapping
|
from tensorrt_llm.mapping import Mapping
|
||||||
from tensorrt_llm.models.convert_utils import (dup_kv_weight, generate_int8,
|
from tensorrt_llm.models.convert_utils import (dup_kv_weight, generate_int8,
|
||||||
smooth_gemm,
|
smooth_gemm,
|
||||||
@ -51,7 +51,7 @@ def load_medusa_hf(medusa_path: str,
|
|||||||
use_weight_only=False,
|
use_weight_only=False,
|
||||||
plugin_weight_only_quant_type=None,
|
plugin_weight_only_quant_type=None,
|
||||||
is_modelopt_ckpt=False):
|
is_modelopt_ckpt=False):
|
||||||
# logger.info("Loading Medusa heads' weights ...")
|
logger.info("Loading Medusa heads' weights ...")
|
||||||
|
|
||||||
if is_modelopt_ckpt:
|
if is_modelopt_ckpt:
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user