#!/usr/bin/env python3
"""
VibeVoice vLLM API with Auto-Recovery from Repetition Loops.

Strategy:
1. Start with greedy decoding (temperature=0, top_p=1.0)
2. Stream and detect repetition patterns in real-time
3. Only output content up to (current_length - window_size) at segment boundaries
4. When loop detected:
   - Roll back to the last complete segment boundary (},) already shown to the user
   - Recovery with temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
5. Max 3 retries, if all fail output error message

User sees: clean streaming transcription output (only complete segments)
Internal: automatic recovery from repetition loops (silent)

Usage:
    python3 test_api_auto_recover.py <audio_or_video_path> [output_path] [--debug]

    # Run the auto-recovery test inside the container (use container path /app/...)
    docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav
"""
import json
import base64
import time
import sys
import os
import subprocess
import re
from collections import Counter


# Extension -> MIME type sent in the data URL. Video containers are normally
# converted to mp3 before this table is consulted.
_MIME_BY_EXT = {
    ".wav": "audio/wav",
    ".mp3": "audio/mpeg",
    ".m4a": "audio/mp4",
    ".mp4": "video/mp4",
    ".m4v": "video/mp4",
    ".mov": "video/mp4",
    ".webm": "video/mp4",
    ".flac": "audio/flac",
    ".ogg": "audio/ogg",
    ".opus": "audio/ogg",
}

# Extensions whose audio track is extracted with ffmpeg before upload.
_VIDEO_EXTENSIONS = frozenset({".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv"})

# Separator between two transcription segments in the JSON array output.
_SEGMENT_SEPARATOR = "},"

# Matches a segment that starts with a bare number instead of the "Start" key,
# e.g. {"2.99,...  (compiled once; it is evaluated for every streamed chunk).
_MISSING_START_KEY_RE = re.compile(r'^\{"(\d+\.?\d*),')

_FAILURE_MESSAGE = (
    "\n[Error] Transcription failed due to model output anomaly. "
    "Please try another audio or contact support."
)


def _guess_mime_type(path: str) -> str:
    """Return the MIME type for *path* based on its (case-insensitive) extension."""
    ext = os.path.splitext(path)[1].lower()
    return _MIME_BY_EXT.get(ext, "application/octet-stream")


def _get_duration_seconds_ffprobe(path: str) -> float:
    """
    Return the media duration in seconds as reported by ffprobe.

    Raises:
        OSError: ffprobe is not installed / not executable.
        subprocess.CalledProcessError: ffprobe could not read the file.
        ValueError: ffprobe printed something that is not a number.
    """
    cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        path,
    ]
    out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
    return float(out)


def _remove_quietly(path: str) -> None:
    """Best-effort removal of a temporary file; a missing file is not an error."""
    try:
        os.remove(path)
    except OSError:
        # Deliberate best-effort: cleanup must never mask the original failure.
        pass


