Add VibeVoice-ASR

This commit is contained in:
Zhiliang Peng
2026-01-21 22:18:33 +08:00
committed by GitHub
parent 6c7369bb31
commit 56cb11e7b2
14 changed files with 4062 additions and 94 deletions
Binary file not shown.

After

Width:  |  Height:  |  Size: 164 KiB

+3 -2
View File
@@ -23,8 +23,9 @@
<img src="https://img.shields.io/badge/Status-New-brightgreen?style=flat" alt="New" />
<img src="https://img.shields.io/badge/Feature-Realtime_TTS-blue?style=flat&logo=soundcharts" alt="Realtime TTS" />
<strong>2026-01-21: 📣 We open-sourced <a href="docs/vibevoice-asr.md"><strong>VibeVoice-ASR</strong></a>, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context.</strong>
<strong>2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time.</strong>
2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time.
2025-12-09: 📣 We added experimental speakers in nine languages (DE, FR, IT, JP, KR, NL, PL, PT, ES) for exploration—welcome to try them out and share your feedback.
@@ -123,4 +124,4 @@ We do not recommend using VibeVoice in commercial or real-world applications wit
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=Microsoft/vibevoice&type=date&legend=top-left)
![Star History Chart](https://api.star-history.com/svg?repos=Microsoft/vibevoice&type=date&legend=top-left)
File diff suppressed because it is too large Load Diff
+554
View File
@@ -0,0 +1,554 @@
#!/usr/bin/env python
"""
VibeVoice ASR Batch Inference Demo Script
This script runs batched inference with the ASR model on individual audio files,
a directory of audio files, or long audio concatenated from a HuggingFace dataset.
"""
import os
import sys
import torch
import numpy as np
from pathlib import Path
import argparse
import time
import json
import re
from typing import List, Dict, Any, Optional
from functools import wraps
from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
class VibeVoiceASRBatchInference:
    """Batch inference wrapper for VibeVoice ASR model.

    The processor and model are loaded once in ``__init__``. Audio inputs are
    then transcribed either as one batch (``transcribe_batch``) or split into
    fixed-size batches (``transcribe_with_batching``).
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        attn_implementation: str = "flash_attention_2"
    ):
        """
        Initialize the ASR batch inference pipeline.

        Args:
            model_path: Path to the pretrained model
            device: Device to run inference on. ``"auto"`` is forwarded as
                ``device_map="auto"``; any other value is applied with ``.to(device)``
            dtype: Data type for model weights
            attn_implementation: Attention implementation to use
                ('flash_attention_2', 'sdpa', 'eager')
        """
        print(f"Loading VibeVoice ASR model from {model_path}")

        # Load processor
        self.processor = VibeVoiceASRProcessor.from_pretrained(
            model_path,
            language_model_pretrained_name="Qwen/Qwen2.5-7B"
        )

        # Load model with specified attention implementation
        print(f"Using attention implementation: {attn_implementation}")
        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
            model_path,
            dtype=dtype,
            device_map=device if device == "auto" else None,
            attn_implementation=attn_implementation,
            trust_remote_code=True
        )

        if device != "auto":
            self.model = self.model.to(device)

        # With device_map="auto" the placement is decided by the loader, so
        # report the device of the first parameter instead.
        self.device = device if device != "auto" else next(self.model.parameters()).device
        self.dtype = dtype
        self.model.eval()
        print(f"Model loaded successfully on {self.device}")

    def _prepare_generation_config(
        self,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 0.9,
        do_sample: bool = True,
        num_beams: int = 1,
    ) -> dict:
        """
        Prepare the keyword arguments passed to ``model.generate``.

        Args:
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature; values <= 0 select greedy decoding
            top_p: Top-p for nucleus sampling (only used when sampling)
            do_sample: Whether sampling is requested
            num_beams: Number of beams; values > 1 select beam search

        Returns:
            Dict of generation keyword arguments
        """
        config = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.processor.pad_id,
            "eos_token_id": self.processor.tokenizer.eos_token_id,
        }

        # Beam search doesn't use sampling
        if num_beams > 1:
            config["num_beams"] = num_beams
            config["do_sample"] = False
            return config

        # Sampling with a non-positive temperature is rejected by generate(),
        # and the defaults here are do_sample=True / temperature=0.0, so a
        # non-positive temperature must fall back to greedy decoding.
        use_sampling = bool(do_sample) and temperature > 0
        config["do_sample"] = use_sampling

        # Only set temperature and top_p when sampling is enabled
        if use_sampling:
            config["temperature"] = temperature
            config["top_p"] = top_p

        return config

    def transcribe_batch(
        self,
        audio_inputs: List,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 1.0,
        do_sample: bool = True,
        num_beams: int = 1,
    ) -> List[Dict[str, Any]]:
        """
        Transcribe multiple audio files/arrays in a single batch.

        Args:
            audio_inputs: List of audio inputs accepted by the processor
                (file paths, arrays, or dicts carrying an ``'id'`` key)
            max_new_tokens: Maximum tokens to generate
            temperature: Temperature for sampling (<= 0 means greedy decoding)
            top_p: Top-p for nucleus sampling
            do_sample: Whether to use sampling
            num_beams: Number of beams for beam search (1 disables beam search)

        Returns:
            List of transcription results, one dict per input with the keys
            ``file``, ``raw_text``, ``segments`` and ``generation_time``
        """
        if len(audio_inputs) == 0:
            return []

        batch_size = len(audio_inputs)
        print(f"\nProcessing batch of {batch_size} audio(s)...")

        # Process all audio together
        inputs = self.processor(
            audio=audio_inputs,
            sampling_rate=None,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True
        )

        # Move to device
        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}

        # Print batch info
        print(f" Input IDs shape: {inputs['input_ids'].shape}")
        print(f" Speech tensors shape: {inputs['speech_tensors'].shape}")
        print(f" Attention mask shape: {inputs['attention_mask'].shape}")

        # Generate
        generation_config = self._prepare_generation_config(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            num_beams=num_beams,
        )

        start_time = time.time()
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                **generation_config
            )
        generation_time = time.time() - start_time

        # Decode outputs for each sample in the batch
        results = []
        input_length = inputs['input_ids'].shape[1]
        eos_token_id = self.processor.tokenizer.eos_token_id

        for i, audio_input in enumerate(audio_inputs):
            # Get generated tokens for this sample (excluding input tokens)
            generated_ids = output_ids[i, input_length:]

            # Cut everything after the first EOS token (the EOS itself is kept);
            # whatever follows it in a batched output is padding.
            eos_positions = (generated_ids == eos_token_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                generated_ids = generated_ids[:eos_positions[0] + 1]

            generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)

            # Parse structured output; a parse failure keeps the raw text and
            # reports an empty segment list instead of aborting the batch.
            try:
                transcription_segments = self.processor.post_process_transcription(generated_text)
            except Exception as e:
                print(f"Warning: Failed to parse structured output: {e}")
                transcription_segments = []

            # Get file name based on input type
            if isinstance(audio_input, str):
                file_name = audio_input
            elif isinstance(audio_input, dict) and 'id' in audio_input:
                file_name = audio_input['id']
            else:
                file_name = f"audio_{i}"

            results.append({
                "file": file_name,
                "raw_text": generated_text,
                "segments": transcription_segments,
                # The batch is generated jointly, so only an average is available.
                "generation_time": generation_time / batch_size,
            })

        print(f" Total generation time: {generation_time:.2f}s")
        print(f" Average time per sample: {generation_time/batch_size:.2f}s")

        return results

    def transcribe_with_batching(
        self,
        audio_inputs: List,
        batch_size: int = 4,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 1.0,
        do_sample: bool = True,
        num_beams: int = 1,
    ) -> List[Dict[str, Any]]:
        """
        Transcribe multiple audio files/arrays with automatic batching.

        Args:
            audio_inputs: List of audio inputs accepted by the processor
            batch_size: Number of samples per batch (must be >= 1)
            max_new_tokens: Maximum tokens to generate
            temperature: Temperature for sampling (<= 0 means greedy decoding)
            top_p: Top-p for nucleus sampling
            do_sample: Whether to use sampling
            num_beams: Number of beams for beam search (1 disables beam search)

        Returns:
            List of transcription results in the same order as ``audio_inputs``

        Raises:
            ValueError: If ``batch_size`` is smaller than 1
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        all_results = []
        num_batches = (len(audio_inputs) + batch_size - 1) // batch_size

        # Process in batches
        for i in range(0, len(audio_inputs), batch_size):
            batch_inputs = audio_inputs[i:i + batch_size]
            print(f"\n{'='*60}")
            print(f"Processing batch {i//batch_size + 1}/{num_batches}")

            batch_results = self.transcribe_batch(
                batch_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                num_beams=num_beams,
            )
            all_results.extend(batch_results)

        return all_results
def print_result(result: Dict[str, Any]):
    """
    Pretty print a single transcription result.

    Prints the file name, the generation time, the raw model output
    (truncated to 500 characters) and up to 50 parsed segments.

    Args:
        result: Dict with the keys ``file``, ``generation_time``,
            ``raw_text`` and ``segments``
    """
    print(f"\nFile: {result['file']}")
    print(f"Generation Time: {result['generation_time']:.2f}s")
    print("\n--- Raw Output ---")

    raw_text = result['raw_text']
    if len(raw_text) > 500:
        print(raw_text[:500] + "...")
    else:
        print(raw_text)

    segments = result['segments']
    if not segments:
        return

    total = len(segments)
    print(f"\n--- Structured Output ({total} segments) ---")
    for segment in segments[:50]:  # Show first 50 segments
        start = segment.get('start_time', 'N/A')
        end = segment.get('end_time', 'N/A')
        speaker = segment.get('speaker_id', 'N/A')
        text = segment.get('text', '')
        print(f"[{start} - {end}] Speaker {speaker}: {text}...")

    hidden = total - 50
    if hidden > 0:
        print(f" ... and {hidden} more segments")
def _resample_linear(audio_array, source_sr, target_sr):
    """Resample a 1-D waveform to ``target_sr`` by linear interpolation (demo quality)."""
    # An empty waveform cannot be interpolated (np.interp rejects empty sample
    # points), and a matching sample rate needs no work.
    if source_sr == target_sr or len(audio_array) == 0:
        return audio_array
    duration = len(audio_array) / source_sr
    new_length = int(duration * target_sr)
    return np.interp(
        np.linspace(0, len(audio_array) - 1, new_length),
        np.arange(len(audio_array)),
        audio_array
    )


def _pack_audio_samples(samples, max_duration, num_audios, target_sr):
    """Pack dataset samples into at most ``num_audios`` waveforms of at most ``max_duration`` seconds.

    Samples are consumed in order. A sample that does not fit is split at the
    boundary and its remainder continues in the next waveform, so no audio is
    duplicated and no waveform exceeds the limit.

    Returns a list of dicts with the keys ``array``, ``duration`` and ``samples_used``.
    """
    max_len = int(max_duration * target_sr)
    if max_len < 1 or num_audios < 1:
        return []
    # A tail shorter than 0.5 s is not worth appending to a waveform that
    # already has content; the waveform is closed instead.
    min_tail_len = int(0.5 * target_sr)

    packed = []
    chunks = []
    filled = 0
    samples_used = 0
    processed = 0

    def flush():
        """Close the waveform being built and start a new, empty one."""
        nonlocal chunks, filled, samples_used
        duration = filled / target_sr
        packed.append({
            'array': np.concatenate(chunks),
            'duration': duration,
            'samples_used': samples_used,
        })
        print(f" Created audio {len(packed)}: {duration:.1f}s from {samples_used} samples")
        chunks, filled, samples_used = [], 0, 0

    for sample in samples:
        if len(packed) >= num_audios:
            break
        if 'audio' not in sample:
            continue

        audio_data = sample['audio']
        pending = _resample_linear(audio_data['array'], audio_data['sampling_rate'], target_sr)

        while len(pending) > 0 and len(packed) < num_audios:
            room = max_len - filled
            if len(pending) <= room:
                chunks.append(pending)
                filled += len(pending)
                samples_used += 1
                break
            # The sample does not fit. An empty waveform always takes a full
            # slice (this guarantees progress for over-long samples); a
            # waveform with content only takes the head if the gap is > 0.5 s.
            if not chunks or room > min_tail_len:
                chunks.append(pending[:room])
                filled += room
                samples_used += 1
                pending = pending[room:]
            flush()

        processed += 1
        if processed % 100 == 0:
            print(f" Processed {processed} samples...")

    # Don't forget the last waveform if it has content
    if chunks and len(packed) < num_audios:
        flush()

    return packed


def load_dataset_and_concatenate(
    dataset_name: str,
    split: str,
    max_duration: float,
    num_audios: int,
    target_sr: int = 24000
) -> Optional[List[np.ndarray]]:
    """
    Load a HuggingFace dataset and concatenate audio samples into long audio chunks.
    (Note, just for demo purpose, not for benchmark evaluation)

    Args:
        dataset_name: HuggingFace dataset name (e.g., 'openslr/librispeech_asr')
        split: Dataset split to use (e.g., 'test', 'test.other')
        max_duration: Maximum duration in seconds for each concatenated audio
        num_audios: Number of concatenated audios to create
        target_sr: Target sample rate (default: 24000)

    Returns:
        List of concatenated audio arrays, or None if the arguments are not
        positive, the optional dependencies are missing, loading fails, or
        the dataset yields no audio
    """
    if num_audios < 1 or int(max_duration * target_sr) < 1:
        print("Warning: num_audios and max_duration must be positive; no audio created")
        return None

    try:
        from datasets import load_dataset
        import torchcodec  # noqa: F401 -- just for decode audio in datasets
    except ImportError:
        print("Please install it with: pip install datasets torchcodec")
        return None

    print(f"\nLoading dataset: {dataset_name} (split: {split})")
    print(f"Will create {num_audios} concatenated audio(s), each up to {max_duration:.1f}s ({max_duration/3600:.2f} hours)")

    try:
        # Use streaming to avoid downloading the entire dataset
        dataset = load_dataset(dataset_name, split=split, streaming=True)
        print("Dataset loaded in streaming mode")

        concatenated_audios = _pack_audio_samples(dataset, max_duration, num_audios, target_sr)

        if not concatenated_audios:
            print("Warning: No audio samples found in dataset")
            return None

        # Extract arrays and print summary
        result = [a['array'] for a in concatenated_audios]
        total_duration = sum(a['duration'] for a in concatenated_audios)
        total_samples = sum(a['samples_used'] for a in concatenated_audios)
        print(f"\nCreated {len(result)} concatenated audio(s), total {total_duration:.1f}s ({total_duration/60:.1f} min) from {total_samples} samples")
        return result

    except Exception as e:
        # Demo-level boundary: report the failure and let the caller bail out.
        print(f"Error loading dataset: {e}")
        import traceback
        traceback.print_exc()
        return None
def main():
    """
    Command-line entry point for the batch ASR demo.

    Collects audio from ``--audio_files``, ``--audio_dir`` and/or ``--dataset``,
    loads the model, transcribes everything in batches of ``--batch_size`` and
    prints each result.
    """
    parser = argparse.ArgumentParser(description="VibeVoice ASR Batch Inference Demo")
    parser.add_argument(
        "--model_path",
        type=str,
        default="",
        help="Path to the model checkpoint"
    )
    parser.add_argument(
        "--audio_files",
        type=str,
        nargs='+',
        required=False,
        help="Paths to audio files for transcription"
    )
    parser.add_argument(
        "--audio_dir",
        type=str,
        required=False,
        help="Directory containing audio files for batch transcription"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=False,
        help="HuggingFace dataset name (e.g., 'openslr/librispeech_asr')"
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Dataset split to use (e.g., 'test', 'test.other', 'test.clean')"
    )
    parser.add_argument(
        "--max_duration",
        type=float,
        default=3600.0,
        help="Maximum duration in seconds for concatenated dataset audio (default: 3600 = 1 hour)"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2,
        help="Batch size for processing multiple files"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        choices=["cuda", "cpu", "auto"],
        help="Device to run inference on"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=32768,
        help="Maximum number of tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for sampling (0 = greedy decoding)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=1.0,
        help="Top-p for nucleus sampling"
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help="Number of beams for beam search. Use 1 for greedy/sampling"
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        help="Attention implementation to use (default: flash_attention_2)"
    )

    args = parser.parse_args()

    # Fail fast: without a model path the run would stream the dataset first
    # and only then die inside the model loader with an unrelated-looking error.
    if not args.model_path:
        parser.error("--model_path is required (local checkpoint path or model name)")

    # Collect audio files
    audio_files = []
    concatenated_audio = None  # For storing concatenated dataset audio

    if args.audio_files:
        audio_files.extend(args.audio_files)

    if args.audio_dir:
        import glob
        for ext in ["*.wav", "*.mp3", "*.flac", "*.mp4", "*.m4a", "*.webm"]:
            # glob order depends on the filesystem; sort for reproducible batches
            audio_files.extend(sorted(glob.glob(os.path.join(args.audio_dir, ext))))

    if args.dataset:
        # One concatenated audio per batch slot
        concatenated_audio = load_dataset_and_concatenate(
            dataset_name=args.dataset,
            split=args.split,
            max_duration=args.max_duration,
            num_audios=args.batch_size,
        )
        if concatenated_audio is None:
            return

    if len(audio_files) == 0 and concatenated_audio is None:
        print("No audio files provided. Please specify --audio_files, --audio_dir, or --dataset.")
        return

    if audio_files:
        print(f"\nAudio files to process ({len(audio_files)}):")
        for f in audio_files:
            print(f" - {f}")

    if concatenated_audio:
        print(f"\nConcatenated dataset audios: {len(concatenated_audio)} audio(s)")

    # Initialize model
    asr = VibeVoiceASRBatchInference(
        model_path=args.model_path,
        device=args.device,
        dtype=torch.bfloat16 if args.device != "cpu" else torch.float32,
        attn_implementation=args.attn_implementation
    )

    # If temperature is 0, use greedy decoding (no sampling)
    do_sample = args.temperature > 0

    # Combine all audio inputs
    all_audio_inputs = audio_files + (concatenated_audio or [])

    print("\n" + "="*80)
    print(f"Processing {len(all_audio_inputs)} audio(s)")
    print("="*80)

    all_results = asr.transcribe_with_batching(
        all_audio_inputs,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        do_sample=do_sample,
        num_beams=args.num_beams,
    )

    # Print results
    print("\n" + "="*80)
    print("Results")
    print("="*80)

    for result in all_results:
        print("\n" + "-"*60)
        print_result(result)


if __name__ == "__main__":
    main()
+62
View File
@@ -0,0 +1,62 @@
# VibeVoice-ASR: Long-Form Rich Transcription with User Prompts
**VibeVoice-ASR** is the latest addition to the **VibeVoice** family. While the original VibeVoice / VibeVoice-Realtime focused on expressive TTS, **VibeVoice-ASR** focuses on understanding long-form speech with high precision and rich metadata.
It is a unified speech-to-text model designed to handle **1-hour long-form audio** in a single pass, generating structured transcriptions containing **Who (Speaker), When (Timestamps), and What (Content)**, with support for **User-Customized Context**.
## 🔥 Key Features
- **🕒 60-min Single-Pass Processing**:
Unlike conventional ASR models that slice audio into short chunks (often losing global context), VibeVoice ASR accepts up to **60 minutes** of continuous audio input within 64K length. This ensures consistent speaker tracking and semantic coherence across the entire hour.
- **👤 Optional Context Injection**:
Users can provide customized context (e.g., specific names, technical terms, or background info) to guide the recognition process, significantly improving accuracy on domain-specific content.
- **📝 Rich Transcription (Who, When, What)**:
The model performs ASR, Diarization, and Timestamping simultaneously. The output is a structured sequence indicating *who* said *what* at *which time*.
## 🏗️ Model Architecture
<p align="center">
<img src="../Figures/VibeVoice_ASR_archi.png" alt="VibeVoice ASR Architecture" width="80%">
</p>
## Installation
We recommend using an NVIDIA Deep Learning Container to manage the CUDA environment.
1. Launch docker
```bash
# NVIDIA PyTorch Container 24.07 ~ 25.12 verified.
# Previous versions are also compatible.
sudo docker run --privileged --net=host --ipc=host --ulimit memlock=-1:-1 --ulimit stack=-1:-1 --gpus all --rm -it nvcr.io/nvidia/pytorch:25.12-py3
## If flash attention is not included in your docker environment, you need to install it manually
## Refer to https://github.com/Dao-AILab/flash-attention for installation instructions
# pip install flash-attn --no-build-isolation
```
2. Install from GitHub
```bash
git clone https://github.com/microsoft/VibeVoice.git
cd VibeVoice
pip install -e .[asr]
```
## Usages
### Usage 1: Launch Gradio demo
```bash
apt update && apt install ffmpeg -y # for demo
python demo/vibevoice_asr_gradio_demo.py --model_path microsoft/VibeVoice-ASR --share
```
### Usage 2: Inference from files directly
```bash
python demo/vibevoice_asr_inference_from_file.py --model_path microsoft/VibeVoice-ASR --audio_files [add an audio path here]
```
## 📄 License
This project is licensed under the [MIT License](../LICENSE).
+17 -9
View File
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
[project]
name = "vibevoice"
version = "0.0.1"
version = "1.0.0"
authors = [
{ name="vibevoice team", email="vibepod@microsoft.com" },
{ name="vibevoice team", email="VibeVoice@microsoft.com" },
]
description = "A model for speech generation with an AR + diffusion architecture."
description = "Open-Source Frontier Voice AI."
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
@@ -18,8 +18,7 @@ classifiers = [
]
dependencies = [
"torch",
"accelerate==1.6.0",
"transformers==4.51.3", # we develop this project on transformers==4.51.3, later version may not be compatible
"accelerate",
"llvmlite>=0.40.0",
"numba>=0.57.0",
"diffusers",
@@ -30,12 +29,21 @@ dependencies = [
"ml-collections",
"absl-py",
"gradio",
"av",
"aiortc",
"uvicorn[standard]",
"fastapi"
]
[project.optional-dependencies]
tts = [
"transformers==4.51.3", # we develop this project on transformers==4.51.3, later version may not be compatible
"av",
"aiortc",
"uvicorn[standard]",
"fastapi"
]
asr = [
"transformers>=4.51.3", # version 4.51.3 and all later versions are supported
"pydub" # for visualization
]
[project.urls]
"Homepage" = "https://github.com/microsoft/VibeVoice"
+104 -1
View File
@@ -240,9 +240,112 @@ class VibeVoiceConfig(PretrainedConfig):
super().__init__(**kwargs)
class VibeVoiceASRConfig(PretrainedConfig):
    """
    Composite configuration for the VibeVoice ASR model.

    Holds three sub-configs (acoustic tokenizer, semantic tokenizer and the
    Qwen2 decoder) and exposes the decoder's main dimensions as read-only
    properties so code expecting a flat text config keeps working.

    Args:
        acoustic_tokenizer_config: ``None`` (use defaults), a dict of keyword
            arguments, or a ``VibeVoiceAcousticTokenizerConfig`` instance
        semantic_tokenizer_config: ``None`` (use defaults), a dict of keyword
            arguments, or a ``VibeVoiceSemanticTokenizerConfig`` instance
        decoder_config: ``None`` (use defaults), a dict whose ``model_type``
            is ``"qwen2"``, or a ``Qwen2Config`` instance
        **kwargs: Forwarded to ``PretrainedConfig.__init__``

    Raises:
        ValueError: If ``decoder_config`` is a dict with an unsupported ``model_type``
        TypeError: If a sub-config argument is of an unsupported type
    """

    model_type = "vibevoice"
    is_composition = True
    sub_configs = {
        "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
        "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
        "decoder_config": Qwen2Config,
    }
    # keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Qwen2`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    def __init__(
        self,
        acoustic_tokenizer_config=None,
        semantic_tokenizer_config=None,
        decoder_config=None,
        **kwargs
    ):
        # kwargs["_attn_implementation"] = "flash_attention_2"
        kwargs["_attn_implementation_autoset"] = False

        self.acoustic_tokenizer_config = self._resolve_tokenizer_config(
            "acoustic_tokenizer_config",
            acoustic_tokenizer_config,
            "vibevoice_acoustic_tokenizer",
        )
        self.semantic_tokenizer_config = self._resolve_tokenizer_config(
            "semantic_tokenizer_config",
            semantic_tokenizer_config,
            "vibevoice_semantic_tokenizer",
        )

        if decoder_config is None:
            self.decoder_config = self.sub_configs["decoder_config"]()
        elif isinstance(decoder_config, dict):
            # Only a Qwen2 decoder is supported when building from a dict
            if decoder_config.get("model_type", '') == "qwen2":
                self.decoder_config = Qwen2Config(**decoder_config)
            else:
                raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
        elif isinstance(decoder_config, Qwen2Config):
            # If an instance of the config class is provided
            self.decoder_config = decoder_config
        else:
            raise TypeError(
                "decoder_config must be None, a dict or a Qwen2Config, "
                f"got {type(decoder_config).__name__}"
            )

        # other parameters
        self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
        self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)

        super().__init__(**kwargs)

    @classmethod
    def _resolve_tokenizer_config(cls, key, value, model_type):
        """Turn ``None`` / dict / config instance into the sub-config registered under ``key``."""
        config_cls = cls.sub_configs[key]
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            # Build from a copy so the caller's dict is not modified when
            # model_type is forced to the expected value.
            return config_cls(**{**value, "model_type": model_type})
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"{key} must be None, a dict or a {config_cls.__name__}, "
            f"got {type(value).__name__}"
        )

    def get_text_config(self, decoder: bool = False):
        """Return the text (decoder) config for generation."""
        return self.decoder_config

    @property
    def vocab_size(self):
        """Return vocab_size from decoder config for generation compatibility."""
        return self.decoder_config.vocab_size

    @property
    def num_attention_heads(self):
        """Return num_attention_heads from decoder config for Ulysses SP compatibility."""
        return self.decoder_config.num_attention_heads

    @property
    def num_key_value_heads(self):
        """Return num_key_value_heads from decoder config for Ulysses SP compatibility."""
        return self.decoder_config.num_key_value_heads

    @property
    def hidden_size(self):
        """Return hidden_size from decoder config for model compatibility."""
        return self.decoder_config.hidden_size

    @property
    def num_hidden_layers(self):
        """Return num_hidden_layers from decoder config for Ulysses SP compatibility."""
        return self.decoder_config.num_hidden_layers

    @property
    def head_dim(self):
        """Return head_dim from decoder config, falling back to hidden_size // num_attention_heads."""
        return getattr(self.decoder_config, 'head_dim', self.hidden_size // self.num_attention_heads)
# Public configuration classes exported by this module.
__all__ = [
    "VibeVoiceAcousticTokenizerConfig",
    "VibeVoiceSemanticTokenizerConfig",
    "VibeVoiceDiffusionHeadConfig",
    "VibeVoiceConfig",
    "VibeVoiceASRConfig",
]
+496
View File
@@ -0,0 +1,496 @@
# copied from https://github.com/vibevoice-community/VibeVoice/blob/main/vibevoice/modular/modeling_vibevoice.py
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
from .configuration_vibevoice import VibeVoiceConfig
logger = logging.get_logger(__name__)
# Compatibility shim: make sure `modeling_utils.ALL_PARALLEL_STYLES` exists and
# is not None before any model class in this module is defined.
# NOTE(review): presumably needed for transformers builds that consult this
# list while validating tensor-parallel plans -- confirm against the pinned version.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
@dataclass
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
    """
    Output container for a VibeVoice causal-LM forward pass.

    Every field defaults to ``None``. The field order is part of the
    interface because ``ModelOutput`` also supports positional access.
    """
    # NOTE(review): field meanings below are inferred from their names and
    # annotations only; the code that fills them is not in this part of the file.
    loss: Optional[torch.FloatTensor] = None  # presumably the language-modeling loss
    diffusion_loss: Optional[torch.FloatTensor] = None  # presumably the diffusion-head loss
    speech_token_num: Optional[int] = None  # presumably the number of speech tokens in the batch
    logits: torch.FloatTensor = None  # prediction scores; annotated non-optional but defaults to None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
    """
    Output type for VibeVoice generation.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences.
        speech_outputs (`List[torch.FloatTensor]`, *optional*):
            List of generated speech waveforms or latents for each speech segment.
    """
    # Both fields default to None; `sequences` is annotated as non-optional
    # but is not enforced at construction time.
    sequences: torch.LongTensor = None
    speech_outputs: Optional[List[torch.FloatTensor]] = None
class SpeechConnector(nn.Module):
    """
    Small projection head: Linear -> RMSNorm -> Linear.

    Maps features whose last dimension is ``input_dim`` to ``output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Sub-modules are created in the same order as before so parameter
        # registration and random initialisation order stay identical.
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, features, **kwargs):
        """Project ``features``; extra keyword arguments are accepted and ignored."""
        projected = self.fc1(features)
        normalized = self.norm(projected)
        return self.fc2(normalized)
# @auto_docstring
class VibeVoicePreTrainedModel(PreTrainedModel):
    """Common base class for VibeVoice models: capability flags plus weight initialisation."""

    config_class = VibeVoiceConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Initialise the weights of a single sub-module.

        The diffusion head initialises itself. Linear layers get a normal
        initialisation with the decoder's ``initializer_range`` (0.02 if the
        config does not provide one) and zero bias. LayerNorm layers get unit
        weight and zero bias. Other module types are left untouched.
        """
        if isinstance(module, VibeVoiceDiffusionHead):
            module.initialize_weights()
            return

        # Use the language model's initializer_range if available
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None for LayerNorm(elementwise_affine=False) and
            # bias is None for LayerNorm(bias=False); guard like the Linear branch.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
# @auto_docstring
class VibeVoiceModel(VibeVoicePreTrainedModel):
    """
    VibeVoice backbone: a decoder language model plus the speech components
    (acoustic/semantic tokenizers, their connectors, a diffusion prediction
    head and its noise scheduler). ``forward`` only runs the language model.
    """

    def __init__(self, config):
        super().__init__(config)

        # Resolve the dtype for the speech components; config.torch_dtype may
        # be a string such as "bfloat16" or an actual torch.dtype.
        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize Qwen2 model for language modeling
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)

        # Initialize speech components if needed
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)

        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
        self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)

        # Register scaling factors as buffers, initialised to NaN.
        # NOTE(review): torch.tensor(float) creates 0-dim tensors, not the 1D
        # tensors an earlier comment mentioned for FSDP compatibility. The
        # shape is kept as-is because changing it would alter the state dict.
        self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
        self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))

        # Initialize prediction head for speech generation
        self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)

        # Initialize noise scheduler
        self.noise_scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
            beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
            prediction_type=config.diffusion_head_config.prediction_type
        )

    def get_input_embeddings(self):
        """
        Return the language model's token embedding.

        Raises:
            RuntimeError: If no embedding can be located on the language model
        """
        if hasattr(self.language_model, 'embed_tokens'):
            # If the language model has an embed_tokens attribute, return it
            return self.language_model.embed_tokens

        for name, attr in self.language_model.fullmap.items():  # parallel by nnscaler, the name is changed
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)

        # An explicit exception instead of `assert False`: asserts are removed
        # under `python -O`, which would turn this into a silent `return None`.
        raise RuntimeError("could not locate the input embeddings on the language model")

    def set_input_embeddings(self, value):
        """Replace the language model's ``embed_tokens`` with ``value``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Set the speech tokenizers used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer

        # Reset the encoder to evaluation mode
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.eval()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the language model and repackage its output.

        All arguments are passed through to ``self.language_model``.
        ``return_dict`` falls back to ``config.use_return_dict`` when None.
        When it is falsy the language model's raw output is returned,
        otherwise a ``BaseModelOutputWithPast``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through language model
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
    """VibeVoice backbone with an LM head and a diffusion loss over speech latents.

    ``forward`` produces token logits and a diffusion loss computed with
    ``self.model.prediction_head``. The token-level cross-entropy loss is not
    computed here; it is left to the training script.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        """Build the backbone and a bias-free LM head from ``config``."""
        super().__init__(config)
        self.model = VibeVoiceModel(config)
        self.vocab_size = config.decoder_config.vocab_size
        self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped backbone."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embeddings of the wrapped backbone."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the LM head module."""
        return self.lm_head

    def set_decoder(self, decoder):
        """Replace the backbone's language model."""
        self.model.language_model = decoder

    def get_decoder(self):
        """Return the backbone's language model."""
        return self.model.language_model

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        Tying only happens when ``config.decoder_config.tie_word_embeddings``
        is truthy. It is done by plain parameter assignment, so it should run
        before any wrapper (e.g. FSDP) takes ownership of the parameters.
        """
        if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if hasattr(input_embeddings, 'weight'):
                output_embeddings.weight = input_embeddings.weight
            else:
                # get_input_embeddings() may hand back the weight tensor directly
                output_embeddings.weight = input_embeddings

            # Keep an existing output bias in step with the (possibly larger) tied weight.
            if getattr(output_embeddings, "bias", None) is not None:
                output_embeddings.bias.data = nn.functional.pad(
                    output_embeddings.bias.data,
                    (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
                    "constant",
                    0,
                )
            print("Tied input and output embeddings using standard assignment.")
        else:
            print("tie_word_embeddings is False, not tying weights.")

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module.

        This is a plain attribute assignment (no in-place copy), so it should
        not be called after the model has been wrapped for distributed training.
        """
        self.lm_head = new_embeddings

    def forward_speech_features(
        self,
        speech_tensors=None,
        speech_masks=None,
        speech_type="audio",
        return_unmask=False
    ):
        """Turn speech input into normalized acoustic latents and connector embeddings.

        Args:
            speech_tensors: Speech input, or ``None``. With ``speech_type="audio"``
                it is passed (after ``unsqueeze(1)``) to the acoustic tokenizer's
                ``encode``. With ``speech_type="vae"`` it is reshaped to
                ``(batch, -1, vae_dim)`` and used as the latent mean.
            speech_masks: Mask used to index the latents, both for estimating
                the normalization statistics and for the masked return value.
            speech_type: ``"audio"`` or ``"vae"``.
            return_unmask: When true, return the full (un-indexed) tensors.

        Returns:
            ``(audio_features, connect_features)``. When ``speech_tensors`` is
            ``None`` these are an all-zero latent of shape ``(1, 1, vae_dim)``
            and its connector output.

        Raises:
            NotImplementedError: For any other ``speech_type``.

        Side effects:
            While ``speech_scaling_factor`` or ``speech_bias_factor`` on
            ``self.model`` is NaN, both are estimated from the masked latents
            (summed over ranks and divided by the world size when
            ``torch.distributed`` is initialized), written in place with
            ``copy_`` and printed.
        """
        if speech_tensors is None:
            # Use config to get vae_dim instead of non-existent self.args
            vae_dim = self.config.acoustic_tokenizer_config.vae_dim
            audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
            connect_features = self.model.acoustic_connector(audio_features)
            return audio_features, connect_features

        # Tokenization and normalization never need gradients.
        with torch.no_grad():
            if speech_type == "audio":
                frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
                audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
            elif speech_type == "vae":
                # Use config to get vae_dim instead of non-existent self.args
                vae_dim = self.config.acoustic_tokenizer_config.vae_dim
                speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)

                # Gaussian sample around speech_mode with one random std per batch item.
                batch_size = speech_mode.size(0)
                value = self.model.acoustic_tokenizer.fix_std / 0.8
                std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
                std = std.view(-1, *[1] * (speech_mode.dim() - 1))
                audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
            else:
                raise NotImplementedError(f"Speech type {speech_type} not implemented")

            if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
                scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
                bias_factor = -audio_tokens[speech_masks].flatten().mean()

                # Only use distributed operations if the process group is initialized
                if dist.is_available() and dist.is_initialized():
                    dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
                    dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
                    world_size = dist.get_world_size()
                    self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
                    self.model.speech_bias_factor.copy_(bias_factor / world_size)
                    print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
                else:
                    # Single process case
                    self.model.speech_scaling_factor.copy_(scaling_factor)
                    self.model.speech_bias_factor.copy_(bias_factor)
                    print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)

            audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor

        # The connector runs outside no_grad.
        connect_features = self.model.acoustic_connector(audio_features)
        if return_unmask:
            return audio_features, connect_features
        return audio_features[speech_masks], connect_features[speech_masks]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # New arguments for speech processing and loss calculation
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speeches_loss_input: Optional[torch.FloatTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        acoustic_input_mask: Optional[torch.BoolTensor] = None,
        acoustic_loss_mask: Optional[torch.BoolTensor] = None,
        ddpm_batch_mul: int = 1,
        **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
    ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
        """Run the backbone on text plus speech embeddings and compute the diffusion loss.

        Token embeddings are looked up from ``input_ids`` (``inputs_embeds`` is
        not read). Positions selected by ``acoustic_input_mask`` are overwritten
        with speech connector features. ``loss`` is always ``None``: the masked
        cross-entropy is computed by the training script.

        Args:
            speech_tensors / speech_masks: Forwarded to ``forward_speech_features``.
            speeches_loss_input: When given, it is AND-ed with ``speech_masks``
                to pick the latents that act as diffusion targets.
            speech_semantic_tensors: Optional input for the semantic connector.
                Its features are added to the acoustic ones on the
                ``speeches_loss_input`` path.
            acoustic_input_mask: Positions of ``x`` that receive speech features.
            acoustic_loss_mask: Positions of the hidden states used as diffusion
                conditions.
            ddpm_batch_mul: Number of noise/timestep draws per target latent.
            **kwargs: ``speech_type`` is read from here (default ``"audio"``).

        Returns:
            A ``VibeVoiceCausalLMOutputWithPast``, or when ``return_dict`` is
            falsy the tuple ``(loss, diffusion_loss, logits, speech_len, ...)``.

        Raises:
            NotImplementedError: For an unsupported diffusion ``prediction_type``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        x = self.get_input_embeddings()(input_ids)

        # The semantic connector cannot consume None, so only run it when
        # semantic latents were actually supplied.
        semantic_speech_all_connect_features = None
        if speech_semantic_tensors is not None:
            semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)

        typed_speech_tensors = speech_tensors.type_as(x) if speech_tensors is not None else None
        speech_type = kwargs.get("speech_type", "audio")

        if speeches_loss_input is not None:
            # only part audio need diffuse
            speech_all_features, speech_all_connect_features = self.forward_speech_features(
                speech_tensors=typed_speech_tensors,
                speech_masks=speech_masks,
                speech_type=speech_type,
                return_unmask=True
            )
            if speech_tensors is not None:
                if semantic_speech_all_connect_features is not None:
                    x[acoustic_input_mask] = (
                        speech_all_connect_features[speech_masks]
                        + semantic_speech_all_connect_features[speech_masks]
                    )
                else:
                    x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
                # Select only the target segments' latents for diffusion loss.
                # Both masks are [num_segments, max_latent_len]; using 2D mask on [B,T,D] selects [N_true, D].
                target_latent_mask = speeches_loss_input & speech_masks
                speech_features = speech_all_features[target_latent_mask]
                speech_connect_features = speech_all_connect_features[target_latent_mask]
        else:
            speech_features, speech_connect_features = self.forward_speech_features(
                speech_tensors=typed_speech_tensors,
                speech_masks=speech_masks,
                speech_type=speech_type,
            )
            if speech_tensors is not None:
                x[acoustic_input_mask] = speech_connect_features

        # Always request a ModelOutput from the backbone: the attribute access
        # and ``to_tuple()`` below do not work on a plain tuple.
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=x,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=False,
            return_dict=True,
            cache_position=cache_position,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        # The custom CE loss with masking is calculated in the training script,
        # so ``labels`` is accepted but no loss is computed from it here.
        loss = None

        # --- Diffusion Loss Calculation ---
        diffusion_loss = None
        # Defined up front so both return paths work when the diffusion branch is skipped.
        speech_len = 0
        has_diffusion_targets = (
            speech_tensors is not None
            and acoustic_loss_mask is not None
            and acoustic_loss_mask.sum().item() > 0
        )
        if has_diffusion_targets:
            condition_features = hidden_states[acoustic_loss_mask]

            speech_len, latent_size = speech_features.shape

            noise = torch.randn(
                (speech_len * ddpm_batch_mul, latent_size),
                device=hidden_states.device,
                dtype=hidden_states.dtype
            )

            # Uniformly sampled diffusion timesteps, one per repeated latent.
            timesteps = torch.multinomial(
                torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
                speech_len * ddpm_batch_mul,
                replacement=True,
            ).to(hidden_states.device)

            speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
            condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)

            noisy_speech_features = self.model.noise_scheduler.add_noise(
                speech_features_repeated, noise, timesteps
            )

            model_output = self.model.prediction_head(
                noisy_speech_features,
                timesteps.type_as(x),
                condition_features_repeated
            )

            prediction_type = self.config.diffusion_head_config.prediction_type
            if prediction_type == "epsilon":
                target_for_loss = noise
            elif prediction_type == "v_prediction":
                target_for_loss = self.model.noise_scheduler.get_velocity(
                    speech_features_repeated, noise, timesteps
                )
            else:
                raise NotImplementedError(f"Prediction type {prediction_type} not implemented")

            diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
            if latent_size > 0 and ddpm_batch_mul > 0:
                diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
            else:
                diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
        else:
            # Dummy loss for DDP to work when there are no speech samples in a batch,
            # but we are in a speech context.
            diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
            diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
            diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
        # --- End Diffusion Loss Calculation ---

        if not return_dict:
            output = (logits, speech_len) + outputs.to_tuple()[1:]
            return (loss, diffusion_loss) + output

        return VibeVoiceCausalLMOutputWithPast(
            loss=loss,
            diffusion_loss=diffusion_loss,
            speech_token_num=speech_len,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Register the VibeVoice classes with the transformers Auto* factories,
# keyed by VibeVoiceConfig.
AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
# Public names exported by a star-import of this module.
__all__ = [
"VibeVoiceModel",
"VibeVoicePreTrainedModel",
"VibeVoiceForConditionalGeneration",
"VibeVoiceCausalLMOutputWithPast",
"VibeVoiceGenerationOutput",
]
+520
View File
@@ -0,0 +1,520 @@
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation import GenerationMixin
from .modular_vibevoice_tokenizer import (
VibeVoiceTokenizerStreamingCache,
VibeVoiceTokenizerEncoderOutput
)
from .configuration_vibevoice import VibeVoiceASRConfig
from .modeling_vibevoice import (
VibeVoiceCausalLMOutputWithPast,
SpeechConnector
)
# Module-level logger from the transformers logging utilities.
logger = logging.get_logger(__name__)
# Compatibility shim: if ``modeling_utils`` has no ``ALL_PARALLEL_STYLES`` (or it is None),
# install a default list. Presumably needed so the ``_tp_plan`` entries declared in this
# module are accepted by the installed transformers version -- TODO confirm.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
# @auto_docstring
class VibeVoiceASRPreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the VibeVoice ASR models.

    Declares the config class, the base-model prefix and the capability flags
    read by transformers, and provides weight initialization for linear and
    layer-norm modules.
    """

    config_class = VibeVoiceASRConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialize ``module`` in place.

        ``nn.Linear`` weights are drawn from a normal distribution and biases
        are zeroed. ``nn.LayerNorm`` weights are set to one and biases to zero.
        Other module types are left untouched.

        The standard deviation is ``initializer_range`` from
        ``config.language_model_config`` if present, else from
        ``config.decoder_config``, else ``0.02``.
        """
        # Use the language model's initializer_range if available
        if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
            std = self.config.language_model_config.initializer_range
        elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
            std = self.config.decoder_config.initializer_range
        else:
            std = 0.02  # Default value

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight and/or bias are None when the layer was built without
            # affine parameters (or without a bias); skip what is not there.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
# @auto_docstring
class VibeVoiceASRModel(VibeVoiceASRPreTrainedModel):
    """Decoder language model bundled with the speech tokenizers and connectors.

    ``forward`` only runs the language model. The acoustic/semantic tokenizers
    and their connectors are held here as submodules for callers to use.
    """

    def __init__(self, config):
        """Instantiate the language model, both tokenizers and both connectors."""
        super().__init__(config)

        # config.torch_dtype may be a dtype name (str) or a torch.dtype; default to float32.
        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Language model built from the decoder config (not cast to ``dtype`` here).
        lm_config = config.decoder_config
        self.language_model = AutoModel.from_config(lm_config)

        # Speech components, all cast to the resolved dtype.
        self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
        self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)

        self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
        self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)

    def get_input_embeddings(self):
        """Return the language model's token embeddings.

        Uses ``language_model.embed_tokens`` when it exists. Otherwise searches
        ``language_model.fullmap`` for the entry whose ``orig_name`` is
        ``'embed_tokens.weight'`` and returns the attribute stored under that
        (renamed) key.

        Raises:
            AssertionError: If neither lookup finds the embeddings.
        """
        if hasattr(self.language_model, 'embed_tokens'):
            # If the language model has an embed_tokens attribute, return it
            return self.language_model.embed_tokens

        for name, attr in self.language_model.fullmap.items():  # parallel by nnscaler, the name is changed
            if attr.orig_name == 'embed_tokens.weight':
                return getattr(self.language_model, name)
        # Explicit raise instead of ``assert False`` so the failure is not
        # silently dropped (returning None) when Python runs with -O.
        raise AssertionError('should not arrive here')

    def set_input_embeddings(self, value):
        """Replace the language model's ``embed_tokens``."""
        self.language_model.embed_tokens = value

    def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
        """Set the speech tokenizers used for encoding and decoding speech."""
        self.acoustic_tokenizer = acoustic_tokenizer
        self.semantic_tokenizer = semantic_tokenizer

        # Reset the encoder to evaluation mode
        if self.acoustic_tokenizer is not None:
            self.acoustic_tokenizer.eval()

        if self.semantic_tokenizer is not None:
            self.semantic_tokenizer.eval()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the language model and repackage its output.

        All arguments are forwarded unchanged. ``return_dict`` defaults to
        ``config.use_return_dict``. When it is falsy the language model's raw
        output is returned, otherwise a ``BaseModelOutputWithPast``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through language model
        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        if not return_dict:
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class VibeVoiceASRForConditionalGeneration(VibeVoiceASRPreTrainedModel, GenerationMixin):
    """
    VibeVoice model for Automatic Speech Recognition (ASR) with language modeling head for conditional generation.
    This class is designed for inference and generation tasks.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        """Build the ASR backbone and a bias-free LM head cast to the configured dtype."""
        super().__init__(config)
        self.model = VibeVoiceASRModel(config)
        self.vocab_size = config.decoder_config.vocab_size

        # Determine the dtype to use
        if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
            if isinstance(config.torch_dtype, str):
                dtype = getattr(torch, config.torch_dtype)
            else:
                dtype = config.torch_dtype
        else:
            dtype = torch.float32

        # Initialize lm_head with the correct dtype
        self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False).to(dtype)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped backbone."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embeddings of the wrapped backbone."""
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the LM head module."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone's language model."""
        self.model.language_model = decoder

    def get_decoder(self):
        """Return the backbone's language model."""
        return self.model.language_model

    def tie_weights(self):
        """Tie the weights between the input embeddings and the output embeddings."""
        if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()
            if hasattr(input_embeddings, 'weight'):
                output_embeddings.weight = input_embeddings.weight
            else:
                # get_input_embeddings() may hand back the weight tensor directly
                output_embeddings.weight = input_embeddings

    def encode_speech(
        self,
        speech_tensors: torch.FloatTensor,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        streaming_segment_duration: float = 60.0,  # seconds
    ):
        """
        Encode speech input into features that can be used by the language model.

        This method is called once before generation to process the speech input.
        Audio longer than ``streaming_segment_duration`` seconds (60s by default,
        at a fixed 24 kHz sample rate) is encoded segment by segment with
        streaming caches to avoid conv overflow (>2^32). The per-segment means
        are concatenated before the final sampling step.

        Args:
            speech_tensors: Input audio tensor [batch_size, samples]; a 1-D
                tensor is treated as a batch of one.
            speech_masks: Optional mask used to index both feature tensors
                before they are summed.
            speech_semantic_tensors: Optional pre-computed semantic tokens. When
                given, the semantic tokenizer is not run (on either path).
            streaming_segment_duration: Segment duration in seconds for streaming processing (default: 60s)

        Returns:
            The sum of the acoustic and semantic connector features, indexed by
            ``speech_masks`` when a mask is given.

        Raises:
            ValueError: If streaming is selected with a non-positive segment length.
        """
        if hasattr(self.config, 'torch_dtype') and self.config.torch_dtype is not None:
            if isinstance(self.config.torch_dtype, str):
                dtype = getattr(torch, self.config.torch_dtype)
            else:
                dtype = self.config.torch_dtype
        else:
            dtype = torch.float32

        speech_tensors = speech_tensors.to(dtype)

        # Ensure proper shape: (batch, samples)
        if speech_tensors.ndim == 1:
            speech_tensors = speech_tensors.unsqueeze(0)

        batch_size, total_samples = speech_tensors.shape
        sample_rate = 24000  # fix 24kHz sample rate

        # Calculate segment size in samples
        segment_samples = int(streaming_segment_duration * sample_rate)

        # Decide whether to use streaming based on audio length
        use_streaming = total_samples > segment_samples

        # The semantic tokenizer is only needed when no pre-computed tokens were supplied.
        need_semantic_encoding = speech_semantic_tensors is None

        with torch.no_grad():
            if not use_streaming:
                # Short audio: direct processing (original behavior)
                encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
                audio_tokens = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
                acoustic_features = self.model.acoustic_connector(audio_tokens)

                # Encode semantic features
                if need_semantic_encoding:
                    semantic_tokens = self.model.semantic_tokenizer.encode(speech_tensors.unsqueeze(1)).mean
                else:
                    semantic_tokens = speech_semantic_tensors
                semantic_features = self.model.semantic_connector(semantic_tokens)
            else:
                # Long audio: streaming processing with one cache per tokenizer
                acoustic_encoder_cache = VibeVoiceTokenizerStreamingCache()
                semantic_encoder_cache = VibeVoiceTokenizerStreamingCache()
                acoustic_mean_segments = []
                semantic_mean_segments = []
                sample_indices = torch.arange(batch_size, device=speech_tensors.device)

                def _iter_segments(total_length: int, segment_length: int):
                    """Iterate over audio segments with a given segment length."""
                    if segment_length <= 0:
                        raise ValueError("segment_length must be positive")
                    for start in range(0, total_length, segment_length):
                        end = min(start + segment_length, total_length)
                        if end > start:
                            yield start, end

                # Process each segment for both acoustic and semantic tokenizers
                segments = list(_iter_segments(total_samples, segment_samples))
                num_segments = len(segments)
                for seg_idx, (start, end) in enumerate(segments):
                    chunk = speech_tensors[:, start:end].contiguous()
                    if chunk.numel() == 0:
                        continue

                    # Check if this is the final segment
                    is_final = (seg_idx == num_segments - 1)

                    # Encode chunk for acoustic tokenizer (don't sample yet)
                    acoustic_encoder_output = self.model.acoustic_tokenizer.encode(
                        chunk.unsqueeze(1),
                        cache=acoustic_encoder_cache,
                        sample_indices=sample_indices,
                        use_cache=True,
                        is_final_chunk=is_final,
                    )
                    acoustic_mean_segments.append(acoustic_encoder_output.mean)

                    # Encode chunk for semantic tokenizer (take mean directly)
                    if need_semantic_encoding:
                        semantic_encoder_output = self.model.semantic_tokenizer.encode(
                            chunk.unsqueeze(1),
                            cache=semantic_encoder_cache,
                            sample_indices=sample_indices,
                            use_cache=True,
                            is_final_chunk=is_final,
                        )
                        semantic_mean_segments.append(semantic_encoder_output.mean)

                # Concatenate all acoustic means and sample once
                acoustic_mean_full = torch.cat(acoustic_mean_segments, dim=1).contiguous()
                acoustic_encoder_output = VibeVoiceTokenizerEncoderOutput(
                    mean=acoustic_mean_full,
                    std=self.model.acoustic_tokenizer.fix_std
                )
                audio_tokens = acoustic_encoder_output.sample(
                    dist_type=self.model.acoustic_tokenizer.std_dist_type
                )[0]
                acoustic_features = self.model.acoustic_connector(audio_tokens)

                # Concatenate all semantic means (or use the pre-computed tokens)
                if need_semantic_encoding:
                    semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous()
                else:
                    semantic_tokens = speech_semantic_tensors
                semantic_features = self.model.semantic_connector(semantic_tokens)

            # Combine acoustic and semantic features
            if speech_masks is not None:
                combined_features = acoustic_features[speech_masks] + semantic_features[speech_masks]
            else:
                combined_features = acoustic_features + semantic_features

        return combined_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # Speech-specific arguments
        speech_tensors: Optional[torch.FloatTensor] = None,
        speech_masks: Optional[torch.BoolTensor] = None,
        speech_semantic_tensors: Optional[torch.FloatTensor] = None,
        acoustic_input_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
        """
        Forward pass for the model. Handles both training and generation scenarios.

        Token embeddings are looked up from ``input_ids`` unless ``inputs_embeds``
        is given. When both ``speech_tensors`` and ``acoustic_input_mask`` are
        provided, the speech is encoded with ``encode_speech`` and written into
        the embeddings at the masked positions. When ``labels`` is given, a
        shifted next-token cross-entropy loss is computed.

        Returns:
            A ``VibeVoiceCausalLMOutputWithPast``, or when ``return_dict`` is
            falsy a tuple ``(loss?, logits, ...)`` where ``loss`` is only
            present if it was computed.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Process inputs
        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # If we have speech input and acoustic_input_mask, encode and insert speech features
        if speech_tensors is not None and acoustic_input_mask is not None:
            speech_features = self.encode_speech(
                speech_tensors=speech_tensors,
                speech_masks=speech_masks,
                speech_semantic_tensors=speech_semantic_tensors,
            )
            inputs_embeds[acoustic_input_mask] = speech_features

        # Forward through the model
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return VibeVoiceCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        speech_tensors=None,
        speech_masks=None,
        speech_semantic_tensors=None,
        acoustic_input_mask=None,
        **kwargs,
    ):
        """
        Prepare inputs for generation step. This method is called by generate()
        for each token generation step.

        Following Qwen2-VL's approach: speech inputs are only forwarded on the first pass
        (when cache_position[0] == 0), and are excluded in subsequent generation steps.

        Returns:
            A dict of keyword arguments for ``forward``. Any extra ``kwargs``
            are merged in last.
        """
        # Length already held by the cache; works for legacy tuple caches and
        # for cache objects exposing get_seq_length().
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, tuple):
                past_length = past_key_values[0][0].shape[2]
            else:
                past_length = past_key_values.get_seq_length()

            # Keep only the new tokens
            if input_ids is not None and input_ids.shape[1] > past_length:
                input_ids = input_ids[:, past_length:]

        # Prepare position ids
        if position_ids is None and attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None and input_ids is not None:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # Prepare cache position. Reuses past_length so a tuple cache no longer
        # hits a missing get_seq_length().
        if cache_position is None:
            current = input_ids if input_ids is not None else inputs_embeds
            cache_position = torch.arange(
                past_length,
                past_length + current.shape[1],
                device=current.device
            )

        # Prepare model inputs
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )

        # Following Qwen2-VL pattern: only include speech inputs on the first forward pass
        # (when cache_position[0] == 0), exclude them in subsequent generation steps
        if cache_position is not None and len(cache_position) > 0 and cache_position[0] == 0:
            # First forward pass - include speech inputs if provided
            model_inputs.update({
                "speech_tensors": speech_tensors,
                "speech_masks": speech_masks,
                "speech_semantic_tensors": speech_semantic_tensors,
                "acoustic_input_mask": acoustic_input_mask,
            })
        else:
            # Subsequent generation steps - exclude speech inputs
            model_inputs.update({
                "speech_tensors": None,
                "speech_masks": None,
                "speech_semantic_tensors": None,
                "acoustic_input_mask": None,
            })

        # Include any remaining kwargs that might be needed
        model_inputs.update(kwargs)

        return model_inputs
# Register the ASR classes with the transformers Auto* factories,
# keyed by VibeVoiceASRConfig.
AutoModel.register(VibeVoiceASRConfig, VibeVoiceASRModel)
AutoModelForCausalLM.register(VibeVoiceASRConfig, VibeVoiceASRForConditionalGeneration)
# Public names exported by a star-import of this module.
__all__ = [
"VibeVoiceASRPreTrainedModel",
"VibeVoiceASRModel",
"VibeVoiceASRForConditionalGeneration",
]
@@ -207,8 +207,107 @@ class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
"""Id used for padding (returns -100 for loss masking)."""
return self._pad_id
class VibeVoiceASRTextTokenizerFast(Qwen2TokenizerFast):
    """
    "Fast" text tokenizer for VibeVoice-ASR (backed by HuggingFace's *tokenizers* library).

    It is a Qwen2 fast tokenizer that additionally registers three speech
    marker tokens, caches their ids, and installs a ChatML-style chat template.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file.
        merges_file (`str`, *optional*):
            Path to the merges file.
        tokenizer_file (`str`, *optional*):
            Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token.
        bos_token (`str`, *optional*):
            The beginning of sequence token. Not used for vibevoice.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The token used for padding.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to the Qwen2 tokenizer.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<|endoftext|>",
        bos_token=None,
        eos_token="<|endoftext|>",
        pad_token="<|endoftext|>",
        add_prefix_space=False,
        **kwargs,
    ):
        base_options = dict(
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
        )
        super().__init__(**base_options, **kwargs)

        # Register the speech marker tokens and cache their ids.
        self._add_vibevoice_special_tokens()

        # Chat template taken from the Qwen2.5-VL fine-tuning code:
        # https://github.com/QwenLM/Qwen2.5-VL/blob/d2240f11656bfe404b9ba56db4e51cd09f522ff1/qwen-vl-finetune/qwenvl/data/data_qwen_packed.py#L57C5-L57C222
        self.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

    def _add_vibevoice_special_tokens(self):
        """Register the speech marker tokens, cache the ids used by the properties, and return how many tokens were added."""
        # Existing vision-token strings are reused as speech markers.
        start_token = "<|object_ref_start|>"  # Speech start
        end_token = "<|object_ref_end|>"  # Speech end
        speech_pad_token = "<|box_start|>"  # Speech diffusion pad

        num_added = self.add_special_tokens(
            {"additional_special_tokens": [start_token, end_token, speech_pad_token]}
        )

        # Cache special token IDs
        self._speech_start_id = self.convert_tokens_to_ids(start_token)
        self._speech_end_id = self.convert_tokens_to_ids(end_token)
        self._speech_pad_id = self.convert_tokens_to_ids(speech_pad_token)
        self._eos_id = self.eos_token_id  # qwen2 / qwen3
        self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')

        return num_added

    @property
    def eos_id(self) -> int:
        """Cached id of the end-of-sequence token."""
        return self._eos_id

    @property
    def speech_start_id(self) -> int:
        """Cached id of the speech start marker."""
        return self._speech_start_id

    @property
    def speech_end_id(self) -> int:
        """Cached id of the speech end marker."""
        return self._speech_end_id

    @property
    def speech_pad_id(self) -> int:
        """Cached id of the speech diffusion pad token."""
        return self._speech_pad_id

    @property
    def pad_id(self) -> int:
        """Cached id of the ``<|image_pad|>`` token."""
        return self._pad_id
# Public names exported by a star-import of this tokenizer module.
__all__ = [
"VibeVoiceTextTokenizer",
"VibeVoiceTextTokenizerFast",
"VibeVoiceASRTextTokenizerFast",
]
+313 -10
View File
@@ -17,7 +17,7 @@ from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig
logger = logging.get_logger(__name__)
@@ -26,14 +26,13 @@ import os
try:
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
APEX_AVAILABLE = True
logger.info("APEX FusedRMSNorm is available and will be used for optimization")
# logger.info("APEX FusedRMSNorm is available and will be used for optimization")
if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
APEX_AVAILABLE = False
logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
# logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
except ImportError:
APEX_AVAILABLE = False
logger.warning("APEX FusedRMSNorm not available, using native implementation")
# APEX_AVAILABLE=False
# logger.warning("APEX FusedRMSNorm not available, using native implementation")
# Normalization modules
class ConvLayerNorm(nn.LayerNorm):
@@ -297,7 +296,8 @@ class SConv1d(nn.Module):
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
sample_indices: Optional[torch.Tensor] = None,
use_cache: bool = False,
debug: bool = False) -> torch.Tensor:
debug: bool = False,
is_final_chunk: bool = False) -> torch.Tensor:
"""
Forward pass with optional streaming support via cache.
@@ -307,6 +307,7 @@ class SConv1d(nn.Module):
sample_indices: Indices identifying each sample for cache management
use_cache: Whether to use cached states for streaming
debug: Whether to print debug information
is_final_chunk: Whether this is the final chunk (adds extra padding for alignment)
Returns:
Output tensor
@@ -322,12 +323,13 @@ class SConv1d(nn.Module):
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
assert len(sample_indices) == B, "sample_indices must match batch size"
return self._forward_streaming(x, cache, sample_indices, debug)
return self._forward_streaming(x, cache, sample_indices, debug, is_final_chunk)
def _forward_streaming(self, x: torch.Tensor,
cache: VibeVoiceTokenizerStreamingCache,
sample_indices: torch.Tensor,
debug: bool = False) -> torch.Tensor:
debug: bool = False,
is_final_chunk: bool = False) -> torch.Tensor:
"""Streaming forward pass with cache operations kept separate from compiled code"""
B, C, T = x.shape
@@ -350,6 +352,16 @@ class SConv1d(nn.Module):
input_with_context = torch.cat([cached_states, x], dim=2)
else:
input_with_context = x
# For final chunk, add extra padding to ensure ceil behavior (same as non-streaming)
if is_final_chunk:
extra_padding = get_extra_padding_for_conv1d(
input_with_context, self.kernel_size, self.stride, self.padding_total
)
if extra_padding > 0:
input_with_context = pad1d(input_with_context, (0, extra_padding), mode=self.pad_mode)
if debug:
print(f"[DEBUG] Final chunk: added extra_padding={extra_padding}")
if debug:
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
@@ -684,6 +696,135 @@ class Block1D(nn.Module):
return x
class TokenizerEncoder(nn.Module):
    """
    Encoder component for the VibeVoice tokenizer that converts audio to latent representations.

    Layout built in ``__init__``: a stem convolution, then one strided
    downsampling convolution per entry of ``config.ratios`` (applied in
    reversed order), with a stage of ``Block1D`` layers run after each of
    them, followed by a final norm and a projection head to
    ``config.dimension`` channels.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()

        # Extract parameters from config
        self.channels = config.channels
        self.dimension = config.dimension
        self.n_filters = config.n_filters
        # The encoder walks the ratios in the reverse of the configured order.
        self.ratios = list(reversed(config.ratios))
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        # Total downsampling factor: product of all stride ratios.
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # determine the norm type based on layernorm
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # stem and intermediate downsampling conv layers
        stem = nn.Sequential(
            SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
        )

        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(stem)
        # Each downsampling layer strides by ratios[i] and doubles the channel count.
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** i)
            out_ch = self.n_filters * (2 ** (i + 1))
            downsample_layer = nn.Sequential(
                SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
            )
            self.downsample_layers.append(downsample_layer)

        # configure the transformer blocks
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.stages = nn.ModuleList()
        # One drop-path rate per block, linearly spaced from 0 to drop_path_rate.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0

        # Stage i holds depths[i] blocks, each n_filters * 2**i channels wide.
        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** i)
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        # in_ch still holds the channel width of the last stage built above.
        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
        """Run the downsampling layers and block stages, then the final norm.

        The streaming arguments (``cache``, ``sample_indices``, ``use_cache``,
        ``debug``, ``is_final_chunk``) are forwarded to every ``SConv1d`` call.
        """
        # Stage i pairs downsample_layers[i] with stages[i], so this indexing
        # relies on downsample_layers having at least len(depths) entries.
        for i in range(len(self.depths)):
            # Apply downsampling
            for layer in self.downsample_layers[i]:
                if isinstance(layer, SConv1d):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
                else:
                    x = layer(x)

            # Apply stage (Block1D contains Convlayer which contains SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Block1D forward with cache support: the block is unrolled by
                    # hand so the streaming arguments reach its inner SConv1d.
                    # NOTE(review): this path calls no drop-path op on the block;
                    # confirm that matches Block1D.forward.
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x

                    # FFN part
                    residual = x
                    x = block.ffn_norm(x)
                    # The FFN is applied with the last two axes swapped, then swapped back.
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)

        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
        """Apply ``forward_features`` followed by the projection head."""
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
        return x
