[TRTLLM-8310][feat] Add Qwen3-VL-MoE (#9689)
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
commit 8ba8699f66 (parent dff77efa2a)
@@ -27,7 +27,7 @@ nvidia-modelopt[torch]~=0.37.0
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-10.html#rel-25-10 uses 2.27.7
nvidia-nccl-cu13==2.27.7
nvidia-cuda-nvrtc
transformers==4.56.0
transformers==4.57.1
prometheus_client
prometheus_fastapi_instrumentator
pydantic>=2.9.1
@@ -76,3 +76,4 @@ partial_json_parser
apache-tvm-ffi==0.1.4 # used for reduce nvidia-cutlass-dsl host overhead
torch-c-dlpack-ext==0.1.3 # used for reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf
mistral-common==1.8.6
torchao>=0.14.1
@@ -578,8 +578,9 @@ class PositionalEmbeddingParams:
    rope: Optional[RopeParams] = None
    is_neox: bool = True

    # mRoPE params (currently, Qwen2/2.5-VL uses it)
    # mRoPE params
    mrope_section: Optional[List[int]] = None
    mrope_interleaved: bool = False

    def __post_init__(self) -> None:
        if self.type.is_deferred():
@@ -28,6 +28,7 @@ from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
from .modeling_qwen3 import Qwen3ForCausalLM
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
from .modeling_qwen3_next import Qwen3NextForCausalLM
from .modeling_qwen3vl_moe import Qwen3MoeVLModel
from .modeling_qwen_moe import Qwen2MoeForCausalLM
from .modeling_seedoss import SeedOssForCausalLM
from .modeling_siglip import SiglipVisionModel
@@ -71,6 +72,7 @@ __all__ = [
    "Qwen3ForCausalLM",
    "Qwen3MoeForCausalLM",
    "Qwen3NextForCausalLM",
    "Qwen3MoeVLModel",
    "GptOssForCausalLM",
    "SeedOssForCausalLM",
    "Glm4MoeForCausalLM",
@@ -0,0 +1,24 @@
from torch import nn

from tensorrt_llm._torch.models.checkpoints.hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper
from tensorrt_llm._torch.modules.fused_moe.interface import MoE


@register_mapper("HF", "Qwen3VLMoeForConditionalGeneration")
class Qwen3VLMoeHfWeightMapper(Qwen3MoeHfWeightMapper):
    def handle_special_instance_module(
        self,
        module: nn.Module,
        module_name: str,
        module_weights: dict,
        allow_partial_loading: bool = False,
    ) -> None:
        if isinstance(module, MoE):
            updated_module_weights = {}
            for weight_name, weight_value in module_weights.items():
                new_weight_name = weight_name.replace("scale_inv", "weight_scale")
                updated_module_weights[new_weight_name] = weight_value
            module.load_weights(
                weights=[updated_module_weights], allow_partial_loading=allow_partial_loading
            )
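# A minimal sketch (hypothetical weight names, not taken from an actual checkpoint) of the
# key renaming the mapper above applies before handing the expert weights to
# MoE.load_weights(): every "scale_inv" suffix is rewritten to "weight_scale".
module_weights = {
    "0.gate_up_proj.weight": "<tensor>",
    "0.gate_up_proj.scale_inv": "<block-scale tensor>",
}
updated_module_weights = {
    name.replace("scale_inv", "weight_scale"): value
    for name, value in module_weights.items()
}
print(updated_module_weights)
# {'0.gate_up_proj.weight': '<tensor>', '0.gate_up_proj.weight_scale': '<block-scale tensor>'}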
@@ -1,6 +1,5 @@
import copy
import dataclasses
import os
from typing import Any, Dict, List, Tuple

import torch
@@ -20,7 +19,8 @@ from tensorrt_llm._torch.models.checkpoints.mistral.weight_mapper import \
from tensorrt_llm._torch.models.modeling_mistral_large3 import (
    Mistral3Gate, MistralLarge3ForCausalLM)
from tensorrt_llm._torch.models.modeling_multimodal_utils import (
    find_input_mm_embeds, fuse_input_embeds, get_multimodal_embeddings)
    _MULTIMODAL_ENV_NAME, _is_disagg, find_input_mm_embeds, fuse_input_embeds,
    get_multimodal_embeddings)
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
                                                       DecoderModelForCausalLM,
                                                       _load_weights_impl,
@@ -45,13 +45,6 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.logger import logger

_MULTIMODAL_ENV_NAME = "TLLM_MULTIMODAL_DISAGGREGATED"


# Make this a runtime lookup rather than a module-wide constant for easier unit testing.
def _is_disagg() -> bool:
    return os.getenv(_MULTIMODAL_ENV_NAME, "0") == "1"


class MistralAttention(Attention):

@@ -373,6 +366,7 @@ class Mistral3VLM(PreTrainedModel):
        )

        config = model_config.pretrained_config
        self._supports_sdpa = True
        super().__init__(config)

        vision_feature_layer = getattr(config, "vision_feature_layer", -1)
@@ -17,7 +17,8 @@
# and s2wrapper: https://github.com/bfshi/scaling_on_scales

import math
from typing import Any, Dict, List, Optional, Tuple, cast
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

import torch
import torch.nn.functional as F
@@ -29,6 +30,13 @@ from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm.inputs.multimodal import MultimodalParams
from tensorrt_llm.logger import logger

_MULTIMODAL_ENV_NAME = "TLLM_MULTIMODAL_DISAGGREGATED"


# Make this a runtime lookup rather than a module-wide constant for easier unit testing.
def _is_disagg() -> bool:
    return os.getenv(_MULTIMODAL_ENV_NAME, "0") == "1"


def _get_uncached_multimodal_params(
        multimodal_params: List[MultimodalParams], ) -> List[MultimodalParams]:
@@ -67,17 +75,17 @@ def _cache_multimodal_embeddings(
    mostly for chunked prefill. It does not persist embeddings across different requests or sessions.
    """
    # TODO: support multiple multimodal modalities per request
    assert len(
        embeddings
    ) == 1, "Currently only support single mm_embeds (single modality) per request"
    if len(embeddings) > 1:
        raise ValueError("Multiple modalities caching is not supported yet.")
    mm_embed = embeddings[0]

    # Collect embedding lengths for each parameter
    embed_lengths = [
        param.multimodal_runtime.total_mm_tokens_in_request -
        param.multimodal_runtime.total_special_tokens_in_request
        for param in multimodal_params if param.multimodal_runtime is not None
    ]
    embed_lengths = []
    for param in multimodal_params:
        if param.multimodal_runtime is not None:
            embed_lengths.append(
                param.multimodal_runtime.total_mm_tokens_in_request -
                param.multimodal_runtime.total_special_tokens_in_request)

    # Validate total length matches
    total_expected = sum(embed_lengths)
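# A minimal, self-contained sketch of why _is_disagg() above is a runtime lookup rather
# than a module-wide constant: a unit test can flip the environment variable and the
# change is picked up immediately, with no stale module-level value to patch.
import os

_MULTIMODAL_ENV_NAME = "TLLM_MULTIMODAL_DISAGGREGATED"


def _is_disagg() -> bool:
    return os.getenv(_MULTIMODAL_ENV_NAME, "0") == "1"


os.environ[_MULTIMODAL_ENV_NAME] = "1"
assert _is_disagg()      # picked up immediately
os.environ[_MULTIMODAL_ENV_NAME] = "0"
assert not _is_disagg()  # no re-import needed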
@@ -103,7 +111,10 @@


def get_multimodal_embeddings(
    encoder_forward_fn,
    encoder_forward_fn: Callable[
        [List[MultimodalParams]],
        Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]],
    ],
    multimodal_params: List[MultimodalParams],
    encoder_kwargs: Optional[Dict[str, Any]] = None,
) -> List[torch.Tensor]:
@@ -117,12 +128,13 @@
    4. Gather all embeddings for the batch

    Args:
        encoder_forward_fn: Callable that performs encoder forward pass
            Should accept List[MultimodalParams] and return List[torch.Tensor]
        multimodal_params: All multimodal parameters in the batch

        encoder_forward_fn: Callable that performs encoder forward pass.
            Should accept List[MultimodalParams] and return List[torch.Tensor] or
            Tuple[List[torch.Tensor], Dict[str, Any]] for models with auxiliary outputs.
        multimodal_params: All multimodal parameters in the batch.
        encoder_kwargs: Optional kwargs to pass to encoder_forward_fn.
    Returns:
        List of multimodal embeddings for all multimodal params in the batch
        List of multimodal embeddings for all multimodal params in the batch.
    """
    if not multimodal_params:
        return []
@@ -134,12 +146,13 @@
    # Step 2: Run encoder forward only on uncached parameters
    if uncached_multimodal_params:
        kwargs = encoder_kwargs or {}
        encoder_outputs = encoder_forward_fn(uncached_multimodal_params,
                                             **kwargs)
        encoder_embeddings = encoder_forward_fn(uncached_multimodal_params,
                                                **kwargs)

        # TODO: support multiple multimodal modalities per request
        if len(encoder_outputs) > 1:
            return encoder_outputs
        if len(encoder_embeddings) > 1:
            logger.warning("Multiple modalities caching is not supported yet.")
            return encoder_embeddings

        # Validate that multimodal_runtime has required attributes for caching
        if (not hasattr(uncached_multimodal_params[0], 'multimodal_runtime')
@@ -147,13 +160,13 @@
                or uncached_multimodal_params[0].multimodal_runtime.
                total_mm_tokens_in_request is None):
            logger.warning(
                "Multimodal runtime data missing or incomplete - recomputed all embeddings"
                "Multimodal runtime data missing or incomplete, will not cache embeddings."
            )
            return encoder_outputs
            return encoder_embeddings

        # Step 3: Cache the computed embeddings to multimodal_data["multimodal_embedding"]
        _cache_multimodal_embeddings(uncached_multimodal_params,
                                     encoder_outputs)
                                     encoder_embeddings)

    # Step 4: Gather all embeddings for the batch
    for param in multimodal_params:
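# A minimal sketch of an encoder callable that satisfies the new encoder_forward_fn
# annotation above: it may return a plain embedding tensor or an (embeddings, aux_outputs)
# tuple for models with auxiliary outputs. Shapes, the 4-token-per-request assumption, and
# the "deepstack_features" key are illustrative only; MultimodalParams is referenced as a
# forward-reference string so the sketch runs without importing tensorrt_llm.
from typing import Any, Dict, List, Tuple, Union

import torch


def dummy_encoder_forward(
    multimodal_params: List["MultimodalParams"],
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]:
    embeddings = torch.zeros(len(multimodal_params) * 4, 8)  # toy: 4 mm tokens, hidden=8
    aux_outputs = {"deepstack_features": [torch.zeros_like(embeddings)]}
    return embeddings, aux_outputs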
@@ -301,8 +314,12 @@ def fuse_input_embeds(
    mm_token_ids: Optional[torch.IntTensor] = None,
    text_token_indices: Optional[torch.IntTensor] = None,
    mm_token_indices: Optional[torch.IntTensor] = None,
    extra_embeds: Optional[List[torch.Tensor]] = None,
    **kwargs,
) -> Tuple[Optional[torch.IntTensor], Optional[torch.FloatTensor]]:
# TODO: make unified return type for all models
) -> Union[Tuple[Optional[torch.IntTensor], Optional[torch.FloatTensor]],
           Tuple[Optional[torch.IntTensor], Optional[torch.FloatTensor],
                 Optional[List[torch.FloatTensor]]]]:
    """
    Fuse text and multimodal embeddings. input_ids is [text_total_length + mm_total_length] and mm_embed is [mm_total_length, hidden_dim]. We just need to fuse them into [text_total_length + mm_total_length, hidden_dim] by slice-and-assign to the corresponding entries.

@@ -311,6 +328,7 @@
        input_ids: shape [text_total_length + mm_total_length], flattened from List[(text_length1 + mm_total_length1), ..., (text_lengthi + mm_total_lengthi)]. For LLM model, the requests are inflight batched together, but the input_ids are flattened with padding removed. By the slice condition < vocab_size, we can easily separate text / multimodal tokens and naturally batched the LLM embedding lookup
        mm_embeds: List[(mm_total_length1, hidden_dim), ..., (mm_total_lengthi, hidden_dim)].
        mm_token_ids: possible token ids for multimodal tokens, if known. If not known and set to None, it is assumed that the multimodal tokens are out-of-vocabulary tokens.
        extra_embeds: Optional list of extra embed tensors for models that support it (e.g., Qwen3-VL/Qwen3-MoE-VL).
    Returns:
        - If (1) JIT test run, (2) non-multimodal run, i.e. all text-only requests, either context or generation phase (3) multimodal run, all requests in generation phase --> there is no multimodal data, return only the input_ids
        - If (4) multimodal run, mixed batch of context and generation requests, each context request has a multimodal feature --> return only the fused input_embeds of shape [total length, hidden_dim]. For text tokens, LLM embedding layer has already run.
@@ -319,6 +337,8 @@
        - This function may involve host-device synchronization if indices are not provided and filtering is performed. See filter_mm_token_from_input_ids for details.
    """
    if len(mm_embeds) == 0:
        if extra_embeds is not None and len(extra_embeds) > 0:
            return input_ids, None, extra_embeds
        return input_ids, None

    mm_embed = torch.cat(mm_embeds, dim=0)
@@ -330,7 +350,6 @@
            input_ids,
            vocab_size=embedding_layer.num_embeddings,
            mm_token_ids=mm_token_ids)

    if mm_token_indices.shape[0] != mm_embed.shape[0]:
        raise ValueError(
            f"Multimodal token count mismatch: found {len(mm_token_indices)} image tokens in input_ids "
@@ -343,11 +362,23 @@
        mm_embed.shape[-1],
        device=text_embed.device,
        dtype=text_embed.dtype)
    if extra_embeds is not None and len(extra_embeds) > 0:
        # only support single modality for deepstack features for now
        for i, extra_feature in enumerate(extra_embeds):
            extra_embed = torch.zeros(
                input_ids.shape[0],
                mm_embed.shape[-1],
                device=extra_feature.device,
                dtype=extra_feature.dtype,
            )
            extra_embed[mm_token_indices, :] = extra_feature
            extra_embeds[i] = extra_embed

    input_embeds[text_token_indices, :] = text_embed
    input_embeds[mm_token_indices, :] = mm_embed.to(dtype=input_embeds.dtype,
                                                    device=input_embeds.device)

    if extra_embeds is not None and len(extra_embeds) > 0:
        return None, cast(torch.FloatTensor, input_embeds), extra_embeds
    return None, cast(torch.FloatTensor, input_embeds)
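# A minimal, self-contained sketch (toy sizes) of the extra_embeds handling added above:
# each deepstack feature tensor is scattered into a zero tensor of the full fused sequence
# length, so only the multimodal token positions carry non-zero values.
import torch

seq_len, hidden = 6, 4
mm_token_indices = torch.tensor([1, 2, 4])                # hypothetical mm positions
extra_feature = torch.ones(len(mm_token_indices), hidden)

extra_embed = torch.zeros(seq_len, hidden)
extra_embed[mm_token_indices, :] = extra_feature          # text positions stay zero
print(extra_embed[:, 0])  # tensor([0., 1., 1., 0., 1., 0.])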
@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@ -10,8 +9,8 @@ from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
|
||||
PreTrainedModel)
|
||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding,
|
||||
Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLMLP,
|
||||
Qwen2_5_VLVisionBlock, apply_rotary_pos_emb_vision)
|
||||
Qwen2_5_VisionTransformerPretrainedModel, Qwen2_5_VLVisionBlock,
|
||||
apply_rotary_pos_emb_vision)
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import \
|
||||
Qwen2VisionTransformerPretrainedModel
|
||||
|
||||
@ -21,8 +20,9 @@ from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
|
||||
BaseWeightMapper
|
||||
from tensorrt_llm._torch.models.checkpoints.hf.qwen2vl_weight_mapper import \
|
||||
Qwen2VLHfWeightMapper
|
||||
from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg
|
||||
from tensorrt_llm._torch.modules.attention import Attention
|
||||
from tensorrt_llm._torch.modules.linear import Linear
|
||||
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
|
||||
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
|
||||
from tensorrt_llm.functional import PositionEmbeddingType
|
||||
from tensorrt_llm.inputs.multimodal import MultimodalParams
|
||||
@ -38,6 +38,7 @@ from ...sampling_params import SamplingParams
|
||||
from ..attention_backend import AttentionMetadata
|
||||
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
|
||||
from ..attention_backend.utils import get_attention_backend
|
||||
from ..modules.gated_mlp import GatedMLP
|
||||
from ..modules.rotary_embedding import MRotaryEmbedding
|
||||
from .modeling_auto import AutoModelForCausalLM
|
||||
from .modeling_multimodal_utils import (find_input_mm_embeds, fuse_input_embeds,
|
||||
@ -46,48 +47,9 @@ from .modeling_utils import (ModelConfig, QuantConfig, _load_weights_impl,
|
||||
filter_weights, register_auto_model,
|
||||
register_vision_encoder)
|
||||
|
||||
DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
|
||||
PAD_INDEX = -100 # NOTE: refer to https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L269
|
||||
|
||||
|
||||
def process_weights(weights: Dict,
|
||||
prefix: str = "visual",
|
||||
weight_name_mapping: Dict[str, str] = None) -> Dict:
|
||||
"""
|
||||
Filter and transform weights in a single modular function.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of all model weights
|
||||
prefix: Prefix to filter weights by (default: "visual")
|
||||
weight_name_mapping: Optional mapping to transform weight names
|
||||
|
||||
Returns:
|
||||
Dictionary of processed weights ready for loading
|
||||
"""
|
||||
|
||||
# Filter weights by prefix (handles both direct and "model." prefixed keys)
|
||||
filtered_weights = {}
|
||||
for key, weight in weights.items():
|
||||
if key.startswith(prefix):
|
||||
filtered_weights[key] = weight
|
||||
elif key.startswith("model." + prefix):
|
||||
filtered_weights[key[len("model."):]] = weight
|
||||
|
||||
# Transform weight names if mapping provided
|
||||
if weight_name_mapping:
|
||||
transformed_weights = {}
|
||||
for key, weight in filtered_weights.items():
|
||||
new_key = key
|
||||
for old_suffix, new_suffix in weight_name_mapping.items():
|
||||
if key.endswith(old_suffix):
|
||||
new_key = key.replace(old_suffix, new_suffix)
|
||||
break
|
||||
transformed_weights[new_key] = weight
|
||||
return transformed_weights
|
||||
|
||||
return filtered_weights
|
||||
|
||||
|
||||
class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor,
|
||||
BaseMultimodalDummyInputsBuilder):
|
||||
|
||||
@ -310,7 +272,7 @@ class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor,
|
||||
mrope_position_deltas, device=input_ids.device).unsqueeze(1)
|
||||
return position_ids, mrope_position_deltas
|
||||
|
||||
def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
|
||||
def _preprocess(self, text: Dict[str, any], mm_data: Dict[str, any],
|
||||
mm_processor_kwargs: Dict[str, Any]):
|
||||
images = mm_data.get("image")
|
||||
video_datas = mm_data.get("video")
|
||||
@ -323,8 +285,6 @@ class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor,
|
||||
do_rescale = False
|
||||
if videos and isinstance(videos[0][0], torch.Tensor):
|
||||
do_rescale = False
|
||||
# transformers=4.53.1 does not support GPU video tensors in Qwen2VL processor.
|
||||
videos = [[frame.to("cpu") for frame in video] for video in videos]
|
||||
return self.processor(text=[text],
|
||||
images=images,
|
||||
videos=videos,
|
||||
@ -346,7 +306,7 @@ class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor,
|
||||
image_grid_thw: torch.LongTensor,
|
||||
video_grid_thw: torch.LongTensor,
|
||||
attention_mask: torch.Tensor,
|
||||
second_per_grid_ts: torch.Tensor = None) -> dict[str, torch.Tensor]:
|
||||
second_per_grid_ts: torch.Tensor = None) -> Dict[str, torch.Tensor]:
|
||||
mrope_position_ids, mrope_position_deltas = Qwen2VLInputProcessorBase.get_rope_index(
|
||||
self.config, input_ids, image_grid_thw, video_grid_thw,
|
||||
attention_mask, second_per_grid_ts)
|
||||
@ -437,6 +397,10 @@ class Qwen2VisionModelBase(nn.Module):
|
||||
def load_weights(self, weights: Dict):
|
||||
visual_weights = filter_weights("visual", weights)
|
||||
converted_weights = dict()
|
||||
if isinstance(self.visual, (Qwen2VisionTransformerPretrainedModel,
|
||||
Qwen2_5_VisionTransformerPretrainedModel)):
|
||||
self.visual.load_state_dict(visual_weights, strict=True)
|
||||
return
|
||||
|
||||
qkv_pattern = re.compile(r'(.*?)attn\.qkv\.(.*)')
|
||||
for name in visual_weights:
|
||||
@ -559,13 +523,13 @@ class Qwen2_5_VLVisionAttention(Attention):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]],
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# NOTE: Need separate Attention forward() for Qwen2.5-VL for multiple reasons
|
||||
# 1. We don't have the route for handing over position_embeddings to the Attention forward()
|
||||
# 2. Could not override the apply_rope() as we don't have the position_ids in the Vision Attention's rotary embedding.
|
||||
# (TODO: yechank-nvidia) Make OOTO path more modular and reusable for Attention's Rotary Embedding.
|
||||
# (TODO: yechank-nvidia) Make OOTB path more modular and reusable for Attention's Rotary Embedding.
|
||||
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv, None, None
|
||||
@ -593,10 +557,26 @@ class Qwen2_5_VLVisionAttention(Attention):
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5_VLMLP(GatedMLP):
|
||||
|
||||
def __init__(self, model_config: ModelConfig[PretrainedConfig],
|
||||
layer_idx: int):
|
||||
config = model_config.pretrained_config.vision_config
|
||||
super().__init__(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
bias=True,
|
||||
activation=F.silu,
|
||||
dtype=model_config.pretrained_config.torch_dtype,
|
||||
config=model_config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
|
||||
class Qwen2_5_VLVisionBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, model_config: ModelConfig[PretrainedConfig],
|
||||
layer_idx: Optional[int]):
|
||||
layer_idx: int):
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config.vision_config
|
||||
self.norm1 = RMSNorm(hidden_size=config.hidden_size,
|
||||
@ -606,14 +586,15 @@ class Qwen2_5_VLVisionBlock(torch.nn.Module):
|
||||
eps=model_config.pretrained_config.rms_norm_eps,
|
||||
dtype=model_config.pretrained_config.torch_dtype)
|
||||
self.attn = Qwen2_5_VLVisionAttention(model_config, layer_idx)
|
||||
self.mlp = Qwen2_5_VLMLP(config, bias=True)
|
||||
self.mlp = Qwen2_5_VLMLP(model_config, layer_idx)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@ -621,6 +602,7 @@ class Qwen2_5_VLVisionBlock(torch.nn.Module):
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = residual + self.attn(
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
@ -650,21 +632,25 @@ class Qwen2_5_VLPatchMerger(torch.nn.Module):
|
||||
out_features=self.hidden_size,
|
||||
bias=True,
|
||||
dtype=model_config.pretrained_config.torch_dtype,
|
||||
mapping=model_config.mapping),
|
||||
mapping=model_config.mapping,
|
||||
tensor_parallel_mode=TensorParallelMode.COLUMN,
|
||||
allreduce_strategy=model_config.allreduce_strategy),
|
||||
torch.nn.GELU(),
|
||||
Linear(in_features=self.hidden_size,
|
||||
out_features=dim,
|
||||
bias=True,
|
||||
dtype=model_config.pretrained_config.torch_dtype,
|
||||
mapping=model_config.mapping),
|
||||
mapping=model_config.mapping,
|
||||
tensor_parallel_mode=TensorParallelMode.ROW,
|
||||
allreduce_strategy=model_config.allreduce_strategy),
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.ln_q(x)
|
||||
x = x.view(-1, self.hidden_size)
|
||||
x = self.mlp(x)
|
||||
return x
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.ln_q(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2_5_VisionModel(torch.nn.Module):
|
||||
@ -740,7 +726,7 @@ class Qwen2_5_VisionModel(torch.nn.Module):
|
||||
return rotary_pos_emb
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index: list = []
|
||||
window_index: List[torch.Tensor] = []
|
||||
seq_lens = []
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
|
||||
@ -783,13 +769,12 @@ class Qwen2_5_VisionModel(torch.nn.Module):
|
||||
return window_index, seq_lens
|
||||
|
||||
def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata):
|
||||
# NOTE: The single prompt is divided into multiple seq_lens, so pretending have many batch_sizes.
|
||||
batch_size = len(seq_lens)
|
||||
batch_size = 1 # NOTE: Qwen2/2.5-VL concats all the pixel_values into a single tensor, so batch_size is 1
|
||||
prompt_lens = seq_lens
|
||||
seq_lens = torch.tensor(seq_lens, dtype=torch.int, pin_memory=True)
|
||||
request_ids = list(range(1, batch_size + 1))
|
||||
|
||||
attn_metadata.num_contexts = batch_size
|
||||
attn_metadata.num_contexts = len(seq_lens)
|
||||
attn_metadata.request_ids = request_ids
|
||||
attn_metadata.prompt_lens = prompt_lens
|
||||
attn_metadata.seq_lens = seq_lens
|
||||
@ -798,7 +783,7 @@ class Qwen2_5_VisionModel(torch.nn.Module):
|
||||
return attn_metadata
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor,
|
||||
def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor,
|
||||
**kwargs) -> torch.Tensor:
|
||||
window_index, window_seq_lens = self.get_window_index(grid_thw)
|
||||
seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
|
||||
@ -814,7 +799,7 @@ class Qwen2_5_VisionModel(torch.nn.Module):
|
||||
window_seq_lens, self.window_attn_metadata)
|
||||
|
||||
# From this point, pure GPU operation
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
hidden_states = self.patch_embed(pixel_values)
|
||||
seq_len, _ = hidden_states.size()
|
||||
hidden_states = hidden_states.reshape(
|
||||
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||
@ -834,7 +819,6 @@ class Qwen2_5_VisionModel(torch.nn.Module):
|
||||
attn_metadata = full_attn_metadata
|
||||
else:
|
||||
attn_metadata = window_attn_metadata
|
||||
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
@ -857,30 +841,27 @@ class Qwen2VLModelBase(PreTrainedModel):
|
||||
self.original_arch = model_config.pretrained_config.architectures[0]
|
||||
|
||||
# NOTE: Setting disable_fuse_rope to True to do mrope fusion in the model engine by pre-computing rotary_cos_sin in the model engine
|
||||
disabble_fuse_rope = kwargs.get('disable_fuse_rope', False)
|
||||
model_config.pretrained_config.disable_fuse_rope = disabble_fuse_rope
|
||||
disable_fuse_rope = kwargs.get('disable_fuse_rope', False)
|
||||
model_config.pretrained_config.disable_fuse_rope = disable_fuse_rope
|
||||
model_config.pretrained_config.rope_scaling['type'] = 'mrope'
|
||||
config = model_config.pretrained_config
|
||||
|
||||
self._supports_sdpa = True
|
||||
super().__init__(config)
|
||||
|
||||
if not disabble_fuse_rope:
|
||||
self.init_mrope_embedding(model_config)
|
||||
|
||||
self.model_config = model_config
|
||||
self.config = model_config.pretrained_config
|
||||
|
||||
if model_config.attn_backend != 'TRTLLM':
|
||||
raise ValueError("Qwen2/2.5-VL only supports TRTLLM backend now")
|
||||
if not disabble_fuse_rope:
|
||||
if not disable_fuse_rope:
|
||||
self.init_mrope_embedding(model_config)
|
||||
|
||||
llm_model_config = copy.deepcopy(model_config)
|
||||
llm_model_config.pretrained_config.architectures = ["Qwen2ForCausalLM"]
|
||||
self.llm = AutoModelForCausalLM.from_config(llm_model_config)
|
||||
|
||||
if not DISAGG:
|
||||
if not _is_disagg():
|
||||
mm_encoder_config = copy.deepcopy(model_config)
|
||||
self.mm_encoder = Qwen2VisionModelBase(
|
||||
mm_encoder_config, kwargs.get('vision_model_class', None))
|
||||
@ -977,21 +958,28 @@ class Qwen2VLModelBase(PreTrainedModel):
|
||||
multimodal_params = kwargs.get("multimodal_params", [])
|
||||
mm_embeds = []
|
||||
mrope_config = {}
|
||||
if len(multimodal_params) > 0:
|
||||
if not DISAGG:
|
||||
# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts, so we need to separate the mm_multimodal_params from the text-only prompts.
|
||||
mm_multimodal_params = [
|
||||
multimodal_param for multimodal_param in multimodal_params
|
||||
if multimodal_param.multimodal_data.get("image", {}).get(
|
||||
"pixel_values") is not None or multimodal_param.multimodal_data.
|
||||
get("video", {}).get("pixel_values_videos") is not None
|
||||
]
|
||||
if len(mm_multimodal_params) > 0:
|
||||
if not _is_disagg():
|
||||
mm_embeds = get_multimodal_embeddings(
|
||||
encoder_forward_fn=self.mm_encoder.forward,
|
||||
multimodal_params=multimodal_params[:num_context_requests])
|
||||
multimodal_params=mm_multimodal_params)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Qwen2VLModel does not support disaggregated inference yet. Please unset "
|
||||
f"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
|
||||
)
|
||||
mm_embeds = find_input_mm_embeds(
|
||||
mm_embeds, multimodal_params[:num_context_requests])
|
||||
if not self.model_config.pretrained_config.disable_fuse_rope:
|
||||
mrope_config = self.prepare_mrope_config(
|
||||
multimodal_params, num_context_requests)
|
||||
mm_embeds = find_input_mm_embeds(mm_embeds, mm_multimodal_params)
|
||||
|
||||
if not self.model_config.pretrained_config.disable_fuse_rope:
|
||||
mrope_config = self.prepare_mrope_config(multimodal_params,
|
||||
num_context_requests)
|
||||
|
||||
input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
|
||||
input_ids, mm_embeds,
|
||||
@ -1038,9 +1026,8 @@ class Qwen2VLModel(Qwen2VLModelBase):
|
||||
]
|
||||
|
||||
def load_weights(self, weights, weight_mapper: BaseWeightMapper):
|
||||
if not DISAGG:
|
||||
vision_encoder_weights = process_weights(weights, "visual")
|
||||
self.mm_encoder.load_state_dict(vision_encoder_weights, strict=True)
|
||||
if not _is_disagg():
|
||||
self.mm_encoder.load_weights(weights)
|
||||
|
||||
self.llm.load_weights(weights, weight_mapper)
|
||||
|
||||
@ -1063,8 +1050,9 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):
|
||||
def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
|
||||
**kwargs):
|
||||
kwargs['vision_model_class'] = Qwen2_5_VisionModel
|
||||
kwargs[
|
||||
'disable_fuse_rope'] = False # TODO: Make this ModelConfig's argument
|
||||
kwargs['disable_fuse_rope'] = kwargs.get(
|
||||
'disable_fuse_rope',
|
||||
False) # TODO: Make this ModelConfig's argument
|
||||
super().__init__(model_config, *args, **kwargs)
|
||||
|
||||
@property
|
||||
@ -1078,7 +1066,7 @@ class Qwen2_5_VLModel(Qwen2VLModelBase):
|
||||
if isinstance(weight_mapper, Qwen2VLHfWeightMapper):
|
||||
weights = weight_mapper.preprocess_weights(weights)
|
||||
|
||||
if not DISAGG:
|
||||
if not _is_disagg():
|
||||
self.mm_encoder.load_weights(weights)
|
||||
|
||||
self.llm.load_weights(weights)
|
||||
|
||||
@@ -48,7 +48,11 @@ class Qwen3Attention(QKNormRoPEAttention):
            pos_embd_params = PositionalEmbeddingParams(
                type=PositionEmbeddingType.from_string(pos_type),
                rope=RopeParams.from_config(config),
            )
                mrope_section=config.rope_scaling.get("mrope_section", None),
                mrope_interleaved=config.rope_scaling.get(
                    "mrope_interleaved", False))
            if config.rope_scaling.get("mrope_interleaved", False):
                fuse_qk_norm_rope = False
        else:
            pos_embd_params = PositionalEmbeddingParams(
                type=PositionEmbeddingType.rope_gpt_neox,
@@ -64,6 +68,7 @@ class Qwen3Attention(QKNormRoPEAttention):
            pos_embd_params=pos_embd_params,
            fuse_qk_norm_rope=fuse_qk_norm_rope,
            layer_idx=layer_idx,
            rope_fusion=not getattr(config, 'disable_fuse_rope', False),
            dtype=config.torch_dtype,
            dense_bias=getattr(config, "attention_bias", None),
            config=model_config,
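# A minimal sketch (plain dict standing in for config.rope_scaling, illustrative values)
# of the branch added above: interleaved mRoPE falls back to the unfused qk-norm + RoPE
# path, so fuse_qk_norm_rope is forced off while mrope_section is forwarded to
# PositionalEmbeddingParams.
rope_scaling = {
    "type": "mrope",
    "mrope_section": [24, 20, 20],      # illustrative split of the rotary dims
    "mrope_interleaved": True,
}
fuse_qk_norm_rope = not rope_scaling.get("mrope_interleaved", False)
mrope_section = rope_scaling.get("mrope_section", None)
print(fuse_qk_norm_rope, mrope_section)  # False [24, 20, 20]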
@@ -18,7 +18,7 @@ from ..modules.fused_moe import (BaseMoeRoutingMethod, CutlassFusedMoE,
                                 RenormalizeNaiveMoeRoutingMethod,
                                 RoutingMethodType, TRTLLMGenFusedMoE,
                                 create_moe, get_moe_cls)
from ..modules.fused_moe.interface import MoE
from ..modules.fused_moe.interface import MoE, MoEWeightLoadingMode
from ..modules.linear import TensorParallelMode
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
@@ -114,6 +114,7 @@ class Qwen3MoE(nn.Module):
            moe_backend_cls=get_moe_cls(model_config),
        )

        self.weight_loading_mode = MoEWeightLoadingMode.FUSED_GATE_UP_PROJ if config.model_type == "qwen3_vl_moe_text" else MoEWeightLoadingMode.VANILLA
        self.experts = create_moe(
            num_experts=self.num_experts,
            routing_method=self.gate.routing_method,
@@ -124,6 +125,7 @@ class Qwen3MoE(nn.Module):
            reduce_results=False,
            model_config=model_config,
            layer_idx=layer_idx,
            weight_loading_mode=self.weight_loading_mode,
        )

    def forward(
@@ -221,6 +223,8 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        spec_metadata: Optional[SpecMetadata] = None,
        mrope_config: Optional[Dict[str, torch.Tensor]] = None,
        deepstack_embeds: Optional[List[torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        if residual is None:
@@ -236,6 +240,7 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
            attn_metadata=attn_metadata,
            all_reduce_params=AllReduceParams(
                enable_allreduce=not self.disable_attn_allreduce),
            mrope_config=mrope_config,
            **kwargs,
        )

@@ -269,6 +274,10 @@ class Qwen3MoEDecoderLayer(DecoderLayer):
            do_finalize=do_finalize,
        )

        if deepstack_embeds is not None and self.layer_idx in range(
                len(deepstack_embeds)):
            residual = residual + deepstack_embeds[self.layer_idx]

        if self.fusion_config.POST_MOE_FUSION:
            if do_finalize:
                hidden_states, residual = self.allreduce(
@@ -365,6 +374,8 @@ class Qwen3MoEModel(DecoderModel):
        position_ids: Optional[torch.IntTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        mrope_config: Optional[Dict[str, torch.Tensor]] = None,
        deepstack_embeds: Optional[List[torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        if (input_ids is None) ^ (inputs_embeds is not None):
@@ -379,11 +390,14 @@ class Qwen3MoEModel(DecoderModel):

        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(position_ids=position_ids,
                                                    hidden_states=hidden_states,
                                                    attn_metadata=attn_metadata,
                                                    residual=residual,
                                                    spec_metadata=spec_metadata)
            hidden_states, residual = decoder_layer(
                position_ids=position_ids,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                spec_metadata=spec_metadata,
                mrope_config=mrope_config,
                deepstack_embeds=deepstack_embeds)
        return hidden_states
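# A minimal sketch (toy tensors) of the deepstack injection added above: decoder layer i
# adds deepstack_embeds[i] to its residual stream, and layers beyond the number of
# deepstack levels are left untouched.
import torch

hidden = 4
deepstack_embeds = [torch.full((3, hidden), float(i + 1)) for i in range(3)]
residual = torch.zeros(3, hidden)

for layer_idx in range(5):  # pretend there are 5 decoder layers
    if layer_idx in range(len(deepstack_embeds)):
        residual = residual + deepstack_embeds[layer_idx]

print(residual[0])  # tensor([6., 6., 6., 6.]) -> layers 0, 1 and 2 each contributed once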
@@ -23,7 +23,6 @@ import torch.nn.functional as F
import triton
import triton.language as tl
from torch import nn
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation

@@ -320,9 +319,6 @@ class Qwen3NextConfig(PretrainedConfig):
        self.mlp_only_layers = mlp_only_layers


AutoConfig.register("qwen3_next", Qwen3NextConfig)


class Qwen3NextGate(nn.Module):

    def __init__(
tensorrt_llm/_torch/models/modeling_qwen3vl.py (new file, 992 lines)
@@ -0,0 +1,992 @@
import copy
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
||||
from transformers.activations import ACT2FN as HF_ACT2FN
|
||||
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
|
||||
Qwen3VLVisionPatchEmbed as HFQwen3VLVisionPatchEmbed,
|
||||
)
|
||||
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
|
||||
Qwen3VLVisionRotaryEmbedding as HFQwen3VLVisionRotaryEmbedding,
|
||||
)
|
||||
|
||||
from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg
|
||||
from tensorrt_llm.functional import PositionEmbeddingType
|
||||
|
||||
from ..._utils import nvtx_range, nvtx_range_debug
|
||||
from ...inputs import (
|
||||
BaseMultimodalDummyInputsBuilder,
|
||||
BaseMultimodalInputProcessor,
|
||||
ExtraProcessedInputs,
|
||||
TextPrompt,
|
||||
)
|
||||
from ...inputs.multimodal import MultimodalParams
|
||||
from ...logger import logger
|
||||
from ...sampling_params import SamplingParams
|
||||
from ..attention_backend import AttentionMetadata
|
||||
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
|
||||
from ..attention_backend.utils import get_attention_backend
|
||||
from ..modules.layer_norm import LayerNorm
|
||||
from ..modules.linear import Linear, TensorParallelMode
|
||||
from ..modules.mlp import MLP
|
||||
from ..modules.rotary_embedding import MRotaryEmbedding
|
||||
from .modeling_auto import AutoModelForCausalLM
|
||||
from .modeling_multimodal_utils import (
|
||||
find_input_mm_embeds,
|
||||
fuse_input_embeds,
|
||||
get_multimodal_embeddings,
|
||||
)
|
||||
from .modeling_qwen2vl import Qwen2_5_VLVisionAttention
|
||||
from .modeling_utils import ModelConfig, QuantConfig, _load_weights_impl, filter_weights
|
||||
|
||||
|
||||
class Qwen3VLInputProcessorBase(BaseMultimodalInputProcessor, BaseMultimodalDummyInputsBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AutoTokenizer,
|
||||
trust_remote_code: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model_path=model_path,
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
self._dtype = self.config.text_config.dtype
|
||||
self._tokenizer = (
|
||||
tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(model_path)
|
||||
)
|
||||
self._model_path = model_path
|
||||
self._processor = AutoProcessor.from_pretrained(
|
||||
model_path, use_fast=True, trust_remote_code=trust_remote_code
|
||||
)
|
||||
self.tllm_multimodal_token_id = self.get_vocab_size() + 1
|
||||
# temporal patch size for video frames
|
||||
self.temporal_patch_size = getattr(self.config.vision_config, "temporal_patch_size", 1)
|
||||
|
||||
@property
|
||||
def config(self) -> PretrainedConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def tokenizer(self) -> AutoTokenizer:
|
||||
return self._tokenizer
|
||||
|
||||
@property
|
||||
def model_path(self) -> str:
|
||||
return self._model_path
|
||||
|
||||
@property
|
||||
def processor(self) -> AutoProcessor:
|
||||
return self._processor
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self._dtype
|
||||
|
||||
def get_vocab_size(self) -> int:
|
||||
"""Return the vocab size of the model."""
|
||||
return self.config.text_config.vocab_size
|
||||
|
||||
@classmethod
|
||||
def get_rope_index(
|
||||
cls,
|
||||
model_config: PretrainedConfig,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""
|
||||
|
||||
# Since we use timestamps to separate videos, like <t1> <vision_start> <frame1> <vision_end> <t2>
|
||||
# <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split
|
||||
if video_grid_thw is not None:
|
||||
video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
|
||||
video_grid_thw[:, 0] = 1
|
||||
|
||||
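# A minimal sketch (toy grid) of the split above: torch.repeat_interleave turns one video
# entry of shape (t, h, w) into t single-frame entries, matching the
# <timestamp> <vision_start> frame <vision_end> layout described in the docstring.
import torch

video_grid_thw = torch.tensor([[3, 4, 6]])  # one video with 3 frames (illustrative)
video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
video_grid_thw[:, 0] = 1
print(video_grid_thw)  # tensor([[1, 4, 6], [1, 4, 6], [1, 4, 6]])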
spatial_merge_size = model_config.vision_config.spatial_merge_size
|
||||
image_token_id = model_config.image_token_id
|
||||
video_token_id = model_config.video_token_id
|
||||
vision_start_token_id = model_config.vision_start_token_id
|
||||
mrope_position_deltas = []
|
||||
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
|
||||
total_input_ids = input_ids
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(total_input_ids)
|
||||
position_ids = torch.ones(
|
||||
3,
|
||||
input_ids.shape[0],
|
||||
input_ids.shape[1],
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
image_index, video_index = 0, 0
|
||||
attention_mask = attention_mask.to(total_input_ids.device)
|
||||
for i, input_ids in enumerate(total_input_ids):
|
||||
input_ids = input_ids[attention_mask[i] == 1]
|
||||
image_nums, video_nums = 0, 0
|
||||
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
|
||||
vision_tokens = input_ids[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (vision_tokens == video_token_id).sum()
|
||||
input_tokens = input_ids.tolist()
|
||||
llm_pos_ids_list: list = []
|
||||
st = 0
|
||||
remain_images, remain_videos = image_nums, video_nums
|
||||
for _ in range(image_nums + video_nums):
|
||||
if image_token_id in input_tokens and remain_images > 0:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if video_token_id in input_tokens and remain_videos > 0:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t.item(),
|
||||
h.item() // spatial_merge_size,
|
||||
w.item() // spatial_merge_size,
|
||||
)
|
||||
text_len = ed - st
|
||||
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
# t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode
|
||||
# the temporal information for videos)
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(llm_grid_t, -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(llm_grid_t, llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||
)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
|
||||
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
|
||||
mrope_position_deltas = torch.tensor(
|
||||
mrope_position_deltas, device=input_ids.device
|
||||
).unsqueeze(1)
|
||||
return position_ids, mrope_position_deltas
|
||||
else:
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
|
||||
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
||||
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
||||
else:
|
||||
position_ids = (
|
||||
torch.arange(input_ids.shape[1], device=input_ids.device)
|
||||
.view(1, 1, -1)
|
||||
.expand(3, input_ids.shape[0], -1)
|
||||
)
|
||||
mrope_position_deltas = torch.zeros(
|
||||
[input_ids.shape[0], 1],
|
||||
device=input_ids.device,
|
||||
dtype=input_ids.dtype,
|
||||
)
|
||||
|
||||
return position_ids, mrope_position_deltas
|
||||
|
||||
def _preprocess(
|
||||
self, text: Dict[str, Any], mm_data: Dict[str, Any], mm_processor_kwargs: Dict[str, Any]
|
||||
):
|
||||
images = mm_data.get("image")
|
||||
video_datas = mm_data.get("video")
|
||||
if video_datas is not None:
|
||||
videos = [video_data.frames for video_data in video_datas]
|
||||
else:
|
||||
videos = None
|
||||
do_rescale = True
|
||||
if images and isinstance(images[0], torch.Tensor):
|
||||
do_rescale = False
|
||||
if videos and isinstance(videos[0][0], torch.Tensor):
|
||||
do_rescale = False
|
||||
return self.processor(
|
||||
text=[text],
|
||||
images=images,
|
||||
videos=videos,
|
||||
padding=True,
|
||||
do_rescale=do_rescale,
|
||||
return_tensors="pt",
|
||||
**mm_processor_kwargs,
|
||||
)
|
||||
|
||||
def _postprocess(self, input_ids: torch.IntTensor) -> torch.IntTensor:
|
||||
masks = (input_ids == self.config.image_token_id) | (
|
||||
input_ids == self.config.video_token_id
|
||||
)
|
||||
input_ids[masks] = self.tllm_multimodal_token_id
|
||||
return input_ids
|
||||
|
||||
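# A minimal sketch (illustrative token ids, not the real config values) of _postprocess
# above: image and video placeholder ids are rewritten to a single out-of-vocabulary id
# (vocab_size + 1 in the processor) so the LLM side can separate text tokens from
# multimodal tokens with a plain comparison.
import torch

image_token_id, video_token_id = 151655, 151656   # illustrative ids
tllm_multimodal_token_id = 152064                  # illustrative vocab_size + 1
input_ids = torch.tensor([1, 5, image_token_id, image_token_id, 7])

masks = (input_ids == image_token_id) | (input_ids == video_token_id)
input_ids[masks] = tllm_multimodal_token_id
print(input_ids)  # tensor([     1,      5, 152064, 152064,      7])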
def get_mrope_config(
|
||||
self,
|
||||
input_ids: torch.IntTensor,
|
||||
image_grid_thw: torch.LongTensor,
|
||||
video_grid_thw: torch.LongTensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
mrope_position_ids, mrope_position_deltas = Qwen3VLInputProcessorBase.get_rope_index(
|
||||
self.config, input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||
)
|
||||
|
||||
mrope_config = {}
|
||||
mrope_config["mrope_position_ids"] = mrope_position_ids.to("cpu").clone()
|
||||
mrope_config["mrope_position_deltas"] = (
|
||||
mrope_position_deltas.to("cpu").to(torch.int32).clone()
|
||||
)
|
||||
|
||||
return mrope_config
|
||||
|
||||
@nvtx_range("Qwen3VLInputProcessorBase forward()")
|
||||
@torch.inference_mode()
|
||||
def __call__(
|
||||
self,
|
||||
inputs: TextPrompt,
|
||||
sampling_params: SamplingParams,
|
||||
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
|
||||
text_prompt, mm_data, mm_processor_kwargs = (
|
||||
inputs.get("prompt"),
|
||||
inputs.get("multi_modal_data", {}),
|
||||
inputs.get("mm_processor_kwargs", {}),
|
||||
)
|
||||
with nvtx_range_debug("transformers input preprocess"):
|
||||
processed_inputs = self._preprocess(text_prompt, mm_data, mm_processor_kwargs)
|
||||
|
||||
multimodal_data = {}
|
||||
pixel_values = processed_inputs.get("pixel_values", None)
|
||||
if pixel_values is not None:
|
||||
multimodal_data["image"] = {
|
||||
"pixel_values": pixel_values.to(self.dtype),
|
||||
"image_grid_thw": processed_inputs.get("image_grid_thw"),
|
||||
}
|
||||
|
||||
pixel_values_videos = processed_inputs.get("pixel_values_videos", None)
|
||||
if pixel_values_videos is not None:
|
||||
multimodal_data["video"] = {
|
||||
"pixel_values_videos": pixel_values_videos.to(self.dtype),
|
||||
"video_grid_thw": processed_inputs.get("video_grid_thw"),
|
||||
}
|
||||
|
||||
# NOTE: Even on the text-only prompts, we still need 'mrope_position_ids'.
|
||||
mrope_config = self.get_mrope_config(
|
||||
processed_inputs["input_ids"],
|
||||
processed_inputs.get("image_grid_thw", None),
|
||||
processed_inputs.get("video_grid_thw", None),
|
||||
processed_inputs.get("attention_mask", None),
|
||||
)
|
||||
multimodal_data["mrope_config"] = mrope_config
|
||||
|
||||
fused_input_ids = processed_inputs["input_ids"][0]
|
||||
if mm_data:
|
||||
fused_input_ids = self._postprocess(fused_input_ids)
|
||||
|
||||
return fused_input_ids.to(torch.int32).tolist(), {
|
||||
"multimodal_data": multimodal_data,
|
||||
}
|
||||
|
||||
|
||||
class Qwen3VLVisionAttention(Qwen2_5_VLVisionAttention):
|
||||
def __init__(self, model_config, layer_idx):
|
||||
model_config.pretrained_config.max_position_embeddings = (
|
||||
model_config.pretrained_config.text_config.max_position_embeddings
|
||||
)
|
||||
model_config.pretrained_config.vision_config.torch_dtype = (
|
||||
model_config.pretrained_config.text_config.dtype
|
||||
)
|
||||
super().__init__(model_config, layer_idx)
|
||||
|
||||
|
||||
class Qwen3VLVisionMLP(MLP):
|
||||
def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
|
||||
config = model_config.pretrained_config.vision_config
|
||||
super().__init__(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
bias=True,
|
||||
activation=HF_ACT2FN[config.hidden_act],
|
||||
dtype=model_config.pretrained_config.text_config.dtype,
|
||||
config=model_config,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
|
||||
class Qwen3VLVisionBlock(torch.nn.Module):
|
||||
def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config.vision_config
|
||||
|
||||
self.norm1 = LayerNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=model_config.pretrained_config.text_config.rms_norm_eps,
|
||||
dtype=model_config.pretrained_config.text_config.dtype,
|
||||
)
|
||||
self.norm2 = LayerNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=model_config.pretrained_config.text_config.rms_norm_eps,
|
||||
dtype=model_config.pretrained_config.text_config.dtype,
|
||||
)
|
||||
self.attn = Qwen3VLVisionAttention(model_config, layer_idx)
|
||||
self.mlp = Qwen3VLVisionMLP(model_config, layer_idx)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = residual + self.attn(
|
||||
hidden_states=hidden_states,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = residual + self.mlp(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen3VLVisionPatchMerger(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_config: ModelConfig[PretrainedConfig], use_postshuffle_norm: bool = False
|
||||
) -> None:
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config.vision_config
|
||||
self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
|
||||
self.use_postshuffle_norm = use_postshuffle_norm
|
||||
self.norm = LayerNorm(
|
||||
hidden_size=self.hidden_size if use_postshuffle_norm else config.hidden_size,
|
||||
eps=model_config.pretrained_config.text_config.rms_norm_eps,
|
||||
dtype=model_config.pretrained_config.text_config.dtype,
|
||||
)
|
||||
self.linear_fc1 = Linear(
|
||||
in_features=self.hidden_size,
|
||||
out_features=self.hidden_size,
|
||||
bias=True,
|
||||
mapping=model_config.mapping,
|
||||
tensor_parallel_mode=TensorParallelMode.COLUMN,
|
||||
allreduce_strategy=model_config.allreduce_strategy,
|
||||
)
|
||||
self.act_fn = nn.GELU()
|
||||
self.linear_fc2 = Linear(
|
||||
in_features=self.hidden_size,
|
||||
out_features=config.out_hidden_size,
|
||||
bias=True,
|
||||
mapping=model_config.mapping,
|
||||
tensor_parallel_mode=TensorParallelMode.ROW,
|
||||
allreduce_strategy=model_config.allreduce_strategy,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_postshuffle_norm:
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
|
||||
hidden_states = self.norm(hidden_states).view(-1, self.hidden_size)
|
||||
hidden_states = self.linear_fc1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
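# A minimal sketch (toy sizes) of the reshape inside the merger above: features of
# spatial_merge_size**2 neighbouring patches are flattened into one vector before
# linear_fc1 / linear_fc2 project them to out_hidden_size.
import torch

spatial_merge_size, vision_hidden = 2, 3
patch_features = torch.arange(4 * vision_hidden, dtype=torch.float32).reshape(4, vision_hidden)

merged = patch_features.view(-1, vision_hidden * spatial_merge_size**2)
print(merged.shape)  # torch.Size([1, 12]) -> one merged token from 4 patches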
class Qwen3VisionModel(torch.nn.Module):
|
||||
def __init__(self, model_config: ModelConfig[PretrainedConfig]):
|
||||
super().__init__()
|
||||
self.model_config = model_config
|
||||
self.config = self.model_config.pretrained_config.vision_config
|
||||
|
||||
self.spatial_merge_size = self.config.spatial_merge_size
|
||||
self.patch_size = self.config.patch_size
|
||||
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
|
||||
|
||||
self.patch_embed = HFQwen3VLVisionPatchEmbed(
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
self.pos_embed = nn.Embedding(self.config.num_position_embeddings, self.config.hidden_size)
|
||||
self.num_grid_per_side = int(self.config.num_position_embeddings**0.5)
|
||||
|
||||
head_dim = self.config.hidden_size // self.config.num_heads
|
||||
self.rotary_pos_emb = HFQwen3VLVisionRotaryEmbedding(head_dim // 2)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen3VLVisionBlock(model_config, layer_idx=layer_idx)
|
||||
for layer_idx in range(self.config.depth)
|
||||
]
|
||||
)
|
||||
self.merger = Qwen3VLVisionPatchMerger(
|
||||
model_config=model_config,
|
||||
use_postshuffle_norm=False,
|
||||
)
|
||||
self.deepstack_visual_indexes = self.config.deepstack_visual_indexes
|
||||
self.deepstack_merger_list = nn.ModuleList(
|
||||
[
|
||||
Qwen3VLVisionPatchMerger(
|
||||
model_config=model_config,
|
||||
use_postshuffle_norm=True,
|
||||
)
|
||||
for _ in range(len(self.deepstack_visual_indexes))
|
||||
]
|
||||
)
|
||||
self.metadata_cls = get_attention_backend(self.model_config.attn_backend).Metadata
|
||||
|
||||
self.attn_metadata = self.metadata_cls(
|
||||
max_num_requests=8192, # TODO: Make this dynamic
|
||||
max_num_tokens=8192, # TODO: Make this dynamic
|
||||
kv_cache_manager=None,
|
||||
)
|
||||
|
||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
merge_size = self.spatial_merge_size
|
||||
|
||||
max_hw = int(grid_thw[:, 1:].max().item())
|
||||
freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
|
||||
device = freq_table.device
|
||||
|
||||
total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
|
||||
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
||||
|
||||
offset = 0
|
||||
for num_frames, height, width in grid_thw:
|
||||
merged_h, merged_w = height // merge_size, width // merge_size
|
||||
|
||||
block_rows = torch.arange(merged_h, device=device) # block row indices
|
||||
block_cols = torch.arange(merged_w, device=device) # block col indices
|
||||
intra_row = torch.arange(merge_size, device=device) # intra-block row offsets
|
||||
intra_col = torch.arange(merge_size, device=device) # intra-block col offsets
|
||||
|
||||
# Compute full-resolution positions
|
||||
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
|
||||
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
|
||||
|
||||
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
|
||||
coords = torch.stack((row_idx, col_idx), dim=-1)
|
||||
|
||||
if num_frames > 1:
|
||||
coords = coords.repeat(num_frames, 1)
|
||||
|
||||
num_tokens = coords.shape[0]
|
||||
pos_ids[offset : offset + num_tokens] = coords
|
||||
offset += num_tokens
|
||||
|
||||
embeddings = freq_table[pos_ids] # lookup rotary embeddings
|
||||
embeddings = embeddings.flatten(1)
|
||||
return embeddings
|
||||
|
||||
    def fast_pos_embed_interpolate(self, grid_thw):
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side

            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]

            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
        weight_tensor = torch.tensor(
            weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
        )
        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

        patch_pos_embeds_permute = []
        merge_size = self.config.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            pos_embed = pos_embed.repeat(t, 1)
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds

    def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata):
        # NOTE: A single prompt is divided into multiple seq_lens, so we pretend to have that many batch entries.
        batch_size = len(seq_lens)
        prompt_lens = seq_lens
        seq_lens = torch.tensor(seq_lens, dtype=torch.int, pin_memory=True)
        request_ids = list(range(1, batch_size + 1))

        attn_metadata.num_contexts = batch_size
        attn_metadata.request_ids = request_ids
        attn_metadata.prompt_lens = prompt_lens
        attn_metadata.seq_lens = seq_lens
        attn_metadata.max_seq_len = seq_lens.max().item()
        attn_metadata.prepare()
        return attn_metadata

    @torch.inference_mode()
    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).tolist()
        attn_metadata = self.prepare_attn_metadata(seq_lens, self.attn_metadata)

        # Getting positional embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # From this point, pure GPU operation
        hidden_states = self.patch_embed(hidden_states)
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len, -1)

        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        deepstack_feature_lists = []
        for layer_num, block in enumerate(self.blocks):
            hidden_states = block(
                hidden_states,
                attn_metadata=attn_metadata,
                position_embeddings=position_embeddings,
            )
            if layer_num in self.deepstack_visual_indexes:
                deepstack_feature = self.deepstack_merger_list[
                    self.deepstack_visual_indexes.index(layer_num)
                ](hidden_states)
                deepstack_feature_lists.append(deepstack_feature)
        hidden_states = self.merger(hidden_states)

        return hidden_states, deepstack_feature_lists


class Qwen3VisionModelBase(nn.Module):
    def __init__(
        self,
        model_config: ModelConfig[PretrainedConfig],
        model_class: Union[type[PreTrainedModel], type[torch.nn.Module]],
    ):
        super().__init__()
        self.model_config = model_config
        self.model_dtype = self.model_config.pretrained_config.text_config.dtype

        # NOTE: Re-setting QuantConfig to exclude vision encoder weights from quantization load.
        self.model_config.quant_config = QuantConfig(
            kv_cache_quant_algo=self.model_config.quant_config.kv_cache_quant_algo
        )

        self.visual = model_class(self.model_config).to(self.model_dtype)

        self.post_config()

    def post_config(self):
        self.config = self.model_config.pretrained_config.vision_config

    def load_weights(self, weights: Dict[str, torch.Tensor]):
        visual_weights = filter_weights("model.visual", weights)
        converted_weights = {}

        qkv_pattern = re.compile(r"(.*?)attn\.qkv\.(.*)")
        for name in visual_weights:
            # Handle the weights and biases of the vision transformer's qkv projection.
            match = qkv_pattern.match(name)
            if match:
                prefix, suffix = match.groups()
                q_name = f"{prefix}attn.q_proj.{suffix}"
                k_name = f"{prefix}attn.k_proj.{suffix}"
                v_name = f"{prefix}attn.v_proj.{suffix}"
                dim_shape = visual_weights[name].shape[0] // 3
                converted_weights[q_name] = visual_weights[name][:dim_shape]
                converted_weights[k_name] = visual_weights[name][dim_shape : 2 * dim_shape]
                converted_weights[v_name] = visual_weights[name][2 * dim_shape :]
            else:
                converted_weights[name] = visual_weights[name]
        pattern_mapping = {
            r"(.*?)attn.proj.(.*)": r"\1attn.o_proj.\2",
            r"(.*?)mlp.linear_fc1.(.*)": r"\1mlp.up_proj.\2",
            r"(.*?)mlp.linear_fc2.(.*)": r"\1mlp.down_proj.\2",
        }
        self.visual.config.num_attention_heads = self.visual.config.num_heads
        _load_weights_impl(self.visual, converted_weights, params_map=pattern_mapping)

    def _parse_and_batch_multimodal_data(
        self, multimodal_params: List[MultimodalParams]
    ) -> Tuple[Dict[str, Any], Dict[str, List[Any]]]:
        pixel_values_list = []
        pixel_values_videos_list = []
        image_grid_thw_list = []
        video_grid_thw_list = []

        for multimodal_param in multimodal_params:
            multimodal_data = multimodal_param.multimodal_data
            # Process images if present
            if multimodal_data.get("image") is not None:
                pixel_values_list.append(multimodal_data["image"]["pixel_values"])
                image_grid_thw_list.append(multimodal_data["image"]["image_grid_thw"])

            # Process videos if present
            if multimodal_data.get("video") is not None:
                pixel_values_videos_list.append(multimodal_data["video"]["pixel_values_videos"])
                video_grid_thw_list.append(multimodal_data["video"]["video_grid_thw"])

        # Concatenate tensors
        mm_content_dict = {}
        if pixel_values_list:
            mm_content_dict["pixel_values"] = (
                torch.cat(pixel_values_list, dim=0)
                if len(pixel_values_list) > 1
                else pixel_values_list[0]
            )
        if pixel_values_videos_list:
            mm_content_dict["pixel_values_videos"] = (
                torch.cat(pixel_values_videos_list, dim=0)
                if len(pixel_values_videos_list) > 1
                else pixel_values_videos_list[0]
            )

        # Prepare extra data
        mm_extra_data = {}
        if image_grid_thw_list:
            mm_extra_data["image_grid_thw"] = (
                torch.cat(image_grid_thw_list, dim=0)
                if len(image_grid_thw_list) > 1
                else image_grid_thw_list[0]
            )
        if video_grid_thw_list:
            mm_extra_data["video_grid_thw"] = (
                torch.cat(video_grid_thw_list, dim=0)
                if len(video_grid_thw_list) > 1
                else video_grid_thw_list[0]
            )

        return mm_content_dict, mm_extra_data

    @torch.inference_mode()
    def forward(self, multimodal_params: List[MultimodalParams]) -> List[torch.Tensor]:
        mm_content_data, mm_extra_data = self._parse_and_batch_multimodal_data(multimodal_params)
        pixel_values = mm_content_data.get("pixel_values", None)
        pixel_values_videos = mm_content_data.get("pixel_values_videos", None)

        if pixel_values is not None and pixel_values_videos is not None:
            raise ValueError("Currently only support single modality per request")

        image_grid_thw = mm_extra_data.get("image_grid_thw", None)
        video_grid_thw = mm_extra_data.get("video_grid_thw", None)

        embeds = []
        if pixel_values is not None:
            pixel_values = pixel_values.to(self.model_dtype)
            image_embeds, deepstack_image_embeds = self.visual(
                pixel_values, grid_thw=image_grid_thw
            )
            # NOTE: We concatenate deepstack_embeds to mm_embeds
            # The shape will be [seq_len, hidden_dim * (num_deepstack_layers + 1)]
            mixed_image_embeds = torch.cat([image_embeds] + deepstack_image_embeds, dim=1)
            embeds.append(mixed_image_embeds)

        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.to(self.model_dtype)
            video_embeds, deepstack_video_embeds = self.visual(
                pixel_values_videos, grid_thw=video_grid_thw
            )
            # NOTE: We concatenate deepstack_embeds to mm_embeds
            # The shape will be [seq_len, hidden_dim * (num_deepstack_layers + 1)]
            mixed_video_embeds = torch.cat([video_embeds] + deepstack_video_embeds, dim=1)
            embeds.append(mixed_video_embeds)
        return embeds


class Qwen3VLModelBase(PreTrainedModel):
    def __init__(
        self,
        model_config: ModelConfig[PretrainedConfig],
        *args,
        **kwargs,
    ) -> None:
        self.original_arch = model_config.pretrained_config.architectures[0]

        disable_fuse_rope = kwargs.get("disable_fuse_rope", False)
        model_config.pretrained_config.text_config.disable_fuse_rope = disable_fuse_rope
        model_config.pretrained_config.text_config.rope_scaling["type"] = "mrope"
        config = model_config.pretrained_config

        self._supports_sdpa = True
        self._supports_flash_attn = True
        super().__init__(config)
        if not disable_fuse_rope:
            self.init_mrope_embedding(model_config)

        self.model_config = model_config

        llm_model_config = copy.deepcopy(model_config)
        llm_model_config.pretrained_config = config.text_config
        llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"]
        self.llm = AutoModelForCausalLM.from_config(llm_model_config)

        if not _is_disagg():
            self.mm_encoder = Qwen3VisionModelBase(
                model_config, kwargs.get("vision_model_class", None)
            ).eval()

        self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes")
        self.deepstack_num_level = (
            len(config.vision_config.deepstack_visual_indexes) if self.use_deepstack else 0
        )

        self.post_config()

    def post_config(self):
        # use llm.config as config for pytorch model engine
        self.model_config.pretrained_config = self.llm.config
        self.config = self.model_config.pretrained_config

    def infer_max_seq_len(self) -> int:
        return self.llm.infer_max_seq_len()

    def init_mrope_embedding(self, model_config: ModelConfig[PretrainedConfig]):
        config = model_config.pretrained_config.text_config
        pos_embd_params = PositionalEmbeddingParams(
            type=PositionEmbeddingType.from_string(config.rope_scaling["type"]),
            rope=RopeParams.from_config(config),
            mrope_section=config.rope_scaling.get("mrope_section", None),
            mrope_interleaved=config.rope_scaling.get("mrope_interleaved", False),
        )
        self.rotary_emb = MRotaryEmbedding(
            pos_embd_params.rope,
            head_dim=config.hidden_size // config.num_attention_heads,
            is_neox=pos_embd_params.is_neox,
            mrope_section=pos_embd_params.mrope_section,
            mrope_interleaved=pos_embd_params.mrope_interleaved,
        ).to("cuda")
        self.mrope_position_ids_padding_cuda = torch.zeros(
            (
                3,
                1,
                config.max_position_embeddings,
            ),
            dtype=torch.int32,
            device="cuda",
        )

    @nvtx_range("Qwen3-VL prepare_mrope_config")
    def prepare_mrope_config(
        self, multimodal_params: List[MultimodalParams], num_context_requests: int
    ):
        mrope_config = {}
        mrope_rotary_cos_sin = []
        mrope_position_deltas = []
        for multimodal_param in multimodal_params[:num_context_requests]:
            if multimodal_param.multimodal_data.get("mrope_config") is not None:
                with nvtx_range("Qwen3-VL get_cos_sin"):
                    if (
                        multimodal_param.multimodal_data["mrope_config"].get("mrope_position_ids")
                        is not None
                    ):
                        mrope_position_ids = multimodal_param.multimodal_data["mrope_config"][
                            "mrope_position_ids"
                        ]

                        self.mrope_position_ids_padding_cuda[
                            :, :, : mrope_position_ids.shape[-1]
                        ] = mrope_position_ids
                        self.mrope_position_ids_padding_cuda[
                            :, :, mrope_position_ids.shape[-1] :
                        ] = 0
                        cos, sin = self.rotary_emb.get_cos_sin(self.mrope_position_ids_padding_cuda)
                        concat_cos_sin = torch.stack((cos, sin), dim=-1)
                        concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0], -1)
                        mrope_rotary_cos_sin.append(concat_cos_sin)

        for multimodal_param in multimodal_params[num_context_requests:]:
            if multimodal_param.multimodal_data.get("mrope_config") is not None:
                if (
                    multimodal_param.multimodal_data["mrope_config"].get("mrope_position_deltas")
                    is not None
                ):
                    mrope_position_deltas.append(
                        multimodal_param.multimodal_data["mrope_config"]["mrope_position_deltas"]
                    )

        with nvtx_range("Qwen3-VL concat mrope_rotary_cos_sin"):
            if mrope_rotary_cos_sin:
                mrope_config["mrope_rotary_cos_sin"] = torch.cat(mrope_rotary_cos_sin, dim=0)
        with nvtx_range("Qwen3-VL concat mrope_position_deltas"):
            if mrope_position_deltas:
                mrope_config["mrope_position_deltas"] = torch.cat(mrope_position_deltas, dim=0)

        return mrope_config

    def split_mm_embeds(self, mm_embed, deepstack_num_level):
        num_elements = mm_embed.shape[1] // (deepstack_num_level + 1)
        mm_embed_chunks = torch.split(mm_embed, [num_elements] * (deepstack_num_level + 1), dim=1)
        return mm_embed_chunks[0], list(mm_embed_chunks[1:])

    @torch.inference_mode()
    def forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.IntTensor] = None,
        position_ids: Optional[torch.IntTensor] = None,
        input_embeds: Optional[torch.Tensor] = None,
        return_context_logits: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """
        VLM forward logic with inflight batching support.
        """
        num_context_requests, num_generation_requests = (
            attn_metadata.num_contexts,
            attn_metadata.num_generations,
        )
        logger.debug(
            f"num_context_requests: {num_context_requests}, num_generation_requests: {num_generation_requests}"
        )

        multimodal_params = kwargs.get("multimodal_params", [])
        mm_embeds = []
        mrope_config = {}
        deepstack_embeds = []

        # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts,
        # so we need to separate the mm_multimodal_params from the text-only prompts.
        mm_multimodal_params = [
            multimodal_param
            for multimodal_param in multimodal_params
            if multimodal_param.multimodal_data.get("image", {}).get("pixel_values") is not None
            or multimodal_param.multimodal_data.get("video", {}).get("pixel_values_videos")
            is not None
        ]
        if len(mm_multimodal_params) > 0:
            if not _is_disagg():
                mm_embeds = get_multimodal_embeddings(
                    encoder_forward_fn=self.mm_encoder.forward,
                    multimodal_params=mm_multimodal_params,
                )
            else:
                raise NotImplementedError(
                    "Qwen3VLModel does not support disaggregated inference yet. Please unset "
                    "the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
                )
            mm_embeds = find_input_mm_embeds(mm_embeds, mm_multimodal_params)

            if self.use_deepstack:
                for i, mm_embed in enumerate(mm_embeds):
                    mm_embed, deepstack_embed = self.split_mm_embeds(
                        mm_embed, self.deepstack_num_level
                    )
                    mm_embeds[i] = mm_embed
                    deepstack_embeds.extend(deepstack_embed)

        if not self.model_config.pretrained_config.disable_fuse_rope:
            mrope_config = self.prepare_mrope_config(multimodal_params, num_context_requests)

        result = fuse_input_embeds(
            self.llm.model.embed_tokens,
            input_ids,
            mm_embeds,
            extra_embeds=deepstack_embeds,
            **kwargs,
        )
        if len(deepstack_embeds) > 0:
            input_ids, input_embeds, deepstack_embeds = result
        else:
            input_ids, input_embeds = result

        output_prob = self.llm.forward(
            attn_metadata=attn_metadata,
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=input_embeds,
            return_context_logits=return_context_logits,
            deepstack_embeds=deepstack_embeds,
            mrope_config=mrope_config,
        )
        logger.debug(f"output shape: {output_prob.shape}")
        return output_prob

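Editorial note on the deepstack packing used above (not part of the diff): the vision encoder returns the merged visual embedding plus one extra feature map per deepstack layer, and Qwen3VisionModelBase.forward concatenates them along the hidden dimension into a single tensor of shape [seq_len, hidden_dim * (num_deepstack_layers + 1)]; Qwen3VLModelBase.split_mm_embeds later slices it back apart. A minimal standalone sketch of that round trip, with made-up tensor sizes:

import torch

seq_len, hidden_dim, num_deepstack_layers = 4, 8, 3
main = torch.randn(seq_len, hidden_dim)
deepstack = [torch.randn(seq_len, hidden_dim) for _ in range(num_deepstack_layers)]

# Pack: same layout as mixed_image_embeds / mixed_video_embeds above.
mixed = torch.cat([main] + deepstack, dim=1)
assert mixed.shape == (seq_len, hidden_dim * (num_deepstack_layers + 1))

# Unpack: same arithmetic as Qwen3VLModelBase.split_mm_embeds.
chunk = mixed.shape[1] // (num_deepstack_layers + 1)
chunks = torch.split(mixed, [chunk] * (num_deepstack_layers + 1), dim=1)
main_out, deepstack_out = chunks[0], list(chunks[1:])
assert torch.equal(main_out, main)
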
64
tensorrt_llm/_torch/models/modeling_qwen3vl_moe.py
Normal file
@ -0,0 +1,64 @@
from typing import Dict, List

import torch
from transformers import PretrainedConfig

from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg

from ...inputs import (
    MultimodalPlaceholderMetadata,
    MultimodalPlaceholderPlacement,
    register_input_processor,
)
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .checkpoints.hf.qwen3vl_moe_weight_mapper import Qwen3VLMoeHfWeightMapper
from .modeling_qwen3vl import (
    Qwen3VisionModel,
    Qwen3VisionModelBase,
    Qwen3VLInputProcessorBase,
    Qwen3VLModelBase,
)
from .modeling_utils import ModelConfig, register_auto_model, register_vision_encoder


@register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
@register_auto_model("Qwen3VLMoeForConditionalGeneration")
@register_input_processor(
    Qwen3VLInputProcessorBase,
    model_type="qwen3_vl_moe",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={
            "image": "<|vision_start|><|image_pad|><|vision_end|>",
            "video": "<|vision_start|><|video_pad|><|vision_end|>",
        },
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ),
)
class Qwen3MoeVLModel(Qwen3VLModelBase):
    def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, **kwargs):
        # NOTE: HF implementation.
        kwargs["vision_model_class"] = Qwen3VisionModel
        kwargs["disable_fuse_rope"] = kwargs.get(
            "disable_fuse_rope", False
        )  # TODO: Make this ModelConfig's argument
        super().__init__(model_config, *args, **kwargs)

    @property
    def multimodal_data_device_paths(self) -> List[str]:
        return [
            "image.pixel_values",
            "video.pixel_values_videos",
            "multimodal_embedding",
        ]

    def load_weights(self, weights: Dict[str, torch.Tensor], weight_mapper: BaseWeightMapper):
        if not _is_disagg():
            self.mm_encoder.load_weights(weights)

        weight_mapper = Qwen3VLMoeHfWeightMapper()
        weight_mapper.init_model_and_config(self.llm, self.model_config)
        filtered_weights = {k: v for k, v in weights.items() if not k.startswith("model.visual.")}
        params_map = {
            r"^model\.language_model\.(.*)$": r"model.\1",
        }
        self.llm.load_weights(filtered_weights, weight_mapper, params_map=params_map)

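Editorial note (not part of the diff): the params_map in load_weights above is a plain regex rename applied to checkpoint keys before they reach the LLM loader; the HF checkpoint appears to keep the decoder under model.language_model.*, while the TRT-LLM Qwen3MoeForCausalLM expects model.*. A small standalone illustration of that substitution, under the assumption that the loader applies the mapping with a simple re.sub; the exact mechanics inside _load_weights_impl may differ:

import re

params_map = {r"^model\.language_model\.(.*)$": r"model.\1"}

checkpoint_keys = [
    "model.language_model.layers.0.self_attn.q_proj.weight",
    "lm_head.weight",  # untouched: the pattern only matches the language_model prefix
]

renamed = []
for key in checkpoint_keys:
    for pattern, repl in params_map.items():
        key = re.sub(pattern, repl, key)
    renamed.append(key)

print(renamed)
# ['model.layers.0.self_attn.q_proj.weight', 'lm_head.weight']
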
@ -672,10 +672,12 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
    def load_weights(self,
                     weights: Dict,
                     weight_mapper: Optional[BaseWeightMapper] = None,
                     params_map: Optional[Dict[str, str]] = None,
                     allow_partial_loading: bool = False):
        super().load_weights(weights=weights,
                             weight_mapper=weight_mapper,
                             skip_modules=["draft_model"],
                             params_map=params_map,
                             allow_partial_loading=allow_partial_loading)

    def load_draft_weights(self,

@ -561,6 +561,7 @@ class DecoderModelForCausalLM(nn.Module,
                     weights: Dict,
                     weight_mapper: Optional["BaseWeightMapper"] = None,
                     skip_modules: List[str] = [],
                     params_map: Optional[Dict[str, str]] = None,
                     allow_partial_loading: bool = False):
        # TODO smor- this solution is a temporary solution to load weights while we are still using
        # the old checkpoint format loading process. Once checkpoint format is unified

@ -570,6 +571,7 @@ class DecoderModelForCausalLM(nn.Module,
            _load_weights_impl(self,
                               weights,
                               skip_modules,
                               params_map=params_map,
                               preload_weight_modules=preload_weight_modules,
                               allow_partial_loading=allow_partial_loading)
        else:

@ -577,6 +579,7 @@ class DecoderModelForCausalLM(nn.Module,
                               weights,
                               weight_mapper,
                               skip_modules,
                               params_map=params_map,
                               preload_weight_modules=preload_weight_modules,
                               allow_partial_loading=allow_partial_loading)

@ -324,7 +324,7 @@ class Attention(nn.Module):
                head_dim=self.head_dim,
                is_neox=self.pos_embd_params.is_neox,
                mrope_section=self.pos_embd_params.mrope_section,
            )
                mrope_interleaved=self.pos_embd_params.mrope_interleaved)
        else:
            self.rotary_emb = RotaryEmbedding(
                self.pos_embd_params.rope,

@ -160,6 +160,7 @@ class QKNormRoPEAttention(Attention):
        attn_output_gate: Optional[bool] = None,
        is_qk_norm: bool = True,
        reduce_output: bool = True,
        rope_fusion: bool = True,
    ):
        self.pretrained_config = config.pretrained_config

@ -170,7 +171,8 @@ class QKNormRoPEAttention(Attention):

        # If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb
        # will be skipped in the overridden apply_rope.
        rope_fusion = not self.fuse_qk_norm_rope and not skip_rope and not attn_output_gate and not use_gemma_rms_norm
        rope_fusion &= (not self.fuse_qk_norm_rope and not skip_rope
                        and not attn_output_gate and not use_gemma_rms_norm)
        self.is_qk_norm = is_qk_norm
        assert not (fuse_qk_norm_rope and skip_rope
                    ), "Fusing qk norm and skipping rope is not supported"

@ -136,9 +136,22 @@ class MRotaryEmbedding(RotaryEmbedding):
        head_dim: int,
        mrope_section: List[int],
        is_neox: bool = True,
        mrope_interleaved: bool = False,
    ):
        super().__init__(rope_params, head_dim=head_dim, is_neox=is_neox)
        self.mrope_section = mrope_section
        self.mrope_interleaved = mrope_interleaved

    def apply_interleaved_rope(self, x: torch.Tensor) -> torch.Tensor:
        # referenced from https://github.com/vllm-project/vllm/blob/aeb82b1930454498fccc7e91f7c4e0f360cf658a/vllm/model_executor/layers/rotary_embedding/mrope.py#L191
        x_t = x[0].clone()
        x_t[...,
            1:self.mrope_section[1] * 3:3] = x[1, ...,
                                               1:self.mrope_section[1] * 3:3]
        x_t[...,
            2:self.mrope_section[2] * 3:3] = x[2, ...,
                                               2:self.mrope_section[2] * 3:3]
        return x_t

    def get_cos_sin(
        self,

@ -146,16 +159,20 @@ class MRotaryEmbedding(RotaryEmbedding):
        if position_ids.ndim == 3:
            cos_sin = self.rotary_cos_sin[position_ids.view(3, -1)]
            cos, sin = cos_sin[:, :, 0, :], cos_sin[:, :, 1, :]
            cos = torch.cat([
                m[i]
                for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)
            sin = torch.cat([
                m[i]
                for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)
            if self.mrope_interleaved:
                cos = self.apply_interleaved_rope(cos)
                sin = self.apply_interleaved_rope(sin)
            else:
                cos = torch.cat([
                    m[i]
                    for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
                sin = torch.cat([
                    m[i]
                    for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
                ],
                                dim=-1)
        else:
            # Fallback to the original RoPE where position_ids is 2D for dummy requests
            cos_sin = self.rotary_cos_sin[position_ids.view(-1)]

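Editorial note (not part of the diff): the difference between the two branches above is how rotary channels are assigned to the three mRoPE position streams. The non-interleaved path gives each stream (temporal, height, width) its own contiguous block of channels sized by mrope_section, while apply_interleaved_rope scatters the height and width streams into every third channel. A toy sketch of the interleaved selection, with an illustrative mrope_section instead of the real [24, 20, 20]:

import torch

mrope_section = [2, 2, 2]  # illustrative only
dim = sum(mrope_section)   # rotary half-dim covered by the three sections
x = torch.stack([torch.full((dim,), float(i)) for i in range(3)])  # stacked [t, h, w] streams

x_t = x[0].clone()
x_t[1:mrope_section[1] * 3:3] = x[1, 1:mrope_section[1] * 3:3]  # height stream -> channels 1, 4, ...
x_t[2:mrope_section[2] * 3:3] = x[2, 2:mrope_section[2] * 3:3]  # width stream  -> channels 2, 5, ...
print(x_t)  # tensor([0., 1., 2., 0., 1., 2.]), i.e. channels cycle t, h, w
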
@ -319,7 +319,7 @@ class CUDAGraphRunner:
            }
            if self.config.use_mrope:
                sliced_static_tensors["position_ids"] = self.shared_static_tensors[
                    "position_ids"][:, :, :num_tokens_for_capture],
                    "position_ids"][:, :, :num_tokens_for_capture]
                sliced_static_tensors[
                    "multimodal_params"] = self.shared_static_tensors[
                        "multimodal_params"][:batch_size * self.max_beam_width]

@ -590,7 +590,7 @@ def build_llava_engine(args):
        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            args.model_path, dtype=torch.float16)
        wrapper = LlavaOnevisionVisionWrapper(
            model.vision_tower.vision_model.to(args.device),
            model.vision_tower.to(args.device),
            model.multi_modal_projector.to(args.device), model.config)

        export_onnx(wrapper, image, f'{args.output_dir}/onnx')

@ -19,3 +19,5 @@ nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16:
  - accuracy: 26.67
microsoft/Phi-4-multimodal-instruct:
  - accuracy: 53.67
Qwen/Qwen3-VL-30B-A3B-Instruct:
  - accuracy: 55.33

@ -245,3 +245,21 @@ class TestGemma3_27BInstruct(LlmapiAccuracyTestHarness):
        ) as llm:
            task = MMMU(self.MODEL_NAME)
            task.evaluate(llm, sampling_params=self.sampling_params)


class TestQwen3VL_MOE(LlmapiAccuracyTestHarness):
    MODEL_NAME = "Qwen/Qwen3-VL-30B-A3B-Instruct"
    MODEL_PATH = f"{llm_models_root()}/Qwen3/Qwen3-VL-30B-A3B-Instruct"
    MAX_NUM_TOKENS = 16384

    sampling_params = SamplingParams(
        max_tokens=MAX_NUM_TOKENS, truncate_prompt_tokens=MMMU.MAX_INPUT_LEN, stop="<|endoftext|>"
    )

    def test_auto_dtype(self):
        with LLM(
            self.MODEL_PATH,
            max_num_tokens=self.MAX_NUM_TOKENS,
        ) as llm:
            task = MMMU(self.MODEL_NAME)
            task.evaluate(llm, sampling_params=self.sampling_params)

@ -656,6 +656,7 @@ accuracy/test_llm_api_pytorch_multimodal.py::TestVILA1_5_3B::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestNemotron_Nano_12B_V2_VL::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestPhi4MMFusedVisionLora::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestGemma3_27BInstruct::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen3VL_MOE::test_auto_dtype

test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-]
test_e2e.py::test_llama_e2e[use_py_session-remove_input_padding-]

@ -21,6 +21,7 @@ l0_l40s:
      - unittest/_torch/modeling -k "modeling_phi4mm"
      - unittest/_torch/modeling/test_modeling_llava_next.py::TestLlavaNext::test_all
      - unittest/_torch/modeling/test_modeling_qwen2_5vl.py::TestQwen2_5_VL::test_all
      - unittest/_torch/modeling/test_modeling_qwen3vl_moe.py::TestQwen3VLMoe::test_all
      - test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
      - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-audio]
      - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image]

@ -1,3 +1,4 @@
import pytest
import torch
from _model_test_utils import get_small_model_config
from build_and_run_ad import ExperimentConfig

@ -9,6 +10,10 @@ from tensorrt_llm._torch.auto_deploy.utils._graph import move_to_device


def test_build_run_llama4_vlm():
    pytest.skip(
        "Skipping test_build_run_llm4_vlm because Llama4 is giving an error on upgrading transformers version to 4.57.1"
        "https://nvbugspro.nvidia.com/bug/5732942"
    )
    atol = 1e-3
    rtol = 1e-3

@ -201,6 +201,19 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs):
    ],
)
def test_build_ad(model_hub_id: str, llm_extra_args: dict):
    if (
        model_hub_id == "mistralai/Mixtral-8x7B-Instruct-v0.1"
        and llm_extra_args.get("mode") != "transformers"
    ):
        pytest.skip(
            "Mixtral-8x7B-Instruct-v0.1 is giving an error on upgrading transformers version to 4.57.1"
            "https://nvbugspro.nvidia.com/bug/5732942"
        )
    if model_hub_id == "Qwen/Qwen3-30B-A3B" and llm_extra_args.get("mode") != "transformers":
        pytest.skip(
            "Qwen3-30B-A3B is giving an error on upgrading transformers version to 4.57.1"
            "https://nvbugspro.nvidia.com/bug/5732942"
        )
    experiment_config = get_small_model_config(model_hub_id, **llm_extra_args)
    experiment_config["args"]["runtime"] = "demollm"  # Default runtime set to demollm
    experiment_config["args"]["world_size"] = 0  # Default world_size set to 0

@ -271,7 +271,10 @@ class TestLlama4MinLatency(unittest.TestCase):
                "The transformers between 4.55.0 and 4.56.1 have accuracy "
                "issues for Llama4. See: "
                "https://github.com/huggingface/transformers/pull/40609")

        elif transformers.__version__ >= "4.57.1":
            self.skipTest(
                "Bumping transformers version to 4.57.1 has accuracy issues for Llama4. See: "
                "http://nvbugs/5732958")
        torch.random.manual_seed(0)
        config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG)
        # 17B * sizeof(float16) plus some extra for activations

@ -185,6 +185,12 @@ class TestModelingMultimodal(unittest.TestCase, ABC):
        else:
            model.load_weights(hf_model_state_dict)

        for module in model.modules():
            if hasattr(module, "post_load_weights") and not getattr(
                module, "_weights_removed", False
            ):
                module.post_load_weights()

        return model, model_config

    def create_hf_model(self, pretrained_config: PretrainedConfig) -> PreTrainedModel:

@ -457,7 +463,7 @@ class TestModelingMultimodal(unittest.TestCase, ABC):
                "attn_metadata"
            ].create_cuda_graph_metadata(1)

            # Prepare metadata before capture (like in working Qwen2.5-VL test)
            # Prepare metadata before capture
            trtllm_inputs["attn_metadata"].prepare()

            key = (1, 0, False)

@ -187,7 +187,7 @@ class TestQwen2_5_VL(TestModelingMultimodal):
            return self.trtllm_model.forward(**trtllm_inputs)
        else:
            # NOTE: Qwen2.5-VL model uses mrope
            graph_runner = create_mock_cuda_graph_runner(1, True)
            graph_runner = create_mock_cuda_graph_runner(1, use_mrope=True)
            trtllm_inputs["attn_metadata"] = trtllm_inputs[
                "attn_metadata"].create_cuda_graph_metadata(1)

@ -232,13 +232,6 @@ class TestQwen2_5_VL(TestModelingMultimodal):
                                   chunked_prefill=False,
                                   kv_cache_reuse=False),

            # ==== Disable fuse rope scenarios ====
            TestQwen2_5_VLScenario(modality="image",
                                   use_cuda_graph=False,
                                   disable_fuse_rope=True,
                                   chunked_prefill=False,
                                   kv_cache_reuse=False),

            # ==== Chunked Prefill Scenarios ====
            TestQwen2_5_VLScenario(modality="image",
                                   use_cuda_graph=False,

@ -252,6 +245,13 @@ class TestQwen2_5_VL(TestModelingMultimodal):
                                   disable_fuse_rope=False,
                                   chunked_prefill=False,
                                   kv_cache_reuse=True),

            # ==== Disable fuse rope scenarios ====
            TestQwen2_5_VLScenario(modality="image",
                                   use_cuda_graph=False,
                                   disable_fuse_rope=True,
                                   chunked_prefill=False,
                                   kv_cache_reuse=False),
        ]
        return scenarios

283
tests/unittest/_torch/modeling/test_modeling_qwen3vl_moe.py
Normal file
@ -0,0 +1,283 @@
import os
from dataclasses import dataclass
from typing import List

import torch
from _torch.helpers import create_mock_cuda_graph_runner
from test_modeling_multimodal import MultimodalScenario, TestModelingMultimodal
from transformers import Qwen3VLMoeConfig
from transformers import Qwen3VLMoeForConditionalGeneration as HFQwen3VLMoeForConditionalGeneration
from utils.llm_data import llm_models_root

from tensorrt_llm._torch.models.checkpoints.hf.qwen3vl_moe_weight_mapper import (
    Qwen3VLMoeHfWeightMapper,
)
from tensorrt_llm._torch.models.modeling_qwen3vl_moe import Qwen3MoeVLModel

QWEN3_VL_30B_A3B_CONFIG = {
    "architectures": ["Qwen3VLMoeForConditionalGeneration"],
    "image_token_id": 151655,
    "model_type": "qwen3_vl_moe",
    "text_config": {
        "attention_bias": False,
        "attention_dropout": 0.0,
        "bos_token_id": 151643,
        "decoder_sparse_step": 1,
        "dtype": "bfloat16",
        "eos_token_id": 151645,
        "head_dim": 128,
        "hidden_act": "silu",
        "hidden_size": 2048,
        "initializer_range": 0.02,
        "intermediate_size": 6144,
        "max_position_embeddings": 262144,
        "mlp_only_layers": [],
        "model_type": "qwen3_vl_moe_text",
        "moe_intermediate_size": 768,
        "norm_topk_prob": True,
        "num_attention_heads": 32,
        "num_experts": 128,
        "num_experts_per_tok": 8,
        "num_hidden_layers": 2,  # NOTE: Only 2 layers for testing; the full model has 48
        "num_key_value_heads": 4,
        "rms_norm_eps": 1e-06,
        "rope_scaling": {
            "mrope_interleaved": True,
            "mrope_section": [24, 20, 20],
            "rope_type": "default",
        },
        "rope_theta": 5000000,
        "use_cache": True,
        "vocab_size": 151936,
    },
    "tie_word_embeddings": False,
    "transformers_version": "4.57.0.dev0",
    "video_token_id": 151656,
    "vision_config": {
        "deepstack_visual_indexes": [8, 16, 24],
        "depth": 27,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "in_channels": 3,
        "initializer_range": 0.02,
        "intermediate_size": 4304,
        "model_type": "qwen3_vl_moe",
        "num_heads": 16,
        "num_position_embeddings": 2304,
        "out_hidden_size": 2048,
        "patch_size": 16,
        "spatial_merge_size": 2,
        "temporal_patch_size": 2,
    },
    "vision_end_token_id": 151653,
    "vision_start_token_id": 151652,
    "_attn_implementation": "flash_attention_2",
    "_name_or_path": str(os.path.join(llm_models_root(), "Qwen3", "Qwen3-VL-30B-A3B-Instruct")),
}


@dataclass(repr=False)
class TestQwen3VLMoeScenario(MultimodalScenario):
    disable_fuse_rope: bool = False

    def __repr__(self) -> str:
        """Generate a human-readable string representation of the scenario."""
        features = []
        features.append(f"modality:{self.modality.lower()}")
        if self.use_cuda_graph:
            features.append("cuda_graph")
        if self.disable_fuse_rope:
            features.append("no_fuse_rope")
        if self.chunked_prefill:
            features.append("chunked_prefill")
        if self.kv_cache_reuse:
            features.append("kv_cache_reuse")
        return "-".join(features)


class TestQwen3VLMoe(TestModelingMultimodal):
    def get_model_config(self):
        """Return the model configuration dictionary."""
        return QWEN3_VL_30B_A3B_CONFIG

    def get_trtllm_model_class(self):
        return Qwen3MoeVLModel

    def get_hf_model_class(self):
        return HFQwen3VLMoeForConditionalGeneration

    def get_weight_mapper_class(self):
        return Qwen3VLMoeHfWeightMapper

    def get_model_type(self):
        return "qwen3_vl_moe"

    def get_model_config_class(self):
        return Qwen3VLMoeConfig

    def get_trtllm_inputs(
        self,
        input_ids,
        multimodal_params_list,
        is_gen: bool = False,
        num_cached_tokens_per_seq: List[int] = None,
    ):
        trtllm_inputs = super().get_trtllm_inputs(
            input_ids, multimodal_params_list, is_gen, num_cached_tokens_per_seq
        )

        if is_gen:
            mrope_gen_position_ids = []
            for multimodal_param in multimodal_params_list:
                mrope_gen_position_ids.append(
                    multimodal_param.multimodal_data["mrope_config"]["mrope_position_deltas"]
                )
            mrope_gen_position_ids = torch.cat(mrope_gen_position_ids, dim=-1).to(self.device)
            trtllm_inputs["position_ids"] = (
                (trtllm_inputs["position_ids"] + mrope_gen_position_ids).expand(3, -1, 1).cuda()
            )
            gen_multimodal_params_list = []
            for multimodal_param in multimodal_params_list:
                multimodal_param.strip_for_generation()
                multimodal_param.to_device(
                    "multimodal_data",
                    self.device,
                    pin_memory=True,
                    target_keywords=["mrope_config.mrope_position_deltas"],
                )
                gen_multimodal_params_list.append(multimodal_param)
            trtllm_inputs["multimodal_params"] = gen_multimodal_params_list
        else:
            # Mrope position ids
            mrope_position_ids = []
            for multimodal_param in multimodal_params_list:
                mrope_position_ids.append(
                    multimodal_param.multimodal_data["mrope_config"]["mrope_position_ids"]
                )
            position_ids = torch.cat(mrope_position_ids, dim=-1)
            position_ids = position_ids.cuda()
            trtllm_inputs["position_ids"] = position_ids

        return trtllm_inputs

    def init_kv_cache_manager(self, scenario: TestQwen3VLMoeScenario):
        # NOTE: Exactly the same as the parent class method,
        # but with the mrope flag set to True for the Qwen3-VL-MoE model.
        cache_config = self.get_kv_cache_config(scenario)
        tokens_per_block = cache_config["tokens_per_block"]
        max_seq_len = cache_config["max_seq_len"]
        batch_size = cache_config["batch_size"]

        num_blocks = (max_seq_len + tokens_per_block - 1) // tokens_per_block

        self.kv_cache_manager = self.get_kv_cache_manager(
            dtype=self.model_config.pretrained_config.torch_dtype,
            config=self.model_config.pretrained_config,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            num_blocks=num_blocks,
        )

        self.kv_cache_manager.add_dummy_requests(
            request_ids=[1],
            token_nums=[max_seq_len],
            # NOTE: The Qwen3-VL-MoE model uses mrope
            use_mrope=True,
        )

    def run_trtllm_forward(self, trtllm_inputs, use_cuda_graph: bool = False):
        # NOTE: Exactly the same as the parent class method,
        # but with the mrope flag set to True for the Qwen3-VL-MoE model.
        if not use_cuda_graph:
            trtllm_inputs["attn_metadata"].prepare()
            return self.trtllm_model.forward(**trtllm_inputs)
        else:
            # NOTE: The Qwen3-VL-MoE model uses mrope
            graph_runner = create_mock_cuda_graph_runner(1, True)
            trtllm_inputs["attn_metadata"] = trtllm_inputs[
                "attn_metadata"
            ].create_cuda_graph_metadata(1)

            # Prepare metadata before capture (like in the working Qwen2.5-VL test)
            trtllm_inputs["attn_metadata"].prepare()

            key = (1, 0, False)
            graph_runner.capture(
                key=key,
                forward_fn=lambda inputs: self.trtllm_model.forward(**inputs),
                initial_inputs=trtllm_inputs,
            )
            for _ in range(2):
                # Run it twice. This helps us catch problems if buffers are accidentally reallocated in prepare().
                trtllm_inputs["attn_metadata"].prepare()
                logits = graph_runner.replay(key=key, current_inputs=trtllm_inputs)
            return logits.clone()

    def get_scenarios(self) -> List[TestQwen3VLMoeScenario]:
        scenarios = [
            # ==== Modality Sanity Checks ====
            TestQwen3VLMoeScenario(
                modality="image",
                use_cuda_graph=False,
                disable_fuse_rope=False,
                chunked_prefill=False,
                kv_cache_reuse=False,
            ),
            TestQwen3VLMoeScenario(
                modality="video",
                use_cuda_graph=False,
                disable_fuse_rope=False,
                chunked_prefill=False,
                kv_cache_reuse=False,
            ),
            TestQwen3VLMoeScenario(
                modality="multiple_image",
                use_cuda_graph=False,
                disable_fuse_rope=False,
                chunked_prefill=False,
                kv_cache_reuse=False,
            ),
            # ==== CUDA Graph Scenarios ====
            TestQwen3VLMoeScenario(
                modality="image",
                use_cuda_graph=True,
                disable_fuse_rope=False,
                chunked_prefill=False,
                kv_cache_reuse=False,
            ),
            # ==== Chunked Prefill Scenarios ====
            TestQwen3VLMoeScenario(
                modality="image",
                use_cuda_graph=False,
                disable_fuse_rope=False,
                chunked_prefill=True,
                kv_cache_reuse=False,
            ),
            # ==== KV Cache Reuse Scenarios ====
            TestQwen3VLMoeScenario(
                modality="image",
                use_cuda_graph=False,
                disable_fuse_rope=False,
                chunked_prefill=False,
                kv_cache_reuse=True,
            ),
            # ==== Disable fuse rope scenarios ====
            TestQwen3VLMoeScenario(
                modality="image",
                use_cuda_graph=False,
                disable_fuse_rope=True,
                chunked_prefill=False,
                kv_cache_reuse=False,
            ),
        ]
        return scenarios

    def setup_scenario(self, scenario: TestQwen3VLMoeScenario):
        super().setup_scenario(scenario)
        if scenario.disable_fuse_rope:
            self.trtllm_model, self.model_config = self.create_trtllm_model(
                load_weights=True,
                hf_model_state_dict=self.hf_model.state_dict(),
                disable_fuse_rope=True,
            )

@ -106,7 +106,8 @@ class TestSiglipVisionModel(unittest.TestCase):
            attn_backend=backend,
        )

        tllm_model = SiglipVisionModel(model_config).to(dtype).to(device)
        tllm_model = SiglipVisionModel(
            model_config, use_post_layernorm=True).to(dtype).to(device)
        tllm_model.load_weights(hf_model.state_dict())

        # Prepare inputs - create random pixel values for images

@ -1,7 +1,8 @@
regex
fire
tritonclient[all]
transformers==4.56.0
transformers==4.57.1
pandas
tabulate
flash_attn
torchao>=0.14.1