def _extract_audio_from_video(video_path: str) -> str:
    """
    Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
    Returns the path to the extracted audio file. The caller owns the file
    and is responsible for deleting it.

    Raises:
        OSError: ffmpeg is not installed / not executable.
        subprocess.CalledProcessError: ffmpeg failed to extract the audio.
    """
    import tempfile

    # Create temp file with .mp3 extension
    fd, audio_path = tempfile.mkstemp(suffix=".mp3")
    os.close(fd)

    cmd = [
        "ffmpeg", "-y", "-i", video_path,
        "-vn",  # No video
        "-acodec", "libmp3lame",
        "-q:a", "2",  # High quality
        audio_path,
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except BaseException:
        # Do not leave the empty temp file behind when extraction fails.
        _remove_quietly(audio_path)
        raise
    return audio_path


def _is_video_file(path: str) -> bool:
    """Check if the file is a video file that needs audio extraction."""
    ext = os.path.splitext(path)[1].lower()
    return ext in _VIDEO_EXTENSIONS


def _find_last_segment_boundary(text: str) -> int:
    """
    Find the position after the last complete segment boundary (},).
    Returns -1 if no complete segment found.
    """
    pos = text.rfind(_SEGMENT_SEPARATOR)
    if pos != -1:
        return pos + len(_SEGMENT_SEPARATOR)  # Include the },
    return -1


def _find_safe_print_boundary(text: str, max_pos: int) -> int:
    """
    Find the last complete segment boundary before max_pos.
    Returns 0 if no complete segment found before max_pos.
    """
    # Bounded rfind gives the same answer as text[:max_pos].rfind(...) without
    # copying the whole prefix on every streamed chunk.
    pos = text.rfind(_SEGMENT_SEPARATOR, 0, max_pos)
    if pos != -1:
        return pos + len(_SEGMENT_SEPARATOR)  # Include the },
    return 0


class RepetitionDetector:
    """Detect repetition patterns at the tail of streaming text."""

    def __init__(self,
                 min_pattern_len: int = 10,   # Minimum chars for a pattern
                 min_repeats: int = 3,        # Minimum repetitions to trigger
                 window_size: int = 500):     # Window to check for patterns
        # Zero/negative values would divide by zero or never advance the
        # backwards scan, so reject them up front.
        if min_pattern_len < 1:
            raise ValueError("min_pattern_len must be >= 1")
        if min_repeats < 1:
            raise ValueError("min_repeats must be >= 1")
        self.min_pattern_len = min_pattern_len
        self.min_repeats = min_repeats
        self.window_size = window_size
        self.text = ""

    def add_text(self, new_text: str):
        """Add new text and return (is_looping, good_text_end_pos)."""
        self.text += new_text
        return self._check_repetition()

    def check(self, text: str):
        """Replace the tracked text with *text* and return (is_looping, good_text_end_pos)."""
        self.text = text
        return self._check_repetition()

    def _check_repetition(self):
        """
        Check if the recent text contains repetition loops.

        Returns:
            (is_looping, good_end): good_end is an index into self.text; when a
            loop is found it points just past the first instance of the
            repeated unit (or to the start of the repetition if the unit is
            not meaningful). Without a loop it is len(self.text).
        """
        text_len = len(self.text)
        if text_len < self.min_pattern_len * self.min_repeats:
            return False, text_len

        # Check the recent window
        window = self.text[-self.window_size:] if text_len > self.window_size else self.text
        window_offset = text_len - len(window)

        # Method 1: the text ends with the same character pattern repeated
        # back-to-back at least min_repeats times.
        for pattern_len in range(self.min_pattern_len, len(window) // self.min_repeats + 1):
            pattern = window[-pattern_len:]

            count = 0
            pos = len(window)
            while pos >= pattern_len and window[pos - pattern_len:pos] == pattern:
                count += 1
                pos -= pattern_len

            if count >= self.min_repeats:
                repetition_start = text_len - (count * pattern_len)
                # Keep one instance of the pattern (or none if it's garbage)
                if self._is_meaningful(pattern):
                    return True, repetition_start + pattern_len
                return True, repetition_start

        # Method 2: repeated short phrases (like "you're not, you're not")
        # separated by arbitrary whitespace. Word spans are tracked so that
        # good_end is exact even with newlines, double spaces or trailing
        # whitespace.
        word_spans = [m.span() for m in re.finditer(r"\S+", window)]
        words = [window[start:end] for start, end in word_spans]
        if len(words) >= self.min_repeats * 2:
            for phrase_len in range(2, 6):  # 2-5 word phrases
                if len(words) < phrase_len * self.min_repeats:
                    continue

                phrase = words[-phrase_len:]
                count = 0
                idx = len(words)
                while idx >= phrase_len and words[idx - phrase_len:idx] == phrase:
                    count += 1
                    idx -= phrase_len

                if count >= self.min_repeats:
                    # idx is the first word of the repetition; keep exactly one
                    # instance of the phrase.
                    first_instance_end = word_spans[idx + phrase_len - 1][1]
                    return True, window_offset + first_instance_end

        return False, text_len

    def _is_meaningful(self, pattern: str) -> bool:
        """Check if pattern is meaningful content (not just garbage)."""
        # Filter out patterns that are just whitespace or built from fewer
        # than three distinct characters.
        clean = pattern.strip()
        if not clean:
            return False
        if len(set(clean)) < 3:  # Too few unique characters
            return False
        return True

    def get_good_text(self, end_pos: int) -> str:
        """Get text up to the specified position."""
        return self.text[:end_pos]

    def reset(self, keep_text: str = ""):
        """Reset detector, optionally keeping some text."""
        self.text = keep_text


def _extract_delta_content(json_str: str) -> str:
    """
    Return the delta text carried by one SSE JSON payload, or "" if none.

    Malformed JSON and chunks without choices / delta / content are treated
    as "no content" instead of raising, so one odd chunk cannot abort the
    whole transcription.
    """
    try:
        data = json.loads(json_str)
    except json.JSONDecodeError:
        return ""
    choices = data.get("choices") if isinstance(data, dict) else None
    if not choices:
        return ""
    delta = choices[0].get("delta") or {}
    return delta.get("content") or ""


def _strip_continuation_artifacts(new_text: str, log) -> str:
    """
    Drop tokens the model tends to re-emit when continuing from a prefix.

    When continuing from an assistant prefix the model may add "[" or "[{" at
    the start, or repeat the ending "}," / "}" of the prefix. Those are
    removed so prefix + continuation stays a valid JSON array.
    """
    stripped = new_text.lstrip()
    # Model added "[{" or just "[" - remove the "["
    if stripped.startswith("["):
        log("[STRIPPED leading '[' from continuation]")
        return stripped[1:]
    # Model repeated "}," from prefix ending
    if stripped.startswith("},"):
        log("[STRIPPED leading '},' from continuation]")
        return stripped[2:]
    # Model repeated "}" from prefix ending ("}]" is the legitimate array end)
    if stripped.startswith("}") and not stripped.startswith("}]"):
        log("[STRIPPED leading '}' from continuation]")
        return stripped[1:]
    return new_text


def _restore_missing_start_key(new_text: str, log) -> str:
    """Fix malformed JSON where the model skipped the key: {"2.99,... -> {"Start":2.99,..."""
    malformed = _MISSING_START_KEY_RE.match(new_text)
    if not malformed:
        return new_text
    log("[FIXED malformed JSON: added Start key]")
    return '{"Start":' + malformed.group(1) + ',' + new_text[malformed.end():]


def stream_with_recovery(
    url: str,
    base_messages: list,
    audio_data_url: str,
    prompt_text: str,
    max_tokens: int = 32768,
    max_retries: int = 3,
    timeout: int = 12000,
    debug: bool = False,
):
    """
    Stream transcription with automatic recovery from repetition loops.

    Args:
        url: API endpoint
        base_messages: Base messages (system + user with audio)
        audio_data_url: The audio data URL for the request (currently unused;
            the audio is already embedded in base_messages)
        prompt_text: The text prompt (currently unused, see above)
        max_tokens: Maximum tokens to generate
        max_retries: Maximum recovery attempts (default 3)
        timeout: Request timeout in seconds
        debug: If True, show recovery debug info to stderr

    Recovery strategy:
        - First attempt: temperature=0.0, top_p=1.0 (greedy)
        - Recovery: temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
        - If has complete segments: use assistant prefix
        - If no complete segments: restart from scratch
        - Max 3 retries, if all fail output error message

    Returns:
        The final transcription text; the part already shown to the user if
        the request failed midway (HTTP error, timeout, connection error);
        None if every retry ended in a repetition loop.
    """
    # Imported lazily so the pure helpers in this module can be imported and
    # unit-tested without the HTTP client installed.
    import requests

    def _log(msg):
        """Log to stderr only if debug."""
        if debug:
            print(msg, file=sys.stderr)

    def _report(msg):
        """Always report request failures on stderr (stdout stays transcript-only)."""
        print(msg, file=sys.stderr)

    detector = RepetitionDetector(
        min_pattern_len=10,  # At least 10 chars for a pattern
        min_repeats=10,      # Must repeat 10+ times
        window_size=400,     # Check last 400 chars (can detect 10-40 char patterns repeated 10 times)
    )

    accumulated_text = ""  # Prefix (ending at a segment boundary) re-sent on retry
    full_text = ""         # accumulated_text + text received in the current attempt
    shown_len = 0          # How much has been printed to the user (at segment boundaries)
    retry_count = 0

    while retry_count <= max_retries:
        messages = list(base_messages)  # Copy base messages

        # If we have accumulated text from previous attempt, add it as partial
        # assistant response so generation continues from there.
        # NOTE(review): stock vLLM only continues a trailing assistant message
        # when "continue_final_message": true and "add_generation_prompt":
        # false are sent - confirm whether the VibeVoice plugin needs them;
        # the artifact stripping below suggests the prefix is currently
        # treated as a finished turn.
        if accumulated_text:
            messages.append({
                "role": "assistant",
                "content": accumulated_text,
            })

        if retry_count > 0:
            # Recovery: increase temperature each retry to break loops.
            # Rounded so the request/log carry 0.3, not 0.30000000000000004.
            temperature = round(0.1 + 0.1 * retry_count, 2)  # 0.2, 0.3, 0.4 for retry 1, 2, 3
            top_p = 0.95
            if accumulated_text:
                _log(f"[RECOVERY #{retry_count}] Continuing from {len(accumulated_text)} chars with temp={temperature}, top_p=0.95")
            else:
                _log(f"[RECOVERY #{retry_count}] Restarting from scratch with temp={temperature}, top_p=0.95")
        else:
            # First attempt: greedy decoding
            temperature = 0.0
            top_p = 1.0

        payload = {
            "model": "vibevoice",
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "stream": True,
        }

        full_text = accumulated_text
        new_text = ""
        received = ""   # Raw text received so far, used to de-duplicate cumulative chunks
        loop_at = None  # Set to the detector's good_end when a loop is found

        try:
            # The context manager releases the streaming connection even when
            # we abandon the stream early to retry.
            with requests.post(url, json=payload, stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    _report(f"[ERROR] {response.status_code} - {response.text[:500]}")
                    return accumulated_text

                for raw_line in response.iter_lines():
                    if not raw_line:
                        continue

                    line = raw_line.decode("utf-8")
                    if not line.startswith("data: "):
                        continue

                    json_str = line[6:]
                    if json_str.strip() == "[DONE]":
                        break  # Successfully finished without loops

                    content = _extract_delta_content(json_str)
                    if not content:
                        continue

                    # vLLM/OpenAI-compatible streams may emit either
                    # incremental deltas OR the full accumulated text.
                    # Only track the newly-added part.
                    # NOTE(review): heuristic - an incremental delta that
                    # happens to start with everything received so far is
                    # mistaken for a cumulative chunk.
                    if content.startswith(received):
                        to_add = content[len(received):]
                    else:
                        to_add = content
                    if not to_add:
                        continue

                    received += to_add
                    new_text += to_add

                    if accumulated_text:
                        new_text = _strip_continuation_artifacts(new_text, _log)
                    new_text = _restore_missing_start_key(new_text, _log)

                    # Check for repetition in the combined text
                    full_text = accumulated_text + new_text
                    is_looping, good_end = detector.check(full_text)
                    if is_looping:
                        loop_at = good_end
                        break

                    # No loop detected - stream content to user.
                    # Only print up to (full_text_len - window_size) at segment
                    # boundaries, so the user never sees content that might be
                    # rolled back.
                    safe_end = max(0, len(full_text) - detector.window_size)
                    safe_boundary = _find_safe_print_boundary(full_text, safe_end)
                    if safe_boundary > shown_len:
                        print(full_text[shown_len:safe_boundary], end="", flush=True)
                        shown_len = safe_boundary

        except requests.exceptions.Timeout:
            _report("[TIMEOUT]")
            print()
            # Return what the user has actually been shown, not just the
            # prefix of the previous attempt.
            return full_text[:shown_len]
        except Exception as e:
            # Script boundary: keep the partial transcript rather than crash.
            _report(f"[ERROR: {e}]")
            print()
            return full_text[:shown_len]

        if loop_at is None:
            # Stream completed: print any remaining content not printed yet
            remaining = full_text[shown_len:]
            if remaining:
                print(remaining, end="", flush=True)
            print()  # Final newline
            return full_text

        _log(f"[LOOP DETECTED at char {loop_at}]")

        # Use what user has already seen as prefix for retry;
        # shown_len is always at a segment boundary.
        if shown_len > 0:
            accumulated_text = full_text[:shown_len]
            _log(f"[RETRY from user-visible content at {shown_len}]")
        else:
            # No complete segment shown to user yet - restart from scratch
            accumulated_text = ""
            _log("[NO CONTENT SHOWN TO USER - restart from scratch]")

        detector.reset(accumulated_text)

        if debug:
            print("\n[...recovering...]", end="", flush=True, file=sys.stderr)

        retry_count += 1

    # All retries exhausted
    _log("[MAX RETRIES REACHED]")
    print(_FAILURE_MESSAGE, flush=True)
    return None


def test_transcription_with_recovery():
    """
    Main test function with auto-recovery.

    Reads the audio/video path (and optional output path) from sys.argv,
    streams the transcription to stdout and optionally saves it. Any
    temporary audio file extracted from a video is always removed.
    """
    # Parse arguments
    debug = "--debug" in sys.argv or "-debug" in sys.argv
    args = [a for a in sys.argv[1:] if not a.startswith("-")]

    if not args:
        prog = os.path.basename(sys.argv[0]) if sys.argv and sys.argv[0] else "test_api_auto_recover.py"
        print(f"Usage: {prog} <audio_or_video_path> [output_path] [--debug]", file=sys.stderr)
        return

    audio_path = args[0]
    output_path = args[1] if len(args) > 1 else None

    print(f"Loading audio from: {audio_path}")

    temp_audio_path = None
    try:
        actual_audio_path = audio_path
        try:
            # Handle video files: extract audio first
            if _is_video_file(audio_path):
                print("Detected video file, extracting audio...")
                temp_audio_path = _extract_audio_from_video(audio_path)
                actual_audio_path = temp_audio_path
                print(f"Audio extracted to: {temp_audio_path}")

            duration = _get_duration_seconds_ffprobe(actual_audio_path)
            print(f"Audio duration: {duration:.2f} seconds")

            with open(actual_audio_path, "rb") as f:
                audio_bytes = f.read()

            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            print(f"Audio size: {len(audio_bytes)} bytes")

        except (OSError, subprocess.SubprocessError, ValueError) as e:
            print(f"Error preparing audio: {e}")
            return

        url = "http://localhost:8000/v1/chat/completions"

        show_keys = ["Start time", "End time", "Speaker ID", "Content"]
        prompt_text = (
            f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
            + ", ".join(show_keys)
        )

        mime = _guess_mime_type(actual_audio_path)
        data_url = f"data:{mime};base64,{audio_b64}"

        # Base messages (without assistant continuation)
        base_messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
            },
            {
                "role": "user",
                "content": [
                    {"type": "audio_url", "audio_url": {"url": data_url}},
                    {"type": "text", "text": prompt_text}
                ]
            }
        ]

        print(f"\nSending request to {url} (Streaming Mode)...")
        print(f"Prompt: {prompt_text}")
        print("-" * 60)
        print("Response received. Streaming content:\n")

        t0 = time.time()

        result = stream_with_recovery(
            url=url,
            base_messages=base_messages,
            audio_data_url=data_url,
            prompt_text=prompt_text,
            max_tokens=32768,
            max_retries=3,
            debug=debug,
        )

        print("\n[Finished]")
        print("-" * 60)
        print(f"Total time elapsed: {time.time() - t0:.2f}s")

        if result is None:
            print("Transcription failed")
            return

        print(f"Final output length: {len(result)} chars")

        # Optionally save result
        if output_path:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(result)
            print(f"Result saved to: {output_path}")

    finally:
        # Cleanup temp audio file if created - also on early return / error
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
            print(f"Cleaned up temp file: {temp_audio_path}")


if __name__ == "__main__":
    test_transcription_with_recovery()