3 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] 61ecb098d6 Improve error handling and logging for AudioMediaIO compatibility
- Add warnings to inform users which compatibility mode is being used
- Handle both AttributeError and ImportError for better coverage
- Add __init__ method to inherited class for consistency
- Provide clear diagnostic messages when patching fails

Co-authored-by: donglixp <1070872+donglixp@users.noreply.github.com>
2026-01-29 02:24:53 +00:00
copilot-swe-agent[bot] b4cd7c479f Fix vLLM AudioMediaIO compatibility issue
Add try-except blocks to handle both old and new vLLM versions where AudioMediaIO may not exist or may have been moved. This fixes the AttributeError when using newer vLLM versions.

- Handle missing AudioMediaIO by creating standalone implementation
- Add fallback for utils module patching
- Maintain backward compatibility with older vLLM versions

Co-authored-by: donglixp <1070872+donglixp@users.noreply.github.com>
2026-01-29 02:22:47 +00:00
copilot-swe-agent[bot] 11dd7420ec Initial plan 2026-01-29 02:19:04 +00:00
17 changed files with 401 additions and 3524 deletions
+3 -11
View File
@@ -3,13 +3,11 @@
## 🎙️ VibeVoice: Open-Source Frontier Voice AI
[![Project Page](https://img.shields.io/badge/Project-Page-blue?logo=githubpages)](https://microsoft.github.io/VibeVoice)
[![Hugging Face](https://img.shields.io/badge/HuggingFace-Collection-orange?logo=huggingface)](https://huggingface.co/collections/microsoft/vibevoice-68a2ef24a875c44be47b034f)
[![TTS Report](https://img.shields.io/badge/TTS-Report-red?logo=arxiv)](https://openreview.net/pdf?id=FihSkzyxdv)
[![TTS Report](https://img.shields.io/badge/TTS-Report-red?logo=arxiv)](https://arxiv.org/pdf/2508.19205)
[![ASR Report](https://img.shields.io/badge/ASR-Report-yellow?logo=arxiv)](https://arxiv.org/pdf/2601.18184)
[![Colab](https://img.shields.io/badge/StreamingTTS-Colab-green?logo=googlecolab)](https://colab.research.google.com/github/microsoft/VibeVoice/blob/main/demo/VibeVoice_colab.ipynb)
[![ASR Playground](https://img.shields.io/badge/ASR-Playground-6F42C1?logo=gradio)](https://aka.ms/vibevoice-asr)
[![microsoft%2FVibeVoice | Trendshift](https://trendshift.io/api/badge/repositories/15465)](https://trendshift.io/repositories/15465)
</div>
@@ -24,13 +22,7 @@
<h3>📰 News</h3>
<strong>2026-03-29: 🎉 VibeVoice-ASR is being adopted by the open-source community! <a href="https://vibingjustspeakit.github.io/Vibing/">Vibing</a>, a voice-powered input method, is now built on top of VibeVoice-ASR. Download: [macOS](https://github.com/VibingJustSpeakIt/Vibing/releases/download/v0.1.0/Vibing-v0.1.0-mac.dmg) | [Windows](https://github.com/VibingJustSpeakIt/Vibing/releases/download/v0.1.0/Vibing-v0.1.0-windows.exe)</strong>
https://github.com/user-attachments/assets/db0bb23f-ae06-4135-a66a-1ff1669f4f84
<strong>2026-03-06: 🚀 VibeVoice ASR is now part of a <a href="https://github.com/huggingface/transformers/releases/tag/v5.3.0">Transformers release</a>! You can now use our speech recognition model directly through the Hugging Face Transformers library for seamless integration into your projects.</strong>
<strong>2026-01-21:</strong> 📣 We open-sourced <a href="docs/vibevoice-asr.md"><strong>VibeVoice-ASR</strong></a>, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context. Try it in [Playground](https://aka.ms/vibevoice-asr).
<strong>2026-01-21: 📣 We open-sourced <a href="docs/vibevoice-asr.md"><strong>VibeVoice-ASR</strong></a>, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context. Try it in [Playground](https://aka.ms/vibevoice-asr)</strong>.
- ⭐️ VibeVoice-ASR is natively multilingual, supporting over 50 languages — check the [supported languages](docs/vibevoice-asr.md#language-distribution) for details.
- 🔥 The VibeVoice-ASR [finetuning code](finetuning-asr/README.md) is now available!
- ⚡️ **vLLM inference** is now supported for faster inference; see [vllm-asr](docs/vibevoice-vllm-asr.md) for more details.
@@ -44,7 +36,7 @@ https://github.com/user-attachments/assets/db0bb23f-ae06-4135-a66a-1ff1669f4f84
2025-09-05: VibeVoice is an open-source research framework intended to advance collaboration in the speech synthesis community. After release, we discovered instances where the tool was used in ways inconsistent with the stated intent. Since responsible use of AI is one of Microsofts guiding principles, we have removed the VibeVoice-TTS code from this repository.
2025-08-25: 📣 We open-sourced <a href="docs/vibevoice-tts.md"><strong>VibeVoice-TTS</strong></a>, a long-form multi-speaker text-to-speech model that can synthesize speech up to 90 minutes long with up to 4 distinct speakers. — accepted as an [Oral](https://openreview.net/forum?id=FihSkzyxdv) at ICLR 2026! 🔥
2025-08-25: 📣 We open-sourced <a href="docs/vibevoice-tts.md"><strong>VibeVoice-TTS</strong></a>, a long-form multi-speaker text-to-speech model that can synthesize speech up to 90 minutes long with up to 4 distinct speakers.
</div>
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+1 -1
View File
@@ -66,7 +66,7 @@
"print(\"✅ Cloned VibeVoice repository\")\n",
"\n",
"# Install project dependencies\n",
"!uv pip --quiet install --system -e /content/VibeVoice[streamingtts]\n",
"!uv pip --quiet install --system -e /content/VibeVoice[tts]\n",
"!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared && chmod +x cloudflared\n",
"print(\"✅ Installed dependencies\")\n",
"\n",
+1 -1
View File
@@ -142,7 +142,7 @@ class StreamingTTSService:
if name and name in self.voice_presets:
return name
default_key = "en-Carter_man"
default_key = "en-WHTest_man"
if default_key in self.voice_presets:
return default_key
+1 -1
View File
@@ -97,7 +97,7 @@ sudo docker run --privileged --net=host --ipc=host --ulimit memlock=-1:-1 --ulim
git clone https://github.com/microsoft/VibeVoice.git
cd VibeVoice/
pip install -e .[streamingtts]
pip install -e .
```
+3 -96
View File
@@ -10,7 +10,6 @@ Deploy VibeVoice ASR model as a high-performance API service using [vLLM](https:
- **📡 OpenAI-Compatible API**: Standard `/v1/chat/completions` endpoint with streaming support
- **🎵 Long Audio Support**: Process up to 60+ minutes of audio in a single request
- **🔌 Plugin Architecture**: No vLLM source code modification required - just install and run
- **⚡ Data Parallel (DP)**: Run independent model replicas across multiple GPUs with automatic load balancing behind a single port
## 🛠️ Installation
@@ -32,87 +31,10 @@ docker run -d --gpus all --name vibevoice-vllm \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:v0.14.1 \
vllm/vllm-openai:latest \
-c "python3 /app/vllm_plugin/scripts/start_server.py"
```
## ⚡ Multi-GPU Deployment
The launcher supports two types of GPU parallelism via `--tp` and `--dp` flags:
| Flag | Name | What it does |
|------|------|-------------|
| `--tp N` | Tensor Parallel | Splits **one model** across N GPUs (for models too large for a single GPU) |
| `--dp N` | Data Parallel | Runs **N independent replicas**, one per GPU, with automatic load balancing behind a single port |
### Data Parallel (Recommended for scaling throughput)
Run N independent replicas on N GPUs with automatic load balancing behind a single port.
When `--dp N` is specified (N > 1), the launcher automatically starts N independent vLLM
processes behind an nginx reverse proxy (2×N workers) for optimal throughput:
```bash
docker run -d --gpus '"device=0,1,2,3"' --name vibevoice-vllm \
--ipc=host \
-p 8000:8000 \
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:v0.14.1 \
-c "python3 /app/vllm_plugin/scripts/start_server.py --dp 4"
```
Run on all 8 GPUs:
```bash
docker run -d --gpus all --name vibevoice-vllm \
--ipc=host \
-p 8000:8000 \
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:v0.14.1 \
-c "python3 /app/vllm_plugin/scripts/start_server.py --dp 8"
```
### Tensor Parallel
Split a single model across 2 GPUs (useful if GPU memory is limited):
```bash
docker run -d --gpus '"device=0,1"' --name vibevoice-vllm \
--ipc=host \
-p 8000:8000 \
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:v0.14.1 \
-c "python3 /app/vllm_plugin/scripts/start_server.py --tp 2"
```
### Hybrid (DP × TP)
Combine both — e.g., 2 replicas, each split across 2 GPUs (4 GPUs total):
```bash
docker run -d --gpus '"device=0,1,2,3"' --name vibevoice-vllm \
--ipc=host \
-p 8000:8000 \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:v0.14.1 \
-c "python3 /app/vllm_plugin/scripts/start_server.py --dp 2 --tp 2"
```
> **Note**: Total GPUs required = `dp × tp`. Make sure to expose enough GPU devices in the Docker `--gpus` flag.
3. View logs
```bash
docker logs -f vibevoice-vllm
@@ -130,25 +52,10 @@ docker logs -f vibevoice-vllm
Once the vLLM server is running, test it with the provided script:
```bash
# Basic transcription
# Run the test (use container path /app/...)
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav
# With hotwords for better recognition of specific terms
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
```
```bash
# With auto-recovery from repetition loops (for long audio)
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav
# Auto-recover with hotwords
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
```
> **Note**:
> - The audio/video file must be inside the mounted directory (`/app` in the container). Copy your files to the VibeVoice folder before testing.
> - Hotwords help improve recognition of domain-specific terms like proper nouns, technical terms, and speaker names.
> **Note**: The audio file must be inside the mounted directory (`/app` in the container). Copy your audio to the VibeVoice folder before testing.
### Environment Variables
-6
View File
@@ -38,12 +38,6 @@ dependencies = [
"requests",
]
[project.optional-dependencies]
streamingtts = [
"transformers==4.51.3",
]
[project.entry-points."vllm.general_plugins"]
vibevoice = "vllm_plugin:register_vibevoice"
+1
View File
@@ -15,6 +15,7 @@ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
from .model import VibeVoiceForCausalLM
from .inputs import vibevoice_audio_input_mapper
def register_vibevoice():
+241 -113
View File
@@ -5,11 +5,17 @@ This module implements the VibeVoice ASR model with full vLLM multimodal registr
integration for speech-to-text inference.
"""
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal
import json
import math
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from io import BytesIO
import tempfile
import base64
@@ -23,12 +29,32 @@ import base64
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]:
"""Load audio bytes using FFmpeg via stdin-pipe decoding.
def _suffix_from_media_type(media_type: str | None) -> str:
if not media_type:
return ".bin"
mt = media_type.lower().strip()
if mt in ("audio/wav", "audio/x-wav", "audio/wave"):
return ".wav"
if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"):
return ".mp3"
if mt in ("audio/flac",):
return ".flac"
if mt in ("audio/ogg", "audio/opus"):
return ".ogg"
if mt in ("audio/mp4", "audio/m4a"):
return ".m4a"
if mt in ("video/mp4",):
return ".mp4"
return ".bin"
def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]:
"""Load audio bytes using FFmpeg.
Returns:
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
"""
# Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency.
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
normalizer = AudioNormalizer()
audio = normalizer(audio)
@@ -46,53 +72,91 @@ def _ffmpeg_load_file(filepath) -> tuple[np.ndarray, int]:
return audio, sr
# Register FFmpeg-based audio loader
try:
# Try new location (vLLM >= 0.6.x)
from vllm.multimodal.media.audio import AudioMediaIO as _OriginalAudioMediaIO
except ImportError:
# Fall back to old location (vLLM < 0.6.x)
import vllm.multimodal.audio as _vllm_audio_module
import warnings
# Handle both old and new vLLM versions
# In newer versions, AudioMediaIO may not exist or may have been moved
try:
_OriginalAudioMediaIO = _vllm_audio_module.AudioMediaIO
class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
def __init__(self, **kwargs):
# Call parent constructor if it exists
if hasattr(super(), '__init__'):
super().__init__(**kwargs)
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(data)
return _ffmpeg_load_bytes(data, media_type=None)
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(base64.b64decode(data))
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
def load_file(self, filepath) -> tuple[np.ndarray, int]:
return _ffmpeg_load_file(filepath)
# Replace globally
try:
# For new vLLM versions
import vllm.multimodal.media.audio as _vllm_audio_module
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
except ImportError:
# For old vLLM versions
import vllm.multimodal.audio as _vllm_audio_module
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
# Also patch in utils module where it's imported
try:
import vllm.multimodal.utils as _vllm_utils_module
_vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
except (ImportError, AttributeError):
# AudioMediaIO might not be imported in utils in newer versions
except (ImportError, AttributeError) as e:
warnings.warn(f"Could not patch AudioMediaIO in vllm.multimodal.utils: {e}", UserWarning)
except (AttributeError, ImportError) as e:
# AudioMediaIO doesn't exist in this vLLM version
# Define our own standalone implementation
warnings.warn(
f"AudioMediaIO not found in vllm.multimodal.audio ({e}). "
"Using standalone FFmpeg-based implementation for compatibility.",
UserWarning
)
class _PatchedAudioMediaIO:
"""Standalone AudioMediaIO implementation using FFmpeg for audio decoding.
This is used when vLLM doesn't provide AudioMediaIO or it's been moved/removed.
"""
def __init__(self, **kwargs):
pass
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(data, media_type=None)
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
def load_file(self, filepath) -> tuple[np.ndarray, int]:
return _ffmpeg_load_file(filepath)
# Try to register it in the module if possible
try:
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
except (AttributeError, TypeError) as e:
warnings.warn(
f"Could not register AudioMediaIO in vllm.multimodal.audio: {e}. "
"Audio loading will use FFmpeg functions directly.",
UserWarning
)
# ============================================================================
from transformers import BatchFeature
from transformers import Qwen2Config, BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.config import VllmConfig, ModelConfig
from vllm.config.speech_to_text import SpeechToTextConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import MultiModalDataParser
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
from vllm.inputs import PromptType
from vllm.model_executor.models.utils import (
init_vllm_registered_model,
maybe_prefix,
@@ -107,17 +171,7 @@ from vllm.multimodal.processing import (
PromptUpdate,
PromptUpdateDetails,
)
try:
# Try new location (vLLM >= 0.6.x)
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
except ImportError:
# Fall back to old location (vLLM < 0.6.x)
try:
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
except ImportError:
# If neither location works, try individual imports
from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
from vllm.multimodal.processing.inputs import ProcessorInputs
# Import VibeVoice components
from vibevoice.modular.modular_vibevoice_tokenizer import (
@@ -554,88 +608,30 @@ class VibeVoiceProcessingInfo(BaseProcessingInfo):
return tokens
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
"""Return the maximum number of audio tokens per item.
This tells vLLM's scheduler the upper bound so that
``encoder_compute_budget`` is large enough for any audio length
the model can handle, preventing the silent scheduling deadlock
described in docs/max_num_batched_tokens_issue.md.
Formula: audio_tokens = ceil(audio_samples / compress_ratio) + 3
where +3 accounts for speech_start, speech_end, and newline tokens.
The max audio samples is bounded by seq_len (the model's context
window cannot hold more tokens than that).
"""
hf_config = self.get_hf_config()
def _cfg(key: str, default):
if isinstance(hf_config, dict):
return hf_config.get(key, default)
return getattr(hf_config, key, default)
compress_ratio = int(_cfg("speech_tok_compress_ratio", 3200))
sample_rate = int(_cfg("target_sample_rate", 24000))
# Upper bound: 61-minute audio at 24 kHz
max_audio_samples = 61 * 60 * sample_rate # 88,464,000
max_audio_tokens = int(np.ceil(max_audio_samples / compress_ratio)) + 3
# Cannot exceed the model's context window
max_audio_tokens = min(max_audio_tokens, seq_len)
return {"audio": max_audio_tokens}
return {"audio": None}
class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]):
"""
Build dummy inputs for multimodal profiling.
vLLM uses dummy inputs to:
1. Measure peak GPU activation memory → determine KV cache capacity
2. Warm up CUDA graphs
The dummy audio length must be consistent with ``get_mm_max_tokens_per_item``
so that the memory estimate covers the worst-case (longest audio) scenario.
Dummy text uses the raw <|AUDIO|> token(s). vLLM's processing pipeline will
expand each <|AUDIO|> via `VibeVoiceMultiModalProcessor._get_prompt_updates`
into the full ASR format:
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
where N is derived from audio length / compress_ratio.
"""
def _get_max_audio_samples(self, seq_len: int) -> int:
"""Compute maximum audio samples consistent with ``get_mm_max_tokens_per_item``.
Uses the same formula: max_tokens = min(ceil(61min * sr / ratio) + 3, seq_len),
then converts back to samples.
"""
hf_config = self.info.get_hf_config()
def _cfg(key: str, default):
if isinstance(hf_config, dict):
return hf_config.get(key, default)
return getattr(hf_config, key, default)
compress_ratio = int(_cfg("speech_tok_compress_ratio", 3200))
sample_rate = int(_cfg("target_sample_rate", 24000))
# Upper bound: 61-minute audio at 24 kHz
max_hour_samples = 61 * 60 * sample_rate # 88,464,000
max_tokens_from_audio = int(np.ceil(max_hour_samples / compress_ratio)) + 3
# Cannot exceed model context window
max_tokens = min(max_tokens_from_audio, seq_len)
# Convert tokens back to samples
return max_tokens * compress_ratio
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
if num_audios <= 0:
return ""
# Get the audio token from our token info helper
token_info = self.info.get_audio_token_info()
audio_token = token_info["audio_token"]
# Return ONLY the audio tokens - the HF processor adds bos/eos
return audio_token * num_audios
def get_dummy_mm_data(
@@ -644,23 +640,16 @@ class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo
mm_counts: Mapping[str, int],
mm_options: Mapping[str, Any] | None = None,
) -> Dict[str, Any]:
"""Generate dummy audio data for profiling.
"""Generate dummy audio data for profiling."""
feature_extractor = self.info.get_feature_extractor()
The audio length is derived from ``seq_len`` so that profiling
accurately measures memory for the longest audio the model can handle.
Supports ``AudioDummyOptions.length`` override for faster startup.
"""
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
max_audio_len = self._get_max_audio_samples(seq_len)
audio_overrides = mm_options.get("audio") if mm_options else None
# Generate dummy audio as numpy arrays (what the HF processor expects)
return {
"audio": self._get_dummy_audios(
length=max_audio_len,
num_audios=num_audios,
overrides=audio_overrides,
)
"audio": [np.zeros(audio_len, dtype=np.float32) for _ in range(num_audios)]
}
def get_dummy_processor_inputs(
@@ -934,6 +923,17 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
with a causal language model for text generation.
"""
# SupportsTranscription interface
supports_transcription: ClassVar[Literal[True]] = True
supports_transcription_only: ClassVar[bool] = False
supports_segment_timestamp: ClassVar[bool] = False
# Supported languages (Chinese as primary target)
supported_languages: ClassVar[Mapping[str, str]] = {
"zh": "Chinese",
"en": "English",
}
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""Return the placeholder string format for a given modality.
@@ -947,11 +947,113 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return "<|AUDIO|>"
raise ValueError(f"Unsupported modality: {modality}")
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
"""Get the prompt for the ASR model.
Generates a chat-formatted prompt for speech-to-text transcription
with JSON output format.
"""
# If user provides custom prompt, use it
if request_prompt:
return request_prompt
# Calculate audio duration for the prompt
# Audio should be at 24kHz, so duration = len(audio) / 24000
duration = len(audio) / 24000 if audio is not None else 10.0
system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format."
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
user_suffix = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys)
)
# IMPORTANT: keep <|AUDIO|> as the only placeholder token here.
# `_get_prompt_updates` expands it into repeated `<|AUDIO|>` placeholders.
user_content = "<|AUDIO|>\n" + user_suffix
prompt = (
f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
f"<|im_start|>user\n{user_content}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return prompt
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
) -> SpeechToTextConfig:
"""Get the speech to text config for the ASR model."""
return SpeechToTextConfig(
language=None, # Auto-detect or use request language
task_type=task_type,
)
@classmethod
def get_num_audio_tokens(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> int | None:
"""Estimate number of audio tokens from duration.
Returns the number of audio EMBEDDING positions (speech_pad_id tokens).
Note: _get_prompt_updates actually generates:
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
So total prompt tokens = N + 3, but this returns N (the embedding count).
"""
sampling_rate = 24000
compress_ratio = 3200
samples = int(audio_duration_s * sampling_rate)
num_tokens = int(np.ceil(samples / compress_ratio))
return num_tokens
@classmethod
def get_other_languages(cls) -> Mapping[str, str]:
"""Get languages from Whisper map not natively supported."""
# Import LANGUAGES from vllm
try:
from vllm.transformers_utils.tokenizer import LANGUAGES
except ImportError:
# Fallback to empty dict if import fails
return {}
return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages}
@classmethod
def validate_language(cls, language: str | None) -> str | None:
"""Validate the language code."""
if language is None or language in cls.supported_languages:
return language
elif language in cls.get_other_languages():
print(f"Warning: Language {language!r} is not natively supported")
return language
else:
raise ValueError(
f"Unsupported language: {language!r}. "
f"Supported: {list(cls.supported_languages.keys())}"
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
# Keep a copy of the resolved model path for any custom weight-loading
# logic (e.g., loading audio encoder weights in fp32 directly from
# safetensors shards).
self._model_path = vllm_config.model_config.model
self.audio_encoder = VibeVoiceAudioEncoder(config)
# Pass decoder_config to the language model initialization
@@ -1048,54 +1150,76 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# Process each audio through the VibeVoice encoder
embeddings = []
# Get model device for tensor placement.
# Get model device and dtype for alignment
try:
device = next(self.audio_encoder.parameters()).device
dtype = next(self.audio_encoder.parameters()).dtype
except StopIteration:
# Fallback if no parameters (shouldn't happen)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
# Handle both stacked tensor and list of tensors
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
if isinstance(raw_audio, torch.Tensor):
if raw_audio.dim() == 3:
# Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
elif raw_audio.dim() == 2:
# Shape: [batch_size, seq_len]
num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i] for i in range(num_audios)]
else:
# Single 1D tensor
audio_list = [raw_audio]
elif isinstance(raw_audio, (list, tuple)):
audio_list = list(raw_audio)
else:
# Single tensor
audio_list = [raw_audio]
for i, audio_tensor in enumerate(audio_list):
try:
if isinstance(audio_tensor, list):
audio_tensor = torch.stack(audio_tensor)
# Ensure tensor
if not isinstance(audio_tensor, torch.Tensor):
audio_tensor = torch.tensor(audio_tensor)
# Let vLLM handle dtype (bfloat16 by default)
audio_tensor = audio_tensor.to(device=device)
# Get actual length if available, otherwise use full length
if raw_audio_lengths and i < len(raw_audio_lengths):
actual_len = int(raw_audio_lengths[i])
if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
# Truncate from the last dimension (sequence length)
audio_tensor = audio_tensor[..., :actual_len]
if audio_tensor.numel() < 160:
# Skip if audio is too short (< 1 frame)
if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz
continue
# Encode audio through VibeVoice encoder
audio_embeds = self.audio_encoder(
audio_tensor,
use_streaming=use_streaming_flag,
segment_duration_s=streaming_segment_duration,
)
# audio_embeds shape: [1, seq_len, hidden_size]
# We need to return it as a single embedding tensor per audio
final_embed = audio_embeds.squeeze(0)
embeddings.append(final_embed)
except Exception as e:
# Log error but don't crash - this helps debug profiling issues
print(f"[VibeVoice] Error encoding audio {i}: {e}")
import traceback
traceback.print_exc()
# Return empty embedding to avoid crash
continue
return tuple(embeddings)
@@ -1220,6 +1344,7 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Intermediate tensors for pipeline parallelism.
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
"""
try:
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
# Only compute from input_ids if inputs_embeds is not available
if inputs_embeds is None and input_ids is not None:
@@ -1246,6 +1371,9 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
)
return hidden_states
except Exception as e:
raise
# Alias for training checkpoint compatibility
VibeVoiceForASRTraining = VibeVoiceForCausalLM
File diff suppressed because it is too large Load Diff
+16 -298
View File
@@ -9,21 +9,14 @@ One-click deployment script that handles:
4. Generating tokenizer files
5. Starting vLLM server
For DP > 1, launches N independent vLLM processes behind an nginx
reverse proxy for optimal throughput (avoids single-process HTTP
bottleneck of vLLM's built-in DP coordinator).
Usage:
python3 start_server.py [--model MODEL_ID] [--port PORT]
"""
import argparse
import os
import signal
import subprocess
import sys
import textwrap
import time
def run_command(cmd: list[str], description: str, shell: bool = False) -> None:
@@ -84,268 +77,45 @@ def generate_tokenizer(model_path: str) -> None:
)
def _build_vllm_cmd(model_path: str, port: int,
tensor_parallel_size: int = 1,
data_parallel_size: int = 1,
max_num_seqs: int = 64,
max_model_len: int = 65536,
gpu_memory_utilization: float = 0.8) -> list[str]:
"""Build the vllm serve command."""
return [
def start_vllm_server(model_path: str, port: int) -> None:
"""Start vLLM server (replaces current process)."""
print(f"\n{'='*60}")
print(f" Starting vLLM server on port {port}")
print(f"{'='*60}\n")
vllm_cmd = [
"vllm", "serve", model_path,
"--served-model-name", "vibevoice",
"--trust-remote-code",
"--dtype", "bfloat16",
"--max-num-seqs", str(max_num_seqs),
"--max-model-len", str(max_model_len),
"--gpu-memory-utilization", str(gpu_memory_utilization),
"--max-num-seqs", "64",
"--max-model-len", "65536",
"--max-num-batched-tokens", "32768",
"--gpu-memory-utilization", "0.8",
"--enforce-eager",
"--no-enable-prefix-caching",
"--enable-chunked-prefill",
"--chat-template-content-format", "openai",
"--tensor-parallel-size", str(tensor_parallel_size),
"--data-parallel-size", str(data_parallel_size),
"--tensor-parallel-size", "1",
"--allowed-local-media-path", "/app",
"--port", str(port),
]
def start_vllm_server(model_path: str, port: int,
tensor_parallel_size: int = 1,
data_parallel_size: int = 1,
max_num_seqs: int = 64,
max_model_len: int = 65536,
gpu_memory_utilization: float = 0.8) -> None:
"""Start a single vLLM server (replaces current process)."""
print(f"\n{'='*60}")
print(f" Starting vLLM server on port {port}")
print(f" Tensor Parallel (TP): {tensor_parallel_size}")
print(f" Data Parallel (DP): {data_parallel_size}")
print(f" Max Num Seqs: {max_num_seqs}")
print(f" Max Model Len: {max_model_len}")
print(f" GPU Mem Utilization: {gpu_memory_utilization}")
print(f"{'='*60}\n")
vllm_cmd = _build_vllm_cmd(
model_path, port,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
)
os.execvp("vllm", vllm_cmd)
def _install_nginx() -> None:
"""Install nginx if not already available."""
if subprocess.run(["which", "nginx"], capture_output=True).returncode != 0:
run_command(["apt-get", "update"], "Updating package list for nginx")
run_command(
["apt-get", "install", "-y", "nginx"],
"Installing nginx for load balancing"
)
def _write_nginx_config(frontend_port: int, backend_ports: list[int],
num_workers: int = 0) -> str:
"""Write nginx config for round-robin load balancing.
Args:
num_workers: Number of nginx worker processes. 0 = auto (2 × num backends).
"""
if num_workers <= 0:
num_workers = len(backend_ports) * 2
backends = "\n".join(f" server 127.0.0.1:{p};" for p in backend_ports)
config = textwrap.dedent(f"""\
worker_processes {num_workers};
worker_rlimit_nofile 65536;
error_log /dev/stderr warn;
pid /tmp/nginx.pid;
events {{
worker_connections 8192;
}}
http {{
access_log off;
upstream vllm_backends {{
least_conn;
{backends}
}}
server {{
listen {frontend_port};
client_max_body_size 200m;
client_body_buffer_size 10m;
proxy_buffering on;
proxy_buffer_size 64k;
proxy_buffers 16 64k;
location / {{
proxy_pass http://vllm_backends;
proxy_read_timeout 600s;
proxy_connect_timeout 10s;
proxy_send_timeout 600s;
proxy_http_version 1.1;
proxy_set_header Connection "";
}}
}}
}}
""")
config_path = "/tmp/nginx_vllm.conf"
with open(config_path, "w") as f:
f.write(config)
return config_path
def start_dp_server(model_path: str, frontend_port: int,
data_parallel_size: int,
tensor_parallel_size: int = 1,
max_num_seqs: int = 64,
max_model_len: int = 65536,
gpu_memory_utilization: float = 0.8) -> None:
"""Start multiple vLLM workers behind nginx for data parallelism.
Launches N independent vLLM processes (one per GPU group) on internal
ports, with an nginx reverse proxy on the frontend port for load
balancing. This avoids the single-process HTTP bottleneck of vLLM's
built-in DP coordinator when handling large audio payloads.
"""
import torch
num_gpus = torch.cuda.device_count()
gpus_per_replica = tensor_parallel_size
total_gpus_needed = data_parallel_size * gpus_per_replica
assert num_gpus >= total_gpus_needed, (
f"Need {total_gpus_needed} GPUs (dp={data_parallel_size} × tp={tensor_parallel_size}) "
f"but only {num_gpus} available"
)
# Auto-tune per-worker env vars based on dp size
ffmpeg_concurrency = max(
64, int(os.environ.get("VIBEVOICE_FFMPEG_MAX_CONCURRENCY", "64"))
)
media_threads = max(
8, int(os.environ.get("VLLM_MEDIA_LOADING_THREAD_COUNT", "8"))
)
_install_nginx()
# Assign internal ports: frontend_port + 100, +101, ...
backend_ports = [frontend_port + 100 + i for i in range(data_parallel_size)]
print(f"\n{'='*60}")
print(f" Starting DP server with nginx load balancing")
print(f" Frontend port: {frontend_port} (nginx)")
print(f" Backend ports: {backend_ports}")
print(f" Data Parallel: {data_parallel_size}")
print(f" Tensor Parallel: {tensor_parallel_size}")
print(f" GPUs per replica: {gpus_per_replica}")
print(f" Max Num Seqs: {max_num_seqs}")
print(f" Max Model Len: {max_model_len}")
print(f" FFmpeg concurrency (per worker): {ffmpeg_concurrency}")
print(f" Media loading threads (per worker): {media_threads}")
print(f"{'='*60}\n")
# Write nginx config
nginx_conf = _write_nginx_config(frontend_port, backend_ports)
# Launch vLLM workers
workers: list[subprocess.Popen] = []
for rank in range(data_parallel_size):
gpu_start = rank * gpus_per_replica
gpu_ids = ",".join(str(gpu_start + j) for j in range(gpus_per_replica))
port = backend_ports[rank]
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = gpu_ids
env["VIBEVOICE_FFMPEG_MAX_CONCURRENCY"] = str(ffmpeg_concurrency)
env["VLLM_MEDIA_LOADING_THREAD_COUNT"] = str(media_threads)
vllm_cmd = _build_vllm_cmd(
model_path, port,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=1, # Each worker is dp=1
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
)
print(f" Launching worker rank={rank} on GPU(s) {gpu_ids}, port {port}")
proc = subprocess.Popen(vllm_cmd, env=env)
workers.append(proc)
# Start nginx
print(f"\n Starting nginx on port {frontend_port} ...")
nginx_proc = subprocess.Popen(
["nginx", "-c", nginx_conf, "-g", "daemon off;"]
)
# Wait for all backends to be ready
print(" Waiting for all backends to be ready ...")
import urllib.request
for port in backend_ports:
url = f"http://127.0.0.1:{port}/v1/models"
for attempt in range(600): # up to 10 minutes
try:
urllib.request.urlopen(url, timeout=2)
print(f" ✅ Backend on port {port} is ready")
break
except Exception:
time.sleep(1)
else:
print(f" ❌ Backend on port {port} failed to start")
print(f"\n{'='*60}")
print(f" ✅ VibeVoice DP server ready on port {frontend_port}")
print(f" {data_parallel_size} replicas behind nginx load balancer")
print(f"{'='*60}\n")
# Handle shutdown: forward signals to all children
def _shutdown(signum, frame):
print("\nShutting down ...")
nginx_proc.terminate()
for w in workers:
w.terminate()
for w in workers:
w.wait(timeout=10)
nginx_proc.wait(timeout=5)
sys.exit(0)
signal.signal(signal.SIGTERM, _shutdown)
signal.signal(signal.SIGINT, _shutdown)
# Wait for any child to exit (indicates a failure)
while True:
for i, w in enumerate(workers):
ret = w.poll()
if ret is not None:
print(f" ❌ Worker {i} exited with code {ret}")
_shutdown(None, None)
if nginx_proc.poll() is not None:
print(f" ❌ nginx exited with code {nginx_proc.returncode}")
_shutdown(None, None)
time.sleep(1)
def main():
parser = argparse.ArgumentParser(
description="VibeVoice vLLM ASR Server - One-Click Deployment",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Start with default settings (single GPU)
# Start with default settings
python3 start_server.py
# Use custom port
python3 start_server.py --port 8080
# Data parallel: 4 replicas on 4 GPUs (nginx load balancing)
python3 start_server.py --dp 4
# Tensor parallel: split model across 2 GPUs
python3 start_server.py --tp 2
# Skip dependency installation (if already installed)
python3 start_server.py --skip-deps
"""
@@ -371,41 +141,6 @@ Examples:
action="store_true",
help="Skip generating tokenizer files"
)
parser.add_argument(
"--tp", "--tensor-parallel-size",
type=int,
default=1,
dest="tensor_parallel_size",
help="Tensor parallel size: split one model across N GPUs (default: 1)"
)
parser.add_argument(
"--dp", "--data-parallel-size",
type=int,
default=1,
dest="data_parallel_size",
help="Data parallel size: run N independent model replicas for load balancing (default: 1)"
)
parser.add_argument(
"--max-num-seqs",
type=int,
default=64,
dest="max_num_seqs",
help="Maximum number of sequences per batch (default: 64)"
)
parser.add_argument(
"--max-model-len",
type=int,
default=65536,
dest="max_model_len",
help="Maximum model context length (default: 65536)"
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.8,
dest="gpu_memory_utilization",
help="GPU memory utilization fraction (default: 0.8)"
)
args = parser.parse_args()
print("\n" + "="*60)
@@ -426,25 +161,8 @@ Examples:
if not args.skip_tokenizer:
generate_tokenizer(model_path)
# Step 5: Start server
if args.data_parallel_size > 1:
start_dp_server(
model_path, args.port,
data_parallel_size=args.data_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.max_num_seqs,
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
)
else:
start_vllm_server(
model_path, args.port,
tensor_parallel_size=args.tensor_parallel_size,
data_parallel_size=1,
max_num_seqs=args.max_num_seqs,
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
)
# Step 5: Start vLLM server
start_vllm_server(model_path, args.port)
if __name__ == "__main__":
+85 -101
View File
@@ -1,23 +1,14 @@
#!/usr/bin/env python3
"""
Test VibeVoice vLLM API with Streaming and Optional Hotwords Support.
This script tests ASR transcription via the vLLM OpenAI-compatible API.
By default, it runs standard transcription without hotwords.
Optionally, you can provide hotwords (context_info) to improve recognition
of domain-specific content like proper nouns, technical terms, and speaker names.
Hotwords are embedded in the prompt as "with extra info: {hotwords}".
Test VibeVoice vLLM API with Streaming (Real-time output).
Usage:
python test_api_with_hotwords.py [audio_path] [--url URL] [--hotwords "word1,word2"]
python test_api.py [audio_path] [--url URL]
Examples:
# Standard transcription (no hotwords)
python3 test_api.py audio.wav
# With hotwords for better recognition of specific terms
python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice"
python test_api.py # Use default audio
python test_api.py /path/to/audio.wav # Specify audio file
python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL
"""
import requests
import json
@@ -30,38 +21,38 @@ import argparse
def _guess_mime_type(path: str) -> str:
"""Guess MIME type from file extension."""
ext = os.path.splitext(path)[1].lower()
mime_map = {
".wav": "audio/wav",
".mp3": "audio/mpeg",
".m4a": "audio/mp4",
".mp4": "video/mp4",
".flac": "audio/flac",
".ogg": "audio/ogg",
".opus": "audio/ogg",
}
return mime_map.get(ext, "application/octet-stream")
if ext == ".wav":
return "audio/wav"
if ext in (".mp3",):
return "audio/mpeg"
if ext in (".m4a",):
return "audio/mp4"
if ext in (".mp4", ".m4v", ".mov", ".webm"):
return "video/mp4"
if ext in (".flac",):
return "audio/flac"
if ext in (".ogg", ".opus"):
return "audio/ogg"
return "application/octet-stream"
def _get_duration_seconds_ffprobe(path: str) -> float:
"""Get audio duration using ffprobe."""
cmd = [
"ffprobe", "-v", "error",
"-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1",
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
path,
]
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
return float(out)
def _is_video_file(path: str) -> bool:
"""Check if the file is a video file that needs audio extraction."""
ext = os.path.splitext(path)[1].lower()
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
def _extract_audio_from_video(video_path: str) -> str:
"""
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
@@ -83,40 +74,26 @@ def _extract_audio_from_video(video_path: str) -> str:
return audio_path
def test_transcription_with_hotwords(
audio_path: str,
context_info: str = None,
base_url: str = "http://localhost:8000",
):
"""
Test ASR transcription with customized hotwords.
def _is_video_file(path: str) -> bool:
"""Check if the file is a video file that needs audio extraction."""
ext = os.path.splitext(path)[1].lower()
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
Hotwords are embedded in the prompt text as "with extra info: {hotwords}".
This helps the model recognize domain-specific terms more accurately.
Args:
audio_path: Path to the audio file
context_info: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
base_url: vLLM server URL
"""
def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"):
"""Test ASR transcription with streaming output."""
print(f"=" * 70)
print(f"Testing Customized Hotwords Support")
print(f"=" * 70)
print(f"Input file: {audio_path}")
print(f"Hotwords: {context_info or '(none)'}")
print()
print(f"Loading audio from: {audio_path}")
# Handle video files: extract audio first
temp_audio_path = None
actual_audio_path = audio_path
if _is_video_file(audio_path):
print(f"🎬 Detected video file, extracting audio...")
print(f"Detected video file, extracting audio...")
temp_audio_path = _extract_audio_from_video(audio_path)
actual_audio_path = temp_audio_path
print(f"Audio extracted to: {temp_audio_path}")
print(f"Audio extracted to: {temp_audio_path}")
# Load audio
try:
duration = _get_duration_seconds_ffprobe(actual_audio_path)
print(f"Audio duration: {duration:.2f} seconds")
@@ -129,30 +106,16 @@ def test_transcription_with_hotwords(
except Exception as e:
print(f"Error preparing audio: {e}")
# Cleanup temp file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
return
# Build the request
url = f"{base_url}/v1/chat/completions"
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
# Build prompt with optional hotwords
# Hotwords are embedded as "with extra info: {hotwords}" in the prompt
if context_info and context_info.strip():
prompt_text = (
f"This is a {duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\n"
f"Please transcribe it with these keys: " + ", ".join(show_keys)
)
print(f"\n📝 Hotwords embedded in prompt: '{context_info}'")
else:
prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys)
)
print(f"\n📝 No hotwords provided")
mime = _guess_mime_type(actual_audio_path)
data_url = f"data:{mime};base64,{audio_b64}"
@@ -176,19 +139,20 @@ def test_transcription_with_hotwords(
"temperature": 0.0,
"stream": True,
"top_p": 1.0,
"repetition_penalty": 1.0,
}
print(f"\n{'=' * 70}")
print(f"Sending request to {url}")
print(f"{'=' * 70}")
print(f"\nSending request to {url} (Streaming Mode)...")
print(f"Prompt: {prompt_text}")
print("-" * 60)
t0 = time.time()
try:
response = requests.post(url, json=payload, stream=True, timeout=12000)
if response.status_code == 200:
print("\nResponse received. Streaming content:\n")
print("-" * 50)
print("Response received. Streaming content:\n")
printed = ""
for line in response.iter_lines():
@@ -198,72 +162,92 @@ def test_transcription_with_hotwords(
if decoded_line.startswith("data: "):
json_str = decoded_line[6:]
if json_str.strip() == "[DONE]":
print("\n" + "-" * 50)
print("✅ [Finished]")
print("\n\n[Finished]")
break
try:
data = json.loads(json_str)
delta = data['choices'][0]['delta']
content = delta.get('content', '')
if content:
# vLLM/OpenAI-compatible streams may emit either
# incremental deltas OR the full accumulated text.
# Only print the newly-added part to avoid repeats.
if content.startswith(printed):
to_print = content[len(printed):]
else:
to_print = content
if to_print:
print(to_print, end='', flush=True)
printed += to_print
except json.JSONDecodeError:
pass
else:
print(f"Error: {response.status_code}")
print(f"Error: {response.status_code}")
print(response.text)
except requests.exceptions.Timeout:
print("Request timed out!")
print("\nRequest timed out!")
except Exception as e:
print(f"Error: {e}")
print(f"\nError: {e}")
elapsed = time.time() - t0
print(f"\n{'=' * 70}")
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
print(f"📊 RTF (Real-Time Factor): {elapsed / duration:.2f}x")
print(f"{'=' * 70}")
print(f"\n{'-'*60}")
print(f"Total time elapsed: {time.time() - t0:.2f}s")
# Cleanup temp audio file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
print(f"Cleaned up temp file: {temp_audio_path}")
def main():
parser = argparse.ArgumentParser(
description="Test VibeVoice vLLM API with Customized Hotwords"
description="Test VibeVoice vLLM API with streaming output"
)
parser.add_argument(
"audio_path",
nargs="?",
default=None,
help="Path to audio file (wav, mp3, flac, etc.) or video file"
)
parser.add_argument(
"--url",
default="http://localhost:8000",
help="vLLM server URL (default: http://localhost:8000)"
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
help="vLLM server base URL (default: http://localhost:8000)"
)
args = parser.parse_args()
# Run test
test_transcription_with_hotwords(
audio_path=args.audio_path,
context_info=args.hotwords,
base_url=args.url,
)
# Find default audio if not specified
audio_path = args.audio_path
if audio_path is None:
# Try to find a sample audio in common locations
possible_paths = [
# In VibeVoice demo folder
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "en-Carter_man.wav"),
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"),
# Relative to current directory
"demo/voices/en-Carter_man.wav",
"demo/voices/zh-Anchen_man_bgm.wav",
]
for path in possible_paths:
if os.path.exists(path):
audio_path = path
break
if audio_path is None:
print("Error: No audio file specified and no default audio found.")
print("Usage: python test_api.py <audio_path>")
sys.exit(1)
if not os.path.exists(audio_path):
print(f"Error: Audio file not found: {audio_path}")
sys.exit(1)
test_transcription(audio_path, args.url)
if __name__ == "__main__":
-638
View File
@@ -1,638 +0,0 @@
#!/usr/bin/env python3
"""
Test VibeVoice vLLM API with Streaming, Hotwords, and Auto-Recovery.
This script tests ASR transcription with automatic recovery from repetition loops.
Supports optional hotwords to improve recognition of domain-specific terms.
Features:
- Streaming output with real-time repetition detection
- Auto-recovery when model enters repetition loops
- Optional hotwords support (embedded in prompt as "with extra info: {hotwords}")
- Video file support (auto-extracts audio)
Recovery Strategy:
1. First attempt: greedy decoding (temperature=0, top_p=1.0)
2. If loop detected: retry with temperature=0.2/0.3/0.4, top_p=0.95
3. Max 3 retries, truncate to last complete segment boundary
Usage:
python test_api_auto_recover.py <audio_path> [output_path] [--url URL] [--hotwords "word1,word2"] [--debug]
Examples:
# Basic usage
python3 test_api_auto_recover.py audio.wav
# With hotwords
python3 test_api_auto_recover.py audio.wav --hotwords "Microsoft,VibeVoice"
# Save result to file
python3 test_api_auto_recover.py audio.wav result.txt
# Debug mode (show recovery info)
python3 test_api_auto_recover.py audio.wav --debug
"""
import requests
import json
import base64
import time
import sys
import os
import subprocess
import re
import argparse
from collections import Counter
def _guess_mime_type(path: str) -> str:
ext = os.path.splitext(path)[1].lower()
if ext == ".wav":
return "audio/wav"
if ext in (".mp3",):
return "audio/mpeg"
if ext in (".m4a",):
return "audio/mp4"
if ext in (".mp4", ".m4v", ".mov", ".webm"):
return "video/mp4"
if ext in (".flac",):
return "audio/flac"
if ext in (".ogg", ".opus"):
return "audio/ogg"
return "application/octet-stream"
def _get_duration_seconds_ffprobe(path: str) -> float:
cmd = [
"ffprobe", "-v", "error",
"-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1",
path,
]
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
return float(out)
def _extract_audio_from_video(video_path: str) -> str:
"""
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
Returns the path to the extracted audio file.
"""
import tempfile
# Create temp file with .mp3 extension
fd, audio_path = tempfile.mkstemp(suffix=".mp3")
os.close(fd)
cmd = [
"ffmpeg", "-y", "-i", video_path,
"-vn", # No video
"-acodec", "libmp3lame",
"-q:a", "2", # High quality
audio_path
]
subprocess.run(cmd, check=True, capture_output=True)
return audio_path
def _is_video_file(path: str) -> bool:
"""Check if the file is a video file that needs audio extraction."""
ext = os.path.splitext(path)[1].lower()
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
def _find_last_segment_boundary(text: str) -> int:
"""
Find the position after the last complete segment boundary (},).
Returns -1 if no complete segment found.
"""
# Find last "}, " or "}," pattern (segment separator)
pos = text.rfind("},")
if pos != -1:
return pos + 2 # Include the },
return -1
def _find_safe_print_boundary(text: str, max_pos: int) -> int:
"""
Find the last complete segment boundary before max_pos.
Returns 0 if no complete segment found before max_pos.
"""
search_text = text[:max_pos]
pos = search_text.rfind("},")
if pos != -1:
return pos + 2 # Include the },
return 0
class RepetitionDetector:
"""Detect repetition patterns in streaming text."""
def __init__(self,
min_pattern_len: int = 10, # Minimum chars for a pattern
min_repeats: int = 3, # Minimum repetitions to trigger
window_size: int = 500): # Window to check for patterns
self.min_pattern_len = min_pattern_len
self.min_repeats = min_repeats
self.window_size = window_size
self.text = ""
def add_text(self, new_text: str):
"""Add new text and return (is_looping, good_text_end_pos)."""
self.text += new_text
return self._check_repetition()
def _check_repetition(self):
"""Check if the recent text contains repetition loops."""
if len(self.text) < self.min_pattern_len * self.min_repeats:
return False, len(self.text)
# Check the recent window
window = self.text[-self.window_size:] if len(self.text) > self.window_size else self.text
# Method 1: Check for repeated substrings
for pattern_len in range(self.min_pattern_len, len(window) // self.min_repeats + 1):
# Get the last pattern_len characters as potential pattern
pattern = window[-pattern_len:]
# Count how many times this pattern appears at the end
count = 0
pos = len(window)
while pos >= pattern_len:
if window[pos - pattern_len:pos] == pattern:
count += 1
pos -= pattern_len
else:
break
if count >= self.min_repeats:
# Found repetition! Calculate where the good text ends
repetition_start = len(self.text) - (count * pattern_len)
# Keep one instance of the pattern (or none if it's garbage)
good_end = repetition_start + pattern_len if self._is_meaningful(pattern) else repetition_start
return True, good_end
# Method 2: Check for repeated short phrases (like "you're not, you're not")
# Look for patterns like "X, X, X" or "X X X"
words = window.split()
if len(words) >= self.min_repeats * 2:
# Check last N words for repetition
for phrase_len in range(2, 6): # 2-5 word phrases
if len(words) < phrase_len * self.min_repeats:
continue
phrase = " ".join(words[-phrase_len:])
count = 0
idx = len(words)
while idx >= phrase_len:
candidate = " ".join(words[idx - phrase_len:idx])
if candidate == phrase:
count += 1
idx -= phrase_len
else:
break
if count >= self.min_repeats:
# Calculate position in original text
repeated_text = (phrase + " ") * count
good_end = len(self.text) - len(repeated_text.rstrip()) + len(phrase)
return True, max(0, good_end)
return False, len(self.text)
def _is_meaningful(self, pattern: str) -> bool:
"""Check if pattern is meaningful content (not just garbage)."""
# Filter out patterns that are just punctuation, spaces, or very repetitive
clean = pattern.strip()
if not clean:
return False
if len(set(clean)) < 3: # Too few unique characters
return False
return True
def get_good_text(self, end_pos: int) -> str:
"""Get text up to the specified position."""
return self.text[:end_pos]
def reset(self, keep_text: str = ""):
"""Reset detector, optionally keeping some text."""
self.text = keep_text
def stream_with_recovery(
url: str,
base_messages: list,
audio_data_url: str,
prompt_text: str,
max_tokens: int = 32768,
max_retries: int = 3,
timeout: int = 12000,
debug: bool = False,
):
"""
Stream transcription with automatic recovery from repetition loops.
Args:
url: API endpoint
base_messages: Base messages (system + user with audio)
audio_data_url: The audio data URL for the request
prompt_text: The text prompt
max_tokens: Maximum tokens to generate
max_retries: Maximum recovery attempts (default 3)
timeout: Request timeout
debug: If True, show recovery debug info to stderr
Recovery strategy:
- First attempt: temperature=0.0, top_p=1.0 (greedy)
- Recovery: temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
- If has complete segments: use assistant prefix
- If no complete segments: restart from scratch
- Max 3 retries, if all fail output error message
Returns:
Final transcription text
"""
import sys as _sys
def _log(msg):
"""Log to stderr only if debug."""
if debug:
print(msg, file=_sys.stderr)
detector = RepetitionDetector(
min_pattern_len=10, # At least 10 chars for a pattern
min_repeats=10, # Must repeat 10+ times
window_size=400, # Check last 400 chars (can detect 10-40 char patterns repeated 10 times)
)
accumulated_text = ""
retry_count = 0
user_safe_printed_len = 0 # Track how much we've safely shown to user (at segment boundaries)
is_recovery = False # Whether we're in recovery mode
while retry_count <= max_retries:
# Build request payload
messages = list(base_messages) # Copy base messages
# If we have accumulated text from previous attempt, add it as partial assistant response
if accumulated_text:
# Add the good content as a partial assistant message
# vLLM will continue from here
messages.append({
"role": "assistant",
"content": accumulated_text
})
# Set sampling parameters based on recovery state
if is_recovery:
# Recovery: increase temperature each retry to break loops
recovery_temp = 0.1 + 0.1 * retry_count # 0.2, 0.3, 0.4 for retry 1, 2, 3
payload = {
"model": "vibevoice",
"messages": messages,
"max_tokens": max_tokens,
"temperature": recovery_temp,
"top_p": 0.95,
"stream": True,
}
if accumulated_text:
_log(f"[RECOVERY #{retry_count}] Continuing from {len(accumulated_text)} chars with temp={recovery_temp}, top_p=0.95")
else:
_log(f"[RECOVERY #{retry_count}] Restarting from scratch with temp={recovery_temp}, top_p=0.95")
else:
# First attempt: greedy decoding
payload = {
"model": "vibevoice",
"messages": messages,
"max_tokens": max_tokens,
"temperature": 0.0,
"top_p": 1.0,
"stream": True,
}
try:
response = requests.post(url, json=payload, stream=True, timeout=timeout)
if response.status_code != 200:
_log(f"[ERROR] {response.status_code} - {response.text[:500]}")
return accumulated_text
new_text = ""
printed = "" # Track what we've already received to handle vLLM duplicates
for line in response.iter_lines():
if not line:
continue
decoded_line = line.decode('utf-8')
if not decoded_line.startswith("data: "):
continue
json_str = decoded_line[6:]
if json_str.strip() == "[DONE]":
# Successfully finished without loops
full_result = accumulated_text + new_text
# Print any remaining content that wasn't printed yet
if len(full_result) > user_safe_printed_len:
remaining = full_result[user_safe_printed_len:]
print(remaining, end='', flush=True)
print() # Final newline
return full_result
try:
data = json.loads(json_str)
delta = data['choices'][0].get('delta', {})
content = delta.get('content', '')
if content:
# vLLM/OpenAI-compatible streams may emit either
# incremental deltas OR the full accumulated text.
# Only track the newly-added part.
if content.startswith(printed):
to_add = content[len(printed):]
else:
to_add = content
if to_add:
printed += to_add
new_text += to_add
# When continuing from prefix, model may add "[" or "[{" at start
# or repeat the ending "}, " from prefix
# We need to handle these to maintain valid JSON array format
if accumulated_text and new_text:
stripped = new_text.lstrip()
# Case 1: Model added "[{" - remove the "["
if stripped.startswith("[{"):
new_text = stripped[1:]
_log("[STRIPPED leading '[' from continuation]")
# Case 2: Model added just "[" - remove it
elif stripped.startswith("["):
new_text = stripped[1:]
_log("[STRIPPED leading '[' from continuation]")
# Case 3: Model repeated "}," from prefix ending
elif stripped.startswith("},"):
new_text = stripped[2:]
_log("[STRIPPED leading '},' from continuation]")
# Case 4: Model repeated "}" from prefix ending
elif stripped.startswith("}") and not stripped.startswith("}]"):
new_text = stripped[1:]
_log("[STRIPPED leading '}' from continuation]")
# Fix malformed JSON: {"2.99,... -> {"Start":2.99,...
# This happens when model skips "Start": key
import re
malformed = re.match(r'^\{"(\d+\.?\d*),', new_text)
if malformed:
time_val = malformed.group(1)
new_text = '{"Start":' + time_val + ',' + new_text[malformed.end():]
_log(f"[FIXED malformed JSON: added Start key]")
# Check for repetition in the combined text
full_text = accumulated_text + new_text
detector.text = full_text
is_looping, good_end = detector._check_repetition()
if is_looping:
_log(f"[LOOP DETECTED at char {good_end}]")
# Use what user has already seen as prefix for retry
# user_safe_printed_len is always at a segment boundary
if user_safe_printed_len > 0:
accumulated_text = full_text[:user_safe_printed_len]
_log(f"[RETRY from user-visible content at {user_safe_printed_len}]")
else:
# No complete segment shown to user yet - restart from scratch
accumulated_text = ""
_log(f"[NO CONTENT SHOWN TO USER - restart from scratch]")
detector.reset(accumulated_text)
is_recovery = True
if debug:
print("\n[...recovering...]", end='', flush=True, file=sys.stderr)
retry_count += 1
if retry_count > max_retries:
_log(f"[MAX RETRIES REACHED]")
print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
return None
# Break inner loop to retry
break
else:
# No loop detected - stream content to user
# Only print up to (full_text_len - window_size) at segment boundaries
# This ensures user never sees content that might be rolled back
safe_end = max(0, len(full_text) - detector.window_size)
safe_boundary = _find_safe_print_boundary(full_text, safe_end)
if safe_boundary > user_safe_printed_len:
# Print new safe content
to_print = full_text[user_safe_printed_len:safe_boundary]
print(to_print, end='', flush=True)
user_safe_printed_len = safe_boundary
except json.JSONDecodeError:
continue
else:
# Loop completed without break (no repetition detected)
full_result = accumulated_text + new_text
# Print any remaining content that wasn't printed yet
if len(full_result) > user_safe_printed_len:
remaining = full_result[user_safe_printed_len:]
print(remaining, end='', flush=True)
print() # Final newline
return full_result
except requests.exceptions.Timeout:
_log("[TIMEOUT]")
print()
return accumulated_text
except Exception as e:
_log(f"[ERROR: {e}]")
print()
return accumulated_text
# All retries exhausted
print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
return None
def test_transcription_with_recovery(
audio_path: str,
output_path: str = None,
base_url: str = "http://localhost:8000",
hotwords: str = None,
debug: bool = False,
):
"""
Test ASR transcription with auto-recovery from repetition loops.
Args:
audio_path: Path to the audio file
output_path: Optional path to save transcription result
base_url: vLLM server URL
hotwords: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
debug: Show recovery debug info
"""
print(f"=" * 70)
print(f"Testing with Auto-Recovery")
print(f"=" * 70)
print(f"Input file: {audio_path}")
print(f"Hotwords: {hotwords or '(none)'}")
print()
# Handle video files: extract audio first
temp_audio_path = None
actual_audio_path = audio_path
if _is_video_file(audio_path):
print(f"🎬 Detected video file, extracting audio...")
temp_audio_path = _extract_audio_from_video(audio_path)
actual_audio_path = temp_audio_path
print(f"✅ Audio extracted to: {temp_audio_path}")
# Load audio
try:
duration = _get_duration_seconds_ffprobe(actual_audio_path)
print(f"Audio duration: {duration:.2f} seconds")
with open(actual_audio_path, "rb") as f:
audio_bytes = f.read()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
print(f"Audio size: {len(audio_bytes)} bytes")
except Exception as e:
print(f"❌ Error preparing audio: {e}")
# Cleanup temp file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
return
url = f"{base_url}/v1/chat/completions"
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
# Build prompt with optional hotwords
if hotwords and hotwords.strip():
prompt_text = (
f"This is a {duration:.2f} seconds audio, with extra info: {hotwords.strip()}\n\n"
f"Please transcribe it with these keys: " + ", ".join(show_keys)
)
print(f"\n📝 Hotwords embedded in prompt: '{hotwords}'")
else:
prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys)
)
print(f"\n📝 No hotwords provided")
mime = _guess_mime_type(actual_audio_path)
data_url = f"data:{mime};base64,{audio_b64}"
# Base messages (without assistant continuation)
base_messages = [
{
"role": "system",
"content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
},
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": data_url}},
{"type": "text", "text": prompt_text}
]
}
]
print(f"\n{'=' * 70}")
print(f"Sending request to {url}")
print(f"{'=' * 70}")
t0 = time.time()
print("\n✅ Response received. Streaming content:\n")
print("-" * 50)
result = stream_with_recovery(
url=url,
base_messages=base_messages,
audio_data_url=data_url,
prompt_text=prompt_text,
max_tokens=32768,
max_retries=3,
debug=debug,
)
elapsed = time.time() - t0
print("-" * 50)
print("✅ [Finished]")
print(f"\n{'=' * 70}")
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
print(f"{'=' * 70}")
if result is None:
print("❌ Transcription failed")
return
print(f"📄 Final output length: {len(result)} chars")
# Optionally save result
if output_path:
with open(output_path, "w", encoding="utf-8") as f:
f.write(result)
print(f"💾 Result saved to: {output_path}")
# Cleanup temp audio file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
def main():
parser = argparse.ArgumentParser(
description="Test VibeVoice vLLM API with auto-recovery from repetition loops"
)
parser.add_argument(
"audio_path",
help="Path to audio file (wav, mp3, flac, etc.) or video file"
)
parser.add_argument(
"output_path",
nargs="?",
default=None,
help="Optional path to save transcription result"
)
parser.add_argument(
"--url",
default="http://localhost:8000",
help="vLLM server URL (default: http://localhost:8000)"
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
)
parser.add_argument(
"--debug",
action="store_true",
help="Show recovery debug info"
)
args = parser.parse_args()
# Run test
test_transcription_with_recovery(
audio_path=args.audio_path,
output_path=args.output_path,
base_url=args.url,
hotwords=args.hotwords,
debug=args.debug,
)
if __name__ == "__main__":
main()