Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)
[https://nvbugs/5835925][fix] Add EPD disagg support for Qwen3 VL MoE (#10962)
* Why?

  Trying to instantiate a `MultimodalEncoder` for a Qwen3 VL MoE model would fail during weight loading.

* What?

  This commit fixes the bug, alongside:
  - explicit, intentional support for EPD for Qwen3 VL MoE.
  - extends EPD unit tests for Qwen3 VL MoE, albeit with dummy weights.
  - unit tests for the weight mapper fixes.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
parent 0ead17bb85
commit abb8106c01
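For orientation, the entry point this change unblocks is the standalone `MultimodalEncoder` used in the EPD (encoder/prefill/decode) disaggregated flow. A minimal sketch of the encoder stage, assuming an installed `tensorrt_llm` and a local checkpoint; the path is a placeholder, while the constructor keywords and the `with`/`generate` usage are taken from the tests in this diff:

    from tensorrt_llm import MultimodalEncoder

    # Placeholder checkpoint path; the CI tests below synthesize a slimmed-down
    # Qwen3-VL-30B-A3B FP8 checkpoint with dummy weights instead of a real one.
    model_dir = "/models/Qwen3-VL-30B-A3B-Instruct-FP8"

    # Creating the encoder loads the model weights; for Qwen3 VL MoE this is where
    # instantiation previously failed and what the weight-mapper changes below fix.
    encoder = MultimodalEncoder(model=model_dir, max_batch_size=1)
    with encoder:
        # encoder.generate(inputs) would then produce the multimodal embeddings that
        # the prefill/decode LLM instances consume, as exercised in the tests below.
        pass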
@@ -16,9 +16,6 @@ class Qwen3MoeHfWeightMapper(Qwen2MoeHfWeightMapper):
                                                    DecoderModelForCausalLM],
                               config: ModelConfig):
         super().init_model_and_config(model, config)
-        self._num_kv_heads = model.config.num_key_value_heads if hasattr(
-            model.config, 'num_key_value_heads'
-        ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads
 
     def should_skip_module(self, module_name: str) -> bool:
         if module_name.startswith("draft_model"):
@@ -49,3 +46,11 @@ class Qwen3MoeHfWeightMapper(Qwen2MoeHfWeightMapper):
             return processed_weights
 
         return weights
+
+    @property
+    def _num_kv_heads(self) -> int:
+        num_kv_heads = self._model.config.num_key_value_heads if hasattr(
+            self._model.config, 'num_key_value_heads'
+        ) and self._model.config.num_key_value_heads is not None else self._model.config.num_attention_heads
+
+        return num_kv_heads
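The fallback that used to be computed eagerly in `init_model_and_config` now lives behind a `_num_kv_heads` property, so subclasses (such as the VL MoE mapper below) can override it per config type. As a standalone illustration of the same fallback, here is a small sketch assuming only an HF-style config object; the helper name is ours, not part of the change:

    from types import SimpleNamespace

    def resolve_num_kv_heads(config) -> int:
        # Mirror of the property above: prefer num_key_value_heads, fall back to
        # num_attention_heads when the field is absent or None (e.g. pure MHA configs).
        num_kv_heads = getattr(config, "num_key_value_heads", None)
        return num_kv_heads if num_kv_heads is not None else config.num_attention_heads

    print(resolve_num_kv_heads(SimpleNamespace(num_key_value_heads=8, num_attention_heads=32)))    # -> 8
    print(resolve_num_kv_heads(SimpleNamespace(num_key_value_heads=None, num_attention_heads=32)))  # -> 32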
@@ -1,4 +1,8 @@
 from torch import nn
+from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
+    Qwen3VLMoeTextConfig,
+    Qwen3VLMoeVisionConfig,
+)
 
 from tensorrt_llm._torch.models.checkpoints.hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
 from tensorrt_llm._torch.models.modeling_utils import register_mapper
@@ -22,3 +26,20 @@ class Qwen3VLMoeHfWeightMapper(Qwen3MoeHfWeightMapper):
             module.load_weights(
                 weights=[updated_module_weights], allow_partial_loading=allow_partial_loading
             )
+
+    @property
+    def _num_kv_heads(self) -> int:
+        config = self._model.config
+        if isinstance(config, Qwen3VLMoeTextConfig):
+            num_kv_heads = getattr(config, "num_key_value_heads", None)
+            if num_kv_heads is None:
+                num_kv_heads = config.num_attention_heads
+        elif isinstance(config, Qwen3VLMoeVisionConfig):
+            num_kv_heads = config.num_heads
+        else:
+            raise TypeError(
+                "Expected `Qwen3VLMoeTextConfig` or `Qwen3VLMoeVisionConfig`, "
+                f"got {type(config).__name__}"
+            )
+
+        return num_kv_heads
@@ -9,6 +9,7 @@ from ...inputs import (
     MultimodalPlaceholderMetadata,
     MultimodalPlaceholderPlacement,
     register_input_processor,
+    support_multimodal_disaggregated,
 )
 from .checkpoints.base_weight_mapper import BaseWeightMapper
 from .checkpoints.hf.qwen3vl_moe_weight_mapper import Qwen3VLMoeHfWeightMapper
@@ -21,6 +22,14 @@ from .modeling_qwen3vl import (
 from .modeling_utils import ModelConfig, register_auto_model, register_vision_encoder
 
 
+# NOTE: this is technically not strictly necessary, since the underlying mechanism for registering
+# support is tacked onto the input processor class (`Qwen3VLInputProcessorBase`). Given that
+# the `Qwen3VLModel` (defined via the import of `modeling_qwen3vl.py` in this file) has that
+# decorator applied to it, and uses the same input processor class, we get it "for free" here.
+# However, we keep it here to explicitly signify intent that this is supported. This also shields
+# it from e.g. the input processor classes becoming specialized between `Qwen3VLModel` and the
+# below MoE class.
+@support_multimodal_disaggregated
 @register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
 @register_auto_model("Qwen3VLMoeForConditionalGeneration")
 @register_input_processor(
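As a rough illustration of the comment above: a marker decorator of this kind usually just records a capability on the class (or on its associated input processor) so that the serving path can check it before enabling multimodal disaggregation. The sketch below is hypothetical and not the actual TensorRT-LLM implementation; all names are illustrative only:

    def support_multimodal_disaggregated_sketch(cls):
        # Hypothetical stand-in: tag the class so a later capability lookup can find it.
        cls._supports_multimodal_disagg = True
        return cls

    @support_multimodal_disaggregated_sketch
    class FakeQwen3VLMoeModel:
        pass

    assert getattr(FakeQwen3VLMoeModel, "_supports_multimodal_disagg", False)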
@@ -1,3 +1,4 @@
+import copy
 import json
 import os
 import time
@@ -11,9 +12,10 @@ from utils.llm_data import llm_models_root
 
 from tensorrt_llm import MultimodalEncoder
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.inputs import default_multimodal_input_loader
 from tensorrt_llm.llmapi import (CacheTransceiverConfig, DisaggregatedParams,
-                                 KvCacheConfig)
+                                 KvCacheConfig, MoeConfig)
 from tensorrt_llm.llmapi.llm import LLM, SamplingParams
 
 test_data_root = Path(
@@ -27,6 +29,67 @@ example_images = [
 _LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf"
 _QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct"
 _QWEN_3_VL_DIR = llm_models_root() / "Qwen3" / "Qwen3-VL-2B-Instruct"
+_QWEN_3_VL_30B_A3B_FP8_DIR = llm_models_root(
+) / "Qwen3" / "Qwen3-VL-30B-A3B-Instruct-FP8"
+
+_FAKE_QWEN3_VL_30B_A3B_FP8_SENTINEL = "qwen3_vl_30b_a3b_fp8_fake"
+_FAKE_CHECKPOINT_MARKER = ".tllm_fake_checkpoint"
+
+
+# Unlike the other models, we cannot fit a multimodal encoder + 2 copies of the LLM on a single
+# H100 GPU in CI. We therefore resort to creating a slimmed down version of the model with less
+# layers.
+def _get_fake_qwen3_vl_30b_a3b_config() -> dict:
+    config_path = _QWEN_3_VL_30B_A3B_FP8_DIR / "config.json"
+    if not config_path.exists():
+        pytest.skip(f"Qwen3-VL-30B-A3B config not found: {config_path}")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+    config = copy.deepcopy(config)
+    config["text_config"]["num_hidden_layers"] = 2
+    return config
+
+
+def _create_fake_qwen3_vl_30b_a3b_fp8_dir(
+    tmp_path_factory: pytest.TempPathFactory,
+    assets_dir: Path,
+) -> Path:
+    if not assets_dir.exists():
+        pytest.skip(f"Base model dir not found: {assets_dir}")
+
+    fake_dir = tmp_path_factory.mktemp("qwen3_vl_30b_a3b_fp8_fake")
+
+    for item in assets_dir.iterdir():
+        if item.name == "config.json":
+            continue
+        target = fake_dir / item.name
+        if target.exists():
+            continue
+        os.symlink(item, target, target_is_directory=item.is_dir())
+
+    config_path = fake_dir / "config.json"
+    with open(config_path, "w") as f:
+        json.dump(_get_fake_qwen3_vl_30b_a3b_config(), f, indent=2)
+
+    (fake_dir /
+     _FAKE_CHECKPOINT_MARKER).write_text("Synthetic checkpoint for CI tests.\n")
+    return fake_dir
+
+
+def _get_fake_checkpoint_kwargs(model_dir: Path) -> dict:
+    if (model_dir / _FAKE_CHECKPOINT_MARKER).exists():
+        return {"load_format": "dummy"}
+    return {}
+
+
+def _is_fake_checkpoint(model_dir: Path) -> bool:
+    return (model_dir / _FAKE_CHECKPOINT_MARKER).exists()
+
+
+def _get_moe_config_for_blackwell() -> MoeConfig:
+    if get_sm_version() >= 100:
+        return MoeConfig(backend="DEEPGEMM")
+    return MoeConfig()
 
 
 @pytest.mark.parametrize(
@@ -67,10 +130,12 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates):
         free_gpu_memory_fraction=free_gpu_memory_fraction,
         event_buffer_max_size=1024,  # Enable KV cache events
     )
+    moe_config = _get_moe_config_for_blackwell()
 
     llm = LLM(model=encoder_model_dir,
               backend='pytorch',
               kv_cache_config=kv_cache_config,
+              moe_config=moe_config,
               max_batch_size=1)
 
     inputs = _load_inputs(llm, prompts, media)
@@ -100,10 +165,20 @@ def test_kv_event_mm_keys_with_reuse(prompts, expected_num_duplicates):
         f"got {num_duplicates}. Offsets: {mm_keys_offsets}")
 
 
-@pytest.fixture(scope="module",
-                params=[_LLAVA_DIR, _QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR],
-                ids=["llava_7b", "qwen2.5_3b", "qwen3_2b"])
-def model_dir(request) -> Path:
+@pytest.fixture(
+    scope="module",
+    params=[
+        pytest.param(_LLAVA_DIR, id="llava_7b"),
+        pytest.param(_QWEN_2_5_VL_DIR, id="qwen2.5_3b"),
+        pytest.param(_QWEN_3_VL_DIR, id="qwen3_2b"),
+        pytest.param(_FAKE_QWEN3_VL_30B_A3B_FP8_SENTINEL,
+                     id="qwen3_30b_a3b_fp8"),
+    ],
+)
+def model_dir(request, tmp_path_factory: pytest.TempPathFactory) -> Path:
+    if request.param == _FAKE_QWEN3_VL_30B_A3B_FP8_SENTINEL:
+        return _create_fake_qwen3_vl_30b_a3b_fp8_dir(tmp_path_factory,
+                                                     _QWEN_3_VL_DIR)
     return request.param
@@ -125,14 +200,18 @@ def llms(model_dir: Path,
         free_gpu_memory_fraction=free_gpu_memory_fraction,
     )
 
+    load_kwargs = _get_fake_checkpoint_kwargs(model_dir)
+    moe_config = _get_moe_config_for_blackwell()
     llm = LLM(
         model=model_dir,
         backend='pytorch',
         kv_cache_config=kv_cache_config,
+        moe_config=moe_config,
         trust_remote_code=True,
         cache_transceiver_config=cache_transceiver_cfg,
         disable_overlap_scheduler=disable_overlap_scheduler,
         max_batch_size=1,  # fix batch size to reduce non-determinism in tests
+        **load_kwargs,
     )
     with llm:
         if pd_disagg:
@@ -140,8 +219,10 @@ def llms(model_dir: Path,
             model=model_dir,
             backend='pytorch',
             kv_cache_config=kv_cache_config,
+            moe_config=moe_config,
             trust_remote_code=True,
             cache_transceiver_config=cache_transceiver_cfg,
+            **load_kwargs,
         )
         with llm_decode:
             yield (llm, llm_decode)
@@ -252,7 +333,9 @@ def test_single_image_chat(
 
     # Prepare inputs for llm (pass mm_embeddings)
     # Process multimodal data using encoder (pass mm_embeddings)
-    encoder = MultimodalEncoder(model=model_dir, max_batch_size=max_batch_size)
+    encoder = MultimodalEncoder(model=model_dir,
+                                max_batch_size=max_batch_size,
+                                **_get_fake_checkpoint_kwargs(model_dir))
     with encoder:
         encoder_outputs = encoder.generate(inputs)
 
@@ -393,13 +476,15 @@ def test_multi_request_batch_chat(
     embeddings alongside the prompt ("multi_modal_embeddings"), as well as the embedding
     handling within default_multimodal_input_loader.
     """
-    if use_mm_embeddings and model_dir in [_QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR]:
+    if use_mm_embeddings and (model_dir in [_QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR]
+                              or _is_fake_checkpoint(model_dir)):
         pytest.skip("Qwen does not implement attach_multimodal_embeddings")
 
     # Qwen2.5/3 VL's vision encoder seems to output different embeddings based on this value.
     # The test only passes with this set to 1.
-    encoder_max_batch_size = (1 if model_dir
-                              in [_QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR] else 3)
+    encoder_max_batch_size = (1 if
+                              model_dir in [_QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR]
+                              or _is_fake_checkpoint(model_dir) else 3)
 
     llm, llm_decode = llms
     if llm_decode is not None:
@@ -430,7 +515,8 @@
     ) > 0, f"Reference generation has no output text for input {i}"
 
     encoder = MultimodalEncoder(model=model_dir,
-                                max_batch_size=encoder_max_batch_size)
+                                max_batch_size=encoder_max_batch_size,
+                                **_get_fake_checkpoint_kwargs(model_dir))
     with encoder:
         # Encoder path
         encoder_outputs = encoder.generate(inputs)