[Frontend] Consolidate Speech to Text entrypoints. (#42370)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
wang.yuqi
2026-05-12 15:06:57 +08:00
committed by GitHub
parent 8517cdaf90
commit d37e25ffbe
26 changed files with 681 additions and 607 deletions
+2 -2
View File
@@ -31,8 +31,8 @@
/vllm/entrypoints/cli @hmellor @mgoin @DarkLight1337 @russellb
/vllm/entrypoints/mcp @heheda12345
/vllm/entrypoints/openai @aarnphm @chaunceyjiang @DarkLight1337 @russellb
/vllm/entrypoints/openai/realtime @njhill
/vllm/entrypoints/openai/speech_to_text @NickLucche
/vllm/entrypoints/speech_to_text/realtime @njhill
/vllm/entrypoints/speech_to_text @NickLucche
/vllm/entrypoints/pooling @noooop
/vllm/entrypoints/sagemaker @DarkLight1337
/vllm/entrypoints/serve @njhill
@@ -7,8 +7,8 @@ from unittest.mock import AsyncMock, Mock
import pytest
from vllm.entrypoints.openai.speech_to_text.protocol import TranscriptionResponse
from vllm.entrypoints.openai.speech_to_text.speech_to_text import OpenAISpeechToText
from vllm.entrypoints.speech_to_text.base.serving import OpenAISpeechToText
from vllm.entrypoints.speech_to_text.transcription.protocol import TranscriptionResponse
async def _never_finishes():
@@ -24,12 +24,14 @@ from vllm.entrypoints.openai.engine.protocol import (
RequestResponseMetadata,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text.protocol import TranscriptionRequest
from vllm.entrypoints.openai.speech_to_text.serving import OpenAIServingTranscription
from vllm.entrypoints.openai.speech_to_text.speech_to_text import (
from vllm.entrypoints.speech_to_text.base.serving import (
OpenAISpeechToText,
asr_inter_chunk_separator,
)
from vllm.entrypoints.speech_to_text.transcription.protocol import TranscriptionRequest
from vllm.entrypoints.speech_to_text.transcription.serving import (
OpenAIServingTranscription,
)
from vllm.model_executor.models.interfaces import SupportsTranscription
from vllm.outputs import CompletionOutput, RequestOutput
+10 -24
View File
@@ -233,19 +233,12 @@ def build_app(
attach_render_router(app)
if "transcription" in supported_tasks:
from vllm.entrypoints.openai.speech_to_text.api_router import (
attach_router as register_speech_to_text_api_router,
if "transcription" in supported_tasks or "realtime" in supported_tasks:
from vllm.entrypoints.speech_to_text.factories import (
register_speech_to_text_api_routers,
)
register_speech_to_text_api_router(app)
if "realtime" in supported_tasks:
from vllm.entrypoints.openai.realtime.api_router import (
attach_router as register_realtime_api_router,
)
register_realtime_api_router(app)
register_speech_to_text_api_routers(app, supported_tasks)
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling.factories import register_pooling_api_routers
@@ -284,11 +277,11 @@ def build_app(
if "realtime" in supported_tasks:
# Add WebSocket metrics middleware
from vllm.entrypoints.openai.realtime.metrics import (
WebSocketMetricsMiddleware,
from vllm.entrypoints.speech_to_text.factories import (
add_websocket_metrics_middleware,
)
app.add_middleware(WebSocketMetricsMiddleware)
add_websocket_metrics_middleware(app)
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
logger.warning(
@@ -421,20 +414,13 @@ async def init_app_state(
await init_generative_scoring_state(engine_client, state, args, request_logger)
if "transcription" in supported_tasks:
from vllm.entrypoints.openai.speech_to_text.api_router import (
init_transcription_state,
)
if "transcription" in supported_tasks or "realtime" in supported_tasks:
from vllm.entrypoints.speech_to_text.factories import init_speech_to_text_state
init_transcription_state(
init_speech_to_text_state(
engine_client, state, args, request_logger, supported_tasks
)
if "realtime" in supported_tasks:
from vllm.entrypoints.openai.realtime.api_router import init_realtime_state
init_realtime_state(engine_client, state, args, request_logger, supported_tasks)
if any(task in POOLING_TASKS for task in supported_tasks):
from vllm.entrypoints.pooling.factories import init_pooling_state
+5 -5
View File
@@ -39,11 +39,6 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponse,
TranslationRequest,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest,
@@ -51,6 +46,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.speech_to_text.transcription.protocol import (
TranscriptionRequest,
TranscriptionResponse,
)
from vllm.entrypoints.speech_to_text.translation.protocol import TranslationRequest
from vllm.entrypoints.utils import create_error_response
from vllm.inputs import EngineInput, PromptType
from vllm.logger import init_logger
+10 -8
View File
@@ -41,14 +41,6 @@ from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
OpenAIBaseModel,
)
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseVerbose,
TranslationRequest,
TranslationResponse,
TranslationResponseVerbose,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingRequest,
EmbeddingResponse,
@@ -59,6 +51,16 @@ from vllm.entrypoints.pooling.scoring.protocol import (
ScoreRequest,
ScoreResponse,
)
from vllm.entrypoints.speech_to_text.transcription.protocol import (
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseVerbose,
)
from vllm.entrypoints.speech_to_text.translation.protocol import (
TranslationRequest,
TranslationResponse,
TranslationResponseVerbose,
)
from vllm.entrypoints.utils import create_error_response
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
@@ -1,148 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, FastAPI, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionRequest,
TranscriptionResponseVariant,
TranslationRequest,
TranslationResponseVariant,
)
from vllm.entrypoints.openai.speech_to_text.serving import (
OpenAIServingTranscription,
OpenAIServingTranslation,
)
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
logger = init_logger(__name__)
router = APIRouter()
def transcription(request: Request) -> OpenAIServingTranscription:
return request.app.state.openai_serving_transcription
def translation(request: Request) -> OpenAIServingTranslation:
return request.app.state.openai_serving_translation
@router.post(
"/v1/audio/transcriptions",
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_transcriptions(
raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
):
handler = transcription(raw_request)
if handler is None:
raise NotImplementedError("The model does not support Transcriptions API")
audio_data = await request.file.read()
generator = await handler.create_transcription(audio_data, request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, TranscriptionResponseVariant):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/v1/audio/translations",
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_translations(
request: Annotated[TranslationRequest, Form()], raw_request: Request
):
handler = translation(raw_request)
if handler is None:
raise NotImplementedError("The model does not support Translations API")
audio_data = await request.file.read()
generator = await handler.create_translation(audio_data, request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, TranslationResponseVariant):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
def attach_router(app: FastAPI):
app.include_router(router)
def init_transcription_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
state.openai_serving_transcription = (
OpenAIServingTranscription(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
state.openai_serving_translation = (
OpenAIServingTranslation(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
enable_force_include_usage=args.enable_force_include_usage,
)
if "transcription" in supported_tasks
else None
)
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, TypeAlias
import torch
## Protocols for Audio
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
_LONG_INFO = torch.iinfo(torch.long)
@@ -24,18 +24,6 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse,
TranscriptionResponseStreamChoice,
TranscriptionResponseVerbose,
TranscriptionSegment,
TranscriptionStreamResponse,
TranslationResponse,
TranslationResponseStreamChoice,
TranslationResponseVerbose,
TranslationSegment,
TranslationStreamResponse,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EncoderDecoderInput, EngineInput
@@ -51,6 +39,21 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import get_tokenizer
from vllm.utils.async_utils import merge_async_iterators
from ..transcription.protocol import (
TranscriptionResponse,
TranscriptionResponseStreamChoice,
TranscriptionResponseVerbose,
TranscriptionSegment,
TranscriptionStreamResponse,
)
from ..translation.protocol import (
TranslationResponse,
TranslationResponseStreamChoice,
TranslationResponseVerbose,
TranslationSegment,
TranslationStreamResponse,
)
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from fastapi import FastAPI
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
def register_speech_to_text_api_routers(
app: FastAPI,
supported_tasks: tuple["SupportedTask", ...],
):
if "realtime" in supported_tasks:
from .realtime.api_router import router as realtime_router
app.include_router(realtime_router)
if "transcription" in supported_tasks:
from .transcription.api_router import router as transcription_router
app.include_router(transcription_router)
from .translation.api_router import router as translation_router
app.include_router(translation_router)
def add_websocket_metrics_middleware(app: FastAPI):
from .realtime.metrics import WebSocketMetricsMiddleware
app.add_middleware(WebSocketMetricsMiddleware)
def init_speech_to_text_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
if "transcription" in supported_tasks:
from .transcription.serving import OpenAIServingTranscription
state.openai_serving_transcription = OpenAIServingTranscription(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
enable_force_include_usage=args.enable_force_include_usage,
)
from .translation.serving import OpenAIServingTranslation
state.openai_serving_translation = OpenAIServingTranslation(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
enable_force_include_usage=args.enable_force_include_usage,
)
if "realtime" in supported_tasks:
from .realtime.serving import OpenAIServingRealtime
state.openai_serving_realtime = OpenAIServingRealtime(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
)
@@ -1,26 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from fastapi import APIRouter, FastAPI, WebSocket
from fastapi import APIRouter, WebSocket
from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.logger import init_logger
from .connection import RealtimeConnection
logger = init_logger(__name__)
if TYPE_CHECKING:
from argparse import Namespace
from starlette.datastructures import State
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.tasks import SupportedTask
else:
RequestLogger = object
router = APIRouter()
@@ -48,27 +37,3 @@ async def realtime_endpoint(websocket: WebSocket):
connection = RealtimeConnection(websocket, serving)
await connection.handle_connection()
def attach_router(app: FastAPI):
"""Attach the realtime router to the FastAPI app."""
app.include_router(router)
logger.info("Realtime API router attached")
def init_realtime_state(
engine_client: "EngineClient",
state: "State",
args: "Namespace",
request_logger: RequestLogger | None,
supported_tasks: tuple["SupportedTask", ...],
):
state.openai_serving_realtime = (
OpenAIServingRealtime(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
)
if "realtime" in supported_tasks
else None
)
@@ -14,7 +14,10 @@ from starlette.websockets import WebSocketDisconnect
from vllm import envs
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.realtime.protocol import (
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from .protocol import (
ErrorEvent,
InputAudioBufferAppend,
InputAudioBufferCommit,
@@ -22,9 +25,7 @@ from vllm.entrypoints.openai.realtime.protocol import (
TranscriptionDelta,
TranscriptionDone,
)
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from .serving import OpenAIServingRealtime
logger = init_logger(__name__)
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import Annotated
from fastapi import APIRouter, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
from .protocol import TranscriptionRequest, TranscriptionResponseVariant
from .serving import OpenAIServingTranscription
logger = init_logger(__name__)
router = APIRouter()
def transcription(request: Request) -> OpenAIServingTranscription:
return request.app.state.openai_serving_transcription
@router.post(
"/v1/audio/transcriptions",
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_transcriptions(
raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
):
handler = transcription(raw_request)
if handler is None:
raise NotImplementedError("The model does not support Transcriptions API")
audio_data = await request.file.read()
generator = await handler.create_transcription(audio_data, request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, TranscriptionResponseVariant):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@@ -6,7 +6,6 @@ import time
from http import HTTPStatus
from typing import TYPE_CHECKING, Literal, TypeAlias
import torch
from fastapi import HTTPException, UploadFile
from pydantic import (
Field,
@@ -28,13 +27,14 @@ from vllm.sampling_params import (
)
from vllm.utils import random_uuid
from ..base.protocol import _LONG_INFO, AudioResponseFormat
if TYPE_CHECKING:
import numpy as np
from vllm.config import ModelConfig, SpeechToTextConfig
logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
@@ -52,10 +52,6 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
usage: UsageInfo | None = Field(default=None)
## Protocols for Audio
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
class TranscriptionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/audio/createTranscription
@@ -407,276 +403,3 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
TranscriptionResponseVariant: TypeAlias = (
TranscriptionResponse | TranscriptionResponseVerbose
)
class TranslationResponseStreamChoice(OpenAIBaseModel):
delta: DeltaMessage
finish_reason: str | None = None
stop_reason: int | str | None = None
class TranslationStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
object: Literal["translation.chunk"] = "translation.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[TranslationResponseStreamChoice]
usage: UsageInfo | None = Field(default=None)
class TranslationRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/audio/createTranslation
file: UploadFile
"""
The audio file object (not file name) to translate, in one of these
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
"""
model: str | None = None
"""ID of the model to use.
"""
prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio
segment.
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
should match the audio language.
"""
response_format: AudioResponseFormat = Field(default="json")
"""
The format of the output, in one of these options: `json`, `text`, `srt`,
`verbose_json`, or `vtt`.
"""
# TODO support additional sampling parameters
# --8<-- [start:translation-sampling-params]
use_beam_search: bool = False
"""Whether or not beam search should be used."""
n: int = 1
"""The number of beams to be used in beam search."""
length_penalty: float = 1.0
"""Length penalty to be used for beam search."""
include_stop_str_in_output: bool = False
"""Whether to include the stop strings in output text."""
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
# --8<-- [end:translation-sampling-params]
# --8<-- [start:translation-extra-params]
language: str | None = None
"""The language of the input audio we translate from.
Supplying the input language in
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
will improve accuracy.
"""
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
to_language: str | None = None
"""The language of the input audio we translate to.
Please note that this is not supported by all models, refer to the specific
model documentation for more details.
For instance, Whisper only supports `to_language=en`.
"""
stream: bool | None = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
Completion endpoint.
"""
# Flattened stream option to simplify form data.
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:translation-extra-params]
# Default sampling parameters for translation requests.
_DEFAULT_SAMPLING_PARAMS: dict = {
"temperature": 0,
}
def build_stt_params(
self,
audio: "np.ndarray",
stt_config: "SpeechToTextConfig",
model_config: "ModelConfig",
task_type: str,
) -> SpeechToTextParams:
return SpeechToTextParams(
audio=audio,
stt_config=stt_config,
model_config=model_config,
language=self.language,
task_type=task_type,
request_prompt=self.prompt,
to_language=self.to_language,
hotwords=self.hotwords,
)
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: dict | None = None,
) -> BeamSearchParams:
if default_sampling_params is None:
default_sampling_params = {}
max_tokens = default_max_tokens
n = self.n if self.n is not None else 1
# NOTE: Temp 0 is a different fallback than completions
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 0)
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self, default_max_tokens: int, default_sampling_params: dict | None = None
) -> SamplingParams:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
return SamplingParams.from_optional(
temperature=temperature,
max_tokens=max_tokens,
seed=self.seed,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
skip_clone=True, # Created fresh per request, safe to skip clone
)
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
# Find which specific stream option was set
invalid_param = next(
(so for so in stream_opts if data.get(so, False)),
"stream_include_usage",
)
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter=invalid_param,
)
return data
# Translation response objects
class TranslationResponse(OpenAIBaseModel):
text: str
"""The translated text."""
class TranslationWord(OpenAIBaseModel):
end: float
"""End time of the word in seconds."""
start: float
"""Start time of the word in seconds."""
word: str
"""The text content of the word."""
class TranslationSegment(OpenAIBaseModel):
id: int
"""Unique identifier of the segment."""
avg_logprob: float
"""Average logprob of the segment.
If the value is lower than -1, consider the logprobs failed.
"""
compression_ratio: float
"""Compression ratio of the segment.
If the value is greater than 2.4, consider the compression failed.
"""
end: float
"""End time of the segment in seconds."""
no_speech_prob: float | None = None
"""Probability of no speech in the segment.
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
this segment silent.
"""
seek: int
"""Seek offset of the segment."""
start: float
"""Start time of the segment in seconds."""
temperature: float
"""Temperature parameter used for generating the segment."""
text: str
"""Text content of the segment."""
tokens: list[int]
"""Array of token IDs for the text content."""
class TranslationResponseVerbose(OpenAIBaseModel):
duration: str
"""The duration of the input audio."""
language: str
"""The language of the input audio."""
text: str
"""The translated text."""
segments: list[TranslationSegment] | None = None
"""Segments of the translated text and their corresponding details."""
words: list[TranslationWord] | None = None
"""Extracted words and their corresponding timestamps."""
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
@@ -11,21 +11,17 @@ from vllm.entrypoints.openai.engine.protocol import (
RequestResponseMetadata,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text.protocol import (
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from ..base.serving import OpenAISpeechToText
from .protocol import (
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseStreamChoice,
TranscriptionResponseVerbose,
TranscriptionStreamResponse,
TranslationRequest,
TranslationResponse,
TranslationResponseStreamChoice,
TranslationResponseVerbose,
TranslationStreamResponse,
)
from vllm.entrypoints.openai.speech_to_text.speech_to_text import OpenAISpeechToText
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
logger = init_logger(__name__)
@@ -101,76 +97,3 @@ class OpenAIServingTranscription(OpenAISpeechToText):
)
async for chunk in generator:
yield chunk
class OpenAIServingTranslation(OpenAISpeechToText):
"""Handles translation requests."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate",
enable_force_include_usage=enable_force_include_usage,
)
async def create_translation(
self,
audio_data: bytes,
request: TranslationRequest,
raw_request: Request | None = None,
) -> (
TranslationResponse
| TranslationResponseVerbose
| AsyncGenerator[str, None]
| ErrorResponse
):
"""Translation API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranslation
for the API specification. This API mimics the OpenAI translation API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=(
TranslationResponseVerbose
if request.response_format == "verbose_json"
else TranslationResponse
),
stream_generator_method=self.translation_stream_generator,
)
async def translation_stream_generator(
self,
request: TranslationRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
separator: str,
) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="translation.chunk",
response_stream_choice_class=TranslationResponseStreamChoice,
stream_response_class=TranslationStreamResponse,
separator=separator,
)
async for chunk in generator:
yield chunk
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import Annotated
from fastapi import APIRouter, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
from .protocol import TranslationRequest, TranslationResponseVariant
from .serving import OpenAIServingTranslation
logger = init_logger(__name__)
router = APIRouter()
def translation(request: Request) -> OpenAIServingTranslation:
return request.app.state.openai_serving_translation
@router.post(
"/v1/audio/translations",
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_translations(
request: Annotated[TranslationRequest, Form()], raw_request: Request
):
handler = translation(raw_request)
if handler is None:
raise NotImplementedError("The model does not support Translations API")
audio_data = await request.file.read()
generator = await handler.create_translation(audio_data, request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, TranslationResponseVariant):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@@ -0,0 +1,308 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TYPE_CHECKING, Literal, TypeAlias
from fastapi import UploadFile
from pydantic import (
Field,
model_validator,
)
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
OpenAIBaseModel,
UsageInfo,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
SamplingParams,
)
from vllm.utils import random_uuid
from ..base.protocol import _LONG_INFO, AudioResponseFormat
if TYPE_CHECKING:
import numpy as np
from vllm.config import ModelConfig, SpeechToTextConfig
logger = init_logger(__name__)
class TranslationResponseStreamChoice(OpenAIBaseModel):
delta: DeltaMessage
finish_reason: str | None = None
stop_reason: int | str | None = None
class TranslationStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
object: Literal["translation.chunk"] = "translation.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[TranslationResponseStreamChoice]
usage: UsageInfo | None = Field(default=None)
class TranslationRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/audio/createTranslation
file: UploadFile
"""
The audio file object (not file name) to translate, in one of these
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
"""
model: str | None = None
"""ID of the model to use.
"""
prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio
segment.
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
should match the audio language.
"""
response_format: AudioResponseFormat = Field(default="json")
"""
The format of the output, in one of these options: `json`, `text`, `srt`,
`verbose_json`, or `vtt`.
"""
# TODO support additional sampling parameters
# --8<-- [start:translation-sampling-params]
use_beam_search: bool = False
"""Whether or not beam search should be used."""
n: int = 1
"""The number of beams to be used in beam search."""
length_penalty: float = 1.0
"""Length penalty to be used for beam search."""
include_stop_str_in_output: bool = False
"""Whether to include the stop strings in output text."""
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
# --8<-- [end:translation-sampling-params]
# --8<-- [start:translation-extra-params]
language: str | None = None
"""The language of the input audio we translate from.
Supplying the input language in
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
will improve accuracy.
"""
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
to_language: str | None = None
"""The language of the input audio we translate to.
Please note that this is not supported by all models, refer to the specific
model documentation for more details.
For instance, Whisper only supports `to_language=en`.
"""
stream: bool | None = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
Completion endpoint.
"""
# Flattened stream option to simplify form data.
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:translation-extra-params]
# Default sampling parameters for translation requests.
_DEFAULT_SAMPLING_PARAMS: dict = {
"temperature": 0,
}
def build_stt_params(
self,
audio: "np.ndarray",
stt_config: "SpeechToTextConfig",
model_config: "ModelConfig",
task_type: str,
) -> SpeechToTextParams:
return SpeechToTextParams(
audio=audio,
stt_config=stt_config,
model_config=model_config,
language=self.language,
task_type=task_type,
request_prompt=self.prompt,
to_language=self.to_language,
hotwords=self.hotwords,
)
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: dict | None = None,
) -> BeamSearchParams:
if default_sampling_params is None:
default_sampling_params = {}
max_tokens = default_max_tokens
n = self.n if self.n is not None else 1
# NOTE: Temp 0 is a different fallback than completions
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 0)
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self, default_max_tokens: int, default_sampling_params: dict | None = None
) -> SamplingParams:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
)
return SamplingParams.from_optional(
temperature=temperature,
max_tokens=max_tokens,
seed=self.seed,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
skip_clone=True, # Created fresh per request, safe to skip clone
)
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
# Find which specific stream option was set
invalid_param = next(
(so for so in stream_opts if data.get(so, False)),
"stream_include_usage",
)
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter=invalid_param,
)
return data
# Translation response objects
class TranslationResponse(OpenAIBaseModel):
text: str
"""The translated text."""
class TranslationWord(OpenAIBaseModel):
end: float
"""End time of the word in seconds."""
start: float
"""Start time of the word in seconds."""
word: str
"""The text content of the word."""
class TranslationSegment(OpenAIBaseModel):
id: int
"""Unique identifier of the segment."""
avg_logprob: float
"""Average logprob of the segment.
If the value is lower than -1, consider the logprobs failed.
"""
compression_ratio: float
"""Compression ratio of the segment.
If the value is greater than 2.4, consider the compression failed.
"""
end: float
"""End time of the segment in seconds."""
no_speech_prob: float | None = None
"""Probability of no speech in the segment.
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
this segment silent.
"""
seek: int
"""Seek offset of the segment."""
start: float
"""Start time of the segment in seconds."""
temperature: float
"""Temperature parameter used for generating the segment."""
text: str
"""Text content of the segment."""
tokens: list[int]
"""Array of token IDs for the text content."""
class TranslationResponseVerbose(OpenAIBaseModel):
duration: str
"""The duration of the input audio."""
language: str
"""The language of the input audio."""
text: str
"""The translated text."""
segments: list[TranslationSegment] | None = None
"""Segments of the translated text and their corresponding details."""
words: list[TranslationWord] | None = None
"""Extracted words and their corresponding timestamps."""
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import AsyncGenerator
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from ..base.serving import OpenAISpeechToText
from .protocol import (
TranslationRequest,
TranslationResponse,
TranslationResponseStreamChoice,
TranslationResponseVerbose,
TranslationStreamResponse,
)
logger = init_logger(__name__)
class OpenAIServingTranslation(OpenAISpeechToText):
"""Handles translation requests."""
def __init__(
self,
engine_client: EngineClient,
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate",
enable_force_include_usage=enable_force_include_usage,
)
async def create_translation(
self,
audio_data: bytes,
request: TranslationRequest,
raw_request: Request | None = None,
) -> (
TranslationResponse
| TranslationResponseVerbose
| AsyncGenerator[str, None]
| ErrorResponse
):
"""Translation API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranslation
for the API specification. This API mimics the OpenAI translation API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=(
TranslationResponseVerbose
if request.response_format == "verbose_json"
else TranslationResponse
),
stream_generator_method=self.translation_stream_generator,
)
async def translation_stream_generator(
self,
request: TranslationRequest,
result_generator: list[AsyncGenerator[RequestOutput, None]],
request_id: str,
request_metadata: RequestResponseMetadata,
audio_duration_s: float,
separator: str,
) -> AsyncGenerator[str, None]:
generator = self._speech_to_text_stream_generator(
request=request,
list_result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="translation.chunk",
response_stream_choice_class=TranslationResponseStreamChoice,
stream_response_class=TranslationStreamResponse,
separator=separator,
)
async for chunk in generator:
yield chunk