diff --git a/Figures/VibeVoice_ASR_archi.png b/Figures/VibeVoice_ASR_archi.png new file mode 100644 index 0000000..0d24aff Binary files /dev/null and b/Figures/VibeVoice_ASR_archi.png differ diff --git a/README.md b/README.md index 4c4e68a..41a98de 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,9 @@ New Realtime TTS +2026-01-21: 📣 We open-sourced VibeVoice-ASR, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context. -2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time. +2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time. 2025-12-09: 📣 We added experimental speakers in nine languages (DE, FR, IT, JP, KR, NL, PL, PT, ES) for exploration—welcome to try them out and share your feedback. 
@@ -123,4 +124,4 @@ We do not recommend using VibeVoice in commercial or real-world applications wit ## Star History -![Star History Chart](https://api.star-history.com/svg?repos=Microsoft/vibevoice&type=date&legend=top-left) \ No newline at end of file +![Star History Chart](https://api.star-history.com/svg?repos=Microsoft/vibevoice&type=date&legend=top-left) diff --git a/demo/vibevoice_asr_gradio_demo.py b/demo/vibevoice_asr_gradio_demo.py new file mode 100644 index 0000000..c334f95 --- /dev/null +++ b/demo/vibevoice_asr_gradio_demo.py @@ -0,0 +1,1177 @@ +#!/usr/bin/env python +""" +VibeVoice ASR Gradio Demo +""" + +import os +import sys +import torch +import numpy as np +import soundfile as sf +from pathlib import Path +import argparse +import time +import json +import gradio as gr +from typing import List, Dict, Tuple, Optional, Generator +import tempfile +import base64 +import io +import traceback +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed + +# Import TextIteratorStreamer for streaming generation +from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList + +try: + from liger_kernel.transformers import apply_liger_kernel_to_qwen2 + # Only apply RoPE, RMSNorm, SwiGLU patches (these affect the underlying Qwen2 layers) + apply_liger_kernel_to_qwen2( + rope=True, + rms_norm=True, + swiglu=True, + cross_entropy=False, + ) + print("βœ… Liger Kernel applied to Qwen2 components (RoPE, RMSNorm, SwiGLU)") +except Exception as e: + print(f"⚠️ Failed to apply Liger Kernel: {e}, you can install it with: pip install liger-kernel") + +# Try to import pydub for MP3 conversion +try: + from pydub import AudioSegment + HAS_PYDUB = True +except ImportError: + HAS_PYDUB = False + print("⚠️ Warning: pydub not available, falling back to WAV format") + +from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration +from vibevoice.processor.vibevoice_asr_processor import 
VibeVoiceASRProcessor +from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, COMMON_AUDIO_EXTS + + +class VibeVoiceASRInference: + """Simple inference wrapper for VibeVoice ASR model.""" + + def __init__(self, model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, attn_implementation: str = "flash_attention_2"): + """ + Initialize the ASR inference pipeline. + + Args: + model_path: Path to the pretrained model (HuggingFace format directory or model name) + device: Device to run inference on + dtype: Data type for model weights + attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager') + """ + print(f"Loading VibeVoice ASR model from {model_path}") + + # Load processor + self.processor = VibeVoiceASRProcessor.from_pretrained(model_path) + + # Load model + print(f"Using attention implementation: {attn_implementation}") + self.model = VibeVoiceASRForConditionalGeneration.from_pretrained( + model_path, + dtype=dtype, + device_map=device if device == "auto" else None, + attn_implementation=attn_implementation, + trust_remote_code=True + ) + + if device != "auto": + self.model = self.model.to(device) + + self.device = device if device != "auto" else next(self.model.parameters()).device + self.model.eval() + + # Print model info + total_params = sum(p.numel() for p in self.model.parameters()) + print(f"βœ… Model loaded successfully on {self.device}") + print(f"πŸ“Š Total parameters: {total_params:,} ({total_params/1e9:.2f}B)") + + def transcribe( + self, + audio_path: str = None, + audio_array: np.ndarray = None, + sample_rate: int = None, + max_new_tokens: int = 512, + temperature: float = 0.0, + top_p: float = 1.0, + do_sample: bool = False, + num_beams: int = 1, + repetition_penalty: float = 1.0, + context_info: str = None, + streamer: Optional[TextIteratorStreamer] = None, + ) -> dict: + """ + Transcribe audio to text. 
+ + Args: + audio_path: Path to audio file + audio_array: Audio array (if not loading from file) + sample_rate: Sample rate of audio array + max_new_tokens: Maximum tokens to generate + temperature: Temperature for sampling (0 for greedy) + top_p: Top-p for nucleus sampling (1.0 for no filtering) + do_sample: Whether to use sampling + num_beams: Number of beams for beam search (1 for greedy) + repetition_penalty: Repetition penalty (1.0 for no penalty) + context_info: Optional context information (e.g., hotwords, speaker names, topics) to help transcription + streamer: Optional TextIteratorStreamer for streaming output + + Returns: + Dictionary with transcription results + """ + # Process audio + inputs = self.processor( + audio=audio_path, + sampling_rate=sample_rate, + return_tensors="pt", + add_generation_prompt=True, + context_info=context_info + ) + + # Move to device + inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v + for k, v in inputs.items()} + + # Generate + generation_config = { + "max_new_tokens": max_new_tokens, + "temperature": temperature if temperature > 0 else None, + "top_p": top_p if do_sample else None, + "do_sample": do_sample, + "num_beams": num_beams, + "repetition_penalty": repetition_penalty, + "pad_token_id": self.processor.pad_id, + "eos_token_id": self.processor.tokenizer.eos_token_id, + } + + # Add streamer if provided + if streamer is not None: + generation_config["streamer"] = streamer + + # Add stopping criteria for stop button support + generation_config["stopping_criteria"] = StoppingCriteriaList([StopOnFlag()]) + + # Remove None values + generation_config = {k: v for k, v in generation_config.items() if v is not None} + + start_time = time.time() + + # Calculate input token statistics before generation + input_ids = inputs['input_ids'][0] # Shape: [seq_len] + total_input_tokens = input_ids.shape[0] + + # Count padding tokens (tokens equal to pad_id) + pad_id = self.processor.pad_id + padding_mask = (input_ids 
def _float_audio_to_int16(samples: np.ndarray) -> np.ndarray:
    """Scale float audio to int16 PCM, saturating instead of wrapping around."""
    scaled = np.asarray(samples, dtype=np.float64) * 32768.0
    # +1.0 scales to 32768, one past the int16 maximum; without the clip the
    # cast wraps it to -32768 and produces an audible click.
    return np.clip(scaled, -32768.0, 32767.0).astype(np.int16)


def _downsample_linear(samples: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample *samples* from src_sr to dst_sr by linear interpolation."""
    duration = len(samples) / src_sr
    # Keep at least one output sample so extremely short clips stay encodable.
    new_length = max(1, int(duration * dst_sr))
    old_positions = np.arange(len(samples))
    new_positions = np.linspace(0, len(samples) - 1, new_length)
    if samples.ndim == 1:
        return np.interp(new_positions, old_positions, samples)
    # np.interp only accepts 1-D data, so interpolate each channel separately.
    # assumes a (samples, channels) layout, consistent with the axis-0 slicing
    # done by the caller — TODO confirm against load_audio_use_ffmpeg.
    return np.stack(
        [
            np.interp(new_positions, old_positions, samples[:, ch])
            for ch in range(samples.shape[1])
        ],
        axis=1,
    )


def clip_and_encode_audio(
    audio_data: np.ndarray,
    sr: int,
    start_time: float,
    end_time: float,
    segment_idx: int,
    use_mp3: bool = True,
    target_sr: int = 16000,  # Downsample to 16kHz for smaller size
    mp3_bitrate: str = "32k"  # Use low bitrate for minimal transfer
) -> Tuple[int, Optional[str], Optional[str]]:
    """
    Clip one segment out of an audio array and return it as a base64 data URI.

    The requested range is clamped to the bounds of ``audio_data``. The clip is
    downsampled to ``target_sr`` when the source rate is higher, converted to
    16-bit PCM (saturating at full scale), and encoded as MP3 when pydub is
    available and ``use_mp3`` is true; otherwise, or when the MP3 conversion
    fails, it is encoded as WAV.

    Args:
        audio_data: Full audio array, time along axis 0
        sr: Sample rate of ``audio_data``
        start_time: Start time in seconds
        end_time: End time in seconds
        segment_idx: Segment index, returned unchanged for identification
        use_mp3: Whether to use MP3 format (smaller size)
        target_sr: Target sample rate for downsampling (lower = smaller)
        mp3_bitrate: MP3 bitrate (lower = smaller, e.g., "24k", "32k", "48k")

    Returns:
        Tuple of (segment_idx, data_uri, error_message). Exactly one of
        ``data_uri`` and ``error_message`` is not None. This function does not
        raise: every failure is reported through ``error_message``.
    """
    try:
        # Convert times to sample indices and clamp them to the array bounds.
        start_sample = max(0, int(start_time * sr))
        end_sample = min(len(audio_data), int(end_time * sr))

        if start_sample >= end_sample:
            return segment_idx, None, f"Invalid time range: [{start_time:.2f}s - {end_time:.2f}s]"

        segment_data = audio_data[start_sample:end_sample]

        # Downsample only; audio below the target rate is left as it is.
        if sr != target_sr and target_sr < sr:
            segment_data = _downsample_linear(segment_data, sr, target_sr)
            sr = target_sr

        segment_data_int16 = _float_audio_to_int16(segment_data)

        # Encode to WAV once in memory; the bytes serve both as the MP3
        # converter's input and as the fallback payload.
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16')
        wav_bytes = wav_buffer.getvalue()

        if use_mp3 and HAS_PYDUB:
            try:
                audio_segment = AudioSegment.from_wav(io.BytesIO(wav_bytes))
                # Convert to mono if stereo (halves the size)
                if audio_segment.channels > 1:
                    audio_segment = audio_segment.set_channels(1)
                mp3_buffer = io.BytesIO()
                audio_segment.export(mp3_buffer, format='mp3', bitrate=mp3_bitrate)

                audio_base64 = base64.b64encode(mp3_buffer.getvalue()).decode('utf-8')
                # The "audio/mp3" label is kept as-is because the HTML builder
                # in this file detects the format by searching for that text.
                return segment_idx, f"data:audio/mp3;base64,{audio_base64}", None
            except Exception as e:
                # Best effort: fall through to the WAV payload below.
                print(f"MP3 conversion failed for segment {segment_idx}, using WAV: {e}")

        audio_base64 = base64.b64encode(wav_bytes).decode('utf-8')
        return segment_idx, f"data:audio/wav;base64,{audio_base64}", None

    except Exception as e:
        error_msg = f"Error clipping segment {segment_idx}: {str(e)}"
        print(error_msg)
        return segment_idx, None, error_msg
