add VibeVoice-Realtime

This commit is contained in:
YaoyaoChang
2025-12-04 05:38:30 -08:00
parent e81395cf6d
commit fc83be5d92
39 changed files with 8190 additions and 6 deletions
Binary file not shown.

After

Width:  |  Height:  |  Size: 121 KiB

+7 -6
View File
@@ -1,6 +1,6 @@
<div align="center">
## 🎙️ VibeVoice: A Frontier Open-Source Voice AI
## 🎙️ VibeVoice: Frontier Open-Source Voice AI
[![Project Page](https://img.shields.io/badge/Project-Page-blue?logo=microsoft)](https://microsoft.github.io/VibeVoice)
[![Hugging Face](https://img.shields.io/badge/HuggingFace-Collection-orange?logo=huggingface)](https://huggingface.co/collections/microsoft/vibevoice-68a2ef24a875c44be47b034f)
[![Technical Report](https://img.shields.io/badge/Technical-Report-red?logo=adobeacrobatreader)](https://arxiv.org/pdf/2508.19205)
@@ -23,11 +23,12 @@
<img src="https://img.shields.io/badge/Status-New-brightgreen?style=flat" alt="New" />
<img src="https://img.shields.io/badge/Feature-Realtime_TTS-blue?style=flat&logo=soundcharts" alt="Realtime TTS" />
<strong>2025-12-03: 📣 We open-sourced <strong>VibeVoiceRealtime0.5B</strong>, a realtime texttospeech model that supports streaming text input.</strong>
<strong>2025-12-03: 📣 We open-sourced <a href="docs/vibevoice-realtime-0.5b.md"><strong>VibeVoiceRealtime0.5B</strong></a>, a realtime texttospeech model that supports streaming text input and robust long-form speech generation.</strong>
<br>
<a href="https://github.com/user-attachments/assets/c4fb9be1-e721-41c7-9260-5890b49c1a19" target="_blank">▶️ Watch demo video</a>
&nbsp;•&nbsp;
<a href="https://github.com/user-attachments/assets/9aa8ab3c-681d-4a02-b9ea-3f54ffd180b2" target="_blank">🎧 Listen to generated example</a>
https://github.com/user-attachments/assets/0901d274-f6ae-46ef-a0fd-3c4fba4f76dc
> (Launch your own realtime demo via the websocket example in [Usage](docs/vibevoice-realtime-0.5b.md#usage-1-launch-real-time-websocket-demo)).
</div>
@@ -41,7 +42,7 @@ VibeVoice is a novel framework designed for generating **expressive**, **long-fo
VibeVoice currently includes two model variants:
- **Long-form multi-speaker model**: Synthesizes conversational/single-speaker speech up to **90 minutes** with up to **4 distinct speakers**, surpassing the typical 12 speaker limits of many prior models.
- **Realtime streaming TTS model**: Produces initial audible speech in ~**300 ms** and supports **streaming text input** for single-speaker **realtime** speech generation; designed for low-latency generation.
- **[Realtime streaming TTS model](docs/vibevoice-realtime-0.5b.md)**: Produces initial audible speech in ~**300 ms** and supports **streaming text input** for single-speaker **real-time** speech generation; designed for low-latency generation.
A core innovation of VibeVoice is its use of continuous speech tokenizers (Acoustic and Semantic) operating at an ultra-low frame rate of 7.5 Hz. These tokenizers efficiently preserve audio fidelity while significantly boosting computational efficiency for processing long sequences. VibeVoice employs a [next-token diffusion](https://arxiv.org/abs/2412.08635) framework, leveraging a Large Language Model (LLM) to understand textual context and dialogue flow, and a diffusion head to generate high-fidelity acoustic details.
+314
View File
@@ -0,0 +1,314 @@
import argparse
import os
import re
import traceback
from typing import List, Tuple, Union, Dict, Any
import time
import torch
import copy
from vibevoice.modular.modeling_vibevoice_streaming_inference import VibeVoiceStreamingForConditionalGenerationInference
from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
class VoiceMapper:
"""Maps speaker names to voice file paths"""
def __init__(self):
self.setup_voice_presets()
# change name according to our preset voice file
new_dict = {}
for name, path in self.voice_presets.items():
if '_' in name:
name = name.split('_')[0]
if '-' in name:
name = name.split('-')[-1]
new_dict[name] = path
self.voice_presets.update(new_dict)
# print(list(self.voice_presets.keys()))
def setup_voice_presets(self):
"""Setup voice presets by scanning the voices directory."""
voices_dir = os.path.join(os.path.dirname(__file__), "voices/streaming_model")
# Check if voices directory exists
if not os.path.exists(voices_dir):
print(f"Warning: Voices directory not found at {voices_dir}")
self.voice_presets = {}
self.available_voices = {}
return
# Scan for all VOICE files in the voices directory
self.voice_presets = {}
# Get all .pt files in the voices directory
pt_files = [f for f in os.listdir(voices_dir)
if f.lower().endswith('.pt') and os.path.isfile(os.path.join(voices_dir, f))]
# Create dictionary with filename (without extension) as key
for pt_file in pt_files:
# Remove .pt extension to get the name
name = os.path.splitext(pt_file)[0]
# Create full path
full_path = os.path.join(voices_dir, pt_file)
self.voice_presets[name] = full_path
# Sort the voice presets alphabetically by name for better UI
self.voice_presets = dict(sorted(self.voice_presets.items()))
# Filter out voices that don't exist (this is now redundant but kept for safety)
self.available_voices = {
name: path for name, path in self.voice_presets.items()
if os.path.exists(path)
}
print(f"Found {len(self.available_voices)} voice files in {voices_dir}")
print(f"Available voices: {', '.join(self.available_voices.keys())}")
def get_voice_path(self, speaker_name: str) -> str:
"""Get voice file path for a given speaker name"""
# First try exact match
if speaker_name in self.voice_presets:
return self.voice_presets[speaker_name]
# Try partial matching (case insensitive)
speaker_lower = speaker_name.lower()
for preset_name, path in self.voice_presets.items():
if preset_name.lower() in speaker_lower or speaker_lower in preset_name.lower():
return path
# Default to first voice if no match found
default_voice = list(self.voice_presets.values())[0]
print(f"Warning: No voice preset found for '{speaker_name}', using default voice: {default_voice}")
return default_voice
def parse_args():
parser = argparse.ArgumentParser(description="VibeVoiceStreaming Processor TXT Input Test")
parser.add_argument(
"--model_path",
type=str,
default="microsoft/VibeVoice-Realtime-0.5B",
help="Path to the HuggingFace model directory",
)
parser.add_argument(
"--txt_path",
type=str,
default="demo/text_examples/1p_vibevoice.txt",
help="Path to the txt file containing the script",
)
parser.add_argument(
"--speaker_name",
type=str,
default="Wayne",
help="Single speaker name (e.g., --speaker_name Wayne)",
)
parser.add_argument(
"--output_dir",
type=str,
default="./outputs",
help="Directory to save output audio files",
)
parser.add_argument(
"--device",
type=str,
default=("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
help="Device for inference: cuda | mps | cpu",
)
parser.add_argument(
"--cfg_scale",
type=float,
default=1.5,
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.5)",
)
return parser.parse_args()
def main():
args = parse_args()
# Normalize potential 'mpx' typo to 'mps'
if args.device.lower() == "mpx":
print("Note: device 'mpx' detected, treating it as 'mps'.")
args.device = "mps"
# Validate mps availability if requested
if args.device == "mps" and not torch.backends.mps.is_available():
print("Warning: MPS not available. Falling back to CPU.")
args.device = "cpu"
print(f"Using device: {args.device}")
# Initialize voice mapper
voice_mapper = VoiceMapper()
# Check if txt file exists
if not os.path.exists(args.txt_path):
print(f"Error: txt file not found: {args.txt_path}")
return
# Read and parse txt file
print(f"Reading script from: {args.txt_path}")
with open(args.txt_path, 'r', encoding='utf-8') as f:
scripts = f.read().strip()
if not scripts:
print("Error: No valid scripts found in the txt file")
return
full_script = scripts.replace("", "'").replace('', '"').replace('', '"')
print(f"Loading processor & model from {args.model_path}")
processor = VibeVoiceStreamingProcessor.from_pretrained(args.model_path)
# Decide dtype & attention implementation
if args.device == "mps":
load_dtype = torch.float32 # MPS requires float32
attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS
elif args.device == "cuda":
load_dtype = torch.bfloat16
attn_impl_primary = "flash_attention_2"
else: # cpu
load_dtype = torch.float32
attn_impl_primary = "sdpa"
print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
# Load model with device-specific logic
try:
if args.device == "mps":
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=load_dtype,
attn_implementation=attn_impl_primary,
device_map=None, # load then move
)
model.to("mps")
elif args.device == "cuda":
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=load_dtype,
device_map="cuda",
attn_implementation=attn_impl_primary,
)
else: # cpu
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=load_dtype,
device_map="cpu",
attn_implementation=attn_impl_primary,
)
except Exception as e:
if attn_impl_primary == 'flash_attention_2':
print(f"[ERROR] : {type(e).__name__}: {e}")
print(traceback.format_exc())
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=load_dtype,
device_map=(args.device if args.device in ("cuda", "cpu") else None),
attn_implementation='sdpa'
)
if args.device == "mps":
model.to("mps")
else:
raise e
model.eval()
model.set_ddpm_inference_steps(num_steps=5)
if hasattr(model.model, 'language_model'):
print(f"Language model attention: {model.model.language_model.config._attn_implementation}")
target_device = args.device if args.device != "cpu" else "cpu"
voice_sample = voice_mapper.get_voice_path(args.speaker_name)
all_prefilled_outputs = torch.load(voice_sample, map_location=target_device, weights_only=False)
# Prepare inputs for the model
inputs = processor.process_input_with_cached_prompt(
text=full_script,
cached_prompt=all_prefilled_outputs,
padding=True,
return_tensors="pt",
return_attention_mask=True,
)
# Move tensors to target device
for k, v in inputs.items():
if torch.is_tensor(v):
inputs[k] = v.to(target_device)
print(f"Starting generation with cfg_scale: {args.cfg_scale}")
# Generate audio
start_time = time.time()
outputs = model.generate(
**inputs,
max_new_tokens=None,
cfg_scale=args.cfg_scale,
tokenizer=processor.tokenizer,
generation_config={'do_sample': False},
verbose=True,
all_prefilled_outputs=copy.deepcopy(all_prefilled_outputs) if all_prefilled_outputs is not None else None,
)
generation_time = time.time() - start_time
print(f"Generation time: {generation_time:.2f} seconds")
# Calculate audio duration and additional metrics
if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
# Assuming 24kHz sample rate (common for speech synthesis)
sample_rate = 24000
audio_samples = outputs.speech_outputs[0].shape[-1] if len(outputs.speech_outputs[0].shape) > 0 else len(outputs.speech_outputs[0])
audio_duration = audio_samples / sample_rate
rtf = generation_time / audio_duration if audio_duration > 0 else float('inf')
print(f"Generated audio duration: {audio_duration:.2f} seconds")
print(f"RTF (Real Time Factor): {rtf:.2f}x")
else:
print("No audio output generated")
# Calculate token metrics
input_tokens = inputs['tts_text_ids'].shape[1] # Number of input tokens
output_tokens = outputs.sequences.shape[1] # Total tokens (input + generated)
generated_tokens = output_tokens - input_tokens - all_prefilled_outputs['tts_lm']['last_hidden_state'].size(1)
print(f"Prefilling text tokens: {input_tokens}")
print(f"Generated speech tokens: {generated_tokens}")
print(f"Total tokens: {output_tokens}")
# Save output (processor handles device internally)
txt_filename = os.path.splitext(os.path.basename(args.txt_path))[0]
output_path = os.path.join(args.output_dir, f"{txt_filename}_generated.wav")
os.makedirs(args.output_dir, exist_ok=True)
processor.save_audio(
outputs.speech_outputs[0], # First (and only) batch item
output_path=output_path,
)
print(f"Saved output to {output_path}")
# Print summary
print("\n" + "="*50)
print("GENERATION SUMMARY")
print("="*50)
print(f"Input file: {args.txt_path}")
print(f"Output file: {output_path}")
print(f"Speaker names: {args.speaker_name}")
print(f"Prefilling text tokens: {input_tokens}")
print(f"Generated speech tokens: {generated_tokens}")
print(f"Total tokens: {output_tokens}")
print(f"Generation time: {generation_time:.2f} seconds")
print(f"Audio duration: {audio_duration:.2f} seconds")
print(f"RTF (Real Time Factor): {rtf:.2f}x")
print("="*50)
if __name__ == "__main__":
main()
+2
View File
@@ -0,0 +1,2 @@
Generating long-form, multi-speaker conversational audio like podcasts poses significant challenges for traditional Text-to-Speech (TTS) systems, particularly in scalability, speaker consistency, and natural turn-taking. This report presents VibeVoice, a novel model designed to synthesize long-form speech with multiple speakers by employing the next-token diffusion framework, a unified method for modeling continuous data by autoregressively generating latent vectors via diffusion.
A core component of our approach is the continuous speech tokenizers operating at an ultra-low frame rate of 7.5. This tokenizer effectively preserves audio fidelity while significantly boosting computational efficiency for processing long sequences. This enables VibeVoice to synthesize long-form speech for up to 90 minutes (in a 64K context window length) with up to 4 speakers, capturing the authentic conversational "vibe" and surpassing all known open-source and closed-source dialogue models (for example, Gemini 2.5 Pro Preview TTS). Code and checkpoint are available now.
+1
View File
@@ -0,0 +1 @@
VibeVoice is a novel framework designed for generating expressive, long-form, multi-speaker conversational audio, such as podcasts, from text. It addresses significant challenges in traditional Text-to-Speech (TTS) systems, particularly in scalability, speaker consistency, and natural turn-taking. A core innovation of VibeVoice is its use of continuous speech tokenizers operating at an ultra-low frame rate of 7.5 Hz. These tokenizers efficiently preserve audio fidelity while significantly boosting computational efficiency for processing long sequences. VibeVoice employs a next-token diffusion framework, leveraging a Large Language Model to understand textual context and dialogue flow, and a diffusion head to generate high-fidelity acoustic details. The model can synthesize speech up to 90 minutes long with up to 4 distinct speakers, surpassing the typical 1-2 speaker limits of many prior models.
+167
View File
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d1785adb",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/VibeVoice/blob/main/demo/vibevoice_realtime_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "WvIaUJD2y0yU",
"metadata": {
"id": "WvIaUJD2y0yU"
},
"source": [
"# VibeVoice-Realtime Colab — T4 Quickstart\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "e8fTKYGx7DZk",
"metadata": {
"id": "e8fTKYGx7DZk"
},
"source": [
"## Step 1: Setup Environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4wxJ6QHM-ZOb",
"metadata": {
"id": "4wxJ6QHM-ZOb"
},
"outputs": [],
"source": [
"# Check for T4 GPU\n",
"import torch\n",
"if torch.cuda.is_available() and \"T4\" in torch.cuda.get_device_name(0):\n",
" print(\"✅ T4 GPU detected\")\n",
"else:\n",
" print(\"\"\"\n",
" ⚠️ WARNING: T4 GPU not detected\n",
"\n",
" The recommended runtime for this Colab notebook is \"T4 GPU\".\n",
"\n",
" To change the runtime type:\n",
"\n",
" 1. Click on \"Runtime\" in the top navigation menu\n",
" 2. Click on \"Change runtime type\"\n",
" 3. Select \"T4 GPU\"\n",
" 4. Click \"OK\" if a \"Disconnect and delete runtime\" window appears\n",
" 5. Click on \"Save\"\n",
"\n",
" \"\"\")\n",
"\n",
"# Clone the VibeVoice repository\n",
"![ -d /content/VibeVoice ] || git clone --quiet --branch main --depth 1 https://github.com/microsoft/VibeVoice.git /content/VibeVoice\n",
"print(\"✅ Cloned VibeVoice repository\")\n",
"\n",
"# Install project dependencies\n",
"!uv pip --quiet install --system -e /content/VibeVoice\n",
"!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared && chmod +x cloudflared\n",
"print(\"✅ Installed dependencies\")\n",
"\n",
"# Download model\n",
"!HF_XET_HIGH_PERFORMANCE=1 hf download microsoft/VibeVoice-Realtime-0.5B --quiet --local-dir /content/models/VibeVoice-Realtime-0.5B > /dev/null\n",
"print(\"✅ Downloaded model: microsoft/VibeVoice-Realtime-0.5B\")\n"
]
},
{
"cell_type": "markdown",
"id": "pgKlV7153Ifi",
"metadata": {
"id": "pgKlV7153Ifi"
},
"source": [
"## Step 2: Launch VibeVoice-Realtime Demo"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "Yc1N9EHswFxA",
"metadata": {
"id": "Yc1N9EHswFxA"
},
"outputs": [],
"source": [
"import subprocess, re, time, threading\n",
"\n",
"srv = subprocess.Popen(\n",
" \"python /content/VibeVoice/demo/vibevoice_realtime_demo.py --model_path /content/models/VibeVoice-Realtime-0.5B --port 8000\",\n",
" shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True,\n",
")\n",
"cf = subprocess.Popen(\n",
" \"./cloudflared tunnel --url http://localhost:8000 --no-autoupdate\",\n",
" shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True,\n",
")\n",
"\n",
"public_url = None\n",
"server_ready = False\n",
"url_pattern = re.compile(r\"(https://[a-z0-9-]+\\.trycloudflare\\.com)\")\n",
"\n",
"def read_srv():\n",
" global server_ready\n",
" for ln in srv.stdout:\n",
" print(ln.strip())\n",
" if \"Uvicorn running on\" in ln:\n",
" server_ready = True\n",
"\n",
"def read_cf():\n",
" global public_url\n",
" for ln in cf.stdout:\n",
" m = url_pattern.search(ln)\n",
" if m:\n",
" public_url = m.group(1)\n",
" break\n",
"\n",
"threading.Thread(target=read_srv, daemon=True).start()\n",
"threading.Thread(target=read_cf, daemon=True).start()\n",
"\n",
"\n",
"while True:\n",
" if server_ready and public_url:\n",
" print(f\"✅ Public URL: {public_url}\\n\");\n",
" public_url = None\n",
" time.sleep(0.25)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"include_colab_link": true,
"machine_shape": "hm",
"name": "VibeVoice_Colab.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+17
View File
@@ -0,0 +1,17 @@
import argparse, os, uvicorn
def main():
p = argparse.ArgumentParser()
p.add_argument("--port", type=int, default=3000)
p.add_argument("--model_path", type=str, default="default_model")
p.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mpx", "mps"])
p.add_argument("--reload", action="store_true", help="Reload the model or not")
args = p.parse_args()
os.environ["MODEL_PATH"] = args.model_path
os.environ["MODEL_DEVICE"] = args.device
uvicorn.run("web.app:app", host="0.0.0.0", port=args.port, reload=args.reload)
if __name__ == "__main__":
main()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+506
View File
@@ -0,0 +1,506 @@
import datetime
import builtins
import asyncio
import json
import os
import threading
import traceback
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, cast
import numpy as np
import torch
from fastapi import FastAPI, WebSocket
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from starlette.websockets import WebSocketDisconnect, WebSocketState
from vibevoice.modular.modeling_vibevoice_streaming_inference import (
VibeVoiceStreamingForConditionalGenerationInference,
)
from vibevoice.processor.vibevoice_streaming_processor import (
VibeVoiceStreamingProcessor,
)
from vibevoice.modular.streamer import AudioStreamer
import copy
BASE = Path(__file__).parent
SAMPLE_RATE = 24_000
def get_timestamp():
timestamp = datetime.datetime.utcnow().replace(
tzinfo=datetime.timezone.utc
).astimezone(
datetime.timezone(datetime.timedelta(hours=8))
).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
return timestamp
class StreamingTTSService:
def __init__(
self,
model_path: str,
device: str = "cuda",
inference_steps: int = 5,
) -> None:
self.model_path = Path(model_path)
self.inference_steps = inference_steps
self.sample_rate = SAMPLE_RATE
self.processor: Optional[VibeVoiceStreamingProcessor] = None
self.model: Optional[VibeVoiceStreamingForConditionalGenerationInference] = None
self.voice_presets: Dict[str, Path] = {}
self.default_voice_key: Optional[str] = None
self._voice_cache: Dict[str, Tuple[object, Path, str]] = {}
if device == "mpx":
print("Note: device 'mpx' detected, treating it as 'mps'.")
device = "mps"
if device == "mps" and not torch.backends.mps.is_available():
print("Warning: MPS not available. Falling back to CPU.")
device = "cpu"
self.device = device
self._torch_device = torch.device(device)
def load(self) -> None:
print(f"[startup] Loading processor from {self.model_path}")
self.processor = VibeVoiceStreamingProcessor.from_pretrained(str(self.model_path))
# Decide dtype & attention
if self.device == "mps":
load_dtype = torch.float32
device_map = None
attn_impl_primary = "sdpa"
elif self.device == "cuda":
load_dtype = torch.bfloat16
device_map = 'cuda'
attn_impl_primary = "flash_attention_2"
else:
load_dtype = torch.float32
device_map = 'cpu'
attn_impl_primary = "sdpa"
print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
# Load model
try:
self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
str(self.model_path),
torch_dtype=load_dtype,
device_map=device_map,
attn_implementation=attn_impl_primary,
)
if self.device == "mps":
self.model.to("mps")
except Exception as e:
if attn_impl_primary == 'flash_attention_2':
print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
str(self.model_path),
torch_dtype=load_dtype,
device_map=self.device,
attn_implementation='sdpa',
)
print("Load model with SDPA successfully ")
else:
raise e
self.model.eval()
self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
self.model.model.noise_scheduler.config,
algorithm_type="sde-dpmsolver++",
beta_schedule="squaredcos_cap_v2",
)
self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
self.voice_presets = self._load_voice_presets()
preset_name = os.environ.get("VOICE_PRESET")
self.default_voice_key = self._determine_voice_key(preset_name)
self._ensure_voice_cached(self.default_voice_key)
def _load_voice_presets(self) -> Dict[str, Path]:
voices_dir = BASE.parent / "voices" / "streaming_model"
if not voices_dir.exists():
raise RuntimeError(f"Voices directory not found: {voices_dir}")
presets: Dict[str, Path] = {}
for pt_path in voices_dir.glob("*.pt"):
presets[pt_path.stem] = pt_path
if not presets:
raise RuntimeError(f"No voice preset (.pt) files found in {voices_dir}")
print(f"[startup] Found {len(presets)} voice presets")
return dict(sorted(presets.items()))
def _determine_voice_key(self, name: Optional[str]) -> str:
if name and name in self.voice_presets:
return name
default_key = "en-WHTest_man"
if default_key in self.voice_presets:
return default_key
first_key = next(iter(self.voice_presets))
print(f"[startup] Using fallback voice preset: {first_key}")
return first_key
def _ensure_voice_cached(self, key: str) -> Tuple[object, Path, str]:
if key not in self.voice_presets:
raise RuntimeError(f"Voice preset {key!r} not found")
if key not in self._voice_cache:
preset_path = self.voice_presets[key]
print(f"[startup] Loading voice preset {key} from {preset_path}")
print(f"[startup] Loading prefilled prompt from {preset_path}")
prefilled_outputs = torch.load(
preset_path,
map_location=self._torch_device,
weights_only=False,
)
self._voice_cache[key] = prefilled_outputs
return self._voice_cache[key]
def _get_voice_resources(self, requested_key: Optional[str]) -> Tuple[str, object, Path, str]:
key = requested_key if requested_key and requested_key in self.voice_presets else self.default_voice_key
if key is None:
key = next(iter(self.voice_presets))
self.default_voice_key = key
prefilled_outputs = self._ensure_voice_cached(key)
return key, prefilled_outputs
def _prepare_inputs(self, text: str, prefilled_outputs: object):
if not self.processor or not self.model:
raise RuntimeError("StreamingTTSService not initialized")
processor_kwargs = {
"text": text.strip(),
"cached_prompt": prefilled_outputs,
"padding": True,
"return_tensors": "pt",
"return_attention_mask": True,
}
processed = self.processor.process_input_with_cached_prompt(**processor_kwargs)
prepared = {
key: value.to(self._torch_device) if hasattr(value, "to") else value
for key, value in processed.items()
}
return prepared
def _run_generation(
self,
inputs,
audio_streamer: AudioStreamer,
errors,
cfg_scale: float,
do_sample: bool,
temperature: float,
top_p: float,
refresh_negative: bool,
prefilled_outputs,
stop_event: threading.Event,
) -> None:
try:
self.model.generate(
**inputs,
max_new_tokens=None,
cfg_scale=cfg_scale,
tokenizer=self.processor.tokenizer,
generation_config={
"do_sample": do_sample,
"temperature": temperature if do_sample else 1.0,
"top_p": top_p if do_sample else 1.0,
},
audio_streamer=audio_streamer,
stop_check_fn=stop_event.is_set,
verbose=False,
refresh_negative=refresh_negative,
all_prefilled_outputs=copy.deepcopy(prefilled_outputs),
)
except Exception as exc: # pragma: no cover - diagnostic logging
errors.append(exc)
traceback.print_exc()
audio_streamer.end()
def stream(
self,
text: str,
cfg_scale: float = 1.5,
do_sample: bool = False,
temperature: float = 0.9,
top_p: float = 0.9,
refresh_negative: bool = True,
inference_steps: Optional[int] = None,
voice_key: Optional[str] = None,
log_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
stop_event: Optional[threading.Event] = None,
) -> Iterator[np.ndarray]:
if not text.strip():
return
text = text.replace("", "'")
selected_voice, prefilled_outputs = self._get_voice_resources(voice_key)
def emit(event: str, **payload: Any) -> None:
if log_callback:
try:
log_callback(event, **payload)
except Exception as exc:
print(f"[log_callback] Error while emitting {event}: {exc}")
steps_to_use = self.inference_steps
if inference_steps is not None:
try:
parsed_steps = int(inference_steps)
if parsed_steps > 0:
steps_to_use = parsed_steps
except (TypeError, ValueError):
pass
if self.model:
self.model.set_ddpm_inference_steps(num_steps=steps_to_use)
self.inference_steps = steps_to_use
inputs = self._prepare_inputs(text, prefilled_outputs)
audio_streamer = AudioStreamer(batch_size=1, stop_signal=None, timeout=None)
errors: list = []
stop_signal = stop_event or threading.Event()
thread = threading.Thread(
target=self._run_generation,
kwargs={
"inputs": inputs,
"audio_streamer": audio_streamer,
"errors": errors,
"cfg_scale": cfg_scale,
"do_sample": do_sample,
"temperature": temperature,
"top_p": top_p,
"refresh_negative": refresh_negative,
"prefilled_outputs": prefilled_outputs,
"stop_event": stop_signal,
},
daemon=True,
)
thread.start()
generated_samples = 0
try:
stream = audio_streamer.get_stream(0)
for audio_chunk in stream:
if torch.is_tensor(audio_chunk):
audio_chunk = audio_chunk.detach().cpu().to(torch.float32).numpy()
else:
audio_chunk = np.asarray(audio_chunk, dtype=np.float32)
if audio_chunk.ndim > 1:
audio_chunk = audio_chunk.reshape(-1)
peak = np.max(np.abs(audio_chunk)) if audio_chunk.size else 0.0
if peak > 1.0:
audio_chunk = audio_chunk / peak
generated_samples += int(audio_chunk.size)
emit(
"model_progress",
generated_sec=generated_samples / self.sample_rate,
chunk_sec=audio_chunk.size / self.sample_rate,
)
chunk_to_yield = audio_chunk.astype(np.float32, copy=False)
yield chunk_to_yield
finally:
stop_signal.set()
audio_streamer.end()
thread.join()
if errors:
emit("generation_error", message=str(errors[0]))
raise errors[0]
def chunk_to_pcm16(self, chunk: np.ndarray) -> bytes:
chunk = np.clip(chunk, -1.0, 1.0)
pcm = (chunk * 32767.0).astype(np.int16)
return pcm.tobytes()
app = FastAPI()
@app.on_event("startup")
async def _startup() -> None:
model_path = os.environ.get("MODEL_PATH")
if not model_path:
raise RuntimeError("MODEL_PATH not set in environment")
device = os.environ.get("MODEL_DEVICE", "cuda")
service = StreamingTTSService(
model_path=model_path,
device=device
)
service.load()
app.state.tts_service = service
app.state.model_path = model_path
app.state.device = device
app.state.websocket_lock = asyncio.Lock()
print("[startup] Model ready.")
def streaming_tts(text: str, **kwargs) -> Iterator[np.ndarray]:
service: StreamingTTSService = app.state.tts_service
yield from service.stream(text, **kwargs)
@app.websocket("/stream")
async def websocket_stream(ws: WebSocket) -> None:
await ws.accept()
text = ws.query_params.get("text", "")
print(f"Client connected, text={text!r}")
cfg_param = ws.query_params.get("cfg")
steps_param = ws.query_params.get("steps")
voice_param = ws.query_params.get("voice")
try:
cfg_scale = float(cfg_param) if cfg_param is not None else 1.5
except ValueError:
cfg_scale = 1.5
if cfg_scale <= 0:
cfg_scale = 1.5
try:
inference_steps = int(steps_param) if steps_param is not None else None
if inference_steps is not None and inference_steps <= 0:
inference_steps = None
except ValueError:
inference_steps = None
service: StreamingTTSService = app.state.tts_service
lock: asyncio.Lock = app.state.websocket_lock
if lock.locked():
busy_message = {
"type": "log",
"event": "backend_busy",
"data": {"message": "Please wait for the other requests to complete."},
"timestamp": get_timestamp(),
}
print("Please wait for the other requests to complete.")
try:
await ws.send_text(json.dumps(busy_message))
except Exception:
pass
await ws.close(code=1013, reason="Service busy")
return
acquired = False
try:
await lock.acquire()
acquired = True
log_queue: "Queue[Dict[str, Any]]" = Queue()
def enqueue_log(event: str, **data: Any) -> None:
log_queue.put({"event": event, "data": data})
async def flush_logs() -> None:
while True:
try:
entry = log_queue.get_nowait()
except Empty:
break
message = {
"type": "log",
"event": entry.get("event"),
"data": entry.get("data", {}),
"timestamp": get_timestamp(),
}
try:
await ws.send_text(json.dumps(message))
except Exception:
break
enqueue_log(
"backend_request_received",
text_length=len(text or ""),
cfg_scale=cfg_scale,
inference_steps=inference_steps,
voice=voice_param,
)
stop_signal = threading.Event()
iterator = streaming_tts(
text,
cfg_scale=cfg_scale,
inference_steps=inference_steps,
voice_key=voice_param,
log_callback=enqueue_log,
stop_event=stop_signal,
)
sentinel = object()
first_ws_send_logged = False
await flush_logs()
try:
while ws.client_state == WebSocketState.CONNECTED:
await flush_logs()
chunk = await asyncio.to_thread(next, iterator, sentinel)
if chunk is sentinel:
break
chunk = cast(np.ndarray, chunk)
payload = service.chunk_to_pcm16(chunk)
await ws.send_bytes(payload)
if not first_ws_send_logged:
first_ws_send_logged = True
enqueue_log("backend_first_chunk_sent")
await flush_logs()
except WebSocketDisconnect:
print("Client disconnected (WebSocketDisconnect)")
enqueue_log("client_disconnected")
stop_signal.set()
finally:
stop_signal.set()
enqueue_log("backend_stream_complete")
await flush_logs()
try:
iterator_close = getattr(iterator, "close", None)
if callable(iterator_close):
iterator_close()
except Exception:
pass
# clear the log queue
while not log_queue.empty():
try:
log_queue.get_nowait()
except Empty:
break
if ws.client_state == WebSocketState.CONNECTED:
await ws.close()
print("WS handler exit")
finally:
if acquired:
lock.release()
@app.get("/")
def index():
return FileResponse(BASE / "index.html")
@app.get("/config")
def get_config():
service: StreamingTTSService = app.state.tts_service
voices = sorted(service.voice_presets.keys())
return {
"voices": voices,
"default_voice": service.default_voice_key,
}
+1017
View File
File diff suppressed because it is too large Load Diff
+137
View File
@@ -0,0 +1,137 @@
<div align="center">
## 🎙️ VibeVoice-Realtime: Real-time LongForm TexttoSpeech with Streaming Input
[![Hugging Face](https://img.shields.io/badge/HuggingFace-Collection-orange?logo=huggingface)](https://huggingface.co/microsoft/VibeVoice-Realtime-0.5B)
[![Colab](https://img.shields.io/badge/Run-Colab-orange?logo=googlecolab)](https://colab.research.google.com/github/microsoft/VibeVoice/blob/main/demo/vibevoice_realtime_colab.ipynb)
</div>
VibeVoice-Realtime is a **lightweight realtime** text-to-speech model supporting **streaming text input** and **robust long-form speech generation**. It can be used to build real-time TTS services, narrate live data streams, and let different LLMs start speaking from their very first tokens (plug in your preferred model) long before a full answer is generated. It produces initial audible speech in **~300 milliseconds** (hardware dependent).
<div align="center">
| Model | Context Length | Generation Length | Weight |
|-------|----------------|----------|----------|
| VibeVoice-Realtime-0.5B | 8K | ~10 min | [HF link](https://huggingface.co/microsoft/VibeVoice-Realtime-0.5B) |
</div>
The model uses an interleaved, windowed design: it incrementally encodes incoming text chunks while, in parallel, continuing diffusion-based acoustic latent generation from prior context. Unlike the full multi-speaker long-form variants, this streaming model removes the semantic tokenizer and relies solely on an efficient acoustic tokenizer operating at an ultra-low frame rate (7.5 Hz).
<div align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="../Figures/VibeVoice_logo_white.png">
<img src="../Figures/VibeVoice_Realtime.png" alt="VibeVoice Realtime Overview" width="800" />
</picture>
<br>
<em>Overview of VibeVoice Realtime Model.</em>
</div>
Key features:
- Parameter size: 0.5B (deployment-friendly)
- Real-time TTS (~300 milliseconds first audible latency)
- Streaming text input
- Robust long-form speech generation
This real-time variant supports only a single speaker. For multispeaker conversational speech generation, please use other VibeVoice models (longform multispeaker variants). The model is currently intended for English speech only; other languages may produce unpredictable results.
To mitigate deepfake risks and ensure low latency for the first speech chunk, voice prompts are provided in an embedded format. For users requiring voice customization, please reach out to our team. We will also be expanding the range of available speakers.
### 📋 TODO
- [ ] Add more voices (expand available speakers/voice timbres)
- [ ] Implement streaming text input function to feed new tokens while audio is still being generated
- [ ] Merge models into official HuggingFace's `transformers` repository
### 🎵 Demo Examples
<div align="center" id="generated-example-audio-vibevoice-realtime">
https://github.com/user-attachments/assets/9aa8ab3c-681d-4a02-b9ea-3f54ffd180b2
</div>
## Results
The model achieves satisfactory performance on short-sentence benchmarks, while the model is more focused on longform speech generation.
### Zero-shot TTS performance on LibriSpeech test-clean set
| Model | WER (%) ↓ | Speaker Similarity ↑ |
|:--------------------|:---------:|:----------------:|
| VALL-E 2 | 2.40 | 0.643 |
| Voicebox | 1.90 | 0.662 |
| MELLE | 2.10 | 0.625 |
| **VibeVoice-Realtime-0.5B** | 2.00 | 0.695 |
### Zero-shot TTS performance on SEED test-en set
| Model | WER (%) ↓ | Speaker Similarity ↑ |
|:--------------------|:---------:|:----------------:|
| MaskGCT | 2.62 | 0.714 |
| Seed-TTS | 2.25 | 0.762 |
| FireRedTTS | 3.82 | 0.460 |
| SparkTTS | 1.98 | 0.584 |
| CosyVoice2 | 2.57 | 0.652 |
| **VibeVoice-Realtime-0.5B** | 2.05 | 0.633 |
## Installation
We recommend to use NVIDIA Deep Learning Container to manage the CUDA environment.
1. Launch docker
```bash
# NVIDIA PyTorch Container 24.07 / 24.10 / 24.12 verified.
# Later versions are also compatible.
sudo docker run --privileged --net=host --ipc=host --ulimit memlock=-1:-1 --ulimit stack=-1:-1 --gpus all --rm -it nvcr.io/nvidia/pytorch:24.07-py3
## If flash attention is not included in your docker environment, you need to install it manually
## Refer to https://github.com/Dao-AILab/flash-attention for installation instructions
# pip install flash-attn --no-build-isolation
```
2. Install from github
```bash
git clone https://github.com/microsoft/VibeVoice.git
cd VibeVoice/
pip install -e .
```
## Usages
### Usage 1: Launch real-time websocket demo
Note: NVIDIA T4 / Mac M4 Pro achieve realtime in our tests; other devices with weaker inference capability may require further testing and speed optimizations.
Due to network latency, the time when audio playback is heard may exceed the ~300 ms first speech chunk generation latency.
```bash
python demo/vibevoice_realtime_demo.py --model_path microsoft/VibeVoice-Realtime-0.5B
```
Tip: You can also deploy and run the real-time demo on [Colab](https://colab.research.google.com/github/microsoft/VibeVoice/blob/main/demo/vibevoice_realtime_colab.ipynb).
### Usage 2: Inference from files directly
```bash
# We provide some LLM generated example scripts under demo/text_examples/ for demo
python demo/realtime_model_inference_from_file.py --model_path microsoft/VibeVoice-Realtime-0.5B --txt_path demo/text_examples/1p_vibevoice.txt --speaker_name Carter
```
## Risks and limitations
While efforts have been made to optimize it through various techniques, it may still produce outputs that are unexpected, biased, or inaccurate. VibeVoice inherits any biases, errors, or omissions produced by its base model (specifically, Qwen2.5 0.5b in this release).
Potential for Deepfakes and Disinformation: High-quality synthetic speech can be misused to create convincing fake audio content for impersonation, fraud, or spreading disinformation. Users must ensure transcripts are reliable, check content accuracy, and avoid using generated content in misleading ways. Users are expected to use the generated content and to deploy the models in a lawful manner, in full compliance with all applicable laws and regulations in the relevant jurisdictions. It is best practice to disclose the use of AI when sharing AI-generated content.
English only: Transcripts in languages other than English may result in unexpected audio outputs.
Non-Speech Audio: The model focuses solely on speech synthesis and does not handle background noise, music, or other sound effects.
Code, formulas, and special symbols: The model does not currently support reading code, mathematical formulas, or uncommon symbols. Please preprocess input text to remove or normalize such content to avoid unpredictable results.
Very short inputs: When the input text is extremely short (three words or fewer), the models stability may degrade.
We do not recommend using VibeVoice in commercial or real-world applications without further testing and development. This model is intended for research and development purposes only. Please use responsibly.
+45
View File
@@ -0,0 +1,45 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "vibevoice"
version = "0.0.1"
authors = [
{ name="vibevoice team", email="vibepod@microsoft.com" },
]
description = "A model for speech generation with an AR + diffusion architecture."
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
# "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"torch",
"accelerate==1.6.0",
"transformers==4.51.3", # we develop this project on transformers==4.51.3, later version may not be compatible
"llvmlite>=0.40.0",
"numba>=0.57.0",
"diffusers",
"tqdm",
"numpy",
"scipy",
"librosa",
"ml-collections",
"absl-py",
"gradio",
"av",
"aiortc",
"uvicorn[standard]",
"fastapi"
]
[project.urls]
"Homepage" = "https://github.com/microsoft/VibeVoice"
"Bug Tracker" = "https://github.com/microsoft/VibeVoice/issues"
[tool.setuptools.packages.find]
where = ["."]
View File
+112
View File
@@ -0,0 +1,112 @@
{
"_attn_implementation_autoset": true,
"acoustic_vae_dim": 64,
"acoustic_tokenizer_config": {
"causal": true,
"channels": 1,
"conv_bias": true,
"conv_norm": "none",
"corpus_normalize": 0.0,
"decoder_depths": null,
"decoder_n_filters": 32,
"decoder_ratios": [
8,
5,
5,
4,
2,
2
],
"disable_last_norm": true,
"encoder_depths": "3-3-3-3-3-3-8",
"encoder_n_filters": 32,
"encoder_ratios": [
8,
5,
5,
4,
2,
2
],
"fix_std": 0.5,
"layer_scale_init_value": 1e-06,
"layernorm": "RMSNorm",
"layernorm_elementwise_affine": true,
"layernorm_eps": 1e-05,
"mixer_layer": "depthwise_conv",
"model_type": "vibepod_acoustic_tokenizer",
"pad_mode": "constant",
"std_dist_type": "gaussian",
"vae_dim": 64,
"weight_init_value": 0.01
},
"decoder_config": {
"attention_dropout": 0.0,
"hidden_act": "silu",
"hidden_size": 1536,
"initializer_range": 0.02,
"intermediate_size": 8960,
"max_position_embeddings": 65536,
"max_window_layers": 28,
"model_type": "qwen2",
"num_attention_heads": 12,
"num_hidden_layers": 28,
"num_key_value_heads": 2,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
},
"diffusion_head_config": {
"ddpm_batch_mul": 4,
"ddpm_beta_schedule": "cosine",
"ddpm_num_inference_steps": 20,
"ddpm_num_steps": 1000,
"diffusion_type": "ddpm",
"head_ffn_ratio": 3.0,
"head_layers": 4,
"hidden_size": 1536,
"latent_size": 64,
"model_type": "vibepod_diffusion_head",
"prediction_type": "v_prediction",
"rms_norm_eps": 1e-05,
"speech_vae_dim": 64
},
"model_type": "vibepod",
"semantic_tokenizer_config": {
"causal": true,
"channels": 1,
"conv_bias": true,
"conv_norm": "none",
"corpus_normalize": 0.0,
"disable_last_norm": true,
"encoder_depths": "3-3-3-3-3-3-8",
"encoder_n_filters": 32,
"encoder_ratios": [
8,
5,
5,
4,
2,
2
],
"fix_std": 0,
"layer_scale_init_value": 1e-06,
"layernorm": "RMSNorm",
"layernorm_elementwise_affine": true,
"layernorm_eps": 1e-05,
"mixer_layer": "depthwise_conv",
"model_type": "vibepod_semantic_tokenizer",
"pad_mode": "constant",
"std_dist_type": "none",
"vae_dim": 128,
"weight_init_value": 0.01
},
"semantic_vae_dim": 128,
"torch_dtype": "bfloat16"
}
+113
View File
@@ -0,0 +1,113 @@
{
"_attn_implementation_autoset": true,
"acoustic_vae_dim": 64,
"acoustic_tokenizer_config": {
"causal": true,
"channels": 1,
"conv_bias": true,
"conv_norm": "none",
"corpus_normalize": 0.0,
"decoder_depths": null,
"decoder_n_filters": 32,
"decoder_ratios": [
8,
5,
5,
4,
2,
2
],
"disable_last_norm": true,
"encoder_depths": "3-3-3-3-3-3-8",
"encoder_n_filters": 32,
"encoder_ratios": [
8,
5,
5,
4,
2,
2
],
"fix_std": 0.5,
"layer_scale_init_value": 1e-06,
"layernorm": "RMSNorm",
"layernorm_elementwise_affine": true,
"layernorm_eps": 1e-05,
"mixer_layer": "depthwise_conv",
"model_type": "vibepod_acoustic_tokenizer",
"pad_mode": "constant",
"std_dist_type": "gaussian",
"vae_dim": 64,
"weight_init_value": 0.01
},
"decoder_config": {
"attention_dropout": 0.0,
"hidden_act": "silu",
"hidden_size": 3584,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 32768,
"max_window_layers": 28,
"model_type": "qwen2",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.1",
"use_cache": true,
"use_mrope": false,
"use_sliding_window": false,
"vocab_size": 152064
},
"diffusion_head_config": {
"ddpm_batch_mul": 4,
"ddpm_beta_schedule": "cosine",
"ddpm_num_inference_steps": 20,
"ddpm_num_steps": 1000,
"diffusion_type": "ddpm",
"head_ffn_ratio": 3.0,
"head_layers": 4,
"hidden_size": 3584,
"latent_size": 64,
"model_type": "vibepod_diffusion_head",
"prediction_type": "v_prediction",
"rms_norm_eps": 1e-05,
"speech_vae_dim": 64
},
"model_type": "vibepod",
"semantic_tokenizer_config": {
"causal": true,
"channels": 1,
"conv_bias": true,
"conv_norm": "none",
"corpus_normalize": 0.0,
"disable_last_norm": true,
"encoder_depths": "3-3-3-3-3-3-8",
"encoder_n_filters": 32,
"encoder_ratios": [
8,
5,
5,
4,
2,
2
],
"fix_std": 0,
"layer_scale_init_value": 1e-06,
"layernorm": "RMSNorm",
"layernorm_elementwise_affine": true,
"layernorm_eps": 1e-05,
"mixer_layer": "depthwise_conv",
"model_type": "vibepod_semantic_tokenizer",
"pad_mode": "constant",
"std_dist_type": "none",
"vae_dim": 128,
"weight_init_value": 0.01
},
"semantic_vae_dim": 128,
"torch_dtype": "bfloat16"
}
View File
@@ -0,0 +1,248 @@
""" VibeVoice_AcousticTokenizer model configuration"""
from typing import Dict, List, Optional, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
logger = logging.get_logger(__name__)
class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
model_type = "vibevoice_acoustic_tokenizer"
def __init__(
self,
channels: int = 1,
corpus_normalize: float = 0.0,
causal: bool = True,
vae_dim: int = 64,
fix_std: float = 0.5,
std_dist_type: str = 'gaussian',
# common
mixer_layer: str = 'depthwise_conv',
conv_norm: str = 'none',
pad_mode: str = 'constant',
disable_last_norm: bool = True,
layernorm: str = 'RMSNorm',
layernorm_eps: float = 1e-5,
layernorm_elementwise_affine: bool = True,
conv_bias: bool = True,
layer_scale_init_value: float = 1e-6,
weight_init_value: float = 1e-2,
# encoder specific
encoder_n_filters: int = 32,
encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
encoder_depths: str = "3-3-3-3-3-3-8",
# decoder specific
decoder_n_filters: int = 32,
decoder_ratios: Optional[List[int]] = None, # if None, same as encoder
decoder_depths: Optional[str] = None,
**kwargs
):
super().__init__(**kwargs)
self.channels = channels
self.corpus_normalize = corpus_normalize
self.causal = causal
self.vae_dim = vae_dim
self.fix_std = fix_std
self.std_dist_type = std_dist_type
# common parameters
self.conv_norm = conv_norm
self.pad_mode = pad_mode
self.layernorm_eps = layernorm_eps
self.disable_last_norm = disable_last_norm
self.layernorm = layernorm
self.layernorm_elementwise_affine = layernorm_elementwise_affine
self.conv_bias = conv_bias
self.layer_scale_init_value = layer_scale_init_value
self.weight_init_value = weight_init_value
self.mixer_layer = mixer_layer
# encoder specific parameters
self.encoder_n_filters = encoder_n_filters
self.encoder_ratios = encoder_ratios
self.encoder_depths = encoder_depths
# decoder specific parameters
self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios
self.decoder_n_filters = decoder_n_filters
self.decoder_depths = decoder_depths
class VibeVoiceSemanticTokenizerConfig(PretrainedConfig):
model_type = "vibevoice_semantic_tokenizer"
def __init__(
self,
channels: int = 1,
corpus_normalize: float = 0.0,
causal: bool = True,
vae_dim: int = 64,
fix_std: float = 0,
std_dist_type: str = 'none',
# common
mixer_layer: str = 'depthwise_conv',
conv_norm: str = 'none',
pad_mode: str = 'constant',
disable_last_norm: bool = True,
layernorm: str = 'RMSNorm',
layernorm_eps: float = 1e-5,
layernorm_elementwise_affine: bool = True,
conv_bias: bool = True,
layer_scale_init_value: float = 1e-6,
weight_init_value: float = 1e-2,
# encoder specific
encoder_n_filters: int = 32,
encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2],
encoder_depths: str = "3-3-3-3-3-3-8",
**kwargs
):
super().__init__(**kwargs)
self.channels = channels
self.corpus_normalize = corpus_normalize
self.causal = causal
self.vae_dim = vae_dim
self.fix_std = fix_std
self.std_dist_type = std_dist_type
# common parameters
self.conv_norm = conv_norm
self.pad_mode = pad_mode
self.layernorm_eps = layernorm_eps
self.disable_last_norm = disable_last_norm
self.layernorm = layernorm
self.layernorm_elementwise_affine = layernorm_elementwise_affine
self.conv_bias = conv_bias
self.layer_scale_init_value = layer_scale_init_value
self.weight_init_value = weight_init_value
self.mixer_layer = mixer_layer
# encoder specific parameters
self.encoder_n_filters = encoder_n_filters
self.encoder_ratios = encoder_ratios
self.encoder_depths = encoder_depths
class VibeVoiceDiffusionHeadConfig(PretrainedConfig):
model_type = "vibevoice_diffusion_head"
def __init__(
self,
hidden_size=768,
head_layers=4,
head_ffn_ratio=3.0,
rms_norm_eps=1e-5,
latent_size=64,
speech_vae_dim=None,
prediction_type="v_prediction",
diffusion_type="ddpm",
ddpm_num_steps=1000,
ddpm_num_inference_steps=20,
ddpm_beta_schedule="cosine",
ddpm_batch_mul=4,
**kwargs
):
self.hidden_size = hidden_size
self.head_layers = head_layers
self.head_ffn_ratio = head_ffn_ratio
self.rms_norm_eps = rms_norm_eps
self.latent_size = latent_size
self.speech_vae_dim = speech_vae_dim
self.prediction_type = prediction_type
self.diffusion_type = diffusion_type
self.ddpm_num_steps = ddpm_num_steps
self.ddpm_num_inference_steps = ddpm_num_inference_steps
self.ddpm_beta_schedule = ddpm_beta_schedule
self.ddpm_batch_mul = ddpm_batch_mul
super().__init__(**kwargs)
class VibeVoiceConfig(PretrainedConfig):
model_type = "vibevoice"
is_composition = True
sub_configs = {
"acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
"semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig,
"decoder_config": Qwen2Config,
"diffusion_head_config": VibeVoiceDiffusionHeadConfig,
}
# keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__(
self,
acoustic_tokenizer_config=None,
semantic_tokenizer_config=None,
decoder_config=None,
diffusion_head_config=None,
**kwargs
):
# kwargs["_attn_implementation"] = "flash_attention_2"
kwargs["_attn_implementation_autoset"] = False
if acoustic_tokenizer_config is None:
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
elif isinstance(acoustic_tokenizer_config, dict):
acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
# If an instance of the config class is provided
self.acoustic_tokenizer_config = acoustic_tokenizer_config
if semantic_tokenizer_config is None:
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]()
elif isinstance(semantic_tokenizer_config, dict):
semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer"
self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config)
elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig):
# If an instance of the config class is provided
self.semantic_tokenizer_config = semantic_tokenizer_config
if decoder_config is None:
self.decoder_config = self.sub_configs["decoder_config"]()
elif isinstance(decoder_config, dict):
# If a dictionary is provided, instantiate the config class with it
# self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
if decoder_config.get("model_type", '') == "qwen2":
self.decoder_config = Qwen2Config(**decoder_config)
else:
raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
elif isinstance(decoder_config, (Qwen2Config,)):
# If an instance of the config class is provided
self.decoder_config = decoder_config
if diffusion_head_config is None:
self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
elif isinstance(diffusion_head_config, dict):
diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
# If an instance of the config class is provided
self.diffusion_head_config = diffusion_head_config
# other parameters
self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128)
super().__init__(**kwargs)
__all__ = [
"VibeVoiceAcousticTokenizerConfig",
"VibeVoiceSemanticTokenizerConfig",
"VibeVoiceDiffusionHeadConfig",
"VibeVoiceConfig"
]
@@ -0,0 +1,85 @@
""" VibeVoice Streaming model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceDiffusionHeadConfig
logger = logging.get_logger(__name__)
class VibeVoiceStreamingConfig(PretrainedConfig):
model_type = "vibevoice_streaming"
is_composition = True
sub_configs = {
"acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig,
"decoder_config": Qwen2Config,
"diffusion_head_config": VibeVoiceDiffusionHeadConfig,
}
# keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen2`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__(
self,
acoustic_tokenizer_config=None,
decoder_config=None,
diffusion_head_config=None,
tts_backbone_num_hidden_layers=20,
**kwargs
):
# kwargs["_attn_implementation"] = "flash_attention_2"
kwargs["_attn_implementation_autoset"] = False
if acoustic_tokenizer_config is None:
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]()
elif isinstance(acoustic_tokenizer_config, dict):
acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer"
self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config)
elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig):
# If an instance of the config class is provided
self.acoustic_tokenizer_config = acoustic_tokenizer_config
if decoder_config is None:
self.decoder_config = self.sub_configs["decoder_config"]()
elif isinstance(decoder_config, dict):
# If a dictionary is provided, instantiate the config class with it
# self.decoder_config = self.sub_configs["decoder_config"](**decoder_config)
if decoder_config.get("model_type", '') == "qwen2":
self.decoder_config = Qwen2Config(**decoder_config)
else:
raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}")
elif isinstance(decoder_config, (Qwen2Config,)):
# If an instance of the config class is provided
self.decoder_config = decoder_config
if diffusion_head_config is None:
self.diffusion_head_config = self.sub_configs["diffusion_head_config"]()
elif isinstance(diffusion_head_config, dict):
diffusion_head_config["model_type"] = "vibevoice_diffusion_head"
self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config)
elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig):
# If an instance of the config class is provided
self.diffusion_head_config = diffusion_head_config
# other parameters
self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64)
# The decoder of the model is divided into two components. The lower Transformer layers are only used for encoding text, while the upper Transformer layers are used for encoding text and generating speech. `tts_backbone_num_hidden_layers` indicates the number of upper layers used for TTS.
self.tts_backbone_num_hidden_layers = tts_backbone_num_hidden_layers
super().__init__(**kwargs)
__all__ = [
"VibeVoiceStreamingConfig"
]
@@ -0,0 +1,190 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
from .configuration_vibevoice_streaming import VibeVoiceStreamingConfig
logger = logging.get_logger(__name__)
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
class BinaryClassifier(nn.Module):
def __init__(self, hidden_size):
super(BinaryClassifier, self).__init__()
self.fc1 = nn.Linear(hidden_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
class SpeechConnector(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.fc1 = nn.Linear(input_dim, output_dim)
self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
self.fc2 = nn.Linear(output_dim, output_dim)
def forward(self, features, **kwargs):
x = self.fc1(features)
x = self.norm(x)
x = self.fc2(x)
return x
# @auto_docstring
class VibeVoiceStreamingPreTrainedModel(PreTrainedModel):
config_class = VibeVoiceStreamingConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
if isinstance(module, VibeVoiceDiffusionHead):
module.initialize_weights()
return
# Use the language model's initializer_range if available
if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'):
std = self.config.language_model_config.initializer_range
elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'):
std = self.config.decoder_config.initializer_range
else:
std = 0.02 # Default value
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
# @auto_docstring
class VibeVoiceStreamingModel(VibeVoiceStreamingPreTrainedModel):
def __init__(self, config):
super().__init__(config)
if hasattr(config, 'torch_dtype') and config.torch_dtype is not None:
if isinstance(config.torch_dtype, str):
dtype = getattr(torch, config.torch_dtype)
else:
dtype = config.torch_dtype
else:
dtype = torch.float32
# Initialize Qwen2 model for language modeling.
# The lower Transformer layers are only used for encoding text, while the upper Transformer layers are used for encoding text and generating speech.
# To keep the code clean, we constructs two language models.
# The final norm layer of the first language_model is set to identity and will not be used in inference.
lm_config = copy.deepcopy(config.decoder_config)
lm_backbone_num_hidden_layers = getattr(lm_config, 'num_hidden_layers', 24) - config.tts_backbone_num_hidden_layers
lm_config.num_hidden_layers = lm_backbone_num_hidden_layers
self.language_model = AutoModel.from_config(lm_config)
self.language_model.norm = nn.Identity()
# We only need the Transformer layers here. Note that embed_tokens in tts_language_model is unused
tts_lm_config = copy.deepcopy(lm_config)
tts_lm_config.num_hidden_layers = config.tts_backbone_num_hidden_layers
self.tts_language_model = AutoModel.from_config(tts_lm_config)
# Marks the text that needs to be spoken by the TTS model.
self.tts_input_types = nn.Embedding(num_embeddings=2, embedding_dim=config.decoder_config.hidden_size)
# Initialize speech components if needed
self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype)
self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype)
# Register scaling factors as buffers - use 1D tensors for FSDP compatibility
self.register_buffer('speech_scaling_factor', torch.tensor(float('nan')))
self.register_buffer('speech_bias_factor', torch.tensor(float('nan')))
# Initialize prediction head for speech generation
self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype)
# Initialize noise scheduler
self.noise_scheduler = DPMSolverMultistepScheduler(
num_train_timesteps=config.diffusion_head_config.ddpm_num_steps,
beta_schedule=config.diffusion_head_config.ddpm_beta_schedule,
prediction_type=config.diffusion_head_config.prediction_type
)
def get_input_embeddings(self):
if hasattr(self.language_model, 'embed_tokens'):
# If the language model has an embed_tokens attribute, return it
return self.language_model.embed_tokens
for name, attr in self.language_model.fullmap.items(): # parallel by nnscaler, the name is changed
if attr.orig_name == 'embed_tokens.weight':
return getattr(self.language_model, name)
assert False, 'should not arrive here'
def set_input_embeddings(self, value):
self.language_model.embed_tokens = value
def set_speech_tokenizers(self, acoustic_tokenizer=None):
"""Set the speech tokenizers used for encoding and decoding speech."""
self.acoustic_tokenizer = acoustic_tokenizer
# Reset the encoder to evaluation mode
if self.acoustic_tokenizer is not None:
self.acoustic_tokenizer.eval()
def forward(self, *args, **kwargs):
"""
Intentionally not implemented.
This streaming model is split into two explicit submodules:
- `language_model` for plain text processing (lower layers).
- `tts_language_model` for TTS-related upper layers.
We deliberately avoid a unified `forward` to prevent accidental calls
that mix responsibilities.
To use the model:
- Call `self.language_model(...)` for text embeddings / hidden states.
- Call `self.tts_language_model(...)` for the TTS portion.
- Use the dedicated inference class for combined generation logic.
"""
raise RuntimeError(
"VibeVoiceStreamingModel.forward is intentionally disabled. "
"Use `model.language_model(...)` or `model.tts_language_model(...)` instead."
)
AutoModel.register(VibeVoiceStreamingConfig, VibeVoiceStreamingModel)
__all__ = [
"VibeVoiceStreamingPreTrainedModel",
"VibeVoiceStreamingModel",
]
@@ -0,0 +1,726 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers import modeling_utils
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import logging
from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache
from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead
from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler
from .configuration_vibevoice_streaming import VibeVoiceStreamingConfig
from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast
from .modeling_vibevoice_streaming import VibeVoiceStreamingPreTrainedModel, VibeVoiceStreamingModel, BinaryClassifier
from .streamer import AudioStreamer, AsyncAudioStreamer
logger = logging.get_logger(__name__)
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
TTS_TEXT_WINDOW_SIZE = 5
TTS_SPEECH_WINDOW_SIZE = 6
def _update_model_kwargs_for_generation(
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
num_new_tokens: int = 1,
) -> Dict[str, Any]:
"""
Update model_kwargs after adding new tokens.
Mainly for the case num_new_tokens > 1 (e.g. a whole text window):
- past_key_values: take from current outputs
- attention_mask: append num_new_tokens ones
- cache_position: advance by creating a range for all new positions
"""
# update past_key_values keeping its naming used in model code
model_kwargs["past_key_values"] = getattr(outputs, "past_key_values")
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))], dim=-1
)
model_kwargs["cache_position"] = torch.arange(model_kwargs["cache_position"][-1] + 1, model_kwargs["cache_position"][-1] + num_new_tokens + 1).to(model_kwargs["cache_position"].device)
return model_kwargs
@dataclass
class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast):
logits: Optional[torch.FloatTensor] = None
@dataclass
class VibeVoiceGenerationOutput(ModelOutput):
"""
Output type for VibeVoice generation.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences.
speech_outputs (`List[torch.FloatTensor]`, *optional*):
List of generated speech waveforms or latents for each speech segment.
"""
sequences: torch.LongTensor = None
speech_outputs: Optional[List[torch.FloatTensor]] = None
reach_max_step_sample: Optional[torch.BoolTensor] = None
class VibeVoiceStreamingForConditionalGenerationInference(VibeVoiceStreamingPreTrainedModel, GenerationMixin):
def __init__(self, config):
super().__init__(config)
# Initialize the base model
self.model = VibeVoiceStreamingModel(config)
# TTS generation EOS classifier
self.tts_eos_classifier = BinaryClassifier(config.decoder_config.hidden_size)
# inference configuration
self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps
# Initialize weights and apply final processing
self.post_init()
@property
def noise_scheduler(self):
return self.model.noise_scheduler
@property
def prediction_head(self):
return self.model.prediction_head
@property
def speech_scaling_factor(self):
return self.model.speech_scaling_factor
@property
def speech_bias_factor(self):
return self.model.speech_bias_factor
@property
def acoustic_tokenizer(self):
return self.model.acoustic_tokenizer
@property
def acoustic_connector(self):
return self.model.acoustic_connector
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
"""
# Tie lm_head.weight to language_model.embed_tokens.weight
if not getattr(self.config, 'tie_word_embeddings', False):
return
if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'):
self.lm_head.weight = self.model.language_model.embed_tokens.weight
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
"""
This model does not define an `lm_head` (vocabulary projection).
"""
return None
def set_output_embeddings(self, new_embeddings):
"""
No-op because there is no `lm_head`. Provided only to satisfy optional API calls.
To enable, first create `self.lm_head` then allow assignment.
"""
raise RuntimeError("Output embeddings (lm_head) are not defined for this model. "
"Create one before calling set_output_embeddings if needed.")
def set_speech_tokenizers(self, acoustic_tokenizer=None):
"""Set the speech tokenizers used for encoding and decoding speech."""
self.model.set_speech_tokenizers(acoustic_tokenizer)
def set_ddpm_inference_steps(self, num_steps=None):
self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
# @can_return_tuple
def forward_lm(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Single pass of the base text LM.
- Builds embeddings if `inputs_embeds` not provided.
- Uses (and returns) `past_key_values` when `use_cache=True`.
- No loss / no lm_head / no speech logic.
Args:
input_ids: (B, S) token ids.
attention_mask: (B, S) mask.
past_key_values: cache from previous steps.
cache_position: positions for cached tokens.
labels: unsupported (will raise).
Returns:
BaseModelOutputWithPast with `last_hidden_state` and `past_key_values`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Get embeddings
if inputs_embeds is None:
inputs_embeds = self.model.get_input_embeddings()(input_ids)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
if labels is not None:
raise NotImplementedError("Loss computation is not implemented in this version.")
return BaseModelOutputWithPast(
past_key_values=outputs.past_key_values,
last_hidden_state=hidden_states,
attentions=outputs.attentions,
)
# @can_return_tuple
def forward_tts_lm(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
lm_last_hidden_state: Optional[torch.FloatTensor] = None,
tts_text_masks: Optional[torch.BoolTensor] = None,
**kwargs,
) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]:
"""
Single pass of the TTS LM.
- Overwrites tail embeddings with `lm_last_hidden_state`.
- Adds type embedding via `tts_text_masks` (1=text, 0=speech).
- Predicts EOS from last hidden state (binary classifier).
- No loss / no full acoustic decoding here.
Args:
input_ids: (B, S) token ids.
attention_mask: (B, S) mask.
lm_last_hidden_state: (B, K, H) hidden states to splice into the tail.
tts_text_masks: (B, 1) mask marking current position as text(1)/speech(0).
past_key_values: cache from previous TTS steps.
cache_position: positions for cached tokens.
labels: unsupported (will raise).
Returns:
VibeVoiceCausalLMOutputWithPast with `logits` (EOS), `last_hidden_state`, `past_key_values`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Get embeddings
if inputs_embeds is None:
# Will be replaced with lm_last_hidden_state
inputs_embeds = self.model.get_input_embeddings()(input_ids)
# Replace the last part of inputs_embeds with lm_last_hidden_state
start_idx = inputs_embeds.shape[1] - lm_last_hidden_state.shape[1]
inputs_embeds[:, start_idx:, :] = lm_last_hidden_state
# Adds type embedding via `tts_text_masks`.
inputs_embeds = inputs_embeds + self.model.tts_input_types(tts_text_masks.long())
outputs = self.model.tts_language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
logits = self.tts_eos_classifier(hidden_states[:, -1, :])
if labels is not None:
raise NotImplementedError("Loss computation is not implemented in this version.")
return VibeVoiceCausalLMOutputWithPast(
logits=logits,
past_key_values=outputs.past_key_values,
last_hidden_state=hidden_states,
attentions=outputs.attentions,
)
def forward(self, *args, **kwargs):
"""
Unified forward is intentionally disabled.
Reasons:
1. The inference pipeline is staged: base text LM, then TTS LM, plus streaming & diffusion handled in `generate`.
2. A monolithic call would hide required sequencing (prefill, window stepping, speech diffusion sampling).
Use instead:
- self.forward_lm(...) for a base text LM step (prefill or incremental).
- self.forward_tts_lm(...) for a single TTS LM step (needs LM hidden states).
- self.generate(...) for full streaming (text + speech + diffusion + audio assembly).
Raises:
RuntimeError: Always (by design).
"""
raise RuntimeError(
"Unified forward is disabled. Use `forward_lm`, `forward_tts_lm`, or `generate` instead."
)
def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs):
if generation_config is None:
generation_config = GenerationConfig(
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id = tokenizer.pad_token_id
)
else:
generation_config = GenerationConfig(
**generation_config,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id = tokenizer.pad_token_id
)
generation_config, model_kwargs = self._prepare_generation_config(
generation_config,
True,
speech_start_id=tokenizer.speech_start_id,
speech_end_id=tokenizer.speech_end_id,
speech_diffusion_id=tokenizer.speech_diffusion_id,
**kwargs
)
generation_config.speech_start_id = tokenizer.speech_start_id
generation_config.speech_end_id = tokenizer.speech_end_id
generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
batch_size = inputs_tensor.shape[0]
device = self.device
self._prepare_special_tokens(generation_config, True, device=device)
generation_config.use_cache = True
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = inputs_tensor.to(self.device)
input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
max_cache_length = generation_config.max_length - 1
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
for k, v in model_kwargs.items():
if isinstance(v, torch.Tensor):
model_kwargs[k] = v.to(device=device)
if return_processors:
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=LogitsProcessorList(),
device=inputs_tensor.device,
model_kwargs=model_kwargs,
)
stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList())
return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria
else:
return generation_config, model_kwargs, input_ids
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
speech_tensors: Optional[torch.FloatTensor] = None,
speech_masks: Optional[torch.BoolTensor] = None,
speech_input_mask: Optional[torch.BoolTensor] = None,
tts_text_ids: Optional[torch.LongTensor] = None,
return_speech: bool = True,
cfg_scale: float = 1.0,
stop_check_fn: Optional[Callable[[], bool]] = None,
**kwargs,
) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]:
"""
Text is fed in small windows (dynamic slicing of `tts_text_ids`), which enables streaming text input: you dont need the full text upfront. After each text window, a loop samples several speech latents (diffusion). The interleaved text encoding + speech generation enables streaming text input and realtime speech output.
The function only supports batch size = 1 currently.
- Windowed text prefill → incremental LM + TTS LM updates.
- Interleave speech token diffusion sampling (`sample_speech_tokens`).
- Stops on EOS (binary classifier) or max length / external `stop_check_fn`.
- Returns final token `sequences` and (optionally) concatenated speech audio.
Args (selected):
tts_text_ids: Full text tokens to stream in windows.
audio_streamer: If provided, emits audio chunks during generation.
cfg_scale: Classifier-free guidance scale for speech diffusion.
return_speech: If False, skips audio decode concatenation.
stop_check_fn: External early-stop hook (returns True to halt).
Returns:
VibeVoiceGenerationOutput with:
- sequences: final token ids
- speech_outputs: list of concatenated audio tensors (or None)
- reach_max_step_sample: flags for samples stopped by max length
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
tokenizer = kwargs.pop("tokenizer", None)
neg_text_input_id = tokenizer.convert_tokens_to_ids("<|image_pad|>")
tts_lm_input_ids = kwargs.pop("tts_lm_input_ids", None)
tts_lm_attention_mask = kwargs.pop("tts_lm_attention_mask", None)
# all_prefilled_outputs: cached prefilled prompt outputs for lm, tts_lm, neg_lm, neg_tts_lm
all_prefilled_outputs = kwargs.pop("all_prefilled_outputs", None)
tts_text_ids = tts_text_ids.to(self.device)
if kwargs.get('max_new_tokens', None) is None:
kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - tts_lm_input_ids.shape[-1]
generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs(
generation_config, inputs, tokenizer, return_processors=True, **kwargs
)
negative_kwargs = {
'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), neg_text_input_id, dtype=torch.long, device=kwargs['input_ids'].device),
'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
'max_new_tokens': kwargs.get('max_new_tokens', 100)
}
negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **negative_kwargs
)
tts_lm_kwargs = {
'input_ids': tts_lm_input_ids,
'attention_mask': tts_lm_attention_mask,
'max_new_tokens': kwargs.get('max_new_tokens', 100)
}
tts_lm_generation_config, tts_lm_model_kwargs, tts_lm_input_ids = self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_kwargs
)
tts_lm_negative_kwargs = {
'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), neg_text_input_id, dtype=torch.long, device=kwargs['input_ids'].device),
'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device),
'max_new_tokens': kwargs.get('max_new_tokens', 100)
}
tts_lm_negative_generation_config, tts_lm_negative_model_kwargs, tts_lm_negative_input_ids = self._build_generate_config_model_kwargs(
None, None, tokenizer, return_processors=False, **tts_lm_negative_kwargs
)
acoustic_cache = VibeVoiceTokenizerStreamingCache()
batch_size = input_ids.shape[0]
assert batch_size == 1, "Currently only supports batch size == 1"
device = input_ids.device
finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device)
verbose = kwargs.get("verbose", False)
# Initialize audio chunks storage for each sample
audio_chunks = [[] for _ in range(batch_size)]
tts_text_window_index = 0
reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device)
first_text_window_size = TTS_TEXT_WINDOW_SIZE if tts_text_ids.shape[1] >= TTS_TEXT_WINDOW_SIZE else tts_text_ids.shape[1]
outputs = all_prefilled_outputs["lm"]
tts_lm_outputs = all_prefilled_outputs["tts_lm"]
negative_outputs = all_prefilled_outputs["neg_lm"]
tts_lm_negative_outputs = all_prefilled_outputs["neg_tts_lm"]
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=first_text_window_size,
)
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=first_text_window_size,
)
negative_model_kwargs = self._update_model_kwargs_for_generation(
negative_outputs, negative_model_kwargs, is_encoder_decoder=False,
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False,
)
step = tts_lm_input_ids.shape[1]
total_generated_speech_tokens = 0
total_prefilled_text_tokens = 0
if kwargs.get("show_progress_bar", True):
progress_bar = tqdm(
total=tts_lm_generation_config.max_length,
desc=f"Prefilled {step} tokens, current step ({step} / {tts_lm_generation_config.max_length})",
initial=step,
leave=False
)
else:
progress_bar = None
while True:
# Check for external stop signal
if stop_check_fn is not None and stop_check_fn():
if verbose:
print(f"Generation stopped externally at step {step + 1}")
# End the audio streamer if it exists
if audio_streamer is not None:
audio_streamer.end()
break
# # Check if audio_streamer has been ended (stopped externally)
# if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'):
# if any(audio_streamer.finished_flags):
# if verbose:
# print(f"Audio generation stopped externally at step {step + 1}")
# break
if finished_tags.all():
if hasattr(progress_bar, 'set_description'):
progress_bar.set_description("Generation complete")
break
cur_input_tts_text_ids = tts_text_ids[:, tts_text_window_index*TTS_TEXT_WINDOW_SIZE:(tts_text_window_index+1)*TTS_TEXT_WINDOW_SIZE]
next_text_window_size = tts_text_ids[:, (tts_text_window_index+1)*TTS_TEXT_WINDOW_SIZE:(tts_text_window_index+2)*TTS_TEXT_WINDOW_SIZE].shape[1]
tts_text_window_index += 1
if cur_input_tts_text_ids.shape[1] > 0:
input_ids = torch.cat([input_ids, cur_input_tts_text_ids], dim=-1)
tts_lm_input_ids = torch.cat([tts_lm_input_ids, cur_input_tts_text_ids], dim=-1)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(f"Reached maximum generation length {generation_config.max_length}, stopped it.")
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
step += cur_input_tts_text_ids.shape[1]
total_prefilled_text_tokens += cur_input_tts_text_ids.shape[1]
if progress_bar is not None:
progress_bar.update(cur_input_tts_text_ids.shape[1])
progress_bar.set_description(f"Prefilled {total_prefilled_text_tokens} text tokens, generated {total_generated_speech_tokens} speech tokens, current step ({step} / {tts_lm_generation_config.max_length})")
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# Forward pass through the model
outputs = self.forward_lm(
**model_inputs, return_dict=True, output_attentions=False, output_hidden_states=False,
)
model_kwargs = _update_model_kwargs_for_generation(
outputs, model_kwargs, num_new_tokens=next_text_window_size,
)
tts_lm_model_inputs = self.prepare_inputs_for_generation(tts_lm_input_ids, **tts_lm_model_kwargs)
tts_lm_additional_inputs = {
"tts_text_masks": torch.ones_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": outputs.last_hidden_state,
}
# Forward pass through the model
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs, **tts_lm_additional_inputs, return_dict=True, output_attentions=False, output_hidden_states=False,
)
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False,
)
diffusion_indices = torch.LongTensor([0])
for cur_speech_index in range(TTS_SPEECH_WINDOW_SIZE):
positive_condition = tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]
negative_condition = tts_lm_negative_outputs.last_hidden_state[diffusion_indices, -1, :]
speech_latent = self.sample_speech_tokens(
positive_condition,
negative_condition,
cfg_scale=cfg_scale,
).unsqueeze(1)
# Decode acoustic latent to audio using acoustic streaming cache
scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device)
audio_chunk = self.model.acoustic_tokenizer.decode(
scaled_latent.to(self.model.acoustic_tokenizer.device),
cache=acoustic_cache, # Use acoustic-specific cache
sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device),
use_cache=True,
debug=False
)
# Store audio chunks for each sample
for i, sample_idx in enumerate(diffusion_indices):
idx = sample_idx.item()
# Only append audio chunk if the sample is not finished
if not finished_tags[idx]:
audio_chunks[idx].append(audio_chunk[i])
# Add streaming support here
if audio_streamer is not None:
# Stream the audio chunks immediately
audio_streamer.put(audio_chunk, diffusion_indices)
acoustic_embed = self.model.acoustic_connector(speech_latent)
tts_lm_input_ids = torch.cat([tts_lm_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
break
step += 1
total_generated_speech_tokens += 1
if progress_bar is not None:
progress_bar.update(1)
progress_bar.set_description(f"Prefilled {total_prefilled_text_tokens} text tokens, generated {total_generated_speech_tokens} speech tokens, current step ({step} / {tts_lm_generation_config.max_length})")
tts_lm_model_inputs = self.prepare_inputs_for_generation(tts_lm_input_ids, **tts_lm_model_kwargs)
tts_lm_additional_inputs = {
"tts_text_masks": torch.zeros_like(tts_lm_input_ids[:, -1:]),
"lm_last_hidden_state": acoustic_embed,
}
# Forward pass through the model
tts_lm_outputs = self.forward_tts_lm(
**tts_lm_model_inputs, **tts_lm_additional_inputs, return_dict=True, output_attentions=False, output_hidden_states=False,
)
if cur_speech_index == TTS_SPEECH_WINDOW_SIZE - 1 and next_text_window_size > 0:
tts_lm_model_kwargs = _update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, num_new_tokens=next_text_window_size,
)
else:
tts_lm_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_outputs, tts_lm_model_kwargs, is_encoder_decoder=False,
)
tts_lm_negative_input_ids = torch.cat([tts_lm_negative_input_ids, torch.ones_like(tts_lm_input_ids[:, -1:])], dim=-1)
tts_lm_negative_model_inputs = self.prepare_inputs_for_generation(tts_lm_negative_input_ids, **tts_lm_negative_model_kwargs)
# Forward negative pass through the model
tts_lm_negative_additional_inputs = {
"tts_text_masks": torch.zeros_like(tts_lm_negative_input_ids[:, -1:]),
"lm_last_hidden_state": acoustic_embed,
}
tts_lm_negative_outputs = self.forward_tts_lm(
**tts_lm_negative_model_inputs, **tts_lm_negative_additional_inputs, return_dict=True, output_attentions=False, output_hidden_states=False,
)
tts_lm_negative_model_kwargs = self._update_model_kwargs_for_generation(
tts_lm_negative_outputs, tts_lm_negative_model_kwargs, is_encoder_decoder=False,
)
tts_eos_logits = torch.sigmoid(self.tts_eos_classifier(tts_lm_outputs.last_hidden_state[diffusion_indices, -1, :]))
if tts_eos_logits[0].item() > 0.5:
# If EOS token is predicted, we can stop generation for this sample
finished_tags[diffusion_indices] = True
if audio_streamer is not None:
audio_streamer.end(diffusion_indices)
if tts_lm_input_ids.shape[1] > tts_lm_generation_config.max_length:
if verbose:
print(f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it.")
reached_samples = torch.arange(batch_size, device=device)[~finished_tags]
if reached_samples.numel() > 0:
reach_max_step_sample[reached_samples] = True
break
if audio_streamer is not None:
audio_streamer.end()
# Concatenate audio chunks for each sample
final_audio_outputs = []
for sample_chunks in audio_chunks:
if sample_chunks:
# Concatenate all chunks along the time dimension (assumed to be the last dimension)
concatenated_audio = torch.cat(sample_chunks, dim=-1)
final_audio_outputs.append(concatenated_audio)
else:
# If no audio was generated for this sample, append None
final_audio_outputs.append(None)
if reach_max_step_sample is not None and reach_max_step_sample.any():
print(f"Reached maximum generation length {tts_lm_generation_config.max_length}, stopped it.")
return VibeVoiceGenerationOutput(
sequences=tts_lm_input_ids,
speech_outputs=final_audio_outputs if return_speech else None,
reach_max_step_sample=reach_max_step_sample,
)
@torch.no_grad()
def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0):
self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps)
condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device)
speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition)
for t in self.model.noise_scheduler.timesteps:
half = speech[: len(speech) // 2]
combined = torch.cat([half, half], dim=0)
eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition)
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample
return speech[: len(speech) // 2]
AutoModelForCausalLM.register(VibeVoiceStreamingConfig, VibeVoiceStreamingForConditionalGenerationInference)
__all__ = [
"VibeVoiceStreamingForConditionalGenerationInference",
]
@@ -0,0 +1,287 @@
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.modeling_utils import PreTrainedModel
# from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.activations import ACT2FN
from transformers.utils import logging
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
logger = logging.get_logger(__name__)
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
super().__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.register_parameter('weight', None)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if self.weight is not None:
output = output * self.weight
return output
def extra_repr(self) -> str:
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
def modulate(x, shift, scale):
"""Apply modulation to input tensor."""
return x * (1 + scale) + shift
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
Args:
hidden_size (`int`): Size of the output embedding
frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=False),
# nn.SiLU(),
ACT2FN['silu'],
nn.Linear(hidden_size, hidden_size, bias=False),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
These may be fractional.
dim (`int`): The dimension of the output.
max_period (`int`, optional): Controls the minimum frequency of the embeddings.
Returns:
`torch.Tensor`: An [N, D] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding.to(t.dtype)
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class FeedForwardNetwork(nn.Module):
"""
Standard feed-forward network with SwiGLU activation.
Args:
embed_dim (`int`): Input dimension
ffn_dim (`int`): Hidden dimension
"""
def __init__(
self,
embed_dim,
ffn_dim,
):
super().__init__()
self.embed_dim = embed_dim
self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False)
self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function
def forward(self, x):
gate = self.gate_proj(x)
up = self.up_proj(x)
# SwiGLU activation
# gate = F.silu(gate)
gate = self.act_fn(gate)
return self.down_proj(gate * up)
class HeadLayer(nn.Module):
"""
A layer in the diffusion head.
Args:
embed_dim (`int`): Input dimension
ffn_dim (`int`): Hidden dimension
cond_dim (`int`): Condition embedding dimension
norm_eps (`float`, optional): Epsilon for normalization
"""
def __init__(
self,
embed_dim,
ffn_dim,
cond_dim,
norm_eps=1e-5,
):
super().__init__()
self.embed_dim = embed_dim
self.cond_dim = cond_dim
self.ffn_dim = ffn_dim
self.ffn = FeedForwardNetwork(
self.embed_dim,
self.ffn_dim,
)
self.norm = RMSNorm(self.embed_dim, eps=norm_eps)
self.adaLN_modulation = nn.Sequential(
# nn.SiLU(),
ACT2FN['silu'],
nn.Linear(cond_dim, 3 * self.embed_dim, bias=False)
)
def forward(self, x, c):
shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1)
x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn))
return x
class FinalLayer(nn.Module):
"""
Final layer in the diffusion head.
Args:
hidden_size (`int`): Input dimension
output_size (`int`): Output dimension
cond_size (`int`): Condition embedding dimension
norm_eps (`float`, optional): Epsilon for normalization
"""
def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
super().__init__()
self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
self.linear = nn.Linear(hidden_size, output_size, bias=False)
self.adaLN_modulation = nn.Sequential(
# nn.SiLU(),
ACT2FN['silu'],
nn.Linear(cond_size, 2 * hidden_size, bias=False)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class VibeVoiceDiffusionHead(PreTrainedModel):
"""
Diffusion head model for vibevoice.
Args:
config (`VibeVoiceDiffusionHeadConfig`): Model configuration
latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`.
"""
config_class = VibeVoiceDiffusionHeadConfig
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(
self,
config,
):
super().__init__(config)
self.config = config
self.cond_dim = config.hidden_size
latent_size = config.latent_size
self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
self.t_embedder = TimestepEmbedder(self.cond_dim)
ffn_dim = int(config.hidden_size * config.head_ffn_ratio)
# Create the intermediate layers
self.layers = nn.ModuleList([
HeadLayer(
embed_dim=config.hidden_size,
ffn_dim=ffn_dim,
cond_dim=self.cond_dim,
norm_eps=config.rms_norm_eps
)
for _ in range(config.head_layers)
])
# Final layer for output
self.final_layer = FinalLayer(
hidden_size=config.hidden_size,
output_size=latent_size,
cond_size=self.cond_dim,
norm_eps=config.rms_norm_eps
)
self.initialize_weights()
def initialize_weights(self):
"""Initialize the weights of the model."""
# Initialize timestep embedder
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers
for layer in self.layers:
nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)
# Zero-out output layers
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
def forward(
self,
noisy_images,
timesteps,
condition,
):
"""
Forward pass of the prediction head.
Args:
noisy_images (`torch.Tensor`): Noisy images/latents to denoise
timesteps (`torch.Tensor`): Timesteps for diffusion
condition (`torch.Tensor`): Conditioning information
Returns:
`torch.Tensor`: The predicted noise/velocity
"""
x = self.noisy_images_proj(noisy_images)
t = self.t_embedder(timesteps)
condition = self.cond_proj(condition)
c = condition + t
for layer in self.layers:
x = layer(x, c)
x = self.final_layer(x, c)
return x
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
__all__ = [
"VibeVoiceDiffusionHead",
]
@@ -0,0 +1,214 @@
"""Tokenization classes for vibevoice."""
from typing import List, Optional, Union
from transformers.utils import logging
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
logger = logging.get_logger(__name__)
class VibeVoiceTextTokenizer(Qwen2Tokenizer):
"""
Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech.
Args:
vocab_file (`str`):
Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
errors (`str`, *optional*, defaults to `"replace"`):
Paradigm to follow when decoding bytes to UTF-8.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token.
bos_token (`str`, *optional*):
The beginning of sequence token. Not used for vibevoice.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding.
add_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to add special tokens when encoding.
"""
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
merges_file,
errors="replace",
unk_token="<|endoftext|>",
bos_token=None,
eos_token="<|endoftext|>",
pad_token="<|endoftext|>",
add_prefix_space=False,
add_special_tokens=True,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
merges_file=merges_file,
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
add_special_tokens=add_special_tokens,
**kwargs,
)
# Add VibeVoice-specific special tokens
self._add_vibevoice_special_tokens()
def _add_vibevoice_special_tokens(self):
"""Add VibeVoice-specific special tokens."""
special_tokens = {
"additional_special_tokens": [
"<|vision_start|>", # Speech start (reusing vision tokens)
"<|vision_end|>", # Speech end
"<|vision_pad|>", # Speech diffusion pad
]
}
num_added = self.add_special_tokens(special_tokens)
# Cache special token IDs
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
return num_added
@property
def eos_id(self) -> int:
"""Id of the end of sequence token."""
return self._eos_id
@property
def speech_start_id(self) -> int:
"""Id of the speech start token."""
return self._speech_start_id
@property
def speech_end_id(self) -> int:
"""Id of the speech end token."""
return self._speech_end_id
@property
def speech_diffusion_id(self) -> int:
"""Id of the speech diffusion token."""
return self._speech_diffusion_id
@property
def pad_id(self) -> int:
"""Id used for padding (returns -100 for loss masking)."""
return -100
class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast):
"""
Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library).
Based on the Qwen2 tokenizer with additional special tokens for speech.
Args:
vocab_file (`str`, *optional*):
Path to the vocabulary file.
merges_file (`str`, *optional*):
Path to the merges file.
tokenizer_file (`str`, *optional*):
Path to [tokenizers](https://github.com/huggingface/tokenizers) file.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token.
bos_token (`str`, *optional*):
The beginning of sequence token. Not used for vibevoice.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The token used for padding.
"""
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file=None,
merges_file=None,
tokenizer_file=None,
unk_token="<|endoftext|>",
bos_token=None,
eos_token="<|endoftext|>",
pad_token="<|endoftext|>",
add_prefix_space=False,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
merges_file=merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
**kwargs,
)
# Add VibeVoice-specific special tokens
self._add_vibevoice_special_tokens()
def _add_vibevoice_special_tokens(self):
"""Add VibeVoice-specific special tokens."""
special_tokens = {
"additional_special_tokens": [
"<|vision_start|>", # Speech start (reusing vision tokens)
"<|vision_end|>", # Speech end
"<|vision_pad|>", # Speech diffusion pad
]
}
num_added = self.add_special_tokens(special_tokens)
# Cache special token IDs
self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>")
self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>")
self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>")
# self._eos_id = self.convert_tokens_to_ids('<|endoftext|>')
self._eos_id = self.eos_token_id # qwen2 / qwen3
self._pad_id = self.convert_tokens_to_ids('<|image_pad|>')
return num_added
@property
def eos_id(self) -> int:
"""Id of the end of sequence token."""
return self._eos_id
@property
def speech_start_id(self) -> int:
"""Id of the speech start token."""
return self._speech_start_id
@property
def speech_end_id(self) -> int:
"""Id of the speech end token."""
return self._speech_end_id
@property
def speech_diffusion_id(self) -> int:
"""Id of the speech diffusion token."""
return self._speech_diffusion_id
@property
def pad_id(self) -> int:
"""Id used for padding (returns -100 for loss masking)."""
return self._pad_id
__all__ = [
"VibeVoiceTextTokenizer",
"VibeVoiceTextTokenizerFast",
]
@@ -0,0 +1,904 @@
import math
import typing as tp
from functools import partial
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.modeling_utils import PreTrainedModel
from transformers.activations import ACT2FN
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig
logger = logging.get_logger(__name__)
import os
# Try to import APEX FusedRMSNorm
try:
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
APEX_AVAILABLE = True
logger.info("APEX FusedRMSNorm is available and will be used for optimization")
if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
APEX_AVAILABLE = False
logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
except ImportError:
APEX_AVAILABLE = False
logger.warning("APEX FusedRMSNorm not available, using native implementation")
# APEX_AVAILABLE=False
# Normalization modules
class ConvLayerNorm(nn.LayerNorm):
"""
Convolution-friendly LayerNorm that moves channels to last dimensions
before running the normalization and moves them back to original position right after.
"""
def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
super().__init__(normalized_shape, **kwargs)
def forward(self, x):
x = x.transpose(1, 2) # b ... t -> b t ...
x = nn.functional.layer_norm(x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(x)
x = x.transpose(1, 2) # b t ... -> b ... t
return x
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
super().__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
weight_shape = (dim,) if weight_shape is None else weight_shape
self.weight = nn.Parameter(torch.ones(weight_shape))
else:
self.register_parameter('weight', None)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if self.weight is not None:
output = output * self.weight
return output
def extra_repr(self) -> str:
return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
class ConvRMSNorm(RMSNorm):
def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
super().__init__(dim, eps, elementwise_affine, weight_shape)
def forward(self, x):
x = x.transpose(1, 2) # b ... t -> b t ...
if (not APEX_AVAILABLE) or (not self.elementwise_affine):
# Fallback to native implementation
output = self._norm(x.float()).type_as(x)
if self.weight is not None:
output = output * self.weight
else:
output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
output = output.transpose(1, 2) # b t ... -> b ... t
return output
# Convolutional layers and utilities
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
'time_layer_norm', 'layer_norm', 'time_group_norm'])
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
assert norm in CONV_NORMALIZATIONS
if norm == 'weight_norm':
return nn.utils.weight_norm(module)
elif norm == 'spectral_norm':
return nn.utils.spectral_norm(module)
else:
# We already check was in CONV_NORMALIZATION, so any other choice
# doesn't need reparametrization.
return module
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
"""Return the proper normalization module. If causal is True, this will ensure the returned
module is causal, or return an error if the normalization doesn't support causal evaluation.
"""
assert norm in CONV_NORMALIZATIONS
if norm == 'layer_norm':
assert isinstance(module, nn.modules.conv._ConvNd)
return ConvLayerNorm(module.out_channels, **norm_kwargs)
elif norm == 'time_group_norm':
if causal:
raise ValueError("GroupNorm doesn't support causal evaluation.")
assert isinstance(module, nn.modules.conv._ConvNd)
return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
else:
return nn.Identity()
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
padding_total: int = 0) -> int:
"""Calculate extra padding needed for convolution to have the same output length"""
length = x.shape[-1]
n_frames = (length - kernel_size + padding_total) / stride + 1
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
return ideal_length - length
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
"""Pad 1D input with handling for small inputs in reflect mode"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == 'reflect':
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
assert (padding_left + padding_right) <= x.shape[-1]
end = x.shape[-1] - padding_right
return x[..., padding_left: end]
class NormConv1d(nn.Module):
"""Wrapper around Conv1d and normalization applied to this conv"""
def __init__(self, *args, causal: bool = False, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
return x
class NormConvTranspose1d(nn.Module):
"""Wrapper around ConvTranspose1d and normalization applied to this conv"""
def __init__(self, *args, causal: bool = False, norm: str = 'none',
norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
super().__init__()
self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
self.norm_type = norm
def forward(self, x):
x = self.convtr(x)
x = self.norm(x)
return x
class VibeVoiceTokenizerStreamingCache:
"""Cache for streaming convolution, similar to KV cache in attention"""
def __init__(self):
self.cache = {} # Dict mapping (layer_id, sample_idx) to state tensor
def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
"""Get cached states for given layer and sample indices"""
states = []
max_length = 0
# First pass: collect states and find max length
for idx in sample_indices.tolist():
key = (layer_id, idx)
if key not in self.cache:
return None # If any sample is missing, return None
state = self.cache[key]
states.append(state)
max_length = max(max_length, state.shape[-1])
# Second pass: pad states to max length if needed
if len(states) > 0 and states[0].dim() >= 2:
padded_states = []
for state in states:
if state.shape[-1] < max_length:
# Pad on the time dimension (last dimension)
pad_size = max_length - state.shape[-1]
# Pad with zeros on the LEFT to align the most recent samples
padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0)
padded_states.append(padded_state)
else:
padded_states.append(state)
return torch.stack(padded_states, dim=0)
else:
return torch.stack(states, dim=0)
def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
"""Set cached states for given layer and sample indices"""
for i, idx in enumerate(sample_indices.tolist()):
key = (layer_id, idx)
self.cache[key] = states[i].detach()
def set_to_zero(self, sample_indices: torch.Tensor):
"""Set all cached states to zero for given sample indices"""
for key in list(self.cache.keys()):
layer_id, sample_idx = key
if sample_idx in sample_indices.tolist():
# Create zero tensor with same shape and dtype as cached tensor
cached_tensor = self.cache[key]
self.cache[key] = torch.zeros_like(cached_tensor)
def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
"""Clear cache for specific layer/samples or everything"""
if layer_id is None and sample_indices is None:
self.cache.clear()
elif layer_id is not None and sample_indices is None:
# Clear all samples for a specific layer
keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id]
for k in keys_to_remove:
del self.cache[k]
elif layer_id is not None and sample_indices is not None:
# Clear specific samples for a specific layer
for idx in sample_indices.tolist():
key = (layer_id, idx)
self.cache.pop(key, None)
class SConv1d(nn.Module):
"""Conv1d with built-in handling of asymmetric or causal padding and normalization."""
def __init__(self, in_channels: int, out_channels: int,
kernel_size: int, stride: int = 1, dilation: int = 1,
groups: int = 1, bias: bool = True, causal: bool = False,
norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
pad_mode: str = 'reflect'):
super().__init__()
self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
dilation=dilation, groups=groups, bias=bias, causal=causal,
norm=norm, norm_kwargs=norm_kwargs)
self.causal = causal
self.pad_mode = pad_mode
# Store configuration
self.kernel_size = kernel_size
self.dilation = dilation
self.stride = stride
self.in_channels = in_channels
self.out_channels = out_channels
# For causal convolution, we need to maintain kernel_size - 1 samples as context
# need to check use which context_size is more suitable
# self.context_size = (kernel_size - 1) * dilation
self.context_size = (kernel_size - 1) * dilation - (stride - 1)
# For non-streaming mode, calculate padding
self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
# Create a unique layer ID for cache management
self._layer_id = None
@property
def layer_id(self):
if self._layer_id is None:
self._layer_id = f"sconv1d_{id(self)}"
return self._layer_id
def forward(self, x: torch.Tensor,
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
sample_indices: Optional[torch.Tensor] = None,
use_cache: bool = False,
debug: bool = False) -> torch.Tensor:
"""
Forward pass with optional streaming support via cache.
Args:
x: Input tensor [batch_size, channels, time]
cache: VibeVoiceTokenizerStreamingCache object for maintaining states
sample_indices: Indices identifying each sample for cache management
use_cache: Whether to use cached states for streaming
debug: Whether to print debug information
Returns:
Output tensor
"""
B, C, T = x.shape
# Non-streaming mode
if not use_cache or cache is None:
return self._forward_non_streaming(x, debug=debug)
# Streaming mode
assert self.causal, "Streaming mode is only supported for causal convolutions"
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
assert len(sample_indices) == B, "sample_indices must match batch size"
return self._forward_streaming(x, cache, sample_indices, debug)
def _forward_streaming(self, x: torch.Tensor,
cache: VibeVoiceTokenizerStreamingCache,
sample_indices: torch.Tensor,
debug: bool = False) -> torch.Tensor:
"""Streaming forward pass with cache operations kept separate from compiled code"""
B, C, T = x.shape
# Cache operations (not compiled)
cached_states = cache.get(self.layer_id, sample_indices)
if cached_states is None:
# First chunk - initialize with zeros for context
if self.context_size > 0:
cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
if debug:
print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
else:
cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
if debug:
print(f"[DEBUG] No context needed (kernel_size=stride)")
# Concatenate cached states with input
if cached_states.shape[2] > 0:
input_with_context = torch.cat([cached_states, x], dim=2)
else:
input_with_context = x
if debug:
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")
# Apply convolution directly - no extra padding in streaming mode
# The conv layer will handle its own padding internally
output = self.conv(input_with_context)
if debug:
print(f"[DEBUG] Output shape: {output.shape}")
# Update cache for next chunk
if self.context_size > 0:
# Calculate how many samples to keep
total_input_length = input_with_context.shape[2]
# Keep the last context_size samples
if total_input_length >= self.context_size:
new_cache_start = total_input_length - self.context_size
new_cache = input_with_context[:, :, new_cache_start:]
else:
# If we have less than context_size samples, keep everything
new_cache = input_with_context
if debug:
print(f"[DEBUG] New cache shape: {new_cache.shape}")
cache.set(self.layer_id, sample_indices, new_cache)
return output
def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
"""Standard forward pass without streaming"""
B, C, T = x.shape
kernel_size = self.kernel_size
stride = self.stride
dilation = self.dilation
padding_total = self.padding_total
# Compute extra padding for stride alignment
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
if debug:
print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")
if self.causal:
# Left padding for causal
if self.pad_mode == 'constant':
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
else:
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
else:
# Symmetric padding for non-causal
padding_right = padding_total // 2
padding_left = padding_total - padding_right
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
if debug:
print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")
output = self.conv(x)
if debug:
print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")
return output
class SConvTranspose1d(nn.Module):
"""ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization."""
def __init__(self, in_channels: int, out_channels: int,
kernel_size: int, stride: int = 1, causal: bool = False,
norm: str = 'none', trim_right_ratio: float = 1.,
norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
super().__init__()
self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
self.causal = causal
self.trim_right_ratio = trim_right_ratio
assert self.causal or self.trim_right_ratio == 1., \
"`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
# Store configuration
self.kernel_size = kernel_size
self.stride = stride
self.in_channels = in_channels
self.out_channels = out_channels
# For transposed convolution, padding calculation is different
self.padding_total = kernel_size - stride
# For streaming, we need to keep track of input history
# Transposed conv needs to see multiple input samples to produce correct output
self.context_size = kernel_size - 1
# Create a unique layer ID for cache management
self._layer_id = None
@property
def layer_id(self):
if self._layer_id is None:
self._layer_id = f"sconvtr1d_{id(self)}"
return self._layer_id
def forward(self, x: torch.Tensor,
cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
sample_indices: Optional[torch.Tensor] = None,
use_cache: bool = False,
debug: bool = False) -> torch.Tensor:
"""
Forward pass with optional streaming support via cache.
"""
B, C, T = x.shape
# Non-streaming mode
if not use_cache or cache is None:
return self._forward_non_streaming(x, debug=debug)
# Streaming mode
assert sample_indices is not None, "sample_indices must be provided for streaming mode"
assert len(sample_indices) == B, "sample_indices must match batch size"
return self._forward_streaming(x, cache, sample_indices, debug)
def _forward_streaming(self, x: torch.Tensor,
cache: VibeVoiceTokenizerStreamingCache,
sample_indices: torch.Tensor,
debug: bool = False) -> torch.Tensor:
"""Streaming forward pass with cache operations kept separate from compiled code"""
B, C, T = x.shape
# Cache operations (not compiled)
cached_input = cache.get(self.layer_id, sample_indices)
if cached_input is None:
# First chunk - no history yet
cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
if debug:
print(f"[DEBUG] Initialized empty cache for transposed conv")
# Concatenate cached input with new input
full_input = torch.cat([cached_input, x], dim=2)
if debug:
print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")
# First chunk or debug mode - use uncompiled version
full_output = self.convtr(full_input)
if debug:
print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")
# Calculate padding to remove
if self.causal:
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
padding_left = self.padding_total - padding_right
else:
padding_right = self.padding_total // 2
padding_left = self.padding_total - padding_right
# Remove padding
if padding_left + padding_right > 0:
full_output = unpad1d(full_output, (padding_left, padding_right))
if debug:
print(f"[DEBUG] After unpadding: {full_output.shape}")
# Determine which part of the output corresponds to the new input
if cached_input.shape[2] == 0:
# First chunk - return all output
output = full_output
else:
# Subsequent chunks - return only the new output
expected_new_output = T * self.stride
# Take the last expected_new_output samples
if full_output.shape[2] >= expected_new_output:
output = full_output[:, :, -expected_new_output:]
else:
output = full_output
if debug:
print(f"[DEBUG] Final streaming output shape: {output.shape}")
# Update cache
if full_input.shape[2] > self.context_size:
new_cache = full_input[:, :, -self.context_size:]
else:
new_cache = full_input
if debug:
print(f"[DEBUG] New cache shape: {new_cache.shape}")
cache.set(self.layer_id, sample_indices, new_cache)
return output
def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
"""Standard forward pass without streaming"""
if debug:
print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")
# Apply transposed convolution
y = self.convtr(x)
if debug:
print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")
# Calculate and remove padding
if self.causal:
padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
padding_left = self.padding_total - padding_right
else:
padding_right = self.padding_total // 2
padding_left = self.padding_total - padding_right
if padding_left + padding_right > 0:
y = unpad1d(y, (padding_left, padding_right))
if debug:
print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
return y
# FFN
class FFN(nn.Module):
def __init__(
self,
embed_dim,
ffn_dim,
bias=False,
):
super().__init__()
self.embed_dim = embed_dim
self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
self.gelu = ACT2FN["gelu"]
self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)
def forward(self, x):
x = self.linear1(x)
x = self.gelu(x)
x = self.linear2(x)
return x
class Convlayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
groups=1,
bias=True,
pad_mode='zeros',
norm='weight_norm',
causal=True,
):
super().__init__()
self.conv = SConv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal)
def forward(self, x):
return self.conv(x)
class Block1D(nn.Module):
def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
layer_scale_init_value=1e-6, **kwargs):
super().__init__()
if kwargs.get('layernorm', 'LN') == 'LN':
self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
elif kwargs.get('layernorm', 'RMSNorm') == 'RMSNorm':
self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
if mixer_layer == 'conv':
self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1),
kernel_size=kernel_size,
pad_mode=kwargs.get('pad_mode', 'reflect'),
norm=kwargs.get('norm', 'none'),
causal=kwargs.get('causal', True),
bias=kwargs.get('bias', True),
)
elif mixer_layer == 'depthwise_conv':
self.mixer = Convlayer(dim, dim, groups=dim,
kernel_size=kernel_size,
pad_mode=kwargs.get('pad_mode', 'reflect'),
norm=kwargs.get('norm', 'none'),
causal=kwargs.get('causal', True),
bias=kwargs.get('bias', True),
)
else:
raise ValueError(f"Unsupported mixer layer: {mixer_layer}")
self.ffn = FFN(
dim,
kwargs.get('ffn_expansion', 4) * dim,
bias=kwargs.get('bias', False),
)
self.drop_path = nn.Identity() if drop_path <= 0. else nn.modules.DropPath(drop_path)
if layer_scale_init_value > 0:
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
else:
self.gamma = None
self.ffn_gamma = None
def forward(self, x):
# mixer
residual = x
x = self.norm(x)
x = self.mixer(x)
if self.gamma is not None:
x = x * self.gamma.unsqueeze(-1)
x = residual + self.drop_path(x)
# ffn
residual = x
x = self.ffn_norm(x)
x = x.permute(0, 2, 1)
x = self.ffn(x)
x = x.permute(0, 2, 1)
if self.ffn_gamma is not None:
x = x * self.ffn_gamma.unsqueeze(-1)
x = residual + self.drop_path(x)
return x
class TokenizerDecoder(nn.Module):
"""
Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.
Args:
config: Configuration object with model parameters
"""
def __init__(self, config):
super().__init__()
# Extract parameters from config
self.dimension = config.dimension
self.channels = config.channels
self.n_filters = config.n_filters
self.ratios = config.ratios
# IMPORTANT CHANGE: Don't reverse depths again since they're already reversed in VibeVoiceAcousticTokenizerModel
self.depths = config.depths # Changed from list(reversed(config.depths))
self.n_residual_layers = getattr(config, "n_residual_layers", 1)
self.hop_length = np.prod(self.ratios)
self.causal = config.causal
# Additional config parameters with defaults
kernel_size = getattr(config, "kernel_size", 7)
last_kernel_size = getattr(config, "last_kernel_size", 7)
norm = getattr(config, "norm", "none")
norm_params = getattr(config, "norm_params", {})
pad_mode = getattr(config, "pad_mode", "reflect")
bias = getattr(config, "bias", True)
layernorm = getattr(config, "layernorm", "LN")
layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
drop_path_rate = getattr(config, "drop_path_rate", 0.0)
mixer_layer = getattr(config, "mixer_layer", "conv")
layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
disable_last_norm = getattr(config, "disable_last_norm", False)
# determine the norm type based on layernorm
if layernorm == 'LN':
norm_type = ConvLayerNorm
elif layernorm == 'RMSNorm':
norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
else:
raise ValueError(f"Unsupported norm type: {layernorm}")
# stem and upsampling layers
stem = nn.Sequential(
SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
)
self.upsample_layers = nn.ModuleList()
self.upsample_layers.append(stem)
for i in range(len(self.ratios)):
in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
upsample_layer = nn.Sequential(
SConvTranspose1d(in_ch, out_ch,
kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
norm=norm, norm_kwargs=norm_params, bias=bias,
causal=self.causal, trim_right_ratio=trim_right_ratio),
)
self.upsample_layers.append(upsample_layer)
# configure transformer blocks
layer_type = partial(
Block1D,
mixer_layer=mixer_layer,
layernorm=layernorm,
eps=layernorm_eps,
causal=self.causal,
pad_mode=pad_mode,
norm=norm,
bias=bias,
layer_scale_init_value=layer_scale_init_value,
)
self.stages = nn.ModuleList()
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
cur = 0
# Create stages in the same order as the original model
for i in range(len(self.depths)):
in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
stage = nn.Sequential(
*[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
)
self.stages.append(stage)
cur += self.depths[i]
if not disable_last_norm:
self.norm = norm_type(in_ch, eps=layernorm_eps)
else:
self.norm = nn.Identity()
self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
for i in range(len(self.depths)):
# Apply upsampling
for layer in self.upsample_layers[i]:
if isinstance(layer, (SConv1d, SConvTranspose1d)):
x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
else:
x = layer(x)
# Apply stage (Block1D contains Convlayer which contains SConv1d)
for block in self.stages[i]:
if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
# Block1D forward with cache support
residual = x
x = block.norm(x)
x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
if block.gamma is not None:
x = x * block.gamma.unsqueeze(-1)
x = residual + x
# FFN part
residual = x
x = block.ffn_norm(x)
x = x.permute(0, 2, 1)
x = block.ffn(x)
x = x.permute(0, 2, 1)
if block.ffn_gamma is not None:
x = x * block.ffn_gamma.unsqueeze(-1)
x = residual + x
else:
x = block(x)
return self.norm(x)
def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
return x
class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
"""VibeVoice speech tokenizer model (only decoder) for acoustic tokens"""
config_class = VibeVoiceAcousticTokenizerConfig
base_model_prefix = "vibevoice_acoustic_tokenizer"
_supports_flash_attn_2 = True
_supports_sdpa = True
_no_split_modules = ["TokenizerDecoder"]
def __init__(self, config):
super().__init__(config)
self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
self.std_dist_type = getattr(config, "std_dist_type", "fix")
# Parse encoder depths
if isinstance(config.encoder_depths, str):
encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
else:
encoder_depths = config.encoder_depths
# Parse decoder depths if provided
if config.decoder_depths is not None and isinstance(config.decoder_depths, str):
decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
else:
# Default: use reversed encoder depths if decoder_depths is None
decoder_depths = list(reversed(encoder_depths))
# Create decoder config
decoder_config = copy.deepcopy(config)
decoder_config.dimension = config.vae_dim
decoder_config.n_filters = config.decoder_n_filters
decoder_config.ratios = config.decoder_ratios
decoder_config.depths = decoder_depths
decoder_config.norm = config.conv_norm
decoder_config.pad_mode = config.pad_mode
decoder_config.bias = config.conv_bias
decoder_config.layernorm_eps = config.layernorm_eps
decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
decoder_config.mixer_layer = config.mixer_layer
decoder_config.layer_scale_init_value = config.layer_scale_init_value
decoder_config.disable_last_norm = config.disable_last_norm
self.decoder = TokenizerDecoder(decoder_config)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
"""Initialize weights for the model"""
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=self.config.weight_init_value)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv1d):
nn.init.normal_(module.weight, std=self.config.weight_init_value)
if module.bias is not None:
nn.init.zeros_(module.bias)
@torch.no_grad()
def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
"""Convert latent representations back to audio"""
if latents.shape[1] == self.config.vae_dim:
pass
else:
latents = latents.permute(0, 2, 1)
audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
return audio
AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
__all__ = [
"VibeVoiceTokenizerStreamingCache",
"VibeVoiceAcousticTokenizerModel",
]
+264
View File
@@ -0,0 +1,264 @@
from __future__ import annotations
import torch
import asyncio
from queue import Queue
from typing import TYPE_CHECKING, Optional
from transformers.generation import BaseStreamer
class AudioStreamer(BaseStreamer):
"""
Audio streamer that stores audio chunks in queues for each sample in the batch.
This allows streaming audio generation for multiple samples simultaneously.
Parameters:
batch_size (`int`):
The batch size for generation
stop_signal (`any`, *optional*):
The signal to put in the queue when generation ends. Defaults to None.
timeout (`float`, *optional*):
The timeout for the audio queue. If `None`, the queue will block indefinitely.
"""
def __init__(
self,
batch_size: int,
stop_signal: Optional[any] = None,
timeout: Optional[float] = None,
):
self.batch_size = batch_size
self.stop_signal = stop_signal
self.timeout = timeout
# Create a queue for each sample in the batch
self.audio_queues = [Queue() for _ in range(batch_size)]
self.finished_flags = [False for _ in range(batch_size)]
self.sample_indices_map = {} # Maps from sample index to queue index
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
"""
Receives audio chunks and puts them in the appropriate queues.
Args:
audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
sample_indices: Tensor indicating which samples these chunks belong to
"""
for i, sample_idx in enumerate(sample_indices):
idx = sample_idx.item()
if idx < self.batch_size and not self.finished_flags[idx]:
# Convert to numpy or keep as tensor based on preference
audio_chunk = audio_chunks[i].detach().cpu()
self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
def end(self, sample_indices: Optional[torch.Tensor] = None):
"""
Signals the end of generation for specified samples or all samples.
Args:
sample_indices: Optional tensor of sample indices to end. If None, ends all.
"""
if sample_indices is None:
# End all samples
for idx in range(self.batch_size):
if not self.finished_flags[idx]:
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
self.finished_flags[idx] = True
else:
# End specific samples
for sample_idx in sample_indices:
idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
if idx < self.batch_size and not self.finished_flags[idx]:
self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
self.finished_flags[idx] = True
def __iter__(self):
"""Returns an iterator over the batch of audio streams."""
return AudioBatchIterator(self)
def get_stream(self, sample_idx: int):
"""Get the audio stream for a specific sample."""
if sample_idx >= self.batch_size:
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
return AudioSampleIterator(self, sample_idx)
class AudioSampleIterator:
"""Iterator for a single audio stream from the batch."""
def __init__(self, streamer: AudioStreamer, sample_idx: int):
self.streamer = streamer
self.sample_idx = sample_idx
def __iter__(self):
return self
def __next__(self):
value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
if value == self.streamer.stop_signal:
raise StopIteration()
return value
class AudioBatchIterator:
"""Iterator that yields audio chunks for all samples in the batch."""
def __init__(self, streamer: AudioStreamer):
self.streamer = streamer
self.active_samples = set(range(streamer.batch_size))
def __iter__(self):
return self
def __next__(self):
if not self.active_samples:
raise StopIteration()
batch_chunks = {}
samples_to_remove = set()
# Try to get chunks from all active samples
for idx in self.active_samples:
try:
value = self.streamer.audio_queues[idx].get(block=False)
if value == self.streamer.stop_signal:
samples_to_remove.add(idx)
else:
batch_chunks[idx] = value
except:
# Queue is empty for this sample, skip it this iteration
pass
# Remove finished samples
self.active_samples -= samples_to_remove
if batch_chunks:
return batch_chunks
elif self.active_samples:
# If no chunks were ready but we still have active samples,
# wait a bit and try again
import time
time.sleep(0.01)
return self.__next__()
else:
raise StopIteration()
class AsyncAudioStreamer(AudioStreamer):
"""
Async version of AudioStreamer for use in async contexts.
"""
def __init__(
self,
batch_size: int,
stop_signal: Optional[any] = None,
timeout: Optional[float] = None,
):
super().__init__(batch_size, stop_signal, timeout)
# Replace regular queues with async queues
self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
self.loop = asyncio.get_running_loop()
def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
"""Put audio chunks in the appropriate async queues."""
for i, sample_idx in enumerate(sample_indices):
idx = sample_idx.item()
if idx < self.batch_size and not self.finished_flags[idx]:
audio_chunk = audio_chunks[i].detach().cpu()
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait, audio_chunk
)
def end(self, sample_indices: Optional[torch.Tensor] = None):
"""Signal the end of generation for specified samples."""
if sample_indices is None:
indices_to_end = range(self.batch_size)
else:
indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
for idx in indices_to_end:
if idx < self.batch_size and not self.finished_flags[idx]:
self.loop.call_soon_threadsafe(
self.audio_queues[idx].put_nowait, self.stop_signal
)
self.finished_flags[idx] = True
async def get_stream(self, sample_idx: int):
"""Get async iterator for a specific sample's audio stream."""
if sample_idx >= self.batch_size:
raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
while True:
value = await self.audio_queues[sample_idx].get()
if value == self.stop_signal:
break
yield value
def __aiter__(self):
"""Returns an async iterator over all audio streams."""
return AsyncAudioBatchIterator(self)
class AsyncAudioBatchIterator:
"""Async iterator for batch audio streaming."""
def __init__(self, streamer: AsyncAudioStreamer):
self.streamer = streamer
self.active_samples = set(range(streamer.batch_size))
def __aiter__(self):
return self
async def __anext__(self):
if not self.active_samples:
raise StopAsyncIteration()
batch_chunks = {}
samples_to_remove = set()
# Create tasks for all active samples
tasks = {
idx: asyncio.create_task(self._get_chunk(idx))
for idx in self.active_samples
}
# Wait for at least one chunk to be ready
done, pending = await asyncio.wait(
tasks.values(),
return_when=asyncio.FIRST_COMPLETED,
timeout=self.streamer.timeout
)
# Cancel pending tasks
for task in pending:
task.cancel()
# Process completed tasks
for idx, task in tasks.items():
if task in done:
try:
value = await task
if value == self.streamer.stop_signal:
samples_to_remove.add(idx)
else:
batch_chunks[idx] = value
except asyncio.CancelledError:
pass
self.active_samples -= samples_to_remove
if batch_chunks:
return batch_chunks
elif self.active_samples:
# Try again if we still have active samples
return await self.__anext__()
else:
raise StopAsyncIteration()
async def _get_chunk(self, idx):
"""Helper to get a chunk from a specific queue."""
return await self.streamer.audio_queues[idx].get()
View File
+692
View File
@@ -0,0 +1,692 @@
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re
import numpy as np
import torch
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging
from .vibevoice_tokenizer_processor import AudioNormalizer
logger = logging.get_logger(__name__)
class VibeVoiceProcessor:
r"""
Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.
[`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.
Args:
tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
The tokenizer for text processing.
audio_processor (`VibeVoiceTokenizerProcessor`):
The audio processor for speech processing.
speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
The compression ratio for speech tokenization.
db_normalize (`bool`, *optional*, defaults to True):
Whether to apply decibel normalization to audio inputs.
"""
def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
self.tokenizer = tokenizer
self.audio_processor = audio_processor
self.speech_tok_compress_ratio = speech_tok_compress_ratio
self.db_normalize = db_normalize
self.audio_normalizer = AudioNormalizer() if db_normalize else None
self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model
- a path to a *directory* containing processor config
Returns:
[`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.
"""
import os
import json
from transformers.utils import cached_file
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
from vibevoice.modular.modular_vibevoice_text_tokenizer import (
VibeVoiceTextTokenizer,
VibeVoiceTextTokenizerFast
)
# Try to load from local path first, then from HF hub
config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
config = None
if os.path.exists(config_path):
# Local path exists
with open(config_path, 'r') as f:
config = json.load(f)
else:
# Try to load from HF hub
try:
config_file = cached_file(
pretrained_model_name_or_path,
"preprocessor_config.json",
**kwargs
)
with open(config_file, 'r') as f:
config = json.load(f)
except Exception as e:
logger.warning(f"Could not load preprocessor_config.json from {pretrained_model_name_or_path}: {e}")
logger.warning("Using default configuration")
config = {
"speech_tok_compress_ratio": 3200,
"db_normalize": True,
}
# Extract main processor parameters
speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
db_normalize = config.get("db_normalize", True)
# Load tokenizer - try from model path first, then fallback to Qwen
language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
if 'qwen' in language_model_pretrained_name.lower():
tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
language_model_pretrained_name,
**kwargs
)
else:
raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")
# Load audio processor
if "audio_processor" in config:
# Create audio processor from config
audio_config = config["audio_processor"]
audio_processor = VibeVoiceTokenizerProcessor(
sampling_rate=audio_config.get("sampling_rate", 24000),
normalize_audio=audio_config.get("normalize_audio", True),
target_dB_FS=audio_config.get("target_dB_FS", -25),
eps=audio_config.get("eps", 1e-6),
)
else:
# Create default audio processor
audio_processor = VibeVoiceTokenizerProcessor()
# Create and return the processor
return cls(
tokenizer=tokenizer,
audio_processor=audio_processor,
speech_tok_compress_ratio=speech_tok_compress_ratio,
db_normalize=db_normalize,
)
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
"""
Save a processor to a directory, so that it can be re-loaded using the
[`~VibeVoiceProcessor.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the processor will be saved.
"""
import os
import json
os.makedirs(save_directory, exist_ok=True)
# Save processor configuration
processor_config = {
"processor_class": "VibeVoiceProcessor",
"speech_tok_compress_ratio": self.speech_tok_compress_ratio,
"db_normalize": self.db_normalize,
"audio_processor": {
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
"sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
"normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
"target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
"eps": getattr(self.audio_processor, 'eps', 1e-6),
}
}
config_path = os.path.join(save_directory, "preprocessor_config.json")
with open(config_path, 'w') as f:
json.dump(processor_config, f, indent=2)
logger.info(f"Processor configuration saved in {config_path}")
def __call__(
self,
text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: bool = True,
**kwargs,
) -> BatchEncoding:
"""
Main method to process one or more podcast scripts with optional voice samples.
Args:
text (`str`, `List[str]`):
The input text(s) to process. Can be:
- A single script string
- A list of script strings for batch processing
- A path to a .json or .txt file
- A list of paths
voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
Voice samples for each script. Can be:
- A list of samples for a single script
- A list of lists for batch processing
padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
Whether to pad sequences to the same length
truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
Whether to truncate sequences
max_length (`int`, *optional*):
Maximum length of the returned sequences
return_tensors (`str` or `TensorType`, *optional*):
If set, will return tensors of a particular framework
return_attention_mask (`bool`, defaults to `True`):
Whether to return the attention mask
Returns:
`BatchEncoding`: A BatchEncoding with the following fields:
- **input_ids** -- List of token id sequences or tensor
- **attention_mask** -- List of attention masks or tensor
- **speech_tensors** -- Padded speech inputs (if voice_samples provided)
- **speech_masks** -- Speech masks (if voice_samples provided)
- **speech_input_mask** -- Boolean masks indicating speech token positions
"""
# Handle single vs batch input
if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
# Single input
texts = [text]
is_batched = False
else:
# Batch input
texts = text
is_batched = True
# Handle voice samples
if voice_samples is not None:
if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
# Single set of voice samples
voice_samples_list = [voice_samples]
else:
# Batch of voice samples
voice_samples_list = voice_samples
else:
voice_samples_list = [None] * len(texts)
# Process each input
all_encodings = []
for text_input, voice_input in zip(texts, voice_samples_list):
encoding = self._process_single(text_input, voice_input)
all_encodings.append(encoding)
# Combine batch
batch_encoding = self._batch_encode(
all_encodings,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
return_attention_mask=return_attention_mask,
)
return batch_encoding
def _process_single(
self,
text: Union[str, TextInput],
voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
) -> Dict[str, Any]:
"""Process a single podcast script."""
# Determine if text is a file path or direct script
script = None
if isinstance(text, str):
# Check if it's a file path
if text.endswith('.json') and os.path.exists(text):
script = self._convert_json_to_script(text)
elif text.endswith('.txt') and os.path.exists(text):
script = self._convert_text_to_script(text)
else:
# Assume it's the script content directly
script = text
if script is None:
raise ValueError(f"Could not process input text: {text}")
# Parse the script
parsed_lines = self._parse_script(script)
all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
# Create system prompt
# system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
system_tokens = self.tokenizer.encode(self.system_prompt)
# Process voice samples if provided
if voice_samples:
voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
else:
voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
# Build full token sequence
full_tokens = system_tokens + voice_tokens
speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
# Add text input section
full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))
for speaker_id, speaker_text in parsed_lines:
speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
full_tokens += speaker_text_tokens
speech_input_mask += [False] * len(speaker_text_tokens)
# Add speech output section
full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)
return {
"input_ids": full_tokens,
"speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
"speech_input_mask": speech_input_mask,
"parsed_script": parsed_lines,
"all_speakers": all_speakers,
}
def _batch_encode(
self,
encodings: List[Dict[str, Any]],
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: bool = True,
) -> BatchEncoding:
"""Combine multiple encodings into a batch with padding."""
# Extract input_ids and create attention_mask
input_ids_list = [enc["input_ids"] for enc in encodings]
speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
# Determine padding strategy
if isinstance(padding, bool):
padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
elif isinstance(padding, str):
padding_strategy = PaddingStrategy(padding)
else:
padding_strategy = padding
# Apply padding to input_ids
if padding_strategy != PaddingStrategy.DO_NOT_PAD:
if padding_strategy == PaddingStrategy.LONGEST:
max_len = max(len(ids) for ids in input_ids_list)
elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
max_len = max_length
else:
max_len = max(len(ids) for ids in input_ids_list)
# Pad sequences
padded_input_ids = []
attention_masks = []
padded_speech_input_masks = []
for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
# Truncate if needed
if truncation and len(input_ids) > max_len:
input_ids = input_ids[:max_len]
speech_mask = speech_mask[:max_len]
# Pad
padding_length = max_len - len(input_ids)
# padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
attention_mask = [0] * padding_length + [1] * len(input_ids)
padded_speech_mask = [False] * padding_length + speech_mask
padded_input_ids.append(padded_ids)
attention_masks.append(attention_mask)
padded_speech_input_masks.append(padded_speech_mask)
input_ids_list = padded_input_ids
speech_input_masks_list = padded_speech_input_masks
else:
# No padding, just create attention masks
attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
# Process speech inputs
all_speech_inputs = []
has_speech = False
for enc in encodings:
if enc["speech_inputs"] is not None:
all_speech_inputs.extend(enc["speech_inputs"])
has_speech = True
# Prepare batch encoding
batch_encoding = BatchEncoding()
# Handle tensor conversion
if return_tensors is not None:
batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
if return_attention_mask and attention_masks is not None:
batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
else:
batch_encoding["input_ids"] = input_ids_list
if return_attention_mask and attention_masks is not None:
batch_encoding["attention_mask"] = attention_masks
batch_encoding["speech_input_mask"] = speech_input_masks_list
# Process speech tensors if present
if has_speech:
speech_dict = self.prepare_speech_inputs(
all_speech_inputs,
return_tensors=return_tensors,
)
batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
batch_encoding["speech_masks"] = speech_dict["speech_masks"]
else:
batch_encoding["speech_tensors"] = None
batch_encoding["speech_masks"] = None
# Add metadata
batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
return batch_encoding
def _create_voice_prompt(
self,
speaker_samples: List[Union[str, np.ndarray]]
) -> Tuple[List[int], List[np.ndarray], List[bool]]:
"""
Create voice prompt tokens and process audio samples.
Returns:
tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
"""
vae_token_id = self.tokenizer.speech_diffusion_id
voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
voice_speech_inputs = []
voice_speech_masks = [False] * len(voice_full_tokens)
for speaker_id, speaker_audio in enumerate(speaker_samples):
prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
# Process audio
if isinstance(speaker_audio, str):
# Load audio from file
wav = self.audio_processor._load_audio_from_path(speaker_audio)
else:
wav = np.array(speaker_audio, dtype=np.float32)
# Apply normalization if needed
if self.db_normalize and self.audio_normalizer:
wav = self.audio_normalizer(wav)
# Calculate token length based on compression ratio
# if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
# vae_tok_len = wav.shape[0]
# else:
vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
# Build tokens and masks
speaker_tokens = (prefix_tokens +
[self.tokenizer.speech_start_id] +
[vae_token_id] * vae_tok_len +
[self.tokenizer.speech_end_id] +
self.tokenizer.encode('\n', add_special_tokens=False))
vae_input_mask = ([False] * len(prefix_tokens) +
[False] +
[True] * vae_tok_len +
[False] +
[False])
voice_full_tokens.extend(speaker_tokens)
voice_speech_masks.extend(vae_input_mask)
voice_speech_inputs.append(wav)
return voice_full_tokens, voice_speech_inputs, voice_speech_masks
def prepare_speech_inputs(
self,
speech_inputs: List[np.ndarray],
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> Dict[str, Any]:
"""
Prepare speech inputs for model consumption.
Args:
speech_inputs: List of speech arrays
return_tensors: Output tensor type
device: Device to place tensors on
dtype: Data type for tensors
Returns:
Dictionary with padded_speeches and speech_masks
"""
if not speech_inputs:
return {"padded_speeches": None, "speech_masks": None}
# Calculate sequence lengths
vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
# vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
max_speech_length = max(s.shape[0] for s in speech_inputs)
# Pad speeches
if speech_inputs[0].ndim == 1:
padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
else:
padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)
for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
padded_speeches[i, :len(speech)] = speech
speech_masks[i, :vae_tok_length] = True
result = {
"padded_speeches": padded_speeches,
"speech_masks": speech_masks,
}
# Convert to tensors if requested
if return_tensors == "pt":
result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)
return result
def _convert_json_to_script(self, json_file: str) -> str:
"""
Convert JSON format to script format.
Expected JSON format:
[
{"speaker": "1", "text": "Hello everyone..."},
{"speaker": "2", "text": "Great to be here..."}
]
"""
import json
with open(json_file, 'r', encoding='utf-8') as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("JSON file must contain a list of speaker entries")
script_lines = []
for item in data:
if not isinstance(item, dict):
logger.warning(f"Skipping non-dict entry: {item}")
continue
speaker = item.get('speaker')
text = item.get('text')
if speaker is None or text is None:
logger.warning(f"Skipping entry missing speaker or text: {item}")
continue
# Ensure speaker ID is valid
try:
speaker_id = int(speaker)
except (ValueError, TypeError):
logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
continue
# Clean up text
text = text.strip()
if text:
script_lines.append(f"Speaker {speaker_id}: {text}")
if not script_lines:
raise ValueError("No valid entries found in JSON file")
return "\n".join(script_lines)
def _convert_text_to_script(self, text_file: str) -> str:
"""
Convert text file to script format.
Handles multiple formats:
1. Already formatted as "Speaker X: text"
2. Plain text (assigns to Speaker 1)
Handles edge cases like multiple colons in a line.
"""
with open(text_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
script_lines = []
current_speaker = 1
for line in lines:
line = line.strip()
if not line:
continue
# Try to parse as "Speaker X: text" format
# Use regex to be more robust
speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
if speaker_match:
speaker_id = int(speaker_match.group(1))
text = speaker_match.group(2).strip()
if text:
script_lines.append(f"Speaker {speaker_id}: {text}")
else:
# Treat as plain text - assign to current speaker
script_lines.append(f"Speaker {current_speaker}: {line}")
if not script_lines:
raise ValueError("No valid content found in text file")
return "\n".join(script_lines)
def _parse_script(self, script: str) -> List[Tuple[int, str]]:
"""Parse script into list of (speaker_id, text) tuples."""
lines = script.strip().split("\n")
parsed_lines = []
speaker_ids = []
# First pass: parse all lines and collect speaker IDs
for line in lines:
if not line.strip():
continue
# Use regex to handle edge cases like multiple colons
match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
if match:
speaker_id = int(match.group(1))
text = ' ' + match.group(2).strip()
parsed_lines.append((speaker_id, text))
speaker_ids.append(speaker_id)
else:
logger.warning(f"Could not parse line: '{line}'")
if not parsed_lines:
raise ValueError("No valid speaker lines found in script")
# Check if we need to normalize speaker IDs (only if all are > 0)
min_speaker_id = min(speaker_ids)
if min_speaker_id > 0:
# Normalize to start from 0
normalized_lines = []
for speaker_id, text in parsed_lines:
normalized_lines.append((speaker_id - 1, text))
return normalized_lines
else:
# Keep original IDs
return parsed_lines
def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
"""Merge text and audio inputs into a single BatchEncoding."""
# Start with text inputs
merged = BatchEncoding(text_inputs)
# Add audio-specific fields
if "audio" in audio_inputs:
merged["speech_inputs"] = audio_inputs["audio"]
if "streaming" in audio_inputs:
merged["streaming"] = audio_inputs["streaming"]
return merged
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
"""
Return the list of inputs accepted by the model.
"""
tokenizer_input_names = self.tokenizer.model_input_names
audio_processor_input_names = self.audio_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))
def save_audio(self,
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
output_path: str = "output.wav",
sampling_rate: Optional[int] = None,
normalize: bool = False,
batch_prefix: str = "audio_",
) -> str:
"""
Save audio data to a file.
Args:
audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
The audio data to save. Can be a single tensor/array or a list of them.
output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".
Returns:
str: The path to the saved audio file.
"""
return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
__all__ = [
"VibeVoiceProcessor",
]
@@ -0,0 +1,409 @@
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re
import numpy as np
import torch
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging
from .vibevoice_tokenizer_processor import AudioNormalizer
logger = logging.get_logger(__name__)
class VibeVoiceStreamingProcessor:
r"""
Constructs a VibeVoice Streaming processor which wraps a VibeVoice tokenizer and audio processor into a single processor.
Args:
tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
The tokenizer for text processing.
audio_processor (`VibeVoiceTokenizerProcessor`):
The audio processor for speech processing.
speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
The compression ratio for speech tokenization.
db_normalize (`bool`, *optional*, defaults to True):
Whether to apply decibel normalization to audio inputs.
"""
def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
self.tokenizer = tokenizer
self.audio_processor = audio_processor
self.speech_tok_compress_ratio = speech_tok_compress_ratio
self.db_normalize = db_normalize
self.audio_normalizer = AudioNormalizer() if db_normalize else None
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Instantiate a VibeVoiceStreamingProcessor from a pretrained VibeVoice Streaming processor.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model
- a path to a *directory* containing processor config
Returns:
[`VibeVoiceStreamingProcessor`]: The processor object instantiated from pretrained model.
"""
import os
import json
from transformers.utils import cached_file
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
from vibevoice.modular.modular_vibevoice_text_tokenizer import (
VibeVoiceTextTokenizer,
VibeVoiceTextTokenizerFast
)
# Try to load from local path first, then from HF hub
config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
config = None
if os.path.exists(config_path):
# Local path exists
with open(config_path, 'r') as f:
config = json.load(f)
else:
# Try to load from HF hub
try:
config_file = cached_file(
pretrained_model_name_or_path,
"preprocessor_config.json",
**kwargs
)
with open(config_file, 'r') as f:
config = json.load(f)
except Exception as e:
logger.warning(f"Could not load preprocessor_config.json from {pretrained_model_name_or_path}: {e}")
logger.warning("Using default configuration")
config = {
"speech_tok_compress_ratio": 3200,
"db_normalize": True,
}
# Extract main processor parameters
speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
db_normalize = config.get("db_normalize", True)
# Load tokenizer - try from model path first, then fallback to Qwen
language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
if 'qwen' in language_model_pretrained_name.lower():
tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
language_model_pretrained_name,
**kwargs
)
else:
raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")
# Load audio processor
if "audio_processor" in config:
# Create audio processor from config
audio_config = config["audio_processor"]
audio_processor = VibeVoiceTokenizerProcessor(
sampling_rate=audio_config.get("sampling_rate", 24000),
normalize_audio=audio_config.get("normalize_audio", True),
target_dB_FS=audio_config.get("target_dB_FS", -25),
eps=audio_config.get("eps", 1e-6),
)
else:
# Create default audio processor
audio_processor = VibeVoiceTokenizerProcessor()
# Create and return the processor
return cls(
tokenizer=tokenizer,
audio_processor=audio_processor,
speech_tok_compress_ratio=speech_tok_compress_ratio,
db_normalize=db_normalize,
)
def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
"""
Save a processor to a directory, so that it can be re-loaded using the
[`~VibeVoiceStreamingProcessor.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the processor will be saved.
"""
import os
import json
os.makedirs(save_directory, exist_ok=True)
# Save processor configuration
processor_config = {
"processor_class": "VibeVoiceStreamingProcessor",
"speech_tok_compress_ratio": self.speech_tok_compress_ratio,
"db_normalize": self.db_normalize,
"audio_processor": {
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
"sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
"normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
"target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
"eps": getattr(self.audio_processor, 'eps', 1e-6),
}
}
config_path = os.path.join(save_directory, "preprocessor_config.json")
with open(config_path, 'w') as f:
json.dump(processor_config, f, indent=2)
logger.info(f"Processor configuration saved in {config_path}")
def __call__(self) -> BatchEncoding:
"""
Note:
This method is intentionally not implemented in the streaming processor.
Use `process_input_with_cached_prompt` for streaming use cases.
"""
raise NotImplementedError(
"VibeVoiceStreamingProcessor.__call__ is not implemented. "
"Use process_input_with_cached_prompt for streaming inputs."
)
def process_input_with_cached_prompt(
self,
text: Optional[str] = None,
cached_prompt: Optional[Dict[str, Any]] = None,
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: bool = True,
**kwargs,
) -> BatchEncoding:
"""
Main method to process one text script based on cached prompt. The function currently only supports single examples.
Args:
text (`str`):
The input text to process.
cached_prompt (`Dict[str, Any]`, *optional*):
The cached prompt to use for processing. It contains the kv cache of the voice prompt.
padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
Whether to pad sequences to the same length
truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
Whether to truncate sequences
max_length (`int`, *optional*):
Maximum length of the returned sequences
return_tensors (`str` or `TensorType`, *optional*):
If set, will return tensors of a particular framework
return_attention_mask (`bool`, defaults to `True`):
Whether to return the attention mask
Returns:
`BatchEncoding`: A BatchEncoding with the following fields:
- **input_ids** -- List of token id sequences or tensor
- **attention_mask** -- List of attention masks or tensor
- **tts_lm_input_ids** -- List of token id sequences or tensor used for TTS LM
- **tts_lm_attention_mask** -- List of attention masks or tensor used for TTS LM
- **tts_text_ids** -- List of token id sequences or tensor for TTS text input
- **speech_tensors** -- Padded speech inputs (if voice_samples provided)
- **speech_masks** -- Speech masks (if voice_samples provided)
- **speech_input_mask** -- Boolean masks indicating speech token positions
"""
# Only support single example
texts = [text]
cached_prompts = [cached_prompt]
is_batched = False
# Process each input
all_encodings = []
for text_input, cached_prompt_input in zip(texts, cached_prompts):
script_tokens = self.tokenizer.encode(text_input.strip() + "\n", add_special_tokens=False)
input_id_length = cached_prompt_input['lm']['last_hidden_state'].size(1)
tts_lm_input_id_length = cached_prompt_input['tts_lm']['last_hidden_state'].size(1)
# psudo input ids and masks
input_ids = [self.tokenizer.pad_id] * input_id_length
tts_lm_input_ids = [self.tokenizer.pad_id] * tts_lm_input_id_length
speech_input_mask = [False] * tts_lm_input_id_length
encoding = {
"input_ids": input_ids,
"tts_lm_input_ids": tts_lm_input_ids,
"tts_text_ids": script_tokens,
"speech_inputs": None,
"speech_input_mask": speech_input_mask,
}
all_encodings.append(encoding)
# Combine batch
batch_encoding = self._batch_encode(
all_encodings,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
return_attention_mask=return_attention_mask,
)
return batch_encoding
def _batch_encode(
self,
encodings: List[Dict[str, Any]],
padding: Union[bool, str, PaddingStrategy] = True,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: bool = True,
) -> BatchEncoding:
"""Combine multiple encodings into a batch with padding."""
# Extract input_ids and create attention_mask
input_ids_list = [enc["input_ids"] for enc in encodings]
tts_lm_input_ids_list = [enc["tts_lm_input_ids"] for enc in encodings]
tts_text_ids_list = [enc["tts_text_ids"] for enc in encodings]
speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
tts_lm_attention_masks = [[1] * len(ids) for ids in tts_lm_input_ids_list] if return_attention_mask else None
# Process speech inputs
all_speech_inputs = []
has_speech = False
for enc in encodings:
if enc["speech_inputs"] is not None:
all_speech_inputs.extend(enc["speech_inputs"])
has_speech = True
# Prepare batch encoding
batch_encoding = BatchEncoding()
# Handle tensor conversion
if return_tensors is not None:
batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
batch_encoding["tts_lm_input_ids"] = torch.tensor(tts_lm_input_ids_list, dtype=torch.long)
batch_encoding["tts_text_ids"] = torch.tensor(tts_text_ids_list, dtype=torch.long)
if return_attention_mask and attention_masks is not None:
batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
batch_encoding["tts_lm_attention_mask"] = torch.tensor(tts_lm_attention_masks, dtype=torch.long)
batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
else:
batch_encoding["input_ids"] = input_ids_list
batch_encoding["tts_lm_input_ids"] = tts_lm_input_ids_list
batch_encoding["tts_text_ids"] = tts_text_ids_list
if return_attention_mask and attention_masks is not None:
batch_encoding["attention_mask"] = attention_masks
batch_encoding["tts_lm_attention_mask"] = tts_lm_attention_masks
batch_encoding["speech_input_mask"] = speech_input_masks_list
# Process speech tensors if present
if has_speech:
speech_dict = self.prepare_speech_inputs(
all_speech_inputs,
return_tensors=return_tensors,
)
batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
batch_encoding["speech_masks"] = speech_dict["speech_masks"]
else:
batch_encoding["speech_tensors"] = None
batch_encoding["speech_masks"] = None
return batch_encoding
def prepare_speech_inputs(
self,
speech_inputs: List[np.ndarray],
return_tensors: Optional[Union[str, TensorType]] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> Dict[str, Any]:
"""
Prepare speech inputs for model consumption.
Args:
speech_inputs: List of speech arrays
return_tensors: Output tensor type
device: Device to place tensors on
dtype: Data type for tensors
Returns:
Dictionary with padded_speeches and speech_masks
"""
if not speech_inputs:
return {"padded_speeches": None, "speech_masks": None}
# Calculate sequence lengths
vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
# vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
max_speech_length = max(s.shape[0] for s in speech_inputs)
# Pad speeches
if speech_inputs[0].ndim == 1:
padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
else:
padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)
for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
padded_speeches[i, :len(speech)] = speech
speech_masks[i, :vae_tok_length] = True
result = {
"padded_speeches": padded_speeches,
"speech_masks": speech_masks,
}
# Convert to tensors if requested
if return_tensors == "pt":
result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)
return result
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
Please refer to the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
"""
Return the list of inputs accepted by the model.
"""
tokenizer_input_names = self.tokenizer.model_input_names
audio_processor_input_names = self.audio_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))
def save_audio(self,
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
output_path: str = "output.wav",
sampling_rate: Optional[int] = None,
normalize: bool = False,
batch_prefix: str = "audio_",
) -> str:
"""
Save audio data to a file.
Args:
audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
The audio data to save. Can be a single tensor/array or a list of them.
output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".
Returns:
str: The path to the saved audio file.
"""
return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
__all__ = [
"VibeVoiceStreamingProcessor",
]
@@ -0,0 +1,483 @@
"""
Processor class for VibeVoice models.
"""
import os
import json
import warnings
from typing import List, Optional, Union, Dict, Any
import numpy as np
import torch
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.utils import logging
logger = logging.get_logger(__name__)
class AudioNormalizer:
"""
Audio normalization class for VibeVoice tokenizer.
This class provides audio normalization to ensure consistent input levels
for the VibeVoice tokenizer while maintaining audio quality.
"""
def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6):
"""
Initialize the audio normalizer.
Args:
target_dB_FS (float): Target dB FS level for the audio. Default: -25
eps (float): Small value to avoid division by zero. Default: 1e-6
"""
self.target_dB_FS = target_dB_FS
self.eps = eps
def tailor_dB_FS(self, audio: np.ndarray) -> tuple:
"""
Adjust the audio to the target dB FS level.
Args:
audio (np.ndarray): Input audio signal
Returns:
tuple: (normalized_audio, rms, scalar)
"""
rms = np.sqrt(np.mean(audio**2))
scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
normalized_audio = audio * scalar
return normalized_audio, rms, scalar
def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple:
"""
Avoid clipping by scaling down if necessary.
Args:
audio (np.ndarray): Input audio signal
scalar (float, optional): Explicit scaling factor
Returns:
tuple: (normalized_audio, scalar)
"""
if scalar is None:
max_val = np.max(np.abs(audio))
if max_val > 1.0:
scalar = max_val + self.eps
else:
scalar = 1.0
return audio / scalar, scalar
def __call__(self, audio: np.ndarray) -> np.ndarray:
"""
Normalize the audio by adjusting to target dB FS and avoiding clipping.
Args:
audio (np.ndarray): Input audio signal
Returns:
np.ndarray: Normalized audio signal
"""
# First adjust to target dB FS
audio, _, _ = self.tailor_dB_FS(audio)
# Then avoid clipping
audio, _ = self.avoid_clipping(audio)
return audio
# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components
class VibeVoiceTokenizerProcessor(FeatureExtractionMixin):
"""
Processor for VibeVoice acoustic tokenizer models.
This processor handles audio preprocessing for VibeVoice models, including:
- Audio format conversion (stereo to mono)
- Optional audio normalization
- Streaming support for infinite-length audio
Args:
sampling_rate (int, optional): Expected sampling rate. Defaults to 24000.
normalize_audio (bool, optional): Whether to normalize audio. Defaults to True.
target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25.
eps (float, optional): Small value for numerical stability. Defaults to 1e-6.
"""
model_input_names = ["input_features"]
def __init__(
self,
sampling_rate: int = 24000,
normalize_audio: bool = True,
target_dB_FS: float = -25,
eps: float = 1e-6,
**kwargs,
):
super().__init__(**kwargs)
self.sampling_rate = sampling_rate
self.normalize_audio = normalize_audio
# Initialize audio normalizer if needed
if self.normalize_audio:
self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps)
else:
self.normalizer = None
# Save config
self.feature_extractor_dict = {
"sampling_rate": sampling_rate,
"normalize_audio": normalize_audio,
"target_dB_FS": target_dB_FS,
"eps": eps,
}
def _ensure_mono(self, audio: np.ndarray) -> np.ndarray:
"""
Convert stereo audio to mono if needed.
Args:
audio (np.ndarray): Input audio array
Returns:
np.ndarray: Mono audio array
"""
if len(audio.shape) == 1:
return audio
elif len(audio.shape) == 2:
if audio.shape[0] == 2: # (2, time)
return np.mean(audio, axis=0)
elif audio.shape[1] == 2: # (time, 2)
return np.mean(audio, axis=1)
else:
# If one dimension is 1, squeeze it
if audio.shape[0] == 1:
return audio.squeeze(0)
elif audio.shape[1] == 1:
return audio.squeeze(1)
else:
raise ValueError(f"Unexpected audio shape: {audio.shape}")
else:
raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}")
def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray:
"""
Process a single audio array.
Args:
audio: Single audio input
Returns:
np.ndarray: Processed audio
"""
# Convert to numpy array
if not isinstance(audio, np.ndarray):
audio = np.array(audio, dtype=np.float32)
else:
audio = audio.astype(np.float32)
# Ensure mono
audio = self._ensure_mono(audio)
# Normalize if requested
if self.normalize_audio and self.normalizer is not None:
audio = self.normalizer(audio)
return audio
def __call__(
self,
audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None,
sampling_rate: Optional[int] = None,
return_tensors: Optional[str] = None,
**kwargs,
):
"""
Process audio for VibeVoice models.
Args:
audio: Audio input(s) to process. Can be:
- str: Path to audio file
- np.ndarray: Audio array
- List[float]: Audio as list of floats
- List[np.ndarray]: Batch of audio arrays
- List[str]: Batch of audio file paths
sampling_rate (int, optional): Sampling rate of the input audio
return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy)
Returns:
dict: Processed audio inputs with keys:
- input_features: Audio tensor(s) ready for the model
"""
if audio is None:
raise ValueError("Audio input is required")
# Validate sampling rate
if sampling_rate is not None and sampling_rate != self.sampling_rate:
logger.warning(
f"Input sampling rate ({sampling_rate}) differs from expected "
f"sampling rate ({self.sampling_rate}). Please resample your audio."
)
# Handle different input types
if isinstance(audio, str):
# Single audio file path
audio = self._load_audio_from_path(audio)
is_batched = False
elif isinstance(audio, list):
if len(audio) == 0:
raise ValueError("Empty audio list provided")
# Check if it's a list of file paths
if all(isinstance(item, str) for item in audio):
# Batch of audio file paths
audio = [self._load_audio_from_path(path) for path in audio]
is_batched = True
else:
# Check if it's batched audio arrays
is_batched = isinstance(audio[0], (np.ndarray, list))
else:
# Single audio array or list
is_batched = False
# Process audio
if is_batched:
processed_audio = [self._process_single_audio(a) for a in audio]
else:
processed_audio = [self._process_single_audio(audio)]
# Convert to tensors if requested
if return_tensors == "pt":
if len(processed_audio) == 1:
# Create a proper batch dimension (B, T)
input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1)
else:
# For batched input with different lengths, create a batch properly
input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1)
elif return_tensors == "np":
if len(processed_audio) == 1:
input_features = processed_audio[0][np.newaxis, np.newaxis, :]
else:
input_features = np.stack(processed_audio)[:, np.newaxis, :]
else:
input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio
outputs = {
"audio": input_features, # Use "audio" instead of "input_features"
}
return outputs
def _load_audio_from_path(self, audio_path: str) -> np.ndarray:
"""
Load audio from file path.
Args:
audio_path (str): Path to audio file
Returns:
np.ndarray: Loaded audio array
"""
# Get file extension to determine loading method
file_ext = os.path.splitext(audio_path)[1].lower()
if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']:
# Audio file - use librosa
import librosa
audio_array, sr = librosa.load(
audio_path,
sr=self.sampling_rate,
mono=True
)
return audio_array
elif file_ext == '.pt':
# PyTorch tensor file
audio_tensor = torch.load(audio_path, map_location='cpu').squeeze()
if isinstance(audio_tensor, torch.Tensor):
audio_array = audio_tensor.numpy()
else:
audio_array = np.array(audio_tensor)
return audio_array.astype(np.float32)
elif file_ext == '.npy':
# NumPy file
audio_array = np.load(audio_path)
return audio_array.astype(np.float32)
else:
raise ValueError(
f"Unsupported file format: {file_ext}. "
f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz"
)
def preprocess_audio(
self,
audio_path_or_array: Union[str, np.ndarray],
normalize: Optional[bool] = None,
) -> np.ndarray:
"""
Convenience method to preprocess audio from file path or array.
This method is kept for backward compatibility but __call__ is recommended.
Args:
audio_path_or_array: Path to audio file or numpy array
normalize: Whether to normalize (overrides default setting)
Returns:
np.ndarray: Preprocessed audio array
"""
if isinstance(audio_path_or_array, str):
audio_array = self._load_audio_from_path(audio_path_or_array)
else:
audio_array = np.array(audio_path_or_array, dtype=np.float32)
# Override normalization setting if specified
original_normalize = self.normalize_audio
if normalize is not None:
self.normalize_audio = normalize
try:
processed = self._process_single_audio(audio_array)
finally:
# Restore original setting
self.normalize_audio = original_normalize
return processed
# Override to_dict method for configuration saving
def to_dict(self) -> Dict[str, Any]:
"""
Convert the object to a dict containing all attributes needed for serialization.
"""
return self.feature_extractor_dict
def save_audio(
self,
audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
output_path: str = "output.wav",
sampling_rate: Optional[int] = None,
normalize: bool = False,
batch_prefix: str = "audio_",
):
"""
Save audio data to WAV file(s).
Args:
audio: Audio data to save. Can be:
- torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T)
- np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T)
- List of tensors or arrays
output_path: Path where to save the audio. If saving multiple files,
this is treated as a directory and individual files will be saved inside.
sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate.
normalize: Whether to normalize audio before saving.
batch_prefix: Prefix for batch files when saving multiple audios.
Returns:
List[str]: Paths to the saved audio files.
"""
if sampling_rate is None:
sampling_rate = self.sampling_rate
try:
import soundfile as sf
except ImportError:
raise ImportError(
"soundfile is required to save audio files. "
"Install it with: pip install soundfile"
)
# Ensure audio is in the right format
if isinstance(audio, torch.Tensor):
# Convert PyTorch tensor to numpy
audio_np = audio.float().detach().cpu().numpy()
elif isinstance(audio, np.ndarray):
audio_np = audio
elif isinstance(audio, list):
# Handle list of tensors or arrays
if all(isinstance(a, torch.Tensor) for a in audio):
audio_np = [a.float().detach().cpu().numpy() for a in audio]
else:
audio_np = audio
else:
raise ValueError(f"Unsupported audio type: {type(audio)}")
saved_paths = []
# Handle based on shape or type
if isinstance(audio_np, list):
# Multiple separate audios to save
output_dir = output_path
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
# Save each audio
for i, audio_item in enumerate(audio_np):
audio_item = self._prepare_audio_for_save(audio_item, normalize)
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
sf.write(file_path, audio_item, sampling_rate)
saved_paths.append(file_path)
else:
# Handle different dimensions
if len(audio_np.shape) >= 3: # (B, C, T) or similar
# Get batch size
batch_size = audio_np.shape[0]
if batch_size > 1:
# Multiple audios in a batch
output_dir = output_path
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
# Save each audio in the batch
for i in range(batch_size):
# Extract single audio and remove channel dim if present
single_audio = audio_np[i]
if len(single_audio.shape) > 1:
if single_audio.shape[0] == 1: # (1, T)
single_audio = single_audio.squeeze(0)
single_audio = self._prepare_audio_for_save(single_audio, normalize)
file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav")
sf.write(file_path, single_audio, sampling_rate)
saved_paths.append(file_path)
else:
# Single audio with batch and channel dims
audio_item = audio_np.squeeze() # Remove batch and channel dimensions
audio_item = self._prepare_audio_for_save(audio_item, normalize)
sf.write(output_path, audio_item, sampling_rate)
saved_paths.append(output_path)
else:
# Single audio without batch dimension
audio_item = self._prepare_audio_for_save(audio_np, normalize)
sf.write(output_path, audio_item, sampling_rate)
saved_paths.append(output_path)
return saved_paths
def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray:
"""
Prepare audio for saving by ensuring it's the right shape and optionally normalizing.
Args:
audio: Audio data as numpy array
normalize: Whether to normalize audio
Returns:
np.ndarray: Processed audio ready for saving
"""
# Ensure right dimensionality
if len(audio.shape) > 1 and audio.shape[0] == 1: # (1, T)
audio = audio.squeeze(0)
# Normalize if requested
if normalize:
max_val = np.abs(audio).max()
if max_val > 0:
audio = audio / max_val
return audio
__all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"]
View File
File diff suppressed because it is too large Load Diff
+19
View File
@@ -0,0 +1,19 @@
import math
import torch
class UniformSampler:
def __init__(self, timesteps = 1000):
self.timesteps = timesteps
def sample(self, batch_size, device):
return torch.randint(0, self.timesteps, (batch_size,), device=device)
class LogitNormalSampler:
def __init__(self, timesteps = 1000, m = 0, s = 1):
self.timesteps = timesteps
timesteps = torch.linspace(0, 1, timesteps)
logit = torch.log(timesteps / (1 - timesteps))
self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi))
def sample(self, batch_size, device):
return torch.multinomial(self.prob, batch_size, replacement=True).to(device)
View File
@@ -0,0 +1,166 @@
#!/usr/bin/env python
# coding=utf-8
import argparse
import json
import os
from pathlib import Path
import re
import torch
from typing import Dict, List, Tuple
from vibevoice.modular.configuration_vibevoice import (
VibeVoiceConfig
)
from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
from transformers.utils import logging
logger = logging.get_logger(__name__)
def convert_vibevoice_nnscaler_checkpoint_to_hf(
checkpoint_path: str,
pytorch_dump_folder_path: str,
config_path: str = None,
):
"""
Convert a nnscaler VibeVoice checkpoint to HuggingFace format.
Supports both regular checkpoints and tensor parallel checkpoints.
"""
# Load regular checkpoint
logger.info(f"Loading regular checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader']
# config = checkpoint['train_args']
init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path']
pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path']
init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1]
if init_config_path.exists():
logger.info(f"Loading initial config from {init_config_path}")
with open(init_config_path, 'r') as f:
init_config = json.load(f)
else:
raise FileNotFoundError(f"Initial config file {init_config_path} not found. Please provide a valid path.")
tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True)
logger.info(f"Tie word embeddings: {tie_word_embeddings}")
init_config['decoder_config']['use_cache'] = True
config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings)
# # Extract the model state dict
model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')}
if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys():
# If not tying weights, we need to add the lm_head weight separately
model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight']
# Override with provided config if available
if config_path:
logger.info(f"Loading config from {config_path}")
with open(config_path, 'r') as f:
config_dict = json.load(f)
config = VibeVoiceConfig.from_dict(config_dict)
# Set the default dtype to bfloat16 before creating the model
original_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
# Create the HuggingFace model
logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model")
model = VibeVoiceForConditionalGeneration(config)
# Restore original dtype
torch.set_default_dtype(original_dtype)
# Load the state dict
logger.info("Loading weights into model")
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
if missing_keys:
logger.warning(f"Missing keys: {missing_keys}")
if unexpected_keys:
logger.warning(f"Unexpected keys: {unexpected_keys}")
# Create output directory
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
# Save the model and config
logger.info(f"Saving model to {pytorch_dump_folder_path}")
# Save config
config.save_pretrained(pytorch_dump_folder_path)
# Save VibeVoiceProcessor configuration
logger.info("Saving VibeVoiceProcessor configuration")
processor_config = {
"processor_class": "VibeVoiceProcessor",
"speech_tok_compress_ratio": 3200,
"db_normalize": True,
# Audio processor configuration
"audio_processor": {
"feature_extractor_type": "VibeVoiceTokenizerProcessor",
"sampling_rate": 24000,
"normalize_audio": True,
"target_dB_FS": -25,
"eps": 1e-6,
},
"language_model_pretrained_name": pretrained_name,
}
processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json")
with open(processor_config_path, 'w') as f:
json.dump(processor_config, f, indent=2)
logger.info(f"Saved processor config to {processor_config_path}")
# Save model with sharding
# save_pretrained handles tied weights automatically
logger.info("Saving model weights with sharding...")
model.save_pretrained(
pytorch_dump_folder_path,
max_shard_size="2GB", # Set maximum size for each shard
safe_serialization=True # Ensure saving in .safetensors format
)
logger.info(f"Model weights saved to {pytorch_dump_folder_path}")
logger.info("Conversion complete!")
# Verify the saved model can be loaded
logger.info("Verifying saved model...")
loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path)
logger.info("Model successfully loaded from saved checkpoint!")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--nnscaler_checkpoint_path",
type=str,
required=True,
help="Path to the fairseq checkpoint (.pt file). For tensor parallel checkpoints, "
"provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), "
"and the script will automatically detect and merge all parts.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
type=str,
required=True,
help="Path to the output PyTorch model directory",
)
parser.add_argument(
"--config_path",
type=str,
default=None,
help="Optional path to a config JSON file to override extracted config",
)
args = parser.parse_args()
convert_vibevoice_nnscaler_checkpoint_to_hf(
args.nnscaler_checkpoint_path,
args.pytorch_dump_folder_path,
args.config_path,
)
if __name__ == "__main__":
main()