diff --git a/Figures/VibeVoice_ASR_archi.png b/Figures/VibeVoice_ASR_archi.png
new file mode 100644
index 0000000..0d24aff
Binary files /dev/null and b/Figures/VibeVoice_ASR_archi.png differ
diff --git a/README.md b/README.md
index 4c4e68a..41a98de 100644
--- a/README.md
+++ b/README.md
@@ -23,8 +23,9 @@
+2026-01-21: 📣 We open-sourced VibeVoice-ASR, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context.
-2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time.
+2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time.
2025-12-09: 📣 We added experimental speakers in nine languages (DE, FR, IT, JP, KR, NL, PL, PT, ES) for exploration—welcome to try them out and share your feedback.
@@ -123,4 +124,4 @@ We do not recommend using VibeVoice in commercial or real-world applications wit
## Star History
-
\ No newline at end of file
+
diff --git a/demo/vibevoice_asr_gradio_demo.py b/demo/vibevoice_asr_gradio_demo.py
new file mode 100644
index 0000000..c334f95
--- /dev/null
+++ b/demo/vibevoice_asr_gradio_demo.py
@@ -0,0 +1,1177 @@
+#!/usr/bin/env python
+"""
+VibeVoice ASR Gradio Demo
+"""
+
+import os
+import sys
+import torch
+import numpy as np
+import soundfile as sf
+from pathlib import Path
+import argparse
+import time
+import json
+import gradio as gr
+from typing import List, Dict, Tuple, Optional, Generator
+import tempfile
+import base64
+import io
+import traceback
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+# Import TextIteratorStreamer for streaming generation
+from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
+
+# Optional acceleration: patch the Qwen2 layers with Liger kernels when the
+# package is installed. Any failure is reported and the demo carries on
+# with the stock implementation.
+try:
+    from liger_kernel.transformers import apply_liger_kernel_to_qwen2
+    # Only apply RoPE, RMSNorm, SwiGLU patches (these affect the underlying Qwen2 layers)
+    apply_liger_kernel_to_qwen2(
+        rope=True,
+        rms_norm=True,
+        swiglu=True,
+        cross_entropy=False,
+    )
+    print("✅ Liger Kernel applied to Qwen2 components (RoPE, RMSNorm, SwiGLU)")
+except Exception as e:
+    # Deliberately broad: a missing or incompatible package must not stop the demo.
+    print(f"⚠️ Failed to apply Liger Kernel: {e}, you can install it with: pip install liger-kernel")
+
+# Try to import pydub for MP3 conversion; HAS_PYDUB tells the encoders
+# below whether MP3 output is possible.
+try:
+    from pydub import AudioSegment
+    HAS_PYDUB = True
+except ImportError:
+    HAS_PYDUB = False
+    print("⚠️ Warning: pydub not available, falling back to WAV format")
+
+from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
+from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
+from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, COMMON_AUDIO_EXTS
+
+
+class VibeVoiceASRInference:
+    """Inference wrapper that owns one VibeVoice ASR processor/model pair.
+
+    ``__init__`` loads both from ``model_path``; ``transcribe`` runs one
+    generation pass and returns the raw text, parsed segments, timing and
+    input-token statistics.
+    """
+
+    def __init__(self, model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, attn_implementation: str = "flash_attention_2"):
+        """
+        Initialize the ASR inference pipeline.
+
+        Args:
+            model_path: Path to the pretrained model (HuggingFace format directory or model name)
+            device: Device to run inference on; "auto" delegates placement to ``device_map``
+            dtype: Data type for model weights
+            attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
+        """
+        print(f"Loading VibeVoice ASR model from {model_path}")
+
+        # Load processor
+        self.processor = VibeVoiceASRProcessor.from_pretrained(model_path)
+
+        # Load model
+        print(f"Using attention implementation: {attn_implementation}")
+        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
+            model_path,
+            dtype=dtype,
+            device_map=device if device == "auto" else None,
+            attn_implementation=attn_implementation,
+            trust_remote_code=True
+        )
+
+        # With an explicit device the weights are moved by hand; with "auto"
+        # the device is read back from the first parameter.
+        if device != "auto":
+            self.model = self.model.to(device)
+
+        self.device = device if device != "auto" else next(self.model.parameters()).device
+        self.model.eval()
+
+        # Print model info
+        total_params = sum(p.numel() for p in self.model.parameters())
+        print(f"✅ Model loaded successfully on {self.device}")
+        print(f"π Total parameters: {total_params:,} ({total_params/1e9:.2f}B)")
+
+    def transcribe(
+        self,
+        audio_path: str = None,
+        audio_array: np.ndarray = None,
+        sample_rate: int = None,
+        max_new_tokens: int = 512,
+        temperature: float = 0.0,
+        top_p: float = 1.0,
+        do_sample: bool = False,
+        num_beams: int = 1,
+        repetition_penalty: float = 1.0,
+        context_info: str = None,
+        streamer: Optional[TextIteratorStreamer] = None,
+    ) -> dict:
+        """
+        Transcribe audio to text.
+
+        Args:
+            audio_path: Path to audio file
+            audio_array: Audio array, used only when ``audio_path`` is None
+            sample_rate: Sample rate of audio array
+            max_new_tokens: Maximum tokens to generate
+            temperature: Temperature for sampling (only forwarded when sampling is enabled)
+            top_p: Top-p for nucleus sampling (1.0 for no filtering)
+            do_sample: Whether to use sampling
+            num_beams: Number of beams for beam search (1 for greedy)
+            repetition_penalty: Repetition penalty (1.0 for no penalty)
+            context_info: Optional context information (e.g., hotwords, speaker names, topics) to help transcription
+            streamer: Optional TextIteratorStreamer for streaming output
+
+        Returns:
+            Dictionary with keys ``raw_text``, ``segments``, ``generation_time``
+            and ``input_tokens`` (``total`` / ``speech`` / ``text`` / ``padding`` counts).
+
+        Raises:
+            ValueError: If neither ``audio_path`` nor ``audio_array`` is given.
+        """
+        # The original always forwarded ``audio_path`` and silently dropped
+        # ``audio_array``. NOTE(review): assumes the processor's ``audio``
+        # argument also accepts an in-memory array - confirm against the processor.
+        audio = audio_path if audio_path is not None else audio_array
+        if audio is None:
+            raise ValueError("Either audio_path or audio_array must be provided")
+
+        # Process audio
+        inputs = self.processor(
+            audio=audio,
+            sampling_rate=sample_rate,
+            return_tensors="pt",
+            add_generation_prompt=True,
+            context_info=context_info
+        )
+
+        # Move to device
+        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+                  for k, v in inputs.items()}
+
+        # Sampling-only knobs are left out of greedy/beam runs, where they
+        # have no effect.
+        generation_config = {
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature if do_sample and temperature > 0 else None,
+            "top_p": top_p if do_sample else None,
+            "do_sample": do_sample,
+            "num_beams": num_beams,
+            "repetition_penalty": repetition_penalty,
+            "pad_token_id": self.processor.pad_id,
+            "eos_token_id": self.processor.tokenizer.eos_token_id,
+        }
+
+        # Add streamer if provided
+        if streamer is not None:
+            generation_config["streamer"] = streamer
+
+        # Add stopping criteria for stop button support
+        generation_config["stopping_criteria"] = StoppingCriteriaList([StopOnFlag()])
+
+        # Remove None values
+        generation_config = {k: v for k, v in generation_config.items() if v is not None}
+
+        # Input token statistics (first sequence of the batch only)
+        input_ids = inputs['input_ids'][0]  # Shape: [seq_len]
+        total_input_tokens = input_ids.shape[0]
+
+        # Padding tokens are those equal to pad_id
+        num_padding_tokens = (input_ids == self.processor.pad_id).sum().item()
+
+        # Speech tokens: both delimiters plus everything between them
+        speech_start_id = self.processor.speech_start_id
+        speech_end_id = self.processor.speech_end_id
+        num_speech_tokens = 0
+        in_speech = False
+        for token_id in input_ids.tolist():
+            if token_id == speech_start_id:
+                in_speech = True
+                num_speech_tokens += 1  # Count speech_start token
+            elif token_id == speech_end_id:
+                in_speech = False
+                num_speech_tokens += 1  # Count speech_end token
+            elif in_speech:
+                num_speech_tokens += 1
+
+        # Text tokens = total - speech - padding
+        num_text_tokens = total_input_tokens - num_speech_tokens - num_padding_tokens
+
+        # Start the timer here so it covers generation only, not the bookkeeping above.
+        start_time = time.time()
+        with torch.no_grad():
+            output_ids = self.model.generate(
+                **inputs,
+                **generation_config
+            )
+        generation_time = time.time() - start_time
+
+        # Decode only the newly generated suffix
+        generated_ids = output_ids[0, inputs['input_ids'].shape[1]:]
+        generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)
+
+        # Parse structured output; a parse failure degrades to "no segments"
+        try:
+            transcription_segments = self.processor.post_process_transcription(generated_text)
+        except Exception as e:
+            print(f"Warning: Failed to parse structured output: {e}")
+            transcription_segments = []
+
+        return {
+            "raw_text": generated_text,
+            "segments": transcription_segments,
+            "generation_time": generation_time,
+            "input_tokens": {
+                "total": total_input_tokens,
+                "speech": num_speech_tokens,
+                "text": num_text_tokens,
+                "padding": num_padding_tokens,
+            },
+        }
+
+
+def clip_and_encode_audio(
+    audio_data: np.ndarray,
+    sr: int,
+    start_time: float,
+    end_time: float,
+    segment_idx: int,
+    use_mp3: bool = True,
+    target_sr: int = 16000,  # Downsample to 16kHz for smaller size
+    mp3_bitrate: str = "32k"  # Use low bitrate for minimal transfer
+) -> Tuple[int, Optional[str], Optional[str]]:
+    """
+    Clip one audio segment and encode it as a base64 data URI.
+
+    Args:
+        audio_data: Full audio array
+        sr: Sample rate
+        start_time: Start time in seconds
+        end_time: End time in seconds
+        segment_idx: Segment index for identification
+        use_mp3: Whether to use MP3 format (smaller size); needs pydub
+        target_sr: Target sample rate for downsampling (lower = smaller)
+        mp3_bitrate: MP3 bitrate (lower = smaller, e.g., "24k", "32k", "48k")
+
+    Returns:
+        Tuple of (segment_idx, data_uri, error_message). Exactly one of
+        ``data_uri`` / ``error_message`` is None; the function never raises.
+    """
+    try:
+        # Convert time to sample indices, clamped to the available audio
+        start_sample = max(0, int(start_time * sr))
+        end_sample = min(len(audio_data), int(end_time * sr))
+
+        if start_sample >= end_sample:
+            return segment_idx, None, f"Invalid time range: [{start_time:.2f}s - {end_time:.2f}s]"
+
+        # Extract segment
+        segment_data = audio_data[start_sample:end_sample]
+
+        # Downsample if needed (reduces data size significantly)
+        if target_sr < sr:
+            # Simple downsampling using linear interpolation. Keep at least
+            # one output sample so a very short clip does not become empty.
+            duration = len(segment_data) / sr
+            new_length = max(1, int(duration * target_sr))
+            indices = np.linspace(0, len(segment_data) - 1, new_length)
+            segment_data = np.interp(indices, np.arange(len(segment_data)), segment_data)
+            sr = target_sr
+
+        # Convert float audio to int16. Clip first: +1.0 scales to 32768,
+        # which wraps to -32768 in int16 and produces an audible click.
+        segment_data_int16 = np.clip(segment_data * 32768.0, -32768.0, 32767.0).astype(np.int16)
+
+        # Encode once to an in-memory WAV; it feeds the MP3 conversion and
+        # doubles as the fallback payload.
+        wav_buffer = io.BytesIO()
+        sf.write(wav_buffer, segment_data_int16, sr, format='WAV', subtype='PCM_16')
+
+        # Convert to MP3 if pydub is available and use_mp3 is True
+        if use_mp3 and HAS_PYDUB:
+            try:
+                wav_buffer.seek(0)
+                audio_segment = AudioSegment.from_wav(wav_buffer)
+                # Convert to mono if stereo (halves the size)
+                if audio_segment.channels > 1:
+                    audio_segment = audio_segment.set_channels(1)
+                mp3_buffer = io.BytesIO()
+                audio_segment.export(mp3_buffer, format='mp3', bitrate=mp3_bitrate)
+
+                audio_base64 = base64.b64encode(mp3_buffer.getvalue()).decode('utf-8')
+                return segment_idx, f"data:audio/mp3;base64,{audio_base64}", None
+            except Exception as e:
+                # Fall back to WAV on error
+                print(f"MP3 conversion failed for segment {segment_idx}, using WAV: {e}")
+
+        # WAV fallback (no temp file, reuse the in-memory buffer)
+        audio_base64 = base64.b64encode(wav_buffer.getvalue()).decode('utf-8')
+        return segment_idx, f"data:audio/wav;base64,{audio_base64}", None
+
+    except Exception as e:
+        error_msg = f"Error clipping segment {segment_idx}: {str(e)}"
+        print(error_msg)
+        return segment_idx, None, error_msg
+
+
+def extract_audio_segments(audio_path: str, segments: List[Dict]) -> List[Tuple[str, str, Optional[str]]]:
+    """
+    Clip and encode every segment of one audio file, in parallel.
+
+    The file is decoded once and each valid segment is handed to
+    ``clip_and_encode_audio`` on a thread pool. The result has exactly one
+    entry per input segment, in input order, so callers may index it in
+    step with ``segments``.
+
+    Args:
+        audio_path: Path to original audio file
+        segments: List of segment dictionaries with start_time, end_time, etc.
+
+    Returns:
+        List of tuples (segment_label, audio_base64_src, error_msg). Segments
+        without a usable numeric time range get ``(label, None, None)``.
+        Returns an empty list if the audio cannot be loaded or processing
+        fails unexpectedly.
+    """
+    try:
+        # Read audio file once using ffmpeg for better format support
+        print(f"π Loading audio file: {audio_path}")
+        audio_data, sr = load_audio_use_ffmpeg(audio_path, resample=False)
+        print(f"✅ Audio loaded: {len(audio_data)} samples, {sr} Hz")
+
+        use_mp3 = HAS_PYDUB  # Use MP3 if available
+
+        # idx -> (audio_src, error_msg). Unusable segments get their
+        # placeholder here and are never submitted. The original put the
+        # index in slot 0 of a sentinel tuple, so its "skip if task[0] is
+        # None" test never fired and one bad segment emptied the whole result.
+        outcomes = {}
+        jobs = []
+        for i, seg in enumerate(segments):
+            start_time = seg.get('start_time')
+            end_time = seg.get('end_time')
+            if (isinstance(start_time, (int, float)) and
+                    isinstance(end_time, (int, float)) and
+                    start_time < end_time):
+                jobs.append((i, start_time, end_time))
+            else:
+                outcomes[i] = (None, None)
+
+        # Use CPU count for max workers
+        max_workers = os.cpu_count() or 4
+        print(f"π Starting parallel processing with {max_workers} threads...")
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = {
+                executor.submit(clip_and_encode_audio, audio_data, sr, start, end, i, use_mp3): i
+                for i, start, end in jobs
+            }
+            total = len(futures)
+            completed_count = 0
+
+            for future in as_completed(futures):
+                idx = futures[future]
+                try:
+                    _, audio_src, error_msg = future.result()
+                    outcomes[idx] = (audio_src, error_msg)
+                except Exception as e:
+                    outcomes[idx] = (None, f"Processing error: {str(e)}")
+                    print(f"Error on segment {idx}: {e}")
+                completed_count += 1
+                # Log progress every 100 segments or at completion
+                if completed_count % 100 == 0 or completed_count == total:
+                    print(f"Progress: {completed_count}/{total} segments processed ({completed_count*100//total}%)")
+
+        print(f"✅ Completed processing all {len(futures)} valid segments")
+
+        def _fmt_time(value):
+            # Numeric times keep the original "12.34s" look; anything else is N/A.
+            return f"{value:.2f}s" if isinstance(value, (int, float)) else "N/A"
+
+        # Build output list with labels, one entry per input segment
+        audio_segments = []
+        for idx, seg in enumerate(segments):
+            audio_src, error_msg = outcomes[idx]
+            speaker_id = seg.get('speaker_id', 'N/A')
+            segment_label = (
+                f"Segment {idx+1}: [{_fmt_time(seg.get('start_time'))} - "
+                f"{_fmt_time(seg.get('end_time'))}] Speaker {speaker_id}"
+            )
+            audio_segments.append((segment_label, audio_src, error_msg))
+
+        return audio_segments
+
+    except Exception as e:
+        print(f"Error loading audio file: {e}")
+        return []
+
+
+# Module-wide handle to the loaded ASR wrapper; stays None until a model is loaded
+asr_model = None
+
+# Cooperative cancellation switch that StopOnFlag polls during generation
+stop_generation_flag = False
+
+
+class StopOnFlag(StoppingCriteria):
+    """Stopping criterion that reports the current value of the module-level stop flag."""
+
+    def __call__(self, input_ids, scores, **kwargs):
+        # The flag is only read here, so no ``global`` declaration is needed;
+        # the arguments are ignored.
+        return stop_generation_flag
+
+
+def parse_time_to_seconds(val: Optional[str]) -> Optional[float]:
+    """Parse a time given as seconds, ``mm:ss`` or ``hh:mm:ss`` into float seconds.
+
+    Args:
+        val: User-supplied text, e.g. ``"30.5"``, ``"1:30"`` or ``"00:01:30.5"``.
+
+    Returns:
+        The time in seconds, or None when ``val`` is None, blank, or not a
+        valid non-negative finite time. The function never raises on
+        malformed text.
+    """
+    import math  # local import: the module header does not import math
+
+    if val is None:
+        return None
+    text = val.strip()
+    if not text:
+        return None
+
+    fields = [field.strip() for field in text.split(":")]
+    if len(fields) > 3:
+        return None
+    # Clock fields must be plain unsigned decimals ("5", "5.", ".5", "5.25").
+    if len(fields) > 1 and not all(f.replace(".", "", 1).isdigit() for f in fields):
+        return None
+
+    try:
+        # isdigit() also accepts characters float() rejects (e.g. "²"), so
+        # the conversion itself must be guarded.
+        numbers = [float(f) for f in fields]
+    except ValueError:
+        return None
+
+    # "nan", "inf" and negative values parse as floats but are not usable times.
+    if not all(math.isfinite(n) and n >= 0 for n in numbers):
+        return None
+
+    # Fold the fields left to right: seconds, mm:ss or hh:mm:ss.
+    total = 0.0
+    for number in numbers:
+        total = total * 60 + number
+    return total
+
+
+def slice_audio_to_temp(
+    audio_data: np.ndarray,
+    sample_rate: int,
+    start_sec: Optional[float],
+    end_sec: Optional[float]
+) -> Tuple[Optional[str], Optional[str]]:
+    """Slice audio_data to [start_sec, end_sec) and write it to a temp WAV file.
+
+    Args:
+        audio_data: Full audio array; its length is taken as the sample count.
+        sample_rate: Sample rate of ``audio_data`` in Hz.
+        start_sec: Slice start in seconds; None means the beginning.
+        end_sec: Slice end in seconds; None means the end of the audio.
+
+    Returns:
+        ``(path, None)`` on success, where ``path`` is a 16-bit PCM WAV file
+        that the caller must delete, or ``(None, message)`` when the clamped
+        range is empty.
+    """
+    n_samples = len(audio_data)
+    full_duration = n_samples / float(sample_rate)
+    # Clamp the requested window to the audio that actually exists
+    start = 0.0 if start_sec is None else max(0.0, start_sec)
+    end = full_duration if end_sec is None else min(full_duration, end_sec)
+    if end <= start:
+        return None, f"Invalid time range: start={start:.2f}s, end={end:.2f}s"
+
+    segment = audio_data[int(start * sample_rate):int(end * sample_rate)]
+    # Clip before the cast: +1.0 scales to 32768, which wraps to -32768 in int16.
+    segment_int16 = np.clip(segment * 32768.0, -32768.0, 32767.0).astype(np.int16)
+
+    # delete=False because the path outlives this function
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
+        temp_path = temp_file.name
+    try:
+        sf.write(temp_path, segment_int16, sample_rate, subtype='PCM_16')
+    except Exception:
+        # Do not leave an orphaned temp file behind when the write fails
+        try:
+            os.unlink(temp_path)
+        except OSError:
+            pass
+        raise
+    return temp_path, None
+
+
+def initialize_model(model_path: str, device: str = "cuda", attn_implementation: str = "flash_attention_2"):
+    """Load the ASR model into the module-level ``asr_model`` and report the outcome.
+
+    Args:
+        model_path: Path to the model (HuggingFace format directory or model name)
+        device: Target device; "cpu" selects float32 weights, anything else bfloat16
+        attn_implementation: Attention implementation forwarded to the model loader
+
+    Returns:
+        A human-readable status string. Failures are caught, their traceback
+        is printed, and the returned message starts with the error marker
+        instead of raising.
+    """
+    global asr_model
+    try:
+        dtype = torch.bfloat16 if device != "cpu" else torch.float32
+        asr_model = VibeVoiceASRInference(
+            model_path=model_path,
+            device=device,
+            dtype=dtype,
+            attn_implementation=attn_implementation
+        )
+        return f"✅ Model loaded successfully from {model_path}"
+    except Exception as e:
+        # traceback is already imported at module level
+        traceback.print_exc()
+        # The leading marker is kept byte-for-byte: create_gradio_interface
+        # detects failure by testing this prefix.
+        return f"β Error loading model: {str(e)}"
+
+
+def transcribe_audio(
+    audio_input,
+    audio_path_input: str,
+    start_time_input: str,
+    end_time_input: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    do_sample: bool,
+    repetition_penalty: float = 1.0,
+    context_info: str = ""
+) -> Generator[Tuple[str, str], None, None]:
+    """
+    Transcribe audio and stream the results (generator for the Gradio UI).
+
+    Generation runs on a background thread. While it runs, the partial text
+    is yielded; the final yield carries the full raw output plus the HTML
+    for the per-segment audio players.
+
+    Args:
+        audio_input: Audio file path or tuple (sample_rate, audio_data)
+        audio_path_input: Optional path typed by the user; takes precedence over ``audio_input``
+        start_time_input: Optional slice start (seconds or hh:mm:ss)
+        end_time_input: Optional slice end (seconds or hh:mm:ss)
+        max_new_tokens: Maximum tokens to generate
+        temperature: Temperature for sampling (0 for greedy)
+        top_p: Top-p for nucleus sampling
+        do_sample: Whether to use sampling
+        repetition_penalty: Repetition penalty (1.0 for no penalty)
+        context_info: Optional context information (e.g., hotwords, speaker names, topics)
+
+    Yields:
+        Tuple of (raw_text, audio_segments_html)
+    """
+    import html  # local import: used only to escape text placed into the HTML below
+
+    if asr_model is None:
+        yield "β Please load a model first!", ""
+        return
+
+    if not audio_path_input and audio_input is None:
+        yield "β Please provide audio input!", ""
+        return
+
+    # Temp WAV files created by this call; removed in the ``finally`` below.
+    temp_paths = []
+
+    try:
+        print("[INFO] Transcription requested")
+        start_sec = parse_time_to_seconds(start_time_input)
+        end_sec = parse_time_to_seconds(end_time_input)
+        print(f"[INFO] Parsed time range: start={start_sec}, end={end_sec}")
+        if (start_time_input and start_sec is None) or (end_time_input and end_sec is None):
+            yield "β Invalid time format. Use seconds or hh:mm:ss.", ""
+            return
+
+        audio_path = None
+        audio_array = None
+        sample_rate = None
+
+        if audio_path_input:
+            candidate = Path(audio_path_input.strip())
+            if not candidate.exists():
+                yield f"β Provided path does not exist: {candidate}", ""
+                return
+            audio_path = str(candidate)
+            print(f"[INFO] Using provided audio path: {audio_path}")
+        # Get audio file path (Gradio Audio component returns tuple (sample_rate, audio_data) or file path)
+        elif isinstance(audio_input, str):
+            audio_path = audio_input
+            print(f"[INFO] Using uploaded audio path: {audio_path}")
+        elif isinstance(audio_input, tuple):
+            # Audio from microphone: (sample_rate, audio_data)
+            sample_rate, audio_array = audio_input
+            print(f"[INFO] Received microphone audio with sample_rate={sample_rate}")
+        elif audio_path is None:
+            yield "β Invalid audio input format!", ""
+            return
+
+        # If slicing is requested, load and slice audio
+        if start_sec is not None or end_sec is not None:
+            print("[INFO] Slicing audio per requested time range")
+            if audio_array is None or sample_rate is None:
+                try:
+                    audio_array, sample_rate = load_audio_use_ffmpeg(audio_path, resample=False)
+                    print("[INFO] Loaded audio for slicing via ffmpeg")
+                except Exception as exc:
+                    yield f"β Failed to load audio for slicing: {exc}", ""
+                    return
+            sliced_path, err = slice_audio_to_temp(audio_array, sample_rate, start_sec, end_sec)
+            if err:
+                yield f"β {err}", ""
+                return
+            temp_paths.append(sliced_path)
+            audio_path = sliced_path
+            print(f"[INFO] Sliced audio written to temp file: {audio_path}")
+        elif audio_array is not None and sample_rate is not None:
+            # no slicing but microphone input: write to temp file
+            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+            audio_path = temp_file.name
+            temp_file.close()
+            temp_paths.append(audio_path)
+            audio_data_int16 = (audio_array * 32768.0).astype(np.int16)
+            sf.write(audio_path, audio_data_int16, sample_rate, subtype='PCM_16')
+            print(f"[INFO] Microphone audio saved to temp file: {audio_path}")
+
+        # Create streamer for real-time output
+        streamer = TextIteratorStreamer(
+            asr_model.processor.tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
+
+        # Store result in a mutable container for the thread
+        result_container = {"result": None, "error": None}
+
+        def run_transcription():
+            try:
+                result_container["result"] = asr_model.transcribe(
+                    audio_path=audio_path,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=do_sample,
+                    repetition_penalty=repetition_penalty,
+                    context_info=context_info if context_info and context_info.strip() else None,
+                    streamer=streamer
+                )
+            except Exception as e:
+                result_container["error"] = str(e)
+                traceback.print_exc()
+                # If the failure happened before or during generate(), nothing
+                # has closed the stream and the consumer loop below would block
+                # forever. Close it here; an extra end after a normal finish is
+                # never read.
+                streamer.end()
+
+        # Start transcription in background thread
+        print("[INFO] Starting model transcription (streaming mode)")
+        start_time = time.time()
+        transcription_thread = threading.Thread(target=run_transcription)
+        transcription_thread.start()
+
+        # Yield streaming output
+        generated_text = ""
+        token_count = 0
+        for new_text in streamer:
+            generated_text += new_text
+            token_count += 1
+            elapsed = time.time() - start_time
+            # Show streaming output with live stats, format for readability
+            formatted_text = generated_text.replace('},', '},\n')
+            streaming_output = f"--- π΄ LIVE Streaming Output (tokens: {token_count}, time: {elapsed:.1f}s) ---\n{formatted_text}"
+            yield streaming_output, "
β³ Generating transcription... Audio segments will appear after completion.
"
+
+        # Wait for thread to complete
+        transcription_thread.join()
+
+        if result_container["error"]:
+            yield f"β Error during transcription: {result_container['error']}", ""
+            return
+
+        result = result_container["result"]
+        generation_time = time.time() - start_time
+
+        # Get input token statistics
+        input_tokens = result.get('input_tokens', {})
+        speech_tokens = input_tokens.get('speech', 0)
+        text_tokens = input_tokens.get('text', 0)
+        padding_tokens = input_tokens.get('padding', 0)
+        total_input = input_tokens.get('total', 0)
+
+        # Format final raw output with input/output token stats
+        raw_output = f"--- ✅ Raw Output ---\n"
+        raw_output += f"π₯ Input: {total_input} tokens (π€ speech: {speech_tokens}, π text: {text_tokens}, β¬ pad: {padding_tokens})\n"
+        raw_output += f"π€ Output: {token_count} tokens | β±οΈ Time: {generation_time:.2f}s\n"
+        raw_output += f"---\n"
+        # Format raw text for better readability: add newline after each dict (},)
+        formatted_raw_text = result['raw_text'].replace('},', '},\n')
+        raw_output += formatted_raw_text
+
+        # Debug: print raw output to console
+        print(f"[DEBUG] Raw model output:")
+        print(f"[DEBUG] {result['raw_text']}")
+        print(f"[DEBUG] Found {len(result['segments'])} segments")
+
+        # Create audio segments with server-side encoding (low quality for minimal transfer)
+        # Using: 16kHz mono MP3 @ 32kbps = ~4KB per second of audio
+        audio_segments_html = ""
+        segments = result['segments']
+
+        if segments:
+            num_segments = len(segments)
+            print(f"[INFO] Creating per-segment audio clips ({num_segments} segments, 16kHz mono MP3 @ 32kbps)")
+
+            # Extract all audio segments efficiently (load audio only once)
+            audio_segments = extract_audio_segments(audio_path, segments)
+            print("[INFO] Completed creating audio clips")
+
+            # Calculate approximate total size
+            total_duration = sum(
+                (seg.get('end_time', 0) - seg.get('start_time', 0))
+                for seg in segments
+                if isinstance(seg.get('start_time'), (int, float)) and isinstance(seg.get('end_time'), (int, float))
+            )
+            approx_size_kb = total_duration * 4  # ~4KB per second at 32kbps
+
+            # Add CSS for theme-aware styling
+            theme_css = """
+
+ """
+
+            audio_segments_html = theme_css
+            audio_segments_html += f""
+
+            # Add format info
+            format_info = "MP3 32kbps 16kHz mono" if HAS_PYDUB else "WAV 16kHz"
+            audio_segments_html += f"
π Audio Segments ({num_segments} segments)"
+            audio_segments_html += f"π¦ ~{approx_size_kb:.0f}KB ({format_info})
"
+            audio_segments_html += "
π΅ Click the play button to listen to each segment directly!
"
+
+            for i, (label, audio_src, error_msg) in enumerate(audio_segments):
+                seg = segments[i] if i < len(segments) else {}
+                start_time = seg.get('start_time', 'N/A')
+                end_time = seg.get('end_time', 'N/A')
+                speaker_id = seg.get('speaker_id', 'N/A')
+                # Transcript text is model output derived from user audio;
+                # escape it before placing it into HTML.
+                content = html.escape(str(seg.get('text', '')))
+
+                # Format times nicely
+                start_str = f"{start_time:.2f}" if isinstance(start_time, (int, float)) else str(start_time)
+                end_str = f"{end_time:.2f}" if isinstance(end_time, (int, float)) else str(end_time)
+
+                audio_segments_html += f"""
+
+
+
+
+ {content}
+
+ """
+
+                if audio_src:
+                    # Detect format from data URI
+                    audio_type = 'audio/mp3' if 'audio/mp3' in audio_src else 'audio/wav'
+                    audio_segments_html += f"""
+
+ """
+                elif error_msg:
+                    audio_segments_html += f"""
+
+ β {html.escape(str(error_msg))}
+
+ """
+                else:
+                    audio_segments_html += """
+
+ Audio playback unavailable for this segment
+
+ """
+
+                audio_segments_html += "
"
+
+            audio_segments_html += "
"
+        else:
+            audio_segments_html = """
+
+
+
β No audio segments available.
+
This could happen if the model output doesn't contain valid time stamps.
+
+ """
+
+        # Final yield with complete results
+        yield raw_output, audio_segments_html
+
+    except Exception as e:
+        print(f"Error during transcription: {e}")
+        print(traceback.format_exc())
+        yield f"β Error during transcription: {str(e)}", ""
+    finally:
+        # The clips are embedded as base64 in the HTML, so the temp WAVs are
+        # no longer needed. Removal is best-effort.
+        for temp_path in temp_paths:
+            try:
+                os.unlink(temp_path)
+            except OSError:
+                pass
+
+
+def create_gradio_interface(model_path: str, default_max_tokens: int = 8192, attn_implementation: str = "flash_attention_2"):
+ """Load the model and build (but do not launch) the Gradio interface.
+
+ The model is initialized first; if loading fails the process exits with
+ status 1 before any UI is created.
+
+ Args:
+ model_path: Path to the model (HuggingFace format directory or model name)
+ default_max_tokens: Default value for max_new_tokens slider
+ attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
+
+ Returns:
+ Tuple ``(demo, custom_css)``: the ``gr.Blocks`` app and the CSS string
+ the caller is expected to pass on when launching.
+ """
+
+ # Initialize model at startup
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_status = initialize_model(model_path, device, attn_implementation)
+ print(model_status)
+
+ # Exit if model loading failed
+ # (failure is detected purely by the status string's leading marker)
+ if model_status.startswith("β"):
+ print("\n" + "="*80)
+ print("π₯ FATAL ERROR: Model loading failed!")
+ print("="*80)
+ print("Cannot start demo without a valid model. Please check:")
+ print(" 1. Model path is correct")
+ print(" 2. Model files are not corrupted")
+ print(" 3. You have enough GPU memory")
+ print(" 4. CUDA is properly installed (if using GPU)")
+ print("="*80)
+ sys.exit(1)
+
+ # Custom CSS for Stop button styling
+ custom_css = """
+ #stop-btn {
+ background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%) !important;
+ border: none !important;
+ color: white !important;
+ }
+ #stop-btn:hover {
+ background: linear-gradient(135deg, #dc2626 0%, #b91c1c 100%) !important;
+ }
+ """
+
+ # Gradio 6.0+ moved theme/css to launch()
+ # (which is why the CSS is returned to the caller instead of being set here)
+ with gr.Blocks(title="VibeVoice ASR Demo") as demo:
+ gr.Markdown("# ποΈ VibeVoice ASR Demo")
+ gr.Markdown("Upload audio files or record from microphone to get speech-to-text transcription with speaker diarization.")
+ gr.Markdown(f"**Model loaded from:** `{model_path}`")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ # Generation parameters
+ gr.Markdown("## βοΈ Generation Parameters")
+ max_tokens_slider = gr.Slider(
+ minimum=4096,
+ maximum=65536,
+ value=default_max_tokens,
+ step=4096,
+ label="Max New Tokens"
+ )
+
+ # Sampling parameters
+ gr.Markdown("### π² Sampling")
+ do_sample_checkbox = gr.Checkbox(
+ value=False,
+ label="Enable Sampling",
+ info="Enable random sampling instead of deterministic decoding"
+ )
+
+ # Hidden until the checkbox above is ticked (see the change handler below)
+ with gr.Column(visible=False) as sampling_params:
+ temperature_slider = gr.Slider(
+ minimum=0.0,
+ maximum=2.0,
+ value=0.0,
+ step=0.1,
+ label="Temperature",
+ info="0 = greedy, higher = more random"
+ )
+ top_p_slider = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=1.0,
+ step=0.05,
+ label="Top-p (Nucleus Sampling)",
+ info="1.0 = no filtering"
+ )
+
+ # Repetition penalty (works with both greedy and sampling)
+ repetition_penalty_slider = gr.Slider(
+ minimum=1.0,
+ maximum=1.2,
+ value=1.0,
+ step=0.01,
+ label="Repetition Penalty",
+ info="1.0 = no penalty, higher = less repetition (works with greedy & sampling)"
+ )
+
+ # Context information section
+ gr.Markdown("## π Context Info (Optional)")
+ context_info_input = gr.Textbox(
+ label="Context Information",
+ placeholder="Enter hotwords, speaker names, topics, or other context to help transcription...\nExample:\nJohn Smith\nMachine Learning\nOpenAI",
+ lines=4,
+ max_lines=8,
+ interactive=True,
+ info="Provide context like proper nouns, technical terms, or speaker names to improve accuracy"
+ )
+
+ with gr.Column(scale=2):
+ # Audio input section
+ gr.Markdown("## π΅ Audio Input")
+ # type="filepath": the handler receives a path string for this component
+ audio_input = gr.Audio(
+ label="Upload Audio File or Record from Microphone",
+ sources=["upload", "microphone"],
+ type="filepath",
+ interactive=True,
+ buttons=["download"]
+ )
+
+ with gr.Accordion("π Advanced: Remote Path & Time Slicing", open=False):
+ audio_path_input = gr.Textbox(
+ label="Audio path (optional)",
+ placeholder="Enter remote audio file path",
+ lines=1
+ )
+ with gr.Row():
+ start_time_input = gr.Textbox(
+ label="Start time",
+ placeholder="e.g., 0 or 00:00:00",
+ lines=1,
+ info="Leave empty to start from the beginning"
+ )
+ end_time_input = gr.Textbox(
+ label="End time",
+ placeholder="e.g., 30.5 or 00:00:30.5",
+ lines=1,
+ info="Leave empty to use full length"
+ )
+
+ with gr.Row():
+ transcribe_button = gr.Button("π― Transcribe", variant="primary", size="lg", scale=3)
+ stop_button = gr.Button("βΉοΈ Stop", variant="secondary", size="lg", scale=1, elem_id="stop-btn")
+
+ # Results section
+ gr.Markdown("## π Results")
+
+ with gr.Tabs():
+ with gr.TabItem("Raw Output"):
+ raw_output = gr.Textbox(
+ label="Raw Transcription Output",
+ lines=8,
+ max_lines=20,
+ interactive=False
+ )
+
+ with gr.TabItem("Audio Segments"):
+ audio_segments_output = gr.HTML(
+ label="Play individual segments to verify accuracy"
+ )
+
+ # Event handlers
+ # Show/hide the sampling sliders together with the checkbox state
+ do_sample_checkbox.change(
+ fn=lambda x: gr.update(visible=x),
+ inputs=[do_sample_checkbox],
+ outputs=[sampling_params]
+ )
+
+ # NOTE: the stop flag is a single module-level global, so it is not
+ # scoped to one request; both helpers below act on that shared value.
+ def reset_stop_flag():
+ """Reset stop flag before starting transcription."""
+ global stop_generation_flag
+ stop_generation_flag = False
+
+ def set_stop_flag():
+ """Set stop flag to interrupt generation."""
+ global stop_generation_flag
+ stop_generation_flag = True
+ return "βΉοΈ Stop requested..."
+
+ # Clear the flag first (unqueued), then run the streaming transcription.
+ # The order of ``inputs`` must match transcribe_audio's parameter order.
+ transcribe_button.click(
+ fn=reset_stop_flag,
+ inputs=[],
+ outputs=[],
+ queue=False
+ ).then(
+ fn=transcribe_audio,
+ inputs=[
+ audio_input,
+ audio_path_input,
+ start_time_input,
+ end_time_input,
+ max_tokens_slider,
+ temperature_slider,
+ top_p_slider,
+ do_sample_checkbox,
+ repetition_penalty_slider,
+ context_info_input
+ ],
+ outputs=[raw_output, audio_segments_output]
+ )
+
+ # queue=False so the stop request is not held behind the running job
+ stop_button.click(
+ fn=set_stop_flag,
+ inputs=[],
+ outputs=[raw_output],
+ queue=False
+ )
+
+ # Usage instructions shown at the bottom of the page
+ gr.Markdown("## π Instructions")
+ gr.Markdown(f"""
+ 1. **Upload Audio**: Use the audio component to upload a file or record from microphone
+ - **Supported formats**: {', '.join(sorted(set([ext.lower() for ext in COMMON_AUDIO_EXTS])))}
+ - Optionally set **Start/End time** (seconds or hh:mm:ss) to clip before transcription
+ 2. **Context Info (Optional)**: Provide context to improve transcription accuracy
+ - Add hotwords, proper nouns, speaker names, or technical terms
+ - One item per line or comma-separated
+ - Examples: "John Smith", "OpenAI", "machine learning"
+ 3. **Adjust Parameters**: Configure generation parameters as needed
+ 4. **Transcribe**: Click "Transcribe" to get results
+ 5. **Review Results**:
+ - **Raw Output**: View the model's original output
+ - **Audio Segments**: Play individual segments directly to verify accuracy
+
+ **Audio Segments**: Each segment shows the time range, speaker ID, transcribed content, and an embedded audio player for immediate verification.
+ """)
+
+ return demo, custom_css
+
+
+def main():
+    """Parse command-line options, build the Gradio app and serve it."""
+    cli = argparse.ArgumentParser(description="VibeVoice ASR Gradio Demo")
+
+    # (flag, add_argument keyword arguments), registered in this order
+    option_specs = (
+        ("--model_path", dict(
+            type=str,
+            default="microsoft/VibeVoice-ASR",
+            help="Path to the model (HuggingFace format directory or model name)",
+        )),
+        ("--attn_implementation", dict(
+            type=str,
+            default="flash_attention_2",
+            help="Attention implementation to use (default: flash_attention_2)",
+        )),
+        ("--max_new_tokens", dict(
+            type=int,
+            default=32768,
+            help="Default max new tokens for generation (default: 32768)",
+        )),
+        ("--host", dict(
+            type=str,
+            default="0.0.0.0",
+            help="Host to bind the server to",
+        )),
+        ("--port", dict(
+            type=int,
+            default=7860,
+            help="Port to bind the server to",
+        )),
+        ("--share", dict(
+            action="store_true",
+            help="Create a public link",
+        )),
+    )
+    for flag, spec in option_specs:
+        cli.add_argument(flag, **spec)
+
+    opts = cli.parse_args()
+
+    # Build the interface (this also loads the model)
+    demo, custom_css = create_gradio_interface(
+        model_path=opts.model_path,
+        default_max_tokens=opts.max_new_tokens,
+        attn_implementation=opts.attn_implementation,
+    )
+
+    print(f"π Starting VibeVoice ASR Demo...")
+    print(f"π Server will be available at: http://{opts.host}:{opts.port}")
+
+    # Gradio 6.0+ takes theme/css at launch() time
+    soft_theme = gr.themes.Soft()
+
+    # Enable queue for concurrent request handling
+    demo.queue(default_concurrency_limit=3)
+    demo.launch(
+        server_name=opts.host,
+        server_port=opts.port,
+        share=opts.share,
+        show_error=True,
+        theme=soft_theme,
+        css=custom_css,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/demo/vibevoice_asr_inference_from_file.py b/demo/vibevoice_asr_inference_from_file.py
new file mode 100644
index 0000000..e7d95d4
--- /dev/null
+++ b/demo/vibevoice_asr_inference_from_file.py
@@ -0,0 +1,554 @@
+#!/usr/bin/env python
+"""
+VibeVoice ASR Batch Inference Demo Script
+
+This script supports batch inference for ASR model and compares results
+between batch processing and single-sample processing.
+"""
+
+import os
+import sys
+import torch
+import numpy as np
+from pathlib import Path
+import argparse
+import time
+import json
+import re
+from typing import List, Dict, Any, Optional
+from functools import wraps
+
+from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
+from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
+
+
class VibeVoiceASRBatchInference:
    """Batch inference wrapper for VibeVoice ASR model.

    Loads the processor and the model once, then transcribes lists of audio
    inputs either as a single batch (`transcribe_batch`) or split into
    fixed-size batches (`transcribe_with_batching`).
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        attn_implementation: str = "flash_attention_2"
    ):
        """
        Initialize the ASR batch inference pipeline.

        Args:
            model_path: Path to the pretrained model
            device: Device to run inference on; "auto" is forwarded as `device_map`
            dtype: Data type for model weights
            attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
        """
        print(f"Loading VibeVoice ASR model from {model_path}")

        # Load processor
        self.processor = VibeVoiceASRProcessor.from_pretrained(
            model_path,
            language_model_pretrained_name="Qwen/Qwen2.5-7B"
        )

        # Load model with specified attention implementation
        print(f"Using attention implementation: {attn_implementation}")
        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
            model_path,
            dtype=dtype,
            device_map=device if device == "auto" else None,
            attn_implementation=attn_implementation,
            trust_remote_code=True
        )

        # With an explicit device the model is moved manually; with "auto" the
        # placement is left to `device_map`.
        if device != "auto":
            self.model = self.model.to(device)

        self.device = device if device != "auto" else next(self.model.parameters()).device
        self.dtype = dtype
        self.model.eval()

        print(f"Model loaded successfully on {self.device}")

    def _prepare_generation_config(
        self,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 0.9,
        do_sample: bool = True,
        num_beams: int = 1,
    ) -> dict:
        """
        Prepare the keyword arguments passed to `model.generate`.

        Args:
            max_new_tokens: Maximum tokens to generate
            temperature: Temperature for sampling; values <= 0 select greedy decoding
            top_p: Top-p for nucleus sampling (only used when sampling)
            do_sample: Whether sampling is requested
            num_beams: Number of beams; values > 1 select beam search and disable sampling

        Returns:
            Dictionary of generation keyword arguments
        """
        config = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.processor.pad_id,
            "eos_token_id": self.processor.tokenizer.eos_token_id,
        }

        if num_beams > 1:
            # Beam search doesn't use sampling
            config["num_beams"] = num_beams
            config["do_sample"] = False
        elif do_sample and temperature is not None and temperature > 0:
            config["do_sample"] = True
            config["temperature"] = temperature
            config["top_p"] = top_p
        else:
            # Sampling with a non-positive temperature is not meaningful, so a
            # request like (do_sample=True, temperature=0.0) -- which is what
            # the default arguments produce -- falls back to greedy decoding.
            config["do_sample"] = False

        return config

    def transcribe_batch(
        self,
        audio_inputs: List,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 1.0,
        do_sample: bool = True,
        num_beams: int = 1,
    ) -> List[Dict[str, Any]]:
        """
        Transcribe multiple audio files/arrays in a single batch.

        Args:
            audio_inputs: List of audio file paths or (array, sampling_rate) tuples
            max_new_tokens: Maximum tokens to generate
            temperature: Temperature for sampling (<= 0 means greedy decoding)
            top_p: Top-p for nucleus sampling
            do_sample: Whether to use sampling
            num_beams: Number of beams for beam search (1 disables beam search)

        Returns:
            List of transcription results, one dict per input with the keys
            "file", "raw_text", "segments" and "generation_time"
        """
        if len(audio_inputs) == 0:
            return []

        batch_size = len(audio_inputs)
        print(f"\nProcessing batch of {batch_size} audio(s)...")

        # Process all audio together
        inputs = self.processor(
            audio=audio_inputs,
            sampling_rate=None,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True
        )

        # Move tensors to the target device; non-tensor entries pass through
        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}

        # Print batch info
        print(f" Input IDs shape: {inputs['input_ids'].shape}")
        print(f" Speech tensors shape: {inputs['speech_tensors'].shape}")
        print(f" Attention mask shape: {inputs['attention_mask'].shape}")

        # Generate
        generation_config = self._prepare_generation_config(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            num_beams=num_beams,
        )

        start_time = time.time()

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                **generation_config
            )

        generation_time = time.time() - start_time

        # Decode outputs for each sample in the batch
        results = []
        input_length = inputs['input_ids'].shape[1]
        eos_token_id = self.processor.tokenizer.eos_token_id

        for i, audio_input in enumerate(audio_inputs):
            # Get generated tokens for this sample (excluding input tokens)
            generated_ids = output_ids[i, input_length:]

            # Cut everything after the first EOS token (the EOS itself is kept
            # and later dropped by skip_special_tokens during decoding)
            eos_positions = (generated_ids == eos_token_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                generated_ids = generated_ids[:eos_positions[0] + 1]

            generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)

            # Parse structured output; a parse failure keeps the raw text and
            # reports no segments instead of failing the whole batch
            try:
                transcription_segments = self.processor.post_process_transcription(generated_text)
            except Exception as e:
                print(f"Warning: Failed to parse structured output: {e}")
                transcription_segments = []

            # Get file name based on input type
            if isinstance(audio_input, str):
                file_name = audio_input
            elif isinstance(audio_input, dict) and 'id' in audio_input:
                file_name = audio_input['id']
            else:
                file_name = f"audio_{i}"

            results.append({
                "file": file_name,
                "raw_text": generated_text,
                "segments": transcription_segments,
                # The batch is generated in one call, so the per-sample time
                # is the batch time split evenly.
                "generation_time": generation_time / batch_size,
            })

        print(f" Total generation time: {generation_time:.2f}s")
        print(f" Average time per sample: {generation_time/batch_size:.2f}s")

        return results

    def transcribe_with_batching(
        self,
        audio_inputs: List,
        batch_size: int = 4,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 1.0,
        do_sample: bool = True,
        num_beams: int = 1,
    ) -> List[Dict[str, Any]]:
        """
        Transcribe multiple audio files/arrays with automatic batching.

        Args:
            audio_inputs: List of audio file paths or (array, sampling_rate) tuples
            batch_size: Number of samples per batch (must be >= 1)
            max_new_tokens: Maximum tokens to generate
            temperature: Temperature for sampling (<= 0 means greedy decoding)
            top_p: Top-p for nucleus sampling
            do_sample: Whether to use sampling
            num_beams: Number of beams for beam search (1 disables beam search)

        Returns:
            List of transcription results in input order

        Raises:
            ValueError: If batch_size is smaller than 1
        """
        # A negative step would make range() empty and silently drop every input.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        all_results = []

        # Process in batches
        for i in range(0, len(audio_inputs), batch_size):
            batch_inputs = audio_inputs[i:i + batch_size]
            print(f"\n{'='*60}")
            print(f"Processing batch {i//batch_size + 1}/{(len(audio_inputs) + batch_size - 1)//batch_size}")

            batch_results = self.transcribe_batch(
                batch_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                num_beams=num_beams,
            )
            all_results.extend(batch_results)

        return all_results
+
+
def print_result(result: Dict[str, Any]):
    """Print one transcription result: header, raw text preview, then segments.

    The raw text is cut to 500 characters (with a trailing "...") and at most
    the first 50 segments are listed, followed by a count of the remainder.
    """
    preview_chars = 500
    max_segments = 50

    print(f"\nFile: {result['file']}")
    print(f"Generation Time: {result['generation_time']:.2f}s")
    print("\n--- Raw Output ---")

    raw_text = result['raw_text']
    if len(raw_text) > preview_chars:
        print(raw_text[:preview_chars] + "...")
    else:
        print(raw_text)

    segments = result['segments']
    if not segments:
        return

    print(f"\n--- Structured Output ({len(segments)} segments) ---")
    for seg in segments[:max_segments]:
        start = seg.get('start_time', 'N/A')
        end = seg.get('end_time', 'N/A')
        speaker = seg.get('speaker_id', 'N/A')
        text = seg.get('text', '')
        print(f"[{start} - {end}] Speaker {speaker}: {text}...")

    hidden = len(segments) - max_segments
    if hidden > 0:
        print(f" ... and {hidden} more segments")
+
+
def load_dataset_and_concatenate(
    dataset_name: str,
    split: str,
    max_duration: float,
    num_audios: int,
    target_sr: int = 24000
) -> Optional[List[np.ndarray]]:
    """
    Load a HuggingFace dataset and concatenate audio samples into long audio chunks.
    (Note, just for demo purpose, not for benchmark evaluation)

    Samples are streamed, linearly resampled to `target_sr` when needed and
    appended to the current audio until `max_duration` would be exceeded. A
    sample that does not fit is split: its head fills the current audio (only
    when more than 0.5s of room is left) and its remainder is carried over to
    the next audio, so no part of a sample is used twice and no audio is
    longer than `max_duration`.

    Args:
        dataset_name: HuggingFace dataset name (e.g., 'openslr/librispeech_asr')
        split: Dataset split to use (e.g., 'test', 'test.other')
        max_duration: Maximum duration in seconds for each concatenated audio
        num_audios: Number of concatenated audios to create
        target_sr: Target sample rate (default: 24000)

    Returns:
        List of concatenated audio arrays, or None if loading fails, the
        required packages are missing, `max_duration` is too small to hold a
        single sample, or no audio was found
    """
    # The cap must hold at least one sample, otherwise nothing could ever be
    # packed (this also rejects NaN).
    if not (max_duration * target_sr >= 1):
        print(f"Error: max_duration must cover at least one sample at {target_sr} Hz (got {max_duration})")
        return None

    try:
        from datasets import load_dataset
        import torchcodec  # noqa: F401 -- just for decode audio in datasets
    except ImportError:
        print("Please install it with: pip install datasets torchcodec")
        return None

    print(f"\nLoading dataset: {dataset_name} (split: {split})")
    print(f"Will create {num_audios} concatenated audio(s), each up to {max_duration:.1f}s ({max_duration/3600:.2f} hours)")

    try:
        # Use streaming to avoid downloading the entire dataset
        dataset = load_dataset(dataset_name, split=split, streaming=True)
        print("Dataset loaded in streaming mode")

        concatenated_audios = []  # List of concatenated audio metadata

        # State of the audio currently being assembled
        current_chunks = []
        current_duration = 0.0
        current_samples_used = 0
        sample_idx = 0

        for sample in dataset:
            if len(concatenated_audios) >= num_audios:
                break

            if 'audio' not in sample:
                continue

            audio_data = sample['audio']
            audio_array = audio_data['array']
            sr = audio_data['sampling_rate']

            # Nothing to add (and np.interp cannot resample an empty array)
            if len(audio_array) == 0:
                continue

            # Resample if needed (plain linear interpolation)
            if sr != target_sr:
                duration = len(audio_array) / sr
                new_length = int(duration * target_sr)
                audio_array = np.interp(
                    np.linspace(0, len(audio_array) - 1, new_length),
                    np.arange(len(audio_array)),
                    audio_array
                )

            # `pending` is the part of this sample not yet placed in any audio
            pending = audio_array
            while len(pending) > 0 and len(concatenated_audios) < num_audios:
                pending_duration = len(pending) / target_sr

                if current_duration + pending_duration <= max_duration:
                    current_chunks.append(pending)
                    current_duration += pending_duration
                    current_samples_used += 1
                    break

                # The rest of the sample does not fit. Fill the remaining room
                # when it is worth it (> 0.5s); an empty audio always takes a
                # full-length head so the loop is guaranteed to make progress.
                remaining_duration = max_duration - current_duration
                if remaining_duration > 0.5 or not current_chunks:
                    samples_to_take = int(remaining_duration * target_sr)
                    current_chunks.append(pending[:samples_to_take])
                    current_duration += samples_to_take / target_sr
                    current_samples_used += 1
                    pending = pending[samples_to_take:]

                # Save current concatenated audio and start a new one
                concatenated_audios.append({
                    'array': np.concatenate(current_chunks),
                    'duration': current_duration,
                    'samples_used': current_samples_used,
                })
                print(f" Created audio {len(concatenated_audios)}: {current_duration:.1f}s from {current_samples_used} samples")

                current_chunks = []
                current_duration = 0.0
                current_samples_used = 0

            sample_idx += 1
            if sample_idx % 100 == 0:
                print(f" Processed {sample_idx} samples...")

        # Don't forget the last batch if it has content
        if current_chunks and len(concatenated_audios) < num_audios:
            concatenated_audios.append({
                'array': np.concatenate(current_chunks),
                'duration': current_duration,
                'samples_used': current_samples_used,
            })
            print(f" Created audio {len(concatenated_audios)}: {current_duration:.1f}s from {current_samples_used} samples")

        if not concatenated_audios:
            print("Warning: No audio samples found in dataset")
            return None

        # Extract arrays and print summary
        result = [a['array'] for a in concatenated_audios]
        total_duration = sum(a['duration'] for a in concatenated_audios)
        total_samples = sum(a['samples_used'] for a in concatenated_audios)
        print(f"\nCreated {len(result)} concatenated audio(s), total {total_duration:.1f}s ({total_duration/60:.1f} min) from {total_samples} samples")

        return result

    except Exception as e:
        # Demo helper: report any loading/decoding failure and signal it with
        # None instead of propagating.
        print(f"Error loading dataset: {e}")
        import traceback
        traceback.print_exc()
        return None
+
+
def main():
    """Command-line entry point for batch ASR inference.

    Collects audio from --audio_files, --audio_dir and/or --dataset, loads the
    model, transcribes everything in batches of --batch_size and prints each
    result. Returns without loading the model when no audio was provided or
    the dataset could not be loaded.
    """
    parser = argparse.ArgumentParser(description="VibeVoice ASR Batch Inference Demo")
    parser.add_argument(
        "--model_path",
        type=str,
        # Same default as the Gradio demo; an empty path cannot be loaded.
        default="microsoft/VibeVoice-ASR",
        help="Path to the model checkpoint (local directory or HuggingFace model name)"
    )
    parser.add_argument(
        "--audio_files",
        type=str,
        nargs='+',
        required=False,
        help="Paths to audio files for transcription"
    )
    parser.add_argument(
        "--audio_dir",
        type=str,
        required=False,
        help="Directory containing audio files for batch transcription"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=False,
        help="HuggingFace dataset name (e.g., 'openslr/librispeech_asr')"
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Dataset split to use (e.g., 'test', 'test.other', 'test.clean')"
    )
    parser.add_argument(
        "--max_duration",
        type=float,
        default=3600.0,
        help="Maximum duration in seconds for concatenated dataset audio (default: 3600 = 1 hour)"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2,
        help="Batch size for processing multiple files"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cuda", "cpu", "auto"],
        help="Device to run inference on"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=32768,
        help="Maximum number of tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for sampling (0 = greedy decoding)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=1.0,
        help="Top-p for nucleus sampling"
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help="Number of beams for beam search. Use 1 for greedy/sampling"
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        help="Attention implementation to use (default: flash_attention_2)"
    )

    args = parser.parse_args()

    # Collect audio files
    audio_files = []
    concatenated_audio = None  # For storing concatenated dataset audio

    if args.audio_files:
        audio_files.extend(args.audio_files)

    if args.audio_dir:
        import glob
        for ext in ["*.wav", "*.mp3", "*.flac", "*.mp4", "*.m4a", "*.webm"]:
            # glob returns matches in arbitrary order; sort for reproducible batches
            audio_files.extend(sorted(glob.glob(os.path.join(args.audio_dir, ext))))

    if args.dataset:
        # One concatenated audio is created per batch slot
        concatenated_audio = load_dataset_and_concatenate(
            dataset_name=args.dataset,
            split=args.split,
            max_duration=args.max_duration,
            num_audios=args.batch_size,
        )
        if concatenated_audio is None:
            return

    if len(audio_files) == 0 and concatenated_audio is None:
        print("No audio files provided. Please specify --audio_files, --audio_dir, or --dataset.")
        return

    if audio_files:
        print(f"\nAudio files to process ({len(audio_files)}):")
        for f in audio_files:
            print(f" - {f}")

    if concatenated_audio:
        print(f"\nConcatenated dataset audios: {len(concatenated_audio)} audio(s)")

    # Initialize model (bfloat16 everywhere except on CPU)
    asr = VibeVoiceASRBatchInference(
        model_path=args.model_path,
        device=args.device,
        dtype=torch.bfloat16 if args.device != "cpu" else torch.float32,
        attn_implementation=args.attn_implementation
    )

    # If temperature is 0, use greedy decoding (no sampling)
    do_sample = args.temperature > 0

    # Combine all audio inputs
    all_audio_inputs = audio_files + (concatenated_audio or [])

    print("\n" + "="*80)
    print(f"Processing {len(all_audio_inputs)} audio(s)")
    print("="*80)

    all_results = asr.transcribe_with_batching(
        all_audio_inputs,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        do_sample=do_sample,
        num_beams=args.num_beams,
    )

    # Print results
    print("\n" + "="*80)
    print("Results")
    print("="*80)
    for result in all_results:
        print("\n" + "-"*60)
        print_result(result)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/vibevoice-asr.md b/docs/vibevoice-asr.md
new file mode 100644
index 0000000..5b2b130
--- /dev/null
+++ b/docs/vibevoice-asr.md
@@ -0,0 +1,62 @@
+# VibeVoice-ASR: Long-Form Rich Transcription with User Prompts
+
+**VibeVoice-ASR** is the latest addition to the **VibeVoice** family. While the original VibeVoice / VibeVoice-Realtime focused on expressive TTS, **VibeVoice-ASR** focuses on understanding long-form speech with high precision and rich metadata.
+
+It is a unified speech-to-text model designed to handle **1-hour long-form audio** in a single pass, generating structured transcriptions containing **Who (Speaker), When (Timestamps), and What (Content)**, with support for **User-Customized Context**.
+
+## 🔥 Key Features
+
+- **π 60-min Single-Pass Processing**:
+ Unlike conventional ASR models that slice audio into short chunks (often losing global context), VibeVoice ASR accepts up to **60 minutes** of continuous audio input within 64K length. This ensures consistent speaker tracking and semantic coherence across the entire hour.
+
+- **π€ Optional Context Injection**:
+ Users can provide customized context (e.g., specific names, technical terms, or background info) to guide the recognition process, significantly improving accuracy on domain-specific content.
+
+- **π Rich Transcription (Who, When, What)**:
+ The model performs ASR, Diarization, and Timestamping simultaneously. The output is a structured sequence indicating *who* said *what* at *which time*.
+
+## 🏗️ Model Architecture
+
+
+
+
+
+## Installation
+We recommend using the NVIDIA Deep Learning Container to manage the CUDA environment.
+
+1. Launch docker
+```bash
+# NVIDIA PyTorch Container 24.07 ~ 25.12 verified.
+# Previous versions are also compatible.
+sudo docker run --privileged --net=host --ipc=host --ulimit memlock=-1:-1 --ulimit stack=-1:-1 --gpus all --rm -it nvcr.io/nvidia/pytorch:25.12-py3
+
+## If flash attention is not included in your docker environment, you need to install it manually
+## Refer to https://github.com/Dao-AILab/flash-attention for installation instructions
+# pip install flash-attn --no-build-isolation
+```
+
+2. Install from github
+```bash
+git clone https://github.com/microsoft/VibeVoice.git
+cd VibeVoice
+pip install -e .[asr]
+```
+
+## Usages
+
+### Usage 1: Launch Gradio demo
+```bash
+apt update && apt install ffmpeg -y # for demo
+
+python demo/vibevoice_asr_gradio_demo.py --model_path microsoft/VibeVoice-ASR --share
+```
+
+### Usage 2: Inference from files directly
+```bash
+python demo/vibevoice_asr_inference_from_file.py --model_path microsoft/VibeVoice-ASR --audio_files [add an audio path here]
+```
+
+
+## π License
+
+This project is licensed under the [MIT License](../LICENSE).
diff --git a/pyproject.toml b/pyproject.toml
index a36e9cd..6dc69b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
[project]
name = "vibevoice"
-version = "0.0.1"
+version = "1.0.0"
authors = [
- { name="vibevoice team", email="vibepod@microsoft.com" },
+ { name="vibevoice team", email="VibeVoice@microsoft.com" },
]
-description = "A model for speech generation with an AR + diffusion architecture."
+description = "Open-Source Frontier Voice AI."
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
@@ -18,8 +18,7 @@ classifiers = [
]
dependencies = [
"torch",
- "accelerate==1.6.0",
- "transformers==4.51.3", # we develop this project on transformers==4.51.3, later version may not be compatible
+ "accelerate",
"llvmlite>=0.40.0",
"numba>=0.57.0",
"diffusers",
@@ -30,12 +29,21 @@ dependencies = [
"ml-collections",
"absl-py",
"gradio",
- "av",
- "aiortc",
- "uvicorn[standard]",
- "fastapi"
]
+[project.optional-dependencies]
+tts = [
+ "transformers==4.51.3", # we develop this project on transformers==4.51.3, later version may not be compatible
+ "av",
+ "aiortc",
+ "uvicorn[standard]",
+ "fastapi"
+]
+
+asr = [
+ "transformers>=4.51.3", # the versions after 4.51.3 are all support
+ "pydub" # for visualization
+]
[project.urls]
"Homepage" = "https://github.com/microsoft/VibeVoice"
diff --git a/vibevoice/modular/configuration_vibevoice.py b/vibevoice/modular/configuration_vibevoice.py
index fcffcb9..02b2751 100644
--- a/vibevoice/modular/configuration_vibevoice.py
+++ b/vibevoice/modular/configuration_vibevoice.py
@@ -240,9 +240,112 @@ class VibeVoiceConfig(PretrainedConfig):
super().__init__(**kwargs)
class VibeVoiceASRConfig(PretrainedConfig):
    """Composite configuration for the VibeVoice ASR model.

    Holds three sub-configs (acoustic tokenizer, semantic tokenizer and a
    Qwen2 decoder) and re-exposes several decoder attributes as read-only
    properties on the top-level config.

    Args:
        acoustic_tokenizer_config: None (use defaults), a dict of keyword
            arguments, or a VibeVoiceAcousticTokenizerConfig instance.
        semantic_tokenizer_config: None (use defaults), a dict of keyword
            arguments, or a VibeVoiceSemanticTokenizerConfig instance.
        decoder_config: None (use defaults), a dict whose "model_type" is
            "qwen2", or a Qwen2Config instance.
        **kwargs: Forwarded to PretrainedConfig.__init__.

    Raises:
        ValueError: If a sub-config has an unsupported type, or the decoder
            dict names a model type other than "qwen2".
    """
    model_type = "vibevoice"
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
        "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
    }
    # keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen2`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        **kwargs
    ):

        # kwargs["_attn_implementation"] = "flash_attention_2"
        kwargs["_attn_implementation_autoset"] = False

        if acoustic_tokenizer_config is None:
            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
        elif isinstance(acoustic_tokenizer_config, dict):
            # Work on a copy so the caller's dict is not modified
            acoustic_kwargs = {**acoustic_tokenizer_config, "model_type": "vibevoice_acoustic_tokenizer"}
            self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_kwargs)
        elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
            # If an instance of the config class is provided
            self.acoustic_tokenizer_config = acoustic_tokenizer_config
        else:
            # Fail here with a clear message instead of later with an
            # AttributeError on the unset attribute.
            raise ValueError(
                f"Unsupported acoustic_tokenizer_config type: {type(acoustic_tokenizer_config).__name__}"
            )

        if semantic_tokenizer_config is None:
            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
        elif isinstance(semantic_tokenizer_config, dict):
            # Work on a copy so the caller's dict is not modified
            semantic_kwargs = {**semantic_tokenizer_config, "model_type": "vibevoice_semantic_tokenizer"}
            self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_kwargs)
        elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
            # If an instance of the config class is provided
            self.semantic_tokenizer_config = semantic_tokenizer_config
        else:
            raise ValueError(
                f"Unsupported semantic_tokenizer_config type: {type(semantic_tokenizer_config).__name__}"
            )

        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            # If a dictionary is provided, instantiate the config class with it
            # self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
            if decoder_config.get("model_type", '') == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
        elif isinstance(decoder_config, Qwen2Config):
            # If an instance of the config class is provided
            self.decoder_config = decoder_config
        else:
            raise ValueError(f"Unsupported decoder_config type: {type(decoder_config).__name__}")

        # other parameters (fall back to 64 / 128 when a tokenizer config has no `vae_dim`)
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)

        super().__init__(**kwargs)

    def get_text_config(self, decoder: bool = False):
        """Return the text (decoder) config for generation; `decoder` is ignored."""
        return self.decoder_config

    @property
    def vocab_size(self):
        """Return vocab_size from decoder config for generation compatibility."""
        return self.decoder_config.vocab_size

    @property
    def num_attention_heads(self):
        """Return num_attention_heads from decoder config for Ulysses SP compatibility."""
        return self.decoder_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        """Return num_key_value_heads from decoder config for Ulysses SP compatibility."""
        return self.decoder_config.num_key_value_heads

    @property
    def hidden_size(self):
        """Return hidden_size from decoder config for model compatibility."""
        return self.decoder_config.hidden_size

    @property
    def num_hidden_layers(self):
        """Return num_hidden_layers from decoder config for Ulysses SP compatibility."""
        return self.decoder_config.num_hidden_layers

    @property
    def head_dim(self):
        """Return head_dim from decoder config, or hidden_size // num_attention_heads if it has none."""
        return getattr(self.decoder_config, 'head_dim', self.hidden_size // self.num_attention_heads)
+
__all__ = [
"VibeVoiceAcousticTokenizerConfig",
"VibeVoiceSemanticTokenizerConfig",
"VibeVoiceDiffusionHeadConfig",
- "VibeVoiceConfig"
+ "VibeVoiceConfig",
+ "VibeVoiceASRConfig"
]
\ No newline at end of file
diff --git a/vibevoice/modular/modeling_vibevoice.py b/vibevoice/modular/modeling_vibevoice.py
new file mode 100644
index 0000000..a4ecbab
--- /dev/null
+++ b/vibevoice/modular/modeling_vibevoice.py
@@ -0,0 +1,496 @@
+# copied from https://github.com/vibevoice-community/VibeVoice/blob/main/vibevoice/modular/modeling_vibevoice.py
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union, Callable
+from tqdm import tqdm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from transformers.models.auto import AutoModel, AutoModelForCausalLM
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from transformers import modeling_utils
+from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.utils import logging
+
+
+from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
+from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
+from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
+
+from .configuration_vibevoice import VibeVoiceConfig
+
+
+logger = logging.get_logger(__name__)
+
# Compatibility shim: make sure `modeling_utils.ALL_PARALLEL_STYLES` exists and
# is not None by installing a default list of style names when it is missing
# or unset. NOTE(review): presumably needed for transformers versions that
# validate tensor-parallel plans against this list -- confirm which versions.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
+
@dataclass
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
    """Output container for the VibeVoice causal language model.

    Every field defaults to None.
    NOTE(review): the per-field meanings are not visible in this block;
    confirm them against the forward() that builds this output.
    """
    loss: Optional[torch.FloatTensor] = None
    # presumably the loss of the diffusion head -- TODO confirm
    diffusion_loss: Optional[torch.FloatTensor] = None
    # presumably the number of speech tokens in the batch -- TODO confirm
    speech_token_num: Optional[int] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
    """
    Output type for VibeVoice generation.

    Both fields default to None.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences.
        speech_outputs (`List[torch.FloatTensor]`, *optional*):
            List of generated speech waveforms or latents for each speech segment.
    """
    sequences: Optional[torch.LongTensor] = None
    speech_outputs: Optional[List[torch.FloatTensor]] = None
+
+
class SpeechConnector(nn.Module):
    """Projection block: Linear(input_dim -> output_dim), RMSNorm, Linear(output_dim -> output_dim)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        width = output_dim
        # Submodules are created in the same order as they are applied.
        self.fc1 = nn.Linear(input_dim, width)
        self.norm = LlamaRMSNorm(width, eps=1e-6)
        self.fc2 = nn.Linear(width, width)

    def forward(self, features, **kwargs):
        """Apply fc1, norm and fc2 in sequence; extra keyword arguments are ignored."""
        return self.fc2(self.norm(self.fc1(features)))
+
+
+# @auto_docstring
class VibeVoicePreTrainedModel(PreTrainedModel):
    """Base class for VibeVoice models: declares the config class, the
    capability flags read by the `PreTrainedModel` machinery, and the weight
    initialization scheme."""
    config_class = VibeVoiceConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        The diffusion head initializes itself. Linear layers get
        normal(0, std) weights and zero bias; LayerNorm gets unit weight and
        zero bias. Other module types are left untouched.
        """
        if isinstance(module, VibeVoiceDiffusionHead):
            module.initialize_weights()
            return

        # Use the language model's initializer_range if available
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight and/or bias are None when the norm was built without
            # affine parameters or without bias
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
+
+# @auto_docstring
class VibeVoiceModel(VibeVoicePreTrainedModel):
    """Backbone holding the language model together with the speech components.

    Builds, from the config: the decoder language model, the acoustic and
    semantic tokenizers, one `SpeechConnector` per tokenizer projecting into
    the decoder hidden size, the diffusion prediction head and its noise
    scheduler. `forward` only runs the language model.
    """

    def __init__(self, config):
        super().__init__(config)

        # Resolve the dtype for the speech components: config.torch_dtype may
        # be a string name or a torch.dtype; default to float32 when unset.
        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize Qwen2 model for language modeling
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)

        # Initialize speech components if needed
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)

        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
        self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)

        # Register scaling factors as buffers. They are 0-dim scalar tensors
        # holding NaN until real values are loaded or assigned.
        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))

        # Initialize prediction head for speech generation
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)

        # Initialize noise scheduler
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type
        )

    def get_input_embeddings(self):
        """Return the language model's token embedding.

        Raises:
            RuntimeError: If the embedding cannot be located.
        """
        if hasattr(self.language_model, 'embed_tokens'):
            # If the language model has an embed_tokens attribute, return it
            return self.language_model.embed_tokens

        for name, attr in self.language_model.fullmap.items():  # parallel by nnscaler, the name is changed
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)
        # An explicit raise (not `assert`) so the failure is not stripped under -O
        raise RuntimeError("Could not locate 'embed_tokens' on the language model")

    def set_input_embeddings(self, value):
        """Replace the language model's token embedding."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Set the speech tokenizers used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer

        # Reset the encoder to evaluation mode
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.eval()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model on the given inputs.

        All arguments are forwarded unchanged to the language model. When
        `return_dict` resolves to False its raw output is returned; otherwise
        the result is repackaged as a `BaseModelOutputWithPast`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through language model
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
+class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = VibeVoiceModel(config)
+ self.vocab_size = config.decoder_config.vocab_size
+ self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
+
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_decoder(self, decoder):
+ self.model.language_model = decoder
+
+ def get_decoder(self):
+ return self.model.language_model
+
+ def tie_weights(self):
+ """
+ Tie the weights between the input embeddings and the output embeddings.
+ """
+ if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
+ # This override does the tying itself (it does not call the standard PreTrainedModel method):
+ # lm_head is pointed at the same parameter object as the input embedding, which is
+ # CORRECT to do BEFORE FSDP wraps the model.
+ output_embeddings = self.get_output_embeddings()
+ input_embeddings = self.get_input_embeddings()
+ if hasattr(input_embeddings, 'weight'):
+ output_embeddings.weight = input_embeddings.weight
+ else:
+ # maybe returned input_embeddings a tensor directly
+ output_embeddings.weight = input_embeddings
+
+ if getattr(output_embeddings, "bias", None) is not None:
+ output_embeddings.bias.data = nn.functional.pad(
+ output_embeddings.bias.data,
+ (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
+ "constant",
+ 0,
+ )
+ print("Tied input and output embeddings using standard assignment.")
+ else:
+ print("tie_word_embeddings is False, not tying weights.")
+
+ # set_output_embeddings replaces the lm_head module outright.
+ # The key is to avoid calling it after accelerator.prepare().
+ def set_output_embeddings(self, new_embeddings):
+ # Rebinds self.lm_head to the given module; no in-place data.copy_ is performed,
+ # so the best way is to not call this after prepare().
+ self.lm_head = new_embeddings
+
+ def forward_speech_features(
+ self,
+ speech_tensors=None,
+ speech_masks=None,
+ speech_type="audio",
+ return_unmask=False
+ ):
+ if speech_tensors is None:
+ # Use config to get vae_dim instead of non-existent self.args
+ vae_dim = self.config.acoustic_tokenizer_config.vae_dim
+ audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
+ connect_features = self.model.acoustic_connector(audio_features)
+ return audio_features, connect_features
+ else:
+ with torch.no_grad():
+ if speech_type == "audio":
+ with torch.no_grad():
+ frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
+ audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
+
+ elif speech_type == "vae":
+ # Use config to get vae_dim instead of non-existent self.args
+ vae_dim = self.config.acoustic_tokenizer_config.vae_dim
+ speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
+
+ # gaussian sample from the speech_mode
+ batch_size = speech_mode.size(0)
+ value = self.model.acoustic_tokenizer.fix_std / 0.8
+ std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
+ std = std.view(-1, *[1] * (speech_mode.dim() - 1))
+ audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
+ else:
+ raise NotImplementedError(f"Speech type {speech_type} not implemented")
+
+ if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
+ scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
+ bias_factor = -audio_tokens[speech_masks].flatten().mean()
+
+ # Only use distributed operations if the process group is initialized
+ if dist.is_available() and dist.is_initialized():
+ dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
+ dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
+ world_size = dist.get_world_size()
+ self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
+ self.model.speech_bias_factor.copy_(bias_factor / world_size)
+ print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
+ else:
+ # Single process case
+ self.model.speech_scaling_factor.copy_(scaling_factor)
+ self.model.speech_bias_factor.copy_(bias_factor)
+ print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
+
+ audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
+
+ connect_features = self.model.acoustic_connector(audio_features)
+ if return_unmask:
+ return audio_features, connect_features
+ return audio_features[speech_masks], connect_features[speech_masks]
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ # New arguments for speech processing and loss calculation
+ speech_tensors: Optional[torch.FloatTensor] = None,
+ speech_masks: Optional[torch.BoolTensor] = None,
+ speeches_loss_input: Optional[torch.FloatTensor] = None,
+ speech_semantic_tensors: Optional[torch.FloatTensor] = None,
+ acoustic_input_mask: Optional[torch.BoolTensor] = None,
+ acoustic_loss_mask: Optional[torch.BoolTensor] = None,
+ ddpm_batch_mul: int = 1,
+ **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
+ ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ x = self.get_input_embeddings()(input_ids)
+
+ semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
+ if speeches_loss_input is not None:
+ # Only part of the audio is a diffusion target: encode everything unmasked, then select the target latents below
+ speech_all_features, speech_all_connect_features = self.forward_speech_features(
+ speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
+ speech_masks=speech_masks,
+ speech_type=kwargs.get("speech_type", "audio"),
+ return_unmask=True
+ )
+ if speech_tensors is not None:
+ if semantic_speech_all_connect_features is not None:
+ x[acoustic_input_mask] = (
+ speech_all_connect_features[speech_masks]
+ + semantic_speech_all_connect_features[speech_masks]
+ )
+ else:
+ x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
+
+ # Select only the target segments' latents for diffusion loss.
+ # Both masks are [num_segments, max_latent_len]; using 2D mask on [B,T,D] selects [N_true, D].
+ target_latent_mask = speeches_loss_input & speech_masks
+ speech_features = speech_all_features[target_latent_mask]
+ speech_connect_features = speech_all_connect_features[target_latent_mask]
+ else:
+ speech_features, speech_connect_features = self.forward_speech_features(
+ speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
+ speech_masks=speech_masks,
+ speech_type=kwargs.get("speech_type", "audio"),
+ )
+ if speech_tensors is not None:
+ x[acoustic_input_mask] = speech_connect_features
+
+ outputs = self.model(
+ input_ids=None,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=x,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=False,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ logits = self.lm_head(hidden_states)
+ # logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # The custom CE loss with masking is calculated in the training script.
+ # We leave the standard loss calculation here as None.
+ pass
+
+ # --- Diffusion Loss Calculation ---
+ diffusion_loss = None
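+ # Runs only when speech is present and acoustic_loss_mask selects something. NOTE(review): speech_len is bound only in this branch but is read at return time regardless — confirm.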
+ if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
+ condition_features = hidden_states[acoustic_loss_mask]
+
+ speech_len, latent_size = speech_features.shape
+
+ noise = torch.randn(
+ (speech_len * ddpm_batch_mul, latent_size),
+ device=hidden_states.device,
+ dtype=hidden_states.dtype
+ )
+
+ timesteps = torch.multinomial(
+ torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
+ speech_len * ddpm_batch_mul,
+ replacement=True,
+ ).to(hidden_states.device)
+
+ speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
+ condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
+
+ noisy_speech_features = self.model.noise_scheduler.add_noise(
+ speech_features_repeated, noise, timesteps
+ )
+
+ model_output = self.model.prediction_head(
+ noisy_speech_features,
+ timesteps.type_as(x),
+ condition_features_repeated
+ )
+
+ prediction_type = self.config.diffusion_head_config.prediction_type
+ if prediction_type == "epsilon":
+ target_for_loss = noise
+ elif prediction_type == "v_prediction":
+ target_for_loss = self.model.noise_scheduler.get_velocity(
+ speech_features_repeated, noise, timesteps
+ )
+ else:
+ raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
+
+ diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
+ if latent_size > 0 and ddpm_batch_mul > 0:
+ diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
+ else:
+ diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
+
+ else:
+ # Dummy loss for DDP to work when there are no speech samples in a batch,
+ # but we are in a speech context.
+ diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
+ diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
+ diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
+ # --- End Diffusion Loss Calculation ---
+
+ if not return_dict:
+ output = (logits, speech_len) + outputs.to_tuple()[1:]
+ return (loss, diffusion_loss) + output
+
+ return VibeVoiceCausalLMOutputWithPast(
+ loss=loss,
+ diffusion_loss=diffusion_loss,
+ speech_token_num=speech_len if speech_tensors is not None else 0,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
+AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
+
+__all__ = [
+ "VibeVoiceModel",
+ "VibeVoicePreTrainedModel",
+ "VibeVoiceForConditionalGeneration",
+ "VibeVoiceCausalLMOutputWithPast",
+ "VibeVoiceGenerationOutput",
+]
\ No newline at end of file
diff --git a/vibevoice/modular/modeling_vibevoice_asr.py b/vibevoice/modular/modeling_vibevoice_asr.py
new file mode 100644
index 0000000..706bf00
--- /dev/null
+++ b/vibevoice/modular/modeling_vibevoice_asr.py
@@ -0,0 +1,520 @@
+from typing import List, Optional, Tuple, Union
+import torch
+import torch.nn as nn
+
+from transformers.models.auto import AutoModel, AutoModelForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast
+from transformers import modeling_utils
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.generation import GenerationMixin
+
+from .modular_vibevoice_tokenizer import (
+ VibeVoiceTokenizerStreamingCache,
+ VibeVoiceTokenizerEncoderOutput
+)
+
+from .configuration_vibevoice import VibeVoiceASRConfig
+from .modeling_vibevoice import (
+ VibeVoiceCausalLMOutputWithPast,
+ SpeechConnector
+)
+
+logger = logging.get_logger(__name__)
+
+if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
+ modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
+
+# @auto_docstring
+class VibeVoiceASRPreTrainedModel(PreTrainedModel):
+ config_class = VibeVoiceASRConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _skip_keys_device_placement = "past_key_values"
+ _supports_cache_class = True
+ _supports_flash_attn = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+
+ # Use the language model's initializer_range if available
+ if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
+ std = self.config.language_model_config.initializer_range
+ elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
+ std = self.config.decoder_config.initializer_range
+ else:
+ std = 0.02 # Default value
+
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.weight.data.fill_(1.0)
+ module.bias.data.zero_()
+
+# @auto_docstring
+class VibeVoiceASRModel(VibeVoiceASRPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
+ if isinstance(config.torch_dtype, str):
+ dtype = getattr(torch, config.torch_dtype)
+ else:
+ dtype = config.torch_dtype
+ else:
+ dtype = torch.float32
+
+ # Initialize Qwen2 model for language modeling
+ lm_config = config.decoder_config
+ self.language_model = AutoModel.from_config(lm_config)
+
+ # Initialize speech components (tokenizers and connectors), each cast to the configured dtype
+ self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
+ self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
+
+ self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
+ self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
+
+ def get_input_embeddings(self):
+ if hasattr(self.language_model, 'embed_tokens'):
+ # If the language model has an embed_tokens attribute, return it
+ return self.language_model.embed_tokens
+
+ for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
+ if attr.orig_name == 'embed_tokens.weight':
+ return getattr(self.language_model, name)
+ assert False, 'should not arrive here'
+
+ def set_input_embeddings(self, value):
+ self.language_model.embed_tokens = value
+
+ def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
+ """Set the speech tokenizers used for encoding and decoding speech."""
+ self.acoustic_tokenizer = acoustic_tokenizer
+ self.semantic_tokenizer = semantic_tokenizer
+
+ # Put whichever tokenizers were supplied into evaluation mode
+ if self.acoustic_tokenizer is not None:
+ self.acoustic_tokenizer.eval()
+
+ if self.semantic_tokenizer is not None:
+ self.semantic_tokenizer.eval()
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # Forward through language model
+ outputs = self.language_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ if not return_dict:
+ return outputs
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+class VibeVoiceASRForConditionalGeneration(VibeVoiceASRPreTrainedModel, GenerationMixin):
+ """
+ VibeVoice model for Automatic Speech Recognition (ASR) with language modeling head for conditional generation.
+ This class is designed for inference and generation tasks.
+ """
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = VibeVoiceASRModel(config)
+ self.vocab_size = config.decoder_config.vocab_size
+
+ # Determine the dtype to use
+ if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
+ if isinstance(config.torch_dtype, str):
+ dtype = getattr(torch, config.torch_dtype)
+ else:
+ dtype = config.torch_dtype
+ else:
+ dtype = torch.float32
+
+ # Initialize lm_head with the correct dtype
+ self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False).to(dtype)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.language_model = decoder
+
+ def get_decoder(self):
+ return self.model.language_model
+
+ def tie_weights(self):
+ """Tie the weights between the input embeddings and the output embeddings."""
+ if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
+ output_embeddings = self.get_output_embeddings()
+ input_embeddings = self.get_input_embeddings()
+ if hasattr(input_embeddings, 'weight'):
+ output_embeddings.weight = input_embeddings.weight
+ else:
+ output_embeddings.weight = input_embeddings
+
+ def encode_speech(
+ self,
+ speech_tensors: torch.FloatTensor,
+ speech_masks: Optional[torch.BoolTensor] = None,
+ speech_semantic_tensors: Optional[torch.FloatTensor] = None,
+ streaming_segment_duration: float = 60.0, # seconds
+ ):
+ """
+ Encode speech input into features that can be used by the language model.
+ This method is called once before generation to process the speech input.
+
+ For audio longer than `streaming_segment_duration` (60s by default), uses streaming processing to avoid conv overflow (>2^32).
+ Segments are processed independently, then concatenated before final sampling.
+
+ Args:
+ speech_tensors: Input audio tensor [batch_size, samples]
+ speech_masks: Optional mask for speech features
+ speech_semantic_tensors: Optional pre-computed semantic tokens
+ streaming_segment_duration: Segment duration in seconds for streaming processing (default: 60s)
+ """
+ if hasattr(self.config, 'torch_dtype') and self.config.torch_dtype is not None:
+ if isinstance(self.config.torch_dtype, str):
+ dtype = getattr(torch, self.config.torch_dtype)
+ else:
+ dtype = self.config.torch_dtype
+ else:
+ dtype = torch.float32
+
+ speech_tensors = speech_tensors.to(dtype)
+
+ # Ensure proper shape: (batch, samples)
+ if speech_tensors.ndim == 1:
+ speech_tensors = speech_tensors.unsqueeze(0)
+
+ batch_size, total_samples = speech_tensors.shape
+ sample_rate = 24000 # fix 24kHz sample rate
+
+ # Calculate segment size in samples
+ segment_samples = int(streaming_segment_duration * sample_rate)
+
+ # Decide whether to use streaming based on audio length
+ use_streaming = total_samples > segment_samples
+
+ with torch.no_grad():
+ if not use_streaming:
+ # Short audio: direct processing (original behavior)
+ encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
+ audio_tokens = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
+ acoustic_features = self.model.acoustic_connector(audio_tokens)
+
+ # Encode semantic features
+ if speech_semantic_tensors is not None:
+ semantic_features = self.model.semantic_connector(speech_semantic_tensors)
+ else:
+ semantic_tokens = self.model.semantic_tokenizer.encode(speech_tensors.unsqueeze(1)).mean
+ semantic_features = self.model.semantic_connector(semantic_tokens)
+ else:
+ # Long audio: streaming processing
+ # print(f"Using streaming processing for long audio: {total_samples/sample_rate:.1f}s "
+ # f"(segment size: {streaming_segment_duration}s)")
+
+ # Initialize caches for both tokenizers
+ acoustic_encoder_cache = VibeVoiceTokenizerStreamingCache()
+ semantic_encoder_cache = VibeVoiceTokenizerStreamingCache()
+ acoustic_mean_segments = []
+ semantic_mean_segments = []
+ sample_indices = torch.arange(batch_size, device=speech_tensors.device)
+
+ # Helper function from batch_asr_sft_cache.py
+ def _iter_segments(total_length: int, segment_length: int):
+ """Iterate over audio segments with a given segment length."""
+ if segment_length <= 0:
+ raise ValueError("segment_length must be positive")
+ for start in range(0, total_length, segment_length):
+ end = min(start + segment_length, total_length)
+ if end > start:
+ yield start, end
+
+ # Process each segment for both acoustic and semantic tokenizers
+ segments = list(_iter_segments(total_samples, segment_samples))
+ num_segments = len(segments)
+ for seg_idx, (start, end) in enumerate(segments):
+ chunk = speech_tensors[:, start:end].contiguous()
+ if chunk.numel() == 0:
+ continue
+
+ # Check if this is the final segment
+ is_final = (seg_idx == num_segments - 1)
+
+ # Encode chunk for acoustic tokenizer (don't sample yet)
+ acoustic_encoder_output = self.model.acoustic_tokenizer.encode(
+ chunk.unsqueeze(1),
+ cache=acoustic_encoder_cache,
+ sample_indices=sample_indices,
+ use_cache=True,
+ is_final_chunk=is_final,
+ )
+ acoustic_mean_segments.append(acoustic_encoder_output.mean)
+
+ # Encode chunk for semantic tokenizer (take mean directly)
+ semantic_encoder_output = self.model.semantic_tokenizer.encode(
+ chunk.unsqueeze(1),
+ cache=semantic_encoder_cache,
+ sample_indices=sample_indices,
+ use_cache=True,
+ is_final_chunk=is_final,
+ )
+ semantic_mean_segments.append(semantic_encoder_output.mean)
+
+ # print(f"Processed {len(acoustic_mean_segments)} segments.")
+ # Concatenate all acoustic means and sample once
+ acoustic_mean_full = torch.cat(acoustic_mean_segments, dim=1).contiguous()
+ acoustic_encoder_output = VibeVoiceTokenizerEncoderOutput(
+ mean=acoustic_mean_full,
+ std=self.model.acoustic_tokenizer.fix_std
+ )
+ audio_tokens = acoustic_encoder_output.sample(
+ dist_type=self.model.acoustic_tokenizer.std_dist_type
+ )[0]
+ acoustic_features = self.model.acoustic_connector(audio_tokens)
+
+ # Concatenate all semantic means
+ semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous()
+ semantic_features = self.model.semantic_connector(semantic_tokens)
+
+ # Combine acoustic and semantic features
+ if speech_masks is not None:
+ combined_features = acoustic_features[speech_masks] + semantic_features[speech_masks]
+ else:
+ combined_features = acoustic_features + semantic_features
+
+ return combined_features
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ # Speech-specific arguments
+ speech_tensors: Optional[torch.FloatTensor] = None,
+ speech_masks: Optional[torch.BoolTensor] = None,
+ speech_semantic_tensors: Optional[torch.FloatTensor] = None,
+ acoustic_input_mask: Optional[torch.BoolTensor] = None,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutput]:
+ """
+ Forward pass for the model. Handles both training and generation scenarios.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ # Process inputs
+ if inputs_embeds is None and input_ids is not None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # If we have speech input and acoustic_input_mask, encode and insert speech features
+ if speech_tensors is not None and acoustic_input_mask is not None:
+ speech_features = self.encode_speech(
+ speech_tensors=speech_tensors,
+ speech_masks=speech_masks,
+ speech_semantic_tensors=speech_semantic_tensors,
+ )
+ inputs_embeds[acoustic_input_mask] = speech_features
+
+ # Forward through the model
+ outputs = self.model(
+ input_ids=None,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return VibeVoiceCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ speech_tensors=None,
+ speech_masks=None,
+ speech_semantic_tensors=None,
+ acoustic_input_mask=None,
+ **kwargs,
+ ):
+ """
+ Prepare inputs for generation step. This method is called by generate()
+ for each token generation step.
+
+ Following Qwen2-VL's approach: speech inputs are only forwarded on the first pass
+ (when cache_position[0] == 0), and are excluded in subsequent generation steps.
+ """
+ # If we have past key values, we only need to process the new tokens
+ if past_key_values is not None:
+ if isinstance(past_key_values, tuple):
+ past_length = past_key_values[0][0].shape[2]
+ else:
+ past_length = past_key_values.get_seq_length()
+
+ # Keep only the new tokens
+ if input_ids is not None and input_ids.shape[1] > past_length:
+ input_ids = input_ids[:, past_length:]
+
+ # Prepare position ids
+ if position_ids is None and attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values is not None and input_ids is not None:
+ position_ids = position_ids[:, -input_ids.shape[1]:]
+
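+ # Prepare cache position. NOTE(review): this calls get_seq_length() directly, whereas the branch above also accepts a tuple cache — confirm tuples cannot reach here.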
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]),
+ device=input_ids.device if input_ids is not None else inputs_embeds.device
+ )
+
+ # Prepare model inputs
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ }
+ )
+
+ # Following Qwen2-VL pattern: only include speech inputs on the first forward pass
+ # (when cache_position[0] == 0), exclude them in subsequent generation steps
+ if cache_position is not None and len(cache_position) > 0 and cache_position[0] == 0:
+ # First forward pass - include speech inputs if provided
+ model_inputs.update({
+ "speech_tensors": speech_tensors,
+ "speech_masks": speech_masks,
+ "speech_semantic_tensors": speech_semantic_tensors,
+ "acoustic_input_mask": acoustic_input_mask,
+ })
+ else:
+ # Subsequent generation steps - exclude speech inputs
+ model_inputs.update({
+ "speech_tensors": None,
+ "speech_masks": None,
+ "speech_semantic_tensors": None,
+ "acoustic_input_mask": None,
+ })
+
+ # Include any remaining kwargs that might be needed
+ model_inputs.update(kwargs)
+
+ return model_inputs
+
+AutoModel.register(VibeVoiceASRConfig, VibeVoiceASRModel)
+AutoModelForCausalLM.register(VibeVoiceASRConfig, VibeVoiceASRForConditionalGeneration)
+
+__all__ = [
+ "VibeVoiceASRPreTrainedModel",
+ "VibeVoiceASRModel",
+ "VibeVoiceASRForConditionalGeneration",
+]
\ No newline at end of file
diff --git a/vibevoice/modular/modular_vibevoice_text_tokenizer.py b/vibevoice/modular/modular_vibevoice_text_tokenizer.py
index bfa7bdd..da5669f 100644
--- a/vibevoice/modular/modular_vibevoice_text_tokenizer.py
+++ b/vibevoice/modular/modular_vibevoice_text_tokenizer.py
@@ -207,8 +207,107 @@ class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
"""Id used for padding (returns -100 for loss masking)."""
return self._pad_id
+class VibeVoiceASRTextTokenizerFast(Qwen2TokenizerFast):
+    """
+    Construct a "fast" VibeVoice ASR tokenizer (backed by HuggingFace's *tokenizers* library).
+    Based on the Qwen2 tokenizer; three Qwen special tokens are reused as speech start/end/pad markers.
+
+    Args:
+        vocab_file (`str`, *optional*):
+            Path to the vocabulary file.
+        merges_file (`str`, *optional*):
+            Path to the merges file.
+        tokenizer_file (`str`, *optional*):
+            Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
+        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The unknown token.
+        bos_token (`str`, *optional*):
+            The beginning of sequence token. Not used for vibevoice.
+        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The end of sequence token.
+        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+            The token used for padding. Note that the `pad_id` property does not use this token.
+    """
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file=None,
+        merges_file=None,
+        tokenizer_file=None,
+        unk_token="<|endoftext|>",
+        bos_token=None,
+        eos_token="<|endoftext|>",
+        pad_token="<|endoftext|>",
+        add_prefix_space=False,
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file=vocab_file,
+            merges_file=merges_file,
+            tokenizer_file=tokenizer_file,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            add_prefix_space=add_prefix_space,
+            **kwargs,
+        )
+
+        # Register the speech marker tokens and cache their ids (must run after super().__init__)
+        self._add_vibevoice_special_tokens()
+
+        # https://github.com/QwenLM/Qwen2.5-VL/blob/d2240f11656bfe404b9ba56db4e51cd09f522ff1/qwen-vl-finetune/qwenvl/data/data_qwen_packed.py#L57C5-L57C222
+        self.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+
+    def _add_vibevoice_special_tokens(self):
+        """Register the speech marker tokens, cache the ids used by the properties below, and return the number of tokens added."""
+        special_tokens = {
+            "additional_special_tokens": [
+                "<|object_ref_start|>",  # Speech start (reusing vision tokens)
+                "<|object_ref_end|>",    # Speech end
+                "<|box_start|>",         # Speech pad (exposed as `speech_pad_id`)
+            ]
+        }
+        num_added = self.add_special_tokens(special_tokens)
+
+        # Cache special token IDs
+        self._speech_start_id = self.convert_tokens_to_ids("<|object_ref_start|>")
+        self._speech_end_id = self.convert_tokens_to_ids("<|object_ref_end|>")
+        self._speech_pad_id = self.convert_tokens_to_ids("<|box_start|>")
+
+        self._eos_id = self.eos_token_id  # qwen2 / qwen3
+        self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')  # not the id of `pad_token`
+
+        return num_added
+
+    @property
+    def eos_id(self) -> int:
+        """Id of the end of sequence token."""
+        return self._eos_id
+
+    @property
+    def speech_start_id(self) -> int:
+        """Id of the speech start token."""
+        return self._speech_start_id
+
+    @property
+    def speech_end_id(self) -> int:
+        """Id of the speech end token."""
+        return self._speech_end_id
+
+    @property
+    def speech_pad_id(self) -> int:
+        """Id of the speech pad token (`<|box_start|>`)."""
+        return self._speech_pad_id
+
+    @property
+    def pad_id(self) -> int:  # id of '<|image_pad|>', cached in _add_vibevoice_special_tokens
+        return self._pad_id
+
__all__ = [
"VibeVoiceTextTokenizer",
"VibeVoiceTextTokenizerFast",
+ "VibeVoiceASRTextTokenizerFast",
]
\ No newline at end of file
diff --git a/vibevoice/modular/modular_vibevoice_tokenizer.py b/vibevoice/modular/modular_vibevoice_tokenizer.py
index 0031b26..454f9c1 100644
--- a/vibevoice/modular/modular_vibevoice_tokenizer.py
+++ b/vibevoice/modular/modular_vibevoice_tokenizer.py
@@ -17,7 +17,7 @@ from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN
-from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig
+from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig
logger = logging.get_logger(__name__)
@@ -26,14 +26,13 @@ import os
try:
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
APEX_AVAILABLE = True
- logger.info("APEX FusedRMSNorm is available and will be used for optimization")
+ # logger.info("APEX FusedRMSNorm is available and will be used for optimization")
if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
APEX_AVAILABLE = False
- logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
+ # logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
except ImportError:
APEX_AVAILABLE = False
- logger.warning("APEX FusedRMSNorm not available, using native implementation")
-# APEX_AVAILABLE=False
+ # logger.warning("APEX FusedRMSNorm not available, using native implementation")
# Normalization modules
class ConvLayerNorm(nn.LayerNorm):
@@ -297,7 +296,8 @@ class SConv1d(nn.Module):
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
sample_indices: Optional[torch.Tensor] = None,
use_cache: bool = False,
- debug: bool = False) -> torch.Tensor:
+ debug: bool = False,
+ is_final_chunk: bool = False) -> torch.Tensor:
"""
Forward pass with optional streaming support via cache.
@@ -307,6 +307,7 @@ class SConv1d(nn.Module):
sample_indices: Indices identifying each sample for cache management
use_cache: Whether to use cached states for streaming
debug: Whether to print debug information
+ is_final_chunk: Whether this is the final chunk (adds extra padding for alignment)
Returns:
Output tensor
@@ -322,12 +323,13 @@ class SConv1d(nn.Module):
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
assert len(sample_indices) == B, "sample_indices must match batch size"
- return self._forward_streaming(x, cache, sample_indices, debug)
+ return self._forward_streaming(x, cache, sample_indices, debug, is_final_chunk)
def _forward_streaming(self, x: torch.Tensor,
cache: VibeVoiceTokenizerStreamingCache,
sample_indices: torch.Tensor,
- debug: bool = False) -> torch.Tensor:
+ debug: bool = False,
+ is_final_chunk: bool = False) -> torch.Tensor:
"""Streaming forward pass with cache operations kept separate from compiled code"""
B, C, T = x.shape
@@ -350,6 +352,16 @@ class SConv1d(nn.Module):
input_with_context = torch.cat([cached_states, x], dim=2)
else:
input_with_context = x
+
+ # For final chunk, add extra padding to ensure ceil behavior (same as non-streaming)
+ if is_final_chunk:
+ extra_padding = get_extra_padding_for_conv1d(
+ input_with_context, self.kernel_size, self.stride, self.padding_total
+ )
+ if extra_padding > 0:
+ input_with_context = pad1d(input_with_context, (0, extra_padding), mode=self.pad_mode)
+ if debug:
+ print(f"[DEBUG] Final chunk: added extra_padding={extra_padding}")
if debug:
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
@@ -684,6 +696,135 @@ class Block1D(nn.Module):
return x
+class TokenizerEncoder(nn.Module):
+    """
+    Encoder component for the VibeVoice tokenizer that converts audio to latent representations.
+
+    Args:
+        config: Configuration object with model parameters
+    """
+    def __init__(self, config):
+        super().__init__()
+
+        # Extract parameters from config
+        self.channels = config.channels
+        self.dimension = config.dimension
+        self.n_filters = config.n_filters
+        self.ratios = list(reversed(config.ratios))  # stored in reverse order of config.ratios
+        self.depths = config.depths
+        self.n_residual_layers = getattr(config, "n_residual_layers", 1)  # stored only; not read in this class
+        self.hop_length = np.prod(self.ratios)  # product of all stride ratios
+        self.causal = config.causal
+
+        # Additional config parameters with defaults
+        kernel_size = getattr(config, "kernel_size", 7)
+        last_kernel_size = getattr(config, "last_kernel_size", 7)
+        norm = getattr(config, "norm", "none")
+        norm_params = getattr(config, "norm_params", {})
+        pad_mode = getattr(config, "pad_mode", "reflect")
+        bias = getattr(config, "bias", True)
+        layernorm = getattr(config, "layernorm", "LN")
+        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
+        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
+        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
+        mixer_layer = getattr(config, "mixer_layer", "conv")
+        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
+        disable_last_norm = getattr(config, "disable_last_norm", False)
+
+        # determine the norm type based on layernorm
+        if layernorm == 'LN':
+            norm_type = ConvLayerNorm
+        elif layernorm == 'RMSNorm':
+            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
+        else:
+            raise ValueError(f"Unsupported norm type: {layernorm}")
+
+        # stem and intermediate downsampling conv layers
+        stem = nn.Sequential(
+            SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
+        )
+
+        self.downsample_layers = nn.ModuleList()
+        self.downsample_layers.append(stem)
+        for i in range(len(self.ratios)):  # one strided conv per ratio; channels double each time
+            in_ch = self.n_filters * (2 ** i)
+            out_ch = self.n_filters * (2 ** (i + 1))
+            downsample_layer = nn.Sequential(
+                SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
+            )
+            self.downsample_layers.append(downsample_layer)
+
+        # configure the Block1D factory shared by all stages
+        layer_type = partial(
+            Block1D,
+            mixer_layer=mixer_layer,
+            layernorm=layernorm,
+            eps=layernorm_eps,
+            causal=self.causal,
+            pad_mode=pad_mode,
+            norm=norm,
+            bias=bias,
+            layer_scale_init_value=layer_scale_init_value,
+        )
+
+        self.stages = nn.ModuleList()
+        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]  # one rate per block, increasing linearly
+        cur = 0
+
+        for i in range(len(self.depths)):
+            in_ch = self.n_filters * (2 ** i)
+            stage = nn.Sequential(
+                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
+            )
+            self.stages.append(stage)
+            cur += self.depths[i]
+
+        if not disable_last_norm:
+            self.norm = norm_type(in_ch, eps=layernorm_eps)  # in_ch: channel count of the last stage (left over from the loop)
+        else:
+            self.norm = nn.Identity()
+        self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
+
+    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):  # runs every stage and the final norm, but not the head
+        for i in range(len(self.depths)):  # indexes downsample_layers by stage, so len(depths) must not exceed len(ratios) + 1
+            # Apply downsampling
+            for layer in self.downsample_layers[i]:
+                if isinstance(layer, SConv1d):
+                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
+                else:
+                    x = layer(x)
+
+            # Apply stage (Block1D contains Convlayer which contains SConv1d)
+            for block in self.stages[i]:
+                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
+                    # Block forward written out inline so the cache/streaming arguments reach the inner SConv1d
+                    residual = x
+                    x = block.norm(x)
+                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
+                    if block.gamma is not None:
+                        x = x * block.gamma.unsqueeze(-1)
+                    x = residual + x
+                    # NOTE(review): no drop-path is applied on this inline path; confirm that matches Block1D.forward
+                    # FFN part
+                    residual = x
+                    x = block.ffn_norm(x)
+                    x = x.permute(0, 2, 1)  # swap the last two axes for the FFN, swapped back below
+                    x = block.ffn(x)
+                    x = x.permute(0, 2, 1)
+                    if block.ffn_gamma is not None:
+                        x = x * block.ffn_gamma.unsqueeze(-1)
+                    x = residual + x
+                else:
+                    x = block(x)  # blocks without an SConv1d mixer run their own forward (no cache arguments)
+
+        return self.norm(x)
+
+    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):  # forward_features followed by the output head conv
+        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
+        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
+        return x
+
+
class TokenizerDecoder(nn.Module):
"""
Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
@@ -821,15 +962,63 @@ class TokenizerDecoder(nn.Module):
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
return x
+
+@dataclass
+class VibeVoiceTokenizerEncoderOutput:
+    """
+    Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.
+
+    Args:
+        mean (`torch.FloatTensor`): The mean parameters of the distribution.
+        std (`float` or `torch.FloatTensor`, *optional*): Fixed standard deviation; must be set to use the 'fix' or 'gaussian' sampling modes.
+    """
+    mean: torch.Tensor
+    std: Optional[Union[float, torch.Tensor]] = None
+
+    def sample(self, dist_type='fix'):
+        """
+        Draw a sample from the distribution; always returns a `(values, std)` pair.
+
+        Args:
+            dist_type (`str`): 'fix' adds noise scaled by `self.std`; 'gaussian' draws one random noise scale per batch item (`randn * std / 0.8`); any other value returns the mean unchanged.
+
+        Returns:
+            `torch.FloatTensor`: Sampled values (the mean itself for the pass-through case).
+            `float`, `torch.FloatTensor` or `None`: The scale that was applied (`self.std`, or the per-item scale drawn for 'gaussian').
+        """
+        if dist_type == 'fix':
+            x = self.mean + self.std * torch.randn_like(self.mean)
+            return x, self.std
+        elif dist_type == 'gaussian':
+            batch_size = self.mean.size(0)
+            value = self.std / 0.8  # NOTE(review): 0.8 is an unexplained constant; confirm its origin
+            std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value  # signed, one value per batch item
+
+            while std.dim() < self.mean.dim():  # append trailing axes so std broadcasts against mean
+                std = std.unsqueeze(-1)
+
+            x = self.mean + std * torch.randn_like(self.mean)
+            return x, std
+        else:
+            return self.mean, self.std  # pass-through: no noise added
+
+    def kl(self):
+        """Return the element-wise squared error between the mean and zero (a KL-like penalty on the mean; the std is not used)."""
+        target = torch.zeros_like(self.mean)
+        return F.mse_loss(self.mean, target, reduction='none')
+
+    def mode(self):
+        """Return the distribution mode (which is the mean for Gaussian)."""
+        return self.mean
class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
- """VibeVoice speech tokenizer model (only decoder) for acoustic tokens"""
+ """VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""
config_class = VibeVoiceAcousticTokenizerConfig
base_model_prefix = "vibevoice_acoustic_tokenizer"
_supports_flash_attn_2 = True
_supports_sdpa = True
- _no_split_modules = ["TokenizerDecoder"]
+ _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]
def __init__(self, config):
super().__init__(config)
@@ -850,6 +1039,21 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
# Default: use reversed encoder depths if decoder_depths is None
decoder_depths = list(reversed(encoder_depths))
+ # Create encoder config
+ encoder_config = copy.deepcopy(config)
+ encoder_config.dimension = config.vae_dim
+ encoder_config.n_filters = config.encoder_n_filters
+ encoder_config.ratios = config.encoder_ratios
+ encoder_config.depths = encoder_depths
+ encoder_config.norm = config.conv_norm
+ encoder_config.pad_mode = config.pad_mode
+ encoder_config.bias = config.conv_bias
+ encoder_config.layernorm_eps = config.layernorm_eps
+ encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
+ encoder_config.mixer_layer = config.mixer_layer
+ encoder_config.layer_scale_init_value = config.layer_scale_init_value
+ encoder_config.disable_last_norm = config.disable_last_norm
+
# Create decoder config
decoder_config = copy.deepcopy(config)
decoder_config.dimension = config.vae_dim
@@ -865,6 +1069,8 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
decoder_config.layer_scale_init_value = config.layer_scale_init_value
decoder_config.disable_last_norm = config.disable_last_norm
+ # Initialize encoder and decoder
+ self.encoder = TokenizerEncoder(encoder_config)
self.decoder = TokenizerDecoder(decoder_config)
# Initialize weights
@@ -884,6 +1090,24 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
if module.bias is not None:
nn.init.zeros_(module.bias)
+ @torch.no_grad()
+ def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
+ """Convert audio to latent representations"""
+ latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
+ return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)
+
+ @torch.no_grad()
+ def sampling(self, encoder_output, dist_type=None):
+ """Sample from the encoder output distribution"""
+ dist_type = dist_type or self.std_dist_type
+
+ if dist_type == 'fix':
+ return encoder_output.sample(dist_type='fix')
+ elif dist_type == 'gaussian':
+ return encoder_output.sample(dist_type='gaussian')
+ else:
+ raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")
+
@torch.no_grad()
def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
"""Convert latent representations back to audio"""
@@ -895,10 +1119,89 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
return audio
+    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
+        """Encode audio, sample latents and decode them; returns `(reconstructed, sampled_latents)`. encode/sampling/decode are all `@torch.no_grad()`, so no gradients flow through this call."""
+        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)  # is_final_chunk stays at its default (False)
+        sampled_latents, _ = self.sampling(encoder_output)  # sampling mode comes from self.std_dist_type
+        reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
+        return reconstructed, sampled_latents
+
+
+class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
+    """VibeVoice speech tokenizer model with only encoder for semantic tokens"""
+
+    config_class = VibeVoiceSemanticTokenizerConfig
+    base_model_prefix = "vibevoice_semantic_tokenizer"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _no_split_modules = ["TokenizerEncoder"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        # Parse encoder depths: either a '-'-separated string of ints or an already-parsed sequence
+        if isinstance(config.encoder_depths, str):
+            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
+        else:
+            encoder_depths = config.encoder_depths
+
+        # Create encoder config: a deep copy of the model config with the attribute names TokenizerEncoder reads
+        encoder_config = copy.deepcopy(config)
+        encoder_config.dimension = config.vae_dim
+        encoder_config.n_filters = config.encoder_n_filters
+        encoder_config.ratios = config.encoder_ratios
+        encoder_config.depths = encoder_depths
+        encoder_config.norm = config.conv_norm
+        encoder_config.pad_mode = config.pad_mode
+        encoder_config.bias = config.conv_bias
+        encoder_config.layernorm_eps = config.layernorm_eps
+        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
+        encoder_config.mixer_layer = config.mixer_layer
+        encoder_config.layer_scale_init_value = config.layer_scale_init_value
+        encoder_config.disable_last_norm = config.disable_last_norm
+
+        # Initialize encoder (this model has no decoder)
+        self.encoder = TokenizerEncoder(encoder_config)
+
+        # Initialize weights
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        """Initialize Linear/Conv1d weights from N(0, config.weight_init_value) with zero bias; reset LayerNorm to weight 1, bias 0"""
+        if isinstance(module, nn.Linear):
+            nn.init.normal_(module.weight, std=self.config.weight_init_value)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.LayerNorm):
+            nn.init.ones_(module.weight)
+            nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Conv1d):
+            nn.init.normal_(module.weight, std=self.config.weight_init_value)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+
+    @torch.no_grad()
+    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
+        """Run the encoder without gradients and wrap the latents (last two axes swapped) as an encoder output whose `std` is left unset"""
+        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
+        return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))
+
+    @torch.no_grad()
+    def sampling(self, encoder_output, dist_type=None):
+        """Return `encoder_output.sample(dist_type='none')`, i.e. the mean and std unchanged; the `dist_type` argument is ignored"""
+        return encoder_output.sample(dist_type='none')
+
+    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
+        """Encode audio and return `(None, latents)`; there is no decoder, so the first slot is always None"""
+        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)  # is_final_chunk stays at its default (False)
+        sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
+        return None, sampled_latents
AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
+AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)
__all__ = [
"VibeVoiceTokenizerStreamingCache",
"VibeVoiceAcousticTokenizerModel",
+ "VibeVoiceSemanticTokenizerModel",
]
\ No newline at end of file
diff --git a/vibevoice/processor/audio_utils.py b/vibevoice/processor/audio_utils.py
new file mode 100644
index 0000000..ad4d10d
--- /dev/null
+++ b/vibevoice/processor/audio_utils.py
@@ -0,0 +1,143 @@
+import numpy as np
+from subprocess import run
+from typing import List, Optional, Union, Dict, Any
+
+COMMON_AUDIO_EXTS = [
+ '.mp3', '.MP3', '.Mp3', # All case variations of mp3
+ '.m4a',
+ '.mp4', '.MP4',
+ '.wav', '.WAV',
+ '.m4v',
+ '.aac',
+ '.ogg',
+ '.mov', '.MOV',
+ '.opus',
+ '.m4b',
+ '.flac',
+ '.wma', '.WMA',
+ '.rm', '.3gp', '.mpeg', '.flv', '.webm', '.mp2', '.aif', '.aiff', '.oga', '.ogv', '.mpga', '.m3u8', '.amr'
+]
+
+def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24000):
+    """
+    Decode an audio file to a mono float32 waveform with ffmpeg, optionally resampling.
+    Returns the waveform and its sample rate; raises `subprocess.CalledProcessError` if ffprobe/ffmpeg exits non-zero.
+
+    Parameters
+    ----------
+    file: str
+        The audio file to open
+    resample: bool
+        Whether to resample the audio to ``target_sr``
+    target_sr: int
+        The target sample rate if resampling is requested
+
+    Returns
+    -------
+    A tuple containing:
+        - A NumPy array with the audio waveform in float32 dtype (16-bit PCM divided by 32768)
+        - The sample rate of that array: ``target_sr`` when resampling, otherwise the file's own rate
+    """
+    if not resample:
+        # Probe the file's native rate so ffmpeg can be told to keep it
+        cmd_probe = [
+            "ffprobe",
+            "-v", "quiet", "-select_streams", "a:0",  # first audio stream only: several streams would print several rates
+            "-show_entries", "stream=sample_rate",
+            "-of", "default=noprint_wrappers=1:nokey=1",
+            file
+        ]
+        # Keep only the first reported value in case ffprobe still prints more than one line
+        original_sr = int(run(cmd_probe, capture_output=True, check=True).stdout.decode().strip().split("\n")[0])
+    else:
+        original_sr = None
+
+    # Now load the audio
+    sr_to_use = target_sr if resample else original_sr
+
+    cmd = [
+        "ffmpeg",
+        "-nostdin",
+        "-threads", "0",
+        "-i", file,
+        "-f", "s16le",
+        "-ac", "1",
+        "-acodec", "pcm_s16le",
+        "-ar", str(sr_to_use),
+        "-"
+    ]
+
+    out = run(cmd, capture_output=True, check=True).stdout
+    audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+    return audio_data, sr_to_use
+
+class AudioNormalizer:
+    """
+    Audio normalization class for VibeVoice tokenizer.
+
+    Calling an instance scales the signal to a target RMS level (in dB FS)
+    and then, if the peak exceeds 1.0, scales it down so it no longer clips.
+    """
+
+    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
+        """
+        Initialize the audio normalizer.
+
+        Args:
+            target_dB_FS (float): Target dB FS level for the audio. Default: -25
+            eps (float): Small value to avoid division by zero. Default: 1e-6
+        """
+        self.target_dB_FS = target_dB_FS
+        self.eps = eps
+
+    def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
+        """
+        Scale the audio so that its RMS level matches the target dB FS level.
+
+        Args:
+            audio (np.ndarray): Input audio signal
+
+        Returns:
+            tuple: (normalized_audio, rms, scalar) where rms is the input's RMS and scalar the gain that was applied
+        """
+        rms = np.sqrt(np.mean(audio**2))  # RMS over every element of the array
+        scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)  # target level as linear amplitude / current RMS; eps guards rms == 0
+        normalized_audio = audio * scalar
+        return normalized_audio, rms, scalar
+
+    def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
+        """
+        Avoid clipping by scaling down if necessary.
+
+        Args:
+            audio (np.ndarray): Input audio signal
+            scalar (float, optional): Divisor to apply; when None it is derived from the peak absolute value
+
+        Returns:
+            tuple: (audio / scalar, scalar)
+        """
+        if scalar is None:
+            max_val = np.max(np.abs(audio))
+            if max_val > 1.0:
+                scalar = max_val + self.eps  # brings the peak just below 1.0
+            else:
+                scalar = 1.0  # already within [-1, 1]: leave the level untouched
+
+        return audio / scalar, scalar
+
+    def __call__(self, audio: np.ndarray) -> np.ndarray:
+        """
+        Normalize the audio by adjusting to target dB FS and avoiding clipping.
+
+        Args:
+            audio (np.ndarray): Input audio signal
+
+        Returns:
+            np.ndarray: Normalized audio signal
+        """
+        # First adjust to target dB FS
+        audio, _, _ = self.tailor_dB_FS(audio)
+        # Then avoid clipping (the gain above can push peaks past 1.0)
+        audio, _ = self.avoid_clipping(audio)
+        return audio
\ No newline at end of file
diff --git a/vibevoice/processor/vibevoice_asr_processor.py b/vibevoice/processor/vibevoice_asr_processor.py
new file mode 100644
index 0000000..007dc39
--- /dev/null
+++ b/vibevoice/processor/vibevoice_asr_processor.py
@@ -0,0 +1,572 @@
+"""
+Processor class for VibeVoice ASR models.
+"""
+
+import os
+import json
+import math
+import warnings
+from typing import List, Optional, Union, Dict, Any, Tuple
+
+import numpy as np
+import torch
+
+from transformers.tokenization_utils_base import BatchEncoding
+from transformers.utils import TensorType, logging
+from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor, AudioNormalizer
+
+try:
+ from .audio_utils import load_audio_use_ffmpeg
+ HAS_FFMPEG_UTILS = True
+except ImportError:
+ HAS_FFMPEG_UTILS = False
+ warnings.warn("audio_utils not available, will fall back to soundfile for audio loading")
+
+logger = logging.get_logger(__name__)
+
+SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format."
+
+
+class VibeVoiceASRProcessor:
+ """
+ Processor for VibeVoice ASR (Automatic Speech Recognition) models.
+
+ This processor handles audio preprocessing and tokenization for ASR tasks,
+ following the exact format used in training with proper chat templates.
+
+ Args:
+ tokenizer: The text tokenizer for processing text
+ audio_processor: The audio processor for processing speech
+ speech_tok_compress_ratio (int): Compression ratio for speech tokenization
+ target_sample_rate (int): Target sample rate for audio
+ normalize_audio (bool): Whether to normalize audio input
+ """
+
+    def __init__(
+        self,
+        tokenizer=None,
+        audio_processor=None,
+        speech_tok_compress_ratio=320,  # NOTE(review): from_pretrained() falls back to 3200, not 320; confirm which default is intended
+        target_sample_rate=24000,
+        normalize_audio=True,
+        **kwargs  # accepted but not used by this constructor
+    ):
+        self.tokenizer = tokenizer  # _cache_special_tokens() below calls methods on it, so None fails there
+        self.audio_processor = audio_processor or VibeVoiceTokenizerProcessor(
+            sampling_rate=target_sample_rate,
+            normalize_audio=normalize_audio
+        )
+        self.speech_tok_compress_ratio = speech_tok_compress_ratio
+        self.target_sample_rate = target_sample_rate
+        self.normalize_audio = normalize_audio
+
+        if normalize_audio:
+            self.audio_normalizer = AudioNormalizer()  # always built with default settings, independent of audio_processor
+        else:
+            self.audio_normalizer = None
+
+        # Cache special token IDs
+        self._cache_special_tokens()
+
+    def _cache_special_tokens(self):
+        """Copy the speech start/end/pad ids and the padding id from the tokenizer onto the processor."""
+        # Prefer the tokenizer's own id properties; otherwise look the id up by token string.
+        if hasattr(self.tokenizer, 'speech_start_id'):
+            self.speech_start_id = self.tokenizer.speech_start_id
+        else:
+            self.speech_start_id = self.tokenizer.convert_tokens_to_ids("<|speech_start|>")
+        # NOTE(review): the fallback strings here differ from the markers VibeVoiceASRTextTokenizerFast registers (e.g. "<|object_ref_start|>"); confirm
+        if hasattr(self.tokenizer, 'speech_end_id'):
+            self.speech_end_id = self.tokenizer.speech_end_id
+        else:
+            self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|speech_end|>")
+
+        if hasattr(self.tokenizer, 'speech_pad_id'):
+            self.speech_pad_id = self.tokenizer.speech_pad_id
+        else:
+            self.speech_pad_id = self.tokenizer.convert_tokens_to_ids("<|speech_pad|>")
+
+        if hasattr(self.tokenizer, 'pad_id'):  # a tokenizer-specific pad_id takes precedence over pad_token_id
+            self.pad_id = self.tokenizer.pad_id
+        elif hasattr(self.tokenizer, 'pad_token_id'):
+            self.pad_id = self.tokenizer.pad_token_id
+        else:
+            self.pad_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """
+        Load the processor: read `preprocessor_config.json`, then build the tokenizer and audio processor.
+
+        Args:
+            pretrained_model_name_or_path: Local directory, or an identifier resolved through `cached_file`
+            **kwargs: Forwarded to `cached_file` and the tokenizer; `language_model_pretrained_name` is consumed here
+
+        Returns:
+            VibeVoiceASRProcessor: The loaded processor (ValueError is raised for a non-Qwen tokenizer name)
+        """
+        lm_name_override = kwargs.pop("language_model_pretrained_name", None)  # always consume it so it never leaks into the calls below
+        from transformers.utils import cached_file
+        from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
+
+        # Try to load configuration: a local directory first, then whatever cached_file can resolve
+        config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
+        config = {}
+
+        if os.path.exists(config_path):
+            with open(config_path, 'r', encoding="utf-8") as f:
+                config = json.load(f)
+        else:
+            try:
+                config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    "preprocessor_config.json",
+                    **kwargs
+                )
+                with open(config_file, 'r', encoding="utf-8") as f:
+                    config = json.load(f)
+            except Exception as e:
+                logger.warning(f"Could not load preprocessor_config.json: {e}")
+                logger.warning("Using default configuration")
+
+        # Extract parameters
+        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
+        target_sample_rate = config.get("target_sample_rate", 24000)
+        normalize_audio = config.get("normalize_audio", True)
+
+        # Load tokenizer: the config value wins, then the caller's override, then the default
+        language_model_pretrained_name = config.get("language_model_pretrained_name", None) or lm_name_override or "Qwen/Qwen2.5-1.5B"
+        logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
+
+        if 'qwen' in language_model_pretrained_name.lower():
+            tokenizer = VibeVoiceASRTextTokenizerFast.from_pretrained(
+                language_model_pretrained_name,
+                **kwargs
+            )
+        else:
+            raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}")
+
+        # Load audio processor
+        audio_processor = VibeVoiceTokenizerProcessor(
+            sampling_rate=target_sample_rate,
+            normalize_audio=normalize_audio,
+            target_dB_FS=config.get("target_dB_FS", -25),
+            eps=config.get("eps", 1e-6),
+        )
+
+        return cls(
+            tokenizer=tokenizer,
+            audio_processor=audio_processor,
+            speech_tok_compress_ratio=speech_tok_compress_ratio,
+            target_sample_rate=target_sample_rate,
+            normalize_audio=normalize_audio,
+        )
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
+        """
+        Write `preprocessor_config.json` into a directory (created if missing); nothing else is saved here.
+
+        Args:
+            save_directory: Directory to save the configuration
+            **kwargs: Accepted but not used
+        """
+        import json
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # NOTE(review): "language_model_pretrained_name" is not written, so from_pretrained() falls back to its default name
+        processor_config = {
+            "processor_class": "VibeVoiceASRProcessor",
+            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
+            "target_sample_rate": self.target_sample_rate,
+            "normalize_audio": self.normalize_audio,
+            "target_dB_FS": -25,  # hard-coded, not read from the audio processor
+            "eps": 1e-6,  # hard-coded, not read from the audio processor
+        }
+
+        config_path = os.path.join(save_directory, "preprocessor_config.json")
+        with open(config_path, 'w') as f:
+            json.dump(processor_config, f, indent=2)
+
+        logger.info(f"Processor configuration saved in {config_path}")
+
+    def __call__(
+        self,
+        audio: Optional[Union[str, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, torch.Tensor]]]] = None,
+        sampling_rate: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        padding: bool = True,
+        max_length: Optional[int] = None,
+        truncation: bool = False,
+        add_generation_prompt: bool = True,
+        use_streaming: bool = True,
+        context_info: Optional[str] = None,
+        **kwargs  # accepted but not used in this method
+    ) -> BatchEncoding:
+        """
+        Process audio input for ASR model; raises ValueError when `audio` is None.
+
+        Args:
+            audio: Audio input(s). Can be:
+                - str: Path to audio file
+                - np.ndarray: Audio array
+                - torch.Tensor: Audio tensor
+                - List of the above for batch processing
+            sampling_rate: Sampling rate of input audio (the same value is used for every item)
+            return_tensors: Output format ('pt' for PyTorch, 'np' for NumPy)
+            padding: Whether to pad batch inputs
+            max_length: Maximum sequence length
+            truncation: Whether to truncate long sequences
+            add_generation_prompt: Whether to add generation prompt for inference
+            use_streaming: Whether to use streaming mode (True by default, auto False if <60s)
+            context_info: Optional context information (e.g., hotwords, metadata), shared by every item in the batch
+
+        Returns:
+            BatchEncoding with:
+                - input_ids: Token IDs for the model
+                - attention_mask: Attention mask
+                - acoustic_input_mask: Mask indicating speech token positions
+                - speech_tensors: Processed speech features
+                - speech_masks: Valid speech masks
+                - vae_tok_seqlens: Length of each speech segment in tokens
+        """
+        if audio is None:
+            raise ValueError("Audio input is required for ASR processing")
+
+        # Handle single vs batch input: a single item is wrapped into a one-element list
+        if isinstance(audio, list):
+            is_batched = True  # set for both branches but not read again in this method
+            audio_list = audio
+        else:
+            is_batched = False
+            audio_list = [audio]
+
+        # Process each audio input independently
+        all_encodings = []
+        for audio_input in audio_list:
+            encoding = self._process_single_audio(
+                audio_input,
+                sampling_rate=sampling_rate,
+                add_generation_prompt=add_generation_prompt,
+                use_streaming=use_streaming,
+                context_info=context_info,
+            )
+            all_encodings.append(encoding)
+
+        # Combine into batch (padding/truncation/tensor conversion are delegated to _batch_encode)
+        batch_encoding = self._batch_encode(
+            all_encodings,
+            padding=padding,
+            max_length=max_length,
+            truncation=truncation,
+            return_tensors=return_tensors,
+        )
+
+        return batch_encoding
+
+    def _process_single_audio(
+        self,
+        audio: Union[str, np.ndarray, torch.Tensor],
+        sampling_rate: Optional[int] = None,
+        add_generation_prompt: bool = True,
+        use_streaming: bool = True,
+        context_info: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """
+        Process a single audio input.
+
+        Args:
+            audio: Single audio input (file path, NumPy array or torch tensor)
+            sampling_rate: Sampling rate of an array/tensor input. When given and different
+                from ``self.target_sample_rate`` the samples are resampled. Not used for
+                file paths, whose rate is read from the file itself
+            add_generation_prompt: Whether to add generation prompt.
+                NOTE(review): not forwarded to the chat template here - confirm intent
+            use_streaming: Accepted for interface compatibility; the encoding built here
+                does not depend on it
+            context_info: Optional context information (e.g., hotwords, metadata) to help transcription
+
+        Returns:
+            Dictionary with "input_ids", "acoustic_input_mask", "speech" (float32
+            waveform) and "vae_tok_len" (number of speech placeholder tokens)
+        """
+        if isinstance(audio, str):
+            audio_array = None
+            # Load from file using ffmpeg for better format support
+            if HAS_FFMPEG_UTILS:
+                try:
+                    audio_array, source_sr = load_audio_use_ffmpeg(audio, resample=False)
+                except Exception as e:
+                    # Fall back to soundfile if ffmpeg fails
+                    warnings.warn(f"ffmpeg loading failed, falling back to soundfile: {e}")
+            if audio_array is None:
+                # soundfile is the fallback when ffmpeg is unavailable or failed
+                import soundfile as sf
+                audio_array, source_sr = sf.read(audio)
+                if audio_array.ndim > 1:
+                    audio_array = audio_array.mean(axis=1)  # Convert to mono
+        else:
+            if isinstance(audio, torch.Tensor):
+                # detach + float so grad-tracking or half/bfloat16 tensors convert cleanly
+                audio_array = audio.detach().cpu().float().numpy()
+            else:
+                audio_array = np.array(audio, dtype=np.float32)
+            if audio_array.ndim > 1:
+                audio_array = audio_array.squeeze()
+            # In-memory audio carries no rate of its own: trust the caller-declared one
+            source_sr = sampling_rate
+
+        # Resample if needed. This also covers array/tensor input, which would otherwise
+        # be treated as target-rate audio and yield a wrong duration and token count
+        if source_sr is not None and source_sr != self.target_sample_rate:
+            import librosa
+            audio_array = librosa.resample(
+                audio_array,
+                orig_sr=source_sr,
+                target_sr=self.target_sample_rate
+            )
+
+        # Ensure float32
+        audio_array = audio_array.astype(np.float32)
+
+        # Normalize if needed
+        if self.normalize_audio and self.audio_normalizer:
+            audio_array = self.audio_normalizer(audio_array)
+
+        # Calculate audio duration
+        audio_duration = len(audio_array) / self.target_sample_rate
+
+        # Calculate token length: uses ceil (encoder adds extra_padding for stride
+        # alignment); the same formula is applied regardless of streaming mode
+        vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio)
+
+        # Build token sequence following training format
+        # 1. System prompt - use apply_chat_template then encode like in training
+        system_prompt_text = self.tokenizer.apply_chat_template(
+            [{"role": "system", "content": SYSTEM_PROMPT}],
+            tokenize=False
+        )
+        system_tokens = self.tokenizer.encode(system_prompt_text)
+
+        # 2. User input with speech tokens
+        # Build speech placeholder string
+        sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id)
+        sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id)
+        sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id)
+
+        # User suffix with audio duration info
+        show_keys = ['Start time', 'End time', 'Speaker ID', 'Content']
+        if context_info and context_info.strip():
+            user_suffix = f"This is a {audio_duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\nPlease transcribe it with these keys: " + ", ".join(show_keys)
+        else:
+            user_suffix = f"This is a {audio_duration:.2f} seconds audio, please transcribe it with these keys: " + ", ".join(show_keys)
+
+        # String repetition avoids building a per-token list for long recordings
+        user_input_string = (
+            sp_start_token + sp_pad_token * vae_tok_len + sp_end_token + '\n' + user_suffix
+        )
+
+        user_tokens = self.tokenizer.apply_chat_template(
+            [{"role": "user", "content": user_input_string}],
+            tokenize=True
+        )
+
+        # Combine tokens
+        full_tokens = system_tokens + user_tokens
+
+        # Create acoustic input mask
+        speech_pad_id = self.speech_pad_id
+        acoustic_input_mask = [1 if token == speech_pad_id else 0 for token in full_tokens]
+
+        return {
+            "input_ids": full_tokens,
+            "acoustic_input_mask": acoustic_input_mask,
+            "speech": audio_array,
+            "vae_tok_len": vae_tok_len,
+        }
+
+    def _batch_encode(
+        self,
+        encodings: List[Dict[str, Any]],
+        padding: bool = True,
+        max_length: Optional[int] = None,
+        truncation: bool = False,
+        return_tensors: Optional[str] = None,
+    ) -> BatchEncoding:
+        """
+        Combine multiple encodings into a batch.
+
+        Args:
+            encodings: List of encoded samples, each with "input_ids",
+                "acoustic_input_mask", "speech" and "vae_tok_len"
+            padding: Whether to left-pad every sequence to a common length
+            max_length: Pad target (never below the longest sequence) and truncation cap
+            truncation: Whether to cut sequences longer than ``max_length`` (from the
+                right). NOTE(review): this can drop speech placeholder tokens so the
+                acoustic mask no longer matches ``vae_tok_len`` - confirm with callers
+            return_tensors: "pt" for PyTorch tensors; anything else returns lists /
+                NumPy arrays, unwrapped to a single sample when the batch has one item
+
+        Returns:
+            BatchEncoding with input_ids, attention_mask, acoustic_input_mask,
+            speech_tensors and speech_masks
+        """
+        # Extract components
+        input_ids_list = [enc["input_ids"] for enc in encodings]
+        acoustic_masks_list = [enc["acoustic_input_mask"] for enc in encodings]
+        speech_list = [enc["speech"] for enc in encodings]
+        vae_tok_lens = [enc["vae_tok_len"] for enc in encodings]
+
+        # Truncate up front so the request is honored with or without padding
+        if truncation and max_length is not None:
+            input_ids_list = [ids[:max_length] for ids in input_ids_list]
+            acoustic_masks_list = [mask[:max_length] for mask in acoustic_masks_list]
+
+        if padding:
+            longest = max(len(ids) for ids in input_ids_list)
+            # Never pad to less than the longest sequence: a negative pad width would
+            # silently leave the rows ragged and break tensor conversion
+            target_length = longest if max_length is None else max(max_length, longest)
+
+            # Pad sequences to left (for autoregressive generation)
+            pad_widths = [target_length - len(ids) for ids in input_ids_list]
+            attention_masks = [
+                [0] * width + [1] * len(ids)
+                for width, ids in zip(pad_widths, input_ids_list)
+            ]
+            input_ids_list = [
+                [self.pad_id] * width + ids
+                for width, ids in zip(pad_widths, input_ids_list)
+            ]
+            acoustic_masks_list = [
+                [0] * width + mask
+                for width, mask in zip(pad_widths, acoustic_masks_list)
+            ]
+        else:
+            attention_masks = [[1] * len(ids) for ids in input_ids_list]
+
+        # Process speech tensors - raw audio is 1D, so we keep it as is
+        max_speech_length = max(len(s) for s in speech_list)
+        padded_speeches = np.zeros((len(speech_list), max_speech_length), dtype=np.float32)
+        speech_masks = np.zeros((len(speech_list), max(vae_tok_lens)), dtype=bool)
+
+        for i, (speech, vae_len) in enumerate(zip(speech_list, vae_tok_lens)):
+            padded_speeches[i, :len(speech)] = speech
+            speech_masks[i, :vae_len] = True
+
+        # Create batch encoding
+        batch_encoding = BatchEncoding()
+
+        if return_tensors == "pt":
+            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
+            batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
+            batch_encoding["acoustic_input_mask"] = torch.tensor(acoustic_masks_list, dtype=torch.bool)
+            batch_encoding["speech_tensors"] = torch.tensor(padded_speeches, dtype=torch.float32)
+            batch_encoding["speech_masks"] = torch.tensor(speech_masks, dtype=torch.bool)
+            # Note: vae_tok_seqlens and speech_type are not included as they are not model inputs
+        else:
+            batch_encoding["input_ids"] = input_ids_list if len(input_ids_list) > 1 else input_ids_list[0]
+            batch_encoding["attention_mask"] = attention_masks if len(attention_masks) > 1 else attention_masks[0]
+            batch_encoding["acoustic_input_mask"] = acoustic_masks_list if len(acoustic_masks_list) > 1 else acoustic_masks_list[0]
+            batch_encoding["speech_tensors"] = padded_speeches if len(padded_speeches) > 1 else padded_speeches[0]
+            batch_encoding["speech_masks"] = speech_masks if len(speech_masks) > 1 else speech_masks[0]
+
+        return batch_encoding
+
+    def batch_decode(self, *args, **kwargs):
+        """Decode a batch of token ID sequences to text via the tokenizer."""
+        # Pure pass-through: positional and keyword arguments reach the
+        # tokenizer's batch_decode untouched and its result is returned as-is.
+        tokenizer = self.tokenizer
+        return tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """Decode a single sequence of token IDs to text via the tokenizer."""
+        # Pure pass-through: positional and keyword arguments reach the
+        # tokenizer's decode untouched and its result is returned as-is.
+        tokenizer = self.tokenizer
+        return tokenizer.decode(*args, **kwargs)
+
+    def post_process_transcription(self, text: str) -> List[Dict[str, Any]]:
+        """
+        Post-process the generated transcription text to extract structured data.
+
+        JSON is read from a ```json fence when present, else from the first JSON array
+        (or an object opening before it); brackets inside string values are tolerated.
+
+        Args:
+            text: Generated text from the model
+
+        Returns:
+            List of dictionaries with transcription segments (keys "start_time",
+            "end_time", "speaker_id", "text"); empty when nothing could be parsed
+        """
+        # Map keys to expected format
+        key_mapping = {
+            "Start time": "start_time",
+            "Start": "start_time",
+            "End time": "end_time",
+            "End": "end_time",
+            "Speaker ID": "speaker_id",
+            "Speaker": "speaker_id",
+            "Content": "text",
+        }
+
+        try:
+            fence_start = text.find("```json")
+            if fence_start != -1:
+                # Extract JSON from markdown code block
+                json_start = fence_start + 7
+                json_end = text.find("```", json_start)
+                if json_end == -1:
+                    json_end = len(text)  # fence never closed: keep the whole tail
+                result = json.loads(text[json_start:json_end].strip())
+            else:
+                bracket_pos = text.find("[")
+                brace_pos = text.find("{")
+                # Prefer the first array; an object is only a fallback candidate when
+                # no array exists or the object opens before the first "["
+                candidates = [bracket_pos] if bracket_pos != -1 else []
+                if brace_pos != -1 and (bracket_pos == -1 or brace_pos < bracket_pos):
+                    candidates.append(brace_pos)
+                if not candidates:
+                    result = json.loads(text)
+                decoder = json.JSONDecoder()
+                for attempt, start in enumerate(candidates):
+                    try:
+                        # raw_decode parses one complete value and ignores trailing text
+                        result, _ = decoder.raw_decode(text[start:])
+                    except json.JSONDecodeError:
+                        if attempt == len(candidates) - 1:
+                            raise
+                    else:
+                        break
+
+            # Ensure it's a list
+            if isinstance(result, dict):
+                result = [result]
+
+            # Validate and clean up the result
+            cleaned_result = []
+            for item in result:
+                if not isinstance(item, dict):
+                    continue
+                cleaned_item = {mapped: item[key] for key, mapped in key_mapping.items() if key in item}
+                if cleaned_item:
+                    cleaned_result.append(cleaned_item)
+            return cleaned_result
+
+        except json.JSONDecodeError as e:
+            logger.warning(f"Failed to parse JSON from transcription: {e}")
+            logger.debug(f"Raw text: {text}")
+            return []
+        except Exception as e:
+            logger.warning(f"Error post-processing transcription: {e}")
+            return []
+
+    @property
+    def model_input_names(self):
+        """List the input names accepted by the model."""
+        return list(("input_ids", "attention_mask", "acoustic_input_mask", "speech_tensors", "speech_masks"))
+
+__all__ = ["VibeVoiceASRProcessor"]
diff --git a/vibevoice/processor/vibevoice_tokenizer_processor.py b/vibevoice/processor/vibevoice_tokenizer_processor.py
index 19b5795..67f61a6 100644
--- a/vibevoice/processor/vibevoice_tokenizer_processor.py
+++ b/vibevoice/processor/vibevoice_tokenizer_processor.py
@@ -13,80 +13,10 @@ import torch
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.utils import logging
+from .audio_utils import AudioNormalizer
+
logger = logging.get_logger(__name__)
-
-class AudioNormalizer:
- """
- Audio normalization class for VibeVoice tokenizer.
-
- This class provides audio normalization to ensure consistent input levels
- for the VibeVoice tokenizer while maintaining audio quality.
- """
-
- def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
- """
- Initialize the audio normalizer.
-
- Args:
- target_dB_FS (float): Target dB FS level for the audio. Default: -25
- eps (float): Small value to avoid division by zero. Default: 1e-6
- """
- self.target_dB_FS = target_dB_FS
- self.eps = eps
-
- def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
- """
- Adjust the audio to the target dB FS level.
-
- Args:
- audio (np.ndarray): Input audio signal
-
- Returns:
- tuple: (normalized_audio, rms, scalar)
- """
- rms = np.sqrt(np.mean(audio**2))
- scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
- normalized_audio = audio * scalar
- return normalized_audio, rms, scalar
-
- def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
- """
- Avoid clipping by scaling down if necessary.
-
- Args:
- audio (np.ndarray): Input audio signal
- scalar (float, optional): Explicit scaling factor
-
- Returns:
- tuple: (normalized_audio, scalar)
- """
- if scalar is None:
- max_val = np.max(np.abs(audio))
- if max_val > 1.0:
- scalar = max_val + self.eps
- else:
- scalar = 1.0
-
- return audio / scalar, scalar
-
- def __call__(self, audio: np.ndarray) -> np.ndarray:
- """
- Normalize the audio by adjusting to target dB FS and avoiding clipping.
-
- Args:
- audio (np.ndarray): Input audio signal
-
- Returns:
- np.ndarray: Normalized audio signal
- """
- # First adjust to target dB FS
- audio, _, _ = self.tailor_dB_FS(audio)
- # Then avoid clipping
- audio, _ = self.avoid_clipping(audio)
- return audio
-
-
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
"""