+ + Returns: + List of tuples (segment_label, audio_base64_src, error_msg) + """ + try: + # Read audio file once using ffmpeg for better format support + print(f"πŸ“‚ Loading audio file: {audio_path}") + audio_data, sr = load_audio_use_ffmpeg(audio_path, resample=False) + print(f"βœ… Audio loaded: {len(audio_data)} samples, {sr} Hz") + + # Prepare tasks + tasks = [] + use_mp3 = HAS_PYDUB # Use MP3 if available + + for i, seg in enumerate(segments): + start_time = seg.get('start_time') + end_time = seg.get('end_time') + + # Skip if times are not available or invalid + if (not isinstance(start_time, (int, float)) or + not isinstance(end_time, (int, float)) or + start_time >= end_time): + tasks.append((i, None, None, None, None, None)) # Will be filtered later + continue + + tasks.append((audio_data, sr, start_time, end_time, i, use_mp3)) + + # Process in parallel using ThreadPoolExecutor + results = [] + total_segments = len(tasks) + completed_count = 0 + + # Use CPU count for max workers + max_workers = os.cpu_count() or 4 + print(f"πŸš€ Starting parallel processing with {max_workers} threads...") + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {} + for task in tasks: + if task[0] is None: # Skip invalid tasks + continue + future = executor.submit(clip_and_encode_audio, *task) + futures[future] = task[4] # segment_idx + + for future in as_completed(futures): + try: + result = future.result() + results.append(result) + completed_count += 1 + # Log progress every 100 segments or at completion + if completed_count % 100 == 0 or completed_count == len(futures): + print(f"Progress: {completed_count}/{len(futures)} segments processed ({completed_count*100//len(futures)}%)") + except Exception as e: + idx = futures[future] + results.append((idx, None, f"Processing error: {str(e)}")) + completed_count += 1 + print(f"Error on segment {idx}: {e}") + + print(f"βœ… Completed processing all {len(futures)} valid segments") + + # Sort by segment index 
def parse_time_to_seconds(val: Optional[str]) -> Optional[float]:
    """Parse a plain seconds value or an ``[hh:]mm:ss`` timestamp into seconds.

    Args:
        val: User-supplied text such as ``"30.5"``, ``"02:30"`` or
            ``"00:02:30.5"``. ``None`` and blank strings mean "not provided".

    Returns:
        The time in seconds, or ``None`` when the value is missing or cannot
        be interpreted as a finite time. This function never raises for
        malformed text.
    """
    import math  # local import: `math` is not among the file-level imports

    if val is None:
        return None
    val = val.strip()
    if not val:
        return None

    # Plain number of seconds.
    try:
        seconds = float(val)
    except ValueError:
        seconds = None
    if seconds is not None:
        # float() also parses "nan" and "inf"; those are not usable
        # timestamps, so report them as invalid instead of passing them on.
        return seconds if math.isfinite(seconds) else None

    if ":" not in val:
        return None

    parts = [p.strip() for p in val.split(":")]
    if len(parts) not in (2, 3):
        return None
    # Each field must be digits with at most one decimal point.
    if not all(p.replace(".", "", 1).isdigit() for p in parts):
        return None
    try:
        # isdigit() accepts some characters float() rejects (e.g. the
        # superscript two), so the conversion itself still needs guarding.
        numbers = [float(p) for p in parts]
    except ValueError:
        return None

    if len(numbers) == 2:
        numbers.insert(0, 0.0)  # mm:ss -> 0 hours
    h, m, s = numbers
    total = h * 3600 + m * 60 + s
    return total if math.isfinite(total) else None
def initialize_model(model_path: str, device: str = "cuda", attn_implementation: str = "flash_attention_2"):
    """Load the ASR model into the module-level ``asr_model`` and report the outcome.

    Args:
        model_path: Model location, forwarded to ``VibeVoiceASRInference``.
        device: Device string forwarded to ``VibeVoiceASRInference``. ``"cpu"``
            selects float32 weights; any other value selects bfloat16.
        attn_implementation: Attention implementation name, forwarded unchanged.

    Returns:
        A human-readable status string. On failure the string starts with
        "❌" and the exception is not re-raised (its traceback is printed);
        ``asr_model`` then keeps whatever value it had before the call.
    """
    global asr_model

    dtype = torch.float32 if device == "cpu" else torch.bfloat16
    try:
        asr_model = VibeVoiceASRInference(
            model_path=model_path,
            device=device,
            dtype=dtype,
            attn_implementation=attn_implementation
        )
    except Exception as e:
        # Top-level boundary: turn any load failure into a status message so
        # the caller can decide how to react. Uses the file-level `traceback`
        # import rather than re-importing it here.
        traceback.print_exc()
        return f"❌ Error loading model: {str(e)}"
    return f"✅ Model loaded successfully from {model_path}"
+ + Args: + audio_input: Audio file path or tuple (sample_rate, audio_data) + max_new_tokens: Maximum tokens to generate + temperature: Temperature for sampling (0 for greedy) + top_p: Top-p for nucleus sampling + do_sample: Whether to use sampling + context_info: Optional context information (e.g., hotwords, speaker names, topics) + + Yields: + Tuple of (raw_text, audio_segments_html) + """ + if asr_model is None: + yield "❌ Please load a model first!", "" + return + + if not audio_path_input and audio_input is None: + yield "❌ Please provide audio input!", "" + return + + try: + print("[INFO] Transcription requested") + start_sec = parse_time_to_seconds(start_time_input) + end_sec = parse_time_to_seconds(end_time_input) + print(f"[INFO] Parsed time range: start={start_sec}, end={end_sec}") + if (start_time_input and start_sec is None) or (end_time_input and end_sec is None): + yield "❌ Invalid time format. Use seconds or hh:mm:ss.", "" + return + + audio_path = None + audio_array = None + sample_rate = None + + if audio_path_input: + candidate = Path(audio_path_input.strip()) + if not candidate.exists(): + yield f"❌ Provided path does not exist: {candidate}", "" + return + audio_path = str(candidate) + print(f"[INFO] Using provided audio path: {audio_path}") + # Get audio file path (Gradio Audio component returns tuple (sample_rate, audio_data) or file path) + elif isinstance(audio_input, str): + audio_path = audio_input + print(f"[INFO] Using uploaded audio path: {audio_path}") + elif isinstance(audio_input, tuple): + # Audio from microphone: (sample_rate, audio_data) + sample_rate, audio_array = audio_input + print(f"[INFO] Received microphone audio with sample_rate={sample_rate}") + elif audio_path is None: + yield "❌ Invalid audio input format!", "" + return + + # If slicing is requested, load and slice audio + if start_sec is not None or end_sec is not None: + print("[INFO] Slicing audio per requested time range") + if audio_array is None or sample_rate is 
None: + try: + audio_array, sample_rate = load_audio_use_ffmpeg(audio_path, resample=False) + print("[INFO] Loaded audio for slicing via ffmpeg") + except Exception as exc: + yield f"❌ Failed to load audio for slicing: {exc}", "" + return + sliced_path, err = slice_audio_to_temp(audio_array, sample_rate, start_sec, end_sec) + if err: + yield f"❌ {err}", "" + return + audio_path = sliced_path + print(f"[INFO] Sliced audio written to temp file: {audio_path}") + elif audio_array is not None and sample_rate is not None: + # no slicing but microphone input: write to temp file + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") + audio_path = temp_file.name + temp_file.close() + audio_data_int16 = (audio_array * 32768.0).astype(np.int16) + sf.write(audio_path, audio_data_int16, sample_rate, subtype='PCM_16') + print(f"[INFO] Microphone audio saved to temp file: {audio_path}") + + # Create streamer for real-time output + streamer = TextIteratorStreamer( + asr_model.processor.tokenizer, + skip_prompt=True, + skip_special_tokens=True + ) + + # Store result in a mutable container for the thread + result_container = {"result": None, "error": None} + + def run_transcription(): + try: + result_container["result"] = asr_model.transcribe( + audio_path=audio_path, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + do_sample=do_sample, + repetition_penalty=repetition_penalty, + context_info=context_info if context_info and context_info.strip() else None, + streamer=streamer + ) + except Exception as e: + result_container["error"] = str(e) + traceback.print_exc() + + # Start transcription in background thread + print("[INFO] Starting model transcription (streaming mode)") + start_time = time.time() + transcription_thread = threading.Thread(target=run_transcription) + transcription_thread.start() + + # Yield streaming output + generated_text = "" + token_count = 0 + for new_text in streamer: + generated_text += new_text + token_count += 1 
+ elapsed = time.time() - start_time + # Show streaming output with live stats, format for readability + formatted_text = generated_text.replace('},', '},\n') + streaming_output = f"--- πŸ”΄ LIVE Streaming Output (tokens: {token_count}, time: {elapsed:.1f}s) ---\n{formatted_text}" + yield streaming_output, "
⏳ Generating transcription... Audio segments will appear after completion.
" + + # Wait for thread to complete + transcription_thread.join() + + if result_container["error"]: + yield f"❌ Error during transcription: {result_container['error']}", "" + return + + result = result_container["result"] + generation_time = time.time() - start_time + + # Get input token statistics + input_tokens = result.get('input_tokens', {}) + speech_tokens = input_tokens.get('speech', 0) + text_tokens = input_tokens.get('text', 0) + padding_tokens = input_tokens.get('padding', 0) + total_input = input_tokens.get('total', 0) + + # Format final raw output with input/output token stats + raw_output = f"--- βœ… Raw Output ---\n" + raw_output += f"πŸ“₯ Input: {total_input} tokens (🎀 speech: {speech_tokens}, πŸ“ text: {text_tokens}, ⬜ pad: {padding_tokens})\n" + raw_output += f"πŸ“€ Output: {token_count} tokens | ⏱️ Time: {generation_time:.2f}s\n" + raw_output += f"---\n" + # Format raw text for better readability: add newline after each dict (},) + formatted_raw_text = result['raw_text'].replace('},', '},\n') + raw_output += formatted_raw_text + + # Debug: print raw output to console + print(f"[DEBUG] Raw model output:") + print(f"[DEBUG] {result['raw_text']}") + print(f"[DEBUG] Found {len(result['segments'])} segments") + + # Create audio segments with server-side encoding (low quality for minimal transfer) + # Using: 16kHz mono MP3 @ 32kbps = ~4KB per second of audio + audio_segments_html = "" + segments = result['segments'] + + if segments: + num_segments = len(segments) + print(f"[INFO] Creating per-segment audio clips ({num_segments} segments, 16kHz mono MP3 @ 32kbps)") + + # Extract all audio segments efficiently (load audio only once) + audio_segments = extract_audio_segments(audio_path, segments) + print("[INFO] Completed creating audio clips") + + # Calculate approximate total size + total_duration = sum( + (seg.get('end_time', 0) - seg.get('start_time', 0)) + for seg in segments + if isinstance(seg.get('start_time'), (int, float)) and 
isinstance(seg.get('end_time'), (int, float)) + ) + approx_size_kb = total_duration * 4 # ~4KB per second at 32kbps + + # Add CSS for theme-aware styling + theme_css = """ + + """ + + audio_segments_html = theme_css + audio_segments_html += f"
" + + # Add format info + format_info = "MP3 32kbps 16kHz mono" if HAS_PYDUB else "WAV 16kHz" + audio_segments_html += f"

πŸ”Š Audio Segments ({num_segments} segments)" + audio_segments_html += f"πŸ“¦ ~{approx_size_kb:.0f}KB ({format_info})

" + audio_segments_html += "

🎡 Click the play button to listen to each segment directly!

" + + for i, (label, audio_src, error_msg) in enumerate(audio_segments): + seg = segments[i] if i < len(segments) else {} + start_time = seg.get('start_time', 'N/A') + end_time = seg.get('end_time', 'N/A') + speaker_id = seg.get('speaker_id', 'N/A') + content = seg.get('text', '') + + # Format times nicely + start_str = f"{start_time:.2f}" if isinstance(start_time, (int, float)) else str(start_time) + end_str = f"{end_time:.2f}" if isinstance(end_time, (int, float)) else str(end_time) + + audio_segments_html += f""" +
+
+

Segment {i+1}

+
+ Time: [{start_str}s - {end_str}s] | + Speaker: {speaker_id} +
+
+ +
+ {content} +
+ """ + + if audio_src: + # Detect format from data URI + audio_type = 'audio/mp3' if 'audio/mp3' in audio_src else 'audio/wav' + audio_segments_html += f""" + + """ + elif error_msg: + audio_segments_html += f""" +
+ ❌ {error_msg} +
+ """ + else: + audio_segments_html += """ +
+ Audio playback unavailable for this segment +
+ """ + + audio_segments_html += "
" + + audio_segments_html += "
" + else: + audio_segments_html = """ + +
+

❌ No audio segments available.

+

This could happen if the model output doesn't contain valid time stamps.

