#!/usr/bin/env python
"""
VibeVoice ASR Gradio Demo

This demo uses the vLLM API server instead of loading the model directly.
Supports concurrent requests (non-blocking) and streaming output.

Usage:
    python gradio_asr_demo_api.py --api_url http://localhost:8000
"""

import os
import sys
import io
import json
import time
import base64
import asyncio
import platform
import tempfile
import argparse
import threading
import subprocess
import traceback
import shutil
import uuid
from datetime import datetime
from pathlib import Path
from typing import AsyncGenerator, Dict, Generator, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed

import httpx
import numpy as np
import soundfile as sf
import gradio as gr

# Try to import pydub for MP3 conversion
try:
    from pydub import AudioSegment
    HAS_PYDUB = True
except ImportError:
    HAS_PYDUB = False
    print("â ī¸ Warning: pydub not available, falling back to WAV format")

# Common audio extensions supported
COMMON_AUDIO_EXTS = {
    '.wav', '.mp3', '.flac', '.ogg', '.opus', '.m4a', '.aac', '.wma', '.aiff', '.aif'
}

# Common video extensions supported
COMMON_VIDEO_EXTS = {
    '.mp4', '.webm', '.mov', '.avi', '.mkv', '.flv', '.wmv', '.m4v',
    '.mpeg', '.mpg', '.3gp', '.ts'
}

# Default max video size in MB
DEFAULT_MAX_VIDEO_SIZE_MB = 50

# Default directory to save uploaded files
DEFAULT_UPLOAD_SAVE_DIR = "local/from_custom"

# Custom temporary directory (override with the VIBEVOICE_TEMP_DIR env var)
CUSTOM_TEMP_DIR = os.environ.get("VIBEVOICE_TEMP_DIR", "/tmp/vibevoice_demo")
os.makedirs(CUSTOM_TEMP_DIR, exist_ok=True)


# ============================================================================
# Cloudflared Tunnel Support
# ============================================================================

CLOUDFLARED_PATH = os.path.expanduser("~/.local/bin/cloudflared")

# platform.machine() value -> suffix of the Linux cloudflared release asset.
# Unrecognized machines fall back to amd64, the only build used previously.
_CLOUDFLARED_LINUX_ARCH = {
    "x86_64": "amd64",
    "amd64": "amd64",
    "aarch64": "arm64",
    "arm64": "arm64",
}


def download_cloudflared():
    """Download the cloudflared binary to CLOUDFLARED_PATH if it is missing.

    The binary is fetched with wget into a temporary ``.part`` file and is only
    moved into place after wget succeeds. A failed or interrupted download
    therefore never leaves a truncated file at CLOUDFLARED_PATH, which later
    calls would otherwise accept as a valid binary.

    Returns:
        True if the binary is available (already present or just downloaded),
        False if the download failed.
    """
    if os.path.exists(CLOUDFLARED_PATH):
        return True

    print("đĨ Downloading cloudflared...")
    os.makedirs(os.path.dirname(CLOUDFLARED_PATH), exist_ok=True)

    arch = _CLOUDFLARED_LINUX_ARCH.get(platform.machine().lower(), "amd64")
    download_url = (
        "https://github.com/cloudflare/cloudflared/releases/latest/download/"
        f"cloudflared-linux-{arch}"
    )
    partial_path = CLOUDFLARED_PATH + ".part"

    try:
        subprocess.run(
            ["wget", "-q", download_url, "-O", partial_path],
            check=True,
            timeout=120,
        )
        os.chmod(partial_path, 0o755)
        # Atomic rename: CLOUDFLARED_PATH only ever holds a complete download.
        os.replace(partial_path, CLOUDFLARED_PATH)
        print("â cloudflared downloaded successfully")
        return True
    except Exception as e:
        print(f"â Failed to download cloudflared: {e}")
        # Best-effort cleanup of the partial download.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return False


