Add VibeVoice-ASR
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 164 KiB |
@@ -23,8 +23,9 @@
|
||||
<img src="https://img.shields.io/badge/Status-New-brightgreen?style=flat" alt="New" />
|
||||
<img src="https://img.shields.io/badge/Feature-Realtime_TTS-blue?style=flat&logo=soundcharts" alt="Realtime TTS" />
|
||||
|
||||
<strong>2026-01-21: 📣 We open-sourced <a href="docs/vibevoice-asr.md"><strong>VibeVoice-ASR</strong></a>, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context.</strong>
|
||||
|
||||
<strong>2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time.</strong>
|
||||
2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time.
|
||||
|
||||
2025-12-09: 📣 We added experimental speakers in nine languages (DE, FR, IT, JP, KR, NL, PL, PT, ES) for exploration—welcome to try them out and share your feedback.
|
||||
|
||||
@@ -123,4 +124,4 @@ We do not recommend using VibeVoice in commercial or real-world applications wit
|
||||
|
||||
## Star History
|
||||
|
||||

|
||||

|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,554 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
VibeVoice ASR Batch Inference Demo Script
|
||||
|
||||
This script supports batch inference for ASR model and compares results
|
||||
between batch processing and single-sample processing.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import time
|
||||
import json
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional
|
||||
from functools import wraps
|
||||
|
||||
from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
|
||||
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
||||
|
||||
|
||||
class VibeVoiceASRBatchInference:
|
||||
"""Batch inference wrapper for VibeVoice ASR model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
attn_implementation: str = "flash_attention_2"
|
||||
):
|
||||
"""
|
||||
Initialize the ASR batch inference pipeline.
|
||||
|
||||
Args:
|
||||
model_path: Path to the pretrained model
|
||||
device: Device to run inference on
|
||||
dtype: Data type for model weights
|
||||
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
|
||||
"""
|
||||
print(f"Loading VibeVoice ASR model from {model_path}")
|
||||
|
||||
# Load processor
|
||||
self.processor = VibeVoiceASRProcessor.from_pretrained(
|
||||
model_path,
|
||||
language_model_pretrained_name="Qwen/Qwen2.5-7B"
|
||||
)
|
||||
|
||||
# Load model with specified attention implementation
|
||||
print(f"Using attention implementation: {attn_implementation}")
|
||||
self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
dtype=dtype,
|
||||
device_map=device if device == "auto" else None,
|
||||
attn_implementation=attn_implementation,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
if device != "auto":
|
||||
self.model = self.model.to(device)
|
||||
|
||||
self.device = device if device != "auto" else next(self.model.parameters()).device
|
||||
self.dtype = dtype
|
||||
self.model.eval()
|
||||
|
||||
print(f"Model loaded successfully on {self.device}")
|
||||
|
||||
def _prepare_generation_config(
|
||||
self,
|
||||
max_new_tokens: int = 512,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 0.9,
|
||||
do_sample: bool = True,
|
||||
num_beams: int = 1,
|
||||
) -> dict:
|
||||
"""Prepare generation configuration."""
|
||||
config = {
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"pad_token_id": self.processor.pad_id,
|
||||
"eos_token_id": self.processor.tokenizer.eos_token_id,
|
||||
}
|
||||
|
||||
# Beam search vs sampling
|
||||
if num_beams > 1:
|
||||
config["num_beams"] = num_beams
|
||||
config["do_sample"] = False # Beam search doesn't use sampling
|
||||
else:
|
||||
config["do_sample"] = do_sample
|
||||
# Only set temperature and top_p when sampling is enabled
|
||||
if do_sample:
|
||||
config["temperature"] = temperature
|
||||
config["top_p"] = top_p
|
||||
|
||||
return config
|
||||
|
||||
def transcribe_batch(
|
||||
self,
|
||||
audio_inputs: List,
|
||||
max_new_tokens: int = 512,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
do_sample: bool = True,
|
||||
num_beams: int = 1,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Transcribe multiple audio files/arrays in a single batch.
|
||||
|
||||
Args:
|
||||
audio_inputs: List of audio file paths or (array, sampling_rate) tuples
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
temperature: Temperature for sampling
|
||||
top_p: Top-p for nucleus sampling
|
||||
do_sample: Whether to use sampling
|
||||
|
||||
Returns:
|
||||
List of transcription results
|
||||
"""
|
||||
if len(audio_inputs) == 0:
|
||||
return []
|
||||
|
||||
batch_size = len(audio_inputs)
|
||||
print(f"\nProcessing batch of {batch_size} audio(s)...")
|
||||
|
||||
# Process all audio together
|
||||
inputs = self.processor(
|
||||
audio=audio_inputs,
|
||||
sampling_rate=None,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Move to device
|
||||
inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in inputs.items()}
|
||||
|
||||
# Print batch info
|
||||
print(f" Input IDs shape: {inputs['input_ids'].shape}")
|
||||
print(f" Speech tensors shape: {inputs['speech_tensors'].shape}")
|
||||
print(f" Attention mask shape: {inputs['attention_mask'].shape}")
|
||||
|
||||
# Generate
|
||||
generation_config = self._prepare_generation_config(
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
do_sample=do_sample,
|
||||
num_beams=num_beams,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids = self.model.generate(
|
||||
**inputs,
|
||||
**generation_config
|
||||
)
|
||||
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
# Decode outputs for each sample in the batch
|
||||
results = []
|
||||
input_length = inputs['input_ids'].shape[1]
|
||||
|
||||
for i, audio_input in enumerate(audio_inputs):
|
||||
# Get generated tokens for this sample (excluding input tokens)
|
||||
generated_ids = output_ids[i, input_length:]
|
||||
|
||||
# Remove padding tokens from the end
|
||||
# Find the first eos_token or pad_token
|
||||
eos_positions = (generated_ids == self.processor.tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
|
||||
if len(eos_positions) > 0:
|
||||
generated_ids = generated_ids[:eos_positions[0] + 1]
|
||||
|
||||
generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# Parse structured output
|
||||
try:
|
||||
transcription_segments = self.processor.post_process_transcription(generated_text)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse structured output: {e}")
|
||||
transcription_segments = []
|
||||
|
||||
# Get file name based on input type
|
||||
if isinstance(audio_input, str):
|
||||
file_name = audio_input
|
||||
elif isinstance(audio_input, dict) and 'id' in audio_input:
|
||||
file_name = audio_input['id']
|
||||
else:
|
||||
file_name = f"audio_{i}"
|
||||
|
||||
results.append({
|
||||
"file": file_name,
|
||||
"raw_text": generated_text,
|
||||
"segments": transcription_segments,
|
||||
"generation_time": generation_time / batch_size,
|
||||
})
|
||||
|
||||
print(f" Total generation time: {generation_time:.2f}s")
|
||||
print(f" Average time per sample: {generation_time/batch_size:.2f}s")
|
||||
|
||||
return results
|
||||
|
||||
def transcribe_with_batching(
|
||||
self,
|
||||
audio_inputs: List,
|
||||
batch_size: int = 4,
|
||||
max_new_tokens: int = 512,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
do_sample: bool = True,
|
||||
num_beams: int = 1,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Transcribe multiple audio files/arrays with automatic batching.
|
||||
|
||||
Args:
|
||||
audio_inputs: List of audio file paths or (array, sampling_rate) tuples
|
||||
batch_size: Number of samples per batch
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
temperature: Temperature for sampling
|
||||
top_p: Top-p for nucleus sampling
|
||||
do_sample: Whether to use sampling
|
||||
|
||||
Returns:
|
||||
List of transcription results
|
||||
"""
|
||||
all_results = []
|
||||
|
||||
# Process in batches
|
||||
for i in range(0, len(audio_inputs), batch_size):
|
||||
batch_inputs = audio_inputs[i:i + batch_size]
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Processing batch {i//batch_size + 1}/{(len(audio_inputs) + batch_size - 1)//batch_size}")
|
||||
|
||||
batch_results = self.transcribe_batch(
|
||||
batch_inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
do_sample=do_sample,
|
||||
num_beams=num_beams,
|
||||
)
|
||||
all_results.extend(batch_results)
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def print_result(result: Dict[str, Any]):
|
||||
"""Pretty print a single transcription result."""
|
||||
print(f"\nFile: {result['file']}")
|
||||
print(f"Generation Time: {result['generation_time']:.2f}s")
|
||||
print(f"\n--- Raw Output ---")
|
||||
print(result['raw_text'][:500] + "..." if len(result['raw_text']) > 500 else result['raw_text'])
|
||||
|
||||
if result['segments']:
|
||||
print(f"\n--- Structured Output ({len(result['segments'])} segments) ---")
|
||||
for seg in result['segments'][:50]: # Show first 50 segments
|
||||
print(f"[{seg.get('start_time', 'N/A')} - {seg.get('end_time', 'N/A')}] "
|
||||
f"Speaker {seg.get('speaker_id', 'N/A')}: {seg.get('text', '')}...")
|
||||
if len(result['segments']) > 50:
|
||||
print(f" ... and {len(result['segments']) - 50} more segments")
|
||||
|
||||
|
||||
def load_dataset_and_concatenate(
|
||||
dataset_name: str,
|
||||
split: str,
|
||||
max_duration: float,
|
||||
num_audios: int,
|
||||
target_sr: int = 24000
|
||||
) -> Optional[List[np.ndarray]]:
|
||||
"""
|
||||
Load a HuggingFace dataset and concatenate audio samples into long audio chunks.
|
||||
(Note, just for demo purpose, not for benchmark evaluation)
|
||||
|
||||
Args:
|
||||
dataset_name: HuggingFace dataset name (e.g., 'openslr/librispeech_asr')
|
||||
split: Dataset split to use (e.g., 'test', 'test.other')
|
||||
max_duration: Maximum duration in seconds for each concatenated audio
|
||||
num_audios: Number of concatenated audios to create
|
||||
target_sr: Target sample rate (default: 24000)
|
||||
|
||||
Returns:
|
||||
List of concatenated audio arrays, or None if loading fails
|
||||
"""
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
import torchcodec # just for decode audio in datasets
|
||||
except ImportError:
|
||||
print("Please install it with: pip install datasets torchcodec")
|
||||
return None
|
||||
|
||||
print(f"\nLoading dataset: {dataset_name} (split: {split})")
|
||||
print(f"Will create {num_audios} concatenated audio(s), each up to {max_duration:.1f}s ({max_duration/3600:.2f} hours)")
|
||||
|
||||
try:
|
||||
# Use streaming to avoid downloading the entire dataset
|
||||
dataset = load_dataset(dataset_name, split=split, streaming=True)
|
||||
print(f"Dataset loaded in streaming mode")
|
||||
|
||||
concatenated_audios = [] # List of concatenated audio metadata
|
||||
|
||||
# Create multiple concatenated audios based on num_audios
|
||||
current_chunks = []
|
||||
current_duration = 0.0
|
||||
current_samples_used = 0
|
||||
sample_idx = 0
|
||||
|
||||
for sample in dataset:
|
||||
if len(concatenated_audios) >= num_audios:
|
||||
break
|
||||
|
||||
if 'audio' not in sample:
|
||||
continue
|
||||
|
||||
audio_data = sample['audio']
|
||||
audio_array = audio_data['array']
|
||||
sr = audio_data['sampling_rate']
|
||||
|
||||
# Resample if needed
|
||||
if sr != target_sr:
|
||||
duration = len(audio_array) / sr
|
||||
new_length = int(duration * target_sr)
|
||||
audio_array = np.interp(
|
||||
np.linspace(0, len(audio_array) - 1, new_length),
|
||||
np.arange(len(audio_array)),
|
||||
audio_array
|
||||
)
|
||||
|
||||
chunk_duration = len(audio_array) / target_sr
|
||||
|
||||
# Check if adding this chunk exceeds max_duration
|
||||
if current_duration + chunk_duration > max_duration:
|
||||
remaining_duration = max_duration - current_duration
|
||||
if remaining_duration > 0.5: # Only add if > 0.5s remaining
|
||||
samples_to_take = int(remaining_duration * target_sr)
|
||||
current_chunks.append(audio_array[:samples_to_take])
|
||||
current_duration += remaining_duration
|
||||
current_samples_used += 1
|
||||
|
||||
# Save current concatenated audio and start a new one
|
||||
if current_chunks:
|
||||
concatenated_audios.append({
|
||||
'array': np.concatenate(current_chunks),
|
||||
'duration': current_duration,
|
||||
'samples_used': current_samples_used,
|
||||
})
|
||||
print(f" Created audio {len(concatenated_audios)}: {current_duration:.1f}s from {current_samples_used} samples")
|
||||
|
||||
# Reset for next concatenated audio
|
||||
current_chunks = []
|
||||
current_duration = 0.0
|
||||
current_samples_used = 0
|
||||
|
||||
if len(concatenated_audios) >= num_audios:
|
||||
break
|
||||
|
||||
current_chunks.append(audio_array)
|
||||
current_duration += chunk_duration
|
||||
current_samples_used += 1
|
||||
|
||||
sample_idx += 1
|
||||
if sample_idx % 100 == 0:
|
||||
print(f" Processed {sample_idx} samples...")
|
||||
|
||||
# Don't forget the last batch if it has content
|
||||
if current_chunks and len(concatenated_audios) < num_audios:
|
||||
concatenated_audios.append({
|
||||
'array': np.concatenate(current_chunks),
|
||||
'duration': current_duration,
|
||||
'samples_used': current_samples_used,
|
||||
})
|
||||
print(f" Created audio {len(concatenated_audios)}: {current_duration:.1f}s from {current_samples_used} samples")
|
||||
|
||||
if not concatenated_audios:
|
||||
print("Warning: No audio samples found in dataset")
|
||||
return None
|
||||
|
||||
# Extract arrays and print summary
|
||||
result = [a['array'] for a in concatenated_audios]
|
||||
total_duration = sum(a['duration'] for a in concatenated_audios)
|
||||
total_samples = sum(a['samples_used'] for a in concatenated_audios)
|
||||
print(f"\nCreated {len(result)} concatenated audio(s), total {total_duration:.1f}s ({total_duration/60:.1f} min) from {total_samples} samples")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading dataset: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="VibeVoice ASR Batch Inference Demo")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to the model checkpoint"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio_files",
|
||||
type=str,
|
||||
nargs='+',
|
||||
required=False,
|
||||
help="Paths to audio files for transcription"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio_dir",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Directory containing audio files for batch transcription"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
help="HuggingFace dataset name (e.g., 'openslr/librispeech_asr')"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split",
|
||||
type=str,
|
||||
default="test",
|
||||
help="Dataset split to use (e.g., 'test', 'test.other', 'test.clean')"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_duration",
|
||||
type=float,
|
||||
default=3600.0,
|
||||
help="Maximum duration in seconds for concatenated dataset audio (default: 3600 = 1 hour)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Batch size for processing multiple files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
choices=["cuda", "cpu", "auto"],
|
||||
help="Device to run inference on"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_new_tokens",
|
||||
type=int,
|
||||
default=32768,
|
||||
help="Maximum number of tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Temperature for sampling (0 = greedy decoding)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_p",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Top-p for nucleus sampling"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of beams for beam search. Use 1 for greedy/sampling"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attn_implementation",
|
||||
type=str,
|
||||
default="flash_attention_2",
|
||||
help="Attention implementation to use (default: flash_attention_2)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Collect audio files
|
||||
audio_files = []
|
||||
concatenated_audio = None # For storing concatenated dataset audio
|
||||
|
||||
if args.audio_files:
|
||||
audio_files.extend(args.audio_files)
|
||||
|
||||
if args.audio_dir:
|
||||
import glob
|
||||
for ext in ["*.wav", "*.mp3", "*.flac", "*.mp4", "*.m4a", "*.webm"]:
|
||||
audio_files.extend(glob.glob(os.path.join(args.audio_dir, ext)))
|
||||
|
||||
if args.dataset:
|
||||
concatenated_audio = load_dataset_and_concatenate(
|
||||
dataset_name=args.dataset,
|
||||
split=args.split,
|
||||
max_duration=args.max_duration,
|
||||
num_audios=args.batch_size,
|
||||
)
|
||||
if concatenated_audio is None:
|
||||
return
|
||||
|
||||
if len(audio_files) == 0 and concatenated_audio is None:
|
||||
print("No audio files provided. Please specify --audio_files, --audio_dir, or --dataset.")
|
||||
return
|
||||
|
||||
if audio_files:
|
||||
print(f"\nAudio files to process ({len(audio_files)}):")
|
||||
for f in audio_files:
|
||||
print(f" - {f}")
|
||||
|
||||
if concatenated_audio:
|
||||
print(f"\nConcatenated dataset audios: {len(concatenated_audio)} audio(s)")
|
||||
|
||||
# Initialize model
|
||||
asr = VibeVoiceASRBatchInference(
|
||||
model_path=args.model_path,
|
||||
device=args.device,
|
||||
dtype=torch.bfloat16 if args.device != "cpu" else torch.float32,
|
||||
attn_implementation=args.attn_implementation
|
||||
)
|
||||
|
||||
# If temperature is 0, use greedy decoding (no sampling)
|
||||
do_sample = args.temperature > 0
|
||||
|
||||
# Combine all audio inputs
|
||||
all_audio_inputs = audio_files + (concatenated_audio or [])
|
||||
|
||||
print("\n" + "="*80)
|
||||
print(f"Processing {len(all_audio_inputs)} audio(s)")
|
||||
print("="*80)
|
||||
|
||||
all_results = asr.transcribe_with_batching(
|
||||
all_audio_inputs,
|
||||
batch_size=args.batch_size,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
do_sample=do_sample,
|
||||
num_beams=args.num_beams,
|
||||
)
|
||||
|
||||
# Print results
|
||||
print("\n" + "="*80)
|
||||
print("Results")
|
||||
print("="*80)
|
||||
for result in all_results:
|
||||
print("\n" + "-"*60)
|
||||
print_result(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,62 @@
|
||||
# VibeVoice-ASR: Long-Form Rich Transcription with User Prompts
|
||||
|
||||
**VibeVoice-ASR** is the latest addition to the **VibeVoice** family. While the original VibeVoice / VibeVoice-Realtime focused on expressive TTS, **VibeVoice-ASR** focuses on understanding long-form speech with high precision and rich metadata.
|
||||
|
||||
It is a unified speech-to-text model designed to handle **1-hour long-form audio** in a single pass, generating structured transcriptions containing **Who (Speaker), When (Timestamps), and What (Content)**, with support for **User-Customized Context**.
|
||||
|
||||
## 🔥 Key Features
|
||||
|
||||
- **🕒 60-min Single-Pass Processing**:
|
||||
Unlike conventional ASR models that slice audio into short chunks (often losing global context), VibeVoice ASR accepts up to **60 minutes** of continuous audio input within 64K length. This ensures consistent speaker tracking and semantic coherence across the entire hour.
|
||||
|
||||
- **👤 Optional Context Injection**:
|
||||
Users can provide customized context (e.g., specific names, technical terms, or background info) to guide the recognition process, significantly improving accuracy on domain-specific content.
|
||||
|
||||
- **📝 Rich Transcription (Who, When, What)**:
|
||||
The model performs ASR, Diarization, and Timestamping simultaneously. The output is a structured sequence indicating *who* said *what* at *which time*.
|
||||
|
||||
## 🏗️ Model Architecture
|
||||
|
||||
<p align="center">
|
||||
<img src="../Figures/VibeVoice_ASR_archi.png" alt="VibeVoice ASR Architecture" width="80%">
|
||||
</p>
|
||||
|
||||
## Installation
|
||||
We recommend to use NVIDIA Deep Learning Container to manage the CUDA environment.
|
||||
|
||||
1. Launch docker
|
||||
```bash
|
||||
# NVIDIA PyTorch Container 24.07 ~ 25.12 verified.
|
||||
# Previous versions are also compatible.
|
||||
sudo docker run --privileged --net=host --ipc=host --ulimit memlock=-1:-1 --ulimit stack=-1:-1 --gpus all --rm -it nvcr.io/nvidia/pytorch:25.12-py3
|
||||
|
||||
## If flash attention is not included in your docker environment, you need to install it manually
|
||||
## Refer to https://github.com/Dao-AILab/flash-attention for installation instructions
|
||||
# pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
2. Install from github
|
||||
```bash
|
||||
git clone https://github.com/microsoft/VibeVoice.git
|
||||
cd VibeVoice
|
||||
pip install -e .[asr]
|
||||
```
|
||||
|
||||
## Usages
|
||||
|
||||
### Usage 1: Launch Gradio demo
|
||||
```bash
|
||||
apt update && apt install ffmpeg -y # for demo
|
||||
|
||||
python demo/vibevoice_asr_gradio_demo.py --model_path microsoft/VibeVoice-ASR --share
|
||||
```
|
||||
|
||||
### Usage 2: Inference from files directly
|
||||
```bash
|
||||
python demo/vibevoice_asr_inference_from_file.py --model_path microsoft/VibeVoice-ASR --audio_files [add a audio path here]
|
||||
```
|
||||
|
||||
|
||||
## 📄 License
|
||||
|
||||
This project is licensed under the [MIT License](../LICENSE).
|
||||
+17
-9
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "vibevoice"
|
||||
version = "0.0.1"
|
||||
version = "1.0.0"
|
||||
authors = [
|
||||
{ name="vibevoice team", email="vibepod@microsoft.com" },
|
||||
{ name="vibevoice team", email="VibeVoice@microsoft.com" },
|
||||
]
|
||||
description = "A model for speech generation with an AR + diffusion architecture."
|
||||
description = "Open-Source Frontier Voice AI."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
classifiers = [
|
||||
@@ -18,8 +18,7 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"torch",
|
||||
"accelerate==1.6.0",
|
||||
"transformers==4.51.3", # we develop this project on transformers==4.51.3, later version may not be compatible
|
||||
"accelerate",
|
||||
"llvmlite>=0.40.0",
|
||||
"numba>=0.57.0",
|
||||
"diffusers",
|
||||
@@ -30,12 +29,21 @@ dependencies = [
|
||||
"ml-collections",
|
||||
"absl-py",
|
||||
"gradio",
|
||||
"av",
|
||||
"aiortc",
|
||||
"uvicorn[standard]",
|
||||
"fastapi"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
tts = [
|
||||
"transformers==4.51.3", # we develop this project on transformers==4.51.3, later version may not be compatible
|
||||
"av",
|
||||
"aiortc",
|
||||
"uvicorn[standard]",
|
||||
"fastapi"
|
||||
]
|
||||
|
||||
asr = [
|
||||
"transformers>=4.51.3", # the versions after 4.51.3 are all support
|
||||
"pydub" # for visualization
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://github.com/microsoft/VibeVoice"
|
||||
|
||||
@@ -240,9 +240,112 @@ class VibeVoiceConfig(PretrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
class VibeVoiceASRConfig(PretrainedConfig):
|
||||
model_type = "vibevoice"
|
||||
is_composition = True
|
||||
sub_configs = {
|
||||
"acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
|
||||
"semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
|
||||
"decoder_config": Qwen2Config,
|
||||
}
|
||||
# keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `Qwen2`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
acoustic_tokenizer_config=None,
|
||||
semantic_tokenizer_config=None,
|
||||
decoder_config=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
# kwargs["_attn_implementation"] = "flash_attention_2"
|
||||
kwargs["_attn_implementation_autoset"] = False
|
||||
|
||||
if acoustic_tokenizer_config is None:
|
||||
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
|
||||
elif isinstance(acoustic_tokenizer_config, dict):
|
||||
acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
|
||||
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
|
||||
elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
|
||||
# If an instance of the config class is provided
|
||||
self.acoustic_tokenizer_config = acoustic_tokenizer_config
|
||||
|
||||
if semantic_tokenizer_config is None:
|
||||
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
|
||||
elif isinstance(semantic_tokenizer_config, dict):
|
||||
semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
|
||||
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
|
||||
elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
|
||||
# If an instance of the config class is provided
|
||||
self.semantic_tokenizer_config = semantic_tokenizer_config
|
||||
|
||||
if decoder_config is None:
|
||||
self.decoder_config = self.sub_configs["decoder_config"]()
|
||||
elif isinstance(decoder_config, dict):
|
||||
# If a dictionary is provided, instantiate the config class with it
|
||||
# self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
|
||||
if decoder_config.get("model_type", '') == "qwen2":
|
||||
self.decoder_config = Qwen2Config(**decoder_config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
|
||||
elif isinstance(decoder_config, Qwen2Config):
|
||||
# If an instance of the config class is provided
|
||||
self.decoder_config = decoder_config
|
||||
|
||||
# other parameters
|
||||
self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
|
||||
self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_text_config(self, decoder: bool = False):
|
||||
"""Return the text (decoder) config for generation."""
|
||||
return self.decoder_config
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
"""Return vocab_size from decoder config for generation compatibility."""
|
||||
return self.decoder_config.vocab_size
|
||||
|
||||
@property
|
||||
def num_attention_heads(self):
|
||||
"""Return num_attention_heads from decoder config for Ulysses SP compatibility."""
|
||||
return self.decoder_config.num_attention_heads
|
||||
|
||||
@property
|
||||
def num_key_value_heads(self):
|
||||
"""Return num_key_value_heads from decoder config for Ulysses SP compatibility."""
|
||||
return self.decoder_config.num_key_value_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
"""Return hidden_size from decoder config for model compatibility."""
|
||||
return self.decoder_config.hidden_size
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
"""Return num_hidden_layers from decoder config for Ulysses SP compatibility."""
|
||||
return self.decoder_config.num_hidden_layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
"""Return head_dim from decoder config for Ulysses SP compatibility."""
|
||||
return getattr(self.decoder_config, 'head_dim', self.hidden_size // self.num_attention_heads)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceAcousticTokenizerConfig",
|
||||
"VibeVoiceSemanticTokenizerConfig",
|
||||
"VibeVoiceDiffusionHeadConfig",
|
||||
"VibeVoiceConfig"
|
||||
"VibeVoiceConfig",
|
||||
"VibeVoiceASRConfig"
|
||||
]
|
||||
@@ -0,0 +1,496 @@
|
||||
# copied from https://github.com/vibevoice-community/VibeVoice/blob/main/vibevoice/modular/modeling_vibevoice.py
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union, Callable
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
from transformers import modeling_utils
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
|
||||
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
|
||||
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
|
||||
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
||||
|
||||
@dataclass
|
||||
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
diffusion_loss: Optional[torch.FloatTensor] = None
|
||||
speech_token_num: Optional[int] = None
|
||||
logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VibeVoiceGenerationOutput(ModelOutput):
|
||||
"""
|
||||
Output type for VibeVoice generation.
|
||||
|
||||
Args:
|
||||
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The generated sequences.
|
||||
speech_outputs (`List[torch.FloatTensor]`, *optional*):
|
||||
List of generated speech waveforms or latents for each speech segment.
|
||||
"""
|
||||
sequences: torch.LongTensor = None
|
||||
speech_outputs: Optional[List[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class SpeechConnector(nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(input_dim, output_dim)
|
||||
self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
|
||||
self.fc2 = nn.Linear(output_dim, output_dim)
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = self.fc1(features)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
# @auto_docstring
|
||||
class VibeVoicePreTrainedModel(PreTrainedModel):
|
||||
config_class = VibeVoiceConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, VibeVoiceDiffusionHead):
|
||||
module.initialize_weights()
|
||||
return
|
||||
|
||||
# Use the language model's initializer_range if available
|
||||
if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
|
||||
std = self.config.language_model_config.initializer_range
|
||||
elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
|
||||
std = self.config.decoder_config.initializer_range
|
||||
else:
|
||||
std = 0.02 # Default value
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
# @auto_docstring
|
||||
class VibeVoiceModel(VibeVoicePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
|
||||
if isinstance(config.torch_dtype, str):
|
||||
dtype = getattr(torch, config.torch_dtype)
|
||||
else:
|
||||
dtype = config.torch_dtype
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
# Initialize Qwen2 model for language modeling
|
||||
lm_config = config.decoder_config
|
||||
self.language_model = AutoModel.from_config(lm_config)
|
||||
|
||||
# Initialize speech components if needed
|
||||
self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
|
||||
self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
|
||||
|
||||
self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||
self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||
|
||||
# Register scaling factors as buffers - use 1D tensors for FSDP compatibility
|
||||
self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
|
||||
self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
|
||||
|
||||
# Initialize prediction head for speech generation
|
||||
self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
|
||||
|
||||
# Initialize noise scheduler
|
||||
self.noise_scheduler = DPMSolverMultistepScheduler(
|
||||
num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
|
||||
beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
|
||||
prediction_type=config.diffusion_head_config.prediction_type
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
if hasattr(self.language_model, 'embed_tokens'):
|
||||
# If the language model has an embed_tokens attribute, return it
|
||||
return self.language_model.embed_tokens
|
||||
|
||||
for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
|
||||
if attr.orig_name == 'embed_tokens.weight':
|
||||
return getattr(self.language_model, name)
|
||||
assert False, 'should not arrive here'
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.embed_tokens = value
|
||||
|
||||
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
|
||||
"""Set the speech tokenizers used for encoding and decoding speech."""
|
||||
self.acoustic_tokenizer = acoustic_tokenizer
|
||||
self.semantic_tokenizer = semantic_tokenizer
|
||||
|
||||
# Reset the encoder to evaluation mode
|
||||
if self.acoustic_tokenizer is not None:
|
||||
self.acoustic_tokenizer.eval()
|
||||
|
||||
if self.semantic_tokenizer is not None:
|
||||
self.semantic_tokenizer.eval()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Forward through language model
|
||||
outputs = self.language_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return outputs
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = VibeVoiceModel(config)
|
||||
self.vocab_size = config.decoder_config.vocab_size
|
||||
self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.language_model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.language_model
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings and the output embeddings.
|
||||
"""
|
||||
if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
|
||||
# The standard PreTrainedModel method will handle the tying.
|
||||
# It typically does a simple parameter object assignment, which is
|
||||
# CORRECT to do BEFORE FSDP wraps the model.
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
input_embeddings = self.get_input_embeddings()
|
||||
if hasattr(input_embeddings, 'weight'):
|
||||
output_embeddings.weight = input_embeddings.weight
|
||||
else:
|
||||
# maybe returned input_embeddings a tensor directly
|
||||
output_embeddings.weight = input_embeddings
|
||||
|
||||
if getattr(output_embeddings, "bias", None) is not None:
|
||||
output_embeddings.bias.data = nn.functional.pad(
|
||||
output_embeddings.bias.data,
|
||||
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
|
||||
"constant",
|
||||
0,
|
||||
)
|
||||
print("Tied input and output embeddings using standard assignment.")
|
||||
else:
|
||||
print("tie_word_embeddings is False, not tying weights.")
|
||||
|
||||
# Also, ensure set_output_embeddings is safe, though your implementation looks okay.
|
||||
# The key is to avoid calling it after accelerator.prepare().
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
# Your current implementation using data.copy_ is good practice,
|
||||
# but the best way is to not call this after prepare().
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def forward_speech_features(
|
||||
self,
|
||||
speech_tensors=None,
|
||||
speech_masks=None,
|
||||
speech_type="audio",
|
||||
return_unmask=False
|
||||
):
|
||||
if speech_tensors is None:
|
||||
# Use config to get vae_dim instead of non-existent self.args
|
||||
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
|
||||
audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
|
||||
connect_features = self.model.acoustic_connector(audio_features)
|
||||
return audio_features, connect_features
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if speech_type == "audio":
|
||||
with torch.no_grad():
|
||||
frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
|
||||
audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
|
||||
|
||||
elif speech_type == "vae":
|
||||
# Use config to get vae_dim instead of non-existent self.args
|
||||
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
|
||||
speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
|
||||
|
||||
# gaussian sample from the speech_mode
|
||||
batch_size = speech_mode.size(0)
|
||||
value = self.model.acoustic_tokenizer.fix_std / 0.8
|
||||
std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
|
||||
std = std.view(-1, *[1] * (speech_mode.dim() - 1))
|
||||
audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
|
||||
else:
|
||||
raise NotImplementedError(f"Speech type {speech_type} not implemented")
|
||||
|
||||
if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
|
||||
scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
|
||||
bias_factor = -audio_tokens[speech_masks].flatten().mean()
|
||||
|
||||
# Only use distributed operations if the process group is initialized
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
|
||||
dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
|
||||
world_size = dist.get_world_size()
|
||||
self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
|
||||
self.model.speech_bias_factor.copy_(bias_factor / world_size)
|
||||
print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
|
||||
else:
|
||||
# Single process case
|
||||
self.model.speech_scaling_factor.copy_(scaling_factor)
|
||||
self.model.speech_bias_factor.copy_(bias_factor)
|
||||
print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
|
||||
|
||||
audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
|
||||
|
||||
connect_features = self.model.acoustic_connector(audio_features)
|
||||
if return_unmask:
|
||||
return audio_features, connect_features
|
||||
return audio_features[speech_masks], connect_features[speech_masks]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
# New arguments for speech processing and loss calculation
|
||||
speech_tensors: Optional[torch.FloatTensor] = None,
|
||||
speech_masks: Optional[torch.BoolTensor] = None,
|
||||
speeches_loss_input: Optional[torch.FloatTensor] = None,
|
||||
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
|
||||
acoustic_input_mask: Optional[torch.BoolTensor] = None,
|
||||
acoustic_loss_mask: Optional[torch.BoolTensor] = None,
|
||||
ddpm_batch_mul: int = 1,
|
||||
**kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
|
||||
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
x = self.get_input_embeddings()(input_ids)
|
||||
|
||||
semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
|
||||
if speeches_loss_input is not None:
|
||||
# only part audio need diffuse
|
||||
speech_all_features, speech_all_connect_features = self.forward_speech_features(
|
||||
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
|
||||
speech_masks=speech_masks,
|
||||
speech_type=kwargs.get("speech_type", "audio"),
|
||||
return_unmask=True
|
||||
)
|
||||
if speech_tensors is not None:
|
||||
if semantic_speech_all_connect_features is not None:
|
||||
x[acoustic_input_mask] = (
|
||||
speech_all_connect_features[speech_masks]
|
||||
+ semantic_speech_all_connect_features[speech_masks]
|
||||
)
|
||||
else:
|
||||
x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
|
||||
|
||||
# Select only the target segments' latents for diffusion loss.
|
||||
# Both masks are [num_segments, max_latent_len]; using 2D mask on [B,T,D] selects [N_true, D].
|
||||
target_latent_mask = speeches_loss_input & speech_masks
|
||||
speech_features = speech_all_features[target_latent_mask]
|
||||
speech_connect_features = speech_all_connect_features[target_latent_mask]
|
||||
else:
|
||||
speech_features, speech_connect_features = self.forward_speech_features(
|
||||
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
|
||||
speech_masks=speech_masks,
|
||||
speech_type=kwargs.get("speech_type", "audio"),
|
||||
)
|
||||
if speech_tensors is not None:
|
||||
x[acoustic_input_mask] = speech_connect_features
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=None,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=x,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=False,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
logits = self.lm_head(hidden_states)
|
||||
# logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# The custom CE loss with masking is calculated in the training script.
|
||||
# We leave the standard loss calculation here as None.
|
||||
pass
|
||||
|
||||
# --- Diffusion Loss Calculation ---
|
||||
diffusion_loss = None
|
||||
# This block is executed only if we are in a context that involves speech.
|
||||
if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
|
||||
condition_features = hidden_states[acoustic_loss_mask]
|
||||
|
||||
speech_len, latent_size = speech_features.shape
|
||||
|
||||
noise = torch.randn(
|
||||
(speech_len * ddpm_batch_mul, latent_size),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
timesteps = torch.multinomial(
|
||||
torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
|
||||
speech_len * ddpm_batch_mul,
|
||||
replacement=True,
|
||||
).to(hidden_states.device)
|
||||
|
||||
speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
|
||||
condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
|
||||
|
||||
noisy_speech_features = self.model.noise_scheduler.add_noise(
|
||||
speech_features_repeated, noise, timesteps
|
||||
)
|
||||
|
||||
model_output = self.model.prediction_head(
|
||||
noisy_speech_features,
|
||||
timesteps.type_as(x),
|
||||
condition_features_repeated
|
||||
)
|
||||
|
||||
prediction_type = self.config.diffusion_head_config.prediction_type
|
||||
if prediction_type == "epsilon":
|
||||
target_for_loss = noise
|
||||
elif prediction_type == "v_prediction":
|
||||
target_for_loss = self.model.noise_scheduler.get_velocity(
|
||||
speech_features_repeated, noise, timesteps
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
|
||||
|
||||
diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
|
||||
if latent_size > 0 and ddpm_batch_mul > 0:
|
||||
diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
|
||||
else:
|
||||
diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
|
||||
|
||||
else:
|
||||
# Dummy loss for DDP to work when there are no speech samples in a batch,
|
||||
# but we are in a speech context.
|
||||
diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
|
||||
diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
|
||||
diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
|
||||
# --- End Diffusion Loss Calculation ---
|
||||
|
||||
if not return_dict:
|
||||
output = (logits, speech_len) + outputs.to_tuple()[1:]
|
||||
return (loss, diffusion_loss) + output
|
||||
|
||||
return VibeVoiceCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
diffusion_loss=diffusion_loss,
|
||||
speech_token_num=speech_len if speech_tensors is not None else 0,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
|
||||
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceModel",
|
||||
"VibeVoicePreTrainedModel",
|
||||
"VibeVoiceForConditionalGeneration",
|
||||
"VibeVoiceCausalLMOutputWithPast",
|
||||
"VibeVoiceGenerationOutput",
|
||||
]
|
||||
@@ -0,0 +1,520 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||
|
||||
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast
|
||||
from transformers import modeling_utils
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
from transformers.generation import GenerationMixin
|
||||
|
||||
from .modular_vibevoice_tokenizer import (
|
||||
VibeVoiceTokenizerStreamingCache,
|
||||
VibeVoiceTokenizerEncoderOutput
|
||||
)
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceASRConfig
|
||||
from .modeling_vibevoice import (
|
||||
VibeVoiceCausalLMOutputWithPast,
|
||||
SpeechConnector
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
|
||||
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
||||
|
||||
# @auto_docstring
|
||||
class VibeVoiceASRPreTrainedModel(PreTrainedModel):
|
||||
config_class = VibeVoiceASRConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
||||
# Use the language model's initializer_range if available
|
||||
if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
|
||||
std = self.config.language_model_config.initializer_range
|
||||
elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
|
||||
std = self.config.decoder_config.initializer_range
|
||||
else:
|
||||
std = 0.02 # Default value
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
# @auto_docstring
|
||||
class VibeVoiceASRModel(VibeVoiceASRPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
|
||||
if isinstance(config.torch_dtype, str):
|
||||
dtype = getattr(torch, config.torch_dtype)
|
||||
else:
|
||||
dtype = config.torch_dtype
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
# Initialize Qwen2 model for language modeling
|
||||
lm_config = config.decoder_config
|
||||
self.language_model = AutoModel.from_config(lm_config)
|
||||
|
||||
# Initialize speech components if needed
|
||||
self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
|
||||
self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
|
||||
|
||||
self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||
self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
if hasattr(self.language_model, 'embed_tokens'):
|
||||
# If the language model has an embed_tokens attribute, return it
|
||||
return self.language_model.embed_tokens
|
||||
|
||||
for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
|
||||
if attr.orig_name == 'embed_tokens.weight':
|
||||
return getattr(self.language_model, name)
|
||||
assert False, 'should not arrive here'
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.embed_tokens = value
|
||||
|
||||
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
|
||||
"""Set the speech tokenizers used for encoding and decoding speech."""
|
||||
self.acoustic_tokenizer = acoustic_tokenizer
|
||||
self.semantic_tokenizer = semantic_tokenizer
|
||||
|
||||
# Reset the encoder to evaluation mode
|
||||
if self.acoustic_tokenizer is not None:
|
||||
self.acoustic_tokenizer.eval()
|
||||
|
||||
if self.semantic_tokenizer is not None:
|
||||
self.semantic_tokenizer.eval()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Forward through language model
|
||||
outputs = self.language_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return outputs
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
class VibeVoiceASRForConditionalGeneration(VibeVoiceASRPreTrainedModel, GenerationMixin):
|
||||
"""
|
||||
VibeVoice model for Automatic Speech Recognition (ASR) with language modeling head for conditional generation.
|
||||
This class is designed for inference and generation tasks.
|
||||
"""
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = VibeVoiceASRModel(config)
|
||||
self.vocab_size = config.decoder_config.vocab_size
|
||||
|
||||
# Determine the dtype to use
|
||||
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
|
||||
if isinstance(config.torch_dtype, str):
|
||||
dtype = getattr(torch, config.torch_dtype)
|
||||
else:
|
||||
dtype = config.torch_dtype
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
# Initialize lm_head with the correct dtype
|
||||
self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False).to(dtype)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.language_model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.language_model
|
||||
|
||||
def tie_weights(self):
|
||||
"""Tie the weights between the input embeddings and the output embeddings."""
|
||||
if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
input_embeddings = self.get_input_embeddings()
|
||||
if hasattr(input_embeddings, 'weight'):
|
||||
output_embeddings.weight = input_embeddings.weight
|
||||
else:
|
||||
output_embeddings.weight = input_embeddings
|
||||
|
||||
def encode_speech(
|
||||
self,
|
||||
speech_tensors: torch.FloatTensor,
|
||||
speech_masks: Optional[torch.BoolTensor] = None,
|
||||
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
|
||||
streaming_segment_duration: float = 60.0, # seconds
|
||||
):
|
||||
"""
|
||||
Encode speech input into features that can be used by the language model.
|
||||
This method is called once before generation to process the speech input.
|
||||
|
||||
For long audio (>600s by default), uses streaming processing to avoid conv overflow (>2^32).
|
||||
Segments are processed independently, then concatenated before final sampling.
|
||||
|
||||
Args:
|
||||
speech_tensors: Input audio tensor [batch_size, samples]
|
||||
speech_masks: Optional mask for speech features
|
||||
speech_semantic_tensors: Optional pre-computed semantic tokens
|
||||
streaming_segment_duration: Segment duration in seconds for streaming processing (default: 60s)
|
||||
"""
|
||||
if hasattr(self.config, 'torch_dtype') and self.config.torch_dtype is not None:
|
||||
if isinstance(self.config.torch_dtype, str):
|
||||
dtype = getattr(torch, self.config.torch_dtype)
|
||||
else:
|
||||
dtype = self.config.torch_dtype
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
speech_tensors = speech_tensors.to(dtype)
|
||||
|
||||
# Ensure proper shape: (batch, samples)
|
||||
if speech_tensors.ndim == 1:
|
||||
speech_tensors = speech_tensors.unsqueeze(0)
|
||||
|
||||
batch_size, total_samples = speech_tensors.shape
|
||||
sample_rate = 24000 # fix 24kHz sample rate
|
||||
|
||||
# Calculate segment size in samples
|
||||
segment_samples = int(streaming_segment_duration * sample_rate)
|
||||
|
||||
# Decide whether to use streaming based on audio length
|
||||
use_streaming = total_samples > segment_samples
|
||||
|
||||
with torch.no_grad():
|
||||
if not use_streaming:
|
||||
# Short audio: direct processing (original behavior)
|
||||
encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
|
||||
audio_tokens = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
||||
acoustic_features = self.model.acoustic_connector(audio_tokens)
|
||||
|
||||
# Encode semantic features
|
||||
if speech_semantic_tensors is not None:
|
||||
semantic_features = self.model.semantic_connector(speech_semantic_tensors)
|
||||
else:
|
||||
semantic_tokens = self.model.semantic_tokenizer.encode(speech_tensors.unsqueeze(1)).mean
|
||||
semantic_features = self.model.semantic_connector(semantic_tokens)
|
||||
else:
|
||||
# Long audio: streaming processing
|
||||
# print(f"Using streaming processing for long audio: {total_samples/sample_rate:.1f}s "
|
||||
# f"(segment size: {streaming_segment_duration}s)")
|
||||
|
||||
# Initialize caches for both tokenizers
|
||||
acoustic_encoder_cache = VibeVoiceTokenizerStreamingCache()
|
||||
semantic_encoder_cache = VibeVoiceTokenizerStreamingCache()
|
||||
acoustic_mean_segments = []
|
||||
semantic_mean_segments = []
|
||||
sample_indices = torch.arange(batch_size, device=speech_tensors.device)
|
||||
|
||||
# Helper function from batch_asr_sft_cache.py
|
||||
def _iter_segments(total_length: int, segment_length: int):
|
||||
"""Iterate over audio segments with a given segment length."""
|
||||
if segment_length <= 0:
|
||||
raise ValueError("segment_length must be positive")
|
||||
for start in range(0, total_length, segment_length):
|
||||
end = min(start + segment_length, total_length)
|
||||
if end > start:
|
||||
yield start, end
|
||||
|
||||
# Process each segment for both acoustic and semantic tokenizers
|
||||
segments = list(_iter_segments(total_samples, segment_samples))
|
||||
num_segments = len(segments)
|
||||
for seg_idx, (start, end) in enumerate(segments):
|
||||
chunk = speech_tensors[:, start:end].contiguous()
|
||||
if chunk.numel() == 0:
|
||||
continue
|
||||
|
||||
# Check if this is the final segment
|
||||
is_final = (seg_idx == num_segments - 1)
|
||||
|
||||
# Encode chunk for acoustic tokenizer (don't sample yet)
|
||||
acoustic_encoder_output = self.model.acoustic_tokenizer.encode(
|
||||
chunk.unsqueeze(1),
|
||||
cache=acoustic_encoder_cache,
|
||||
sample_indices=sample_indices,
|
||||
use_cache=True,
|
||||
is_final_chunk=is_final,
|
||||
)
|
||||
acoustic_mean_segments.append(acoustic_encoder_output.mean)
|
||||
|
||||
# Encode chunk for semantic tokenizer (take mean directly)
|
||||
semantic_encoder_output = self.model.semantic_tokenizer.encode(
|
||||
chunk.unsqueeze(1),
|
||||
cache=semantic_encoder_cache,
|
||||
sample_indices=sample_indices,
|
||||
use_cache=True,
|
||||
is_final_chunk=is_final,
|
||||
)
|
||||
semantic_mean_segments.append(semantic_encoder_output.mean)
|
||||
|
||||
# print(f"Processed {len(acoustic_mean_segments)} segments.")
|
||||
# Concatenate all acoustic means and sample once
|
||||
acoustic_mean_full = torch.cat(acoustic_mean_segments, dim=1).contiguous()
|
||||
acoustic_encoder_output = VibeVoiceTokenizerEncoderOutput(
|
||||
mean=acoustic_mean_full,
|
||||
std=self.model.acoustic_tokenizer.fix_std
|
||||
)
|
||||
audio_tokens = acoustic_encoder_output.sample(
|
||||
dist_type=self.model.acoustic_tokenizer.std_dist_type
|
||||
)[0]
|
||||
acoustic_features = self.model.acoustic_connector(audio_tokens)
|
||||
|
||||
# Concatenate all semantic means
|
||||
semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous()
|
||||
semantic_features = self.model.semantic_connector(semantic_tokens)
|
||||
|
||||
# Combine acoustic and semantic features
|
||||
if speech_masks is not None:
|
||||
combined_features = acoustic_features[speech_masks] + semantic_features[speech_masks]
|
||||
else:
|
||||
combined_features = acoustic_features + semantic_features
|
||||
|
||||
return combined_features
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
# Speech-specific arguments
|
||||
speech_tensors: Optional[torch.FloatTensor] = None,
|
||||
speech_masks: Optional[torch.BoolTensor] = None,
|
||||
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
|
||||
acoustic_input_mask: Optional[torch.BoolTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutput]:
|
||||
"""
|
||||
Forward pass for the model. Handles both training and generation scenarios.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
# Process inputs
|
||||
if inputs_embeds is None and input_ids is not None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# If we have speech input and acoustic_input_mask, encode and insert speech features
|
||||
if speech_tensors is not None and acoustic_input_mask is not None:
|
||||
speech_features = self.encode_speech(
|
||||
speech_tensors=speech_tensors,
|
||||
speech_masks=speech_masks,
|
||||
speech_semantic_tensors=speech_semantic_tensors,
|
||||
)
|
||||
inputs_embeds[acoustic_input_mask] = speech_features
|
||||
|
||||
# Forward through the model
|
||||
outputs = self.model(
|
||||
input_ids=None,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return VibeVoiceCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
speech_tensors=None,
|
||||
speech_masks=None,
|
||||
speech_semantic_tensors=None,
|
||||
acoustic_input_mask=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Prepare inputs for generation step. This method is called by generate()
|
||||
for each token generation step.
|
||||
|
||||
Following Qwen2-VL's approach: speech inputs are only forwarded on the first pass
|
||||
(when cache_position[0] == 0), and are excluded in subsequent generation steps.
|
||||
"""
|
||||
# If we have past key values, we only need to process the new tokens
|
||||
if past_key_values is not None:
|
||||
if isinstance(past_key_values, tuple):
|
||||
past_length = past_key_values[0][0].shape[2]
|
||||
else:
|
||||
past_length = past_key_values.get_seq_length()
|
||||
|
||||
# Keep only the new tokens
|
||||
if input_ids is not None and input_ids.shape[1] > past_length:
|
||||
input_ids = input_ids[:, past_length:]
|
||||
|
||||
# Prepare position ids
|
||||
if position_ids is None and attention_mask is not None:
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values is not None and input_ids is not None:
|
||||
position_ids = position_ids[:, -input_ids.shape[1]:]
|
||||
|
||||
# Prepare cache position
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]),
|
||||
device=input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
)
|
||||
|
||||
# Prepare model inputs
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
|
||||
# Following Qwen2-VL pattern: only include speech inputs on the first forward pass
|
||||
# (when cache_position[0] == 0), exclude them in subsequent generation steps
|
||||
if cache_position is not None and len(cache_position) > 0 and cache_position[0] == 0:
|
||||
# First forward pass - include speech inputs if provided
|
||||
model_inputs.update({
|
||||
"speech_tensors": speech_tensors,
|
||||
"speech_masks": speech_masks,
|
||||
"speech_semantic_tensors": speech_semantic_tensors,
|
||||
"acoustic_input_mask": acoustic_input_mask,
|
||||
})
|
||||
else:
|
||||
# Subsequent generation steps - exclude speech inputs
|
||||
model_inputs.update({
|
||||
"speech_tensors": None,
|
||||
"speech_masks": None,
|
||||
"speech_semantic_tensors": None,
|
||||
"acoustic_input_mask": None,
|
||||
})
|
||||
|
||||
# Include any remaining kwargs that might be needed
|
||||
model_inputs.update(kwargs)
|
||||
|
||||
return model_inputs
|
||||
|
||||
AutoModel.register(VibeVoiceASRConfig, VibeVoiceASRModel)
|
||||
AutoModelForCausalLM.register(VibeVoiceASRConfig, VibeVoiceASRForConditionalGeneration)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceASRPreTrainedModel",
|
||||
"VibeVoiceASRModel",
|
||||
"VibeVoiceASRForConditionalGeneration",
|
||||
]
|
||||
@@ -207,8 +207,107 @@ class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
|
||||
"""Id used for padding (returns -100 for loss masking)."""
|
||||
return self._pad_id
|
||||
|
||||
class VibeVoiceASRTextTokenizerFast(Qwen2TokenizerFast):
|
||||
"""
|
||||
Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
|
||||
Based on the Qwen2 tokenizer with additional special tokens for speech.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`, *optional*):
|
||||
Path to the vocabulary file.
|
||||
merges_file (`str`, *optional*):
|
||||
Path to the merges file.
|
||||
tokenizer_file (`str`, *optional*):
|
||||
Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
|
||||
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The unknown token.
|
||||
bos_token (`str`, *optional*):
|
||||
The beginning of sequence token. Not used for vibevoice.
|
||||
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The end of sequence token.
|
||||
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||
The token used for padding.
|
||||
"""
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file=None,
|
||||
merges_file=None,
|
||||
tokenizer_file=None,
|
||||
unk_token="<|endoftext|>",
|
||||
bos_token=None,
|
||||
eos_token="<|endoftext|>",
|
||||
pad_token="<|endoftext|>",
|
||||
add_prefix_space=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
vocab_file=vocab_file,
|
||||
merges_file=merges_file,
|
||||
tokenizer_file=tokenizer_file,
|
||||
unk_token=unk_token,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
pad_token=pad_token,
|
||||
add_prefix_space=add_prefix_space,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Add VibeVoice-specific special tokens
|
||||
self._add_vibevoice_special_tokens()
|
||||
|
||||
# https://github.com/QwenLM/Qwen2.5-VL/blob/d2240f11656bfe404b9ba56db4e51cd09f522ff1/qwen-vl-finetune/qwenvl/data/data_qwen_packed.py#L57C5-L57C222
|
||||
self.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
|
||||
def _add_vibevoice_special_tokens(self):
|
||||
"""Add VibeVoice-specific special tokens."""
|
||||
special_tokens = {
|
||||
"additional_special_tokens": [
|
||||
"<|object_ref_start|>", # Speech start (reusing vision tokens)
|
||||
"<|object_ref_end|>", # Speech end
|
||||
"<|box_start|>", # Speech diffusion pad
|
||||
]
|
||||
}
|
||||
num_added = self.add_special_tokens(special_tokens)
|
||||
|
||||
# Cache special token IDs
|
||||
self._speech_start_id = self.convert_tokens_to_ids("<|object_ref_start|>")
|
||||
self._speech_end_id = self.convert_tokens_to_ids("<|object_ref_end|>")
|
||||
self._speech_pad_id = self.convert_tokens_to_ids("<|box_start|>")
|
||||
|
||||
self._eos_id = self.eos_token_id # qwen2 / qwen3
|
||||
self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
|
||||
|
||||
return num_added
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
"""Id of the end of sequence token."""
|
||||
return self._eos_id
|
||||
|
||||
@property
|
||||
def speech_start_id(self) -> int:
|
||||
"""Id of the speech start token."""
|
||||
return self._speech_start_id
|
||||
|
||||
@property
|
||||
def speech_end_id(self) -> int:
|
||||
"""Id of the speech end token."""
|
||||
return self._speech_end_id
|
||||
|
||||
@property
|
||||
def speech_pad_id(self) -> int:
|
||||
"""Id of the speech diffusion token."""
|
||||
return self._speech_pad_id
|
||||
|
||||
@property
|
||||
def pad_id(self) -> int:
|
||||
return self._pad_id
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceTextTokenizer",
|
||||
"VibeVoiceTextTokenizerFast",
|
||||
"VibeVoiceASRTextTokenizerFast",
|
||||
]
|
||||
@@ -17,7 +17,7 @@ from transformers.utils import logging
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig
|
||||
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -26,14 +26,13 @@ import os
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
|
||||
APEX_AVAILABLE = True
|
||||
logger.info("APEX FusedRMSNorm is available and will be used for optimization")
|
||||
# logger.info("APEX FusedRMSNorm is available and will be used for optimization")
|
||||
if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
|
||||
APEX_AVAILABLE = False
|
||||
logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
|
||||
# logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
|
||||
except ImportError:
|
||||
APEX_AVAILABLE = False
|
||||
logger.warning("APEX FusedRMSNorm not available, using native implementation")
|
||||
# APEX_AVAILABLE=False
|
||||
# logger.warning("APEX FusedRMSNorm not available, using native implementation")
|
||||
|
||||
# Normalization modules
|
||||
class ConvLayerNorm(nn.LayerNorm):
|
||||
@@ -297,7 +296,8 @@ class SConv1d(nn.Module):
|
||||
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
|
||||
sample_indices: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
debug: bool = False) -> torch.Tensor:
|
||||
debug: bool = False,
|
||||
is_final_chunk: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass with optional streaming support via cache.
|
||||
|
||||
@@ -307,6 +307,7 @@ class SConv1d(nn.Module):
|
||||
sample_indices: Indices identifying each sample for cache management
|
||||
use_cache: Whether to use cached states for streaming
|
||||
debug: Whether to print debug information
|
||||
is_final_chunk: Whether this is the final chunk (adds extra padding for alignment)
|
||||
|
||||
Returns:
|
||||
Output tensor
|
||||
@@ -322,12 +323,13 @@ class SConv1d(nn.Module):
|
||||
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
|
||||
assert len(sample_indices) == B, "sample_indices must match batch size"
|
||||
|
||||
return self._forward_streaming(x, cache, sample_indices, debug)
|
||||
return self._forward_streaming(x, cache, sample_indices, debug, is_final_chunk)
|
||||
|
||||
def _forward_streaming(self, x: torch.Tensor,
|
||||
cache: VibeVoiceTokenizerStreamingCache,
|
||||
sample_indices: torch.Tensor,
|
||||
debug: bool = False) -> torch.Tensor:
|
||||
debug: bool = False,
|
||||
is_final_chunk: bool = False) -> torch.Tensor:
|
||||
"""Streaming forward pass with cache operations kept separate from compiled code"""
|
||||
B, C, T = x.shape
|
||||
|
||||
@@ -350,6 +352,16 @@ class SConv1d(nn.Module):
|
||||
input_with_context = torch.cat([cached_states, x], dim=2)
|
||||
else:
|
||||
input_with_context = x
|
||||
|
||||
# For final chunk, add extra padding to ensure ceil behavior (same as non-streaming)
|
||||
if is_final_chunk:
|
||||
extra_padding = get_extra_padding_for_conv1d(
|
||||
input_with_context, self.kernel_size, self.stride, self.padding_total
|
||||
)
|
||||
if extra_padding > 0:
|
||||
input_with_context = pad1d(input_with_context, (0, extra_padding), mode=self.pad_mode)
|
||||
if debug:
|
||||
print(f"[DEBUG] Final chunk: added extra_padding={extra_padding}")
|
||||
|
||||
if debug:
|
||||
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
|
||||
@@ -684,6 +696,135 @@ class Block1D(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class TokenizerEncoder(nn.Module):
|
||||
"""
|
||||
Encoder component for the VibeVoice tokenizer that converts audio to latent representations.
|
||||
|
||||
Args:
|
||||
config: Configuration object with model parameters
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
# Extract parameters from config
|
||||
self.channels = config.channels
|
||||
self.dimension = config.dimension
|
||||
self.n_filters = config.n_filters
|
||||
self.ratios = list(reversed(config.ratios))
|
||||
self.depths = config.depths
|
||||
self.n_residual_layers = getattr(config, "n_residual_layers", 1)
|
||||
self.hop_length = np.prod(self.ratios)
|
||||
self.causal = config.causal
|
||||
|
||||
# Additional config parameters with defaults
|
||||
kernel_size = getattr(config, "kernel_size", 7)
|
||||
last_kernel_size = getattr(config, "last_kernel_size", 7)
|
||||
norm = getattr(config, "norm", "none")
|
||||
norm_params = getattr(config, "norm_params", {})
|
||||
pad_mode = getattr(config, "pad_mode", "reflect")
|
||||
bias = getattr(config, "bias", True)
|
||||
layernorm = getattr(config, "layernorm", "LN")
|
||||
layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
|
||||
layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
|
||||
drop_path_rate = getattr(config, "drop_path_rate", 0.0)
|
||||
mixer_layer = getattr(config, "mixer_layer", "conv")
|
||||
layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
|
||||
disable_last_norm = getattr(config, "disable_last_norm", False)
|
||||
|
||||
# determine the norm type based on layernorm
|
||||
if layernorm == 'LN':
|
||||
norm_type = ConvLayerNorm
|
||||
elif layernorm == 'RMSNorm':
|
||||
norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"Unsupported norm type: {layernorm}")
|
||||
|
||||
# stem and intermediate downsampling conv layers
|
||||
stem = nn.Sequential(
|
||||
SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
|
||||
)
|
||||
|
||||
self.downsample_layers = nn.ModuleList()
|
||||
self.downsample_layers.append(stem)
|
||||
for i in range(len(self.ratios)):
|
||||
in_ch = self.n_filters * (2 ** i)
|
||||
out_ch = self.n_filters * (2 ** (i + 1))
|
||||
downsample_layer = nn.Sequential(
|
||||
SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
|
||||
)
|
||||
self.downsample_layers.append(downsample_layer)
|
||||
|
||||
# configure the transformer blocks
|
||||
layer_type = partial(
|
||||
Block1D,
|
||||
mixer_layer=mixer_layer,
|
||||
layernorm=layernorm,
|
||||
eps=layernorm_eps,
|
||||
causal=self.causal,
|
||||
pad_mode=pad_mode,
|
||||
norm=norm,
|
||||
bias=bias,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
|
||||
self.stages = nn.ModuleList()
|
||||
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
||||
cur = 0
|
||||
|
||||
for i in range(len(self.depths)):
|
||||
in_ch = self.n_filters * (2 ** i)
|
||||
stage = nn.Sequential(
|
||||
*[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
|
||||
)
|
||||
self.stages.append(stage)
|
||||
cur += self.depths[i]
|
||||
|
||||
if not disable_last_norm:
|
||||
self.norm = norm_type(in_ch, eps=layernorm_eps)
|
||||
else:
|
||||
self.norm = nn.Identity()
|
||||
self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
|
||||
|
||||
def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
|
||||
for i in range(len(self.depths)):
|
||||
# Apply downsampling
|
||||
for layer in self.downsample_layers[i]:
|
||||
if isinstance(layer, SConv1d):
|
||||
x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
# Apply stage (Block1D contains Convlayer which contains SConv1d)
|
||||
for block in self.stages[i]:
|
||||
if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
|
||||
# Block1D forward with cache support
|
||||
residual = x
|
||||
x = block.norm(x)
|
||||
x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||
if block.gamma is not None:
|
||||
x = x * block.gamma.unsqueeze(-1)
|
||||
x = residual + x
|
||||
|
||||
# FFN part
|
||||
residual = x
|
||||
x = block.ffn_norm(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = block.ffn(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
if block.ffn_gamma is not None:
|
||||
x = x * block.ffn_gamma.unsqueeze(-1)
|
||||
x = residual + x
|
||||
else:
|
||||
x = block(x)
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
|
||||
x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||
return x
|
||||
|
||||
|
||||
class TokenizerDecoder(nn.Module):
|
||||
"""
|
||||
Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
|
||||
@@ -821,15 +962,63 @@ class TokenizerDecoder(nn.Module):
|
||||
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class VibeVoiceTokenizerEncoderOutput:
|
||||
"""
|
||||
Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.
|
||||
|
||||
Args:
|
||||
mean (`torch.FloatTensor`): The mean parameters of the distribution.
|
||||
std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
|
||||
"""
|
||||
mean: torch.Tensor
|
||||
std: Optional[Union[float, torch.Tensor]] = None
|
||||
|
||||
def sample(self, dist_type='fix'):
|
||||
"""
|
||||
Sample from the distribution.
|
||||
|
||||
Args:
|
||||
dist_type (`str`): Sampling method, either 'fix' or 'gaussian'.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Sampled values.
|
||||
`torch.FloatTensor` (optional): Standard deviation used (only when dist_type='gaussian').
|
||||
"""
|
||||
if dist_type == 'fix':
|
||||
x = self.mean + self.std * torch.randn_like(self.mean)
|
||||
return x, self.std
|
||||
elif dist_type == 'gaussian':
|
||||
batch_size = self.mean.size(0)
|
||||
value = self.std / 0.8
|
||||
std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value
|
||||
|
||||
while std.dim() < self.mean.dim():
|
||||
std = std.unsqueeze(-1)
|
||||
|
||||
x = self.mean + std * torch.randn_like(self.mean)
|
||||
return x, std
|
||||
else:
|
||||
return self.mean, self.std
|
||||
|
||||
def kl(self):
|
||||
"""Compute KL divergence between this distribution and a standard normal."""
|
||||
target = torch.zeros_like(self.mean)
|
||||
return F.mse_loss(self.mean, target, reduction='none')
|
||||
|
||||
def mode(self):
|
||||
"""Return the distribution mode (which is the mean for Gaussian)."""
|
||||
return self.mean
|
||||
|
||||
class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
||||
"""VibeVoice speech tokenizer model (only decoder) for acoustic tokens"""
|
||||
"""VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""
|
||||
|
||||
config_class = VibeVoiceAcousticTokenizerConfig
|
||||
base_model_prefix = "vibevoice_acoustic_tokenizer"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_no_split_modules = ["TokenizerDecoder"]
|
||||
_no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@@ -850,6 +1039,21 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
||||
# Default: use reversed encoder depths if decoder_depths is None
|
||||
decoder_depths = list(reversed(encoder_depths))
|
||||
|
||||
# Create encoder config
|
||||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.dimension = config.vae_dim
|
||||
encoder_config.n_filters = config.encoder_n_filters
|
||||
encoder_config.ratios = config.encoder_ratios
|
||||
encoder_config.depths = encoder_depths
|
||||
encoder_config.norm = config.conv_norm
|
||||
encoder_config.pad_mode = config.pad_mode
|
||||
encoder_config.bias = config.conv_bias
|
||||
encoder_config.layernorm_eps = config.layernorm_eps
|
||||
encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
|
||||
encoder_config.mixer_layer = config.mixer_layer
|
||||
encoder_config.layer_scale_init_value = config.layer_scale_init_value
|
||||
encoder_config.disable_last_norm = config.disable_last_norm
|
||||
|
||||
# Create decoder config
|
||||
decoder_config = copy.deepcopy(config)
|
||||
decoder_config.dimension = config.vae_dim
|
||||
@@ -865,6 +1069,8 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
||||
decoder_config.layer_scale_init_value = config.layer_scale_init_value
|
||||
decoder_config.disable_last_norm = config.disable_last_norm
|
||||
|
||||
# Initialize encoder and decoder
|
||||
self.encoder = TokenizerEncoder(encoder_config)
|
||||
self.decoder = TokenizerDecoder(decoder_config)
|
||||
|
||||
# Initialize weights
|
||||
@@ -884,6 +1090,24 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
|
||||
"""Convert audio to latent representations"""
|
||||
latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||
return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)
|
||||
|
||||
@torch.no_grad()
|
||||
def sampling(self, encoder_output, dist_type=None):
|
||||
"""Sample from the encoder output distribution"""
|
||||
dist_type = dist_type or self.std_dist_type
|
||||
|
||||
if dist_type == 'fix':
|
||||
return encoder_output.sample(dist_type='fix')
|
||||
elif dist_type == 'gaussian':
|
||||
return encoder_output.sample(dist_type='gaussian')
|
||||
else:
|
||||
raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
|
||||
"""Convert latent representations back to audio"""
|
||||
@@ -895,10 +1119,89 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
||||
audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||
return audio
|
||||
|
||||
def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
||||
"""Full forward pass: encode audio to latents, then decode back to audio"""
|
||||
encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||
sampled_latents, _ = self.sampling(encoder_output)
|
||||
reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||
return reconstructed, sampled_latents
|
||||
|
||||
|
||||
class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
|
||||
"""VibeVoice speech tokenizer model with only encoder for semantic tokens"""
|
||||
|
||||
config_class = VibeVoiceSemanticTokenizerConfig
|
||||
base_model_prefix = "vibevoice_semantic_tokenizer"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_no_split_modules = ["TokenizerEncoder"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
# Parse encoder depths
|
||||
if isinstance(config.encoder_depths, str):
|
||||
encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
|
||||
else:
|
||||
encoder_depths = config.encoder_depths
|
||||
|
||||
# Create encoder config
|
||||
encoder_config = copy.deepcopy(config)
|
||||
encoder_config.dimension = config.vae_dim
|
||||
encoder_config.n_filters = config.encoder_n_filters
|
||||
encoder_config.ratios = config.encoder_ratios
|
||||
encoder_config.depths = encoder_depths
|
||||
encoder_config.norm = config.conv_norm
|
||||
encoder_config.pad_mode = config.pad_mode
|
||||
encoder_config.bias = config.conv_bias
|
||||
encoder_config.layernorm_eps = config.layernorm_eps
|
||||
encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
|
||||
encoder_config.mixer_layer = config.mixer_layer
|
||||
encoder_config.layer_scale_init_value = config.layer_scale_init_value
|
||||
encoder_config.disable_last_norm = config.disable_last_norm
|
||||
|
||||
# Initialize encoder and decoder
|
||||
self.encoder = TokenizerEncoder(encoder_config)
|
||||
|
||||
# Initialize weights
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize weights for the model"""
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Conv1d):
|
||||
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
|
||||
"""Convert audio to latent representations"""
|
||||
latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||
return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))
|
||||
|
||||
@torch.no_grad()
|
||||
def sampling(self, encoder_output, dist_type=None):
|
||||
"""Sample from the encoder output distribution"""
|
||||
return encoder_output.sample(dist_type='none')
|
||||
|
||||
def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
||||
"""Full forward pass: encode audio to latents, then decode back to audio"""
|
||||
encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||
sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
|
||||
return None, sampled_latents
|
||||
|
||||
AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
|
||||
AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceTokenizerStreamingCache",
|
||||
"VibeVoiceAcousticTokenizerModel",
|
||||
"VibeVoiceSemanticTokenizerModel",
|
||||
]
|
||||
@@ -0,0 +1,143 @@
|
||||
import numpy as np
|
||||
from subprocess import run
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
COMMON_AUDIO_EXTS = [
|
||||
'.mp3', '.MP3', '.Mp3', # All case variations of mp3
|
||||
'.m4a',
|
||||
'.mp4', '.MP4',
|
||||
'.wav', '.WAV',
|
||||
'.m4v',
|
||||
'.aac',
|
||||
'.ogg',
|
||||
'.mov', '.MOV',
|
||||
'.opus',
|
||||
'.m4b',
|
||||
'.flac',
|
||||
'.wma', '.WMA',
|
||||
'.rm', '.3gp', '.mpeg', '.flv', '.webm', '.mp2', '.aif', '.aiff', '.oga', '.ogv', '.mpga', '.m3u8', '.amr'
|
||||
]
|
||||
|
||||
def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24000):
|
||||
"""
|
||||
Open an audio file and read as mono waveform, optionally resampling.
|
||||
Returns both the audio data and the original sample rate.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file: str
|
||||
The audio file to open
|
||||
resample: bool
|
||||
Whether to resample the audio
|
||||
target_sr: int
|
||||
The target sample rate if resampling is requested
|
||||
|
||||
Returns
|
||||
-------
|
||||
A tuple containing:
|
||||
- A NumPy array with the audio waveform in float32 dtype
|
||||
- The original sample rate of the audio file
|
||||
"""
|
||||
if not resample:
|
||||
# First, get the original sample rate
|
||||
cmd_probe = [
|
||||
"ffprobe",
|
||||
"-v", "quiet",
|
||||
"-show_entries", "stream=sample_rate",
|
||||
"-of", "default=noprint_wrappers=1:nokey=1",
|
||||
file
|
||||
]
|
||||
|
||||
original_sr = int(run(cmd_probe, capture_output=True, check=True).stdout.decode().strip())
|
||||
else:
|
||||
original_sr = None
|
||||
|
||||
# Now load the audio
|
||||
sr_to_use = target_sr if resample else original_sr
|
||||
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-threads", "0",
|
||||
"-i", file,
|
||||
"-f", "s16le",
|
||||
"-ac", "1",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(sr_to_use),
|
||||
"-"
|
||||
]
|
||||
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||
|
||||
return audio_data, sr_to_use
|
||||
|
||||
class AudioNormalizer:
|
||||
"""
|
||||
Audio normalization class for VibeVoice tokenizer.
|
||||
|
||||
This class provides audio normalization to ensure consistent input levels
|
||||
for the VibeVoice tokenizer while maintaining audio quality.
|
||||
"""
|
||||
|
||||
def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
|
||||
"""
|
||||
Initialize the audio normalizer.
|
||||
|
||||
Args:
|
||||
target_dB_FS (float): Target dB FS level for the audio. Default: -25
|
||||
eps (float): Small value to avoid division by zero. Default: 1e-6
|
||||
"""
|
||||
self.target_dB_FS = target_dB_FS
|
||||
self.eps = eps
|
||||
|
||||
def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
|
||||
"""
|
||||
Adjust the audio to the target dB FS level.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio signal
|
||||
|
||||
Returns:
|
||||
tuple: (normalized_audio, rms, scalar)
|
||||
"""
|
||||
rms = np.sqrt(np.mean(audio**2))
|
||||
scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
|
||||
normalized_audio = audio * scalar
|
||||
return normalized_audio, rms, scalar
|
||||
|
||||
def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
|
||||
"""
|
||||
Avoid clipping by scaling down if necessary.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio signal
|
||||
scalar (float, optional): Explicit scaling factor
|
||||
|
||||
Returns:
|
||||
tuple: (normalized_audio, scalar)
|
||||
"""
|
||||
if scalar is None:
|
||||
max_val = np.max(np.abs(audio))
|
||||
if max_val > 1.0:
|
||||
scalar = max_val + self.eps
|
||||
else:
|
||||
scalar = 1.0
|
||||
|
||||
return audio / scalar, scalar
|
||||
|
||||
def __call__(self, audio: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Normalize the audio by adjusting to target dB FS and avoiding clipping.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio signal
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized audio signal
|
||||
"""
|
||||
# First adjust to target dB FS
|
||||
audio, _, _ = self.tailor_dB_FS(audio)
|
||||
# Then avoid clipping
|
||||
audio, _ = self.avoid_clipping(audio)
|
||||
return audio
|
||||
@@ -0,0 +1,572 @@
|
||||
"""
|
||||
Processor class for VibeVoice ASR models.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Union, Dict, Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
from transformers.utils import TensorType, logging
|
||||
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor, AudioNormalizer
|
||||
|
||||
try:
|
||||
from .audio_utils import load_audio_use_ffmpeg
|
||||
HAS_FFMPEG_UTILS = True
|
||||
except ImportError:
|
||||
HAS_FFMPEG_UTILS = False
|
||||
warnings.warn("audio_utils not available, will fall back to soundfile for audio loading")
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format."
|
||||
|
||||
|
||||
class VibeVoiceASRProcessor:
|
||||
"""
|
||||
Processor for VibeVoice ASR (Automatic Speech Recognition) models.
|
||||
|
||||
This processor handles audio preprocessing and tokenization for ASR tasks,
|
||||
following the exact format used in training with proper chat templates.
|
||||
|
||||
Args:
|
||||
tokenizer: The text tokenizer for processing text
|
||||
audio_processor: The audio processor for processing speech
|
||||
speech_tok_compress_ratio (int): Compression ratio for speech tokenization
|
||||
target_sample_rate (int): Target sample rate for audio
|
||||
normalize_audio (bool): Whether to normalize audio input
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer=None,
|
||||
audio_processor=None,
|
||||
speech_tok_compress_ratio=320,
|
||||
target_sample_rate=24000,
|
||||
normalize_audio=True,
|
||||
**kwargs
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.audio_processor = audio_processor or VibeVoiceTokenizerProcessor(
|
||||
sampling_rate=target_sample_rate,
|
||||
normalize_audio=normalize_audio
|
||||
)
|
||||
self.speech_tok_compress_ratio = speech_tok_compress_ratio
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.normalize_audio = normalize_audio
|
||||
|
||||
if normalize_audio:
|
||||
self.audio_normalizer = AudioNormalizer()
|
||||
else:
|
||||
self.audio_normalizer = None
|
||||
|
||||
# Cache special token IDs
|
||||
self._cache_special_tokens()
|
||||
|
||||
def _cache_special_tokens(self):
|
||||
"""Cache special token IDs for efficiency."""
|
||||
# Add safety checks for special tokens
|
||||
if hasattr(self.tokenizer, 'speech_start_id'):
|
||||
self.speech_start_id = self.tokenizer.speech_start_id
|
||||
else:
|
||||
self.speech_start_id = self.tokenizer.convert_tokens_to_ids("<|speech_start|>")
|
||||
|
||||
if hasattr(self.tokenizer, 'speech_end_id'):
|
||||
self.speech_end_id = self.tokenizer.speech_end_id
|
||||
else:
|
||||
self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|speech_end|>")
|
||||
|
||||
if hasattr(self.tokenizer, 'speech_pad_id'):
|
||||
self.speech_pad_id = self.tokenizer.speech_pad_id
|
||||
else:
|
||||
self.speech_pad_id = self.tokenizer.convert_tokens_to_ids("<|speech_pad|>")
|
||||
|
||||
if hasattr(self.tokenizer, 'pad_id'):
|
||||
self.pad_id = self.tokenizer.pad_id
|
||||
elif hasattr(self.tokenizer, 'pad_token_id'):
|
||||
self.pad_id = self.tokenizer.pad_token_id
|
||||
else:
|
||||
self.pad_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
"""
|
||||
Load processor from a pretrained model path.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path: Path to the pretrained model
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
VibeVoiceASRProcessor: The loaded processor
|
||||
"""
|
||||
import json
|
||||
from transformers.utils import cached_file
|
||||
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
|
||||
|
||||
# Try to load configuration
|
||||
config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
|
||||
config = {}
|
||||
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
try:
|
||||
config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
"preprocessor_config.json",
|
||||
**kwargs
|
||||
)
|
||||
with open(config_file, 'r') as f:
|
||||
config = json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load preprocessor_config.json: {e}")
|
||||
logger.warning("Using default configuration")
|
||||
|
||||
# Extract parameters
|
||||
speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
|
||||
target_sample_rate = config.get("target_sample_rate", 24000)
|
||||
normalize_audio = config.get("normalize_audio", True)
|
||||
|
||||
# Load tokenizer
|
||||
language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
|
||||
logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
|
||||
|
||||
if 'qwen' in language_model_pretrained_name.lower():
|
||||
tokenizer = VibeVoiceASRTextTokenizerFast.from_pretrained(
|
||||
language_model_pretrained_name,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}")
|
||||
|
||||
# Load audio processor
|
||||
audio_processor = VibeVoiceTokenizerProcessor(
|
||||
sampling_rate=target_sample_rate,
|
||||
normalize_audio=normalize_audio,
|
||||
target_dB_FS=config.get("target_dB_FS", -25),
|
||||
eps=config.get("eps", 1e-6),
|
||||
)
|
||||
|
||||
return cls(
|
||||
tokenizer=tokenizer,
|
||||
audio_processor=audio_processor,
|
||||
speech_tok_compress_ratio=speech_tok_compress_ratio,
|
||||
target_sample_rate=target_sample_rate,
|
||||
normalize_audio=normalize_audio,
|
||||
)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
|
||||
"""
|
||||
Save processor configuration to a directory.
|
||||
|
||||
Args:
|
||||
save_directory: Directory to save the configuration
|
||||
**kwargs: Additional keyword arguments
|
||||
"""
|
||||
import json
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# Save processor configuration
|
||||
processor_config = {
|
||||
"processor_class": "VibeVoiceASRProcessor",
|
||||
"speech_tok_compress_ratio": self.speech_tok_compress_ratio,
|
||||
"target_sample_rate": self.target_sample_rate,
|
||||
"normalize_audio": self.normalize_audio,
|
||||
"target_dB_FS": -25,
|
||||
"eps": 1e-6,
|
||||
}
|
||||
|
||||
config_path = os.path.join(save_directory, "preprocessor_config.json")
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(processor_config, f, indent=2)
|
||||
|
||||
logger.info(f"Processor configuration saved in {config_path}")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
audio: Optional[Union[str, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, torch.Tensor]]]] = None,
|
||||
sampling_rate: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
padding: bool = True,
|
||||
max_length: Optional[int] = None,
|
||||
truncation: bool = False,
|
||||
add_generation_prompt: bool = True,
|
||||
use_streaming: bool = True,
|
||||
context_info: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Process audio input for ASR model.
|
||||
|
||||
Args:
|
||||
audio: Audio input(s). Can be:
|
||||
- str: Path to audio file
|
||||
- np.ndarray: Audio array
|
||||
- torch.Tensor: Audio tensor
|
||||
- List of the above for batch processing
|
||||
sampling_rate: Sampling rate of input audio
|
||||
return_tensors: Output format ('pt' for PyTorch, 'np' for NumPy)
|
||||
padding: Whether to pad batch inputs
|
||||
max_length: Maximum sequence length
|
||||
truncation: Whether to truncate long sequences
|
||||
add_generation_prompt: Whether to add generation prompt for inference
|
||||
use_streaming: Whether to use streaming mode (True by default, auto False if <60s)
|
||||
context_info: Optional context information (e.g., hotwords, metadata) to help transcription
|
||||
|
||||
Returns:
|
||||
BatchEncoding with:
|
||||
- input_ids: Token IDs for the model
|
||||
- attention_mask: Attention mask
|
||||
- acoustic_input_mask: Mask indicating speech token positions
|
||||
- speech_tensors: Processed speech features
|
||||
- speech_masks: Valid speech masks
|
||||
- vae_tok_seqlens: Length of each speech segment in tokens
|
||||
"""
|
||||
if audio is None:
|
||||
raise ValueError("Audio input is required for ASR processing")
|
||||
|
||||
# Handle single vs batch input
|
||||
if isinstance(audio, list):
|
||||
is_batched = True
|
||||
audio_list = audio
|
||||
else:
|
||||
is_batched = False
|
||||
audio_list = [audio]
|
||||
|
||||
# Process each audio input
|
||||
all_encodings = []
|
||||
for audio_input in audio_list:
|
||||
encoding = self._process_single_audio(
|
||||
audio_input,
|
||||
sampling_rate=sampling_rate,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
use_streaming=use_streaming,
|
||||
context_info=context_info,
|
||||
)
|
||||
all_encodings.append(encoding)
|
||||
|
||||
# Combine into batch
|
||||
batch_encoding = self._batch_encode(
|
||||
all_encodings,
|
||||
padding=padding,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
|
||||
return batch_encoding
|
||||
|
||||
def _process_single_audio(
|
||||
self,
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
sampling_rate: Optional[int] = None,
|
||||
add_generation_prompt: bool = True,
|
||||
use_streaming: bool = True,
|
||||
context_info: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a single audio input.
|
||||
|
||||
Args:
|
||||
audio: Single audio input
|
||||
sampling_rate: Audio sampling rate
|
||||
add_generation_prompt: Whether to add generation prompt
|
||||
context_info: Optional context information (e.g., hotwords, metadata) to help transcription
|
||||
|
||||
Returns:
|
||||
Dictionary with processed tokens and audio features
|
||||
"""
|
||||
# Process audio through audio processor
|
||||
if isinstance(audio, str):
|
||||
# Load from file using ffmpeg for better format support
|
||||
if HAS_FFMPEG_UTILS:
|
||||
try:
|
||||
audio_array, file_sr = load_audio_use_ffmpeg(audio, resample=False)
|
||||
except Exception as e:
|
||||
# Fall back to soundfile if ffmpeg fails
|
||||
warnings.warn(f"ffmpeg loading failed, falling back to soundfile: {e}")
|
||||
import soundfile as sf
|
||||
audio_array, file_sr = sf.read(audio)
|
||||
if audio_array.ndim > 1:
|
||||
audio_array = audio_array.mean(axis=1) # Convert to mono
|
||||
else:
|
||||
import soundfile as sf
|
||||
audio_array, file_sr = sf.read(audio)
|
||||
if audio_array.ndim > 1:
|
||||
audio_array = audio_array.mean(axis=1) # Convert to mono
|
||||
|
||||
# Resample if needed
|
||||
if file_sr != self.target_sample_rate:
|
||||
import librosa
|
||||
audio_array = librosa.resample(
|
||||
audio_array,
|
||||
orig_sr=file_sr,
|
||||
target_sr=self.target_sample_rate
|
||||
)
|
||||
elif isinstance(audio, torch.Tensor):
|
||||
audio_array = audio.cpu().numpy()
|
||||
if audio_array.ndim > 1:
|
||||
audio_array = audio_array.squeeze()
|
||||
else:
|
||||
audio_array = np.array(audio, dtype=np.float32)
|
||||
if audio_array.ndim > 1:
|
||||
audio_array = audio_array.squeeze()
|
||||
|
||||
# Ensure float32
|
||||
audio_array = audio_array.astype(np.float32)
|
||||
|
||||
# Normalize if needed
|
||||
if self.normalize_audio and self.audio_normalizer:
|
||||
audio_array = self.audio_normalizer(audio_array)
|
||||
|
||||
# Calculate audio duration
|
||||
audio_duration = len(audio_array) / self.target_sample_rate
|
||||
|
||||
# Auto-disable streaming for short audio (<60s)
|
||||
if use_streaming and audio_duration < 60.0:
|
||||
use_streaming = False
|
||||
|
||||
# Calculate token length based on streaming mode
|
||||
# Non-streaming: uses ceil (encoder adds extra_padding for stride alignment)
|
||||
# Streaming: uses floor (segments processed independently, no global alignment)
|
||||
# if use_streaming:
|
||||
# vae_tok_len = len(audio_array) // self.speech_tok_compress_ratio
|
||||
# else:
|
||||
vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio)
|
||||
|
||||
# Build token sequence following training format
|
||||
# 1. System prompt - use apply_chat_template then encode like in training
|
||||
system_prompt_text = self.tokenizer.apply_chat_template(
|
||||
[{"role": "system", "content": SYSTEM_PROMPT}],
|
||||
tokenize=False
|
||||
)
|
||||
system_tokens = self.tokenizer.encode(system_prompt_text)
|
||||
|
||||
# 2. User input with speech tokens
|
||||
# Build speech placeholder string
|
||||
sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id)
|
||||
sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id)
|
||||
sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id)
|
||||
|
||||
# User suffix with audio duration info
|
||||
show_keys = ['Start time', 'End time', 'Speaker ID', 'Content']
|
||||
if context_info and context_info.strip():
|
||||
user_suffix = f"This is a {audio_duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\nPlease transcribe it with these keys: " + ", ".join(show_keys)
|
||||
else:
|
||||
user_suffix = f"This is a {audio_duration:.2f} seconds audio, please transcribe it with these keys: " + ", ".join(show_keys)
|
||||
|
||||
user_input_string = ''.join(
|
||||
[sp_start_token] + [sp_pad_token] * vae_tok_len + [sp_end_token]
|
||||
) + '\n' + user_suffix
|
||||
|
||||
user_tokens = self.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": user_input_string}],
|
||||
tokenize=True
|
||||
)
|
||||
|
||||
# Combine tokens
|
||||
full_tokens = system_tokens + user_tokens
|
||||
|
||||
# Create acoustic input mask
|
||||
acoustic_input_mask = [1 if token == self.speech_pad_id else 0 for token in full_tokens]
|
||||
|
||||
return {
|
||||
"input_ids": full_tokens,
|
||||
"acoustic_input_mask": acoustic_input_mask,
|
||||
"speech": audio_array,
|
||||
"vae_tok_len": vae_tok_len,
|
||||
}
|
||||
|
||||
def _batch_encode(
|
||||
self,
|
||||
encodings: List[Dict[str, Any]],
|
||||
padding: bool = True,
|
||||
max_length: Optional[int] = None,
|
||||
truncation: bool = False,
|
||||
return_tensors: Optional[str] = None,
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Combine multiple encodings into a batch.
|
||||
|
||||
Args:
|
||||
encodings: List of encoded samples
|
||||
padding: Whether to pad sequences
|
||||
max_length: Maximum sequence length
|
||||
truncation: Whether to truncate
|
||||
return_tensors: Output format
|
||||
|
||||
Returns:
|
||||
BatchEncoding with batched data
|
||||
"""
|
||||
# Extract components
|
||||
input_ids_list = [enc["input_ids"] for enc in encodings]
|
||||
acoustic_masks_list = [enc["acoustic_input_mask"] for enc in encodings]
|
||||
speech_list = [enc["speech"] for enc in encodings]
|
||||
vae_tok_lens = [enc["vae_tok_len"] for enc in encodings]
|
||||
|
||||
# Determine max length for padding
|
||||
if padding:
|
||||
if max_length is not None:
|
||||
target_length = max_length
|
||||
else:
|
||||
target_length = max(len(ids) for ids in input_ids_list)
|
||||
|
||||
# Pad sequences
|
||||
padded_input_ids = []
|
||||
padded_acoustic_masks = []
|
||||
attention_masks = []
|
||||
|
||||
for input_ids, acoustic_mask in zip(input_ids_list, acoustic_masks_list):
|
||||
# Truncate if needed
|
||||
if truncation and len(input_ids) > target_length:
|
||||
input_ids = input_ids[:target_length]
|
||||
acoustic_mask = acoustic_mask[:target_length]
|
||||
|
||||
# Pad sequences to left (for autoregressive generation)
|
||||
padding_length = target_length - len(input_ids)
|
||||
padded_ids = [self.pad_id] * padding_length + input_ids
|
||||
padded_acoustic = [0] * padding_length + acoustic_mask
|
||||
attention_mask = [0] * padding_length + [1] * len(input_ids)
|
||||
|
||||
padded_input_ids.append(padded_ids)
|
||||
padded_acoustic_masks.append(padded_acoustic)
|
||||
attention_masks.append(attention_mask)
|
||||
|
||||
input_ids_list = padded_input_ids
|
||||
acoustic_masks_list = padded_acoustic_masks
|
||||
else:
|
||||
attention_masks = [[1] * len(ids) for ids in input_ids_list]
|
||||
|
||||
# Process speech tensors - raw audio is 1D, so we keep it as is
|
||||
max_speech_length = max(len(s) for s in speech_list)
|
||||
padded_speeches = np.zeros((len(speech_list), max_speech_length), dtype=np.float32)
|
||||
speech_masks = np.zeros((len(speech_list), max(vae_tok_lens)), dtype=bool)
|
||||
|
||||
for i, (speech, vae_len) in enumerate(zip(speech_list, vae_tok_lens)):
|
||||
padded_speeches[i, :len(speech)] = speech
|
||||
speech_masks[i, :vae_len] = True
|
||||
|
||||
# Create batch encoding
|
||||
batch_encoding = BatchEncoding()
|
||||
|
||||
if return_tensors == "pt":
|
||||
batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
|
||||
batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
|
||||
batch_encoding["acoustic_input_mask"] = torch.tensor(acoustic_masks_list, dtype=torch.bool)
|
||||
batch_encoding["speech_tensors"] = torch.tensor(padded_speeches, dtype=torch.float32)
|
||||
batch_encoding["speech_masks"] = torch.tensor(speech_masks, dtype=torch.bool)
|
||||
# Note: vae_tok_seqlens and speech_type are not included as they are not model inputs
|
||||
else:
|
||||
batch_encoding["input_ids"] = input_ids_list if len(input_ids_list) > 1 else input_ids_list[0]
|
||||
batch_encoding["attention_mask"] = attention_masks if len(attention_masks) > 1 else attention_masks[0]
|
||||
batch_encoding["acoustic_input_mask"] = acoustic_masks_list if len(acoustic_masks_list) > 1 else acoustic_masks_list[0]
|
||||
batch_encoding["speech_tensors"] = padded_speeches if len(padded_speeches) > 1 else padded_speeches[0]
|
||||
batch_encoding["speech_masks"] = speech_masks if len(speech_masks) > 1 else speech_masks[0]
|
||||
|
||||
return batch_encoding
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
Decode batch of token IDs to text.
|
||||
Forwards to tokenizer's batch_decode method.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
Decode token IDs to text.
|
||||
Forwards to tokenizer's decode method.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_transcription(self, text: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Post-process the generated transcription text to extract structured data.
|
||||
|
||||
Args:
|
||||
text: Generated text from the model
|
||||
|
||||
Returns:
|
||||
List of dictionaries with transcription segments
|
||||
"""
|
||||
try:
|
||||
# Try to parse as JSON
|
||||
if "```json" in text:
|
||||
# Extract JSON from markdown code block
|
||||
json_start = text.find("```json") + 7
|
||||
json_end = text.find("```", json_start)
|
||||
json_str = text[json_start:json_end].strip()
|
||||
else:
|
||||
# Try to find JSON array or object
|
||||
json_start = text.find("[")
|
||||
if json_start == -1:
|
||||
json_start = text.find("{")
|
||||
if json_start != -1:
|
||||
# Find matching closing bracket
|
||||
bracket_count = 0
|
||||
json_end = json_start
|
||||
for i in range(json_start, len(text)):
|
||||
if text[i] in "[{":
|
||||
bracket_count += 1
|
||||
elif text[i] in "]}":
|
||||
bracket_count -= 1
|
||||
if bracket_count == 0:
|
||||
json_end = i + 1
|
||||
break
|
||||
json_str = text[json_start:json_end]
|
||||
else:
|
||||
json_str = text
|
||||
|
||||
# Parse JSON
|
||||
result = json.loads(json_str)
|
||||
|
||||
# Ensure it's a list
|
||||
if isinstance(result, dict):
|
||||
result = [result]
|
||||
|
||||
# Validate and clean up the result
|
||||
cleaned_result = []
|
||||
for item in result:
|
||||
if isinstance(item, dict):
|
||||
cleaned_item = {}
|
||||
# Map keys to expected format
|
||||
key_mapping = {
|
||||
"Start time": "start_time",
|
||||
"Start": "start_time",
|
||||
"End time": "end_time",
|
||||
"End": "end_time",
|
||||
"Speaker ID": "speaker_id",
|
||||
"Speaker": "speaker_id",
|
||||
"Content": "text",
|
||||
}
|
||||
for key, mapped_key in key_mapping.items():
|
||||
if key in item:
|
||||
cleaned_item[mapped_key] = item[key]
|
||||
|
||||
if cleaned_item:
|
||||
cleaned_result.append(cleaned_item)
|
||||
|
||||
return cleaned_result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse JSON from transcription: {e}")
|
||||
logger.debug(f"Raw text: {text}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Error post-processing transcription: {e}")
|
||||
return []
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
"""Return the list of inputs accepted by the model."""
|
||||
return ["input_ids", "attention_mask", "acoustic_input_mask", "speech_tensors", "speech_masks"]
|
||||
|
||||
__all__ = ["VibeVoiceASRProcessor"]
|
||||
@@ -13,80 +13,10 @@ import torch
|
||||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||
from transformers.utils import logging
|
||||
|
||||
from .audio_utils import AudioNormalizer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class AudioNormalizer:
|
||||
"""
|
||||
Audio normalization class for VibeVoice tokenizer.
|
||||
|
||||
This class provides audio normalization to ensure consistent input levels
|
||||
for the VibeVoice tokenizer while maintaining audio quality.
|
||||
"""
|
||||
|
||||
def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
|
||||
"""
|
||||
Initialize the audio normalizer.
|
||||
|
||||
Args:
|
||||
target_dB_FS (float): Target dB FS level for the audio. Default: -25
|
||||
eps (float): Small value to avoid division by zero. Default: 1e-6
|
||||
"""
|
||||
self.target_dB_FS = target_dB_FS
|
||||
self.eps = eps
|
||||
|
||||
def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
|
||||
"""
|
||||
Adjust the audio to the target dB FS level.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio signal
|
||||
|
||||
Returns:
|
||||
tuple: (normalized_audio, rms, scalar)
|
||||
"""
|
||||
rms = np.sqrt(np.mean(audio**2))
|
||||
scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
|
||||
normalized_audio = audio * scalar
|
||||
return normalized_audio, rms, scalar
|
||||
|
||||
def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
|
||||
"""
|
||||
Avoid clipping by scaling down if necessary.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio signal
|
||||
scalar (float, optional): Explicit scaling factor
|
||||
|
||||
Returns:
|
||||
tuple: (normalized_audio, scalar)
|
||||
"""
|
||||
if scalar is None:
|
||||
max_val = np.max(np.abs(audio))
|
||||
if max_val > 1.0:
|
||||
scalar = max_val + self.eps
|
||||
else:
|
||||
scalar = 1.0
|
||||
|
||||
return audio / scalar, scalar
|
||||
|
||||
def __call__(self, audio: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Normalize the audio by adjusting to target dB FS and avoiding clipping.
|
||||
|
||||
Args:
|
||||
audio (np.ndarray): Input audio signal
|
||||
|
||||
Returns:
|
||||
np.ndarray: Normalized audio signal
|
||||
"""
|
||||
# First adjust to target dB FS
|
||||
audio, _, _ = self.tailor_dB_FS(audio)
|
||||
# Then avoid clipping
|
||||
audio, _ = self.avoid_clipping(audio)
|
||||
return audio
|
||||
|
||||
|
||||
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
|
||||
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user