class TokenizerDecoder(nn.Module):
"""
Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
@@ -821,15 +962,63 @@ class TokenizerDecoder(nn.Module):
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
return x
@dataclass
class VibeVoiceTokenizerEncoderOutput:
    """
    Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.

    Args:
        mean (`torch.FloatTensor`): The mean parameters of the distribution.
        std (`float` or `torch.FloatTensor`, *optional*): Fixed standard deviation value.
            When ``None`` no noise scale is available and sampling is deterministic.
    """
    mean: torch.Tensor
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type='fix'):
        """
        Sample from the distribution.

        Args:
            dist_type (`str`): Sampling method.
                - ``'fix'``: add Gaussian noise scaled by ``std`` to ``mean``.
                - ``'gaussian'``: draw one noise scale per batch element from
                  ``N(0, (std / 0.8)^2)`` and use it to scale the Gaussian noise.
                - anything else: no noise, ``mean`` is returned unchanged.

        Returns:
            tuple: ``(sample, std_used)``. ``std_used`` is ``self.std`` except in
            ``'gaussian'`` mode, where it is the per-sample scale tensor,
            broadcastable against ``mean``.
        """
        # Without a std there is nothing to scale the noise with. Return the
        # mean instead of failing on arithmetic with None.
        if self.std is None or dist_type not in ('fix', 'gaussian'):
            return self.mean, self.std

        if dist_type == 'fix':
            x = self.mean + self.std * torch.randn_like(self.mean)
            return x, self.std

        # dist_type == 'gaussian'
        batch_size = self.mean.size(0)
        value = self.std / 0.8
        std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value
        # Append trailing singleton axes so the per-sample scale broadcasts over mean.
        while std.dim() < self.mean.dim():
            std = std.unsqueeze(-1)
        x = self.mean + std * torch.randn_like(self.mean)
        return x, std

    def kl(self):
        """Return the element-wise squared mean (MSE of ``mean`` against zeros, unreduced).

        This is the mean-dependent penalty used in place of a full KL term;
        no variance term is included.
        """
        target = torch.zeros_like(self.mean)
        return F.mse_loss(self.mean, target, reduction='none')

    def mode(self):
        """Return the distribution mode (which is the mean for Gaussian)."""
        return self.mean
class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
"""VibeVoice speech tokenizer model (only decoder) for acoustic tokens"""
"""VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""
config_class = VibeVoiceAcousticTokenizerConfig
base_model_prefix = "vibevoice_acoustic_tokenizer"
_supports_flash_attn_2 = True
_supports_sdpa = True
_no_split_modules = ["TokenizerDecoder"]
_no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]
def __init__(self, config):
super().__init__(config)
@@ -850,6 +1039,21 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
# Default: use reversed encoder depths if decoder_depths is None
decoder_depths = list(reversed(encoder_depths))
# Create encoder config
encoder_config = copy.deepcopy(config)
encoder_config.dimension = config.vae_dim
encoder_config.n_filters = config.encoder_n_filters
encoder_config.ratios = config.encoder_ratios
encoder_config.depths = encoder_depths
encoder_config.norm = config.conv_norm
encoder_config.pad_mode = config.pad_mode
encoder_config.bias = config.conv_bias
encoder_config.layernorm_eps = config.layernorm_eps
encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
encoder_config.mixer_layer = config.mixer_layer
encoder_config.layer_scale_init_value = config.layer_scale_init_value
encoder_config.disable_last_norm = config.disable_last_norm
# Create decoder config
decoder_config = copy.deepcopy(config)
decoder_config.dimension = config.vae_dim
@@ -865,6 +1069,8 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
decoder_config.layer_scale_init_value = config.layer_scale_init_value
decoder_config.disable_last_norm = config.disable_last_norm
# Initialize encoder and decoder
self.encoder = TokenizerEncoder(encoder_config)
self.decoder = TokenizerDecoder(decoder_config)
# Initialize weights
@@ -884,6 +1090,24 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
if module.bias is not None:
nn.init.zeros_(module.bias)
@torch.no_grad()
def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
    """Convert audio to latent representations.

    Runs ``self.encoder`` on ``audio`` with the streaming arguments forwarded
    unchanged, with gradients disabled.

    Returns:
        VibeVoiceTokenizerEncoderOutput: ``mean`` is the encoder output with
        its last two axes swapped; ``std`` is ``self.fix_std``.
    """
    latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
    return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)
