27 Commits

Author SHA1 Message Date
Zhiliang Peng 3c976491d4 Update README with new TTS report and ICLR oral acceptance
Updated TTS report link and added conference acceptance note.
2026-03-31 12:24:50 +08:00
Jianwei Yu c766f12e23 docs: add Vibing download links
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-29 03:28:59 +00:00
Jianwei Yu 8f133837dc docs: add Vibing demo video to news section
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-29 02:33:10 +00:00
Jianwei Yu 0857b6d59f docs: fix news bold formatting
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-29 01:48:20 +00:00
Jianwei Yu c8371b6bb6 docs: add Vibing voice input adoption news
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-29 01:43:59 +00:00
Jianwei Yu b691f99191 docs: add Trendshift #1 trending badge to README
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-28 17:07:00 +00:00
Jianwei Yu 5cd81bb497 fix: restore sequential encoder (batch encoder causes OOM)
Batch encoder across multiple requests caused GPU OOM when vLLM
scheduler sends many audio items at once. The encoder intermediates
(~700MB per 69s audio) compete with KV cache for GPU memory.

Sequential encoding is stable and proven correct. The encoder
(267ms per request) is not the primary throughput bottleneck when
encoder cache is enabled (default).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-27 18:48:06 +00:00
Jianwei Yu cd945395d4 feat: set nginx workers to 2×dp for optimal HTTP throughput
Nginx worker_processes now defaults to 2×N (where N is the number of DP
replicas) instead of 'auto'. This ensures enough HTTP handler processes
to fully saturate all GPU backends under heavy concurrent load.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-27 09:16:05 +00:00
Jianwei Yu e6b65abb9b fix: auto-tune per-worker env vars in DP mode
Pass VIBEVOICE_FFMPEG_MAX_CONCURRENCY and VLLM_MEDIA_LOADING_THREAD_COUNT
to each worker subprocess so they inherit the correct settings regardless
of how the container is launched (--skip-deps or not).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-27 07:57:49 +00:00
Jianwei Yu 3817f74d46 feat: nginx-based data parallel for optimal ASR throughput
When --dp N is specified (N > 1), the launcher now starts N independent
vLLM processes behind an nginx reverse proxy instead of using vLLM's
built-in DP coordinator. This avoids the single-process HTTP bottleneck
when handling large base64 audio payloads, achieving near-linear scaling
(7.2x with 8 GPUs at 4096 concurrent requests).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-27 07:43:32 +00:00
JianweiYu 9634518ca4 Add data parallel (DP) support to vLLM server launcher
- Add --dp/--data-parallel-size flag for running independent model replicas
  across multiple GPUs with automatic load balancing behind a single port
- Add --tp/--tensor-parallel-size flag (previously hardcoded to 1)
- Update docs/vibevoice-vllm-asr.md with multi-GPU deployment guide
  covering DP, TP, and hybrid (DP × TP) configurations

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-24 11:53:31 +00:00
JianweiYu 09ca114fa3 Add Gradio ASR demo with video support and demo audio/video files
- Add gradio_asr_demo_api_video.py: Gradio web UI supporting audio/video upload,
  streaming output, hotwords, and Cloudflare tunnel
- Add demo/asr_demo/: demo audio and video files for the Gradio interface

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-22 06:11:51 +00:00
Zhiliang Peng 4c419978c9 Merge pull request #255 from sd983527/main
Add news about VibeVoice ASR Transformers integration
2026-03-06 14:08:47 +08:00
Yan Xia 7e73beec97 Add news about VibeVoice ASR Transformers integration
- Added announcement that VibeVoice ASR is now part of Transformers v5.3.0 release
- Linked to the official Hugging Face Transformers release page
- Positioned as the latest news item with today's date

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 13:32:21 +08:00
Li Dong 7ef9dbe300 Merge pull request #247 from Damon-Salvetore/fix/vllm-version-compat
fix: vllm-version-stable
2026-02-28 11:12:24 +08:00
Damon-Salvetore 165e17e5ed fix: vllm-version-stable 2026-02-25 07:30:43 +00:00
Jianwei Yu 1807b858d4 Merge pull request #236 from Damon-Salvetore/main
fix backend
2026-02-10 00:07:05 +08:00
YingboHAO a4add8e52f fix backend 2026-02-08 09:58:19 +00:00
Jianwei Yu ce3d40c78f Merge pull request #233 from Damon-Salvetore/main
Add hot words support
2026-02-07 12:32:03 +08:00
YingboHAO 0508c3e86f fix 2026-02-06 14:38:16 +00:00
YingboHAO 7761242bf3 fix 2026-02-06 05:52:48 +00:00
YingboHAO bb54f78d0e feat: add hotwords support for vLLM ASR 2026-02-04 10:33:20 +00:00
YaoyaoChang 0aa8cb4c64 fix default speaker 2026-02-03 00:35:04 -08:00
YaoyaoChang e43c1e2cdb streaming use transformers==4.51.3 2026-02-03 00:30:52 -08:00
Jianwei Yu e16491d65e Merge pull request #228 from Damon-Salvetore/vllm-1
[Fix] Resolve occasional infinite loops during vLLM inference
2026-02-03 10:38:40 +08:00
YingboHAO e26f1c263f 1 2026-02-02 13:50:27 +00:00
YingboHAO 0055161273 Add test_api_auto_recover.py and test audio files 2026-02-02 13:49:01 +00:00
17 changed files with 3524 additions and 401 deletions
+11 -3
@@ -3,11 +3,13 @@
## 🎙️ VibeVoice: Open-Source Frontier Voice AI
[![Project Page](https://img.shields.io/badge/Project-Page-blue?logo=githubpages)](https://microsoft.github.io/VibeVoice)
[![Hugging Face](https://img.shields.io/badge/HuggingFace-Collection-orange?logo=huggingface)](https://huggingface.co/collections/microsoft/vibevoice-68a2ef24a875c44be47b034f)
[![TTS Report](https://img.shields.io/badge/TTS-Report-red?logo=arxiv)](https://arxiv.org/pdf/2508.19205)
[![TTS Report](https://img.shields.io/badge/TTS-Report-red?logo=arxiv)](https://openreview.net/pdf?id=FihSkzyxdv)
[![ASR Report](https://img.shields.io/badge/ASR-Report-yellow?logo=arxiv)](https://arxiv.org/pdf/2601.18184)
[![Colab](https://img.shields.io/badge/StreamingTTS-Colab-green?logo=googlecolab)](https://colab.research.google.com/github/microsoft/VibeVoice/blob/main/demo/VibeVoice_colab.ipynb)
[![ASR Playground](https://img.shields.io/badge/ASR-Playground-6F42C1?logo=gradio)](https://aka.ms/vibevoice-asr)
[![microsoft%2FVibeVoice | Trendshift](https://trendshift.io/api/badge/repositories/15465)](https://trendshift.io/repositories/15465)
</div>
@@ -22,7 +24,13 @@
<h3>📰 News</h3>
<strong>2026-01-21: 📣 We open-sourced <a href="docs/vibevoice-asr.md"><strong>VibeVoice-ASR</strong></a>, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context. Try it in [Playground](https://aka.ms/vibevoice-asr)</strong>.
<strong>2026-03-29: 🎉 VibeVoice-ASR is being adopted by the open-source community! <a href="https://vibingjustspeakit.github.io/Vibing/">Vibing</a>, a voice-powered input method, is now built on top of VibeVoice-ASR. Download: [macOS](https://github.com/VibingJustSpeakIt/Vibing/releases/download/v0.1.0/Vibing-v0.1.0-mac.dmg) | [Windows](https://github.com/VibingJustSpeakIt/Vibing/releases/download/v0.1.0/Vibing-v0.1.0-windows.exe)</strong>
https://github.com/user-attachments/assets/db0bb23f-ae06-4135-a66a-1ff1669f4f84
<strong>2026-03-06: 🚀 VibeVoice ASR is now part of a <a href="https://github.com/huggingface/transformers/releases/tag/v5.3.0">Transformers release</a>! You can now use our speech recognition model directly through the Hugging Face Transformers library for seamless integration into your projects.</strong>
<strong>2026-01-21:</strong> 📣 We open-sourced <a href="docs/vibevoice-asr.md"><strong>VibeVoice-ASR</strong></a>, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context. Try it in [Playground](https://aka.ms/vibevoice-asr).
- ⭐️ VibeVoice-ASR is natively multilingual, supporting over 50 languages — check the [supported languages](docs/vibevoice-asr.md#language-distribution) for details.
- 🔥 The VibeVoice-ASR [finetuning code](finetuning-asr/README.md) is now available!
- ⚡️ **vLLM inference** is now supported for faster inference; see [vllm-asr](docs/vibevoice-vllm-asr.md) for more details.
@@ -36,7 +44,7 @@
2025-09-05: VibeVoice is an open-source research framework intended to advance collaboration in the speech synthesis community. After release, we discovered instances where the tool was used in ways inconsistent with the stated intent. Since responsible use of AI is one of Microsoft's guiding principles, we have removed the VibeVoice-TTS code from this repository.
2025-08-25: 📣 We open-sourced <a href="docs/vibevoice-tts.md"><strong>VibeVoice-TTS</strong></a>, a long-form multi-speaker text-to-speech model that can synthesize speech up to 90 minutes long with up to 4 distinct speakers.
2025-08-25: 📣 We open-sourced <a href="docs/vibevoice-tts.md"><strong>VibeVoice-TTS</strong></a>, a long-form multi-speaker text-to-speech model that can synthesize speech up to 90 minutes long with up to 4 distinct speakers. It was accepted as an [Oral](https://openreview.net/forum?id=FihSkzyxdv) at ICLR 2026! 🔥
</div>
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+1 -1
@@ -66,7 +66,7 @@
"print(\"✅ Cloned VibeVoice repository\")\n",
"\n",
"# Install project dependencies\n",
"!uv pip --quiet install --system -e /content/VibeVoice[tts]\n",
"!uv pip --quiet install --system -e /content/VibeVoice[streamingtts]\n",
"!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared && chmod +x cloudflared\n",
"print(\"✅ Installed dependencies\")\n",
"\n",
+1 -1
@@ -142,7 +142,7 @@ class StreamingTTSService:
if name and name in self.voice_presets:
return name
default_key = "en-WHTest_man"
default_key = "en-Carter_man"
if default_key in self.voice_presets:
return default_key
+1 -1
@@ -97,7 +97,7 @@ sudo docker run --privileged --net=host --ipc=host --ulimit memlock=-1:-1 --ulim
git clone https://github.com/microsoft/VibeVoice.git
cd VibeVoice/
pip install -e .
pip install -e .[streamingtts]
```
+96 -3
@@ -10,6 +10,7 @@ Deploy VibeVoice ASR model as a high-performance API service using [vLLM](https:
- **📡 OpenAI-Compatible API**: Standard `/v1/chat/completions` endpoint with streaming support
- **🎵 Long Audio Support**: Process up to 60+ minutes of audio in a single request
- **🔌 Plugin Architecture**: No vLLM source code modification required - just install and run
- **⚡ Data Parallel (DP)**: Run independent model replicas across multiple GPUs with automatic load balancing behind a single port
## 🛠️ Installation
@@ -31,10 +32,87 @@ docker run -d --gpus all --name vibevoice-vllm \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:latest \
vllm/vllm-openai:v0.14.1 \
-c "python3 /app/vllm_plugin/scripts/start_server.py"
```
## ⚡ Multi-GPU Deployment
The launcher supports two types of GPU parallelism via `--tp` and `--dp` flags:
| Flag | Name | What it does |
|------|------|-------------|
| `--tp N` | Tensor Parallel | Splits **one model** across N GPUs (for models too large for a single GPU) |
| `--dp N` | Data Parallel | Runs **N independent replicas**, one per GPU, with automatic load balancing behind a single port |
### Data Parallel (Recommended for scaling throughput)
Run N independent replicas on N GPUs with automatic load balancing behind a single port.
When `--dp N` is specified (N > 1), the launcher automatically starts N independent vLLM
processes behind an nginx reverse proxy (2×N workers) for optimal throughput:
```bash
docker run -d --gpus '"device=0,1,2,3"' --name vibevoice-vllm \
--ipc=host \
-p 8000:8000 \
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:v0.14.1 \
-c "python3 /app/vllm_plugin/scripts/start_server.py --dp 4"
```
Run on all 8 GPUs:
```bash
docker run -d --gpus all --name vibevoice-vllm \
--ipc=host \
-p 8000:8000 \
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:v0.14.1 \
-c "python3 /app/vllm_plugin/scripts/start_server.py --dp 8"
```
### Tensor Parallel
Split a single model across 2 GPUs (useful if GPU memory is limited):
```bash
docker run -d --gpus '"device=0,1"' --name vibevoice-vllm \
--ipc=host \
-p 8000:8000 \
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:v0.14.1 \
-c "python3 /app/vllm_plugin/scripts/start_server.py --tp 2"
```
### Hybrid (DP × TP)
Combine both — e.g., 2 replicas, each split across 2 GPUs (4 GPUs total):
```bash
docker run -d --gpus '"device=0,1,2,3"' --name vibevoice-vllm \
--ipc=host \
-p 8000:8000 \
-v $(pwd):/app \
-w /app \
--entrypoint bash \
vllm/vllm-openai:v0.14.1 \
-c "python3 /app/vllm_plugin/scripts/start_server.py --dp 2 --tp 2"
```
> **Note**: Total GPUs required = `dp × tp`. Make sure to expose enough GPU devices in the Docker `--gpus` flag.
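
The sizing rules in this section (total GPUs = `dp × tp`, and 2×N nginx workers for N data-parallel replicas) can be checked before launching a container. The following is a minimal sketch only: `required_gpus`, `nginx_workers`, and `check_layout` are hypothetical helper names used for illustration and are not part of the launcher script.

```python
def required_gpus(dp: int = 1, tp: int = 1) -> int:
    """Number of GPUs a DP x TP layout needs (hypothetical helper)."""
    if dp < 1 or tp < 1:
        raise ValueError("dp and tp must both be >= 1")
    return dp * tp


def nginx_workers(dp: int) -> int:
    """Worker count following the documented 2 x N rule for N replicas."""
    if dp < 1:
        raise ValueError("dp must be >= 1")
    return 2 * dp


def check_layout(dp: int, tp: int, visible_gpus: int) -> None:
    """Raise if fewer GPUs are exposed to the container than the layout needs."""
    needed = required_gpus(dp, tp)
    if visible_gpus < needed:
        raise RuntimeError(
            f"--dp {dp} --tp {tp} needs {needed} GPUs, "
            f"but only {visible_gpus} are visible"
        )


if __name__ == "__main__":
    # Hybrid example from this section: 2 replicas x 2 GPUs each.
    check_layout(dp=2, tp=2, visible_gpus=4)
    print(required_gpus(dp=2, tp=2), nginx_workers(dp=2))
```

If `check_layout` raises, widen the Docker `--gpus` device list or lower `--dp`/`--tp` before starting the server.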
3. View logs
```bash
docker logs -f vibevoice-vllm
@@ -52,10 +130,25 @@ docker logs -f vibevoice-vllm
Once the vLLM server is running, test it with the provided script:
```bash
# Run the test (use container path /app/...)
# Basic transcription
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav
# With hotwords for better recognition of specific terms
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
```
> **Note**: The audio file must be inside the mounted directory (`/app` in the container). Copy your audio to the VibeVoice folder before testing.
```bash
# With auto-recovery from repetition loops (for long audio)
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav
# Auto-recover with hotwords
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
```
> **Note**:
> - The audio/video file must be inside the mounted directory (`/app` in the container). Copy your files to the VibeVoice folder before testing.
> - Hotwords help improve recognition of domain-specific terms like proper nouns, technical terms, and speaker names.
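
For readers who want to call the `/v1/chat/completions` endpoint without the bundled test scripts, the sketch below builds a request body carrying base64-encoded audio. It is an illustration under stated assumptions, not the repository's client: the model name `vibevoice` is a placeholder, the `audio_url` content part follows vLLM's OpenAI-compatible multimodal convention as far as we know, and the prompt wording is borrowed from the plugin's model code rather than from `test_api.py`, which may differ. How hotwords are passed is not shown here, so it is left out.

```python
import base64
import json

SYSTEM_PROMPT = (
    "You are a helpful assistant that transcribes audio input "
    "into text output in JSON format."
)


def build_transcription_request(
    audio_bytes: bytes,
    duration_s: float,
    model: str = "vibevoice",  # assumption: placeholder model name
    mime: str = "audio/wav",
) -> dict:
    """Build a chat-completions payload with inline base64 audio (illustrative)."""
    encoded = base64.b64encode(audio_bytes).decode("ascii")
    data_url = f"data:{mime};base64,{encoded}"
    user_text = (
        f"This is a {duration_s:.2f} seconds audio, please transcribe it "
        "with these keys: Start time, End time, Speaker ID, Content"
    )
    return {
        "model": model,
        "stream": True,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
                    # assumption: audio is passed as an `audio_url` data URL
                    {"type": "audio_url", "audio_url": {"url": data_url}},
                    {"type": "text", "text": user_text},
                ],
            },
        ],
    }


if __name__ == "__main__":
    payload = build_transcription_request(b"\x00" * 16, duration_s=1.5)
    # Only the shape is printed; no request is sent in this sketch.
    print(json.dumps(payload)[:80])
```

The resulting dictionary can be sent as JSON to `http://localhost:8000/v1/chat/completions` with any HTTP client once the server is up.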
### Environment Variables
+6
@@ -38,6 +38,12 @@ dependencies = [
"requests",
]
[project.optional-dependencies]
streamingtts = [
"transformers==4.51.3",
]
[project.entry-points."vllm.general_plugins"]
vibevoice = "vllm_plugin:register_vibevoice"
-1
@@ -15,7 +15,6 @@ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
from .model import VibeVoiceForCausalLM
from .inputs import vibevoice_audio_input_mapper
def register_vibevoice():
+156 -284
@@ -5,17 +5,11 @@ This module implements the VibeVoice ASR model with full vLLM multimodal registr
integration for speech-to-text inference.
"""
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal
import json
import math
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from io import BytesIO
import tempfile
import base64
@@ -29,32 +23,12 @@ import base64
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
def _suffix_from_media_type(media_type: str | None) -> str:
if not media_type:
return ".bin"
mt = media_type.lower().strip()
if mt in ("audio/wav", "audio/x-wav", "audio/wave"):
return ".wav"
if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"):
return ".mp3"
if mt in ("audio/flac",):
return ".flac"
if mt in ("audio/ogg", "audio/opus"):
return ".ogg"
if mt in ("audio/mp4", "audio/m4a"):
return ".m4a"
if mt in ("video/mp4",):
return ".mp4"
return ".bin"
def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]:
"""Load audio bytes using FFmpeg.
def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]:
"""Load audio bytes using FFmpeg via stdin-pipe decoding.
Returns:
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
"""
# Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency.
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
normalizer = AudioNormalizer()
audio = normalizer(audio)
@@ -72,91 +46,53 @@ def _ffmpeg_load_file(filepath) -> tuple[np.ndarray, int]:
return audio, sr
# Register FFmpeg-based audio loader
import vllm.multimodal.audio as _vllm_audio_module
import warnings
# Handle both old and new vLLM versions
# In newer versions, AudioMediaIO may not exist or may have been moved
try:
# Try new location (vLLM >= 0.6.x)
from vllm.multimodal.media.audio import AudioMediaIO as _OriginalAudioMediaIO
except ImportError:
# Fall back to old location (vLLM < 0.6.x)
import vllm.multimodal.audio as _vllm_audio_module
_OriginalAudioMediaIO = _vllm_audio_module.AudioMediaIO
class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
def __init__(self, **kwargs):
# Call parent constructor if it exists
if hasattr(super(), '__init__'):
super().__init__(**kwargs)
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(data, media_type=None)
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
def load_file(self, filepath) -> tuple[np.ndarray, int]:
return _ffmpeg_load_file(filepath)
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(data)
# Replace globally
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(base64.b64decode(data))
def load_file(self, filepath) -> tuple[np.ndarray, int]:
return _ffmpeg_load_file(filepath)
# Replace globally
try:
# For new vLLM versions
import vllm.multimodal.media.audio as _vllm_audio_module
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
# Also patch in utils module where it's imported
try:
import vllm.multimodal.utils as _vllm_utils_module
_vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
except (ImportError, AttributeError) as e:
warnings.warn(f"Could not patch AudioMediaIO in vllm.multimodal.utils: {e}", UserWarning)
except (AttributeError, ImportError) as e:
# AudioMediaIO doesn't exist in this vLLM version
# Define our own standalone implementation
warnings.warn(
f"AudioMediaIO not found in vllm.multimodal.audio ({e}). "
"Using standalone FFmpeg-based implementation for compatibility.",
UserWarning
)
class _PatchedAudioMediaIO:
"""Standalone AudioMediaIO implementation using FFmpeg for audio decoding.
This is used when vLLM doesn't provide AudioMediaIO or it's been moved/removed.
"""
def __init__(self, **kwargs):
pass
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(data, media_type=None)
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
def load_file(self, filepath) -> tuple[np.ndarray, int]:
return _ffmpeg_load_file(filepath)
# Try to register it in the module if possible
try:
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
except (AttributeError, TypeError) as e:
warnings.warn(
f"Could not register AudioMediaIO in vllm.multimodal.audio: {e}. "
"Audio loading will use FFmpeg functions directly.",
UserWarning
)
except ImportError:
# For old vLLM versions
import vllm.multimodal.audio as _vllm_audio_module
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
# Also patch in utils module where it's imported
try:
import vllm.multimodal.utils as _vllm_utils_module
_vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
except (ImportError, AttributeError):
# AudioMediaIO might not be imported in utils in newer versions
pass
# ============================================================================
from transformers import Qwen2Config, BatchFeature
from transformers import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.config import VllmConfig, ModelConfig
from vllm.config.speech_to_text import SpeechToTextConfig
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import MultiModalDataParser
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
from vllm.inputs import PromptType
from vllm.model_executor.models.utils import (
init_vllm_registered_model,
maybe_prefix,
@@ -171,7 +107,17 @@ from vllm.multimodal.processing import (
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
try:
# Try new location (vLLM >= 0.6.x)
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
except ImportError:
# Fall back to old location (vLLM < 0.6.x)
try:
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
except ImportError:
# If neither location works, try individual imports
from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
from vllm.multimodal.processing.inputs import ProcessorInputs
# Import VibeVoice components
from vibevoice.modular.modular_vibevoice_tokenizer import (
@@ -608,30 +554,88 @@ class VibeVoiceProcessingInfo(BaseProcessingInfo):
return tokens
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
return {"audio": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
"""Return the maximum number of audio tokens per item.
This tells vLLM's scheduler the upper bound so that
``encoder_compute_budget`` is large enough for any audio length
the model can handle, preventing the silent scheduling deadlock
described in docs/max_num_batched_tokens_issue.md.
Formula: audio_tokens = ceil(audio_samples / compress_ratio) + 3
where +3 accounts for speech_start, speech_end, and newline tokens.
The max audio samples is bounded by seq_len (the model's context
window cannot hold more tokens than that).
"""
hf_config = self.get_hf_config()
def _cfg(key: str, default):
if isinstance(hf_config, dict):
return hf_config.get(key, default)
return getattr(hf_config, key, default)
compress_ratio = int(_cfg("speech_tok_compress_ratio", 3200))
sample_rate = int(_cfg("target_sample_rate", 24000))
# Upper bound: 61-minute audio at 24 kHz
max_audio_samples = 61 * 60 * sample_rate # 87,840,000 at 24 kHz
max_audio_tokens = int(np.ceil(max_audio_samples / compress_ratio)) + 3
# Cannot exceed the model's context window
max_audio_tokens = min(max_audio_tokens, seq_len)
return {"audio": max_audio_tokens}
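
As a worked instance of the formula in the docstring above, the standalone sketch below reproduces the same arithmetic outside vLLM. The function name `max_audio_tokens` is hypothetical, and the defaults (24 kHz, compress ratio 3200, 61 minutes) are simply the fallback values used in this method, not values read from a real model config.

```python
import math


def max_audio_tokens(
    seq_len: int,
    minutes: int = 61,
    sample_rate: int = 24000,
    compress_ratio: int = 3200,
) -> int:
    """Upper bound on audio tokens per item, capped by the context window."""
    max_samples = minutes * 60 * sample_rate  # 87,840,000 with the defaults
    # +3 covers the speech_start, speech_end, and newline tokens.
    tokens = math.ceil(max_samples / compress_ratio) + 3
    return min(tokens, seq_len)


if __name__ == "__main__":
    print(max_audio_tokens(seq_len=65536))  # 27450 + 3 = 27453
    print(max_audio_tokens(seq_len=8192))   # capped at 8192
```

With these defaults each second of audio costs 24000 / 3200 = 7.5 tokens, so the 61-minute bound only fits when the context window holds at least 27,453 tokens.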
class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]):
"""
Build dummy inputs for multimodal profiling.
Dummy text uses the raw <|AUDIO|> token(s). vLLM's processing pipeline will
expand each <|AUDIO|> via `VibeVoiceMultiModalProcessor._get_prompt_updates`
into the full ASR format:
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
where N is derived from audio length / compress_ratio.
vLLM uses dummy inputs to:
1. Measure peak GPU activation memory → determine KV cache capacity
2. Warm up CUDA graphs
The dummy audio length must be consistent with ``get_mm_max_tokens_per_item``
so that the memory estimate covers the worst-case (longest audio) scenario.
"""
def _get_max_audio_samples(self, seq_len: int) -> int:
"""Compute maximum audio samples consistent with ``get_mm_max_tokens_per_item``.
Uses the same formula: max_tokens = min(ceil(61min * sr / ratio) + 3, seq_len),
then converts back to samples.
"""
hf_config = self.info.get_hf_config()
def _cfg(key: str, default):
if isinstance(hf_config, dict):
return hf_config.get(key, default)
return getattr(hf_config, key, default)
compress_ratio = int(_cfg("speech_tok_compress_ratio", 3200))
sample_rate = int(_cfg("target_sample_rate", 24000))
# Upper bound: 61-minute audio at 24 kHz
max_hour_samples = 61 * 60 * sample_rate # 87,840,000 at 24 kHz
max_tokens_from_audio = int(np.ceil(max_hour_samples / compress_ratio)) + 3
# Cannot exceed model context window
max_tokens = min(max_tokens_from_audio, seq_len)
# Convert tokens back to samples
return max_tokens * compress_ratio
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
if num_audios <= 0:
return ""
# Get the audio token from our token info helper
token_info = self.info.get_audio_token_info()
audio_token = token_info["audio_token"]
# Return ONLY the audio tokens - the HF processor adds bos/eos
return audio_token * num_audios
def get_dummy_mm_data(
@@ -640,16 +644,23 @@ class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo
mm_counts: Mapping[str, int],
mm_options: Mapping[str, Any] | None = None,
) -> Dict[str, Any]:
"""Generate dummy audio data for profiling."""
feature_extractor = self.info.get_feature_extractor()
"""Generate dummy audio data for profiling.
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
The audio length is derived from ``seq_len`` so that profiling
accurately measures memory for the longest audio the model can handle.
Supports ``AudioDummyOptions.length`` override for faster startup.
"""
num_audios = mm_counts.get("audio", 0)
# Generate dummy audio as numpy arrays (what the HF processor expects)
max_audio_len = self._get_max_audio_samples(seq_len)
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio": [np.zeros(audio_len, dtype=np.float32) for _ in range(num_audios)]
"audio": self._get_dummy_audios(
length=max_audio_len,
num_audios=num_audios,
overrides=audio_overrides,
)
}
def get_dummy_processor_inputs(
@@ -923,17 +934,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
with a causal language model for text generation.
"""
# SupportsTranscription interface
supports_transcription: ClassVar[Literal[True]] = True
supports_transcription_only: ClassVar[bool] = False
supports_segment_timestamp: ClassVar[bool] = False
# Supported languages (Chinese as primary target)
supported_languages: ClassVar[Mapping[str, str]] = {
"zh": "Chinese",
"en": "English",
}
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""Return the placeholder string format for a given modality.
@@ -947,112 +947,10 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return "<|AUDIO|>"
raise ValueError(f"Unsupported modality: {modality}")
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
"""Get the prompt for the ASR model.
Generates a chat-formatted prompt for speech-to-text transcription
with JSON output format.
"""
# If user provides custom prompt, use it
if request_prompt:
return request_prompt
# Calculate audio duration for the prompt
# Audio should be at 24kHz, so duration = len(audio) / 24000
duration = len(audio) / 24000 if audio is not None else 10.0
system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format."
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
user_suffix = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys)
)
# IMPORTANT: keep <|AUDIO|> as the only placeholder token here.
# `_get_prompt_updates` expands it into repeated `<|AUDIO|>` placeholders.
user_content = "<|AUDIO|>\n" + user_suffix
prompt = (
f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
f"<|im_start|>user\n{user_content}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return prompt
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
) -> SpeechToTextConfig:
"""Get the speech to text config for the ASR model."""
return SpeechToTextConfig(
language=None, # Auto-detect or use request language
task_type=task_type,
)
@classmethod
def get_num_audio_tokens(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> int | None:
"""Estimate number of audio tokens from duration.
Returns the number of audio EMBEDDING positions (speech_pad_id tokens).
Note: _get_prompt_updates actually generates:
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
So total prompt tokens = N + 3, but this returns N (the embedding count).
"""
sampling_rate = 24000
compress_ratio = 3200
samples = int(audio_duration_s * sampling_rate)
num_tokens = int(np.ceil(samples / compress_ratio))
return num_tokens
@classmethod
def get_other_languages(cls) -> Mapping[str, str]:
"""Get languages from Whisper map not natively supported."""
# Import LANGUAGES from vllm
try:
from vllm.transformers_utils.tokenizer import LANGUAGES
except ImportError:
# Fallback to empty dict if import fails
return {}
return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages}
@classmethod
def validate_language(cls, language: str | None) -> str | None:
"""Validate the language code."""
if language is None or language in cls.supported_languages:
return language
elif language in cls.get_other_languages():
print(f"Warning: Language {language!r} is not natively supported")
return language
else:
raise ValueError(
f"Unsupported language: {language!r}. "
f"Supported: {list(cls.supported_languages.keys())}"
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
# Keep a copy of the resolved model path for any custom weight-loading
# logic (e.g., loading audio encoder weights in fp32 directly from
# safetensors shards).
self._model_path = vllm_config.model_config.model
self.audio_encoder = VibeVoiceAudioEncoder(config)
@@ -1150,76 +1048,54 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# Process each audio through the VibeVoice encoder
embeddings = []
# Get model device and dtype for alignment
# Get model device for tensor placement.
try:
device = next(self.audio_encoder.parameters()).device
dtype = next(self.audio_encoder.parameters()).dtype
except StopIteration:
# Fallback if no parameters (shouldn't happen)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
# Handle both stacked tensor and list of tensors
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
if isinstance(raw_audio, torch.Tensor):
if raw_audio.dim() == 3:
# Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
elif raw_audio.dim() == 2:
# Shape: [batch_size, seq_len]
num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i] for i in range(num_audios)]
else:
# Single 1D tensor
audio_list = [raw_audio]
elif isinstance(raw_audio, (list, tuple)):
audio_list = list(raw_audio)
else:
# Single tensor
audio_list = [raw_audio]
for i, audio_tensor in enumerate(audio_list):
try:
if isinstance(audio_tensor, list):
audio_tensor = torch.stack(audio_tensor)
# Ensure tensor
if not isinstance(audio_tensor, torch.Tensor):
audio_tensor = torch.tensor(audio_tensor)
# Let vLLM handle dtype (bfloat16 by default)
audio_tensor = audio_tensor.to(device=device)
# Get actual length if available, otherwise use full length
if raw_audio_lengths and i < len(raw_audio_lengths):
actual_len = int(raw_audio_lengths[i])
if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
# Truncate from the last dimension (sequence length)
audio_tensor = audio_tensor[..., :actual_len]
# Skip if audio is too short (< 1 frame)
if audio_tensor.numel() < 160: # Minimum 160 samples (~6.7 ms at 24 kHz)
if audio_tensor.numel() < 160:
continue
# Encode audio through VibeVoice encoder
audio_embeds = self.audio_encoder(
audio_tensor,
use_streaming=use_streaming_flag,
segment_duration_s=streaming_segment_duration,
)
# audio_embeds shape: [1, seq_len, hidden_size]
# We need to return it as a single embedding tensor per audio
final_embed = audio_embeds.squeeze(0)
embeddings.append(final_embed)
except Exception as e:
# Log error but don't crash - this helps debug profiling issues
print(f"[VibeVoice] Error encoding audio {i}: {e}")
import traceback
traceback.print_exc()
# Return empty embedding to avoid crash
continue
return tuple(embeddings)
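The per-clip trimming and the short-audio guard in the loop above can be illustrated without PyTorch. The following is a simplified sketch that uses plain Python lists in place of tensors; `trim_and_filter` is a hypothetical helper name, and the 24 kHz figure in the comment is taken from the comment in the earlier version of the code.

```python
MIN_SAMPLES = 160  # same threshold as the diff; roughly 6.7 ms at 24 kHz


def trim_and_filter(clips, lengths=None, min_samples=MIN_SAMPLES):
    """Trim each padded clip to its true length, then drop clips that are too short."""
    kept = []
    for i, clip in enumerate(clips):
        if lengths and i < len(lengths):
            actual = int(lengths[i])
            # Ignore lengths that are zero or larger than the padded buffer.
            if 0 < actual <= len(clip):
                clip = clip[:actual]
        if len(clip) < min_samples:
            continue  # too short to encode, skip it
        kept.append(clip)
    return kept
```

One consequence worth noting: because short or failing clips are skipped rather than replaced by a placeholder, the returned collection can be shorter than the input batch.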
@@ -1344,35 +1220,31 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Intermediate tensors for pipeline parallelism.
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
"""
try:
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
# Only compute from input_ids if inputs_embeds is not available
if inputs_embeds is None and input_ids is not None:
# Compute embeddings from input_ids
inputs_embeds = self.get_input_embeddings()(input_ids)
# If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
if intermediate_tensors is not None:
inputs_embeds = None
# Get the inner model - handle both wrapped and direct language models
language_model = self.language_model
if hasattr(language_model, "language_model"):
language_model = language_model.language_model
# Call the language model's model (Qwen2Model)
# vLLM V1 passes kv_caches and attn_metadata via context, not arguments
# IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
hidden_states = language_model.model(
input_ids=None, # Always None when we have inputs_embeds
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds
)
return hidden_states
except Exception as e:
raise
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
# Only compute from input_ids if inputs_embeds is not available
if inputs_embeds is None and input_ids is not None:
# Compute embeddings from input_ids
inputs_embeds = self.get_input_embeddings()(input_ids)
# If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
if intermediate_tensors is not None:
inputs_embeds = None
# Get the inner model - handle both wrapped and direct language models
language_model = self.language_model
if hasattr(language_model, "language_model"):
language_model = language_model.language_model
# Call the language model's model (Qwen2Model)
# vLLM V1 passes kv_caches and attn_metadata via context, not arguments
# IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
hidden_states = language_model.model(
input_ids=None, # Always None when we have inputs_embeds
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds
)
return hidden_states
# Alias for training checkpoint compatibility
File diff suppressed because it is too large
+298 -16
@@ -9,14 +9,21 @@ One-click deployment script that handles:
4. Generating tokenizer files
5. Starting vLLM server
For DP > 1, launches N independent vLLM processes behind an nginx
reverse proxy for optimal throughput (avoids single-process HTTP
bottleneck of vLLM's built-in DP coordinator).
Usage:
python3 start_server.py [--model MODEL_ID] [--port PORT]
"""
import argparse
import os
import signal
import subprocess
import sys
import textwrap
import time
def run_command(cmd: list[str], description: str, shell: bool = False) -> None:
@@ -77,45 +84,268 @@ def generate_tokenizer(model_path: str) -> None:
)
def start_vllm_server(model_path: str, port: int) -> None:
"""Start vLLM server (replaces current process)."""
print(f"\n{'='*60}")
print(f" Starting vLLM server on port {port}")
print(f"{'='*60}\n")
vllm_cmd = [
def _build_vllm_cmd(model_path: str, port: int,
tensor_parallel_size: int = 1,
data_parallel_size: int = 1,
max_num_seqs: int = 64,
max_model_len: int = 65536,
gpu_memory_utilization: float = 0.8) -> list[str]:
"""Build the vllm serve command."""
return [
"vllm", "serve", model_path,
"--served-model-name", "vibevoice",
"--trust-remote-code",
"--dtype", "bfloat16",
"--max-num-seqs", "64",
"--max-model-len", "65536",
"--max-num-batched-tokens", "32768",
"--gpu-memory-utilization", "0.8",
"--enforce-eager",
"--max-num-seqs", str(max_num_seqs),
"--max-model-len", str(max_model_len),
"--gpu-memory-utilization", str(gpu_memory_utilization),
"--no-enable-prefix-caching",
"--enable-chunked-prefill",
"--chat-template-content-format", "openai",
"--tensor-parallel-size", "1",
"--tensor-parallel-size", str(tensor_parallel_size),
"--data-parallel-size", str(data_parallel_size),
"--allowed-local-media-path", "/app",
"--port", str(port),
]
def start_vllm_server(model_path: str, port: int,
tensor_parallel_size: int = 1,
data_parallel_size: int = 1,
max_num_seqs: int = 64,
max_model_len: int = 65536,
gpu_memory_utilization: float = 0.8) -> None:
"""Start a single vLLM server (replaces current process)."""
print(f"\n{'='*60}")
print(f" Starting vLLM server on port {port}")
print(f" Tensor Parallel (TP): {tensor_parallel_size}")
print(f" Data Parallel (DP): {data_parallel_size}")
print(f" Max Num Seqs: {max_num_seqs}")
print(f" Max Model Len: {max_model_len}")
print(f" GPU Mem Utilization: {gpu_memory_utilization}")
print(f"{'='*60}\n")
vllm_cmd = _build_vllm_cmd(
model_path, port,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
)
os.execvp("vllm", vllm_cmd)
def _install_nginx() -> None:
"""Install nginx if not already available."""
if subprocess.run(["which", "nginx"], capture_output=True).returncode != 0:
run_command(["apt-get", "update"], "Updating package list for nginx")
run_command(
["apt-get", "install", "-y", "nginx"],
"Installing nginx for load balancing"
)
def _write_nginx_config(frontend_port: int, backend_ports: list[int],
num_workers: int = 0) -> str:
"""Write nginx config for least-connections (least_conn) load balancing.
Args:
num_workers: Number of nginx worker processes. 0 = auto (2 × num backends).
"""
if num_workers <= 0:
num_workers = len(backend_ports) * 2
backends = "\n".join(f" server 127.0.0.1:{p};" for p in backend_ports)
config = textwrap.dedent(f"""\
worker_processes {num_workers};
worker_rlimit_nofile 65536;
error_log /dev/stderr warn;
pid /tmp/nginx.pid;
events {{
worker_connections 8192;
}}
http {{
access_log off;
upstream vllm_backends {{
least_conn;
{backends}
}}
server {{
listen {frontend_port};
client_max_body_size 200m;
client_body_buffer_size 10m;
proxy_buffering on;
proxy_buffer_size 64k;
proxy_buffers 16 64k;
location / {{
proxy_pass http://vllm_backends;
proxy_read_timeout 600s;
proxy_connect_timeout 10s;
proxy_send_timeout 600s;
proxy_http_version 1.1;
proxy_set_header Connection "";
}}
}}
}}
""")
config_path = "/tmp/nginx_vllm.conf"
with open(config_path, "w") as f:
f.write(config)
return config_path
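To see how the worker count and the upstream list are derived from the backend ports, the rendering step can be reduced to a pure function. This is a sketch under assumptions: `render_upstream` is a hypothetical name, and it emits only a fragment (worker count, events, and the upstream block), not the full config written by `_write_nginx_config`.

```python
def render_upstream(backend_ports, num_workers=0):
    """Render a reduced nginx config fragment for the given backend ports."""
    if num_workers <= 0:
        # Same default as the diff: two nginx workers per backend.
        num_workers = 2 * len(backend_ports)
    servers = "\n".join(f"        server 127.0.0.1:{p};" for p in backend_ports)
    return (
        f"worker_processes {num_workers};\n"
        "events {\n"
        "    worker_connections 8192;\n"
        "}\n"
        "http {\n"
        "    upstream vllm_backends {\n"
        "        least_conn;\n"
        f"{servers}\n"
        "    }\n"
        "}\n"
    )
```

Building the text from plain string pieces avoids the doubled braces that an f-string template needs, at the cost of being a little more verbose.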
def start_dp_server(model_path: str, frontend_port: int,
data_parallel_size: int,
tensor_parallel_size: int = 1,
max_num_seqs: int = 64,
max_model_len: int = 65536,
gpu_memory_utilization: float = 0.8) -> None:
"""Start multiple vLLM workers behind nginx for data parallelism.
Launches N independent vLLM processes (one per GPU group) on internal
ports, with an nginx reverse proxy on the frontend port for load
balancing. This avoids the single-process HTTP bottleneck of vLLM's
built-in DP coordinator when handling large audio payloads.
"""
import torch
num_gpus = torch.cuda.device_count()
gpus_per_replica = tensor_parallel_size
total_gpus_needed = data_parallel_size * gpus_per_replica
assert num_gpus >= total_gpus_needed, (
f"Need {total_gpus_needed} GPUs (dp={data_parallel_size} × tp={tensor_parallel_size}) "
f"but only {num_gpus} available"
)
# Per-worker env vars: honor user overrides but enforce minimums (64 ffmpeg jobs, 8 media threads)
ffmpeg_concurrency = max(
64, int(os.environ.get("VIBEVOICE_FFMPEG_MAX_CONCURRENCY", "64"))
)
media_threads = max(
8, int(os.environ.get("VLLM_MEDIA_LOADING_THREAD_COUNT", "8"))
)
_install_nginx()
# Assign internal ports: frontend_port + 100, +101, ...
backend_ports = [frontend_port + 100 + i for i in range(data_parallel_size)]
print(f"\n{'='*60}")
print(f" Starting DP server with nginx load balancing")
print(f" Frontend port: {frontend_port} (nginx)")
print(f" Backend ports: {backend_ports}")
print(f" Data Parallel: {data_parallel_size}")
print(f" Tensor Parallel: {tensor_parallel_size}")
print(f" GPUs per replica: {gpus_per_replica}")
print(f" Max Num Seqs: {max_num_seqs}")
print(f" Max Model Len: {max_model_len}")
print(f" FFmpeg concurrency (per worker): {ffmpeg_concurrency}")
print(f" Media loading threads (per worker): {media_threads}")
print(f"{'='*60}\n")
# Write nginx config
nginx_conf = _write_nginx_config(frontend_port, backend_ports)
# Launch vLLM workers
workers: list[subprocess.Popen] = []
for rank in range(data_parallel_size):
gpu_start = rank * gpus_per_replica
gpu_ids = ",".join(str(gpu_start + j) for j in range(gpus_per_replica))
port = backend_ports[rank]
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = gpu_ids
env["VIBEVOICE_FFMPEG_MAX_CONCURRENCY"] = str(ffmpeg_concurrency)
env["VLLM_MEDIA_LOADING_THREAD_COUNT"] = str(media_threads)
vllm_cmd = _build_vllm_cmd(
model_path, port,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=1, # Each worker is dp=1
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
)
print(f" Launching worker rank={rank} on GPU(s) {gpu_ids}, port {port}")
proc = subprocess.Popen(vllm_cmd, env=env)
workers.append(proc)
# Start nginx
print(f"\n Starting nginx on port {frontend_port} ...")
nginx_proc = subprocess.Popen(
["nginx", "-c", nginx_conf, "-g", "daemon off;"]
)
# Wait for all backends to be ready
print(" Waiting for all backends to be ready ...")
import urllib.request
for port in backend_ports:
url = f"http://127.0.0.1:{port}/v1/models"
for attempt in range(600): # up to 10 minutes
try:
urllib.request.urlopen(url, timeout=2)
print(f" ✅ Backend on port {port} is ready")
break
except Exception:
time.sleep(1)
else:
print(f" ❌ Backend on port {port} failed to start")
print(f"\n{'='*60}")
print(f" ✅ VibeVoice DP server ready on port {frontend_port}")
print(f" {data_parallel_size} replicas behind nginx load balancer")
print(f"{'='*60}\n")
# Handle shutdown: forward signals to all children
def _shutdown(signum, frame):
print("\nShutting down ...")
nginx_proc.terminate()
for w in workers:
w.terminate()
for w in workers:
w.wait(timeout=10)
nginx_proc.wait(timeout=5)
sys.exit(0)
signal.signal(signal.SIGTERM, _shutdown)
signal.signal(signal.SIGINT, _shutdown)
# Wait for any child to exit (indicates a failure)
while True:
for i, w in enumerate(workers):
ret = w.poll()
if ret is not None:
print(f" ❌ Worker {i} exited with code {ret}")
_shutdown(None, None)
if nginx_proc.poll() is not None:
print(f" ❌ nginx exited with code {nginx_proc.returncode}")
_shutdown(None, None)
time.sleep(1)
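The GPU and port assignment performed in `start_dp_server` is plain arithmetic, so it can be checked without GPUs. The sketch below mirrors that arithmetic; `plan_replicas` is a hypothetical helper that is not part of the script.

```python
def plan_replicas(frontend_port, dp, tp):
    """Return (rank, CUDA_VISIBLE_DEVICES value, backend port) for each replica."""
    plan = []
    for rank in range(dp):
        first_gpu = rank * tp  # each replica owns a contiguous block of tp GPUs
        gpu_ids = ",".join(str(first_gpu + j) for j in range(tp))
        plan.append((rank, gpu_ids, frontend_port + 100 + rank))
    return plan


if __name__ == "__main__":
    for rank, gpus, port in plan_replicas(8000, dp=2, tp=2):
        print(f"rank={rank} gpus={gpus} port={port}")
```

With `dp=2` and `tp=2` on port 8000, the replicas get GPUs `0,1` and `2,3` on ports 8100 and 8101, so the range above the frontend port must be free.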
def main():
parser = argparse.ArgumentParser(
description="VibeVoice vLLM ASR Server - One-Click Deployment",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Start with default settings
# Start with default settings (single GPU)
python3 start_server.py
# Use custom port
python3 start_server.py --port 8080
# Data parallel: 4 replicas on 4 GPUs (nginx load balancing)
python3 start_server.py --dp 4
# Tensor parallel: split model across 2 GPUs
python3 start_server.py --tp 2
# Skip dependency installation (if already installed)
python3 start_server.py --skip-deps
"""
@@ -141,6 +371,41 @@ Examples:
action="store_true",
help="Skip generating tokenizer files"
)
parser.add_argument(
"--tp", "--tensor-parallel-size",
type=int,
default=1,
dest="tensor_parallel_size",
help="Tensor parallel size: split one model across N GPUs (default: 1)"
)
parser.add_argument(
"--dp", "--data-parallel-size",
type=int,
default=1,
dest="data_parallel_size",
help="Data parallel size: run N independent model replicas for load balancing (default: 1)"
)
parser.add_argument(
"--max-num-seqs",
type=int,
default=64,
dest="max_num_seqs",
help="Maximum number of sequences per batch (default: 64)"
)
parser.add_argument(
"--max-model-len",
type=int,
default=65536,
dest="max_model_len",
help="Maximum model context length (default: 65536)"
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.8,
dest="gpu_memory_utilization",
help="GPU memory utilization fraction (default: 0.8)"
)
args = parser.parse_args()
print("\n" + "="*60)
@@ -161,8 +426,25 @@ Examples:
if not args.skip_tokenizer:
generate_tokenizer(model_path)
# Step 5: Start vLLM server
start_vllm_server(model_path, args.port)
# Step 5: Start server
if args.data_parallel_size > 1:
start_dp_server(
model_path, args.port,
data_parallel_size=args.data_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.max_num_seqs,
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
)
else:
start_vllm_server(
model_path, args.port,
tensor_parallel_size=args.tensor_parallel_size,
data_parallel_size=1,
max_num_seqs=args.max_num_seqs,
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
)
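The short and long flag aliases and the dispatch on `data_parallel_size` can be exercised on their own. This is a reduced sketch: `build_parser` and `pick_launcher` are hypothetical names, and only the two parallelism flags from the diff are reproduced.

```python
import argparse


def build_parser():
    """Parser with the two parallelism flags, each reachable by a short or long alias."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp", "--tensor-parallel-size", type=int, default=1,
                        dest="tensor_parallel_size")
    parser.add_argument("--dp", "--data-parallel-size", type=int, default=1,
                        dest="data_parallel_size")
    return parser


def pick_launcher(args):
    """Mirror the dispatch in main(): nginx plus workers only when dp > 1."""
    return "nginx+workers" if args.data_parallel_size > 1 else "single"


if __name__ == "__main__":
    parsed = build_parser().parse_args(["--dp", "4"])
    print(pick_launcher(parsed))
```

Setting `dest` explicitly keeps the attribute name stable regardless of which alias argparse would otherwise pick.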
if __name__ == "__main__":
+107 -91
@@ -1,14 +1,23 @@
#!/usr/bin/env python3
"""
Test VibeVoice vLLM API with Streaming (Real-time output).
Test VibeVoice vLLM API with Streaming and Optional Hotwords Support.
This script tests ASR transcription via the vLLM OpenAI-compatible API.
By default, it runs standard transcription without hotwords.
Optionally, you can provide hotwords (context_info) to improve recognition
of domain-specific content like proper nouns, technical terms, and speaker names.
Hotwords are embedded in the prompt as "with extra info: {hotwords}".
Usage:
python test_api.py [audio_path] [--url URL]
python test_api.py [audio_path] [--url URL] [--hotwords "word1,word2"]
Examples:
python test_api.py # Use default audio
python test_api.py /path/to/audio.wav # Specify audio file
python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL
# Standard transcription (no hotwords)
python3 test_api.py audio.wav
# With hotwords for better recognition of specific terms
python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice"
"""
import requests
import json
@@ -21,38 +30,38 @@ import argparse
def _guess_mime_type(path: str) -> str:
"""Guess MIME type from file extension."""
ext = os.path.splitext(path)[1].lower()
if ext == ".wav":
return "audio/wav"
if ext in (".mp3",):
return "audio/mpeg"
if ext in (".m4a",):
return "audio/mp4"
if ext in (".mp4", ".m4v", ".mov", ".webm"):
return "video/mp4"
if ext in (".flac",):
return "audio/flac"
if ext in (".ogg", ".opus"):
return "audio/ogg"
return "application/octet-stream"
mime_map = {
".wav": "audio/wav",
".mp3": "audio/mpeg",
".m4a": "audio/mp4",
".mp4": "video/mp4",
".flac": "audio/flac",
".ogg": "audio/ogg",
".opus": "audio/ogg",
}
return mime_map.get(ext, "application/octet-stream")
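The MIME lookup feeds the `data:` URL that carries the audio in the request. A minimal sketch of that combination follows; `build_data_url` and the shortened `MIME_BY_EXT` table are assumptions for illustration, not names from the script.

```python
import base64
import os

# Shortened, hypothetical version of the lookup table in the diff.
MIME_BY_EXT = {".wav": "audio/wav", ".mp3": "audio/mpeg", ".flac": "audio/flac"}


def build_data_url(audio_bytes: bytes, path: str) -> str:
    """Encode raw audio bytes as a base64 data URL, typed by the file extension."""
    ext = os.path.splitext(path)[1].lower()
    mime = MIME_BY_EXT.get(ext, "application/octet-stream")
    b64 = base64.b64encode(audio_bytes).decode("ascii")
    return f"data:{mime};base64,{b64}"
```

Base64 inflates the payload by about a third, which is one reason the nginx config in this change raises `client_max_body_size`.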
def _get_duration_seconds_ffprobe(path: str) -> float:
"""Get audio duration using ffprobe."""
cmd = [
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
"ffprobe", "-v", "error",
"-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1",
path,
]
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
return float(out)
def _is_video_file(path: str) -> bool:
"""Check if the file is a video file that needs audio extraction."""
ext = os.path.splitext(path)[1].lower()
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
def _extract_audio_from_video(video_path: str) -> str:
"""
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
@@ -74,26 +83,40 @@ def _extract_audio_from_video(video_path: str) -> str:
return audio_path
def _is_video_file(path: str) -> bool:
"""Check if the file is a video file that needs audio extraction."""
ext = os.path.splitext(path)[1].lower()
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"):
"""Test ASR transcription with streaming output."""
def test_transcription_with_hotwords(
audio_path: str,
context_info: str = None,
base_url: str = "http://localhost:8000",
):
"""
Test ASR transcription with customized hotwords.
print(f"Loading audio from: {audio_path}")
Hotwords are embedded in the prompt text as "with extra info: {hotwords}".
This helps the model recognize domain-specific terms more accurately.
Args:
audio_path: Path to the audio file
context_info: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
base_url: vLLM server URL
"""
print(f"=" * 70)
print(f"Testing Customized Hotwords Support")
print(f"=" * 70)
print(f"Input file: {audio_path}")
print(f"Hotwords: {context_info or '(none)'}")
print()
# Handle video files: extract audio first
temp_audio_path = None
actual_audio_path = audio_path
if _is_video_file(audio_path):
print(f"Detected video file, extracting audio...")
print(f"🎬 Detected video file, extracting audio...")
temp_audio_path = _extract_audio_from_video(audio_path)
actual_audio_path = temp_audio_path
print(f"Audio extracted to: {temp_audio_path}")
print(f"Audio extracted to: {temp_audio_path}")
# Load audio
try:
duration = _get_duration_seconds_ffprobe(actual_audio_path)
print(f"Audio duration: {duration:.2f} seconds")
@@ -106,16 +129,30 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
except Exception as e:
print(f"Error preparing audio: {e}")
# Cleanup temp file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
return
# Build the request
url = f"{base_url}/v1/chat/completions"
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys)
)
# Build prompt with optional hotwords
# Hotwords are embedded as "with extra info: {hotwords}" in the prompt
if context_info and context_info.strip():
prompt_text = (
f"This is a {duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\n"
f"Please transcribe it with these keys: " + ", ".join(show_keys)
)
print(f"\n📝 Hotwords embedded in prompt: '{context_info}'")
else:
prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys)
)
print(f"\n📝 No hotwords provided")
mime = _guess_mime_type(actual_audio_path)
data_url = f"data:{mime};base64,{audio_b64}"
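The prompt construction with optional hotwords can be pulled out into a pure function for quick checks. This sketch copies the wording from the diff; `build_prompt` and `KEYS` are hypothetical names.

```python
from typing import Optional

KEYS = ["Start time", "End time", "Speaker ID", "Content"]


def build_prompt(duration_s: float, hotwords: Optional[str] = None) -> str:
    """Build the transcription prompt, embedding hotwords when they are non-blank."""
    keys = ", ".join(KEYS)
    if hotwords and hotwords.strip():
        return (
            f"This is a {duration_s:.2f} seconds audio, "
            f"with extra info: {hotwords.strip()}\n\n"
            f"Please transcribe it with these keys: {keys}"
        )
    return (
        f"This is a {duration_s:.2f} seconds audio, "
        f"please transcribe it with these keys: {keys}"
    )
```

A whitespace-only hotword string falls back to the plain prompt, matching the `context_info.strip()` check in the script.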
@@ -139,20 +176,19 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
"temperature": 0.0,
"stream": True,
"top_p": 1.0,
"repetition_penalty": 1.0,
}
print(f"\nSending request to {url} (Streaming Mode)...")
print(f"Prompt: {prompt_text}")
print("-" * 60)
print(f"\n{'=' * 70}")
print(f"Sending request to {url}")
print(f"{'=' * 70}")
t0 = time.time()
try:
response = requests.post(url, json=payload, stream=True, timeout=12000)
if response.status_code == 200:
print("Response received. Streaming content:\n")
print("\nResponse received. Streaming content:\n")
print("-" * 50)
printed = ""
for line in response.iter_lines():
@@ -162,92 +198,72 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
if decoded_line.startswith("data: "):
json_str = decoded_line[6:]
if json_str.strip() == "[DONE]":
print("\n\n[Finished]")
print("\n" + "-" * 50)
print("✅ [Finished]")
break
try:
data = json.loads(json_str)
delta = data['choices'][0]['delta']
content = delta.get('content', '')
if content:
# vLLM/OpenAI-compatible streams may emit either
# incremental deltas OR the full accumulated text.
# Only print the newly-added part to avoid repeats.
if content.startswith(printed):
to_print = content[len(printed):]
else:
to_print = content
if to_print:
print(to_print, end='', flush=True)
printed += to_print
except json.JSONDecodeError:
pass
else:
print(f"Error: {response.status_code}")
print(f"Error: {response.status_code}")
print(response.text)
except requests.exceptions.Timeout:
print("\nRequest timed out!")
print("Request timed out!")
except Exception as e:
print(f"\nError: {e}")
print(f"Error: {e}")
print(f"\n{'-'*60}")
print(f"Total time elapsed: {time.time() - t0:.2f}s")
elapsed = time.time() - t0
print(f"\n{'=' * 70}")
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
print(f"📊 RTF (Real-Time Factor): {elapsed / duration:.2f}x")
print(f"{'=' * 70}")
# Cleanup temp audio file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
print(f"Cleaned up temp file: {temp_audio_path}")
print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
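The streaming loop above parses server-sent `data:` lines and de-duplicates content that may arrive either as increments or as accumulated text. The sketch below isolates that logic so it can be fed canned lines; `collect_stream_text` is a hypothetical helper, and the chunk layout is assumed to follow the OpenAI-compatible `choices[0].delta.content` shape used in the script.

```python
import json


def collect_stream_text(lines):
    """Assemble streamed text from SSE lines, tolerating cumulative chunks."""
    printed = ""
    for raw in lines:
        if not raw:
            continue  # keep-alive blank line
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        if not line.startswith("data: "):
            continue
        body = line[len("data: "):]
        if body.strip() == "[DONE]":
            break
        try:
            delta = json.loads(body)["choices"][0].get("delta", {})
        except (json.JSONDecodeError, KeyError, IndexError):
            continue  # skip malformed chunks instead of aborting the stream
        content = delta.get("content") or ""
        if not content:
            continue
        # Cumulative chunk: keep only the unseen suffix. Otherwise treat as increment.
        new = content[len(printed):] if content.startswith(printed) else content
        printed += new
    return printed
```

The prefix check is a heuristic: an incremental chunk that happens to begin with everything received so far would be misread as cumulative, which is unlikely beyond the first few characters but not impossible.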
def main():
parser = argparse.ArgumentParser(
description="Test VibeVoice vLLM API with streaming output"
description="Test VibeVoice vLLM API with Customized Hotwords"
)
parser.add_argument(
"audio_path",
nargs="?",
default=None,
help="Path to audio file (wav, mp3, flac, etc.) or video file"
)
parser.add_argument(
"--url",
default="http://localhost:8000",
help="vLLM server base URL (default: http://localhost:8000)"
help="vLLM server URL (default: http://localhost:8000)"
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
)
args = parser.parse_args()
# Find default audio if not specified
audio_path = args.audio_path
if audio_path is None:
# Try to find a sample audio in common locations
possible_paths = [
# In VibeVoice demo folder
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "en-Carter_man.wav"),
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"),
# Relative to current directory
"demo/voices/en-Carter_man.wav",
"demo/voices/zh-Anchen_man_bgm.wav",
]
for path in possible_paths:
if os.path.exists(path):
audio_path = path
break
if audio_path is None:
print("Error: No audio file specified and no default audio found.")
print("Usage: python test_api.py <audio_path>")
sys.exit(1)
if not os.path.exists(audio_path):
print(f"Error: Audio file not found: {audio_path}")
sys.exit(1)
test_transcription(audio_path, args.url)
# Run test
test_transcription_with_hotwords(
audio_path=args.audio_path,
context_info=args.hotwords,
base_url=args.url,
)
if __name__ == "__main__":
+638
@@ -0,0 +1,638 @@
#!/usr/bin/env python3
"""
Test VibeVoice vLLM API with Streaming, Hotwords, and Auto-Recovery.
This script tests ASR transcription with automatic recovery from repetition loops.
Supports optional hotwords to improve recognition of domain-specific terms.
Features:
- Streaming output with real-time repetition detection
- Auto-recovery when model enters repetition loops
- Optional hotwords support (embedded in prompt as "with extra info: {hotwords}")
- Video file support (auto-extracts audio)
Recovery Strategy:
1. First attempt: greedy decoding (temperature=0, top_p=1.0)
2. If loop detected: retry with temperature=0.2/0.3/0.4, top_p=0.95
3. Max 3 retries, truncate to last complete segment boundary
Usage:
python test_api_auto_recover.py <audio_path> [output_path] [--url URL] [--hotwords "word1,word2"] [--debug]
Examples:
# Basic usage
python3 test_api_auto_recover.py audio.wav
# With hotwords
python3 test_api_auto_recover.py audio.wav --hotwords "Microsoft,VibeVoice"
# Save result to file
python3 test_api_auto_recover.py audio.wav result.txt
# Debug mode (show recovery info)
python3 test_api_auto_recover.py audio.wav --debug
"""
import requests
import json
import base64
import time
import sys
import os
import subprocess
import re
import argparse
from collections import Counter
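The Recovery Strategy described in the module docstring amounts to a small schedule of sampling parameters. A minimal sketch of that schedule, with `sampling_params` as a hypothetical helper name and rounding added to keep the floating-point values tidy:

```python
def sampling_params(retry_count: int) -> dict:
    """Greedy decoding on the first attempt, then a warmer temperature per retry."""
    if retry_count <= 0:
        return {"temperature": 0.0, "top_p": 1.0}
    # 0.2, 0.3, 0.4 for retries 1, 2, 3 (rounded to avoid float noise).
    return {"temperature": round(0.1 + 0.1 * retry_count, 2), "top_p": 0.95}


if __name__ == "__main__":
    for attempt in range(4):
        print(attempt, sampling_params(attempt))
```

Raising the temperature gradually is meant to break a deterministic loop while staying close to the greedy output.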
def _guess_mime_type(path: str) -> str:
ext = os.path.splitext(path)[1].lower()
if ext == ".wav":
return "audio/wav"
if ext in (".mp3",):
return "audio/mpeg"
if ext in (".m4a",):
return "audio/mp4"
if ext in (".mp4", ".m4v", ".mov", ".webm"):
return "video/mp4"
if ext in (".flac",):
return "audio/flac"
if ext in (".ogg", ".opus"):
return "audio/ogg"
return "application/octet-stream"
def _get_duration_seconds_ffprobe(path: str) -> float:
cmd = [
"ffprobe", "-v", "error",
"-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1",
path,
]
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
return float(out)
def _extract_audio_from_video(video_path: str) -> str:
"""
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
Returns the path to the extracted audio file.
"""
import tempfile
# Create temp file with .mp3 extension
fd, audio_path = tempfile.mkstemp(suffix=".mp3")
os.close(fd)
cmd = [
"ffmpeg", "-y", "-i", video_path,
"-vn", # No video
"-acodec", "libmp3lame",
"-q:a", "2", # High quality
audio_path
]
subprocess.run(cmd, check=True, capture_output=True)
return audio_path
def _is_video_file(path: str) -> bool:
"""Check if the file is a video file that needs audio extraction."""
ext = os.path.splitext(path)[1].lower()
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
def _find_last_segment_boundary(text: str) -> int:
"""
Find the position after the last complete segment boundary (},).
Returns -1 if no complete segment found.
"""
# Find last "}, " or "}," pattern (segment separator)
pos = text.rfind("},")
if pos != -1:
return pos + 2 # Include the },
return -1
def _find_safe_print_boundary(text: str, max_pos: int) -> int:
"""
Find the last complete segment boundary before max_pos.
Returns 0 if no complete segment found before max_pos.
"""
search_text = text[:max_pos]
pos = search_text.rfind("},")
if pos != -1:
return pos + 2 # Include the },
return 0
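Both boundary helpers rely on `},` marking the end of a complete segment in the streamed JSON array. The sketch below shows the truncation on a partial output; `last_segment_boundary` is a hypothetical merged form of the two helpers, and the sample text is made up.

```python
def last_segment_boundary(text: str, max_pos=None) -> int:
    """Index just past the last '},' (optionally before max_pos), or -1 if none."""
    search = text if max_pos is None else text[:max_pos]
    pos = search.rfind("},")
    return pos + 2 if pos != -1 else -1


if __name__ == "__main__":
    partial = '[{"Content":"hi"},{"Content":"the'
    cut = last_segment_boundary(partial)
    # Keep only the complete segments; the unfinished one is dropped.
    print(partial[:cut])
```

A literal `},` inside transcribed content would also match, so this is a heuristic rather than a JSON-aware parse.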
class RepetitionDetector:
"""Detect repetition patterns in streaming text."""
def __init__(self,
min_pattern_len: int = 10, # Minimum chars for a pattern
min_repeats: int = 3, # Minimum repetitions to trigger
window_size: int = 500): # Window to check for patterns
self.min_pattern_len = min_pattern_len
self.min_repeats = min_repeats
self.window_size = window_size
self.text = ""
def add_text(self, new_text: str):
"""Add new text and return (is_looping, good_text_end_pos)."""
self.text += new_text
return self._check_repetition()
def _check_repetition(self):
"""Check if the recent text contains repetition loops."""
if len(self.text) < self.min_pattern_len * self.min_repeats:
return False, len(self.text)
# Check the recent window
window = self.text[-self.window_size:] if len(self.text) > self.window_size else self.text
# Method 1: Check for repeated substrings
for pattern_len in range(self.min_pattern_len, len(window) // self.min_repeats + 1):
# Get the last pattern_len characters as potential pattern
pattern = window[-pattern_len:]
# Count how many times this pattern appears at the end
count = 0
pos = len(window)
while pos >= pattern_len:
if window[pos - pattern_len:pos] == pattern:
count += 1
pos -= pattern_len
else:
break
if count >= self.min_repeats:
# Found repetition! Calculate where the good text ends
repetition_start = len(self.text) - (count * pattern_len)
# Keep one instance of the pattern (or none if it's garbage)
good_end = repetition_start + pattern_len if self._is_meaningful(pattern) else repetition_start
return True, good_end
# Method 2: Check for repeated short phrases (like "you're not, you're not")
# Look for patterns like "X, X, X" or "X X X"
words = window.split()
if len(words) >= self.min_repeats * 2:
# Check last N words for repetition
for phrase_len in range(2, 6): # 2-5 word phrases
if len(words) < phrase_len * self.min_repeats:
continue
phrase = " ".join(words[-phrase_len:])
count = 0
idx = len(words)
while idx >= phrase_len:
candidate = " ".join(words[idx - phrase_len:idx])
if candidate == phrase:
count += 1
idx -= phrase_len
else:
break
if count >= self.min_repeats:
# Calculate position in original text
repeated_text = (phrase + " ") * count
good_end = len(self.text) - len(repeated_text.rstrip()) + len(phrase)
return True, max(0, good_end)
return False, len(self.text)
def _is_meaningful(self, pattern: str) -> bool:
"""Check if pattern is meaningful content (not just garbage)."""
# Filter out patterns that are just punctuation, spaces, or very repetitive
clean = pattern.strip()
if not clean:
return False
if len(set(clean)) < 3: # Too few unique characters
return False
return True
def get_good_text(self, end_pos: int) -> str:
"""Get text up to the specified position."""
return self.text[:end_pos]
def reset(self, keep_text: str = ""):
"""Reset detector, optionally keeping some text."""
self.text = keep_text
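The core of Method 1 in `RepetitionDetector` is counting how many times the trailing substring repeats back-to-back at the end of the text. That step can be sketched on its own; `tail_repeat_count` is a hypothetical helper, and the sample strings are made up.

```python
def tail_repeat_count(text: str, pattern_len: int) -> int:
    """Count consecutive copies of the last pattern_len characters at the end of text."""
    if pattern_len <= 0 or len(text) < pattern_len:
        return 0
    pattern = text[-pattern_len:]
    count = 0
    pos = len(text)
    # Walk backwards one pattern-width at a time while the chunk still matches.
    while pos >= pattern_len and text[pos - pattern_len:pos] == pattern:
        count += 1
        pos -= pattern_len
    return count


if __name__ == "__main__":
    unit = "you're not, "
    print(tail_repeat_count("intro " + unit * 4, len(unit)))
```

The detector runs this count for every candidate length from `min_pattern_len` up to `len(window) // min_repeats`, so the window size bounds the longest loop it can catch.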
def stream_with_recovery(
url: str,
base_messages: list,
audio_data_url: str,
prompt_text: str,
max_tokens: int = 32768,
max_retries: int = 3,
timeout: int = 12000,
debug: bool = False,
):
"""
Stream transcription with automatic recovery from repetition loops.
Args:
url: API endpoint
base_messages: Base messages (system + user with audio)
audio_data_url: The audio data URL for the request
prompt_text: The text prompt
max_tokens: Maximum tokens to generate
max_retries: Maximum recovery attempts (default 3)
timeout: Request timeout
debug: If True, show recovery debug info to stderr
Recovery strategy:
- First attempt: temperature=0.0, top_p=1.0 (greedy)
- Recovery: temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
- If has complete segments: use assistant prefix
- If no complete segments: restart from scratch
- Max 3 retries, if all fail output error message
Returns:
Final transcription text
"""
import sys as _sys
def _log(msg):
"""Log to stderr only if debug."""
if debug:
print(msg, file=_sys.stderr)
detector = RepetitionDetector(
min_pattern_len=10, # At least 10 chars for a pattern
min_repeats=10, # Must repeat 10+ times
window_size=400, # Check last 400 chars (can detect 10-40 char patterns repeated 10 times)
)
accumulated_text = ""
retry_count = 0
user_safe_printed_len = 0 # Track how much we've safely shown to user (at segment boundaries)
is_recovery = False # Whether we're in recovery mode
while retry_count <= max_retries:
# Build request payload
messages = list(base_messages) # Copy base messages
# If we have accumulated text from previous attempt, add it as partial assistant response
if accumulated_text:
# Add the good content as a partial assistant message
# vLLM will continue from here
messages.append({
"role": "assistant",
"content": accumulated_text
})
# Set sampling parameters based on recovery state
if is_recovery:
# Recovery: increase temperature each retry to break loops
recovery_temp = 0.1 + 0.1 * retry_count # 0.2, 0.3, 0.4 for retry 1, 2, 3
payload = {
"model": "vibevoice",
"messages": messages,
"max_tokens": max_tokens,
"temperature": recovery_temp,
"top_p": 0.95,
"stream": True,
}
if accumulated_text:
_log(f"[RECOVERY #{retry_count}] Continuing from {len(accumulated_text)} chars with temp={recovery_temp}, top_p=0.95")
else:
_log(f"[RECOVERY #{retry_count}] Restarting from scratch with temp={recovery_temp}, top_p=0.95")
else:
# First attempt: greedy decoding
payload = {
"model": "vibevoice",
"messages": messages,
"max_tokens": max_tokens,
"temperature": 0.0,
"top_p": 1.0,
"stream": True,
}
try:
response = requests.post(url, json=payload, stream=True, timeout=timeout)
if response.status_code != 200:
_log(f"[ERROR] {response.status_code} - {response.text[:500]}")
return accumulated_text
new_text = ""
printed = "" # Track what we've already received to handle vLLM duplicates
for line in response.iter_lines():
if not line:
continue
decoded_line = line.decode('utf-8')
if not decoded_line.startswith("data: "):
continue
json_str = decoded_line[6:]
if json_str.strip() == "[DONE]":
# Successfully finished without loops
full_result = accumulated_text + new_text
# Print any remaining content that wasn't printed yet
if len(full_result) > user_safe_printed_len:
remaining = full_result[user_safe_printed_len:]
print(remaining, end='', flush=True)
print() # Final newline
return full_result
try:
data = json.loads(json_str)
delta = data['choices'][0].get('delta', {})
content = delta.get('content', '')
if content:
# vLLM/OpenAI-compatible streams may emit either
# incremental deltas OR the full accumulated text.
# Only track the newly-added part.
if content.startswith(printed):
to_add = content[len(printed):]
else:
to_add = content
if to_add:
printed += to_add
new_text += to_add
# When continuing from prefix, model may add "[" or "[{" at start
# or repeat the ending "}, " from prefix
# We need to handle these to maintain valid JSON array format
if accumulated_text and new_text:
stripped = new_text.lstrip()
# Cases 1-2: Model added "[" or "[{" - remove the "["
if stripped.startswith("["):
new_text = stripped[1:]
_log("[STRIPPED leading '[' from continuation]")
# Case 3: Model repeated "}," from prefix ending
elif stripped.startswith("},"):
new_text = stripped[2:]
_log("[STRIPPED leading '},' from continuation]")
# Case 4: Model repeated "}" from prefix ending
elif stripped.startswith("}") and not stripped.startswith("}]"):
new_text = stripped[1:]
_log("[STRIPPED leading '}' from continuation]")
# Fix malformed JSON: {"2.99,... -> {"Start":2.99,...
# This happens when model skips "Start": key
import re
malformed = re.match(r'^\{"(\d+\.?\d*),', new_text)
if malformed:
time_val = malformed.group(1)
new_text = '{"Start":' + time_val + ',' + new_text[malformed.end():]
_log("[FIXED malformed JSON: added Start key]")
# Check for repetition in the combined text
full_text = accumulated_text + new_text
detector.text = full_text
is_looping, good_end = detector._check_repetition()
if is_looping:
_log(f"[LOOP DETECTED at char {good_end}]")
# Use what user has already seen as prefix for retry
# user_safe_printed_len is always at a segment boundary
if user_safe_printed_len > 0:
accumulated_text = full_text[:user_safe_printed_len]
_log(f"[RETRY from user-visible content at {user_safe_printed_len}]")
else:
# No complete segment shown to user yet - restart from scratch
accumulated_text = ""
_log("[NO CONTENT SHOWN TO USER - restart from scratch]")
detector.reset(accumulated_text)
is_recovery = True
if debug:
# Use the local alias from "import sys as _sys" above
print("\n[...recovering...]", end='', flush=True, file=_sys.stderr)
retry_count += 1
if retry_count > max_retries:
_log("[MAX RETRIES REACHED]")
print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
return None
# Break inner loop to retry
break
else:
# No loop detected - stream content to user
# Only print up to (full_text_len - window_size) at segment boundaries
# This ensures user never sees content that might be rolled back
safe_end = max(0, len(full_text) - detector.window_size)
safe_boundary = _find_safe_print_boundary(full_text, safe_end)
if safe_boundary > user_safe_printed_len:
# Print new safe content
to_print = full_text[user_safe_printed_len:safe_boundary]
print(to_print, end='', flush=True)
user_safe_printed_len = safe_boundary
except json.JSONDecodeError:
continue
else:
# Loop completed without break (no repetition detected)
full_result = accumulated_text + new_text
# Print any remaining content that wasn't printed yet
if len(full_result) > user_safe_printed_len:
remaining = full_result[user_safe_printed_len:]
print(remaining, end='', flush=True)
print() # Final newline
return full_result
except requests.exceptions.Timeout:
_log("[TIMEOUT]")
print()
return accumulated_text
except Exception as e:
_log(f"[ERROR: {e}]")
print()
return accumulated_text
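# Defensive fallback: every path inside the loop returns or retries, so this is
# only reached when the loop body never runs (e.g. a negative max_retries)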
print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
return None
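# The continuation clean-up performed inside stream_with_recovery could also be
# written as a standalone pure function, which would make it easier to unit test.
# The block below is a minimal sketch under that assumption, not part of the script:
# the name normalize_continuation is hypothetical, and it only mirrors the
# prefix-stripping cases and the missing "Start" key fix shown above.

```python
import re

# Matches a segment that opens with a bare number where the "Start" key should be,
# e.g. {"2.99,...  (hypothetical module-level constant for this sketch)
_MISSING_START_KEY = re.compile(r'^\{"(\d+\.?\d*),')


def normalize_continuation(new_text: str) -> str:
    """Clean the start of text generated after an assistant prefix.

    Hypothetical helper: removes a repeated "[", "}," or "}" that the model
    may emit when continuing a JSON array, then restores a skipped "Start" key.
    """
    stripped = new_text.lstrip()
    if stripped.startswith("["):
        # Model reopened the array: drop the "["
        new_text = stripped[1:]
    elif stripped.startswith("},"):
        # Model repeated the separator that already ends the prefix
        new_text = stripped[2:]
    elif stripped.startswith("}") and not stripped.startswith("}]"):
        # Model repeated the closing brace (but "}]" is a legitimate array end)
        new_text = stripped[1:]

    match = _MISSING_START_KEY.match(new_text)
    if match:
        # {"2.99,... -> {"Start":2.99,...
        new_text = '{"Start":' + match.group(1) + "," + new_text[match.end():]
    return new_text
```

# In the streaming loop this would only be applied when accumulated_text is non-empty,
# i.e. when the request continues from an assistant prefix.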
def test_transcription_with_recovery(
audio_path: str,
output_path: str = None,
base_url: str = "http://localhost:8000",
hotwords: str = None,
debug: bool = False,
):
"""
Test ASR transcription with auto-recovery from repetition loops.
Args:
audio_path: Path to the audio file
output_path: Optional path to save transcription result
base_url: vLLM server URL
hotwords: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
debug: Show recovery debug info
"""
print("=" * 70)
print("Testing with Auto-Recovery")
print("=" * 70)
print(f"Input file: {audio_path}")
print(f"Hotwords: {hotwords or '(none)'}")
print()
# Handle video files: extract audio first
temp_audio_path = None
actual_audio_path = audio_path
if _is_video_file(audio_path):
print("🎬 Detected video file, extracting audio...")
temp_audio_path = _extract_audio_from_video(audio_path)
actual_audio_path = temp_audio_path
print(f"✅ Audio extracted to: {temp_audio_path}")
# Load audio
try:
duration = _get_duration_seconds_ffprobe(actual_audio_path)
print(f"Audio duration: {duration:.2f} seconds")
with open(actual_audio_path, "rb") as f:
audio_bytes = f.read()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
print(f"Audio size: {len(audio_bytes)} bytes")
except Exception as e:
print(f"❌ Error preparing audio: {e}")
# Cleanup temp file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
return
url = f"{base_url}/v1/chat/completions"
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
# Build prompt with optional hotwords
if hotwords and hotwords.strip():
prompt_text = (
f"This is a {duration:.2f} seconds audio, with extra info: {hotwords.strip()}\n\n"
"Please transcribe it with these keys: " + ", ".join(show_keys)
)
print(f"\n📝 Hotwords embedded in prompt: '{hotwords}'")
else:
prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys)
)
print("\n📝 No hotwords provided")
mime = _guess_mime_type(actual_audio_path)
data_url = f"data:{mime};base64,{audio_b64}"
# Base messages (without assistant continuation)
base_messages = [
{
"role": "system",
"content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
},
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": data_url}},
{"type": "text", "text": prompt_text}
]
}
]
print(f"\n{'=' * 70}")
print(f"Sending request to {url}")
print(f"{'=' * 70}")
t0 = time.time()
# The request is sent inside stream_with_recovery, so nothing has been received yet
print("\n📡 Streaming content:\n")
print("-" * 50)
result = stream_with_recovery(
url=url,
base_messages=base_messages,
audio_data_url=data_url,
prompt_text=prompt_text,
max_tokens=32768,
max_retries=3,
debug=debug,
)
elapsed = time.time() - t0
print("-" * 50)
print("✅ [Finished]")
print(f"\n{'=' * 70}")
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
print(f"{'=' * 70}")
# Cleanup temp audio file if created; done before the early return below
# so that a failed transcription does not leak the extracted audio
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
if result is None:
print("❌ Transcription failed")
return
print(f"📄 Final output length: {len(result)} chars")
# Optionally save result
if output_path:
with open(output_path, "w", encoding="utf-8") as f:
f.write(result)
print(f"💾 Result saved to: {output_path}")
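# The prompt construction in test_transcription_with_recovery could be isolated so
# that the exact text sent to the server can be checked without audio or a running
# server. The block below is a minimal sketch under that assumption: build_prompt_text
# is a hypothetical helper name, and the default keys are copied from show_keys above.

```python
from typing import Optional, Sequence

# Keys requested from the model, copied from show_keys in the script
DEFAULT_SHOW_KEYS = ("Start time", "End time", "Speaker ID", "Content")


def build_prompt_text(
    duration: float,
    hotwords: Optional[str] = None,
    show_keys: Sequence[str] = DEFAULT_SHOW_KEYS,
) -> str:
    """Build the transcription prompt, embedding hotwords when provided.

    Hypothetical helper for illustration; it reproduces the two prompt
    variants used by the script.
    """
    keys = ", ".join(show_keys)
    if hotwords and hotwords.strip():
        # Hotwords are passed to the model as "extra info" in the prompt
        return (
            f"This is a {duration:.2f} seconds audio, "
            f"with extra info: {hotwords.strip()}\n\n"
            f"Please transcribe it with these keys: {keys}"
        )
    return (
        f"This is a {duration:.2f} seconds audio, "
        f"please transcribe it with these keys: {keys}"
    )
```

# Example: build_prompt_text(12.5, "Microsoft,Azure") yields the hotword variant,
# while build_prompt_text(12.5) or a whitespace-only hotwords string yields the plain one.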
def main():
parser = argparse.ArgumentParser(
description="Test VibeVoice vLLM API with auto-recovery from repetition loops"
)
parser.add_argument(
"audio_path",
help="Path to audio file (wav, mp3, flac, etc.) or video file"
)
parser.add_argument(
"output_path",
nargs="?",
default=None,
help="Optional path to save transcription result"
)
parser.add_argument(
"--url",
default="http://localhost:8000",
help="vLLM server URL (default: http://localhost:8000)"
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
)
parser.add_argument(
"--debug",
action="store_true",
help="Show recovery debug info"
)
args = parser.parse_args()
# Run test
test_transcription_with_recovery(
audio_path=args.audio_path,
output_path=args.output_path,
base_url=args.url,
hotwords=args.hotwords,
debug=args.debug,
)
if __name__ == "__main__":
main()