def start_cloudflared_tunnel(port: int):
    """Start a cloudflared tunnel that forwards to ``http://localhost:<port>``.

    The process's combined stdout/stderr is echoed to the console, prefixed
    with ``[cloudflared]``, by a daemon thread. The intent is that the tunnel
    URL cloudflared reports becomes visible to the operator.

    Args:
        port: Local port to expose through the tunnel.

    Returns:
        The running ``subprocess.Popen`` handle. Returns None if the binary is
        unavailable, cannot be launched, or exits within the startup grace
        period.
    """
    if not download_cloudflared():
        print("â Cannot start cloudflared tunnel")
        return None

    print(f"đ Starting cloudflared tunnel for port {port}...")
    try:
        process = subprocess.Popen(
            [CLOUDFLARED_PATH, "tunnel", "--url", f"http://localhost:{port}"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except OSError as e:
        print(f"[ERROR] Failed to launch cloudflared: {e}")
        return None

    # Drain output in the background so the pipe never fills up and blocks
    # cloudflared, and so its log (including the tunnel URL) is shown.
    def read_output():
        for line in process.stdout:
            print(f"[cloudflared] {line.strip()}")

    thread = threading.Thread(target=read_output, daemon=True)
    thread.start()

    # Give it a moment to start, then make sure it did not die immediately.
    time.sleep(3)
    if process.poll() is not None:
        print(f"[ERROR] cloudflared exited during startup (exit code {process.returncode})")
        return None
    return process


# ============================================================================
# Audio Utilities
# ============================================================================

# File extension (lower-case, with dot) -> MIME type used in the audio data URL.
_MIME_TYPES = {
    ".wav": "audio/wav",
    ".mp3": "audio/mpeg",
    ".m4a": "audio/mp4",
    ".mp4": "video/mp4",
    ".m4v": "video/mp4",
    ".mov": "video/mp4",
    ".webm": "video/webm",
    ".flac": "audio/flac",
    ".ogg": "audio/ogg",
    ".opus": "audio/ogg",
    ".aac": "audio/aac",
}


def _guess_mime_type(path: str) -> str:
    """Guess the MIME type from the file extension.

    Returns ``application/octet-stream`` for extensions not in ``_MIME_TYPES``.
    """
    ext = os.path.splitext(path)[1].lower()
    return _MIME_TYPES.get(ext, "application/octet-stream")


def _get_duration_seconds_ffprobe(path: str) -> float:
    """Return the media duration in seconds.

    Uses ffprobe's container duration first, then soundfile's header info.
    Returns 0.0 if neither can determine the duration.
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        path,
    ]
    try:
        # stderr is discarded (not merged into stdout) so diagnostic messages
        # cannot corrupt the number being parsed.
        out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL, timeout=60)
        return float(out.decode("utf-8").strip())
    except Exception:
        # Fallback: try soundfile
        try:
            return sf.info(path).duration
        except Exception:
            return 0.0


def load_audio_ffmpeg(path: str, target_sr: int = None) -> Tuple[np.ndarray, int]:
    """Load an audio file as mono float32 samples.

    soundfile is tried first (faster for the formats it supports) and returns
    the file's own sample rate. Multi-channel audio is averaged to mono. Peaks
    above 10.0 are scaled down by the peak value; peaks in (1.0, 10.0] are
    clipped to [-1, 1].

    If soundfile fails (or yields no samples), ffmpeg decodes the file to mono
    float32 PCM at ``target_sr``, which defaults to 16000.

    Note: ``target_sr`` is only applied by the ffmpeg fallback; the soundfile
    path does not resample.

    Args:
        path: Path to the audio/media file.
        target_sr: Output sample rate for the ffmpeg fallback.

    Returns:
        Tuple of (samples, sample_rate).

    Raises:
        RuntimeError: If the ffmpeg fallback fails or produces no samples.
    """
    try:
        # Use soundfile first (faster for supported formats)
        audio_data, sr = sf.read(path, dtype='float32')
        if audio_data.size == 0:
            # Hand empty decodes to the ffmpeg fallback, which reports them.
            raise ValueError("soundfile returned no samples")

        # Debug: log audio info
        print(f"[DEBUG] soundfile loaded: shape={audio_data.shape}, sr={sr}, dtype={audio_data.dtype}")
        print(f"[DEBUG] audio range: min={audio_data.min():.6f}, max={audio_data.max():.6f}")

        # Convert to mono if multi-channel
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
            print(f"[DEBUG] Converted to mono: shape={audio_data.shape}")

        # Ensure data is in [-1, 1] range (soundfile should do this, but verify)
        max_val = max(abs(audio_data.max()), abs(audio_data.min()))
        if max_val > 10.0:
            # Likely int16 or other integer format, normalize it
            audio_data = audio_data / max_val
            print(f"[DEBUG] Normalized audio (int format detected), original max_val={max_val}")
        elif max_val > 1.0:
            # Float format with slight overflow, just clip it
            audio_data = np.clip(audio_data, -1.0, 1.0)
            print(f"[DEBUG] Clipped audio (slight overflow), original max_val={max_val}")

        # Check for silent audio
        if not np.any(audio_data):
            print("[WARNING] Audio appears to be completely silent!")

        return audio_data, sr
    except Exception as e:
        print(f"[DEBUG] soundfile failed: {e}, trying ffmpeg...")

    # Fallback to ffmpeg
    try:
        target_sr = target_sr or 16000
        cmd = [
            "ffmpeg", "-i", path,
            "-f", "f32le",
            "-acodec", "pcm_f32le",
            "-ac", "1",
            "-ar", str(target_sr),
            "-",
        ]
        # stdin is detached so ffmpeg cannot consume the server's terminal input.
        result = subprocess.run(
            cmd,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        audio_data = np.frombuffer(result.stdout, dtype=np.float32)

        # Check for empty output BEFORE calling min()/max(): numpy reductions
        # raise on empty arrays and would otherwise mask this clearer error.
        if audio_data.size == 0:
            ffmpeg_err = result.stderr.decode("utf-8", errors="ignore").strip()[-500:]
            raise RuntimeError(
                f"ffmpeg returned empty audio data (exit code {result.returncode}): {ffmpeg_err}"
            )

        print(f"[DEBUG] ffmpeg loaded: shape={audio_data.shape}, sr={target_sr}")
        print(f"[DEBUG] audio range: min={audio_data.min():.6f}, max={audio_data.max():.6f}")

        # Check for silent audio
        if not np.any(audio_data):
            print("[WARNING] Audio appears to be completely silent!")

        return audio_data, target_sr
    except Exception as e:
        raise RuntimeError(f"Failed to load audio: {e}") from e


def get_file_size_mb(file_path: str) -> float:
    """Return the file size in MB (1024 * 1024 bytes), or 0.0 if it cannot be read."""
    try:
        return os.path.getsize(file_path) / (1024 * 1024)
    except (OSError, TypeError, ValueError):
        return 0.0


def is_video_file(file_path: str) -> bool:
    """Return True if the file's extension is one of COMMON_VIDEO_EXTS."""
    ext = os.path.splitext(file_path)[1].lower()
    return ext in COMMON_VIDEO_EXTS


def format_srt_time(seconds: float) -> str:
    """Convert seconds to an SRT timestamp (``HH:MM:SS,mmm``).

    The value is rounded to the nearest millisecond and then split with integer
    arithmetic. Truncating the float fraction instead turns e.g. 1.001 into
    ``,000``. Negative inputs are clamped to zero.
    """
    total_ms = max(0, int(round(float(seconds) * 1000)))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millis = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def _iter_subtitle_cues(segments: List[Dict]) -> Generator[Tuple[float, float, str], None, None]:
    """Yield ``(start, end, text)`` for every segment usable as a subtitle cue.

    Key lookup mirrors the ASR output variants: ``Start``/``start``/``Start time``,
    ``End``/``end``/``End time``, ``Content``/``content``/``text`` and
    ``Speaker``/``speaker``/``Speaker ID``. The following segments are skipped:
    non-dict items, segments whose start/end is None, and segments whose
    start/end is not numeric. When a speaker is present, the text is prefixed
    with ``[Speaker <id>]``.
    """
    for seg in segments or []:
        if not isinstance(seg, dict):
            continue
        start = seg.get('Start', seg.get('start', seg.get('Start time', 0)))
        end = seg.get('End', seg.get('end', seg.get('End time', 0)))
        content = seg.get('Content', seg.get('content', seg.get('text', '')))
        speaker = seg.get('Speaker', seg.get('speaker', seg.get('Speaker ID', None)))

        if start is None or end is None:
            continue
        try:
            start_s, end_s = float(start), float(end)
        except (TypeError, ValueError):
            continue

        # Always produce a str so the caller's "\n".join never sees None/ints.
        content = '' if content is None else str(content)
        text = f"[Speaker {speaker}] {content}" if speaker is not None else content
        yield start_s, end_s, text


def segments_to_srt(segments: List[Dict]) -> str:
    """
    Convert ASR segments to SRT subtitle format.

    Cues are numbered consecutively from 1 over the segments actually emitted,
    so skipped segments leave no gaps in the numbering.

    Args:
        segments: List of segment dictionaries with Start, End, Content keys

    Returns:
        SRT formatted string
    """
    srt_lines = []
    for cue_no, (start, end, text) in enumerate(_iter_subtitle_cues(segments), 1):
        srt_lines.append(f"{cue_no}")
        srt_lines.append(f"{format_srt_time(start)} --> {format_srt_time(end)}")
        srt_lines.append(text)
        srt_lines.append("")  # Empty line between entries
    return "\n".join(srt_lines)


def segments_to_vtt(segments: List[Dict]) -> str:
    """
    Convert ASR segments to WebVTT subtitle format (for HTML5 video).

    Uses the same cue selection and numbering as segments_to_srt, with the
    ``WEBVTT`` header and ``HH:MM:SS.mmm`` timestamps.

    Args:
        segments: List of segment dictionaries with Start, End, Content keys

    Returns:
        WebVTT formatted string
    """
    vtt_lines = ["WEBVTT", ""]
    for cue_no, (start, end, text) in enumerate(_iter_subtitle_cues(segments), 1):
        # WebVTT uses HH:MM:SS.mmm format (dot instead of comma)
        start_time = format_srt_time(start).replace(',', '.')
        end_time = format_srt_time(end).replace(',', '.')
        vtt_lines.append(f"{cue_no}")
        vtt_lines.append(f"{start_time} --> {end_time}")
        vtt_lines.append(text)
        vtt_lines.append("")  # Empty line between entries
    return "\n".join(vtt_lines)


# Audio formats that need conversion (browsers and some APIs may not
support them directly) AUDIO_FORMATS_NEED_CONVERSION = {'.opus', '.ogg', '.flac', '.aiff', '.aif', '.wma'} # Audio formats that can be used directly AUDIO_FORMATS_DIRECT = {'.wav', '.mp3', '.m4a', '.aac'} def convert_audio_to_mp3( audio_path: str, output_path: Optional[str] = None, sample_rate: int = 16000, bitrate: str = "128k" ) -> Tuple[Optional[str], Optional[str]]: """ Convert audio file to MP3 format using ffmpeg. This is useful for converting formats like opus, ogg, flac that may not be well-supported by browsers or some APIs. Args: audio_path: Path to the input audio file output_path: Optional output path. If None, creates a temp file. sample_rate: Target sample rate bitrate: Audio bitrate (e.g., '128k') Returns: Tuple of (mp3_path, error_message) - If successful: (mp3_path, None) - If failed: (None, error_message) """ try: if output_path is None: temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3", dir=CUSTOM_TEMP_DIR) temp_file.close() output_path = temp_file.name cmd = [ "ffmpeg", "-y", # Overwrite output file "-i", audio_path, "-acodec", "libmp3lame", # MP3 codec "-ar", str(sample_rate), # Sample rate "-ac", "1", # Mono "-b:a", bitrate, # Audio bitrate output_path ] result = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=300 # 5 minutes timeout ) if result.returncode != 0: error_msg = result.stderr.decode('utf-8', errors='ignore') if os.path.exists(output_path): os.unlink(output_path) return None, f"ffmpeg error: {error_msg[:500]}" if not os.path.exists(output_path) or os.path.getsize(output_path) == 0: return None, "Failed to convert audio: output file is empty" return output_path, None except subprocess.TimeoutExpired: return None, "Audio conversion timed out (>5 minutes)" except Exception as e: return None, f"Error converting audio: {str(e)}" def extract_audio_from_video( video_path: str, output_path: Optional[str] = None, sample_rate: int = 16000, output_format: str = "mp3", bitrate: str = "128k" ) -> 
Tuple[Optional[str], Optional[str]]: """ Extract audio from video file using ffmpeg. Args: video_path: Path to the video file output_path: Optional output path for extracted audio. If None, creates a temp file. sample_rate: Target sample rate for extracted audio output_format: Output audio format ('mp3' or 'wav') bitrate: Audio bitrate for mp3 (e.g., '128k') Returns: Tuple of (audio_path, error_message) - If successful: (audio_path, None) - If failed: (None, error_message) """ try: if output_path is None: suffix = f".{output_format}" temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=CUSTOM_TEMP_DIR) temp_file.close() output_path = temp_file.name # Use ffmpeg to extract audio if output_format == "mp3": cmd = [ "ffmpeg", "-y", # Overwrite output file "-i", video_path, "-vn", # No video "-acodec", "libmp3lame", # MP3 codec "-ar", str(sample_rate), # Sample rate "-ac", "1", # Mono "-b:a", bitrate, # Audio bitrate output_path ] else: cmd = [ "ffmpeg", "-y", # Overwrite output file "-i", video_path, "-vn", # No video "-acodec", "pcm_s16le", # PCM 16-bit "-ar", str(sample_rate), # Sample rate "-ac", "1", # Mono output_path ] result = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=300 # 5 minutes timeout ) if result.returncode != 0: error_msg = result.stderr.decode('utf-8', errors='ignore') # Clean up temp file on error if os.path.exists(output_path): os.unlink(output_path) return None, f"ffmpeg error: {error_msg[:500]}" # Verify the output file exists and has content if not os.path.exists(output_path) or os.path.getsize(output_path) == 0: return None, "Failed to extract audio: output file is empty" return output_path, None except subprocess.TimeoutExpired: return None, "Audio extraction timed out (>5 minutes)" except Exception as e: return None, f"Error extracting audio: {str(e)}" def convert_video_to_mp4( video_path: str, output_path: Optional[str] = None, height: int = 480, crf: int = 28, fps: int = 30 ) -> 
Tuple[Optional[str], Optional[str]]: """ Convert video to MP4 format with compression (480p by default). Args: video_path: Path to the input video file (e.g., WebM) output_path: Optional output path. If None, creates a temp file. height: Target video height (width auto-scaled to maintain aspect ratio) crf: Constant Rate Factor for compression (18-28 recommended, higher = smaller file) fps: Target frame rate (default 30fps) Returns: Tuple of (mp4_path, error_message) - If successful: (mp4_path, None) - If failed: (None, error_message) """ try: if output_path is None: temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", dir=CUSTOM_TEMP_DIR) temp_file.close() output_path = temp_file.name # Use ffmpeg to convert to MP4 with H.264 codec # Scale to target height while maintaining aspect ratio (-2 ensures even dimensions) cmd = [ "ffmpeg", "-y", # Overwrite output file "-i", video_path, "-vf", f"scale=-2:{height},fps={fps}", # Scale to 480p height + set fps "-c:v", "libx264", # H.264 video codec "-preset", "fast", # Encoding speed/compression tradeoff "-crf", str(crf), # Quality (lower = better, 18-28 typical) "-c:a", "aac", # AAC audio codec "-b:a", "128k", # Audio bitrate "-movflags", "+faststart", # Enable streaming output_path ] result = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=600 # 10 minutes timeout ) if result.returncode != 0: error_msg = result.stderr.decode('utf-8', errors='ignore') # Clean up temp file on error if os.path.exists(output_path): os.unlink(output_path) return None, f"ffmpeg error: {error_msg[:500]}" # Verify the output file exists and has content if not os.path.exists(output_path) or os.path.getsize(output_path) == 0: return None, "Failed to convert video: output file is empty" return output_path, None except subprocess.TimeoutExpired: return None, "Video conversion timed out (>10 minutes)" except Exception as e: return None, f"Error converting video: {str(e)}" def parse_time_to_seconds(val: 
Optional[str]) -> Optional[float]: """Parse seconds or hh:mm:ss to float seconds.""" if val is None: return None val = val.strip() if not val: return None try: return float(val) except ValueError: pass if ":" in val: parts = val.split(":") if not all(p.strip().replace(".", "", 1).isdigit() for p in parts): return None parts = [float(p) for p in parts] if len(parts) == 3: h, m, s = parts elif len(parts) == 2: h = 0 m, s = parts else: return None return h * 3600 + m * 60 + s return None def slice_audio_to_temp( audio_path: str, start_sec: Optional[float], end_sec: Optional[float] ) -> Tuple[Optional[str], Optional[str]]: """Slice audio to [start_sec, end_sec) and write to a temp WAV file.""" try: audio_data, sample_rate = load_audio_ffmpeg(audio_path) n_samples = len(audio_data) full_duration = n_samples / float(sample_rate) start = 0.0 if start_sec is None else max(0.0, start_sec) end = full_duration if end_sec is None else min(full_duration, end_sec) if end <= start: return None, f"Invalid time range: start={start}, end={end}" start_idx = int(start * sample_rate) end_idx = int(end * sample_rate) segment = audio_data[start_idx:end_idx] temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir=CUSTOM_TEMP_DIR) temp_file.close() # Use 32767.0 instead of 32768.0 to avoid potential overflow segment_int16 = (segment * 32767.0).astype(np.int16) sf.write(temp_file.name, segment_int16, sample_rate, subtype='PCM_16') return temp_file.name, None except Exception as e: return None, f"Error slicing audio: {e}" def clip_and_encode_audio( audio_data: np.ndarray, sr: int, start_time: float, end_time: float, segment_idx: int, use_mp3: bool = True, target_sr: int = 16000, mp3_bitrate: str = "32k" ) -> Tuple[int, Optional[str], Optional[str]]: """Clip audio segment and encode to base64.""" try: start_sample = int(start_time * sr) end_sample = int(end_time * sr) start_sample = max(0, start_sample) end_sample = min(len(audio_data), end_sample) if start_sample >= 
end_sample: return segment_idx, None, f"Invalid segment range: {start_time}-{end_time}" segment_data = audio_data[start_sample:end_sample] # Resample if needed if sr != target_sr and target_sr < sr: import scipy.signal num_samples = int(len(segment_data) * target_sr / sr) segment_data = scipy.signal.resample(segment_data, num_samples) sr = target_sr # Use 32767.0 instead of 32768.0 to avoid potential overflow segment_data_int16 = (segment_data * 32767.0).astype(np.int16) # Convert to MP3 if pydub is available if use_mp3 and HAS_PYDUB: try: wav_buffer = io.BytesIO() sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16') wav_buffer.seek(0) audio_segment = AudioSegment.from_wav(wav_buffer) mp3_buffer = io.BytesIO() audio_segment.export(mp3_buffer, format='mp3', bitrate=mp3_bitrate) mp3_buffer.seek(0) audio_bytes = mp3_buffer.read() audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') return segment_idx, f"data:audio/mpeg;base64,{audio_base64}", None except Exception: pass # Fallback to WAV wav_buffer = io.BytesIO() sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16') wav_buffer.seek(0) audio_bytes = wav_buffer.read() audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') return segment_idx, f"data:audio/wav;base64,{audio_base64}", None except Exception as e: return segment_idx, None, f"Error: {str(e)}" def extract_audio_segments(audio_path: str, segments: List[Dict]) -> List[Tuple[str, str, Optional[str]]]: """Extract multiple segments from audio file with parallel processing.""" try: audio_data, sr = load_audio_ffmpeg(audio_path) tasks = [] use_mp3 = HAS_PYDUB for i, seg in enumerate(segments): start_time = seg.get('Start', seg.get('start', seg.get('Start time', 0))) end_time = seg.get('End', seg.get('end', seg.get('End time', 0))) if start_time is not None and end_time is not None: tasks.append((audio_data, sr, float(start_time), float(end_time), i, use_mp3)) results = [] max_workers = os.cpu_count() or 4 
with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = { executor.submit(clip_and_encode_audio, *task): task[4] for task in tasks } for future in as_completed(futures): results.append(future.result()) results.sort(key=lambda x: x[0]) audio_segments = [] for i, (idx, audio_src, error_msg) in enumerate(results): if idx < len(segments): seg = segments[idx] label = f"Segment {idx + 1}" audio_segments.append((label, audio_src, error_msg)) return audio_segments except Exception as e: print(f"Error loading audio file: {e}") return [] # ============================================================================ # API Client # ============================================================================ class VibeVoiceAPIClient: """Client for VibeVoice vLLM API.""" def __init__(self, api_url: str = "http://localhost:8000", model_name: str = None): self.api_url = api_url.rstrip("/") self._model_name = model_name # User-specified model name (can be None for auto-detect) self._available_models: List[str] = [] # Cached available models self.endpoint = f"{self.api_url}/v1/chat/completions" @property def model_name(self) -> str: """Get the model name (auto-detected if not specified).""" if self._model_name: return self._model_name if self._available_models: return self._available_models[0] return "vibevoice" # Fallback default @model_name.setter def model_name(self, value: str): """Set the model name.""" self._model_name = value def get_available_models_sync(self) -> List[str]: """Fetch available models from vLLM API (synchronous).""" try: response = httpx.get(f"{self.api_url}/v1/models", timeout=10) if response.status_code == 200: data = response.json() models = [m.get('id') for m in data.get('data', []) if m.get('id')] self._available_models = models print(f"đ Available models: {models}") return models return [] except Exception as e: print(f"â ī¸ Failed to fetch models: {e}") return [] async def get_available_models(self) -> List[str]: """Fetch available 
models from vLLM API (async).""" try: async with httpx.AsyncClient() as client: response = await client.get(f"{self.api_url}/v1/models", timeout=10) if response.status_code == 200: data = response.json() models = [m.get('id') for m in data.get('data', []) if m.get('id')] self._available_models = models return models return [] except Exception as e: print(f"â ī¸ Failed to fetch models: {e}") return [] async def check_health(self) -> Tuple[bool, str]: """Check if the API server is healthy.""" try: async with httpx.AsyncClient() as client: response = await client.get(f"{self.api_url}/health", timeout=5) if response.status_code == 200: return True, "API server is healthy" return False, f"API returned status {response.status_code}" except httpx.ConnectError: return False, "Cannot connect to API server" except Exception as e: return False, f"Health check failed: {e}" def check_health_sync(self) -> Tuple[bool, str]: """Synchronous health check for startup.""" try: response = httpx.get(f"{self.api_url}/health", timeout=5) if response.status_code == 200: return True, "API server is healthy" return False, f"API returned status {response.status_code}" except httpx.ConnectError: return False, "Cannot connect to API server" except Exception as e: return False, f"Health check failed: {e}" async def transcribe_streaming( self, audio_path: str, max_tokens: int = 4096, temperature: float = 0.0, top_p: float = 1.0, context_info: str = None, timeout: int = 1200, ) -> AsyncGenerator[Tuple[str, Optional[Dict]], None]: """ Transcribe audio using streaming API (async version). 
Yields: Tuple of (accumulated_text, final_result_or_none) """ # Get audio duration duration = _get_duration_seconds_ffprobe(audio_path) # Read and encode audio with open(audio_path, "rb") as f: audio_bytes = f.read() audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") # Build prompt show_keys = ["Start", "End", "Speaker", "Content"] prompt_text = ( f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " + ", ".join(show_keys) ) # Add context info if provided if context_info and context_info.strip(): prompt_text += f"\n\nContext information (hotwords, speaker names, etc.):\n{context_info.strip()}" # Build request payload mime = _guess_mime_type(audio_path) data_url = f"data:{mime};base64,{audio_b64}" payload = { "model": self.model_name, "messages": [ { "role": "system", "content": "You are a helpful assistant that transcribes audio input into text output in JSON format." }, { "role": "user", "content": [ {"type": "audio_url", "audio_url": {"url": data_url}}, {"type": "text", "text": prompt_text} ] } ], "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, "stream": True, "stream_options": {"include_usage": True}, # Enable token statistics } # Send request with streaming using async httpx try: async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client: async with client.stream( "POST", self.endpoint, json=payload, ) as response: if response.status_code != 200: error_text = await response.aread() error_msg = f"API error: {response.status_code} - {error_text.decode()}" yield error_msg, {"error": error_msg} return accumulated_text = "" usage_info = None stopped = False # Declare global at function level for proper access global stop_generation_flag # Process streaming lines with periodic stop flag checking async for line in response.aiter_lines(): # Check stop flag after each line if stop_generation_flag: stopped = True print("[INFO] Stop flag detected, breaking out of stream...") break if line: if 
line.startswith("data: "): json_str = line[6:] if json_str.strip() == "[DONE]": break try: data = json.loads(json_str) # Check for usage info (sent in final chunk) if 'usage' in data and data['usage']: usage_info = data['usage'] if 'choices' in data and data['choices']: delta = data['choices'][0].get('delta', {}) content = delta.get('content', '') if content: # Handle incremental or full text if content.startswith(accumulated_text): accumulated_text = content else: accumulated_text += content yield accumulated_text, None except json.JSONDecodeError: pass # If stopped, try to close the response to stop receiving more data if stopped: try: await response.aclose() print("[INFO] Response closed after stop") except Exception: pass # Parse final result with partial parsing support segments, parse_warning = self._parse_segments(accumulated_text) if segments is None: segments = [] final_result = { "raw_text": accumulated_text, "segments": segments, "duration": duration, "usage": usage_info, # Include token statistics "stopped": stopped, # Whether generation was stopped by user "parse_warning": parse_warning, # Warning if partial parse } yield accumulated_text, final_result except httpx.TimeoutException: yield "Request timed out", {"error": "timeout"} except Exception as e: yield f"Error: {str(e)}", {"error": str(e)} def _parse_segments(self, raw_text: str) -> Tuple[Optional[List[Dict]], Optional[str]]: """ Parse segments from raw API response. Handles truncated responses by extracting complete segments. 
Returns: Tuple of (segments_list, warning_message) - If fully successful: (segments, None) - If partially successful: (segments, warning_message) - If failed: (None, error_message) """ if not raw_text: return None, "Empty response" # Try to find JSON array in the response text = raw_text.strip() # Try direct parse first try: result = json.loads(text) if isinstance(result, list): return result, None elif isinstance(result, dict) and "segments" in result: return result["segments"], None elif isinstance(result, dict): # Single segment return [result], None except json.JSONDecodeError: pass # Try to extract JSON array from text import re # Try to find array pattern array_match = re.search(r'\[[\s\S]*\]', text) if array_match: try: result = json.loads(array_match.group()) if isinstance(result, list): return result, None except json.JSONDecodeError: pass # Try to find object with segments obj_match = re.search(r'\{[\s\S]*"segments"[\s\S]*\}', text) if obj_match: try: result = json.loads(obj_match.group()) if "segments" in result: return result["segments"], None except json.JSONDecodeError: pass # ===== Handle truncated response ===== # Try to parse individual complete segments from truncated array segments = self._parse_truncated_segments(text) if segments: return segments, f"â ī¸ Partial parse: {len(segments)} segments recovered from truncated response" return None, "Cannot parse JSON from response" def _parse_truncated_segments(self, text: str) -> Optional[List[Dict]]: """ Parse complete segments from a truncated JSON array response. This handles cases where the response is cut off mid-segment. Strategy: 1. Find all complete JSON objects {...} that look like segments 2. Validate each has expected keys (Start, End, Content or Speaker) 3. 
Try to recover incomplete last segment (e.g., repetition truncation) """ # Check if text starts with array text = text.strip() if not text.startswith('['): # Try to find array start array_start = text.find('[') if array_start == -1: return None text = text[array_start:] # Find all complete JSON objects # Pattern: {...} that are properly closed segments = [] depth = 0 obj_start = -1 in_string = False escape_next = False for i, char in enumerate(text): if escape_next: escape_next = False continue if char == '\\': escape_next = True continue if char == '"' and not escape_next: in_string = not in_string continue if in_string: continue if char == '{': if depth == 0: obj_start = i depth += 1 elif char == '}': depth -= 1 if depth == 0 and obj_start != -1: # Found a complete object obj_str = text[obj_start:i+1] try: obj = json.loads(obj_str) # Validate it looks like a segment if self._is_valid_segment(obj): segments.append(obj) except json.JSONDecodeError: pass obj_start = -1 # Try to recover incomplete last segment (truncated due to repetition) if obj_start != -1: incomplete_text = text[obj_start:] recovered = self._recover_incomplete_segment(incomplete_text) if recovered and self._is_valid_segment(recovered): segments.append(recovered) return segments if segments else None def _recover_incomplete_segment(self, incomplete_text: str) -> Optional[Dict]: """ Try to recover an incomplete segment that was truncated. Handles cases like repetition where Content is cut off mid-string. Example input: {"Start":198.36,"End":206.86,"Speaker":0,"Content":"I'm not gonna do it, I'm not gonna do it, I'm not... 
""" import re # Try to extract available fields segment = {} # Extract Start start_match = re.search(r'"Start"\s*:\s*([0-9.]+)', incomplete_text) if start_match: segment['Start'] = float(start_match.group(1)) # Extract End end_match = re.search(r'"End"\s*:\s*([0-9.]+)', incomplete_text) if end_match: segment['End'] = float(end_match.group(1)) # Extract Speaker speaker_match = re.search(r'"Speaker"\s*:\s*([0-9]+)', incomplete_text) if speaker_match: segment['Speaker'] = int(speaker_match.group(1)) # Extract Content - handle truncated string content_match = re.search(r'"Content"\s*:\s*"', incomplete_text) if content_match: content_start = content_match.end() # Find the content, may be truncated content_text = incomplete_text[content_start:] # Check for repetition pattern and clean it cleaned_content = self._clean_repetition(content_text) if cleaned_content: segment['Content'] = cleaned_content segment['_truncated'] = True # Mark as recovered from truncation # Must have at least Start, End to be valid if 'Start' in segment and 'End' in segment: return segment return None def _clean_repetition(self, content: str) -> Optional[str]: """ Clean content from truncated string. For repetition cases, keep first 500 characters. """ # Remove trailing incomplete quote if any content = content.rstrip('"') if not content: return None # Keep first 500 characters for repetition cases if len(content) > 500: return content[:500] + "..." return content def _is_valid_segment(self, obj: Dict) -> bool: """ Check if a dict looks like a valid ASR segment. Must have Start, End, and either Content or Speaker. 
""" if not isinstance(obj, dict): return False # Check for time boundaries (various possible key names) has_start = any(k in obj for k in ['Start', 'start', 'Start time']) has_end = any(k in obj for k in ['End', 'end', 'End time']) if not (has_start and has_end): return False # Should have content or speaker has_content = any(k in obj for k in ['Content', 'content', 'text']) has_speaker = any(k in obj for k in ['Speaker', 'speaker', 'Speaker ID']) return has_content or has_speaker # ============================================================================ # Global State # ============================================================================ api_client: Optional[VibeVoiceAPIClient] = None # Global stop flag for generation stop_generation_flag = False # Event to signal stop for async operations stop_event: Optional[asyncio.Event] = None # ============================================================================ # Gradio Interface Functions # ============================================================================ async def transcribe_audio( media_input, max_new_tokens: int, temperature: float, top_p: float, do_sample: bool, context_info: str = "", max_video_size_mb: float = DEFAULT_MAX_VIDEO_SIZE_MB ) -> AsyncGenerator[Tuple[str, str, Optional[str], Optional[str], Optional[str]], None]: """ Transcribe audio/video using API and return results (async streaming version). 
Args: media_input: Audio/Video file path or tuple (sample_rate, audio_data) for microphone max_new_tokens: Maximum tokens to generate temperature: Temperature for sampling top_p: Top-p for nucleus sampling do_sample: Whether to use sampling (affects temperature) context_info: Optional context information max_video_size_mb: Maximum video file size in MB Yields: Tuple of (raw_text, audio_segments_html, srt_content, video_path, vtt_content) """ global api_client, stop_generation_flag # Reset stop flag at the start of each transcription stop_generation_flag = False print("[INFO] Stop flag reset at transcribe_audio start") if api_client is None: yield "â API client not initialized. Please check API URL.", "", None, None, None return # Check API health (async) healthy, msg = await api_client.check_health() if not healthy: yield f"â API server not available: {msg}", "", None, None, None return if media_input is None: yield "â Please provide an audio or video input.", "", None, None, None return try: print("[INFO] Transcription requested via API") # Determine audio path and track if input is video audio_path = None original_video_path = None # Track original video for playback with subtitles temp_file_to_cleanup = None extracted_audio_to_cleanup = None # Track extracted audio from video is_video_input = False # Handle media input if isinstance(media_input, str): # Check if uploaded file is a video if is_video_file(media_input): is_video_input = True original_video_path = media_input # Keep video path for subtitle playback video_size = get_file_size_mb(media_input) print(f"[INFO] Uploaded video file size: {video_size:.2f} MB (limit: {max_video_size_mb} MB)") if video_size > max_video_size_mb: yield f"â Video file too large: {video_size:.2f} MB. 
Maximum allowed: {max_video_size_mb} MB", "", None, None, None return yield f"đŦ Extracting audio from video ({video_size:.2f} MB)...", "", None, None, None extracted_path, extract_error = extract_audio_from_video(media_input) if extract_error: yield f"â Failed to extract audio from video: {extract_error}", "", None, None, None return audio_path = extracted_path extracted_audio_to_cleanup = extracted_path print(f"[INFO] Extracted audio from video: {audio_path}") else: # Audio file audio_path = media_input print(f"[INFO] Using uploaded audio file: {audio_path}") elif isinstance(media_input, tuple): # Gradio microphone input: (sample_rate, audio_array) sample_rate, audio_array = media_input audio_array = np.array(audio_array, dtype=np.float32) # Debug: log input audio info print(f"[DEBUG] Microphone input: shape={audio_array.shape}, sr={sample_rate}") print(f"[DEBUG] Microphone audio range: min={audio_array.min():.6f}, max={audio_array.max():.6f}") # Normalize to [-1, 1] range properly max_val = max(abs(audio_array.max()), abs(audio_array.min())) if max_val > 10.0: # Data is likely int16 or similar, normalize it audio_array = audio_array / max_val print(f"[DEBUG] Normalized microphone audio (int format detected), original max_val={max_val}") elif max_val > 1.0: # Float format with slight overflow, just clip it audio_array = np.clip(audio_array, -1.0, 1.0) print(f"[DEBUG] Clipped microphone audio (slight overflow), original max_val={max_val}") # Check for silent audio if max_val == 0: print(f"[WARNING] Microphone audio appears to be completely silent!") temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir=CUSTOM_TEMP_DIR) temp_file.close() audio_int16 = (audio_array * 32767.0).astype(np.int16) sf.write(temp_file.name, audio_int16, sample_rate, subtype='PCM_16') audio_path = temp_file.name temp_file_to_cleanup = temp_file.name print(f"[INFO] Saved microphone input to: {audio_path}") # Final check - if we still don't have audio_path, something went 
wrong if audio_path is None: yield "â Invalid audio input format.", "", None, None, None return # Set temperature based on sampling mode actual_temp = temperature if do_sample else 0.0 # Start streaming transcription print("[INFO] Starting API transcription (streaming mode)") start_time = time.time() final_result = None token_count = 0 accumulated_text = "" # Track accumulated text for stop case async for text, result in api_client.transcribe_streaming( audio_path=audio_path, max_tokens=max_new_tokens, temperature=actual_temp, top_p=top_p, context_info=context_info, ): # Track accumulated text if text: accumulated_text = text # Check stop flag at higher level too (already declared global at function start) if stop_generation_flag: print("[INFO] Stop flag detected in transcribe_audio, breaking...") # Create a stopped result - parse whatever we have so far stopped_segments, stopped_warning = api_client._parse_segments(accumulated_text) if accumulated_text else ([], None) final_result = { "raw_text": accumulated_text or "", "segments": stopped_segments or [], "duration": 0, "usage": None, "stopped": True, "parse_warning": stopped_warning, } break if result is not None: final_result = result else: # Streaming update - format for readability token_count = len(text.split()) # Rough estimate formatted_text = text.replace('},', '},\n') yield f"đ Transcribing... 
({token_count} tokens)\n---\n{formatted_text}", "", None, None, None generation_time = time.time() - start_time if final_result is None or "error" in final_result: error_msg = final_result.get("error", "Unknown error") if final_result else "No response" yield f"â Transcription failed: {error_msg}", "", None, None, None return # Check if stopped by user was_stopped = final_result.get('stopped', False) # Format final output with token statistics if was_stopped: raw_output = f"--- âšī¸ Transcription Stopped ---\n" else: raw_output = f"--- â Transcription Complete ---\n" raw_output += f"âąī¸ Time: {generation_time:.2f}s | đĩ Audio: {final_result.get('duration', 0):.2f}s\n" # Add token statistics if available usage = final_result.get('usage') if usage: prompt_tokens = usage.get('prompt_tokens', 0) completion_tokens = usage.get('completion_tokens', 0) total_tokens = usage.get('total_tokens', 0) tokens_per_sec = completion_tokens / generation_time if generation_time > 0 else 0 raw_output += f"đ Tokens: {prompt_tokens} (prompt) + {completion_tokens} (completion) = {total_tokens} (total)\n" raw_output += f"⥠Speed: {tokens_per_sec:.1f} tokens/s\n" # Add parse warning if partial parsing was used parse_warning = final_result.get('parse_warning') if parse_warning: raw_output += f"{parse_warning}\n" raw_output += f"---\n" formatted_raw_text = final_result['raw_text'].replace('},', '},\n') raw_output += formatted_raw_text # Generate audio segments HTML segments = final_result.get('segments', []) audio_segments_html = "" num_segments = len(segments) if segments: # Extract audio clips for each segment audio_clips = extract_audio_segments(audio_path, segments) # Calculate approximate total size total_duration = sum( (float(seg.get('End', seg.get('end', 0))) - float(seg.get('Start', seg.get('start', 0)))) for seg in segments if seg.get('End') is not None and seg.get('Start') is not None ) approx_size_kb = total_duration * 4 # ~4KB per second at 32kbps # Add CSS for theme-aware 
styling (matching original demo) theme_css = """ """ # Build HTML format_info = "MP3 32kbps 16kHz mono" if HAS_PYDUB else "WAV 16kHz" audio_segments_html = theme_css audio_segments_html += "
đĩ Click the play button to listen to each segment directly!
" for i, seg in enumerate(segments): start = seg.get('Start', seg.get('start', seg.get('Start time', None))) end = seg.get('End', seg.get('end', seg.get('End time', None))) speaker = seg.get('Speaker', seg.get('speaker', seg.get('Speaker ID', None))) content = seg.get('Content', seg.get('content', seg.get('text', ''))) # Format times nicely start_str = f"{float(start):.2f}" if start is not None else "N/A" end_str = f"{float(end):.2f}" if end is not None else "N/A" speaker_str = str(speaker) if speaker is not None else "N/A" # Get audio clip audio_html = "" error_html = "" if i < len(audio_clips): _, audio_src, error_msg = audio_clips[i] if audio_src: audio_type = 'audio/mp3' if 'audio/mp3' in audio_src or 'audio/mpeg' in audio_src else 'audio/wav' audio_html = f""" """ elif error_msg: error_html = f"""â No audio segments available.
This could happen if the model output doesn't contain valid time stamps.
đŦ No video input detected.
Upload a video file to see playback with subtitles.