+
+ """ + + # Final yield with complete results + yield raw_output, audio_segments_html + + except Exception as e: + print(f"Error during transcription: {e}") + print(traceback.format_exc()) + yield f"❌ Error during transcription: {str(e)}", "" + + +def create_gradio_interface(model_path: str, default_max_tokens: int = 8192, attn_implementation: str = "flash_attention_2"): + """Create and launch Gradio interface. + + Args: + model_path: Path to the model (HuggingFace format directory or model name) + default_max_tokens: Default value for max_new_tokens slider + attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager') + """ + + # Initialize model at startup + device = "cuda" if torch.cuda.is_available() else "cpu" + model_status = initialize_model(model_path, device, attn_implementation) + print(model_status) + + # Exit if model loading failed + if model_status.startswith("❌"): + print("\n" + "="*80) + print("πŸ’₯ FATAL ERROR: Model loading failed!") + print("="*80) + print("Cannot start demo without a valid model. Please check:") + print(" 1. Model path is correct") + print(" 2. Model files are not corrupted") + print(" 3. You have enough GPU memory") + print(" 4. 
CUDA is properly installed (if using GPU)") + print("="*80) + sys.exit(1) + + # Custom CSS for Stop button styling + custom_css = """ + #stop-btn { + background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%) !important; + border: none !important; + color: white !important; + } + #stop-btn:hover { + background: linear-gradient(135deg, #dc2626 0%, #b91c1c 100%) !important; + } + """ + + # Gradio 6.0+ moved theme/css to launch() + with gr.Blocks(title="VibeVoice ASR Demo") as demo: + gr.Markdown("# πŸŽ™οΈ VibeVoice ASR Demo") + gr.Markdown("Upload audio files or record from microphone to get speech-to-text transcription with speaker diarization.") + gr.Markdown(f"**Model loaded from:** `{model_path}`") + + with gr.Row(): + with gr.Column(scale=1): + # Generation parameters + gr.Markdown("## βš™οΈ Generation Parameters") + max_tokens_slider = gr.Slider( + minimum=4096, + maximum=65536, + value=default_max_tokens, + step=4096, + label="Max New Tokens" + ) + + # Sampling parameters + gr.Markdown("### 🎲 Sampling") + do_sample_checkbox = gr.Checkbox( + value=False, + label="Enable Sampling", + info="Enable random sampling instead of deterministic decoding" + ) + + with gr.Column(visible=False) as sampling_params: + temperature_slider = gr.Slider( + minimum=0.0, + maximum=2.0, + value=0.0, + step=0.1, + label="Temperature", + info="0 = greedy, higher = more random" + ) + top_p_slider = gr.Slider( + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.05, + label="Top-p (Nucleus Sampling)", + info="1.0 = no filtering" + ) + + # Repetition penalty (works with both greedy and sampling) + repetition_penalty_slider = gr.Slider( + minimum=1.0, + maximum=1.2, + value=1.0, + step=0.01, + label="Repetition Penalty", + info="1.0 = no penalty, higher = less repetition (works with greedy & sampling)" + ) + + # Context information section + gr.Markdown("## πŸ“‹ Context Info (Optional)") + context_info_input = gr.Textbox( + label="Context Information", + placeholder="Enter hotwords, 
speaker names, topics, or other context to help transcription...\nExample:\nJohn Smith\nMachine Learning\nOpenAI", + lines=4, + max_lines=8, + interactive=True, + info="Provide context like proper nouns, technical terms, or speaker names to improve accuracy" + ) + + with gr.Column(scale=2): + # Audio input section + gr.Markdown("## 🎡 Audio Input") + audio_input = gr.Audio( + label="Upload Audio File or Record from Microphone", + sources=["upload", "microphone"], + type="filepath", + interactive=True, + buttons=["download"] + ) + + with gr.Accordion("πŸ“‚ Advanced: Remote Path & Time Slicing", open=False): + audio_path_input = gr.Textbox( + label="Audio path (optional)", + placeholder="Enter remote audio file path", + lines=1 + ) + with gr.Row(): + start_time_input = gr.Textbox( + label="Start time", + placeholder="e.g., 0 or 00:00:00", + lines=1, + info="Leave empty to start from the beginning" + ) + end_time_input = gr.Textbox( + label="End time", + placeholder="e.g., 30.5 or 00:00:30.5", + lines=1, + info="Leave empty to use full length" + ) + + with gr.Row(): + transcribe_button = gr.Button("🎯 Transcribe", variant="primary", size="lg", scale=3) + stop_button = gr.Button("⏹️ Stop", variant="secondary", size="lg", scale=1, elem_id="stop-btn") + + # Results section + gr.Markdown("## πŸ“ Results") + + with gr.Tabs(): + with gr.TabItem("Raw Output"): + raw_output = gr.Textbox( + label="Raw Transcription Output", + lines=8, + max_lines=20, + interactive=False + ) + + with gr.TabItem("Audio Segments"): + audio_segments_output = gr.HTML( + label="Play individual segments to verify accuracy" + ) + + # Event handlers + do_sample_checkbox.change( + fn=lambda x: gr.update(visible=x), + inputs=[do_sample_checkbox], + outputs=[sampling_params] + ) + + def reset_stop_flag(): + """Reset stop flag before starting transcription.""" + global stop_generation_flag + stop_generation_flag = False + + def set_stop_flag(): + """Set stop flag to interrupt generation.""" + global 
stop_generation_flag + stop_generation_flag = True + return "⏹️ Stop requested..." + + transcribe_button.click( + fn=reset_stop_flag, + inputs=[], + outputs=[], + queue=False + ).then( + fn=transcribe_audio, + inputs=[ + audio_input, + audio_path_input, + start_time_input, + end_time_input, + max_tokens_slider, + temperature_slider, + top_p_slider, + do_sample_checkbox, + repetition_penalty_slider, + context_info_input + ], + outputs=[raw_output, audio_segments_output] + ) + + stop_button.click( + fn=set_stop_flag, + inputs=[], + outputs=[raw_output], + queue=False + ) + + # Add examples + gr.Markdown("## πŸ“‹ Instructions") + gr.Markdown(f""" + 1. **Upload Audio**: Use the audio component to upload a file or record from microphone + - **Supported formats**: {', '.join(sorted(set([ext.lower() for ext in COMMON_AUDIO_EXTS])))} + - Optionally set **Start/End time** (seconds or hh:mm:ss) to clip before transcription + 2. **Context Info (Optional)**: Provide context to improve transcription accuracy + - Add hotwords, proper nouns, speaker names, or technical terms + - One item per line or comma-separated + - Examples: "John Smith", "OpenAI", "machine learning" + 3. **Adjust Parameters**: Configure generation parameters as needed + 4. **Transcribe**: Click "Transcribe" to get results + 5. **Review Results**: + - **Raw Output**: View the model's original output + - **Audio Segments**: Play individual segments directly to verify accuracy + + **Audio Segments**: Each segment shows the time range, speaker ID, transcribed content, and an embedded audio player for immediate verification. 
def main():
    """Parse command-line options, build the Gradio interface and launch the demo server.

    Command-line options:
        --model_path: model directory or model name handed to the interface builder.
        --attn_implementation: attention implementation handed to the interface builder.
        --max_new_tokens: default generation budget shown in the interface.
        --host / --port: address the server binds to.
        --share: ask Gradio to create a public link.
    """
    parser = argparse.ArgumentParser(description="VibeVoice ASR Gradio Demo")
    parser.add_argument(
        "--model_path",
        type=str,
        default="microsoft/VibeVoice-ASR",
        help="Path to the model (HuggingFace format directory or model name)"
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        help="Attention implementation to use (default: flash_attention_2)"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=32768,
        help="Default max new tokens for generation (default: 32768)"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind the server to"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to bind the server to"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a public link"
    )

    args = parser.parse_args()

    # Build the UI first so model-loading problems surface before the server starts.
    demo, custom_css = create_gradio_interface(
        model_path=args.model_path,
        default_max_tokens=args.max_new_tokens,
        attn_implementation=args.attn_implementation
    )

    # Banner emoji were mis-encoded in the source text; restored to the intended glyphs.
    print("🚀 Starting VibeVoice ASR Demo...")
    print(f"📍 Server will be available at: http://{args.host}:{args.port}")

    # Gradio 6.0+ moved theme/css to launch()
    launch_kwargs = {
        "server_name": args.host,
        "server_port": args.port,
        "share": args.share,
        "show_error": True,
        "theme": gr.themes.Soft(),
        "css": custom_css,
    }

    # Enable queue for concurrent request handling
    demo.queue(default_concurrency_limit=3)
    demo.launch(**launch_kwargs)


if __name__ == "__main__":
    main()
