[Refactor] Drop direct dependency on librosa (#39079)

Signed-off-by: Nick Cao <ncao@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Nick Cao
2026-04-18 02:55:38 -04:00
committed by GitHub
parent 87518c3027
commit 153ba7f0f3
13 changed files with 40 additions and 44 deletions
+3 -3
View File
@@ -193,7 +193,7 @@ Provide a fast duration→token estimate to improve streaming usage statistics:
The API server takes care of basic audio I/O and optional chunking before building prompts:
- Resampling: Input audio is resampled to `SpeechToTextConfig.sample_rate` using `librosa`.
- Resampling: Input audio is resampled to `SpeechToTextConfig.sample_rate` using `AudioResampler`.
- Chunking: If `SpeechToTextConfig.allow_audio_chunking` is True and the duration exceeds `max_audio_clip_s`, the server splits the audio into overlapping chunks and generates a prompt per chunk. Overlap is controlled by `overlap_chunk_second`.
- Energy-aware splitting: When `min_energy_split_window_size` is set, the server finds low-energy regions to minimize cutting within words.
@@ -206,8 +206,8 @@ Relevant server logic:
async def _preprocess_speech_to_text(...):
language = self.model_cls.validate_language(request.language)
...
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
duration = librosa.get_duration(y=y, sr=sr)
y, sr = load_audio(bytes_, sr=self.asr_config.sample_rate)
duration = get_audio_duration(y=y, sr=sr)
do_split_audio = (self.asr_config.allow_audio_chunking
and duration > self.asr_config.max_audio_clip_s)
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
+3 -3
View File
@@ -300,12 +300,12 @@ Full example: [examples/offline_inference/audio_language.py](../../examples/offl
Speech-to-text models like Whisper have a maximum audio length they can process (typically 30 seconds). For longer audio files, vLLM provides a utility to intelligently split audio into chunks at quiet points to minimize cutting through speech.
```python
import librosa
from vllm import LLM, SamplingParams
from vllm.multimodal.audio import split_audio
from vllm.multimodal.media.audio import load_audio
# Load long audio file
audio, sr = librosa.load("long_audio.wav", sr=16000)
audio, sr = load_audio("long_audio.wav", sr=16000)
# Split into chunks at low-energy (quiet) regions
chunks = split_audio(
@@ -832,7 +832,7 @@ Then, you can use the OpenAI client as follows:
base_url=openai_api_base,
)
# Any format supported by librosa is supported
# Any format supported by soundfile/PyAV is supported
audio_url = AudioAsset("winning_call").url
audio_base64 = encode_base64_content_from_url(audio_url)
@@ -267,7 +267,7 @@ def run_audio(model: str, max_completion_tokens: int) -> None:
{
"type": "input_audio",
"input_audio": {
# Any format supported by librosa is supported
# Any format supported by soundfile/PyAV is supported
"data": audio_base64,
"format": "wav",
},
@@ -292,7 +292,7 @@ def run_audio(model: str, max_completion_tokens: int) -> None:
{
"type": "audio_url",
"audio_url": {
# Any format supported by librosa is supported
# Any format supported by soundfile/PyAV is supported
"url": audio_url
},
},
@@ -316,7 +316,7 @@ def run_audio(model: str, max_completion_tokens: int) -> None:
{
"type": "audio_url",
"audio_url": {
# Any format supported by librosa is supported
# Any format supported by soundfile/PyAV is supported
"url": f"data:audio/ogg;base64,{audio_base64}"
},
},
@@ -12,7 +12,6 @@ model, for example:
Requirements:
- vllm with audio support
- websockets
- librosa
- numpy
The script:
@@ -26,12 +25,12 @@ import argparse
import asyncio
import json
import librosa
import numpy as np
import pybase64 as base64
import websockets
from vllm.assets.audio import AudioAsset
from vllm.multimodal.media.audio import load_audio
def audio_to_pcm16_base64(audio_path: str) -> str:
@@ -39,7 +38,7 @@ def audio_to_pcm16_base64(audio_path: str) -> str:
Load an audio file and convert it to base64-encoded PCM16 @ 16kHz.
"""
# Load audio and resample to 16kHz mono
audio, _ = librosa.load(audio_path, sr=16000, mono=True)
audio, _ = load_audio(audio_path, sr=16000, mono=True)
# Convert to PCM16
pcm16 = (audio * 32767).astype(np.int16)
# Encode as base64
@@ -13,7 +13,6 @@ import io
import time
from statistics import mean, median
import librosa
import pytest
import soundfile
import torch
@@ -21,6 +20,7 @@ from datasets import load_dataset
from evaluate import load
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from vllm.multimodal.audio import get_audio_duration
from vllm.tokenizers import get_tokenizer
from ....models.registry import HF_EXAMPLE_MODELS
@@ -84,7 +84,7 @@ async def process_dataset(model, client, data, concurrent_request):
trust_remote_code=model_info.trust_remote_code,
)
# Warmup call as the first `librosa.load` server-side is quite slow.
# Warmup call as the first `load_audio` server-side is quite slow.
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
_ = await bound_transcribe(sem, client, tokenizer, (audio, sr), "")
@@ -118,7 +118,7 @@ def print_performance_metrics(results, total_time):
def add_duration(sample):
y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000
sample["duration_ms"] = get_audio_duration(y=y, sr=sr) * 1000
return sample
@@ -5,7 +5,6 @@ import asyncio
import json
import warnings
import librosa
import numpy as np
import pybase64 as base64
import pytest
@@ -14,6 +13,7 @@ import websockets
from tests.entrypoints.openai.conftest import add_attention_backend
from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
from vllm.assets.audio import AudioAsset
from vllm.multimodal.media.audio import load_audio
# Increase engine iteration timeout for ROCm where first-use JIT compilation
# can exceed the default 60s, causing a silent deadlock in feed_tokens.
@@ -56,7 +56,7 @@ async def send_event(ws, event: dict) -> None:
def mary_had_lamb_audio_chunks() -> list[str]:
"""Audio split into ~0.1 second chunks for streaming."""
path = AudioAsset("mary_had_lamb").get_local_path()
audio, _ = librosa.load(str(path), sr=16000, mono=True)
audio, _ = load_audio(str(path), sr=16000, mono=True)
# Split into ~0.1 second chunks (1600 samples at 16kHz)
chunk_size = 1600
@@ -6,7 +6,6 @@ import asyncio
import io
import json
import librosa
import numpy as np
import openai
import pytest
@@ -14,6 +13,7 @@ import pytest_asyncio
import soundfile as sf
from tests.utils import RemoteOpenAIServer
from vllm.multimodal.media.audio import load_audio
from vllm.platforms import current_platform
MODEL_NAME = "openai/whisper-large-v3-turbo"
@@ -134,7 +134,7 @@ async def test_bad_requests(mary_had_lamb, whisper_client):
@pytest.mark.asyncio
async def test_long_audio_request(mary_had_lamb, whisper_client):
mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
audio, sr = load_audio(mary_had_lamb)
# Add small silence after each audio for repeatability in the split process
audio = np.pad(audio, (0, 1600))
repeated_audio = np.tile(audio, 10)
@@ -7,7 +7,6 @@ import io
import json
import httpx
import librosa
import numpy as np
import openai
import pytest
@@ -17,6 +16,7 @@ import soundfile as sf
from tests.entrypoints.openai.conftest import add_attention_backend
from tests.utils import RemoteOpenAIServer
from vllm.logger import init_logger
from vllm.multimodal.media.audio import load_audio
logger = init_logger(__name__)
@@ -264,7 +264,7 @@ async def test_long_audio_request(foscolo, client_and_model):
if model_name == "google/gemma-3n-E2B-it":
pytest.skip("Gemma3n does not support long audio requests")
foscolo.seek(0)
audio, sr = librosa.load(foscolo)
audio, sr = load_audio(foscolo)
repeated_audio = np.tile(audio, 2)
# Repeated audio to buffer
buffer = io.BytesIO()
@@ -4,7 +4,6 @@
import os
from collections.abc import Sequence
import librosa
import pytest
import regex as re
from huggingface_hub import snapshot_download
@@ -14,6 +13,7 @@ from vllm.assets.image import ImageAsset
from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.multimodal.image import convert_image_mode, rescale_image_size
from vllm.multimodal.media.audio import load_audio
from ....conftest import (
IMAGE_ASSETS,
@@ -290,7 +290,7 @@ def test_vision_speech_models(
num_logprobs: int,
) -> None:
# use the example speech question so that the model outputs are reasonable
audio = librosa.load(speech_question, sr=None)
audio = load_audio(speech_question, sr=None)
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
inputs_vision_speech = [
@@ -4,11 +4,11 @@
from collections.abc import Sequence
from typing import Any
import librosa
import pytest
from transformers import AutoModelForSpeechSeq2Seq
from vllm.assets.audio import AudioAsset
from vllm.multimodal.audio import AudioResampler
from vllm.platforms import current_platform
from ....conftest import HfRunner, PromptAudioInput, VllmRunner
@@ -93,13 +93,12 @@ def run_test(
def resampled_assets() -> list[tuple[Any, int]]:
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
sampled_assets = []
resampler = AudioResampler(target_sr=WHISPER_SAMPLE_RATE)
for asset in audio_assets:
audio, orig_sr = asset.audio_and_sample_rate
# Resample to Whisper's expected sample rate (16kHz)
if orig_sr != WHISPER_SAMPLE_RATE:
audio = librosa.resample(
audio, orig_sr=orig_sr, target_sr=WHISPER_SAMPLE_RATE
)
audio = resampler.resample(audio, orig_sr=orig_sr)
sampled_assets.append(
(audio, WHISPER_SAMPLE_RATE),
)
+2 -2
View File
@@ -3,12 +3,12 @@
from pathlib import Path
from unittest.mock import patch
import librosa
import numpy as np
import pybase64 as base64
import pytest
from vllm.multimodal.media import AudioMediaIO
from vllm.multimodal.media.audio import load_audio
from ...conftest import AudioTestAssets
@@ -73,6 +73,6 @@ def test_audio_media_io_from_video(video_assets):
video_path = video_assets[0].video_path
with open(video_path, "rb") as f:
audio, sr = audio_io.load_bytes(f.read())
audio_ref, sr_ref = librosa.load(video_path, sr=None)
audio_ref, sr_ref = load_audio(video_path, sr=None)
assert sr == sr_ref
np.testing.assert_allclose(audio_ref, audio, atol=1e-4)
+3 -3
View File
@@ -29,9 +29,9 @@ except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile
# being librosa's main backend. Used to validate if an audio loading error is due to a
# server error vs a client error (invalid audio file).
# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`,
# soundfile being the main audio loading backend. Used to validate if an audio
# loading error is due to a server error vs a client error (invalid audio file).
# 0 = sf_error(NULL) race condition: when multiple threads fail sf_open_virtual
# concurrently, one thread may clear the global error before another reads it,
# producing code=0 ("Garbled error message from libsndfile" in soundfile).
@@ -4,11 +4,11 @@ import logging
import math
import random
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchaudio.functional import melscale_fbanks
from transformers import AutoFeatureExtractor, AutoProcessor, BatchFeature
from transformers.feature_extraction_sequence_utils import (
SequenceFeatureExtractor,
@@ -129,17 +129,15 @@ class FilterbankFeatures(nn.Module):
self.pad_min_duration = 0.0
self.pad_direction = "both"
filterbanks = torch.tensor(
librosa.filters.mel(
sr=sample_rate,
n_fft=self.n_fft,
n_mels=nfilt,
fmin=lowfreq,
fmax=highfreq,
norm=mel_norm,
),
dtype=torch.float,
).unsqueeze(0)
filterbanks = melscale_fbanks(
n_freqs=self.n_fft // 2 + 1,
f_min=lowfreq,
f_max=highfreq,
n_mels=nfilt,
sample_rate=sample_rate,
norm=mel_norm,
mel_scale="slaney",
).T.unsqueeze(0)
self.register_buffer("fb", filterbanks)
# Calculate maximum sequence length