@torch.no_grad()
def sampling(self, encoder_output, dist_type=None):
    """Draw latents from an encoder output.

    Args:
        encoder_output: Object exposing ``sample(dist_type=...)``.
        dist_type: ``'fix'`` or ``'gaussian'``. A falsy value selects
            ``self.std_dist_type``.

    Returns:
        Whatever ``encoder_output.sample`` returns for the chosen mode.

    Raises:
        ValueError: If the resolved mode is neither ``'fix'`` nor ``'gaussian'``.
    """
    chosen = dist_type or self.std_dist_type
    for supported in ('fix', 'gaussian'):
        if chosen == supported:
            return encoder_output.sample(dist_type=supported)
    raise ValueError(f"Unsupported dist_type: {chosen}, expected 'fix' or 'gaussian'")
@torch.no_grad()
def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
"""Convert latent representations back to audio"""
@@ -895,10 +1119,89 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
return audio
def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
    """Full forward pass: encode audio to latents, then decode back to audio

    Returns:
        tuple: ``(reconstructed, sampled_latents)``, the decoder output and
        the latents that were sampled from the encoder output.
    """
    encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
    # No dist_type is passed, so sampling() falls back to self.std_dist_type.
    sampled_latents, _ = self.sampling(encoder_output)
    # NOTE(review): the same cache object is handed to both encode and decode; confirm this is intended.
    reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
    return reconstructed, sampled_latents
