# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from examples/quantization/hf_ptq.py
"""

import copy
import json
import os
import random
import sys
import time
from importlib.metadata import version

import numpy as np
import torch
from accelerate.hooks import remove_hook_from_module
from datasets import load_dataset
from modelopt.torch.utils import print_rank_0
from safetensors.torch import load_file, save_file
from torch.utils.data import DataLoader
from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor,
                          AutoTokenizer)

from .._utils import release_gc, str_dtype_to_torch
from ..logger import logger
from ..mapping import Mapping
from .image_processing import MllamaImageProcessor
from .mode import QuantAlgo

EMPTY_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "enable": False,
        },
        "*input_quantizer": {
            "enable": False
        },
        "*lm_head*": {
            "enable": False
        },
        "*output_layer*": {
            "enable": False
        },
        "default": {
            "enable": False
        },
    },
    "algorithm": "max",
}

KV_CACHE_CFG = {
    "*.query_key_value.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.Wqkv.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.W_pack.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.c_attn.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.k_proj.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.v_proj.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.k.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.v.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
}

KV_QUANT_CFG_CHOICES = {
    "fp8": "FP8_KV_CFG",
    "nvfp4": "NVFP4_KV_CFG",
}


def quant_cfg_choices():
    import modelopt.torch.quantization as mtq
    QUANT_CFG_CHOICES = {
        "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
        "fp8": mtq.FP8_DEFAULT_CFG,
        "fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
        "int4_awq": mtq.INT4_AWQ_CFG,
        "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
        "int8_wo": EMPTY_CFG,
        "int4_wo": EMPTY_CFG,
        "full_prec": EMPTY_CFG,
    }
    if hasattr(mtq, "NVFP4_DEFAULT_CFG"):
        QUANT_CFG_CHOICES["nvfp4"] = mtq.NVFP4_DEFAULT_CFG
    return QUANT_CFG_CHOICES


def model_type_is_enc_dec(model_type):
    return model_type in ["t5", "bart"]


# Note: "GPTBigCodeForCausalLM" is intentionally listed only once below
# (mapped to "gptnext"); an earlier duplicate entry would be silently
# overridden by the later one anyway.
MODEL_NAME_PATTERN_MAP = {
    "GPT2": "gpt2",
    "Xverse": "llama",
    "MllamaForConditionalGeneration": "mllama",
    "Llama": "llama",
    "MllamaForCausalLM": "mllama",
    "Mistral": "llama",
    "GPTJ": "gptj",
    "FalconForCausalLM": "falcon",
    "RWForCausalLM": "falcon",
    "baichuan": "baichuan",
    "MPT": "mpt",
    "Bloom": "bloom",
    "ChatGLM": "chatglm",
    "QWen": "qwen",
    "Qwen2VLForConditionalGeneration": "qwen2_vl",
    "RecurrentGemma": "recurrentgemma",
    "Gemma3": "gemma3",
    "Gemma2": "gemma2",
    "Gemma": "gemma",
    "MixtralForCausalLM": "llama",
    "NemotronForCausalLM": "nemotron",
    "ArcticForCausalLM": "llama",
    "PhiMoEForCausalLM": "phi3",
    "Phi3SmallForCausalLM": "phi3small",
"Phi3ForCausalLM": "phi3", "Phi3VForCausalLM": "phi3", "Starcoder2ForCausalLM": "gptnext", "GPTBigCodeForCausalLM": "gptnext", "GLM": "glm", "Exaone": "exaone", "DeciLMForCausalLM": "deci", "DeepseekForCausalLM": "deepseek", "GraniteForCausalLM": "granite", "GraniteMoeForCausalLM": "granitemoe", "T5": "t5", "Bart": "bart" } MULTIMODAL_DATASETS = ['scienceqa', 'science_qa'] class _CustomDataset(torch.utils.data.Dataset): def __init__(self, encodings): self.encodings = encodings def __getitem__(self, idx): item = { key: val[idx].clone().detach().requires_grad_(False) for key, val in self.encodings.items() } return item def __len__(self): return len(self.encodings["input_ids"]) class EncDecModelWrapper(torch.nn.Module): def __init__(self, hf_model=None): super().__init__() self.hf_model = hf_model self.model_type = get_model_type(hf_model) def forward(self, **kwargs): self.hf_model.generate(**kwargs) def __getattr__(self, name): try: return super().__getattr__(name) except AttributeError: return getattr(self.hf_model, name) def get_tokenizer(ckpt_path, max_seq_length=2048, model_type=None): logger.info(f"Initializing tokenizer from {ckpt_path}") tokenizer = AutoTokenizer.from_pretrained( ckpt_path, model_max_length=max_seq_length, padding_side="left", trust_remote_code=True, ) if tokenizer.pad_token is None: if model_type and model_type == "qwen": # qwen use token id 151643 as pad and eos tokens tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) elif model_type and model_type == "qwen2_vl": # qwen use token id 151643 as pad and 151643 and 151645 as eos tokens tokenizer.eos_token = [ tokenizer.convert_ids_to_tokens(151643), tokenizer.convert_ids_to_tokens(151645) ] tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) else: tokenizer.pad_token = tokenizer.eos_token assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!" return tokenizer def get_processor(ckpt_path, max_seq_length=2048, model_type=None, device=None): logger.info(f"Initializing tokenizer from {ckpt_path}") processor = AutoProcessor.from_pretrained( ckpt_path, model_max_length=max_seq_length, padding_side="left", trust_remote_code=True, ) if processor.tokenizer.pad_token is None: if model_type and model_type == "qwen": # qwen use token id 151643 as pad and eos tokens processor.tokenizer.eos_token = processor.tokenizer.convert_ids_to_tokens( 151643) processor.tokenizer.pad_token = processor.tokenizer.convert_ids_to_tokens( 151643) else: processor.tokenizer.pad_token = processor.tokenizer.eos_token assert processor.tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!" 
    if model_type == 'mllama':
        processor = MllamaImageProcessor(processor, device)

    return processor


def _get_vila_model(model_dir):
    sys.path.append(model_dir + "/../VILA")
    from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa
    from transformers import AutoModel
    model = AutoModel.from_pretrained(
        model_dir,
        device_map='auto',
        trust_remote_code=True,
    )
    return model.llm


def get_hf_config(ckpt_path):
    if "mpt" in ckpt_path:
        # MPT-7B cannot get initialized from AutoConfig
        from transformers import MptConfig
        return MptConfig.from_pretrained(ckpt_path)
    else:
        return AutoConfig.from_pretrained(ckpt_path, trust_remote_code=True)


def _get_llava_qwen_model(model_dir, dtype, device):
    if "hf" in model_dir:
        from transformers import LlavaOnevisionForConditionalGeneration
        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            model_dir, dtype=dtype, device_map=device)
        model = model.language_model
    else:
        from llava.model.builder import load_pretrained_model
        _, model, _, _ = load_pretrained_model(model_dir,
                                               None,
                                               'llava_qwen',
                                               torch_dtype=dtype,
                                               device_map=device)
    return model


def get_model(ckpt_path: str,
              dtype: str = 'bfloat16',
              device: str = 'cuda',
              device_map: str = "auto"):
    logger.info(f"Initializing model from {ckpt_path}")

    # Note: VILA model is not in public HF model zoo yet. We need to explicitly import from the git repo
    hf_config = get_hf_config(ckpt_path)
    torch_dtype = str_dtype_to_torch(dtype)

    model_cls = AutoModelForCausalLM
    if hf_config.model_type == "llava":
        from transformers import LlavaForConditionalGeneration
        model_cls = LlavaForConditionalGeneration
    elif hf_config.model_type == "mpt":
        from transformers import MptForCausalLM
        model_cls = MptForCausalLM
    elif hf_config.model_type == 'mllama':
        from transformers import MllamaForConditionalGeneration
        model_cls = MllamaForConditionalGeneration
    elif hf_config.model_type == 'qwen2_vl':
        from transformers import Qwen2VLForConditionalGeneration
        model_cls = Qwen2VLForConditionalGeneration

    if "vila" in ckpt_path:
        model = _get_vila_model(ckpt_path)
    elif "llava-onevision-qwen2" in ckpt_path:
        model = _get_llava_qwen_model(ckpt_path, dtype, device)
    elif hf_config.model_type == "glm":
        from transformers import AutoModelForSeq2SeqLM
        model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_path,
                                                      device_map="cuda",
                                                      dtype=torch_dtype,
                                                      trust_remote_code=True)
    elif model_type_is_enc_dec(hf_config.model_type):
        from transformers import AutoModelForSeq2SeqLM
        model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_path,
                                                      device_map=device,
                                                      dtype=torch_dtype,
                                                      trust_remote_code=True)
        model = EncDecModelWrapper(hf_model=model)
    else:
        model = model_cls.from_pretrained(
            ckpt_path,
            device_map=device_map if device != "cpu" else "cpu",
            dtype="auto",
            trust_remote_code=True)
        if hf_config.model_type in ["llava", "internvl_chat"]:
            model = model.language_model
        elif hf_config.model_type == "qwen2_vl":
            # WAR for Qwen2-VL because its lm_head is outside of the LLM
            lm_head = model.lm_head
            model = model.model
            model.lm_head = lm_head
    model.eval()

    model_dtype = next(model.parameters()).dtype
    if torch_dtype != model_dtype:
        logger.info(
            f"[TensorRT-LLM][WARNING] The manually set model data type is {dtype}, "
            f"but the data type of the HuggingFace model is {model_dtype}.")

    return model


def get_model_type(model):
    if type(model).__name__ == "EncDecModelWrapper":
        return model.model_type
    if type(model).__name__ in MODEL_NAME_PATTERN_MAP:
        return MODEL_NAME_PATTERN_MAP[type(model).__name__]
    for k, v in MODEL_NAME_PATTERN_MAP.items():
        if k.lower() in type(model).__name__.lower():
            return v
    return None
def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
                         tokenizer=None,
                         batch_size=1,
                         calib_size=512,
                         block_size=512,
                         device=None,
                         include_labels=False):
    logger.info("Loading calibration dataset")
    if dataset_name_or_dir == "pileval":
        dataset = load_dataset(
            "json",
            data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
            split="train",
            trust_remote_code=True)
        dataset = dataset["text"][:calib_size]
    elif "scienceqa" in dataset_name_or_dir.lower(
    ) or "science_qa" in dataset_name_or_dir.lower():
        if os.path.isdir(dataset_name_or_dir):
            dataset = load_dataset(dataset_name_or_dir,
                                   split="train",
                                   trust_remote_code=True)
        else:
            dataset = load_dataset("derek-thomas/ScienceQA",
                                   split="train",
                                   trust_remote_code=True)
        dataset = dataset.select(range(calib_size))
    elif "cnn_dailymail" in dataset_name_or_dir:
        dataset = load_dataset(
            dataset_name_or_dir,
            name="3.0.0",
            split="train",
            trust_remote_code=True,
        )
        dataset = dataset["article"][:calib_size]
    elif os.path.isdir(dataset_name_or_dir):
        logger.info(
            f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
            "assuming the calibration data are in the train split and text column."
        )
        dataset = load_dataset(dataset_name_or_dir,
                               split="train",
                               trust_remote_code=True)
        dataset = dataset["text"][:calib_size]
    else:
        raise NotImplementedError(
            f"Unsupported dataset name or local repo directory: {dataset_name_or_dir}."
        )

    is_multimodal = False
    for dataset_name in MULTIMODAL_DATASETS:
        if dataset_name in dataset_name_or_dir:
            is_multimodal = True

    if is_multimodal:
        # Apply the preprocessing function to the dataset
        processed_dataset = dataset.map(tokenizer.preprocess_function,
                                        batched=False,
                                        remove_columns=dataset.column_names)

        # Create DataLoader with the custom collate function
        calib_dataloader = DataLoader(processed_dataset,
                                      batch_size=batch_size,
                                      shuffle=False,
                                      collate_fn=tokenizer.collate_function)
    else:
        batch_encoded = tokenizer.batch_encode_plus(dataset,
                                                    return_tensors="pt",
                                                    padding=True,
                                                    truncation=True,
                                                    max_length=block_size)
        if device:
            batch_encoded = batch_encoded.to(device)

        if include_labels:
            # Labels are needed when backward is called in the model.
            # The labels should be a shifted version of the input_ids.
            # However, we should not shift the input_ids here since the labels are shifted by
            # Huggingface models during loss calculation as shown here -
            # https://github.com/huggingface/transformers/blob/7f79a97399bb52aad8460e1da2f36577d5dccfed/src/transformers/models/llama/modeling_llama.py#L1093-L1095
            batch_encoded["labels"] = torch.where(
                batch_encoded["attention_mask"] > 0.5,
                batch_encoded["input_ids"], -100)
            batch_encoded = _CustomDataset(batch_encoded)
        else:
            # For backward compatibility, if labels are not needed, we only return input_ids.
            batch_encoded = _CustomDataset(
                {"input_ids": batch_encoded["input_ids"]})

        calib_dataloader = DataLoader(batch_encoded,
                                      batch_size=batch_size,
                                      shuffle=False)

    return calib_dataloader


def quantize_model(model, quant_cfg, calib_dataloader, batch_size, qformat,
                   auto_quantize_bits):
    import modelopt.torch.quantization as mtq

    # NOTE: for ModelOpt v0.19 release
    # calibrate_loop = dataset_utils.create_forward_loop(
    #     calib_dataloader, dataloader=calib_dataloader)

    def calibrate_loop():
        if calib_dataloader is None:
            return
        with torch.no_grad():
            low_mem_mode = False
            for idx, data in enumerate(calib_dataloader):
                logger.debug(f"Calibrating batch {idx}")
                batch_size = data[list(data.keys())[0]].shape[0]
                if batch_size == 1:
                    model(**data)
                elif not low_mem_mode:
                    # Try running the forward once.
                    # If it runs out of memory, retry inference with the input tensors split in half.
                    try:
                        model(**data)
                    except torch.OutOfMemoryError:
                        print(
                            "Warning: torch.OutOfMemoryError detected, try reducing the batch size..."
                        )
                        low_mem_mode = True

                if low_mem_mode:
                    split_data_1 = {
                        key: data[key][:batch_size // 2, ...]
                        for key in data
                    }
                    model(**split_data_1)

                    split_data_2 = {
                        key: data[key][batch_size // 2:, ...]
                        for key in data
                    }
                    model(**split_data_2)

    QUANT_CFG_CHOICES = {
        "int8": "INT8_DEFAULT_CFG",
        "int8_sq": "INT8_SMOOTHQUANT_CFG",
        "fp8": "FP8_DEFAULT_CFG",
        "fp8_pc_pt": "FP8_PER_CHANNEL_PER_TOKEN_CFG",
        "int4_awq": "INT4_AWQ_CFG",
        "w4a8_awq": "W4A8_AWQ_BETA_CFG",
    }

    logger.info("Starting quantization...")
    start_time = time.time()
    if auto_quantize_bits:
        logger.info("Starting mixed precision quantization...")

        from packaging import version as v
        opt_kwargs = {}
        modelopt_version = version('nvidia-modelopt')
        if v.parse(modelopt_version) > v.parse("0.21"):
            opt_kwargs['disabled_layers'] = ["*lm_head*"]

        model, search_history = mtq.auto_quantize(
            model,
            data_loader=calib_dataloader,
            loss_func=lambda output, batch: output.loss,
            constraints={"effective_bits": auto_quantize_bits},
            forward_step=lambda model, batch: model(**batch),
            quantization_formats=[
                QUANT_CFG_CHOICES[item] for item in qformat.split(",")
            ] + [None],
            num_calib_steps=len(calib_dataloader),
            num_score_steps=min(
                len(calib_dataloader), 128 // batch_size
            ),  # Limit the number of score steps to avoid long calibration time
            verbose=True,
            **opt_kwargs)
        mtq.print_quant_summary(model)

        # We need to explicitly calibrate for kv cache quantization
        enable_kv_cache_quantization = "int8" not in qformat
        if enable_kv_cache_quantization:
            mtq.set_quantizer_by_cfg(
                model,
                quant_cfg={
                    "*output_quantizer": {
                        "num_bits": (4, 3),
                        "axis": None,
                        "enable": True
                    }
                },
            )
            # Calibrate only the output quantizer this time; disable all other quantizers.
            with mtq.set_quantizer_by_cfg_context(model, {
                    "*": {
                        "enable": False
                    },
                    "*output_quantizer": {
                        "enable": True
                    }
            }):
                mtq.calibrate(model,
                              algorithm="max",
                              forward_loop=calibrate_loop)
    else:
        mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    end_time = time.time()
    logger.info(
        "Quantization done. Total time used: {:.2f} s.".format(end_time -
                                                               start_time))
    return model


def combine_medusa_weight(tp_size, pp_size, base_model_output_dir,
                          num_medusa_heads, num_medusa_layers, max_draft_len,
                          medusa_hidden_act, medusa_model_dir,
                          quant_medusa_head):

    with open(f"{medusa_model_dir}/config.json", "r") as fp:
        medusa_config = json.load(fp)

    num_medusa_heads_from_config = medusa_config.get('medusa_num_heads',
                                                     num_medusa_heads)
    num_medusa_layers = medusa_config.get('medusa_num_layers',
                                          num_medusa_layers)
    if num_medusa_heads is None:
        num_medusa_heads = num_medusa_heads_from_config

    assert max_draft_len > 0, "should have max_draft_len > 0"

    world_size = tp_size * pp_size
    # Process for each rank
    for rank in range(world_size):
        mapping = Mapping(world_size=world_size,
                          rank=rank,
                          tp_size=tp_size,
                          pp_size=pp_size)

        # 1. Load medusa weight for each rank
        from tensorrt_llm.models.medusa.weight import load_medusa_hf
        medusa_weights = load_medusa_hf(medusa_path=medusa_model_dir,
                                        num_medusa_heads=num_medusa_heads,
                                        num_medusa_layers=num_medusa_layers,
                                        mapping=mapping,
                                        dtype="float16")

        # 2. Load base model safetensors (after quant)
        base_model_weights = load_file(
            f"{base_model_output_dir}/rank{rank}.safetensors")

        # 3. Combine and save the weights
        base_model_weights.update(medusa_weights)
        save_file(base_model_weights,
                  f"{base_model_output_dir}/rank{rank}.safetensors")

    # 4. Add medusa config into config.json
    with open(f"{base_model_output_dir}/config.json", 'r') as f:
        base_model_config = json.load(f)

    with open(f"{base_model_output_dir}/config.json", 'w') as f:
        base_model_config['architecture'] = "MedusaForCausalLM"
        base_model_config['quantization']['exclude_modules'] = [
            'lm_head',
            '*router',
            '*vocab_embedding',
            '*position_embedding',
            '*block_embedding',
        ]
        if not quant_medusa_head:
            base_model_config['quantization']['exclude_modules'].append(
                '*medusa_heads*')

        base_model_config['max_draft_len'] = max_draft_len
        base_model_config['num_medusa_heads'] = num_medusa_heads
        base_model_config['num_medusa_layers'] = num_medusa_layers

        json.dump(base_model_config, f, indent=4)

    torch.cuda.empty_cache()
    logger.info("Combining the Medusa heads' weights is done.")


def quantize_and_export(*,
                        model_dir,
                        device,
                        calib_dataset,
                        dtype,
                        qformat,
                        kv_cache_dtype,
                        calib_size,
                        batch_size,
                        calib_max_seq_length,
                        awq_block_size,
                        output_dir,
                        tp_size,
                        pp_size,
                        cp_size,
                        seed,
                        tokenizer_max_seq_length,
                        num_medusa_heads=None,
                        num_medusa_layers=None,
                        max_draft_len=None,
                        medusa_hidden_act=None,
                        medusa_model_dir=None,
                        quant_medusa_head=None,
                        auto_quantize_bits=None,
                        device_map="auto",
                        quantize_lm_head=False):
    '''
    Load the model from model_dir, quantize it with ModelOpt, and export the
    quantized model as a TRT-LLM checkpoint.
    '''
    try:
        import modelopt  # noqa
    except ImportError as e:
        logger.error(
            "Failed to import modelopt, please check the ModelOpt installation. "
            "Currently it is known to be unsupported on Windows OS.")
        raise e

    import modelopt.torch.quantization as mtq
    from modelopt.torch.export import export_tensorrt_llm_checkpoint

    from tensorrt_llm.models.convert_utils import infer_dtype

    if not torch.cuda.is_available():
        raise EnvironmentError("GPU is required for inference.")

    random.seed(seed)
    np.random.seed(seed)

    # Check that only one quantization format is provided for the non auto_quant case
    if not auto_quantize_bits:
        assert (len(qformat.split(",")) == 1
                ), "Quantization supports only one quantization format."

    hf_config = get_hf_config(model_dir)
    dtype = infer_dtype(dtype, getattr(hf_config, 'torch_dtype', None))

    model = get_model(model_dir, dtype, device=device, device_map=device_map)
    model_type = get_model_type(model)
    is_enc_dec = model_type_is_enc_dec(model_type)

    if "vila" in model_dir:
        tokenizer = get_tokenizer(model_dir + "/llm",
                                  max_seq_length=tokenizer_max_seq_length,
                                  model_type=model_type)
    elif model_type == "mllama":
        tokenizer = get_processor(model_dir,
                                  max_seq_length=tokenizer_max_seq_length,
                                  model_type=model_type,
                                  device=device)
    else:
        tokenizer = get_tokenizer(model_dir,
                                  max_seq_length=tokenizer_max_seq_length,
                                  model_type=model_type)

    if qformat in ["full_prec", "int8_wo", "int4_wo"
                   ] and kv_cache_dtype is None:
        logger.info(f"No quantization applied, export {dtype} model")
    else:
        if "awq" in qformat:
            if calib_size > 32:
                logger.info(
                    f"AWQ calibration could take longer with calib_size = {calib_size}. "
                    "Using calib_size=32 instead.")
                calib_size = 32
            logger.info(
                "\nAWQ calibration could take longer than other calibration methods. Please"
                " increase the batch size to speed up the calibration process."
                " Batch size can be set by adding the argument --batch_size to the command line.\n"
            )

        quant_cfg = None
        if not auto_quantize_bits:
            if qformat in quant_cfg_choices():
                quant_cfg = quant_cfg_choices()[qformat]
            else:
                raise ValueError(f"Unsupported quantization format: {qformat}")

            if "awq" in qformat:
                quant_cfg = copy.deepcopy(quant_cfg_choices()[qformat])
                weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
                if isinstance(weight_quantizer, list):
                    weight_quantizer = weight_quantizer[0]
                if awq_block_size:
                    weight_quantizer["block_sizes"][-1] = awq_block_size

                # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
                if "w4a8_awq" == qformat and model_type in ["gemma", "mpt"]:
                    quant_cfg["algorithm"] = {
                        "method": "awq_lite",
                        "alpha_step": 1
                    }

            if kv_cache_dtype is not None:
                if kv_cache_dtype == "fp8":
                    kv_cache_quant_cfg = getattr(
                        mtq, KV_QUANT_CFG_CHOICES[kv_cache_dtype])["quant_cfg"]
                    quant_cfg["quant_cfg"].update(kv_cache_quant_cfg)
                else:
                    quant_cfg["quant_cfg"].update(KV_CACHE_CFG)  # type: ignore

            # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
            if model_type == "gemma" and "int8_sq" in qformat:
                quant_cfg["algorithm"] = {
                    "method": "smoothquant",
                    "alpha": 0.5
                }

            if qformat == 'fp8' and quantize_lm_head:
                print_rank_0("Quantizing lm_head layer")
                del quant_cfg["quant_cfg"]["*lm_head*"]

        calib_dataloader = get_calib_dataloader(
            dataset_name_or_dir=calib_dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            calib_size=calib_size,
            block_size=calib_max_seq_length,
            device=model.device,
            include_labels=auto_quantize_bits is not None,
        )

        model = quantize_model(model, quant_cfg, calib_dataloader, batch_size,
                               qformat, auto_quantize_bits)

    with torch.inference_mode():
        if model_type is None:
            logger.info(
                f"Unknown model type {type(model).__name__}. Continue exporting..."
            )
            model_type = f"unknown:{type(model).__name__}"

        architecture = type(model).__name__

        export_path = output_dir
        start_time = time.time()

        # Move meta tensor back to device before exporting.
        remove_hook_from_module(model, recurse=True)

        QUANT_ALGO = {
            "int8": "INT8",
            "int8_sq": "W8A8_SQ_PER_CHANNEL",
            "fp8": "FP8",
            "int4_awq": "W4A16_AWQ",
            "w4a8_awq": "W4A8_AWQ",
        }

        if model_type == 'mllama':
            model = model.language_model

        export_tensorrt_llm_checkpoint(
            model.hf_model if is_enc_dec else model,
            model_type,
            getattr(torch, dtype),
            export_dir=export_path,
            inference_tensor_parallel=tp_size,
            inference_pipeline_parallel=pp_size,
        )

        export_paths = []
        tensorrt_llm_configs = []
        if not is_enc_dec:
            with open(f"{export_path}/config.json", "r") as f:
                tensorrt_llm_config = json.load(f)
            tensorrt_llm_configs.append(tensorrt_llm_config)
            export_paths.append(export_path)
        else:
            for component in ["encoder", "decoder"]:
                with open(f"{export_path}/{component}/config.json", "r") as f:
                    tensorrt_llm_config = json.load(f)
                tensorrt_llm_configs.append(tensorrt_llm_config)
                export_paths.append(f"{export_path}/{component}")

        for export_path, tensorrt_llm_config in zip(export_paths,
                                                    tensorrt_llm_configs):

            tensorrt_llm_config["model_type"] = model_type
            if not is_enc_dec:
                tensorrt_llm_config["architecture"] = architecture

            # Workaround for wo quantization
            if qformat in ["int8_wo", "int4_wo", "full_prec"]:
                if qformat == "int8_wo":
                    tensorrt_llm_config["quantization"][
                        "quant_algo"] = QuantAlgo.W8A16
                elif qformat == "int4_wo":
                    tensorrt_llm_config["quantization"][
                        "quant_algo"] = QuantAlgo.W4A16
                else:
                    tensorrt_llm_config["quantization"]["quant_algo"] = None

            # HF uses rope_scaling while tensorrt_llm uses rotary_scaling
            if hasattr(model.config, "rope_scaling"
                       ) and "rotary_scaling" not in tensorrt_llm_config:
                tensorrt_llm_config["rotary_scaling"] = getattr(
                    model.config, "rope_scaling")
            with open(f"{export_path}/config.json", "w") as f:
                json.dump(tensorrt_llm_config, f, indent=4)

            # Workaround for Modelopt 0.9.x fp8_kv_cache knob issue
            if qformat in ['fp8', 'nvfp4'] and kv_cache_dtype is None:
                with open(f"{export_path}/config.json", "r") as f:
                    tensorrt_llm_config = json.load(f)
                tensorrt_llm_config["quantization"][
                    "kv_cache_quant_algo"] = None
                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

            # Workaround for qwen version
            if model_type == 'qwen' or model_type == 'qwen2_vl':
                with open(f"{export_path}/config.json", "r") as f:
                    tensorrt_llm_config = json.load(f)
                qwen_config = AutoConfig.from_pretrained(
                    model_dir, trust_remote_code=True)
                try:
                    from transformers import LlavaOnevisionConfig
                    if isinstance(qwen_config, LlavaOnevisionConfig):
                        qwen_config = qwen_config.text_config
                except Exception:
                    pass
                tensorrt_llm_config["qwen_type"] = qwen_config.model_type
                if qwen_config.model_type == "qwen2":
                    tensorrt_llm_config[
                        "norm_epsilon"] = qwen_config.rms_norm_eps
                    tensorrt_llm_config["rotary_base"] = qwen_config.rope_theta
                tensorrt_llm_config[
                    "intermediate_size"] = qwen_config.intermediate_size
                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

            # Set rotary parameters correctly for chatglm.
            if model_type == 'chatglm':
                rotary_base = 10000.0
                rotary_embedding_scaling = None
                chatglm_config = AutoConfig.from_pretrained(
                    model_dir, trust_remote_code=True)
                chatglm_version = tensorrt_llm_config['chatglm_version']
                rope_ratio = tensorrt_llm_config.get('rope_ratio', 1.0)
                if chatglm_version == 'chatglm2':
                    if rope_ratio > 1:
                        rotary_embedding_scaling = {
                            'type': 'linear',
                            'factor': rope_ratio
                        }
                elif chatglm_version == 'chatglm3':
                    rotary_base *= rope_ratio

                with open(f"{export_path}/config.json", "r") as f:
                    tensorrt_llm_config = json.load(f)
                tensorrt_llm_config['rotary_base'] = rotary_base
                tensorrt_llm_config[
                    'rotary_scaling'] = rotary_embedding_scaling
                tensorrt_llm_config['rotary_pct'] = 0.5
                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

            # context parallel
            if cp_size > 1:
                with open(f"{export_path}/config.json", "r") as f:
                    tensorrt_llm_config = json.load(f)
                tensorrt_llm_config["mapping"]["cp_size"] = cp_size
                tensorrt_llm_config["mapping"]["attn_tp_size"] = -1
                tensorrt_llm_config["mapping"]["attn_cp_size"] = -1
                tensorrt_llm_config["mapping"]["world_size"] *= cp_size
                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

            if model_type == 'gptnext':
                with open(f"{export_path}/config.json", "r") as f:
                    tensorrt_llm_config = json.load(f)
                if tensorrt_llm_config['max_position_embeddings'] is None:
                    tensorrt_llm_config['max_position_embeddings'] = getattr(
                        model.config, "n_positions", None)
                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

            # Workaround for combining medusa head
            # TODO: move this integration into modelopt to avoid redundant reading and writing
            if medusa_model_dir is not None:
                combine_medusa_weight(tp_size, pp_size, export_path,
                                      num_medusa_heads, num_medusa_layers,
                                      max_draft_len, medusa_hidden_act,
                                      medusa_model_dir, quant_medusa_head)

            # Workaround for mllama
            if model_type == 'mllama':
                from tensorrt_llm.models.mllama.config import MLLaMAConfig
                config = MLLaMAConfig.from_hugging_face(
                    model_dir,
                    dtype=dtype,
                )
                for key, value in config.to_dict().items():
                    if key not in tensorrt_llm_config:
                        tensorrt_llm_config[key] = value

                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

        end_time = time.time()
        logger.info(
            "Quantized model exported to {} \nTotal time used {:.2f} s.".
            format(export_path, end_time - start_time))

    # Need to delete the model and release memory explicitly;
    # otherwise torch may retain its GPU memory until a delayed GC running,
    # which reduces the available GPU memory for subsequent stages.
    del model
    release_gc()


def unwrap_model(model, module_instances=None):
    # Reference: https://github.com/NVIDIA/Megatron-LM/blob/core_r0.8.0/megatron/training/utils.py
    from megatron.core import DistributedDataParallel as DDP
    from megatron.core.transformer.module import Float16Module
    if module_instances is None:
        module_instances = (DDP, Float16Module)

    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model


def get_nemo_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
                              batch_size=64,
                              calib_size=512,
                              max_sequence_length=512):
    if dataset_name_or_dir == "pileval":
        dataset = load_dataset(
            "json",
            data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
            split="train",
            trust_remote_code=True)
        text_column = "text"
    elif "wikitext" in dataset_name_or_dir:
        dataset = load_dataset(dataset_name_or_dir,
                               "wikitext-103-v1",
                               split="train",
                               trust_remote_code=True)
        text_column = "text"
    elif "cnn_dailymail" in dataset_name_or_dir:
        dataset = load_dataset(dataset_name_or_dir,
                               name="3.0.0",
                               split="train",
                               trust_remote_code=True)
        text_column = "article"
    elif os.path.isdir(dataset_name_or_dir):
        logger.info(
            f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
            "assuming the calibration data are in the train split and text column."
        )
        dataset = load_dataset(dataset_name_or_dir,
                               split="train",
                               trust_remote_code=True)
        text_column = "text"
    else:
        raise NotImplementedError(
            f"Unsupported dataset name or local repo directory: {dataset_name_or_dir}."
        )
    calib_size = max(min(len(dataset), calib_size), batch_size)
    for i in range(calib_size // batch_size):
        batch = dataset[i * batch_size:(i + 1) * batch_size][text_column]
        for j in range(len(batch)):
            batch[j] = batch[j][:max_sequence_length]
        yield batch


def quantize_nemo_and_export(*, nemo_ckpt_path, decoder_type, calib_dataset,
                             calib_tp_size, calib_pp_size, dtype, qformat,
                             kv_cache_dtype, calib_size, batch_size,
                             calib_max_seq_length, awq_block_size, output_dir,
                             tp_size, pp_size, cp_size, seed):
    try:
        import modelopt  # noqa
    except ImportError as e:
        logger.error(
            "Failed to import modelopt, please check the ModelOpt installation. "
            "Currently it is known to be unsupported on Windows OS.")
        raise e

    import modelopt.torch.quantization as mtq
    from megatron.core import parallel_state
    from megatron.core.transformer.module import Float16Module
    from modelopt.torch.export import export_tensorrt_llm_checkpoint
    from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import \
        MegatronGPTModel
    from nemo.collections.nlp.modules.common.text_generation_strategy import \
        GPTModelTextGenerationStrategy
    from nemo.collections.nlp.parts.nlp_overrides import (
        NLPDDPStrategy, NLPSaveRestoreConnector)
    from nemo.utils.model_utils import load_config, save_artifacts
    from omegaconf.omegaconf import open_dict
    from pytorch_lightning.trainer.trainer import Trainer

    if not torch.cuda.is_available():
        raise EnvironmentError("GPU is required for the inference.")

    random.seed(seed)
    np.random.seed(seed)

    model_cfg = load_config(nemo_ckpt_path)

    # dtype is used for non-quantized layers
    supported_dtype = ["auto", "float16", "bfloat16"]
    assert dtype in supported_dtype, f"{dtype} not supported. Supported dtypes are {supported_dtype}"
    if dtype == 'auto':
        dtype = model_cfg.get('precision', None)
        if dtype is None:
            dtype = 'float16'
        elif 'bf16' in dtype or 'bfloat16' in dtype:
            dtype = 'bfloat16'
        else:
            dtype = 'float16'
        logger.info(f"Specified dtype 'auto'; inferred dtype {dtype!r}.")
    torch_dtype = getattr(torch, dtype)

    with open_dict(model_cfg):
        model_cfg.activations_checkpoint_method = None
        model_cfg.activations_checkpoint_granularity = None
        model_cfg.tensor_model_parallel_size = calib_tp_size
        model_cfg.pipeline_model_parallel_size = calib_pp_size
        model_cfg.sequence_parallel = False
        # Only custom modelopt spec is supported for PTQ: this custom spec is largely based on local Megatron-LM
        # layer definitions to avoid Transformer Engine implementations that are currently not supported.
        model_cfg.name = "modelopt"

    # trainer required for restoring model parallel models
    trainer_config = {
        'devices': calib_tp_size * calib_pp_size,
        'num_nodes': 1,
        'accelerator': 'gpu',
        'logger': False,
        'precision': model_cfg.precision,
        'enable_checkpointing': False,
    }
    trainer = Trainer(strategy=NLPDDPStrategy(), **trainer_config)
    connector = NLPSaveRestoreConnector()

    model = MegatronGPTModel.restore_from(
        restore_path=nemo_ckpt_path,
        trainer=trainer,
        override_config_path=model_cfg,
        save_restore_connector=connector,
    )
    model.freeze()

    print_rank_0(model)
    # Have to turn off activations_checkpoint_method for inference
    try:
        model.model.module.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

    # Check whether the DDP is initialized
    if parallel_state.is_unitialized():

        def dummy():
            return

        if model.trainer.strategy.launcher is not None:
            model.trainer.strategy.launcher.launch(dummy,
                                                   trainer=model.trainer)
        model.trainer.strategy.setup_environment()

    inference_config = {
        'greedy': False,
        'top_k': 0,
        'top_p': 0.9,
        'temperature': 1.0,
        'add_BOS': True,
        'tokens_to_generate': 30,
        'all_probs': False,
        'repetition_penalty': 1.2,
        'min_tokens_to_generate': 0,
        'compute_logprob': False,
        'batch_size': batch_size,
        'max_context_length': calib_max_seq_length,
        'strategy': GPTModelTextGenerationStrategy(model),
    }
    model.set_inference_config(inference_config)

    if qformat in ["full_prec", "int8_wo", "int4_wo"
                   ] and kv_cache_dtype is None:
        print_rank_0(f"No quantization applied, export {dtype} model")
    else:
        if "awq" in qformat:
            if calib_size > 32:
                print_rank_0(
                    "AWQ calibration could take longer with calib_size ="
                    f" {calib_size}. Using calib_size=32 instead.")
                calib_size = 32
            print_rank_0(
                "\nAWQ calibration could take longer than other calibration methods. Please"
                " increase the batch size to speed up the calibration process."
                " Batch size can be set by adding the argument inference.batch_size= to the command"
                " line.\n")

        dataloader = get_nemo_calib_dataloader(
            dataset_name_or_dir=calib_dataset,
            batch_size=batch_size,
            calib_size=calib_size,
            max_sequence_length=calib_max_seq_length,
        )

        # =================== Start Quantization ====================
        if qformat in quant_cfg_choices():
            quant_cfg = quant_cfg_choices()[qformat]
        else:
            raise ValueError(f"Unsupported quantization format: {qformat}")

        if "awq" in qformat:
            quant_cfg = copy.deepcopy(quant_cfg_choices()[qformat])
            weight_quantizer = quant_cfg["quant_cfg"][
                "*weight_quantizer"]  # type: ignore
            if isinstance(weight_quantizer, list):
                weight_quantizer = weight_quantizer[0]
            weight_quantizer["block_sizes"][-1] = awq_block_size

        if kv_cache_dtype is not None:
            if kv_cache_dtype == "fp8":
                for value in KV_CACHE_CFG.values():
                    value.update({"num_bits": (4, 3)})  # type: ignore
            quant_cfg["quant_cfg"].update(KV_CACHE_CFG)  # type: ignore

        print_rank_0(quant_cfg)

        # Always turn on FP8 kv cache to save memory footprint.
        # For int8_sq, we use int8 kv cache.
        # TODO: Investigate why enabling FP8 kv cache will cause accuracy regressions for nemotron.
        # quant_cfg["quant_cfg"]["*output_quantizer"] = {  # type: ignore[index]
        #     "num_bits": 8 if args.qformat == "int8_sq" else (4, 3),
        #     "axis": None,
        #     "enable": args.decoder_type != "gptnext",
        # }

        dataloader = [data for data in dataloader]

        def forward_loop(model):
            for i, batch in enumerate(dataloader):
                print_rank_0(f"Calibrating batch {i}")
                model.predict_step(batch, i)

        start_time = time.time()
        model = mtq.quantize(model, quant_cfg,
                             forward_loop)  # type: ignore[arg-type]
        end_time = time.time()
        tot_time = end_time - start_time
        tput = calib_size / tot_time
        print_rank_0(
            f"Quantization done. Total time used {tot_time}s. Throughput {tput} samples/s"
        )
        # =================== End Quantization ======================

        if decoder_type == "gptnext":
            # We found squared_relu may have an under-calibration problem.
            # Clamp the scaling_factor with a min threshold to avoid under-calibration.
            maxbound = 0
            if qformat == "fp8":
                maxbound = 448
            elif qformat == "int8_sq":
                maxbound = 127
            model = mtq.postprocess_amax(
                model, "*input_quantizer",
                lambda amax: torch.clamp(amax, min=0.01 * maxbound))

        if torch.distributed.get_rank() == 0:
            mtq.print_quant_summary(model)

    if model_cfg.megatron_amp_O2:
        model.model = unwrap_model(model.model, Float16Module)

    start_time = time.time()
    export_tensorrt_llm_checkpoint(
        model,
        decoder_type,
        torch_dtype,
        export_dir=output_dir,
        inference_tensor_parallel=tp_size,
        inference_pipeline_parallel=pp_size,
    )

    # context parallel
    # Note: the checkpoint is exported to output_dir above, so the config is
    # patched there as well.
    if cp_size > 1:
        with open(f"{output_dir}/config.json", "r") as f:
            tensorrt_llm_config = json.load(f)
        tensorrt_llm_config["mapping"]["cp_size"] = cp_size
        tensorrt_llm_config["mapping"]["world_size"] *= cp_size
        with open(f"{output_dir}/config.json", "w") as f:
            json.dump(tensorrt_llm_config, f, indent=4)

    end_time = time.time()
    print_rank_0(
        f"Model config exported to: {output_dir}. Total time used {end_time - start_time}s"
    )
    if torch.distributed.get_rank() == 0:
        save_artifacts(model, output_dir, use_abspath=True)

    # Need to delete the model and release memory explicitly;
    # otherwise torch may retain its GPU memory until a delayed GC running,
    # which reduces the available GPU memory for subsequent stages.
    del model
    release_gc()
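

# ---------------------------------------------------------------------------
# Illustrative usage sketch (kept fully commented out so the module's runtime
# behavior is unchanged). It shows how `quantize_and_export` above might be
# driven from a caller script. The module path is assumed from the relative
# imports at the top of this file, and the directory paths and sizes below are
# hypothetical placeholders, not values taken from this repository.
#
#   from tensorrt_llm.quantization.quantize_by_modelopt import quantize_and_export
#
#   quantize_and_export(
#       model_dir="/path/to/hf_model",        # hypothetical HF checkpoint dir
#       device="cuda",
#       calib_dataset="cnn_dailymail",
#       dtype="auto",
#       qformat="fp8",
#       kv_cache_dtype="fp8",
#       calib_size=512,
#       batch_size=1,
#       calib_max_seq_length=512,
#       awq_block_size=128,
#       output_dir="/path/to/trtllm_ckpt",    # hypothetical output dir
#       tp_size=1,
#       pp_size=1,
#       cp_size=1,
#       seed=1234,
#       tokenizer_max_seq_length=2048,
#   )
# ---------------------------------------------------------------------------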