549 lines
22 KiB
Python
549 lines
22 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
VibeVoice vLLM API with Auto-Recovery from Repetition Loops.
|
|
|
|
Strategy:
|
|
1. Start with greedy decoding (temperature=0, top_p=1.0)
|
|
2. Stream and detect repetition patterns in real-time
|
|
3. Only output content up to (current_length - window_size) at segment boundaries
|
|
4. When loop detected:
|
|
- Truncate to last complete segment boundary (},)
|
|
- Recovery with temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
|
|
5. Max 3 retries, if all fail output error message
|
|
|
|
User sees: clean streaming transcription output (only complete segments)
|
|
Internal: automatic recovery from repetition loops (silent)
|
|
"""
|
|
import requests
|
|
import json
|
|
import base64
|
|
import time
|
|
import sys
|
|
import os
|
|
import subprocess
|
|
import re
|
|
from collections import Counter
|
|
|
|
|
|
def _guess_mime_type(path: str) -> str:
|
|
ext = os.path.splitext(path)[1].lower()
|
|
if ext == ".wav":
|
|
return "audio/wav"
|
|
if ext in (".mp3",):
|
|
return "audio/mpeg"
|
|
if ext in (".m4a",):
|
|
return "audio/mp4"
|
|
if ext in (".mp4", ".m4v", ".mov", ".webm"):
|
|
return "video/mp4"
|
|
if ext in (".flac",):
|
|
return "audio/flac"
|
|
if ext in (".ogg", ".opus"):
|
|
return "audio/ogg"
|
|
return "application/octet-stream"
|
|
|
|
|
|
def _get_duration_seconds_ffprobe(path: str) -> float:
|
|
cmd = [
|
|
"ffprobe", "-v", "error",
|
|
"-show_entries", "format=duration",
|
|
"-of", "default=noprint_wrappers=1:nokey=1",
|
|
path,
|
|
]
|
|
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
|
|
return float(out)
|
|
|
|
|
|
def _extract_audio_from_video(video_path: str) -> str:
|
|
"""
|
|
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
|
|
Returns the path to the extracted audio file.
|
|
"""
|
|
import tempfile
|
|
# Create temp file with .mp3 extension
|
|
fd, audio_path = tempfile.mkstemp(suffix=".mp3")
|
|
os.close(fd)
|
|
|
|
cmd = [
|
|
"ffmpeg", "-y", "-i", video_path,
|
|
"-vn", # No video
|
|
"-acodec", "libmp3lame",
|
|
"-q:a", "2", # High quality
|
|
audio_path
|
|
]
|
|
subprocess.run(cmd, check=True, capture_output=True)
|
|
return audio_path
|
|
|
|
|
|
def _is_video_file(path: str) -> bool:
|
|
"""Check if the file is a video file that needs audio extraction."""
|
|
ext = os.path.splitext(path)[1].lower()
|
|
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
|
|
|
|
|
|
def _find_last_segment_boundary(text: str) -> int:
|
|
"""
|
|
Find the position after the last complete segment boundary (},).
|
|
Returns -1 if no complete segment found.
|
|
"""
|
|
# Find last "}, " or "}," pattern (segment separator)
|
|
pos = text.rfind("},")
|
|
if pos != -1:
|
|
return pos + 2 # Include the },
|
|
return -1
|
|
|
|
|
|
def _find_safe_print_boundary(text: str, max_pos: int) -> int:
|
|
"""
|
|
Find the last complete segment boundary before max_pos.
|
|
Returns 0 if no complete segment found before max_pos.
|
|
"""
|
|
search_text = text[:max_pos]
|
|
pos = search_text.rfind("},")
|
|
if pos != -1:
|
|
return pos + 2 # Include the },
|
|
return 0
|
|
|
|
|
|
class RepetitionDetector:
    """Detect repetition loops at the tail of an incrementally growing text.

    Two checks are applied to the last ``window_size`` characters:

    1. a character pattern of at least ``min_pattern_len`` characters that is
       repeated back-to-back at least ``min_repeats`` times at the very end;
    2. a phrase of 2 to 5 whitespace-separated words that is repeated
       back-to-back at least ``min_repeats`` times at the very end.
    """

    def __init__(self,
                 min_pattern_len: int = 10,   # Minimum chars for a pattern
                 min_repeats: int = 3,        # Minimum repetitions to trigger
                 window_size: int = 500):     # Window to check for patterns
        self.min_pattern_len = min_pattern_len
        self.min_repeats = min_repeats
        self.window_size = window_size
        # Full text seen so far; callers may also assign it directly.
        self.text = ""

    def add_text(self, new_text: str):
        """Append ``new_text`` and return ``(is_looping, good_text_end_pos)``."""
        self.text += new_text
        return self._check_repetition()

    def _check_repetition(self):
        """Check whether the recent text ends in a repetition loop.

        Returns:
            ``(False, len(self.text))`` when no loop is found, otherwise
            ``(True, good_end)`` where ``good_end`` is an index into
            ``self.text`` marking where the text worth keeping ends. One
            instance of the repeated unit is kept, except for a character
            pattern that ``_is_meaningful`` rejects, which is dropped entirely.
        """
        total_len = len(self.text)
        if total_len < self.min_pattern_len * self.min_repeats:
            return False, total_len

        # Check the recent window only.
        window = self.text[-self.window_size:] if total_len > self.window_size else self.text

        # Method 1: the suffix of length pattern_len repeated back-to-back.
        for pattern_len in range(self.min_pattern_len, len(window) // self.min_repeats + 1):
            pattern = window[-pattern_len:]

            # Walk backwards from the end counting consecutive copies.
            count = 0
            pos = len(window)
            while pos >= pattern_len and window[pos - pattern_len:pos] == pattern:
                count += 1
                pos -= pattern_len

            if count >= self.min_repeats:
                repetition_start = total_len - (count * pattern_len)
                # Keep one instance of the pattern (or none if it's garbage).
                if self._is_meaningful(pattern):
                    return True, repetition_start + pattern_len
                return True, repetition_start

        # Method 2: repeated short phrases (like "you're not, you're not").
        # Tokens carry their real offsets so that the returned position is
        # exact even when words are separated by several spaces, tabs or
        # newlines, or when the text ends with whitespace.
        tokens = list(re.finditer(r"\S+", window))
        words = [token.group() for token in tokens]
        if len(words) >= self.min_repeats * 2:
            window_offset = total_len - len(window)
            for phrase_len in range(2, 6):  # 2-5 word phrases
                if len(words) < phrase_len * self.min_repeats:
                    continue

                phrase = words[-phrase_len:]
                count = 0
                idx = len(words)
                while idx >= phrase_len and words[idx - phrase_len:idx] == phrase:
                    count += 1
                    idx -= phrase_len

                if count >= self.min_repeats:
                    # idx is now the first word of the earliest repeated
                    # instance; keep the text up to the end of that instance.
                    last_kept_token = tokens[idx + phrase_len - 1]
                    return True, window_offset + last_kept_token.end()

        return False, total_len

    def _is_meaningful(self, pattern: str) -> bool:
        """Check if pattern is meaningful content (not just garbage).

        A pattern is rejected when it is blank after stripping or when it is
        made of fewer than three distinct characters.
        """
        clean = pattern.strip()
        if not clean:
            return False
        if len(set(clean)) < 3:  # Too few unique characters
            return False
        return True

    def get_good_text(self, end_pos: int) -> str:
        """Get text up to the specified position."""
        return self.text[:end_pos]

    def reset(self, keep_text: str = ""):
        """Reset detector, optionally keeping some text."""
        self.text = keep_text
|
|
|
|
|
|
def stream_with_recovery(
    url: str,
    base_messages: list,
    audio_data_url: str,
    prompt_text: str,
    max_tokens: int = 32768,
    max_retries: int = 3,
    timeout: int = 12000,
    debug: bool = False,
):
    """
    Stream transcription with automatic recovery from repetition loops.

    Content is written to stdout only up to a segment boundary (``},``) that
    lies at least one detector window behind the newest text, so text that is
    later rolled back is never shown to the user.

    Args:
        url: API endpoint
        base_messages: Base messages (system + user with audio)
        audio_data_url: The audio data URL for the request (not read by this
            function; the audio is expected to be inside base_messages)
        prompt_text: The text prompt (not read by this function; expected to
            be inside base_messages)
        max_tokens: Maximum tokens to generate
        max_retries: Maximum recovery attempts (default 3)
        timeout: Request timeout passed to requests.post
        debug: If True, show recovery debug info to stderr

    Recovery strategy:
        - First attempt: temperature=0.0, top_p=1.0 (greedy)
        - Recovery: temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
        - If complete segments were already shown: resend them as a partial
          assistant message and continue from there
        - If nothing was shown yet: restart from scratch
        - After max_retries failed recoveries an error message is printed

    Returns:
        The final transcription text on success. On a non-200 response, a
        timeout or any other exception, the prefix carried into the failed
        attempt (possibly empty). None when all retries are exhausted.
    """

    def _log(msg):
        """Log to stderr only if debug."""
        if debug:
            print(msg, file=sys.stderr)

    def _emit_final(full_result):
        """Print the part of full_result not yet shown, then a final newline."""
        if len(full_result) > user_safe_printed_len:
            print(full_result[user_safe_printed_len:], end='', flush=True)
        print()  # Final newline
        return full_result

    # Matches a segment that starts with a bare number where the "Start" key
    # should be, e.g. {"2.99,... ; compiled once instead of per chunk.
    malformed_start_re = re.compile(r'^\{"(\d+\.?\d*),')

    detector = RepetitionDetector(
        min_pattern_len=10,   # At least 10 chars for a pattern
        min_repeats=10,       # Must repeat 10+ times
        window_size=400,      # Check last 400 chars (can detect 10-40 char patterns repeated 10 times)
    )

    accumulated_text = ""
    retry_count = 0
    user_safe_printed_len = 0  # How much has been shown to the user (always at a segment boundary)
    is_recovery = False        # Whether we're in recovery mode

    while retry_count <= max_retries:
        messages = list(base_messages)  # Copy base messages

        # Text kept from a previous attempt is sent as a partial assistant
        # response for the model to continue from.
        if accumulated_text:
            messages.append({
                "role": "assistant",
                "content": accumulated_text,
            })

        # First attempt: greedy decoding.
        payload = {
            "model": "vibevoice",
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": 0.0,
            "top_p": 1.0,
            "stream": True,
        }
        if is_recovery:
            # Recovery: raise the temperature on each retry to break loops.
            # Rounded so the request carries exactly 0.2 / 0.3 / 0.4 instead
            # of binary floating-point noise such as 0.30000000000000004.
            recovery_temp = round(0.1 + 0.1 * retry_count, 2)
            payload["temperature"] = recovery_temp
            payload["top_p"] = 0.95
            if accumulated_text:
                _log(f"[RECOVERY #{retry_count}] Continuing from {len(accumulated_text)} chars with temp={recovery_temp}, top_p=0.95")
            else:
                _log(f"[RECOVERY #{retry_count}] Restarting from scratch with temp={recovery_temp}, top_p=0.95")

        try:
            # The context manager closes the streaming connection on every
            # exit path, including the early break taken when a loop is
            # detected; otherwise the abandoned stream would stay open while
            # the next attempt is started.
            with requests.post(url, json=payload, stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    _log(f"[ERROR] {response.status_code} - {response.text[:500]}")
                    return accumulated_text

                new_text = ""
                printed = ""  # Everything received in this attempt, to handle cumulative chunks

                for line in response.iter_lines():
                    if not line:
                        continue

                    decoded_line = line.decode('utf-8')
                    if not decoded_line.startswith("data: "):
                        continue

                    json_str = decoded_line[6:]
                    if json_str.strip() == "[DONE]":
                        # Successfully finished without loops.
                        return _emit_final(accumulated_text + new_text)

                    try:
                        data = json.loads(json_str)
                    except json.JSONDecodeError:
                        continue

                    # A chunk without any choice carries no text; skip it
                    # rather than letting a KeyError/IndexError abort the
                    # whole transcription.
                    choices = data.get('choices') if isinstance(data, dict) else None
                    if not choices:
                        _log("[SKIPPED chunk without choices]")
                        continue

                    delta = choices[0].get('delta') or {}
                    content = delta.get('content') or ''
                    if not content:
                        continue

                    # Streams may emit either incremental deltas OR the full
                    # accumulated text; only track the newly-added part.
                    # NOTE(review): a genuine delta that happens to start with
                    # everything received so far is indistinguishable from a
                    # cumulative chunk and loses that prefix.
                    if content.startswith(printed):
                        to_add = content[len(printed):]
                    else:
                        to_add = content
                    if not to_add:
                        continue

                    printed += to_add
                    new_text += to_add

                    # When continuing from a prefix the model may open a new
                    # array ("[" or "[{") or repeat the "}," / "}" that ends
                    # the prefix; drop those to keep a valid JSON array.
                    if accumulated_text and new_text:
                        stripped = new_text.lstrip()
                        if stripped.startswith("["):
                            new_text = stripped[1:]
                            _log("[STRIPPED leading '[' from continuation]")
                        elif stripped.startswith("},"):
                            new_text = stripped[2:]
                            _log("[STRIPPED leading '},' from continuation]")
                        elif stripped.startswith("}") and not stripped.startswith("}]"):
                            new_text = stripped[1:]
                            _log("[STRIPPED leading '}' from continuation]")

                    # Fix malformed JSON: {"2.99,... -> {"Start":2.99,...
                    # This happens when the model skips the "Start": key.
                    malformed = malformed_start_re.match(new_text)
                    if malformed:
                        new_text = '{"Start":' + malformed.group(1) + ',' + new_text[malformed.end():]
                        _log("[FIXED malformed JSON: added Start key]")

                    # Check for repetition in the combined text.
                    full_text = accumulated_text + new_text
                    detector.text = full_text
                    is_looping, good_end = detector._check_repetition()

                    if not is_looping:
                        # Only print up to (length - window_size), cut at a
                        # segment boundary, so the user never sees content
                        # that might be rolled back.
                        safe_end = max(0, len(full_text) - detector.window_size)
                        safe_boundary = _find_safe_print_boundary(full_text, safe_end)
                        if safe_boundary > user_safe_printed_len:
                            print(full_text[user_safe_printed_len:safe_boundary], end='', flush=True)
                            user_safe_printed_len = safe_boundary
                        continue

                    _log(f"[LOOP DETECTED at char {good_end}]")

                    # Retry from what the user has already seen;
                    # user_safe_printed_len is always at a segment boundary.
                    if user_safe_printed_len > 0:
                        accumulated_text = full_text[:user_safe_printed_len]
                        _log(f"[RETRY from user-visible content at {user_safe_printed_len}]")
                    else:
                        # No complete segment shown yet - restart from scratch.
                        accumulated_text = ""
                        _log("[NO CONTENT SHOWN TO USER - restart from scratch]")

                    detector.reset(accumulated_text)
                    is_recovery = True

                    if debug:
                        print("\n[...recovering...]", end='', flush=True, file=sys.stderr)

                    retry_count += 1

                    if retry_count > max_retries:
                        _log("[MAX RETRIES REACHED]")
                        print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
                        return None

                    # Leave the stream and start the next attempt.
                    break
                else:
                    # Stream ended without a break (no repetition detected).
                    return _emit_final(accumulated_text + new_text)

        except requests.exceptions.Timeout:
            _log("[TIMEOUT]")
            print()
            return accumulated_text
        except Exception as e:
            _log(f"[ERROR: {e}]")
            print()
            return accumulated_text

    # All retries exhausted
    print("\n[Error] Transcription failed due to model output anomaly. Please try another audio or contact support.", flush=True)
    return None
|
|
|
|
|
|
def test_transcription_with_recovery():
    """Command-line driver: transcribe one media file with auto-recovery.

    Command line: ``<audio_or_video_path> [output_path] [--debug]``.

    Video inputs are first converted to a temporary mp3 file. The audio is
    base64-encoded into a data URL and sent to the local chat-completions
    endpoint through ``stream_with_recovery``. When ``output_path`` is given,
    the final transcription is written to it as UTF-8.

    Returns:
        None. Progress and errors are reported on the console.
    """
    # Parse arguments: anything starting with "-" is treated as a flag.
    debug = "--debug" in sys.argv or "-debug" in sys.argv
    args = [a for a in sys.argv[1:] if not a.startswith("-")]

    if not args:
        # Without this guard a missing path surfaces as a bare IndexError.
        program = os.path.basename(sys.argv[0]) if sys.argv else "script"
        print(f"Usage: {program} <audio_or_video_path> [output_path] [--debug]", file=sys.stderr)
        return

    audio_path = args[0]
    output_path = args[1] if len(args) > 1 else None

    print(f"Loading audio from: {audio_path}")

    temp_audio_path = None
    try:
        actual_audio_path = audio_path
        try:
            # Handle video files: extract audio first.
            if _is_video_file(audio_path):
                print("Detected video file, extracting audio...")
                temp_audio_path = _extract_audio_from_video(audio_path)
                actual_audio_path = temp_audio_path
                print(f"Audio extracted to: {temp_audio_path}")

            duration = _get_duration_seconds_ffprobe(actual_audio_path)
            print(f"Audio duration: {duration:.2f} seconds")

            with open(actual_audio_path, "rb") as f:
                audio_bytes = f.read()

            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            print(f"Audio size: {len(audio_bytes)} bytes")
        except Exception as e:
            print(f"Error preparing audio: {e}")
            return

        url = "http://localhost:8000/v1/chat/completions"

        show_keys = ["Start time", "End time", "Speaker ID", "Content"]
        prompt_text = (
            f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
            + ", ".join(show_keys)
        )

        mime = _guess_mime_type(actual_audio_path)
        data_url = f"data:{mime};base64,{audio_b64}"

        # Base messages (without assistant continuation)
        base_messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
            },
            {
                "role": "user",
                "content": [
                    {"type": "audio_url", "audio_url": {"url": data_url}},
                    {"type": "text", "text": prompt_text}
                ]
            }
        ]

        print(f"\nSending request to {url} (Streaming Mode)...")
        print(f"Prompt: {prompt_text}")
        print("-" * 60)
        print("Response received. Streaming content:\n")

        t0 = time.time()

        result = stream_with_recovery(
            url=url,
            base_messages=base_messages,
            audio_data_url=data_url,
            prompt_text=prompt_text,
            max_tokens=32768,
            max_retries=3,
            debug=debug,
        )

        print("\n[Finished]")
        print("-" * 60)
        print(f"Total time elapsed: {time.time() - t0:.2f}s")

        if result is None:
            print("Transcription failed")
            return

        print(f"Final output length: {len(result)} chars")

        # Optionally save result
        if output_path:
            with open(output_path, "w", encoding="utf-8") as f:
                f.write(result)
            print(f"Result saved to: {output_path}")
    finally:
        # Runs on every exit path, so the extracted audio is not left behind
        # when preparation or transcription fails.
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
            print(f"Cleaned up temp file: {temp_audio_path}")
|
|
|
|
|
|
# Script entry point: run the command-line transcription driver only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    test_transcription_with_recovery()
|