model and compares results +between batch processing and single-sample processing. +""" + +import os +import sys +import torch +import numpy as np +from pathlib import Path +import argparse +import time +import json +import re +from typing import List, Dict, Any, Optional +from functools import wraps + +from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration +from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor + + +class VibeVoiceASRBatchInference: + """Batch inference wrapper for VibeVoice ASR model.""" + + def __init__( + self, + model_path: str, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + attn_implementation: str = "flash_attention_2" + ): + """ + Initialize the ASR batch inference pipeline. + + Args: + model_path: Path to the pretrained model + device: Device to run inference on + dtype: Data type for model weights + attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager') + """ + print(f"Loading VibeVoice ASR model from {model_path}") + + # Load processor + self.processor = VibeVoiceASRProcessor.from_pretrained( + model_path, + language_model_pretrained_name="Qwen/Qwen2.5-7B" + ) + + # Load model with specified attention implementation + print(f"Using attention implementation: {attn_implementation}") + self.model = VibeVoiceASRForConditionalGeneration.from_pretrained( + model_path, + dtype=dtype, + device_map=device if device == "auto" else None, + attn_implementation=attn_implementation, + trust_remote_code=True + ) + + if device != "auto": + self.model = self.model.to(device) + + self.device = device if device != "auto" else next(self.model.parameters()).device + self.dtype = dtype + self.model.eval() + + print(f"Model loaded successfully on {self.device}") + + def _prepare_generation_config( + self, + max_new_tokens: int = 512, + temperature: float = 0.0, + top_p: float = 0.9, + do_sample: bool = True, + num_beams: int = 1, + ) -> dict: + 
"""Prepare generation configuration.""" + config = { + "max_new_tokens": max_new_tokens, + "pad_token_id": self.processor.pad_id, + "eos_token_id": self.processor.tokenizer.eos_token_id, + } + + # Beam search vs sampling + if num_beams > 1: + config["num_beams"] = num_beams + config["do_sample"] = False # Beam search doesn't use sampling + else: + config["do_sample"] = do_sample + # Only set temperature and top_p when sampling is enabled + if do_sample: + config["temperature"] = temperature + config["top_p"] = top_p + + return config + + def transcribe_batch( + self, + audio_inputs: List, + max_new_tokens: int = 512, + temperature: float = 0.0, + top_p: float = 1.0, + do_sample: bool = True, + num_beams: int = 1, + ) -> List[Dict[str, Any]]: + """ + Transcribe multiple audio files/arrays in a single batch. + + Args: + audio_inputs: List of audio file paths or (array, sampling_rate) tuples + max_new_tokens: Maximum tokens to generate + temperature: Temperature for sampling + top_p: Top-p for nucleus sampling + do_sample: Whether to use sampling + + Returns: + List of transcription results + """ + if len(audio_inputs) == 0: + return [] + + batch_size = len(audio_inputs) + print(f"\nProcessing batch of {batch_size} audio(s)...") + + # Process all audio together + inputs = self.processor( + audio=audio_inputs, + sampling_rate=None, + return_tensors="pt", + padding=True, + add_generation_prompt=True + ) + + # Move to device + inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v + for k, v in inputs.items()} + + # Print batch info + print(f" Input IDs shape: {inputs['input_ids'].shape}") + print(f" Speech tensors shape: {inputs['speech_tensors'].shape}") + print(f" Attention mask shape: {inputs['attention_mask'].shape}") + + # Generate + generation_config = self._prepare_generation_config( + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + do_sample=do_sample, + num_beams=num_beams, + ) + + start_time = time.time() + + with 
torch.no_grad(): + output_ids = self.model.generate( + **inputs, + **generation_config + ) + + generation_time = time.time() - start_time + + # Decode outputs for each sample in the batch + results = [] + input_length = inputs['input_ids'].shape[1] + + for i, audio_input in enumerate(audio_inputs): + # Get generated tokens for this sample (excluding input tokens) + generated_ids = output_ids[i, input_length:] + + # Remove padding tokens from the end + # Find the first eos_token or pad_token + eos_positions = (generated_ids == self.processor.tokenizer.eos_token_id).nonzero(as_tuple=True)[0] + if len(eos_positions) > 0: + generated_ids = generated_ids[:eos_positions[0] + 1] + + generated_text = self.processor.decode(generated_ids, skip_special_tokens=True) + + # Parse structured output + try: + transcription_segments = self.processor.post_process_transcription(generated_text) + except Exception as e: + print(f"Warning: Failed to parse structured output: {e}") + transcription_segments = [] + + # Get file name based on input type + if isinstance(audio_input, str): + file_name = audio_input + elif isinstance(audio_input, dict) and 'id' in audio_input: + file_name = audio_input['id'] + else: + file_name = f"audio_{i}" + + results.append({ + "file": file_name, + "raw_text": generated_text, + "segments": transcription_segments, + "generation_time": generation_time / batch_size, + }) + + print(f" Total generation time: {generation_time:.2f}s") + print(f" Average time per sample: {generation_time/batch_size:.2f}s") + + return results + + def transcribe_with_batching( + self, + audio_inputs: List, + batch_size: int = 4, + max_new_tokens: int = 512, + temperature: float = 0.0, + top_p: float = 1.0, + do_sample: bool = True, + num_beams: int = 1, + ) -> List[Dict[str, Any]]: + """ + Transcribe multiple audio files/arrays with automatic batching. 
+ + Args: + audio_inputs: List of audio file paths or (array, sampling_rate) tuples + batch_size: Number of samples per batch + max_new_tokens: Maximum tokens to generate + temperature: Temperature for sampling + top_p: Top-p for nucleus sampling + do_sample: Whether to use sampling + + Returns: + List of transcription results + """ + all_results = [] + + # Process in batches + for i in range(0, len(audio_inputs), batch_size): + batch_inputs = audio_inputs[i:i + batch_size] + print(f"\n{'='*60}") + print(f"Processing batch {i//batch_size + 1}/{(len(audio_inputs) + batch_size - 1)//batch_size}") + + batch_results = self.transcribe_batch( + batch_inputs, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + do_sample=do_sample, + num_beams=num_beams, + ) + all_results.extend(batch_results) + + return all_results + + +def print_result(result: Dict[str, Any]): + """Pretty print a single transcription result.""" + print(f"\nFile: {result['file']}") + print(f"Generation Time: {result['generation_time']:.2f}s") + print(f"\n--- Raw Output ---") + print(result['raw_text'][:500] + "..." if len(result['raw_text']) > 500 else result['raw_text']) + + if result['segments']: + print(f"\n--- Structured Output ({len(result['segments'])} segments) ---") + for seg in result['segments'][:50]: # Show first 50 segments + print(f"[{seg.get('start_time', 'N/A')} - {seg.get('end_time', 'N/A')}] " + f"Speaker {seg.get('speaker_id', 'N/A')}: {seg.get('text', '')}...") + if len(result['segments']) > 50: + print(f" ... and {len(result['segments']) - 50} more segments") + + +def load_dataset_and_concatenate( + dataset_name: str, + split: str, + max_duration: float, + num_audios: int, + target_sr: int = 24000 +) -> Optional[List[np.ndarray]]: + """ + Load a HuggingFace dataset and concatenate audio samples into long audio chunks. 
+ (Note, just for demo purpose, not for benchmark evaluation) + + Args: + dataset_name: HuggingFace dataset name (e.g., 'openslr/librispeech_asr') + split: Dataset split to use (e.g., 'test', 'test.other') + max_duration: Maximum duration in seconds for each concatenated audio + num_audios: Number of concatenated audios to create + target_sr: Target sample rate (default: 24000) + + Returns: + List of concatenated audio arrays, or None if loading fails + """ + try: + from datasets import load_dataset + import torchcodec # just for decode audio in datasets + except ImportError: + print("Please install it with: pip install datasets torchcodec") + return None + + print(f"\nLoading dataset: {dataset_name} (split: {split})") + print(f"Will create {num_audios} concatenated audio(s), each up to {max_duration:.1f}s ({max_duration/3600:.2f} hours)") + + try: + # Use streaming to avoid downloading the entire dataset + dataset = load_dataset(dataset_name, split=split, streaming=True) + print(f"Dataset loaded in streaming mode") + + concatenated_audios = [] # List of concatenated audio metadata + + # Create multiple concatenated audios based on num_audios + current_chunks = [] + current_duration = 0.0 + current_samples_used = 0 + sample_idx = 0 + + for sample in dataset: + if len(concatenated_audios) >= num_audios: + break + + if 'audio' not in sample: + continue + + audio_data = sample['audio'] + audio_array = audio_data['array'] + sr = audio_data['sampling_rate'] + + # Resample if needed + if sr != target_sr: + duration = len(audio_array) / sr + new_length = int(duration * target_sr) + audio_array = np.interp( + np.linspace(0, len(audio_array) - 1, new_length), + np.arange(len(audio_array)), + audio_array + ) + + chunk_duration = len(audio_array) / target_sr + + # Check if adding this chunk exceeds max_duration + if current_duration + chunk_duration > max_duration: + remaining_duration = max_duration - current_duration + if remaining_duration > 0.5: # Only add if > 0.5s 
remaining + samples_to_take = int(remaining_duration * target_sr) + current_chunks.append(audio_array[:samples_to_take]) + current_duration += remaining_duration + current_samples_used += 1 + + # Save current concatenated audio and start a new one + if current_chunks: + concatenated_audios.append({ + 'array': np.concatenate(current_chunks), + 'duration': current_duration, + 'samples_used': current_samples_used, + }) + print(f" Created audio {len(concatenated_audios)}: {current_duration:.1f}s from {current_samples_used} samples") + + # Reset for next concatenated audio + current_chunks = [] + current_duration = 0.0 + current_samples_used = 0 + + if len(concatenated_audios) >= num_audios: + break + + current_chunks.append(audio_array) + current_duration += chunk_duration + current_samples_used += 1 + + sample_idx += 1 + if sample_idx % 100 == 0: + print(f" Processed {sample_idx} samples...") + + # Don't forget the last batch if it has content + if current_chunks and len(concatenated_audios) < num_audios: + concatenated_audios.append({ + 'array': np.concatenate(current_chunks), + 'duration': current_duration, + 'samples_used': current_samples_used, + }) + print(f" Created audio {len(concatenated_audios)}: {current_duration:.1f}s from {current_samples_used} samples") + + if not concatenated_audios: + print("Warning: No audio samples found in dataset") + return None + + # Extract arrays and print summary + result = [a['array'] for a in concatenated_audios] + total_duration = sum(a['duration'] for a in concatenated_audios) + total_samples = sum(a['samples_used'] for a in concatenated_audios) + print(f"\nCreated {len(result)} concatenated audio(s), total {total_duration:.1f}s ({total_duration/60:.1f} min) from {total_samples} samples") + + return result + + except Exception as e: + print(f"Error loading dataset: {e}") + import traceback + traceback.print_exc() + return None + + +def main(): + parser = argparse.ArgumentParser(description="VibeVoice ASR Batch Inference 
Demo") + parser.add_argument( + "--model_path", + type=str, + default="", + help="Path to the model checkpoint" + ) + parser.add_argument( + "--audio_files", + type=str, + nargs='+', + required=False, + help="Paths to audio files for transcription" + ) + parser.add_argument( + "--audio_dir", + type=str, + required=False, + help="Directory containing audio files for batch transcription" + ) + parser.add_argument( + "--dataset", + type=str, + required=False, + help="HuggingFace dataset name (e.g., 'openslr/librispeech_asr')" + ) + parser.add_argument( + "--split", + type=str, + default="test", + help="Dataset split to use (e.g., 'test', 'test.other', 'test.clean')" + ) + parser.add_argument( + "--max_duration", + type=float, + default=3600.0, + help="Maximum duration in seconds for concatenated dataset audio (default: 3600 = 1 hour)" + ) + parser.add_argument( + "--batch_size", + type=int, + default=2, + help="Batch size for processing multiple files" + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + choices=["cuda", "cpu", "auto"], + help="Device to run inference on" + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=32768, + help="Maximum number of tokens to generate" + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Temperature for sampling (0 = greedy decoding)" + ) + parser.add_argument( + "--top_p", + type=float, + default=1.0, + help="Top-p for nucleus sampling" + ) + parser.add_argument( + "--num_beams", + type=int, + default=1, + help="Number of beams for beam search. 
Use 1 for greedy/sampling" + ) + parser.add_argument( + "--attn_implementation", + type=str, + default="flash_attention_2", + help="Attention implementation to use (default: flash_attention_2)" + ) + + args = parser.parse_args() + + # Collect audio files + audio_files = [] + concatenated_audio = None # For storing concatenated dataset audio + + if args.audio_files: + audio_files.extend(args.audio_files) + + if args.audio_dir: + import glob + for ext in ["*.wav", "*.mp3", "*.flac", "*.mp4", "*.m4a", "*.webm"]: + audio_files.extend(glob.glob(os.path.join(args.audio_dir, ext))) + + if args.dataset: + concatenated_audio = load_dataset_and_concatenate( + dataset_name=args.dataset, + split=args.split, + max_duration=args.max_duration, + num_audios=args.batch_size, + ) + if concatenated_audio is None: + return + + if len(audio_files) == 0 and concatenated_audio is None: + print("No audio files provided. Please specify --audio_files, --audio_dir, or --dataset.") + return + + if audio_files: + print(f"\nAudio files to process ({len(audio_files)}):") + for f in audio_files: + print(f" - {f}") + + if concatenated_audio: + print(f"\nConcatenated dataset audios: {len(concatenated_audio)} audio(s)") + + # Initialize model + asr = VibeVoiceASRBatchInference( + model_path=args.model_path, + device=args.device, + dtype=torch.bfloat16 if args.device != "cpu" else torch.float32, + attn_implementation=args.attn_implementation + ) + + # If temperature is 0, use greedy decoding (no sampling) + do_sample = args.temperature > 0 + + # Combine all audio inputs + all_audio_inputs = audio_files + (concatenated_audio or []) + + print("\n" + "="*80) + print(f"Processing {len(all_audio_inputs)} audio(s)") + print("="*80) + + all_results = asr.transcribe_with_batching( + all_audio_inputs, + batch_size=args.batch_size, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_p=args.top_p, + do_sample=do_sample, + num_beams=args.num_beams, + ) + + # Print results + print("\n" + 
"="*80) + print("Results") + print("="*80) + for result in all_results: + print("\n" + "-"*60) + print_result(result) + + +if __name__ == "__main__": + main() diff --git a/docs/vibevoice-asr.md b/docs/vibevoice-asr.md new file mode 100644 index 0000000..5b2b130 --- /dev/null +++ b/docs/vibevoice-asr.md @@ -0,0 +1,62 @@ +# VibeVoice-ASR: Long-Form Rich Transcription with User Prompts + +**VibeVoice-ASR** is the latest addition to the **VibeVoice** family. While the original VibeVoice / VibeVoice-Realtime focused on expressive TTS, **VibeVoice-ASR** focuses on understanding long-form speech with high precision and rich metadata. + +It is a unified speech-to-text model designed to handle **1-hour long-form audio** in a single pass, generating structured transcriptions containing **Who (Speaker), When (Timestamps), and What (Content)**, with support for **User-Customized Context**. + +## πŸ”₯ Key Features + +- **πŸ•’ 60-min Single-Pass Processing**: + Unlike conventional ASR models that slice audio into short chunks (often losing global context), VibeVoice ASR accepts up to **60 minutes** of continuous audio input within 64K length. This ensures consistent speaker tracking and semantic coherence across the entire hour. + +- **πŸ‘€ Optional Context Injection**: + Users can provide customized context (e.g., specific names, technical terms, or background info) to guide the recognition process, significantly improving accuracy on domain-specific content. + +- **πŸ“ Rich Transcription (Who, When, What)**: + The model performs ASR, Diarization, and Timestamping simultaneously. The output is a structured sequence indicating *who* said *what* at *which time*. + +## πŸ—οΈ Model Architecture + +

+  <img src="../Figures/VibeVoice_ASR_archi.png" alt="VibeVoice ASR Architecture">

+ +## Installation +We recommend using the NVIDIA Deep Learning Container to manage the CUDA environment. + +1. Launch docker +```bash +# NVIDIA PyTorch Container 24.07 ~ 25.12 verified. +# Previous versions are also compatible. +sudo docker run --privileged --net=host --ipc=host --ulimit memlock=-1:-1 --ulimit stack=-1:-1 --gpus all --rm -it nvcr.io/nvidia/pytorch:25.12-py3 + +## If flash attention is not included in your docker environment, you need to install it manually +## Refer to https://github.com/Dao-AILab/flash-attention for installation instructions +# pip install flash-attn --no-build-isolation +``` + +2. Install from GitHub +```bash +git clone https://github.com/microsoft/VibeVoice.git +cd VibeVoice +pip install -e .[asr] +``` + +## Usages + +### Usage 1: Launch Gradio demo +```bash +apt update && apt install ffmpeg -y # for demo + +python demo/vibevoice_asr_gradio_demo.py --model_path microsoft/VibeVoice-ASR --share +``` + +### Usage 2: Inference from files directly +```bash +python demo/vibevoice_asr_inference_from_file.py --model_path microsoft/VibeVoice-ASR --audio_files [add an audio path here] +``` + + +## 📄 License + +This project is licensed under the [MIT License](../LICENSE). diff --git a/pyproject.toml b/pyproject.toml index a36e9cd..6dc69b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta" [project] name = "vibevoice" -version = "0.0.1" +version = "1.0.0" authors = [ - { name="vibevoice team", email="vibepod@microsoft.com" }, + { name="vibevoice team", email="VibeVoice@microsoft.com" }, ] -description = "A model for speech generation with an AR + diffusion architecture." +description = "Open-Source Frontier Voice AI."
class VibeVoiceASRConfig(PretrainedConfig):
    """Composite configuration for the VibeVoice ASR model.

    Holds three sub-configurations (acoustic tokenizer, semantic tokenizer and
    Qwen2 decoder) and exposes selected decoder attributes as read-only
    properties so code that expects a flat config keeps working.
    """

    model_type = "vibevoice"
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
        "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
    }
    # Default tensor parallel plan for base model `Qwen2`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    @staticmethod
    def _build_tokenizer_config(value, config_cls, model_type, arg_name):
        """Turn `value` (None, dict or `config_cls` instance) into a `config_cls` instance."""
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            # Work on a copy so the caller's dict is not modified.
            params = dict(value)
            params["model_type"] = model_type
            return config_cls(**params)
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"{arg_name} must be None, a dict or a {config_cls.__name__} instance, "
            f"got {type(value).__name__}"
        )

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        **kwargs
    ):
        """Build the composite config.

        Args:
            acoustic_tokenizer_config: None (use defaults), a dict of parameters,
                or a `VibeVoiceAcousticTokenizerConfig` instance.
            semantic_tokenizer_config: None (use defaults), a dict of parameters,
                or a `VibeVoiceSemanticTokenizerConfig` instance.
            decoder_config: None (use defaults), a dict whose "model_type" is
                "qwen2", or a `Qwen2Config` instance.
            **kwargs: Forwarded to `PretrainedConfig.__init__`.

        Raises:
            ValueError: If `decoder_config` is a dict with an unsupported "model_type".
            TypeError: If any sub-config argument has an unsupported type.
        """
        kwargs["_attn_implementation_autoset"] = False

        self.acoustic_tokenizer_config = self._build_tokenizer_config(
            acoustic_tokenizer_config,
            self.sub_configs["acoustic_tokenizer_config"],
            "vibevoice_acoustic_tokenizer",
            "acoustic_tokenizer_config",
        )
        self.semantic_tokenizer_config = self._build_tokenizer_config(
            semantic_tokenizer_config,
            self.sub_configs["semantic_tokenizer_config"],
            "vibevoice_semantic_tokenizer",
            "semantic_tokenizer_config",
        )

        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            # Only Qwen2 decoders are supported when building from a dict.
            decoder_type = decoder_config.get("model_type", '')
            if decoder_type != "qwen2":
                raise ValueError(f"Unsupported decoder model type: {decoder_type}")
            self.decoder_config = Qwen2Config(**decoder_config)
        elif isinstance(decoder_config, Qwen2Config):
            # If an instance of the config class is provided
            self.decoder_config = decoder_config
        else:
            raise TypeError(
                "decoder_config must be None, a dict or a Qwen2Config instance, "
                f"got {type(decoder_config).__name__}"
            )

        # Latent sizes of the two tokenizers, with fallbacks when the
        # sub-config does not define `vae_dim`.
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)

        super().__init__(**kwargs)

    def get_text_config(self, decoder: bool = False, **kwargs):
        """Return the text (decoder) config for generation.

        The decoder config is returned regardless of the selector arguments;
        extra keywords are accepted and ignored so callers that pass them do
        not fail.
        """
        return self.decoder_config

    @property
    def vocab_size(self):
        """Return vocab_size from decoder config for generation compatibility."""
        return self.decoder_config.vocab_size

    @property
    def num_attention_heads(self):
        """Return num_attention_heads from decoder config for Ulysses SP compatibility."""
        return self.decoder_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        """Return num_key_value_heads from decoder config for Ulysses SP compatibility."""
        return self.decoder_config.num_key_value_heads

    @property
    def hidden_size(self):
        """Return hidden_size from decoder config for model compatibility."""
        return self.decoder_config.hidden_size

    @property
    def num_hidden_layers(self):
        """Return num_hidden_layers from decoder config for Ulysses SP compatibility."""
        return self.decoder_config.num_hidden_layers

    @property
    def head_dim(self):
        """Return head_dim from decoder config, falling back to hidden_size // num_attention_heads."""
        return getattr(self.decoder_config, 'head_dim', self.hidden_size // self.num_attention_heads)
nn +import torch.nn.functional as F +import torch.distributed as dist + +from transformers.models.auto import AutoModel, AutoModelForCausalLM + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput +from transformers.models.llama.modeling_llama import LlamaRMSNorm +from transformers import modeling_utils +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.utils import logging + + +from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel +from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead +from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler + +from .configuration_vibevoice import VibeVoiceConfig + + +logger = logging.get_logger(__name__) + +if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None: + modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] + +@dataclass +class VibeVoiceCausalLMOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + diffusion_loss: Optional[torch.FloatTensor] = None + speech_token_num: Optional[int] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class VibeVoiceGenerationOutput(ModelOutput): + """ + Output type for VibeVoice generation. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. + speech_outputs (`List[torch.FloatTensor]`, *optional*): + List of generated speech waveforms or latents for each speech segment. 
class SpeechConnector(nn.Module):
    """Project speech features to the language-model width: Linear -> RMSNorm -> Linear."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Apply the projection to `features`; extra keyword arguments are ignored."""
        x = self.fc1(features)
        x = self.norm(x)
        x = self.fc2(x)
        return x


