diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py
index 9b37f8c7b2..5aa7f7ce70 100644
--- a/examples/llm-api/quickstart_advanced.py
+++ b/examples/llm-api/quickstart_advanced.py
@@ -23,6 +23,11 @@ def add_llm_args(parser):
                         type=str,
                         nargs="+",
                         help="A single or a list of text prompts.")
+    parser.add_argument('--checkpoint_format',
+                        type=str,
+                        default=None,
+                        choices=["HF", "mistral"],
+                        help="Model checkpoint format.")
     # Build config
     parser.add_argument("--max_seq_len",
                         type=int,
@@ -237,6 +242,7 @@ def setup_llm(args, **kwargs):
     llm = LLM(
         model=args.model_dir,
         backend='pytorch',
+        checkpoint_format=args.checkpoint_format,
         disable_overlap_scheduler=args.disable_overlap_scheduler,
         kv_cache_config=kv_cache_config,
         attn_backend=args.attention_backend,
diff --git a/examples/models/core/mistral_large_3/README.md b/examples/models/core/mistral_large_3/README.md
new file mode 100644
index 0000000000..dfd3fd0c28
--- /dev/null
+++ b/examples/models/core/mistral_large_3/README.md
@@ -0,0 +1,53 @@
+# Mistral Large 3
+
+* Set up the model path
+
+```bash
+export mistral_large_3_model_path=
+```
+
+## LLM-only run
+
+* Run Mistral Large 3 with `quickstart_advanced.py`
+
+```bash
+mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py \
+    --model_dir ${mistral_large_3_model_path} \
+    --tp_size 4 \
+    --moe_ep_size 4 \
+    --max_tokens 100 \
+    --checkpoint_format mistral \
+    --moe_backend TRTLLM
+```
+
+* Launch `trtllm-serve` and send a request
+
+```bash
+echo "
+backend: pytorch
+tensor_parallel_size: 4
+moe_expert_parallel_size: 4
+enable_attention_dp: false
+kv_cache_config:
+  enable_block_reuse: true
+checkpoint_format: mistral
+" > serve.yml
+mpirun -n 1 --allow-run-as-root --oversubscribe python3 -m tensorrt_llm.commands.serve serve \
+    ${mistral_large_3_model_path} \
+    --host localhost --port 8001 --backend pytorch \
+    --extra_llm_api_options serve.yml \
+    --tokenizer ${mistral_large_3_model_path} \
+    2>&1 | tee serve_debug.log &
+
+curl http://localhost:8001/v1/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "${mistral_large_3_model_path}",
+        "prompt": "The capital of France is",
+        "max_tokens": 16,
+        "top_k": 16
+    }'
+
+# The result should look similar to:
+{"id":"cmpl-7e342c1d722d4226a1bf3ed35d762c35","object":"text_completion","created":1764061351,"model":"${mistral_large_3_model_path}","choices":[{"index":0,"text":"The capital of France is **Paris**.\n\nParis is the largest city in France and","token_ids":null,"logprobs":null,"context_logits":null,"finish_reason":"length","stop_reason":null,"disaggregated_params":null,"avg_decoded_tokens_per_iter":1.0}],"usage":{"prompt_tokens":7,"total_tokens":23,"completion_tokens":16,"prompt_tokens_details":{"cached_tokens":1}},"prompt_token_ids":null}
+```
diff --git a/requirements.txt b/requirements.txt
index e123aafcde..8f740a9ede 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -75,3 +75,4 @@ numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
 partial_json_parser
 apache-tvm-ffi==0.1.4 # used for reduce nvidia-cutlass-dsl host overhead
 torch-c-dlpack-ext==0.1.3 # used for reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf
+mistral-common==1.8.6
diff --git a/tensorrt_llm/_torch/models/checkpoints/__init__.py b/tensorrt_llm/_torch/models/checkpoints/__init__.py
index 6a7426eb5b..590a4c7ea9 100644
--- a/tensorrt_llm/_torch/models/checkpoints/__init__.py
+++
b/tensorrt_llm/_torch/models/checkpoints/__init__.py @@ -12,11 +12,30 @@ from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper from .hf.weight_loader import HfWeightLoader from .hf.weight_mapper import HfWeightMapper +from .mistral.checkpoint_loader import (MistralCheckpointLoader, + MistralLarge3CheckpointLoader) +from .mistral.config_loader import MistralConfigLoader +from .mistral.weight_mapper import (MistralLarge3WeightMapper, + MistralWeightMapper) __all__ = [ - "HfConfigLoader", "HfWeightLoader", "HfWeightMapper", - "BaseCheckpointLoader", "HfCheckpointLoader", "NemotronHHfWeightMapper", - "Gemma3HfWeightMapper", "MixtralHfWeightMapper", "Llama4HfWeightMapper", - "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper", - "Qwen3NextHfWeightMapper", "LlavaNextHfWeightMapper" + "HfConfigLoader", + "HfWeightLoader", + "HfWeightMapper", + "MistralConfigLoader", + "MistralWeightMapper", + "MistralCheckpointLoader", + "BaseCheckpointLoader", + "HfCheckpointLoader", + "NemotronHHfWeightMapper", + "Gemma3HfWeightMapper", + "MixtralHfWeightMapper", + "Llama4HfWeightMapper", + "Qwen2MoeHfWeightMapper", + "Qwen3MoeHfWeightMapper", + "Qwen2VLHfWeightMapper", + "Qwen3NextHfWeightMapper", + "LlavaNextHfWeightMapper", + "MistralLarge3CheckpointLoader", + "MistralLarge3WeightMapper", ] diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py index 7c24f19ae7..3b1c3af172 100644 --- a/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py @@ -19,6 +19,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping +@register_checkpoint_weight_loader("mistral") @register_checkpoint_weight_loader("HF") class HfWeightLoader(BaseWeightLoader): """ diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/__init__.py b/tensorrt_llm/_torch/models/checkpoints/mistral/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py b/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py new file mode 100644 index 0000000000..433bde665b --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/checkpoint_loader.py @@ -0,0 +1,75 @@ +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader +from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper +from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader +from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import MistralConfigLoader +from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_loader + + +@register_checkpoint_loader("mistral") +class MistralCheckpointLoader(HfCheckpointLoader): + def __init__( + self, + *, + weight_loader: BaseWeightLoader | None = None, + weight_mapper: BaseWeightMapper | None = None, + config_loader: BaseConfigLoader | None = None, + ): + super().__init__( + weight_loader=weight_loader, weight_mapper=weight_mapper, config_loader=config_loader + ) + self._checkpoint_format = "mistral" + self.mm_module_mapping = { + "vision_encoder": "vision_tower", + "pre_mm_projector_norm": "multi_modal_projector.norm", + "vision_language_adapter": "multi_modal_projector", + 
"patch_merger": "multi_modal_projector.patch_merger", + } + + def preprocess_weights(self, weights: dict) -> dict: + """ + Aggregate weights by module + """ + hf_weights = {} + + for key, value in weights.items(): + modules = key.split(".") + + if modules[0] not in self.mm_module_mapping.keys(): + hf_weights["language_model." + key] = value + + else: + modules[0] = self.mm_module_mapping[modules[0]] + hf_weights[".".join(modules)] = value + + return hf_weights + + def inverse_nvfp4_global_scales(self, weights): + for key in weights.keys(): + if "global_scale" in key: + weights[key] = 1.0 / weights[key] + + def load_weights(self, checkpoint_dir: str, **kwargs): + weights = super().weight_loader.load_weights(checkpoint_dir, **kwargs) + weights = self.preprocess_weights(weights) + # The definition of global_scale is different in Mistral, need to inverse the scale + self.inverse_nvfp4_global_scales(weights) + return weights + + def get_default_config_loader(self) -> MistralConfigLoader: + return MistralConfigLoader() + + +@register_checkpoint_loader("mistral_large_3") +class MistralLarge3CheckpointLoader(MistralCheckpointLoader): + def __init__( + self, + *, + weight_loader: BaseWeightLoader | None = None, + weight_mapper: BaseWeightMapper | None = None, + config_loader: BaseConfigLoader | None = None, + ): + super().__init__( + weight_loader=weight_loader, weight_mapper=weight_mapper, config_loader=config_loader + ) + self._checkpoint_format = "mistral_large_3" diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py new file mode 100644 index 0000000000..95e93fdc05 --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py @@ -0,0 +1,314 @@ +import json +from pathlib import Path +from typing import Any + +from transformers import PretrainedConfig, WhisperConfig + +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader +from tensorrt_llm._torch.models.modeling_utils import register_config_loader +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization.mode import QuantAlgo + +################### +# vllm code here +# https://github.com/vllm-project/vllm/blob/48a5fff66e78985a634abac0d8d7f271da744000/vllm/transformers_utils/configs/mistral.py +################### + + +def adapt_config_dict( + config_dict: dict[str, Any], + defaults: dict[str, Any] = {}, +) -> PretrainedConfig: + config_dict = _remap_general_mistral_args(config_dict) + + if bool(config_dict.get("quantization")): + config_dict = _remap_mistral_quantization_args(config_dict) + + is_moe = bool(config_dict.get("moe")) + is_mistral_large_3 = is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0 + if config_dict.get("model_type") == "mamba": + config_dict["architectures"] = ["Mamba2ForCausalLM"] + elif is_moe and is_mistral_large_3: + config_dict = _remap_moe_args(config_dict) + config_dict["model_type"] = "deepseek_v3" + config_dict["architectures"] = ["MistralLarge3ForCausalLM"] + + assert "llama_4_scaling" in config_dict, "MistralLarge3 expect llama4 scaling config." 
+ llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"] + assert all( + [key in config_dict["llama_4_scaling"] for key in llama_4_scaling_config_keys] + ), f"llama_4_scaling config should define the keys: {','.join(llama_4_scaling_config_keys)}" + elif is_moe: + config_dict["architectures"] = ["MixtralForCausalLM"] + else: + config_dict["architectures"] = ["MistralForCausalLM"] + + if bool(config_dict.get("yarn")): + config_dict = _remap_mistral_yarn_args(config_dict) + + if bool(config_dict.get("llama_4_scaling")): + llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"] + assert all( + [key in config_dict["llama_4_scaling"] for key in llama_4_scaling_config_keys] + ), f"llama_4_scaling config should define the keys: {','.join(llama_4_scaling_config_keys)}" + + is_vision = (config_dict.get("multimodal") or {}).get("vision_encoder_args") or config_dict.get( + "vision_encoder" + ) + is_audio = bool( + ((config_dict.get("multimodal") or {}).get("whisper_model_args") or {}).get("encoder_args") + ) + + assert not (is_vision and is_audio), "Vision and audio are mutually exclusive" + + if is_vision: + config_dict = _remap_mistral_vision_args(config_dict) + if is_audio: + config_dict = _remap_mistral_audio_args(config_dict) + + for k, v in defaults.items(): + config_dict.setdefault(k, v) + + config = PretrainedConfig.from_dict(config_dict) + + return config + + +def _remap_mistral_vision_args(config: dict) -> dict: + if config.get("multimodal"): + vision_config = config.pop("multimodal") + else: + vision_config = config.pop("vision_encoder") + + quant_config = config.get("quantization_config") + config = { + "model_type": "pixtral", + "architectures": ["PixtralForConditionalGeneration"], + "text_config": PretrainedConfig.from_dict(config), + "vision_config": PretrainedConfig.from_dict(vision_config), + } + if quant_config: + config["quantization_config"] = quant_config + return config + + +def _remap_mistral_yarn_args(config: dict) -> dict: + yarn_config_map = { + "factor": "factor", + "original_max_position_embeddings": "original_max_position_embeddings", + "beta": "beta_fast", + "alpha": "beta_slow", + "apply_scale": "apply_yarn_scaling", + } + yarn_config = config.get("yarn") or {} + config["rope_parameters"] = { + "rope_type": "yarn", + "mscale_all_dim": 1, + } + + if rope_theta := config.pop("rope_theta", None): + config["rope_parameters"]["rope_theta"] = rope_theta + + for old_name, new_name in yarn_config_map.items(): + if old_name in yarn_config: + config["rope_parameters"][new_name] = yarn_config.pop(old_name) + + assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}" + + return config + + +def _remap_general_mistral_args(config: dict) -> dict: + # Mistral key -> HF key + config_mapping = { + "dim": "hidden_size", + "norm_eps": "rms_norm_eps", + "n_kv_heads": "num_key_value_heads", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "hidden_dim": "intermediate_size", + } + # HF key -> (Mistral key, default value) + top_level_mapping_with_default = { + "model_type": ("model_type", "transformer"), + "hidden_act": ("activation", "silu"), + "tie_word_embeddings": ("tied_embeddings", False), + "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)), + "max_position_embeddings": ("max_position_embeddings", 128_000), + } + + for key, new_key in config_mapping.items(): + if key in config: + config[new_key] = config.pop(key) + + for new_key, (key, default_value) in 
top_level_mapping_with_default.items(): + config[new_key] = config.pop(key, default_value) + + return config + + +def _remap_mistral_quantization_args(config: dict) -> dict: + if config.get("quantization"): + quantization = config.pop("quantization", {}) + if quantization.get("qformat_weight") == "fp8_e4m3": + qscheme_act = quantization.get("qscheme_act") + assert qscheme_act in ("NO_SCALES", "TENSOR", None), ( + "Only NO_SCALES and TENSOR (default) are supported for qscheme_act" + ) + is_dynamic = qscheme_act == "NO_SCALES" + config["quantization_config"] = { + "quant_method": "fp8", + "activation_scheme": "dynamic" if is_dynamic else "static", + } + else: + raise ValueError(f"Found unknown quantization='{quantization}' in config") + + return config + + +def _remap_mistral_audio_args(config: dict) -> dict: + whisper_args = config["multimodal"].pop("whisper_model_args") + encoder_args = whisper_args["encoder_args"] + downsample_args = whisper_args["downsample_args"] + + quant_config = config.get("quantization_config") + config = { + "model_type": "whixtral", + "architectures": ["VoxtralForConditionalGeneration"], + "text_config": PretrainedConfig.from_dict(config), + "audio_config": WhisperConfig( + num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"], + window_size=encoder_args["audio_encoding_args"]["window_size"], + sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"], + hop_length=encoder_args["audio_encoding_args"]["hop_length"], + downsample_factor=downsample_args["downsample_factor"], + d_model=encoder_args["dim"], + encoder_layers=encoder_args["n_layers"], + encoder_ffn_dim=encoder_args["hidden_dim"], + encoder_attention_heads=encoder_args["n_heads"], + vocab_size=encoder_args["vocab_size"], + max_source_positions=encoder_args["max_source_positions"], + is_encoder_decoder=False, # Override WhisperConfig default + ), + } + if quant_config: + config["quantization_config"] = quant_config + return config + + +def _remap_moe_args(config: dict) -> dict: + moe_config_map = { + "route_every_n": "moe_layer_freq", + "first_k_dense_replace": "first_k_dense_replace", + "num_experts_per_tok": "num_experts_per_tok", + "num_experts": "n_routed_experts", + "expert_hidden_dim": "moe_intermediate_size", + "routed_scale": "routed_scaling_factor", + "num_shared_experts": "n_shared_experts", + "num_expert_groups": "n_group", + "num_expert_groups_per_tok": "topk_group", + } + moe_config = config.get("moe", {}) + for old_name, new_name in moe_config_map.items(): + if old_name in moe_config: + value = moe_config.pop(old_name) + config[new_name] = value + + config["topk_method"] = None + config["norm_topk_prob"] = True + config["scoring_func"] = "softmax" + + return config + + +###################### +# End of vllm code +###################### + + +@register_config_loader("mistral") +@register_config_loader("mistral_large_3") +class MistralConfigLoader(BaseConfigLoader): + def _load_mistral_config_dict(self, checkpoint_dir: str, config_file_name: str) -> dict | None: + file_path = Path(checkpoint_dir) / Path(config_file_name) + + if file_path.exists() and file_path.is_file(): + with open(file_path) as file: + return json.load(file) + return None + + # Adaptation of + # https://github.com/vllm-project/vllm/blob/48a5fff66e78985a634abac0d8d7f271da744000/vllm/transformers_utils/config.py#L175 + def _parse_mistral_config(self, checkpoint_dir: str): + config_file_name = "params.json" + + # This function loads a params.json config which + # should be used when loading models in mistral 
format + config_dict = self._load_mistral_config_dict(checkpoint_dir, config_file_name) + if config_dict is None: + raise ValueError( + f"Failed to load '{config_file_name}' config from '{checkpoint_dir}'. " + f"Only local checkpoints are supported for mistral format." + ) + assert isinstance(config_dict, dict) + + if (max_position_embeddings := config_dict.get("max_position_embeddings")) is None: + max_position_embeddings = 128_000 + config_dict["max_position_embeddings"] = max_position_embeddings + + pretrained_config = adapt_config_dict(config_dict) + + # Mistral configs may define sliding_window as list[int]. Convert it + # to int and add the layer_types list[str] to make it HF compatible + if (sliding_window := getattr(pretrained_config, "sliding_window", None)) and isinstance( + sliding_window, list + ): + pattern_repeats = pretrained_config.num_hidden_layers // len(sliding_window) + layer_types = sliding_window * pattern_repeats + pretrained_config.layer_types = [ + "full_attention" if layer_type is None else "sliding_attention" + for layer_type in layer_types + ] + pretrained_config.sliding_window = next(filter(None, sliding_window), None) + + return config_dict, pretrained_config + + def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig: + # Re-write from ModelConfig.from_pretrained + + config_dict, pretrained_config = self._parse_mistral_config(checkpoint_dir) + + # Some checkpoints lack torch_dtype, populate with dtype + pretrained_config.torch_dtype = getattr(pretrained_config, "dtype", None) + quant_config = QuantConfig() + layer_quant_config = None + + hf_quant_config = pretrained_config.quantization_config + if hf_quant_config.get("quant_method") == "compressed-tensors": + if "NVFP4" in hf_quant_config.get("config_groups"): + quant_config.quant_algo = QuantAlgo.NVFP4 + quant_config.group_size = 16 + ignore_list = hf_quant_config.get("ignore", []) + quant_config.exclude_modules = [] + if "re:.*attn.*" in ignore_list: + quant_config.exclude_modules.append("model.layers.*.self_attn.*") + if "re:vision_encoder.*" in ignore_list: + quant_config.exclude_modules.append("vision_encoder*") + if "re:vision_language_adapter.*" in ignore_list: + quant_config.exclude_modules.append("vision_language_adapter*") + + elif "FP8_BLOCK" in hf_quant_config.get("config_groups"): + quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES + quant_config.group_size = 128 + quant_config.exclude_modules = ["*q_a_proj*", "*kv_a_proj_with_mqa*"] + + kwargs.pop("trust_remote_code", None) # ModelConfig does not have this input parameter + model_config = ModelConfig( + pretrained_config=pretrained_config, + quant_config=quant_config, + quant_config_dict=layer_quant_config, + **kwargs, + ) + model_config._frozen = True + return model_config diff --git a/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py new file mode 100644 index 0000000000..28362f1f90 --- /dev/null +++ b/tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py @@ -0,0 +1,131 @@ +from torch import nn + +from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper +from tensorrt_llm._torch.models.modeling_utils import register_mapper + + +@register_mapper("mistral", "MistralForCausalLM") +@register_mapper("mistral", "PixtralForConditionalGeneration") +class MistralWeightMapper(HfWeightMapper): + def __init__(self): + super().__init__() + + self._callbacks.append(self._permute_qk) + + self.pixtral_mapping = { + "wq": "q_proj", + 
"wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "w_in": "linear_1", + "w_out": "linear_2", + } + + self.mistral_llm_mapping = { + "layers": "model.layers", + "attention": "self_attn", + "qscale_act": "input_scale", + "qscale_weight": "weight_scale_inv", + "kv_fake_quantizer.qscale_act": "kv_scale", + "q_fake_quantizer.qscale_act": "attn.q_scale", + "k_fake_quantizer.qscale_act": "k_scale", + "v_fake_quantizer.qscale_act": "v_scale", + "attention_norm": "input_layernorm", + "feed_forward": "mlp", + "ffn_norm": "post_attention_layernorm", + "tok_embeddings": "model.embed_tokens", + "output": "lm_head", + "norm": "model.norm", + # For Eagle3 + "language_model.eagle_linear": "model.fc", + "language_model.layers": "layers", + "language_model.norm": "norm", + } + self.mistral_llm_mapping.update(self.pixtral_mapping) + + # Adapted from: + # https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657 + def rename_by_params_map(self, params_map: dict[str, str], weights: dict) -> dict: + renamed_weights = {} + + for key in list(weights.keys()): + new_key = key + modules = key.split(".") + num_modules = len(modules) + for i in range(num_modules): + item = modules[i] + next_item = modules[i + 1] if i < num_modules - 1 else None + + combined_item = f"{item}.{next_item}" if next_item is not None else None + + if combined_item in params_map: + new_key = new_key.replace(combined_item, params_map[combined_item]) + elif item in params_map: + new_key = new_key.replace(item, params_map[item]) + + renamed_weights[new_key] = weights[key] + + return renamed_weights + + def _permute_qk(self, module: nn.Module, new_name: str, weights: dict): + # Adapted from: + # https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657 + + processed_weights = {} + config = self.config.pretrained_config + + def permute(w, n_heads: int, attn_out: int): + attn_in = config.head_dim * n_heads + + return ( + w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + .transpose(1, 2) + .reshape(attn_in, attn_out) + ) + + # rotary embeds should be sliced + # If using quantized model in mistral format, + # quantization scales (qscale_weight) also need to be sliced + + if new_name in ["k_proj", "q_proj"]: + n_heads = ( + config.num_key_value_heads if new_name == "k_proj" else config.num_attention_heads + ) + + processed_weights["weight"] = permute(weights["weight"], n_heads, config.hidden_size) + + if "qscale_weight" in weights and weights["qscale_weight"].numel() > 1: + processed_weights["qscale_weight"] = permute(weights["qscale_weight"], n_heads, 1) + + return processed_weights + + return weights + + +@register_mapper("mistral_large_3") +@register_mapper("mistral_large_3", "PixtralForConditionalGeneration") +@register_mapper("mistral_large_3", "MistralLarge3ForCausalLM") +class MistralLarge3WeightMapper(MistralWeightMapper): + def __init__(self): + super().__init__() + + self.mistral_llm_mapping.update( + { + "wkv_a_with_mqa": "kv_a_proj_with_mqa", + "wkv_b": "kv_b_proj", + "wq_a": "q_a_proj", + "q_a_norm": "q_a_layernorm", + "wq_b": "q_b_proj", + "kv_a_norm": "kv_a_layernorm", + "k_fake_quantizer.qscale_act": "mla_attn.mla_attn.k_scale", + "q_fake_quantizer.qscale_act": "mla_attn.mla_attn.q_scale", + "v_fake_quantizer.qscale_act": "mla_attn.mla_attn.v_scale", + "gate": "mlp.gate", + "shared_experts": "mlp.shared_experts", + "experts": 
"mlp.experts", + "router_biases": "mlp.gate.e_score_correction_bias", + } + ) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 40fbaa983d..8df4eae706 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -746,17 +746,19 @@ class Deepseekv3MoE(nn.Module): config = model_config.pretrained_config self.top_k = top_k self.use_dp = model_config.mapping.enable_attention_dp - self.gate = DeepseekV3Gate( - hidden_size, - num_experts, - top_k=top_k, - n_group=config.n_group, - topk_group=config.topk_group, - routed_scaling_factor=config.routed_scaling_factor, - dtype=dtype, - fuse_routing_kernel=True, - apply_routing=False, - moe_backend=model_config.moe_backend) + gate_cls = DeepseekV3Gate + if hasattr(model_config.pretrained_config, "gate_cls"): + gate_cls = model_config.pretrained_config.gate_cls + self.gate = gate_cls(hidden_size, + num_experts, + top_k=top_k, + n_group=config.n_group, + topk_group=config.topk_group, + routed_scaling_factor=config.routed_scaling_factor, + dtype=dtype, + fuse_routing_kernel=True, + apply_routing=False, + moe_backend=model_config.moe_backend) self.experts = create_moe( num_experts=num_experts, routing_method=self.gate.routing_method, diff --git a/tensorrt_llm/_torch/models/modeling_mistral.py b/tensorrt_llm/_torch/models/modeling_mistral.py index 9ade4dee22..2667d20d55 100644 --- a/tensorrt_llm/_torch/models/modeling_mistral.py +++ b/tensorrt_llm/_torch/models/modeling_mistral.py @@ -1,6 +1,7 @@ +import copy import dataclasses import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple import torch import torchvision @@ -14,11 +15,16 @@ from tensorrt_llm._torch.attention_backend.interface import ( PositionalEmbeddingParams, RopeParams) from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models import modeling_pixtral +from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import \ + MistralWeightMapper +from tensorrt_llm._torch.models.modeling_mistral_large3 import ( + Mistral3Gate, MistralLarge3ForCausalLM) from tensorrt_llm._torch.models.modeling_multimodal_utils import ( find_input_mm_embeds, fuse_input_embeds, get_multimodal_embeddings) from tensorrt_llm._torch.models.modeling_utils import (DecoderModel, DecoderModelForCausalLM, _load_weights_impl, + filter_weights, register_auto_model) from tensorrt_llm._torch.modules.attention import Attention from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer @@ -52,7 +58,7 @@ class MistralAttention(Attention): def __init__( self, model_config: ModelConfig[MistralConfig], - layer_idx: Optional[int] = None, + layer_idx: int | None = None, ): config = model_config.pretrained_config super().__init__( @@ -111,8 +117,8 @@ class MistralDecoderLayer(DecoderLayer): position_ids: torch.IntTensor, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor] = None, - spec_metadata: Optional[SpecMetadata] = None, + residual: torch.Tensor | None = None, + spec_metadata: SpecMetadata | None = None, **kwargs, ) -> torch.Tensor: if residual is None: @@ -169,11 +175,11 @@ class MistralModel(DecoderModel): def forward( self, attn_metadata: AttentionMetadata, - input_ids: Optional[torch.IntTensor] = None, - position_ids: Optional[torch.IntTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - spec_metadata: Optional[SpecMetadata] = None, - 
lora_params: Optional[Any] = None, + input_ids: torch.IntTensor | None = None, + position_ids: torch.IntTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + spec_metadata: SpecMetadata | None = None, + lora_params: Any | None = None, ) -> torch.Tensor: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( @@ -222,7 +228,7 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor, self, model_path: str, config: PretrainedConfig, - tokenizer: Optional[AutoTokenizer], + tokenizer: AutoTokenizer | None, trust_remote_code: bool = False, **kwargs, ): @@ -264,9 +270,11 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor, @torch.inference_mode() def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams - ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + ) -> Tuple[List[int], ExtraProcessedInputs | None]: images = inputs.get("multi_modal_data", {}).get("image") - do_rescale = self.processor.image_processor.do_rescale + mm_processor_kwargs = inputs.get("mm_processor_kwargs", {}) + do_rescale = getattr(self.processor.image_processor, "do_rescale", + False) if images is not None and isinstance(images[0], torch.Tensor): # The default multimodal input loader will normalize images to [0, 1] when the requested # format is "pt" (pytorch tensors), but not for "pil" (PIL images). @@ -276,6 +284,7 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor, text=inputs["prompt"], images=images, do_rescale=do_rescale, + **mm_processor_kwargs, ) input_ids = processed.pop("input_ids").tolist()[0] # Remaining in `processed`: @@ -331,6 +340,7 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor, @register_auto_model("Mistral3ForConditionalGeneration") +@register_auto_model("PixtralForConditionalGeneration") @register_input_processor( Mistral3InputProcessor, model_type="mistral3", @@ -365,34 +375,48 @@ class Mistral3VLM(PreTrainedModel): config = model_config.pretrained_config super().__init__(config) - self.model_config = model_config - - llm_model_config = self._get_sub_model_config(model_config, - "text_config") - # This is necessary for the auto weight mapper to figure out what it needs. - llm_model_config.pretrained_config.architectures = config.architectures - self.llm = MistralForCausalLM(llm_model_config) - - self._device = "cuda" - # NOTE: current `modelopt` does not support quantizing the vision portion. - vision_model_config = self._get_sub_model_config(model_config, - "vision_config", - quant_config=None) - self._vision_tower = modeling_pixtral.PixtralVisionModel( - vision_model_config) - self._multi_modal_projector = Mistral3MultiModalProjector(model_config) - vision_feature_layer = config.vision_feature_layer + vision_feature_layer = getattr(config, "vision_feature_layer", -1) if vision_feature_layer != -1: raise ValueError( f"Using intermediate layers ({vision_feature_layer}) in the `PixtralVisionModel` " f"is not supported. 
Please use `vision_feature_layer=-1`.") + self._device = "cuda" self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16) - - self._image_token_ids = torch.tensor([config.image_token_index], + image_token_index = getattr( + config, "image_token_index", None) or getattr( + config.vision_config, "image_token_id", None) + self._image_token_ids = torch.tensor([image_token_index], dtype=torch.int32, device=self._device) + + model_config_cp = copy.deepcopy(model_config) + + llm_model_config = self._get_sub_model_config(model_config_cp, + "text_config") + self.model_config = model_config_cp + llm_class = MistralForCausalLM + if llm_model_config.pretrained_config.architectures[ + 0] == "MistralLarge3ForCausalLM": + llm_class = MistralLarge3ForCausalLM + + llm_model_config.pretrained_config.gate_cls = Mistral3Gate + self.llm = llm_class(llm_model_config) + self.model_config.extra_attrs.update(llm_model_config.extra_attrs) + + # NOTE: current `modelopt` does not support quantizing the vision portion. + # NOTE: attn_backend: Pixtral head size not always divisible by 128 + vision_model_config = self._get_sub_model_config(model_config_cp, + "vision_config", + attn_backend="VANILLA", + quant_config=None) + + self._vision_tower = modeling_pixtral.PixtralVisionModel( + vision_model_config) + self._multi_modal_projector = Mistral3MultiModalProjector( + model_config).eval().to(self._device) self._post_config() + self.is_loaded = True # This is necessary because the executor looks at # `model.model_config.pretrained_config.vocab_size`. @@ -400,18 +424,39 @@ class Mistral3VLM(PreTrainedModel): self.config = self.llm.config self.model_config.pretrained_config = self.llm.config - def load_weights(self, weights: Dict, *args, **kwargs): - llm_weights = _filter_weights(weights, "language_model.") - self.llm.load_weights(llm_weights, *args, **kwargs) + def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs): + vit_params_map = None + if weight_mapper: + if isinstance(weight_mapper, MistralWeightMapper): + vit_params_map = weight_mapper.pixtral_mapping - vit_weights = _filter_weights(weights, "vision_tower.") - self._vision_tower.load_weights(vit_weights, *args, **kwargs) + llm_weights = filter_weights(weights=weights, prefix="language_model") + logger.debug(f"Loading weights for {type(self.llm)}") + self.llm.load_weights(llm_weights) + logger.debug(f"Successfully loaded weights for {type(self.llm)}") - mm_projector_weights = _filter_weights(weights, - "multi_modal_projector.") - # `_load_weights_impl` assumes `config.hidden_size` exists, which is not the case for the - # top-level `Mistral3Config`. 
+ vit_weights = filter_weights(weights=weights, prefix="vision_tower") + logger.debug(f"Loading weights for {type(self._vision_tower)}") + + if vit_params_map is not None: + vit_weights = weight_mapper.rename_by_params_map( + weights=vit_weights, params_map=vit_params_map) + + self._vision_tower.load_weights(vit_weights) + logger.debug( + f"Successfully loaded weights for {type(self._vision_tower)}") + + logger.debug(f"Loading weights for {type(self._multi_modal_projector)}") + mm_projector_weights = filter_weights(weights=weights, + prefix="multi_modal_projector") + + if vit_params_map is not None: + mm_projector_weights = weight_mapper.rename_by_params_map( + weights=mm_projector_weights, params_map=vit_params_map) self._multi_modal_projector.load_state_dict(mm_projector_weights) + logger.debug( + f"Successfully loaded weights for {type(self._multi_modal_projector)}" + ) def infer_max_seq_len(self) -> int: return self.llm.infer_max_seq_len() @@ -420,9 +465,10 @@ class Mistral3VLM(PreTrainedModel): def forward( self, attn_metadata: AttentionMetadata, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, return_context_logits: bool = False, + spec_metadata: SpecMetadata | None = None, **kwargs, ) -> torch.Tensor: """Forward method.""" @@ -455,6 +501,7 @@ class Mistral3VLM(PreTrainedModel): position_ids=position_ids, inputs_embeds=inputs_embeds, return_context_logits=return_context_logits, + spec_metadata=spec_metadata, ) @staticmethod @@ -465,16 +512,41 @@ class Mistral3VLM(PreTrainedModel): ) -> ModelConfig: # Extract the subconfig from the `transformers` config and shove it into our own # `ModelConfig` class. + assert name in [ + "text_config", "vision_config" + ], f"Expected subconfig name to be either 'text_config' or 'vision_config'. Got {name} instead." + pretrained_config = getattr(model_config.pretrained_config, name) + sub_model_config: ModelConfig[MistralConfig] = dataclasses.replace( model_config, pretrained_config=getattr(model_config.pretrained_config, name), **changes, ) + if name == "text_config": + sub_model_config._frozen = False + sub_model_config.skip_create_weights_in_init = True + if not hasattr( + sub_model_config.pretrained_config, "architectures" + ) or sub_model_config.pretrained_config.architectures is None: + sub_model_config.pretrained_config.architectures = model_config.pretrained_config.architectures + sub_model_config._frozen = True + # Make sure some fields that are not explicitly included in the sub config, but present # in the top-level config, are replicated. 
if (hasattr(sub_model_config.pretrained_config, "torch_dtype") and sub_model_config.pretrained_config.torch_dtype is None): - sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype + sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype or torch.bfloat16 + + if name == "vision_config": + pretrained_config = sub_model_config.pretrained_config + defaults = { + "head_dim": pretrained_config.hidden_size // + pretrained_config.num_attention_heads, + "hidden_act": "silu", + } + for attr, default in defaults.items(): + if not hasattr(pretrained_config, attr): + setattr(pretrained_config, attr, default) return sub_model_config @@ -572,6 +644,12 @@ class Mistral3VLM(PreTrainedModel): def mm_token_ids(self): return self._image_token_ids + def load_draft_weights( + self, + weights: Dict, + weight_mapper: MistralWeightMapper | None = None) -> None: + self.llm.load_draft_weights(weights, weight_mapper=weight_mapper) + # Original implementation: # https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/modeling_mistral3.py#L66 @@ -586,13 +664,15 @@ class Mistral3PatchMerger(torch.nn.Module): self.config = config hidden_size = config.vision_config.hidden_size - self._spatial_merge_size = config.spatial_merge_size + self._spatial_merge_size = getattr( + config, "spatial_merge_size", None) or getattr( + config.vision_config, "spatial_merge_size") self._patch_size = config.vision_config.patch_size self.merging_layer = Linear( in_features=hidden_size * self._spatial_merge_size**2, out_features=hidden_size, bias=False, - dtype=config.torch_dtype, + dtype=config.torch_dtype or model_config.torch_dtype, mapping=model_config.mapping, ) @@ -640,7 +720,7 @@ class Mistral3MultiModalProjector(torch.nn.Module): self.model_config = model_config self.config = config - dtype = config.torch_dtype + dtype = config.torch_dtype or model_config.torch_dtype self.norm = RMSNorm( hidden_size=config.vision_config.hidden_size, # NOTE: the original implementation actually does not look at the config for this value. 
@@ -650,21 +730,21 @@ class Mistral3MultiModalProjector(torch.nn.Module): ) self.patch_merger = Mistral3PatchMerger(model_config) # We have hidden_size * the number of vision feature layers - num_feature_layers = 1 if isinstance(config.vision_feature_layer, - int) else len( - config.vision_feature_layer) + vision_feature_layer = getattr(config, "vision_feature_layer", -1) + num_feature_layers = 1 if isinstance(vision_feature_layer, + int) else len(vision_feature_layer) self.linear_1 = Linear( in_features=config.vision_config.hidden_size * num_feature_layers, out_features=config.text_config.hidden_size, - bias=config.multimodal_projector_bias, + bias=getattr(config, "multimodal_projector_bias", None), dtype=dtype, mapping=model_config.mapping, ) - self.act = ACT2FN[config.projector_hidden_act] + self.act = ACT2FN[getattr(config, "projector_hidden_act", "gelu")] self.linear_2 = Linear( in_features=config.text_config.hidden_size, out_features=config.text_config.hidden_size, - bias=config.multimodal_projector_bias, + bias=getattr(config, "multimodal_projector_bias", None), dtype=dtype, mapping=model_config.mapping, ) diff --git a/tensorrt_llm/_torch/models/modeling_mistral_large3.py b/tensorrt_llm/_torch/models/modeling_mistral_large3.py new file mode 100644 index 0000000000..c88cebdf05 --- /dev/null +++ b/tensorrt_llm/_torch/models/modeling_mistral_large3.py @@ -0,0 +1,70 @@ +from typing import Dict, List + +import torch +from torch import nn + +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import MistralLarge3WeightMapper +from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3ForCausalLM +from tensorrt_llm._torch.models.modeling_utils import register_auto_model +from tensorrt_llm._torch.modules.fused_moe import RenormalizeNaiveMoeRoutingMethod +from tensorrt_llm.quantization.mode import QuantAlgo + + +class Mistral3Gate(nn.Module): + def __init__( + self, + hidden_size: int, + num_experts: int, + top_k: int, + dtype: torch.dtype | None = None, + **kwargs, + ): + super().__init__() + self.weight = nn.Parameter( + torch.empty((num_experts, hidden_size), dtype=dtype), requires_grad=False + ) + self.top_k = top_k + self.dtype = dtype + self.routing_method = RenormalizeNaiveMoeRoutingMethod(top_k=self.top_k) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits: torch.Tensor = torch.ops.trtllm.cublas_mm( + hidden_states, self.weight.t(), bias=None, out_dtype=self.dtype + ) + return logits + + def load_weights(self, weights: List[Dict]): + assert len(weights) == 1 + + self.weight.copy_(weights[0]["weight"][:]) + + +@register_auto_model("MistralLarge3ForCausalLM") +class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM): + def __init__(self, model_config: ModelConfig): + super().__init__(model_config) + self.weight_mapper = MistralLarge3WeightMapper() + + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + + def load_weights(self, weights: Dict): + assert self.model_config is not None, "self.model_config is required" + params_map = self.weight_mapper.mistral_llm_mapping.copy() + quantization_weights_map: Dict[str, str] = {} + if self.model_config.quant_config.quant_algo == QuantAlgo.NVFP4: + quantization_weights_map = { + "weight_packed": "weight", + "input_global_scale": "input_scale", + "weight_global_scale": "weight_scale_2", + } + elif self.model_config.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES: + quantization_weights_map = { + "weight_scale": 
"weight_scale_inv", + } + if quantization_weights_map: + params_map.update(quantization_weights_map) + weights = self.weight_mapper.rename_by_params_map(weights=weights, params_map=params_map) + + super().load_weights(weights) diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py index 6013d51fa2..e4fa9da6e6 100644 --- a/tensorrt_llm/_torch/pyexecutor/config_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py @@ -38,14 +38,22 @@ _CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDic def load_pretrained_config(model_name_or_path: str, trust_remote_code: bool = False, + checkpoint_format: str = None, **kwargs) -> transformers.PretrainedConfig: config_dict, _ = transformers.PretrainedConfig.get_config_dict( model_name_or_path, **kwargs) model_type = config_dict.get("model_type") + if model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[model_type] model_config = config_class.from_pretrained(model_name_or_path, **kwargs) + elif checkpoint_format in ("mistral", "mistral_large_3"): + from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \ + MistralConfigLoader + model_config = getattr( + MistralConfigLoader().load(model_name_or_path).pretrained_config, + "text_config") else: model_config = transformers.AutoConfig.from_pretrained( model_name_or_path, trust_remote_code=trust_remote_code) diff --git a/tensorrt_llm/inputs/registry.py b/tensorrt_llm/inputs/registry.py index 7737600e6f..54902a5ba3 100644 --- a/tensorrt_llm/inputs/registry.py +++ b/tensorrt_llm/inputs/registry.py @@ -600,6 +600,12 @@ def create_input_processor( logger.debug( f"Unable to load HF config from {model_path_or_dir}: {e}. Falling back." ) + elif checkpoint_format in ("mistral", "mistral_large_3"): + logger.debug(f"Detected checkpoint_format={checkpoint_format}.") + from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \ + MistralConfigLoader + model_config = MistralConfigLoader().load(model_path_or_dir) + config = model_config.pretrained_config else: logger.debug( f"checkpoint_format={checkpoint_format}; skipping HF config load.") diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 70285d0aea..c9699bb91f 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -112,7 +112,8 @@ class OpenAIServer: from tensorrt_llm._torch.pyexecutor.config_utils import \ load_pretrained_config self.model_config = load_pretrained_config(hf_tokenizer_path, - trust_remote_code=trust_remote_code) + trust_remote_code=trust_remote_code, + checkpoint_format=getattr(self.llm.args, "checkpoint_format", None)) except Exception: logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path) self.model_config = None diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index c62ff5a0d8..33f7dddc6b 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -281,3 +281,5 @@ bigcode/starcoder2-7b: - accuracy: 26.5 bigcode/starcoder2-15b: - accuracy: 54.5 +mistral/Mistral-Large-3-675B: + - accuracy: 90.83 diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml index dd404ba8f7..f728919abe 100644 --- a/tests/integration/defs/accuracy/references/mmlu.yaml +++ b/tests/integration/defs/accuracy/references/mmlu.yaml @@ -340,3 +340,5 @@ 
mistralai/Mistral-Nemo-12b-Base: - accuracy: 69.66 - quant_algo: FP8 accuracy: 69.66 +mistral/Mistral-Large-3-675B: + - accuracy: 87.54 diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index f4bb84ae63..3b667b15c9 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -4828,3 +4828,107 @@ class TestLlama3_1_8B_Instruct_RocketKV(LlmapiAccuracyTestHarness): task.evaluate(llm, sampling_params=sampling_params, extra_evaluator_kwargs=extra_evaluator_kwargs) + + +class TestMistralLarge3_675B(LlmapiAccuracyTestHarness): + MODEL_NAME = "mistral/Mistral-Large-3-675B" + + @skip_pre_blackwell + @pytest.mark.skip_less_mpi_world_size(4) + @pytest.mark.skip_less_device_memory(183000) + @pytest.mark.parametrize( + "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend,eagle3", + [ + (4, 1, 4, False, True, True, "TRTLLM", False), + ], + ids=[ + "latency_moe_trtllm", + ], + ) + def test_nvfp4_4gpus(self, tp_size, pp_size, ep_size, attention_dp, + cuda_graph, overlap_scheduler, moe_backend, eagle3): + + if moe_backend == "TRTLLM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE TRTLLM backend does not support SM version 120 or 121") + + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + moe_config=MoeConfig(backend=moe_backend)) + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4, + enable_block_reuse=not eagle3) + spec_config = None + if eagle3: + spec_config = EagleDecodingConfig( + max_draft_len=2, + speculative_model_dir= + f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/", + eagle3_one_model=True) + with LLM( + f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-NVFP4/", + checkpoint_format="mistral", + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + **pytorch_config, + enable_attention_dp=attention_dp, + kv_cache_config=kv_cache_config, + speculative_config=spec_config) as llm: + + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + @skip_pre_blackwell + @pytest.mark.skip_less_mpi_world_size(8) + @pytest.mark.skip_less_device_memory(183000) + @pytest.mark.parametrize( + "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend,eagle3", + [ + (8, 1, 8, False, True, True, "DEEPGEMM", False), + ], + ids=[ + "latency_moe_deepgemm", + ], + ) + def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph, + overlap_scheduler, moe_backend, eagle3): + + if moe_backend == "DEEPGEMM" and (get_sm_version() == 120 + or get_sm_version() == 121): + pytest.skip( + "MOE DEEPGEMM backend does not support SM version 120 or 121") + + pytorch_config = dict( + disable_overlap_scheduler=not overlap_scheduler, + cuda_graph_config=CudaGraphConfig() if cuda_graph else None, + moe_config=MoeConfig(backend=moe_backend)) + + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4, + enable_block_reuse=not eagle3) + spec_config = None + if eagle3: + spec_config = EagleDecodingConfig( + max_draft_len=2, + speculative_model_dir= + f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/", + eagle3_one_model=True) + with LLM( + 
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512/", + checkpoint_format="mistral", + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + moe_expert_parallel_size=ep_size, + **pytorch_config, + enable_attention_dp=attention_dp, + kv_cache_config=kv_cache_config, + speculative_config=spec_config) as llm: + + task = MMLU(self.MODEL_NAME) + task.evaluate(llm) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index ccd23bdf08..f54045dd16 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -109,6 +109,7 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv] TIMEOUT (180) - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency] TIMEOUT (180) - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv] TIMEOUT (180) + - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_fp8[latency_moe_deepgemm] TIMEOUT (90) - condition: ranges: system_gpu_count: diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index 447e989f54..b53a64c61b 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -91,3 +91,4 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[enable_configurable_moe-dp4-trtllm-fp8] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[enable_configurable_moe-fp8] - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90) + - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_fp8[latency_moe_deepgemm] TIMEOUT (90)