Add VibeVoice-ASR
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 164 KiB |
@@ -23,8 +23,9 @@
|
|||||||
<img src="https://img.shields.io/badge/Status-New-brightgreen?style=flat" alt="New" />
|
<img src="https://img.shields.io/badge/Status-New-brightgreen?style=flat" alt="New" />
|
||||||
<img src="https://img.shields.io/badge/Feature-Realtime_TTS-blue?style=flat&logo=soundcharts" alt="Realtime TTS" />
|
<img src="https://img.shields.io/badge/Feature-Realtime_TTS-blue?style=flat&logo=soundcharts" alt="Realtime TTS" />
|
||||||
|
|
||||||
|
<strong>2026-01-21: 📣 We open-sourced <a href="docs/vibevoice-asr.md"><strong>VibeVoice-ASR</strong></a>, a unified speech-to-text model designed to handle 60-minute long-form audio in a single pass, generating structured transcriptions containing Who (Speaker), When (Timestamps), and What (Content), with support for User-Customized Context.</strong>
|
||||||
|
|
||||||
<strong>2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time.</strong>
|
2025-12-16: 📣 We added more experimental speakers for exploration, including multilingual voices and 11 distinct English style voices. [Try it](docs/vibevoice-realtime-0.5b.md#optional-more-experimental-voices). More speaker types will be added over time.
|
||||||
|
|
||||||
2025-12-09: 📣 We added experimental speakers in nine languages (DE, FR, IT, JP, KR, NL, PL, PT, ES) for exploration—welcome to try them out and share your feedback.
|
2025-12-09: 📣 We added experimental speakers in nine languages (DE, FR, IT, JP, KR, NL, PL, PT, ES) for exploration—welcome to try them out and share your feedback.
|
||||||
|
|
||||||
@@ -123,4 +124,4 @@ We do not recommend using VibeVoice in commercial or real-world applications wit
|
|||||||
|
|
||||||
## Star History
|
## Star History
|
||||||
|
|
||||||

|

|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,554 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
VibeVoice ASR Batch Inference Demo Script
|
||||||
|
|
||||||
|
This script supports batch inference for ASR model and compares results
|
||||||
|
between batch processing and single-sample processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
|
||||||
|
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class VibeVoiceASRBatchInference:
|
||||||
|
"""Batch inference wrapper for VibeVoice ASR model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str,
|
||||||
|
device: str = "cuda",
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
attn_implementation: str = "flash_attention_2"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the ASR batch inference pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the pretrained model
|
||||||
|
device: Device to run inference on
|
||||||
|
dtype: Data type for model weights
|
||||||
|
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
|
||||||
|
"""
|
||||||
|
print(f"Loading VibeVoice ASR model from {model_path}")
|
||||||
|
|
||||||
|
# Load processor
|
||||||
|
self.processor = VibeVoiceASRProcessor.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
language_model_pretrained_name="Qwen/Qwen2.5-7B"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load model with specified attention implementation
|
||||||
|
print(f"Using attention implementation: {attn_implementation}")
|
||||||
|
self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
dtype=dtype,
|
||||||
|
device_map=device if device == "auto" else None,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if device != "auto":
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
|
self.device = device if device != "auto" else next(self.model.parameters()).device
|
||||||
|
self.dtype = dtype
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
print(f"Model loaded successfully on {self.device}")
|
||||||
|
|
||||||
|
def _prepare_generation_config(
|
||||||
|
self,
|
||||||
|
max_new_tokens: int = 512,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
top_p: float = 0.9,
|
||||||
|
do_sample: bool = True,
|
||||||
|
num_beams: int = 1,
|
||||||
|
) -> dict:
|
||||||
|
"""Prepare generation configuration."""
|
||||||
|
config = {
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"pad_token_id": self.processor.pad_id,
|
||||||
|
"eos_token_id": self.processor.tokenizer.eos_token_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Beam search vs sampling
|
||||||
|
if num_beams > 1:
|
||||||
|
config["num_beams"] = num_beams
|
||||||
|
config["do_sample"] = False # Beam search doesn't use sampling
|
||||||
|
else:
|
||||||
|
config["do_sample"] = do_sample
|
||||||
|
# Only set temperature and top_p when sampling is enabled
|
||||||
|
if do_sample:
|
||||||
|
config["temperature"] = temperature
|
||||||
|
config["top_p"] = top_p
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def transcribe_batch(
|
||||||
|
self,
|
||||||
|
audio_inputs: List,
|
||||||
|
max_new_tokens: int = 512,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
do_sample: bool = True,
|
||||||
|
num_beams: int = 1,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Transcribe multiple audio files/arrays in a single batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_inputs: List of audio file paths or (array, sampling_rate) tuples
|
||||||
|
max_new_tokens: Maximum tokens to generate
|
||||||
|
temperature: Temperature for sampling
|
||||||
|
top_p: Top-p for nucleus sampling
|
||||||
|
do_sample: Whether to use sampling
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of transcription results
|
||||||
|
"""
|
||||||
|
if len(audio_inputs) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
batch_size = len(audio_inputs)
|
||||||
|
print(f"\nProcessing batch of {batch_size} audio(s)...")
|
||||||
|
|
||||||
|
# Process all audio together
|
||||||
|
inputs = self.processor(
|
||||||
|
audio=audio_inputs,
|
||||||
|
sampling_rate=None,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in inputs.items()}
|
||||||
|
|
||||||
|
# Print batch info
|
||||||
|
print(f" Input IDs shape: {inputs['input_ids'].shape}")
|
||||||
|
print(f" Speech tensors shape: {inputs['speech_tensors'].shape}")
|
||||||
|
print(f" Attention mask shape: {inputs['attention_mask'].shape}")
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
generation_config = self._prepare_generation_config(
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
do_sample=do_sample,
|
||||||
|
num_beams=num_beams,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output_ids = self.model.generate(
|
||||||
|
**inputs,
|
||||||
|
**generation_config
|
||||||
|
)
|
||||||
|
|
||||||
|
generation_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Decode outputs for each sample in the batch
|
||||||
|
results = []
|
||||||
|
input_length = inputs['input_ids'].shape[1]
|
||||||
|
|
||||||
|
for i, audio_input in enumerate(audio_inputs):
|
||||||
|
# Get generated tokens for this sample (excluding input tokens)
|
||||||
|
generated_ids = output_ids[i, input_length:]
|
||||||
|
|
||||||
|
# Remove padding tokens from the end
|
||||||
|
# Find the first eos_token or pad_token
|
||||||
|
eos_positions = (generated_ids == self.processor.tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
|
||||||
|
if len(eos_positions) > 0:
|
||||||
|
generated_ids = generated_ids[:eos_positions[0] + 1]
|
||||||
|
|
||||||
|
generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Parse structured output
|
||||||
|
try:
|
||||||
|
transcription_segments = self.processor.post_process_transcription(generated_text)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to parse structured output: {e}")
|
||||||
|
transcription_segments = []
|
||||||
|
|
||||||
|
# Get file name based on input type
|
||||||
|
if isinstance(audio_input, str):
|
||||||
|
file_name = audio_input
|
||||||
|
elif isinstance(audio_input, dict) and 'id' in audio_input:
|
||||||
|
file_name = audio_input['id']
|
||||||
|
else:
|
||||||
|
file_name = f"audio_{i}"
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"file": file_name,
|
||||||
|
"raw_text": generated_text,
|
||||||
|
"segments": transcription_segments,
|
||||||
|
"generation_time": generation_time / batch_size,
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f" Total generation time: {generation_time:.2f}s")
|
||||||
|
print(f" Average time per sample: {generation_time/batch_size:.2f}s")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def transcribe_with_batching(
|
||||||
|
self,
|
||||||
|
audio_inputs: List,
|
||||||
|
batch_size: int = 4,
|
||||||
|
max_new_tokens: int = 512,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
do_sample: bool = True,
|
||||||
|
num_beams: int = 1,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Transcribe multiple audio files/arrays with automatic batching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_inputs: List of audio file paths or (array, sampling_rate) tuples
|
||||||
|
batch_size: Number of samples per batch
|
||||||
|
max_new_tokens: Maximum tokens to generate
|
||||||
|
temperature: Temperature for sampling
|
||||||
|
top_p: Top-p for nucleus sampling
|
||||||
|
do_sample: Whether to use sampling
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of transcription results
|
||||||
|
"""
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
# Process in batches
|
||||||
|
for i in range(0, len(audio_inputs), batch_size):
|
||||||
|
batch_inputs = audio_inputs[i:i + batch_size]
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"Processing batch {i//batch_size + 1}/{(len(audio_inputs) + batch_size - 1)//batch_size}")
|
||||||
|
|
||||||
|
batch_results = self.transcribe_batch(
|
||||||
|
batch_inputs,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
do_sample=do_sample,
|
||||||
|
num_beams=num_beams,
|
||||||
|
)
|
||||||
|
all_results.extend(batch_results)
|
||||||
|
|
||||||
|
return all_results
|
||||||
|
|
||||||
|
|
||||||
|
def print_result(result: Dict[str, Any]):
|
||||||
|
"""Pretty print a single transcription result."""
|
||||||
|
print(f"\nFile: {result['file']}")
|
||||||
|
print(f"Generation Time: {result['generation_time']:.2f}s")
|
||||||
|
print(f"\n--- Raw Output ---")
|
||||||
|
print(result['raw_text'][:500] + "..." if len(result['raw_text']) > 500 else result['raw_text'])
|
||||||
|
|
||||||
|
if result['segments']:
|
||||||
|
print(f"\n--- Structured Output ({len(result['segments'])} segments) ---")
|
||||||
|
for seg in result['segments'][:50]: # Show first 50 segments
|
||||||
|
print(f"[{seg.get('start_time', 'N/A')} - {seg.get('end_time', 'N/A')}] "
|
||||||
|
f"Speaker {seg.get('speaker_id', 'N/A')}: {seg.get('text', '')}...")
|
||||||
|
if len(result['segments']) > 50:
|
||||||
|
print(f" ... and {len(result['segments']) - 50} more segments")
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_and_concatenate(
|
||||||
|
dataset_name: str,
|
||||||
|
split: str,
|
||||||
|
max_duration: float,
|
||||||
|
num_audios: int,
|
||||||
|
target_sr: int = 24000
|
||||||
|
) -> Optional[List[np.ndarray]]:
|
||||||
|
"""
|
||||||
|
Load a HuggingFace dataset and concatenate audio samples into long audio chunks.
|
||||||
|
(Note, just for demo purpose, not for benchmark evaluation)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset_name: HuggingFace dataset name (e.g., 'openslr/librispeech_asr')
|
||||||
|
split: Dataset split to use (e.g., 'test', 'test.other')
|
||||||
|
max_duration: Maximum duration in seconds for each concatenated audio
|
||||||
|
num_audios: Number of concatenated audios to create
|
||||||
|
target_sr: Target sample rate (default: 24000)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of concatenated audio arrays, or None if loading fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from datasets import load_dataset
|
||||||
|
import torchcodec # just for decode audio in datasets
|
||||||
|
except ImportError:
|
||||||
|
print("Please install it with: pip install datasets torchcodec")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print(f"\nLoading dataset: {dataset_name} (split: {split})")
|
||||||
|
print(f"Will create {num_audios} concatenated audio(s), each up to {max_duration:.1f}s ({max_duration/3600:.2f} hours)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use streaming to avoid downloading the entire dataset
|
||||||
|
dataset = load_dataset(dataset_name, split=split, streaming=True)
|
||||||
|
print(f"Dataset loaded in streaming mode")
|
||||||
|
|
||||||
|
concatenated_audios = [] # List of concatenated audio metadata
|
||||||
|
|
||||||
|
# Create multiple concatenated audios based on num_audios
|
||||||
|
current_chunks = []
|
||||||
|
current_duration = 0.0
|
||||||
|
current_samples_used = 0
|
||||||
|
sample_idx = 0
|
||||||
|
|
||||||
|
for sample in dataset:
|
||||||
|
if len(concatenated_audios) >= num_audios:
|
||||||
|
break
|
||||||
|
|
||||||
|
if 'audio' not in sample:
|
||||||
|
continue
|
||||||
|
|
||||||
|
audio_data = sample['audio']
|
||||||
|
audio_array = audio_data['array']
|
||||||
|
sr = audio_data['sampling_rate']
|
||||||
|
|
||||||
|
# Resample if needed
|
||||||
|
if sr != target_sr:
|
||||||
|
duration = len(audio_array) / sr
|
||||||
|
new_length = int(duration * target_sr)
|
||||||
|
audio_array = np.interp(
|
||||||
|
np.linspace(0, len(audio_array) - 1, new_length),
|
||||||
|
np.arange(len(audio_array)),
|
||||||
|
audio_array
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_duration = len(audio_array) / target_sr
|
||||||
|
|
||||||
|
# Check if adding this chunk exceeds max_duration
|
||||||
|
if current_duration + chunk_duration > max_duration:
|
||||||
|
remaining_duration = max_duration - current_duration
|
||||||
|
if remaining_duration > 0.5: # Only add if > 0.5s remaining
|
||||||
|
samples_to_take = int(remaining_duration * target_sr)
|
||||||
|
current_chunks.append(audio_array[:samples_to_take])
|
||||||
|
current_duration += remaining_duration
|
||||||
|
current_samples_used += 1
|
||||||
|
|
||||||
|
# Save current concatenated audio and start a new one
|
||||||
|
if current_chunks:
|
||||||
|
concatenated_audios.append({
|
||||||
|
'array': np.concatenate(current_chunks),
|
||||||
|
'duration': current_duration,
|
||||||
|
'samples_used': current_samples_used,
|
||||||
|
})
|
||||||
|
print(f" Created audio {len(concatenated_audios)}: {current_duration:.1f}s from {current_samples_used} samples")
|
||||||
|
|
||||||
|
# Reset for next concatenated audio
|
||||||
|
current_chunks = []
|
||||||
|
current_duration = 0.0
|
||||||
|
current_samples_used = 0
|
||||||
|
|
||||||
|
if len(concatenated_audios) >= num_audios:
|
||||||
|
break
|
||||||
|
|
||||||
|
current_chunks.append(audio_array)
|
||||||
|
current_duration += chunk_duration
|
||||||
|
current_samples_used += 1
|
||||||
|
|
||||||
|
sample_idx += 1
|
||||||
|
if sample_idx % 100 == 0:
|
||||||
|
print(f" Processed {sample_idx} samples...")
|
||||||
|
|
||||||
|
# Don't forget the last batch if it has content
|
||||||
|
if current_chunks and len(concatenated_audios) < num_audios:
|
||||||
|
concatenated_audios.append({
|
||||||
|
'array': np.concatenate(current_chunks),
|
||||||
|
'duration': current_duration,
|
||||||
|
'samples_used': current_samples_used,
|
||||||
|
})
|
||||||
|
print(f" Created audio {len(concatenated_audios)}: {current_duration:.1f}s from {current_samples_used} samples")
|
||||||
|
|
||||||
|
if not concatenated_audios:
|
||||||
|
print("Warning: No audio samples found in dataset")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract arrays and print summary
|
||||||
|
result = [a['array'] for a in concatenated_audios]
|
||||||
|
total_duration = sum(a['duration'] for a in concatenated_audios)
|
||||||
|
total_samples = sum(a['samples_used'] for a in concatenated_audios)
|
||||||
|
print(f"\nCreated {len(result)} concatenated audio(s), total {total_duration:.1f}s ({total_duration/60:.1f} min) from {total_samples} samples")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading dataset: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="VibeVoice ASR Batch Inference Demo")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_path",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Path to the model checkpoint"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--audio_files",
|
||||||
|
type=str,
|
||||||
|
nargs='+',
|
||||||
|
required=False,
|
||||||
|
help="Paths to audio files for transcription"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--audio_dir",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help="Directory containing audio files for batch transcription"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help="HuggingFace dataset name (e.g., 'openslr/librispeech_asr')"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--split",
|
||||||
|
type=str,
|
||||||
|
default="test",
|
||||||
|
help="Dataset split to use (e.g., 'test', 'test.other', 'test.clean')"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_duration",
|
||||||
|
type=float,
|
||||||
|
default=3600.0,
|
||||||
|
help="Maximum duration in seconds for concatenated dataset audio (default: 3600 = 1 hour)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Batch size for processing multiple files"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
choices=["cuda", "cpu", "auto"],
|
||||||
|
help="Device to run inference on"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_new_tokens",
|
||||||
|
type=int,
|
||||||
|
default=32768,
|
||||||
|
help="Maximum number of tokens to generate"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Temperature for sampling (0 = greedy decoding)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top_p",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Top-p for nucleus sampling"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_beams",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of beams for beam search. Use 1 for greedy/sampling"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--attn_implementation",
|
||||||
|
type=str,
|
||||||
|
default="flash_attention_2",
|
||||||
|
help="Attention implementation to use (default: flash_attention_2)"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Collect audio files
|
||||||
|
audio_files = []
|
||||||
|
concatenated_audio = None # For storing concatenated dataset audio
|
||||||
|
|
||||||
|
if args.audio_files:
|
||||||
|
audio_files.extend(args.audio_files)
|
||||||
|
|
||||||
|
if args.audio_dir:
|
||||||
|
import glob
|
||||||
|
for ext in ["*.wav", "*.mp3", "*.flac", "*.mp4", "*.m4a", "*.webm"]:
|
||||||
|
audio_files.extend(glob.glob(os.path.join(args.audio_dir, ext)))
|
||||||
|
|
||||||
|
if args.dataset:
|
||||||
|
concatenated_audio = load_dataset_and_concatenate(
|
||||||
|
dataset_name=args.dataset,
|
||||||
|
split=args.split,
|
||||||
|
max_duration=args.max_duration,
|
||||||
|
num_audios=args.batch_size,
|
||||||
|
)
|
||||||
|
if concatenated_audio is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(audio_files) == 0 and concatenated_audio is None:
|
||||||
|
print("No audio files provided. Please specify --audio_files, --audio_dir, or --dataset.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if audio_files:
|
||||||
|
print(f"\nAudio files to process ({len(audio_files)}):")
|
||||||
|
for f in audio_files:
|
||||||
|
print(f" - {f}")
|
||||||
|
|
||||||
|
if concatenated_audio:
|
||||||
|
print(f"\nConcatenated dataset audios: {len(concatenated_audio)} audio(s)")
|
||||||
|
|
||||||
|
# Initialize model
|
||||||
|
asr = VibeVoiceASRBatchInference(
|
||||||
|
model_path=args.model_path,
|
||||||
|
device=args.device,
|
||||||
|
dtype=torch.bfloat16 if args.device != "cpu" else torch.float32,
|
||||||
|
attn_implementation=args.attn_implementation
|
||||||
|
)
|
||||||
|
|
||||||
|
# If temperature is 0, use greedy decoding (no sampling)
|
||||||
|
do_sample = args.temperature > 0
|
||||||
|
|
||||||
|
# Combine all audio inputs
|
||||||
|
all_audio_inputs = audio_files + (concatenated_audio or [])
|
||||||
|
|
||||||
|
print("\n" + "="*80)
|
||||||
|
print(f"Processing {len(all_audio_inputs)} audio(s)")
|
||||||
|
print("="*80)
|
||||||
|
|
||||||
|
all_results = asr.transcribe_with_batching(
|
||||||
|
all_audio_inputs,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
max_new_tokens=args.max_new_tokens,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_p=args.top_p,
|
||||||
|
do_sample=do_sample,
|
||||||
|
num_beams=args.num_beams,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print("\n" + "="*80)
|
||||||
|
print("Results")
|
||||||
|
print("="*80)
|
||||||
|
for result in all_results:
|
||||||
|
print("\n" + "-"*60)
|
||||||
|
print_result(result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
# VibeVoice-ASR: Long-Form Rich Transcription with User Prompts
|
||||||
|
|
||||||
|
**VibeVoice-ASR** is the latest addition to the **VibeVoice** family. While the original VibeVoice / VibeVoice-Realtime focused on expressive TTS, **VibeVoice-ASR** focuses on understanding long-form speech with high precision and rich metadata.
|
||||||
|
|
||||||
|
It is a unified speech-to-text model designed to handle **1-hour long-form audio** in a single pass, generating structured transcriptions containing **Who (Speaker), When (Timestamps), and What (Content)**, with support for **User-Customized Context**.
|
||||||
|
|
||||||
|
## 🔥 Key Features
|
||||||
|
|
||||||
|
- **🕒 60-min Single-Pass Processing**:
|
||||||
|
Unlike conventional ASR models that slice audio into short chunks (often losing global context), VibeVoice ASR accepts up to **60 minutes** of continuous audio input within 64K length. This ensures consistent speaker tracking and semantic coherence across the entire hour.
|
||||||
|
|
||||||
|
- **👤 Optional Context Injection**:
|
||||||
|
Users can provide customized context (e.g., specific names, technical terms, or background info) to guide the recognition process, significantly improving accuracy on domain-specific content.
|
||||||
|
|
||||||
|
- **📝 Rich Transcription (Who, When, What)**:
|
||||||
|
The model performs ASR, Diarization, and Timestamping simultaneously. The output is a structured sequence indicating *who* said *what* at *which time*.
|
||||||
|
|
||||||
|
## 🏗️ Model Architecture
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="../Figures/VibeVoice_ASR_archi.png" alt="VibeVoice ASR Architecture" width="80%">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
We recommend to use NVIDIA Deep Learning Container to manage the CUDA environment.
|
||||||
|
|
||||||
|
1. Launch docker
|
||||||
|
```bash
|
||||||
|
# NVIDIA PyTorch Container 24.07 ~ 25.12 verified.
|
||||||
|
# Previous versions are also compatible.
|
||||||
|
sudo docker run --privileged --net=host --ipc=host --ulimit memlock=-1:-1 --ulimit stack=-1:-1 --gpus all --rm -it nvcr.io/nvidia/pytorch:25.12-py3
|
||||||
|
|
||||||
|
## If flash attention is not included in your docker environment, you need to install it manually
|
||||||
|
## Refer to https://github.com/Dao-AILab/flash-attention for installation instructions
|
||||||
|
# pip install flash-attn --no-build-isolation
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install from github
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/microsoft/VibeVoice.git
|
||||||
|
cd VibeVoice
|
||||||
|
pip install -e .[asr]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usages
|
||||||
|
|
||||||
|
### Usage 1: Launch Gradio demo
|
||||||
|
```bash
|
||||||
|
apt update && apt install ffmpeg -y # for demo
|
||||||
|
|
||||||
|
python demo/vibevoice_asr_gradio_demo.py --model_path microsoft/VibeVoice-ASR --share
|
||||||
|
```
|
||||||
|
|
||||||
|
### Usage 2: Inference from files directly
|
||||||
|
```bash
|
||||||
|
python demo/vibevoice_asr_inference_from_file.py --model_path microsoft/VibeVoice-ASR --audio_files [add a audio path here]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 📄 License
|
||||||
|
|
||||||
|
This project is licensed under the [MIT License](../LICENSE).
|
||||||
+17
-9
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "vibevoice"
|
name = "vibevoice"
|
||||||
version = "0.0.1"
|
version = "1.0.0"
|
||||||
authors = [
|
authors = [
|
||||||
{ name="vibevoice team", email="vibepod@microsoft.com" },
|
{ name="vibevoice team", email="VibeVoice@microsoft.com" },
|
||||||
]
|
]
|
||||||
description = "A model for speech generation with an AR + diffusion architecture."
|
description = "Open-Source Frontier Voice AI."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
classifiers = [
|
classifiers = [
|
||||||
@@ -18,8 +18,7 @@ classifiers = [
|
|||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"torch",
|
"torch",
|
||||||
"accelerate==1.6.0",
|
"accelerate",
|
||||||
"transformers==4.51.3", # we develop this project on transformers==4.51.3, later version may not be compatible
|
|
||||||
"llvmlite>=0.40.0",
|
"llvmlite>=0.40.0",
|
||||||
"numba>=0.57.0",
|
"numba>=0.57.0",
|
||||||
"diffusers",
|
"diffusers",
|
||||||
@@ -30,12 +29,21 @@ dependencies = [
|
|||||||
"ml-collections",
|
"ml-collections",
|
||||||
"absl-py",
|
"absl-py",
|
||||||
"gradio",
|
"gradio",
|
||||||
"av",
|
|
||||||
"aiortc",
|
|
||||||
"uvicorn[standard]",
|
|
||||||
"fastapi"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
tts = [
|
||||||
|
"transformers==4.51.3", # we develop this project on transformers==4.51.3, later version may not be compatible
|
||||||
|
"av",
|
||||||
|
"aiortc",
|
||||||
|
"uvicorn[standard]",
|
||||||
|
"fastapi"
|
||||||
|
]
|
||||||
|
|
||||||
|
asr = [
|
||||||
|
"transformers>=4.51.3", # the versions after 4.51.3 are all support
|
||||||
|
"pydub" # for visualization
|
||||||
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
"Homepage" = "https://github.com/microsoft/VibeVoice"
|
"Homepage" = "https://github.com/microsoft/VibeVoice"
|
||||||
|
|||||||
@@ -240,9 +240,112 @@ class VibeVoiceConfig(PretrainedConfig):
|
|||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
class VibeVoiceASRConfig(PretrainedConfig):
|
||||||
|
model_type = "vibevoice"
|
||||||
|
is_composition = True
|
||||||
|
sub_configs = {
|
||||||
|
"acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
|
||||||
|
"semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
|
||||||
|
"decoder_config": Qwen2Config,
|
||||||
|
}
|
||||||
|
# keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
# Default tensor parallel plan for base model `Qwen2`
|
||||||
|
base_model_tp_plan = {
|
||||||
|
"layers.*.self_attn.q_proj": "colwise",
|
||||||
|
"layers.*.self_attn.k_proj": "colwise",
|
||||||
|
"layers.*.self_attn.v_proj": "colwise",
|
||||||
|
"layers.*.self_attn.o_proj": "rowwise",
|
||||||
|
"layers.*.mlp.gate_proj": "colwise",
|
||||||
|
"layers.*.mlp.up_proj": "colwise",
|
||||||
|
"layers.*.mlp.down_proj": "rowwise",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
acoustic_tokenizer_config=None,
|
||||||
|
semantic_tokenizer_config=None,
|
||||||
|
decoder_config=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
# kwargs["_attn_implementation"] = "flash_attention_2"
|
||||||
|
kwargs["_attn_implementation_autoset"] = False
|
||||||
|
|
||||||
|
if acoustic_tokenizer_config is None:
|
||||||
|
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
|
||||||
|
elif isinstance(acoustic_tokenizer_config, dict):
|
||||||
|
acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
|
||||||
|
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
|
||||||
|
elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
|
||||||
|
# If an instance of the config class is provided
|
||||||
|
self.acoustic_tokenizer_config = acoustic_tokenizer_config
|
||||||
|
|
||||||
|
if semantic_tokenizer_config is None:
|
||||||
|
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
|
||||||
|
elif isinstance(semantic_tokenizer_config, dict):
|
||||||
|
semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
|
||||||
|
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
|
||||||
|
elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
|
||||||
|
# If an instance of the config class is provided
|
||||||
|
self.semantic_tokenizer_config = semantic_tokenizer_config
|
||||||
|
|
||||||
|
if decoder_config is None:
|
||||||
|
self.decoder_config = self.sub_configs["decoder_config"]()
|
||||||
|
elif isinstance(decoder_config, dict):
|
||||||
|
# If a dictionary is provided, instantiate the config class with it
|
||||||
|
# self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
|
||||||
|
if decoder_config.get("model_type", '') == "qwen2":
|
||||||
|
self.decoder_config = Qwen2Config(**decoder_config)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
|
||||||
|
elif isinstance(decoder_config, Qwen2Config):
|
||||||
|
# If an instance of the config class is provided
|
||||||
|
self.decoder_config = decoder_config
|
||||||
|
|
||||||
|
# other parameters
|
||||||
|
self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
|
||||||
|
self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def get_text_config(self, decoder: bool = False):
|
||||||
|
"""Return the text (decoder) config for generation."""
|
||||||
|
return self.decoder_config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
"""Return vocab_size from decoder config for generation compatibility."""
|
||||||
|
return self.decoder_config.vocab_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_attention_heads(self):
|
||||||
|
"""Return num_attention_heads from decoder config for Ulysses SP compatibility."""
|
||||||
|
return self.decoder_config.num_attention_heads
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_key_value_heads(self):
|
||||||
|
"""Return num_key_value_heads from decoder config for Ulysses SP compatibility."""
|
||||||
|
return self.decoder_config.num_key_value_heads
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hidden_size(self):
|
||||||
|
"""Return hidden_size from decoder config for model compatibility."""
|
||||||
|
return self.decoder_config.hidden_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_hidden_layers(self):
|
||||||
|
"""Return num_hidden_layers from decoder config for Ulysses SP compatibility."""
|
||||||
|
return self.decoder_config.num_hidden_layers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def head_dim(self):
|
||||||
|
"""Return head_dim from decoder config for Ulysses SP compatibility."""
|
||||||
|
return getattr(self.decoder_config, 'head_dim', self.hidden_size // self.num_attention_heads)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"VibeVoiceAcousticTokenizerConfig",
|
"VibeVoiceAcousticTokenizerConfig",
|
||||||
"VibeVoiceSemanticTokenizerConfig",
|
"VibeVoiceSemanticTokenizerConfig",
|
||||||
"VibeVoiceDiffusionHeadConfig",
|
"VibeVoiceDiffusionHeadConfig",
|
||||||
"VibeVoiceConfig"
|
"VibeVoiceConfig",
|
||||||
|
"VibeVoiceASRConfig"
|
||||||
]
|
]
|
||||||
@@ -0,0 +1,496 @@
|
|||||||
|
# copied from https://github.com/vibevoice-community/VibeVoice/blob/main/vibevoice/modular/modeling_vibevoice.py
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union, Callable
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||||
|
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
|
from transformers import modeling_utils
|
||||||
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel
|
||||||
|
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
|
||||||
|
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
|
||||||
|
|
||||||
|
from .configuration_vibevoice import VibeVoiceConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
|
||||||
|
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VibeVoiceCausalLMOutputWithPast(ModelOutput):
|
||||||
|
loss: Optional[torch.FloatTensor] = None
|
||||||
|
diffusion_loss: Optional[torch.FloatTensor] = None
|
||||||
|
speech_token_num: Optional[int] = None
|
||||||
|
logits: torch.FloatTensor = None
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VibeVoiceGenerationOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Output type for VibeVoice generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
The generated sequences.
|
||||||
|
speech_outputs (`List[torch.FloatTensor]`, *optional*):
|
||||||
|
List of generated speech waveforms or latents for each speech segment.
|
||||||
|
"""
|
||||||
|
sequences: torch.LongTensor = None
|
||||||
|
speech_outputs: Optional[List[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechConnector(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.fc1 = nn.Linear(input_dim, output_dim)
|
||||||
|
self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
|
||||||
|
self.fc2 = nn.Linear(output_dim, output_dim)
|
||||||
|
|
||||||
|
def forward(self, features, **kwargs):
|
||||||
|
x = self.fc1(features)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# @auto_docstring
|
||||||
|
class VibeVoicePreTrainedModel(PreTrainedModel):
|
||||||
|
config_class = VibeVoiceConfig
|
||||||
|
base_model_prefix = "model"
|
||||||
|
supports_gradient_checkpointing = True
|
||||||
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
_supports_attention_backend = True
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if isinstance(module, VibeVoiceDiffusionHead):
|
||||||
|
module.initialize_weights()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use the language model's initializer_range if available
|
||||||
|
if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
|
||||||
|
std = self.config.language_model_config.initializer_range
|
||||||
|
elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
|
||||||
|
std = self.config.decoder_config.initializer_range
|
||||||
|
else:
|
||||||
|
std = 0.02 # Default value
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
# @auto_docstring
|
||||||
|
class VibeVoiceModel(VibeVoicePreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
|
||||||
|
if isinstance(config.torch_dtype, str):
|
||||||
|
dtype = getattr(torch, config.torch_dtype)
|
||||||
|
else:
|
||||||
|
dtype = config.torch_dtype
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
# Initialize Qwen2 model for language modeling
|
||||||
|
lm_config = config.decoder_config
|
||||||
|
self.language_model = AutoModel.from_config(lm_config)
|
||||||
|
|
||||||
|
# Initialize speech components if needed
|
||||||
|
self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
|
||||||
|
self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
|
||||||
|
|
||||||
|
self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||||
|
self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||||
|
|
||||||
|
# Register scaling factors as buffers - use 1D tensors for FSDP compatibility
|
||||||
|
self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
|
||||||
|
self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
|
||||||
|
|
||||||
|
# Initialize prediction head for speech generation
|
||||||
|
self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
|
||||||
|
|
||||||
|
# Initialize noise scheduler
|
||||||
|
self.noise_scheduler = DPMSolverMultistepScheduler(
|
||||||
|
num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
|
||||||
|
beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
|
||||||
|
prediction_type=config.diffusion_head_config.prediction_type
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
if hasattr(self.language_model, 'embed_tokens'):
|
||||||
|
# If the language model has an embed_tokens attribute, return it
|
||||||
|
return self.language_model.embed_tokens
|
||||||
|
|
||||||
|
for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
|
||||||
|
if attr.orig_name == 'embed_tokens.weight':
|
||||||
|
return getattr(self.language_model, name)
|
||||||
|
assert False, 'should not arrive here'
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.language_model.embed_tokens = value
|
||||||
|
|
||||||
|
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
|
||||||
|
"""Set the speech tokenizers used for encoding and decoding speech."""
|
||||||
|
self.acoustic_tokenizer = acoustic_tokenizer
|
||||||
|
self.semantic_tokenizer = semantic_tokenizer
|
||||||
|
|
||||||
|
# Reset the encoder to evaluation mode
|
||||||
|
if self.acoustic_tokenizer is not None:
|
||||||
|
self.acoustic_tokenizer.eval()
|
||||||
|
|
||||||
|
if self.semantic_tokenizer is not None:
|
||||||
|
self.semantic_tokenizer.eval()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# Forward through language model
|
||||||
|
outputs = self.language_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=outputs.last_hidden_state,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel):
|
||||||
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = VibeVoiceModel(config)
|
||||||
|
self.vocab_size = config.decoder_config.vocab_size
|
||||||
|
self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False)
|
||||||
|
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.get_input_embeddings()
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.set_input_embeddings(value)
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.language_model = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.language_model
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
"""
|
||||||
|
Tie the weights between the input embeddings and the output embeddings.
|
||||||
|
"""
|
||||||
|
if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
|
||||||
|
# The standard PreTrainedModel method will handle the tying.
|
||||||
|
# It typically does a simple parameter object assignment, which is
|
||||||
|
# CORRECT to do BEFORE FSDP wraps the model.
|
||||||
|
output_embeddings = self.get_output_embeddings()
|
||||||
|
input_embeddings = self.get_input_embeddings()
|
||||||
|
if hasattr(input_embeddings, 'weight'):
|
||||||
|
output_embeddings.weight = input_embeddings.weight
|
||||||
|
else:
|
||||||
|
# maybe returned input_embeddings a tensor directly
|
||||||
|
output_embeddings.weight = input_embeddings
|
||||||
|
|
||||||
|
if getattr(output_embeddings, "bias", None) is not None:
|
||||||
|
output_embeddings.bias.data = nn.functional.pad(
|
||||||
|
output_embeddings.bias.data,
|
||||||
|
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]),
|
||||||
|
"constant",
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
print("Tied input and output embeddings using standard assignment.")
|
||||||
|
else:
|
||||||
|
print("tie_word_embeddings is False, not tying weights.")
|
||||||
|
|
||||||
|
# Also, ensure set_output_embeddings is safe, though your implementation looks okay.
|
||||||
|
# The key is to avoid calling it after accelerator.prepare().
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
# Your current implementation using data.copy_ is good practice,
|
||||||
|
# but the best way is to not call this after prepare().
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def forward_speech_features(
|
||||||
|
self,
|
||||||
|
speech_tensors=None,
|
||||||
|
speech_masks=None,
|
||||||
|
speech_type="audio",
|
||||||
|
return_unmask=False
|
||||||
|
):
|
||||||
|
if speech_tensors is None:
|
||||||
|
# Use config to get vae_dim instead of non-existent self.args
|
||||||
|
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
|
||||||
|
audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight)
|
||||||
|
connect_features = self.model.acoustic_connector(audio_features)
|
||||||
|
return audio_features, connect_features
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
if speech_type == "audio":
|
||||||
|
with torch.no_grad():
|
||||||
|
frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0]
|
||||||
|
audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0]
|
||||||
|
|
||||||
|
elif speech_type == "vae":
|
||||||
|
# Use config to get vae_dim instead of non-existent self.args
|
||||||
|
vae_dim = self.config.acoustic_tokenizer_config.vae_dim
|
||||||
|
speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim)
|
||||||
|
|
||||||
|
# gaussian sample from the speech_mode
|
||||||
|
batch_size = speech_mode.size(0)
|
||||||
|
value = self.model.acoustic_tokenizer.fix_std / 0.8
|
||||||
|
std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value
|
||||||
|
std = std.view(-1, *[1] * (speech_mode.dim() - 1))
|
||||||
|
audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Speech type {speech_type} not implemented")
|
||||||
|
|
||||||
|
if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor):
|
||||||
|
scaling_factor = 1. / audio_tokens[speech_masks].flatten().std()
|
||||||
|
bias_factor = -audio_tokens[speech_masks].flatten().mean()
|
||||||
|
|
||||||
|
# Only use distributed operations if the process group is initialized
|
||||||
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM)
|
||||||
|
dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM)
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
self.model.speech_scaling_factor.copy_(scaling_factor / world_size)
|
||||||
|
self.model.speech_bias_factor.copy_(bias_factor / world_size)
|
||||||
|
print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
|
||||||
|
else:
|
||||||
|
# Single process case
|
||||||
|
self.model.speech_scaling_factor.copy_(scaling_factor)
|
||||||
|
self.model.speech_bias_factor.copy_(bias_factor)
|
||||||
|
print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True)
|
||||||
|
|
||||||
|
audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor
|
||||||
|
|
||||||
|
connect_features = self.model.acoustic_connector(audio_features)
|
||||||
|
if return_unmask:
|
||||||
|
return audio_features, connect_features
|
||||||
|
return audio_features[speech_masks], connect_features[speech_masks]
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
# New arguments for speech processing and loss calculation
|
||||||
|
speech_tensors: Optional[torch.FloatTensor] = None,
|
||||||
|
speech_masks: Optional[torch.BoolTensor] = None,
|
||||||
|
speeches_loss_input: Optional[torch.FloatTensor] = None,
|
||||||
|
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
|
||||||
|
acoustic_input_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
acoustic_loss_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
ddpm_batch_mul: int = 1,
|
||||||
|
**kwargs: Optional[Dict[str, Union[torch.Tensor, str]]],
|
||||||
|
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
x = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors)
|
||||||
|
if speeches_loss_input is not None:
|
||||||
|
# only part audio need diffuse
|
||||||
|
speech_all_features, speech_all_connect_features = self.forward_speech_features(
|
||||||
|
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
|
||||||
|
speech_masks=speech_masks,
|
||||||
|
speech_type=kwargs.get("speech_type", "audio"),
|
||||||
|
return_unmask=True
|
||||||
|
)
|
||||||
|
if speech_tensors is not None:
|
||||||
|
if semantic_speech_all_connect_features is not None:
|
||||||
|
x[acoustic_input_mask] = (
|
||||||
|
speech_all_connect_features[speech_masks]
|
||||||
|
+ semantic_speech_all_connect_features[speech_masks]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x[acoustic_input_mask] = speech_all_connect_features[speech_masks]
|
||||||
|
|
||||||
|
# Select only the target segments' latents for diffusion loss.
|
||||||
|
# Both masks are [num_segments, max_latent_len]; using 2D mask on [B,T,D] selects [N_true, D].
|
||||||
|
target_latent_mask = speeches_loss_input & speech_masks
|
||||||
|
speech_features = speech_all_features[target_latent_mask]
|
||||||
|
speech_connect_features = speech_all_connect_features[target_latent_mask]
|
||||||
|
else:
|
||||||
|
speech_features, speech_connect_features = self.forward_speech_features(
|
||||||
|
speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None,
|
||||||
|
speech_masks=speech_masks,
|
||||||
|
speech_type=kwargs.get("speech_type", "audio"),
|
||||||
|
)
|
||||||
|
if speech_tensors is not None:
|
||||||
|
x[acoustic_input_mask] = speech_connect_features
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=x,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs.last_hidden_state
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
# logits = logits.float()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# The custom CE loss with masking is calculated in the training script.
|
||||||
|
# We leave the standard loss calculation here as None.
|
||||||
|
pass
|
||||||
|
|
||||||
|
# --- Diffusion Loss Calculation ---
|
||||||
|
diffusion_loss = None
|
||||||
|
# This block is executed only if we are in a context that involves speech.
|
||||||
|
if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0:
|
||||||
|
condition_features = hidden_states[acoustic_loss_mask]
|
||||||
|
|
||||||
|
speech_len, latent_size = speech_features.shape
|
||||||
|
|
||||||
|
noise = torch.randn(
|
||||||
|
(speech_len * ddpm_batch_mul, latent_size),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
timesteps = torch.multinomial(
|
||||||
|
torch.ones(self.config.diffusion_head_config.ddpm_num_steps),
|
||||||
|
speech_len * ddpm_batch_mul,
|
||||||
|
replacement=True,
|
||||||
|
).to(hidden_states.device)
|
||||||
|
|
||||||
|
speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0)
|
||||||
|
condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0)
|
||||||
|
|
||||||
|
noisy_speech_features = self.model.noise_scheduler.add_noise(
|
||||||
|
speech_features_repeated, noise, timesteps
|
||||||
|
)
|
||||||
|
|
||||||
|
model_output = self.model.prediction_head(
|
||||||
|
noisy_speech_features,
|
||||||
|
timesteps.type_as(x),
|
||||||
|
condition_features_repeated
|
||||||
|
)
|
||||||
|
|
||||||
|
prediction_type = self.config.diffusion_head_config.prediction_type
|
||||||
|
if prediction_type == "epsilon":
|
||||||
|
target_for_loss = noise
|
||||||
|
elif prediction_type == "v_prediction":
|
||||||
|
target_for_loss = self.model.noise_scheduler.get_velocity(
|
||||||
|
speech_features_repeated, noise, timesteps
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Prediction type {prediction_type} not implemented")
|
||||||
|
|
||||||
|
diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum')
|
||||||
|
if latent_size > 0 and ddpm_batch_mul > 0:
|
||||||
|
diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul
|
||||||
|
else:
|
||||||
|
diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Dummy loss for DDP to work when there are no speech samples in a batch,
|
||||||
|
# but we are in a speech context.
|
||||||
|
diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0
|
||||||
|
diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0
|
||||||
|
diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0
|
||||||
|
# --- End Diffusion Loss Calculation ---
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits, speech_len) + outputs.to_tuple()[1:]
|
||||||
|
return (loss, diffusion_loss) + output
|
||||||
|
|
||||||
|
return VibeVoiceCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
diffusion_loss=diffusion_loss,
|
||||||
|
speech_token_num=speech_len if speech_tensors is not None else 0,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
AutoModel.register(VibeVoiceConfig, VibeVoiceModel)
|
||||||
|
AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"VibeVoiceModel",
|
||||||
|
"VibeVoicePreTrainedModel",
|
||||||
|
"VibeVoiceForConditionalGeneration",
|
||||||
|
"VibeVoiceCausalLMOutputWithPast",
|
||||||
|
"VibeVoiceGenerationOutput",
|
||||||
|
]
|
||||||
@@ -0,0 +1,520 @@
|
|||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||||
|
|
||||||
|
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast
|
||||||
|
from transformers import modeling_utils
|
||||||
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
from transformers.utils import logging
|
||||||
|
from transformers.generation import GenerationMixin
|
||||||
|
|
||||||
|
from .modular_vibevoice_tokenizer import (
|
||||||
|
VibeVoiceTokenizerStreamingCache,
|
||||||
|
VibeVoiceTokenizerEncoderOutput
|
||||||
|
)
|
||||||
|
|
||||||
|
from .configuration_vibevoice import VibeVoiceASRConfig
|
||||||
|
from .modeling_vibevoice import (
|
||||||
|
VibeVoiceCausalLMOutputWithPast,
|
||||||
|
SpeechConnector
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
|
||||||
|
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
|
||||||
|
|
||||||
|
# @auto_docstring
|
||||||
|
class VibeVoiceASRPreTrainedModel(PreTrainedModel):
|
||||||
|
config_class = VibeVoiceASRConfig
|
||||||
|
base_model_prefix = "model"
|
||||||
|
supports_gradient_checkpointing = True
|
||||||
|
_skip_keys_device_placement = "past_key_values"
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_flash_attn = True
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
_supports_quantized_cache = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
_supports_attention_backend = True
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
|
||||||
|
# Use the language model's initializer_range if available
|
||||||
|
if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
|
||||||
|
std = self.config.language_model_config.initializer_range
|
||||||
|
elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
|
||||||
|
std = self.config.decoder_config.initializer_range
|
||||||
|
else:
|
||||||
|
std = 0.02 # Default value
|
||||||
|
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
# @auto_docstring
|
||||||
|
class VibeVoiceASRModel(VibeVoiceASRPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
|
||||||
|
if isinstance(config.torch_dtype, str):
|
||||||
|
dtype = getattr(torch, config.torch_dtype)
|
||||||
|
else:
|
||||||
|
dtype = config.torch_dtype
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
# Initialize Qwen2 model for language modeling
|
||||||
|
lm_config = config.decoder_config
|
||||||
|
self.language_model = AutoModel.from_config(lm_config)
|
||||||
|
|
||||||
|
# Initialize speech components if needed
|
||||||
|
self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
|
||||||
|
self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype)
|
||||||
|
|
||||||
|
self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||||
|
self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype)
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
if hasattr(self.language_model, 'embed_tokens'):
|
||||||
|
# If the language model has an embed_tokens attribute, return it
|
||||||
|
return self.language_model.embed_tokens
|
||||||
|
|
||||||
|
for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
|
||||||
|
if attr.orig_name == 'embed_tokens.weight':
|
||||||
|
return getattr(self.language_model, name)
|
||||||
|
assert False, 'should not arrive here'
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.language_model.embed_tokens = value
|
||||||
|
|
||||||
|
def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None):
|
||||||
|
"""Set the speech tokenizers used for encoding and decoding speech."""
|
||||||
|
self.acoustic_tokenizer = acoustic_tokenizer
|
||||||
|
self.semantic_tokenizer = semantic_tokenizer
|
||||||
|
|
||||||
|
# Reset the encoder to evaluation mode
|
||||||
|
if self.acoustic_tokenizer is not None:
|
||||||
|
self.acoustic_tokenizer.eval()
|
||||||
|
|
||||||
|
if self.semantic_tokenizer is not None:
|
||||||
|
self.semantic_tokenizer.eval()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# Forward through language model
|
||||||
|
outputs = self.language_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=outputs.last_hidden_state,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
class VibeVoiceASRForConditionalGeneration(VibeVoiceASRPreTrainedModel, GenerationMixin):
|
||||||
|
"""
|
||||||
|
VibeVoice model for Automatic Speech Recognition (ASR) with language modeling head for conditional generation.
|
||||||
|
This class is designed for inference and generation tasks.
|
||||||
|
"""
|
||||||
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.model = VibeVoiceASRModel(config)
|
||||||
|
self.vocab_size = config.decoder_config.vocab_size
|
||||||
|
|
||||||
|
# Determine the dtype to use
|
||||||
|
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
|
||||||
|
if isinstance(config.torch_dtype, str):
|
||||||
|
dtype = getattr(torch, config.torch_dtype)
|
||||||
|
else:
|
||||||
|
dtype = config.torch_dtype
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
# Initialize lm_head with the correct dtype
|
||||||
|
self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False).to(dtype)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.get_input_embeddings()
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.model.set_input_embeddings(value)
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.model.language_model = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.model.language_model
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
"""Tie the weights between the input embeddings and the output embeddings."""
|
||||||
|
if getattr(self.config.decoder_config, 'tie_word_embeddings', False):
|
||||||
|
output_embeddings = self.get_output_embeddings()
|
||||||
|
input_embeddings = self.get_input_embeddings()
|
||||||
|
if hasattr(input_embeddings, 'weight'):
|
||||||
|
output_embeddings.weight = input_embeddings.weight
|
||||||
|
else:
|
||||||
|
output_embeddings.weight = input_embeddings
|
||||||
|
|
||||||
|
def encode_speech(
|
||||||
|
self,
|
||||||
|
speech_tensors: torch.FloatTensor,
|
||||||
|
speech_masks: Optional[torch.BoolTensor] = None,
|
||||||
|
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
|
||||||
|
streaming_segment_duration: float = 60.0, # seconds
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Encode speech input into features that can be used by the language model.
|
||||||
|
This method is called once before generation to process the speech input.
|
||||||
|
|
||||||
|
For long audio (>600s by default), uses streaming processing to avoid conv overflow (>2^32).
|
||||||
|
Segments are processed independently, then concatenated before final sampling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speech_tensors: Input audio tensor [batch_size, samples]
|
||||||
|
speech_masks: Optional mask for speech features
|
||||||
|
speech_semantic_tensors: Optional pre-computed semantic tokens
|
||||||
|
streaming_segment_duration: Segment duration in seconds for streaming processing (default: 60s)
|
||||||
|
"""
|
||||||
|
if hasattr(self.config, 'torch_dtype') and self.config.torch_dtype is not None:
|
||||||
|
if isinstance(self.config.torch_dtype, str):
|
||||||
|
dtype = getattr(torch, self.config.torch_dtype)
|
||||||
|
else:
|
||||||
|
dtype = self.config.torch_dtype
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
speech_tensors = speech_tensors.to(dtype)
|
||||||
|
|
||||||
|
# Ensure proper shape: (batch, samples)
|
||||||
|
if speech_tensors.ndim == 1:
|
||||||
|
speech_tensors = speech_tensors.unsqueeze(0)
|
||||||
|
|
||||||
|
batch_size, total_samples = speech_tensors.shape
|
||||||
|
sample_rate = 24000 # fix 24kHz sample rate
|
||||||
|
|
||||||
|
# Calculate segment size in samples
|
||||||
|
segment_samples = int(streaming_segment_duration * sample_rate)
|
||||||
|
|
||||||
|
# Decide whether to use streaming based on audio length
|
||||||
|
use_streaming = total_samples > segment_samples
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if not use_streaming:
|
||||||
|
# Short audio: direct processing (original behavior)
|
||||||
|
encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))
|
||||||
|
audio_tokens = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0]
|
||||||
|
acoustic_features = self.model.acoustic_connector(audio_tokens)
|
||||||
|
|
||||||
|
# Encode semantic features
|
||||||
|
if speech_semantic_tensors is not None:
|
||||||
|
semantic_features = self.model.semantic_connector(speech_semantic_tensors)
|
||||||
|
else:
|
||||||
|
semantic_tokens = self.model.semantic_tokenizer.encode(speech_tensors.unsqueeze(1)).mean
|
||||||
|
semantic_features = self.model.semantic_connector(semantic_tokens)
|
||||||
|
else:
|
||||||
|
# Long audio: streaming processing
|
||||||
|
# print(f"Using streaming processing for long audio: {total_samples/sample_rate:.1f}s "
|
||||||
|
# f"(segment size: {streaming_segment_duration}s)")
|
||||||
|
|
||||||
|
# Initialize caches for both tokenizers
|
||||||
|
acoustic_encoder_cache = VibeVoiceTokenizerStreamingCache()
|
||||||
|
semantic_encoder_cache = VibeVoiceTokenizerStreamingCache()
|
||||||
|
acoustic_mean_segments = []
|
||||||
|
semantic_mean_segments = []
|
||||||
|
sample_indices = torch.arange(batch_size, device=speech_tensors.device)
|
||||||
|
|
||||||
|
# Helper function from batch_asr_sft_cache.py
|
||||||
|
def _iter_segments(total_length: int, segment_length: int):
|
||||||
|
"""Iterate over audio segments with a given segment length."""
|
||||||
|
if segment_length <= 0:
|
||||||
|
raise ValueError("segment_length must be positive")
|
||||||
|
for start in range(0, total_length, segment_length):
|
||||||
|
end = min(start + segment_length, total_length)
|
||||||
|
if end > start:
|
||||||
|
yield start, end
|
||||||
|
|
||||||
|
# Process each segment for both acoustic and semantic tokenizers
|
||||||
|
segments = list(_iter_segments(total_samples, segment_samples))
|
||||||
|
num_segments = len(segments)
|
||||||
|
for seg_idx, (start, end) in enumerate(segments):
|
||||||
|
chunk = speech_tensors[:, start:end].contiguous()
|
||||||
|
if chunk.numel() == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if this is the final segment
|
||||||
|
is_final = (seg_idx == num_segments - 1)
|
||||||
|
|
||||||
|
# Encode chunk for acoustic tokenizer (don't sample yet)
|
||||||
|
acoustic_encoder_output = self.model.acoustic_tokenizer.encode(
|
||||||
|
chunk.unsqueeze(1),
|
||||||
|
cache=acoustic_encoder_cache,
|
||||||
|
sample_indices=sample_indices,
|
||||||
|
use_cache=True,
|
||||||
|
is_final_chunk=is_final,
|
||||||
|
)
|
||||||
|
acoustic_mean_segments.append(acoustic_encoder_output.mean)
|
||||||
|
|
||||||
|
# Encode chunk for semantic tokenizer (take mean directly)
|
||||||
|
semantic_encoder_output = self.model.semantic_tokenizer.encode(
|
||||||
|
chunk.unsqueeze(1),
|
||||||
|
cache=semantic_encoder_cache,
|
||||||
|
sample_indices=sample_indices,
|
||||||
|
use_cache=True,
|
||||||
|
is_final_chunk=is_final,
|
||||||
|
)
|
||||||
|
semantic_mean_segments.append(semantic_encoder_output.mean)
|
||||||
|
|
||||||
|
# print(f"Processed {len(acoustic_mean_segments)} segments.")
|
||||||
|
# Concatenate all acoustic means and sample once
|
||||||
|
acoustic_mean_full = torch.cat(acoustic_mean_segments, dim=1).contiguous()
|
||||||
|
acoustic_encoder_output = VibeVoiceTokenizerEncoderOutput(
|
||||||
|
mean=acoustic_mean_full,
|
||||||
|
std=self.model.acoustic_tokenizer.fix_std
|
||||||
|
)
|
||||||
|
audio_tokens = acoustic_encoder_output.sample(
|
||||||
|
dist_type=self.model.acoustic_tokenizer.std_dist_type
|
||||||
|
)[0]
|
||||||
|
acoustic_features = self.model.acoustic_connector(audio_tokens)
|
||||||
|
|
||||||
|
# Concatenate all semantic means
|
||||||
|
semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous()
|
||||||
|
semantic_features = self.model.semantic_connector(semantic_tokens)
|
||||||
|
|
||||||
|
# Combine acoustic and semantic features
|
||||||
|
if speech_masks is not None:
|
||||||
|
combined_features = acoustic_features[speech_masks] + semantic_features[speech_masks]
|
||||||
|
else:
|
||||||
|
combined_features = acoustic_features + semantic_features
|
||||||
|
|
||||||
|
return combined_features
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
# Speech-specific arguments
|
||||||
|
speech_tensors: Optional[torch.FloatTensor] = None,
|
||||||
|
speech_masks: Optional[torch.BoolTensor] = None,
|
||||||
|
speech_semantic_tensors: Optional[torch.FloatTensor] = None,
|
||||||
|
acoustic_input_mask: Optional[torch.BoolTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[Tuple, CausalLMOutput]:
|
||||||
|
"""
|
||||||
|
Forward pass for the model. Handles both training and generation scenarios.
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
# Process inputs
|
||||||
|
if inputs_embeds is None and input_ids is not None:
|
||||||
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
# If we have speech input and acoustic_input_mask, encode and insert speech features
|
||||||
|
if speech_tensors is not None and acoustic_input_mask is not None:
|
||||||
|
speech_features = self.encode_speech(
|
||||||
|
speech_tensors=speech_tensors,
|
||||||
|
speech_masks=speech_masks,
|
||||||
|
speech_semantic_tensors=speech_semantic_tensors,
|
||||||
|
)
|
||||||
|
inputs_embeds[acoustic_input_mask] = speech_features
|
||||||
|
|
||||||
|
# Forward through the model
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=None,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return VibeVoiceCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
position_ids=None,
|
||||||
|
use_cache=True,
|
||||||
|
speech_tensors=None,
|
||||||
|
speech_masks=None,
|
||||||
|
speech_semantic_tensors=None,
|
||||||
|
acoustic_input_mask=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Prepare inputs for generation step. This method is called by generate()
|
||||||
|
for each token generation step.
|
||||||
|
|
||||||
|
Following Qwen2-VL's approach: speech inputs are only forwarded on the first pass
|
||||||
|
(when cache_position[0] == 0), and are excluded in subsequent generation steps.
|
||||||
|
"""
|
||||||
|
# If we have past key values, we only need to process the new tokens
|
||||||
|
if past_key_values is not None:
|
||||||
|
if isinstance(past_key_values, tuple):
|
||||||
|
past_length = past_key_values[0][0].shape[2]
|
||||||
|
else:
|
||||||
|
past_length = past_key_values.get_seq_length()
|
||||||
|
|
||||||
|
# Keep only the new tokens
|
||||||
|
if input_ids is not None and input_ids.shape[1] > past_length:
|
||||||
|
input_ids = input_ids[:, past_length:]
|
||||||
|
|
||||||
|
# Prepare position ids
|
||||||
|
if position_ids is None and attention_mask is not None:
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values is not None and input_ids is not None:
|
||||||
|
position_ids = position_ids[:, -input_ids.shape[1]:]
|
||||||
|
|
||||||
|
# Prepare cache position
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens,
|
||||||
|
past_seen_tokens + (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]),
|
||||||
|
device=input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare model inputs
|
||||||
|
if inputs_embeds is not None and past_key_values is None:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
|
else:
|
||||||
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Following Qwen2-VL pattern: only include speech inputs on the first forward pass
|
||||||
|
# (when cache_position[0] == 0), exclude them in subsequent generation steps
|
||||||
|
if cache_position is not None and len(cache_position) > 0 and cache_position[0] == 0:
|
||||||
|
# First forward pass - include speech inputs if provided
|
||||||
|
model_inputs.update({
|
||||||
|
"speech_tensors": speech_tensors,
|
||||||
|
"speech_masks": speech_masks,
|
||||||
|
"speech_semantic_tensors": speech_semantic_tensors,
|
||||||
|
"acoustic_input_mask": acoustic_input_mask,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# Subsequent generation steps - exclude speech inputs
|
||||||
|
model_inputs.update({
|
||||||
|
"speech_tensors": None,
|
||||||
|
"speech_masks": None,
|
||||||
|
"speech_semantic_tensors": None,
|
||||||
|
"acoustic_input_mask": None,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Include any remaining kwargs that might be needed
|
||||||
|
model_inputs.update(kwargs)
|
||||||
|
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
AutoModel.register(VibeVoiceASRConfig, VibeVoiceASRModel)
|
||||||
|
AutoModelForCausalLM.register(VibeVoiceASRConfig, VibeVoiceASRForConditionalGeneration)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"VibeVoiceASRPreTrainedModel",
|
||||||
|
"VibeVoiceASRModel",
|
||||||
|
"VibeVoiceASRForConditionalGeneration",
|
||||||
|
]
|
||||||
@@ -207,8 +207,107 @@ class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
|
|||||||
"""Id used for padding (returns -100 for loss masking)."""
|
"""Id used for padding (returns -100 for loss masking)."""
|
||||||
return self._pad_id
|
return self._pad_id
|
||||||
|
|
||||||
|
class VibeVoiceASRTextTokenizerFast(Qwen2TokenizerFast):
|
||||||
|
"""
|
||||||
|
Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
|
||||||
|
Based on the Qwen2 tokenizer with additional special tokens for speech.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (`str`, *optional*):
|
||||||
|
Path to the vocabulary file.
|
||||||
|
merges_file (`str`, *optional*):
|
||||||
|
Path to the merges file.
|
||||||
|
tokenizer_file (`str`, *optional*):
|
||||||
|
Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
|
||||||
|
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||||
|
The unknown token.
|
||||||
|
bos_token (`str`, *optional*):
|
||||||
|
The beginning of sequence token. Not used for vibevoice.
|
||||||
|
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||||
|
The end of sequence token.
|
||||||
|
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||||
|
The token used for padding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["input_ids", "attention_mask"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file=None,
|
||||||
|
merges_file=None,
|
||||||
|
tokenizer_file=None,
|
||||||
|
unk_token="<|endoftext|>",
|
||||||
|
bos_token=None,
|
||||||
|
eos_token="<|endoftext|>",
|
||||||
|
pad_token="<|endoftext|>",
|
||||||
|
add_prefix_space=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
vocab_file=vocab_file,
|
||||||
|
merges_file=merges_file,
|
||||||
|
tokenizer_file=tokenizer_file,
|
||||||
|
unk_token=unk_token,
|
||||||
|
bos_token=bos_token,
|
||||||
|
eos_token=eos_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
add_prefix_space=add_prefix_space,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add VibeVoice-specific special tokens
|
||||||
|
self._add_vibevoice_special_tokens()
|
||||||
|
|
||||||
|
# https://github.com/QwenLM/Qwen2.5-VL/blob/d2240f11656bfe404b9ba56db4e51cd09f522ff1/qwen-vl-finetune/qwenvl/data/data_qwen_packed.py#L57C5-L57C222
|
||||||
|
self.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||||
|
|
||||||
|
def _add_vibevoice_special_tokens(self):
|
||||||
|
"""Add VibeVoice-specific special tokens."""
|
||||||
|
special_tokens = {
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<|object_ref_start|>", # Speech start (reusing vision tokens)
|
||||||
|
"<|object_ref_end|>", # Speech end
|
||||||
|
"<|box_start|>", # Speech diffusion pad
|
||||||
|
]
|
||||||
|
}
|
||||||
|
num_added = self.add_special_tokens(special_tokens)
|
||||||
|
|
||||||
|
# Cache special token IDs
|
||||||
|
self._speech_start_id = self.convert_tokens_to_ids("<|object_ref_start|>")
|
||||||
|
self._speech_end_id = self.convert_tokens_to_ids("<|object_ref_end|>")
|
||||||
|
self._speech_pad_id = self.convert_tokens_to_ids("<|box_start|>")
|
||||||
|
|
||||||
|
self._eos_id = self.eos_token_id # qwen2 / qwen3
|
||||||
|
self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
|
||||||
|
|
||||||
|
return num_added
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_id(self) -> int:
|
||||||
|
"""Id of the end of sequence token."""
|
||||||
|
return self._eos_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speech_start_id(self) -> int:
|
||||||
|
"""Id of the speech start token."""
|
||||||
|
return self._speech_start_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speech_end_id(self) -> int:
|
||||||
|
"""Id of the speech end token."""
|
||||||
|
return self._speech_end_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def speech_pad_id(self) -> int:
|
||||||
|
"""Id of the speech diffusion token."""
|
||||||
|
return self._speech_pad_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad_id(self) -> int:
|
||||||
|
return self._pad_id
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"VibeVoiceTextTokenizer",
|
"VibeVoiceTextTokenizer",
|
||||||
"VibeVoiceTextTokenizerFast",
|
"VibeVoiceTextTokenizerFast",
|
||||||
|
"VibeVoiceASRTextTokenizerFast",
|
||||||
]
|
]
|
||||||
@@ -17,7 +17,7 @@ from transformers.utils import logging
|
|||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig
|
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
@@ -26,14 +26,13 @@ import os
|
|||||||
try:
|
try:
|
||||||
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
|
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
|
||||||
APEX_AVAILABLE = True
|
APEX_AVAILABLE = True
|
||||||
logger.info("APEX FusedRMSNorm is available and will be used for optimization")
|
# logger.info("APEX FusedRMSNorm is available and will be used for optimization")
|
||||||
if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
|
if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
|
||||||
APEX_AVAILABLE = False
|
APEX_AVAILABLE = False
|
||||||
logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
|
# logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
APEX_AVAILABLE = False
|
APEX_AVAILABLE = False
|
||||||
logger.warning("APEX FusedRMSNorm not available, using native implementation")
|
# logger.warning("APEX FusedRMSNorm not available, using native implementation")
|
||||||
# APEX_AVAILABLE=False
|
|
||||||
|
|
||||||
# Normalization modules
|
# Normalization modules
|
||||||
class ConvLayerNorm(nn.LayerNorm):
|
class ConvLayerNorm(nn.LayerNorm):
|
||||||
@@ -297,7 +296,8 @@ class SConv1d(nn.Module):
|
|||||||
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
|
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
|
||||||
sample_indices: Optional[torch.Tensor] = None,
|
sample_indices: Optional[torch.Tensor] = None,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
debug: bool = False) -> torch.Tensor:
|
debug: bool = False,
|
||||||
|
is_final_chunk: bool = False) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass with optional streaming support via cache.
|
Forward pass with optional streaming support via cache.
|
||||||
|
|
||||||
@@ -307,6 +307,7 @@ class SConv1d(nn.Module):
|
|||||||
sample_indices: Indices identifying each sample for cache management
|
sample_indices: Indices identifying each sample for cache management
|
||||||
use_cache: Whether to use cached states for streaming
|
use_cache: Whether to use cached states for streaming
|
||||||
debug: Whether to print debug information
|
debug: Whether to print debug information
|
||||||
|
is_final_chunk: Whether this is the final chunk (adds extra padding for alignment)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Output tensor
|
Output tensor
|
||||||
@@ -322,12 +323,13 @@ class SConv1d(nn.Module):
|
|||||||
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
|
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
|
||||||
assert len(sample_indices) == B, "sample_indices must match batch size"
|
assert len(sample_indices) == B, "sample_indices must match batch size"
|
||||||
|
|
||||||
return self._forward_streaming(x, cache, sample_indices, debug)
|
return self._forward_streaming(x, cache, sample_indices, debug, is_final_chunk)
|
||||||
|
|
||||||
def _forward_streaming(self, x: torch.Tensor,
|
def _forward_streaming(self, x: torch.Tensor,
|
||||||
cache: VibeVoiceTokenizerStreamingCache,
|
cache: VibeVoiceTokenizerStreamingCache,
|
||||||
sample_indices: torch.Tensor,
|
sample_indices: torch.Tensor,
|
||||||
debug: bool = False) -> torch.Tensor:
|
debug: bool = False,
|
||||||
|
is_final_chunk: bool = False) -> torch.Tensor:
|
||||||
"""Streaming forward pass with cache operations kept separate from compiled code"""
|
"""Streaming forward pass with cache operations kept separate from compiled code"""
|
||||||
B, C, T = x.shape
|
B, C, T = x.shape
|
||||||
|
|
||||||
@@ -350,6 +352,16 @@ class SConv1d(nn.Module):
|
|||||||
input_with_context = torch.cat([cached_states, x], dim=2)
|
input_with_context = torch.cat([cached_states, x], dim=2)
|
||||||
else:
|
else:
|
||||||
input_with_context = x
|
input_with_context = x
|
||||||
|
|
||||||
|
# For final chunk, add extra padding to ensure ceil behavior (same as non-streaming)
|
||||||
|
if is_final_chunk:
|
||||||
|
extra_padding = get_extra_padding_for_conv1d(
|
||||||
|
input_with_context, self.kernel_size, self.stride, self.padding_total
|
||||||
|
)
|
||||||
|
if extra_padding > 0:
|
||||||
|
input_with_context = pad1d(input_with_context, (0, extra_padding), mode=self.pad_mode)
|
||||||
|
if debug:
|
||||||
|
print(f"[DEBUG] Final chunk: added extra_padding={extra_padding}")
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
|
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
|
||||||
@@ -684,6 +696,135 @@ class Block1D(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Encoder component for the VibeVoice tokenizer that converts audio to latent representations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration object with model parameters
|
||||||
|
"""
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Extract parameters from config
|
||||||
|
self.channels = config.channels
|
||||||
|
self.dimension = config.dimension
|
||||||
|
self.n_filters = config.n_filters
|
||||||
|
self.ratios = list(reversed(config.ratios))
|
||||||
|
self.depths = config.depths
|
||||||
|
self.n_residual_layers = getattr(config, "n_residual_layers", 1)
|
||||||
|
self.hop_length = np.prod(self.ratios)
|
||||||
|
self.causal = config.causal
|
||||||
|
|
||||||
|
# Additional config parameters with defaults
|
||||||
|
kernel_size = getattr(config, "kernel_size", 7)
|
||||||
|
last_kernel_size = getattr(config, "last_kernel_size", 7)
|
||||||
|
norm = getattr(config, "norm", "none")
|
||||||
|
norm_params = getattr(config, "norm_params", {})
|
||||||
|
pad_mode = getattr(config, "pad_mode", "reflect")
|
||||||
|
bias = getattr(config, "bias", True)
|
||||||
|
layernorm = getattr(config, "layernorm", "LN")
|
||||||
|
layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
|
||||||
|
layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
|
||||||
|
drop_path_rate = getattr(config, "drop_path_rate", 0.0)
|
||||||
|
mixer_layer = getattr(config, "mixer_layer", "conv")
|
||||||
|
layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
|
||||||
|
disable_last_norm = getattr(config, "disable_last_norm", False)
|
||||||
|
|
||||||
|
# determine the norm type based on layernorm
|
||||||
|
if layernorm == 'LN':
|
||||||
|
norm_type = ConvLayerNorm
|
||||||
|
elif layernorm == 'RMSNorm':
|
||||||
|
norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported norm type: {layernorm}")
|
||||||
|
|
||||||
|
# stem and intermediate downsampling conv layers
|
||||||
|
stem = nn.Sequential(
|
||||||
|
SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.downsample_layers = nn.ModuleList()
|
||||||
|
self.downsample_layers.append(stem)
|
||||||
|
for i in range(len(self.ratios)):
|
||||||
|
in_ch = self.n_filters * (2 ** i)
|
||||||
|
out_ch = self.n_filters * (2 ** (i + 1))
|
||||||
|
downsample_layer = nn.Sequential(
|
||||||
|
SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
|
||||||
|
)
|
||||||
|
self.downsample_layers.append(downsample_layer)
|
||||||
|
|
||||||
|
# configure the transformer blocks
|
||||||
|
layer_type = partial(
|
||||||
|
Block1D,
|
||||||
|
mixer_layer=mixer_layer,
|
||||||
|
layernorm=layernorm,
|
||||||
|
eps=layernorm_eps,
|
||||||
|
causal=self.causal,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
norm=norm,
|
||||||
|
bias=bias,
|
||||||
|
layer_scale_init_value=layer_scale_init_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.stages = nn.ModuleList()
|
||||||
|
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
||||||
|
cur = 0
|
||||||
|
|
||||||
|
for i in range(len(self.depths)):
|
||||||
|
in_ch = self.n_filters * (2 ** i)
|
||||||
|
stage = nn.Sequential(
|
||||||
|
*[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
|
||||||
|
)
|
||||||
|
self.stages.append(stage)
|
||||||
|
cur += self.depths[i]
|
||||||
|
|
||||||
|
if not disable_last_norm:
|
||||||
|
self.norm = norm_type(in_ch, eps=layernorm_eps)
|
||||||
|
else:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
|
||||||
|
|
||||||
|
def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
|
||||||
|
for i in range(len(self.depths)):
|
||||||
|
# Apply downsampling
|
||||||
|
for layer in self.downsample_layers[i]:
|
||||||
|
if isinstance(layer, SConv1d):
|
||||||
|
x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
# Apply stage (Block1D contains Convlayer which contains SConv1d)
|
||||||
|
for block in self.stages[i]:
|
||||||
|
if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
|
||||||
|
# Block1D forward with cache support
|
||||||
|
residual = x
|
||||||
|
x = block.norm(x)
|
||||||
|
x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||||
|
if block.gamma is not None:
|
||||||
|
x = x * block.gamma.unsqueeze(-1)
|
||||||
|
x = residual + x
|
||||||
|
|
||||||
|
# FFN part
|
||||||
|
residual = x
|
||||||
|
x = block.ffn_norm(x)
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = block.ffn(x)
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
if block.ffn_gamma is not None:
|
||||||
|
x = x * block.ffn_gamma.unsqueeze(-1)
|
||||||
|
x = residual + x
|
||||||
|
else:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
return self.norm(x)
|
||||||
|
|
||||||
|
def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
|
||||||
|
x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||||
|
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class TokenizerDecoder(nn.Module):
|
class TokenizerDecoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
|
Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
|
||||||
@@ -821,15 +962,63 @@ class TokenizerDecoder(nn.Module):
|
|||||||
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VibeVoiceTokenizerEncoderOutput:
|
||||||
|
"""
|
||||||
|
Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mean (`torch.FloatTensor`): The mean parameters of the distribution.
|
||||||
|
std (`float` or `torch.FloatTensor`): Fixed standard deviation value.
|
||||||
|
"""
|
||||||
|
mean: torch.Tensor
|
||||||
|
std: Optional[Union[float, torch.Tensor]] = None
|
||||||
|
|
||||||
|
def sample(self, dist_type='fix'):
|
||||||
|
"""
|
||||||
|
Sample from the distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dist_type (`str`): Sampling method, either 'fix' or 'gaussian'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`torch.FloatTensor`: Sampled values.
|
||||||
|
`torch.FloatTensor` (optional): Standard deviation used (only when dist_type='gaussian').
|
||||||
|
"""
|
||||||
|
if dist_type == 'fix':
|
||||||
|
x = self.mean + self.std * torch.randn_like(self.mean)
|
||||||
|
return x, self.std
|
||||||
|
elif dist_type == 'gaussian':
|
||||||
|
batch_size = self.mean.size(0)
|
||||||
|
value = self.std / 0.8
|
||||||
|
std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value
|
||||||
|
|
||||||
|
while std.dim() < self.mean.dim():
|
||||||
|
std = std.unsqueeze(-1)
|
||||||
|
|
||||||
|
x = self.mean + std * torch.randn_like(self.mean)
|
||||||
|
return x, std
|
||||||
|
else:
|
||||||
|
return self.mean, self.std
|
||||||
|
|
||||||
|
def kl(self):
|
||||||
|
"""Compute KL divergence between this distribution and a standard normal."""
|
||||||
|
target = torch.zeros_like(self.mean)
|
||||||
|
return F.mse_loss(self.mean, target, reduction='none')
|
||||||
|
|
||||||
|
def mode(self):
|
||||||
|
"""Return the distribution mode (which is the mean for Gaussian)."""
|
||||||
|
return self.mean
|
||||||
|
|
||||||
class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
||||||
"""VibeVoice speech tokenizer model (only decoder) for acoustic tokens"""
|
"""VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens"""
|
||||||
|
|
||||||
config_class = VibeVoiceAcousticTokenizerConfig
|
config_class = VibeVoiceAcousticTokenizerConfig
|
||||||
base_model_prefix = "vibevoice_acoustic_tokenizer"
|
base_model_prefix = "vibevoice_acoustic_tokenizer"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
_no_split_modules = ["TokenizerDecoder"]
|
_no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -850,6 +1039,21 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
|||||||
# Default: use reversed encoder depths if decoder_depths is None
|
# Default: use reversed encoder depths if decoder_depths is None
|
||||||
decoder_depths = list(reversed(encoder_depths))
|
decoder_depths = list(reversed(encoder_depths))
|
||||||
|
|
||||||
|
# Create encoder config
|
||||||
|
encoder_config = copy.deepcopy(config)
|
||||||
|
encoder_config.dimension = config.vae_dim
|
||||||
|
encoder_config.n_filters = config.encoder_n_filters
|
||||||
|
encoder_config.ratios = config.encoder_ratios
|
||||||
|
encoder_config.depths = encoder_depths
|
||||||
|
encoder_config.norm = config.conv_norm
|
||||||
|
encoder_config.pad_mode = config.pad_mode
|
||||||
|
encoder_config.bias = config.conv_bias
|
||||||
|
encoder_config.layernorm_eps = config.layernorm_eps
|
||||||
|
encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
|
||||||
|
encoder_config.mixer_layer = config.mixer_layer
|
||||||
|
encoder_config.layer_scale_init_value = config.layer_scale_init_value
|
||||||
|
encoder_config.disable_last_norm = config.disable_last_norm
|
||||||
|
|
||||||
# Create decoder config
|
# Create decoder config
|
||||||
decoder_config = copy.deepcopy(config)
|
decoder_config = copy.deepcopy(config)
|
||||||
decoder_config.dimension = config.vae_dim
|
decoder_config.dimension = config.vae_dim
|
||||||
@@ -865,6 +1069,8 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
|||||||
decoder_config.layer_scale_init_value = config.layer_scale_init_value
|
decoder_config.layer_scale_init_value = config.layer_scale_init_value
|
||||||
decoder_config.disable_last_norm = config.disable_last_norm
|
decoder_config.disable_last_norm = config.disable_last_norm
|
||||||
|
|
||||||
|
# Initialize encoder and decoder
|
||||||
|
self.encoder = TokenizerEncoder(encoder_config)
|
||||||
self.decoder = TokenizerDecoder(decoder_config)
|
self.decoder = TokenizerDecoder(decoder_config)
|
||||||
|
|
||||||
# Initialize weights
|
# Initialize weights
|
||||||
@@ -884,6 +1090,24 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
|||||||
if module.bias is not None:
|
if module.bias is not None:
|
||||||
nn.init.zeros_(module.bias)
|
nn.init.zeros_(module.bias)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
|
||||||
|
"""Convert audio to latent representations"""
|
||||||
|
latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||||
|
return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sampling(self, encoder_output, dist_type=None):
|
||||||
|
"""Sample from the encoder output distribution"""
|
||||||
|
dist_type = dist_type or self.std_dist_type
|
||||||
|
|
||||||
|
if dist_type == 'fix':
|
||||||
|
return encoder_output.sample(dist_type='fix')
|
||||||
|
elif dist_type == 'gaussian':
|
||||||
|
return encoder_output.sample(dist_type='gaussian')
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
|
def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
|
||||||
"""Convert latent representations back to audio"""
|
"""Convert latent representations back to audio"""
|
||||||
@@ -895,10 +1119,89 @@ class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
|
|||||||
audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||||
return audio
|
return audio
|
||||||
|
|
||||||
|
def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
||||||
|
"""Full forward pass: encode audio to latents, then decode back to audio"""
|
||||||
|
encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||||
|
sampled_latents, _ = self.sampling(encoder_output)
|
||||||
|
reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||||
|
return reconstructed, sampled_latents
|
||||||
|
|
||||||
|
|
||||||
|
class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
|
||||||
|
"""VibeVoice speech tokenizer model with only encoder for semantic tokens"""
|
||||||
|
|
||||||
|
config_class = VibeVoiceSemanticTokenizerConfig
|
||||||
|
base_model_prefix = "vibevoice_semantic_tokenizer"
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
_no_split_modules = ["TokenizerEncoder"]
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
# Parse encoder depths
|
||||||
|
if isinstance(config.encoder_depths, str):
|
||||||
|
encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
|
||||||
|
else:
|
||||||
|
encoder_depths = config.encoder_depths
|
||||||
|
|
||||||
|
# Create encoder config
|
||||||
|
encoder_config = copy.deepcopy(config)
|
||||||
|
encoder_config.dimension = config.vae_dim
|
||||||
|
encoder_config.n_filters = config.encoder_n_filters
|
||||||
|
encoder_config.ratios = config.encoder_ratios
|
||||||
|
encoder_config.depths = encoder_depths
|
||||||
|
encoder_config.norm = config.conv_norm
|
||||||
|
encoder_config.pad_mode = config.pad_mode
|
||||||
|
encoder_config.bias = config.conv_bias
|
||||||
|
encoder_config.layernorm_eps = config.layernorm_eps
|
||||||
|
encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
|
||||||
|
encoder_config.mixer_layer = config.mixer_layer
|
||||||
|
encoder_config.layer_scale_init_value = config.layer_scale_init_value
|
||||||
|
encoder_config.disable_last_norm = config.disable_last_norm
|
||||||
|
|
||||||
|
# Initialize encoder and decoder
|
||||||
|
self.encoder = TokenizerEncoder(encoder_config)
|
||||||
|
|
||||||
|
# Initialize weights
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
"""Initialize weights for the model"""
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
nn.init.ones_(module.weight)
|
||||||
|
nn.init.zeros_(module.bias)
|
||||||
|
elif isinstance(module, nn.Conv1d):
|
||||||
|
nn.init.normal_(module.weight, std=self.config.weight_init_value)
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.zeros_(module.bias)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False, is_final_chunk=False):
|
||||||
|
"""Convert audio to latent representations"""
|
||||||
|
latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug, is_final_chunk=is_final_chunk)
|
||||||
|
return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sampling(self, encoder_output, dist_type=None):
|
||||||
|
"""Sample from the encoder output distribution"""
|
||||||
|
return encoder_output.sample(dist_type='none')
|
||||||
|
|
||||||
|
def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
|
||||||
|
"""Full forward pass: encode audio to latents, then decode back to audio"""
|
||||||
|
encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
||||||
|
sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
|
||||||
|
return None, sampled_latents
|
||||||
|
|
||||||
AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
|
AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
|
||||||
|
AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"VibeVoiceTokenizerStreamingCache",
|
"VibeVoiceTokenizerStreamingCache",
|
||||||
"VibeVoiceAcousticTokenizerModel",
|
"VibeVoiceAcousticTokenizerModel",
|
||||||
|
"VibeVoiceSemanticTokenizerModel",
|
||||||
]
|
]
|
||||||
@@ -0,0 +1,143 @@
|
|||||||
|
import numpy as np
|
||||||
|
from subprocess import run
|
||||||
|
from typing import List, Optional, Union, Dict, Any
|
||||||
|
|
||||||
|
COMMON_AUDIO_EXTS = [
|
||||||
|
'.mp3', '.MP3', '.Mp3', # All case variations of mp3
|
||||||
|
'.m4a',
|
||||||
|
'.mp4', '.MP4',
|
||||||
|
'.wav', '.WAV',
|
||||||
|
'.m4v',
|
||||||
|
'.aac',
|
||||||
|
'.ogg',
|
||||||
|
'.mov', '.MOV',
|
||||||
|
'.opus',
|
||||||
|
'.m4b',
|
||||||
|
'.flac',
|
||||||
|
'.wma', '.WMA',
|
||||||
|
'.rm', '.3gp', '.mpeg', '.flv', '.webm', '.mp2', '.aif', '.aiff', '.oga', '.ogv', '.mpga', '.m3u8', '.amr'
|
||||||
|
]
|
||||||
|
|
||||||
|
def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24000):
|
||||||
|
"""
|
||||||
|
Open an audio file and read as mono waveform, optionally resampling.
|
||||||
|
Returns both the audio data and the original sample rate.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
file: str
|
||||||
|
The audio file to open
|
||||||
|
resample: bool
|
||||||
|
Whether to resample the audio
|
||||||
|
target_sr: int
|
||||||
|
The target sample rate if resampling is requested
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A tuple containing:
|
||||||
|
- A NumPy array with the audio waveform in float32 dtype
|
||||||
|
- The original sample rate of the audio file
|
||||||
|
"""
|
||||||
|
if not resample:
|
||||||
|
# First, get the original sample rate
|
||||||
|
cmd_probe = [
|
||||||
|
"ffprobe",
|
||||||
|
"-v", "quiet",
|
||||||
|
"-show_entries", "stream=sample_rate",
|
||||||
|
"-of", "default=noprint_wrappers=1:nokey=1",
|
||||||
|
file
|
||||||
|
]
|
||||||
|
|
||||||
|
original_sr = int(run(cmd_probe, capture_output=True, check=True).stdout.decode().strip())
|
||||||
|
else:
|
||||||
|
original_sr = None
|
||||||
|
|
||||||
|
# Now load the audio
|
||||||
|
sr_to_use = target_sr if resample else original_sr
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-nostdin",
|
||||||
|
"-threads", "0",
|
||||||
|
"-i", file,
|
||||||
|
"-f", "s16le",
|
||||||
|
"-ac", "1",
|
||||||
|
"-acodec", "pcm_s16le",
|
||||||
|
"-ar", str(sr_to_use),
|
||||||
|
"-"
|
||||||
|
]
|
||||||
|
|
||||||
|
out = run(cmd, capture_output=True, check=True).stdout
|
||||||
|
audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||||
|
|
||||||
|
return audio_data, sr_to_use
|
||||||
|
|
||||||
|
class AudioNormalizer:
|
||||||
|
"""
|
||||||
|
Audio normalization class for VibeVoice tokenizer.
|
||||||
|
|
||||||
|
This class provides audio normalization to ensure consistent input levels
|
||||||
|
for the VibeVoice tokenizer while maintaining audio quality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
|
||||||
|
"""
|
||||||
|
Initialize the audio normalizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_dB_FS (float): Target dB FS level for the audio. Default: -25
|
||||||
|
eps (float): Small value to avoid division by zero. Default: 1e-6
|
||||||
|
"""
|
||||||
|
self.target_dB_FS = target_dB_FS
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
|
||||||
|
"""
|
||||||
|
Adjust the audio to the target dB FS level.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio (np.ndarray): Input audio signal
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (normalized_audio, rms, scalar)
|
||||||
|
"""
|
||||||
|
rms = np.sqrt(np.mean(audio**2))
|
||||||
|
scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
|
||||||
|
normalized_audio = audio * scalar
|
||||||
|
return normalized_audio, rms, scalar
|
||||||
|
|
||||||
|
def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
|
||||||
|
"""
|
||||||
|
Avoid clipping by scaling down if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio (np.ndarray): Input audio signal
|
||||||
|
scalar (float, optional): Explicit scaling factor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (normalized_audio, scalar)
|
||||||
|
"""
|
||||||
|
if scalar is None:
|
||||||
|
max_val = np.max(np.abs(audio))
|
||||||
|
if max_val > 1.0:
|
||||||
|
scalar = max_val + self.eps
|
||||||
|
else:
|
||||||
|
scalar = 1.0
|
||||||
|
|
||||||
|
return audio / scalar, scalar
|
||||||
|
|
||||||
|
def __call__(self, audio: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Normalize the audio by adjusting to target dB FS and avoiding clipping.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio (np.ndarray): Input audio signal
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Normalized audio signal
|
||||||
|
"""
|
||||||
|
# First adjust to target dB FS
|
||||||
|
audio, _, _ = self.tailor_dB_FS(audio)
|
||||||
|
# Then avoid clipping
|
||||||
|
audio, _ = self.avoid_clipping(audio)
|
||||||
|
return audio
|
||||||
@@ -0,0 +1,572 @@
|
|||||||
|
"""
|
||||||
|
Processor class for VibeVoice ASR models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Union, Dict, Any, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.tokenization_utils_base import BatchEncoding
|
||||||
|
from transformers.utils import TensorType, logging
|
||||||
|
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor, AudioNormalizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .audio_utils import load_audio_use_ffmpeg
|
||||||
|
HAS_FFMPEG_UTILS = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_FFMPEG_UTILS = False
|
||||||
|
warnings.warn("audio_utils not available, will fall back to soundfile for audio loading")
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
SYSTEM_PROMPT = "You are a helpful assistant that transcribes audio input into text output in JSON format."
|
||||||
|
|
||||||
|
|
||||||
|
class VibeVoiceASRProcessor:
|
||||||
|
"""
|
||||||
|
Processor for VibeVoice ASR (Automatic Speech Recognition) models.
|
||||||
|
|
||||||
|
This processor handles audio preprocessing and tokenization for ASR tasks,
|
||||||
|
following the exact format used in training with proper chat templates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer: The text tokenizer for processing text
|
||||||
|
audio_processor: The audio processor for processing speech
|
||||||
|
speech_tok_compress_ratio (int): Compression ratio for speech tokenization
|
||||||
|
target_sample_rate (int): Target sample rate for audio
|
||||||
|
normalize_audio (bool): Whether to normalize audio input
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer=None,
|
||||||
|
audio_processor=None,
|
||||||
|
speech_tok_compress_ratio=320,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
normalize_audio=True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.audio_processor = audio_processor or VibeVoiceTokenizerProcessor(
|
||||||
|
sampling_rate=target_sample_rate,
|
||||||
|
normalize_audio=normalize_audio
|
||||||
|
)
|
||||||
|
self.speech_tok_compress_ratio = speech_tok_compress_ratio
|
||||||
|
self.target_sample_rate = target_sample_rate
|
||||||
|
self.normalize_audio = normalize_audio
|
||||||
|
|
||||||
|
if normalize_audio:
|
||||||
|
self.audio_normalizer = AudioNormalizer()
|
||||||
|
else:
|
||||||
|
self.audio_normalizer = None
|
||||||
|
|
||||||
|
# Cache special token IDs
|
||||||
|
self._cache_special_tokens()
|
||||||
|
|
||||||
|
def _cache_special_tokens(self):
|
||||||
|
"""Cache special token IDs for efficiency."""
|
||||||
|
# Add safety checks for special tokens
|
||||||
|
if hasattr(self.tokenizer, 'speech_start_id'):
|
||||||
|
self.speech_start_id = self.tokenizer.speech_start_id
|
||||||
|
else:
|
||||||
|
self.speech_start_id = self.tokenizer.convert_tokens_to_ids("<|speech_start|>")
|
||||||
|
|
||||||
|
if hasattr(self.tokenizer, 'speech_end_id'):
|
||||||
|
self.speech_end_id = self.tokenizer.speech_end_id
|
||||||
|
else:
|
||||||
|
self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|speech_end|>")
|
||||||
|
|
||||||
|
if hasattr(self.tokenizer, 'speech_pad_id'):
|
||||||
|
self.speech_pad_id = self.tokenizer.speech_pad_id
|
||||||
|
else:
|
||||||
|
self.speech_pad_id = self.tokenizer.convert_tokens_to_ids("<|speech_pad|>")
|
||||||
|
|
||||||
|
if hasattr(self.tokenizer, 'pad_id'):
|
||||||
|
self.pad_id = self.tokenizer.pad_id
|
||||||
|
elif hasattr(self.tokenizer, 'pad_token_id'):
|
||||||
|
self.pad_id = self.tokenizer.pad_token_id
|
||||||
|
else:
|
||||||
|
self.pad_id = self.tokenizer.convert_tokens_to_ids("<|endoftext|>")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||||
|
"""
|
||||||
|
Load processor from a pretrained model path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_model_name_or_path: Path to the pretrained model
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VibeVoiceASRProcessor: The loaded processor
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from transformers.utils import cached_file
|
||||||
|
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
|
||||||
|
|
||||||
|
# Try to load configuration
|
||||||
|
config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
config_file = cached_file(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
"preprocessor_config.json",
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
with open(config_file, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not load preprocessor_config.json: {e}")
|
||||||
|
logger.warning("Using default configuration")
|
||||||
|
|
||||||
|
# Extract parameters
|
||||||
|
speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
|
||||||
|
target_sample_rate = config.get("target_sample_rate", 24000)
|
||||||
|
normalize_audio = config.get("normalize_audio", True)
|
||||||
|
|
||||||
|
# Load tokenizer
|
||||||
|
language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
|
||||||
|
logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
|
||||||
|
|
||||||
|
if 'qwen' in language_model_pretrained_name.lower():
|
||||||
|
tokenizer = VibeVoiceASRTextTokenizerFast.from_pretrained(
|
||||||
|
language_model_pretrained_name,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}")
|
||||||
|
|
||||||
|
# Load audio processor
|
||||||
|
audio_processor = VibeVoiceTokenizerProcessor(
|
||||||
|
sampling_rate=target_sample_rate,
|
||||||
|
normalize_audio=normalize_audio,
|
||||||
|
target_dB_FS=config.get("target_dB_FS", -25),
|
||||||
|
eps=config.get("eps", 1e-6),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
audio_processor=audio_processor,
|
||||||
|
speech_tok_compress_ratio=speech_tok_compress_ratio,
|
||||||
|
target_sample_rate=target_sample_rate,
|
||||||
|
normalize_audio=normalize_audio,
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
|
||||||
|
"""
|
||||||
|
Save processor configuration to a directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_directory: Directory to save the configuration
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
|
|
||||||
|
# Save processor configuration
|
||||||
|
processor_config = {
|
||||||
|
"processor_class": "VibeVoiceASRProcessor",
|
||||||
|
"speech_tok_compress_ratio": self.speech_tok_compress_ratio,
|
||||||
|
"target_sample_rate": self.target_sample_rate,
|
||||||
|
"normalize_audio": self.normalize_audio,
|
||||||
|
"target_dB_FS": -25,
|
||||||
|
"eps": 1e-6,
|
||||||
|
}
|
||||||
|
|
||||||
|
config_path = os.path.join(save_directory, "preprocessor_config.json")
|
||||||
|
with open(config_path, 'w') as f:
|
||||||
|
json.dump(processor_config, f, indent=2)
|
||||||
|
|
||||||
|
logger.info(f"Processor configuration saved in {config_path}")
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
audio: Optional[Union[str, np.ndarray, torch.Tensor, List[Union[str, np.ndarray, torch.Tensor]]]] = None,
|
||||||
|
sampling_rate: Optional[int] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
padding: bool = True,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
truncation: bool = False,
|
||||||
|
add_generation_prompt: bool = True,
|
||||||
|
use_streaming: bool = True,
|
||||||
|
context_info: Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> BatchEncoding:
|
||||||
|
"""
|
||||||
|
Process audio input for ASR model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: Audio input(s). Can be:
|
||||||
|
- str: Path to audio file
|
||||||
|
- np.ndarray: Audio array
|
||||||
|
- torch.Tensor: Audio tensor
|
||||||
|
- List of the above for batch processing
|
||||||
|
sampling_rate: Sampling rate of input audio
|
||||||
|
return_tensors: Output format ('pt' for PyTorch, 'np' for NumPy)
|
||||||
|
padding: Whether to pad batch inputs
|
||||||
|
max_length: Maximum sequence length
|
||||||
|
truncation: Whether to truncate long sequences
|
||||||
|
add_generation_prompt: Whether to add generation prompt for inference
|
||||||
|
use_streaming: Whether to use streaming mode (True by default, auto False if <60s)
|
||||||
|
context_info: Optional context information (e.g., hotwords, metadata) to help transcription
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BatchEncoding with:
|
||||||
|
- input_ids: Token IDs for the model
|
||||||
|
- attention_mask: Attention mask
|
||||||
|
- acoustic_input_mask: Mask indicating speech token positions
|
||||||
|
- speech_tensors: Processed speech features
|
||||||
|
- speech_masks: Valid speech masks
|
||||||
|
- vae_tok_seqlens: Length of each speech segment in tokens
|
||||||
|
"""
|
||||||
|
if audio is None:
|
||||||
|
raise ValueError("Audio input is required for ASR processing")
|
||||||
|
|
||||||
|
# Handle single vs batch input
|
||||||
|
if isinstance(audio, list):
|
||||||
|
is_batched = True
|
||||||
|
audio_list = audio
|
||||||
|
else:
|
||||||
|
is_batched = False
|
||||||
|
audio_list = [audio]
|
||||||
|
|
||||||
|
# Process each audio input
|
||||||
|
all_encodings = []
|
||||||
|
for audio_input in audio_list:
|
||||||
|
encoding = self._process_single_audio(
|
||||||
|
audio_input,
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
use_streaming=use_streaming,
|
||||||
|
context_info=context_info,
|
||||||
|
)
|
||||||
|
all_encodings.append(encoding)
|
||||||
|
|
||||||
|
# Combine into batch
|
||||||
|
batch_encoding = self._batch_encode(
|
||||||
|
all_encodings,
|
||||||
|
padding=padding,
|
||||||
|
max_length=max_length,
|
||||||
|
truncation=truncation,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
return batch_encoding
|
||||||
|
|
||||||
|
def _process_single_audio(
|
||||||
|
self,
|
||||||
|
audio: Union[str, np.ndarray, torch.Tensor],
|
||||||
|
sampling_rate: Optional[int] = None,
|
||||||
|
add_generation_prompt: bool = True,
|
||||||
|
use_streaming: bool = True,
|
||||||
|
context_info: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Process a single audio input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: Single audio input
|
||||||
|
sampling_rate: Audio sampling rate
|
||||||
|
add_generation_prompt: Whether to add generation prompt
|
||||||
|
context_info: Optional context information (e.g., hotwords, metadata) to help transcription
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with processed tokens and audio features
|
||||||
|
"""
|
||||||
|
# Process audio through audio processor
|
||||||
|
if isinstance(audio, str):
|
||||||
|
# Load from file using ffmpeg for better format support
|
||||||
|
if HAS_FFMPEG_UTILS:
|
||||||
|
try:
|
||||||
|
audio_array, file_sr = load_audio_use_ffmpeg(audio, resample=False)
|
||||||
|
except Exception as e:
|
||||||
|
# Fall back to soundfile if ffmpeg fails
|
||||||
|
warnings.warn(f"ffmpeg loading failed, falling back to soundfile: {e}")
|
||||||
|
import soundfile as sf
|
||||||
|
audio_array, file_sr = sf.read(audio)
|
||||||
|
if audio_array.ndim > 1:
|
||||||
|
audio_array = audio_array.mean(axis=1) # Convert to mono
|
||||||
|
else:
|
||||||
|
import soundfile as sf
|
||||||
|
audio_array, file_sr = sf.read(audio)
|
||||||
|
if audio_array.ndim > 1:
|
||||||
|
audio_array = audio_array.mean(axis=1) # Convert to mono
|
||||||
|
|
||||||
|
# Resample if needed
|
||||||
|
if file_sr != self.target_sample_rate:
|
||||||
|
import librosa
|
||||||
|
audio_array = librosa.resample(
|
||||||
|
audio_array,
|
||||||
|
orig_sr=file_sr,
|
||||||
|
target_sr=self.target_sample_rate
|
||||||
|
)
|
||||||
|
elif isinstance(audio, torch.Tensor):
|
||||||
|
audio_array = audio.cpu().numpy()
|
||||||
|
if audio_array.ndim > 1:
|
||||||
|
audio_array = audio_array.squeeze()
|
||||||
|
else:
|
||||||
|
audio_array = np.array(audio, dtype=np.float32)
|
||||||
|
if audio_array.ndim > 1:
|
||||||
|
audio_array = audio_array.squeeze()
|
||||||
|
|
||||||
|
# Ensure float32
|
||||||
|
audio_array = audio_array.astype(np.float32)
|
||||||
|
|
||||||
|
# Normalize if needed
|
||||||
|
if self.normalize_audio and self.audio_normalizer:
|
||||||
|
audio_array = self.audio_normalizer(audio_array)
|
||||||
|
|
||||||
|
# Calculate audio duration
|
||||||
|
audio_duration = len(audio_array) / self.target_sample_rate
|
||||||
|
|
||||||
|
# Auto-disable streaming for short audio (<60s)
|
||||||
|
if use_streaming and audio_duration < 60.0:
|
||||||
|
use_streaming = False
|
||||||
|
|
||||||
|
# Calculate token length based on streaming mode
|
||||||
|
# Non-streaming: uses ceil (encoder adds extra_padding for stride alignment)
|
||||||
|
# Streaming: uses floor (segments processed independently, no global alignment)
|
||||||
|
# if use_streaming:
|
||||||
|
# vae_tok_len = len(audio_array) // self.speech_tok_compress_ratio
|
||||||
|
# else:
|
||||||
|
vae_tok_len = math.ceil(len(audio_array) / self.speech_tok_compress_ratio)
|
||||||
|
|
||||||
|
# Build token sequence following training format
|
||||||
|
# 1. System prompt - use apply_chat_template then encode like in training
|
||||||
|
system_prompt_text = self.tokenizer.apply_chat_template(
|
||||||
|
[{"role": "system", "content": SYSTEM_PROMPT}],
|
||||||
|
tokenize=False
|
||||||
|
)
|
||||||
|
system_tokens = self.tokenizer.encode(system_prompt_text)
|
||||||
|
|
||||||
|
# 2. User input with speech tokens
|
||||||
|
# Build speech placeholder string
|
||||||
|
sp_start_token = self.tokenizer.convert_ids_to_tokens(self.speech_start_id)
|
||||||
|
sp_pad_token = self.tokenizer.convert_ids_to_tokens(self.speech_pad_id)
|
||||||
|
sp_end_token = self.tokenizer.convert_ids_to_tokens(self.speech_end_id)
|
||||||
|
|
||||||
|
# User suffix with audio duration info
|
||||||
|
show_keys = ['Start time', 'End time', 'Speaker ID', 'Content']
|
||||||
|
if context_info and context_info.strip():
|
||||||
|
user_suffix = f"This is a {audio_duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\nPlease transcribe it with these keys: " + ", ".join(show_keys)
|
||||||
|
else:
|
||||||
|
user_suffix = f"This is a {audio_duration:.2f} seconds audio, please transcribe it with these keys: " + ", ".join(show_keys)
|
||||||
|
|
||||||
|
user_input_string = ''.join(
|
||||||
|
[sp_start_token] + [sp_pad_token] * vae_tok_len + [sp_end_token]
|
||||||
|
) + '\n' + user_suffix
|
||||||
|
|
||||||
|
user_tokens = self.tokenizer.apply_chat_template(
|
||||||
|
[{"role": "user", "content": user_input_string}],
|
||||||
|
tokenize=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine tokens
|
||||||
|
full_tokens = system_tokens + user_tokens
|
||||||
|
|
||||||
|
# Create acoustic input mask
|
||||||
|
acoustic_input_mask = [1 if token == self.speech_pad_id else 0 for token in full_tokens]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": full_tokens,
|
||||||
|
"acoustic_input_mask": acoustic_input_mask,
|
||||||
|
"speech": audio_array,
|
||||||
|
"vae_tok_len": vae_tok_len,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _batch_encode(
|
||||||
|
self,
|
||||||
|
encodings: List[Dict[str, Any]],
|
||||||
|
padding: bool = True,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
truncation: bool = False,
|
||||||
|
return_tensors: Optional[str] = None,
|
||||||
|
) -> BatchEncoding:
|
||||||
|
"""
|
||||||
|
Combine multiple encodings into a batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encodings: List of encoded samples
|
||||||
|
padding: Whether to pad sequences
|
||||||
|
max_length: Maximum sequence length
|
||||||
|
truncation: Whether to truncate
|
||||||
|
return_tensors: Output format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BatchEncoding with batched data
|
||||||
|
"""
|
||||||
|
# Extract components
|
||||||
|
input_ids_list = [enc["input_ids"] for enc in encodings]
|
||||||
|
acoustic_masks_list = [enc["acoustic_input_mask"] for enc in encodings]
|
||||||
|
speech_list = [enc["speech"] for enc in encodings]
|
||||||
|
vae_tok_lens = [enc["vae_tok_len"] for enc in encodings]
|
||||||
|
|
||||||
|
# Determine max length for padding
|
||||||
|
if padding:
|
||||||
|
if max_length is not None:
|
||||||
|
target_length = max_length
|
||||||
|
else:
|
||||||
|
target_length = max(len(ids) for ids in input_ids_list)
|
||||||
|
|
||||||
|
# Pad sequences
|
||||||
|
padded_input_ids = []
|
||||||
|
padded_acoustic_masks = []
|
||||||
|
attention_masks = []
|
||||||
|
|
||||||
|
for input_ids, acoustic_mask in zip(input_ids_list, acoustic_masks_list):
|
||||||
|
# Truncate if needed
|
||||||
|
if truncation and len(input_ids) > target_length:
|
||||||
|
input_ids = input_ids[:target_length]
|
||||||
|
acoustic_mask = acoustic_mask[:target_length]
|
||||||
|
|
||||||
|
# Pad sequences to left (for autoregressive generation)
|
||||||
|
padding_length = target_length - len(input_ids)
|
||||||
|
padded_ids = [self.pad_id] * padding_length + input_ids
|
||||||
|
padded_acoustic = [0] * padding_length + acoustic_mask
|
||||||
|
attention_mask = [0] * padding_length + [1] * len(input_ids)
|
||||||
|
|
||||||
|
padded_input_ids.append(padded_ids)
|
||||||
|
padded_acoustic_masks.append(padded_acoustic)
|
||||||
|
attention_masks.append(attention_mask)
|
||||||
|
|
||||||
|
input_ids_list = padded_input_ids
|
||||||
|
acoustic_masks_list = padded_acoustic_masks
|
||||||
|
else:
|
||||||
|
attention_masks = [[1] * len(ids) for ids in input_ids_list]
|
||||||
|
|
||||||
|
# Process speech tensors - raw audio is 1D, so we keep it as is
|
||||||
|
max_speech_length = max(len(s) for s in speech_list)
|
||||||
|
padded_speeches = np.zeros((len(speech_list), max_speech_length), dtype=np.float32)
|
||||||
|
speech_masks = np.zeros((len(speech_list), max(vae_tok_lens)), dtype=bool)
|
||||||
|
|
||||||
|
for i, (speech, vae_len) in enumerate(zip(speech_list, vae_tok_lens)):
|
||||||
|
padded_speeches[i, :len(speech)] = speech
|
||||||
|
speech_masks[i, :vae_len] = True
|
||||||
|
|
||||||
|
# Create batch encoding
|
||||||
|
batch_encoding = BatchEncoding()
|
||||||
|
|
||||||
|
if return_tensors == "pt":
|
||||||
|
batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
|
||||||
|
batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
|
||||||
|
batch_encoding["acoustic_input_mask"] = torch.tensor(acoustic_masks_list, dtype=torch.bool)
|
||||||
|
batch_encoding["speech_tensors"] = torch.tensor(padded_speeches, dtype=torch.float32)
|
||||||
|
batch_encoding["speech_masks"] = torch.tensor(speech_masks, dtype=torch.bool)
|
||||||
|
# Note: vae_tok_seqlens and speech_type are not included as they are not model inputs
|
||||||
|
else:
|
||||||
|
batch_encoding["input_ids"] = input_ids_list if len(input_ids_list) > 1 else input_ids_list[0]
|
||||||
|
batch_encoding["attention_mask"] = attention_masks if len(attention_masks) > 1 else attention_masks[0]
|
||||||
|
batch_encoding["acoustic_input_mask"] = acoustic_masks_list if len(acoustic_masks_list) > 1 else acoustic_masks_list[0]
|
||||||
|
batch_encoding["speech_tensors"] = padded_speeches if len(padded_speeches) > 1 else padded_speeches[0]
|
||||||
|
batch_encoding["speech_masks"] = speech_masks if len(speech_masks) > 1 else speech_masks[0]
|
||||||
|
|
||||||
|
return batch_encoding
|
||||||
|
|
||||||
|
def batch_decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Decode batch of token IDs to text.
|
||||||
|
Forwards to tokenizer's batch_decode method.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||||
|
|
||||||
|
def decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Decode token IDs to text.
|
||||||
|
Forwards to tokenizer's decode method.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
def post_process_transcription(self, text: str) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Post-process the generated transcription text to extract structured data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Generated text from the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries with transcription segments
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to parse as JSON
|
||||||
|
if "```json" in text:
|
||||||
|
# Extract JSON from markdown code block
|
||||||
|
json_start = text.find("```json") + 7
|
||||||
|
json_end = text.find("```", json_start)
|
||||||
|
json_str = text[json_start:json_end].strip()
|
||||||
|
else:
|
||||||
|
# Try to find JSON array or object
|
||||||
|
json_start = text.find("[")
|
||||||
|
if json_start == -1:
|
||||||
|
json_start = text.find("{")
|
||||||
|
if json_start != -1:
|
||||||
|
# Find matching closing bracket
|
||||||
|
bracket_count = 0
|
||||||
|
json_end = json_start
|
||||||
|
for i in range(json_start, len(text)):
|
||||||
|
if text[i] in "[{":
|
||||||
|
bracket_count += 1
|
||||||
|
elif text[i] in "]}":
|
||||||
|
bracket_count -= 1
|
||||||
|
if bracket_count == 0:
|
||||||
|
json_end = i + 1
|
||||||
|
break
|
||||||
|
json_str = text[json_start:json_end]
|
||||||
|
else:
|
||||||
|
json_str = text
|
||||||
|
|
||||||
|
# Parse JSON
|
||||||
|
result = json.loads(json_str)
|
||||||
|
|
||||||
|
# Ensure it's a list
|
||||||
|
if isinstance(result, dict):
|
||||||
|
result = [result]
|
||||||
|
|
||||||
|
# Validate and clean up the result
|
||||||
|
cleaned_result = []
|
||||||
|
for item in result:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
cleaned_item = {}
|
||||||
|
# Map keys to expected format
|
||||||
|
key_mapping = {
|
||||||
|
"Start time": "start_time",
|
||||||
|
"Start": "start_time",
|
||||||
|
"End time": "end_time",
|
||||||
|
"End": "end_time",
|
||||||
|
"Speaker ID": "speaker_id",
|
||||||
|
"Speaker": "speaker_id",
|
||||||
|
"Content": "text",
|
||||||
|
}
|
||||||
|
for key, mapped_key in key_mapping.items():
|
||||||
|
if key in item:
|
||||||
|
cleaned_item[mapped_key] = item[key]
|
||||||
|
|
||||||
|
if cleaned_item:
|
||||||
|
cleaned_result.append(cleaned_item)
|
||||||
|
|
||||||
|
return cleaned_result
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"Failed to parse JSON from transcription: {e}")
|
||||||
|
logger.debug(f"Raw text: {text}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error post-processing transcription: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
"""Return the list of inputs accepted by the model."""
|
||||||
|
return ["input_ids", "attention_mask", "acoustic_input_mask", "speech_tensors", "speech_masks"]
|
||||||
|
|
||||||
|
__all__ = ["VibeVoiceASRProcessor"]
|
||||||
@@ -13,80 +13,10 @@ import torch
|
|||||||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
from .audio_utils import AudioNormalizer
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AudioNormalizer:
|
|
||||||
"""
|
|
||||||
Audio normalization class for VibeVoice tokenizer.
|
|
||||||
|
|
||||||
This class provides audio normalization to ensure consistent input levels
|
|
||||||
for the VibeVoice tokenizer while maintaining audio quality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
|
|
||||||
"""
|
|
||||||
Initialize the audio normalizer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_dB_FS (float): Target dB FS level for the audio. Default: -25
|
|
||||||
eps (float): Small value to avoid division by zero. Default: 1e-6
|
|
||||||
"""
|
|
||||||
self.target_dB_FS = target_dB_FS
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
|
|
||||||
"""
|
|
||||||
Adjust the audio to the target dB FS level.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio (np.ndarray): Input audio signal
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (normalized_audio, rms, scalar)
|
|
||||||
"""
|
|
||||||
rms = np.sqrt(np.mean(audio**2))
|
|
||||||
scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
|
|
||||||
normalized_audio = audio * scalar
|
|
||||||
return normalized_audio, rms, scalar
|
|
||||||
|
|
||||||
def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
|
|
||||||
"""
|
|
||||||
Avoid clipping by scaling down if necessary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio (np.ndarray): Input audio signal
|
|
||||||
scalar (float, optional): Explicit scaling factor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (normalized_audio, scalar)
|
|
||||||
"""
|
|
||||||
if scalar is None:
|
|
||||||
max_val = np.max(np.abs(audio))
|
|
||||||
if max_val > 1.0:
|
|
||||||
scalar = max_val + self.eps
|
|
||||||
else:
|
|
||||||
scalar = 1.0
|
|
||||||
|
|
||||||
return audio / scalar, scalar
|
|
||||||
|
|
||||||
def __call__(self, audio: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Normalize the audio by adjusting to target dB FS and avoiding clipping.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio (np.ndarray): Input audio signal
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Normalized audio signal
|
|
||||||
"""
|
|
||||||
# First adjust to target dB FS
|
|
||||||
audio, _, _ = self.tailor_dB_FS(audio)
|
|
||||||
# Then avoid clipping
|
|
||||||
audio, _ = self.avoid_clipping(audio)
|
|
||||||
return audio
|
|
||||||
|
|
||||||
|
|
||||||
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
|
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
|
||||||
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
|
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user