Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 61ecb098d6 | |||
| b4cd7c479f | |||
| 11dd7420ec |
@@ -3,13 +3,11 @@
|
||||
## 🎙️ VibeVoice: Open-Source Frontier Voice AI
|
||||
[](https://microsoft.github.io/VibeVoice)
|
||||
[](https://huggingface.co/collections/microsoft/vibevoice-68a2ef24a875c44be47b034f)
|
||||
[](https://openreview.net/pdf?id=FihSkzyxdv)
|
||||
[](https://arxiv.org/pdf/2508.19205)
|
||||
[](https://arxiv.org/pdf/2601.18184)
|
||||
[](https://colab.research.google.com/github/microsoft/VibeVoice/blob/main/demo/VibeVoice_colab.ipynb)
|
||||
[](https://aka.ms/vibevoice-asr)
|
||||
|
||||
[](https://trendshift.io/repositories/15465)
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
@@ -24,13 +22,7 @@
|
||||
|
||||
<h3>📰 News</h3>
|
||||
|
||||
<strong>2026-03-29: 🎉 VibeVoice-ASR is being adopted by the open-source community! <a href="https://vibingjustspeakit.github.io/Vibing/">Vibing</a>, a voice-powered input method, is now built on top of VibeVoice-ASR. Download: [macOS](https://github.com/VibingJustSpeakIt/Vibing/releases/download/v0.1.0/Vibing-v0.1.0-mac.dmg) | [Windows](https://github.com/VibingJustSpeakIt/Vibing/releases/download/v0.1.0/Vibing-v0.1.0-windows.exe)</strong>
|
||||
|
||||
https://github.com/user-attachments/assets/db0bb23f-ae06-4135-a66a-1ff1669f4f84
|
||||
|
||||
<strong>2026-03-06: 🚀 VibeVoice ASR is now part of a <a href="https://github.com/huggingface/transformers/releases/tag/v5.3.0">Transformers release</a>! You can now use our speech recognition model directly through the Hugging Face Transformers library for seamless integration into your projects.</strong>
|
||||
|
||||
<strong>2026-01-21:</strong> 📣 We open-sourced <a href="docs/vibevoice-asr.md"><strong>VibeVoice-ASR</strong></a>, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context. Try it in [Playground](https://aka.ms/vibevoice-asr).
|
||||
<strong>2026-01-21: 📣 We open-sourced <a href="docs/vibevoice-asr.md"><strong>VibeVoice-ASR</strong></a>, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context. Try it in [Playground](https://aka.ms/vibevoice-asr)</strong>.
|
||||
- ⭐️ VibeVoice-ASR is natively multilingual, supporting over 50 languages — check the [supported languages](docs/vibevoice-asr.md#language-distribution) for details.
|
||||
- 🔥 The VibeVoice-ASR [finetuning code](finetuning-asr/README.md) is now available!
|
||||
- ⚡️ **vLLM inference** is now supported for faster inference; see [vllm-asr](docs/vibevoice-vllm-asr.md) for more details.
|
||||
@@ -44,7 +36,7 @@ https://github.com/user-attachments/assets/db0bb23f-ae06-4135-a66a-1ff1669f4f84
|
||||
2025-09-05: VibeVoice is an open-source research framework intended to advance collaboration in the speech synthesis community. After release, we discovered instances where the tool was used in ways inconsistent with the stated intent. Since responsible use of AI is one of Microsoft’s guiding principles, we have removed the VibeVoice-TTS code from this repository.
|
||||
|
||||
|
||||
2025-08-25: 📣 We open-sourced <a href="docs/vibevoice-tts.md"><strong>VibeVoice-TTS</strong></a>, a long-form multi-speaker text-to-speech model that can synthesize speech up to 90 minutes long with up to 4 distinct speakers. — accepted as an [Oral](https://openreview.net/forum?id=FihSkzyxdv) at ICLR 2026! 🔥
|
||||
2025-08-25: 📣 We open-sourced <a href="docs/vibevoice-tts.md"><strong>VibeVoice-TTS</strong></a>, a long-form multi-speaker text-to-speech model that can synthesize speech up to 90 minutes long with up to 4 distinct speakers.
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -66,7 +66,7 @@
|
||||
"print(\"✅ Cloned VibeVoice repository\")\n",
|
||||
"\n",
|
||||
"# Install project dependencies\n",
|
||||
"!uv pip --quiet install --system -e /content/VibeVoice[streamingtts]\n",
|
||||
"!uv pip --quiet install --system -e /content/VibeVoice[tts]\n",
|
||||
"!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared && chmod +x cloudflared\n",
|
||||
"print(\"✅ Installed dependencies\")\n",
|
||||
"\n",
|
||||
|
||||
+1
-1
@@ -142,7 +142,7 @@ class StreamingTTSService:
|
||||
if name and name in self.voice_presets:
|
||||
return name
|
||||
|
||||
default_key = "en-Carter_man"
|
||||
default_key = "en-WHTest_man"
|
||||
if default_key in self.voice_presets:
|
||||
return default_key
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ sudo docker run --privileged --net=host --ipc=host --ulimit memlock=-1:-1 --ulim
|
||||
git clone https://github.com/microsoft/VibeVoice.git
|
||||
cd VibeVoice/
|
||||
|
||||
pip install -e .[streamingtts]
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ Deploy VibeVoice ASR model as a high-performance API service using [vLLM](https:
|
||||
- **📡 OpenAI-Compatible API**: Standard `/v1/chat/completions` endpoint with streaming support
|
||||
- **🎵 Long Audio Support**: Process up to 60+ minutes of audio in a single request
|
||||
- **🔌 Plugin Architecture**: No vLLM source code modification required - just install and run
|
||||
- **⚡ Data Parallel (DP)**: Run independent model replicas across multiple GPUs with automatic load balancing behind a single port
|
||||
|
||||
## 🛠️ Installation
|
||||
|
||||
@@ -32,87 +31,10 @@ docker run -d --gpus all --name vibevoice-vllm \
|
||||
-v $(pwd):/app \
|
||||
-w /app \
|
||||
--entrypoint bash \
|
||||
vllm/vllm-openai:v0.14.1 \
|
||||
vllm/vllm-openai:latest \
|
||||
-c "python3 /app/vllm_plugin/scripts/start_server.py"
|
||||
```
|
||||
|
||||
## ⚡ Multi-GPU Deployment
|
||||
|
||||
The launcher supports two types of GPU parallelism via `--tp` and `--dp` flags:
|
||||
|
||||
| Flag | Name | What it does |
|
||||
|------|------|-------------|
|
||||
| `--tp N` | Tensor Parallel | Splits **one model** across N GPUs (for models too large for a single GPU) |
|
||||
| `--dp N` | Data Parallel | Runs **N independent replicas**, one per GPU, with automatic load balancing behind a single port |
|
||||
|
||||
### Data Parallel (Recommended for scaling throughput)
|
||||
|
||||
Run N independent replicas on N GPUs with automatic load balancing behind a single port.
|
||||
When `--dp N` is specified (N > 1), the launcher automatically starts N independent vLLM
|
||||
processes behind an nginx reverse proxy (2×N workers) for optimal throughput:
|
||||
|
||||
```bash
|
||||
docker run -d --gpus '"device=0,1,2,3"' --name vibevoice-vllm \
|
||||
--ipc=host \
|
||||
-p 8000:8000 \
|
||||
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
|
||||
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
|
||||
-v $(pwd):/app \
|
||||
-w /app \
|
||||
--entrypoint bash \
|
||||
vllm/vllm-openai:v0.14.1 \
|
||||
-c "python3 /app/vllm_plugin/scripts/start_server.py --dp 4"
|
||||
```
|
||||
|
||||
Run on all 8 GPUs:
|
||||
|
||||
```bash
|
||||
docker run -d --gpus all --name vibevoice-vllm \
|
||||
--ipc=host \
|
||||
-p 8000:8000 \
|
||||
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
|
||||
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
|
||||
-v $(pwd):/app \
|
||||
-w /app \
|
||||
--entrypoint bash \
|
||||
vllm/vllm-openai:v0.14.1 \
|
||||
-c "python3 /app/vllm_plugin/scripts/start_server.py --dp 8"
|
||||
```
|
||||
|
||||
### Tensor Parallel
|
||||
|
||||
Split a single model across 2 GPUs (useful if GPU memory is limited):
|
||||
|
||||
```bash
|
||||
docker run -d --gpus '"device=0,1"' --name vibevoice-vllm \
|
||||
--ipc=host \
|
||||
-p 8000:8000 \
|
||||
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
|
||||
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
|
||||
-v $(pwd):/app \
|
||||
-w /app \
|
||||
--entrypoint bash \
|
||||
vllm/vllm-openai:v0.14.1 \
|
||||
-c "python3 /app/vllm_plugin/scripts/start_server.py --tp 2"
|
||||
```
|
||||
|
||||
### Hybrid (DP × TP)
|
||||
|
||||
Combine both — e.g., 2 replicas, each split across 2 GPUs (4 GPUs total):
|
||||
|
||||
```bash
|
||||
docker run -d --gpus '"device=0,1,2,3"' --name vibevoice-vllm \
|
||||
--ipc=host \
|
||||
-p 8000:8000 \
|
||||
-v $(pwd):/app \
|
||||
-w /app \
|
||||
--entrypoint bash \
|
||||
vllm/vllm-openai:v0.14.1 \
|
||||
-c "python3 /app/vllm_plugin/scripts/start_server.py --dp 2 --tp 2"
|
||||
```
|
||||
|
||||
> **Note**: Total GPUs required = `dp × tp`. Make sure to expose enough GPU devices in the Docker `--gpus` flag.
|
||||
|
||||
3. View logs
|
||||
```bash
|
||||
docker logs -f vibevoice-vllm
|
||||
@@ -130,25 +52,10 @@ docker logs -f vibevoice-vllm
|
||||
Once the vLLM server is running, test it with the provided script:
|
||||
|
||||
```bash
|
||||
# Basic transcription
|
||||
# Run the test (use container path /app/...)
|
||||
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav
|
||||
|
||||
# With hotwords for better recognition of specific terms
|
||||
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
|
||||
|
||||
```
|
||||
|
||||
```bash
|
||||
# With auto-recovery from repetition loops (for long audio)
|
||||
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav
|
||||
|
||||
# Auto-recover with hotwords
|
||||
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
|
||||
```
|
||||
|
||||
> **Note**:
|
||||
> - The audio/video file must be inside the mounted directory (`/app` in the container). Copy your files to the VibeVoice folder before testing.
|
||||
> - Hotwords help improve recognition of domain-specific terms like proper nouns, technical terms, and speaker names.
|
||||
> **Note**: The audio file must be inside the mounted directory (`/app` in the container). Copy your audio to the VibeVoice folder before testing.
|
||||
|
||||
### Environment Variables
|
||||
|
||||
|
||||
@@ -38,12 +38,6 @@ dependencies = [
|
||||
"requests",
|
||||
]
|
||||
|
||||
|
||||
[project.optional-dependencies]
|
||||
streamingtts = [
|
||||
"transformers==4.51.3",
|
||||
]
|
||||
|
||||
[project.entry-points."vllm.general_plugins"]
|
||||
vibevoice = "vllm_plugin:register_vibevoice"
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
|
||||
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
|
||||
|
||||
from .model import VibeVoiceForCausalLM
|
||||
from .inputs import vibevoice_audio_input_mapper
|
||||
|
||||
|
||||
def register_vibevoice():
|
||||
|
||||
+284
-156
@@ -5,11 +5,17 @@ This module implements the VibeVoice ASR model with full vLLM multimodal registr
|
||||
integration for speech-to-text inference.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence
|
||||
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
import tempfile
|
||||
import base64
|
||||
|
||||
|
||||
@@ -23,12 +29,32 @@ import base64
|
||||
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
|
||||
|
||||
|
||||
def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]:
|
||||
"""Load audio bytes using FFmpeg via stdin-pipe decoding.
|
||||
def _suffix_from_media_type(media_type: str | None) -> str:
|
||||
if not media_type:
|
||||
return ".bin"
|
||||
mt = media_type.lower().strip()
|
||||
if mt in ("audio/wav", "audio/x-wav", "audio/wave"):
|
||||
return ".wav"
|
||||
if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"):
|
||||
return ".mp3"
|
||||
if mt in ("audio/flac",):
|
||||
return ".flac"
|
||||
if mt in ("audio/ogg", "audio/opus"):
|
||||
return ".ogg"
|
||||
if mt in ("audio/mp4", "audio/m4a"):
|
||||
return ".m4a"
|
||||
if mt in ("video/mp4",):
|
||||
return ".mp4"
|
||||
return ".bin"
|
||||
|
||||
|
||||
def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]:
|
||||
"""Load audio bytes using FFmpeg.
|
||||
|
||||
Returns:
|
||||
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
|
||||
"""
|
||||
# Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency.
|
||||
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
|
||||
normalizer = AudioNormalizer()
|
||||
audio = normalizer(audio)
|
||||
@@ -46,53 +72,91 @@ def _ffmpeg_load_file(filepath) -> tuple[np.ndarray, int]:
|
||||
return audio, sr
|
||||
|
||||
# Register FFmpeg-based audio loader
|
||||
import vllm.multimodal.audio as _vllm_audio_module
|
||||
import warnings
|
||||
|
||||
# Handle both old and new vLLM versions
|
||||
# In newer versions, AudioMediaIO may not exist or may have been moved
|
||||
try:
|
||||
# Try new location (vLLM >= 0.6.x)
|
||||
from vllm.multimodal.media.audio import AudioMediaIO as _OriginalAudioMediaIO
|
||||
except ImportError:
|
||||
# Fall back to old location (vLLM < 0.6.x)
|
||||
import vllm.multimodal.audio as _vllm_audio_module
|
||||
_OriginalAudioMediaIO = _vllm_audio_module.AudioMediaIO
|
||||
|
||||
class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
|
||||
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
|
||||
|
||||
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
|
||||
return _ffmpeg_load_bytes(data)
|
||||
class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
|
||||
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Call parent constructor if it exists
|
||||
if hasattr(super(), '__init__'):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
|
||||
return _ffmpeg_load_bytes(data, media_type=None)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
|
||||
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
|
||||
|
||||
def load_file(self, filepath) -> tuple[np.ndarray, int]:
|
||||
return _ffmpeg_load_file(filepath)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
|
||||
return _ffmpeg_load_bytes(base64.b64decode(data))
|
||||
|
||||
def load_file(self, filepath) -> tuple[np.ndarray, int]:
|
||||
return _ffmpeg_load_file(filepath)
|
||||
|
||||
# Replace globally
|
||||
try:
|
||||
# For new vLLM versions
|
||||
import vllm.multimodal.media.audio as _vllm_audio_module
|
||||
# Replace globally
|
||||
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
|
||||
except ImportError:
|
||||
# For old vLLM versions
|
||||
import vllm.multimodal.audio as _vllm_audio_module
|
||||
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
|
||||
|
||||
# Also patch in utils module where it's imported
|
||||
try:
|
||||
import vllm.multimodal.utils as _vllm_utils_module
|
||||
_vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
|
||||
except (ImportError, AttributeError):
|
||||
# AudioMediaIO might not be imported in utils in newer versions
|
||||
pass
|
||||
|
||||
# Also patch in utils module where it's imported
|
||||
try:
|
||||
import vllm.multimodal.utils as _vllm_utils_module
|
||||
_vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
|
||||
except (ImportError, AttributeError) as e:
|
||||
warnings.warn(f"Could not patch AudioMediaIO in vllm.multimodal.utils: {e}", UserWarning)
|
||||
|
||||
except (AttributeError, ImportError) as e:
|
||||
# AudioMediaIO doesn't exist in this vLLM version
|
||||
# Define our own standalone implementation
|
||||
warnings.warn(
|
||||
f"AudioMediaIO not found in vllm.multimodal.audio ({e}). "
|
||||
"Using standalone FFmpeg-based implementation for compatibility.",
|
||||
UserWarning
|
||||
)
|
||||
|
||||
class _PatchedAudioMediaIO:
|
||||
"""Standalone AudioMediaIO implementation using FFmpeg for audio decoding.
|
||||
|
||||
This is used when vLLM doesn't provide AudioMediaIO or it's been moved/removed.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
|
||||
return _ffmpeg_load_bytes(data, media_type=None)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
|
||||
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
|
||||
|
||||
def load_file(self, filepath) -> tuple[np.ndarray, int]:
|
||||
return _ffmpeg_load_file(filepath)
|
||||
|
||||
# Try to register it in the module if possible
|
||||
try:
|
||||
_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
|
||||
except (AttributeError, TypeError) as e:
|
||||
warnings.warn(
|
||||
f"Could not register AudioMediaIO in vllm.multimodal.audio: {e}. "
|
||||
"Audio loading will use FFmpeg functions directly.",
|
||||
UserWarning
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
|
||||
from transformers import BatchFeature
|
||||
from transformers import Qwen2Config, BatchFeature
|
||||
from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.config import VllmConfig, ModelConfig
|
||||
from vllm.config.speech_to_text import SpeechToTextConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.model_executor.models.utils import (
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
@@ -107,17 +171,7 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
try:
|
||||
# Try new location (vLLM >= 0.6.x)
|
||||
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
|
||||
except ImportError:
|
||||
# Fall back to old location (vLLM < 0.6.x)
|
||||
try:
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
except ImportError:
|
||||
# If neither location works, try individual imports
|
||||
from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
|
||||
from vllm.multimodal.processing.inputs import ProcessorInputs
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
|
||||
# Import VibeVoice components
|
||||
from vibevoice.modular.modular_vibevoice_tokenizer import (
|
||||
@@ -554,88 +608,30 @@ class VibeVoiceProcessingInfo(BaseProcessingInfo):
|
||||
return tokens
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": 1}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
"""Return the maximum number of audio tokens per item.
|
||||
|
||||
This tells vLLM's scheduler the upper bound so that
|
||||
``encoder_compute_budget`` is large enough for any audio length
|
||||
the model can handle, preventing the silent scheduling deadlock
|
||||
described in docs/max_num_batched_tokens_issue.md.
|
||||
|
||||
Formula: audio_tokens = ceil(audio_samples / compress_ratio) + 3
|
||||
where +3 accounts for speech_start, speech_end, and newline tokens.
|
||||
The max audio samples is bounded by seq_len (the model's context
|
||||
window cannot hold more tokens than that).
|
||||
"""
|
||||
hf_config = self.get_hf_config()
|
||||
|
||||
def _cfg(key: str, default):
|
||||
if isinstance(hf_config, dict):
|
||||
return hf_config.get(key, default)
|
||||
return getattr(hf_config, key, default)
|
||||
|
||||
compress_ratio = int(_cfg("speech_tok_compress_ratio", 3200))
|
||||
sample_rate = int(_cfg("target_sample_rate", 24000))
|
||||
|
||||
# Upper bound: 61-minute audio at 24 kHz
|
||||
max_audio_samples = 61 * 60 * sample_rate # 88,464,000
|
||||
max_audio_tokens = int(np.ceil(max_audio_samples / compress_ratio)) + 3
|
||||
|
||||
# Cannot exceed the model's context window
|
||||
max_audio_tokens = min(max_audio_tokens, seq_len)
|
||||
|
||||
return {"audio": max_audio_tokens}
|
||||
return {"audio": None}
|
||||
|
||||
|
||||
class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]):
|
||||
"""
|
||||
Build dummy inputs for multimodal profiling.
|
||||
|
||||
vLLM uses dummy inputs to:
|
||||
1. Measure peak GPU activation memory → determine KV cache capacity
|
||||
2. Warm up CUDA graphs
|
||||
|
||||
The dummy audio length must be consistent with ``get_mm_max_tokens_per_item``
|
||||
so that the memory estimate covers the worst-case (longest audio) scenario.
|
||||
Dummy text uses the raw <|AUDIO|> token(s). vLLM's processing pipeline will
|
||||
expand each <|AUDIO|> via `VibeVoiceMultiModalProcessor._get_prompt_updates`
|
||||
into the full ASR format:
|
||||
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
|
||||
where N is derived from audio length / compress_ratio.
|
||||
"""
|
||||
|
||||
def _get_max_audio_samples(self, seq_len: int) -> int:
|
||||
"""Compute maximum audio samples consistent with ``get_mm_max_tokens_per_item``.
|
||||
|
||||
Uses the same formula: max_tokens = min(ceil(61min * sr / ratio) + 3, seq_len),
|
||||
then converts back to samples.
|
||||
"""
|
||||
hf_config = self.info.get_hf_config()
|
||||
|
||||
def _cfg(key: str, default):
|
||||
if isinstance(hf_config, dict):
|
||||
return hf_config.get(key, default)
|
||||
return getattr(hf_config, key, default)
|
||||
|
||||
compress_ratio = int(_cfg("speech_tok_compress_ratio", 3200))
|
||||
sample_rate = int(_cfg("target_sample_rate", 24000))
|
||||
|
||||
# Upper bound: 61-minute audio at 24 kHz
|
||||
max_hour_samples = 61 * 60 * sample_rate # 88,464,000
|
||||
max_tokens_from_audio = int(np.ceil(max_hour_samples / compress_ratio)) + 3
|
||||
# Cannot exceed model context window
|
||||
max_tokens = min(max_tokens_from_audio, seq_len)
|
||||
# Convert tokens back to samples
|
||||
return max_tokens * compress_ratio
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
if num_audios <= 0:
|
||||
return ""
|
||||
|
||||
# Get the audio token from our token info helper
|
||||
token_info = self.info.get_audio_token_info()
|
||||
audio_token = token_info["audio_token"]
|
||||
|
||||
# Return ONLY the audio tokens - the HF processor adds bos/eos
|
||||
return audio_token * num_audios
|
||||
|
||||
def get_dummy_mm_data(
|
||||
@@ -644,23 +640,16 @@ class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, Any] | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate dummy audio data for profiling.
|
||||
"""Generate dummy audio data for profiling."""
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
|
||||
The audio length is derived from ``seq_len`` so that profiling
|
||||
accurately measures memory for the longest audio the model can handle.
|
||||
Supports ``AudioDummyOptions.length`` override for faster startup.
|
||||
"""
|
||||
sampling_rate = feature_extractor.sampling_rate
|
||||
audio_len = feature_extractor.chunk_length * sampling_rate
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
max_audio_len = self._get_max_audio_samples(seq_len)
|
||||
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
|
||||
|
||||
# Generate dummy audio as numpy arrays (what the HF processor expects)
|
||||
return {
|
||||
"audio": self._get_dummy_audios(
|
||||
length=max_audio_len,
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides,
|
||||
)
|
||||
"audio": [np.zeros(audio_len, dtype=np.float32) for _ in range(num_audios)]
|
||||
}
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
@@ -934,6 +923,17 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
with a causal language model for text generation.
|
||||
"""
|
||||
|
||||
# SupportsTranscription interface
|
||||
supports_transcription: ClassVar[Literal[True]] = True
|
||||
supports_transcription_only: ClassVar[bool] = False
|
||||
supports_segment_timestamp: ClassVar[bool] = False
|
||||
|
||||
# Supported languages (Chinese as primary target)
|
||||
supported_languages: ClassVar[Mapping[str, str]] = {
|
||||
"zh": "Chinese",
|
||||
"en": "English",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
"""Return the placeholder string format for a given modality.
|
||||
@@ -947,10 +947,112 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return "<|AUDIO|>"
|
||||
raise ValueError(f"Unsupported modality: {modality}")
|
||||
|
||||
@classmethod
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
language: str | None,
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: str | None,
|
||||
) -> PromptType:
|
||||
"""Get the prompt for the ASR model.
|
||||
|
||||
Generates a chat-formatted prompt for speech-to-text transcription
|
||||
with JSON output format.
|
||||
"""
|
||||
# If user provides custom prompt, use it
|
||||
if request_prompt:
|
||||
return request_prompt
|
||||
|
||||
# Calculate audio duration for the prompt
|
||||
# Audio should be at 24kHz, so duration = len(audio) / 24000
|
||||
duration = len(audio) / 24000 if audio is not None else 10.0
|
||||
|
||||
system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format."
|
||||
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
|
||||
user_suffix = (
|
||||
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
||||
+ ", ".join(show_keys)
|
||||
)
|
||||
|
||||
# IMPORTANT: keep <|AUDIO|> as the only placeholder token here.
|
||||
# `_get_prompt_updates` expands it into repeated `<|AUDIO|>` placeholders.
|
||||
user_content = "<|AUDIO|>\n" + user_suffix
|
||||
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{user_content}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(
|
||||
cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
|
||||
) -> SpeechToTextConfig:
|
||||
"""Get the speech to text config for the ASR model."""
|
||||
return SpeechToTextConfig(
|
||||
language=None, # Auto-detect or use request language
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_num_audio_tokens(
|
||||
cls,
|
||||
audio_duration_s: float,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
) -> int | None:
|
||||
"""Estimate number of audio tokens from duration.
|
||||
|
||||
Returns the number of audio EMBEDDING positions (speech_pad_id tokens).
|
||||
Note: _get_prompt_updates actually generates:
|
||||
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
|
||||
So total prompt tokens = N + 3, but this returns N (the embedding count).
|
||||
"""
|
||||
sampling_rate = 24000
|
||||
compress_ratio = 3200
|
||||
samples = int(audio_duration_s * sampling_rate)
|
||||
num_tokens = int(np.ceil(samples / compress_ratio))
|
||||
return num_tokens
|
||||
|
||||
@classmethod
|
||||
def get_other_languages(cls) -> Mapping[str, str]:
|
||||
"""Get languages from Whisper map not natively supported."""
|
||||
# Import LANGUAGES from vllm
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import LANGUAGES
|
||||
except ImportError:
|
||||
# Fallback to empty dict if import fails
|
||||
return {}
|
||||
return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages}
|
||||
|
||||
@classmethod
|
||||
def validate_language(cls, language: str | None) -> str | None:
|
||||
"""Validate the language code."""
|
||||
if language is None or language in cls.supported_languages:
|
||||
return language
|
||||
elif language in cls.get_other_languages():
|
||||
print(f"Warning: Language {language!r} is not natively supported")
|
||||
return language
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported language: {language!r}. "
|
||||
f"Supported: {list(cls.supported_languages.keys())}"
|
||||
)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.config = config
|
||||
|
||||
# Keep a copy of the resolved model path for any custom weight-loading
|
||||
# logic (e.g., loading audio encoder weights in fp32 directly from
|
||||
# safetensors shards).
|
||||
self._model_path = vllm_config.model_config.model
|
||||
|
||||
self.audio_encoder = VibeVoiceAudioEncoder(config)
|
||||
|
||||
@@ -1048,54 +1150,76 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# Process each audio through the VibeVoice encoder
|
||||
embeddings = []
|
||||
|
||||
# Get model device for tensor placement.
|
||||
# Get model device and dtype for alignment
|
||||
try:
|
||||
device = next(self.audio_encoder.parameters()).device
|
||||
dtype = next(self.audio_encoder.parameters()).dtype
|
||||
except StopIteration:
|
||||
# Fallback if no parameters (shouldn't happen)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Handle both stacked tensor and list of tensors
|
||||
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
|
||||
if isinstance(raw_audio, torch.Tensor):
|
||||
if raw_audio.dim() == 3:
|
||||
# Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
|
||||
num_audios = raw_audio.shape[0]
|
||||
audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
|
||||
elif raw_audio.dim() == 2:
|
||||
# Shape: [batch_size, seq_len]
|
||||
num_audios = raw_audio.shape[0]
|
||||
audio_list = [raw_audio[i] for i in range(num_audios)]
|
||||
else:
|
||||
# Single 1D tensor
|
||||
audio_list = [raw_audio]
|
||||
elif isinstance(raw_audio, (list, tuple)):
|
||||
audio_list = list(raw_audio)
|
||||
else:
|
||||
# Single tensor
|
||||
audio_list = [raw_audio]
|
||||
|
||||
for i, audio_tensor in enumerate(audio_list):
|
||||
try:
|
||||
if isinstance(audio_tensor, list):
|
||||
audio_tensor = torch.stack(audio_tensor)
|
||||
|
||||
# Ensure tensor
|
||||
if not isinstance(audio_tensor, torch.Tensor):
|
||||
audio_tensor = torch.tensor(audio_tensor)
|
||||
|
||||
# Let vLLM handle dtype (bfloat16 by default)
|
||||
audio_tensor = audio_tensor.to(device=device)
|
||||
|
||||
# Get actual length if available, otherwise use full length
|
||||
if raw_audio_lengths and i < len(raw_audio_lengths):
|
||||
actual_len = int(raw_audio_lengths[i])
|
||||
if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
|
||||
# Truncate from the last dimension (sequence length)
|
||||
audio_tensor = audio_tensor[..., :actual_len]
|
||||
if audio_tensor.numel() < 160:
|
||||
|
||||
# Skip if audio is too short (< 1 frame)
|
||||
if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz
|
||||
continue
|
||||
|
||||
# Encode audio through VibeVoice encoder
|
||||
audio_embeds = self.audio_encoder(
|
||||
audio_tensor,
|
||||
use_streaming=use_streaming_flag,
|
||||
segment_duration_s=streaming_segment_duration,
|
||||
)
|
||||
|
||||
# audio_embeds shape: [1, seq_len, hidden_size]
|
||||
# We need to return it as a single embedding tensor per audio
|
||||
final_embed = audio_embeds.squeeze(0)
|
||||
embeddings.append(final_embed)
|
||||
|
||||
except Exception as e:
|
||||
# Log error but don't crash - this helps debug profiling issues
|
||||
print(f"[VibeVoice] Error encoding audio {i}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Return empty embedding to avoid crash
|
||||
continue
|
||||
|
||||
return tuple(embeddings)
|
||||
@@ -1220,31 +1344,35 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
intermediate_tensors: Intermediate tensors for pipeline parallelism.
|
||||
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
|
||||
"""
|
||||
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
|
||||
# Only compute from input_ids if inputs_embeds is not available
|
||||
if inputs_embeds is None and input_ids is not None:
|
||||
# Compute embeddings from input_ids
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# Get the inner model - handle both wrapped and direct language models
|
||||
language_model = self.language_model
|
||||
if hasattr(language_model, "language_model"):
|
||||
language_model = language_model.language_model
|
||||
|
||||
# Call the language model's model (Qwen2Model)
|
||||
# vLLM V1 passes kv_caches and attn_metadata via context, not arguments
|
||||
# IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
|
||||
hidden_states = language_model.model(
|
||||
input_ids=None, # Always None when we have inputs_embeds
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
try:
|
||||
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
|
||||
# Only compute from input_ids if inputs_embeds is not available
|
||||
if inputs_embeds is None and input_ids is not None:
|
||||
# Compute embeddings from input_ids
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# Get the inner model - handle both wrapped and direct language models
|
||||
language_model = self.language_model
|
||||
if hasattr(language_model, "language_model"):
|
||||
language_model = language_model.language_model
|
||||
|
||||
# Call the language model's model (Qwen2Model)
|
||||
# vLLM V1 passes kv_caches and attn_metadata via context, not arguments
|
||||
# IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
|
||||
hidden_states = language_model.model(
|
||||
input_ids=None, # Always None when we have inputs_embeds
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
|
||||
# Alias for training checkpoint compatibility
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,21 +9,14 @@ One-click deployment script that handles:
|
||||
4. Generating tokenizer files
|
||||
5. Starting vLLM server
|
||||
|
||||
For DP > 1, launches N independent vLLM processes behind an nginx
|
||||
reverse proxy for optimal throughput (avoids single-process HTTP
|
||||
bottleneck of vLLM's built-in DP coordinator).
|
||||
|
||||
Usage:
|
||||
python3 start_server.py [--model MODEL_ID] [--port PORT]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
|
||||
|
||||
def run_command(cmd: list[str], description: str, shell: bool = False) -> None:
|
||||
@@ -84,268 +77,45 @@ def generate_tokenizer(model_path: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _build_vllm_cmd(model_path: str, port: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
data_parallel_size: int = 1,
|
||||
max_num_seqs: int = 64,
|
||||
max_model_len: int = 65536,
|
||||
gpu_memory_utilization: float = 0.8) -> list[str]:
|
||||
"""Build the vllm serve command."""
|
||||
return [
|
||||
def start_vllm_server(model_path: str, port: int) -> None:
|
||||
"""Start vLLM server (replaces current process)."""
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Starting vLLM server on port {port}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
vllm_cmd = [
|
||||
"vllm", "serve", model_path,
|
||||
"--served-model-name", "vibevoice",
|
||||
"--trust-remote-code",
|
||||
"--dtype", "bfloat16",
|
||||
"--max-num-seqs", str(max_num_seqs),
|
||||
"--max-model-len", str(max_model_len),
|
||||
"--gpu-memory-utilization", str(gpu_memory_utilization),
|
||||
"--max-num-seqs", "64",
|
||||
"--max-model-len", "65536",
|
||||
"--max-num-batched-tokens", "32768",
|
||||
"--gpu-memory-utilization", "0.8",
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--enable-chunked-prefill",
|
||||
"--chat-template-content-format", "openai",
|
||||
"--tensor-parallel-size", str(tensor_parallel_size),
|
||||
"--data-parallel-size", str(data_parallel_size),
|
||||
"--tensor-parallel-size", "1",
|
||||
"--allowed-local-media-path", "/app",
|
||||
"--port", str(port),
|
||||
]
|
||||
|
||||
|
||||
def start_vllm_server(model_path: str, port: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
data_parallel_size: int = 1,
|
||||
max_num_seqs: int = 64,
|
||||
max_model_len: int = 65536,
|
||||
gpu_memory_utilization: float = 0.8) -> None:
|
||||
"""Start a single vLLM server (replaces current process)."""
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Starting vLLM server on port {port}")
|
||||
print(f" Tensor Parallel (TP): {tensor_parallel_size}")
|
||||
print(f" Data Parallel (DP): {data_parallel_size}")
|
||||
print(f" Max Num Seqs: {max_num_seqs}")
|
||||
print(f" Max Model Len: {max_model_len}")
|
||||
print(f" GPU Mem Utilization: {gpu_memory_utilization}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
vllm_cmd = _build_vllm_cmd(
|
||||
model_path, port,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
data_parallel_size=data_parallel_size,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
)
|
||||
os.execvp("vllm", vllm_cmd)
|
||||
|
||||
|
||||
def _install_nginx() -> None:
|
||||
"""Install nginx if not already available."""
|
||||
if subprocess.run(["which", "nginx"], capture_output=True).returncode != 0:
|
||||
run_command(["apt-get", "update"], "Updating package list for nginx")
|
||||
run_command(
|
||||
["apt-get", "install", "-y", "nginx"],
|
||||
"Installing nginx for load balancing"
|
||||
)
|
||||
|
||||
|
||||
def _write_nginx_config(frontend_port: int, backend_ports: list[int],
|
||||
num_workers: int = 0) -> str:
|
||||
"""Write nginx config for round-robin load balancing.
|
||||
|
||||
Args:
|
||||
num_workers: Number of nginx worker processes. 0 = auto (2 × num backends).
|
||||
"""
|
||||
if num_workers <= 0:
|
||||
num_workers = len(backend_ports) * 2
|
||||
backends = "\n".join(f" server 127.0.0.1:{p};" for p in backend_ports)
|
||||
config = textwrap.dedent(f"""\
|
||||
worker_processes {num_workers};
|
||||
worker_rlimit_nofile 65536;
|
||||
error_log /dev/stderr warn;
|
||||
pid /tmp/nginx.pid;
|
||||
|
||||
events {{
|
||||
worker_connections 8192;
|
||||
}}
|
||||
|
||||
http {{
|
||||
access_log off;
|
||||
|
||||
upstream vllm_backends {{
|
||||
least_conn;
|
||||
{backends}
|
||||
}}
|
||||
|
||||
server {{
|
||||
listen {frontend_port};
|
||||
client_max_body_size 200m;
|
||||
client_body_buffer_size 10m;
|
||||
proxy_buffering on;
|
||||
proxy_buffer_size 64k;
|
||||
proxy_buffers 16 64k;
|
||||
|
||||
location / {{
|
||||
proxy_pass http://vllm_backends;
|
||||
proxy_read_timeout 600s;
|
||||
proxy_connect_timeout 10s;
|
||||
proxy_send_timeout 600s;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Connection "";
|
||||
}}
|
||||
}}
|
||||
}}
|
||||
""")
|
||||
config_path = "/tmp/nginx_vllm.conf"
|
||||
with open(config_path, "w") as f:
|
||||
f.write(config)
|
||||
return config_path
|
||||
|
||||
|
||||
def start_dp_server(model_path: str, frontend_port: int,
|
||||
data_parallel_size: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
max_num_seqs: int = 64,
|
||||
max_model_len: int = 65536,
|
||||
gpu_memory_utilization: float = 0.8) -> None:
|
||||
"""Start multiple vLLM workers behind nginx for data parallelism.
|
||||
|
||||
Launches N independent vLLM processes (one per GPU group) on internal
|
||||
ports, with an nginx reverse proxy on the frontend port for load
|
||||
balancing. This avoids the single-process HTTP bottleneck of vLLM's
|
||||
built-in DP coordinator when handling large audio payloads.
|
||||
"""
|
||||
import torch
|
||||
num_gpus = torch.cuda.device_count()
|
||||
gpus_per_replica = tensor_parallel_size
|
||||
total_gpus_needed = data_parallel_size * gpus_per_replica
|
||||
assert num_gpus >= total_gpus_needed, (
|
||||
f"Need {total_gpus_needed} GPUs (dp={data_parallel_size} × tp={tensor_parallel_size}) "
|
||||
f"but only {num_gpus} available"
|
||||
)
|
||||
|
||||
# Auto-tune per-worker env vars based on dp size
|
||||
ffmpeg_concurrency = max(
|
||||
64, int(os.environ.get("VIBEVOICE_FFMPEG_MAX_CONCURRENCY", "64"))
|
||||
)
|
||||
media_threads = max(
|
||||
8, int(os.environ.get("VLLM_MEDIA_LOADING_THREAD_COUNT", "8"))
|
||||
)
|
||||
|
||||
_install_nginx()
|
||||
|
||||
# Assign internal ports: frontend_port + 100, +101, ...
|
||||
backend_ports = [frontend_port + 100 + i for i in range(data_parallel_size)]
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Starting DP server with nginx load balancing")
|
||||
print(f" Frontend port: {frontend_port} (nginx)")
|
||||
print(f" Backend ports: {backend_ports}")
|
||||
print(f" Data Parallel: {data_parallel_size}")
|
||||
print(f" Tensor Parallel: {tensor_parallel_size}")
|
||||
print(f" GPUs per replica: {gpus_per_replica}")
|
||||
print(f" Max Num Seqs: {max_num_seqs}")
|
||||
print(f" Max Model Len: {max_model_len}")
|
||||
print(f" FFmpeg concurrency (per worker): {ffmpeg_concurrency}")
|
||||
print(f" Media loading threads (per worker): {media_threads}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Write nginx config
|
||||
nginx_conf = _write_nginx_config(frontend_port, backend_ports)
|
||||
|
||||
# Launch vLLM workers
|
||||
workers: list[subprocess.Popen] = []
|
||||
for rank in range(data_parallel_size):
|
||||
gpu_start = rank * gpus_per_replica
|
||||
gpu_ids = ",".join(str(gpu_start + j) for j in range(gpus_per_replica))
|
||||
port = backend_ports[rank]
|
||||
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = gpu_ids
|
||||
env["VIBEVOICE_FFMPEG_MAX_CONCURRENCY"] = str(ffmpeg_concurrency)
|
||||
env["VLLM_MEDIA_LOADING_THREAD_COUNT"] = str(media_threads)
|
||||
|
||||
vllm_cmd = _build_vllm_cmd(
|
||||
model_path, port,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
data_parallel_size=1, # Each worker is dp=1
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
)
|
||||
|
||||
print(f" Launching worker rank={rank} on GPU(s) {gpu_ids}, port {port}")
|
||||
proc = subprocess.Popen(vllm_cmd, env=env)
|
||||
workers.append(proc)
|
||||
|
||||
# Start nginx
|
||||
print(f"\n Starting nginx on port {frontend_port} ...")
|
||||
nginx_proc = subprocess.Popen(
|
||||
["nginx", "-c", nginx_conf, "-g", "daemon off;"]
|
||||
)
|
||||
|
||||
# Wait for all backends to be ready
|
||||
print(" Waiting for all backends to be ready ...")
|
||||
import urllib.request
|
||||
for port in backend_ports:
|
||||
url = f"http://127.0.0.1:{port}/v1/models"
|
||||
for attempt in range(600): # up to 10 minutes
|
||||
try:
|
||||
urllib.request.urlopen(url, timeout=2)
|
||||
print(f" ✅ Backend on port {port} is ready")
|
||||
break
|
||||
except Exception:
|
||||
time.sleep(1)
|
||||
else:
|
||||
print(f" ❌ Backend on port {port} failed to start")
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" ✅ VibeVoice DP server ready on port {frontend_port}")
|
||||
print(f" {data_parallel_size} replicas behind nginx load balancer")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Handle shutdown: forward signals to all children
|
||||
def _shutdown(signum, frame):
|
||||
print("\nShutting down ...")
|
||||
nginx_proc.terminate()
|
||||
for w in workers:
|
||||
w.terminate()
|
||||
for w in workers:
|
||||
w.wait(timeout=10)
|
||||
nginx_proc.wait(timeout=5)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, _shutdown)
|
||||
signal.signal(signal.SIGINT, _shutdown)
|
||||
|
||||
# Wait for any child to exit (indicates a failure)
|
||||
while True:
|
||||
for i, w in enumerate(workers):
|
||||
ret = w.poll()
|
||||
if ret is not None:
|
||||
print(f" ❌ Worker {i} exited with code {ret}")
|
||||
_shutdown(None, None)
|
||||
if nginx_proc.poll() is not None:
|
||||
print(f" ❌ nginx exited with code {nginx_proc.returncode}")
|
||||
_shutdown(None, None)
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="VibeVoice vLLM ASR Server - One-Click Deployment",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Start with default settings (single GPU)
|
||||
# Start with default settings
|
||||
python3 start_server.py
|
||||
|
||||
# Use custom port
|
||||
python3 start_server.py --port 8080
|
||||
|
||||
# Data parallel: 4 replicas on 4 GPUs (nginx load balancing)
|
||||
python3 start_server.py --dp 4
|
||||
|
||||
# Tensor parallel: split model across 2 GPUs
|
||||
python3 start_server.py --tp 2
|
||||
|
||||
# Skip dependency installation (if already installed)
|
||||
python3 start_server.py --skip-deps
|
||||
"""
|
||||
@@ -371,41 +141,6 @@ Examples:
|
||||
action="store_true",
|
||||
help="Skip generating tokenizer files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp", "--tensor-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
dest="tensor_parallel_size",
|
||||
help="Tensor parallel size: split one model across N GPUs (default: 1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dp", "--data-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
dest="data_parallel_size",
|
||||
help="Data parallel size: run N independent model replicas for load balancing (default: 1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-num-seqs",
|
||||
type=int,
|
||||
default=64,
|
||||
dest="max_num_seqs",
|
||||
help="Maximum number of sequences per batch (default: 64)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
default=65536,
|
||||
dest="max_model_len",
|
||||
help="Maximum model context length (default: 65536)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.8,
|
||||
dest="gpu_memory_utilization",
|
||||
help="GPU memory utilization fraction (default: 0.8)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("\n" + "="*60)
|
||||
@@ -426,25 +161,8 @@ Examples:
|
||||
if not args.skip_tokenizer:
|
||||
generate_tokenizer(model_path)
|
||||
|
||||
# Step 5: Start server
|
||||
if args.data_parallel_size > 1:
|
||||
start_dp_server(
|
||||
model_path, args.port,
|
||||
data_parallel_size=args.data_parallel_size,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
max_num_seqs=args.max_num_seqs,
|
||||
max_model_len=args.max_model_len,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
)
|
||||
else:
|
||||
start_vllm_server(
|
||||
model_path, args.port,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
data_parallel_size=1,
|
||||
max_num_seqs=args.max_num_seqs,
|
||||
max_model_len=args.max_model_len,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
)
|
||||
# Step 5: Start vLLM server
|
||||
start_vllm_server(model_path, args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+91
-107
@@ -1,23 +1,14 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test VibeVoice vLLM API with Streaming and Optional Hotwords Support.
|
||||
|
||||
This script tests ASR transcription via the vLLM OpenAI-compatible API.
|
||||
By default, it runs standard transcription without hotwords.
|
||||
|
||||
Optionally, you can provide hotwords (context_info) to improve recognition
|
||||
of domain-specific content like proper nouns, technical terms, and speaker names.
|
||||
Hotwords are embedded in the prompt as "with extra info: {hotwords}".
|
||||
Test VibeVoice vLLM API with Streaming (Real-time output).
|
||||
|
||||
Usage:
|
||||
python test_api_with_hotwords.py [audio_path] [--url URL] [--hotwords "word1,word2"]
|
||||
python test_api.py [audio_path] [--url URL]
|
||||
|
||||
Examples:
|
||||
# Standard transcription (no hotwords)
|
||||
python3 test_api.py audio.wav
|
||||
|
||||
# With hotwords for better recognition of specific terms
|
||||
python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice"
|
||||
python test_api.py # Use default audio
|
||||
python test_api.py /path/to/audio.wav # Specify audio file
|
||||
python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
@@ -30,38 +21,38 @@ import argparse
|
||||
|
||||
|
||||
def _guess_mime_type(path: str) -> str:
|
||||
"""Guess MIME type from file extension."""
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
mime_map = {
|
||||
".wav": "audio/wav",
|
||||
".mp3": "audio/mpeg",
|
||||
".m4a": "audio/mp4",
|
||||
".mp4": "video/mp4",
|
||||
".flac": "audio/flac",
|
||||
".ogg": "audio/ogg",
|
||||
".opus": "audio/ogg",
|
||||
}
|
||||
return mime_map.get(ext, "application/octet-stream")
|
||||
if ext == ".wav":
|
||||
return "audio/wav"
|
||||
if ext in (".mp3",):
|
||||
return "audio/mpeg"
|
||||
if ext in (".m4a",):
|
||||
return "audio/mp4"
|
||||
if ext in (".mp4", ".m4v", ".mov", ".webm"):
|
||||
return "video/mp4"
|
||||
if ext in (".flac",):
|
||||
return "audio/flac"
|
||||
if ext in (".ogg", ".opus"):
|
||||
return "audio/ogg"
|
||||
return "application/octet-stream"
|
||||
|
||||
|
||||
def _get_duration_seconds_ffprobe(path: str) -> float:
|
||||
"""Get audio duration using ffprobe."""
|
||||
cmd = [
|
||||
"ffprobe", "-v", "error",
|
||||
"-show_entries", "format=duration",
|
||||
"-of", "default=noprint_wrappers=1:nokey=1",
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
path,
|
||||
]
|
||||
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
|
||||
return float(out)
|
||||
|
||||
|
||||
def _is_video_file(path: str) -> bool:
|
||||
"""Check if the file is a video file that needs audio extraction."""
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
|
||||
|
||||
|
||||
def _extract_audio_from_video(video_path: str) -> str:
|
||||
"""
|
||||
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
|
||||
@@ -83,40 +74,26 @@ def _extract_audio_from_video(video_path: str) -> str:
|
||||
return audio_path
|
||||
|
||||
|
||||
def test_transcription_with_hotwords(
|
||||
audio_path: str,
|
||||
context_info: str = None,
|
||||
base_url: str = "http://localhost:8000",
|
||||
):
|
||||
"""
|
||||
Test ASR transcription with customized hotwords.
|
||||
def _is_video_file(path: str) -> bool:
|
||||
"""Check if the file is a video file that needs audio extraction."""
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
|
||||
|
||||
|
||||
def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"):
|
||||
"""Test ASR transcription with streaming output."""
|
||||
|
||||
Hotwords are embedded in the prompt text as "with extra info: {hotwords}".
|
||||
This helps the model recognize domain-specific terms more accurately.
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file
|
||||
context_info: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
|
||||
base_url: vLLM server URL
|
||||
"""
|
||||
|
||||
print(f"=" * 70)
|
||||
print(f"Testing Customized Hotwords Support")
|
||||
print(f"=" * 70)
|
||||
print(f"Input file: {audio_path}")
|
||||
print(f"Hotwords: {context_info or '(none)'}")
|
||||
print()
|
||||
print(f"Loading audio from: {audio_path}")
|
||||
|
||||
# Handle video files: extract audio first
|
||||
temp_audio_path = None
|
||||
actual_audio_path = audio_path
|
||||
if _is_video_file(audio_path):
|
||||
print(f"🎬 Detected video file, extracting audio...")
|
||||
print(f"Detected video file, extracting audio...")
|
||||
temp_audio_path = _extract_audio_from_video(audio_path)
|
||||
actual_audio_path = temp_audio_path
|
||||
print(f"✅ Audio extracted to: {temp_audio_path}")
|
||||
print(f"Audio extracted to: {temp_audio_path}")
|
||||
|
||||
# Load audio
|
||||
try:
|
||||
duration = _get_duration_seconds_ffprobe(actual_audio_path)
|
||||
print(f"Audio duration: {duration:.2f} seconds")
|
||||
@@ -129,30 +106,16 @@ def test_transcription_with_hotwords(
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error preparing audio: {e}")
|
||||
# Cleanup temp file if created
|
||||
if temp_audio_path and os.path.exists(temp_audio_path):
|
||||
os.remove(temp_audio_path)
|
||||
return
|
||||
|
||||
# Build the request
|
||||
url = f"{base_url}/v1/chat/completions"
|
||||
|
||||
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
|
||||
|
||||
# Build prompt with optional hotwords
|
||||
# Hotwords are embedded as "with extra info: {hotwords}" in the prompt
|
||||
if context_info and context_info.strip():
|
||||
prompt_text = (
|
||||
f"This is a {duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\n"
|
||||
f"Please transcribe it with these keys: " + ", ".join(show_keys)
|
||||
)
|
||||
print(f"\n📝 Hotwords embedded in prompt: '{context_info}'")
|
||||
else:
|
||||
prompt_text = (
|
||||
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
||||
+ ", ".join(show_keys)
|
||||
)
|
||||
print(f"\n📝 No hotwords provided")
|
||||
prompt_text = (
|
||||
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
||||
+ ", ".join(show_keys)
|
||||
)
|
||||
|
||||
mime = _guess_mime_type(actual_audio_path)
|
||||
data_url = f"data:{mime};base64,{audio_b64}"
|
||||
@@ -176,19 +139,20 @@ def test_transcription_with_hotwords(
|
||||
"temperature": 0.0,
|
||||
"stream": True,
|
||||
"top_p": 1.0,
|
||||
"repetition_penalty": 1.0,
|
||||
}
|
||||
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f"Sending request to {url}")
|
||||
print(f"{'=' * 70}")
|
||||
print(f"\nSending request to {url} (Streaming Mode)...")
|
||||
print(f"Prompt: {prompt_text}")
|
||||
print("-" * 60)
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
|
||||
response = requests.post(url, json=payload, stream=True, timeout=12000)
|
||||
|
||||
if response.status_code == 200:
|
||||
print("\n✅ Response received. Streaming content:\n")
|
||||
print("-" * 50)
|
||||
print("Response received. Streaming content:\n")
|
||||
|
||||
printed = ""
|
||||
for line in response.iter_lines():
|
||||
@@ -198,72 +162,92 @@ def test_transcription_with_hotwords(
|
||||
if decoded_line.startswith("data: "):
|
||||
json_str = decoded_line[6:]
|
||||
if json_str.strip() == "[DONE]":
|
||||
print("\n" + "-" * 50)
|
||||
print("✅ [Finished]")
|
||||
print("\n\n[Finished]")
|
||||
break
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
|
||||
delta = data['choices'][0]['delta']
|
||||
content = delta.get('content', '')
|
||||
if content:
|
||||
|
||||
# vLLM/OpenAI-compatible streams may emit either
|
||||
# incremental deltas OR the full accumulated text.
|
||||
# Only print the newly-added part to avoid repeats.
|
||||
if content.startswith(printed):
|
||||
to_print = content[len(printed):]
|
||||
else:
|
||||
to_print = content
|
||||
|
||||
if to_print:
|
||||
print(to_print, end='', flush=True)
|
||||
printed += to_print
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
else:
|
||||
print(f"❌ Error: {response.status_code}")
|
||||
print(f"Error: {response.status_code}")
|
||||
print(response.text)
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
print("❌ Request timed out!")
|
||||
print("\nRequest timed out!")
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
print(f"\nError: {e}")
|
||||
|
||||
elapsed = time.time() - t0
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
|
||||
print(f"📊 RTF (Real-Time Factor): {elapsed / duration:.2f}x")
|
||||
print(f"{'=' * 70}")
|
||||
print(f"\n{'-'*60}")
|
||||
print(f"Total time elapsed: {time.time() - t0:.2f}s")
|
||||
|
||||
# Cleanup temp audio file if created
|
||||
if temp_audio_path and os.path.exists(temp_audio_path):
|
||||
os.remove(temp_audio_path)
|
||||
print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
|
||||
print(f"Cleaned up temp file: {temp_audio_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test VibeVoice vLLM API with Customized Hotwords"
|
||||
description="Test VibeVoice vLLM API with streaming output"
|
||||
)
|
||||
parser.add_argument(
|
||||
"audio_path",
|
||||
nargs="?",
|
||||
default=None,
|
||||
help="Path to audio file (wav, mp3, flac, etc.) or video file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
default="http://localhost:8000",
|
||||
help="vLLM server URL (default: http://localhost:8000)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hotwords",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
|
||||
help="vLLM server base URL (default: http://localhost:8000)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run test
|
||||
test_transcription_with_hotwords(
|
||||
audio_path=args.audio_path,
|
||||
context_info=args.hotwords,
|
||||
base_url=args.url,
|
||||
)
|
||||
# Find default audio if not specified
|
||||
audio_path = args.audio_path
|
||||
if audio_path is None:
|
||||
# Try to find a sample audio in common locations
|
||||
possible_paths = [
|
||||
# In VibeVoice demo folder
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "en-Carter_man.wav"),
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"),
|
||||
# Relative to current directory
|
||||
"demo/voices/en-Carter_man.wav",
|
||||
"demo/voices/zh-Anchen_man_bgm.wav",
|
||||
]
|
||||
|
||||
for path in possible_paths:
|
||||
if os.path.exists(path):
|
||||
audio_path = path
|
||||
break
|
||||
|
||||
if audio_path is None:
|
||||
print("Error: No audio file specified and no default audio found.")
|
||||
print("Usage: python test_api.py <audio_path>")
|
||||
sys.exit(1)
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
print(f"Error: Audio file not found: {audio_path}")
|
||||
sys.exit(1)
|
||||
|
||||
test_transcription(audio_path, args.url)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,638 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test VibeVoice vLLM API with Streaming, Hotwords, and Auto-Recovery.
|
||||
|
||||
This script tests ASR transcription with automatic recovery from repetition loops.
|
||||
Supports optional hotwords to improve recognition of domain-specific terms.
|
||||
|
||||
Features:
|
||||
- Streaming output with real-time repetition detection
|
||||
- Auto-recovery when model enters repetition loops
|
||||
- Optional hotwords support (embedded in prompt as "with extra info: {hotwords}")
|
||||
- Video file support (auto-extracts audio)
|
||||
|
||||
Recovery Strategy:
|
||||
1. First attempt: greedy decoding (temperature=0, top_p=1.0)
|
||||
2. If loop detected: retry with temperature=0.2/0.3/0.4, top_p=0.95
|
||||
3. Max 3 retries, truncate to last complete segment boundary
|
||||
|
||||
Usage:
|
||||
python test_api_auto_recover.py <audio_path> [output_path] [--url URL] [--hotwords "word1,word2"] [--debug]
|
||||
|
||||
Examples:
|
||||
# Basic usage
|
||||
python3 test_api_auto_recover.py audio.wav
|
||||
|
||||
# With hotwords
|
||||
python3 test_api_auto_recover.py audio.wav --hotwords "Microsoft,VibeVoice"
|
||||
|
||||
# Save result to file
|
||||
python3 test_api_auto_recover.py audio.wav result.txt
|
||||
|
||||
# Debug mode (show recovery info)
|
||||
python3 test_api_auto_recover.py audio.wav --debug
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
import base64
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import re
|
||||
import argparse
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def _guess_mime_type(path: str) -> str:
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
if ext == ".wav":
|
||||
return "audio/wav"
|
||||
if ext in (".mp3",):
|
||||
return "audio/mpeg"
|
||||
if ext in (".m4a",):
|
||||
return "audio/mp4"
|
||||
if ext in (".mp4", ".m4v", ".mov", ".webm"):
|
||||
return "video/mp4"
|
||||
if ext in (".flac",):
|
||||
return "audio/flac"
|
||||
if ext in (".ogg", ".opus"):
|
||||
return "audio/ogg"
|
||||
return "application/octet-stream"
|
||||
|
||||
|
||||
def _get_duration_seconds_ffprobe(path: str) -> float:
|
||||
cmd = [
|
||||
"ffprobe", "-v", "error",
|
||||
"-show_entries", "format=duration",
|
||||
"-of", "default=noprint_wrappers=1:nokey=1",
|
||||
path,
|
||||
]
|
||||
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
|
||||
return float(out)
|
||||
|
||||
|
||||
def _extract_audio_from_video(video_path: str) -> str:
|
||||
"""
|
||||
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
|
||||
Returns the path to the extracted audio file.
|
||||
"""
|
||||
import tempfile
|
||||
# Create temp file with .mp3 extension
|
||||
fd, audio_path = tempfile.mkstemp(suffix=".mp3")
|
||||
os.close(fd)
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y", "-i", video_path,
|
||||
"-vn", # No video
|
||||
"-acodec", "libmp3lame",
|
||||
"-q:a", "2", # High quality
|
||||
audio_path
|
||||
]
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
return audio_path
|
||||
|
||||
|
||||
def _is_video_file(path: str) -> bool:
|
||||
"""Check if the file is a video file that needs audio extraction."""
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
|
||||
|
||||
|
||||
def _find_last_segment_boundary(text: str) -> int:
|
||||
"""
|
||||
Find the position after the last complete segment boundary (},).
|
||||
Returns -1 if no complete segment found.
|
||||
"""
|
||||
# Find last "}, " or "}," pattern (segment separator)
|
||||
pos = text.rfind("},")
|
||||
if pos != -1:
|
||||
return pos + 2 # Include the },
|
||||
return -1
|
||||
|
||||
|
||||
def _find_safe_print_boundary(text: str, max_pos: int) -> int:
|
||||
"""
|
||||
Find the last complete segment boundary before max_pos.
|
||||
Returns 0 if no complete segment found before max_pos.
|
||||
"""
|
||||
search_text = text[:max_pos]
|
||||
pos = search_text.rfind("},")
|
||||
if pos != -1:
|
||||
return pos + 2 # Include the },
|
||||
return 0
|
||||
|
||||
|
||||
class RepetitionDetector:
|
||||
"""Detect repetition patterns in streaming text."""
|
||||
|
||||
def __init__(self,
|
||||
min_pattern_len: int = 10, # Minimum chars for a pattern
|
||||
min_repeats: int = 3, # Minimum repetitions to trigger
|
||||
window_size: int = 500): # Window to check for patterns
|
||||
self.min_pattern_len = min_pattern_len
|
||||
self.min_repeats = min_repeats
|
||||
self.window_size = window_size
|
||||
self.text = ""
|
||||
|
||||
def add_text(self, new_text: str):
|
||||
"""Add new text and return (is_looping, good_text_end_pos)."""
|
||||
self.text += new_text
|
||||
return self._check_repetition()
|
||||
|
||||
def _check_repetition(self):
|
||||
"""Check if the recent text contains repetition loops."""
|
||||
if len(self.text) < self.min_pattern_len * self.min_repeats:
|
||||
return False, len(self.text)
|
||||
|
||||
# Check the recent window
|
||||
window = self.text[-self.window_size:] if len(self.text) > self.window_size else self.text
|
||||
|
||||
# Method 1: Check for repeated substrings
|
||||
for pattern_len in range(self.min_pattern_len, len(window) // self.min_repeats + 1):
|
||||
# Get the last pattern_len characters as potential pattern
|
||||
pattern = window[-pattern_len:]
|
||||
|
||||
# Count how many times this pattern appears at the end
|
||||
count = 0
|
||||
pos = len(window)
|
||||
while pos >= pattern_len:
|
||||
if window[pos - pattern_len:pos] == pattern:
|
||||
count += 1
|
||||
pos -= pattern_len
|
||||
else:
|
||||
break
|
||||
|
||||
if count >= self.min_repeats:
|
||||
# Found repetition! Calculate where the good text ends
|
||||
repetition_start = len(self.text) - (count * pattern_len)
|
||||
# Keep one instance of the pattern (or none if it's garbage)
|
||||
good_end = repetition_start + pattern_len if self._is_meaningful(pattern) else repetition_start
|
||||
return True, good_end
|
||||
|
||||
# Method 2: Check for repeated short phrases (like "you're not, you're not")
|
||||
# Look for patterns like "X, X, X" or "X X X"
|
||||
words = window.split()
|
||||
if len(words) >= self.min_repeats * 2:
|
||||
# Check last N words for repetition
|
||||
for phrase_len in range(2, 6): # 2-5 word phrases
|
||||
if len(words) < phrase_len * self.min_repeats:
|
||||
continue
|
||||
|
||||
phrase = " ".join(words[-phrase_len:])
|
||||
count = 0
|
||||
idx = len(words)
|
||||
while idx >= phrase_len:
|
||||
candidate = " ".join(words[idx - phrase_len:idx])
|
||||
if candidate == phrase:
|
||||
count += 1
|
||||
idx -= phrase_len
|
||||
else:
|
||||
break
|
||||
|
||||
if count >= self.min_repeats:
|
||||
# Calculate position in original text
|
||||
repeated_text = (phrase + " ") * count
|
||||
good_end = len(self.text) - len(repeated_text.rstrip()) + len(phrase)
|
||||
return True, max(0, good_end)
|
||||
|
||||
return False, len(self.text)
|
||||
|
||||
def _is_meaningful(self, pattern: str) -> bool:
|
||||
"""Check if pattern is meaningful content (not just garbage)."""
|
||||
# Filter out patterns that are just punctuation, spaces, or very repetitive
|
||||
clean = pattern.strip()
|
||||
if not clean:
|
||||
return False
|
||||
if len(set(clean)) < 3: # Too few unique characters
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_good_text(self, end_pos: int) -> str:
|
||||
"""Get text up to the specified position."""
|
||||
return self.text[:end_pos]
|
||||
|
||||
def reset(self, keep_text: str = ""):
|
||||
"""Reset detector, optionally keeping some text."""
|
||||
self.text = keep_text
|
||||
|
||||
|
||||
def stream_with_recovery(
|
||||
url: str,
|
||||
base_messages: list,
|
||||
audio_data_url: str,
|
||||
prompt_text: str,
|
||||
max_tokens: int = 32768,
|
||||
max_retries: int = 3,
|
||||
timeout: int = 12000,
|
||||
debug: bool = False,
|
||||
):
|
||||
"""
|
||||
Stream transcription with automatic recovery from repetition loops.
|
||||
|
||||
Args:
|
||||
url: API endpoint
|
||||
base_messages: Base messages (system + user with audio)
|
||||
audio_data_url: The audio data URL for the request
|
||||
prompt_text: The text prompt
|
||||
max_tokens: Maximum tokens to generate
|
||||
max_retries: Maximum recovery attempts (default 3)
|
||||
timeout: Request timeout
|
||||
debug: If True, show recovery debug info to stderr
|
||||
|
||||
Recovery strategy:
|
||||
- First attempt: temperature=0.0, top_p=1.0 (greedy)
|
||||
- Recovery: temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
|
||||
- If has complete segments: use assistant prefix
|
||||
- If no complete segments: restart from scratch
|
||||
- Max 3 retries, if all fail output error message
|
||||
|
||||
Returns:
|
||||
Final transcription text
|
||||
"""
|
||||
import sys as _sys
|
||||
|
||||
def _log(msg):
|
||||
"""Log to stderr only if debug."""
|
||||
if debug:
|
||||
print(msg, file=_sys.stderr)
|
||||
|
||||
detector = RepetitionDetector(
|
||||
min_pattern_len=10, # At least 10 chars for a pattern
|
||||
min_repeats=10, # Must repeat 10+ times
|
||||
window_size=400, # Check last 400 chars (can detect 10-40 char patterns repeated 10 times)
|
||||
)
|
||||
|
||||
accumulated_text = ""
|
||||
retry_count = 0
|
||||
user_safe_printed_len = 0 # Track how much we've safely shown to user (at segment boundaries)
|
||||
is_recovery = False # Whether we're in recovery mode
|
||||
|
||||
while retry_count <= max_retries:
|
||||
# Build request payload
|
||||
messages = list(base_messages) # Copy base messages
|
||||
|
||||
# If we have accumulated text from previous attempt, add it as partial assistant response
|
||||
if accumulated_text:
|
||||
# Add the good content as a partial assistant message
|
||||
# vLLM will continue from here
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": accumulated_text
|
||||
})
|
||||
|
||||
# Set sampling parameters based on recovery state
|
||||
if is_recovery:
|
||||
# Recovery: increase temperature each retry to break loops
|
||||
recovery_temp = 0.1 + 0.1 * retry_count # 0.2, 0.3, 0.4 for retry 1, 2, 3
|
||||
payload = {
|
||||
"model": "vibevoice",
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": recovery_temp,
|
||||
"top_p": 0.95,
|
||||
"stream": True,
|
||||
}
|
||||
if accumulated_text:
|
||||
_log(f"[RECOVERY #{retry_count}] Continuing from {len(accumulated_text)} chars with temp={recovery_temp}, top_p=0.95")
|
||||
else:
|
||||
_log(f"[RECOVERY #{retry_count}] Restarting from scratch with temp={recovery_temp}, top_p=0.95")
|
||||
else:
|
||||
# First attempt: greedy decoding
|
||||
payload = {
|
||||
"model": "vibevoice",
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, stream=True, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
_log(f"[ERROR] {response.status_code} - {response.text[:500]}")
|
||||
return accumulated_text
|
||||
|
||||
new_text = ""
|
||||
printed = "" # Track what we've already received to handle vLLM duplicates
|
||||
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
|
||||
decoded_line = line.decode('utf-8')
|
||||
if not decoded_line.startswith("data: "):
|
||||
continue
|
||||
|
||||
json_str = decoded_line[6:]
|
||||
if json_str.strip() == "[DONE]":
|
||||
# Successfully finished without loops
|
||||
full_result = accumulated_text + new_text
|
||||
# Print any remaining content that wasn't printed yet
|
||||
if len(full_result) > user_safe_printed_len:
|
||||
remaining = full_result[user_safe_printed_len:]
|
||||
print(remaining, end='', flush=True)
|
||||
print() # Final newline
|
||||
return full_result
|
||||
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
delta = data['choices'][0].get('delta', {})
|
||||
content = delta.get('content', '')
|
||||
|
||||
if content:
|
||||
# vLLM/OpenAI-compatible streams may emit either
|
||||
# incremental deltas OR the full accumulated text.
|
||||
# Only track the newly-added part.
|
||||
if content.startswith(printed):
|
||||
to_add = content[len(printed):]
|
||||
else:
|
||||
to_add = content
|
||||
|
||||
if to_add:
|
||||
printed += to_add
|
||||
new_text += to_add
|
||||
|
||||
# When continuing from prefix, model may add "[" or "[{" at start
|
||||
# or repeat the ending "}, " from prefix
|
||||
# We need to handle these to maintain valid JSON array format
|
||||
if accumulated_text and new_text:
|
||||
stripped = new_text.lstrip()
|
||||
# Case 1: Model added "[{" - remove the "["
|
||||
if stripped.startswith("[{"):
|
||||
new_text = stripped[1:]
|
||||
_log("[STRIPPED leading '[' from continuation]")
|
||||
# Case 2: Model added just "[" - remove it
|
||||
elif stripped.startswith("["):
|
||||
new_text = stripped[1:]
|
||||
_log("[STRIPPED leading '[' from continuation]")
|
||||
# Case 3: Model repeated "}," from prefix ending
|
||||
elif stripped.startswith("},"):
|
||||
new_text = stripped[2:]
|
||||
_log("[STRIPPED leading '},' from continuation]")
|
||||
# Case 4: Model repeated "}" from prefix ending
|
||||
elif stripped.startswith("}") and not stripped.startswith("}]"):
|
||||
new_text = stripped[1:]
|
||||
_log("[STRIPPED leading '}' from continuation]")
|
||||
|
||||
# Fix malformed JSON: {"2.99,... -> {"Start":2.99,...
|
||||
# This happens when model skips "Start": key
|
||||
import re
|
||||
malformed = re.match(r'^\{"(\d+\.?\d*),', new_text)
|
||||
if malformed:
|
||||
time_val = malformed.group(1)
|
||||
new_text = '{"Start":' + time_val + ',' + new_text[malformed.end():]
|
||||
_log(f"[FIXED malformed JSON: added Start key]")
|
||||
|
||||
# Check for repetition in the combined text
|
||||
full_text = accumulated_text + new_text
|
||||
detector.text = full_text
|
||||
is_looping, good_end = detector._check_repetition()
|
||||
|
||||
if is_looping:
|
||||
_log(f"[LOOP DETECTED at char {good_end}]")
|
||||
|
||||
# Use what user has already seen as prefix for retry
|
||||
# user_safe_printed_len is always at a segment boundary
|
||||
if user_safe_printed_len > 0:
|
||||
accumulated_text = full_text[:user_safe_printed_len]
|
||||
_log(f"[RETRY from user-visible content at {user_safe_printed_len}]")
|
||||
else:
|
||||
# No complete segment shown to user yet - restart from scratch
|
||||
accumulated_text = ""
|
||||
_log(f"[NO CONTENT SHOWN TO USER - restart from scratch]")
|
||||
|
||||
detector.reset(accumulated_text)
|
||||
is_recovery = True
|
||||
|
||||
if debug:
|
||||
print("\n[...recovering...]", end='', flush=True, file=sys.stderr)
|
||||
|
||||
retry_count += 1
|
||||
|
||||
if retry_count > max_retries:
|
||||
_log(f"[MAX RETRIES REACHED]")
|
||||
print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
|
||||
return None
|
||||
|
||||
# Break inner loop to retry
|
||||
break
|
||||
else:
|
||||
# No loop detected - stream content to user
|
||||
# Only print up to (full_text_len - window_size) at segment boundaries
|
||||
# This ensures user never sees content that might be rolled back
|
||||
safe_end = max(0, len(full_text) - detector.window_size)
|
||||
safe_boundary = _find_safe_print_boundary(full_text, safe_end)
|
||||
|
||||
if safe_boundary > user_safe_printed_len:
|
||||
# Print new safe content
|
||||
to_print = full_text[user_safe_printed_len:safe_boundary]
|
||||
print(to_print, end='', flush=True)
|
||||
user_safe_printed_len = safe_boundary
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
else:
|
||||
# Loop completed without break (no repetition detected)
|
||||
full_result = accumulated_text + new_text
|
||||
|
||||
# Print any remaining content that wasn't printed yet
|
||||
if len(full_result) > user_safe_printed_len:
|
||||
remaining = full_result[user_safe_printed_len:]
|
||||
print(remaining, end='', flush=True)
|
||||
|
||||
print() # Final newline
|
||||
return full_result
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
_log("[TIMEOUT]")
|
||||
print()
|
||||
return accumulated_text
|
||||
except Exception as e:
|
||||
_log(f"[ERROR: {e}]")
|
||||
print()
|
||||
return accumulated_text
|
||||
|
||||
# All retries exhausted
|
||||
print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
|
||||
return None
|
||||
|
||||
|
||||
def test_transcription_with_recovery(
|
||||
audio_path: str,
|
||||
output_path: str = None,
|
||||
base_url: str = "http://localhost:8000",
|
||||
hotwords: str = None,
|
||||
debug: bool = False,
|
||||
):
|
||||
"""
|
||||
Test ASR transcription with auto-recovery from repetition loops.
|
||||
|
||||
Args:
|
||||
audio_path: Path to the audio file
|
||||
output_path: Optional path to save transcription result
|
||||
base_url: vLLM server URL
|
||||
hotwords: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
|
||||
debug: Show recovery debug info
|
||||
"""
|
||||
|
||||
print(f"=" * 70)
|
||||
print(f"Testing with Auto-Recovery")
|
||||
print(f"=" * 70)
|
||||
print(f"Input file: {audio_path}")
|
||||
print(f"Hotwords: {hotwords or '(none)'}")
|
||||
print()
|
||||
|
||||
# Handle video files: extract audio first
|
||||
temp_audio_path = None
|
||||
actual_audio_path = audio_path
|
||||
if _is_video_file(audio_path):
|
||||
print(f"🎬 Detected video file, extracting audio...")
|
||||
temp_audio_path = _extract_audio_from_video(audio_path)
|
||||
actual_audio_path = temp_audio_path
|
||||
print(f"✅ Audio extracted to: {temp_audio_path}")
|
||||
|
||||
# Load audio
|
||||
try:
|
||||
duration = _get_duration_seconds_ffprobe(actual_audio_path)
|
||||
print(f"Audio duration: {duration:.2f} seconds")
|
||||
|
||||
with open(actual_audio_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
|
||||
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
|
||||
print(f"Audio size: {len(audio_bytes)} bytes")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error preparing audio: {e}")
|
||||
# Cleanup temp file if created
|
||||
if temp_audio_path and os.path.exists(temp_audio_path):
|
||||
os.remove(temp_audio_path)
|
||||
return
|
||||
|
||||
url = f"{base_url}/v1/chat/completions"
|
||||
|
||||
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
|
||||
|
||||
# Build prompt with optional hotwords
|
||||
if hotwords and hotwords.strip():
|
||||
prompt_text = (
|
||||
f"This is a {duration:.2f} seconds audio, with extra info: {hotwords.strip()}\n\n"
|
||||
f"Please transcribe it with these keys: " + ", ".join(show_keys)
|
||||
)
|
||||
print(f"\n📝 Hotwords embedded in prompt: '{hotwords}'")
|
||||
else:
|
||||
prompt_text = (
|
||||
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
||||
+ ", ".join(show_keys)
|
||||
)
|
||||
print(f"\n📝 No hotwords provided")
|
||||
|
||||
mime = _guess_mime_type(actual_audio_path)
|
||||
data_url = f"data:{mime};base64,{audio_b64}"
|
||||
|
||||
# Base messages (without assistant continuation)
|
||||
base_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": {"url": data_url}},
|
||||
{"type": "text", "text": prompt_text}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f"Sending request to {url}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
t0 = time.time()
|
||||
print("\n✅ Response received. Streaming content:\n")
|
||||
print("-" * 50)
|
||||
|
||||
result = stream_with_recovery(
|
||||
url=url,
|
||||
base_messages=base_messages,
|
||||
audio_data_url=data_url,
|
||||
prompt_text=prompt_text,
|
||||
max_tokens=32768,
|
||||
max_retries=3,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
elapsed = time.time() - t0
|
||||
print("-" * 50)
|
||||
print("✅ [Finished]")
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
if result is None:
|
||||
print("❌ Transcription failed")
|
||||
return
|
||||
|
||||
print(f"📄 Final output length: {len(result)} chars")
|
||||
|
||||
# Optionally save result
|
||||
if output_path:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(result)
|
||||
print(f"💾 Result saved to: {output_path}")
|
||||
|
||||
# Cleanup temp audio file if created
|
||||
if temp_audio_path and os.path.exists(temp_audio_path):
|
||||
os.remove(temp_audio_path)
|
||||
print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test VibeVoice vLLM API with auto-recovery from repetition loops"
|
||||
)
|
||||
parser.add_argument(
|
||||
"audio_path",
|
||||
help="Path to audio file (wav, mp3, flac, etc.) or video file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_path",
|
||||
nargs="?",
|
||||
default=None,
|
||||
help="Optional path to save transcription result"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
default="http://localhost:8000",
|
||||
help="vLLM server URL (default: http://localhost:8000)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hotwords",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Show recovery debug info"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run test
|
||||
test_transcription_with_recovery(
|
||||
audio_path=args.audio_path,
|
||||
output_path=args.output_path,
|
||||
base_url=args.url,
|
||||
hotwords=args.hotwords,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user