mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Frontend] Consolidate Speech to Text entrypoints. (#42370)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
+2
-2
@@ -31,8 +31,8 @@
|
||||
/vllm/entrypoints/cli @hmellor @mgoin @DarkLight1337 @russellb
|
||||
/vllm/entrypoints/mcp @heheda12345
|
||||
/vllm/entrypoints/openai @aarnphm @chaunceyjiang @DarkLight1337 @russellb
|
||||
/vllm/entrypoints/openai/realtime @njhill
|
||||
/vllm/entrypoints/openai/speech_to_text @NickLucche
|
||||
/vllm/entrypoints/speech_to_text/realtime @njhill
|
||||
/vllm/entrypoints/speech_to_text @NickLucche
|
||||
/vllm/entrypoints/pooling @noooop
|
||||
/vllm/entrypoints/sagemaker @DarkLight1337
|
||||
/vllm/entrypoints/serve @njhill
|
||||
|
||||
@@ -7,8 +7,8 @@ from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import TranscriptionResponse
|
||||
from vllm.entrypoints.openai.speech_to_text.speech_to_text import OpenAISpeechToText
|
||||
from vllm.entrypoints.speech_to_text.base.serving import OpenAISpeechToText
|
||||
from vllm.entrypoints.speech_to_text.transcription.protocol import TranscriptionResponse
|
||||
|
||||
|
||||
async def _never_finishes():
|
||||
|
||||
+5
-3
@@ -24,12 +24,14 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import TranscriptionRequest
|
||||
from vllm.entrypoints.openai.speech_to_text.serving import OpenAIServingTranscription
|
||||
from vllm.entrypoints.openai.speech_to_text.speech_to_text import (
|
||||
from vllm.entrypoints.speech_to_text.base.serving import (
|
||||
OpenAISpeechToText,
|
||||
asr_inter_chunk_separator,
|
||||
)
|
||||
from vllm.entrypoints.speech_to_text.transcription.protocol import TranscriptionRequest
|
||||
from vllm.entrypoints.speech_to_text.transcription.serving import (
|
||||
OpenAIServingTranscription,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import SupportsTranscription
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
|
||||
|
||||
@@ -233,19 +233,12 @@ def build_app(
|
||||
|
||||
attach_render_router(app)
|
||||
|
||||
if "transcription" in supported_tasks:
|
||||
from vllm.entrypoints.openai.speech_to_text.api_router import (
|
||||
attach_router as register_speech_to_text_api_router,
|
||||
if "transcription" in supported_tasks or "realtime" in supported_tasks:
|
||||
from vllm.entrypoints.speech_to_text.factories import (
|
||||
register_speech_to_text_api_routers,
|
||||
)
|
||||
|
||||
register_speech_to_text_api_router(app)
|
||||
|
||||
if "realtime" in supported_tasks:
|
||||
from vllm.entrypoints.openai.realtime.api_router import (
|
||||
attach_router as register_realtime_api_router,
|
||||
)
|
||||
|
||||
register_realtime_api_router(app)
|
||||
register_speech_to_text_api_routers(app, supported_tasks)
|
||||
|
||||
if any(task in POOLING_TASKS for task in supported_tasks):
|
||||
from vllm.entrypoints.pooling.factories import register_pooling_api_routers
|
||||
@@ -284,11 +277,11 @@ def build_app(
|
||||
|
||||
if "realtime" in supported_tasks:
|
||||
# Add WebSocket metrics middleware
|
||||
from vllm.entrypoints.openai.realtime.metrics import (
|
||||
WebSocketMetricsMiddleware,
|
||||
from vllm.entrypoints.speech_to_text.factories import (
|
||||
add_websocket_metrics_middleware,
|
||||
)
|
||||
|
||||
app.add_middleware(WebSocketMetricsMiddleware)
|
||||
add_websocket_metrics_middleware(app)
|
||||
|
||||
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
|
||||
logger.warning(
|
||||
@@ -421,20 +414,13 @@ async def init_app_state(
|
||||
|
||||
await init_generative_scoring_state(engine_client, state, args, request_logger)
|
||||
|
||||
if "transcription" in supported_tasks:
|
||||
from vllm.entrypoints.openai.speech_to_text.api_router import (
|
||||
init_transcription_state,
|
||||
)
|
||||
if "transcription" in supported_tasks or "realtime" in supported_tasks:
|
||||
from vllm.entrypoints.speech_to_text.factories import init_speech_to_text_state
|
||||
|
||||
init_transcription_state(
|
||||
init_speech_to_text_state(
|
||||
engine_client, state, args, request_logger, supported_tasks
|
||||
)
|
||||
|
||||
if "realtime" in supported_tasks:
|
||||
from vllm.entrypoints.openai.realtime.api_router import init_realtime_state
|
||||
|
||||
init_realtime_state(engine_client, state, args, request_logger, supported_tasks)
|
||||
|
||||
if any(task in POOLING_TASKS for task in supported_tasks):
|
||||
from vllm.entrypoints.pooling.factories import init_pooling_state
|
||||
|
||||
|
||||
@@ -39,11 +39,6 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranslationRequest,
|
||||
)
|
||||
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
|
||||
from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
DetokenizeRequest,
|
||||
@@ -51,6 +46,11 @@ from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
)
|
||||
from vllm.entrypoints.speech_to_text.transcription.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
)
|
||||
from vllm.entrypoints.speech_to_text.translation.protocol import TranslationRequest
|
||||
from vllm.entrypoints.utils import create_error_response
|
||||
from vllm.inputs import EngineInput, PromptType
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -41,14 +41,6 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
OpenAIBaseModel,
|
||||
)
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseVerbose,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseVerbose,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
@@ -59,6 +51,16 @@ from vllm.entrypoints.pooling.scoring.protocol import (
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
from vllm.entrypoints.speech_to_text.transcription.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseVerbose,
|
||||
)
|
||||
from vllm.entrypoints.speech_to_text.translation.protocol import (
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseVerbose,
|
||||
)
|
||||
from vllm.entrypoints.utils import create_error_response
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
from fastapi import APIRouter, FastAPI, Form, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponseVariant,
|
||||
TranslationRequest,
|
||||
TranslationResponseVariant,
|
||||
)
|
||||
from vllm.entrypoints.openai.speech_to_text.serving import (
|
||||
OpenAIServingTranscription,
|
||||
OpenAIServingTranslation,
|
||||
)
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.tasks import SupportedTask
|
||||
else:
|
||||
RequestLogger = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def transcription(request: Request) -> OpenAIServingTranscription:
|
||||
return request.app.state.openai_serving_transcription
|
||||
|
||||
|
||||
def translation(request: Request) -> OpenAIServingTranslation:
|
||||
return request.app.state.openai_serving_translation
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/transcriptions",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_transcriptions(
|
||||
raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
|
||||
):
|
||||
handler = transcription(raw_request)
|
||||
if handler is None:
|
||||
raise NotImplementedError("The model does not support Transcriptions API")
|
||||
|
||||
audio_data = await request.file.read()
|
||||
|
||||
generator = await handler.create_transcription(audio_data, request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranscriptionResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/translations",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_translations(
|
||||
request: Annotated[TranslationRequest, Form()], raw_request: Request
|
||||
):
|
||||
handler = translation(raw_request)
|
||||
if handler is None:
|
||||
raise NotImplementedError("The model does not support Translations API")
|
||||
|
||||
audio_data = await request.file.read()
|
||||
|
||||
generator = await handler.create_translation(audio_data, request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranslationResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
def init_transcription_state(
|
||||
engine_client: "EngineClient",
|
||||
state: "State",
|
||||
args: "Namespace",
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
state.openai_serving_transcription = (
|
||||
OpenAIServingTranscription(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_translation = (
|
||||
OpenAIServingTranslation(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
if "transcription" in supported_tasks
|
||||
else None
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
|
||||
## Protocols for Audio
|
||||
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
|
||||
_LONG_INFO = torch.iinfo(torch.long)
|
||||
+15
-12
@@ -24,18 +24,6 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseStreamChoice,
|
||||
TranscriptionResponseVerbose,
|
||||
TranscriptionSegment,
|
||||
TranscriptionStreamResponse,
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationResponseVerbose,
|
||||
TranslationSegment,
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import EncoderDecoderInput, EngineInput
|
||||
@@ -51,6 +39,21 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
from ..transcription.protocol import (
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseStreamChoice,
|
||||
TranscriptionResponseVerbose,
|
||||
TranscriptionSegment,
|
||||
TranscriptionStreamResponse,
|
||||
)
|
||||
from ..translation.protocol import (
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationResponseVerbose,
|
||||
TranslationSegment,
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
|
||||
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
|
||||
SpeechToTextResponseVerbose: TypeAlias = (
|
||||
TranscriptionResponseVerbose | TranslationResponseVerbose
|
||||
@@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.tasks import SupportedTask
|
||||
else:
|
||||
RequestLogger = object
|
||||
|
||||
|
||||
def register_speech_to_text_api_routers(
|
||||
app: FastAPI,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
if "realtime" in supported_tasks:
|
||||
from .realtime.api_router import router as realtime_router
|
||||
|
||||
app.include_router(realtime_router)
|
||||
|
||||
if "transcription" in supported_tasks:
|
||||
from .transcription.api_router import router as transcription_router
|
||||
|
||||
app.include_router(transcription_router)
|
||||
|
||||
from .translation.api_router import router as translation_router
|
||||
|
||||
app.include_router(translation_router)
|
||||
|
||||
|
||||
def add_websocket_metrics_middleware(app: FastAPI):
|
||||
from .realtime.metrics import WebSocketMetricsMiddleware
|
||||
|
||||
app.add_middleware(WebSocketMetricsMiddleware)
|
||||
|
||||
|
||||
def init_speech_to_text_state(
|
||||
engine_client: "EngineClient",
|
||||
state: "State",
|
||||
args: "Namespace",
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
if "transcription" in supported_tasks:
|
||||
from .transcription.serving import OpenAIServingTranscription
|
||||
|
||||
state.openai_serving_transcription = OpenAIServingTranscription(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
|
||||
from .translation.serving import OpenAIServingTranslation
|
||||
|
||||
state.openai_serving_translation = OpenAIServingTranslation(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
enable_force_include_usage=args.enable_force_include_usage,
|
||||
)
|
||||
|
||||
if "realtime" in supported_tasks:
|
||||
from .realtime.serving import OpenAIServingRealtime
|
||||
|
||||
state.openai_serving_realtime = OpenAIServingRealtime(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
+3
-38
@@ -1,26 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, FastAPI, WebSocket
|
||||
from fastapi import APIRouter, WebSocket
|
||||
|
||||
from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
|
||||
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .connection import RealtimeConnection
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.tasks import SupportedTask
|
||||
else:
|
||||
RequestLogger = object
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -48,27 +37,3 @@ async def realtime_endpoint(websocket: WebSocket):
|
||||
|
||||
connection = RealtimeConnection(websocket, serving)
|
||||
await connection.handle_connection()
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
"""Attach the realtime router to the FastAPI app."""
|
||||
app.include_router(router)
|
||||
logger.info("Realtime API router attached")
|
||||
|
||||
|
||||
def init_realtime_state(
|
||||
engine_client: "EngineClient",
|
||||
state: "State",
|
||||
args: "Namespace",
|
||||
request_logger: RequestLogger | None,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
):
|
||||
state.openai_serving_realtime = (
|
||||
OpenAIServingRealtime(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
if "realtime" in supported_tasks
|
||||
else None
|
||||
)
|
||||
+5
-4
@@ -14,7 +14,10 @@ from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
from vllm import envs
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
|
||||
from vllm.entrypoints.openai.realtime.protocol import (
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .protocol import (
|
||||
ErrorEvent,
|
||||
InputAudioBufferAppend,
|
||||
InputAudioBufferCommit,
|
||||
@@ -22,9 +25,7 @@ from vllm.entrypoints.openai.realtime.protocol import (
|
||||
TranscriptionDelta,
|
||||
TranscriptionDone,
|
||||
)
|
||||
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from .serving import OpenAIServingRealtime
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Form, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .protocol import TranscriptionRequest, TranscriptionResponseVariant
|
||||
from .serving import OpenAIServingTranscription
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def transcription(request: Request) -> OpenAIServingTranscription:
|
||||
return request.app.state.openai_serving_transcription
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/transcriptions",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_transcriptions(
|
||||
raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
|
||||
):
|
||||
handler = transcription(raw_request)
|
||||
if handler is None:
|
||||
raise NotImplementedError("The model does not support Transcriptions API")
|
||||
|
||||
audio_data = await request.file.read()
|
||||
|
||||
generator = await handler.create_transcription(audio_data, request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranscriptionResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
+2
-279
@@ -6,7 +6,6 @@ import time
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from pydantic import (
|
||||
Field,
|
||||
@@ -28,13 +27,14 @@ from vllm.sampling_params import (
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
from ..base.protocol import _LONG_INFO, AudioResponseFormat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_LONG_INFO = torch.iinfo(torch.long)
|
||||
|
||||
|
||||
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
|
||||
@@ -52,10 +52,6 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
|
||||
usage: UsageInfo | None = Field(default=None)
|
||||
|
||||
|
||||
## Protocols for Audio
|
||||
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]
|
||||
|
||||
|
||||
class TranscriptionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/audio/createTranscription
|
||||
@@ -407,276 +403,3 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
|
||||
TranscriptionResponseVariant: TypeAlias = (
|
||||
TranscriptionResponse | TranscriptionResponseVerbose
|
||||
)
|
||||
|
||||
|
||||
class TranslationResponseStreamChoice(OpenAIBaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: str | None = None
|
||||
stop_reason: int | str | None = None
|
||||
|
||||
|
||||
class TranslationStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
|
||||
object: Literal["translation.chunk"] = "translation.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[TranslationResponseStreamChoice]
|
||||
usage: UsageInfo | None = Field(default=None)
|
||||
|
||||
|
||||
class TranslationRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
|
||||
file: UploadFile
|
||||
"""
|
||||
The audio file object (not file name) to translate, in one of these
|
||||
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
||||
"""
|
||||
|
||||
model: str | None = None
|
||||
"""ID of the model to use.
|
||||
"""
|
||||
|
||||
prompt: str = Field(default="")
|
||||
"""An optional text to guide the model's style or continue a previous audio
|
||||
segment.
|
||||
|
||||
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
||||
should match the audio language.
|
||||
"""
|
||||
|
||||
response_format: AudioResponseFormat = Field(default="json")
|
||||
"""
|
||||
The format of the output, in one of these options: `json`, `text`, `srt`,
|
||||
`verbose_json`, or `vtt`.
|
||||
"""
|
||||
|
||||
# TODO support additional sampling parameters
|
||||
# --8<-- [start:translation-sampling-params]
|
||||
use_beam_search: bool = False
|
||||
"""Whether or not beam search should be used."""
|
||||
|
||||
n: int = 1
|
||||
"""The number of beams to be used in beam search."""
|
||||
|
||||
length_penalty: float = 1.0
|
||||
"""Length penalty to be used for beam search."""
|
||||
|
||||
include_stop_str_in_output: bool = False
|
||||
"""Whether to include the stop strings in output text."""
|
||||
|
||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
"""The seed to use for sampling."""
|
||||
|
||||
temperature: float = Field(default=0.0)
|
||||
"""The sampling temperature, between 0 and 1.
|
||||
|
||||
Higher values like 0.8 will make the output more random, while lower values
|
||||
like 0.2 will make it more focused / deterministic. If set to 0, the model
|
||||
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
|
||||
to automatically increase the temperature until certain thresholds are hit.
|
||||
"""
|
||||
# --8<-- [end:translation-sampling-params]
|
||||
|
||||
# --8<-- [start:translation-extra-params]
|
||||
language: str | None = None
|
||||
"""The language of the input audio we translate from.
|
||||
|
||||
Supplying the input language in
|
||||
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
|
||||
will improve accuracy.
|
||||
"""
|
||||
|
||||
hotwords: str | None = None
|
||||
"""
|
||||
hotwords refers to a list of important words or phrases that the model
|
||||
should pay extra attention to during transcription.
|
||||
"""
|
||||
|
||||
to_language: str | None = None
|
||||
"""The language of the input audio we translate to.
|
||||
|
||||
Please note that this is not supported by all models, refer to the specific
|
||||
model documentation for more details.
|
||||
For instance, Whisper only supports `to_language=en`.
|
||||
"""
|
||||
|
||||
stream: bool | None = False
|
||||
"""Custom field not present in the original OpenAI definition. When set,
|
||||
it will enable output to be streamed in a similar fashion as the Chat
|
||||
Completion endpoint.
|
||||
"""
|
||||
# Flattened stream option to simplify form data.
|
||||
stream_include_usage: bool | None = False
|
||||
stream_continuous_usage_stats: bool | None = False
|
||||
|
||||
max_completion_tokens: int | None = None
|
||||
"""The maximum number of tokens to generate."""
|
||||
# --8<-- [end:translation-extra-params]
|
||||
|
||||
# Default sampling parameters for translation requests.
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
def build_stt_params(
|
||||
self,
|
||||
audio: "np.ndarray",
|
||||
stt_config: "SpeechToTextConfig",
|
||||
model_config: "ModelConfig",
|
||||
task_type: str,
|
||||
) -> SpeechToTextParams:
|
||||
return SpeechToTextParams(
|
||||
audio=audio,
|
||||
stt_config=stt_config,
|
||||
model_config=model_config,
|
||||
language=self.language,
|
||||
task_type=task_type,
|
||||
request_prompt=self.prompt,
|
||||
to_language=self.to_language,
|
||||
hotwords=self.hotwords,
|
||||
)
|
||||
|
||||
def to_beam_search_params(
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
default_sampling_params: dict | None = None,
|
||||
) -> BeamSearchParams:
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
|
||||
max_tokens = default_max_tokens
|
||||
n = self.n if self.n is not None else 1
|
||||
|
||||
# NOTE: Temp 0 is a different fallback than completions
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get("temperature", 0)
|
||||
|
||||
return BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
length_penalty=self.length_penalty,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
|
||||
def to_sampling_params(
|
||||
self, default_max_tokens: int, default_sampling_params: dict | None = None
|
||||
) -> SamplingParams:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
# Default parameters
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
||||
)
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
seed=self.seed,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
||||
stream = data.get("stream", False)
|
||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
||||
# Find which specific stream option was set
|
||||
invalid_param = next(
|
||||
(so for so in stream_opts if data.get(so, False)),
|
||||
"stream_include_usage",
|
||||
)
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter=invalid_param,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# Translation response objects
|
||||
class TranslationResponse(OpenAIBaseModel):
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
|
||||
class TranslationWord(OpenAIBaseModel):
|
||||
end: float
|
||||
"""End time of the word in seconds."""
|
||||
|
||||
start: float
|
||||
"""Start time of the word in seconds."""
|
||||
|
||||
word: str
|
||||
"""The text content of the word."""
|
||||
|
||||
|
||||
class TranslationSegment(OpenAIBaseModel):
|
||||
id: int
|
||||
"""Unique identifier of the segment."""
|
||||
|
||||
avg_logprob: float
|
||||
"""Average logprob of the segment.
|
||||
|
||||
If the value is lower than -1, consider the logprobs failed.
|
||||
"""
|
||||
|
||||
compression_ratio: float
|
||||
"""Compression ratio of the segment.
|
||||
|
||||
If the value is greater than 2.4, consider the compression failed.
|
||||
"""
|
||||
|
||||
end: float
|
||||
"""End time of the segment in seconds."""
|
||||
|
||||
no_speech_prob: float | None = None
|
||||
"""Probability of no speech in the segment.
|
||||
|
||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||
this segment silent.
|
||||
"""
|
||||
|
||||
seek: int
|
||||
"""Seek offset of the segment."""
|
||||
|
||||
start: float
|
||||
"""Start time of the segment in seconds."""
|
||||
|
||||
temperature: float
|
||||
"""Temperature parameter used for generating the segment."""
|
||||
|
||||
text: str
|
||||
"""Text content of the segment."""
|
||||
|
||||
tokens: list[int]
|
||||
"""Array of token IDs for the text content."""
|
||||
|
||||
|
||||
class TranslationResponseVerbose(OpenAIBaseModel):
|
||||
duration: str
|
||||
"""The duration of the input audio."""
|
||||
|
||||
language: str
|
||||
"""The language of the input audio."""
|
||||
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
segments: list[TranslationSegment] | None = None
|
||||
"""Segments of the translated text and their corresponding details."""
|
||||
|
||||
words: list[TranslationWord] | None = None
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
|
||||
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
|
||||
+5
-82
@@ -11,21 +11,17 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
from ..base.serving import OpenAISpeechToText
|
||||
from .protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseStreamChoice,
|
||||
TranscriptionResponseVerbose,
|
||||
TranscriptionStreamResponse,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationResponseVerbose,
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.speech_to_text.speech_to_text import OpenAISpeechToText
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -101,76 +97,3 @@ class OpenAIServingTranscription(OpenAISpeechToText):
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
|
||||
|
||||
class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
"""Handles translation requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="translate",
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
)
|
||||
|
||||
async def create_translation(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: TranslationRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> (
|
||||
TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
| AsyncGenerator[str, None]
|
||||
| ErrorResponse
|
||||
):
|
||||
"""Translation API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
for the API specification. This API mimics the OpenAI translation API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=(
|
||||
TranslationResponseVerbose
|
||||
if request.response_format == "verbose_json"
|
||||
else TranslationResponse
|
||||
),
|
||||
stream_generator_method=self.translation_stream_generator,
|
||||
)
|
||||
|
||||
async def translation_stream_generator(
|
||||
self,
|
||||
request: TranslationRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
separator: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="translation.chunk",
|
||||
response_stream_choice_class=TranslationResponseStreamChoice,
|
||||
stream_response_class=TranslationStreamResponse,
|
||||
separator=separator,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Form, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .protocol import TranslationRequest, TranslationResponseVariant
|
||||
from .serving import OpenAIServingTranslation
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def translation(request: Request) -> OpenAIServingTranslation:
|
||||
return request.app.state.openai_serving_translation
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/audio/translations",
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_translations(
|
||||
request: Annotated[TranslationRequest, Form()], raw_request: Request
|
||||
):
|
||||
handler = translation(raw_request)
|
||||
if handler is None:
|
||||
raise NotImplementedError("The model does not support Translations API")
|
||||
|
||||
audio_data = await request.file.read()
|
||||
|
||||
generator = await handler.create_translation(audio_data, request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
|
||||
elif isinstance(generator, TranslationResponseVariant):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Literal, TypeAlias
|
||||
|
||||
from fastapi import UploadFile
|
||||
from pydantic import (
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from vllm.config.speech_to_text import SpeechToTextParams
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
OpenAIBaseModel,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import (
|
||||
BeamSearchParams,
|
||||
RequestOutputKind,
|
||||
SamplingParams,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
from ..base.protocol import _LONG_INFO, AudioResponseFormat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TranslationResponseStreamChoice(OpenAIBaseModel):
|
||||
delta: DeltaMessage
|
||||
finish_reason: str | None = None
|
||||
stop_reason: int | str | None = None
|
||||
|
||||
|
||||
class TranslationStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
|
||||
object: Literal["translation.chunk"] = "translation.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: list[TranslationResponseStreamChoice]
|
||||
usage: UsageInfo | None = Field(default=None)
|
||||
|
||||
|
||||
class TranslationRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
|
||||
file: UploadFile
|
||||
"""
|
||||
The audio file object (not file name) to translate, in one of these
|
||||
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
|
||||
"""
|
||||
|
||||
model: str | None = None
|
||||
"""ID of the model to use.
|
||||
"""
|
||||
|
||||
prompt: str = Field(default="")
|
||||
"""An optional text to guide the model's style or continue a previous audio
|
||||
segment.
|
||||
|
||||
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
|
||||
should match the audio language.
|
||||
"""
|
||||
|
||||
response_format: AudioResponseFormat = Field(default="json")
|
||||
"""
|
||||
The format of the output, in one of these options: `json`, `text`, `srt`,
|
||||
`verbose_json`, or `vtt`.
|
||||
"""
|
||||
|
||||
# TODO support additional sampling parameters
|
||||
# --8<-- [start:translation-sampling-params]
|
||||
use_beam_search: bool = False
|
||||
"""Whether or not beam search should be used."""
|
||||
|
||||
n: int = 1
|
||||
"""The number of beams to be used in beam search."""
|
||||
|
||||
length_penalty: float = 1.0
|
||||
"""Length penalty to be used for beam search."""
|
||||
|
||||
include_stop_str_in_output: bool = False
|
||||
"""Whether to include the stop strings in output text."""
|
||||
|
||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
"""The seed to use for sampling."""
|
||||
|
||||
temperature: float = Field(default=0.0)
|
||||
"""The sampling temperature, between 0 and 1.
|
||||
|
||||
Higher values like 0.8 will make the output more random, while lower values
|
||||
like 0.2 will make it more focused / deterministic. If set to 0, the model
|
||||
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
|
||||
to automatically increase the temperature until certain thresholds are hit.
|
||||
"""
|
||||
# --8<-- [end:translation-sampling-params]
|
||||
|
||||
# --8<-- [start:translation-extra-params]
|
||||
language: str | None = None
|
||||
"""The language of the input audio we translate from.
|
||||
|
||||
Supplying the input language in
|
||||
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
|
||||
will improve accuracy.
|
||||
"""
|
||||
|
||||
hotwords: str | None = None
|
||||
"""
|
||||
hotwords refers to a list of important words or phrases that the model
|
||||
should pay extra attention to during transcription.
|
||||
"""
|
||||
|
||||
to_language: str | None = None
|
||||
"""The language of the input audio we translate to.
|
||||
|
||||
Please note that this is not supported by all models, refer to the specific
|
||||
model documentation for more details.
|
||||
For instance, Whisper only supports `to_language=en`.
|
||||
"""
|
||||
|
||||
stream: bool | None = False
|
||||
"""Custom field not present in the original OpenAI definition. When set,
|
||||
it will enable output to be streamed in a similar fashion as the Chat
|
||||
Completion endpoint.
|
||||
"""
|
||||
# Flattened stream option to simplify form data.
|
||||
stream_include_usage: bool | None = False
|
||||
stream_continuous_usage_stats: bool | None = False
|
||||
|
||||
max_completion_tokens: int | None = None
|
||||
"""The maximum number of tokens to generate."""
|
||||
# --8<-- [end:translation-extra-params]
|
||||
|
||||
# Default sampling parameters for translation requests.
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"temperature": 0,
|
||||
}
|
||||
|
||||
def build_stt_params(
|
||||
self,
|
||||
audio: "np.ndarray",
|
||||
stt_config: "SpeechToTextConfig",
|
||||
model_config: "ModelConfig",
|
||||
task_type: str,
|
||||
) -> SpeechToTextParams:
|
||||
return SpeechToTextParams(
|
||||
audio=audio,
|
||||
stt_config=stt_config,
|
||||
model_config=model_config,
|
||||
language=self.language,
|
||||
task_type=task_type,
|
||||
request_prompt=self.prompt,
|
||||
to_language=self.to_language,
|
||||
hotwords=self.hotwords,
|
||||
)
|
||||
|
||||
def to_beam_search_params(
|
||||
self,
|
||||
default_max_tokens: int,
|
||||
default_sampling_params: dict | None = None,
|
||||
) -> BeamSearchParams:
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
|
||||
max_tokens = default_max_tokens
|
||||
n = self.n if self.n is not None else 1
|
||||
|
||||
# NOTE: Temp 0 is a different fallback than completions
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get("temperature", 0)
|
||||
|
||||
return BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
length_penalty=self.length_penalty,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
|
||||
def to_sampling_params(
|
||||
self, default_max_tokens: int, default_sampling_params: dict | None = None
|
||||
) -> SamplingParams:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
if default_sampling_params is None:
|
||||
default_sampling_params = {}
|
||||
# Default parameters
|
||||
if (temperature := self.temperature) is None:
|
||||
temperature = default_sampling_params.get(
|
||||
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
|
||||
)
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
seed=self.seed,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
|
||||
stream = data.get("stream", False)
|
||||
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
|
||||
# Find which specific stream option was set
|
||||
invalid_param = next(
|
||||
(so for so in stream_opts if data.get(so, False)),
|
||||
"stream_include_usage",
|
||||
)
|
||||
raise VLLMValidationError(
|
||||
"Stream options can only be defined when `stream=True`.",
|
||||
parameter=invalid_param,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# Translation response objects
|
||||
class TranslationResponse(OpenAIBaseModel):
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
|
||||
class TranslationWord(OpenAIBaseModel):
|
||||
end: float
|
||||
"""End time of the word in seconds."""
|
||||
|
||||
start: float
|
||||
"""Start time of the word in seconds."""
|
||||
|
||||
word: str
|
||||
"""The text content of the word."""
|
||||
|
||||
|
||||
class TranslationSegment(OpenAIBaseModel):
|
||||
id: int
|
||||
"""Unique identifier of the segment."""
|
||||
|
||||
avg_logprob: float
|
||||
"""Average logprob of the segment.
|
||||
|
||||
If the value is lower than -1, consider the logprobs failed.
|
||||
"""
|
||||
|
||||
compression_ratio: float
|
||||
"""Compression ratio of the segment.
|
||||
|
||||
If the value is greater than 2.4, consider the compression failed.
|
||||
"""
|
||||
|
||||
end: float
|
||||
"""End time of the segment in seconds."""
|
||||
|
||||
no_speech_prob: float | None = None
|
||||
"""Probability of no speech in the segment.
|
||||
|
||||
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
|
||||
this segment silent.
|
||||
"""
|
||||
|
||||
seek: int
|
||||
"""Seek offset of the segment."""
|
||||
|
||||
start: float
|
||||
"""Start time of the segment in seconds."""
|
||||
|
||||
temperature: float
|
||||
"""Temperature parameter used for generating the segment."""
|
||||
|
||||
text: str
|
||||
"""Text content of the segment."""
|
||||
|
||||
tokens: list[int]
|
||||
"""Array of token IDs for the text content."""
|
||||
|
||||
|
||||
class TranslationResponseVerbose(OpenAIBaseModel):
|
||||
duration: str
|
||||
"""The duration of the input audio."""
|
||||
|
||||
language: str
|
||||
"""The language of the input audio."""
|
||||
|
||||
text: str
|
||||
"""The translated text."""
|
||||
|
||||
segments: list[TranslationSegment] | None = None
|
||||
"""Segments of the translated text and their corresponding details."""
|
||||
|
||||
words: list[TranslationWord] | None = None
|
||||
"""Extracted words and their corresponding timestamps."""
|
||||
|
||||
|
||||
TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
|
||||
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
from ..base.serving import OpenAISpeechToText
|
||||
from .protocol import (
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
TranslationResponseStreamChoice,
|
||||
TranslationResponseVerbose,
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTranslation(OpenAISpeechToText):
|
||||
"""Handles translation requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_force_include_usage: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
task_type="translate",
|
||||
enable_force_include_usage=enable_force_include_usage,
|
||||
)
|
||||
|
||||
async def create_translation(
|
||||
self,
|
||||
audio_data: bytes,
|
||||
request: TranslationRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> (
|
||||
TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
| AsyncGenerator[str, None]
|
||||
| ErrorResponse
|
||||
):
|
||||
"""Translation API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/audio/createTranslation
|
||||
for the API specification. This API mimics the OpenAI translation API.
|
||||
"""
|
||||
return await self._create_speech_to_text(
|
||||
audio_data=audio_data,
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
response_class=(
|
||||
TranslationResponseVerbose
|
||||
if request.response_format == "verbose_json"
|
||||
else TranslationResponse
|
||||
),
|
||||
stream_generator_method=self.translation_stream_generator,
|
||||
)
|
||||
|
||||
async def translation_stream_generator(
|
||||
self,
|
||||
request: TranslationRequest,
|
||||
result_generator: list[AsyncGenerator[RequestOutput, None]],
|
||||
request_id: str,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
audio_duration_s: float,
|
||||
separator: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generator = self._speech_to_text_stream_generator(
|
||||
request=request,
|
||||
list_result_generator=result_generator,
|
||||
request_id=request_id,
|
||||
request_metadata=request_metadata,
|
||||
audio_duration_s=audio_duration_s,
|
||||
chunk_object_type="translation.chunk",
|
||||
response_stream_choice_class=TranslationResponseStreamChoice,
|
||||
stream_response_class=TranslationStreamResponse,
|
||||
separator=separator,
|
||||
)
|
||||
async for chunk in generator:
|
||||
yield chunk
|
||||
Reference in New Issue
Block a user