class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
    """Encoder-only VibeVoice speech tokenizer that produces semantic latents."""

    config_class = VibeVoiceSemanticTokenizerConfig
    base_model_prefix = "vibevoice_semantic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        """Build the encoder from ``config`` and initialise its weights.

        ``config.encoder_depths`` may be a ``'-'``-separated string such as
        ``"3-3-9"`` or an already parsed sequence.
        """
        super().__init__(config)

        depths_spec = config.encoder_depths
        if isinstance(depths_spec, str):
            encoder_depths = [int(d) for d in depths_spec.split('-')]
        else:
            encoder_depths = depths_spec

        # The encoder reads generic field names, so copy the model config and
        # populate those fields from their encoder-specific counterparts.
        encoder_config = copy.deepcopy(config)
        field_sources = {
            "dimension": "vae_dim",
            "n_filters": "encoder_n_filters",
            "ratios": "encoder_ratios",
            "norm": "conv_norm",
            "pad_mode": "pad_mode",
            "bias": "conv_bias",
            "layernorm_eps": "layernorm_eps",
            "layernorm_elementwise_affine": "layernorm_elementwise_affine",
            "mixer_layer": "mixer_layer",
            "layer_scale_init_value": "layer_scale_init_value",
            "disable_last_norm": "disable_last_norm",
        }
        for target_name, source_name in field_sources.items():
            setattr(encoder_config, target_name, getattr(config, source_name))
        encoder_config.depths = encoder_depths

        # This model has an encoder only; there is no decoder.
        self.encoder = TokenizerEncoder(encoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule.

        Linear and Conv1d weights are drawn from a normal distribution with
        std ``config.weight_init_value`` and their biases zeroed; LayerNorm
        weights are set to one and biases to zero.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    @torch.no_grad()
    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
        """Encode audio into latents, with gradients disabled.

        Returns:
            VibeVoiceTokenizerEncoderOutput: ``mean`` is the encoder output with
            its last two axes swapped; no ``std`` is set.
        """
        encoded = self.encoder(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
            is_final_chunk=is_final_chunk,
        )
        return VibeVoiceTokenizerEncoderOutput(mean=encoded.permute(0, 2, 1))

    @torch.no_grad()
    def sampling(self, encoder_output, dist_type=None):
        """Return ``encoder_output.sample(dist_type='none')``; ``dist_type`` is ignored."""
        return encoder_output.sample(dist_type='none')

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Encode audio and return ``(None, latents)``.

        The first element is always ``None`` because this model has no decoder
        and therefore produces no reconstruction.
        """
        encoder_output = self.encode(
            audio,
            cache=cache,
            sample_indices=sample_indices,
            use_cache=use_cache,
            debug=debug,
        )
        sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
        return None, sampled_latents
