diff --git a/docs/vibevoice-vllm-asr.md b/docs/vibevoice-vllm-asr.md new file mode 100644 index 0000000..f32cab4 --- /dev/null +++ b/docs/vibevoice-vllm-asr.md @@ -0,0 +1,112 @@ +# VibeVoice vLLM ASR Deployment + +Huggingface + +Deploy the VibeVoice ASR model as a high-performance API service using [vLLM](https://github.com/vllm-project/vllm). This plugin provides OpenAI-compatible API endpoints for speech-to-text transcription with streaming support. + +## 🔥 Key Features + +- **🚀 High-Performance Serving**: Optimized for high-throughput ASR inference with vLLM's continuous batching +- **📡 OpenAI-Compatible API**: Standard `/v1/chat/completions` endpoint with streaming support +- **🎵 Long Audio Support**: Process up to 60+ minutes of audio in a single request +- **🔌 Plugin Architecture**: No vLLM source code modification required - just install and run + +## 🛠️ Installation + +Using Official vLLM Docker Image (Recommended) + +```bash +# 1. Pull the official vLLM image +docker pull vllm/vllm-openai:latest + +# 2. Start an interactive container +docker run -it --gpus all --name vibevoice-vllm \ + --ipc=host \ + -p 8000:8000 \ + -e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \ + -e PYTORCH_ALLOC_CONF=expandable_segments:True \ + -v /path/to/models:/models \ + -v /path/to/VibeVoice:/app \ + -w /app \ + --entrypoint bash \ + vllm/vllm-openai:latest + +# 3. Inside container: Install system dependencies +bash vllm_plugin/scripts/install_deps.sh + +# 4. Inside container: Install VibeVoice with vLLM support +pip install -e .[vllm] + +# 5. Inside container: (Optional) Generate tokenizer files if needed +python3 -m vllm_plugin.tools.generate_tokenizer_files --output /models/your_model + +# 6. 
Inside container: Start vLLM server +vllm serve /models/your_model \ + --served-model-name vibevoice \ + --trust-remote-code \ + --dtype bfloat16 \ + --max-num-seqs 64 \ + --max-model-len 65536 \ + --max-num-batched-tokens 32768 \ + --gpu-memory-utilization 0.8 \ + --enforce-eager \ + --no-enable-prefix-caching \ + --enable-chunked-prefill \ + --chat-template-content-format openai \ + --tensor-parallel-size 1 \ + --allowed-local-media-path /app \ + --port 8000 +``` + +> **Note**: This approach allows you to switch models, adjust parameters, and debug issues without rebuilding the container. + + +## 🚀 Quick Start + +### Test the API + +Once the vLLM server is running, test it with the provided script: + +```bash +# Run the test script (inside container) +python3 vllm_plugin/tests/test_api.py /path/to/audio.wav +``` + + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `VIBEVOICE_FFMPEG_MAX_CONCURRENCY` | Maximum concurrent FFmpeg processes for audio decoding (unset, `0`, or negative = no limit) | unset (no limit); the example above uses `64` | +| `PYTORCH_CUDA_ALLOC_CONF` | CUDA memory allocator config | `expandable_segments:True` | + + + +## 📊 Performance Tips + +1. **GPU Memory**: Use `--gpu-memory-utilization 0.9` for maximum throughput if you have a dedicated GPU +2. **Batch Size**: Increase `--max-num-seqs` for higher concurrency (requires more GPU memory) +3. **FFmpeg Concurrency**: Tune `VIBEVOICE_FFMPEG_MAX_CONCURRENCY` based on CPU cores + +## 🚨 Troubleshooting + +### Common Issues + +1. **"CUDA out of memory"** + - Reduce `--gpu-memory-utilization` + - Reduce `--max-num-seqs` + - Use a smaller `--max-model-len` + +2. **"Audio decoding failed"** + - Ensure FFmpeg is installed: `ffmpeg -version` + - Check audio file format is supported + +3. **"Model not found"** + - Ensure model path contains `config.json` and model weights + - Generate tokenizer files if missing + +4. 
**"Plugin not loaded"** + - Verify installation: `pip show vibevoice` + - Check entry point: `pip show -f vibevoice | grep entry` + + diff --git a/pyproject.toml b/pyproject.toml index 6dc69b3..0cc390a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,9 +45,20 @@ asr = [ "pydub" # for visualization ] +vllm = [ + "transformers>=4.51.3", + "fastapi", + "uvicorn[standard]", + "requests", +] + +[project.entry-points."vllm.general_plugins"] +vibevoice = "vllm_plugin:register_vibevoice" + [project.urls] "Homepage" = "https://github.com/microsoft/VibeVoice" "Bug Tracker" = "https://github.com/microsoft/VibeVoice/issues" [tool.setuptools.packages.find] where = ["."] +include = ["vibevoice*", "vllm_plugin*"] diff --git a/vibevoice/modular/configuration_vibevoice.py b/vibevoice/modular/configuration_vibevoice.py index 02b2751..d5e3149 100644 --- a/vibevoice/modular/configuration_vibevoice.py +++ b/vibevoice/modular/configuration_vibevoice.py @@ -240,6 +240,26 @@ class VibeVoiceConfig(PretrainedConfig): super().__init__(**kwargs) + def get_text_config(self, decoder=False): + """ + Returns the text config for this model. + + vLLM uses this method to get the text configuration from multimodal models. + This allows vLLM to correctly determine hidden_size, num_attention_heads, + and other properties needed for memory profiling and model execution. + + For VibeVoice, the "text config" is the decoder_config (Qwen2Config). + + Args: + decoder: If True, return the decoder config (for encoder-decoder models). + For VibeVoice, this is always the decoder_config. + + Returns: + The decoder configuration (Qwen2Config) which contains hidden_size, etc. 
+ """ + return self.decoder_config + + class VibeVoiceASRConfig(PretrainedConfig): model_type = "vibevoice" is_composition = True diff --git a/vibevoice/processor/audio_utils.py b/vibevoice/processor/audio_utils.py index ad4d10d..3f9d112 100644 --- a/vibevoice/processor/audio_utils.py +++ b/vibevoice/processor/audio_utils.py @@ -1,3 +1,6 @@ +import os +import threading + import numpy as np from subprocess import run from typing import List, Optional, Union, Dict, Any @@ -57,6 +60,7 @@ def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24 cmd = [ "ffmpeg", + "-loglevel", "error", "-nostdin", "-threads", "0", "-i", file, @@ -64,14 +68,84 @@ def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24 "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr_to_use), - "-" + "-", ] - out = run(cmd, capture_output=True, check=True).stdout + out = _run_ffmpeg(cmd).stdout audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 return audio_data, sr_to_use + +def _get_ffmpeg_max_concurrency() -> int: + """Get the maximum FFmpeg concurrency from environment variable.""" + v = os.getenv("VIBEVOICE_FFMPEG_MAX_CONCURRENCY", "") + try: + n = int(v) if v.strip() else 0 + except Exception: + n = 0 + # 0/negative means no explicit limit. + return n + + +_FFMPEG_MAX_CONCURRENCY = _get_ffmpeg_max_concurrency() +_FFMPEG_SEM = threading.Semaphore(_FFMPEG_MAX_CONCURRENCY) if _FFMPEG_MAX_CONCURRENCY > 0 else None + + +def _run_ffmpeg(cmd: list, *, stdin_bytes: bytes = None): + """Run ffmpeg with optional global concurrency limiting. + + This is important for vLLM multi-request concurrency: spawning too many + ffmpeg processes can saturate CPU/IO and cause request failures/timeouts. 
+ """ + if _FFMPEG_SEM is None: + return run(cmd, capture_output=True, check=True, input=stdin_bytes) + with _FFMPEG_SEM: + return run(cmd, capture_output=True, check=True, input=stdin_bytes) + + +def load_audio_bytes_use_ffmpeg(data: bytes, *, resample: bool = False, target_sr: int = 24000): + """Decode audio bytes via ffmpeg stdin pipe. + + Compared to writing bytes to a temp file, this avoids filesystem IO and + reduces contention under high request concurrency. + + Parameters + ---------- + data: bytes + The audio data bytes + resample: bool + Whether to resample the audio (must be True) + target_sr: int + The target sample rate if resampling is requested + + Returns + ------- + A tuple containing: + - A NumPy array with the audio waveform in float32 dtype + - The sample rate + """ + if not resample: + # For stdin bytes, we don't have a cheap/robust way to probe original sr. + # Keep behavior explicit. + raise ValueError("load_audio_bytes_use_ffmpeg requires resample=True") + + cmd = [ + "ffmpeg", + "-loglevel", "error", + "-threads", "0", + "-i", "pipe:0", + "-f", "s16le", + "-ac", "1", + "-acodec", "pcm_s16le", + "-ar", str(target_sr), + "-", + ] + out = _run_ffmpeg(cmd, stdin_bytes=data).stdout + audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + return audio_data, target_sr + + class AudioNormalizer: """ Audio normalization class for VibeVoice tokenizer. diff --git a/vllm_plugin/__init__.py b/vllm_plugin/__init__.py new file mode 100644 index 0000000..b696a45 --- /dev/null +++ b/vllm_plugin/__init__.py @@ -0,0 +1,63 @@ +"""VibeVoice vLLM Plugin - Registers VibeVoice model for vLLM inference. + +This plugin enables VibeVoice ASR models to be loaded and served through vLLM. +It registers the model architecture, configuration, tokenizer, and processor +with their respective registries. + +The plugin is automatically loaded by vLLM via the 'vllm.general_plugins' +entry point defined in pyproject.toml. 
+""" + +from vllm.model_executor.models import ModelRegistry +from transformers import AutoConfig, AutoTokenizer, Qwen2Tokenizer, AutoProcessor, Qwen2AudioProcessor + +from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig +from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast + +from .model import VibeVoiceForCausalLM +from .inputs import vibevoice_audio_input_mapper + + +def register_vibevoice(): + """Register VibeVoice model with vLLM and transformers. + + This function is called automatically by vLLM through the entry point + mechanism. It registers: + - VibeVoiceConfig with AutoConfig + - VibeVoiceASRTextTokenizerFast with AutoTokenizer (for ASR) + - Qwen2AudioProcessor with AutoProcessor + - VibeVoiceForCausalLM with vLLM ModelRegistry + """ + # Register the configuration class with transformers + AutoConfig.register("vibevoice", VibeVoiceConfig) + + # Register the tokenizer with transformers. + # IMPORTANT (ASR): Align with the PyTorch ASR path. + # VibeVoiceASRTextTokenizerFast maps: + # speech_start_id -> <|object_ref_start|> + # speech_pad_id -> <|box_start|> + # speech_end_id -> <|object_ref_end|> + # This significantly affects ASR quality even when requests succeed. 
+ try: + AutoTokenizer.register( + VibeVoiceConfig, + slow_tokenizer_class=Qwen2Tokenizer, + fast_tokenizer_class=VibeVoiceASRTextTokenizerFast, + ) + except Exception: + pass # May already be registered + + # Register the processor with transformers + try: + AutoProcessor.register(VibeVoiceConfig, processor_class=Qwen2AudioProcessor) + except Exception: + pass # May already be registered + + # Register the model class with the architecture name "VibeVoice" + # This name must match the "architectures" list in config.json + ModelRegistry.register_model("VibeVoice", VibeVoiceForCausalLM) + ModelRegistry.register_model("VibeVoiceForASRTraining", VibeVoiceForCausalLM) + + +# Note: This function is called via vllm.general_plugins entry point +# defined in pyproject.toml, ensuring it runs in all vLLM processes diff --git a/vllm_plugin/inputs.py b/vllm_plugin/inputs.py new file mode 100644 index 0000000..e801484 --- /dev/null +++ b/vllm_plugin/inputs.py @@ -0,0 +1,82 @@ +"""Audio input mapper for vLLM multimodal pipeline. + +This module handles audio data loading and preprocessing for VibeVoice ASR inference. +It converts various audio input formats (path, bytes, numpy array) into tensors +that can be processed by the VibeVoice model. +""" +import torch +import numpy as np +from typing import Union, List +from vllm.multimodal.inputs import MultiModalInputs +from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer + + +def load_audio(audio_path: str, target_sr: int = 24000) -> np.ndarray: + """Load and normalize audio from file path. 
+ + Args: + audio_path: Path to audio file + target_sr: Target sample rate (default 24kHz for VibeVoice) + + Returns: + Normalized audio waveform as numpy array + """ + # Load with FFmpeg (handles various formats) + audio, sr = load_audio_use_ffmpeg(audio_path, resample=True, target_sr=target_sr) + + # Normalize audio + normalizer = AudioNormalizer() + audio = normalizer(audio) + + return audio + + +def vibevoice_audio_input_mapper(ctx, data: Union[str, bytes, np.ndarray, List[str]]) -> MultiModalInputs: + """Map audio input data to vLLM MultiModalInputs format. + + This function is registered as the input mapper for VibeVoice audio processing. + It handles multiple input formats and converts them to normalized tensors. + + Args: + ctx: vLLM context (unused) + data: Audio data in one of these formats: + - str: Path to audio file + - bytes: Raw audio bytes (any format FFmpeg supports) + - np.ndarray: Pre-loaded audio waveform + - List[str]: List of audio paths (only first is used) + + Returns: + MultiModalInputs containing: + - audio: Audio tensor (float32) + - audio_length: Length of audio in samples + """ + # Handle list input (take first item) + if isinstance(data, list): + data = data[0] + + audio_waveform = None + + if isinstance(data, str): + # Load from file path + audio_waveform = load_audio(data) + + elif isinstance(data, bytes): + # Decode bytes directly via ffmpeg stdin pipe to avoid temp-file IO + audio_waveform, _sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000) + normalizer = AudioNormalizer() + audio_waveform = normalizer(audio_waveform) + + elif isinstance(data, np.ndarray): + # Already loaded numpy array + audio_waveform = data + else: + raise ValueError(f"Unsupported audio data type: {type(data)}") + + # Convert to tensor + audio_tensor = torch.from_numpy(audio_waveform).float() + audio_length = audio_tensor.shape[0] + + return MultiModalInputs({ + "audio": audio_tensor, + "audio_length": audio_length + }) diff --git 
a/vllm_plugin/model.py b/vllm_plugin/model.py new file mode 100644 index 0000000..bcb4ca3 --- /dev/null +++ b/vllm_plugin/model.py @@ -0,0 +1,1329 @@ +""" +VibeVoice vLLM Plugin Model - Native Multimodal Integration + +This module implements the VibeVoice ASR model with full vLLM multimodal registry +integration for speech-to-text inference. +""" + +from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal +import json +import math +import os +import sys +from pathlib import Path +import torch +import torch.nn as nn +import numpy as np +from io import BytesIO +import tempfile +import base64 + + +# ============================================================================ +# Audio Loading: FFmpeg-based AudioMediaIO +# ============================================================================ +# VibeVoice uses FFmpeg for audio decoding to ensure consistent behavior +# across different audio formats (MP3, WAV, FLAC, etc.). + + +from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer + + +def _suffix_from_media_type(media_type: str | None) -> str: + if not media_type: + return ".bin" + mt = media_type.lower().strip() + if mt in ("audio/wav", "audio/x-wav", "audio/wave"): + return ".wav" + if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"): + return ".mp3" + if mt in ("audio/flac",): + return ".flac" + if mt in ("audio/ogg", "audio/opus"): + return ".ogg" + if mt in ("audio/mp4", "audio/m4a"): + return ".m4a" + if mt in ("video/mp4",): + return ".mp4" + return ".bin" + + +def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]: + """Load audio bytes using FFmpeg. + + Returns: + Tuple of (audio_waveform, sample_rate). Sample rate is always 24000. + """ + # Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency. 
+ audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000) + normalizer = AudioNormalizer() + audio = normalizer(audio) + return audio, sr + +def _ffmpeg_load_file(filepath) -> tuple[np.ndarray, int]: + """Load audio file using FFmpeg. + + Returns: + Tuple of (audio_waveform, sample_rate). Sample rate is always 24000. + """ + audio, sr = load_audio_use_ffmpeg(str(filepath), resample=True, target_sr=24000) + normalizer = AudioNormalizer() + audio = normalizer(audio) + return audio, sr + +# Register FFmpeg-based audio loader +import vllm.multimodal.audio as _vllm_audio_module +_OriginalAudioMediaIO = _vllm_audio_module.AudioMediaIO + +class _PatchedAudioMediaIO(_OriginalAudioMediaIO): + """AudioMediaIO implementation using FFmpeg for audio decoding.""" + + def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]: + return _ffmpeg_load_bytes(data, media_type=None) + + def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]: + return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type) + + def load_file(self, filepath) -> tuple[np.ndarray, int]: + return _ffmpeg_load_file(filepath) + +# Replace globally +_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO + +# Also patch in utils module where it's imported +import vllm.multimodal.utils as _vllm_utils_module +_vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO + +# ============================================================================ + +from transformers import Qwen2Config, BatchFeature +from transformers.models.whisper import WhisperFeatureExtractor +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.config import VllmConfig, ModelConfig +from vllm.config.speech_to_text import SpeechToTextConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.parse import MultiModalDataParser +from vllm.sequence import IntermediateTensors 
+from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings +from vllm.inputs import PromptType +from vllm.model_executor.models.utils import ( + init_vllm_registered_model, + maybe_prefix, + AutoWeightsLoader, + WeightsMapper, +) +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs + +# Import VibeVoice components +from vibevoice.modular.modular_vibevoice_tokenizer import ( + VibeVoiceAcousticTokenizerModel, + VibeVoiceSemanticTokenizerModel, + VibeVoiceTokenizerStreamingCache, + VibeVoiceTokenizerEncoderOutput, +) +from vibevoice.modular.configuration_vibevoice import ( + VibeVoiceAcousticTokenizerConfig, + VibeVoiceSemanticTokenizerConfig, +) + + +class SpeechConnector(nn.Module): + """Projects speech features to language model hidden dimension. 
+ + Architecture: fc1 -> RMSNorm -> fc2 (no activation function) + """ + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.fc1 = nn.Linear(input_dim, output_dim) + self.norm = LlamaRMSNorm(output_dim, eps=1e-6) + self.fc2 = nn.Linear(output_dim, output_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.norm(x) + x = self.fc2(x) + return x + + +class LlamaRMSNorm(nn.Module): + """RMSNorm layer used in SpeechConnector.""" + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class VibeVoiceAudioEncoder(nn.Module): + """ + VibeVoice Audio Encoder module. + + Encapsulates Acoustic/Semantic VAE Tokenizers and projection Connectors. + Converts raw audio waveforms into embeddings compatible with the language model. 
+ + Features: + - Streaming support for long audio (>60s by default) + - Configurable dtype for numerical precision + - Supports both sampling and deterministic (mean) modes + """ + def __init__(self, config): + super().__init__() + self.config = config + + import sys + + def get_cfg(obj, key, default=None): + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + self.acoustic_vae_dim = get_cfg(config, "acoustic_vae_dim", 64) + self.semantic_vae_dim = get_cfg(config, "semantic_vae_dim", 128) + + decoder_config = get_cfg(config, "decoder_config") + text_config = get_cfg(config, "text_config") + + target_hidden_size = None + + if decoder_config is not None: + target_hidden_size = get_cfg(decoder_config, "hidden_size") + + if target_hidden_size is None and text_config is not None: + target_hidden_size = get_cfg(text_config, "hidden_size") + + if target_hidden_size is None: + target_hidden_size = get_cfg(config, "hidden_size") + + if target_hidden_size is None: + print("[VibeVoice] WARN: Could not find hidden_size in config! 
Defaulting to 3584 (7B).", file=sys.stderr) + self.hidden_size = 3584 + else: + self.hidden_size = target_hidden_size + + ac_cfg = get_cfg(config, "acoustic_tokenizer_config") + sc_cfg = get_cfg(config, "semantic_tokenizer_config") + + if ac_cfg is None or sc_cfg is None: + raise ValueError("Missing acoustic/semantic tokenizer config in model config") + + # Handle both dict and already-constructed config objects + if isinstance(ac_cfg, VibeVoiceAcousticTokenizerConfig): + acoustic_config = ac_cfg + elif isinstance(ac_cfg, dict): + acoustic_config = VibeVoiceAcousticTokenizerConfig(**ac_cfg) + else: + raise TypeError(f"acoustic_tokenizer_config has unexpected type: {type(ac_cfg)}") + + if isinstance(sc_cfg, VibeVoiceSemanticTokenizerConfig): + semantic_config = sc_cfg + elif isinstance(sc_cfg, dict): + semantic_config = VibeVoiceSemanticTokenizerConfig(**sc_cfg) + else: + raise TypeError(f"semantic_tokenizer_config has unexpected type: {type(sc_cfg)}") + + # Tokenizers use float32 for numerical precision + self.acoustic_tokenizer = VibeVoiceAcousticTokenizerModel(acoustic_config) + self.semantic_tokenizer = VibeVoiceSemanticTokenizerModel(semantic_config) + + # Get audio encoder dtype from config (defaults to float32 for precision) + root_torch_dtype = get_cfg(config, "torch_dtype", None) + if root_torch_dtype is not None: + if isinstance(root_torch_dtype, str): + self._audio_encoder_dtype = getattr(torch, root_torch_dtype) + else: + self._audio_encoder_dtype = root_torch_dtype + else: + self._audio_encoder_dtype = torch.float32 + + self.acoustic_connector = SpeechConnector(self.acoustic_vae_dim, self.hidden_size) + self.semantic_connector = SpeechConnector(self.semantic_vae_dim, self.hidden_size) + + self.compress_ratio = get_cfg(config, "speech_tok_compress_ratio", 3200) + + # Streaming controls + self.sample_rate = get_cfg(config, "target_sample_rate", 24000) + + # Default to True (per requirement): segment + cache inside one forward call. 
+ self.enable_streaming = get_cfg(config, "enable_streaming", True) + self.streaming_segment_duration = get_cfg(config, "streaming_segment_duration", 60.0) + + # Control whether to use sample() or .mean for acoustic tokens + # Default: use sample() for training-consistent behavior + # Set VIBEVOICE_USE_MEAN=1 for deterministic output + use_mean_env = os.getenv("VIBEVOICE_USE_MEAN", "").strip().lower() + self.use_sample = use_mean_env not in ("1", "true", "yes") + + # Language model dtype (set by VibeVoiceForCausalLM.__init__) + # This is the dtype that audio embeddings will be converted to before + # being passed to the language model. Defaults to bfloat16. + self._lm_dtype: torch.dtype = torch.bfloat16 + + def _ensure_audio_encoder_dtype(self): + """Ensure all audio encoder components use the correct dtype from config. + + vLLM may convert weights to a different dtype (e.g., bfloat16) during loading. + This method converts audio encoder components back to the config-specified dtype + (typically float32) for numerical precision during audio encoding. 
+ """ + import sys + target_dtype = self._audio_encoder_dtype + + # Check and convert acoustic_tokenizer + try: + acoustic_dtype = next(self.acoustic_tokenizer.parameters()).dtype + if acoustic_dtype != target_dtype: + self.acoustic_tokenizer = self.acoustic_tokenizer.to(dtype=target_dtype) + print(f"[VibeVoice] Converted acoustic_tokenizer to {target_dtype} (was {acoustic_dtype})", file=sys.stderr) + except StopIteration: + pass + + # Check and convert semantic_tokenizer + try: + semantic_dtype = next(self.semantic_tokenizer.parameters()).dtype + if semantic_dtype != target_dtype: + self.semantic_tokenizer = self.semantic_tokenizer.to(dtype=target_dtype) + print(f"[VibeVoice] Converted semantic_tokenizer to {target_dtype} (was {semantic_dtype})", file=sys.stderr) + except StopIteration: + pass + + # Check and convert acoustic_connector + try: + ac_conn_dtype = next(self.acoustic_connector.parameters()).dtype + if ac_conn_dtype != target_dtype: + self.acoustic_connector = self.acoustic_connector.to(dtype=target_dtype) + print(f"[VibeVoice] Converted acoustic_connector to {target_dtype} (was {ac_conn_dtype})", file=sys.stderr) + except StopIteration: + pass + + # Check and convert semantic_connector + try: + sc_conn_dtype = next(self.semantic_connector.parameters()).dtype + if sc_conn_dtype != target_dtype: + self.semantic_connector = self.semantic_connector.to(dtype=target_dtype) + print(f"[VibeVoice] Converted semantic_connector to {target_dtype} (was {sc_conn_dtype})", file=sys.stderr) + except StopIteration: + pass + + def forward( + self, + audio: torch.Tensor, + *, + use_streaming: bool = True, + segment_duration_s: Optional[float] = None, + use_sample: Optional[bool] = None, + ) -> torch.Tensor: + """Encode audio with optional streaming for long clips. 
+ + Args: + audio: Input audio tensor [B, T] or [T] + use_streaming: Whether to enable segmented encoding for long audio + segment_duration_s: Segment length in seconds (defaults to 60s) + use_sample: If True, use sampling for acoustic tokens; if False, use mean + Defaults to self.use_sample (controlled by VIBEVOICE_USE_MEAN env var) + + Returns: + Audio embeddings tensor compatible with the language model + """ + # Ensure audio encoder components use correct dtype + self._ensure_audio_encoder_dtype() + + # Audio input should match the audio encoder dtype + audio = audio.to(dtype=self._audio_encoder_dtype) + + if audio.ndim == 1: + audio = audio.unsqueeze(0) + + # Resolve streaming options + segment_duration = segment_duration_s or self.streaming_segment_duration + sample_rate = self.sample_rate + total_samples = audio.shape[-1] + segment_samples = int(segment_duration * sample_rate) + + use_streaming = use_streaming and self.enable_streaming and total_samples > segment_samples + + # Resolve use_sample flag + if use_sample is None: + use_sample = self.use_sample + + # Keep encoding in inference mode to avoid autograd build-up + with torch.no_grad(): + if not use_streaming: + acoustic_input = audio.unsqueeze(1) + acoustic_out = self.acoustic_tokenizer.encode(acoustic_input) + # Use sample() or .mean based on use_sample flag + if use_sample: + acoustic_tokens = acoustic_out.sample( + dist_type=self.acoustic_tokenizer.std_dist_type + )[0] + else: + acoustic_tokens = acoustic_out.mean + + # Connector is now float32, no conversion needed + acoustic_embeds = self.acoustic_connector(acoustic_tokens) + + semantic_out = self.semantic_tokenizer.encode(acoustic_input) + # Semantic always uses .mean for consistency + semantic_tokens = semantic_out.mean + # Connector is now float32, no conversion needed + semantic_embeds = self.semantic_connector(semantic_tokens) + else: + # ========================================== + # Streaming path (Retained for future use) + # 
========================================== + acoustic_cache = VibeVoiceTokenizerStreamingCache() + semantic_cache = VibeVoiceTokenizerStreamingCache() + acoustic_mean_segments = [] + semantic_mean_segments = [] + batch_size = audio.shape[0] + sample_indices = torch.arange(batch_size, device=audio.device) + + def _iter_segments(total_length: int, segment_length: int): + for start in range(0, total_length, segment_length): + end = min(start + segment_length, total_length) + if end > start: + yield start, end + + segments = list(_iter_segments(total_samples, segment_samples)) + num_segments = len(segments) + for seg_idx, (start, end) in enumerate(segments): + chunk = audio[:, start:end].contiguous() + if chunk.numel() == 0: + continue + + # Check if this is the final segment + is_final = (seg_idx == num_segments - 1) + + # --- Acoustic Encode --- + acoustic_enc_out = self.acoustic_tokenizer.encode( + chunk.unsqueeze(1), + cache=acoustic_cache, + sample_indices=sample_indices, + use_cache=True, + is_final_chunk=is_final, + ) + acoustic_mean_segments.append(acoustic_enc_out.mean) + + # --- Semantic Encode --- + semantic_enc_out = self.semantic_tokenizer.encode( + chunk.unsqueeze(1), + cache=semantic_cache, + sample_indices=sample_indices, + use_cache=True, + is_final_chunk=is_final, + ) + semantic_mean_segments.append(semantic_enc_out.mean) + + # Concatenate sequence outputs (Acoustic) + if len(acoustic_mean_segments) == 0: + acoustic_mean_full = torch.zeros( + (batch_size, 0, self.acoustic_vae_dim), + device=audio.device, + dtype=self._audio_encoder_dtype # Use config dtype + ) + else: + acoustic_mean_full = torch.cat(acoustic_mean_segments, dim=1).contiguous() + + # Get acoustic tokens based on use_sample flag + acoustic_enc_full = VibeVoiceTokenizerEncoderOutput( + mean=acoustic_mean_full, + std=self.acoustic_tokenizer.fix_std, + ) + if use_sample: + acoustic_tokens = acoustic_enc_full.sample( + dist_type=self.acoustic_tokenizer.std_dist_type + )[0] + else: + 
acoustic_tokens = acoustic_enc_full.mean + # Connector uses same dtype as tokenizer + acoustic_embeds = self.acoustic_connector(acoustic_tokens) + + # Concatenate sequence outputs (Semantic) + if len(semantic_mean_segments) == 0: + semantic_tokens = torch.zeros( + (batch_size, 0, self.semantic_vae_dim), + device=audio.device, + dtype=self._audio_encoder_dtype # Use config dtype + ) + else: + semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous() + # Connector uses same dtype as tokenizer + semantic_embeds = self.semantic_connector(semantic_tokens) + + # Combine acoustic and semantic embeddings + combined_embeds = acoustic_embeds + semantic_embeds + + # Convert to language model dtype for compatibility + # Audio encoder uses config.torch_dtype (typically float32) for numerical precision, + # but LM expects the dtype specified by vLLM's --dtype flag (e.g., bfloat16, float16) + combined_embeds = combined_embeds.to(dtype=self._lm_dtype) + + return combined_embeds + +# ============================================================================ +# vLLM Multimodal Processing Infrastructure +# ============================================================================ + +class VibeVoiceProcessingInfo(BaseProcessingInfo): + """Processing info for VibeVoice multimodal model.""" + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_feature_extractor(self, **kwargs) -> WhisperFeatureExtractor: + """ + Get a WhisperFeatureExtractor for vLLM profiling compatibility. + + IMPORTANT: This is NOT used in actual inference! + VibeVoice uses its own acoustic/semantic VAE tokenizers operating on + raw 24kHz waveforms, NOT Whisper mel spectrograms. + + This feature extractor exists only to satisfy vLLM's multimodal + profiling infrastructure which may query audio parameters like + sampling_rate and chunk_length for memory estimation. 
+ """ + # Read config from preprocessor_config.json if available + import json + import os + model_path = self.ctx.model_config.model + preprocessor_path = os.path.join(model_path, "preprocessor_config.json") + + # Default values: keep a coherent (sr, hop) pair. + # VibeVoice runs at 24kHz in this repo (see demo/asr_transcribe_file.py). + config = { + "sampling_rate": 24000, + "feature_size": 128, + # 10ms hop at 24kHz + "hop_length": 240, + "chunk_length": 30, + "n_fft": 400, + "padding_value": 0.0, + } + + # Try to load from config file + if os.path.exists(preprocessor_path): + try: + with open(preprocessor_path, "r") as f: + file_config = json.load(f) + config.update({k: file_config[k] for k in config.keys() if k in file_config}) + except Exception: + pass # Use defaults + + return WhisperFeatureExtractor( + feature_size=config["feature_size"], + sampling_rate=config["sampling_rate"], + hop_length=config["hop_length"], + chunk_length=config["chunk_length"], + n_fft=config["n_fft"], + padding_value=config["padding_value"], + ) + + def get_audio_token_info(self) -> dict: + """ + Get audio special tokens and their IDs. + + Returns dict with: + audio_token, audio_bos_token, audio_eos_token, + audio_token_id, audio_bos_id, audio_eos_id + """ + tokenizer = self.get_tokenizer() + vocab = tokenizer.get_vocab() + + # VibeVoice special tokens + tokens = { + "audio_token": "<|AUDIO|>", + "audio_bos_token": "<|audio_bos|>", + "audio_eos_token": "<|audio_eos|>", + } + + # Get IDs + tokens["audio_token_id"] = vocab.get(tokens["audio_token"]) + tokens["audio_bos_id"] = vocab.get(tokens["audio_bos_token"]) + tokens["audio_eos_id"] = vocab.get(tokens["audio_eos_token"]) + + return tokens + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": None} + + +class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]): + """ + Build dummy inputs for multimodal profiling. + + Dummy text uses the raw <|AUDIO|> token(s). 
vLLM's processing pipeline will + expand each <|AUDIO|> via `VibeVoiceMultiModalProcessor._get_prompt_updates` + into the full ASR format: + [speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id] + where N is derived from audio length / compress_ratio. + """ + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + if num_audios <= 0: + return "" + + # Get the audio token from our token info helper + token_info = self.info.get_audio_token_info() + audio_token = token_info["audio_token"] + + # Return ONLY the audio tokens - the HF processor adds bos/eos + return audio_token * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, Any] | None = None, + ) -> Dict[str, Any]: + """Generate dummy audio data for profiling.""" + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + # Generate dummy audio as numpy arrays (what the HF processor expects) + return { + "audio": [np.zeros(audio_len, dtype=np.float32) for _ in range(num_audios)] + } + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, Any] | None = None, + ) -> ProcessorInputs: + """Build ProcessorInputs for dummy profiling.""" + return ProcessorInputs( + prompt=self.get_dummy_text(mm_counts), + mm_data=self.get_dummy_mm_data(seq_len, mm_counts, mm_options), + ) + + +def _vibevoice_field_config(hf_inputs: Mapping[str, torch.Tensor]): + """Map HF processor output keys to audio modality. + + Returns a config dict that tells vLLM how to batch multimodal data. 
class VibeVoiceMultiModalProcessor(BaseMultiModalProcessor[VibeVoiceProcessingInfo]):
    """
    Multimodal processor for VibeVoice.

    Handles the conversion of raw audio inputs to model-ready features,
    and manages the prompt token replacement for audio placeholders.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        """Create a data parser with the correct target sample rate (24kHz)."""
        # VibeVoice requires 24kHz, not 16kHz (Whisper default)
        return MultiModalDataParser(target_sr=24000)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Process prompt and audio for vLLM multimodal pipeline.

        We intentionally do NOT run a HF processor that would pre-expand
        `<|AUDIO|>` inside this method. Instead we:
        1) Tokenize the prompt as-is (so `<|AUDIO|>` stays a single token)
        2) Store raw audio tensors for `embed_multimodal` to encode later
        3) Let vLLM call `_get_prompt_updates` to expand the single `<|AUDIO|>`
           into the full ASR format:
           [speech_start] + N*[speech_pad] + [speech_end] + [\\n]

        Returns:
            A BatchFeature with ``input_ids`` and, when audio is present,
            ``raw_audio`` ([num_audios, max_len], zero padded, float32),
            ``raw_audio_lengths`` (unpadded sample counts) and ``salt``.
        """
        # Handle the case where 'audios' key is used (transformers deprecation)
        mm_data = dict(mm_data)  # Make a mutable copy
        audios = mm_data.pop("audios", None)
        if audios is not None and "audio" not in mm_data:
            mm_data["audio"] = audios

        # Normalise the audio payload to a list of clips. A bare ndarray is
        # a single clip; it must not be truth-tested (`not array` raises
        # ValueError for arrays with more than one element).
        raw_audio = mm_data.get("audio")
        if isinstance(raw_audio, np.ndarray):
            raw_audio_list = [raw_audio] if raw_audio.size > 0 else []
        elif not raw_audio:
            raw_audio_list = []
        elif isinstance(raw_audio, list):
            raw_audio_list = raw_audio
        else:
            raw_audio_list = list(raw_audio)

        tokenizer = self.info.get_tokenizer()

        # Text-only input handling
        if not raw_audio_list:
            prompt_ids = tokenizer.encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Tokenize prompt directly to preserve <|AUDIO|> as a single token
        # vLLM will expand it via _get_prompt_updates
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
        prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
        result = BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Right-pad every clip with zeros to the longest one so they can be
        # stacked into a single tensor for vLLM's batched field config.
        audio_lengths = [len(clip) for clip in raw_audio_list]
        max_len = max(audio_lengths)
        padded_clips = []
        for clip, clip_len in zip(raw_audio_list, audio_lengths):
            # np.asarray is a no-op for ndarrays and lets array-likes through.
            clip = np.asarray(clip)
            if clip_len < max_len:
                clip = np.pad(clip, (0, max_len - clip_len), mode="constant")
            padded_clips.append(torch.from_numpy(clip).float())

        # Shape: [num_audios, max_len]
        result["raw_audio"] = torch.stack(padded_clips, dim=0)
        # True (unpadded) lengths so the encoder can trim the padding again
        result["raw_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long)

        # Add a random salt to ensure unique hash and bypass cache
        import uuid

        salt_val = uuid.uuid4().int % 100000
        result["salt"] = torch.tensor([salt_val], dtype=torch.long).expand(
            len(raw_audio_list)
        )

        return result

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Return whether the HF processor applies prompt updates.

        Returns False because we handle token expansion via _get_prompt_updates.
        """
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Configure which HF output fields map to which modality."""
        return _vibevoice_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Define how to replace the audio placeholder in the prompt.

        vLLM's OpenAI multimodal parsing inserts the model placeholder string
        from `get_placeholder_str` (here: `<|AUDIO|>`) into the conversation.
        Each placeholder is replaced by
            [speech_start] + N * [speech_pad] + [speech_end] + [newline]
        where N = max(1, ceil(audio_samples / speech_tok_compress_ratio)).
        If the speech tokens are not available the replacement falls back to
        [audio_bos] + N * [<|AUDIO|>] + [audio_eos], or to N * [<|AUDIO|>].

        Returns an empty list when `<|AUDIO|>` is not in the vocabulary.
        """
        token_info = self.info.get_audio_token_info()
        audio_token = token_info["audio_token"]
        audio_token_id = token_info["audio_token_id"]
        if audio_token_id is None:
            return []
        audio_bos_id = token_info.get("audio_bos_id")
        audio_eos_id = token_info.get("audio_eos_id")

        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        def _first_id(*candidates):
            # First candidate that is not None. An `or` chain would wrongly
            # skip a legitimate token id of 0.
            for candidate in candidates:
                if candidate is not None:
                    return candidate
            return None

        # Look up speech token IDs from vocabulary
        # These tokens mark the start/end of audio embeddings in the prompt
        speech_start_id = _first_id(
            vocab.get("<|object_ref_start|>"),
            getattr(tokenizer, "speech_start_id", None),
            vocab.get("<|speech_start|>"),
        )
        speech_end_id = _first_id(
            vocab.get("<|object_ref_end|>"),
            getattr(tokenizer, "speech_end_id", None),
            vocab.get("<|speech_end|>"),
        )
        speech_pad_id = _first_id(
            vocab.get("<|box_start|>"),
            getattr(tokenizer, "speech_pad_id", None),
            vocab.get("<|speech_pad|>"),
        )

        # Raw audio lengths (in samples) as stored by _call_hf_processor.
        # This may be a tensor, so it must never be truth-tested directly:
        # bool() of a tensor with more than one element raises.
        raw_audio_lengths = out_mm_kwargs.get_data().get("raw_audio_lengths")
        if isinstance(raw_audio_lengths, torch.Tensor) and raw_audio_lengths.dim() == 0:
            raw_audio_lengths = raw_audio_lengths.reshape(1)
        num_lengths = 0 if raw_audio_lengths is None else len(raw_audio_lengths)

        # Fetch defaults from model config when available (falls back to 3200)
        hf_config = self.info.get_hf_config()
        if isinstance(hf_config, dict):
            compress_ratio = int(hf_config.get("speech_tok_compress_ratio", 3200))
        else:
            compress_ratio = int(getattr(hf_config, "speech_tok_compress_ratio", 3200))

        def _to_int_len(x) -> int:
            if x is None:
                return 0
            if isinstance(x, torch.Tensor):
                # Accept 0-dim or 1-dim scalar-like tensors
                if x.numel() == 1:
                    return int(x.item())
                # If a full tensor is passed accidentally, fall back to its length
                return int(x.shape[0])
            return int(x)

        def get_replacement(item_idx: int):
            if item_idx < num_lengths:
                audio_len = _to_int_len(raw_audio_lengths[item_idx])
            else:
                # Fallback: estimate for 30 second audio at 24kHz
                audio_len = 30 * 24000
            # Always emit at least one embedding position.
            num_features = max(1, int(np.ceil(audio_len / compress_ratio)))

            # Build replacement token sequence:
            # <|speech_start|> + N * <|speech_pad|> + <|speech_end|> + \n
            # The newline is important for correct prompt structure.
            # NOTE(review): 198 is hard-coded as the '\n' token id - confirm
            # it matches the tokenizer in use.
            newline_id = 198  # '\n' token
            if (
                speech_start_id is not None
                and speech_pad_id is not None
                and speech_end_id is not None
            ):
                embed_id = int(speech_pad_id)
                replacement_ids = (
                    [int(speech_start_id)]
                    + [embed_id] * num_features
                    + [int(speech_end_id), newline_id]
                )
            # Fallback: add audio BOS/EOS boundaries around repeated <|AUDIO|>.
            elif audio_bos_id is not None and audio_eos_id is not None:
                embed_id = int(audio_token_id)
                replacement_ids = (
                    [int(audio_bos_id)] + [embed_id] * num_features + [int(audio_eos_id)]
                )
            else:
                embed_id = int(audio_token_id)
                replacement_ids = [embed_id] * num_features

            return PromptUpdateDetails.select_token_id(
                replacement_ids,
                embed_token_id=int(embed_id),
            )

        return [
            PromptReplacement(
                modality="audio",
                # Keep string placeholder matching for maximum vLLM compatibility.
                target=audio_token,
                replacement=get_replacement,
            )
        ]
+ """ + + # SupportsTranscription interface + supports_transcription: ClassVar[Literal[True]] = True + supports_transcription_only: ClassVar[bool] = False + supports_segment_timestamp: ClassVar[bool] = False + + # Supported languages (Chinese as primary target) + supported_languages: ClassVar[Mapping[str, str]] = { + "zh": "Chinese", + "en": "English", + } + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + """Return the placeholder string format for a given modality. + + Returns "<|AUDIO|>" which vLLM inserts into the conversation prompt. + This single placeholder is later expanded by `_get_prompt_updates` into: + [speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id] + where N = ceil(audio_samples / compress_ratio). + """ + if modality.startswith("audio"): + return "<|AUDIO|>" + raise ValueError(f"Unsupported modality: {modality}") + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: str | None, + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: str | None, + ) -> PromptType: + """Get the prompt for the ASR model. + + Generates a chat-formatted prompt for speech-to-text transcription + with JSON output format. + """ + # If user provides custom prompt, use it + if request_prompt: + return request_prompt + + # Calculate audio duration for the prompt + # Audio should be at 24kHz, so duration = len(audio) / 24000 + duration = len(audio) / 24000 if audio is not None else 10.0 + + system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format." + show_keys = ["Start time", "End time", "Speaker ID", "Content"] + user_suffix = ( + f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " + + ", ".join(show_keys) + ) + + # IMPORTANT: keep <|AUDIO|> as the only placeholder token here. 
@classmethod
def get_num_audio_tokens(
    cls,
    audio_duration_s: float,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
) -> int | None:
    """Estimate number of audio tokens from duration.

    Returns the number of audio EMBEDDING positions (speech_pad_id tokens).
    Note: _get_prompt_updates actually generates:
        [speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
    So total prompt tokens = N + 3, but this returns N (the embedding count).

    The compression ratio is read from ``speech_tok_compress_ratio`` on
    ``model_config.hf_config`` when present, the same key that
    `_get_prompt_updates` uses, and defaults to 3200 otherwise. Audio is
    assumed to be sampled at 24 kHz. ``stt_config`` is not consulted.
    """
    sampling_rate = 24000
    default_ratio = 3200

    # hf_config may be a dict or an attribute-style config object.
    hf_config = getattr(model_config, "hf_config", None)
    if isinstance(hf_config, dict):
        ratio = hf_config.get("speech_tok_compress_ratio", default_ratio)
    else:
        ratio = getattr(hf_config, "speech_tok_compress_ratio", default_ratio)
    compress_ratio = int(ratio) if ratio else default_ratio
    if compress_ratio <= 0:
        # A non-positive ratio is meaningless; keep the estimate usable.
        compress_ratio = default_ratio

    # Clamp so a negative duration cannot yield a negative token count.
    samples = max(0, int(audio_duration_s * sampling_rate))
    return int(np.ceil(samples / compress_ratio))
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
    """
    Extract audio embeddings using VibeVoice's acoustic/semantic tokenizers.

    Called by vLLM to get audio embeddings that replace audio placeholder tokens.

    Keyword Args:
        raw_audio: Stacked tensor ([B, T] or [B, 1, T]), a single 1-D
            tensor, or a list/tuple of per-audio tensors.
        raw_audio_lengths: Optional unpadded sample counts (possibly nested);
            used to trim padding before encoding.
        use_streaming / streaming_segment_duration: Optional per-call
            overrides of the audio encoder's streaming settings.

    Returns:
        Tuple of embedding tensors, one per audio input. An empty list is
        returned when no audio is supplied (e.g. during memory profiling).

    Raises:
        RuntimeError: if any audio cannot be encoded, including clips
            shorter than 160 samples. Skipping such an item silently would
            return fewer embeddings than audio inputs and break the
            one-per-input contract above.
    """
    # Get raw audio data (stored by our processor)
    raw_audio = kwargs.get("raw_audio")
    raw_audio_lengths = kwargs.get("raw_audio_lengths")

    # Handle no audio input - this happens during memory profiling
    if raw_audio is None:
        return []

    # Handle empty audio list
    if isinstance(raw_audio, (list, tuple)) and len(raw_audio) == 0:
        return []

    def _flatten(value):
        """Flatten nested lists/tuples/tensors of lengths to a flat list."""
        if isinstance(value, torch.Tensor):
            value = value.tolist()
        if not isinstance(value, (list, tuple)):
            return [value]
        flat = []
        for item in value:
            flat.extend(_flatten(item))
        return flat

    raw_audio_lengths = [] if raw_audio_lengths is None else _flatten(raw_audio_lengths)

    # Streaming controls. Enabled by default; can be overridden per-call.
    use_streaming_flag = bool(
        kwargs.get(
            "use_streaming",
            getattr(self.audio_encoder, "enable_streaming", True),
        )
    )
    streaming_segment_duration = kwargs.get(
        "streaming_segment_duration",
        getattr(self.audio_encoder, "streaming_segment_duration", 60.0),
    )

    # Device of the audio encoder; inputs are moved there before encoding.
    try:
        device = next(self.audio_encoder.parameters()).device
    except StopIteration:
        # Fallback if no parameters (shouldn't happen)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Handle both stacked tensor and list of tensors
    # vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
    if isinstance(raw_audio, torch.Tensor):
        if raw_audio.dim() == 3:
            # Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
            audio_list = [raw_audio[i].squeeze(0) for i in range(raw_audio.shape[0])]
        elif raw_audio.dim() == 2:
            # Shape: [batch_size, seq_len]
            audio_list = [raw_audio[i] for i in range(raw_audio.shape[0])]
        else:
            # Single 1D tensor
            audio_list = [raw_audio]
    elif isinstance(raw_audio, (list, tuple)):
        audio_list = list(raw_audio)
    else:
        # Single non-tensor item
        audio_list = [raw_audio]

    embeddings = []
    for i, audio_tensor in enumerate(audio_list):
        try:
            if isinstance(audio_tensor, list):
                audio_tensor = torch.stack(audio_tensor)

            # Ensure tensor
            if not isinstance(audio_tensor, torch.Tensor):
                audio_tensor = torch.tensor(audio_tensor)

            # Only the device is aligned here; dtype is left untouched.
            audio_tensor = audio_tensor.to(device=device)

            # Trim padding using the true length when it is available
            if raw_audio_lengths and i < len(raw_audio_lengths):
                actual_len = int(raw_audio_lengths[i])
                if 0 < actual_len <= audio_tensor.shape[-1]:
                    # Truncate from the last dimension (sequence length)
                    audio_tensor = audio_tensor[..., :actual_len]

            # 160 samples is roughly 6.7 ms at 24 kHz; shorter clips are
            # rejected rather than skipped so the output stays aligned
            # with the inputs.
            if audio_tensor.numel() < 160:
                raise ValueError(
                    f"audio has only {audio_tensor.numel()} samples (minimum 160)"
                )

            # Encode audio through VibeVoice encoder
            audio_embeds = self.audio_encoder(
                audio_tensor,
                use_streaming=use_streaming_flag,
                segment_duration_s=streaming_segment_duration,
            )
        except Exception as e:
            # Log which item failed, then surface the error. Dropping the
            # item would silently misalign embeddings and placeholders.
            print(f"[VibeVoice] Error encoding audio {i}: {e}")
            raise RuntimeError(f"[VibeVoice] Failed to encode audio {i}") from e

        # audio_embeds shape: [1, seq_len, hidden_size]
        # Return a single [seq_len, hidden_size] tensor per audio
        embeddings.append(audio_embeds.squeeze(0))

    return tuple(embeddings)
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    is_multimodal: Optional[torch.Tensor] = None,
    **kwargs,  # Accept any additional kwargs for compatibility
) -> torch.Tensor:
    """Embed token IDs and splice in multimodal (audio) embeddings.

    Args:
        input_ids: Token IDs to run through the text embedding layer.
        multimodal_embeddings: Pre-computed audio embeddings, either a
            Tensor or a List of Tensors.
        is_multimodal: Boolean mask marking the multimodal positions.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        The text embeddings, with the multimodal embeddings merged in when
        both ``multimodal_embeddings`` and ``is_multimodal`` are supplied.
    """
    from vllm.model_executor.models.utils import _merge_multimodal_embeddings

    text_embeds = self.get_input_embeddings()(input_ids)

    # Nothing to merge unless both the embeddings and the mask are present.
    if multimodal_embeddings is None or is_multimodal is None:
        return text_embeds

    # vLLM's merge helper handles List[Tensor] inputs.
    return _merge_multimodal_embeddings(
        text_embeds,
        multimodal_embeddings,
        is_multimodal,
    )
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """
    Forward pass for VibeVoice ASR model.

    Handles embedding computation and language model forward pass.
    Uses inputs_embeds if provided (from vLLM multimodal merge),
    otherwise computes embeddings from input_ids.

    Args:
        input_ids: Token IDs. May be None when inputs_embeds is provided.
        positions: Position indices for the input tokens.
        intermediate_tensors: Intermediate tensors for pipeline parallelism.
        inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).

    Returns:
        Whatever the inner language model's ``model`` module returns.
    """
    if intermediate_tensors is not None:
        # Pipeline parallelism: this call consumes the intermediate tensors,
        # so embeddings are neither needed nor computed.
        inputs_embeds = None
    elif inputs_embeds is None and input_ids is not None:
        # No pre-merged embeddings supplied: embed the token IDs ourselves.
        inputs_embeds = self.get_input_embeddings()(input_ids)

    # Get the inner model - handle both wrapped and direct language models
    language_model = self.language_model
    if hasattr(language_model, "language_model"):
        language_model = language_model.language_model

    # Call the language model's model (Qwen2Model)
    # vLLM V1 passes kv_caches and attn_metadata via context, not arguments
    # IMPORTANT: input_ids is always None here so tokens are never embedded twice.
    return language_model.model(
        input_ids=None,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )
def _guess_mime_type(path: str) -> str:
    """Map a file path to a MIME type using its (case-insensitive) extension.

    Unknown or missing extensions yield "application/octet-stream".
    """
    mime_by_extension = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".m4a": "audio/mp4",
        ".mp4": "video/mp4",
        ".m4v": "video/mp4",
        ".mov": "video/mp4",
        ".webm": "video/mp4",
        ".flac": "audio/flac",
        ".ogg": "audio/ogg",
        ".opus": "audio/ogg",
    }
    _, suffix = os.path.splitext(path)
    return mime_by_extension.get(suffix.lower(), "application/octet-stream")
