mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
support hotwords for FunASR model (#39674)
Signed-off-by: zixiao <shunli.dsl@alibaba-inc.com> Co-authored-by: zixiao <shunli.dsl@alibaba-inc.com>
This commit is contained in:
@@ -27,7 +27,12 @@ from vllm.assets.audio import AudioAsset
|
||||
|
||||
|
||||
def sync_openai(
|
||||
audio_path: str, client: OpenAI, model: str, *, repetition_penalty: float = 1.3
|
||||
audio_path: str,
|
||||
client: OpenAI,
|
||||
model: str,
|
||||
*,
|
||||
repetition_penalty: float = 1.3,
|
||||
hotwords: str = None,
|
||||
):
|
||||
"""
|
||||
Perform synchronous transcription using OpenAI-compatible API.
|
||||
@@ -43,12 +48,15 @@ def sync_openai(
|
||||
extra_body=dict(
|
||||
seed=4419,
|
||||
repetition_penalty=repetition_penalty,
|
||||
hotwords=hotwords,
|
||||
),
|
||||
)
|
||||
print("transcription result [sync]:", transcription.text)
|
||||
|
||||
|
||||
async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str):
|
||||
async def stream_openai_response(
|
||||
audio_path: str, client: AsyncOpenAI, model: str, hotwords: str = None
|
||||
):
|
||||
"""
|
||||
Perform asynchronous transcription using OpenAI-compatible API.
|
||||
"""
|
||||
@@ -64,6 +72,7 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: st
|
||||
extra_body=dict(
|
||||
seed=420,
|
||||
top_p=0.6,
|
||||
hotwords=hotwords,
|
||||
),
|
||||
stream=True,
|
||||
)
|
||||
@@ -136,6 +145,7 @@ def main(args):
|
||||
client=client,
|
||||
model=model,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
hotwords=args.hotwords,
|
||||
)
|
||||
|
||||
# Run the asynchronous function
|
||||
@@ -146,7 +156,10 @@ def main(args):
|
||||
)
|
||||
asyncio.run(
|
||||
stream_openai_response(
|
||||
args.audio_path if args.audio_path else winning_call, client, model
|
||||
args.audio_path if args.audio_path else winning_call,
|
||||
client,
|
||||
model,
|
||||
hotwords=args.hotwords,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -174,5 +187,11 @@ if __name__ == "__main__":
|
||||
default=1.3,
|
||||
help="repetition penalty",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hotwords",
|
||||
type=str,
|
||||
default=None,
|
||||
help="hotwords",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -35,6 +35,12 @@ class SpeechToTextParams:
|
||||
language: str | None = None
|
||||
"""ISO 639-1 language code (validated / auto-detected)."""
|
||||
|
||||
hotwords: str | None = None
|
||||
"""
|
||||
hotwords refers to a list of important words or phrases that the model
|
||||
should pay extra attention to during transcription.
|
||||
"""
|
||||
|
||||
task_type: str = "transcribe"
|
||||
"""``"transcribe"`` or ``"translate"``."""
|
||||
|
||||
|
||||
@@ -78,6 +78,12 @@ class TranscriptionRequest(OpenAIBaseModel):
|
||||
will improve accuracy and latency.
|
||||
"""
|
||||
|
||||
hotwords: str | None = None
|
||||
"""
|
||||
hotwords refers to a list of important words or phrases that the model
|
||||
should pay extra attention to during transcription.
|
||||
"""
|
||||
|
||||
prompt: str = Field(default="")
|
||||
"""An optional text to guide the model's style or continue a previous audio
|
||||
segment.
|
||||
@@ -205,6 +211,7 @@ class TranscriptionRequest(OpenAIBaseModel):
|
||||
task_type=task_type,
|
||||
request_prompt=self.prompt,
|
||||
to_language=self.to_language,
|
||||
hotwords=self.hotwords,
|
||||
)
|
||||
|
||||
def to_beam_search_params(
|
||||
@@ -481,6 +488,12 @@ class TranslationRequest(OpenAIBaseModel):
|
||||
will improve accuracy.
|
||||
"""
|
||||
|
||||
hotwords: str | None = None
|
||||
"""
|
||||
hotwords refers to a list of important words or phrases that the model
|
||||
should pay extra attention to during transcription.
|
||||
"""
|
||||
|
||||
to_language: str | None = None
|
||||
"""The language of the input audio we translate to.
|
||||
|
||||
@@ -522,6 +535,7 @@ class TranslationRequest(OpenAIBaseModel):
|
||||
task_type=task_type,
|
||||
request_prompt=self.prompt,
|
||||
to_language=self.to_language,
|
||||
hotwords=self.hotwords,
|
||||
)
|
||||
|
||||
def to_beam_search_params(
|
||||
|
||||
@@ -881,13 +881,20 @@ class FunASRForConditionalGeneration(
|
||||
audio = stt_params.audio
|
||||
stt_config = stt_params.stt_config
|
||||
language = stt_params.language
|
||||
hotwords = stt_params.hotwords
|
||||
|
||||
if language is None:
|
||||
raise ValueError(
|
||||
"Language must be specified when creating the funasr prompt"
|
||||
)
|
||||
|
||||
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501
|
||||
if hotwords is not None:
|
||||
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n热词列表:[{}]\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n".format( # noqa: E501
|
||||
hotwords
|
||||
)
|
||||
else:
|
||||
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501
|
||||
|
||||
prompt = {
|
||||
"prompt": funasr_prompt,
|
||||
"multi_modal_data": {
|
||||
|
||||
Reference in New Issue
Block a user