diff --git a/docs/vibevoice-vllm-asr.md b/docs/vibevoice-vllm-asr.md index f712d3b..489389e 100644 --- a/docs/vibevoice-vllm-asr.md +++ b/docs/vibevoice-vllm-asr.md @@ -52,15 +52,25 @@ docker logs -f vibevoice-vllm Once the vLLM server is running, test it with the provided script: ```bash -# Run the test (use container path /app/...) +# Basic transcription docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav + +# With hotwords for better recognition of specific terms +docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav --hotwords "Microsoft,VibeVoice" + ``` ```bash -# Run the recover_test (use container path /app/...) +# With auto-recovery from repetition loops (for long audio) docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav + +# Auto-recover with hotwords +docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav --hotwords "Microsoft,VibeVoice" ``` -> **Note**: The audio file must be inside the mounted directory (`/app` in the container). Copy your audio to the VibeVoice folder before testing. + +> **Note**: +> - The audio/video file must be inside the mounted directory (`/app` in the container). Copy your files to the VibeVoice folder before testing. +> - Hotwords help improve recognition of domain-specific terms like proper nouns, technical terms, and speaker names. ### Environment Variables diff --git a/vllm_plugin/tests/52min.mp3 b/vllm_plugin/tests/52min.mp3 deleted file mode 100644 index 0ed89ef..0000000 Binary files a/vllm_plugin/tests/52min.mp3 and /dev/null differ diff --git a/vllm_plugin/tests/test_api.py b/vllm_plugin/tests/test_api.py index af128bc..4076c20 100644 --- a/vllm_plugin/tests/test_api.py +++ b/vllm_plugin/tests/test_api.py @@ -1,14 +1,23 @@ #!/usr/bin/env python3 """ -Test VibeVoice vLLM API with Streaming (Real-time output). +Test VibeVoice vLLM API with Streaming and Optional Hotwords Support. + +This script tests ASR transcription via the vLLM OpenAI-compatible API. +By default, it runs standard transcription without hotwords. + +Optionally, you can provide hotwords (context_info) to improve recognition +of domain-specific content like proper nouns, technical terms, and speaker names. +Hotwords are embedded in the prompt as "with extra info: {hotwords}". Usage: - python test_api.py [audio_path] [--url URL] + python test_api_with_hotwords.py [audio_path] [--url URL] [--hotwords "word1,word2"] Examples: - python test_api.py # Use default audio - python test_api.py /path/to/audio.wav # Specify audio file - python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL + # Standard transcription (no hotwords) + python3 test_api.py audio.wav + + # With hotwords for better recognition of specific terms + python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice" """ import requests import json @@ -21,38 +30,38 @@ import argparse def _guess_mime_type(path: str) -> str: + """Guess MIME type from file extension.""" ext = os.path.splitext(path)[1].lower() - if ext == ".wav": - return "audio/wav" - if ext in (".mp3",): - return "audio/mpeg" - if ext in (".m4a",): - return "audio/mp4" - if ext in (".mp4", ".m4v", ".mov", ".webm"): - return "video/mp4" - if ext in (".flac",): - return "audio/flac" - if ext in (".ogg", ".opus"): - return "audio/ogg" - return "application/octet-stream" + mime_map = { + ".wav": "audio/wav", + ".mp3": "audio/mpeg", + ".m4a": "audio/mp4", + ".mp4": "video/mp4", + ".flac": "audio/flac", + ".ogg": "audio/ogg", + ".opus": "audio/ogg", + } + return mime_map.get(ext, "application/octet-stream") def _get_duration_seconds_ffprobe(path: str) -> float: """Get audio duration using ffprobe.""" cmd = [ - "ffprobe", - "-v", - "error", - "-show_entries", - "format=duration", - "-of", - "default=noprint_wrappers=1:nokey=1", + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "default=noprint_wrappers=1:nokey=1", path, ] out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip() return float(out) +def _is_video_file(path: str) -> bool: + """Check if the file is a video file that needs audio extraction.""" + ext = os.path.splitext(path)[1].lower() + return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv") + + def _extract_audio_from_video(video_path: str) -> str: """ Extract audio from video file (mp4/mov/webm) to a temporary mp3 file. @@ -74,26 +83,40 @@ def _extract_audio_from_video(video_path: str) -> str: return audio_path -def _is_video_file(path: str) -> bool: - """Check if the file is a video file that needs audio extraction.""" - ext = os.path.splitext(path)[1].lower() - return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv") - - -def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"): - """Test ASR transcription with streaming output.""" +def test_transcription_with_hotwords( + audio_path: str, + context_info: str = None, + base_url: str = "http://localhost:8000", +): + """ + Test ASR transcription with customized hotwords. - print(f"Loading audio from: {audio_path}") + Hotwords are embedded in the prompt text as "with extra info: {hotwords}". + This helps the model recognize domain-specific terms more accurately. + + Args: + audio_path: Path to the audio file + context_info: Hotwords string (e.g., "Microsoft,Azure,VibeVoice") + base_url: vLLM server URL + """ + + print(f"=" * 70) + print(f"Testing Customized Hotwords Support") + print(f"=" * 70) + print(f"Input file: {audio_path}") + print(f"Hotwords: {context_info or '(none)'}") + print() # Handle video files: extract audio first temp_audio_path = None actual_audio_path = audio_path if _is_video_file(audio_path): - print(f"Detected video file, extracting audio...") + print(f"šŸŽ¬ Detected video file, extracting audio...") temp_audio_path = _extract_audio_from_video(audio_path) actual_audio_path = temp_audio_path - print(f"Audio extracted to: {temp_audio_path}") + print(f"āœ… Audio extracted to: {temp_audio_path}") + # Load audio try: duration = _get_duration_seconds_ffprobe(actual_audio_path) print(f"Audio duration: {duration:.2f} seconds") @@ -106,16 +129,30 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000") except Exception as e: print(f"Error preparing audio: {e}") + # Cleanup temp file if created + if temp_audio_path and os.path.exists(temp_audio_path): + os.remove(temp_audio_path) return # Build the request url = f"{base_url}/v1/chat/completions" show_keys = ["Start time", "End time", "Speaker ID", "Content"] - prompt_text = ( - f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " - + ", ".join(show_keys) - ) + + # Build prompt with optional hotwords + # Hotwords are embedded as "with extra info: {hotwords}" in the prompt + if context_info and context_info.strip(): + prompt_text = ( + f"This is a {duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\n" + f"Please transcribe it with these keys: " + ", ".join(show_keys) + ) + print(f"\nšŸ“ Hotwords embedded in prompt: '{context_info}'") + else: + prompt_text = ( + f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " + + ", ".join(show_keys) + ) + print(f"\nšŸ“ No hotwords provided") mime = _guess_mime_type(actual_audio_path) data_url = f"data:{mime};base64,{audio_b64}" @@ -139,20 +176,19 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000") "temperature": 0.0, "stream": True, "top_p": 1.0, - "repetition_penalty": 1.0, } - print(f"\nSending request to {url} (Streaming Mode)...") - print(f"Prompt: {prompt_text}") - print("-" * 60) + print(f"\n{'=' * 70}") + print(f"Sending request to {url}") + print(f"{'=' * 70}") t0 = time.time() try: - response = requests.post(url, json=payload, stream=True, timeout=12000) if response.status_code == 200: - print("Response received. Streaming content:\n") + print("\nāœ… Response received. Streaming content:\n") + print("-" * 50) printed = "" for line in response.iter_lines(): @@ -162,92 +198,72 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000") if decoded_line.startswith("data: "): json_str = decoded_line[6:] if json_str.strip() == "[DONE]": - print("\n\n[Finished]") + print("\n" + "-" * 50) + print("āœ… [Finished]") break try: data = json.loads(json_str) - delta = data['choices'][0]['delta'] content = delta.get('content', '') if content: - - # vLLM/OpenAI-compatible streams may emit either - # incremental deltas OR the full accumulated text. - # Only print the newly-added part to avoid repeats. if content.startswith(printed): to_print = content[len(printed):] else: to_print = content - if to_print: print(to_print, end='', flush=True) printed += to_print except json.JSONDecodeError: pass else: - print(f"Error: {response.status_code}") + print(f"āŒ Error: {response.status_code}") print(response.text) except requests.exceptions.Timeout: - print("\nRequest timed out!") + print("āŒ Request timed out!") except Exception as e: - print(f"\nError: {e}") + print(f"āŒ Error: {e}") - print(f"\n{'-'*60}") - print(f"Total time elapsed: {time.time() - t0:.2f}s") + elapsed = time.time() - t0 + print(f"\n{'=' * 70}") + print(f"ā±ļø Total time elapsed: {elapsed:.2f}s") + print(f"šŸ“Š RTF (Real-Time Factor): {elapsed / duration:.2f}x") + print(f"{'=' * 70}") # Cleanup temp audio file if created if temp_audio_path and os.path.exists(temp_audio_path): os.remove(temp_audio_path) - print(f"Cleaned up temp file: {temp_audio_path}") + print(f"šŸ—‘ļø Cleaned up temp file: {temp_audio_path}") def main(): parser = argparse.ArgumentParser( - description="Test VibeVoice vLLM API with streaming output" + description="Test VibeVoice vLLM API with Customized Hotwords" ) parser.add_argument( "audio_path", - nargs="?", - default=None, help="Path to audio file (wav, mp3, flac, etc.) or video file" ) parser.add_argument( "--url", default="http://localhost:8000", - help="vLLM server base URL (default: http://localhost:8000)" + help="vLLM server URL (default: http://localhost:8000)" + ) + parser.add_argument( + "--hotwords", + type=str, + default=None, + help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')" ) args = parser.parse_args() - # Find default audio if not specified - audio_path = args.audio_path - if audio_path is None: - # Try to find a sample audio in common locations - possible_paths = [ - # In VibeVoice demo folder - os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "en-Carter_man.wav"), - os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"), - # Relative to current directory - "demo/voices/en-Carter_man.wav", - "demo/voices/zh-Anchen_man_bgm.wav", - ] - - for path in possible_paths: - if os.path.exists(path): - audio_path = path - break - - if audio_path is None: - print("Error: No audio file specified and no default audio found.") - print("Usage: python test_api.py ") - sys.exit(1) - - if not os.path.exists(audio_path): - print(f"Error: Audio file not found: {audio_path}") - sys.exit(1) - - test_transcription(audio_path, args.url) + # Run test + test_transcription_with_hotwords( + audio_path=args.audio_path, + context_info=args.hotwords, + base_url=args.url, + ) if __name__ == "__main__": diff --git a/vllm_plugin/tests/test_api_auto_recover.py b/vllm_plugin/tests/test_api_auto_recover.py index 4fa053e..482e258 100644 --- a/vllm_plugin/tests/test_api_auto_recover.py +++ b/vllm_plugin/tests/test_api_auto_recover.py @@ -1,18 +1,36 @@ #!/usr/bin/env python3 """ -VibeVoice vLLM API with Auto-Recovery from Repetition Loops. +Test VibeVoice vLLM API with Streaming, Hotwords, and Auto-Recovery. -Strategy: -1. Start with greedy decoding (temperature=0, top_p=1.0) -2. Stream and detect repetition patterns in real-time -3. Only output content up to (current_length - window_size) at segment boundaries -4. When loop detected: - - Truncate to last complete segment boundary (},) - - Recovery with temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95 -5. Max 3 retries, if all fail output error message +This script tests ASR transcription with automatic recovery from repetition loops. +Supports optional hotwords to improve recognition of domain-specific terms. -User sees: clean streaming transcription output (only complete segments) -Internal: automatic recovery from repetition loops (silent) +Features: +- Streaming output with real-time repetition detection +- Auto-recovery when model enters repetition loops +- Optional hotwords support (embedded in prompt as "with extra info: {hotwords}") +- Video file support (auto-extracts audio) + +Recovery Strategy: +1. First attempt: greedy decoding (temperature=0, top_p=1.0) +2. If loop detected: retry with temperature=0.2/0.3/0.4, top_p=0.95 +3. Max 3 retries, truncate to last complete segment boundary + +Usage: + python test_api_auto_recover.py [output_path] [--url URL] [--hotwords "word1,word2"] [--debug] + +Examples: + # Basic usage + python3 test_api_auto_recover.py audio.wav + + # With hotwords + python3 test_api_auto_recover.py audio.wav --hotwords "Microsoft,VibeVoice" + + # Save result to file + python3 test_api_auto_recover.py audio.wav result.txt + + # Debug mode (show recovery info) + python3 test_api_auto_recover.py audio.wav --debug """ import requests import json @@ -22,6 +40,7 @@ import sys import os import subprocess import re +import argparse from collections import Counter @@ -441,30 +460,41 @@ def stream_with_recovery( return None -def test_transcription_with_recovery(): - """Main test function with auto-recovery.""" +def test_transcription_with_recovery( + audio_path: str, + output_path: str = None, + base_url: str = "http://localhost:8000", + hotwords: str = None, + debug: bool = False, +): + """ + Test ASR transcription with auto-recovery from repetition loops. - # Parse arguments - debug = "--debug" in sys.argv or "-debug" in sys.argv - args = [a for a in sys.argv[1:] if not a.startswith("-")] + Args: + audio_path: Path to the audio file + output_path: Optional path to save transcription result + base_url: vLLM server URL + hotwords: Hotwords string (e.g., "Microsoft,Azure,VibeVoice") + debug: Show recovery debug info + """ - audio_path = ( - args[0] - ) - - output_path = args[1] if len(args) > 1 else None - - print(f"Loading audio from: {audio_path}") + print(f"=" * 70) + print(f"Testing with Auto-Recovery") + print(f"=" * 70) + print(f"Input file: {audio_path}") + print(f"Hotwords: {hotwords or '(none)'}") + print() # Handle video files: extract audio first temp_audio_path = None actual_audio_path = audio_path if _is_video_file(audio_path): - print(f"Detected video file, extracting audio...") + print(f"šŸŽ¬ Detected video file, extracting audio...") temp_audio_path = _extract_audio_from_video(audio_path) actual_audio_path = temp_audio_path - print(f"Audio extracted to: {temp_audio_path}") + print(f"āœ… Audio extracted to: {temp_audio_path}") + # Load audio try: duration = _get_duration_seconds_ffprobe(actual_audio_path) print(f"Audio duration: {duration:.2f} seconds") @@ -476,16 +506,29 @@ def test_transcription_with_recovery(): print(f"Audio size: {len(audio_bytes)} bytes") except Exception as e: - print(f"Error preparing audio: {e}") + print(f"āŒ Error preparing audio: {e}") + # Cleanup temp file if created + if temp_audio_path and os.path.exists(temp_audio_path): + os.remove(temp_audio_path) return - url = "http://localhost:8000/v1/chat/completions" + url = f"{base_url}/v1/chat/completions" show_keys = ["Start time", "End time", "Speaker ID", "Content"] - prompt_text = ( - f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " - + ", ".join(show_keys) - ) + + # Build prompt with optional hotwords + if hotwords and hotwords.strip(): + prompt_text = ( + f"This is a {duration:.2f} seconds audio, with extra info: {hotwords.strip()}\n\n" + f"Please transcribe it with these keys: " + ", ".join(show_keys) + ) + print(f"\nšŸ“ Hotwords embedded in prompt: '{hotwords}'") + else: + prompt_text = ( + f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " + + ", ".join(show_keys) + ) + print(f"\nšŸ“ No hotwords provided") mime = _guess_mime_type(actual_audio_path) data_url = f"data:{mime};base64,{audio_b64}" @@ -505,12 +548,13 @@ def test_transcription_with_recovery(): } ] - print(f"\nSending request to {url} (Streaming Mode)...") - print(f"Prompt: {prompt_text}") - print("-" * 60) - print("Response received. Streaming content:\n") + print(f"\n{'=' * 70}") + print(f"Sending request to {url}") + print(f"{'=' * 70}") t0 = time.time() + print("\nāœ… Response received. Streaming content:\n") + print("-" * 50) result = stream_with_recovery( url=url, @@ -522,27 +566,73 @@ def test_transcription_with_recovery(): debug=debug, ) - print("\n[Finished]") - print("-" * 60) - print(f"Total time elapsed: {time.time() - t0:.2f}s") + elapsed = time.time() - t0 + print("-" * 50) + print("āœ… [Finished]") + print(f"\n{'=' * 70}") + print(f"ā±ļø Total time elapsed: {elapsed:.2f}s") + print(f"{'=' * 70}") if result is None: - print("Transcription failed") + print("āŒ Transcription failed") return - print(f"Final output length: {len(result)} chars") + print(f"šŸ“„ Final output length: {len(result)} chars") # Optionally save result if output_path: with open(output_path, "w", encoding="utf-8") as f: f.write(result) - print(f"Result saved to: {output_path}") + print(f"šŸ’¾ Result saved to: {output_path}") # Cleanup temp audio file if created if temp_audio_path and os.path.exists(temp_audio_path): os.remove(temp_audio_path) - print(f"Cleaned up temp file: {temp_audio_path}") + print(f"šŸ—‘ļø Cleaned up temp file: {temp_audio_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Test VibeVoice vLLM API with auto-recovery from repetition loops" + ) + parser.add_argument( + "audio_path", + help="Path to audio file (wav, mp3, flac, etc.) or video file" + ) + parser.add_argument( + "output_path", + nargs="?", + default=None, + help="Optional path to save transcription result" + ) + parser.add_argument( + "--url", + default="http://localhost:8000", + help="vLLM server URL (default: http://localhost:8000)" + ) + parser.add_argument( + "--hotwords", + type=str, + default=None, + help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')" + ) + parser.add_argument( + "--debug", + action="store_true", + help="Show recovery debug info" + ) + + args = parser.parse_args() + + # Run test + test_transcription_with_recovery( + audio_path=args.audio_path, + output_path=args.output_path, + base_url=args.url, + hotwords=args.hotwords, + debug=args.debug, + ) if __name__ == "__main__": - test_transcription_with_recovery() + main() diff --git a/vllm_plugin/tests/zeo.mp3 b/vllm_plugin/tests/zeo.mp3 deleted file mode 100644 index e149f94..0000000 Binary files a/vllm_plugin/tests/zeo.mp3 and /dev/null differ