# Register each tokenizer config class with AutoModel, paired with its model class.
AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)

# Public names exported by this module.
__all__ = [
    "VibeVoiceTokenizerStreamingCache",
    "VibeVoiceAcousticTokenizerModel",
    "VibeVoiceSemanticTokenizerModel",
]
+143
View File
@@ -0,0 +1,143 @@
import numpy as np
from subprocess import run
from typing import List, Optional, Union, Dict, Any
# File extensions (with leading dot) of common audio and audio-bearing media files.
# Entries are literal strings; only some of them also list upper- or mixed-case variants.
COMMON_AUDIO_EXTS = [
    '.mp3', '.MP3', '.Mp3',  # Common case variations of mp3 (not exhaustive)
    '.m4a',
    '.mp4', '.MP4',
    '.wav', '.WAV',
    '.m4v',
    '.aac',
    '.ogg',
    '.mov', '.MOV',
    '.opus',
    '.m4b',
    '.flac',
    '.wma', '.WMA',
    '.rm', '.3gp', '.mpeg', '.flv', '.webm', '.mp2', '.aif', '.aiff', '.oga', '.ogv', '.mpga', '.m3u8', '.amr'
]
def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24000):
    """
    Open an audio file and read as mono waveform, optionally resampling.

    Decoding is delegated to the ``ffmpeg`` command line tool. When
    ``resample`` is False the file's own sample rate is first queried with
    ``ffprobe`` and the audio is decoded at that rate.

    Parameters
    ----------
    file: str
        The audio file to open
    resample: bool
        Whether to resample the audio
    target_sr: int
        The target sample rate if resampling is requested

    Returns
    -------
    A tuple containing:
    - A NumPy array with the audio waveform in float32 dtype, scaled to [-1, 1)
    - The sample rate of the returned waveform: ``target_sr`` when
      ``resample`` is True, otherwise the sample rate reported by ffprobe

    Raises
    ------
    subprocess.CalledProcessError
        If ffprobe or ffmpeg exits with a non-zero status.
    ValueError
        If ffprobe reports no usable audio sample rate for the file.
    """
    if resample:
        sr_to_use = target_sr
    else:
        # Query only the first audio stream. Without a stream selector,
        # ffprobe prints one value per matching stream, and the joined
        # multi-line output cannot be parsed as a single integer.
        cmd_probe = [
            "ffprobe",
            "-v", "quiet",
            "-select_streams", "a:0",
            "-show_entries", "stream=sample_rate",
            "-of", "default=noprint_wrappers=1:nokey=1",
            file
        ]
        probe_fields = run(cmd_probe, capture_output=True, check=True).stdout.decode().split()
        if not probe_fields:
            raise ValueError(f"ffprobe found no audio stream in {file!r}")
        # Some containers repeat the entry; the first value is sufficient.
        sr_to_use = int(probe_fields[0])

    # Decode to raw mono signed 16-bit little-endian PCM on stdout.
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr_to_use),
        "-"
    ]

    out = run(cmd, capture_output=True, check=True).stdout
    # int16 samples -> float32 by dividing by 2**15.
    audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0

    return audio_data, sr_to_use
