[None][feat] support Qwen3-VL dense model in pytorch backend (#9060)

Signed-off-by: Nekofish-L <liuxiangyang@mail.ustc.edu.cn>
This commit is contained in:
Necofish 2025-12-31 16:54:26 +08:00 committed by GitHub
parent 827d12caaf
commit 73870ae4ad
6 changed files with 360 additions and 23 deletions

tensorrt_llm/_torch/models/__init__.py

@@ -28,6 +28,7 @@ from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
from .modeling_qwen3 import Qwen3ForCausalLM
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
from .modeling_qwen3_next import Qwen3NextForCausalLM
from .modeling_qwen3vl import Qwen3VLModel
from .modeling_qwen3vl_moe import Qwen3MoeVLModel
from .modeling_qwen_moe import Qwen2MoeForCausalLM
from .modeling_seedoss import SeedOssForCausalLM
@@ -76,6 +77,7 @@ __all__ = [
"GptOssForCausalLM",
"SeedOssForCausalLM",
"Glm4MoeForCausalLM",
"Qwen3VLModel",
]
if transformers.__version__ >= "4.45.1":

tensorrt_llm/_torch/models/checkpoints/__init__.py

@@ -10,6 +10,7 @@ from .hf.qwen2_moe_weight_mapper import Qwen2MoeHfWeightMapper
from .hf.qwen2vl_weight_mapper import Qwen2VLHfWeightMapper
from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper
from .hf.qwen3vl_weight_mapper import Qwen3VLHfWeightMapper
from .hf.weight_loader import HfWeightLoader
from .hf.weight_mapper import HfWeightMapper
from .mistral.checkpoint_loader import (MistralCheckpointLoader,
@@ -19,23 +20,12 @@ from .mistral.weight_mapper import (MistralLarge3WeightMapper,
MistralWeightMapper)
__all__ = [
"HfConfigLoader",
"HfWeightLoader",
"HfWeightMapper",
"MistralConfigLoader",
"MistralWeightMapper",
"MistralCheckpointLoader",
"BaseCheckpointLoader",
"HfCheckpointLoader",
"NemotronHHfWeightMapper",
"Gemma3HfWeightMapper",
"MixtralHfWeightMapper",
"Llama4HfWeightMapper",
"Qwen2MoeHfWeightMapper",
"Qwen3MoeHfWeightMapper",
"Qwen2VLHfWeightMapper",
"Qwen3NextHfWeightMapper",
"LlavaNextHfWeightMapper",
"MistralLarge3CheckpointLoader",
"MistralLarge3WeightMapper",
"HfConfigLoader", "HfWeightLoader", "HfWeightMapper", "MistralConfigLoader",
"MistralWeightMapper", "MistralCheckpointLoader", "BaseCheckpointLoader",
"HfCheckpointLoader", "NemotronHHfWeightMapper", "Gemma3HfWeightMapper",
"MixtralHfWeightMapper", "Llama4HfWeightMapper", "Qwen2MoeHfWeightMapper",
"Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper",
"Qwen3NextHfWeightMapper", "LlavaNextHfWeightMapper",
"MistralLarge3CheckpointLoader", "MistralLarge3WeightMapper",
"Qwen3VLHfWeightMapper"
]

tensorrt_llm/_torch/models/checkpoints/hf/qwen3vl_weight_mapper.py (new file)

@@ -0,0 +1,8 @@
from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper
@register_mapper("HF", "Qwen3VLForConditionalGeneration")
class Qwen3VLHfWeightMapper(HfWeightMapper):
def preprocess_weights(self, weights: dict) -> dict:
return weights
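The new mapper above is registered by architecture name so the HF checkpoint path can resolve it at load time. Below is a minimal, self-contained sketch of that decorator-registration pattern; the local register_mapper and _MAPPER_REGISTRY here are simplified stand-ins for the real machinery in modeling_utils, not its actual implementation.

from typing import Callable, Dict, Tuple, Type

_MAPPER_REGISTRY: Dict[Tuple[str, str], Type] = {}

def register_mapper(fmt: str, architecture: str) -> Callable[[Type], Type]:
    # Key a weight-mapper class by (checkpoint format, architecture string).
    def decorator(cls: Type) -> Type:
        _MAPPER_REGISTRY[(fmt, architecture)] = cls
        return cls
    return decorator

@register_mapper("HF", "Qwen3VLForConditionalGeneration")
class SketchQwen3VLHfWeightMapper:
    def preprocess_weights(self, weights: dict) -> dict:
        # The dense Qwen3-VL checkpoint needs no key rewriting at this stage;
        # the language-model prefix is handled later via params_map in load_weights.
        return weights

# At load time the mapper is looked up from the checkpoint's architecture string.
mapper_cls = _MAPPER_REGISTRY[("HF", "Qwen3VLForConditionalGeneration")]
assert mapper_cls().preprocess_weights({"w": 0}) == {"w": 0}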

tensorrt_llm/_torch/models/modeling_qwen3.py

@@ -121,6 +121,8 @@ class Qwen3DecoderLayer(DecoderLayer):
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
spec_metadata: Optional[SpecMetadata] = None,
mrope_config: Optional[dict] = None,
deepstack_embeds: Optional[list[torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
if residual is None:
@@ -137,6 +139,7 @@ class Qwen3DecoderLayer(DecoderLayer):
attn_metadata=attn_metadata,
all_reduce_params=AllReduceParams(
enable_allreduce=not self.disable_allreduce),
mrope_config=mrope_config,
**kwargs,
)
@@ -150,6 +153,9 @@
enable_allreduce=not self.disable_allreduce),
cutlass_min_latency_mode=False,
)
if deepstack_embeds is not None and self.layer_idx in range(
len(deepstack_embeds)):
residual = residual + deepstack_embeds[self.layer_idx]
if spec_metadata is not None:
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
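For context on the deepstack branch added above: Qwen3-VL injects multi-scale vision features ("DeepStack") into the residual stream of the first few decoder layers, one tensor per layer index. A minimal sketch of that injection with made-up shapes and a stand-in for the attention/MLP body:

import torch

num_layers, num_tokens, hidden_size = 4, 10, 4096
residual = torch.zeros(num_tokens, hidden_size)
# One projected vision feature per deepstack layer (cf. the deepstack_visual_indexes
# entry in the test config later in this diff).
deepstack_embeds = [torch.randn(num_tokens, hidden_size) for _ in range(3)]

for layer_idx in range(num_layers):
    # ... self-attention and MLP would update `residual` here ...
    if deepstack_embeds is not None and layer_idx in range(len(deepstack_embeds)):
        # Only layers 0 .. len(deepstack_embeds)-1 receive an extra visual feature.
        residual = residual + deepstack_embeds[layer_idx]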
@@ -191,6 +197,9 @@ class Qwen3Model(DecoderModel):
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
mrope_config: Optional[dict] = None,
# args for deepstack
deepstack_embeds: Optional[list[torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -211,8 +220,8 @@
attn_metadata=attn_metadata,
residual=residual,
spec_metadata=spec_metadata,
)
mrope_config=mrope_config,
deepstack_embeds=deepstack_embeds)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

tensorrt_llm/_torch/models/modeling_qwen3vl.py

@@ -21,7 +21,10 @@ from ...inputs import (
BaseMultimodalDummyInputsBuilder,
BaseMultimodalInputProcessor,
ExtraProcessedInputs,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement,
TextPrompt,
register_input_processor,
)
from ...inputs.multimodal import MultimodalParams
from ...logger import logger
@@ -33,6 +36,8 @@ from ..modules.layer_norm import LayerNorm
from ..modules.linear import Linear, TensorParallelMode
from ..modules.mlp import MLP
from ..modules.rotary_embedding import MRotaryEmbedding
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .checkpoints.hf.qwen3vl_weight_mapper import Qwen3VLHfWeightMapper
from .modeling_auto import AutoModelForCausalLM
from .modeling_multimodal_utils import (
find_input_mm_embeds,
@@ -40,7 +45,14 @@ from .modeling_multimodal_utils import (
get_multimodal_embeddings,
)
from .modeling_qwen2vl import Qwen2_5_VLVisionAttention
from .modeling_utils import ModelConfig, QuantConfig, _load_weights_impl, filter_weights
from .modeling_utils import (
ModelConfig,
QuantConfig,
_load_weights_impl,
filter_weights,
register_auto_model,
register_vision_encoder,
)
class Qwen3VLInputProcessorBase(BaseMultimodalInputProcessor, BaseMultimodalDummyInputsBuilder):
@@ -807,7 +819,12 @@ class Qwen3VLModelBase(PreTrainedModel):
llm_model_config = copy.deepcopy(model_config)
llm_model_config.pretrained_config = config.text_config
llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"]
if self.original_arch == "Qwen3VLForConditionalGeneration":
llm_model_config.pretrained_config.architectures = ["Qwen3ForCausalLM"]
elif self.original_arch == "Qwen3VLMoeForConditionalGeneration":
llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"]
else:
raise ValueError(f"Unsupported architecture: {self.original_arch}")
self.llm = AutoModelForCausalLM.from_config(llm_model_config)
if not _is_disagg():
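The if/elif added above picks the text backbone that AutoModelForCausalLM should build for each VL variant. A table-driven sketch of the same dispatch (illustrative only; the diff keeps the explicit branches):

# Maps the original VL architecture to the text-backbone architecture it wraps.
_TEXT_BACKBONE = {
    "Qwen3VLForConditionalGeneration": "Qwen3ForCausalLM",        # dense (this PR)
    "Qwen3VLMoeForConditionalGeneration": "Qwen3MoeForCausalLM",  # MoE (already supported)
}

def select_text_backbone(original_arch: str) -> list:
    if original_arch not in _TEXT_BACKBONE:
        raise ValueError(f"Unsupported architecture: {original_arch}")
    return [_TEXT_BACKBONE[original_arch]]

assert select_text_backbone("Qwen3VLForConditionalGeneration") == ["Qwen3ForCausalLM"]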
@@ -990,3 +1007,42 @@
)
logger.debug(f"output shape: {output_prob.shape}")
return output_prob
@register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
@register_auto_model("Qwen3VLForConditionalGeneration")
@register_input_processor(
Qwen3VLInputProcessorBase,
model_type="qwen3_vl",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<|vision_start|><|image_pad|><|vision_end|>",
"video": "<|vision_start|><|video_pad|><|vision_end|>",
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
),
)
class Qwen3VLModel(Qwen3VLModelBase):
def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, **kwargs):
# NOTE: HF implementation.
kwargs["vision_model_class"] = Qwen3VisionModel
kwargs["disable_fuse_rope"] = kwargs.get(
"disable_fuse_rope", False
) # TODO: Make this ModelConfig's argument
super().__init__(model_config, *args, **kwargs)
@property
def multimodal_data_device_paths(self) -> List[str]:
return ["image.pixel_values", "video.pixel_values_videos", "multimodal_embedding"]
def load_weights(self, weights: Dict[str, torch.Tensor], weight_mapper: BaseWeightMapper):
if not _is_disagg():
self.mm_encoder.load_weights(weights)
weight_mapper = Qwen3VLHfWeightMapper()
weight_mapper.init_model_and_config(self.llm, self.model_config)
filtered_weights = {k: v for k, v in weights.items() if not k.startswith("model.visual.")}
params_map = {
r"^model\.language_model\.(.*)$": r"model.\1",
}
self.llm.load_weights(filtered_weights, weight_mapper, params_map=params_map)
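The params_map above rewrites HF checkpoint keys so the text-tower weights match the standalone Qwen3 language model: HF stores them under model.language_model.*, while the wrapped LLM expects model.*. A small sketch of that renaming with a hypothetical remap_keys helper (the real substitution happens inside the load_weights machinery):

import re
import torch

params_map = {r"^model\.language_model\.(.*)$": r"model.\1"}

def remap_keys(weights: dict, params_map: dict) -> dict:
    remapped = {}
    for key, tensor in weights.items():
        for pattern, repl in params_map.items():
            key = re.sub(pattern, repl, key)
        remapped[key] = tensor
    return remapped

weights = {
    "model.language_model.layers.0.self_attn.q_proj.weight": torch.empty(1),
    "lm_head.weight": torch.empty(1),  # no match, key is left untouched
}
print(list(remap_keys(weights, params_map)))
# ['model.layers.0.self_attn.q_proj.weight', 'lm_head.weight']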

test_modeling_qwen3vl.py (new test file)

@@ -0,0 +1,272 @@
import os
from dataclasses import dataclass
from typing import List
import torch
from _torch.helpers import create_mock_cuda_graph_runner
from test_modeling_multimodal import MultimodalScenario, TestModelingMultimodal
from transformers import Qwen3VLConfig
from transformers import Qwen3VLForConditionalGeneration as HFQwen3VLForConditionalLM
from utils.llm_data import llm_models_root
from tensorrt_llm._torch.models.checkpoints.hf.qwen3vl_weight_mapper import Qwen3VLHfWeightMapper
from tensorrt_llm._torch.models.modeling_qwen3vl import Qwen3VLModel
QWEN3_VL_8B_CONFIG = {
"architectures": ["Qwen3VLForConditionalGeneration"],
"image_token_id": 151655,
"model_type": "qwen3_vl",
"text_config": {
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
"num_attention_heads": 32,
"num_hidden_layers": 4,
# NOTE: Only 4 layers for testing, 36 layers for full model.
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"mrope_interleaved": True,
"mrope_section": [24, 20, 20],
"rope_type": "default",
},
"rope_theta": 5000000,
"use_cache": True,
"vocab_size": 151936,
},
"tie_word_embeddings": False,
"transformers_version": "4.57.0.dev0",
"video_token_id": 151656,
"vision_config": {
"deepstack_visual_indexes": [8, 16, 24],
"depth": 27,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4304,
"model_type": "qwen3_vl",
"num_heads": 16,
"num_position_embeddings": 2304,
"out_hidden_size": 4096,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
"vision_end_token_id": 151653,
"vision_start_token_id": 151652,
"_attn_implementation": "flash_attention_2",
"_name_or_path": str(os.path.join(llm_models_root(), "Qwen3", "Qwen3-VL-8B-Instruct")),
}
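One property of the rope_scaling block above worth noting: the M-RoPE sections partition the rotary half-dimension among the temporal, height, and width position ids, so they must sum to head_dim // 2. A quick check (an observation about the config, not part of the test):

head_dim = 128
mrope_section = [24, 20, 20]   # temporal, height, width rotary sub-dimensions
assert sum(mrope_section) == head_dim // 2   # 64 rotary frequency pairs in total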
@dataclass(repr=False)
class TestQwen3VLScenario(MultimodalScenario):
disable_fuse_rope: bool = False
def __repr__(self) -> str:
"""Generate a human-readable string representation of the scenario."""
features = []
features.append(f"modality:{self.modality.lower()}")
if self.use_cuda_graph:
features.append("cuda_graph")
if self.disable_fuse_rope:
features.append("no_fuse_rope")
if self.chunked_prefill:
features.append("chunked_prefill")
if self.kv_cache_reuse:
features.append("kv_cache_reuse")
return "-".join(features)
class TestQwen3VL(TestModelingMultimodal):
def get_model_config(self):
"""Return the model configuration dictionary."""
return QWEN3_VL_8B_CONFIG
def get_trtllm_model_class(self):
return Qwen3VLModel
def get_hf_model_class(self):
return HFQwen3VLForConditionalLM
def get_weight_mapper_class(self):
return Qwen3VLHfWeightMapper
def get_model_type(self):
return "qwen3_vl"
def get_model_config_class(self):
return Qwen3VLConfig
def get_trtllm_inputs(
self,
input_ids,
multimodal_params_list,
is_gen: bool = False,
num_cached_tokens_per_seq: List[int] = None,
):
trtllm_inputs = super().get_trtllm_inputs(
input_ids, multimodal_params_list, is_gen, num_cached_tokens_per_seq
)
if is_gen:
mrope_gen_position_ids = []
for multimodal_param in multimodal_params_list:
mrope_gen_position_ids.append(
multimodal_param.multimodal_data["mrope_config"]["mrope_position_deltas"]
)
mrope_gen_position_ids = torch.cat(mrope_gen_position_ids, dim=-1).to(self.device)
trtllm_inputs["position_ids"] = (
(trtllm_inputs["position_ids"] + mrope_gen_position_ids).expand(3, -1, 1).cuda()
)
gen_multimodal_params_list = []
for multimodal_param in multimodal_params_list:
multimodal_param.strip_for_generation()
multimodal_param.to_device(
"multimodal_data",
self.device,
pin_memory=True,
target_keywords=["mrope_config.mrope_position_deltas"],
)
gen_multimodal_params_list.append(multimodal_param)
trtllm_inputs["multimodal_params"] = gen_multimodal_params_list
else:
# Mrope position ids
mrope_position_ids = []
for multimodal_param in multimodal_params_list:
mrope_position_ids.append(
multimodal_param.multimodal_data["mrope_config"]["mrope_position_ids"]
)
position_ids = torch.cat(mrope_position_ids, dim=-1)
position_ids = position_ids.cuda()
trtllm_inputs["position_ids"] = position_ids
return trtllm_inputs
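Regarding the generation branch above: prefill carries full (3, num_tokens) M-RoPE position ids, while decode keeps only a per-sequence delta and shifts the scalar position by it before broadcasting to the three rope axes. A tiny sketch with invented values for a single decoding sequence:

import torch

position_ids = torch.tensor([[100]])            # (1, num_tokens): text position at decode step
mrope_position_deltas = torch.tensor([[-4]])    # (1, 1): delta carried over from prefill
gen_position_ids = (position_ids + mrope_position_deltas).expand(3, -1, 1)
print(gen_position_ids.shape)  # torch.Size([3, 1, 1]); same value on temporal/height/width axes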
def init_kv_cache_manager(self, scenario: TestQwen3VLScenario):
"""NOTE: Exactly the same as the parent class method, but with the mrope flag set to True for Qwen3-VL model."""
cache_config = self.get_kv_cache_config(scenario)
tokens_per_block = cache_config["tokens_per_block"]
max_seq_len = cache_config["max_seq_len"]
batch_size = cache_config["batch_size"]
num_blocks = (max_seq_len + tokens_per_block - 1) // tokens_per_block
self.kv_cache_manager = self.get_kv_cache_manager(
dtype=self.model_config.pretrained_config.torch_dtype,
config=self.model_config.pretrained_config,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
batch_size=batch_size,
num_blocks=num_blocks,
)
self.kv_cache_manager.add_dummy_requests(
request_ids=[1],
token_nums=[max_seq_len],
# NOTE: Qwen3-VL model uses mrope
use_mrope=True,
)
def run_trtllm_forward(self, trtllm_inputs, use_cuda_graph: bool = False):
"""NOTE: Exactly the same as the parent class method, but with the mrope flag set to True for Qwen3-VL model."""
if not use_cuda_graph:
trtllm_inputs["attn_metadata"].prepare()
return self.trtllm_model.forward(**trtllm_inputs)
else:
# NOTE: Qwen3-VL model uses mrope
graph_runner = create_mock_cuda_graph_runner(1, True)
trtllm_inputs["attn_metadata"] = trtllm_inputs[
"attn_metadata"
].create_cuda_graph_metadata(1)
# Prepare metadata before capture (like in working Qwen2.5-VL test)
trtllm_inputs["attn_metadata"].prepare()
key = (1, 0, False)
graph_runner.capture(
key=key,
forward_fn=lambda inputs: self.trtllm_model.forward(**inputs),
initial_inputs=trtllm_inputs,
)
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated in prepare().
trtllm_inputs["attn_metadata"].prepare()
logits = graph_runner.replay(key=key, current_inputs=trtllm_inputs)
return logits.clone()
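Background on the capture/replay flow above: a CUDA graph replays fixed device pointers, so every replay must reuse the buffers that were captured; that is why prepare() runs before capture and again before each replay, and why a prepare() that reallocates buffers would break replay. A hedged sketch of the underlying discipline with plain torch.cuda.CUDAGraph (not the TRT-LLM graph runner; requires a GPU):

import torch

static_x = torch.zeros(8, device="cuda")
graph = torch.cuda.CUDAGraph()

# Recommended warm-up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_y = static_x * 2 + 1
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(graph):
    static_y = static_x * 2 + 1   # captured once against static_x / static_y

for step in range(2):
    static_x.copy_(torch.full((8,), float(step), device="cuda"))  # write in place, never rebind
    graph.replay()
    print(static_y[:2])  # step 0 -> [1., 1.], step 1 -> [3., 3.]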
def get_scenarios(self) -> List[TestQwen3VLScenario]:
scenarios = [
# ==== Modality Sanity Checks ====
TestQwen3VLScenario(
modality="image",
use_cuda_graph=False,
disable_fuse_rope=False,
chunked_prefill=False,
kv_cache_reuse=False,
),
TestQwen3VLScenario(
modality="video",
use_cuda_graph=False,
disable_fuse_rope=False,
chunked_prefill=False,
kv_cache_reuse=False,
),
TestQwen3VLScenario(
modality="multiple_image",
use_cuda_graph=False,
disable_fuse_rope=False,
chunked_prefill=False,
kv_cache_reuse=False,
),
# ==== CUDA Graph Scenarios ====
TestQwen3VLScenario(
modality="image",
use_cuda_graph=True,
disable_fuse_rope=False,
chunked_prefill=False,
kv_cache_reuse=False,
),
# ==== Disable fuse rope scenarios ====
# TestQwen3VLScenario(modality="image",
# use_cuda_graph=False,
# disable_fuse_rope=True,
# chunked_prefill=False,
# kv_cache_reuse=False),
# ==== Chunked Prefill Scenarios ====
TestQwen3VLScenario(
modality="image",
use_cuda_graph=False,
disable_fuse_rope=False,
chunked_prefill=True,
kv_cache_reuse=False,
),
# ==== KV Cache Reuse Scenarios ====
TestQwen3VLScenario(
modality="image",
use_cuda_graph=False,
disable_fuse_rope=False,
chunked_prefill=False,
kv_cache_reuse=True,
),
]
return scenarios
def setup_scenario(self, scenario: TestQwen3VLScenario):
super().setup_scenario(scenario)
if scenario.disable_fuse_rope:
self.trtllm_model, self.model_config = self.create_trtllm_model(
load_weights=True,
hf_model_state_dict=self.hf_model.state_dict(),
disable_fuse_rope=True,
)