# @auto_docstring
class VibeVoicePreTrainedModel(PreTrainedModel):
    """Shared base class: capability flags and weight initialization for VibeVoice models."""

    config_class = VibeVoiceConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize `module`: the diffusion head initializes itself, Linear and LayerNorm are set here."""
        if isinstance(module, VibeVoiceDiffusionHead):
            module.initialize_weights()
            return

        # Use the language model's initializer_range if available
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm can be built without affine weight and/or bias.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()


# @auto_docstring
class VibeVoiceModel(VibeVoicePreTrainedModel):
    """Language model plus speech tokenizers, connectors, diffusion head and noise scheduler."""

    def __init__(self, config):
        super().__init__(config)

        # Resolve the dtype for the speech components; config.torch_dtype may
        # be a string such as "bfloat16" or a torch.dtype.
        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize Qwen2 model for language modeling
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)

        # Initialize speech components if needed
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)

        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
        self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)

        # Register scaling factors as buffers, initialized to NaN.
        # NOTE(review): the original comment asked for 1D tensors for FSDP
        # compatibility, but these are 0-dim scalars -- confirm against
        # existing checkpoints before changing the shape.
        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))

        # Initialize prediction head for speech generation
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)

        # Initialize noise scheduler
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type
        )

    def get_input_embeddings(self):
        """Return the token embedding of the language model.

        Raises:
            AssertionError: If the embedding cannot be located.
        """
        if hasattr(self.language_model, 'embed_tokens'):
            # If the language model has an embed_tokens attribute, return it
            return self.language_model.embed_tokens

        # parallel by nnscaler, the name is changed
        fullmap = getattr(self.language_model, 'fullmap', {})
        for name, attr in fullmap.items():
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)

        # Explicit raise instead of `assert False`, which is stripped under -O.
        raise AssertionError('should not arrive here')

    def set_input_embeddings(self, value):
        """Replace the token embedding of the language model with `value`."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Set the speech tokenizers used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer

        # Reset the encoder to evaluation mode
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.eval()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model; every argument is forwarded to it unchanged.

        Returns:
            The language model's raw output when `return_dict` resolves to
            False, otherwise a `BaseModelOutputWithPast` built from it.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through language model
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_decoder(self, decoder): + self.model.language_model = decoder + + def get_decoder(self): + return self.model.language_model + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. + """ + if getattr(self.config.decoder_config, 'tie_word_embeddings', False): + # The standard PreTrainedModel method will handle the tying. + # It typically does a simple parameter object assignment, which is + # CORRECT to do BEFORE FSDP wraps the model. + output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + if hasattr(input_embeddings, 'weight'): + output_embeddings.weight = input_embeddings.weight + else: + # maybe returned input_embeddings a tensor directly + output_embeddings.weight = input_embeddings + + if getattr(output_embeddings, "bias", None) is not None: + output_embeddings.bias.data = nn.functional.pad( + output_embeddings.bias.data, + (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]), + "constant", + 0, + ) + print("Tied input and output embeddings using standard assignment.") + else: + print("tie_word_embeddings is False, not tying weights.") + + # Also, ensure set_output_embeddings is safe, though your implementation looks okay. + # The key is to avoid calling it after accelerator.prepare(). + def set_output_embeddings(self, new_embeddings): + # Your current implementation using data.copy_ is good practice, + # but the best way is to not call this after prepare(). 
+ self.lm_head = new_embeddings + + def forward_speech_features( + self, + speech_tensors=None, + speech_masks=None, + speech_type="audio", + return_unmask=False + ): + if speech_tensors is None: + # Use config to get vae_dim instead of non-existent self.args + vae_dim = self.config.acoustic_tokenizer_config.vae_dim + audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight) + connect_features = self.model.acoustic_connector(audio_features) + return audio_features, connect_features + else: + with torch.no_grad(): + if speech_type == "audio": + with torch.no_grad(): + frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0] + audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0] + + elif speech_type == "vae": + # Use config to get vae_dim instead of non-existent self.args + vae_dim = self.config.acoustic_tokenizer_config.vae_dim + speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim) + + # gaussian sample from the speech_mode + batch_size = speech_mode.size(0) + value = self.model.acoustic_tokenizer.fix_std / 0.8 + std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value + std = std.view(-1, *[1] * (speech_mode.dim() - 1)) + audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode) + else: + raise NotImplementedError(f"Speech type {speech_type} not implemented") + + if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor): + scaling_factor = 1. 
/ audio_tokens[speech_masks].flatten().std() + bias_factor = -audio_tokens[speech_masks].flatten().mean() + + # Only use distributed operations if the process group is initialized + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM) + dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM) + world_size = dist.get_world_size() + self.model.speech_scaling_factor.copy_(scaling_factor / world_size) + self.model.speech_bias_factor.copy_(bias_factor / world_size) + print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True) + else: + # Single process case + self.model.speech_scaling_factor.copy_(scaling_factor) + self.model.speech_bias_factor.copy_(bias_factor) + print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True) + + audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor + + connect_features = self.model.acoustic_connector(audio_features) + if return_unmask: + return audio_features, connect_features + return audio_features[speech_masks], connect_features[speech_masks] + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # New arguments for speech processing and loss calculation + speech_tensors: Optional[torch.FloatTensor] = None, + speech_masks: Optional[torch.BoolTensor] = None, + speeches_loss_input: 
Optional[torch.FloatTensor] = None, + speech_semantic_tensors: Optional[torch.FloatTensor] = None, + acoustic_input_mask: Optional[torch.BoolTensor] = None, + acoustic_loss_mask: Optional[torch.BoolTensor] = None, + ddpm_batch_mul: int = 1, + **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]], + ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + x = self.get_input_embeddings()(input_ids) + + semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors) + if speeches_loss_input is not None: + # only part audio need diffuse + speech_all_features, speech_all_connect_features = self.forward_speech_features( + speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None, + speech_masks=speech_masks, + speech_type=kwargs.get("speech_type", "audio"), + return_unmask=True + ) + if speech_tensors is not None: + if semantic_speech_all_connect_features is not None: + x[acoustic_input_mask] = ( + speech_all_connect_features[speech_masks] + + semantic_speech_all_connect_features[speech_masks] + ) + else: + x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + + # Select only the target segments' latents for diffusion loss. + # Both masks are [num_segments, max_latent_len]; using 2D mask on [B,T,D] selects [N_true, D]. 
+ target_latent_mask = speeches_loss_input & speech_masks + speech_features = speech_all_features[target_latent_mask] + speech_connect_features = speech_all_connect_features[target_latent_mask] + else: + speech_features, speech_connect_features = self.forward_speech_features( + speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None, + speech_masks=speech_masks, + speech_type=kwargs.get("speech_type", "audio"), + ) + if speech_tensors is not None: + x[acoustic_input_mask] = speech_connect_features + + outputs = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=x, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=False, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + # logits = logits.float() + + loss = None + if labels is not None: + # The custom CE loss with masking is calculated in the training script. + # We leave the standard loss calculation here as None. + pass + + # --- Diffusion Loss Calculation --- + diffusion_loss = None + # This block is executed only if we are in a context that involves speech. 
+ if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0: + condition_features = hidden_states[acoustic_loss_mask] + + speech_len, latent_size = speech_features.shape + + noise = torch.randn( + (speech_len * ddpm_batch_mul, latent_size), + device=hidden_states.device, + dtype=hidden_states.dtype + ) + + timesteps = torch.multinomial( + torch.ones(self.config.diffusion_head_config.ddpm_num_steps), + speech_len * ddpm_batch_mul, + replacement=True, + ).to(hidden_states.device) + + speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0) + condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0) + + noisy_speech_features = self.model.noise_scheduler.add_noise( + speech_features_repeated, noise, timesteps + ) + + model_output = self.model.prediction_head( + noisy_speech_features, + timesteps.type_as(x), + condition_features_repeated + ) + + prediction_type = self.config.diffusion_head_config.prediction_type + if prediction_type == "epsilon": + target_for_loss = noise + elif prediction_type == "v_prediction": + target_for_loss = self.model.noise_scheduler.get_velocity( + speech_features_repeated, noise, timesteps + ) + else: + raise NotImplementedError(f"Prediction type {prediction_type} not implemented") + + diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum') + if latent_size > 0 and ddpm_batch_mul > 0: + diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul + else: + diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device) + + else: + # Dummy loss for DDP to work when there are no speech samples in a batch, + # but we are in a speech context. 
+ diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0 + diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0 + diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0 + # --- End Diffusion Loss Calculation --- + + if not return_dict: + output = (logits, speech_len) + outputs.to_tuple()[1:] + return (loss, diffusion_loss) + output + + return VibeVoiceCausalLMOutputWithPast( + loss=loss, + diffusion_loss=diffusion_loss, + speech_token_num=speech_len if speech_tensors is not None else 0, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +AutoModel.register(VibeVoiceConfig, VibeVoiceModel) +AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration) + +__all__ = [ + "VibeVoiceModel", + "VibeVoicePreTrainedModel", + "VibeVoiceForConditionalGeneration", + "VibeVoiceCausalLMOutputWithPast", + "VibeVoiceGenerationOutput", +] \ No newline at end of file diff --git a/vibevoice/modular/modeling_vibevoice_asr.py b/vibevoice/modular/modeling_vibevoice_asr.py new file mode 100644 index 0000000..706bf00 --- /dev/null +++ b/vibevoice/modular/modeling_vibevoice_asr.py @@ -0,0 +1,520 @@ +from typing import List, Optional, Tuple, Union +import torch +import torch.nn as nn + +from transformers.models.auto import AutoModel, AutoModelForCausalLM + +from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast +from transformers import modeling_utils +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation import GenerationMixin + +from .modular_vibevoice_tokenizer import ( + VibeVoiceTokenizerStreamingCache, + VibeVoiceTokenizerEncoderOutput +) + +from .configuration_vibevoice import VibeVoiceASRConfig +from .modeling_vibevoice import ( + VibeVoiceCausalLMOutputWithPast, + SpeechConnector 
+) + +logger = logging.get_logger(__name__) + +if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None: + modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] + +# @auto_docstring +class VibeVoiceASRPreTrainedModel(PreTrainedModel): + config_class = VibeVoiceASRConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + + # Use the language model's initializer_range if available + if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'): + std = self.config.language_model_config.initializer_range + elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'): + std = self.config.decoder_config.initializer_range + else: + std = 0.02 # Default value + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + +# @auto_docstring +class VibeVoiceASRModel(VibeVoiceASRPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if hasattr(config, 'torch_dtype') and config.torch_dtype is not None: + if isinstance(config.torch_dtype, str): + dtype = getattr(torch, config.torch_dtype) + else: + dtype = config.torch_dtype + else: + dtype = torch.float32 + + # Initialize Qwen2 model for language modeling + lm_config = config.decoder_config + self.language_model = AutoModel.from_config(lm_config) + + # Initialize speech components if needed + self.acoustic_tokenizer = 
AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype) + self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype) + + self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype) + self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype) + + def get_input_embeddings(self): + if hasattr(self.language_model, 'embed_tokens'): + # If the language model has an embed_tokens attribute, return it + return self.language_model.embed_tokens + + for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed + if attr.orig_name == 'embed_tokens.weight': + return getattr(self.language_model, name) + assert False, 'should not arrive here' + + def set_input_embeddings(self, value): + self.language_model.embed_tokens = value + + def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None): + """Set the speech tokenizers used for encoding and decoding speech.""" + self.acoustic_tokenizer = acoustic_tokenizer + self.semantic_tokenizer = semantic_tokenizer + + # Reset the encoder to evaluation mode + if self.acoustic_tokenizer is not None: + self.acoustic_tokenizer.eval() + + if self.semantic_tokenizer is not None: + self.semantic_tokenizer.eval() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Forward 
through language model + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + if not return_dict: + return outputs + + return BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class VibeVoiceASRForConditionalGeneration(VibeVoiceASRPreTrainedModel, GenerationMixin): + """ + VibeVoice model for Automatic Speech Recognition (ASR) with language modeling head for conditional generation. + This class is designed for inference and generation tasks. + """ + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = VibeVoiceASRModel(config) + self.vocab_size = config.decoder_config.vocab_size + + # Determine the dtype to use + if hasattr(config, 'torch_dtype') and config.torch_dtype is not None: + if isinstance(config.torch_dtype, str): + dtype = getattr(torch, config.torch_dtype) + else: + dtype = config.torch_dtype + else: + dtype = torch.float32 + + # Initialize lm_head with the correct dtype + self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False).to(dtype) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.language_model 
= decoder + + def get_decoder(self): + return self.model.language_model + + def tie_weights(self): + """Tie the weights between the input embeddings and the output embeddings.""" + if getattr(self.config.decoder_config, 'tie_word_embeddings', False): + output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + if hasattr(input_embeddings, 'weight'): + output_embeddings.weight = input_embeddings.weight + else: + output_embeddings.weight = input_embeddings + + def encode_speech( + self, + speech_tensors: torch.FloatTensor, + speech_masks: Optional[torch.BoolTensor] = None, + speech_semantic_tensors: Optional[torch.FloatTensor] = None, + streaming_segment_duration: float = 60.0, # seconds + ): + """ + Encode speech input into features that can be used by the language model. + This method is called once before generation to process the speech input. + + For long audio (>600s by default), uses streaming processing to avoid conv overflow (>2^32). + Segments are processed independently, then concatenated before final sampling. 
+ + Args: + speech_tensors: Input audio tensor [batch_size, samples] + speech_masks: Optional mask for speech features + speech_semantic_tensors: Optional pre-computed semantic tokens + streaming_segment_duration: Segment duration in seconds for streaming processing (default: 60s) + """ + if hasattr(self.config, 'torch_dtype') and self.config.torch_dtype is not None: + if isinstance(self.config.torch_dtype, str): + dtype = getattr(torch, self.config.torch_dtype) + else: + dtype = self.config.torch_dtype + else: + dtype = torch.float32 + + speech_tensors = speech_tensors.to(dtype) + + # Ensure proper shape: (batch, samples) + if speech_tensors.ndim == 1: + speech_tensors = speech_tensors.unsqueeze(0) + + batch_size, total_samples = speech_tensors.shape + sample_rate = 24000 # fix 24kHz sample rate + + # Calculate segment size in samples + segment_samples = int(streaming_segment_duration * sample_rate) + + # Decide whether to use streaming based on audio length + use_streaming = total_samples > segment_samples + + with torch.no_grad(): + if not use_streaming: + # Short audio: direct processing (original behavior) + encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1)) + audio_tokens = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0] + acoustic_features = self.model.acoustic_connector(audio_tokens) + + # Encode semantic features + if speech_semantic_tensors is not None: + semantic_features = self.model.semantic_connector(speech_semantic_tensors) + else: + semantic_tokens = self.model.semantic_tokenizer.encode(speech_tensors.unsqueeze(1)).mean + semantic_features = self.model.semantic_connector(semantic_tokens) + else: + # Long audio: streaming processing + # print(f"Using streaming processing for long audio: {total_samples/sample_rate:.1f}s " + # f"(segment size: {streaming_segment_duration}s)") + + # Initialize caches for both tokenizers + acoustic_encoder_cache = VibeVoiceTokenizerStreamingCache() + 
semantic_encoder_cache = VibeVoiceTokenizerStreamingCache() + acoustic_mean_segments = [] + semantic_mean_segments = [] + sample_indices = torch.arange(batch_size, device=speech_tensors.device) + + # Helper function from batch_asr_sft_cache.py + def _iter_segments(total_length: int, segment_length: int): + """Iterate over audio segments with a given segment length.""" + if segment_length <= 0: + raise ValueError("segment_length must be positive") + for start in range(0, total_length, segment_length): + end = min(start + segment_length, total_length) + if end > start: + yield start, end + + # Process each segment for both acoustic and semantic tokenizers + segments = list(_iter_segments(total_samples, segment_samples)) + num_segments = len(segments) + for seg_idx, (start, end) in enumerate(segments): + chunk = speech_tensors[:, start:end].contiguous() + if chunk.numel() == 0: + continue + + # Check if this is the final segment + is_final = (seg_idx == num_segments - 1) + + # Encode chunk for acoustic tokenizer (don't sample yet) + acoustic_encoder_output = self.model.acoustic_tokenizer.encode( + chunk.unsqueeze(1), + cache=acoustic_encoder_cache, + sample_indices=sample_indices, + use_cache=True, + is_final_chunk=is_final, + ) + acoustic_mean_segments.append(acoustic_encoder_output.mean) + + # Encode chunk for semantic tokenizer (take mean directly) + semantic_encoder_output = self.model.semantic_tokenizer.encode( + chunk.unsqueeze(1), + cache=semantic_encoder_cache, + sample_indices=sample_indices, + use_cache=True, + is_final_chunk=is_final, + ) + semantic_mean_segments.append(semantic_encoder_output.mean) + + # print(f"Processed {len(acoustic_mean_segments)} segments.") + # Concatenate all acoustic means and sample once + acoustic_mean_full = torch.cat(acoustic_mean_segments, dim=1).contiguous() + acoustic_encoder_output = VibeVoiceTokenizerEncoderOutput( + mean=acoustic_mean_full, + std=self.model.acoustic_tokenizer.fix_std + ) + audio_tokens = 
acoustic_encoder_output.sample( + dist_type=self.model.acoustic_tokenizer.std_dist_type + )[0] + acoustic_features = self.model.acoustic_connector(audio_tokens) + + # Concatenate all semantic means + semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous() + semantic_features = self.model.semantic_connector(semantic_tokens) + + # Combine acoustic and semantic features + if speech_masks is not None: + combined_features = acoustic_features[speech_masks] + semantic_features[speech_masks] + else: + combined_features = acoustic_features + semantic_features + + return combined_features + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # Speech-specific arguments + speech_tensors: Optional[torch.FloatTensor] = None, + speech_masks: Optional[torch.BoolTensor] = None, + speech_semantic_tensors: Optional[torch.FloatTensor] = None, + acoustic_input_mask: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutput]: + """ + Forward pass for the model. Handles both training and generation scenarios. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + + # Process inputs + if inputs_embeds is None and input_ids is not None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # If we have speech input and acoustic_input_mask, encode and insert speech features + if speech_tensors is not None and acoustic_input_mask is not None: + speech_features = self.encode_speech( + speech_tensors=speech_tensors, + speech_masks=speech_masks, + speech_semantic_tensors=speech_semantic_tensors, + ) + inputs_embeds[acoustic_input_mask] = speech_features + + # Forward through the model + outputs = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return 
VibeVoiceCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + speech_tensors=None, + speech_masks=None, + speech_semantic_tensors=None, + acoustic_input_mask=None, + **kwargs, + ): + """ + Prepare inputs for generation step. This method is called by generate() + for each token generation step. + + Following Qwen2-VL's approach: speech inputs are only forwarded on the first pass + (when cache_position[0] == 0), and are excluded in subsequent generation steps. + """ + # If we have past key values, we only need to process the new tokens + if past_key_values is not None: + if isinstance(past_key_values, tuple): + past_length = past_key_values[0][0].shape[2] + else: + past_length = past_key_values.get_seq_length() + + # Keep only the new tokens + if input_ids is not None and input_ids.shape[1] > past_length: + input_ids = input_ids[:, past_length:] + + # Prepare position ids + if position_ids is None and attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values is not None and input_ids is not None: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # Prepare cache position + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]), + device=input_ids.device if input_ids is not None else inputs_embeds.device + ) + + # Prepare model inputs + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: 
+ model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + + # Following Qwen2-VL pattern: only include speech inputs on the first forward pass + # (when cache_position[0] == 0), exclude them in subsequent generation steps + if cache_position is not None and len(cache_position) > 0 and cache_position[0] == 0: + # First forward pass - include speech inputs if provided + model_inputs.update({ + "speech_tensors": speech_tensors, + "speech_masks": speech_masks, + "speech_semantic_tensors": speech_semantic_tensors, + "acoustic_input_mask": acoustic_input_mask, + }) + else: + # Subsequent generation steps - exclude speech inputs + model_inputs.update({ + "speech_tensors": None, + "speech_masks": None, + "speech_semantic_tensors": None, + "acoustic_input_mask": None, + }) + + # Include any remaining kwargs that might be needed + model_inputs.update(kwargs) + + return model_inputs + +AutoModel.register(VibeVoiceASRConfig, VibeVoiceASRModel) +AutoModelForCausalLM.register(VibeVoiceASRConfig, VibeVoiceASRForConditionalGeneration) + +__all__ = [ + "VibeVoiceASRPreTrainedModel", + "VibeVoiceASRModel", + "VibeVoiceASRForConditionalGeneration", +] \ No newline at end of file diff --git a/vibevoice/modular/modular_vibevoice_text_tokenizer.py b/vibevoice/modular/modular_vibevoice_text_tokenizer.py index bfa7bdd..da5669f 100644 --- a/vibevoice/modular/modular_vibevoice_text_tokenizer.py +++ b/vibevoice/modular/modular_vibevoice_text_tokenizer.py @@ -207,8 +207,107 @@ class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast): """Id used for padding (returns -100 for loss masking).""" return self._pad_id +class VibeVoiceASRTextTokenizerFast(Qwen2TokenizerFast): + """ + Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library). 
+ Based on the Qwen2 tokenizer with additional special tokens for speech. + + Args: + vocab_file (`str`, *optional*): + Path to the vocabulary file. + merges_file (`str`, *optional*): + Path to the merges file. + tokenizer_file (`str`, *optional*): + Path to [tokenizers](https://github.com/huggingface/tokenizers) file. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. + bos_token (`str`, *optional*): + The beginning of sequence token. Not used for vibevoice. + eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The token used for padding. + """ + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + add_prefix_space=False, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + tokenizer_file=tokenizer_file, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + # Add VibeVoice-specific special tokens + self._add_vibevoice_special_tokens() + + # https://github.com/QwenLM/Qwen2.5-VL/blob/d2240f11656bfe404b9ba56db4e51cd09f522ff1/qwen-vl-finetune/qwenvl/data/data_qwen_packed.py#L57C5-L57C222 + self.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + + def _add_vibevoice_special_tokens(self): + """Add VibeVoice-specific special tokens.""" + special_tokens = { + "additional_special_tokens": [ + "<|object_ref_start|>", # Speech start (reusing vision tokens) + "<|object_ref_end|>", # Speech end + "<|box_start|>", # Speech diffusion pad + ] + } + 
num_added = self.add_special_tokens(special_tokens) + + # Cache special token IDs + self._speech_start_id = self.convert_tokens_to_ids("<|object_ref_start|>") + self._speech_end_id = self.convert_tokens_to_ids("<|object_ref_end|>") + self._speech_pad_id = self.convert_tokens_to_ids("<|box_start|>") + + self._eos_id = self.eos_token_id # qwen2 / qwen3 + self._pad_id = self.convert_tokens_to_ids('<|image_pad|>') + + return num_added + + @property + def eos_id(self) -> int: + """Id of the end of sequence token.""" + return self._eos_id + + @property + def speech_start_id(self) -> int: + """Id of the speech start token.""" + return self._speech_start_id + + @property + def speech_end_id(self) -> int: + """Id of the speech end token.""" + return self._speech_end_id + + @property + def speech_pad_id(self) -> int: + """Id of the speech diffusion token.""" + return self._speech_pad_id + + @property + def pad_id(self) -> int: + return self._pad_id + __all__ = [ "VibeVoiceTextTokenizer", "VibeVoiceTextTokenizerFast", + "VibeVoiceASRTextTokenizerFast", ] \ No newline at end of file diff --git a/vibevoice/modular/modular_vibevoice_tokenizer.py b/vibevoice/modular/modular_vibevoice_tokenizer.py index 0031b26..454f9c1 100644 --- a/vibevoice/modular/modular_vibevoice_tokenizer.py +++ b/vibevoice/modular/modular_vibevoice_tokenizer.py @@ -17,7 +17,7 @@ from transformers.utils import logging from transformers.modeling_utils import PreTrainedModel from transformers.activations import ACT2FN -from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig +from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig logger = logging.get_logger(__name__) @@ -26,14 +26,13 @@ import os try: from apex.normalization.fused_layer_norm import fused_rms_norm_affine APEX_AVAILABLE = True - logger.info("APEX FusedRMSNorm is available and will be used for optimization") + # logger.info("APEX FusedRMSNorm is available and will be used for 
optimization") if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0: APEX_AVAILABLE = False - logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0") + # logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0") except ImportError: APEX_AVAILABLE = False - logger.warning("APEX FusedRMSNorm not available, using native implementation") -# APEX_AVAILABLE=False + # logger.warning("APEX FusedRMSNorm not available, using native implementation") # Normalization modules class ConvLayerNorm(nn.LayerNorm): @@ -297,7 +296,8 @@ class SConv1d(nn.Module): cache: Optional[VibeVoiceTokenizerStreamingCache] = None, sample_indices: Optional[torch.Tensor] = None, use_cache: bool = False, - debug: bool = False) -> torch.Tensor: + debug: bool = False, + is_final_chunk: bool = False) -> torch.Tensor: """ Forward pass with optional streaming support via cache. @@ -307,6 +307,7 @@ class SConv1d(nn.Module): sample_indices: Indices identifying each sample for cache management use_cache: Whether to use cached states for streaming debug: Whether to print debug information + is_final_chunk: Whether this is the final chunk (adds extra padding for alignment) Returns: Output tensor @@ -322,12 +323,13 @@ class SConv1d(nn.Module): assert sample_indices is not None, "sample_indices must be provided for streaming mode" assert len(sample_indices) == B, "sample_indices must match batch size" - return self._forward_streaming(x, cache, sample_indices, debug) + return self._forward_streaming(x, cache, sample_indices, debug, is_final_chunk) def _forward_streaming(self, x: torch.Tensor, cache: VibeVoiceTokenizerStreamingCache, sample_indices: torch.Tensor, - debug: bool = False) -> torch.Tensor: + debug: bool = False, + is_final_chunk: bool = False) -> torch.Tensor: """Streaming forward pass with cache operations kept separate from compiled code""" B, C, T = x.shape @@ -350,6 +352,16 @@ class SConv1d(nn.Module): input_with_context = 
torch.cat([cached_states, x], dim=2) else: input_with_context = x + + # For final chunk, add extra padding to ensure ceil behavior (same as non-streaming) + if is_final_chunk: + extra_padding = get_extra_padding_for_conv1d( + input_with_context, self.kernel_size, self.stride, self.padding_total + ) + if extra_padding > 0: + input_with_context = pad1d(input_with_context, (0, extra_padding), mode=self.pad_mode) + if debug: + print(f"[DEBUG] Final chunk: added extra_padding={extra_padding}") if debug: print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}") @@ -684,6 +696,135 @@ class Block1D(nn.Module): return x +class TokenizerEncoder(nn.Module): + """ + Encoder component for the VibeVoice tokenizer that converts audio to latent representations. + + Args: + config: Configuration object with model parameters + """ + def __init__(self, config): + super().__init__() + + # Extract parameters from config + self.channels = config.channels + self.dimension = config.dimension + self.n_filters = config.n_filters + self.ratios = list(reversed(config.ratios)) + self.depths = config.depths + self.n_residual_layers = getattr(config, "n_residual_layers", 1) + self.hop_length = np.prod(self.ratios) + self.causal = config.causal + + # Additional config parameters with defaults + kernel_size = getattr(config, "kernel_size", 7) + last_kernel_size = getattr(config, "last_kernel_size", 7) + norm = getattr(config, "norm", "none") + norm_params = getattr(config, "norm_params", {}) + pad_mode = getattr(config, "pad_mode", "reflect") + bias = getattr(config, "bias", True) + layernorm = getattr(config, "layernorm", "LN") + layernorm_eps = getattr(config, "layernorm_eps", 1e-6) + layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True) + drop_path_rate = getattr(config, "drop_path_rate", 0.0) + mixer_layer = getattr(config, "mixer_layer", "conv") + layer_scale_init_value = getattr(config, 
"layer_scale_init_value", 0) + disable_last_norm = getattr(config, "disable_last_norm", False) + + # determine the norm type based on layernorm + if layernorm == 'LN': + norm_type = ConvLayerNorm + elif layernorm == 'RMSNorm': + norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine) + else: + raise ValueError(f"Unsupported norm type: {layernorm}") + + # stem and intermediate downsampling conv layers + stem = nn.Sequential( + SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias), + ) + + self.downsample_layers = nn.ModuleList() + self.downsample_layers.append(stem) + for i in range(len(self.ratios)): + in_ch = self.n_filters * (2 ** i) + out_ch = self.n_filters * (2 ** (i + 1)) + downsample_layer = nn.Sequential( + SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias) + ) + self.downsample_layers.append(downsample_layer) + + # configure the transformer blocks + layer_type = partial( + Block1D, + mixer_layer=mixer_layer, + layernorm=layernorm, + eps=layernorm_eps, + causal=self.causal, + pad_mode=pad_mode, + norm=norm, + bias=bias, + layer_scale_init_value=layer_scale_init_value, + ) + + self.stages = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + + for i in range(len(self.depths)): + in_ch = self.n_filters * (2 ** i) + stage = nn.Sequential( + *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])] + ) + self.stages.append(stage) + cur += self.depths[i] + + if not disable_last_norm: + self.norm = norm_type(in_ch, eps=layernorm_eps) + else: + self.norm = nn.Identity() + self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias) + + def forward_features(self, x, cache=None, sample_indices=None, 
@dataclass
class VibeVoiceTokenizerEncoderOutput:
    """
    Output of a VibeVoice tokenizer encoder: a mean tensor plus an optional
    fixed standard deviation.

    Args:
        mean (`torch.Tensor`): The mean parameters of the distribution.
        std (`float` or `torch.Tensor`, *optional*): Fixed standard deviation.
            When `None`, sampling is deterministic and returns `mean`.
    """
    mean: torch.Tensor
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type='fix'):
        """
        Draw a sample from the distribution.

        Args:
            dist_type (`str`): `'fix'` adds noise scaled by `std`; `'gaussian'`
                draws one noise scale per batch element from
                `randn * (std / 0.8)`; any other value returns the mean.

        Returns:
            `tuple`: `(sample, std_used)`. `std_used` is `self.std` for every
            mode except `'gaussian'`, where it is the per-batch scale tensor.
            When `self.std` is `None` the result is `(self.mean, None)`.
        """
        # Without a standard deviation there is nothing to scale the noise by;
        # the arithmetic below would raise a TypeError on None.
        if self.std is None:
            return self.mean, self.std

        if dist_type == 'fix':
            noisy = self.mean + self.std * torch.randn_like(self.mean)
            return noisy, self.std

        if dist_type == 'gaussian':
            batch_size = self.mean.size(0)
            scale = self.std / 0.8
            std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * scale

            # Append singleton axes so the per-batch scale broadcasts over mean.
            while std.dim() < self.mean.dim():
                std = std.unsqueeze(-1)

            noisy = self.mean + std * torch.randn_like(self.mean)
            return noisy, std

        return self.mean, self.std

    def kl(self):
        """
        Return the element-wise squared mean (MSE of `mean` against zeros).

        NOTE(review): presumably used as the mean-dependent part of a KL
        penalty towards a standard normal; `std` does not enter the result.
        """
        target = torch.zeros_like(self.mean)
        return F.mse_loss(self.mean, target, reduction='none')

    def mode(self):
        """Return the distribution mode, i.e. the mean."""
        return self.mean