class AudioNormalizer:
    """
    Loudness normaliser for audio fed to the VibeVoice tokenizer.

    Calling an instance scales the signal so its RMS level matches a target
    dB FS value, then scales it down further if the peak exceeds 1.0.
    """

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
        """
        Store the normalisation settings.

        Args:
            target_dB_FS (float): Target dB FS level for the audio. Default: -25
            eps (float): Small value to avoid division by zero. Default: 1e-6
        """
        self.target_dB_FS = target_dB_FS
        self.eps = eps

    def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
        """
        Scale the audio so that its RMS level matches the target dB FS.

        Args:
            audio (np.ndarray): Input audio signal

        Returns:
            tuple: (normalized_audio, rms, scalar) where ``rms`` is the RMS of
            the input and ``scalar`` the gain that was applied.
        """
        rms = np.sqrt(np.mean(audio**2))
        # Convert the dB FS target to a linear amplitude, then derive the gain.
        target_amplitude = 10 ** (self.target_dB_FS / 20)
        scalar = target_amplitude / (rms + self.eps)
        return audio * scalar, rms, scalar

    def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
        """
        Divide the audio by a factor that keeps its peak at or below 1.0.

        Args:
            audio (np.ndarray): Input audio signal
            scalar (float, optional): Explicit divisor. When omitted it is
                ``peak + eps`` if the absolute peak exceeds 1.0, else 1.0.

        Returns:
            tuple: (normalized_audio, scalar)
        """
        if scalar is None:
            peak = np.max(np.abs(audio))
            scalar = peak + self.eps if peak > 1.0 else 1.0
        return audio / scalar, scalar

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """
        Normalise the audio level and then guard against clipping.

        Args:
            audio (np.ndarray): Input audio signal

        Returns:
            np.ndarray: Normalized audio signal
        """
        leveled = self.tailor_dB_FS(audio)[0]
        return self.avoid_clipping(leveled)[0]
