Merge pull request #233 from Damon-Salvetore/main

Add hot words support
This commit is contained in:
Jianwei Yu
2026-02-07 12:32:03 +08:00
committed by GitHub
7 changed files with 291 additions and 323 deletions
+13 -3
View File
@@ -52,15 +52,25 @@ docker logs -f vibevoice-vllm
Once the vLLM server is running, test it with the provided script: Once the vLLM server is running, test it with the provided script:
```bash ```bash
# Run the test (use container path /app/...) # Basic transcription
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav
# With hotwords for better recognition of specific terms
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
``` ```
```bash ```bash
# Run the recover_test (use container path /app/...) # With auto-recovery from repetition loops (for long audio)
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav
# Auto-recover with hotwords
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
``` ```
> **Note**: The audio file must be inside the mounted directory (`/app` in the container). Copy your audio to the VibeVoice folder before testing.
> **Note**:
> - The audio/video file must be inside the mounted directory (`/app` in the container). Copy your files to the VibeVoice folder before testing.
> - Hotwords help improve recognition of domain-specific terms like proper nouns, technical terms, and speaker names.
### Environment Variables ### Environment Variables
-1
View File
@@ -15,7 +15,6 @@ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
from .model import VibeVoiceForCausalLM from .model import VibeVoiceForCausalLM
from .inputs import vibevoice_audio_input_mapper
def register_vibevoice(): def register_vibevoice():
+13 -160
View File
@@ -5,17 +5,11 @@ This module implements the VibeVoice ASR model with full vLLM multimodal registr
integration for speech-to-text inference. integration for speech-to-text inference.
""" """
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence
import json
import math
import os import os
import sys
from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
from io import BytesIO
import tempfile
import base64 import base64
@@ -29,32 +23,12 @@ import base64
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
def _suffix_from_media_type(media_type: str | None) -> str: def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]:
if not media_type: """Load audio bytes using FFmpeg via stdin-pipe decoding.
return ".bin"
mt = media_type.lower().strip()
if mt in ("audio/wav", "audio/x-wav", "audio/wave"):
return ".wav"
if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"):
return ".mp3"
if mt in ("audio/flac",):
return ".flac"
if mt in ("audio/ogg", "audio/opus"):
return ".ogg"
if mt in ("audio/mp4", "audio/m4a"):
return ".m4a"
if mt in ("video/mp4",):
return ".mp4"
return ".bin"
def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]:
"""Load audio bytes using FFmpeg.
Returns: Returns:
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000. Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
""" """
# Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency.
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000) audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
normalizer = AudioNormalizer() normalizer = AudioNormalizer()
audio = normalizer(audio) audio = normalizer(audio)
@@ -79,10 +53,10 @@ class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
"""AudioMediaIO implementation using FFmpeg for audio decoding.""" """AudioMediaIO implementation using FFmpeg for audio decoding."""
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]: def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(data, media_type=None) return _ffmpeg_load_bytes(data)
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]: def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type) return _ffmpeg_load_bytes(base64.b64decode(data))
def load_file(self, filepath) -> tuple[np.ndarray, int]: def load_file(self, filepath) -> tuple[np.ndarray, int]:
return _ffmpeg_load_file(filepath) return _ffmpeg_load_file(filepath)
@@ -96,17 +70,13 @@ _vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
# ============================================================================ # ============================================================================
from transformers import Qwen2Config, BatchFeature from transformers import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.config import VllmConfig, ModelConfig
from vllm.config.speech_to_text import SpeechToTextConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.parse import MultiModalDataParser
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
from vllm.inputs import PromptType
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
@@ -558,7 +528,7 @@ class VibeVoiceProcessingInfo(BaseProcessingInfo):
return tokens return tokens
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None} return {"audio": 1}
class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]): class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]):
@@ -873,17 +843,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
with a causal language model for text generation. with a causal language model for text generation.
""" """
# SupportsTranscription interface
supports_transcription: ClassVar[Literal[True]] = True
supports_transcription_only: ClassVar[bool] = False
supports_segment_timestamp: ClassVar[bool] = False
# Supported languages (Chinese as primary target)
supported_languages: ClassVar[Mapping[str, str]] = {
"zh": "Chinese",
"en": "English",
}
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""Return the placeholder string format for a given modality. """Return the placeholder string format for a given modality.
@@ -897,113 +856,11 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return "<|AUDIO|>" return "<|AUDIO|>"
raise ValueError(f"Unsupported modality: {modality}") raise ValueError(f"Unsupported modality: {modality}")
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
"""Get the prompt for the ASR model.
Generates a chat-formatted prompt for speech-to-text transcription
with JSON output format.
"""
# If user provides custom prompt, use it
if request_prompt:
return request_prompt
# Calculate audio duration for the prompt
# Audio should be at 24kHz, so duration = len(audio) / 24000
duration = len(audio) / 24000 if audio is not None else 10.0
system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format."
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
user_suffix = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys)
)
# IMPORTANT: keep <|AUDIO|> as the only placeholder token here.
# `_get_prompt_updates` expands it into repeated `<|AUDIO|>` placeholders.
user_content = "<|AUDIO|>\n" + user_suffix
prompt = (
f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
f"<|im_start|>user\n{user_content}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return prompt
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
) -> SpeechToTextConfig:
"""Get the speech to text config for the ASR model."""
return SpeechToTextConfig(
language=None, # Auto-detect or use request language
task_type=task_type,
)
@classmethod
def get_num_audio_tokens(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> int | None:
"""Estimate number of audio tokens from duration.
Returns the number of audio EMBEDDING positions (speech_pad_id tokens).
Note: _get_prompt_updates actually generates:
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
So total prompt tokens = N + 3, but this returns N (the embedding count).
"""
sampling_rate = 24000
compress_ratio = 3200
samples = int(audio_duration_s * sampling_rate)
num_tokens = int(np.ceil(samples / compress_ratio))
return num_tokens
@classmethod
def get_other_languages(cls) -> Mapping[str, str]:
"""Get languages from Whisper map not natively supported."""
# Import LANGUAGES from vllm
try:
from vllm.transformers_utils.tokenizer import LANGUAGES
except ImportError:
# Fallback to empty dict if import fails
return {}
return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages}
@classmethod
def validate_language(cls, language: str | None) -> str | None:
"""Validate the language code."""
if language is None or language in cls.supported_languages:
return language
elif language in cls.get_other_languages():
print(f"Warning: Language {language!r} is not natively supported")
return language
else:
raise ValueError(
f"Unsupported language: {language!r}. "
f"Supported: {list(cls.supported_languages.keys())}"
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
self.config = config self.config = config
# Keep a copy of the resolved model path for any custom weight-loading
# logic (e.g., loading audio encoder weights in fp32 directly from
# safetensors shards).
self._model_path = vllm_config.model_config.model
self.audio_encoder = VibeVoiceAudioEncoder(config) self.audio_encoder = VibeVoiceAudioEncoder(config)
# Pass decoder_config to the language model initialization # Pass decoder_config to the language model initialization
@@ -1100,14 +957,14 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# Process each audio through the VibeVoice encoder # Process each audio through the VibeVoice encoder
embeddings = [] embeddings = []
# Get model device and dtype for alignment # Get model device for tensor placement.
# dtype is NOT set here — audio_encoder.forward() handles it internally:
# input: converted to fp32 (self._audio_encoder_dtype)
# output: converted to bfloat16 (self._lm_dtype)
try: try:
device = next(self.audio_encoder.parameters()).device device = next(self.audio_encoder.parameters()).device
dtype = next(self.audio_encoder.parameters()).dtype
except StopIteration: except StopIteration:
# Fallback if no parameters (shouldn't happen)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
# Handle both stacked tensor and list of tensors # Handle both stacked tensor and list of tensors
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len] # vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
@@ -1138,7 +995,7 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if not isinstance(audio_tensor, torch.Tensor): if not isinstance(audio_tensor, torch.Tensor):
audio_tensor = torch.tensor(audio_tensor) audio_tensor = torch.tensor(audio_tensor)
# Let vLLM handle dtype (bfloat16 by default) # Only place on correct device; audio_encoder.forward() handles dtype
audio_tensor = audio_tensor.to(device=device) audio_tensor = audio_tensor.to(device=device)
# Get actual length if available, otherwise use full length # Get actual length if available, otherwise use full length
@@ -1294,7 +1151,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Intermediate tensors for pipeline parallelism. intermediate_tensors: Intermediate tensors for pipeline parallelism.
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode). inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
""" """
try:
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode) # PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
# Only compute from input_ids if inputs_embeds is not available # Only compute from input_ids if inputs_embeds is not available
if inputs_embeds is None and input_ids is not None: if inputs_embeds is None and input_ids is not None:
@@ -1321,9 +1177,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
) )
return hidden_states return hidden_states
except Exception as e:
raise
# Alias for training checkpoint compatibility # Alias for training checkpoint compatibility
VibeVoiceForASRTraining = VibeVoiceForCausalLM VibeVoiceForASRTraining = VibeVoiceForCausalLM
Binary file not shown.
+101 -85
View File
@@ -1,14 +1,23 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Test VibeVoice vLLM API with Streaming (Real-time output). Test VibeVoice vLLM API with Streaming and Optional Hotwords Support.
This script tests ASR transcription via the vLLM OpenAI-compatible API.
By default, it runs standard transcription without hotwords.
Optionally, you can provide hotwords (context_info) to improve recognition
of domain-specific content like proper nouns, technical terms, and speaker names.
Hotwords are embedded in the prompt as "with extra info: {hotwords}".
Usage: Usage:
python test_api.py [audio_path] [--url URL] python test_api_with_hotwords.py [audio_path] [--url URL] [--hotwords "word1,word2"]
Examples: Examples:
python test_api.py # Use default audio # Standard transcription (no hotwords)
python test_api.py /path/to/audio.wav # Specify audio file python3 test_api.py audio.wav
python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL
# With hotwords for better recognition of specific terms
python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice"
""" """
import requests import requests
import json import json
@@ -21,38 +30,38 @@ import argparse
def _guess_mime_type(path: str) -> str: def _guess_mime_type(path: str) -> str:
"""Guess MIME type from file extension."""
ext = os.path.splitext(path)[1].lower() ext = os.path.splitext(path)[1].lower()
if ext == ".wav": mime_map = {
return "audio/wav" ".wav": "audio/wav",
if ext in (".mp3",): ".mp3": "audio/mpeg",
return "audio/mpeg" ".m4a": "audio/mp4",
if ext in (".m4a",): ".mp4": "video/mp4",
return "audio/mp4" ".flac": "audio/flac",
if ext in (".mp4", ".m4v", ".mov", ".webm"): ".ogg": "audio/ogg",
return "video/mp4" ".opus": "audio/ogg",
if ext in (".flac",): }
return "audio/flac" return mime_map.get(ext, "application/octet-stream")
if ext in (".ogg", ".opus"):
return "audio/ogg"
return "application/octet-stream"
def _get_duration_seconds_ffprobe(path: str) -> float: def _get_duration_seconds_ffprobe(path: str) -> float:
"""Get audio duration using ffprobe.""" """Get audio duration using ffprobe."""
cmd = [ cmd = [
"ffprobe", "ffprobe", "-v", "error",
"-v", "-show_entries", "format=duration",
"error", "-of", "default=noprint_wrappers=1:nokey=1",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
path, path,
] ]
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip() out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
return float(out) return float(out)
def _is_video_file(path: str) -> bool:
"""Check if the file is a video file that needs audio extraction."""
ext = os.path.splitext(path)[1].lower()
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
def _extract_audio_from_video(video_path: str) -> str: def _extract_audio_from_video(video_path: str) -> str:
""" """
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file. Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
@@ -74,26 +83,40 @@ def _extract_audio_from_video(video_path: str) -> str:
return audio_path return audio_path
def _is_video_file(path: str) -> bool: def test_transcription_with_hotwords(
"""Check if the file is a video file that needs audio extraction.""" audio_path: str,
ext = os.path.splitext(path)[1].lower() context_info: str = None,
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv") base_url: str = "http://localhost:8000",
):
"""
Test ASR transcription with customized hotwords.
Hotwords are embedded in the prompt text as "with extra info: {hotwords}".
This helps the model recognize domain-specific terms more accurately.
def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"): Args:
"""Test ASR transcription with streaming output.""" audio_path: Path to the audio file
context_info: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
base_url: vLLM server URL
"""
print(f"Loading audio from: {audio_path}") print(f"=" * 70)
print(f"Testing Customized Hotwords Support")
print(f"=" * 70)
print(f"Input file: {audio_path}")
print(f"Hotwords: {context_info or '(none)'}")
print()
# Handle video files: extract audio first # Handle video files: extract audio first
temp_audio_path = None temp_audio_path = None
actual_audio_path = audio_path actual_audio_path = audio_path
if _is_video_file(audio_path): if _is_video_file(audio_path):
print(f"Detected video file, extracting audio...") print(f"🎬 Detected video file, extracting audio...")
temp_audio_path = _extract_audio_from_video(audio_path) temp_audio_path = _extract_audio_from_video(audio_path)
actual_audio_path = temp_audio_path actual_audio_path = temp_audio_path
print(f"Audio extracted to: {temp_audio_path}") print(f"Audio extracted to: {temp_audio_path}")
# Load audio
try: try:
duration = _get_duration_seconds_ffprobe(actual_audio_path) duration = _get_duration_seconds_ffprobe(actual_audio_path)
print(f"Audio duration: {duration:.2f} seconds") print(f"Audio duration: {duration:.2f} seconds")
@@ -106,16 +129,30 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
except Exception as e: except Exception as e:
print(f"Error preparing audio: {e}") print(f"Error preparing audio: {e}")
# Cleanup temp file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
return return
# Build the request # Build the request
url = f"{base_url}/v1/chat/completions" url = f"{base_url}/v1/chat/completions"
show_keys = ["Start time", "End time", "Speaker ID", "Content"] show_keys = ["Start time", "End time", "Speaker ID", "Content"]
# Build prompt with optional hotwords
# Hotwords are embedded as "with extra info: {hotwords}" in the prompt
if context_info and context_info.strip():
prompt_text = (
f"This is a {duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\n"
f"Please transcribe it with these keys: " + ", ".join(show_keys)
)
print(f"\n📝 Hotwords embedded in prompt: '{context_info}'")
else:
prompt_text = ( prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys) + ", ".join(show_keys)
) )
print(f"\n📝 No hotwords provided")
mime = _guess_mime_type(actual_audio_path) mime = _guess_mime_type(actual_audio_path)
data_url = f"data:{mime};base64,{audio_b64}" data_url = f"data:{mime};base64,{audio_b64}"
@@ -139,20 +176,19 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
"temperature": 0.0, "temperature": 0.0,
"stream": True, "stream": True,
"top_p": 1.0, "top_p": 1.0,
"repetition_penalty": 1.0,
} }
print(f"\nSending request to {url} (Streaming Mode)...") print(f"\n{'=' * 70}")
print(f"Prompt: {prompt_text}") print(f"Sending request to {url}")
print("-" * 60) print(f"{'=' * 70}")
t0 = time.time() t0 = time.time()
try: try:
response = requests.post(url, json=payload, stream=True, timeout=12000) response = requests.post(url, json=payload, stream=True, timeout=12000)
if response.status_code == 200: if response.status_code == 200:
print("Response received. Streaming content:\n") print("\nResponse received. Streaming content:\n")
print("-" * 50)
printed = "" printed = ""
for line in response.iter_lines(): for line in response.iter_lines():
@@ -162,92 +198,72 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
if decoded_line.startswith("data: "): if decoded_line.startswith("data: "):
json_str = decoded_line[6:] json_str = decoded_line[6:]
if json_str.strip() == "[DONE]": if json_str.strip() == "[DONE]":
print("\n\n[Finished]") print("\n" + "-" * 50)
print("✅ [Finished]")
break break
try: try:
data = json.loads(json_str) data = json.loads(json_str)
delta = data['choices'][0]['delta'] delta = data['choices'][0]['delta']
content = delta.get('content', '') content = delta.get('content', '')
if content: if content:
# vLLM/OpenAI-compatible streams may emit either
# incremental deltas OR the full accumulated text.
# Only print the newly-added part to avoid repeats.
if content.startswith(printed): if content.startswith(printed):
to_print = content[len(printed):] to_print = content[len(printed):]
else: else:
to_print = content to_print = content
if to_print: if to_print:
print(to_print, end='', flush=True) print(to_print, end='', flush=True)
printed += to_print printed += to_print
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
else: else:
print(f"Error: {response.status_code}") print(f"Error: {response.status_code}")
print(response.text) print(response.text)
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
print("\nRequest timed out!") print("Request timed out!")
except Exception as e: except Exception as e:
print(f"\nError: {e}") print(f"Error: {e}")
print(f"\n{'-'*60}") elapsed = time.time() - t0
print(f"Total time elapsed: {time.time() - t0:.2f}s") print(f"\n{'=' * 70}")
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
print(f"📊 RTF (Real-Time Factor): {elapsed / duration:.2f}x")
print(f"{'=' * 70}")
# Cleanup temp audio file if created # Cleanup temp audio file if created
if temp_audio_path and os.path.exists(temp_audio_path): if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path) os.remove(temp_audio_path)
print(f"Cleaned up temp file: {temp_audio_path}") print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Test VibeVoice vLLM API with streaming output" description="Test VibeVoice vLLM API with Customized Hotwords"
) )
parser.add_argument( parser.add_argument(
"audio_path", "audio_path",
nargs="?",
default=None,
help="Path to audio file (wav, mp3, flac, etc.) or video file" help="Path to audio file (wav, mp3, flac, etc.) or video file"
) )
parser.add_argument( parser.add_argument(
"--url", "--url",
default="http://localhost:8000", default="http://localhost:8000",
help="vLLM server base URL (default: http://localhost:8000)" help="vLLM server URL (default: http://localhost:8000)"
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
) )
args = parser.parse_args() args = parser.parse_args()
# Find default audio if not specified # Run test
audio_path = args.audio_path test_transcription_with_hotwords(
if audio_path is None: audio_path=args.audio_path,
# Try to find a sample audio in common locations context_info=args.hotwords,
possible_paths = [ base_url=args.url,
# In VibeVoice demo folder )
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "en-Carter_man.wav"),
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"),
# Relative to current directory
"demo/voices/en-Carter_man.wav",
"demo/voices/zh-Anchen_man_bgm.wav",
]
for path in possible_paths:
if os.path.exists(path):
audio_path = path
break
if audio_path is None:
print("Error: No audio file specified and no default audio found.")
print("Usage: python test_api.py <audio_path>")
sys.exit(1)
if not os.path.exists(audio_path):
print(f"Error: Audio file not found: {audio_path}")
sys.exit(1)
test_transcription(audio_path, args.url)
if __name__ == "__main__": if __name__ == "__main__":
+129 -39
View File
@@ -1,18 +1,36 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
VibeVoice vLLM API with Auto-Recovery from Repetition Loops. Test VibeVoice vLLM API with Streaming, Hotwords, and Auto-Recovery.
Strategy: This script tests ASR transcription with automatic recovery from repetition loops.
1. Start with greedy decoding (temperature=0, top_p=1.0) Supports optional hotwords to improve recognition of domain-specific terms.
2. Stream and detect repetition patterns in real-time
3. Only output content up to (current_length - window_size) at segment boundaries
4. When loop detected:
- Truncate to last complete segment boundary (},)
- Recovery with temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
5. Max 3 retries, if all fail output error message
User sees: clean streaming transcription output (only complete segments) Features:
Internal: automatic recovery from repetition loops (silent) - Streaming output with real-time repetition detection
- Auto-recovery when model enters repetition loops
- Optional hotwords support (embedded in prompt as "with extra info: {hotwords}")
- Video file support (auto-extracts audio)
Recovery Strategy:
1. First attempt: greedy decoding (temperature=0, top_p=1.0)
2. If loop detected: retry with temperature=0.2/0.3/0.4, top_p=0.95
3. Max 3 retries, truncate to last complete segment boundary
Usage:
python test_api_auto_recover.py <audio_path> [output_path] [--url URL] [--hotwords "word1,word2"] [--debug]
Examples:
# Basic usage
python3 test_api_auto_recover.py audio.wav
# With hotwords
python3 test_api_auto_recover.py audio.wav --hotwords "Microsoft,VibeVoice"
# Save result to file
python3 test_api_auto_recover.py audio.wav result.txt
# Debug mode (show recovery info)
python3 test_api_auto_recover.py audio.wav --debug
""" """
import requests import requests
import json import json
@@ -22,6 +40,7 @@ import sys
import os import os
import subprocess import subprocess
import re import re
import argparse
from collections import Counter from collections import Counter
@@ -441,30 +460,41 @@ def stream_with_recovery(
return None return None
def test_transcription_with_recovery(): def test_transcription_with_recovery(
"""Main test function with auto-recovery.""" audio_path: str,
output_path: str = None,
base_url: str = "http://localhost:8000",
hotwords: str = None,
debug: bool = False,
):
"""
Test ASR transcription with auto-recovery from repetition loops.
# Parse arguments Args:
debug = "--debug" in sys.argv or "-debug" in sys.argv audio_path: Path to the audio file
args = [a for a in sys.argv[1:] if not a.startswith("-")] output_path: Optional path to save transcription result
base_url: vLLM server URL
hotwords: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
debug: Show recovery debug info
"""
audio_path = ( print(f"=" * 70)
args[0] print(f"Testing with Auto-Recovery")
) print(f"=" * 70)
print(f"Input file: {audio_path}")
output_path = args[1] if len(args) > 1 else None print(f"Hotwords: {hotwords or '(none)'}")
print()
print(f"Loading audio from: {audio_path}")
# Handle video files: extract audio first # Handle video files: extract audio first
temp_audio_path = None temp_audio_path = None
actual_audio_path = audio_path actual_audio_path = audio_path
if _is_video_file(audio_path): if _is_video_file(audio_path):
print(f"Detected video file, extracting audio...") print(f"🎬 Detected video file, extracting audio...")
temp_audio_path = _extract_audio_from_video(audio_path) temp_audio_path = _extract_audio_from_video(audio_path)
actual_audio_path = temp_audio_path actual_audio_path = temp_audio_path
print(f"Audio extracted to: {temp_audio_path}") print(f"Audio extracted to: {temp_audio_path}")
# Load audio
try: try:
duration = _get_duration_seconds_ffprobe(actual_audio_path) duration = _get_duration_seconds_ffprobe(actual_audio_path)
print(f"Audio duration: {duration:.2f} seconds") print(f"Audio duration: {duration:.2f} seconds")
@@ -476,16 +506,29 @@ def test_transcription_with_recovery():
print(f"Audio size: {len(audio_bytes)} bytes") print(f"Audio size: {len(audio_bytes)} bytes")
except Exception as e: except Exception as e:
print(f"Error preparing audio: {e}") print(f"Error preparing audio: {e}")
# Cleanup temp file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
return return
url = "http://localhost:8000/v1/chat/completions" url = f"{base_url}/v1/chat/completions"
show_keys = ["Start time", "End time", "Speaker ID", "Content"] show_keys = ["Start time", "End time", "Speaker ID", "Content"]
# Build prompt with optional hotwords
if hotwords and hotwords.strip():
prompt_text = (
f"This is a {duration:.2f} seconds audio, with extra info: {hotwords.strip()}\n\n"
f"Please transcribe it with these keys: " + ", ".join(show_keys)
)
print(f"\n📝 Hotwords embedded in prompt: '{hotwords}'")
else:
prompt_text = ( prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys) + ", ".join(show_keys)
) )
print(f"\n📝 No hotwords provided")
mime = _guess_mime_type(actual_audio_path) mime = _guess_mime_type(actual_audio_path)
data_url = f"data:{mime};base64,{audio_b64}" data_url = f"data:{mime};base64,{audio_b64}"
@@ -505,12 +548,13 @@ def test_transcription_with_recovery():
} }
] ]
print(f"\nSending request to {url} (Streaming Mode)...") print(f"\n{'=' * 70}")
print(f"Prompt: {prompt_text}") print(f"Sending request to {url}")
print("-" * 60) print(f"{'=' * 70}")
print("Response received. Streaming content:\n")
t0 = time.time() t0 = time.time()
print("\n✅ Response received. Streaming content:\n")
print("-" * 50)
result = stream_with_recovery( result = stream_with_recovery(
url=url, url=url,
@@ -522,27 +566,73 @@ def test_transcription_with_recovery():
debug=debug, debug=debug,
) )
print("\n[Finished]") elapsed = time.time() - t0
print("-" * 60) print("-" * 50)
print(f"Total time elapsed: {time.time() - t0:.2f}s") print("✅ [Finished]")
print(f"\n{'=' * 70}")
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
print(f"{'=' * 70}")
if result is None: if result is None:
print("Transcription failed") print("Transcription failed")
return return
print(f"Final output length: {len(result)} chars") print(f"📄 Final output length: {len(result)} chars")
# Optionally save result # Optionally save result
if output_path: if output_path:
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
f.write(result) f.write(result)
print(f"Result saved to: {output_path}") print(f"💾 Result saved to: {output_path}")
# Cleanup temp audio file if created # Cleanup temp audio file if created
if temp_audio_path and os.path.exists(temp_audio_path): if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path) os.remove(temp_audio_path)
print(f"Cleaned up temp file: {temp_audio_path}") print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
def main():
parser = argparse.ArgumentParser(
description="Test VibeVoice vLLM API with auto-recovery from repetition loops"
)
parser.add_argument(
"audio_path",
help="Path to audio file (wav, mp3, flac, etc.) or video file"
)
parser.add_argument(
"output_path",
nargs="?",
default=None,
help="Optional path to save transcription result"
)
parser.add_argument(
"--url",
default="http://localhost:8000",
help="vLLM server URL (default: http://localhost:8000)"
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
)
parser.add_argument(
"--debug",
action="store_true",
help="Show recovery debug info"
)
args = parser.parse_args()
# Run test
test_transcription_with_recovery(
audio_path=args.audio_path,
output_path=args.output_path,
base_url=args.url,
hotwords=args.hotwords,
debug=args.debug,
)
if __name__ == "__main__": if __name__ == "__main__":
test_transcription_with_recovery() main()
Binary file not shown.