list(reversed(encoder_depths)) + # Create encoder config + encoder_config = copy.deepcopy(config) + encoder_config.dimension = config.vae_dim + encoder_config.n_filters = config.encoder_n_filters + encoder_config.ratios = config.encoder_ratios + encoder_config.depths = encoder_depths + encoder_config.norm = config.conv_norm + encoder_config.pad_mode = config.pad_mode + encoder_config.bias = config.conv_bias + encoder_config.layernorm_eps = config.layernorm_eps + encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine + encoder_config.mixer_layer = config.mixer_layer + encoder_config.layer_scale_init_value = config.layer_scale_init_value + encoder_config.disable_last_norm = config.disable_last_norm + # Create decoder config decoder_config = copy.deepcopy(config) decoder_config.dimension = config.vae_dim @@ -865,6 +1069,8 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel): decoder_config.layer_scale_init_value = config.layer_scale_init_value decoder_config.disable_last_norm = config.disable_last_norm + # Initialize encoder and decoder + self.encoder = TokenizerEncoder(encoder_config) self.decoder = TokenizerDecoder(decoder_config) # Initialize weights @@ -884,6 +1090,24 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel): if module.bias is not None: nn.init.zeros_(module.bias) + @torch.no_grad() + def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False): + """Convert audio to latent representations""" + latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk) + return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std) + + @torch.no_grad() + def sampling(self, encoder_output, dist_type=None): + """Sample from the encoder output distribution""" + dist_type = dist_type or self.std_dist_type + + if dist_type == 'fix': + return encoder_output.sample(dist_type='fix') + elif 
class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
    """VibeVoice speech tokenizer model with only an encoder, for semantic tokens"""

    config_class = VibeVoiceSemanticTokenizerConfig
    base_model_prefix = "vibevoice_semantic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        super().__init__(config)

        # encoder_depths may be given as a list or as a dash-separated string.
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Build the encoder config from a copy so the model config is untouched.
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # This model has no decoder; only the encoder is instantiated.
        self.encoder = TokenizerEncoder(encoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear, LayerNorm and Conv1d parameters of the model."""
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False has no parameters.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
        """
        Convert audio to latent representations.

        All arguments are forwarded to the encoder. The encoder output has its
        last two axes swapped before being wrapped; no standard deviation is
        attached, so the returned distribution carries only a mean.
        """
        latents = self.encoder(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
            is_final_chunk=is_final_chunk,
        )
        return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """
        Return ``encoder_output.sample(dist_type='none')``.

        ``dist_type`` is accepted for signature parity with the acoustic
        tokenizer and is ignored.
        """
        return encoder_output.sample(dist_type='none')

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
        """
        Encode audio to latents.

        Returns:
            tuple: ``(None, latents)``. The first slot is always ``None``
            because this model has no decoder and therefore no reconstruction.
            ``encode`` runs under ``torch.no_grad()``, so no gradients flow
            through this call.
        """
        encoder_output = self.encode(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
            is_final_chunk=is_final_chunk,
        )
        sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
        return None, sampled_latents
import numpy as np
from subprocess import run
from typing import List, Optional, Union, Dict, Any

COMMON_AUDIO_EXTS = [
    '.mp3', '.MP3', '.Mp3',  # All case variations of mp3
    '.m4a',
    '.mp4', '.MP4',
    '.wav', '.WAV',
    '.m4v',
    '.aac',
    '.ogg',
    '.mov', '.MOV',
    '.opus',
    '.m4b',
    '.flac',
    '.wma', '.WMA',
    '.rm', '.3gp', '.mpeg', '.flv', '.webm', '.mp2', '.aif', '.aiff', '.oga', '.ogv', '.mpga', '.m3u8', '.amr'
]


def parse_ffprobe_sample_rate(probe_output: Union[bytes, str]) -> int:
    """
    Extract the first integer sample rate from ffprobe's plain-text output.

    ffprobe may print several lines (one per stream) or placeholders such as
    blank lines, so the first all-digit line is used.

    Raises
    ------
    ValueError
        If no line of the output is an integer.
    """
    text = probe_output.decode() if isinstance(probe_output, bytes) else probe_output
    for line in text.splitlines():
        token = line.strip()
        if token.isdigit():
            return int(token)
    raise ValueError(f"Could not determine sample rate from ffprobe output: {text!r}")


def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24000):
    """
    Open an audio file with ffmpeg and read it as a mono waveform.

    Parameters
    ----------
    file: str
        The audio file to open
    resample: bool
        Whether to resample the audio to ``target_sr``
    target_sr: int
        The target sample rate if resampling is requested

    Returns
    -------
    A tuple containing:
    - A NumPy array with the audio waveform in float32 dtype, scaled by 1/32768
    - The sample rate of the returned waveform: ``target_sr`` when
      ``resample`` is True, otherwise the file's own rate as reported by ffprobe
    """
    if not resample:
        # Ask ffprobe for the rate of the first audio stream only; without the
        # stream selector, containers with several streams yield several lines.
        cmd_probe = [
            "ffprobe",
            "-v", "quiet",
            "-select_streams", "a:0",
            "-show_entries", "stream=sample_rate",
            "-of", "default=noprint_wrappers=1:nokey=1",
            file
        ]
        probe_stdout = run(cmd_probe, capture_output=True, check=True).stdout
        sr_to_use = parse_ffprobe_sample_rate(probe_stdout)
    else:
        sr_to_use = target_sr

    # Decode to raw signed 16-bit little-endian mono PCM on stdout.
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr_to_use),
        "-"
    ]

    out = run(cmd, capture_output=True, check=True).stdout
    audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0

    return audio_data, sr_to_use


class AudioNormalizer:
    """
    Audio normalization class for VibeVoice tokenizer.

    This class provides audio normalization to ensure consistent input levels
    for the VibeVoice tokenizer while maintaining audio quality.
    """

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
        """
        Initialize the audio normalizer.

        Args:
            target_dB_FS (float): Target dB FS level for the audio. Default: -25
            eps (float): Small value to avoid division by zero. Default: 1e-6
        """
        self.target_dB_FS = target_dB_FS
        self.eps = eps

    def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
        """
        Adjust the audio to the target dB FS level.

        Args:
            audio (np.ndarray): Input audio signal

        Returns:
            tuple: (normalized_audio, rms, scalar)
        """
        rms = np.sqrt(np.mean(audio**2))
        # eps keeps the gain finite for silent input (rms == 0).
        scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
        normalized_audio = audio * scalar
        return normalized_audio, rms, scalar

    def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
        """
        Avoid clipping by scaling down if necessary.

        Args:
            audio (np.ndarray): Input audio signal
            scalar (float, optional): Explicit divisor; when omitted, the peak
                magnitude (plus eps) is used if it exceeds 1.0, else 1.0

        Returns:
            tuple: (normalized_audio, scalar)
        """
        if scalar is None:
            max_val = np.max(np.abs(audio))
            if max_val > 1.0:
                scalar = max_val + self.eps
            else:
                scalar = 1.0

        return audio / scalar, scalar

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """
        Normalize the audio by adjusting to target dB FS and avoiding clipping.

        Args:
            audio (np.ndarray): Input audio signal

        Returns:
            np.ndarray: Normalized audio signal
        """
        # First adjust to target dB FS
        audio, _, _ = self.tailor_dB_FS(audio)
        # Then avoid clipping
        audio, _ = self.avoid_clipping(audio)
        return audio
+""" + +import os +import json +import math +import warnings +from typing import List, Optional, Union, Dict, Any, Tuple + +import numpy as np +import torch + +from transformers.tokenization_utils_base import BatchEncoding +from transformers.utils import TensorType, logging +from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor, AudioNormalizer + +try: + from .audio_utils import load_audio_use_ffmpeg + HAS_FFMPEG_UTILS = True +except ImportError: + HAS_FFMPEG_UTILS = False + warnings.warn("audio_utils not available, will fall back to soundfile for audio loading") + +logger = logging.get_logger(__name__) + +SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format." + + +class VibeVoiceASRProcessor: + """ + Processor for VibeVoice ASR (Automatic Speech Recognition) models. + + This processor handles audio preprocessing and tokenization for ASR tasks, + following the exact format used in training with proper chat templates. + + Args: + tokenizer: The text tokenizer for processing text + audio_processor: The audio processor for processing speech + speech_tok_compress_ratio (int): Compression ratio for speech tokenization + target_sample_rate (int): Target sample rate for audio + normalize_audio (bool): Whether to normalize audio input + """ + + def __init__( + self, + tokenizer=None, + audio_processor=None, + speech_tok_compress_ratio=320, + target_sample_rate=24000, + normalize_audio=True, + **kwargs + ): + self.tokenizer = tokenizer + self.audio_processor = audio_processor or VibeVoiceTokenizerProcessor( + sampling_rate=target_sample_rate, + normalize_audio=normalize_audio + ) + self.speech_tok_compress_ratio = speech_tok_compress_ratio + self.target_sample_rate = target_sample_rate + self.normalize_audio = normalize_audio + + if normalize_audio: + self.audio_normalizer = AudioNormalizer() + else: + self.audio_normalizer = None + + # Cache special token IDs + self._cache_special_tokens() + + def 
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """
    Load processor from a pretrained model path.

    ``preprocessor_config.json`` is read from the local directory when it
    exists there, otherwise it is resolved through ``cached_file``. If that
    fails, a warning is logged and default settings are used.

    Args:
        pretrained_model_name_or_path: Path to the pretrained model
        **kwargs: Additional keyword arguments. ``language_model_pretrained_name``
            names the tokenizer to load when the config does not specify one;
            everything else is forwarded to ``cached_file`` and to the
            tokenizer's ``from_pretrained``.

    Returns:
        VibeVoiceASRProcessor: The loaded processor

    Raises:
        ValueError: If the resolved tokenizer name does not contain "qwen".
    """
    from transformers.utils import cached_file
    from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast

    # Take our own option out of kwargs before they are forwarded anywhere,
    # so it never reaches cached_file or the tokenizer loader.
    tokenizer_name_override = kwargs.pop("language_model_pretrained_name", None)

    # Try to load configuration
    config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
    config = {}

    if os.path.exists(config_path):
        with open(config_path, 'r') as f:
            config = json.load(f)
    else:
        # Best effort: a missing or unreadable remote config falls back to
        # the defaults below instead of aborting the load.
        try:
            config_file = cached_file(
                pretrained_model_name_or_path,
                "preprocessor_config.json",
                **kwargs
            )
            with open(config_file, 'r') as f:
                config = json.load(f)
        except Exception as e:
            logger.warning(f"Could not load preprocessor_config.json: {e}")
            logger.warning("Using default configuration")

    # Extract parameters
    speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
    target_sample_rate = config.get("target_sample_rate", 24000)
    normalize_audio = config.get("normalize_audio", True)

    # Tokenizer name precedence: config value, then the caller's override,
    # then the built-in default.
    language_model_pretrained_name = (
        config.get("language_model_pretrained_name", None)
        or tokenizer_name_override
        or "Qwen/Qwen2.5-1.5B"
    )
    logger.info(f"Loading tokenizer from {language_model_pretrained_name}")

    if 'qwen' not in language_model_pretrained_name.lower():
        raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}")

    tokenizer = VibeVoiceASRTextTokenizerFast.from_pretrained(
        language_model_pretrained_name,
        **kwargs
    )

    # Load audio processor
    audio_processor = VibeVoiceTokenizerProcessor(
        sampling_rate=target_sample_rate,
        normalize_audio=normalize_audio,
        target_dB_FS=config.get("target_dB_FS", -25),
        eps=config.get("eps", 1e-6),
    )

    return cls(
        tokenizer=tokenizer,
        audio_processor=audio_processor,
        speech_tok_compress_ratio=speech_tok_compress_ratio,
        target_sample_rate=target_sample_rate,
        normalize_audio=normalize_audio,
    )
def _process_single_audio(
    self,
    audio: Union[str, np.ndarray, torch.Tensor],
    sampling_rate: Optional[int] = None,
    add_generation_prompt: bool = True,
    use_streaming: bool = True,
    context_info: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Process a single audio input.

    Args:
        audio: Single audio input (file path, NumPy array or torch tensor)
        sampling_rate: Sampling rate of an array/tensor input. When given and
            different from ``self.target_sample_rate`` the audio is resampled;
            when ``None`` the input is taken to be at the target rate already.
            Ignored for file paths, whose rate is read from the file.
        add_generation_prompt: Accepted for interface compatibility; not used
            in this method.
        use_streaming: Accepted for interface compatibility; the token count
            below uses ``ceil`` regardless of this flag.
        context_info: Optional context information (e.g., hotwords, metadata) to help transcription

    Returns:
        Dictionary with ``input_ids``, ``acoustic_input_mask``, ``speech``
        (float32 waveform) and ``vae_tok_len``
    """
    if isinstance(audio, str):
        # Prefer ffmpeg for its wider format support; soundfile is the fallback
        # both when ffmpeg helpers are missing and when ffmpeg itself fails.
        audio_array = None
        source_sr = None
        if HAS_FFMPEG_UTILS:
            try:
                audio_array, source_sr = load_audio_use_ffmpeg(audio, resample=False)
            except Exception as e:
                warnings.warn(f"ffmpeg loading failed, falling back to soundfile: {e}")
        if audio_array is None:
            import soundfile as sf
            audio_array, source_sr = sf.read(audio)
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)  # Convert to mono
    else:
        if isinstance(audio, torch.Tensor):
            audio_array = audio.detach().cpu().numpy()
        else:
            audio_array = np.array(audio, dtype=np.float32)
        if audio_array.ndim > 1:
            audio_array = audio_array.squeeze()
        # Resampling needs floating-point samples.
        audio_array = audio_array.astype(np.float32)
        source_sr = sampling_rate

    # Bring the waveform to the rate the duration and token count assume.
    if source_sr is not None and source_sr != self.target_sample_rate:
        import librosa
        audio_array = librosa.resample(
            audio_array,
            orig_sr=source_sr,
            target_sr=self.target_sample_rate
        )

    # Ensure float32
    audio_array = audio_array.astype(np.float32)

    # Normalize if needed
    if self.normalize_audio and self.audio_normalizer:
        audio_array = self.audio_normalizer(audio_array)

    # Calculate audio duration
    audio_duration = len(audio_array) / self.target_sample_rate

    # One speech token per `speech_tok_compress_ratio` samples, rounded up.
    vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio)

    # Build token sequence following training format
    # 1. System prompt - render the chat template as text, then encode it
    system_prompt_text = self.tokenizer.apply_chat_template(
        [{"role": "system", "content": SYSTEM_PROMPT}],
        tokenize=False
    )
    system_tokens = self.tokenizer.encode(system_prompt_text)

    # 2. User input with speech tokens
    # Build speech placeholder string
    sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id)
    sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id)
    sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id)

    # User suffix with audio duration info
    show_keys = ['Start time', 'End time', 'Speaker ID', 'Content']
    if context_info and context_info.strip():
        user_suffix = f"This is a {audio_duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\nPlease transcribe it with these keys: " + ", ".join(show_keys)
    else:
        user_suffix = f"This is a {audio_duration:.2f} seconds audio, please transcribe it with these keys: " + ", ".join(show_keys)

    user_input_string = ''.join(
        [sp_start_token] + [sp_pad_token] * vae_tok_len + [sp_end_token]
    ) + '\n' + user_suffix

    user_tokens = self.tokenizer.apply_chat_template(
        [{"role": "user", "content": user_input_string}],
        tokenize=True
    )

    # Combine tokens
    full_tokens = system_tokens + user_tokens

    # Mark the positions that hold speech placeholder tokens.
    acoustic_input_mask = [1 if token == self.speech_pad_id else 0 for token in full_tokens]

    return {
        "input_ids": full_tokens,
        "acoustic_input_mask": acoustic_input_mask,
        "speech": audio_array,
        "vae_tok_len": vae_tok_len,
    }