@@ -0,0 +1,572 @@
"""
Processor class for VibeVoice ASR models.
"""
import os
import json
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import numpy as np
import torch
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import TensorType, logging
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor, AudioNormalizer
try:
from .audio_utils import load_audio_use_ffmpeg
HAS_FFMPEG_UTILS = True
except ImportError:
HAS_FFMPEG_UTILS = False
warnings.warn("audio_utils not available, will fall back to soundfile for audio loading")
logger = logging.get_logger(__name__)
SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format."
class VibeVoiceASRProcessor:
"""
Processor for VibeVoice ASR (Automatic Speech Recognition) models.
This processor handles audio preprocessing and tokenization for ASR tasks,
following the exact format used in training with proper chat templates.
Args:
tokenizer: The text tokenizer for processing text
audio_processor: The audio processor for processing speech
speech_tok_compress_ratio (int): Compression ratio for speech tokenization
target_sample_rate (int): Target sample rate for audio
normalize_audio (bool): Whether to normalize audio input
"""
def __init__(
    self,
    tokenizer=None,
    audio_processor=None,
    speech_tok_compress_ratio=320,
    target_sample_rate=24000,
    normalize_audio=True,
    **kwargs
):
    """
    Set up the processor and cache the tokenizer's special token ids.

    Args:
        tokenizer: Text tokenizer; its special token ids are looked up immediately.
        audio_processor: Audio processor. When falsy, a ``VibeVoiceTokenizerProcessor``
            is created from ``target_sample_rate`` and ``normalize_audio``.
        speech_tok_compress_ratio (int): Number of audio samples per speech token.
            NOTE(review): ``from_pretrained`` falls back to 3200 when the config
            omits this value, whereas this constructor defaults to 320. Confirm
            which is intended.
        target_sample_rate (int): Sample rate the audio is expected in.
        normalize_audio (bool): Whether an ``AudioNormalizer`` is applied to the audio.
        **kwargs: Accepted and ignored.
    """
    self.tokenizer = tokenizer
    self.speech_tok_compress_ratio = speech_tok_compress_ratio
    self.target_sample_rate = target_sample_rate
    self.normalize_audio = normalize_audio

    if not audio_processor:
        audio_processor = VibeVoiceTokenizerProcessor(
            sampling_rate=target_sample_rate,
            normalize_audio=normalize_audio
        )
    self.audio_processor = audio_processor

    # The normalizer exists only when normalization is enabled.
    self.audio_normalizer = AudioNormalizer() if normalize_audio else None

    # Cache special token IDs
    self._cache_special_tokens()
def _cache_special_tokens(self):
    """Copy the special token ids from the tokenizer onto this processor.

    For each of ``speech_start_id``, ``speech_end_id`` and ``speech_pad_id``
    the tokenizer attribute of the same name is used when present; otherwise
    the id is looked up from a literal token string. ``pad_id`` is taken from
    the tokenizer's ``pad_id``, then ``pad_token_id``, then the id of
    ``"<|endoftext|>"``.
    """
    tok = self.tokenizer

    speech_tokens = (
        ("speech_start_id", "<|speech_start|>"),
        ("speech_end_id", "<|speech_end|>"),
        ("speech_pad_id", "<|speech_pad|>"),
    )
    for attr_name, fallback_token in speech_tokens:
        if hasattr(tok, attr_name):
            token_id = getattr(tok, attr_name)
        else:
            token_id = tok.convert_tokens_to_ids(fallback_token)
        setattr(self, attr_name, token_id)

    if hasattr(tok, 'pad_id'):
        self.pad_id = tok.pad_id
    elif hasattr(tok, 'pad_token_id'):
        self.pad_id = tok.pad_token_id
    else:
        self.pad_id = tok.convert_tokens_to_ids("<|endoftext|>")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """
    Load processor from a pretrained model path.

    ``preprocessor_config.json`` is read from the given directory when it
    exists there; otherwise it is resolved through ``cached_file``. If that
    lookup fails, a warning is logged and default settings are used.

    Args:
        pretrained_model_name_or_path: Path to the pretrained model
        **kwargs: Additional keyword arguments. ``language_model_pretrained_name``
            selects the tokenizer when the config does not name one (default
            ``"Qwen/Qwen2.5-1.5B"``); all remaining entries are forwarded to
            ``cached_file`` and to the tokenizer's ``from_pretrained``.

    Returns:
        VibeVoiceASRProcessor: The loaded processor

    Raises:
        ValueError: If the resolved tokenizer name does not contain ``"qwen"``.
    """
    import json
    from transformers.utils import cached_file
    from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast

    # Consume the processor-only option up front so it is never forwarded to
    # cached_file or the tokenizer loader, whether or not the config names
    # a tokenizer itself.
    fallback_lm_name = kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")

    # Try to load configuration
    config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
    config = {}

    if os.path.exists(config_path):
        with open(config_path, 'r', encoding="utf-8") as f:
            config = json.load(f)
    else:
        # Best effort: a missing or unreachable config must not prevent
        # loading, so any failure here falls back to the defaults below.
        try:
            config_file = cached_file(
                pretrained_model_name_or_path,
                "preprocessor_config.json",
                **kwargs
            )
            with open(config_file, 'r', encoding="utf-8") as f:
                config = json.load(f)
        except Exception as e:
            logger.warning(f"Could not load preprocessor_config.json: {e}")
            logger.warning("Using default configuration")

    # Extract parameters
    speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
    target_sample_rate = config.get("target_sample_rate", 24000)
    normalize_audio = config.get("normalize_audio", True)

    # Load tokenizer; a name stored in the config wins over the keyword argument.
    language_model_pretrained_name = config.get("language_model_pretrained_name", None) or fallback_lm_name
    logger.info(f"Loading tokenizer from {language_model_pretrained_name}")

    if 'qwen' in language_model_pretrained_name.lower():
        tokenizer = VibeVoiceASRTextTokenizerFast.from_pretrained(
            language_model_pretrained_name,
            **kwargs
        )
    else:
        raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}")

    # Load audio processor
    audio_processor = VibeVoiceTokenizerProcessor(
        sampling_rate=target_sample_rate,
        normalize_audio=normalize_audio,
        target_dB_FS=config.get("target_dB_FS", -25),
        eps=config.get("eps", 1e-6),
    )

    return cls(
        tokenizer=tokenizer,
        audio_processor=audio_processor,
        speech_tok_compress_ratio=speech_tok_compress_ratio,
        target_sample_rate=target_sample_rate,
        normalize_audio=normalize_audio,
    )
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
    """
    Write the processor settings to ``preprocessor_config.json``.

    The directory is created if needed. ``target_dB_FS`` and ``eps`` are
    written as the fixed values -25 and 1e-6.

    Args:
        save_directory: Directory to save the configuration
        **kwargs: Accepted and ignored
    """
    import json

    os.makedirs(save_directory, exist_ok=True)
    config_path = os.path.join(save_directory, "preprocessor_config.json")

    # Save processor configuration
    payload = dict(
        processor_class="VibeVoiceASRProcessor",
        speech_tok_compress_ratio=self.speech_tok_compress_ratio,
        target_sample_rate=self.target_sample_rate,
        normalize_audio=self.normalize_audio,
        target_dB_FS=-25,
        eps=1e-6,
    )
    with open(config_path, 'w') as f:
        json.dump(payload, f, indent=2)

    logger.info(f"Processor configuration saved in {config_path}")
def __call__(
    self,
    audio: Optional[Union[str, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, torch.Tensor]]]] = None,
    sampling_rate: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    padding: bool = True,
    max_length: Optional[int] = None,
    truncation: bool = False,
    add_generation_prompt: bool = True,
    use_streaming: bool = True,
    context_info: Optional[str] = None,
    **kwargs
) -> BatchEncoding:
    """
    Turn one or more audio inputs into model inputs for ASR.

    Each input is processed independently by ``_process_single_audio`` and
    the results are merged by ``_batch_encode``.

    Args:
        audio: Audio input(s). Can be:
            - str: Path to audio file
            - np.ndarray: Audio array
            - torch.Tensor: Audio tensor
            - List of the above for batch processing
        sampling_rate: Sampling rate of input audio
        return_tensors: Output format, forwarded to ``_batch_encode``
        padding: Whether to pad batch inputs
        max_length: Maximum sequence length
        truncation: Whether to truncate long sequences
        add_generation_prompt: Whether to add generation prompt for inference
        use_streaming: Whether to use streaming mode (True by default, auto False if <60s)
        context_info: Optional context information (e.g., hotwords, metadata) to help
            transcription; the same value is used for every input in a batch
        **kwargs: Accepted and ignored

    Returns:
        BatchEncoding with:
            - input_ids: Token IDs for the model
            - attention_mask: Attention mask
            - acoustic_input_mask: Mask indicating speech token positions
            - speech_tensors: Padded raw waveforms
            - speech_masks: Valid speech token masks

    Raises:
        ValueError: If ``audio`` is None.
    """
    if audio is None:
        raise ValueError("Audio input is required for ASR processing")

    # A bare input is treated as a batch of one.
    audio_list = audio if isinstance(audio, list) else [audio]

    all_encodings = [
        self._process_single_audio(
            audio_input,
            sampling_rate=sampling_rate,
            add_generation_prompt=add_generation_prompt,
            use_streaming=use_streaming,
            context_info=context_info,
        )
        for audio_input in audio_list
    ]

    return self._batch_encode(
        all_encodings,
        padding=padding,
        max_length=max_length,
        truncation=truncation,
        return_tensors=return_tensors,
    )
