[https://nvbugs/5835925][fix] Add EPD disagg support for Qwen3 VL MoE (#10962)

* Why?

Trying to instantiate a `MultimodalEncoder` for a Qwen3 VL MoE model
would fail during weight loading.

* What?

This commit fixes the bug, and alongside it:
- adds explicit, intentional support for EPD for Qwen3 VL MoE.
- extends the EPD unit tests to cover Qwen3 VL MoE, albeit with dummy weights.
- adds unit tests for the weight mapper fixes.
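For context, the failure surfaced on the encoder-only path used by EPD disaggregation. Below is a minimal sketch of that path, assuming a local Qwen3 VL MoE checkpoint; the model path and batch size are illustrative, while the API calls mirror the test code added in this commit.

```python
from tensorrt_llm import MultimodalEncoder

# Illustrative checkpoint location; any Qwen3 VL MoE checkpoint applies.
model_dir = "/models/Qwen3-VL-30B-A3B-Instruct-FP8"

# Before this fix, constructing the encoder for the MoE variant failed while
# loading weights; the diff below moves the KV-head lookup behind a property so
# the vision config (which lacks `num_key_value_heads`) no longer breaks it.
encoder = MultimodalEncoder(model=model_dir, max_batch_size=1)
with encoder:
    # `inputs` would come from `default_multimodal_input_loader`, as in the tests.
    ...
```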

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
William Zhang authored on 2026-01-27 16:50:50 -08:00, committed by Yanchao Lu
parent 0ead17bb85
commit abb8106c01
4 changed files with 134 additions and 13 deletions

View File

@@ -16,9 +16,6 @@ class Qwen3MoeHfWeightMapper(Qwen2MoeHfWeightMapper):
DecoderModelForCausalLM],
config: ModelConfig):
super().init_model_and_config(model, config)
self._num_kv_heads = model.config.num_key_value_heads if hasattr(
model.config, 'num_key_value_heads'
) and model.config.num_key_value_heads is not None else model.config.num_attention_heads
def should_skip_module(self, module_name: str) -> bool:
if module_name.startswith("draft_model"):
@@ -49,3 +46,11 @@ class Qwen3MoeHfWeightMapper(Qwen2MoeHfWeightMapper):
            return processed_weights
        return weights

    @property
    def _num_kv_heads(self) -> int:
        num_kv_heads = self._model.config.num_key_value_heads if hasattr(
            self._model.config, 'num_key_value_heads'
        ) and self._model.config.num_key_value_heads is not None else self._model.config.num_attention_heads
        return num_kv_heads
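The net effect of this hunk is to turn `_num_kv_heads` from an attribute computed eagerly in `init_model_and_config` into a lazily evaluated property that subclasses can override per config type. A self-contained sketch of the pattern, using stand-in classes rather than the TensorRT-LLM types:

```python
from types import SimpleNamespace


class BaseMapper:
    """Stand-in for a weight mapper that resolves the KV-head count lazily."""

    def __init__(self, model):
        # No eager config access here; the old code computed the KV-head count
        # at init time and broke on configs without the expected attributes.
        self._model = model

    @property
    def _num_kv_heads(self) -> int:
        cfg = self._model.config
        # Fall back to num_attention_heads when num_key_value_heads is missing or None.
        num_kv_heads = getattr(cfg, "num_key_value_heads", None)
        return num_kv_heads if num_kv_heads is not None else cfg.num_attention_heads


# A config without num_key_value_heads no longer breaks mapper construction;
# the fallback only runs if and when the property is actually read.
model = SimpleNamespace(config=SimpleNamespace(num_attention_heads=16))
assert BaseMapper(model)._num_kv_heads == 16
```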

View File

@@ -1,4 +1,8 @@
from torch import nn
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
Qwen3VLMoeTextConfig,
Qwen3VLMoeVisionConfig,
)
from tensorrt_llm._torch.models.checkpoints.hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper
@@ -22,3 +26,20 @@ class Qwen3VLMoeHfWeightMapper(Qwen3MoeHfWeightMapper):
        module.load_weights(
            weights=[updated_module_weights], allow_partial_loading=allow_partial_loading
        )

    @property
    def _num_kv_heads(self) -> int:
        config = self._model.config
        if isinstance(config, Qwen3VLMoeTextConfig):
            num_kv_heads = getattr(config, "num_key_value_heads", None)
            if num_kv_heads is None:
                num_kv_heads = config.num_attention_heads
        elif isinstance(config, Qwen3VLMoeVisionConfig):
            num_kv_heads = config.num_heads
        else:
            raise TypeError(
                "Expected `Qwen3VLMoeTextConfig` or `Qwen3VLMoeVisionConfig`, "
                f"got {type(config).__name__}"
            )
        return num_kv_heads
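When the mapper serves the vision encoder, the active config is `Qwen3VLMoeVisionConfig`, which exposes `num_heads` rather than `num_key_value_heads`; hence the type dispatch above. A sketch of the same dispatch using hypothetical stand-in config classes, not the transformers types:

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class FakeTextConfig:  # stand-in for Qwen3VLMoeTextConfig
    num_attention_heads: int
    num_key_value_heads: Optional[int] = None


@dataclass
class FakeVisionConfig:  # stand-in for Qwen3VLMoeVisionConfig
    num_heads: int


def num_kv_heads(config: Union[FakeTextConfig, FakeVisionConfig]) -> int:
    if isinstance(config, FakeTextConfig):
        # Text configs may omit num_key_value_heads; fall back to num_attention_heads.
        kv = config.num_key_value_heads
        return kv if kv is not None else config.num_attention_heads
    if isinstance(config, FakeVisionConfig):
        return config.num_heads
    raise TypeError(f"Unexpected config type: {type(config).__name__}")


assert num_kv_heads(FakeTextConfig(num_attention_heads=32, num_key_value_heads=4)) == 4
assert num_kv_heads(FakeTextConfig(num_attention_heads=32)) == 32
assert num_kv_heads(FakeVisionConfig(num_heads=16)) == 16
```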

View File

@@ -9,6 +9,7 @@ from ...inputs import (
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement,
register_input_processor,
support_multimodal_disaggregated,
)
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .checkpoints.hf.qwen3vl_moe_weight_mapper import Qwen3VLMoeHfWeightMapper
@@ -21,6 +22,14 @@ from .modeling_qwen3vl import (
from .modeling_utils import ModelConfig, register_auto_model, register_vision_encoder
# NOTE: this is technically not strictly necessary, since the underlying mechanism for registering
# support is tacked onto the input processor class (`Qwen3VLInputProcessorBase`). Given that
# the `Qwen3VLModel` (defined via the import of `modeling_qwen3vl.py` in this file) has that
# decorator applied to it, and uses the same input processor class, we get it "for free" here.
# However, we keep it here to explicitly signify intent that this is supported. This also shields
# it from e.g. the input processor classes becoming specialized between `Qwen3VLModel` and the
# below MoE class.
@support_multimodal_disaggregated
@register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
@register_auto_model("Qwen3VLMoeForConditionalGeneration")
@register_input_processor(

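The NOTE above is worth unpacking: the support flag is recorded on the input processor class rather than on the model class, so the MoE model already inherits it through the processor it shares with `Qwen3VLModel`; the explicit decorator documents intent and protects against the processors diverging later. As a rough mental model only, not the actual `tensorrt_llm.inputs` implementation, such a decorator could look like:

```python
def support_multimodal_disaggregated_sketch(model_cls):
    """Hypothetical sketch: mark a model's input processor as EPD-capable.

    `INPUT_PROCESSOR_CLS` is an assumed attribute name used only for
    illustration; the real registration lives in `tensorrt_llm.inputs`.
    """
    processor_cls = getattr(model_cls, "INPUT_PROCESSOR_CLS", None)
    if processor_cls is not None:
        # Because the flag sits on the (shared) processor class, decorating a
        # second model that reuses the same processor changes nothing at
        # runtime, which is why the decorator on the MoE class is about
        # intent rather than behavior.
        processor_cls.supports_multimodal_disaggregated = True
    return model_cls
```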
View File

@@ -1,3 +1,4 @@
import copy
import json
import os
import time
@@ -11,9 +12,10 @@ from utils.llm_data import llm_models_root
from tensorrt_llm import MultimodalEncoder
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.llmapi import (CacheTransceiverConfig, DisaggregatedParams,
KvCacheConfig)
KvCacheConfig, MoeConfig)
from tensorrt_llm.llmapi.llm import LLM, SamplingParams
test_data_root = Path(
@@ -27,6 +29,67 @@ example_images = [
_LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf"
_QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct"
_QWEN_3_VL_DIR = llm_models_root() / "Qwen3" / "Qwen3-VL-2B-Instruct"
_QWEN_3_VL_30B_A3B_FP8_DIR = llm_models_root(
) / "Qwen3" / "Qwen3-VL-30B-A3B-Instruct-FP8"
_FAKE_QWEN3_VL_30B_A3B_FP8_SENTINEL = "qwen3_vl_30b_a3b_fp8_fake"
_FAKE_CHECKPOINT_MARKER = ".tllm_fake_checkpoint"
# Unlike the other models, we cannot fit a multimodal encoder + 2 copies of the LLM on a single
# H100 GPU in CI. We therefore resort to creating a slimmed-down version of the model with fewer
# layers.
def _get_fake_qwen3_vl_30b_a3b_config() -> dict:
    config_path = _QWEN_3_VL_30B_A3B_FP8_DIR / "config.json"
    if not config_path.exists():
        pytest.skip(f"Qwen3-VL-30B-A3B config not found: {config_path}")
    with open(config_path, "r") as f:
        config = json.load(f)
    config = copy.deepcopy(config)
    config["text_config"]["num_hidden_layers"] = 2
    return config


def _create_fake_qwen3_vl_30b_a3b_fp8_dir(
        tmp_path_factory: pytest.TempPathFactory,
        assets_dir: Path,
) -> Path:
    if not assets_dir.exists():
        pytest.skip(f"Base model dir not found: {assets_dir}")
    fake_dir = tmp_path_factory.mktemp("qwen3_vl_30b_a3b_fp8_fake")
    for item in assets_dir.iterdir():
        if item.name == "config.json":
            continue
        target = fake_dir / item.name
        if target.exists():
            continue
        os.symlink(item, target, target_is_directory=item.is_dir())
    config_path = fake_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(_get_fake_qwen3_vl_30b_a3b_config(), f, indent=2)
    (fake_dir /
     _FAKE_CHECKPOINT_MARKER).write_text("Synthetic checkpoint for CI tests.\n")
    return fake_dir


def _get_fake_checkpoint_kwargs(model_dir: Path) -> dict:
    if (model_dir / _FAKE_CHECKPOINT_MARKER).exists():
        return {"load_format": "dummy"}
    return {}


def _is_fake_checkpoint(model_dir: Path) -> bool:
    return (model_dir / _FAKE_CHECKPOINT_MARKER).exists()


def _get_moe_config_for_blackwell() -> MoeConfig:
    if get_sm_version() >= 100:
        return MoeConfig(backend="DEEPGEMM")
    return MoeConfig()
@pytest.mark.parametrize(
@@ -67,10 +130,12 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates):
free_gpu_memory_fraction=free_gpu_memory_fraction,
event_buffer_max_size=1024, # Enable KV cache events
)
moe_config = _get_moe_config_for_blackwell()
llm = LLM(model=encoder_model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
moe_config=moe_config,
max_batch_size=1)
inputs = _load_inputs(llm, prompts, media)
@@ -100,10 +165,20 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates):
f"got {num_duplicates}. Offsets: {mm_keys_offsets}")
@pytest.fixture(scope="module",
params=[_LLAVA_DIR, _QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR],
ids=["llava_7b", "qwen2.5_3b", "qwen3_2b"])
def model_dir(request) -> Path:
@pytest.fixture(
scope="module",
params=[
pytest.param(_LLAVA_DIR, id="llava_7b"),
pytest.param(_QWEN_2_5_VL_DIR, id="qwen2.5_3b"),
pytest.param(_QWEN_3_VL_DIR, id="qwen3_2b"),
pytest.param(_FAKE_QWEN3_VL_30B_A3B_FP8_SENTINEL,
id="qwen3_30b_a3b_fp8"),
],
)
def model_dir(request, tmp_path_factory: pytest.TempPathFactory) -> Path:
if request.param == _FAKE_QWEN3_VL_30B_A3B_FP8_SENTINEL:
return _create_fake_qwen3_vl_30b_a3b_fp8_dir(tmp_path_factory,
_QWEN_3_VL_DIR)
return request.param
@@ -125,14 +200,18 @@ def llms(model_dir: Path,
free_gpu_memory_fraction=free_gpu_memory_fraction,
)
load_kwargs = _get_fake_checkpoint_kwargs(model_dir)
moe_config = _get_moe_config_for_blackwell()
llm = LLM(
model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
moe_config=moe_config,
trust_remote_code=True,
cache_transceiver_config=cache_transceiver_cfg,
disable_overlap_scheduler=disable_overlap_scheduler,
max_batch_size=1, # fix batch size to reduce non-determinism in tests
**load_kwargs,
)
with llm:
if pd_disagg:
@@ -140,8 +219,10 @@ def llms(model_dir: Path,
model=model_dir,
backend='pytorch',
kv_cache_config=kv_cache_config,
moe_config=moe_config,
trust_remote_code=True,
cache_transceiver_config=cache_transceiver_cfg,
**load_kwargs,
)
with llm_decode:
yield (llm, llm_decode)
@@ -252,7 +333,9 @@ def test_single_image_chat(
# Prepare inputs for llm (pass mm_embeddings)
# Process multimodal data using encoder (pass mm_embeddings)
encoder = MultimodalEncoder(model=model_dir, max_batch_size=max_batch_size)
encoder = MultimodalEncoder(model=model_dir,
max_batch_size=max_batch_size,
**_get_fake_checkpoint_kwargs(model_dir))
with encoder:
encoder_outputs = encoder.generate(inputs)
@@ -393,13 +476,15 @@ def test_multi_request_batch_chat(
embeddings alongside the prompt ("multi_modal_embeddings"), as well as the embedding
handling within default_multimodal_input_loader.
"""
if use_mm_embeddings and model_dir in [_QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR]:
if use_mm_embeddings and (model_dir in [_QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR]
or _is_fake_checkpoint(model_dir)):
pytest.skip("Qwen does not implement attach_multimodal_embeddings")
# Qwen2.5/3 VL's vision encoder seems to output different embeddings based on this value.
# The test only passes with this set to 1.
encoder_max_batch_size = (1 if model_dir
in [_QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR] else 3)
encoder_max_batch_size = (1 if
model_dir in [_QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR]
or _is_fake_checkpoint(model_dir) else 3)
llm, llm_decode = llms
if llm_decode is not None:
@@ -430,7 +515,8 @@ def test_multi_request_batch_chat(
) > 0, f"Reference generation has no output text for input {i}"
encoder = MultimodalEncoder(model=model_dir,
max_batch_size=encoder_max_batch_size)
max_batch_size=encoder_max_batch_size,
**_get_fake_checkpoint_kwargs(model_dir))
with encoder:
# Encoder path
encoder_outputs = encoder.generate(inputs)