Add test_api_auto_recover.py and test audio files

This commit is contained in:
YingboHAO
2026-02-02 13:49:01 +00:00
parent 1eb04f53a2
commit 0055161273
3 changed files with 548 additions and 0 deletions
Binary file not shown.
+548
View File
@@ -0,0 +1,548 @@
#!/usr/bin/env python3
"""
VibeVoice vLLM API with Auto-Recovery from Repetition Loops.
Strategy:
1. Start with greedy decoding (temperature=0, top_p=1.0)
2. Stream and detect repetition patterns in real-time
3. Only output content up to (current_length - window_size) at segment boundaries
4. When loop detected:
- Truncate to last complete segment boundary (},)
- Recovery with temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
5. Max 3 retries, if all fail output error message
User sees: clean streaming transcription output (only complete segments)
Internal: automatic recovery from repetition loops (silent)
"""
import requests
import json
import base64
import time
import sys
import os
import subprocess
import re
from collections import Counter
def _guess_mime_type(path: str) -> str:
ext = os.path.splitext(path)[1].lower()
if ext == ".wav":
return "audio/wav"
if ext in (".mp3",):
return "audio/mpeg"
if ext in (".m4a",):
return "audio/mp4"
if ext in (".mp4", ".m4v", ".mov", ".webm"):
return "video/mp4"
if ext in (".flac",):
return "audio/flac"
if ext in (".ogg", ".opus"):
return "audio/ogg"
return "application/octet-stream"
def _get_duration_seconds_ffprobe(path: str) -> float:
cmd = [
"ffprobe", "-v", "error",
"-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1",
path,
]
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
return float(out)
def _extract_audio_from_video(video_path: str) -> str:
"""
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
Returns the path to the extracted audio file.
"""
import tempfile
# Create temp file with .mp3 extension
fd, audio_path = tempfile.mkstemp(suffix=".mp3")
os.close(fd)
cmd = [
"ffmpeg", "-y", "-i", video_path,
"-vn", # No video
"-acodec", "libmp3lame",
"-q:a", "2", # High quality
audio_path
]
subprocess.run(cmd, check=True, capture_output=True)
return audio_path
def _is_video_file(path: str) -> bool:
"""Check if the file is a video file that needs audio extraction."""
ext = os.path.splitext(path)[1].lower()
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
def _find_last_segment_boundary(text: str) -> int:
"""
Find the position after the last complete segment boundary (},).
Returns -1 if no complete segment found.
"""
# Find last "}, " or "}," pattern (segment separator)
pos = text.rfind("},")
if pos != -1:
return pos + 2 # Include the },
return -1
def _find_safe_print_boundary(text: str, max_pos: int) -> int:
"""
Find the last complete segment boundary before max_pos.
Returns 0 if no complete segment found before max_pos.
"""
search_text = text[:max_pos]
pos = search_text.rfind("},")
if pos != -1:
return pos + 2 # Include the },
return 0
class RepetitionDetector:
"""Detect repetition patterns in streaming text."""
def __init__(self,
min_pattern_len: int = 10, # Minimum chars for a pattern
min_repeats: int = 3, # Minimum repetitions to trigger
window_size: int = 500): # Window to check for patterns
self.min_pattern_len = min_pattern_len
self.min_repeats = min_repeats
self.window_size = window_size
self.text = ""
def add_text(self, new_text: str):
"""Add new text and return (is_looping, good_text_end_pos)."""
self.text += new_text
return self._check_repetition()
def _check_repetition(self):
"""Check if the recent text contains repetition loops."""
if len(self.text) < self.min_pattern_len * self.min_repeats:
return False, len(self.text)
# Check the recent window
window = self.text[-self.window_size:] if len(self.text) > self.window_size else self.text
# Method 1: Check for repeated substrings
for pattern_len in range(self.min_pattern_len, len(window) // self.min_repeats + 1):
# Get the last pattern_len characters as potential pattern
pattern = window[-pattern_len:]
# Count how many times this pattern appears at the end
count = 0
pos = len(window)
while pos >= pattern_len:
if window[pos - pattern_len:pos] == pattern:
count += 1
pos -= pattern_len
else:
break
if count >= self.min_repeats:
# Found repetition! Calculate where the good text ends
repetition_start = len(self.text) - (count * pattern_len)
# Keep one instance of the pattern (or none if it's garbage)
good_end = repetition_start + pattern_len if self._is_meaningful(pattern) else repetition_start
return True, good_end
# Method 2: Check for repeated short phrases (like "you're not, you're not")
# Look for patterns like "X, X, X" or "X X X"
words = window.split()
if len(words) >= self.min_repeats * 2:
# Check last N words for repetition
for phrase_len in range(2, 6): # 2-5 word phrases
if len(words) < phrase_len * self.min_repeats:
continue
phrase = " ".join(words[-phrase_len:])
count = 0
idx = len(words)
while idx >= phrase_len:
candidate = " ".join(words[idx - phrase_len:idx])
if candidate == phrase:
count += 1
idx -= phrase_len
else:
break
if count >= self.min_repeats:
# Calculate position in original text
repeated_text = (phrase + " ") * count
good_end = len(self.text) - len(repeated_text.rstrip()) + len(phrase)
return True, max(0, good_end)
return False, len(self.text)
def _is_meaningful(self, pattern: str) -> bool:
"""Check if pattern is meaningful content (not just garbage)."""
# Filter out patterns that are just punctuation, spaces, or very repetitive
clean = pattern.strip()
if not clean:
return False
if len(set(clean)) < 3: # Too few unique characters
return False
return True
def get_good_text(self, end_pos: int) -> str:
"""Get text up to the specified position."""
return self.text[:end_pos]
def reset(self, keep_text: str = ""):
"""Reset detector, optionally keeping some text."""
self.text = keep_text
def stream_with_recovery(
url: str,
base_messages: list,
audio_data_url: str,
prompt_text: str,
max_tokens: int = 32768,
max_retries: int = 3,
timeout: int = 12000,
debug: bool = False,
):
"""
Stream transcription with automatic recovery from repetition loops.
Args:
url: API endpoint
base_messages: Base messages (system + user with audio)
audio_data_url: The audio data URL for the request
prompt_text: The text prompt
max_tokens: Maximum tokens to generate
max_retries: Maximum recovery attempts (default 3)
timeout: Request timeout
debug: If True, show recovery debug info to stderr
Recovery strategy:
- First attempt: temperature=0.0, top_p=1.0 (greedy)
- Recovery: temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
- If has complete segments: use assistant prefix
- If no complete segments: restart from scratch
- Max 3 retries, if all fail output error message
Returns:
Final transcription text
"""
import sys as _sys
def _log(msg):
"""Log to stderr only if debug."""
if debug:
print(msg, file=_sys.stderr)
detector = RepetitionDetector(
min_pattern_len=10, # At least 10 chars for a pattern
min_repeats=10, # Must repeat 10+ times
window_size=400, # Check last 400 chars (can detect 10-40 char patterns repeated 10 times)
)
accumulated_text = ""
retry_count = 0
user_safe_printed_len = 0 # Track how much we've safely shown to user (at segment boundaries)
is_recovery = False # Whether we're in recovery mode
while retry_count <= max_retries:
# Build request payload
messages = list(base_messages) # Copy base messages
# If we have accumulated text from previous attempt, add it as partial assistant response
if accumulated_text:
# Add the good content as a partial assistant message
# vLLM will continue from here
messages.append({
"role": "assistant",
"content": accumulated_text
})
# Set sampling parameters based on recovery state
if is_recovery:
# Recovery: increase temperature each retry to break loops
recovery_temp = 0.1 + 0.1 * retry_count # 0.2, 0.3, 0.4 for retry 1, 2, 3
payload = {
"model": "vibevoice",
"messages": messages,
"max_tokens": max_tokens,
"temperature": recovery_temp,
"top_p": 0.95,
"stream": True,
}
if accumulated_text:
_log(f"[RECOVERY #{retry_count}] Continuing from {len(accumulated_text)} chars with temp={recovery_temp}, top_p=0.95")
else:
_log(f"[RECOVERY #{retry_count}] Restarting from scratch with temp={recovery_temp}, top_p=0.95")
else:
# First attempt: greedy decoding
payload = {
"model": "vibevoice",
"messages": messages,
"max_tokens": max_tokens,
"temperature": 0.0,
"top_p": 1.0,
"stream": True,
}
try:
response = requests.post(url, json=payload, stream=True, timeout=timeout)
if response.status_code != 200:
_log(f"[ERROR] {response.status_code} - {response.text[:500]}")
return accumulated_text
new_text = ""
printed = "" # Track what we've already received to handle vLLM duplicates
for line in response.iter_lines():
if not line:
continue
decoded_line = line.decode('utf-8')
if not decoded_line.startswith("data: "):
continue
json_str = decoded_line[6:]
if json_str.strip() == "[DONE]":
# Successfully finished without loops
full_result = accumulated_text + new_text
# Print any remaining content that wasn't printed yet
if len(full_result) > user_safe_printed_len:
remaining = full_result[user_safe_printed_len:]
print(remaining, end='', flush=True)
print() # Final newline
return full_result
try:
data = json.loads(json_str)
delta = data['choices'][0].get('delta', {})
content = delta.get('content', '')
if content:
# vLLM/OpenAI-compatible streams may emit either
# incremental deltas OR the full accumulated text.
# Only track the newly-added part.
if content.startswith(printed):
to_add = content[len(printed):]
else:
to_add = content
if to_add:
printed += to_add
new_text += to_add
# When continuing from prefix, model may add "[" or "[{" at start
# or repeat the ending "}, " from prefix
# We need to handle these to maintain valid JSON array format
if accumulated_text and new_text:
stripped = new_text.lstrip()
# Case 1: Model added "[{" - remove the "["
if stripped.startswith("[{"):
new_text = stripped[1:]
_log("[STRIPPED leading '[' from continuation]")
# Case 2: Model added just "[" - remove it
elif stripped.startswith("["):
new_text = stripped[1:]
_log("[STRIPPED leading '[' from continuation]")
# Case 3: Model repeated "}," from prefix ending
elif stripped.startswith("},"):
new_text = stripped[2:]
_log("[STRIPPED leading '},' from continuation]")
# Case 4: Model repeated "}" from prefix ending
elif stripped.startswith("}") and not stripped.startswith("}]"):
new_text = stripped[1:]
_log("[STRIPPED leading '}' from continuation]")
# Fix malformed JSON: {"2.99,... -> {"Start":2.99,...
# This happens when model skips "Start": key
import re
malformed = re.match(r'^\{"(\d+\.?\d*),', new_text)
if malformed:
time_val = malformed.group(1)
new_text = '{"Start":' + time_val + ',' + new_text[malformed.end():]
_log(f"[FIXED malformed JSON: added Start key]")
# Check for repetition in the combined text
full_text = accumulated_text + new_text
detector.text = full_text
is_looping, good_end = detector._check_repetition()
if is_looping:
_log(f"[LOOP DETECTED at char {good_end}]")
# Use what user has already seen as prefix for retry
# user_safe_printed_len is always at a segment boundary
if user_safe_printed_len > 0:
accumulated_text = full_text[:user_safe_printed_len]
_log(f"[RETRY from user-visible content at {user_safe_printed_len}]")
else:
# No complete segment shown to user yet - restart from scratch
accumulated_text = ""
_log(f"[NO CONTENT SHOWN TO USER - restart from scratch]")
detector.reset(accumulated_text)
is_recovery = True
if debug:
print("\n[...recovering...]", end='', flush=True, file=sys.stderr)
retry_count += 1
if retry_count > max_retries:
_log(f"[MAX RETRIES REACHED]")
print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
return None
# Break inner loop to retry
break
else:
# No loop detected - stream content to user
# Only print up to (full_text_len - window_size) at segment boundaries
# This ensures user never sees content that might be rolled back
safe_end = max(0, len(full_text) - detector.window_size)
safe_boundary = _find_safe_print_boundary(full_text, safe_end)
if safe_boundary > user_safe_printed_len:
# Print new safe content
to_print = full_text[user_safe_printed_len:safe_boundary]
print(to_print, end='', flush=True)
user_safe_printed_len = safe_boundary
except json.JSONDecodeError:
continue
else:
# Loop completed without break (no repetition detected)
full_result = accumulated_text + new_text
# Print any remaining content that wasn't printed yet
if len(full_result) > user_safe_printed_len:
remaining = full_result[user_safe_printed_len:]
print(remaining, end='', flush=True)
print() # Final newline
return full_result
except requests.exceptions.Timeout:
_log("[TIMEOUT]")
print()
return accumulated_text
except Exception as e:
_log(f"[ERROR: {e}]")
print()
return accumulated_text
# All retries exhausted
print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
return None
def test_transcription_with_recovery():
"""Main test function with auto-recovery."""
# Parse arguments
debug = "--debug" in sys.argv or "-debug" in sys.argv
args = [a for a in sys.argv[1:] if not a.startswith("-")]
audio_path = (
args[0]
)
output_path = args[1] if len(args) > 1 else None
print(f"Loading audio from: {audio_path}")
# Handle video files: extract audio first
temp_audio_path = None
actual_audio_path = audio_path
if _is_video_file(audio_path):
print(f"Detected video file, extracting audio...")
temp_audio_path = _extract_audio_from_video(audio_path)
actual_audio_path = temp_audio_path
print(f"Audio extracted to: {temp_audio_path}")
try:
duration = _get_duration_seconds_ffprobe(actual_audio_path)
print(f"Audio duration: {duration:.2f} seconds")
with open(actual_audio_path, "rb") as f:
audio_bytes = f.read()
audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
print(f"Audio size: {len(audio_bytes)} bytes")
except Exception as e:
print(f"Error preparing audio: {e}")
return
url = "http://localhost:8000/v1/chat/completions"
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
prompt_text = (
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ ", ".join(show_keys)
)
mime = _guess_mime_type(actual_audio_path)
data_url = f"data:{mime};base64,{audio_b64}"
# Base messages (without assistant continuation)
base_messages = [
{
"role": "system",
"content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
},
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": data_url}},
{"type": "text", "text": prompt_text}
]
}
]
print(f"\nSending request to {url} (Streaming Mode)...")
print(f"Prompt: {prompt_text}")
print("-" * 60)
print("Response received. Streaming content:\n")
t0 = time.time()
result = stream_with_recovery(
url=url,
base_messages=base_messages,
audio_data_url=data_url,
prompt_text=prompt_text,
max_tokens=32768,
max_retries=3,
debug=debug,
)
print("\n[Finished]")
print("-" * 60)
print(f"Total time elapsed: {time.time() - t0:.2f}s")
if result is None:
print("Transcription failed")
return
print(f"Final output length: {len(result)} chars")
# Optionally save result
if output_path:
with open(output_path, "w", encoding="utf-8") as f:
f.write(result)
print(f"Result saved to: {output_path}")
# Cleanup temp audio file if created
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
print(f"Cleaned up temp file: {temp_audio_path}")
if __name__ == "__main__":
test_transcription_with_recovery()
Binary file not shown.