feat: add hotwords support for vLLM ASR

This commit is contained in:
YingboHAO
2026-02-04 10:33:20 +00:00
parent 0aa8cb4c64
commit bb54f78d0e
5 changed files with 253 additions and 137 deletions
+13 -3
View File
@@ -52,15 +52,25 @@ docker logs -f vibevoice-vllm
Once the vLLM server is running, test it with the provided script: Once the vLLM server is running, test it with the provided script:
```bash ```bash
# Run the test (use container path /app/...) # Basic transcription
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav
# With hotwords for better recognition of specific terms
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
``` ```
```bash ```bash
# Run the recover_test (use container path /app/...) # With auto-recovery from repetition loops (for long audio)
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav
# Auto-recover with hotwords
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
``` ```
> **Note**: The audio file must be inside the mounted directory (`/app` in the container). Copy your audio to the VibeVoice folder before testing.
> **Note**:
> - The audio/video file must be inside the mounted directory (`/app` in the container). Copy your files to the VibeVoice folder before testing.
> - Hotwords help improve recognition of domain-specific terms like proper nouns, technical terms, and speaker names.
### Environment Variables ### Environment Variables
Binary file not shown.
+101 -85
View File
@@ -1,14 +1,23 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Test VibeVoice vLLM API with Streaming (Real-time output). Test VibeVoice vLLM API with Streaming and Optional Hotwords Support.
This script tests ASR transcription via the vLLM OpenAI-compatible API.
By default, it runs standard transcription without hotwords.
Optionally, you can provide hotwords (context_info) to improve recognition
of domain-specific content like proper nouns, technical terms, and speaker names.
Hotwords are embedded in the prompt as "with extra info: {hotwords}".
Usage: Usage:
python test_api.py [audio_path] [--url URL] python test_api.py <audio_path> [--url URL] [--hotwords "word1,word2"]
Examples: Examples:
python test_api.py # Use default audio # Standard transcription (no hotwords)
python test_api.py /path/to/audio.wav # Specify audio file python3 test_api.py audio.wav
python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL
# With hotwords for better recognition of specific terms
python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice"
""" """
import requests import requests
import json import json
@@ -21,38 +30,38 @@ import argparse
def _guess_mime_type(path: str) -> str: def _guess_mime_type(path: str) -> str:
"""Guess MIME type from file extension."""
ext = os.path.splitext(path)[1].lower() ext = os.path.splitext(path)[1].lower()
if ext == ".wav": mime_map = {
return "audio/wav" ".wav": "audio/wav",
if ext in (".mp3",): ".mp3": "audio/mpeg",
return "audio/mpeg" ".m4a": "audio/mp4",
if ext in (".m4a",): ".mp4": "video/mp4",
return "audio/mp4" ".flac": "audio/flac",
if ext in (".mp4", ".m4v", ".mov", ".webm"): ".ogg": "audio/ogg",
return "video/mp4" ".opus": "audio/ogg",
if ext in (".flac",): }
return "audio/flac" return mime_map.get(ext, "application/octet-stream")
if ext in (".ogg", ".opus"):
return "audio/ogg"
return "application/octet-stream"
def _get_duration_seconds_ffprobe(path: str) -> float: def _get_duration_seconds_ffprobe(path: str) -> float:
"""Get audio duration using ffprobe.""" """Get audio duration using ffprobe."""
cmd = [ cmd = [
"ffprobe", "ffprobe", "-v", "error",
"-v", "-show_entries", "format=duration",
"error", "-of", "default=noprint_wrappers=1:nokey=1",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
path, path,
] ]
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip() out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
return float(out) return float(out)
def _is_video_file(path: str) -> bool:
"""Check if the file is a video file that needs audio extraction."""
ext = os.path.splitext(path)[1].lower()
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
def _extract_audio_from_video(video_path: str) -> str: def _extract_audio_from_video(video_path: str) -> str:
""" """
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file. Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
@@ -74,26 +83,40 @@ def _extract_audio_from_video(video_path: str) -> str:
return audio_path return audio_path
def _is_video_file(path: str) -> bool: def test_transcription_with_hotwords(
"""Check if the file is a video file that needs audio extraction.""" audio_path: str,
ext = os.path.splitext(path)[1].lower() context_info: str = None,
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv") base_url: str = "http://localhost:8000",
):
"""
Test ASR transcription with customized hotwords.
Hotwords are embedded in the prompt text as "with extra info: {hotwords}".
This helps the model recognize domain-specific terms more accurately.
def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"): Args:
"""Test ASR transcription with streaming output.""" audio_path: Path to the audio file
context_info: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
base_url: vLLM server URL
"""
print(f"Loading audio from: {audio_path}") print(f"=" * 70)
print(f"Testing Customized Hotwords Support")
print(f"=" * 70)
print(f"Input file: {audio_path}")
print(f"Hotwords: {context_info or '(none)'}")
print()
# Handle video files: extract audio first # Handle video files: extract audio first
temp_audio_path = None temp_audio_path = None
actual_audio_path = audio_path actual_audio_path = audio_path
if _is_video_file(audio_path): if _is_video_file(audio_path):
print(f"Detected video file, extracting audio...") print(f"🎬 Detected video file, extracting audio...")
temp_audio_path = _extract_audio_from_video(audio_path) temp_audio_path = _extract_audio_from_video(audio_path)
actual_audio_path = temp_audio_path actual_audio_path = temp_audio_path
print(f"Audio extracted to: {temp_audio_path}") print(f"Audio extracted to: {temp_audio_path}")
# Load audio
try: try:
duration = _get_duration_seconds_ffprobe(actual_audio_path) duration = _get_duration_seconds_ffprobe(actual_audio_path)
print(f"Audio duration: {duration:.2f} seconds") print(f"Audio duration: {duration:.2f} seconds")
@@ -106,16 +129,30 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
except Exception as e: except Exception as e:
print(f"Error preparing audio: {e}") print(f"Error preparing audio: {e}")
# Cleanup temp file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
return return
# Build the request # Build the request
url = f"{base_url}/v1/chat/completions" url = f"{base_url}/v1/chat/completions"
show_keys = ["Start time", "End time", "Speaker ID", "Content"] show_keys = ["Start time", "End time", "Speaker ID", "Content"]
# Build prompt with optional hotwords
# Hotwords are embedded as "with extra info: {hotwords}" in the prompt
if context_info and context_info.strip():
prompt_text = (
f"This is a {duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\n"
f"Please transcribe it with these keys: " + ", ".join(show_keys)
)
print(f"\n📝 Hotwords embedded in prompt: '{context_info}'")
else:
prompt_text = ( prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys) + ", ".join(show_keys)
) )
print(f"\n📝 No hotwords provided")
mime = _guess_mime_type(actual_audio_path) mime = _guess_mime_type(actual_audio_path)
data_url = f"data:{mime};base64,{audio_b64}" data_url = f"data:{mime};base64,{audio_b64}"
@@ -139,20 +176,19 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
"temperature": 0.0, "temperature": 0.0,
"stream": True, "stream": True,
"top_p": 1.0, "top_p": 1.0,
"repetition_penalty": 1.0,
} }
print(f"\nSending request to {url} (Streaming Mode)...") print(f"\n{'=' * 70}")
print(f"Prompt: {prompt_text}") print(f"Sending request to {url}")
print("-" * 60) print(f"{'=' * 70}")
t0 = time.time() t0 = time.time()
try: try:
response = requests.post(url, json=payload, stream=True, timeout=12000) response = requests.post(url, json=payload, stream=True, timeout=12000)
if response.status_code == 200: if response.status_code == 200:
print("Response received. Streaming content:\n") print("\nResponse received. Streaming content:\n")
print("-" * 50)
printed = "" printed = ""
for line in response.iter_lines(): for line in response.iter_lines():
@@ -162,92 +198,72 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
if decoded_line.startswith("data: "): if decoded_line.startswith("data: "):
json_str = decoded_line[6:] json_str = decoded_line[6:]
if json_str.strip() == "[DONE]": if json_str.strip() == "[DONE]":
print("\n\n[Finished]") print("\n" + "-" * 50)
print("✅ [Finished]")
break break
try: try:
data = json.loads(json_str) data = json.loads(json_str)
delta = data['choices'][0]['delta'] delta = data['choices'][0]['delta']
content = delta.get('content', '') content = delta.get('content', '')
if content: if content:
# vLLM/OpenAI-compatible streams may emit either
# incremental deltas OR the full accumulated text.
# Only print the newly-added part to avoid repeats.
if content.startswith(printed): if content.startswith(printed):
to_print = content[len(printed):] to_print = content[len(printed):]
else: else:
to_print = content to_print = content
if to_print: if to_print:
print(to_print, end='', flush=True) print(to_print, end='', flush=True)
printed += to_print printed += to_print
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
else: else:
print(f"Error: {response.status_code}") print(f"Error: {response.status_code}")
print(response.text) print(response.text)
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
print("\nRequest timed out!") print("Request timed out!")
except Exception as e: except Exception as e:
print(f"\nError: {e}") print(f"Error: {e}")
print(f"\n{'-'*60}") elapsed = time.time() - t0
print(f"Total time elapsed: {time.time() - t0:.2f}s") print(f"\n{'=' * 70}")
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
print(f"📊 RTF (Real-Time Factor): {elapsed / duration:.2f}x")
print(f"{'=' * 70}")
# Cleanup temp audio file if created # Cleanup temp audio file if created
if temp_audio_path and os.path.exists(temp_audio_path): if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path) os.remove(temp_audio_path)
print(f"Cleaned up temp file: {temp_audio_path}") print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Test VibeVoice vLLM API with streaming output" description="Test VibeVoice vLLM API with Customized Hotwords"
) )
parser.add_argument( parser.add_argument(
"audio_path", "audio_path",
nargs="?",
default=None,
help="Path to audio file (wav, mp3, flac, etc.) or video file" help="Path to audio file (wav, mp3, flac, etc.) or video file"
) )
parser.add_argument( parser.add_argument(
"--url", "--url",
default="http://localhost:8000", default="http://localhost:8000",
help="vLLM server base URL (default: http://localhost:8000)" help="vLLM server URL (default: http://localhost:8000)"
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
) )
args = parser.parse_args() args = parser.parse_args()
# Find default audio if not specified # Run test
audio_path = args.audio_path test_transcription_with_hotwords(
if audio_path is None: audio_path=args.audio_path,
# Try to find a sample audio in common locations context_info=args.hotwords,
possible_paths = [ base_url=args.url,
# In VibeVoice demo folder )
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "en-Carter_man.wav"),
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"),
# Relative to current directory
"demo/voices/en-Carter_man.wav",
"demo/voices/zh-Anchen_man_bgm.wav",
]
for path in possible_paths:
if os.path.exists(path):
audio_path = path
break
if audio_path is None:
print("Error: No audio file specified and no default audio found.")
print("Usage: python test_api.py <audio_path>")
sys.exit(1)
if not os.path.exists(audio_path):
print(f"Error: Audio file not found: {audio_path}")
sys.exit(1)
test_transcription(audio_path, args.url)
if __name__ == "__main__": if __name__ == "__main__":
+129 -39
View File
@@ -1,18 +1,36 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
VibeVoice vLLM API with Auto-Recovery from Repetition Loops. Test VibeVoice vLLM API with Streaming, Hotwords, and Auto-Recovery.
Strategy: This script tests ASR transcription with automatic recovery from repetition loops.
1. Start with greedy decoding (temperature=0, top_p=1.0) Supports optional hotwords to improve recognition of domain-specific terms.
2. Stream and detect repetition patterns in real-time
3. Only output content up to (current_length - window_size) at segment boundaries
4. When loop detected:
- Truncate to last complete segment boundary (},)
- Recovery with temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
5. Max 3 retries, if all fail output error message
User sees: clean streaming transcription output (only complete segments) Features:
Internal: automatic recovery from repetition loops (silent) - Streaming output with real-time repetition detection
- Auto-recovery when model enters repetition loops
- Optional hotwords support (embedded in prompt as "with extra info: {hotwords}")
- Video file support (auto-extracts audio)
Recovery Strategy:
1. First attempt: greedy decoding (temperature=0, top_p=1.0)
2. If loop detected: retry with temperature=0.2/0.3/0.4, top_p=0.95
3. Max 3 retries, truncate to last complete segment boundary
Usage:
python test_api_auto_recover.py <audio_path> [output_path] [--url URL] [--hotwords "word1,word2"] [--debug]
Examples:
# Basic usage
python3 test_api_auto_recover.py audio.wav
# With hotwords
python3 test_api_auto_recover.py audio.wav --hotwords "Microsoft,VibeVoice"
# Save result to file
python3 test_api_auto_recover.py audio.wav result.txt
# Debug mode (show recovery info)
python3 test_api_auto_recover.py audio.wav --debug
""" """
import requests import requests
import json import json
@@ -22,6 +40,7 @@ import sys
import os import os
import subprocess import subprocess
import re import re
import argparse
from collections import Counter from collections import Counter
@@ -441,30 +460,41 @@ def stream_with_recovery(
return None return None
def test_transcription_with_recovery(): def test_transcription_with_recovery(
"""Main test function with auto-recovery.""" audio_path: str,
output_path: str = None,
base_url: str = "http://localhost:8000",
hotwords: str = None,
debug: bool = False,
):
"""
Test ASR transcription with auto-recovery from repetition loops.
# Parse arguments Args:
debug = "--debug" in sys.argv or "-debug" in sys.argv audio_path: Path to the audio file
args = [a for a in sys.argv[1:] if not a.startswith("-")] output_path: Optional path to save transcription result
base_url: vLLM server URL
hotwords: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
debug: Show recovery debug info
"""
audio_path = ( print(f"=" * 70)
args[0] print(f"Testing with Auto-Recovery")
) print(f"=" * 70)
print(f"Input file: {audio_path}")
output_path = args[1] if len(args) > 1 else None print(f"Hotwords: {hotwords or '(none)'}")
print()
print(f"Loading audio from: {audio_path}")
# Handle video files: extract audio first # Handle video files: extract audio first
temp_audio_path = None temp_audio_path = None
actual_audio_path = audio_path actual_audio_path = audio_path
if _is_video_file(audio_path): if _is_video_file(audio_path):
print(f"Detected video file, extracting audio...") print(f"🎬 Detected video file, extracting audio...")
temp_audio_path = _extract_audio_from_video(audio_path) temp_audio_path = _extract_audio_from_video(audio_path)
actual_audio_path = temp_audio_path actual_audio_path = temp_audio_path
print(f"Audio extracted to: {temp_audio_path}") print(f"Audio extracted to: {temp_audio_path}")
# Load audio
try: try:
duration = _get_duration_seconds_ffprobe(actual_audio_path) duration = _get_duration_seconds_ffprobe(actual_audio_path)
print(f"Audio duration: {duration:.2f} seconds") print(f"Audio duration: {duration:.2f} seconds")
@@ -476,16 +506,29 @@ def test_transcription_with_recovery():
print(f"Audio size: {len(audio_bytes)} bytes") print(f"Audio size: {len(audio_bytes)} bytes")
except Exception as e: except Exception as e:
print(f"Error preparing audio: {e}") print(f"Error preparing audio: {e}")
# Cleanup temp file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
return return
url = "http://localhost:8000/v1/chat/completions" url = f"{base_url}/v1/chat/completions"
show_keys = ["Start time", "End time", "Speaker ID", "Content"] show_keys = ["Start time", "End time", "Speaker ID", "Content"]
# Build prompt with optional hotwords
if hotwords and hotwords.strip():
prompt_text = (
f"This is a {duration:.2f} seconds audio, with extra info: {hotwords.strip()}\n\n"
f"Please transcribe it with these keys: " + ", ".join(show_keys)
)
print(f"\n📝 Hotwords embedded in prompt: '{hotwords}'")
else:
prompt_text = ( prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys) + ", ".join(show_keys)
) )
print(f"\n📝 No hotwords provided")
mime = _guess_mime_type(actual_audio_path) mime = _guess_mime_type(actual_audio_path)
data_url = f"data:{mime};base64,{audio_b64}" data_url = f"data:{mime};base64,{audio_b64}"
@@ -505,12 +548,13 @@ def test_transcription_with_recovery():
} }
] ]
print(f"\nSending request to {url} (Streaming Mode)...") print(f"\n{'=' * 70}")
print(f"Prompt: {prompt_text}") print(f"Sending request to {url}")
print("-" * 60) print(f"{'=' * 70}")
print("Response received. Streaming content:\n")
t0 = time.time() t0 = time.time()
print("\n✅ Response received. Streaming content:\n")
print("-" * 50)
result = stream_with_recovery( result = stream_with_recovery(
url=url, url=url,
@@ -522,27 +566,73 @@ def test_transcription_with_recovery():
debug=debug, debug=debug,
) )
print("\n[Finished]") elapsed = time.time() - t0
print("-" * 60) print("-" * 50)
print(f"Total time elapsed: {time.time() - t0:.2f}s") print("✅ [Finished]")
print(f"\n{'=' * 70}")
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
print(f"{'=' * 70}")
if result is None: if result is None:
print("Transcription failed") print("Transcription failed")
return return
print(f"Final output length: {len(result)} chars") print(f"📄 Final output length: {len(result)} chars")
# Optionally save result # Optionally save result
if output_path: if output_path:
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
f.write(result) f.write(result)
print(f"Result saved to: {output_path}") print(f"💾 Result saved to: {output_path}")
# Cleanup temp audio file if created # Cleanup temp audio file if created
if temp_audio_path and os.path.exists(temp_audio_path): if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path) os.remove(temp_audio_path)
print(f"Cleaned up temp file: {temp_audio_path}") print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
def main():
    """Parse the command line and run the auto-recovery transcription test.

    Accepts a required input media path, an optional output path, and the
    ``--url``, ``--hotwords`` and ``--debug`` options, then forwards all of
    them unchanged to ``test_transcription_with_recovery``.
    """
    # Declared as data so the CLI surface can be read at a glance; the order
    # is significant because argparse assigns positionals in declaration order.
    argument_specs = (
        (
            ("audio_path",),
            {"help": "Path to audio file (wav, mp3, flac, etc.) or video file"},
        ),
        (
            ("output_path",),
            {
                "nargs": "?",
                "default": None,
                "help": "Optional path to save transcription result",
            },
        ),
        (
            ("--url",),
            {
                "default": "http://localhost:8000",
                "help": "vLLM server URL (default: http://localhost:8000)",
            },
        ),
        (
            ("--hotwords",),
            {
                "type": str,
                "default": None,
                "help": "Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')",
            },
        ),
        (
            ("--debug",),
            {"action": "store_true", "help": "Show recovery debug info"},
        ),
    )

    cli = argparse.ArgumentParser(
        description="Test VibeVoice vLLM API with auto-recovery from repetition loops"
    )
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)

    opts = cli.parse_args()

    # Run test
    test_transcription_with_recovery(
        audio_path=opts.audio_path,
        output_path=opts.output_path,
        base_url=opts.url,
        hotwords=opts.hotwords,
        debug=opts.debug,
    )
if __name__ == "__main__": if __name__ == "__main__":
test_transcription_with_recovery() main()
Binary file not shown.