[None][fix] Enable AttentionDP on Qwen3-VL and fix test (#10435)

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Yechan Kim 2026-01-10 00:13:26 +09:00 committed by GitHub
parent 1c69aad850
commit 7295af68ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 90 additions and 17 deletions

View File

@@ -502,8 +502,10 @@ class Qwen2VisionModelBase(nn.Module):
 class Qwen2_5_VLVisionAttention(Attention):

-    def __init__(self, model_config: ModelConfig[PretrainedConfig],
-                 layer_idx: int) -> None:
+    def __init__(self,
+                 model_config: ModelConfig[PretrainedConfig],
+                 layer_idx: int,
+                 reduce_output: bool = True) -> None:
         config = model_config.pretrained_config.vision_config
         super().__init__(
@@ -518,6 +520,7 @@ class Qwen2_5_VLVisionAttention(Attention):
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            reduce_output=reduce_output,
         )

     def forward(
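
Editorial note on the reduce_output flag threaded through here: under plain tensor parallelism each rank computes attention with only its shard of the heads, so the output projection yields a partial sum that must be all-reduced across the TP group; under attention DP each rank owns whole requests and its local output is already complete. A minimal sketch of that distinction, assuming an initialized torch.distributed process group (finish_o_proj is an illustrative name, not a TRT-LLM API):

    import torch
    import torch.distributed as dist

    def finish_o_proj(partial_out: torch.Tensor, reduce_output: bool) -> torch.Tensor:
        if reduce_output:
            # Plain TP: each rank holds a partial sum over its head shard,
            # so the shards must be summed across the TP group.
            dist.all_reduce(partial_out)
        # Attention DP: the local output is already complete for this rank's
        # requests; reducing here would wrongly scale it by tp_size.
        return partial_out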

View File

@@ -15,6 +15,7 @@ from transformers.models.qwen3_vl.modeling_qwen3_vl import (
 from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.mapping import Mapping

 from ..._utils import nvtx_range, nvtx_range_debug
 from ...inputs import (
@@ -439,7 +440,13 @@ class Qwen3VLVisionAttention(Qwen2_5_VLVisionAttention):
         model_config.pretrained_config.vision_config.torch_dtype = (
             model_config.pretrained_config.text_config.dtype
         )
-        super().__init__(model_config, layer_idx)
+        super().__init__(
+            model_config,
+            layer_idx=layer_idx,
+            reduce_output=(
+                not model_config.mapping.enable_attention_dp and model_config.mapping.tp_size > 1
+            ),
+        )


 class Qwen3VLVisionMLP(MLP):
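
The reduce_output expression above collapses to a simple rule, restated here as a hedged sketch (the helper name is ours, not from the codebase):

    def needs_output_allreduce(enable_attention_dp: bool, tp_size: int) -> bool:
        # Reduce only when heads are actually sharded across TP ranks and
        # attention DP has not replaced that sharding with data parallelism.
        return not enable_attention_dp and tp_size > 1

    assert needs_output_allreduce(False, 4)      # plain TP: reduce
    assert not needs_output_allreduce(True, 4)   # attention DP: skip
    assert not needs_output_allreduce(False, 1)  # single rank: nothing to sum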
@@ -453,12 +460,14 @@ class Qwen3VLVisionMLP(MLP):
             dtype=model_config.pretrained_config.text_config.dtype,
             config=model_config,
             layer_idx=layer_idx,
+            overridden_tp_size=1 if model_config.mapping.enable_attention_dp else None,
         )


 class Qwen3VLVisionBlock(torch.nn.Module):

     def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
         super().__init__()
+        self.model_config = model_config
         config = model_config.pretrained_config.vision_config
         self.norm1 = LayerNorm(
@@ -510,11 +519,29 @@ class Qwen3VLVisionPatchMerger(torch.nn.Module):
             eps=model_config.pretrained_config.text_config.rms_norm_eps,
             dtype=model_config.pretrained_config.text_config.dtype,
         )
+        self.mapping = model_config.mapping
+        overridden_tp_size = 1 if model_config.mapping.enable_attention_dp else None
+        if overridden_tp_size is not None:
+            assert self.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
+            # "Misuse" pp_size here to perform all-reduce within smaller groups
+            pp_size = self.mapping.pp_size * self.mapping.tp_size // overridden_tp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
+        else:
+            mapping = self.mapping
         self.linear_fc1 = Linear(
             in_features=self.hidden_size,
             out_features=self.hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             allreduce_strategy=model_config.allreduce_strategy,
         )
@@ -523,7 +550,7 @@ class Qwen3VLVisionPatchMerger(torch.nn.Module):
             in_features=self.hidden_size,
             out_features=config.out_hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             allreduce_strategy=model_config.allreduce_strategy,
         )
@@ -705,8 +732,8 @@ class Qwen3VisionModel(torch.nn.Module):
     @torch.inference_mode()
     def forward(
-        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
-    ) -> torch.Tensor:
+        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).tolist()
         attn_metadata = self.prepare_attn_metadata(seq_lens, self.attn_metadata)
@@ -714,7 +741,7 @@ class Qwen3VisionModel(torch.nn.Module):
         rotary_pos_emb = self.rot_pos_emb(grid_thw)

         # From this point, pure GPU operation
-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(pixel_values)
         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len, -1)
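
For reference, the seq_lens computation in the hunk above can be checked in isolation: each row of grid_thw is (t, h, w), every frame contributes h*w patches, and each image repeats that length t times. A self-contained sketch with grid values we made up:

    import torch

    grid_thw = torch.tensor([[2, 4, 6], [1, 8, 8]])  # two images as (t, h, w)
    seq_lens = torch.repeat_interleave(
        grid_thw[:, 1] * grid_thw[:, 2],  # h * w patches per frame
        grid_thw[:, 0],                   # one entry per frame, repeated t times
    ).tolist()
    assert seq_lens == [24, 24, 64]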

View File

@@ -4,6 +4,8 @@ from typing import Optional
 import torch
 from torch import nn

+from tensorrt_llm.mapping import Mapping
+
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -20,7 +22,8 @@ class MLP(nn.Module):
                  dtype: Optional[torch.dtype] = None,
                  config: Optional[ModelConfig] = None,
                  layer_idx: Optional[int] = None,
-                 reduce_output: bool = True):
+                 reduce_output: bool = True,
+                 overridden_tp_size: Optional[int] = None):
         super().__init__()
         self.layer_idx = layer_idx
@@ -29,6 +32,22 @@ class MLP(nn.Module):
         self.activation = activation
         config = config or ModelConfig()
         self.mapping = config.mapping
+        if overridden_tp_size is not None:
+            assert config.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
+            # "Misuse" pp_size here to perform all-reduce within smaller groups
+            pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
+        else:
+            mapping = config.mapping

         self.up_lora = LoraLayer(
             [LoraModuleType.MLP_H_TO_4H],
             [self.intermediate_size // config.mapping.tp_size])
@@ -38,7 +57,7 @@ class MLP(nn.Module):
             self.intermediate_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.VANILLA),
@@ -55,7 +74,7 @@ class MLP(nn.Module):
             self.hidden_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
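
The overridden_tp_size trick used in both MLP and the patch merger keeps the world size fixed while shrinking the TP group, which is what lets attention-DP ranks run these layers unsharded. A hedged sketch of just the arithmetic, with example sizes we chose (global tp_size=4, pp_size=1, override=1):

    global_tp, global_pp, override = 4, 1, 1
    assert global_tp % override == 0
    tp_size = override                           # 1: weights are not sharded
    pp_size = global_pp * global_tp // override  # 4: absorbs the spare ranks
    assert tp_size * pp_size == global_tp * global_pp  # world size preserved
    # Each rank now forms a TP group of size 1, so the Linear layers neither
    # split weights nor issue an all-reduce, matching enable_attention_dp.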

View File

@@ -27,3 +27,5 @@ Qwen/Qwen3-VL-30B-A3B-Instruct:
 mistral/Mistral-Large-3-675B:
   # Mistral Large 3 675B only supports single image input, so accuracy is lower.
   - accuracy: 47
+Qwen/Qwen3-VL-8B-Instruct:
+  - accuracy: 55.11

View File

@@ -327,3 +327,21 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
         ) as llm:
             task = MMMU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=self.sampling_params)
+
+
+class TestQwen3VL(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "Qwen/Qwen3-VL-8B-Instruct"
+    MODEL_PATH = f"{llm_models_root()}/Qwen3/Qwen3-VL-8B-Instruct"
+    MAX_NUM_TOKENS = 16384
+
+    sampling_params = SamplingParams(
+        max_tokens=MAX_NUM_TOKENS, truncate_prompt_tokens=MMMU.MAX_INPUT_LEN, stop="<|endoftext|>"
+    )
+
+    def test_auto_dtype(self):
+        with LLM(
+            self.MODEL_PATH,
+            max_num_tokens=self.MAX_NUM_TOKENS,
+        ) as llm:
+            task = MMMU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=self.sampling_params)

View File

@@ -14,6 +14,7 @@ l0_l40s:
       backend: pytorch
   tests:
   # ------------- PyTorch tests ---------------
+  # Multimodal modeling tests
   - unittest/_torch/modeling -k "modeling_mllama"
   - unittest/_torch/modeling -k "modeling_siglip"
   - unittest/_torch/modeling -k "modeling_vila"
@@ -22,6 +23,7 @@ l0_l40s:
   - unittest/_torch/modeling/test_modeling_llava_next.py::TestLlavaNext::test_all
   - unittest/_torch/modeling/test_modeling_qwen2_5vl.py::TestQwen2_5_VL::test_all
   - unittest/_torch/modeling/test_modeling_qwen3vl_moe.py::TestQwen3VLMoe::test_all
+  - unittest/_torch/modeling/test_modeling_qwen3vl.py::TestQwen3VL::test_all
   - test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-audio]
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image]

View File

@@ -237,12 +237,6 @@ class TestQwen3VL(TestModelingMultimodal):
                 chunked_prefill=False,
                 kv_cache_reuse=False,
             ),
-            # ==== Disable fuse rope scenarios ====
-            # TestQwen3VLScenario(modality="image",
-            #                     use_cuda_graph=False,
-            #                     disable_fuse_rope=True,
-            #                     chunked_prefill=False,
-            #                     kv_cache_reuse=False),
             # ==== Chunked Prefill Scenarios ====
             TestQwen3VLScenario(
                 modality="image",
@@ -259,6 +253,14 @@ class TestQwen3VL(TestModelingMultimodal):
                 chunked_prefill=False,
                 kv_cache_reuse=True,
             ),
+            # ==== Disable fuse rope scenarios ====
+            TestQwen3VLScenario(
+                modality="image",
+                use_cuda_graph=False,
+                disable_fuse_rope=True,
+                chunked_prefill=False,
+                kv_cache_reuse=False,
+            ),
         ]
         return scenarios