def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"):
    """Test ASR transcription with streaming output.

    Sends the audio (extracted first if ``audio_path`` is a video) to
    ``{base_url}/v1/chat/completions`` as a base64 data URL and prints the
    streamed transcription as it arrives.

    Args:
        audio_path: Path to an audio or video file.
        base_url: Base URL of the vLLM server.

    Any temporary audio file created for a video input is removed on every
    exit path, including early returns and errors.
    """
    print(f"Loading audio from: {audio_path}")

    temp_audio_path = None
    actual_audio_path = audio_path
    try:
        # Handle video files: extract audio first
        if _is_video_file(audio_path):
            print("Detected video file, extracting audio...")
            temp_audio_path = _extract_audio_from_video(audio_path)
            actual_audio_path = temp_audio_path
            print(f"Audio extracted to: {temp_audio_path}")

        try:
            duration = _get_duration_seconds_ffprobe(actual_audio_path)
            print(f"Audio duration: {duration:.2f} seconds")

            with open(actual_audio_path, "rb") as f:
                audio_bytes = f.read()

            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            print(f"Audio size: {len(audio_bytes)} bytes")

        except Exception as e:
            print(f"Error preparing audio: {e}")
            return

        # Build the request
        url = f"{base_url}/v1/chat/completions"

        show_keys = ["Start time", "End time", "Speaker ID", "Content"]
        prompt_text = (
            f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
            + ", ".join(show_keys)
        )

        mime = _guess_mime_type(actual_audio_path)
        data_url = f"data:{mime};base64,{audio_b64}"

        payload = {
            "model": "vibevoice",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "audio_url", "audio_url": {"url": data_url}},
                        {"type": "text", "text": prompt_text}
                    ]
                }
            ],
            "max_tokens": 32768,
            "temperature": 0.0,
            "stream": True,
            "top_p": 1.0,
            "repetition_penalty": 1.0,
        }

        print(f"\nSending request to {url} (Streaming Mode)...")
        print(f"Prompt: {prompt_text}")
        print("-" * 60)

        t0 = time.time()
        try:
            # The context manager releases the streaming connection on exit.
            with requests.post(url, json=payload, stream=True, timeout=12000) as response:
                if response.status_code == 200:
                    print("Response received. Streaming content:\n")

                    printed = ""
                    for line in response.iter_lines():
                        if not line:
                            continue
                        decoded_line = line.decode('utf-8')
                        if not decoded_line.startswith("data: "):
                            continue

                        json_str = decoded_line[6:]
                        if json_str.strip() == "[DONE]":
                            print("\n\n[Finished]")
                            break
                        try:
                            data = json.loads(json_str)
                        except json.JSONDecodeError:
                            continue

                        # Some chunks carry no choices; skip them instead
                        # of aborting the whole stream.
                        choices = data.get('choices') or []
                        if not choices:
                            continue
                        content = (choices[0].get('delta') or {}).get('content', '')
                        if not content:
                            continue

                        # vLLM/OpenAI-compatible streams may emit either
                        # incremental deltas OR the full accumulated text.
                        # Only print the newly-added part to avoid repeats.
                        if content.startswith(printed):
                            to_print = content[len(printed):]
                        else:
                            to_print = content

                        if to_print:
                            print(to_print, end='', flush=True)
                            printed += to_print
                else:
                    print(f"Error: {response.status_code}")
                    print(response.text)

        except requests.exceptions.Timeout:
            print("\nRequest timed out!")
        except Exception as e:
            print(f"\nError: {e}")

        print(f"\n{'-'*60}")
        print(f"Total time elapsed: {time.time() - t0:.2f}s")

    finally:
        # Cleanup temp audio file if created (runs on every exit path)
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
            print(f"Cleaned up temp file: {temp_audio_path}")