Any]], + padding: bool = True, + max_length: Optional[int] = None, + truncation: bool = False, + return_tensors: Optional[str] = None, + ) -> BatchEncoding: + """ + Combine multiple encodings into a batch. + + Args: + encodings: List of encoded samples + padding: Whether to pad sequences + max_length: Maximum sequence length + truncation: Whether to truncate + return_tensors: Output format + + Returns: + BatchEncoding with batched data + """ + # Extract components + input_ids_list = [enc["input_ids"] for enc in encodings] + acoustic_masks_list = [enc["acoustic_input_mask"] for enc in encodings] + speech_list = [enc["speech"] for enc in encodings] + vae_tok_lens = [enc["vae_tok_len"] for enc in encodings] + + # Determine max length for padding + if padding: + if max_length is not None: + target_length = max_length + else: + target_length = max(len(ids) for ids in input_ids_list) + + # Pad sequences + padded_input_ids = [] + padded_acoustic_masks = [] + attention_masks = [] + + for input_ids, acoustic_mask in zip(input_ids_list, acoustic_masks_list): + # Truncate if needed + if truncation and len(input_ids) > target_length: + input_ids = input_ids[:target_length] + acoustic_mask = acoustic_mask[:target_length] + + # Pad sequences to left (for autoregressive generation) + padding_length = target_length - len(input_ids) + padded_ids = [self.pad_id] * padding_length + input_ids + padded_acoustic = [0] * padding_length + acoustic_mask + attention_mask = [0] * padding_length + [1] * len(input_ids) + + padded_input_ids.append(padded_ids) + padded_acoustic_masks.append(padded_acoustic) + attention_masks.append(attention_mask) + + input_ids_list = padded_input_ids + acoustic_masks_list = padded_acoustic_masks + else: + attention_masks = [[1] * len(ids) for ids in input_ids_list] + + # Process speech tensors - raw audio is 1D, so we keep it as is + max_speech_length = max(len(s) for s in speech_list) + padded_speeches = np.zeros((len(speech_list), max_speech_length), 
dtype=np.float32) + speech_masks = np.zeros((len(speech_list), max(vae_tok_lens)), dtype=bool) + + for i, (speech, vae_len) in enumerate(zip(speech_list, vae_tok_lens)): + padded_speeches[i, :len(speech)] = speech + speech_masks[i, :vae_len] = True + + # Create batch encoding + batch_encoding = BatchEncoding() + + if return_tensors == "pt": + batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long) + batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long) + batch_encoding["acoustic_input_mask"] = torch.tensor(acoustic_masks_list, dtype=torch.bool) + batch_encoding["speech_tensors"] = torch.tensor(padded_speeches, dtype=torch.float32) + batch_encoding["speech_masks"] = torch.tensor(speech_masks, dtype=torch.bool) + # Note: vae_tok_seqlens and speech_type are not included as they are not model inputs + else: + batch_encoding["input_ids"] = input_ids_list if len(input_ids_list) > 1 else input_ids_list[0] + batch_encoding["attention_mask"] = attention_masks if len(attention_masks) > 1 else attention_masks[0] + batch_encoding["acoustic_input_mask"] = acoustic_masks_list if len(acoustic_masks_list) > 1 else acoustic_masks_list[0] + batch_encoding["speech_tensors"] = padded_speeches if len(padded_speeches) > 1 else padded_speeches[0] + batch_encoding["speech_masks"] = speech_masks if len(speech_masks) > 1 else speech_masks[0] + + return batch_encoding + + def batch_decode(self, *args, **kwargs): + """ + Decode batch of token IDs to text. + Forwards to tokenizer's batch_decode method. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + Decode token IDs to text. + Forwards to tokenizer's decode method. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_transcription(self, text: str) -> List[Dict[str, Any]]: + """ + Post-process the generated transcription text to extract structured data. 
+ + Args: + text: Generated text from the model + + Returns: + List of dictionaries with transcription segments + """ + try: + # Try to parse as JSON + if "```json" in text: + # Extract JSON from markdown code block + json_start = text.find("```json") + 7 + json_end = text.find("```", json_start) + json_str = text[json_start:json_end].strip() + else: + # Try to find JSON array or object + json_start = text.find("[") + if json_start == -1: + json_start = text.find("{") + if json_start != -1: + # Find matching closing bracket + bracket_count = 0 + json_end = json_start + for i in range(json_start, len(text)): + if text[i] in "[{": + bracket_count += 1 + elif text[i] in "]}": + bracket_count -= 1 + if bracket_count == 0: + json_end = i + 1 + break + json_str = text[json_start:json_end] + else: + json_str = text + + # Parse JSON + result = json.loads(json_str) + + # Ensure it's a list + if isinstance(result, dict): + result = [result] + + # Validate and clean up the result + cleaned_result = [] + for item in result: + if isinstance(item, dict): + cleaned_item = {} + # Map keys to expected format + key_mapping = { + "Start time": "start_time", + "Start": "start_time", + "End time": "end_time", + "End": "end_time", + "Speaker ID": "speaker_id", + "Speaker": "speaker_id", + "Content": "text", + } + for key, mapped_key in key_mapping.items(): + if key in item: + cleaned_item[mapped_key] = item[key] + + if cleaned_item: + cleaned_result.append(cleaned_item) + + return cleaned_result + + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse JSON from transcription: {e}") + logger.debug(f"Raw text: {text}") + return [] + except Exception as e: + logger.warning(f"Error post-processing transcription: {e}") + return [] + + @property + def model_input_names(self): + """Return the list of inputs accepted by the model.""" + return ["input_ids", "attention_mask", "acoustic_input_mask", "speech_tensors", "speech_masks"] + +__all__ = ["VibeVoiceASRProcessor"] diff --git 
class AudioNormalizer:
    """
    Audio normalization class for VibeVoice tokenizer.

    Scales a signal so its RMS level matches a target dB FS value, then
    scales it down again if any sample would exceed full scale. Zero-length
    input is passed through unchanged instead of raising.
    """

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
        """
        Initialize the audio normalizer.

        Args:
            target_dB_FS (float): Target dB FS level for the audio. Default: -25
            eps (float): Small value to avoid division by zero. Default: 1e-6
        """
        self.target_dB_FS = target_dB_FS
        self.eps = eps

    def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
        """
        Adjust the audio to the target dB FS level.

        Args:
            audio (np.ndarray): Input audio signal

        Returns:
            tuple: (normalized_audio, rms, scalar)
        """
        # An empty signal has no mean; treat its level as silence (rms 0)
        # rather than propagating NaN.
        rms = np.sqrt(np.mean(audio**2)) if np.size(audio) else 0.0
        scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
        normalized_audio = audio * scalar
        return normalized_audio, rms, scalar

    def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
        """
        Avoid clipping by scaling down if necessary.

        Args:
            audio (np.ndarray): Input audio signal
            scalar (float, optional): Explicit scaling factor; when omitted it
                is derived from the peak absolute sample

        Returns:
            tuple: (normalized_audio, scalar)
        """
        if scalar is None:
            # np.max raises on zero-size input, so an empty signal has peak 0
            max_val = np.max(np.abs(audio)) if np.size(audio) else 0.0
            if max_val > 1.0:
                scalar = max_val + self.eps
            else:
                scalar = 1.0

        return audio / scalar, scalar

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """
        Normalize the audio by adjusting to target dB FS and avoiding clipping.

        Args:
            audio (np.ndarray): Input audio signal

        Returns:
            np.ndarray: Normalized audio signal
        """
        # First adjust to target dB FS
        audio, _, _ = self.tailor_dB_FS(audio)
        # Then avoid clipping
        audio, _ = self.avoid_clipping(audio)
        return audio