Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][feat] Support Mistral Large3 LLM part (#9820)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Parent: 98d72c7648
Commit: e49c70f6df
@ -23,6 +23,11 @@ def add_llm_args(parser):
                        type=str,
                        nargs="+",
                        help="A single or a list of text prompts.")
    parser.add_argument('--checkpoint_format',
                        type=str,
                        default=None,
                        choices=["HF", "mistral"],
                        help="Model checkpoint format.")
    # Build config
    parser.add_argument("--max_seq_len",
                        type=int,
@ -237,6 +242,7 @@ def setup_llm(args, **kwargs):
    llm = LLM(
        model=args.model_dir,
        backend='pytorch',
        checkpoint_format=args.checkpoint_format,
        disable_overlap_scheduler=args.disable_overlap_scheduler,
        kv_cache_config=kv_cache_config,
        attn_backend=args.attention_backend,
examples/models/core/mistral_large_3/README.md (new file, 53 lines)
@ -0,0 +1,53 @@
# Mistral Large V3

* Set up the model path

```bash
export mistral_large_3_model_path=<mistral_large_3_model_path>
```

## LLM-only run

* Run Mistral Large V3 with `quickstart_advanced.py`

```bash
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py \
    --model_dir ${mistral_large_3_model_path} \
    --tp_size 4 \
    --moe_ep_size 4 \
    --max_tokens 100 \
    --checkpoint_format mistral \
    --moe_backend TRTLLM
```

* Launch `trtllm-serve` and send a request

```bash
echo "
backend: pytorch
tensor_parallel_size: 4
moe_expert_parallel_size: 4
enable_attention_dp: false
kv_cache_config:
  enable_block_reuse: true
checkpoint_format: mistral
" > serve.yml
mpirun -n 1 --allow-run-as-root --oversubscribe python3 -m tensorrt_llm.commands.serve serve \
    ${mistral_large_3_model_path} \
    --host localhost --port 8001 --backend pytorch \
    --extra_llm_api_options serve.yml \
    --tokenizer ${mistral_large_3_model_path} \
    2>&1 | tee serve_debug.log &

curl http://localhost:8001/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "${mistral_large_3_model_path}",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "top_k": 16
    }'

# The result should look like:
{"id":"cmpl-7e342c1d722d4226a1bf3ed35d762c35","object":"text_completion","created":1764061351,"model":"${mistral_large_3_model_path}","choices":[{"index":0,"text":"The capital of France is **Paris**.\n\nParis is the largest city in France and","token_ids":null,"logprobs":null,"context_logits":null,"finish_reason":"length","stop_reason":null,"disaggregated_params":null,"avg_decoded_tokens_per_iter":1.0}],"usage":{"prompt_tokens":7,"total_tokens":23,"completion_tokens":16,"prompt_tokens_details":{"cached_tokens":1}},"prompt_token_ids":null}
```
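For reference, the same flow can also be driven from the LLM API directly rather than through the quickstart script or `trtllm-serve`. The snippet below is a minimal, editor-added sketch based on the `LLM(...)` call introduced by this commit; the checkpoint path is a placeholder and the sampling settings are simplified.

```python
# Minimal sketch of the LLM-API path exercised by quickstart_advanced.py.
# The model path below is a placeholder; adjust parallel sizes to your setup.
from tensorrt_llm import LLM, SamplingParams

llm = LLM(
    model="/path/to/mistral_large_3_checkpoint",  # placeholder path
    backend="pytorch",
    checkpoint_format="mistral",  # new argument added by this commit
    tensor_parallel_size=4,
    moe_expert_parallel_size=4,
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(max_tokens=16),
)
print(outputs[0].outputs[0].text)
```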
@ -75,3 +75,4 @@ numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
partial_json_parser
apache-tvm-ffi==0.1.4 # used to reduce nvidia-cutlass-dsl host overhead
torch-c-dlpack-ext==0.1.3 # used to reduce nvidia-cutlass-dsl host overhead; optional package for improved torch tensor calling perf
mistral-common==1.8.6
@ -12,11 +12,30 @@ from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper
from .hf.weight_loader import HfWeightLoader
from .hf.weight_mapper import HfWeightMapper
from .mistral.checkpoint_loader import (MistralCheckpointLoader,
                                        MistralLarge3CheckpointLoader)
from .mistral.config_loader import MistralConfigLoader
from .mistral.weight_mapper import (MistralLarge3WeightMapper,
                                    MistralWeightMapper)

__all__ = [
    "HfConfigLoader", "HfWeightLoader", "HfWeightMapper",
    "BaseCheckpointLoader", "HfCheckpointLoader", "NemotronHHfWeightMapper",
    "Gemma3HfWeightMapper", "MixtralHfWeightMapper", "Llama4HfWeightMapper",
    "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper",
    "Qwen3NextHfWeightMapper", "LlavaNextHfWeightMapper"
    "HfConfigLoader",
    "HfWeightLoader",
    "HfWeightMapper",
    "MistralConfigLoader",
    "MistralWeightMapper",
    "MistralCheckpointLoader",
    "BaseCheckpointLoader",
    "HfCheckpointLoader",
    "NemotronHHfWeightMapper",
    "Gemma3HfWeightMapper",
    "MixtralHfWeightMapper",
    "Llama4HfWeightMapper",
    "Qwen2MoeHfWeightMapper",
    "Qwen3MoeHfWeightMapper",
    "Qwen2VLHfWeightMapper",
    "Qwen3NextHfWeightMapper",
    "LlavaNextHfWeightMapper",
    "MistralLarge3CheckpointLoader",
    "MistralLarge3WeightMapper",
]
@ -19,6 +19,7 @@ from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping


@register_checkpoint_weight_loader("mistral")
@register_checkpoint_weight_loader("HF")
class HfWeightLoader(BaseWeightLoader):
    """
@ -0,0 +1,75 @@
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader
from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper
from tensorrt_llm._torch.models.checkpoints.hf.checkpoint_loader import HfCheckpointLoader
from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import MistralConfigLoader
from tensorrt_llm._torch.models.modeling_utils import register_checkpoint_loader


@register_checkpoint_loader("mistral")
class MistralCheckpointLoader(HfCheckpointLoader):
    def __init__(
        self,
        *,
        weight_loader: BaseWeightLoader | None = None,
        weight_mapper: BaseWeightMapper | None = None,
        config_loader: BaseConfigLoader | None = None,
    ):
        super().__init__(
            weight_loader=weight_loader, weight_mapper=weight_mapper, config_loader=config_loader
        )
        self._checkpoint_format = "mistral"
        self.mm_module_mapping = {
            "vision_encoder": "vision_tower",
            "pre_mm_projector_norm": "multi_modal_projector.norm",
            "vision_language_adapter": "multi_modal_projector",
            "patch_merger": "multi_modal_projector.patch_merger",
        }

    def preprocess_weights(self, weights: dict) -> dict:
        """
        Aggregate weights by module.
        """
        hf_weights = {}

        for key, value in weights.items():
            modules = key.split(".")

            if modules[0] not in self.mm_module_mapping.keys():
                hf_weights["language_model." + key] = value

            else:
                modules[0] = self.mm_module_mapping[modules[0]]
                hf_weights[".".join(modules)] = value

        return hf_weights

    def inverse_nvfp4_global_scales(self, weights):
        for key in weights.keys():
            if "global_scale" in key:
                weights[key] = 1.0 / weights[key]

    def load_weights(self, checkpoint_dir: str, **kwargs):
        weights = super().weight_loader.load_weights(checkpoint_dir, **kwargs)
        weights = self.preprocess_weights(weights)
        # The definition of global_scale differs in Mistral checkpoints, so the scale needs to be inverted.
        self.inverse_nvfp4_global_scales(weights)
        return weights

    def get_default_config_loader(self) -> MistralConfigLoader:
        return MistralConfigLoader()


@register_checkpoint_loader("mistral_large_3")
class MistralLarge3CheckpointLoader(MistralCheckpointLoader):
    def __init__(
        self,
        *,
        weight_loader: BaseWeightLoader | None = None,
        weight_mapper: BaseWeightMapper | None = None,
        config_loader: BaseConfigLoader | None = None,
    ):
        super().__init__(
            weight_loader=weight_loader, weight_mapper=weight_mapper, config_loader=config_loader
        )
        self._checkpoint_format = "mistral_large_3"
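To make the effect of `preprocess_weights` concrete, here is a small, self-contained sketch of the same remapping on a toy state dict. The mapping table is copied from `mm_module_mapping` above; the key names and tensor values are illustrative placeholders only.

```python
# Standalone sketch of the key remapping done by
# MistralCheckpointLoader.preprocess_weights (toy inputs, no real tensors).
mm_module_mapping = {
    "vision_encoder": "vision_tower",
    "pre_mm_projector_norm": "multi_modal_projector.norm",
    "vision_language_adapter": "multi_modal_projector",
    "patch_merger": "multi_modal_projector.patch_merger",
}

raw_weights = {
    "layers.0.attention.wq.weight": "tensor_a",      # illustrative LLM key
    "vision_encoder.patch_conv.weight": "tensor_b",  # illustrative vision key
}

hf_weights = {}
for key, value in raw_weights.items():
    modules = key.split(".")
    if modules[0] not in mm_module_mapping:
        # LLM weights get the "language_model." prefix.
        hf_weights["language_model." + key] = value
    else:
        # Multimodal weights are renamed to their HF-style module names.
        modules[0] = mm_module_mapping[modules[0]]
        hf_weights[".".join(modules)] = value

# hf_weights == {
#     "language_model.layers.0.attention.wq.weight": "tensor_a",
#     "vision_tower.patch_conv.weight": "tensor_b",
# }
```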
tensorrt_llm/_torch/models/checkpoints/mistral/config_loader.py (new file, 314 lines)
@ -0,0 +1,314 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from transformers import PretrainedConfig, WhisperConfig
|
||||
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader
|
||||
from tensorrt_llm._torch.models.modeling_utils import register_config_loader
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
###################
|
||||
# The helpers below are adapted from vLLM:
|
||||
# https://github.com/vllm-project/vllm/blob/48a5fff66e78985a634abac0d8d7f271da744000/vllm/transformers_utils/configs/mistral.py
|
||||
###################
|
||||
|
||||
|
||||
def adapt_config_dict(
|
||||
config_dict: dict[str, Any],
|
||||
defaults: dict[str, Any] = {},
|
||||
) -> PretrainedConfig:
|
||||
config_dict = _remap_general_mistral_args(config_dict)
|
||||
|
||||
if bool(config_dict.get("quantization")):
|
||||
config_dict = _remap_mistral_quantization_args(config_dict)
|
||||
|
||||
is_moe = bool(config_dict.get("moe"))
|
||||
is_mistral_large_3 = is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
|
||||
if config_dict.get("model_type") == "mamba":
|
||||
config_dict["architectures"] = ["Mamba2ForCausalLM"]
|
||||
elif is_moe and is_mistral_large_3:
|
||||
config_dict = _remap_moe_args(config_dict)
|
||||
config_dict["model_type"] = "deepseek_v3"
|
||||
config_dict["architectures"] = ["MistralLarge3ForCausalLM"]
|
||||
|
||||
assert "llama_4_scaling" in config_dict, "MistralLarge3 expect llama4 scaling config."
|
||||
llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
|
||||
assert all(
|
||||
[key in config_dict["llama_4_scaling"] for key in llama_4_scaling_config_keys]
|
||||
), f"llama_4_scaling config should define the keys: {','.join(llama_4_scaling_config_keys)}"
|
||||
elif is_moe:
|
||||
config_dict["architectures"] = ["MixtralForCausalLM"]
|
||||
else:
|
||||
config_dict["architectures"] = ["MistralForCausalLM"]
|
||||
|
||||
if bool(config_dict.get("yarn")):
|
||||
config_dict = _remap_mistral_yarn_args(config_dict)
|
||||
|
||||
if bool(config_dict.get("llama_4_scaling")):
|
||||
llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
|
||||
assert all(
|
||||
[key in config_dict["llama_4_scaling"] for key in llama_4_scaling_config_keys]
|
||||
), f"llama_4_scaling config should define the keys: {','.join(llama_4_scaling_config_keys)}"
|
||||
|
||||
is_vision = (config_dict.get("multimodal") or {}).get("vision_encoder_args") or config_dict.get(
|
||||
"vision_encoder"
|
||||
)
|
||||
is_audio = bool(
|
||||
((config_dict.get("multimodal") or {}).get("whisper_model_args") or {}).get("encoder_args")
|
||||
)
|
||||
|
||||
assert not (is_vision and is_audio), "Vision and audio are mutually exclusive"
|
||||
|
||||
if is_vision:
|
||||
config_dict = _remap_mistral_vision_args(config_dict)
|
||||
if is_audio:
|
||||
config_dict = _remap_mistral_audio_args(config_dict)
|
||||
|
||||
for k, v in defaults.items():
|
||||
config_dict.setdefault(k, v)
|
||||
|
||||
config = PretrainedConfig.from_dict(config_dict)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _remap_mistral_vision_args(config: dict) -> dict:
|
||||
if config.get("multimodal"):
|
||||
vision_config = config.pop("multimodal")
|
||||
else:
|
||||
vision_config = config.pop("vision_encoder")
|
||||
|
||||
quant_config = config.get("quantization_config")
|
||||
config = {
|
||||
"model_type": "pixtral",
|
||||
"architectures": ["PixtralForConditionalGeneration"],
|
||||
"text_config": PretrainedConfig.from_dict(config),
|
||||
"vision_config": PretrainedConfig.from_dict(vision_config),
|
||||
}
|
||||
if quant_config:
|
||||
config["quantization_config"] = quant_config
|
||||
return config
|
||||
|
||||
|
||||
def _remap_mistral_yarn_args(config: dict) -> dict:
|
||||
yarn_config_map = {
|
||||
"factor": "factor",
|
||||
"original_max_position_embeddings": "original_max_position_embeddings",
|
||||
"beta": "beta_fast",
|
||||
"alpha": "beta_slow",
|
||||
"apply_scale": "apply_yarn_scaling",
|
||||
}
|
||||
yarn_config = config.get("yarn") or {}
|
||||
config["rope_parameters"] = {
|
||||
"rope_type": "yarn",
|
||||
"mscale_all_dim": 1,
|
||||
}
|
||||
|
||||
if rope_theta := config.pop("rope_theta", None):
|
||||
config["rope_parameters"]["rope_theta"] = rope_theta
|
||||
|
||||
for old_name, new_name in yarn_config_map.items():
|
||||
if old_name in yarn_config:
|
||||
config["rope_parameters"][new_name] = yarn_config.pop(old_name)
|
||||
|
||||
assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _remap_general_mistral_args(config: dict) -> dict:
|
||||
# Mistral key -> HF key
|
||||
config_mapping = {
|
||||
"dim": "hidden_size",
|
||||
"norm_eps": "rms_norm_eps",
|
||||
"n_kv_heads": "num_key_value_heads",
|
||||
"n_layers": "num_hidden_layers",
|
||||
"n_heads": "num_attention_heads",
|
||||
"hidden_dim": "intermediate_size",
|
||||
}
|
||||
# HF key -> (Mistral key, default value)
|
||||
top_level_mapping_with_default = {
|
||||
"model_type": ("model_type", "transformer"),
|
||||
"hidden_act": ("activation", "silu"),
|
||||
"tie_word_embeddings": ("tied_embeddings", False),
|
||||
"max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
|
||||
"max_position_embeddings": ("max_position_embeddings", 128_000),
|
||||
}
|
||||
|
||||
for key, new_key in config_mapping.items():
|
||||
if key in config:
|
||||
config[new_key] = config.pop(key)
|
||||
|
||||
for new_key, (key, default_value) in top_level_mapping_with_default.items():
|
||||
config[new_key] = config.pop(key, default_value)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _remap_mistral_quantization_args(config: dict) -> dict:
|
||||
if config.get("quantization"):
|
||||
quantization = config.pop("quantization", {})
|
||||
if quantization.get("qformat_weight") == "fp8_e4m3":
|
||||
qscheme_act = quantization.get("qscheme_act")
|
||||
assert qscheme_act in ("NO_SCALES", "TENSOR", None), (
|
||||
"Only NO_SCALES and TENSOR (default) are supported for qscheme_act"
|
||||
)
|
||||
is_dynamic = qscheme_act == "NO_SCALES"
|
||||
config["quantization_config"] = {
|
||||
"quant_method": "fp8",
|
||||
"activation_scheme": "dynamic" if is_dynamic else "static",
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Found unknown quantization='{quantization}' in config")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _remap_mistral_audio_args(config: dict) -> dict:
|
||||
whisper_args = config["multimodal"].pop("whisper_model_args")
|
||||
encoder_args = whisper_args["encoder_args"]
|
||||
downsample_args = whisper_args["downsample_args"]
|
||||
|
||||
quant_config = config.get("quantization_config")
|
||||
config = {
|
||||
"model_type": "whixtral",
|
||||
"architectures": ["VoxtralForConditionalGeneration"],
|
||||
"text_config": PretrainedConfig.from_dict(config),
|
||||
"audio_config": WhisperConfig(
|
||||
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
|
||||
window_size=encoder_args["audio_encoding_args"]["window_size"],
|
||||
sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
|
||||
hop_length=encoder_args["audio_encoding_args"]["hop_length"],
|
||||
downsample_factor=downsample_args["downsample_factor"],
|
||||
d_model=encoder_args["dim"],
|
||||
encoder_layers=encoder_args["n_layers"],
|
||||
encoder_ffn_dim=encoder_args["hidden_dim"],
|
||||
encoder_attention_heads=encoder_args["n_heads"],
|
||||
vocab_size=encoder_args["vocab_size"],
|
||||
max_source_positions=encoder_args["max_source_positions"],
|
||||
is_encoder_decoder=False, # Override WhisperConfig default
|
||||
),
|
||||
}
|
||||
if quant_config:
|
||||
config["quantization_config"] = quant_config
|
||||
return config
|
||||
|
||||
|
||||
def _remap_moe_args(config: dict) -> dict:
|
||||
moe_config_map = {
|
||||
"route_every_n": "moe_layer_freq",
|
||||
"first_k_dense_replace": "first_k_dense_replace",
|
||||
"num_experts_per_tok": "num_experts_per_tok",
|
||||
"num_experts": "n_routed_experts",
|
||||
"expert_hidden_dim": "moe_intermediate_size",
|
||||
"routed_scale": "routed_scaling_factor",
|
||||
"num_shared_experts": "n_shared_experts",
|
||||
"num_expert_groups": "n_group",
|
||||
"num_expert_groups_per_tok": "topk_group",
|
||||
}
|
||||
moe_config = config.get("moe", {})
|
||||
for old_name, new_name in moe_config_map.items():
|
||||
if old_name in moe_config:
|
||||
value = moe_config.pop(old_name)
|
||||
config[new_name] = value
|
||||
|
||||
config["topk_method"] = None
|
||||
config["norm_topk_prob"] = True
|
||||
config["scoring_func"] = "softmax"
|
||||
|
||||
return config
|
||||
|
||||
|
||||
######################
|
||||
# End of vllm code
|
||||
######################
|
||||
|
||||
|
||||
@register_config_loader("mistral")
|
||||
@register_config_loader("mistral_large_3")
|
||||
class MistralConfigLoader(BaseConfigLoader):
|
||||
def _load_mistral_config_dict(self, checkpoint_dir: str, config_file_name: str) -> dict | None:
|
||||
file_path = Path(checkpoint_dir) / Path(config_file_name)
|
||||
|
||||
if file_path.exists() and file_path.is_file():
|
||||
with open(file_path) as file:
|
||||
return json.load(file)
|
||||
return None
|
||||
|
||||
# Adaptation of
|
||||
# https://github.com/vllm-project/vllm/blob/48a5fff66e78985a634abac0d8d7f271da744000/vllm/transformers_utils/config.py#L175
|
||||
def _parse_mistral_config(self, checkpoint_dir: str):
|
||||
config_file_name = "params.json"
|
||||
|
||||
# This function loads a params.json config which
|
||||
# should be used when loading models in mistral format
|
||||
config_dict = self._load_mistral_config_dict(checkpoint_dir, config_file_name)
|
||||
if config_dict is None:
|
||||
raise ValueError(
|
||||
f"Failed to load '{config_file_name}' config from '{checkpoint_dir}'. "
|
||||
f"Only local checkpoints are supported for mistral format."
|
||||
)
|
||||
assert isinstance(config_dict, dict)
|
||||
|
||||
if (max_position_embeddings := config_dict.get("max_position_embeddings")) is None:
|
||||
max_position_embeddings = 128_000
|
||||
config_dict["max_position_embeddings"] = max_position_embeddings
|
||||
|
||||
pretrained_config = adapt_config_dict(config_dict)
|
||||
|
||||
# Mistral configs may define sliding_window as list[int]. Convert it
|
||||
# to int and add the layer_types list[str] to make it HF compatible
|
||||
if (sliding_window := getattr(pretrained_config, "sliding_window", None)) and isinstance(
|
||||
sliding_window, list
|
||||
):
|
||||
pattern_repeats = pretrained_config.num_hidden_layers // len(sliding_window)
|
||||
layer_types = sliding_window * pattern_repeats
|
||||
pretrained_config.layer_types = [
|
||||
"full_attention" if layer_type is None else "sliding_attention"
|
||||
for layer_type in layer_types
|
||||
]
|
||||
pretrained_config.sliding_window = next(filter(None, sliding_window), None)
|
||||
|
||||
return config_dict, pretrained_config
|
||||
|
||||
def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
|
||||
# Re-write from ModelConfig.from_pretrained
|
||||
|
||||
config_dict, pretrained_config = self._parse_mistral_config(checkpoint_dir)
|
||||
|
||||
# Some checkpoints lack torch_dtype; populate it from dtype
|
||||
pretrained_config.torch_dtype = getattr(pretrained_config, "dtype", None)
|
||||
quant_config = QuantConfig()
|
||||
layer_quant_config = None
|
||||
|
||||
hf_quant_config = pretrained_config.quantization_config
|
||||
if hf_quant_config.get("quant_method") == "compressed-tensors":
|
||||
if "NVFP4" in hf_quant_config.get("config_groups"):
|
||||
quant_config.quant_algo = QuantAlgo.NVFP4
|
||||
quant_config.group_size = 16
|
||||
ignore_list = hf_quant_config.get("ignore", [])
|
||||
quant_config.exclude_modules = []
|
||||
if "re:.*attn.*" in ignore_list:
|
||||
quant_config.exclude_modules.append("model.layers.*.self_attn.*")
|
||||
if "re:vision_encoder.*" in ignore_list:
|
||||
quant_config.exclude_modules.append("vision_encoder*")
|
||||
if "re:vision_language_adapter.*" in ignore_list:
|
||||
quant_config.exclude_modules.append("vision_language_adapter*")
|
||||
|
||||
elif "FP8_BLOCK" in hf_quant_config.get("config_groups"):
|
||||
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
|
||||
quant_config.group_size = 128
|
||||
quant_config.exclude_modules = ["*q_a_proj*", "*kv_a_proj_with_mqa*"]
|
||||
|
||||
kwargs.pop("trust_remote_code", None) # ModelConfig does not have this input parameter
|
||||
model_config = ModelConfig(
|
||||
pretrained_config=pretrained_config,
|
||||
quant_config=quant_config,
|
||||
quant_config_dict=layer_quant_config,
|
||||
**kwargs,
|
||||
)
|
||||
model_config._frozen = True
|
||||
return model_config
|
||||
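As a rough illustration of what `_remap_general_mistral_args` does in the config loader above, the sketch below applies the same Mistral-to-HF key renames to a toy `params.json` dict. The field values are invented for the example and do not describe any real checkpoint; only the key mapping mirrors the code.

```python
# Toy illustration of the key renaming in _remap_general_mistral_args
# (values are made up; only the mapping table matches the code above).
config_mapping = {          # Mistral key -> HF key
    "dim": "hidden_size",
    "norm_eps": "rms_norm_eps",
    "n_kv_heads": "num_key_value_heads",
    "n_layers": "num_hidden_layers",
    "n_heads": "num_attention_heads",
    "hidden_dim": "intermediate_size",
}

params = {                  # minimal fake params.json content
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "n_kv_heads": 8,
    "hidden_dim": 14336,
    "norm_eps": 1e-5,
    "activation": "silu",
}

for mistral_key, hf_key in config_mapping.items():
    if mistral_key in params:
        params[hf_key] = params.pop(mistral_key)

# params now uses HF-style names such as "hidden_size" and
# "num_hidden_layers"; adapt_config_dict builds on this before handing the
# dict to transformers.PretrainedConfig.from_dict.
```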
tensorrt_llm/_torch/models/checkpoints/mistral/weight_mapper.py (new file, 131 lines)
@ -0,0 +1,131 @@
from torch import nn

from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper


@register_mapper("mistral", "MistralForCausalLM")
@register_mapper("mistral", "PixtralForConditionalGeneration")
class MistralWeightMapper(HfWeightMapper):
    def __init__(self):
        super().__init__()

        self._callbacks.append(self._permute_qk)

        self.pixtral_mapping = {
            "wq": "q_proj",
            "wk": "k_proj",
            "wv": "v_proj",
            "wo": "o_proj",
            "w1": "gate_proj",
            "w2": "down_proj",
            "w3": "up_proj",
            "w_in": "linear_1",
            "w_out": "linear_2",
        }

        self.mistral_llm_mapping = {
            "layers": "model.layers",
            "attention": "self_attn",
            "qscale_act": "input_scale",
            "qscale_weight": "weight_scale_inv",
            "kv_fake_quantizer.qscale_act": "kv_scale",
            "q_fake_quantizer.qscale_act": "attn.q_scale",
            "k_fake_quantizer.qscale_act": "k_scale",
            "v_fake_quantizer.qscale_act": "v_scale",
            "attention_norm": "input_layernorm",
            "feed_forward": "mlp",
            "ffn_norm": "post_attention_layernorm",
            "tok_embeddings": "model.embed_tokens",
            "output": "lm_head",
            "norm": "model.norm",
            # For Eagle3
            "language_model.eagle_linear": "model.fc",
            "language_model.layers": "layers",
            "language_model.norm": "norm",
        }
        self.mistral_llm_mapping.update(self.pixtral_mapping)

    # Adapted from:
    # https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657
    def rename_by_params_map(self, params_map: dict[str, str], weights: dict) -> dict:
        renamed_weights = {}

        for key in list(weights.keys()):
            new_key = key
            modules = key.split(".")
            num_modules = len(modules)
            for i in range(num_modules):
                item = modules[i]
                next_item = modules[i + 1] if i < num_modules - 1 else None

                combined_item = f"{item}.{next_item}" if next_item is not None else None

                if combined_item in params_map:
                    new_key = new_key.replace(combined_item, params_map[combined_item])
                elif item in params_map:
                    new_key = new_key.replace(item, params_map[item])

            renamed_weights[new_key] = weights[key]

        return renamed_weights

    def _permute_qk(self, module: nn.Module, new_name: str, weights: dict):
        # Adapted from:
        # https://github.com/vllm-project/vllm/blob/883b42896a9ed9791750d721fad26005b7569eba/vllm/model_executor/models/llama.py#L657

        processed_weights = {}
        config = self.config.pretrained_config

        def permute(w, n_heads: int, attn_out: int):
            attn_in = config.head_dim * n_heads

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced

        if new_name in ["k_proj", "q_proj"]:
            n_heads = (
                config.num_key_value_heads if new_name == "k_proj" else config.num_attention_heads
            )

            processed_weights["weight"] = permute(weights["weight"], n_heads, config.hidden_size)

            if "qscale_weight" in weights and weights["qscale_weight"].numel() > 1:
                processed_weights["qscale_weight"] = permute(weights["qscale_weight"], n_heads, 1)

            return processed_weights

        return weights


@register_mapper("mistral_large_3")
@register_mapper("mistral_large_3", "PixtralForConditionalGeneration")
@register_mapper("mistral_large_3", "MistralLarge3ForCausalLM")
class MistralLarge3WeightMapper(MistralWeightMapper):
    def __init__(self):
        super().__init__()

        self.mistral_llm_mapping.update(
            {
                "wkv_a_with_mqa": "kv_a_proj_with_mqa",
                "wkv_b": "kv_b_proj",
                "wq_a": "q_a_proj",
                "q_a_norm": "q_a_layernorm",
                "wq_b": "q_b_proj",
                "kv_a_norm": "kv_a_layernorm",
                "k_fake_quantizer.qscale_act": "mla_attn.mla_attn.k_scale",
                "q_fake_quantizer.qscale_act": "mla_attn.mla_attn.q_scale",
                "v_fake_quantizer.qscale_act": "mla_attn.mla_attn.v_scale",
                "gate": "mlp.gate",
                "shared_experts": "mlp.shared_experts",
                "experts": "mlp.experts",
                "router_biases": "mlp.gate.e_score_correction_bias",
            }
        )
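The renaming logic above is easiest to see on a concrete key. The sketch below walks two made-up Mistral checkpoint keys through the same two-segment-then-one-segment lookup that `rename_by_params_map` performs; the mapping entries are a subset of `mistral_llm_mapping` above.

```python
# Toy walk-through of rename_by_params_map on illustrative keys.
params_map = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
    "kv_fake_quantizer.qscale_act": "kv_scale",
}

def rename(key: str) -> str:
    new_key = key
    modules = key.split(".")
    for i, item in enumerate(modules):
        next_item = modules[i + 1] if i < len(modules) - 1 else None
        combined = f"{item}.{next_item}" if next_item is not None else None
        # Two-segment matches (e.g. "kv_fake_quantizer.qscale_act") take
        # precedence over single-segment matches.
        if combined in params_map:
            new_key = new_key.replace(combined, params_map[combined])
        elif item in params_map:
            new_key = new_key.replace(item, params_map[item])
    return new_key

print(rename("layers.0.attention.wq.weight"))
# -> "model.layers.0.self_attn.q_proj.weight"
print(rename("layers.0.attention.kv_fake_quantizer.qscale_act"))
# -> "model.layers.0.self_attn.kv_scale"
```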
@ -746,17 +746,19 @@ class Deepseekv3MoE(nn.Module):
        config = model_config.pretrained_config
        self.top_k = top_k
        self.use_dp = model_config.mapping.enable_attention_dp
        self.gate = DeepseekV3Gate(
            hidden_size,
            num_experts,
            top_k=top_k,
            n_group=config.n_group,
            topk_group=config.topk_group,
            routed_scaling_factor=config.routed_scaling_factor,
            dtype=dtype,
            fuse_routing_kernel=True,
            apply_routing=False,
            moe_backend=model_config.moe_backend)
        gate_cls = DeepseekV3Gate
        if hasattr(model_config.pretrained_config, "gate_cls"):
            gate_cls = model_config.pretrained_config.gate_cls
        self.gate = gate_cls(hidden_size,
                             num_experts,
                             top_k=top_k,
                             n_group=config.n_group,
                             topk_group=config.topk_group,
                             routed_scaling_factor=config.routed_scaling_factor,
                             dtype=dtype,
                             fuse_routing_kernel=True,
                             apply_routing=False,
                             moe_backend=model_config.moe_backend)
        self.experts = create_moe(
            num_experts=num_experts,
            routing_method=self.gate.routing_method,
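The change above turns the gate class into a pluggable hook: `Deepseekv3MoE` now instantiates whatever `gate_cls` the pretrained config carries and falls back to `DeepseekV3Gate` otherwise (later in this commit, `modeling_mistral` sets `pretrained_config.gate_cls = Mistral3Gate`). A stripped-down sketch of the pattern, with stand-in classes rather than the real TensorRT-LLM types, looks like this:

```python
# Minimal stand-in sketch of the gate_cls override pattern; the classes here
# are placeholders, not the actual TensorRT-LLM implementations.
class DefaultGate:
    def __init__(self, hidden_size, num_experts, **kwargs):
        self.name = "DeepseekV3Gate-like default"

class CustomGate:
    def __init__(self, hidden_size, num_experts, **kwargs):
        self.name = "Mistral3Gate-like override"

class FakePretrainedConfig:
    pass

def build_gate(pretrained_config, hidden_size, num_experts):
    gate_cls = DefaultGate
    if hasattr(pretrained_config, "gate_cls"):
        # e.g. the Mistral3 VLM path injects its own gate class here
        gate_cls = pretrained_config.gate_cls
    return gate_cls(hidden_size, num_experts)

cfg = FakePretrainedConfig()
print(build_gate(cfg, 7168, 256).name)   # default gate
cfg.gate_cls = CustomGate
print(build_gate(cfg, 7168, 256).name)   # injected gate
```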
@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
@ -14,11 +15,16 @@ from tensorrt_llm._torch.attention_backend.interface import (
|
||||
PositionalEmbeddingParams, RopeParams)
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models import modeling_pixtral
|
||||
from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import \
|
||||
MistralWeightMapper
|
||||
from tensorrt_llm._torch.models.modeling_mistral_large3 import (
|
||||
Mistral3Gate, MistralLarge3ForCausalLM)
|
||||
from tensorrt_llm._torch.models.modeling_multimodal_utils import (
|
||||
find_input_mm_embeds, fuse_input_embeds, get_multimodal_embeddings)
|
||||
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
|
||||
DecoderModelForCausalLM,
|
||||
_load_weights_impl,
|
||||
filter_weights,
|
||||
register_auto_model)
|
||||
from tensorrt_llm._torch.modules.attention import Attention
|
||||
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
|
||||
@ -52,7 +58,7 @@ class MistralAttention(Attention):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig[MistralConfig],
|
||||
layer_idx: Optional[int] = None,
|
||||
layer_idx: int | None = None,
|
||||
):
|
||||
config = model_config.pretrained_config
|
||||
super().__init__(
|
||||
@ -111,8 +117,8 @@ class MistralDecoderLayer(DecoderLayer):
|
||||
position_ids: torch.IntTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
spec_metadata: SpecMetadata | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if residual is None:
|
||||
@ -169,11 +175,11 @@ class MistralModel(DecoderModel):
|
||||
def forward(
|
||||
self,
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_ids: Optional[torch.IntTensor] = None,
|
||||
position_ids: Optional[torch.IntTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
lora_params: Optional[Any] = None,
|
||||
input_ids: torch.IntTensor | None = None,
|
||||
position_ids: torch.IntTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
spec_metadata: SpecMetadata | None = None,
|
||||
lora_params: Any | None = None,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
@ -222,7 +228,7 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
self,
|
||||
model_path: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: Optional[AutoTokenizer],
|
||||
tokenizer: AutoTokenizer | None,
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@ -264,9 +270,11 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
@torch.inference_mode()
|
||||
def __call__(
|
||||
self, inputs: TextPrompt, sampling_params: SamplingParams
|
||||
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
|
||||
) -> Tuple[List[int], ExtraProcessedInputs | None]:
|
||||
images = inputs.get("multi_modal_data", {}).get("image")
|
||||
do_rescale = self.processor.image_processor.do_rescale
|
||||
mm_processor_kwargs = inputs.get("mm_processor_kwargs", {})
|
||||
do_rescale = getattr(self.processor.image_processor, "do_rescale",
|
||||
False)
|
||||
if images is not None and isinstance(images[0], torch.Tensor):
|
||||
# The default multimodal input loader will normalize images to [0, 1] when the requested
|
||||
# format is "pt" (pytorch tensors), but not for "pil" (PIL images).
|
||||
@ -276,6 +284,7 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
text=inputs["prompt"],
|
||||
images=images,
|
||||
do_rescale=do_rescale,
|
||||
**mm_processor_kwargs,
|
||||
)
|
||||
input_ids = processed.pop("input_ids").tolist()[0]
|
||||
# Remaining in `processed`:
|
||||
@ -331,6 +340,7 @@ class Mistral3InputProcessor(BaseMultimodalInputProcessor,
|
||||
|
||||
|
||||
@register_auto_model("Mistral3ForConditionalGeneration")
|
||||
@register_auto_model("PixtralForConditionalGeneration")
|
||||
@register_input_processor(
|
||||
Mistral3InputProcessor,
|
||||
model_type="mistral3",
|
||||
@ -365,34 +375,48 @@ class Mistral3VLM(PreTrainedModel):
|
||||
config = model_config.pretrained_config
|
||||
super().__init__(config)
|
||||
|
||||
self.model_config = model_config
|
||||
|
||||
llm_model_config = self._get_sub_model_config(model_config,
|
||||
"text_config")
|
||||
# This is necessary for the auto weight mapper to figure out what it needs.
|
||||
llm_model_config.pretrained_config.architectures = config.architectures
|
||||
self.llm = MistralForCausalLM(llm_model_config)
|
||||
|
||||
self._device = "cuda"
|
||||
# NOTE: current `modelopt` does not support quantizing the vision portion.
|
||||
vision_model_config = self._get_sub_model_config(model_config,
|
||||
"vision_config",
|
||||
quant_config=None)
|
||||
self._vision_tower = modeling_pixtral.PixtralVisionModel(
|
||||
vision_model_config)
|
||||
self._multi_modal_projector = Mistral3MultiModalProjector(model_config)
|
||||
vision_feature_layer = config.vision_feature_layer
|
||||
vision_feature_layer = getattr(config, "vision_feature_layer", -1)
|
||||
if vision_feature_layer != -1:
|
||||
raise ValueError(
|
||||
f"Using intermediate layers ({vision_feature_layer}) in the `PixtralVisionModel` "
|
||||
f"is not supported. Please use `vision_feature_layer=-1`.")
|
||||
|
||||
self._device = "cuda"
|
||||
self.model_dtype = getattr(config, "torch_dtype", torch.bfloat16)
|
||||
|
||||
self._image_token_ids = torch.tensor([config.image_token_index],
|
||||
image_token_index = getattr(
|
||||
config, "image_token_index", None) or getattr(
|
||||
config.vision_config, "image_token_id", None)
|
||||
self._image_token_ids = torch.tensor([image_token_index],
|
||||
dtype=torch.int32,
|
||||
device=self._device)
|
||||
|
||||
model_config_cp = copy.deepcopy(model_config)
|
||||
|
||||
llm_model_config = self._get_sub_model_config(model_config_cp,
|
||||
"text_config")
|
||||
self.model_config = model_config_cp
|
||||
llm_class = MistralForCausalLM
|
||||
if llm_model_config.pretrained_config.architectures[
|
||||
0] == "MistralLarge3ForCausalLM":
|
||||
llm_class = MistralLarge3ForCausalLM
|
||||
|
||||
llm_model_config.pretrained_config.gate_cls = Mistral3Gate
|
||||
self.llm = llm_class(llm_model_config)
|
||||
self.model_config.extra_attrs.update(llm_model_config.extra_attrs)
|
||||
|
||||
# NOTE: current `modelopt` does not support quantizing the vision portion.
|
||||
# NOTE: use the VANILLA attn_backend because the Pixtral head size is not always divisible by 128
|
||||
vision_model_config = self._get_sub_model_config(model_config_cp,
|
||||
"vision_config",
|
||||
attn_backend="VANILLA",
|
||||
quant_config=None)
|
||||
|
||||
self._vision_tower = modeling_pixtral.PixtralVisionModel(
|
||||
vision_model_config)
|
||||
self._multi_modal_projector = Mistral3MultiModalProjector(
|
||||
model_config).eval().to(self._device)
|
||||
self._post_config()
|
||||
self.is_loaded = True
|
||||
|
||||
# This is necessary because the executor looks at
|
||||
# `model.model_config.pretrained_config.vocab_size`.
|
||||
@ -400,18 +424,39 @@ class Mistral3VLM(PreTrainedModel):
|
||||
self.config = self.llm.config
|
||||
self.model_config.pretrained_config = self.llm.config
|
||||
|
||||
def load_weights(self, weights: Dict, *args, **kwargs):
|
||||
llm_weights = _filter_weights(weights, "language_model.")
|
||||
self.llm.load_weights(llm_weights, *args, **kwargs)
|
||||
def load_weights(self, weights: Dict, weight_mapper=None, *args, **kwargs):
|
||||
vit_params_map = None
|
||||
if weight_mapper:
|
||||
if isinstance(weight_mapper, MistralWeightMapper):
|
||||
vit_params_map = weight_mapper.pixtral_mapping
|
||||
|
||||
vit_weights = _filter_weights(weights, "vision_tower.")
|
||||
self._vision_tower.load_weights(vit_weights, *args, **kwargs)
|
||||
llm_weights = filter_weights(weights=weights, prefix="language_model")
|
||||
logger.debug(f"Loading weights for {type(self.llm)}")
|
||||
self.llm.load_weights(llm_weights)
|
||||
logger.debug(f"Successfully loaded weights for {type(self.llm)}")
|
||||
|
||||
mm_projector_weights = _filter_weights(weights,
|
||||
"multi_modal_projector.")
|
||||
# `_load_weights_impl` assumes `config.hidden_size` exists, which is not the case for the
|
||||
# top-level `Mistral3Config`.
|
||||
vit_weights = filter_weights(weights=weights, prefix="vision_tower")
|
||||
logger.debug(f"Loading weights for {type(self._vision_tower)}")
|
||||
|
||||
if vit_params_map is not None:
|
||||
vit_weights = weight_mapper.rename_by_params_map(
|
||||
weights=vit_weights, params_map=vit_params_map)
|
||||
|
||||
self._vision_tower.load_weights(vit_weights)
|
||||
logger.debug(
|
||||
f"Successfully loaded weights for {type(self._vision_tower)}")
|
||||
|
||||
logger.debug(f"Loading weights for {type(self._multi_modal_projector)}")
|
||||
mm_projector_weights = filter_weights(weights=weights,
|
||||
prefix="multi_modal_projector")
|
||||
|
||||
if vit_params_map is not None:
|
||||
mm_projector_weights = weight_mapper.rename_by_params_map(
|
||||
weights=mm_projector_weights, params_map=vit_params_map)
|
||||
self._multi_modal_projector.load_state_dict(mm_projector_weights)
|
||||
logger.debug(
|
||||
f"Successfully loaded weights for {type(self._multi_modal_projector)}"
|
||||
)
|
||||
|
||||
def infer_max_seq_len(self) -> int:
|
||||
return self.llm.infer_max_seq_len()
|
||||
@ -420,9 +465,10 @@ class Mistral3VLM(PreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
return_context_logits: bool = False,
|
||||
spec_metadata: SpecMetadata | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Forward method."""
|
||||
@ -455,6 +501,7 @@ class Mistral3VLM(PreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
return_context_logits=return_context_logits,
|
||||
spec_metadata=spec_metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -465,16 +512,41 @@ class Mistral3VLM(PreTrainedModel):
|
||||
) -> ModelConfig:
|
||||
# Extract the subconfig from the `transformers` config and shove it into our own
|
||||
# `ModelConfig` class.
|
||||
assert name in [
|
||||
"text_config", "vision_config"
|
||||
], f"Expected subconfig name to be either 'text_config' or 'vision_config'. Got {name} instead."
|
||||
pretrained_config = getattr(model_config.pretrained_config, name)
|
||||
|
||||
sub_model_config: ModelConfig[MistralConfig] = dataclasses.replace(
|
||||
model_config,
|
||||
pretrained_config=getattr(model_config.pretrained_config, name),
|
||||
**changes,
|
||||
)
|
||||
if name == "text_config":
|
||||
sub_model_config._frozen = False
|
||||
sub_model_config.skip_create_weights_in_init = True
|
||||
if not hasattr(
|
||||
sub_model_config.pretrained_config, "architectures"
|
||||
) or sub_model_config.pretrained_config.architectures is None:
|
||||
sub_model_config.pretrained_config.architectures = model_config.pretrained_config.architectures
|
||||
sub_model_config._frozen = True
|
||||
|
||||
# Make sure some fields that are not explicitly included in the sub config, but present
|
||||
# in the top-level config, are replicated.
|
||||
if (hasattr(sub_model_config.pretrained_config, "torch_dtype")
|
||||
and sub_model_config.pretrained_config.torch_dtype is None):
|
||||
sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype
|
||||
sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype or torch.bfloat16
|
||||
|
||||
if name == "vision_config":
|
||||
pretrained_config = sub_model_config.pretrained_config
|
||||
defaults = {
|
||||
"head_dim": pretrained_config.hidden_size //
|
||||
pretrained_config.num_attention_heads,
|
||||
"hidden_act": "silu",
|
||||
}
|
||||
for attr, default in defaults.items():
|
||||
if not hasattr(pretrained_config, attr):
|
||||
setattr(pretrained_config, attr, default)
|
||||
|
||||
return sub_model_config
|
||||
|
||||
@ -572,6 +644,12 @@ class Mistral3VLM(PreTrainedModel):
|
||||
def mm_token_ids(self):
|
||||
return self._image_token_ids
|
||||
|
||||
def load_draft_weights(
|
||||
self,
|
||||
weights: Dict,
|
||||
weight_mapper: MistralWeightMapper | None = None) -> None:
|
||||
self.llm.load_draft_weights(weights, weight_mapper=weight_mapper)
|
||||
|
||||
|
||||
# Original implementation:
|
||||
# https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/mistral3/modeling_mistral3.py#L66
|
||||
@ -586,13 +664,15 @@ class Mistral3PatchMerger(torch.nn.Module):
|
||||
self.config = config
|
||||
|
||||
hidden_size = config.vision_config.hidden_size
|
||||
self._spatial_merge_size = config.spatial_merge_size
|
||||
self._spatial_merge_size = getattr(
|
||||
config, "spatial_merge_size", None) or getattr(
|
||||
config.vision_config, "spatial_merge_size")
|
||||
self._patch_size = config.vision_config.patch_size
|
||||
self.merging_layer = Linear(
|
||||
in_features=hidden_size * self._spatial_merge_size**2,
|
||||
out_features=hidden_size,
|
||||
bias=False,
|
||||
dtype=config.torch_dtype,
|
||||
dtype=config.torch_dtype or model_config.torch_dtype,
|
||||
mapping=model_config.mapping,
|
||||
)
|
||||
|
||||
@ -640,7 +720,7 @@ class Mistral3MultiModalProjector(torch.nn.Module):
|
||||
self.model_config = model_config
|
||||
self.config = config
|
||||
|
||||
dtype = config.torch_dtype
|
||||
dtype = config.torch_dtype or model_config.torch_dtype
|
||||
self.norm = RMSNorm(
|
||||
hidden_size=config.vision_config.hidden_size,
|
||||
# NOTE: the original implementation actually does not look at the config for this value.
|
||||
@ -650,21 +730,21 @@ class Mistral3MultiModalProjector(torch.nn.Module):
|
||||
)
|
||||
self.patch_merger = Mistral3PatchMerger(model_config)
|
||||
# We have hidden_size * the number of vision feature layers
|
||||
num_feature_layers = 1 if isinstance(config.vision_feature_layer,
|
||||
int) else len(
|
||||
config.vision_feature_layer)
|
||||
vision_feature_layer = getattr(config, "vision_feature_layer", -1)
|
||||
num_feature_layers = 1 if isinstance(vision_feature_layer,
|
||||
int) else len(vision_feature_layer)
|
||||
self.linear_1 = Linear(
|
||||
in_features=config.vision_config.hidden_size * num_feature_layers,
|
||||
out_features=config.text_config.hidden_size,
|
||||
bias=config.multimodal_projector_bias,
|
||||
bias=getattr(config, "multimodal_projector_bias", None),
|
||||
dtype=dtype,
|
||||
mapping=model_config.mapping,
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.act = ACT2FN[getattr(config, "projector_hidden_act", "gelu")]
|
||||
self.linear_2 = Linear(
|
||||
in_features=config.text_config.hidden_size,
|
||||
out_features=config.text_config.hidden_size,
|
||||
bias=config.multimodal_projector_bias,
|
||||
bias=getattr(config, "multimodal_projector_bias", None),
|
||||
dtype=dtype,
|
||||
mapping=model_config.mapping,
|
||||
)
|
||||
|
||||
tensorrt_llm/_torch/models/modeling_mistral_large3.py (new file, 70 lines)
@ -0,0 +1,70 @@
from typing import Dict, List

import torch
from torch import nn

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import MistralLarge3WeightMapper
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
from tensorrt_llm._torch.models.modeling_utils import register_auto_model
from tensorrt_llm._torch.modules.fused_moe import RenormalizeNaiveMoeRoutingMethod
from tensorrt_llm.quantization.mode import QuantAlgo


class Mistral3Gate(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        top_k: int,
        dtype: torch.dtype | None = None,
        **kwargs,
    ):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty((num_experts, hidden_size), dtype=dtype), requires_grad=False
        )
        self.top_k = top_k
        self.dtype = dtype
        self.routing_method = RenormalizeNaiveMoeRoutingMethod(top_k=self.top_k)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits: torch.Tensor = torch.ops.trtllm.cublas_mm(
            hidden_states, self.weight.t(), bias=None, out_dtype=self.dtype
        )
        return logits

    def load_weights(self, weights: List[Dict]):
        assert len(weights) == 1

        self.weight.copy_(weights[0]["weight"][:])


@register_auto_model("MistralLarge3ForCausalLM")
class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM):
    def __init__(self, model_config: ModelConfig):
        super().__init__(model_config)
        self.weight_mapper = MistralLarge3WeightMapper()

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)

    def load_weights(self, weights: Dict):
        assert self.model_config is not None, "self.model_config is required"
        params_map = self.weight_mapper.mistral_llm_mapping.copy()
        quantization_weights_map: Dict[str, str] = {}
        if self.model_config.quant_config.quant_algo == QuantAlgo.NVFP4:
            quantization_weights_map = {
                "weight_packed": "weight",
                "input_global_scale": "input_scale",
                "weight_global_scale": "weight_scale_2",
            }
        elif self.model_config.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
            quantization_weights_map = {
                "weight_scale": "weight_scale_inv",
            }
        if quantization_weights_map:
            params_map.update(quantization_weights_map)
        weights = self.weight_mapper.rename_by_params_map(weights=weights, params_map=params_map)

        super().load_weights(weights)
@ -38,14 +38,22 @@ _CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDic

def load_pretrained_config(model_name_or_path: str,
                           trust_remote_code: bool = False,
                           checkpoint_format: str = None,
                           **kwargs) -> transformers.PretrainedConfig:
    config_dict, _ = transformers.PretrainedConfig.get_config_dict(
        model_name_or_path, **kwargs)
    model_type = config_dict.get("model_type")

    if model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[model_type]
        model_config = config_class.from_pretrained(model_name_or_path,
                                                    **kwargs)
    elif checkpoint_format in ("mistral", "mistral_large_3"):
        from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
            MistralConfigLoader
        model_config = getattr(
            MistralConfigLoader().load(model_name_or_path).pretrained_config,
            "text_config")
    else:
        model_config = transformers.AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code)
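A hedged usage sketch of the new `checkpoint_format` argument: when the format is `mistral` or `mistral_large_3`, `load_pretrained_config` now routes through `MistralConfigLoader` and returns the text sub-config instead of calling `transformers.AutoConfig`. The directory below is a placeholder and must contain a Mistral-format `params.json` for this branch to be taken.

```python
# Illustrative call only; the checkpoint directory is a placeholder.
from tensorrt_llm._torch.pyexecutor.config_utils import load_pretrained_config

config = load_pretrained_config(
    "/path/to/mistral_large_3_checkpoint",
    trust_remote_code=False,
    checkpoint_format="mistral",
)
print(type(config))  # a transformers.PretrainedConfig (the text_config)
```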
@ -600,6 +600,12 @@ def create_input_processor(
            logger.debug(
                f"Unable to load HF config from {model_path_or_dir}: {e}. Falling back."
            )
    elif checkpoint_format in ("mistral", "mistral_large_3"):
        logger.debug(f"Detected checkpoint_format={checkpoint_format}.")
        from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
            MistralConfigLoader
        model_config = MistralConfigLoader().load(model_path_or_dir)
        config = model_config.pretrained_config
    else:
        logger.debug(
            f"checkpoint_format={checkpoint_format}; skipping HF config load.")
@ -112,7 +112,8 @@ class OpenAIServer:
            from tensorrt_llm._torch.pyexecutor.config_utils import \
                load_pretrained_config
            self.model_config = load_pretrained_config(hf_tokenizer_path,
                                                       trust_remote_code=trust_remote_code)
                                                       trust_remote_code=trust_remote_code,
                                                       checkpoint_format=getattr(self.llm.args, "checkpoint_format", None))
        except Exception:
            logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path)
            self.model_config = None
@ -281,3 +281,5 @@ bigcode/starcoder2-7b:
  - accuracy: 26.5
bigcode/starcoder2-15b:
  - accuracy: 54.5
mistral/Mistral-Large-3-675B:
  - accuracy: 90.83

@ -340,3 +340,5 @@ mistralai/Mistral-Nemo-12b-Base:
  - accuracy: 69.66
  - quant_algo: FP8
    accuracy: 69.66
mistral/Mistral-Large-3-675B:
  - accuracy: 87.54
@ -4828,3 +4828,107 @@ class TestLlama3_1_8B_Instruct_RocketKV(LlmapiAccuracyTestHarness):
|
||||
task.evaluate(llm,
|
||||
sampling_params=sampling_params,
|
||||
extra_evaluator_kwargs=extra_evaluator_kwargs)
|
||||
|
||||
|
||||
class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
|
||||
MODEL_NAME = "mistral/Mistral-Large-3-675B"
|
||||
|
||||
@skip_pre_blackwell
|
||||
@pytest.mark.skip_less_mpi_world_size(4)
|
||||
@pytest.mark.skip_less_device_memory(183000)
|
||||
@pytest.mark.parametrize(
|
||||
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend,eagle3",
|
||||
[
|
||||
(4, 1, 4, False, True, True, "TRTLLM", False),
|
||||
],
|
||||
ids=[
|
||||
"latency_moe_trtllm",
|
||||
],
|
||||
)
|
||||
def test_nvfp4_4gpus(self, tp_size, pp_size, ep_size, attention_dp,
|
||||
cuda_graph, overlap_scheduler, moe_backend, eagle3):
|
||||
|
||||
if moe_backend == "TRTLLM" and (get_sm_version() == 120
|
||||
or get_sm_version() == 121):
|
||||
pytest.skip(
|
||||
"MOE TRTLLM backend does not support SM version 120 or 121")
|
||||
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
|
||||
moe_config=MoeConfig(backend=moe_backend))
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
|
||||
enable_block_reuse=not eagle3)
|
||||
spec_config = None
|
||||
if eagle3:
|
||||
spec_config = EagleDecodingConfig(
|
||||
max_draft_len=2,
|
||||
speculative_model_dir=
|
||||
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",
|
||||
eagle3_one_model=True)
|
||||
with LLM(
|
||||
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-NVFP4/",
|
||||
checkpoint_format="mistral",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
kv_cache_config=kv_cache_config,
|
||||
speculative_config=spec_config) as llm:
|
||||
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
@skip_pre_blackwell
|
||||
@pytest.mark.skip_less_mpi_world_size(8)
|
||||
@pytest.mark.skip_less_device_memory(183000)
|
||||
@pytest.mark.parametrize(
|
||||
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler,moe_backend,eagle3",
|
||||
[
|
||||
(8, 1, 8, False, True, True, "DEEPGEMM", False),
|
||||
],
|
||||
ids=[
|
||||
"latency_moe_deepgemm",
|
||||
],
|
||||
)
|
||||
def test_fp8(self, tp_size, pp_size, ep_size, attention_dp, cuda_graph,
|
||||
overlap_scheduler, moe_backend, eagle3):
|
||||
|
||||
if moe_backend == "DEEPGEMM" and (get_sm_version() == 120
|
||||
or get_sm_version() == 121):
|
||||
pytest.skip(
|
||||
"MOE DEEPGEMM backend does not support SM version 120 or 121")
|
||||
|
||||
pytorch_config = dict(
|
||||
disable_overlap_scheduler=not overlap_scheduler,
|
||||
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
|
||||
moe_config=MoeConfig(backend=moe_backend))
|
||||
|
||||
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
|
||||
enable_block_reuse=not eagle3)
|
||||
spec_config = None
|
||||
if eagle3:
|
||||
spec_config = EagleDecodingConfig(
|
||||
max_draft_len=2,
|
||||
speculative_model_dir=
|
||||
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512-Eagle/",
|
||||
eagle3_one_model=True)
|
||||
with LLM(
|
||||
f"{llm_models_root()}/Mistral-Large-3-675B/Mistral-Large-3-675B-Instruct-2512/",
|
||||
checkpoint_format="mistral",
|
||||
tensor_parallel_size=tp_size,
|
||||
pipeline_parallel_size=pp_size,
|
||||
moe_expert_parallel_size=ep_size,
|
||||
**pytorch_config,
|
||||
enable_attention_dp=attention_dp,
|
||||
kv_cache_config=kv_cache_config,
|
||||
speculative_config=spec_config) as llm:
|
||||
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
@ -109,6 +109,7 @@ l0_dgx_b200:
      - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_fp8kv] TIMEOUT (180)
      - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[latency] TIMEOUT (180)
      - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus_chunked_prefill[baseline_fp8kv] TIMEOUT (180)
      - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_fp8[latency_moe_deepgemm] TIMEOUT (90)
- condition:
    ranges:
      system_gpu_count:

@ -91,3 +91,4 @@ l0_gb200_multi_gpus:
      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[enable_configurable_moe-dp4-trtllm-fp8]
      - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus_online_eplb[enable_configurable_moe-fp8]
      - accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4_4gpus[latency_moe_trtllm_eagle3] TIMEOUT (90)
      - accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_fp8[latency_moe_deepgemm] TIMEOUT (90)