__name__ == "__main__": + main() diff --git a/vllm_plugin/tools/generate_tokenizer_files.py b/vllm_plugin/tools/generate_tokenizer_files.py new file mode 100644 index 0000000..62de55f --- /dev/null +++ b/vllm_plugin/tools/generate_tokenizer_files.py @@ -0,0 +1,575 @@ +#!/usr/bin/env python3 +""" +Standalone tool to generate VibeVoice tokenizer files from Qwen2 base. + +Downloads base tokenizer from Qwen2 and patches it with VibeVoice-specific +audio tokens and chat template modifications. + +Usage: + python generate_tokenizer_files.py --output /path/to/output [--compare /path/to/reference] +""" + +import argparse +import json +import os +import shutil +import tempfile +from typing import Optional, Dict, Any + + +# Qwen2.5 extended tokens (151646-151664) +# These are NOT in base Qwen2-7B but ARE in Qwen2.5 and Qwen2-VL +# VibeVoice uses some of these for speech: object_ref_start/end, box_start +QWEN25_EXTENDED_TOKENS = { + "<|object_ref_start|>": 151646, # Used as speech_start_id + "<|object_ref_end|>": 151647, # Used as speech_end_id + "<|box_start|>": 151648, # Used as speech_pad_id + "<|box_end|>": 151649, + "<|quad_start|>": 151650, + "<|quad_end|>": 151651, + "<|vision_start|>": 151652, + "<|vision_end|>": 151653, + "<|vision_pad|>": 151654, + "<|image_pad|>": 151655, + "<|video_pad|>": 151656, + "": 151657, + "": 151658, + "<|fim_prefix|>": 151659, + "<|fim_middle|>": 151660, + "<|fim_suffix|>": 151661, + "<|fim_pad|>": 151662, + "<|repo_name|>": 151663, + "<|file_sep|>": 151664, +} + +# VibeVoice-specific audio tokens (IDs follow Qwen2.5's last token 151664) +VIBEVOICE_AUDIO_TOKENS = { + "<|AUDIO|>": 151665, + "<|audio_bos|>": 151666, + "<|audio_eos|>": 151667, +} + +# All extended tokens (Qwen2.5 + VibeVoice) +ALL_EXTENDED_TOKENS = {**QWEN25_EXTENDED_TOKENS, **VIBEVOICE_AUDIO_TOKENS} + +# Chat template with audio support +# Key modification: handles part['type'] == 'audio' or 'audio_url' -> '<|AUDIO|>' +VIBEVOICE_CHAT_TEMPLATE = """{%- if tools %} + {{- 
'<|im_start|>system\\n' }} + {%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {{- messages[0]['content'] }} + {%- else %} + {%- for part in messages[0]['content'] %} + {%- if part['type'] == 'text' %} + {{- part['text'] }} + {%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %} + {{- '<|AUDIO|>' }} + {%- endif %} + {%- endfor %} + {%- endif %} + {%- else %} + {{- 'You are a helpful assistant.' }} + {%- endif %} + {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n" }} + {%- for tool in tools %} + {{- "\\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\"name\\": , \\"arguments\\": }\\n<|im_end|>\\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\\n' }} + {%- if messages[0]['content'] is string %} + {{- messages[0]['content'] }} + {%- else %} + {%- for part in messages[0]['content'] %} + {%- if part['type'] == 'text' %} + {{- part['text'] }} + {%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %} + {{- '<|AUDIO|>' }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\\n' }} + {%- else %} + {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\\n' }} + {%- if message['content'] is string %} + {{- message['content'] }} + {%- else %} + {%- for part in message['content'] %} + {%- if part['type'] == 'text' %} + {{- part['text'] }} + {%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %} + {{- '<|AUDIO|>' }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- 
'<|im_end|>\\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\\n\\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\\n' }} + {%- endfor %} + {{- '<|im_end|>\\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\\n\\n' }} + {{- message.content }} + {{- '\\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\\n' }} +{%- endif %}""" + + +# Default to Qwen2.5-7B which has all the extended tokens (151646-151664) +DEFAULT_QWEN_MODEL = "Qwen/Qwen2.5-7B" + + +def download_qwen_tokenizer_files(output_dir: str, qwen_model: str = DEFAULT_QWEN_MODEL) -> None: + """Download base tokenizer files from Qwen2.5 (which includes extended tokens).""" + try: + from huggingface_hub import hf_hub_download + except ImportError: + raise ImportError("Please install huggingface_hub: pip install huggingface_hub") + + files_to_download = [ + "vocab.json", + "merges.txt", + "tokenizer.json", + "tokenizer_config.json", + ] + + os.makedirs(output_dir, exist_ok=True) + + for filename in files_to_download: + print(f"Downloading {filename} from {qwen_model}...") + hf_hub_download( + repo_id=qwen_model, + filename=filename, + local_dir=output_dir, + local_dir_use_symlinks=False, + ) + + +def patch_tokenizer_config(output_dir: str) -> None: + """ + Patch tokenizer_config.json with VibeVoice audio tokens and chat template. 
+ """ + config_path = os.path.join(output_dir, "tokenizer_config.json") + + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + + # 1. Add ALL extended tokens to added_tokens_decoder (Qwen2.5 + VibeVoice audio) + if "added_tokens_decoder" not in config: + config["added_tokens_decoder"] = {} + + for token, token_id in ALL_EXTENDED_TOKENS.items(): + if str(token_id) not in config["added_tokens_decoder"]: + # Determine if token should be marked as "special" + # tool_call tokens are NOT special in Qwen2.5 + is_special = token not in ("", "", "<|fim_prefix|>", + "<|fim_middle|>", "<|fim_suffix|>", "<|fim_pad|>", + "<|repo_name|>", "<|file_sep|>") + config["added_tokens_decoder"][str(token_id)] = { + "content": token, + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": is_special, + } + + # 2. Add audio tokens to additional_special_tokens + if "additional_special_tokens" not in config: + config["additional_special_tokens"] = [] + + for token in VIBEVOICE_AUDIO_TOKENS.keys(): + if token not in config["additional_special_tokens"]: + config["additional_special_tokens"].append(token) + + # 3. 
Modify chat_template to support audio + # Instead of replacing entirely, we patch the existing template to handle audio + chat_template = config.get("chat_template", "") + if chat_template and "<|AUDIO|>" not in chat_template: + # Insert audio handling into the template + # Find patterns like: {%- if part['type'] == 'text' %} + # Add after: {%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %}\n {{- '<|AUDIO|>' }} + audio_handler = """{%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %} + {{- '<|AUDIO|>' }}""" + + # Pattern to find: after handling 'text' type, before endif + import re + # Look for the pattern where we handle text type and add audio handling + pattern = r"(\{\%- if part\['type'\] == 'text' \%\}\s*\n\s*\{\{- part\['text'\] \}\})" + replacement = r"\1\n " + audio_handler.replace("\n", r"\n") + + modified_template = re.sub(pattern, replacement, chat_template) + + if modified_template != chat_template: + config["chat_template"] = modified_template + print(" - Added audio support to existing chat_template") + else: + # Fallback: use our predefined template + print(" - Warning: Could not patch existing template, using predefined template") + config["chat_template"] = VIBEVOICE_CHAT_TEMPLATE + + # 4. Update model_max_length for long audio support + config["model_max_length"] = 131072 + + # 5. Add add_bos_token if not present + if "add_bos_token" not in config: + config["add_bos_token"] = False + + # Write back + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False) + + print(f"Patched {config_path}") + + +def patch_tokenizer_json(output_dir: str) -> None: + """ + Patch tokenizer.json with VibeVoice audio tokens. 
+ """ + tokenizer_path = os.path.join(output_dir, "tokenizer.json") + + with open(tokenizer_path, "r", encoding="utf-8") as f: + tokenizer = json.load(f) + + # Find existing token IDs to avoid duplicates + existing_ids = set() + if "added_tokens" in tokenizer: + for token_entry in tokenizer["added_tokens"]: + existing_ids.add(token_entry.get("id")) + + # Add ALL extended tokens (Qwen2.5 + VibeVoice audio) + for token, token_id in ALL_EXTENDED_TOKENS.items(): + if token_id not in existing_ids: + # Determine if token should be marked as "special" + is_special = token not in ("", "", "<|fim_prefix|>", + "<|fim_middle|>", "<|fim_suffix|>", "<|fim_pad|>", + "<|repo_name|>", "<|file_sep|>") + tokenizer["added_tokens"].append({ + "id": token_id, + "content": token, + "single_word": False, + "lstrip": False, + "rstrip": False, + "normalized": False, + "special": is_special, + }) + + # Write back + with open(tokenizer_path, "w", encoding="utf-8") as f: + json.dump(tokenizer, f, indent=2, ensure_ascii=False) + + print(f"Patched {tokenizer_path}") + + +def generate_added_tokens_json(output_dir: str) -> None: + """ + Generate added_tokens.json from tokenizer_config.json. + """ + config_path = os.path.join(output_dir, "tokenizer_config.json") + + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + + added_tokens = {} + for token_id, token_info in config.get("added_tokens_decoder", {}).items(): + content = token_info.get("content") + if content: + added_tokens[content] = int(token_id) + + output_path = os.path.join(output_dir, "added_tokens.json") + with open(output_path, "w", encoding="utf-8") as f: + json.dump(added_tokens, f, indent=2, ensure_ascii=False) + + print(f"Generated {output_path}") + + +def generate_special_tokens_map_json(output_dir: str) -> None: + """ + Generate special_tokens_map.json with VibeVoice special tokens. 
+ """ + # Build the special tokens map + special_tokens_map = { + "additional_special_tokens": [], + "eos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + "unk_token": "<|endoftext|>", + } + + # Add audio tokens as additional_special_tokens + for token in VIBEVOICE_AUDIO_TOKENS.keys(): + special_tokens_map["additional_special_tokens"].append({ + "content": token, + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + }) + + # Add some commonly used special tokens + common_special = ["<|object_ref_start|>", "<|object_ref_end|>", "<|box_start|>"] + for token in common_special: + special_tokens_map["additional_special_tokens"].append({ + "content": token, + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + }) + + output_path = os.path.join(output_dir, "special_tokens_map.json") + with open(output_path, "w", encoding="utf-8") as f: + json.dump(special_tokens_map, f, indent=2, ensure_ascii=False) + + print(f"Generated {output_path}") + + +def generate_vibevoice_tokenizer_files(output_dir: str, qwen_model: str = DEFAULT_QWEN_MODEL) -> None: + """ + Generate all 6 VibeVoice tokenizer files. + + Files generated: + 1. vocab.json - from Qwen2.5 (unchanged) + 2. merges.txt - from Qwen2.5 (unchanged) + 3. tokenizer.json - from Qwen2.5 + audio tokens + 4. tokenizer_config.json - from Qwen2.5 + audio tokens + chat_template + 5. added_tokens.json - generated from tokenizer_config.json + 6. 
special_tokens_map.json - generated with VibeVoice tokens + """ + print(f"=== Generating VibeVoice tokenizer files to {output_dir} ===\n") + + # Step 1: Download base files from Qwen2 + download_qwen_tokenizer_files(output_dir, qwen_model) + + # Step 2: Patch tokenizer_config.json + patch_tokenizer_config(output_dir) + + # Step 3: Patch tokenizer.json + patch_tokenizer_json(output_dir) + + # Step 4: Generate added_tokens.json + generate_added_tokens_json(output_dir) + + # Step 5: Generate special_tokens_map.json + generate_special_tokens_map_json(output_dir) + + print(f"\nāœ… All 6 tokenizer files generated in {output_dir}") + + +def compare_json_files(file1: str, file2: str, name: str) -> Dict[str, Any]: + """Compare two JSON files and return differences.""" + result = { + "name": name, + "identical": False, + "differences": [], + } + + if not os.path.exists(file1): + result["differences"].append(f"File 1 not found: {file1}") + return result + + if not os.path.exists(file2): + result["differences"].append(f"File 2 not found: {file2}") + return result + + with open(file1, "r", encoding="utf-8") as f: + data1 = json.load(f) + + with open(file2, "r", encoding="utf-8") as f: + data2 = json.load(f) + + if data1 == data2: + result["identical"] = True + return result + + # Find specific differences + def find_diff(d1, d2, path=""): + diffs = [] + if isinstance(d1, dict) and isinstance(d2, dict): + all_keys = set(d1.keys()) | set(d2.keys()) + for k in all_keys: + new_path = f"{path}.{k}" if path else k + if k not in d1: + diffs.append(f"Missing in generated: {new_path}") + elif k not in d2: + diffs.append(f"Extra in generated: {new_path}") + else: + diffs.extend(find_diff(d1[k], d2[k], new_path)) + elif isinstance(d1, list) and isinstance(d2, list): + if len(d1) != len(d2): + diffs.append(f"{path}: list length differs ({len(d1)} vs {len(d2)})") + # For lists, just check if they're equal (detailed diff is complex) + if d1 != d2: + diffs.append(f"{path}: list content 
differs") + elif d1 != d2: + # Truncate long values for readability + v1 = str(d1)[:100] + "..." if len(str(d1)) > 100 else str(d1) + v2 = str(d2)[:100] + "..." if len(str(d2)) > 100 else str(d2) + diffs.append(f"{path}: '{v1}' vs '{v2}'") + return diffs + + result["differences"] = find_diff(data1, data2) + return result + + +def compare_text_files(file1: str, file2: str, name: str) -> Dict[str, Any]: + """Compare two text files.""" + result = { + "name": name, + "identical": False, + "differences": [], + } + + if not os.path.exists(file1): + result["differences"].append(f"File 1 not found: {file1}") + return result + + if not os.path.exists(file2): + result["differences"].append(f"File 2 not found: {file2}") + return result + + with open(file1, "r", encoding="utf-8") as f: + content1 = f.read() + + with open(file2, "r", encoding="utf-8") as f: + content2 = f.read() + + if content1 == content2: + result["identical"] = True + else: + lines1 = content1.splitlines() + lines2 = content2.splitlines() + result["differences"].append(f"Line count: {len(lines1)} vs {len(lines2)}") + + # Find first difference + for i, (l1, l2) in enumerate(zip(lines1, lines2)): + if l1 != l2: + result["differences"].append(f"First diff at line {i+1}") + break + + return result + + +def compare_with_reference(generated_dir: str, reference_dir: str) -> None: + """Compare generated files with reference files.""" + print(f"\n=== Comparing generated files with reference ===") + print(f"Generated: {generated_dir}") + print(f"Reference: {reference_dir}\n") + + files_to_compare = [ + ("vocab.json", "json"), + ("merges.txt", "text"), + ("tokenizer.json", "json"), + ("tokenizer_config.json", "json"), + ("added_tokens.json", "json"), + ("special_tokens_map.json", "json"), + ] + + all_identical = True + + for filename, file_type in files_to_compare: + gen_file = os.path.join(generated_dir, filename) + ref_file = os.path.join(reference_dir, filename) + + if file_type == "json": + result = 
compare_json_files(gen_file, ref_file, filename) + else: + result = compare_text_files(gen_file, ref_file, filename) + + if result["identical"]: + print(f"āœ… {filename}: IDENTICAL") + else: + print(f"āŒ {filename}: DIFFERENT") + for diff in result["differences"][:5]: # Show first 5 differences + print(f" - {diff}") + if len(result["differences"]) > 5: + print(f" ... and {len(result['differences']) - 5} more differences") + all_identical = False + + print() + if all_identical: + print("šŸŽ‰ All files are identical!") + else: + print("āš ļø Some files have differences. See details above.") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate VibeVoice tokenizer files from Qwen2 base" + ) + parser.add_argument( + "--output", "-o", + type=str, + default=None, + help="Output directory for generated files (default: temp directory)" + ) + parser.add_argument( + "--compare", "-c", + type=str, + default=None, + help="Reference directory to compare generated files against" + ) + parser.add_argument( + "--qwen-model", + type=str, + default=DEFAULT_QWEN_MODEL, + help=f"Qwen model to download base tokenizer from (default: {DEFAULT_QWEN_MODEL})" + ) + + args = parser.parse_args() + + # Determine output directory + if args.output: + output_dir = args.output + cleanup = False + else: + output_dir = tempfile.mkdtemp(prefix="vibevoice_tokenizer_") + cleanup = not args.compare # Only cleanup if not comparing + + try: + # Generate files + generate_vibevoice_tokenizer_files(output_dir, args.qwen_model) + + # Compare if requested + if args.compare: + compare_with_reference(output_dir, args.compare) + + if not args.output: + print(f"\nGenerated files are in: {output_dir}") + + finally: + if cleanup and not args.output: + print(f"\nCleaning up temporary directory: {output_dir}") + shutil.rmtree(output_dir, ignore_errors=True) + + +if __name__ == "__main__": + main()