[Model] Support Hy3 preview (#40681)

Signed-off-by: stevenkuang <stevenkuang@tencent.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
stevenkuang
2026-04-23 22:08:26 +08:00
committed by GitHub
parent 424033f4fc
commit d0009ddb0b
16 changed files with 2696 additions and 0 deletions
+1
View File
@@ -419,6 +419,7 @@ th {
| `Grok1ForCausalLM` | Grok2 | `xai-org/grok-2` | ✅︎ | ✅︎ |
| `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ |
| `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ |
| `HYV3ForCausalLM` | HY3 | `tencent/Hy3-preview-Base`, `tencent/Hy3-preview` | ✅︎ | ✅︎ |
| `HyperCLOVAXForCausalLM` | HyperCLOVAX-SEED-Think-14B | `naver-hyperclovax/HyperCLOVAX-SEED-Think-14B` | ✅︎ | ✅︎ |
| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ |
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |
+5
View File
@@ -324,6 +324,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"HunYuanMoEV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-A13B-Instruct", trust_remote_code=True
),
"HYV3ForCausalLM": _HfExamplesInfo("tencent/Hy3-preview", trust_remote_code=True),
"HyperCLOVAXForCausalLM": _HfExamplesInfo(
"naver-hyperclovax/HyperCLOVAX-SEED-Think-14B",
trust_remote_code=True,
@@ -1516,6 +1517,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
is_available_online=False,
min_transformers_version="5.1.0",
),
"HYV3MTPModel": _HfExamplesInfo(
"tencent/Hy3-preview",
speculative_model="tencent/Hy3-preview",
),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
trust_remote_code=True,
@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.tokenizers import get_tokenizer
parser_name = "hy_v3"
MODEL = "tencent/Hy3-preview"
@pytest.fixture(scope="module")
def hy_v3_tokenizer():
return get_tokenizer(tokenizer_name=MODEL)
WITH_THINK = {
"output": "This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
"reasoning_effort": "high",
}
WITH_THINK_STREAM = {
"output": "This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
"reasoning_effort": "high",
}
WITHOUT_THINK = {
"output": "This is the rest",
"reasoning": None,
"content": "This is the rest",
"is_reasoning_end": True,
"reasoning_effort": "no_think",
}
WITHOUT_THINK_STREAM = {
"output": "This is the rest",
"reasoning": None,
"content": "This is the rest",
"is_reasoning_end": True,
"reasoning_effort": "no_think",
}
WITH_REASONING_EFFORT_NONE = {
"output": "This is the rest",
"reasoning": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
WITH_REASONING_EFFORT_NONE_STREAM = {
"output": "This is the rest",
"reasoning": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
COMPLETE_REASONING = {
"output": "This is a reasoning section</think>",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
"reasoning_effort": "high",
}
MULTILINE_REASONING = {
"output": "This is a reasoning\nsection</think>This is the rest\nThat",
"reasoning": "This is a reasoning\nsection",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
"reasoning_effort": "high",
}
ONLY_OPEN_TAG = {
"output": "This is a reasoning section",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
"reasoning_effort": "high",
}
ONLY_OPEN_TAG_STREAM = {
"output": "This is a reasoning section",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
"reasoning_effort": "high",
}
TEST_CASES = [
pytest.param(
False,
WITH_THINK,
id="with_think",
),
pytest.param(
True,
WITH_THINK_STREAM,
id="with_think_stream",
),
pytest.param(
False,
WITHOUT_THINK,
id="without_think",
),
pytest.param(
True,
WITHOUT_THINK_STREAM,
id="without_think_stream",
),
pytest.param(
False,
WITH_REASONING_EFFORT_NONE,
id="with_reasoning_effort_none",
),
pytest.param(
True,
WITH_REASONING_EFFORT_NONE_STREAM,
id="with_reasoning_effort_none_stream",
),
pytest.param(
False,
COMPLETE_REASONING,
id="complete_reasoning",
),
pytest.param(
True,
COMPLETE_REASONING,
id="complete_reasoning_stream",
),
pytest.param(
False,
MULTILINE_REASONING,
id="multiline_reasoning",
),
pytest.param(
True,
MULTILINE_REASONING,
id="multiline_reasoning_stream",
),
pytest.param(
False,
ONLY_OPEN_TAG,
id="only_open_tag",
),
pytest.param(
True,
ONLY_OPEN_TAG_STREAM,
id="only_open_tag_stream",
),
]
STILL_REASONING_PROMPT = """<hy_begin▁of▁sentence>
You are a helpful assistant.
<reasoning_mode>reasoning_effort:high<hy_User>
What is the capital of France?<hy_Assistant>
<think>The user is asking for the capital of"""
DONE_REASONING_PROMPT = """<hy_begin▁of▁sentence>
You are a helpful assistant.
<reasoning_mode>reasoning_effort:high<hy_User>
What is the capital of France?<hy_Assistant>
<think>The user is asking for the capital of France.</think>
The capital of France is Paris."""
MULTI_TURN_STILL_REASONING_PROMPT = """<hy_begin▁of▁sentence>
You are a helpful assistant.
<reasoning_mode>reasoning_effort:high<hy_User>
What is the capital of France?<hy_Assistant
><think></think>The capital of France is Paris.<eos:6124c78e>
<hy_User>What about Chile?<hy_Assistant>
<think>The user is asking for the capital of"""
MULTI_TURN_DONE_REASONING_PROMPT = """<hy_begin▁of▁sentence>
You are a helpful assistant.
<reasoning_mode>reasoning_effort:high<hy_User>
What is the capital of France?<hy_Assistant
><think></think>The capital of France is Paris.<eos:6124c78e>
<hy_User>What about Chile?<hy_Assistant>
<think>The user is asking for the capital of Chile.</think>
The capital of Chile is Santiago."""
REASONING_END_TEST_CASES = [
pytest.param(STILL_REASONING_PROMPT, False, id="still_reasoning"),
pytest.param(DONE_REASONING_PROMPT, True, id="done_reasoning"),
pytest.param(
MULTI_TURN_STILL_REASONING_PROMPT, False, id="multi_turn_still_reasoning"
),
pytest.param(
MULTI_TURN_DONE_REASONING_PROMPT, True, id="multi_turn_done_reasoning"
),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
hy_v3_tokenizer,
):
output = hy_v3_tokenizer.tokenize(param_dict["output"])
output_tokens: list[str] = [
hy_v3_tokenizer.convert_tokens_to_string([token]) for token in output
]
parser_kwargs = {}
if "reasoning_effort" in param_dict:
parser_kwargs["chat_template_kwargs"] = {
"reasoning_effort": param_dict["reasoning_effort"]
}
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
hy_v3_tokenizer,
**parser_kwargs,
)
reasoning, content = run_reasoning_extraction(
parser, output_tokens, streaming=streaming
)
assert reasoning == param_dict["reasoning"]
assert content == param_dict["content"]
output_ids = hy_v3_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]
@pytest.mark.parametrize("prompt, is_reasoning_end", REASONING_END_TEST_CASES)
def test_is_reasoning_end_full_prompt(
prompt: str, is_reasoning_end: bool, hy_v3_tokenizer
):
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
hy_v3_tokenizer,
chat_template_kwargs={"reasoning_effort": "high"},
)
tokens = hy_v3_tokenizer.tokenize(prompt)
token_ids = hy_v3_tokenizer.convert_tokens_to_ids(tokens)
check_is_reasoning_end = parser.is_reasoning_end(token_ids)
assert check_is_reasoning_end == is_reasoning_end
@@ -0,0 +1,274 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""Tests for the HYV3 tool call parser."""
import json
from unittest.mock import Mock
import pytest
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
FunctionDefinition,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.hy_v3_tool_parser import HYV3ToolParser
parser_name = "hy_v3"
MODEL = "tencent/Hy3-preview"
@pytest.fixture(scope="module")
def hy_v3_tokenizer():
return get_tokenizer(tokenizer_name=MODEL)
@pytest.fixture
def hy_v3_tool_parser(hy_v3_tokenizer):
return HYV3ToolParser(hy_v3_tokenizer)
@pytest.fixture
def mock_request() -> ChatCompletionRequest:
request = Mock(spec=ChatCompletionRequest)
request.tools = [
ChatCompletionToolsParam(
function=FunctionDefinition(name="get_current_date", parameters={}),
),
ChatCompletionToolsParam(
function=FunctionDefinition(
name="get_weather",
parameters={
"type": "object",
"properties": {
"city": {"type": "string"},
"date": {"type": "string"},
},
},
),
),
]
request.tool_choice = "auto"
return request
class TestHYV3ExtractToolCalls:
def test_no_tool_call(self, hy_v3_tool_parser, mock_request):
out = "This is a plain response."
r = hy_v3_tool_parser.extract_tool_calls(out, request=mock_request)
assert not r.tools_called
assert r.content == out
def test_zero_arg_inline(self, hy_v3_tool_parser, mock_request):
out = (
"<tool_calls><tool_call>get_current_date<tool_sep></tool_call></tool_calls>"
)
r = hy_v3_tool_parser.extract_tool_calls(out, request=mock_request)
assert r.tools_called
assert r.tool_calls[0].function.name == "get_current_date"
assert json.loads(r.tool_calls[0].function.arguments) == {}
assert r.content is None
def test_zero_arg_newline(self, hy_v3_tool_parser, mock_request):
out = "<tool_calls>\n<tool_call>get_current_date<tool_sep>\n</tool_call>\n</tool_calls>"
r = hy_v3_tool_parser.extract_tool_calls(out, request=mock_request)
assert r.tools_called
assert r.tool_calls[0].function.name == "get_current_date"
def test_args_same_line(self, hy_v3_tool_parser, mock_request):
out = (
"<tool_calls><tool_call>get_weather<tool_sep><arg_key>city</arg_key><arg_value>Beijing"
"</arg_value><arg_key>date</arg_key><arg_value>2026-03-30</arg_value></tool_call></tool_calls>"
)
r = hy_v3_tool_parser.extract_tool_calls(out, request=mock_request)
assert r.tools_called
assert json.loads(r.tool_calls[0].function.arguments) == {
"city": "Beijing",
"date": "2026-03-30",
}
def test_args_with_newlines(self, hy_v3_tool_parser, mock_request):
out = (
"<tool_calls>\n<tool_call>get_weather<tool_sep>\n<arg_key>city</arg_key>\n<arg_value>Beijing"
"</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2026-03-30</arg_value>\n</tool_call>\n</tool_calls>"
)
r = hy_v3_tool_parser.extract_tool_calls(out, request=mock_request)
assert r.tools_called
assert json.loads(r.tool_calls[0].function.arguments) == {
"city": "Beijing",
"date": "2026-03-30",
}
def test_content_before(self, hy_v3_tool_parser, mock_request):
out = "Checking.<tool_calls>\n<tool_call>get_current_date<tool_sep>\n</tool_call>\n</tool_calls>"
r = hy_v3_tool_parser.extract_tool_calls(out, request=mock_request)
assert r.tools_called
assert r.content == "Checking."
def test_multiple(self, hy_v3_tool_parser, mock_request):
out = (
"<tool_calls>\n<tool_call>get_weather<tool_sep>\n<arg_key>city</arg_key>\n<arg_value>Beijing"
"</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2026-03-30</arg_value>\n</tool_call>\n"
"<tool_call>get_weather<tool_sep>\n<arg_key>city</arg_key>\n<arg_value>Hangzhou</arg_value>\n"
"<arg_key>date</arg_key>\n<arg_value>2026-03-30</arg_value>\n</tool_call>\n</tool_calls>"
)
r = hy_v3_tool_parser.extract_tool_calls(out, request=mock_request)
assert len(r.tool_calls) == 2
def test_empty_content_none(self, hy_v3_tool_parser, mock_request):
out = "<tool_calls>\n<tool_call>get_current_date<tool_sep>\n</tool_call>\n</tool_calls>"
r = hy_v3_tool_parser.extract_tool_calls(out, request=mock_request)
assert r.content is None
def _simulate_streaming(
parser: HYV3ToolParser,
deltas: list[str],
request: ChatCompletionRequest,
) -> list[DeltaMessage | None]:
results: list[DeltaMessage | None] = []
previous_text = ""
previous_token_ids: list[int] = []
vocab = parser.vocab
for delta_text in deltas:
current_text = previous_text + delta_text
delta_token_ids = [tid for tok, tid in vocab.items() if tok in delta_text]
current_token_ids = previous_token_ids + delta_token_ids
result = parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request,
)
results.append(result)
previous_text = current_text
previous_token_ids = current_token_ids
return results
def _collect_streaming_tool_calls(results: list[DeltaMessage | None]) -> list[dict]:
tool_calls: dict[int, dict] = {}
for result in results:
if result is None or not result.tool_calls:
continue
for tc in result.tool_calls:
idx = tc.index
if idx not in tool_calls:
tool_calls[idx] = {
"name": tc.function.name or "",
"arguments": tc.function.arguments or "",
}
else:
if tc.function.name:
tool_calls[idx]["name"] += tc.function.name
if tc.function.arguments:
tool_calls[idx]["arguments"] += tc.function.arguments
return [tool_calls[i] for i in sorted(tool_calls.keys())]
def _collect_streaming_content(results: list[DeltaMessage | None]) -> str:
parts = []
for result in results:
if result is not None and result.content:
parts.append(result.content)
return "".join(parts)
class TestHYV3ExtractToolCallsStreaming:
def test_no_tool_call_streaming(self, hy_v3_tool_parser, mock_request):
deltas = ["This is ", "a plain ", "response."]
results = _simulate_streaming(hy_v3_tool_parser, deltas, mock_request)
content = _collect_streaming_content(results)
assert content == "This is a plain response."
assert len(_collect_streaming_tool_calls(results)) == 0
def test_zero_arg_streaming(self, hy_v3_tool_parser, mock_request):
deltas = [
"<tool_calls>",
"\n<tool_call>",
"get_current_date",
"<tool_sep>",
"\n</tool_call>",
"\n</tool_calls>",
]
results = _simulate_streaming(hy_v3_tool_parser, deltas, mock_request)
tc = _collect_streaming_tool_calls(results)
assert len(tc) == 1
assert tc[0]["name"] == "get_current_date"
assert json.loads(tc[0]["arguments"]) == {}
def test_args_streaming(self, hy_v3_tool_parser, mock_request):
deltas = [
"<tool_calls>",
"\n<tool_call>",
"get_weather",
"<tool_sep>",
"\n<arg_key>city</arg_key>",
"\n<arg_value>Beijing</arg_value>",
"\n<arg_key>date</arg_key>",
"\n<arg_value>2026-03-30</arg_value>",
"\n</tool_call>",
"\n</tool_calls>",
]
results = _simulate_streaming(hy_v3_tool_parser, deltas, mock_request)
tc = _collect_streaming_tool_calls(results)
assert len(tc) == 1 and tc[0]["name"] == "get_weather"
assert json.loads(tc[0]["arguments"]) == {
"city": "Beijing",
"date": "2026-03-30",
}
def test_content_before_streaming(self, hy_v3_tool_parser, mock_request):
deltas = [
"Checking.",
"<tool_calls>",
"\n<tool_call>",
"get_current_date",
"<tool_sep>",
"\n</tool_call>",
"\n</tool_calls>",
]
results = _simulate_streaming(hy_v3_tool_parser, deltas, mock_request)
assert "Checking." in _collect_streaming_content(results)
tc = _collect_streaming_tool_calls(results)
assert len(tc) == 1 and tc[0]["name"] == "get_current_date"
def test_multiple_streaming(self, hy_v3_tool_parser, mock_request):
deltas = [
"<tool_calls>",
"\n<tool_call>",
"get_weather",
"<tool_sep>",
"\n<arg_key>city</arg_key>",
"\n<arg_value>Beijing</arg_value>",
"\n<arg_key>date</arg_key>",
"\n<arg_value>2026-03-30</arg_value>",
"\n</tool_call>",
"\n<tool_call>",
"get_weather",
"<tool_sep>",
"\n<arg_key>city</arg_key>",
"\n<arg_value>Hangzhou</arg_value>",
"\n<arg_key>date</arg_key>",
"\n<arg_value>2026-03-30</arg_value>",
"\n</tool_call>",
"\n</tool_calls>",
]
results = _simulate_streaming(hy_v3_tool_parser, deltas, mock_request)
tc = _collect_streaming_tool_calls(results)
assert len(tc) == 2
assert json.loads(tc[0]["arguments"])["city"] == "Beijing"
assert json.loads(tc[1]["arguments"])["city"] == "Hangzhou"
def test_all_in_one_delta_streaming(self, hy_v3_tool_parser, mock_request):
out = "<tool_calls>\n<tool_call>get_current_date<tool_sep>\n</tool_call>\n</tool_calls>"
results = _simulate_streaming(hy_v3_tool_parser, [out], mock_request)
tc = _collect_streaming_tool_calls(results)
assert len(tc) == 1 and tc[0]["name"] == "get_current_date"
assert json.loads(tc[0]["arguments"]) == {}
+8
View File
@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
"mtp",
"pangu_ultra_moe_mtp",
"step3p5_mtp",
"hy_v3_mtp",
]
NgramGPUTypes = Literal["ngram_gpu"]
DFlashModelTypes = Literal["dflash"]
@@ -364,6 +365,13 @@ class SpeculativeConfig:
if initial_architecture == "MistralLarge3ForCausalLM":
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
if hf_config.model_type == "hy_v3":
hf_config.model_type = "hy_v3_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update(
{"n_predict": n_predict, "architectures": ["HYV3MTPModel"]}
)
return hf_config
def __post_init__(self):
@@ -1562,6 +1562,11 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
# NemotronH format: .mixer.{k,v}_proj.{k,v}_scale ->
# .mixer.attn.{k,v}_scale
(r"\.mixer\.[kv]_proj\.([kv])_scale$", r".mixer.attn.\1_scale"),
# HYV3 format: .self_attn.q.scale -> .self_attn.attn.q_scale
(r"\.self_attn\.q\.scale$", r".self_attn.attn.q_scale"),
# HYV3 format: .self_attn.{k,v}_cache.scale ->
# .self_attn.attn.{k,v}_scale
(r"\.self_attn\.([kv])_cache\.scale$", r".self_attn.attn.\1_scale"),
# Default format: .{k,v}_scale -> .attn.{k,v}_scale
(r"\.([qkv])_scale$", r".attn.\1_scale"),
(r"\.([qkv])_zero_point$", r".attn.\1_zero_point"),
@@ -1576,6 +1581,9 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
".k_zero_point",
".v_zero_point",
".q_zero_point",
".q.scale",
".k_cache.scale",
".v_cache.scale",
)
):
import regex as re
+707
View File
@@ -0,0 +1,707 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# coding=utf-8
# Copyright 2026 The HY team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HY model compatible with HuggingFace weights."""
import typing
from collections.abc import Callable, Iterable
from itertools import islice
from typing import Any
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_ep_group,
get_pp_group,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.hy_v3 import HYV3Config
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
class HYV3FeedForward(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
expert_gate: torch.nn.Linear | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
out = self.act_fn(gate_up)
out, _ = self.down_proj(out)
return out
class HYV3MoEFused(nn.Module):
def __init__(
self,
config: HYV3Config,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}."
)
top_k = config.num_experts_per_tok
intermediate_size = config.expert_hidden_dim
router_scaling_factor = getattr(config, "router_scaling_factor", 1.0)
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
self.physical_expert_end = (
self.physical_expert_start + self.n_local_physical_experts
)
self.gate = GateLinear(
config.hidden_size,
config.num_experts,
bias=False,
out_dtype=torch.float32,
params_dtype=torch.float32,
prefix=f"{prefix}.gate",
)
if config.num_shared_experts > 0:
self.shared_mlp = HYV3FeedForward(
hidden_size=config.hidden_size,
intermediate_size=config.expert_hidden_dim * config.num_shared_experts,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}",
reduce_results=False,
)
else:
self.shared_mlp = None
self.expert_bias = nn.Parameter(torch.empty(config.num_experts))
scoring_func = "sigmoid"
e_score_correction_bias = self.expert_bias
self.experts = FusedMoE(
num_experts=self.n_routed_experts,
top_k=top_k,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
renormalize=config.route_norm,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
scoring_func=scoring_func,
use_grouped_topk=True,
num_expert_group=1,
topk_group=1,
routed_scaling_factor=router_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
n_shared_experts=config.num_shared_experts,
shared_experts=self.shared_mlp,
)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
return final_hidden_states.view(orig_shape)
class HYV3Attention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_parameters: dict[str, Any],
max_position_embeddings: int = 8192,
head_dim: int | None = None,
rms_norm_eps: float = 1e-5,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
dual_chunk_attention_config: dict[str, Any] | None = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
if hasattr(config, "head_dim") and config.head_dim:
self.head_dim = config.head_dim
else:
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.use_qk_norm = getattr(config, "qk_norm", False)
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
quant_config=quant_config,
bias=None,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
if self.use_qk_norm:
self.q_norm = RMSNorm(self.head_dim, rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
output_shape = None
if self.use_qk_norm:
q_by_head = q.view(
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, output_shape)
attn_output = attn_output.view(q.shape[0], -1)
output, _ = self.o_proj(attn_output)
return output
class HYV3DecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
layer_idx = int(prefix.split(".")[-1])
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = HYV3Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_parameters=config.rope_parameters,
max_position_embeddings=max_position_embeddings,
head_dim=config.head_dim,
rms_norm_eps=config.rms_norm_eps,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
if not hasattr(config, "first_k_dense_replace"):
raise ValueError("first_k_dense_replace not exist,please check config")
if layer_idx < config.first_k_dense_replace:
self.mlp = HYV3FeedForward(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.block_type = "feedforward"
else:
self.mlp = HYV3MoEFused(
config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
)
self.block_type = "moe"
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
idx: int = -1,
) -> torch.Tensor:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class HYV3Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.vocab_size = config.vocab_size
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: HYV3DecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
# Set MoE hyperparameters
self.expert_weights = []
self.num_expert_groups = 1
self.moe_layers = []
example_layer = None
for layer in self.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, HYV3DecoderLayer)
if layer.block_type == "moe":
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
self.num_moe_layers = 0
raise RuntimeError("No MoE layer found in model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_logical_experts = getattr(example_layer, "n_logical_experts", None)
self.num_physical_experts = getattr(example_layer, "n_physical_experts", None)
self.num_local_physical_experts = getattr(
example_layer, "n_local_physical_experts", None
)
self.num_routed_experts = getattr(example_layer, "n_routed_experts", None)
self.num_redundant_experts = getattr(example_layer, "n_redundant_experts", None)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.layers:
if isinstance(layer.mlp, HYV3MoEFused):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
hidden_states, residual = layer(positions, hidden_states, residual, idx=idx)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
expert_params_mapping = self.get_expert_mapping()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
is_found = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
is_found = True
break
if is_found:
continue
if name.endswith(".bias") and name not in params_dict:
continue
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
success = weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
if "router.gate." in name:
name = name.replace("router.", "")
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def get_spec_layer_idx_from_weight_name(
config: PretrainedConfig, weight_name: str
) -> int | None:
# HYV3MTP is enabled only when num_nextn_predict_layers is greater than 1
if (
hasattr(config, "num_nextn_predict_layers")
and config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i
return None
class HYV3ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
parallel_config = vllm_config.parallel_config
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.model = HYV3Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
def _filter_weights(weights):
for name, weight in weights:
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue
yield name, weight
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(_filter_weights(weights))
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
+470
View File
@@ -0,0 +1,470 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# coding=utf-8
# Copyright 2026 The HY team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HY V3 MTP model compatible with HuggingFace weights."""
from collections.abc import Iterable
import regex as re
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
from .hy_v3 import HYV3DecoderLayer, get_spec_layer_idx_from_weight_name
from .utils import is_pp_missing_parameter, maybe_prefix
def _is_moe(config: PretrainedConfig) -> bool:
return bool(
getattr(config, "num_experts", None)
and (
(isinstance(config.num_experts, int) and config.num_experts > 1)
or (isinstance(config.num_experts, list) and max(config.num_experts) > 1)
)
)
def _get_cla_factor(config: PretrainedConfig) -> int:
if not getattr(config, "use_cla", False):
return 1
return getattr(config, "cla_share_factor", 1)
class HYV3SharedHead(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
) -> None:
super().__init__()
self.head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states
class HYV3MultiTokenPredictorLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
) -> None:
super().__init__()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.shared_head = HYV3SharedHead(config=config, quant_config=quant_config)
self.mtp_block = HYV3DecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
)
# Final layernorm applied after transformer block, before logits
# projection (matches HF HYV3MTPDecoderLayer.final_layernorm)
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
)
# HYV3DecoderLayer returns (hidden_states, residual)
hidden_states, residual = self.mtp_block(
positions=positions, hidden_states=hidden_states, residual=None
)
hidden_states = residual + hidden_states
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class HYV3MultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict(
{
str(idx): HYV3MultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
for idx in range(
self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers,
)
}
)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
positions,
previous_hidden_states,
inputs_embeds,
current_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor:
current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor(
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
)
return logits
class HYV3MTP(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.model = HYV3MultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(
input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> torch.Tensor | None:
return self.model.compute_logits(hidden_states, spec_step_idx)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput | None:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def _split_qkv_weight(self, qkv: torch.Tensor):
num_attention_heads = self.config.num_attention_heads
num_kv_heads = getattr(
self.config, "num_key_value_heads", self.config.num_attention_heads
)
num_key_value_groups = num_attention_heads // num_kv_heads
hidden_size = self.config.hidden_size
if hasattr(self.config, "head_dim"):
attention_head_dim = self.config.head_dim
elif hasattr(self.config, "attention_head_dim"):
attention_head_dim = self.config.attention_head_dim
else:
attention_head_dim = self.config.hidden_size // num_attention_heads
qkv = qkv.reshape(
num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size
)
q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1)
q = q.reshape(-1, hidden_size)
k = k.reshape(-1, hidden_size)
v = v.reshape(-1, hidden_size)
return torch.concat((q, k, v))
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
cla_factor = _get_cla_factor(self.config)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
num_attention_heads = self.config.num_attention_heads
num_kv_heads = getattr(
self.config, "num_key_value_heads", self.config.num_attention_heads
)
split_params_mapping = [
(".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
(
".qkv_proj",
".qkv_proj",
num_attention_heads + num_kv_heads * 2,
[("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
self._split_qkv_weight,
),
]
if _is_moe(self.config):
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
else:
expert_params_mapping = {}
params_dict = dict(self.named_parameters())
# V3 shared weights mapping:
# - embed_tokens: from main model's model.embed_tokens.weight
# - lm_head: from main model's lm_head.weight → MTP shared_head.head
# (HF infer_mtp uses head_weight=self.lm_head.weight, not the
# checkpoint's model.layers.<N>.shared_head.weight)
# - No norm mapping (V3 MTP has no intermediate norm before lm_head)
mtp_start = self.config.num_hidden_layers
v3_shared_weights = {
"model.embed_tokens.weight": "model.embed_tokens.weight",
"lm_head.weight": f"model.layers.{mtp_start}.shared_head.head.weight",
}
for name, loaded_weight in weights:
# Intercept shared weights before any other processing
if name in v3_shared_weights:
target_name = v3_shared_weights[name]
if target_name in params_dict:
param = params_dict[target_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
continue
if "rotary_emb.inv_freq" in name:
continue
if "gate_proj_bias" in name:
name = name.replace("gate_proj_bias", "gate_proj.bias")
if "up_proj_bias" in name:
name = name.replace("up_proj_bias", "up_proj.bias")
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
name = self._rewrite_spec_layer_name(spec_layer, name)
# Skip weights that _rewrite_spec_layer_name marked for skipping
if name == "__skip__":
continue
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
is_found = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
if weight_name == ".q_proj":
match = re.search(r"layers\.\d+", name)
if match:
layer_id = int(match.group(0).split(".")[-1])
if cla_factor > 1 and layer_id % cla_factor != 0:
continue
name = name.replace(weight_name, param_name)
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
is_found = True
break
if is_found:
continue
for param_name, weight_name, den, split_param, func in split_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
assert loaded_weight.shape[0] % den == 0
units = loaded_weight.shape[0] // den
param = params_dict[name]
weight_loader = param.weight_loader
offset = 0
for shard_id, num in split_param:
new_offset = offset + num * units
if func:
weight_loader(
param, func(loaded_weight)[offset:new_offset], shard_id
)
else:
weight_loader(param, loaded_weight[offset:new_offset], shard_id)
offset = new_offset
break
else:
if name.endswith(".bias") and name not in params_dict:
continue
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
if is_pp_missing_parameter(name, self):
continue
if "mlp.gate.wg." in name:
name = name.replace("wg.", "")
# V3 checkpoint: mlp.router.gate -> mlp.gate
if "mlp.router.gate." in name:
name = name.replace("router.gate.", "gate.")
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""Rewrite spec layer weight names to match vLLM module structure."""
# Skip embed_tokens (doesn't exist in V3 MTP checkpoint under spec
# layer) and shared_head (we use main model's lm_head instead)
if f"model.layers.{spec_layer}.embed_tokens" in name:
return "__skip__"
if f"model.layers.{spec_layer}.shared_head" in name:
return "__skip__"
spec_layer_weight_names = ["enorm", "hnorm", "eh_proj", "final_layernorm"]
spec_layer_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
break
if not spec_layer_weight:
# Transformer block weights go under .mtp_block
name = name.replace(
f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
)
return name
+2
View File
@@ -133,6 +133,7 @@ _TEXT_GENERATION_MODELS = {
"Grok1ForCausalLM": ("grok1", "GrokForCausalLM"),
"HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
"HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
"HYV3ForCausalLM": ("hy_v3", "HYV3ForCausalLM"),
"HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
"HCXVisionV2ForCausalLM": ("hyperclovax_vision_v2", "HCXVisionV2ForCausalLM"),
"HyperCLOVAXForCausalLM": ("hyperclovax", "HyperCLOVAXForCausalLM"),
@@ -599,6 +600,7 @@ _SPECULATIVE_DECODING_MODELS = {
"Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
"Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"),
"Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"),
"HYV3MTPModel": ("hy_v3_mtp", "HYV3MTP"),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
+4
View File
@@ -56,6 +56,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"hunyuan_a13b_reasoning_parser",
"HunyuanA13BReasoningParser",
),
"hy_v3": (
"hy_v3_reasoning_parser",
"HYV3ReasoningParser",
),
"kimi_k2": (
"kimi_k2_reasoning_parser",
"KimiK2ReasoningParser",
+137
View File
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class HYV3ReasoningParser(BaseThinkingReasoningParser):
"""
HYV3 parser that delegates to either HYV3ReasoningParser or
IdentityReasoningParser based on `reasoning_effort`.
The HYV3 model uses <think>...</think> tokens to denote reasoning text.
This parser extracts the reasoning content from the model output.
"""
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
# First, If there is reasoning_effort in chat_kwargs,
# prioritize using chat_kwargs.reasoning_effort.
# If it's not present, use the "reasoning_effort" field
# at the outer level of the chat message.
# Otherwise, If both are empty, assign "no_think".
chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
reasoning_effort = chat_kwargs.pop("reasoning_effort", "no_think")
logger.debug("reasoning_effort for choosing parser: %s", reasoning_effort)
self._identity_parser: IdentityReasoningParser | None
if reasoning_effort == "no_think":
self._identity_parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
else:
self._identity_parser = None
@property
def start_token(self) -> str:
"""The token that starts reasoning content."""
return "<think>"
@property
def end_token(self) -> str:
"""The token that ends reasoning content."""
return "</think>"
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
if self._identity_parser is not None:
return self._identity_parser.is_reasoning_end(input_ids)
return super().is_reasoning_end(input_ids)
def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool:
if self._identity_parser is not None:
return self._identity_parser.is_reasoning_end_streaming(
input_ids, delta_ids
)
return super().is_reasoning_end_streaming(input_ids, delta_ids)
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
if self._identity_parser is not None:
return self._identity_parser.extract_content_ids(input_ids)
return super().extract_content_ids(input_ids)
def extract_reasoning(
self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]:
if self._identity_parser is not None:
return self._identity_parser.extract_reasoning(model_output, request)
return super().extract_reasoning(model_output, request)
def extract_reasoning_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
if self._identity_parser is not None:
return self._identity_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
)
ret = super().extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
)
if (
ret is not None
and self.start_token_id not in previous_token_ids
and self.start_token_id not in delta_token_ids
):
if self.end_token_id in delta_token_ids:
# end token in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.end_token)
reasoning = delta_text[:end_index]
content = delta_text[end_index + len(self.end_token) :]
return DeltaMessage(
reasoning=reasoning,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# end token in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
# no end token in previous or delta, reasoning content continues
return DeltaMessage(reasoning=delta_text)
return ret
+4
View File
@@ -66,6 +66,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"hunyuan_a13b_tool_parser",
"HunyuanA13BToolParser",
),
"hy_v3": (
"hy_v3_tool_parser",
"HYV3ToolParser",
),
"internlm": (
"internlm2_tool_parser",
"Internlm2ToolParser",
+645
View File
@@ -0,0 +1,645 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
logger = init_logger(__name__)
class HYV3ToolParser(ToolParser):
_TYPE_ALIASES: dict[str, str] = {
"str": "string",
"text": "string",
"varchar": "string",
"char": "string",
"enum": "string",
"bool": "boolean",
"binary": "boolean",
"int": "integer",
"float": "number",
"double": "number",
"list": "array",
"dict": "object",
"map": "object",
}
# Prefix-based wildcard matching for non-standard type names.
# Following the same approach as
# qwen3coder_tool_parser._convert_param_value which uses
# param_type.startswith("int"), startswith("uint"), etc.
_INTEGER_PREFIXES: tuple[str, ...] = (
"int",
"uint",
"long",
"short",
"unsigned",
)
_NUMBER_PREFIXES: tuple[str, ...] = ("num", "float")
@staticmethod
def _normalize_type(raw_type: str) -> str:
"""Map non-standard type aliases to JSON Schema standard names.
First performs exact lookup in _TYPE_ALIASES. On miss, falls back
to prefix-based matching using startswith()
- int*/uint*/long*/short*/unsigned* → "integer"
- num*/float* → "number"
"""
exact = HYV3ToolParser._TYPE_ALIASES.get(raw_type)
if exact is not None:
return exact
lower = raw_type.lower()
if any(lower.startswith(p) for p in HYV3ToolParser._INTEGER_PREFIXES):
return "integer"
if any(lower.startswith(p) for p in HYV3ToolParser._NUMBER_PREFIXES):
return "number"
return raw_type
@staticmethod
def _get_arg_schema(
function_name: str,
arg_key: str,
tools: list[ChatCompletionToolsParam] | None,
) -> dict:
"""Look up a specific argument's property schema from the tools list."""
if tools is None:
return {}
for tool in tools:
if tool.function.name == function_name:
if tool.function.parameters is None:
return {}
return tool.function.parameters.get("properties", {}).get(arg_key, {})
logger.warning("No tool named '%s'.", function_name)
return {}
@staticmethod
def _get_schema_options(arg_schema: dict) -> list[dict]:
"""Normalize any property schema into a list of sub-schemas.
- has type (single type) → return [arg_schema]
- anyOf → return the anyOf list
- oneOf → return the oneOf list
- fallback → [{"type": "string"}]
Note: single ``type`` has the highest priority.
"""
if "type" in arg_schema:
return [arg_schema]
if "anyOf" in arg_schema:
return arg_schema["anyOf"]
if "oneOf" in arg_schema:
return arg_schema["oneOf"]
return [{"type": "string"}]
@staticmethod
def _get_types(arg_schema: dict) -> set[str]:
"""Extract normalized, non-null type set from a property schema."""
schemas = HYV3ToolParser._get_schema_options(arg_schema)
return {
HYV3ToolParser._normalize_type(s.get("type", "string")) for s in schemas
} - {"null"}
@staticmethod
def _is_only_string_type(
function_name: str,
arg_key: str,
tools: list[ChatCompletionToolsParam] | None,
) -> bool:
"""Return True if the parameter's type set is exactly {"string"}.
Only pure string types get partial value streaming; compound types
like anyOf(string | array) do not, since the partial value might
end up being a JSON array or object.
"""
arg_schema = HYV3ToolParser._get_arg_schema(function_name, arg_key, tools)
types = HYV3ToolParser._get_types(arg_schema)
return types == {"string"}
@staticmethod
def _try_parse_bool(value: str) -> bool | None:
"""Try to parse a string as bool; return None on failure."""
lower = value.lower()
if lower == "true":
return True
elif lower == "false":
return False
return None
@staticmethod
def _try_parse_int(value: str) -> int | None:
"""Try to parse a string as int; return None on failure."""
try:
return int(value)
except (ValueError, TypeError):
return None
@staticmethod
def _try_parse_wildcard_number(value: str) -> int | float | None:
"""Try to parse a string as a number (int or float).
Decision rule: if the string contains '.' or 'e'/'E' (scientific
notation), parse as float; otherwise parse as int.
Examples:
"5" → int(5)
"5.0" → float(5.0)
"5.3" → float(5.3)
"1e3" → float(1000.0)
"-3" → int(-3)
Return None on failure.
"""
try:
if "." in value or "e" in value or "E" in value:
return float(value)
return int(value)
except (ValueError, TypeError):
return None
@staticmethod
def _deserialize(value: str) -> Any:
"""Deserialize a string value using json.loads then ast.literal_eval."""
try:
return json.loads(value)
except Exception:
pass
try:
return ast.literal_eval(value)
except Exception:
pass
return value
@staticmethod
def _parse_value(
value: str,
function_name: str,
arg_key: str,
tools: list[ChatCompletionToolsParam] | None,
) -> Any:
"""Unified argument value parser with anyOf/oneOf support.
Fallthrough chain:
bool → int → number(wildcard_number)
→ json.loads for array/object
→ string → _deserialize
"""
arg_schema = HYV3ToolParser._get_arg_schema(function_name, arg_key, tools)
types = HYV3ToolParser._get_types(arg_schema)
# 1. Try bool
if "boolean" in types:
result_bool = HYV3ToolParser._try_parse_bool(value)
if result_bool is not None:
return result_bool
# 2. Try int
if "integer" in types:
result_int = HYV3ToolParser._try_parse_int(value)
if result_int is not None:
return result_int
# 3. Try number (wildcard_number: int if no '.'/e/E, float otherwise)
if "number" in types:
result_number = HYV3ToolParser._try_parse_wildcard_number(value)
if result_number is not None:
return result_number
# 4. Try json.loads (covers array/object and other unlisted types)
if types - {"string", "boolean", "integer", "number"}:
try:
return json.loads(value)
except (json.JSONDecodeError, ValueError):
pass
# 5. String fallback
if "string" in types:
return value
# 6. Final fallback
return HYV3ToolParser._deserialize(value)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[
str
] = [] # map what has been streamed for each tool so far to a list
# Streaming state: send tool name first, then return arguments at once
self._streaming_tool_name: str | None = None # tool name being streamed
# State fields for incremental argument streaming
self._completed_args: dict = {} # closed {key: parsed_value}
self._current_arg_key: str | None = None # key being collected
self._current_arg_is_string: bool = False # is current arg pure string?
self._streamed_json_len: int = 0 # bytes of JSON already sent
self.tool_calls_start_token: str = "<tool_calls>"
self.tool_calls_end_token: str = "</tool_calls>"
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_sep_token: str = "<tool_sep>"
self.arg_key_start_token: str = "<arg_key>"
self.arg_key_end_token: str = "</arg_key>"
self.arg_value_start_token: str = "<arg_value>"
self.arg_value_end_token: str = "</arg_value>"
self.tool_call_regex = re.compile(
rf"{self.tool_call_start_token}(.*?){self.tool_sep_token}"
rf"(.*?){self.tool_call_end_token}",
re.DOTALL,
)
self.tool_call_portion_regex = re.compile(
rf"{self.tool_call_start_token}(.*?){self.tool_sep_token}(.*)", re.DOTALL
)
self.func_args_regex = re.compile(
rf"{self.arg_key_start_token}(.*?){self.arg_key_end_token}\s*"
rf"{self.arg_value_start_token}(.*?){self.arg_value_end_token}",
re.DOTALL,
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
self._buffer = ""
if (
self.tool_calls_start_token_id is None
or self.tool_calls_end_token_id is None
):
raise RuntimeError(
"HYV3 Tool parser could not locate tool call "
"start/end tokens in the tokenizer!"
)
def _extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> list[ToolCall]:
try:
function_call_tuples = []
# start_token{name}sep_token{args}end_token...
function_calls = self.tool_call_regex.findall(model_output)
if function_calls:
function_call_tuples.extend(function_calls)
remaining = model_output.split(self.tool_call_end_token)[-1]
function_calls = self.tool_call_portion_regex.findall(remaining)
function_call_tuples += function_calls
else:
function_calls = self.tool_call_portion_regex.findall(model_output)
if function_calls:
function_call_tuples.extend(function_calls)
tool_calls = []
for match in function_call_tuples:
function_name, function_args = match
function_name = function_name.strip()
function_args = function_args.strip()
arg_pairs = self.func_args_regex.findall(function_args)
arg_dict = {}
for key, value in arg_pairs:
parsed_value = HYV3ToolParser._parse_value(
value, function_name, key, request.tools
)
arg_dict[key] = parsed_value
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=function_name,
arguments=json.dumps(arg_dict, ensure_ascii=False),
),
)
)
return tool_calls
except Exception:
logger.exception("Error in extracting tool call from response.")
return []
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
try:
tool_calls = self._extract_tool_calls(model_output, request)
s_index = model_output.find(self.tool_calls_start_token)
content = model_output[:s_index] if s_index != -1 else model_output
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _reset_streaming_tool_state(self):
"""Reset the streaming state for a single tool call."""
self._streaming_tool_name = None
self._completed_args = {}
self._current_arg_key = None
self._current_arg_is_string = False
self._streamed_json_len = 0
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# Check whether current tokens contain the tool_calls start token
if self.tool_calls_start_token_id not in current_token_ids:
return DeltaMessage(content=delta_text)
# Encountered tool_calls start tag; extract preceding content and buffer
if self.tool_calls_start_token in delta_text:
text_parts = delta_text.split(self.tool_calls_start_token)
self._buffer += text_parts[-1]
if text_parts[0]:
return DeltaMessage(content=text_parts[0])
# Don't return None; continue processing buffer for complete content
else:
self._buffer += delta_text
# Encountered finish, extract valid arguments
if (
current_text.find(self.tool_call_end_token + self.tool_calls_end_token)
!= -1
and self._buffer.find(self.tool_call_end_token) == -1
):
self._buffer += self.tool_call_end_token + self.tool_calls_end_token
cur_text = self._buffer
# Haven't encountered tool_call start tag yet; keep buffering
start_idx = cur_text.find(self.tool_call_start_token)
if start_idx == -1 and self._streaming_tool_name is None:
self._buffer = ""
return None
# === Phase 1: Detect tool name (send when tool_sep_token is seen) ===
name_delta: DeltaMessage | None = None
if self._streaming_tool_name is None:
sep_idx = cur_text.find(self.tool_sep_token)
if sep_idx == -1:
# tool_sep not yet seen; keep buffering from tool_call_start
self._buffer = cur_text[start_idx:]
return None
# Extract tool name: between tool_call_start_token and tool_sep_token
name_start = start_idx + len(self.tool_call_start_token)
tool_name = cur_text[name_start:sep_idx].strip()
self._streaming_tool_name = tool_name
# Update buffer: keep only content after tool_sep (i.e. the args portion)
self._buffer = cur_text[sep_idx + len(self.tool_sep_token) :]
# Increment tool_id and send a chunk containing only the name
self.current_tool_id += 1
self._current_tool_call_id = make_tool_call_id()
name_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
id=self._current_tool_call_id,
type="function",
function=DeltaFunctionCall(
name=tool_name,
),
)
]
)
# Check if buffer already has complete arguments (all-in-one-delta)
if self.tool_call_end_token not in self._buffer:
return name_delta
# Buffer already has a complete tool call; continue to phase 2 below
# === Phase 2: Incremental argument streaming ===
return self._extract_streaming_incremental(name_delta, request)
def _make_args_delta(self, argument_diff: str) -> DeltaMessage:
"""Build a DeltaMessage containing only an arguments diff."""
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=argument_diff),
)
]
)
def _extract_streaming_incremental(
self,
name_delta: DeltaMessage | None,
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Incremental phase-2: scan tags in buffer, emit JSON diffs.
Strategy:
- Track completed args and emit each one as a JSON fragment.
- For string-typed args, stream the value character-by-character.
- Withhold the closing ``}`` until ``</tool_call>`` is seen.
We build JSON manually via fragments rather than using json.dumps
with a cursor, because json.dumps of partial-vs-full string values
produces incompatible prefixes (e.g. ``""}`` vs ``"Hello"}``).
"""
buf = self._buffer
is_complete = self.tool_call_end_token in buf
if is_complete:
end_idx = buf.find(self.tool_call_end_token)
args_text = buf[:end_idx]
remaining = buf[end_idx + len(self.tool_call_end_token) :]
else:
args_text = buf
remaining = ""
# --- scan all fully closed kv pairs ---
arg_pairs = self.func_args_regex.findall(args_text)
for key, value in arg_pairs:
key = key.strip()
if key not in self._completed_args:
parsed_value = HYV3ToolParser._parse_value(
value, self._streaming_tool_name or "", key, request.tools
)
self._completed_args[key] = parsed_value
# --- detect partial (unclosed) kv at the tail ---
last_closed_end = 0
for m in self.func_args_regex.finditer(args_text):
last_closed_end = m.end()
tail = args_text[last_closed_end:]
partial_key: str | None = None
partial_value: str | None = None
ak_start = tail.find(self.arg_key_start_token)
if ak_start != -1:
ak_end = tail.find(
self.arg_key_end_token,
ak_start + len(self.arg_key_start_token),
)
if ak_end != -1:
partial_key = tail[
ak_start + len(self.arg_key_start_token) : ak_end
].strip()
self._current_arg_key = partial_key
self._current_arg_is_string = HYV3ToolParser._is_only_string_type(
self._streaming_tool_name or "",
partial_key,
request.tools,
)
av_start = tail.find(self.arg_value_start_token, ak_end)
if av_start != -1:
val_content_start = av_start + len(self.arg_value_start_token)
if self._current_arg_is_string:
partial_value = tail[val_content_start:]
else:
# key not yet closed
self._current_arg_key = None
self._current_arg_is_string = False
# --- build the current JSON snapshot as a string ---
# We construct JSON manually so we can precisely control
# what gets sent incrementally.
snapshot_parts: list[str] = []
for k, v in self._completed_args.items():
k_json = json.dumps(k, ensure_ascii=False)
v_json = json.dumps(v, ensure_ascii=False)
snapshot_parts.append(f"{k_json}: {v_json}")
if partial_key is not None and partial_value is not None:
k_json = json.dumps(partial_key, ensure_ascii=False)
# For string partial value, we build the JSON string
# WITHOUT the closing quote, so the prefix stays stable
# as the value grows. The closing `"` and `}` will be
# sent when the value or tool_call closes.
escaped_val = (
partial_value.replace("\\", "\\\\")
.replace('"', '\\"')
.replace("\n", "\\n")
.replace("\r", "\\r")
.replace("\t", "\\t")
)
# Note: no closing " here it's appended only on close
snapshot_parts.append(f'{k_json}: "{escaped_val}')
snapshot = "{" + ", ".join(snapshot_parts) + "}"
# --- compute diff ---
argument_diff: str | None = None
if is_complete:
# Tool call finished send everything remaining.
# Build final snapshot with proper JSON (all values closed).
final_args = dict(self._completed_args)
final_json = json.dumps(final_args, ensure_ascii=False)
if self._streamed_json_len < len(final_json):
argument_diff = final_json[self._streamed_json_len :]
self._streamed_json_len = len(final_json)
# Record into prev_tool_call_arr
self.prev_tool_call_arr.append(
{
"name": self._streaming_tool_name,
"arguments": final_args,
}
)
self.streamed_args_for_tool.append(final_json)
self._reset_streaming_tool_state()
self._buffer = remaining
else:
# Still in progress withhold the tail.
# For open strings: snapshot ends with ...partial_val}
# we withhold "}" (1 char) the missing closing " will
# be sent when the value closes.
# For no open string: snapshot ends with ...value"}
# we withhold "}" (1 char).
end = len(snapshot) - 1 # exclude trailing "}"
if end > self._streamed_json_len:
argument_diff = snapshot[self._streamed_json_len : end]
self._streamed_json_len = end
# --- construct return DeltaMessage ---
if name_delta is not None and argument_diff:
nd_func = name_delta.tool_calls[0].function
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
id=self._current_tool_call_id,
type="function",
function=DeltaFunctionCall(
name=nd_func.name if nd_func else None,
arguments=argument_diff,
),
)
]
)
elif name_delta is not None:
return name_delta
elif argument_diff:
return self._make_args_delta(argument_diff)
else:
return None
+1
View File
@@ -94,6 +94,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
funaudiochat="FunAudioChatConfig",
granite4_vision="Granite4VisionConfig",
hunyuan_vl="HunYuanVLConfig",
hy_v3="HYV3Config",
isaac="IsaacConfig",
kimi_k2="DeepseekV3Config", # Kimi K2 uses same architecture as DeepSeek V3
kimi_linear="KimiLinearConfig",
@@ -36,6 +36,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"HunYuanVLConfig": "vllm.transformers_utils.configs.hunyuan_vl",
"HunYuanVLTextConfig": "vllm.transformers_utils.configs.hunyuan_vl",
"HunYuanVLVisionConfig": "vllm.transformers_utils.configs.hunyuan_vl",
"HYV3Config": "vllm.transformers_utils.configs.hy_v3",
"HyperCLOVAXConfig": "vllm.transformers_utils.configs.hyperclovax",
"IsaacConfig": "vllm.transformers_utils.configs.isaac",
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
@@ -97,6 +98,7 @@ __all__ = [
"HunYuanVLConfig",
"HunYuanVLTextConfig",
"HunYuanVLVisionConfig",
"HYV3Config",
"HyperCLOVAXConfig",
"IsaacConfig",
"RWConfig",
+185
View File
@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from transformers.configuration_utils import PretrainedConfig
class HYV3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`HYV3Model`].
It is used to instantiate a HYV3 model (HY V3 MoE language model) according to
the specified arguments.
Configuration objects inherit from [`PretrainedConfig`] and can be used to
control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
Args:
vocab_size (`int`, *optional*, defaults to 120832):
Vocabulary size of the model.
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 13312):
Dimension of the dense FFN intermediate representations.
num_hidden_layers (`int`, *optional*, defaults to 80):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer.
num_key_value_heads (`int`, *optional*, defaults to 8):
Number of key-value heads for grouped-query attention.
head_dim (`int`, *optional*, defaults to 128):
Dimension per attention head.
hidden_act (`str`, *optional*, defaults to `"silu"`):
Activation function used in FFN layers.
max_position_embeddings (`int`, *optional*, defaults to 131072):
Maximum sequence length supported by the model.
initializer_range (`float`, *optional*, defaults to 0.006):
Standard deviation of the truncated normal initializer for weight
initialization.
rms_norm_eps (`float`, *optional*, defaults to 1e-5):
Epsilon for RMS normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether to use KV cache for decoding.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*):
Beginning-of-sequence token id.
eos_token_id (`int` or `List[int]`, *optional*):
End-of-sequence token id(s).
rope_parameters (`dict`, *optional*):
The parameters of the RoPE embeddings.
qk_norm (`bool`, *optional*, defaults to `True`):
Whether to apply RMSNorm to query and key states before attention.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie input and output embedding weights.
enable_attention_fp32_softmax (`bool`, *optional*, defaults to `False`):
Whether to upcast attention softmax to float32. Note: the eager attention
path always computes softmax in float32 regardless of this setting; this
flag is reserved for future use with custom attention backends.
enable_lm_head_fp32 (`bool`, *optional*, defaults to `True`):
Whether to upcast the LM head computation to float32.
num_experts (`int`, *optional*, defaults to 192):
Total number of MoE experts.
num_experts_per_tok (`int`, *optional*, defaults to 8):
Number of experts selected per token (top-k routing).
num_shared_experts (`int`, *optional*, defaults to 1):
Number of always-active shared experts combined into a single MLP.
expert_hidden_dim (`int`, *optional*, defaults to 1536):
Intermediate dimension of each individual MoE expert.
moe_router_enable_expert_bias (`bool`, *optional*, defaults to `True`):
Whether to use per-expert load-balancing bias in the router.
moe_router_use_sigmoid (`bool`, *optional*, defaults to `True`):
Whether to use sigmoid (instead of softmax) for router scoring.
route_norm (`bool`, *optional*, defaults to `True`):
Whether to normalize routing scores when using sigmoid routing.
router_scaling_factor (`float`, *optional*):
Optional multiplicative scaling factor applied to routing scores.
use_grouped_mm (`bool`, *optional*, defaults to `False`):
Whether to use grouped GEMM for expert computation (not yet implemented).
enable_moe_fp32_combine (`bool`, *optional*, defaults to `False`):
Whether to accumulate expert outputs in float32.
first_k_dense_replace (`int`, *optional*, defaults to 1):
Number of initial decoder layers that use a dense FFN instead of MoE.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether to output router logits from each MoE layer. Useful for computing
auxiliary load-balancing loss during training. Disabled by default to avoid
the memory overhead of storing per-layer router tensors during inference.
Example:
```python
>>> from transformers import HYV3Config, HYV3Model
>>> config = HYV3Config()
>>> model = HYV3Model(config)
```
"""
model_type = "hy_v3"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=120832,
hidden_size=4096,
intermediate_size=13312,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=131072,
initializer_range=0.006,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=None,
eos_token_id=None,
rope_parameters: dict[str, Any] | None = None,
qk_norm=True,
tie_word_embeddings=False,
enable_attention_fp32_softmax=False,
enable_lm_head_fp32=True,
# MoE specific
num_experts=192,
num_experts_per_tok=8,
num_shared_experts=1,
expert_hidden_dim=1536,
moe_router_enable_expert_bias=True,
moe_router_use_sigmoid=True,
route_norm=True,
router_scaling_factor=None,
use_grouped_mm=False,
enable_moe_fp32_combine=False,
# Dense/MoE layer control
first_k_dense_replace=1,
output_router_logits=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
rope_theta = kwargs.pop("rope_theta", 11158840.0)
if rope_parameters is None:
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
self.rope_parameters = rope_parameters
self.qk_norm = qk_norm
self.tie_word_embeddings = tie_word_embeddings
self.enable_lm_head_fp32 = enable_lm_head_fp32
self.enable_attention_fp32_softmax = enable_attention_fp32_softmax
# MoE specific
self.num_experts = num_experts
self.num_experts_per_tok = num_experts_per_tok
self.num_shared_experts = num_shared_experts
self.expert_hidden_dim = expert_hidden_dim
self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
self.moe_router_use_sigmoid = moe_router_use_sigmoid
self.route_norm = route_norm
self.use_grouped_mm = use_grouped_mm
self.router_scaling_factor = router_scaling_factor
self.enable_moe_fp32_combine = enable_moe_fp32_combine
# Dense/MoE layer control
self.first_k_dense_replace = first_k_dense_replace
self.output_router_logits = output_router_logits
if eos_token_id is not None and isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)