mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][fix] Enable AttentionDP on Qwen3-VL and fix test (#10435)
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
parent 1c69aad850
commit 7295af68ba
@@ -502,8 +502,10 @@ class Qwen2VisionModelBase(nn.Module):
 class Qwen2_5_VLVisionAttention(Attention):
 
-    def __init__(self, model_config: ModelConfig[PretrainedConfig],
-                 layer_idx: int) -> None:
+    def __init__(self,
+                 model_config: ModelConfig[PretrainedConfig],
+                 layer_idx: int,
+                 reduce_output: bool = True) -> None:
         config = model_config.pretrained_config.vision_config
         super().__init__(
@@ -518,6 +520,7 @@ class Qwen2_5_VLVisionAttention(Attention):
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            reduce_output=reduce_output,
         )
 
     def forward(
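For orientation, the new reduce_output flag forwarded to Attention controls whether the tensor-parallel output projection finishes with an all-reduce. A minimal sketch of that contract, with an illustrative module (not the TensorRT-LLM implementation) and assuming a torch.distributed process group is already initialized:

import torch
import torch.distributed as dist

class RowParallelOutProjSketch(torch.nn.Module):
    # Row-parallel projection: each rank holds a slice of the input
    # dimension, so its local matmul yields only a partial sum.
    def __init__(self, in_per_rank: int, out_features: int,
                 reduce_output: bool = True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_per_rank))
        self.reduce_output = reduce_output

    def forward(self, x_shard: torch.Tensor) -> torch.Tensor:
        partial = x_shard @ self.weight.t()
        if self.reduce_output:
            # Classic TP: sum the partial outputs across the TP group.
            dist.all_reduce(partial)
        # With reduce_output=False the collective is skipped, e.g. when a
        # rank already holds the complete result (attention DP).
        return partial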
@@ -15,6 +15,7 @@ from transformers.models.qwen3_vl.modeling_qwen3_vl import (
 
 from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.mapping import Mapping
 
 from ..._utils import nvtx_range, nvtx_range_debug
 from ...inputs import (
@@ -439,7 +440,13 @@ class Qwen3VLVisionAttention(Qwen2_5_VLVisionAttention):
         model_config.pretrained_config.vision_config.torch_dtype = (
             model_config.pretrained_config.text_config.dtype
         )
-        super().__init__(model_config, layer_idx)
+        super().__init__(
+            model_config,
+            layer_idx=layer_idx,
+            reduce_output=(
+                not model_config.mapping.enable_attention_dp and model_config.mapping.tp_size > 1
+            ),
+        )
 
 
 class Qwen3VLVisionMLP(MLP):
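The flag above resolves the way attention DP requires: when each rank runs the full set of vision-attention heads on its own data there is no partial output to combine, so only the genuine tensor-parallel case asks for the reduction:

enable_attention_dp=True,  tp_size=4  ->  reduce_output=False  (full heads per rank)
enable_attention_dp=False, tp_size=4  ->  reduce_output=True   (partial outputs need the all-reduce)
enable_attention_dp=False, tp_size=1  ->  reduce_output=False  (nothing to reduce)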
@@ -453,12 +460,14 @@ class Qwen3VLVisionMLP(MLP):
             dtype=model_config.pretrained_config.text_config.dtype,
             config=model_config,
             layer_idx=layer_idx,
+            overridden_tp_size=1 if model_config.mapping.enable_attention_dp else None,
         )
 
 
 class Qwen3VLVisionBlock(torch.nn.Module):
     def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
         super().__init__()
+        self.model_config = model_config
         config = model_config.pretrained_config.vision_config
 
         self.norm1 = LayerNorm(
@@ -510,11 +519,29 @@ class Qwen3VLVisionPatchMerger(torch.nn.Module):
             eps=model_config.pretrained_config.text_config.rms_norm_eps,
             dtype=model_config.pretrained_config.text_config.dtype,
         )
 
+        self.mapping = model_config.mapping
+        overridden_tp_size = 1 if model_config.mapping.enable_attention_dp else None
+        if overridden_tp_size is not None:
+            assert self.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
+            # "Misuse" pp_size here to perform all-reduce within smaller groups
+            pp_size = self.mapping.pp_size * self.mapping.tp_size // overridden_tp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
+        else:
+            mapping = self.mapping
+
         self.linear_fc1 = Linear(
             in_features=self.hidden_size,
             out_features=self.hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             allreduce_strategy=model_config.allreduce_strategy,
         )
@@ -523,7 +550,7 @@ class Qwen3VLVisionPatchMerger(torch.nn.Module):
             in_features=self.hidden_size,
             out_features=config.out_hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             allreduce_strategy=model_config.allreduce_strategy,
         )
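The '"Misuse" pp_size' comment marks the central trick: Mapping derives tensor-parallel groups from tp_size, so re-declaring the same world with a shrunken tp_size and a correspondingly inflated pp_size keeps world_size intact while making every TP collective span a smaller group; with overridden_tp_size=1 each group is a single rank and the all-reduce degenerates to a no-op. A toy sketch of the regrouping (plain Python; the contiguous, tp-fastest rank layout is an assumption about Mapping):

def tp_group(rank: int, tp_size: int) -> list[int]:
    # Assumed layout: rank = pp_rank * tp_size + tp_rank, so TP groups
    # are consecutive blocks of tp_size ranks.
    start = rank - rank % tp_size
    return list(range(start, start + tp_size))

tp_size, pp_size = 4, 1                            # original mapping
overridden_tp_size = 1                             # the attention DP case above
new_pp = pp_size * tp_size // overridden_tp_size   # 4, so world_size is unchanged

for rank in range(tp_size * pp_size):
    print(rank, tp_group(rank, tp_size), "->", tp_group(rank, overridden_tp_size))
# 0 [0, 1, 2, 3] -> [0]
# 1 [0, 1, 2, 3] -> [1]   ...each rank is alone, so no cross-rank reduce happens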
@@ -705,8 +732,8 @@ class Qwen3VisionModel(torch.nn.Module):
 
     @torch.inference_mode()
     def forward(
-        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
-    ) -> torch.Tensor:
+        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).tolist()
         attn_metadata = self.prepare_attn_metadata(seq_lens, self.attn_metadata)
@@ -714,7 +741,7 @@ class Qwen3VisionModel(torch.nn.Module):
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
         # From this point, pure GPU operation
-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(pixel_values)
         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len, -1)
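The seq_lens line above turns each (t, h, w) patch grid into one attention sequence per frame. A standalone example with made-up grid values:

import torch

# Each grid_thw row is (t, h, w): t frames, each covering h*w patches.
grid_thw = torch.tensor([[1, 4, 6],   # an image: 1 frame, 24 patches
                         [2, 3, 4]])  # a video: 2 frames, 12 patches each
seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                   grid_thw[:, 0]).tolist()
print(seq_lens)  # [24, 12, 12] -> three independent attention sequences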
@@ -4,6 +4,8 @@ from typing import Optional
 import torch
 from torch import nn
 
+from tensorrt_llm.mapping import Mapping
+
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -20,7 +22,8 @@ class MLP(nn.Module):
                  dtype: Optional[torch.dtype] = None,
                  config: Optional[ModelConfig] = None,
                  layer_idx: Optional[int] = None,
-                 reduce_output: bool = True):
+                 reduce_output: bool = True,
+                 overridden_tp_size: Optional[int] = None):
         super().__init__()
         self.layer_idx = layer_idx
@@ -29,6 +32,22 @@ class MLP(nn.Module):
         self.activation = activation
 
         config = config or ModelConfig()
+        self.mapping = config.mapping
+        if overridden_tp_size is not None:
+            assert config.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
+            # "Misuse" pp_size here to perform all-reduce within smaller groups
+            pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
+        else:
+            mapping = config.mapping
+
         self.up_lora = LoraLayer(
             [LoraModuleType.MLP_H_TO_4H],
             [self.intermediate_size // config.mapping.tp_size])
@@ -38,7 +57,7 @@ class MLP(nn.Module):
             self.intermediate_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.VANILLA),
@@ -55,7 +74,7 @@ class MLP(nn.Module):
             self.hidden_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
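The MLP changes mirror the patch-merger logic above: with overridden_tp_size=1 under attention DP, both Linear layers receive the regrouped mapping with tp_size=1, so the up and down projections run as plain unsharded GEMMs and the row-parallel all-reduce disappears. Note that the LoRA shard width a few lines earlier still divides by the original config.mapping.tp_size.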
@@ -27,3 +27,5 @@ Qwen/Qwen3-VL-30B-A3B-Instruct:
 mistral/Mistral-Large-3-675B:
   # Mistral Large 3 675B only supports single image input, so accuracy is lower.
   - accuracy: 47
+Qwen/Qwen3-VL-8B-Instruct:
+  - accuracy: 55.11
@@ -327,3 +327,21 @@ class TestMistralLarge3_675B(LlmapiAccuracyTestHarness):
         ) as llm:
             task = MMMU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=self.sampling_params)
+
+
+class TestQwen3VL(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "Qwen/Qwen3-VL-8B-Instruct"
+    MODEL_PATH = f"{llm_models_root()}/Qwen3/Qwen3-VL-8B-Instruct"
+    MAX_NUM_TOKENS = 16384
+
+    sampling_params = SamplingParams(
+        max_tokens=MAX_NUM_TOKENS, truncate_prompt_tokens=MMMU.MAX_INPUT_LEN, stop="<|endoftext|>"
+    )
+
+    def test_auto_dtype(self):
+        with LLM(
+            self.MODEL_PATH,
+            max_num_tokens=self.MAX_NUM_TOKENS,
+        ) as llm:
+            task = MMMU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=self.sampling_params)
@@ -14,6 +14,7 @@ l0_l40s:
       backend: pytorch
   tests:
   # ------------- PyTorch tests ---------------
+  # Multimodal modeling tests
  - unittest/_torch/modeling -k "modeling_mllama"
  - unittest/_torch/modeling -k "modeling_siglip"
  - unittest/_torch/modeling -k "modeling_vila"
@@ -22,6 +23,7 @@ l0_l40s:
  - unittest/_torch/modeling/test_modeling_llava_next.py::TestLlavaNext::test_all
  - unittest/_torch/modeling/test_modeling_qwen2_5vl.py::TestQwen2_5_VL::test_all
  - unittest/_torch/modeling/test_modeling_qwen3vl_moe.py::TestQwen3VLMoe::test_all
+ - unittest/_torch/modeling/test_modeling_qwen3vl.py::TestQwen3VL::test_all
  - test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
  - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-audio]
  - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image]
@@ -237,12 +237,6 @@ class TestQwen3VL(TestModelingMultimodal):
             chunked_prefill=False,
             kv_cache_reuse=False,
         ),
-        # ==== Disable fuse rope scenarios ====
-        # TestQwen3VLScenario(modality="image",
-        #                     use_cuda_graph=False,
-        #                     disable_fuse_rope=True,
-        #                     chunked_prefill=False,
-        #                     kv_cache_reuse=False),
         # ==== Chunked Prefill Scenarios ====
         TestQwen3VLScenario(
             modality="image",
@@ -259,6 +253,14 @@ class TestQwen3VL(TestModelingMultimodal):
             chunked_prefill=False,
             kv_cache_reuse=True,
         ),
+        # ==== Disable fuse rope scenarios ====
+        TestQwen3VLScenario(
+            modality="image",
+            use_cuda_graph=False,
+            disable_fuse_rope=True,
+            chunked_prefill=False,
+            kv_cache_reuse=False,
+        ),
     ]
     return scenarios