def _process_single_audio(
    self,
    audio: Union[str, np.ndarray, torch.Tensor],
    sampling_rate: Optional[int] = None,
    add_generation_prompt: bool = True,
    use_streaming: bool = True,
    context_info: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Process a single audio input.

    The waveform is loaded (or converted), brought to ``self.target_sample_rate``,
    optionally normalized, and wrapped in the system + user chat prompt in which
    one speech pad token stands in for every speech frame.

    Args:
        audio: Single audio input: a file path, a NumPy array, or a torch tensor
        sampling_rate: Sampling rate of an in-memory ``audio`` array/tensor. When
            given and different from ``self.target_sample_rate`` the waveform is
            resampled. ``None`` means the data is already at the target rate.
            Ignored for file paths, where the file's own rate is used.
        add_generation_prompt: Whether to add generation prompt.
            NOTE(review): this flag is not forwarded to the chat template here;
            confirm where the generation prompt is expected to be appended.
        use_streaming: Accepted for API compatibility. The token count is
            computed the same way in both modes, so it has no effect here.
        context_info: Optional context information (e.g., hotwords, metadata) to help transcription

    Returns:
        Dictionary with ``input_ids``, ``acoustic_input_mask``, ``speech``
        (float32 waveform) and ``vae_tok_len`` (number of speech tokens)
    """
    if isinstance(audio, str):
        # Load from file using ffmpeg for better format support
        audio_array = None
        if HAS_FFMPEG_UTILS:
            try:
                audio_array, source_sr = load_audio_use_ffmpeg(audio, resample=False)
            except Exception as e:
                # Fall back to soundfile if ffmpeg fails
                warnings.warn(f"ffmpeg loading failed, falling back to soundfile: {e}")
        if audio_array is None:
            import soundfile as sf
            audio_array, source_sr = sf.read(audio)
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)  # Convert to mono
    else:
        if isinstance(audio, torch.Tensor):
            # detach() and float() keep the conversion working for tensors that
            # track gradients or use a dtype NumPy cannot represent.
            audio_array = audio.detach().cpu().float().numpy()
        else:
            audio_array = np.array(audio, dtype=np.float32)
        if audio_array.ndim > 1:
            audio_array = audio_array.squeeze()
        # In-memory audio carries no rate of its own: trust the caller's
        # sampling_rate, or assume it is already at the target rate.
        source_sr = sampling_rate if sampling_rate is not None else self.target_sample_rate

    # Resample if needed
    if source_sr != self.target_sample_rate:
        import librosa
        audio_array = librosa.resample(
            audio_array,
            orig_sr=source_sr,
            target_sr=self.target_sample_rate
        )

    # Ensure float32
    audio_array = audio_array.astype(np.float32)

    # Normalize if needed
    if self.normalize_audio and self.audio_normalizer:
        audio_array = self.audio_normalizer(audio_array)

    # Calculate audio duration
    audio_duration = len(audio_array) / self.target_sample_rate

    # One speech token per speech_tok_compress_ratio samples, rounded up.
    # The same rule is used whether or not streaming was requested.
    vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio)

    # Build token sequence following training format
    # 1. System prompt - use apply_chat_template then encode like in training
    system_prompt_text = self.tokenizer.apply_chat_template(
        [{"role": "system", "content": SYSTEM_PROMPT}],
        tokenize=False
    )
    system_tokens = self.tokenizer.encode(system_prompt_text)

    # 2. User input with speech tokens
    # Build speech placeholder string
    sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id)
    sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id)
    sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id)

    # User suffix with audio duration info
    show_keys = ['Start time', 'End time', 'Speaker ID', 'Content']
    if context_info and context_info.strip():
        user_suffix = f"This is a {audio_duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\nPlease transcribe it with these keys: " + ", ".join(show_keys)
    else:
        user_suffix = f"This is a {audio_duration:.2f} seconds audio, please transcribe it with these keys: " + ", ".join(show_keys)

    user_input_string = ''.join(
        [sp_start_token] + [sp_pad_token] * vae_tok_len + [sp_end_token]
    ) + '\n' + user_suffix

    user_tokens = self.tokenizer.apply_chat_template(
        [{"role": "user", "content": user_input_string}],
        tokenize=True
    )

    # Combine tokens
    full_tokens = system_tokens + user_tokens

    # Mark every speech pad position: these are the slots filled with speech features.
    acoustic_input_mask = [1 if token == self.speech_pad_id else 0 for token in full_tokens]

    return {
        "input_ids": full_tokens,
        "acoustic_input_mask": acoustic_input_mask,
        "speech": audio_array,
        "vae_tok_len": vae_tok_len,
    }
def _batch_encode(
    self,
    encodings: List[Dict[str, Any]],
    padding: bool = True,
    max_length: Optional[int] = None,
    truncation: bool = False,
    return_tensors: Optional[str] = None,
) -> BatchEncoding:
    """
    Merge per-sample encodings into one batch.

    Token sequences are left-padded with ``self.pad_id`` (when ``padding`` is
    set) and waveforms are right-padded with zeros.

    Args:
        encodings: List of encoded samples, as produced by ``_process_single_audio``
        padding: Whether to pad sequences to a common length
        max_length: Padding/truncation length; defaults to the longest sequence
        truncation: Whether to cut sequences longer than the target length
            (only applied when ``padding`` is set)
        return_tensors: ``"pt"`` for torch tensors; anything else yields
            lists/NumPy arrays, with a single-sample batch unwrapped to that sample

    Returns:
        BatchEncoding with batched data
    """
    token_seqs = [enc["input_ids"] for enc in encodings]
    mask_seqs = [enc["acoustic_input_mask"] for enc in encodings]
    waveforms = [enc["speech"] for enc in encodings]
    vae_tok_lens = [enc["vae_tok_len"] for enc in encodings]

    if padding:
        if max_length is None:
            target_length = max(len(ids) for ids in token_seqs)
        else:
            target_length = max_length

        padded_token_seqs, padded_mask_seqs, attention_masks = [], [], []
        for ids, mask in zip(token_seqs, mask_seqs):
            if truncation and len(ids) > target_length:
                ids, mask = ids[:target_length], mask[:target_length]

            # Pad on the left so generation continues from the real tokens.
            n_pad = target_length - len(ids)
            padded_token_seqs.append([self.pad_id] * n_pad + ids)
            padded_mask_seqs.append([0] * n_pad + mask)
            attention_masks.append([0] * n_pad + [1] * len(ids))

        token_seqs, mask_seqs = padded_token_seqs, padded_mask_seqs
    else:
        attention_masks = [[1] * len(ids) for ids in token_seqs]

    # Raw audio is 1D per sample; stack it into a zero-padded 2D array.
    n_samples = len(waveforms)
    padded_speeches = np.zeros((n_samples, max(len(w) for w in waveforms)), dtype=np.float32)
    speech_masks = np.zeros((n_samples, max(vae_tok_lens)), dtype=bool)
    for row, (waveform, vae_len) in enumerate(zip(waveforms, vae_tok_lens)):
        padded_speeches[row, :len(waveform)] = waveform
        speech_masks[row, :vae_len] = True

    if return_tensors == "pt":
        fields = {
            "input_ids": torch.tensor(token_seqs, dtype=torch.long),
            "attention_mask": torch.tensor(attention_masks, dtype=torch.long),
            "acoustic_input_mask": torch.tensor(mask_seqs, dtype=torch.bool),
            "speech_tensors": torch.tensor(padded_speeches, dtype=torch.float32),
            "speech_masks": torch.tensor(speech_masks, dtype=torch.bool),
        }
    else:
        def unwrap_single(batch):
            # A batch of one is returned as the bare sample.
            return batch if len(batch) > 1 else batch[0]

        fields = {
            "input_ids": unwrap_single(token_seqs),
            "attention_mask": unwrap_single(attention_masks),
            "acoustic_input_mask": unwrap_single(mask_seqs),
            "speech_tensors": unwrap_single(padded_speeches),
            "speech_masks": unwrap_single(speech_masks),
        }

    batch_encoding = BatchEncoding()
    for key, value in fields.items():
        batch_encoding[key] = value
    return batch_encoding
def batch_decode(self, *args, **kwargs):
    """
    Turn a batch of token-ID sequences back into text.

    Every positional and keyword argument is handed, unchanged, to the
    wrapped tokenizer's ``batch_decode``; whatever it returns is returned
    as-is.
    """
    tokenizer_batch_decode = self.tokenizer.batch_decode
    return tokenizer_batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
    """
    Turn a sequence of token IDs back into text.

    Every positional and keyword argument is handed, unchanged, to the
    wrapped tokenizer's ``decode``; whatever it returns is returned as-is.
    """
    tokenizer_decode = self.tokenizer.decode
    return tokenizer_decode(*args, **kwargs)
def post_process_transcription(self, text: str) -> List[Dict[str, Any]]:
    """
    Post-process the generated transcription text to extract structured data.

    The JSON payload is taken from a ```json fenced block when one is present,
    otherwise from the raw text. The value itself is decoded with
    ``json.JSONDecoder.raw_decode`` starting at an opening bracket, so that
    brackets appearing inside string values (e.g. ``"Content": "a ] b"``) do
    not break the extraction and any text after the JSON value is ignored.

    Args:
        text: Generated text from the model

    Returns:
        List of dictionaries with transcription segments. Each dictionary
        contains whichever of ``start_time``, ``end_time``, ``speaker_id`` and
        ``text`` were present in the corresponding JSON object; objects with
        none of the known keys and non-object entries are dropped. Parsing
        failures are logged and reported as an empty list rather than raised.
    """
    # Model-emitted key -> normalized output key. When an item carries two
    # spellings of the same field, the later entry in this mapping wins.
    key_mapping = {
        "Start time": "start_time",
        "Start": "start_time",
        "End time": "end_time",
        "End": "end_time",
        "Speaker ID": "speaker_id",
        "Speaker": "speaker_id",
        "Content": "text",
    }
    fence = "```json"

    try:
        fence_at = text.find(fence)
        if fence_at != -1:
            # Extract JSON from markdown code block
            payload_start = fence_at + len(fence)
            payload_end = text.find("```", payload_start)
            if payload_end == -1:
                # The closing fence is missing (e.g. generation was cut off):
                # keep everything after the opening fence instead of slicing
                # with -1, which would silently drop the final character.
                payload_end = len(text)
            payload = text[payload_start:payload_end].strip()
        else:
            payload = text

        # Choose where decoding may start. The earliest opening bracket is
        # tried first; when that is a "{", the first "[" is kept as a
        # fallback. There is deliberately no fallback from "[" to a later
        # "{": a truncated array must fail rather than yield only its first
        # element.
        bracket_at = payload.find("[")
        brace_at = payload.find("{")
        if bracket_at == -1 and brace_at == -1:
            starts = []
        elif bracket_at == -1:
            starts = [brace_at]
        elif brace_at == -1 or bracket_at < brace_at:
            starts = [bracket_at]
        else:
            starts = [brace_at, bracket_at]

        decoder = json.JSONDecoder()
        first_error = None
        for start in starts:
            try:
                # raw_decode understands JSON string syntax, so brackets
                # inside strings are handled correctly, and it stops at the
                # end of the first complete value.
                result, _ = decoder.raw_decode(payload, start)
            except json.JSONDecodeError as err:
                if first_error is None:
                    first_error = err
            else:
                break
        else:
            if first_error is not None:
                raise first_error
            # No bracket at all: let json.loads report the error (or return
            # a scalar, which yields no segments below).
            result = json.loads(payload)

        # Ensure it's a list
        if isinstance(result, dict):
            result = [result]

        # Validate and clean up the result
        cleaned_result = []
        for item in result:
            if not isinstance(item, dict):
                continue
            cleaned_item = {
                mapped_key: item[key]
                for key, mapped_key in key_mapping.items()
                if key in item
            }
            if cleaned_item:
                cleaned_result.append(cleaned_item)
        return cleaned_result

    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse JSON from transcription: {e}")
        logger.debug(f"Raw text: {text}")
        return []
    except Exception as e:
        # Deliberate best-effort boundary: post-processing must not crash the
        # caller on unexpected model output.
        logger.warning(f"Error post-processing transcription: {e}")
        return []
@property
def model_input_names(self):
    """List the names of the inputs accepted by the model."""
    names = [
        "input_ids",
        "attention_mask",
        "acoustic_input_mask",
        "speech_tensors",
        "speech_masks",
    ]
    return names
# Public API of this module: the only name exported by a star-import.
__all__ = ["VibeVoiceASRProcessor"]
@@ -13,80 +13,10 @@ import torch
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.utils import logging
from .audio_utils import AudioNormalizer
logger = logging.get_logger(__name__)
class AudioNormalizer:
    """
    Level normalizer for audio fed to the VibeVoice tokenizer.

    Calling an instance first rescales the signal so its RMS matches a target
    dB FS level, then scales it down if the result would exceed full scale.
    """

    def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
        """
        Store the normalization settings.

        Args:
            target_dB_FS (float): Desired RMS level in dB FS. Default: -25
            eps (float): Small constant added to denominators to avoid
                division by zero. Default: 1e-6
        """
        self.eps = eps
        self.target_dB_FS = target_dB_FS

    def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
        """
        Rescale the audio so that its RMS sits at the target dB FS level.

        Args:
            audio (np.ndarray): Input audio signal

        Returns:
            tuple: (normalized_audio, rms, scalar) where ``rms`` is the RMS of
            the input and ``scalar`` is the gain that was applied.
        """
        level = np.sqrt(np.mean(audio**2))
        # dB FS -> linear amplitude, divided by the current level (+eps so a
        # silent input does not divide by zero).
        gain = 10 ** (self.target_dB_FS / 20) / (level + self.eps)
        return audio * gain, level, gain

    def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
        """
        Divide the audio by a scaling factor so that it does not clip.

        Args:
            audio (np.ndarray): Input audio signal
            scalar (float, optional): Explicit divisor. When omitted, the
                divisor is the peak absolute value plus ``eps`` if that peak
                exceeds 1.0, and 1.0 otherwise.

        Returns:
            tuple: (normalized_audio, scalar)
        """
        if scalar is None:
            peak = np.max(np.abs(audio))
            scalar = peak + self.eps if peak > 1.0 else 1.0
        return audio / scalar, scalar

    def __call__(self, audio: np.ndarray) -> np.ndarray:
        """
        Run the full normalization: target-level adjustment, then anti-clipping.

        Args:
            audio (np.ndarray): Input audio signal

        Returns:
            np.ndarray: Normalized audio signal
        """
        leveled, _rms, _gain = self.tailor_dB_FS(audio)
        safe, _divisor = self.avoid_clipping(leveled)
        return safe
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
"""