support hotwords for FunASR model (#39674)

Signed-off-by: zixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: zixiao <shunli.dsl@alibaba-inc.com>
This commit is contained in:
AllenDou
2026-04-22 17:25:06 +08:00
committed by GitHub
parent ed6d30377d
commit 9047288b68
4 changed files with 50 additions and 4 deletions
@@ -27,7 +27,12 @@ from vllm.assets.audio import AudioAsset
def sync_openai(
audio_path: str, client: OpenAI, model: str, *, repetition_penalty: float = 1.3
audio_path: str,
client: OpenAI,
model: str,
*,
repetition_penalty: float = 1.3,
hotwords: str = None,
):
"""
Perform synchronous transcription using OpenAI-compatible API.
@@ -43,12 +48,15 @@ def sync_openai(
extra_body=dict(
seed=4419,
repetition_penalty=repetition_penalty,
hotwords=hotwords,
),
)
print("transcription result [sync]:", transcription.text)
async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str):
async def stream_openai_response(
audio_path: str, client: AsyncOpenAI, model: str, hotwords: str = None
):
"""
Perform asynchronous transcription using OpenAI-compatible API.
"""
@@ -64,6 +72,7 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: st
extra_body=dict(
seed=420,
top_p=0.6,
hotwords=hotwords,
),
stream=True,
)
@@ -136,6 +145,7 @@ def main(args):
client=client,
model=model,
repetition_penalty=args.repetition_penalty,
hotwords=args.hotwords,
)
# Run the asynchronous function
@@ -146,7 +156,10 @@ def main(args):
)
asyncio.run(
stream_openai_response(
args.audio_path if args.audio_path else winning_call, client, model
args.audio_path if args.audio_path else winning_call,
client,
model,
hotwords=args.hotwords,
)
)
else:
@@ -174,5 +187,11 @@ if __name__ == "__main__":
default=1.3,
help="repetition penalty",
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="hotwords",
)
args = parser.parse_args()
main(args)
+6
View File
@@ -35,6 +35,12 @@ class SpeechToTextParams:
language: str | None = None
"""ISO 639-1 language code (validated / auto-detected)."""
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
task_type: str = "transcribe"
"""``"transcribe"`` or ``"translate"``."""
@@ -78,6 +78,12 @@ class TranscriptionRequest(OpenAIBaseModel):
will improve accuracy and latency.
"""
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio
segment.
@@ -205,6 +211,7 @@ class TranscriptionRequest(OpenAIBaseModel):
task_type=task_type,
request_prompt=self.prompt,
to_language=self.to_language,
hotwords=self.hotwords,
)
def to_beam_search_params(
@@ -481,6 +488,12 @@ class TranslationRequest(OpenAIBaseModel):
will improve accuracy.
"""
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
to_language: str | None = None
"""The language of the input audio we translate to.
@@ -522,6 +535,7 @@ class TranslationRequest(OpenAIBaseModel):
task_type=task_type,
request_prompt=self.prompt,
to_language=self.to_language,
hotwords=self.hotwords,
)
def to_beam_search_params(
+8 -1
View File
@@ -881,13 +881,20 @@ class FunASRForConditionalGeneration(
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
hotwords = stt_params.hotwords
if language is None:
raise ValueError(
"Language must be specified when creating the funasr prompt"
)
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501
if hotwords is not None:
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n热词列表:[{}]\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n".format( # noqa: E501
hotwords
)
else:
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501
prompt = {
"prompt": funasr_prompt,
"multi_modal_data": {