diff --git a/docs/vibevoice-vllm-asr.md b/docs/vibevoice-vllm-asr.md
new file mode 100644
index 0000000..f32cab4
--- /dev/null
+++ b/docs/vibevoice-vllm-asr.md
@@ -0,0 +1,112 @@
+# VibeVoice vLLM ASR Deployment
+
+
+
+Deploy the VibeVoice ASR model as a high-performance API service using [vLLM](https://github.com/vllm-project/vllm). This plugin provides OpenAI-compatible API endpoints for speech-to-text transcription with streaming support.
+
+## 🔥 Key Features
+
+- **🚀 High-Performance Serving**: Optimized for high-throughput ASR inference with vLLM's continuous batching
+- **📡 OpenAI-Compatible API**: Standard `/v1/chat/completions` endpoint with streaming support
+- **🎵 Long Audio Support**: Process up to 60+ minutes of audio in a single request
+- **🔌 Plugin Architecture**: No vLLM source code modification required - just install and run
+
+## 🛠️ Installation
+
+### Using Official vLLM Docker Image (Recommended)
+
+```bash
+# 1. Pull the official vLLM image
+docker pull vllm/vllm-openai:latest
+
+# 2. Start an interactive container
+docker run -it --gpus all --name vibevoice-vllm \
+ --ipc=host \
+ -p 8000:8000 \
+ -e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
+ -e PYTORCH_ALLOC_CONF=expandable_segments:True \
+ -v /path/to/models:/models \
+ -v /path/to/VibeVoice:/app \
+ -w /app \
+ --entrypoint bash \
+ vllm/vllm-openai:latest
+
+# 3. Inside container: Install system dependencies
+bash vllm_plugin/scripts/install_deps.sh
+
+# 4. Inside container: Install VibeVoice with vLLM support
+pip install -e .[vllm]
+
+# 5. Inside container: (Optional) Generate tokenizer files if needed
+python3 -m vllm_plugin.tools.generate_tokenizer_files --output /models/your_model
+
+# 6. Inside container: Start vLLM server
+vllm serve /models/your_model \
+ --served-model-name vibevoice \
+ --trust-remote-code \
+ --dtype bfloat16 \
+ --max-num-seqs 64 \
+ --max-model-len 65536 \
+ --max-num-batched-tokens 32768 \
+ --gpu-memory-utilization 0.8 \
+ --enforce-eager \
+ --no-enable-prefix-caching \
+ --enable-chunked-prefill \
+ --chat-template-content-format openai \
+ --tensor-parallel-size 1 \
+ --allowed-local-media-path /app \
+ --port 8000
+```
+
+> **Note**: This approach allows you to switch models, adjust parameters, and debug issues without rebuilding the container.
+
+
+## 🚀 Quick Start
+
+### Test the API
+
+Once the vLLM server is running, test it with the provided script:
+
+```bash
+# Run the test script (inside container)
+python3 vllm_plugin/tests/test_api.py /path/to/audio.wav
+```
+
+
+### Environment Variables
+
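+| Variable | Description | Default |
+|----------|-------------|---------|
+| `VIBEVOICE_FFMPEG_MAX_CONCURRENCY` | Maximum number of concurrent FFmpeg processes for audio decoding | unset (no limit); the `docker run` example above sets `64` |
+| `PYTORCH_ALLOC_CONF` | PyTorch memory allocator config (older PyTorch releases read `PYTORCH_CUDA_ALLOC_CONF` instead) | unset; the `docker run` example above sets `expandable_segments:True` |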
+
+
+
+## 📊 Performance Tips
+
+1. **GPU Memory**: Use `--gpu-memory-utilization 0.9` for maximum throughput if you have a dedicated GPU
+2. **Batch Size**: Increase `--max-num-seqs` for higher concurrency (requires more GPU memory)
+3. **FFmpeg Concurrency**: Tune `VIBEVOICE_FFMPEG_MAX_CONCURRENCY` based on CPU cores
+
+## 🚨 Troubleshooting
+
+### Common Issues
+
+1. **"CUDA out of memory"**
+ - Reduce `--gpu-memory-utilization`
+ - Reduce `--max-num-seqs`
+ - Use smaller `--max-model-len`
+
+2. **"Audio decoding failed"**
+ - Ensure FFmpeg is installed: `ffmpeg -version`
+ - Check audio file format is supported
+
+3. **"Model not found"**
+ - Ensure model path contains `config.json` and model weights
+ - Generate tokenizer files if missing
+
+4. **"Plugin not loaded"**
+ - Verify installation: `pip show vibevoice`
+ - Check entry point: `pip show -f vibevoice | grep entry`
+
+
diff --git a/pyproject.toml b/pyproject.toml
index 6dc69b3..0cc390a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,9 +45,20 @@ asr = [
"pydub" # for visualization
]
+vllm = [
+ "transformers>=4.51.3",
+ "fastapi",
+ "uvicorn[standard]",
+ "requests",
+]
+
+[project.entry-points."vllm.general_plugins"]
+vibevoice = "vllm_plugin:register_vibevoice"
+
[project.urls]
"Homepage" = "https://github.com/microsoft/VibeVoice"
"Bug Tracker" = "https://github.com/microsoft/VibeVoice/issues"
[tool.setuptools.packages.find]
where = ["."]
+include = ["vibevoice*", "vllm_plugin*"]
diff --git a/vibevoice/modular/configuration_vibevoice.py b/vibevoice/modular/configuration_vibevoice.py
index 02b2751..d5e3149 100644
--- a/vibevoice/modular/configuration_vibevoice.py
+++ b/vibevoice/modular/configuration_vibevoice.py
@@ -240,6 +240,26 @@ class VibeVoiceConfig(PretrainedConfig):
super().__init__(**kwargs)
+ def get_text_config(self, decoder=False, encoder=False):
+     """
+     Return the text config for this model.
+
+     vLLM uses this method to get the text configuration from multimodal models.
+     This allows vLLM to correctly determine hidden_size, num_attention_heads,
+     and other properties needed for memory profiling and model execution.
+
+     For VibeVoice, the "text config" is the decoder_config (Qwen2Config).
+
+     Args:
+         decoder: If True, return the decoder config (for encoder-decoder models).
+             For VibeVoice, this is always the decoder_config.
+         encoder: Accepted only so callers that pass ``encoder=...`` do not
+             fail with a TypeError; the value is ignored because VibeVoice
+             exposes a single text config.
+
+     Returns:
+         The decoder configuration (Qwen2Config) which contains hidden_size, etc.
+     """
+     # Both flags are intentionally unused: there is only one text backbone.
+     return self.decoder_config
+
+
class VibeVoiceASRConfig(PretrainedConfig):
model_type = "vibevoice"
is_composition = True
diff --git a/vibevoice/processor/audio_utils.py b/vibevoice/processor/audio_utils.py
index ad4d10d..3f9d112 100644
--- a/vibevoice/processor/audio_utils.py
+++ b/vibevoice/processor/audio_utils.py
@@ -1,3 +1,6 @@
+import os
+import threading
+
import numpy as np
from subprocess import run
from typing import List, Optional, Union, Dict, Any
@@ -57,6 +60,7 @@ def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24
cmd = [
"ffmpeg",
+ "-loglevel", "error",
"-nostdin",
"-threads", "0",
"-i", file,
@@ -64,14 +68,84 @@ def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr_to_use),
- "-"
+ "-",
]
- out = run(cmd, capture_output=True, check=True).stdout
+ out = _run_ffmpeg(cmd).stdout
audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
return audio_data, sr_to_use
+
+def _get_ffmpeg_max_concurrency() -> int:
+    """Read the FFmpeg concurrency limit from ``VIBEVOICE_FFMPEG_MAX_CONCURRENCY``.
+
+    Returns:
+        The parsed integer value. ``0`` is returned when the variable is
+        unset, blank, or not a valid integer. Zero or a negative value means
+        "no explicit limit".
+    """
+    raw = os.getenv("VIBEVOICE_FFMPEG_MAX_CONCURRENCY", "")
+    if not raw.strip():
+        return 0
+    try:
+        return int(raw)
+    except ValueError:
+        # A malformed value falls back to "no limit" rather than raising.
+        return 0
+
+
+# Limit read once at import time; zero or below disables limiting.
+_FFMPEG_MAX_CONCURRENCY = _get_ffmpeg_max_concurrency()
+# Module-level semaphore shared by every decode call in this process,
+# or None when no limit is configured.
+_FFMPEG_SEM = threading.Semaphore(_FFMPEG_MAX_CONCURRENCY) if _FFMPEG_MAX_CONCURRENCY > 0 else None
+
+
+def _run_ffmpeg(cmd: List[str], *, stdin_bytes: Optional[bytes] = None):
+    """Run ffmpeg with optional global concurrency limiting.
+
+    This is important for vLLM multi-request concurrency: spawning too many
+    ffmpeg processes can saturate CPU/IO and cause request failures/timeouts.
+
+    Args:
+        cmd: Full ffmpeg argument vector, passed to ``subprocess.run`` as-is.
+        stdin_bytes: Optional bytes fed to the process on stdin.
+
+    Returns:
+        The ``subprocess.CompletedProcess`` with captured stdout and stderr.
+
+    Raises:
+        subprocess.CalledProcessError: If ffmpeg exits with a non-zero
+            status (``check=True``).
+    """
+    def _invoke():
+        return run(cmd, capture_output=True, check=True, input=stdin_bytes)
+
+    if _FFMPEG_SEM is None:
+        return _invoke()
+    # Holding the semaphore for the whole run caps concurrent ffmpeg processes.
+    with _FFMPEG_SEM:
+        return _invoke()
+
+
+def load_audio_bytes_use_ffmpeg(data: bytes, *, resample: bool = False, target_sr: int = 24000):
+    """Decode in-memory audio bytes to a mono float32 waveform via ffmpeg.
+
+    The bytes are first streamed to ffmpeg over its stdin pipe. Compared to
+    writing bytes to a temp file, this avoids filesystem IO and reduces
+    contention under high request concurrency. Stdin is not seekable, so
+    inputs that ffmpeg needs to seek in can fail to decode that way; when the
+    piped attempt fails, the bytes are written to a temporary file and
+    decoded once more from disk.
+
+    Parameters
+    ----------
+    data: bytes
+        The encoded audio data bytes
+    resample: bool
+        Must be True. The original sample rate is not probed for stdin
+        input, so the caller has to request an explicit output rate.
+    target_sr: int
+        The output sample rate in Hz
+
+    Returns
+    -------
+    A tuple containing:
+    - A NumPy array with the audio waveform in float32 dtype
+    - The sample rate (always ``target_sr``)
+
+    Raises
+    ------
+    ValueError
+        If ``resample`` is False.
+    subprocess.CalledProcessError
+        If ffmpeg fails for both the piped and the temp-file attempt.
+    """
+    import tempfile
+    from subprocess import CalledProcessError
+
+    if not resample:
+        # For stdin bytes, we don't have a cheap/robust way to probe original sr.
+        # Keep behavior explicit.
+        raise ValueError("load_audio_bytes_use_ffmpeg requires resample=True")
+
+    def _build_cmd(source: str, *extra: str) -> list:
+        # Output is raw signed 16-bit little-endian mono PCM on stdout.
+        return [
+            "ffmpeg",
+            "-loglevel", "error",
+            *extra,
+            "-threads", "0",
+            "-i", source,
+            "-f", "s16le",
+            "-ac", "1",
+            "-acodec", "pcm_s16le",
+            "-ar", str(target_sr),
+            "-",
+        ]
+
+    try:
+        out = _run_ffmpeg(_build_cmd("pipe:0"), stdin_bytes=data).stdout
+    except CalledProcessError:
+        # Retry from a seekable file. delete=False plus an explicit unlink
+        # lets ffmpeg open the path after this process has closed it.
+        tmp = tempfile.NamedTemporaryFile(suffix=".bin", delete=False)
+        try:
+            with tmp:
+                tmp.write(data)
+            out = _run_ffmpeg(_build_cmd(tmp.name, "-nostdin")).stdout
+        finally:
+            os.unlink(tmp.name)
+
+    # int16 samples are scaled by 1/32768 to float32.
+    audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+    return audio_data, target_sr
+
+
class AudioNormalizer:
"""
Audio normalization class for VibeVoice tokenizer.
diff --git a/vllm_plugin/__init__.py b/vllm_plugin/__init__.py
new file mode 100644
index 0000000..b696a45
--- /dev/null
+++ b/vllm_plugin/__init__.py
@@ -0,0 +1,63 @@
+"""VibeVoice vLLM Plugin - Registers VibeVoice model for vLLM inference.
+
+This plugin enables VibeVoice ASR models to be loaded and served through vLLM.
+It registers the model architecture, configuration, tokenizer, and processor
+with their respective registries.
+
+The plugin is automatically loaded by vLLM via the 'vllm.general_plugins'
+entry point defined in pyproject.toml.
+"""
+
+from vllm.model_executor.models import ModelRegistry
+from transformers import AutoConfig, AutoTokenizer, Qwen2Tokenizer, AutoProcessor, Qwen2AudioProcessor
+
+from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
+from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
+
+from .model import VibeVoiceForCausalLM
+from .inputs import vibevoice_audio_input_mapper
+
+
+def register_vibevoice():
+    """Register VibeVoice model with vLLM and transformers.
+
+    This function is called automatically by vLLM through the entry point
+    mechanism. It registers:
+    - VibeVoiceConfig with AutoConfig
+    - VibeVoiceASRTextTokenizerFast with AutoTokenizer (for ASR)
+    - Qwen2AudioProcessor with AutoProcessor
+    - VibeVoiceForCausalLM with vLLM ModelRegistry
+
+    Every transformers registration passes ``exist_ok=True`` so that calling
+    this function more than once in the same process does not raise.
+    """
+    import logging
+
+    log = logging.getLogger(__name__)
+
+    # Register the configuration class with transformers.
+    # Without exist_ok=True a second call would raise because the
+    # "vibevoice" key is already taken.
+    AutoConfig.register("vibevoice", VibeVoiceConfig, exist_ok=True)
+
+    # Register the tokenizer with transformers.
+    # IMPORTANT (ASR): Align with the PyTorch ASR path.
+    # VibeVoiceASRTextTokenizerFast maps:
+    #   speech_start_id -> <|object_ref_start|>
+    #   speech_pad_id   -> <|box_start|>
+    #   speech_end_id   -> <|object_ref_end|>
+    # This significantly affects ASR quality even when requests succeed.
+    try:
+        AutoTokenizer.register(
+            VibeVoiceConfig,
+            slow_tokenizer_class=Qwen2Tokenizer,
+            fast_tokenizer_class=VibeVoiceASRTextTokenizerFast,
+            exist_ok=True,
+        )
+    except Exception:
+        # Still best effort, but no longer silent: a wrong tokenizer mapping
+        # degrades transcription quality without failing requests.
+        log.warning("VibeVoice tokenizer registration failed", exc_info=True)
+
+    # Register the processor with transformers
+    try:
+        AutoProcessor.register(VibeVoiceConfig, processor_class=Qwen2AudioProcessor, exist_ok=True)
+    except Exception:
+        log.warning("VibeVoice processor registration failed", exc_info=True)
+
+    # Register the model class with the architecture name "VibeVoice"
+    # This name must match the "architectures" list in config.json
+    ModelRegistry.register_model("VibeVoice", VibeVoiceForCausalLM)
+    ModelRegistry.register_model("VibeVoiceForASRTraining", VibeVoiceForCausalLM)
+
+
+# Note: This function is called via vllm.general_plugins entry point
+# defined in pyproject.toml, ensuring it runs in all vLLM processes
diff --git a/vllm_plugin/inputs.py b/vllm_plugin/inputs.py
new file mode 100644
index 0000000..e801484
--- /dev/null
+++ b/vllm_plugin/inputs.py
@@ -0,0 +1,82 @@
+"""Audio input mapper for vLLM multimodal pipeline.
+
+This module handles audio data loading and preprocessing for VibeVoice ASR inference.
+It converts various audio input formats (path, bytes, numpy array) into tensors
+that can be processed by the VibeVoice model.
+"""
+import torch
+import numpy as np
+from typing import Union, List
+from vllm.multimodal.inputs import MultiModalInputs
+from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
+
+
+def load_audio(audio_path: str, target_sr: int = 24000) -> np.ndarray:
+    """Decode an audio file with FFmpeg and return the normalized waveform.
+
+    Args:
+        audio_path: Path to audio file
+        target_sr: Sample rate the audio is resampled to (default 24kHz for VibeVoice)
+
+    Returns:
+        The waveform after it has been passed through ``AudioNormalizer``
+    """
+    # The sample rate returned by the loader is not needed here.
+    waveform, _ = load_audio_use_ffmpeg(audio_path, resample=True, target_sr=target_sr)
+    return AudioNormalizer()(waveform)
+
+
+def vibevoice_audio_input_mapper(ctx, data: Union[str, bytes, np.ndarray, List[str]]) -> MultiModalInputs:
+    """Map audio input data to vLLM MultiModalInputs format.
+
+    This function is registered as the input mapper for VibeVoice audio processing.
+    It handles multiple input formats and converts them to a float32 tensor.
+
+    Args:
+        ctx: vLLM context (unused)
+        data: Audio data in one of these formats:
+            - str: Path to audio file (decoded with FFmpeg, then normalized)
+            - bytes / bytearray / memoryview: Encoded audio bytes (decoded
+              with FFmpeg at 24kHz, then normalized)
+            - np.ndarray: Pre-loaded audio waveform, used as-is (it is
+              neither resampled nor normalized here)
+            - List[str]: List of audio paths (only first is used)
+
+    Returns:
+        MultiModalInputs containing:
+            - audio: Audio tensor (float32)
+            - audio_length: Size of the tensor's first dimension
+
+    Raises:
+        ValueError: If ``data`` is an empty list or has an unsupported type.
+    """
+    # Handle list input (take first item)
+    if isinstance(data, list):
+        if not data:
+            raise ValueError("Empty audio list: expected at least one audio item")
+        data = data[0]
+
+    if isinstance(data, str):
+        # Load from file path
+        audio_waveform = load_audio(data)
+    elif isinstance(data, (bytes, bytearray, memoryview)):
+        # Decode bytes directly via ffmpeg stdin pipe to avoid temp-file IO
+        decoded, _sr = load_audio_bytes_use_ffmpeg(bytes(data), resample=True, target_sr=24000)
+        audio_waveform = AudioNormalizer()(decoded)
+    elif isinstance(data, np.ndarray):
+        # Already loaded numpy array
+        audio_waveform = data
+    else:
+        raise ValueError(f"Unsupported audio data type: {type(data)}")
+
+    # Convert to tensor
+    audio_tensor = torch.from_numpy(audio_waveform).float()
+
+    return MultiModalInputs({
+        "audio": audio_tensor,
+        "audio_length": audio_tensor.shape[0],
+    })
diff --git a/vllm_plugin/model.py b/vllm_plugin/model.py
new file mode 100644
index 0000000..bcb4ca3
--- /dev/null
+++ b/vllm_plugin/model.py
@@ -0,0 +1,1329 @@
+"""
+VibeVoice vLLM Plugin Model - Native Multimodal Integration
+
+This module implements the VibeVoice ASR model with full vLLM multimodal registry
+integration for speech-to-text inference.
+"""
+
+from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal
+import json
+import math
+import os
+import sys
+from pathlib import Path
+import torch
+import torch.nn as nn
+import numpy as np
+from io import BytesIO
+import tempfile
+import base64
+
+
+# ============================================================================
+# Audio Loading: FFmpeg-based AudioMediaIO
+# ============================================================================
+# VibeVoice uses FFmpeg for audio decoding to ensure consistent behavior
+# across different audio formats (MP3, WAV, FLAC, etc.).
+
+
+from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
+
+
+def _suffix_from_media_type(media_type: str | None) -> str:
+    """Map an audio/video media type to a file-name suffix.
+
+    Matching ignores case, surrounding whitespace, and any parameters after
+    a ``;`` (for example ``"audio/ogg; codecs=opus"`` maps to ``".ogg"``).
+
+    Args:
+        media_type: Media type string, or None/empty when unknown.
+
+    Returns:
+        The suffix including the leading dot, or ``".bin"`` when the media
+        type is missing or not recognized.
+    """
+    if not media_type:
+        return ".bin"
+
+    # Keep only the "type/subtype" part before any parameters.
+    essence = media_type.partition(";")[0].strip().lower()
+
+    suffix_by_type = {
+        "audio/wav": ".wav",
+        "audio/x-wav": ".wav",
+        "audio/wave": ".wav",
+        "audio/mpeg": ".mp3",
+        "audio/mp3": ".mp3",
+        "audio/x-mp3": ".mp3",
+        "audio/flac": ".flac",
+        "audio/ogg": ".ogg",
+        "audio/opus": ".ogg",
+        "audio/mp4": ".m4a",
+        "audio/m4a": ".m4a",
+        "video/mp4": ".mp4",
+    }
+    return suffix_by_type.get(essence, ".bin")
+
+
+def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]:
+    """Decode encoded audio bytes with FFmpeg and normalize the waveform.
+
+    Args:
+        data: Encoded audio bytes.
+        media_type: Accepted for interface compatibility; not used by the
+            current stdin-pipe decoding path.
+
+    Returns:
+        Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
+    """
+    # Decoding goes through ffmpeg's stdin pipe rather than a temp file,
+    # which avoids temp-file IO under high concurrency.
+    waveform, sample_rate = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
+    return AudioNormalizer()(waveform), sample_rate
+
+def _ffmpeg_load_file(filepath) -> tuple[np.ndarray, int]:
+    """Decode an audio file with FFmpeg and normalize the waveform.
+
+    Args:
+        filepath: Location of the audio file; converted with ``str()``
+            before it is handed to the loader.
+
+    Returns:
+        Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
+    """
+    waveform, sample_rate = load_audio_use_ffmpeg(str(filepath), resample=True, target_sr=24000)
+    return AudioNormalizer()(waveform), sample_rate
+
+# Register FFmpeg-based audio loader
+import vllm.multimodal.audio as _vllm_audio_module
+_OriginalAudioMediaIO = _vllm_audio_module.AudioMediaIO
+
+class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
+ """AudioMediaIO implementation using FFmpeg for audio decoding."""
+
+ def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
+ return _ffmpeg_load_bytes(data, media_type=None)
+
+ def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
+ return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
+
+ def load_file(self, filepath) -> tuple[np.ndarray, int]:
+ return _ffmpeg_load_file(filepath)
+
+# Replace globally
+_vllm_audio_module.AudioMediaIO = _PatchedAudioMediaIO
+
+# Also patch in utils module where it's imported
+import vllm.multimodal.utils as _vllm_utils_module
+_vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
+
+# ============================================================================
+
+from transformers import Qwen2Config, BatchFeature
+from transformers.models.whisper import WhisperFeatureExtractor
+from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.config import VllmConfig, ModelConfig
+from vllm.config.speech_to_text import SpeechToTextConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.parse import MultiModalDataParser
+from vllm.sequence import IntermediateTensors
+from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
+from vllm.inputs import PromptType
+from vllm.model_executor.models.utils import (
+ init_vllm_registered_model,
+ maybe_prefix,
+ AutoWeightsLoader,
+ WeightsMapper,
+)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
+from vllm.multimodal.processing import (
+ BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ PromptReplacement,
+ PromptUpdate,
+ PromptUpdateDetails,
+)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+
+# Import VibeVoice components
+from vibevoice.modular.modular_vibevoice_tokenizer import (
+ VibeVoiceAcousticTokenizerModel,
+ VibeVoiceSemanticTokenizerModel,
+ VibeVoiceTokenizerStreamingCache,
+ VibeVoiceTokenizerEncoderOutput,
+)
+from vibevoice.modular.configuration_vibevoice import (
+ VibeVoiceAcousticTokenizerConfig,
+ VibeVoiceSemanticTokenizerConfig,
+)
+
+
+class SpeechConnector(nn.Module):
+    """Map speech features to the language model hidden dimension.
+
+    Layer order is linear -> RMSNorm -> linear, with no activation function
+    in between.
+    """
+
+    def __init__(self, input_dim: int, output_dim: int):
+        super().__init__()
+        # Attribute names are part of the parameter names, so they stay as-is.
+        self.fc1 = nn.Linear(input_dim, output_dim)
+        self.norm = LlamaRMSNorm(output_dim, eps=1e-6)
+        self.fc2 = nn.Linear(output_dim, output_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        projected = self.fc1(x)
+        normalized = self.norm(projected)
+        return self.fc2(normalized)
+
+
+class LlamaRMSNorm(nn.Module):
+    """Root-mean-square normalization with a learned per-channel scale.
+
+    Used by SpeechConnector.
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        original_dtype = hidden_states.dtype
+        # The statistics are computed in float32 whatever the input dtype is.
+        upcast = hidden_states.to(torch.float32)
+        mean_square = upcast.pow(2).mean(-1, keepdim=True)
+        scaled = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
+        # Cast back to the input dtype before applying the learned scale.
+        return self.weight * scaled.to(original_dtype)
+
+
+class VibeVoiceAudioEncoder(nn.Module):
+    """
+    VibeVoice Audio Encoder module.
+
+    Encapsulates Acoustic/Semantic VAE Tokenizers and projection Connectors.
+    Converts raw audio waveforms into embeddings compatible with the language model.
+
+    Features:
+    - Streaming support for long audio (>60s by default)
+    - Configurable dtype for numerical precision
+    - Supports both sampling and deterministic (mean) modes
+    """
+
+    # Sub-modules that are kept in the audio-encoder dtype.
+    _AUDIO_SUBMODULES = (
+        "acoustic_tokenizer",
+        "semantic_tokenizer",
+        "acoustic_connector",
+        "semantic_connector",
+    )
+
+    def __init__(self, config):
+        """Build the tokenizers and connectors described by ``config``.
+
+        Args:
+            config: Model configuration, either a dict or an object with
+                attributes. It must provide ``acoustic_tokenizer_config``
+                and ``semantic_tokenizer_config``.
+
+        Raises:
+            ValueError: If either tokenizer config is missing.
+            TypeError: If a tokenizer config is neither a dict nor an
+                instance of the expected config class.
+        """
+        super().__init__()
+        self.config = config
+
+        def get_cfg(obj, key, default=None):
+            # Configs may arrive as plain dicts or as attribute objects.
+            if isinstance(obj, dict):
+                return obj.get(key, default)
+            return getattr(obj, key, default)
+
+        def as_tokenizer_config(raw, config_cls, field_name):
+            # Handle both dict and already-constructed config objects.
+            if isinstance(raw, config_cls):
+                return raw
+            if isinstance(raw, dict):
+                return config_cls(**raw)
+            raise TypeError(f"{field_name} has unexpected type: {type(raw)}")
+
+        self.acoustic_vae_dim = get_cfg(config, "acoustic_vae_dim", 64)
+        self.semantic_vae_dim = get_cfg(config, "semantic_vae_dim", 128)
+
+        # hidden_size lookup order: decoder_config, text_config, root config.
+        target_hidden_size = None
+        for candidate in (get_cfg(config, "decoder_config"), get_cfg(config, "text_config"), config):
+            if candidate is None:
+                continue
+            target_hidden_size = get_cfg(candidate, "hidden_size")
+            if target_hidden_size is not None:
+                break
+
+        if target_hidden_size is None:
+            print("[VibeVoice] WARN: Could not find hidden_size in config! Defaulting to 3584 (7B).", file=sys.stderr)
+            self.hidden_size = 3584
+        else:
+            self.hidden_size = target_hidden_size
+
+        ac_cfg = get_cfg(config, "acoustic_tokenizer_config")
+        sc_cfg = get_cfg(config, "semantic_tokenizer_config")
+
+        if ac_cfg is None or sc_cfg is None:
+            raise ValueError("Missing acoustic/semantic tokenizer config in model config")
+
+        acoustic_config = as_tokenizer_config(
+            ac_cfg, VibeVoiceAcousticTokenizerConfig, "acoustic_tokenizer_config"
+        )
+        semantic_config = as_tokenizer_config(
+            sc_cfg, VibeVoiceSemanticTokenizerConfig, "semantic_tokenizer_config"
+        )
+
+        self.acoustic_tokenizer = VibeVoiceAcousticTokenizerModel(acoustic_config)
+        self.semantic_tokenizer = VibeVoiceSemanticTokenizerModel(semantic_config)
+
+        # Audio encoder dtype comes from config.torch_dtype and falls back to
+        # float32 when the config does not set one.
+        root_torch_dtype = get_cfg(config, "torch_dtype", None)
+        if root_torch_dtype is None:
+            self._audio_encoder_dtype = torch.float32
+        elif isinstance(root_torch_dtype, str):
+            # Accept both "float32" and "torch.float32".
+            self._audio_encoder_dtype = getattr(torch, root_torch_dtype.split(".")[-1])
+        else:
+            self._audio_encoder_dtype = root_torch_dtype
+
+        self.acoustic_connector = SpeechConnector(self.acoustic_vae_dim, self.hidden_size)
+        self.semantic_connector = SpeechConnector(self.semantic_vae_dim, self.hidden_size)
+
+        self.compress_ratio = get_cfg(config, "speech_tok_compress_ratio", 3200)
+
+        # Streaming controls
+        self.sample_rate = get_cfg(config, "target_sample_rate", 24000)
+
+        # Default to True (per requirement): segment + cache inside one forward call.
+        self.enable_streaming = get_cfg(config, "enable_streaming", True)
+        self.streaming_segment_duration = get_cfg(config, "streaming_segment_duration", 60.0)
+
+        # Control whether to use sample() or .mean for acoustic tokens
+        # Default: use sample() for training-consistent behavior
+        # Set VIBEVOICE_USE_MEAN=1 for deterministic output
+        use_mean_env = os.getenv("VIBEVOICE_USE_MEAN", "").strip().lower()
+        self.use_sample = use_mean_env not in ("1", "true", "yes")
+
+        # Language model dtype (set by VibeVoiceForCausalLM.__init__)
+        # This is the dtype that audio embeddings will be converted to before
+        # being passed to the language model. Defaults to bfloat16.
+        self._lm_dtype: torch.dtype = torch.bfloat16
+
+    def _ensure_audio_encoder_dtype(self):
+        """Ensure all audio encoder components use the correct dtype from config.
+
+        vLLM may convert weights to a different dtype (e.g., bfloat16) during loading.
+        This method converts audio encoder components back to the config-specified dtype
+        (typically float32) for numerical precision during audio encoding.
+        """
+        target_dtype = self._audio_encoder_dtype
+
+        for name in self._AUDIO_SUBMODULES:
+            module = getattr(self, name)
+            first_param = next(module.parameters(), None)
+            if first_param is None:
+                # A module without parameters has nothing to convert.
+                continue
+            current_dtype = first_param.dtype
+            if current_dtype != target_dtype:
+                setattr(self, name, module.to(dtype=target_dtype))
+                print(f"[VibeVoice] Converted {name} to {target_dtype} (was {current_dtype})", file=sys.stderr)
+
+    def _select_acoustic_tokens(self, encoder_output, use_sample: bool) -> torch.Tensor:
+        """Return sampled acoustic tokens, or the mean when sampling is off."""
+        if use_sample:
+            return encoder_output.sample(
+                dist_type=self.acoustic_tokenizer.std_dist_type
+            )[0]
+        return encoder_output.mean
+
+    def _encode_full(self, audio: torch.Tensor, use_sample: bool):
+        """Encode the whole clip in one pass; returns (acoustic, semantic) embeddings."""
+        tokenizer_input = audio.unsqueeze(1)
+
+        acoustic_out = self.acoustic_tokenizer.encode(tokenizer_input)
+        acoustic_tokens = self._select_acoustic_tokens(acoustic_out, use_sample)
+        # Connectors already run in the audio-encoder dtype; no cast needed.
+        acoustic_embeds = self.acoustic_connector(acoustic_tokens)
+
+        semantic_out = self.semantic_tokenizer.encode(tokenizer_input)
+        # Semantic always uses .mean for consistency
+        semantic_embeds = self.semantic_connector(semantic_out.mean)
+        return acoustic_embeds, semantic_embeds
+
+    def _encode_streaming(self, audio: torch.Tensor, segment_samples: int, use_sample: bool):
+        """Encode the clip segment by segment with tokenizer streaming caches.
+
+        Returns (acoustic, semantic) embeddings built from the concatenated
+        per-segment means.
+        """
+        acoustic_cache = VibeVoiceTokenizerStreamingCache()
+        semantic_cache = VibeVoiceTokenizerStreamingCache()
+        acoustic_mean_segments = []
+        semantic_mean_segments = []
+        batch_size = audio.shape[0]
+        total_samples = audio.shape[-1]
+        sample_indices = torch.arange(batch_size, device=audio.device)
+
+        segments = [
+            (start, min(start + segment_samples, total_samples))
+            for start in range(0, total_samples, segment_samples)
+        ]
+        last_index = len(segments) - 1
+        for seg_idx, (start, end) in enumerate(segments):
+            chunk = audio[:, start:end].contiguous()
+            if chunk.numel() == 0:
+                continue
+
+            # The tokenizers are told when they receive the final segment.
+            is_final = seg_idx == last_index
+            chunk_input = chunk.unsqueeze(1)
+
+            # --- Acoustic Encode ---
+            acoustic_enc_out = self.acoustic_tokenizer.encode(
+                chunk_input,
+                cache=acoustic_cache,
+                sample_indices=sample_indices,
+                use_cache=True,
+                is_final_chunk=is_final,
+            )
+            acoustic_mean_segments.append(acoustic_enc_out.mean)
+
+            # --- Semantic Encode ---
+            semantic_enc_out = self.semantic_tokenizer.encode(
+                chunk_input,
+                cache=semantic_cache,
+                sample_indices=sample_indices,
+                use_cache=True,
+                is_final_chunk=is_final,
+            )
+            semantic_mean_segments.append(semantic_enc_out.mean)
+
+        # Concatenate sequence outputs (Acoustic)
+        if len(acoustic_mean_segments) == 0:
+            acoustic_mean_full = torch.zeros(
+                (batch_size, 0, self.acoustic_vae_dim),
+                device=audio.device,
+                dtype=self._audio_encoder_dtype,
+            )
+        else:
+            acoustic_mean_full = torch.cat(acoustic_mean_segments, dim=1).contiguous()
+
+        # Sampling is applied once to the concatenated mean.
+        acoustic_enc_full = VibeVoiceTokenizerEncoderOutput(
+            mean=acoustic_mean_full,
+            std=self.acoustic_tokenizer.fix_std,
+        )
+        acoustic_tokens = self._select_acoustic_tokens(acoustic_enc_full, use_sample)
+        acoustic_embeds = self.acoustic_connector(acoustic_tokens)
+
+        # Concatenate sequence outputs (Semantic)
+        if len(semantic_mean_segments) == 0:
+            semantic_tokens = torch.zeros(
+                (batch_size, 0, self.semantic_vae_dim),
+                device=audio.device,
+                dtype=self._audio_encoder_dtype,
+            )
+        else:
+            semantic_tokens = torch.cat(semantic_mean_segments, dim=1).contiguous()
+        semantic_embeds = self.semantic_connector(semantic_tokens)
+        return acoustic_embeds, semantic_embeds
+
+    def forward(
+        self,
+        audio: torch.Tensor,
+        *,
+        use_streaming: bool = True,
+        segment_duration_s: Optional[float] = None,
+        use_sample: Optional[bool] = None,
+    ) -> torch.Tensor:
+        """Encode audio with optional streaming for long clips.
+
+        Args:
+            audio: Input audio tensor [B, T] or [T]
+            use_streaming: Whether to enable segmented encoding for long audio
+            segment_duration_s: Segment length in seconds (defaults to
+                ``self.streaming_segment_duration``)
+            use_sample: If True, use sampling for acoustic tokens; if False, use mean
+                Defaults to self.use_sample (controlled by VIBEVOICE_USE_MEAN env var)
+
+        Returns:
+            Audio embeddings tensor, cast to ``self._lm_dtype``
+        """
+        # Ensure audio encoder components use correct dtype
+        self._ensure_audio_encoder_dtype()
+
+        # Audio input should match the audio encoder dtype
+        audio = audio.to(dtype=self._audio_encoder_dtype)
+
+        if audio.ndim == 1:
+            audio = audio.unsqueeze(0)
+
+        # Resolve streaming options
+        segment_duration = segment_duration_s or self.streaming_segment_duration
+        total_samples = audio.shape[-1]
+        segment_samples = int(segment_duration * self.sample_rate)
+
+        # Stream only when the clip is longer than one segment. A non-positive
+        # segment length cannot be used as a range() step, so it selects the
+        # single-pass path instead.
+        use_streaming = (
+            use_streaming
+            and self.enable_streaming
+            and segment_samples > 0
+            and total_samples > segment_samples
+        )
+
+        # Resolve use_sample flag
+        if use_sample is None:
+            use_sample = self.use_sample
+
+        # Keep encoding in inference mode to avoid autograd build-up
+        with torch.no_grad():
+            if use_streaming:
+                acoustic_embeds, semantic_embeds = self._encode_streaming(
+                    audio, segment_samples, use_sample
+                )
+            else:
+                acoustic_embeds, semantic_embeds = self._encode_full(audio, use_sample)
+
+            # Combine acoustic and semantic embeddings
+            combined_embeds = acoustic_embeds + semantic_embeds
+
+            # Convert to language model dtype for compatibility
+            # Audio encoder uses config.torch_dtype (typically float32) for numerical precision,
+            # but LM expects the dtype specified by vLLM's --dtype flag (e.g., bfloat16, float16)
+            combined_embeds = combined_embeds.to(dtype=self._lm_dtype)
+
+        return combined_embeds
+
+# ============================================================================
+# vLLM Multimodal Processing Infrastructure
+# ============================================================================
+
+class VibeVoiceProcessingInfo(BaseProcessingInfo):
+    """Processing info for VibeVoice multimodal model."""
+
+    def get_hf_config(self):
+        """Return the Hugging Face config held by the processing context."""
+        return self.ctx.get_hf_config()
+
+    def get_feature_extractor(self, **kwargs) -> WhisperFeatureExtractor:
+        """
+        Get a WhisperFeatureExtractor for vLLM profiling compatibility.
+
+        IMPORTANT: This is NOT used in actual inference!
+        VibeVoice uses its own acoustic/semantic VAE tokenizers operating on
+        raw 24kHz waveforms, NOT Whisper mel spectrograms.
+
+        This feature extractor exists only to satisfy vLLM's multimodal
+        profiling infrastructure which may query audio parameters like
+        sampling_rate and chunk_length for memory estimation.
+
+        Values come from ``preprocessor_config.json`` inside the model path
+        when that file exists and can be parsed; otherwise the defaults below
+        are used. ``kwargs`` are accepted but not used.
+        """
+        # Default values: keep a coherent (sr, hop) pair.
+        # VibeVoice runs at 24kHz in this repo (see demo/asr_transcribe_file.py).
+        config = {
+            "sampling_rate": 24000,
+            "feature_size": 128,
+            # 10ms hop at 24kHz
+            "hop_length": 240,
+            "chunk_length": 30,
+            "n_fft": 400,
+            "padding_value": 0.0,
+        }
+
+        model_path = self.ctx.model_config.model
+        preprocessor_path = os.path.join(model_path, "preprocessor_config.json")
+
+        # Best effort: any problem with the file leaves the defaults in place.
+        if os.path.exists(preprocessor_path):
+            try:
+                with open(preprocessor_path, "r", encoding="utf-8") as f:
+                    file_config = json.load(f)
+            except (OSError, ValueError) as exc:
+                # ValueError covers malformed JSON and undecodable bytes.
+                print(
+                    f"[VibeVoice] WARN: Could not read {preprocessor_path} ({exc}); using defaults.",
+                    file=sys.stderr,
+                )
+            else:
+                if isinstance(file_config, dict):
+                    # Only keys that already have a default are taken over.
+                    config.update({k: file_config[k] for k in config if k in file_config})
+
+        return WhisperFeatureExtractor(
+            feature_size=config["feature_size"],
+            sampling_rate=config["sampling_rate"],
+            hop_length=config["hop_length"],
+            chunk_length=config["chunk_length"],
+            n_fft=config["n_fft"],
+            padding_value=config["padding_value"],
+        )
+
+    def get_audio_token_info(self) -> dict:
+        """
+        Get audio special tokens and their IDs.
+
+        Returns dict with:
+            audio_token, audio_bos_token, audio_eos_token,
+            audio_token_id, audio_bos_id, audio_eos_id
+
+        An ID is None when the tokenizer vocabulary does not contain the
+        corresponding token string.
+        """
+        vocab = self.get_tokenizer().get_vocab()
+
+        # VibeVoice special tokens
+        tokens = {
+            "audio_token": "<|AUDIO|>",
+            "audio_bos_token": "<|audio_bos|>",
+            "audio_eos_token": "<|audio_eos|>",
+        }
+
+        # Get IDs
+        tokens["audio_token_id"] = vocab.get(tokens["audio_token"])
+        tokens["audio_bos_id"] = vocab.get(tokens["audio_bos_token"])
+        tokens["audio_eos_id"] = vocab.get(tokens["audio_eos_token"])
+
+        return tokens
+
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+        """Return the per-modality item limits; audio maps to None."""
+        return {"audio": None}
+
+
+class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]):
+ """
+ Build dummy inputs for multimodal profiling.
+
+ Dummy text uses the raw <|AUDIO|> token(s). vLLM's processing pipeline will
+ expand each <|AUDIO|> via `VibeVoiceMultiModalProcessor._get_prompt_updates`
+ into the full ASR format:
+ [speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
+ where N is derived from audio length / compress_ratio.
+ """
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_audios = mm_counts.get("audio", 0)
+ if num_audios <= 0:
+ return ""
+
+ # Get the audio token from our token info helper
+ token_info = self.info.get_audio_token_info()
+ audio_token = token_info["audio_token"]
+
+ # Return ONLY the audio tokens - the HF processor adds bos/eos
+ return audio_token * num_audios
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, Any] | None = None,
+ ) -> Dict[str, Any]:
+ """Generate dummy audio data for profiling."""
+ feature_extractor = self.info.get_feature_extractor()
+
+ sampling_rate = feature_extractor.sampling_rate
+ audio_len = feature_extractor.chunk_length * sampling_rate
+ num_audios = mm_counts.get("audio", 0)
+
+ # Generate dummy audio as numpy arrays (what the HF processor expects)
+ return {
+ "audio": [np.zeros(audio_len, dtype=np.float32) for _ in range(num_audios)]
+ }
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, Any] | None = None,
+ ) -> ProcessorInputs:
+ """Build ProcessorInputs for dummy profiling."""
+ return ProcessorInputs(
+ prompt=self.get_dummy_text(mm_counts),
+ mm_data=self.get_dummy_mm_data(seq_len, mm_counts, mm_options),
+ )
+
+
+def _vibevoice_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+ """Map HF processor output keys to audio modality.
+
+ Returns a config dict that tells vLLM how to batch multimodal data.
+ """
+ # Always define the config for all fields we use
+ # Even if the field isn't in hf_inputs, vLLM needs to know how to batch it
+ config = {
+ # These are our custom fields for VibeVoice
+ "raw_audio": MultiModalFieldConfig.batched("audio"),
+ "raw_audio_lengths": MultiModalFieldConfig.batched("audio"),
+ "salt": MultiModalFieldConfig.batched("audio"),
+ }
+
+ # Add optional Whisper features if present
+ if "input_features" in hf_inputs:
+ config["input_features"] = MultiModalFieldConfig.batched("audio")
+ if "feature_attention_mask" in hf_inputs:
+ config["feature_attention_mask"] = MultiModalFieldConfig.batched("audio")
+
+ return config
+
+
+class VibeVoiceMultiModalProcessor(BaseMultiModalProcessor[VibeVoiceProcessingInfo]):
+ """
+ Multimodal processor for VibeVoice.
+
+ Handles the conversion of raw audio inputs to model-ready features,
+ and manages the prompt token replacement for audio placeholders.
+ """
+
+ def _get_data_parser(self) -> MultiModalDataParser:
+ """Create a data parser with the correct target sample rate (24kHz)."""
+ # VibeVoice requires 24kHz, not 16kHz (Whisper default)
+ target_sr = 24000
+ return MultiModalDataParser(target_sr=target_sr)
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ """
+ Process prompt and audio for vLLM multimodal pipeline.
+
+ We intentionally do NOT run a HF processor that would pre-expand
+ `<|AUDIO|>` inside this method. Instead we:
+ 1) Tokenize the prompt as-is (so `<|AUDIO|>` stays a single token)
+ 2) Store raw audio tensors for `embed_multimodal` to encode later
+ 3) Let vLLM call `_get_prompt_updates` to expand the single `<|AUDIO|>`
+ into the full ASR format: [speech_start] + N*[speech_pad] + [speech_end] + [\\n]
+ """
+ # Handle the case where 'audios' key is used (transformers deprecation)
+ mm_data = dict(mm_data) # Make a mutable copy
+ audios = mm_data.pop("audios", None)
+ if audios is not None and "audio" not in mm_data:
+ mm_data["audio"] = audios
+
+ # Text-only input handling
+ if not mm_data.get("audio"):
+ prompt_ids = self.info.get_tokenizer().encode(prompt)
+ prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+ return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ # Get raw audio data
+ raw_audio_list = mm_data.get("audio")
+ if isinstance(raw_audio_list, np.ndarray):
+ raw_audio_list = [raw_audio_list]
+ elif not isinstance(raw_audio_list, list):
+ raw_audio_list = list(raw_audio_list)
+
+ # Tokenize prompt directly to preserve <|AUDIO|> as a single token
+ # vLLM will expand it via _get_prompt_updates
+ tokenizer = self.info.get_tokenizer()
+ prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
+ prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+
+ # Create result with input_ids
+ result = BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ # Add raw audio tensors for VibeVoice encoder
+ # Stack into a single tensor for vLLM's batched field config
+ max_len = max(len(a) for a in raw_audio_list)
+ raw_audio_tensors = []
+ audio_lengths = []
+ for audio in raw_audio_list:
+ audio_len = len(audio)
+ audio_lengths.append(audio_len)
+ if audio_len < max_len:
+ audio = np.pad(audio, (0, max_len - audio_len), mode='constant')
+ raw_audio_tensors.append(torch.from_numpy(audio).float())
+
+ # Stack into [num_audios, max_len] tensor
+ stacked_audio = torch.stack(raw_audio_tensors, dim=0) # Shape: [num_audios, max_len]
+ result["raw_audio"] = stacked_audio
+ # Convert lengths to tensor as well
+ result["raw_audio_lengths"] = torch.tensor(audio_lengths, dtype=torch.long)
+
+ # Add a random salt to ensure unique hash and bypass cache
+ import uuid
+ # Use a random integer for salt
+ salt_val = hash(str(uuid.uuid4())) % 100000
+ result["salt"] = torch.tensor([salt_val], dtype=torch.long).expand(len(raw_audio_list))
+
+ return result
+
+ def _hf_processor_applies_updates(
+ self,
+ prompt_text: str,
+ mm_items,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ tokenization_kwargs: Mapping[str, object],
+ ) -> bool:
+ """Return whether the HF processor applies prompt updates.
+
+ Returns False because we handle token expansion via _get_prompt_updates.
+ """
+ return False
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ """Configure which HF output fields map to which modality."""
+ return _vibevoice_field_config(hf_inputs)
+
+ def _get_prompt_updates(
+ self,
+ mm_items,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ """
+ Define how to replace the audio placeholder in the prompt.
+
+ vLLM's OpenAI multimodal parsing inserts the model placeholder string
+ from `get_placeholder_str` (here: `<|AUDIO|>`) into the conversation.
+ We expand that single token into N repeated `<|AUDIO|>` tokens, where N
+ is derived from waveform length and `speech_tok_compress_ratio`.
+ """
+ token_info = self.info.get_audio_token_info()
+ audio_token = token_info["audio_token"]
+ audio_token_id = token_info["audio_token_id"]
+ audio_bos_id = token_info.get("audio_bos_id")
+ audio_eos_id = token_info.get("audio_eos_id")
+
+ tokenizer = self.info.get_tokenizer()
+ vocab = tokenizer.get_vocab()
+
+ def _tok_id(name: str) -> int | None:
+ return vocab.get(name)
+
+ # Look up speech token IDs from vocabulary
+ # These tokens mark the start/end of audio embeddings in the prompt
+ speech_start_id = (
+ _tok_id("<|object_ref_start|>")
+ or getattr(tokenizer, "speech_start_id", None)
+ or _tok_id("<|speech_start|>")
+ )
+ speech_end_id = (
+ _tok_id("<|object_ref_end|>")
+ or getattr(tokenizer, "speech_end_id", None)
+ or _tok_id("<|speech_end|>")
+ )
+ speech_pad_id = (
+ _tok_id("<|box_start|>")
+ or getattr(tokenizer, "speech_pad_id", None)
+ or _tok_id("<|speech_pad|>")
+ )
+
+ if audio_token_id is None:
+ return []
+
+ # Get raw audio lengths (in samples, after resampling to 24kHz) from our stored data
+ out_mm_data = out_mm_kwargs.get_data()
+ raw_audio_lengths = out_mm_data.get("raw_audio_lengths", [])
+
+ # Fetch defaults from model config when available (falls back to 3200)
+ hf_config = self.info.get_hf_config()
+ if isinstance(hf_config, dict):
+ compress_ratio = int(hf_config.get("speech_tok_compress_ratio", 3200))
+ else:
+ compress_ratio = int(getattr(hf_config, "speech_tok_compress_ratio", 3200))
+
+ def _to_int_len(x) -> int:
+ if x is None:
+ return 0
+ if isinstance(x, torch.Tensor):
+ # Accept 0-dim or 1-dim scalar-like tensors
+ if x.numel() == 1:
+ return int(x.item())
+ # If a full tensor is passed accidentally, fall back to its length
+ return int(x.shape[0])
+ return int(x)
+
+ def get_replacement(item_idx: int):
+ if raw_audio_lengths and item_idx < len(raw_audio_lengths):
+ audio_len = _to_int_len(raw_audio_lengths[item_idx])
+ num_features = max(1, int(np.ceil(audio_len / compress_ratio)))
+ else:
+ # Fallback: estimate for 30 second audio at 24kHz
+ num_features = int(np.ceil(30 * 24000 / compress_ratio))
+
+ if num_features == 0:
+ raise ValueError(
+ f"Audio at index {item_idx} is too short to be represented"
+ )
+
+ # Build replacement token sequence:
+ # <|speech_start|> + N * <|speech_pad|> + <|speech_end|> + \n
+ # The newline is important for correct prompt structure.
+ newline_id = 198 # '\n' token
+ if speech_start_id is not None and speech_pad_id is not None and speech_end_id is not None:
+ embed_id = int(speech_pad_id)
+ replacement_ids = [int(speech_start_id)] + [embed_id] * num_features + [int(speech_end_id), newline_id]
+ # Fallback: add audio BOS/EOS boundaries around repeated <|AUDIO|>.
+ elif audio_bos_id is not None and audio_eos_id is not None:
+ embed_id = int(audio_token_id)
+ replacement_ids = [int(audio_bos_id)] + [embed_id] * num_features + [int(audio_eos_id)]
+ else:
+ embed_id = int(audio_token_id)
+ replacement_ids = [embed_id] * num_features
+
+ return PromptUpdateDetails.select_token_id(
+ replacement_ids,
+ embed_token_id=int(embed_id),
+ )
+
+ return [
+ PromptReplacement(
+ modality="audio",
+ # Keep string placeholder matching for maximum vLLM compatibility.
+ target=audio_token,
+ replacement=get_replacement,
+ )
+ ]
+
+
+# ============================================================================
+# Main Model Class
+# ============================================================================
+
+@MULTIMODAL_REGISTRY.register_processor(
+ VibeVoiceMultiModalProcessor,
+ info=VibeVoiceProcessingInfo,
+ dummy_inputs=VibeVoiceDummyInputsBuilder,
+)
+class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+ """
+ VibeVoice ASR model with native vLLM multimodal integration.
+
+ This model combines VibeVoice acoustic/semantic tokenizers for audio encoding
+ with a causal language model for text generation.
+ """
+
+ # SupportsTranscription interface
+ supports_transcription: ClassVar[Literal[True]] = True
+ supports_transcription_only: ClassVar[bool] = False
+ supports_segment_timestamp: ClassVar[bool] = False
+
+ # Supported languages (Chinese as primary target)
+ supported_languages: ClassVar[Mapping[str, str]] = {
+ "zh": "Chinese",
+ "en": "English",
+ }
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ """Return the placeholder string format for a given modality.
+
+ Returns "<|AUDIO|>" which vLLM inserts into the conversation prompt.
+ This single placeholder is later expanded by `_get_prompt_updates` into:
+ [speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
+ where N = ceil(audio_samples / compress_ratio).
+ """
+ if modality.startswith("audio"):
+ return "<|AUDIO|>"
+ raise ValueError(f"Unsupported modality: {modality}")
+
+ @classmethod
+ def get_generation_prompt(
+ cls,
+ audio: np.ndarray,
+ stt_config: SpeechToTextConfig,
+ model_config: ModelConfig,
+ language: str | None,
+ task_type: Literal["transcribe", "translate"],
+ request_prompt: str,
+ to_language: str | None,
+ ) -> PromptType:
+ """Get the prompt for the ASR model.
+
+ Generates a chat-formatted prompt for speech-to-text transcription
+ with JSON output format.
+ """
+ # If user provides custom prompt, use it
+ if request_prompt:
+ return request_prompt
+
+ # Calculate audio duration for the prompt
+ # Audio should be at 24kHz, so duration = len(audio) / 24000
+ duration = len(audio) / 24000 if audio is not None else 10.0
+
+ system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format."
+ show_keys = ["Start time", "End time", "Speaker ID", "Content"]
+ user_suffix = (
+ f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+ + ", ".join(show_keys)
+ )
+
+ # IMPORTANT: keep <|AUDIO|> as the only placeholder token here.
+ # `_get_prompt_updates` expands it into repeated `<|AUDIO|>` placeholders.
+ user_content = "<|AUDIO|>\n" + user_suffix
+
+ prompt = (
+ f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+ f"<|im_start|>user\n{user_content}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+ return prompt
+
+ @classmethod
+ def get_speech_to_text_config(
+ cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
+ ) -> SpeechToTextConfig:
+ """Get the speech to text config for the ASR model."""
+ return SpeechToTextConfig(
+ language=None, # Auto-detect or use request language
+ task_type=task_type,
+ )
+
+ @classmethod
+ def get_num_audio_tokens(
+ cls,
+ audio_duration_s: float,
+ stt_config: SpeechToTextConfig,
+ model_config: ModelConfig,
+ ) -> int | None:
+ """Estimate number of audio tokens from duration.
+
+ Returns the number of audio EMBEDDING positions (speech_pad_id tokens).
+ Note: _get_prompt_updates actually generates:
+ [speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
+ So total prompt tokens = N + 3, but this returns N (the embedding count).
+ """
+ sampling_rate = 24000
+ compress_ratio = 3200
+ samples = int(audio_duration_s * sampling_rate)
+ num_tokens = int(np.ceil(samples / compress_ratio))
+ return num_tokens
+
+ @classmethod
+ def get_other_languages(cls) -> Mapping[str, str]:
+ """Get languages from Whisper map not natively supported."""
+ # Import LANGUAGES from vllm
+ try:
+ from vllm.transformers_utils.tokenizer import LANGUAGES
+ except ImportError:
+ # Fallback to empty dict if import fails
+ return {}
+ return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages}
+
+ @classmethod
+ def validate_language(cls, language: str | None) -> str | None:
+ """Validate the language code."""
+ if language is None or language in cls.supported_languages:
+ return language
+ elif language in cls.get_other_languages():
+ print(f"Warning: Language {language!r} is not natively supported")
+ return language
+ else:
+ raise ValueError(
+ f"Unsupported language: {language!r}. "
+ f"Supported: {list(cls.supported_languages.keys())}"
+ )
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ self.config = config
+
+ # Keep a copy of the resolved model path for any custom weight-loading
+ # logic (e.g., loading audio encoder weights in fp32 directly from
+ # safetensors shards).
+ self._model_path = vllm_config.model_config.model
+
+ self.audio_encoder = VibeVoiceAudioEncoder(config)
+
+ # Pass decoder_config to the language model initialization
+ decoder_config = getattr(config, "decoder_config", config)
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=decoder_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ architectures=["Qwen2ForCausalLM"],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors
+ )
+
+ # Set the language model dtype for audio encoder output conversion
+ # This should match vLLM's --dtype flag (e.g., bfloat16, float16, float32)
+ # Audio encoder internal computation stays in fp32 for numerical precision,
+ # but output is converted to LM dtype for compatibility
+ lm_dtype = vllm_config.model_config.dtype
+ if lm_dtype is not None:
+ self.audio_encoder._lm_dtype = lm_dtype
+
+ # Ensure audio encoder uses correct dtype (typically fp32 for precision)
+ try:
+ self.audio_encoder._ensure_audio_encoder_dtype()
+ except Exception:
+ pass
+
+ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
+ return self.language_model.compute_logits(hidden_states)
+
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ """
+ Extract audio embeddings using VibeVoice's acoustic/semantic tokenizers.
+
+ Called by vLLM to get audio embeddings that replace audio placeholder tokens.
+
+ Returns:
+ Tuple of embedding tensors, one per audio input.
+ """
+ # Get raw audio data (stored by our processor)
+ raw_audio = kwargs.get("raw_audio")
+ raw_audio_lengths = kwargs.get("raw_audio_lengths")
+
+ # Handle no audio input - this happens during memory profiling
+ if raw_audio is None:
+ return []
+
+ # Handle empty audio list
+ if isinstance(raw_audio, (list, tuple)) and len(raw_audio) == 0:
+ return []
+
+ # Flatten raw_audio_lengths if it's nested
+ def flatten_lengths(lengths):
+ """Flatten nested lists/tensors of lengths to a single list."""
+ if lengths is None:
+ return []
+
+ result = []
+ if isinstance(lengths, torch.Tensor):
+ lengths = lengths.tolist()
+
+ if isinstance(lengths, (list, tuple)):
+ for item in lengths:
+ if isinstance(item, (list, tuple)):
+ result.extend(flatten_lengths(item))
+ elif isinstance(item, torch.Tensor):
+ if item.dim() == 0:
+ result.append(item.item())
+ else:
+ result.extend(item.tolist())
+ else:
+ result.append(item)
+ else:
+ result.append(lengths)
+ return result
+
+ raw_audio_lengths = flatten_lengths(raw_audio_lengths)
+
+ # Streaming controls. Enabled by default; can be overridden per-call.
+ use_streaming_flag = bool(
+ kwargs.get(
+ "use_streaming",
+ getattr(self.audio_encoder, "enable_streaming", True),
+ )
+ )
+ streaming_segment_duration = kwargs.get(
+ "streaming_segment_duration",
+ getattr(self.audio_encoder, "streaming_segment_duration", 60.0),
+ )
+
+ # Process each audio through the VibeVoice encoder
+ embeddings = []
+
+ # Get model device and dtype for alignment
+ try:
+ device = next(self.audio_encoder.parameters()).device
+ dtype = next(self.audio_encoder.parameters()).dtype
+ except StopIteration:
+ # Fallback if no parameters (shouldn't happen)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dtype = torch.bfloat16
+
+ # Handle both stacked tensor and list of tensors
+ # vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
+ if isinstance(raw_audio, torch.Tensor):
+ if raw_audio.dim() == 3:
+ # Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
+ num_audios = raw_audio.shape[0]
+ audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
+ elif raw_audio.dim() == 2:
+ # Shape: [batch_size, seq_len]
+ num_audios = raw_audio.shape[0]
+ audio_list = [raw_audio[i] for i in range(num_audios)]
+ else:
+ # Single 1D tensor
+ audio_list = [raw_audio]
+ elif isinstance(raw_audio, (list, tuple)):
+ audio_list = list(raw_audio)
+ else:
+ # Single tensor
+ audio_list = [raw_audio]
+
+ for i, audio_tensor in enumerate(audio_list):
+ try:
+ if isinstance(audio_tensor, list):
+ audio_tensor = torch.stack(audio_tensor)
+
+ # Ensure tensor
+ if not isinstance(audio_tensor, torch.Tensor):
+ audio_tensor = torch.tensor(audio_tensor)
+
+ # Let vLLM handle dtype (bfloat16 by default)
+ audio_tensor = audio_tensor.to(device=device)
+
+ # Get actual length if available, otherwise use full length
+ if raw_audio_lengths and i < len(raw_audio_lengths):
+ actual_len = int(raw_audio_lengths[i])
+ if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
+ # Truncate from the last dimension (sequence length)
+ audio_tensor = audio_tensor[..., :actual_len]
+
+ # Skip if audio is too short (< 1 frame)
+ if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz
+ continue
+
+ # Encode audio through VibeVoice encoder
+ audio_embeds = self.audio_encoder(
+ audio_tensor,
+ use_streaming=use_streaming_flag,
+ segment_duration_s=streaming_segment_duration,
+ )
+
+ # audio_embeds shape: [1, seq_len, hidden_size]
+ # We need to return it as a single embedding tensor per audio
+ final_embed = audio_embeds.squeeze(0)
+ embeddings.append(final_embed)
+
+ except Exception as e:
+ # Log error but don't crash - this helps debug profiling issues
+ print(f"[VibeVoice] Error encoding audio {i}: {e}")
+ import traceback
+ traceback.print_exc()
+ # Return empty embedding to avoid crash
+ continue
+
+ return tuple(embeddings)
+
+ def get_input_embeddings(self) -> torch.nn.Module:
+ """Return the text embedding layer (embed_tokens).
+
+ vLLM uses this to get the embedding module for converting token IDs
+ to embeddings during decode phase.
+
+ Returns:
+ The embed_tokens module from the language model
+ """
+ # Get embed_tokens from the language model
+ if hasattr(self.language_model, 'model') and hasattr(self.language_model.model, 'embed_tokens'):
+ return self.language_model.model.embed_tokens
+ elif hasattr(self.language_model, 'embed_tokens'):
+ return self.language_model.embed_tokens
+ else:
+ # Try to get from inner model
+ inner = self.language_model
+ if hasattr(inner, 'language_model'):
+ inner = inner.language_model
+ if hasattr(inner, 'model') and hasattr(inner.model, 'embed_tokens'):
+ return inner.model.embed_tokens
+
+ raise AttributeError("Cannot find embed_tokens layer")
+
+ def embed_input_ids(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+ is_multimodal: Optional[torch.Tensor] = None,
+ **kwargs, # Accept any additional kwargs for compatibility
+ ) -> torch.Tensor:
+ """Apply token embeddings to input_ids and merge with multimodal embeddings.
+
+ This is the preferred method in vLLM V1 for converting token IDs
+ to embeddings and merging multimodal (audio) embeddings.
+
+ Args:
+ input_ids: Tensor of token IDs to embed
+ multimodal_embeddings: Pre-computed multimodal embeddings (audio).
+ Can be a Tensor or a List of Tensors (vLLM standard).
+ is_multimodal: Boolean mask indicating which positions are multimodal
+ **kwargs: Additional arguments for compatibility
+
+ Returns:
+ Tensor of embeddings with multimodal content merged in
+ """
+ from vllm.model_executor.models.utils import _merge_multimodal_embeddings
+
+ # Get text embeddings
+ embed_tokens = self.get_input_embeddings()
+ inputs_embeds = embed_tokens(input_ids)
+
+ # Merge multimodal embeddings if provided
+ if multimodal_embeddings is not None and is_multimodal is not None:
+ # Use vLLM's standard merge function which handles List[Tensor] correctly
+ inputs_embeds = _merge_multimodal_embeddings(
+ inputs_embeds,
+ multimodal_embeddings,
+ is_multimodal,
+ )
+
+ return inputs_embeds
+
+ def get_language_model(self) -> torch.nn.Module:
+ """Return the language model backbone."""
+ return self.language_model
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
+ """Load model weights from checkpoint.
+
+ The checkpoint has weights named like:
+ - lm_head.weight -> language_model.lm_head.weight
+ - model.language_model.layers.X... -> language_model.model.layers.X...
+ - model.acoustic_tokenizer... -> audio_encoder.acoustic_tokenizer...
+ - model.semantic_tokenizer... -> audio_encoder.semantic_tokenizer...
+ - model.acoustic_connector... -> audio_encoder.acoustic_connector...
+ - model.semantic_connector... -> audio_encoder.semantic_connector...
+
+ Let vLLM handle all dtype conversions according to --dtype flag.
+ """
+ # Map weight prefixes for VibeVoice
+ # The checkpoint uses "model." prefix, we need to remap it
+ mapper = WeightsMapper(
+ orig_to_new_prefix={
+ # Audio encoder components: model.X -> audio_encoder.X
+ "model.acoustic_tokenizer.": "audio_encoder.acoustic_tokenizer.",
+ "model.semantic_tokenizer.": "audio_encoder.semantic_tokenizer.",
+ "model.acoustic_connector.": "audio_encoder.acoustic_connector.",
+ "model.semantic_connector.": "audio_encoder.semantic_connector.",
+ # Language model: model.language_model.X -> language_model.model.X
+ "model.language_model.": "language_model.model.",
+ # LM head: lm_head.X -> language_model.lm_head.X
+ "lm_head.": "language_model.lm_head.",
+ }
+ )
+
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=mapper)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor],
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ """
+ Forward pass for VibeVoice ASR model.
+
+ Handles embedding computation and language model forward pass.
+ Uses inputs_embeds if provided (from vLLM multimodal merge),
+ otherwise computes embeddings from input_ids.
+
+ Args:
+ input_ids: Token IDs. May be None when inputs_embeds is provided.
+ positions: Position indices for the input tokens.
+ intermediate_tensors: Intermediate tensors for pipeline parallelism.
+ inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
+ """
+ try:
+ # PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
+ # Only compute from input_ids if inputs_embeds is not available
+ if inputs_embeds is None and input_ids is not None:
+ # Compute embeddings from input_ids
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ # Get the inner model - handle both wrapped and direct language models
+ language_model = self.language_model
+ if hasattr(language_model, "language_model"):
+ language_model = language_model.language_model
+
+ # Call the language model's model (Qwen2Model)
+ # vLLM V1 passes kv_caches and attn_metadata via context, not arguments
+ # IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
+ hidden_states = language_model.model(
+ input_ids=None, # Always None when we have inputs_embeds
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds
+ )
+ return hidden_states
+
+ except Exception as e:
+ raise
+
+
+# Alias for training checkpoint compatibility.
+# This binds a second module-level name to the very same class object; no
+# subclass is created. Presumably it lets checkpoints whose config names the
+# `VibeVoiceForASRTraining` architecture resolve to this implementation —
+# TODO confirm against the plugin's model registration.
+VibeVoiceForASRTraining = VibeVoiceForCausalLM
diff --git a/vllm_plugin/scripts/install_deps.sh b/vllm_plugin/scripts/install_deps.sh
new file mode 100644
index 0000000..1e62f45
--- /dev/null
+++ b/vllm_plugin/scripts/install_deps.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+# Install system dependencies for VibeVoice vLLM plugin
+# Run this script inside the vLLM container before using the plugin
+
+set -e
+
+echo "Installing system dependencies for VibeVoice vLLM plugin..."
+
+# Update package list
+apt-get update
+
+# Install FFmpeg and audio processing libraries
+apt-get install -y \
+ ffmpeg \
+ libsndfile1 \
+ git
+
+echo "ā
System dependencies installed successfully!"
+echo ""
+echo "Next steps:"
+echo " 1. Install VibeVoice: pip install -e .[vllm]"
+echo " 2. Generate tokenizer files (if needed): python -m vllm_plugin.tools.generate_tokenizer_files -o /path/to/model"
+echo " 3. Start vLLM server: vllm serve --trust-remote-code --enforce-eager --no-enable-prefix-caching"
diff --git a/vllm_plugin/tests/test_api.py b/vllm_plugin/tests/test_api.py
new file mode 100644
index 0000000..af128bc
--- /dev/null
+++ b/vllm_plugin/tests/test_api.py
@@ -0,0 +1,254 @@
+#!/usr/bin/env python3
+"""
+Test VibeVoice vLLM API with Streaming (Real-time output).
+
+Usage:
+ python test_api.py [audio_path] [--url URL]
+
+Examples:
+ python test_api.py # Use default audio
+ python test_api.py /path/to/audio.wav # Specify audio file
+ python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL
+"""
+import requests
+import json
+import base64
+import time
+import sys
+import os
+import subprocess
+import argparse
+
+
+def _guess_mime_type(path: str) -> str:
+    """Map a file path to a MIME type using its case-insensitive extension.
+
+    Unrecognised extensions yield ``application/octet-stream``.
+    """
+    mime_by_suffix = {
+        ".wav": "audio/wav",
+        ".mp3": "audio/mpeg",
+        ".m4a": "audio/mp4",
+        # Every recognised video container is reported as video/mp4.
+        ".mp4": "video/mp4",
+        ".m4v": "video/mp4",
+        ".mov": "video/mp4",
+        ".webm": "video/mp4",
+        ".flac": "audio/flac",
+        ".ogg": "audio/ogg",
+        ".opus": "audio/ogg",
+    }
+    _, suffix = os.path.splitext(path)
+    return mime_by_suffix.get(suffix.lower(), "application/octet-stream")
+
+
+def _get_duration_seconds_ffprobe(path: str) -> float:
+    """Get audio duration using ffprobe.
+
+    Args:
+        path: Path of the media file to inspect.
+
+    Returns:
+        The container duration in seconds, as printed by ffprobe.
+
+    Raises:
+        FileNotFoundError: If the ``ffprobe`` executable cannot be found.
+        subprocess.CalledProcessError: If ffprobe exits with a non-zero status.
+        ValueError: If ffprobe's output is not a number.
+    """
+    cmd = [
+        "ffprobe",
+        "-v",
+        "error",
+        "-show_entries",
+        "format=duration",
+        "-of",
+        "default=noprint_wrappers=1:nokey=1",
+        path,
+    ]
+    # Capture stderr separately: diagnostics printed by ffprobe must not end up
+    # in the text that is parsed as the duration.
+    completed = subprocess.run(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        check=True,
+    )
+    return float(completed.stdout.decode("utf-8").strip())
+
+
+def _extract_audio_from_video(video_path: str) -> str:
+    """
+    Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
+
+    Args:
+        video_path: Path of the video file to read.
+
+    Returns:
+        The path to the extracted audio file. The caller owns the file and is
+        responsible for deleting it.
+
+    Raises:
+        FileNotFoundError: If the ``ffmpeg`` executable cannot be found.
+        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
+    """
+    import tempfile
+
+    # Create temp file with .mp3 extension; only the path is needed because
+    # ffmpeg writes the file itself ("-y" overwrites the empty placeholder).
+    fd, audio_path = tempfile.mkstemp(suffix=".mp3")
+    os.close(fd)
+
+    cmd = [
+        "ffmpeg", "-y", "-i", video_path,
+        "-vn",  # No video
+        "-acodec", "libmp3lame",
+        "-q:a", "2",  # High quality
+        audio_path,
+    ]
+    try:
+        subprocess.run(cmd, check=True, capture_output=True)
+    except BaseException:
+        # The caller never learns the path on failure, so remove the
+        # placeholder here instead of leaking it.
+        try:
+            os.remove(audio_path)
+        except OSError:
+            pass
+        raise
+    return audio_path
+
+
+def _is_video_file(path: str) -> bool:
+    """Tell whether ``path`` has a video extension, i.e. needs audio extraction."""
+    video_suffixes = {".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv"}
+    _, suffix = os.path.splitext(path)
+    return suffix.lower() in video_suffixes
+
+
+def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"):
+    """Test ASR transcription with streaming output.
+
+    Reads the audio (extracting it from a video file first when needed), sends
+    it as a base64 data URL to ``{base_url}/v1/chat/completions`` with
+    ``stream=True`` and prints the transcription as it arrives.
+
+    Args:
+        audio_path: Path to an audio or video file.
+        base_url: Base URL of the vLLM server.
+
+    Returns:
+        None. Progress, results and errors are printed to stdout; request and
+        preparation errors are reported rather than raised.
+    """
+    print(f"Loading audio from: {audio_path}")
+
+    temp_audio_path = None
+    try:
+        try:
+            # Handle video files: extract audio first
+            actual_audio_path = audio_path
+            if _is_video_file(audio_path):
+                print("Detected video file, extracting audio...")
+                temp_audio_path = _extract_audio_from_video(audio_path)
+                actual_audio_path = temp_audio_path
+                print(f"Audio extracted to: {temp_audio_path}")
+
+            duration = _get_duration_seconds_ffprobe(actual_audio_path)
+            print(f"Audio duration: {duration:.2f} seconds")
+
+            with open(actual_audio_path, "rb") as f:
+                audio_bytes = f.read()
+
+            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
+            print(f"Audio size: {len(audio_bytes)} bytes")
+        except Exception as e:
+            # The outer ``finally`` still removes any extracted temp file.
+            print(f"Error preparing audio: {e}")
+            return
+
+        # Build the request
+        url = f"{base_url}/v1/chat/completions"
+
+        show_keys = ["Start time", "End time", "Speaker ID", "Content"]
+        prompt_text = (
+            f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
+            + ", ".join(show_keys)
+        )
+
+        mime = _guess_mime_type(actual_audio_path)
+        data_url = f"data:{mime};base64,{audio_b64}"
+
+        payload = {
+            "model": "vibevoice",
+            "messages": [
+                {
+                    "role": "system",
+                    "content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "audio_url", "audio_url": {"url": data_url}},
+                        {"type": "text", "text": prompt_text}
+                    ]
+                }
+            ],
+            "max_tokens": 32768,
+            "temperature": 0.0,
+            "stream": True,
+            "top_p": 1.0,
+            "repetition_penalty": 1.0,
+        }
+
+        print(f"\nSending request to {url} (Streaming Mode)...")
+        print(f"Prompt: {prompt_text}")
+        print("-" * 60)
+
+        t0 = time.time()
+        try:
+            # The context manager releases the connection even if streaming
+            # stops early.
+            with requests.post(url, json=payload, stream=True, timeout=12000) as response:
+                if response.status_code != 200:
+                    print(f"Error: {response.status_code}")
+                    print(response.text)
+                else:
+                    print("Response received. Streaming content:\n")
+
+                    printed = ""
+                    for line in response.iter_lines():
+                        if not line:
+                            continue
+                        decoded_line = line.decode("utf-8")
+                        if not decoded_line.startswith("data: "):
+                            continue
+
+                        json_str = decoded_line[6:]
+                        if json_str.strip() == "[DONE]":
+                            print("\n\n[Finished]")
+                            break
+                        try:
+                            data = json.loads(json_str)
+                        except json.JSONDecodeError:
+                            # Ignore lines that are not valid JSON.
+                            continue
+
+                        # Skip chunks that carry no choice instead of aborting
+                        # the whole stream on them.
+                        choices = data.get("choices") or []
+                        if not choices:
+                            continue
+                        content = (choices[0].get("delta") or {}).get("content")
+                        if not content:
+                            continue
+
+                        # vLLM/OpenAI-compatible streams may emit either
+                        # incremental deltas OR the full accumulated text.
+                        # Only print the newly-added part to avoid repeats.
+                        if content.startswith(printed):
+                            to_print = content[len(printed):]
+                        else:
+                            to_print = content
+
+                        if to_print:
+                            print(to_print, end="", flush=True)
+                            printed += to_print
+
+        except requests.exceptions.Timeout:
+            print("\nRequest timed out!")
+        except Exception as e:
+            print(f"\nError: {e}")
+
+        print(f"\n{'-'*60}")
+        print(f"Total time elapsed: {time.time() - t0:.2f}s")
+
+    finally:
+        # Cleanup temp audio file if created, on every exit path.
+        if temp_audio_path and os.path.exists(temp_audio_path):
+            os.remove(temp_audio_path)
+            print(f"Cleaned up temp file: {temp_audio_path}")
+
+
+def main():
+    """Command-line entry point: pick an audio file and run the streaming test.
+
+    When no ``audio_path`` argument is given, the first existing sample from a
+    fixed list of demo voices is used. Exits with status 1 if no audio file can
+    be found or the given path does not exist.
+    """
+    parser = argparse.ArgumentParser(
+        description="Test VibeVoice vLLM API with streaming output"
+    )
+    parser.add_argument(
+        "audio_path",
+        nargs="?",
+        default=None,
+        help="Path to audio file (wav, mp3, flac, etc.) or video file"
+    )
+    parser.add_argument(
+        "--url",
+        default="http://localhost:8000",
+        help="vLLM server base URL (default: http://localhost:8000)"
+    )
+
+    args = parser.parse_args()
+
+    # Find default audio if not specified
+    audio_path = args.audio_path
+    if audio_path is None:
+        # Try to find a sample audio in common locations
+        possible_paths = [
+            # In VibeVoice demo folder
+            os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "en-Carter_man.wav"),
+            os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"),
+            # Relative to current directory
+            "demo/voices/en-Carter_man.wav",
+            "demo/voices/zh-Anchen_man_bgm.wav",
+        ]
+        audio_path = next((p for p in possible_paths if os.path.exists(p)), None)
+
+        if audio_path is None:
+            print("Error: No audio file specified and no default audio found.")
+            print("Usage: python test_api.py <audio_path>")
+            sys.exit(1)
+
+    if not os.path.exists(audio_path):
+        print(f"Error: Audio file not found: {audio_path}")
+        sys.exit(1)
+
+    test_transcription(audio_path, args.url)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vllm_plugin/tools/generate_tokenizer_files.py b/vllm_plugin/tools/generate_tokenizer_files.py
new file mode 100644
index 0000000..62de55f
--- /dev/null
+++ b/vllm_plugin/tools/generate_tokenizer_files.py
@@ -0,0 +1,575 @@
+#!/usr/bin/env python3
+"""
+Standalone tool to generate VibeVoice tokenizer files from a Qwen2.5 base.
+
+Downloads the base tokenizer from Qwen2.5 (default: Qwen/Qwen2.5-7B) and patches
+it with VibeVoice-specific audio tokens and chat template modifications.
+
+Usage:
+ python generate_tokenizer_files.py --output /path/to/output [--compare /path/to/reference]
+"""
+
+import argparse
+import json
+import os
+import shutil
+import tempfile
+from typing import Optional, Dict, Any
+
+
+# Qwen2.5 extended tokens (151646-151664)
+# These are NOT in base Qwen2-7B but ARE in Qwen2.5 and Qwen2-VL
+# VibeVoice uses some of these for speech: object_ref_start/end, box_start
+# Every key must be unique: a repeated key would silently drop an ID from the
+# mapping.
+QWEN25_EXTENDED_TOKENS = {
+    "<|object_ref_start|>": 151646,  # Used as speech_start_id
+    "<|object_ref_end|>": 151647,    # Used as speech_end_id
+    "<|box_start|>": 151648,         # Used as speech_pad_id
+    "<|box_end|>": 151649,
+    "<|quad_start|>": 151650,
+    "<|quad_end|>": 151651,
+    "<|vision_start|>": 151652,
+    "<|vision_end|>": 151653,
+    "<|vision_pad|>": 151654,
+    "<|image_pad|>": 151655,
+    "<|video_pad|>": 151656,
+    "<tool_call>": 151657,
+    "</tool_call>": 151658,
+    "<|fim_prefix|>": 151659,
+    "<|fim_middle|>": 151660,
+    "<|fim_suffix|>": 151661,
+    "<|fim_pad|>": 151662,
+    "<|repo_name|>": 151663,
+    "<|file_sep|>": 151664,
+}
+
+# VibeVoice-specific audio tokens (IDs follow Qwen2.5's last token 151664)
+VIBEVOICE_AUDIO_TOKENS = {
+    "<|AUDIO|>": 151665,
+    "<|audio_bos|>": 151666,
+    "<|audio_eos|>": 151667,
+}
+
+# All extended tokens (Qwen2.5 + VibeVoice)
+ALL_EXTENDED_TOKENS = {**QWEN25_EXTENDED_TOKENS, **VIBEVOICE_AUDIO_TOKENS}
+
+# Chat template with audio support
+# Key modification: handles part['type'] == 'audio' or 'audio_url' -> '<|AUDIO|>'
+# The tool-calling text uses the <tools>, <tool_call> and <tool_response> XML
+# tags. Every Jinja tag starts with "-" so the indentation below is stripped at
+# render time and never reaches the prompt.
+VIBEVOICE_CHAT_TEMPLATE = """{%- if tools %}
+    {{- '<|im_start|>system\\n' }}
+    {%- if messages[0]['role'] == 'system' %}
+        {%- if messages[0]['content'] is string %}
+            {{- messages[0]['content'] }}
+        {%- else %}
+            {%- for part in messages[0]['content'] %}
+                {%- if part['type'] == 'text' %}
+                    {{- part['text'] }}
+                {%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %}
+                    {{- '<|AUDIO|>' }}
+                {%- endif %}
+            {%- endfor %}
+        {%- endif %}
+    {%- else %}
+        {{- 'You are a helpful assistant.' }}
+    {%- endif %}
+    {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}
+{%- else %}
+    {%- if messages[0]['role'] == 'system' %}
+        {{- '<|im_start|>system\\n' }}
+        {%- if messages[0]['content'] is string %}
+            {{- messages[0]['content'] }}
+        {%- else %}
+            {%- for part in messages[0]['content'] %}
+                {%- if part['type'] == 'text' %}
+                    {{- part['text'] }}
+                {%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %}
+                    {{- '<|AUDIO|>' }}
+                {%- endif %}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\\n' }}
+    {%- else %}
+        {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}
+    {%- endif %}
+{%- endif %}
+{%- for message in messages %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
+        {{- '<|im_start|>' + message.role + '\\n' }}
+        {%- if message['content'] is string %}
+            {{- message['content'] }}
+        {%- else %}
+            {%- for part in message['content'] %}
+                {%- if part['type'] == 'text' %}
+                    {{- part['text'] }}
+                {%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %}
+                    {{- '<|AUDIO|>' }}
+                {%- endif %}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\\n' }}
+    {%- elif message.role == "assistant" %}
+        {{- '<|im_start|>' + message.role }}
+        {%- if message.content %}
+            {{- '\\n' + message.content }}
+        {%- endif %}
+        {%- for tool_call in message.tool_calls %}
+            {%- if tool_call.function is defined %}
+                {%- set tool_call = tool_call.function %}
+            {%- endif %}
+            {{- '\\n<tool_call>\\n{"name": "' }}
+            {{- tool_call.name }}
+            {{- '", "arguments": ' }}
+            {{- tool_call.arguments | tojson }}
+            {{- '}\\n</tool_call>' }}
+        {%- endfor %}
+        {{- '<|im_end|>\\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\\n<tool_response>\\n' }}
+        {{- message.content }}
+        {{- '\\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\\n' }}
+{%- endif %}"""
+
+
+# Default to Qwen2.5-7B which has all the extended tokens (151646-151664)
+# Hugging Face repo id; used as the default for the ``qwen_model`` parameters
+# and the ``--qwen-model`` command-line option in this module.
+DEFAULT_QWEN_MODEL = "Qwen/Qwen2.5-7B"
+
+
+def download_qwen_tokenizer_files(output_dir: str, qwen_model: str = DEFAULT_QWEN_MODEL) -> None:
+    """Fetch the base tokenizer files of ``qwen_model`` into ``output_dir``.
+
+    The directory is created when missing. Four files are requested through
+    ``huggingface_hub.hf_hub_download``: ``vocab.json``, ``merges.txt``,
+    ``tokenizer.json`` and ``tokenizer_config.json``.
+
+    Raises:
+        ImportError: If ``huggingface_hub`` is not installed.
+    """
+    try:
+        from huggingface_hub import hf_hub_download
+    except ImportError:
+        raise ImportError("Please install huggingface_hub: pip install huggingface_hub")
+
+    base_files = ("vocab.json", "merges.txt", "tokenizer.json", "tokenizer_config.json")
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    for filename in base_files:
+        print(f"Downloading {filename} from {qwen_model}...")
+        download_kwargs = {
+            "repo_id": qwen_model,
+            "filename": filename,
+            "local_dir": output_dir,
+            "local_dir_use_symlinks": False,
+        }
+        hf_hub_download(**download_kwargs)
+
+
+def patch_tokenizer_config(output_dir: str) -> None:
+    """
+    Patch tokenizer_config.json with VibeVoice audio tokens and chat template.
+
+    The file ``<output_dir>/tokenizer_config.json`` is rewritten in place:
+
+    1. every token of ``ALL_EXTENDED_TOKENS`` missing from
+       ``added_tokens_decoder`` is added;
+    2. the VibeVoice audio tokens are appended to ``additional_special_tokens``;
+    3. the chat template gains ``<|AUDIO|>`` handling (patched in place when the
+       expected pattern is found, otherwise replaced by
+       ``VIBEVOICE_CHAT_TEMPLATE``; the predefined template is also installed
+       when the config has no template at all);
+    4. ``model_max_length`` is set to 131072;
+    5. ``add_bos_token`` defaults to ``False`` when absent.
+
+    Args:
+        output_dir: Directory containing ``tokenizer_config.json``.
+    """
+    import re
+
+    config_path = os.path.join(output_dir, "tokenizer_config.json")
+
+    with open(config_path, "r", encoding="utf-8") as f:
+        config = json.load(f)
+
+    # Tokens that must NOT be flagged "special" (tool_call and code tokens).
+    non_special_tokens = frozenset((
+        "<tool_call>", "</tool_call>", "<|fim_prefix|>",
+        "<|fim_middle|>", "<|fim_suffix|>", "<|fim_pad|>",
+        "<|repo_name|>", "<|file_sep|>",
+    ))
+
+    # 1. Add ALL extended tokens to added_tokens_decoder (Qwen2.5 + VibeVoice audio)
+    added_tokens_decoder = config.setdefault("added_tokens_decoder", {})
+    for token, token_id in ALL_EXTENDED_TOKENS.items():
+        key = str(token_id)
+        if key in added_tokens_decoder:
+            continue
+        added_tokens_decoder[key] = {
+            "content": token,
+            "lstrip": False,
+            "normalized": False,
+            "rstrip": False,
+            "single_word": False,
+            "special": token not in non_special_tokens,
+        }
+
+    # 2. Add audio tokens to additional_special_tokens
+    additional_special_tokens = config.setdefault("additional_special_tokens", [])
+    for token in VIBEVOICE_AUDIO_TOKENS:
+        if token not in additional_special_tokens:
+            additional_special_tokens.append(token)
+
+    # 3. Modify chat_template to support audio
+    chat_template = config.get("chat_template", "")
+    if not chat_template:
+        # Nothing to patch: install the predefined audio-aware template.
+        print("  - No chat_template found, using predefined template")
+        config["chat_template"] = VIBEVOICE_CHAT_TEMPLATE
+    elif "<|AUDIO|>" not in chat_template:
+        # Instead of replacing entirely, patch the existing template: right
+        # after the branch that renders a 'text' part, add a branch that
+        # renders audio parts as '<|AUDIO|>'.
+        audio_handler = (
+            "{%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %}\n"
+            "    {{- '<|AUDIO|>' }}"
+        )
+        pattern = r"(\{\%- if part\['type'\] == 'text' \%\}\s*\n\s*\{\{- part\['text'\] \}\})"
+
+        # A callable replacement inserts the handler verbatim, without any
+        # backslash/group-reference processing of its text.
+        modified_template = re.sub(
+            pattern,
+            lambda match: match.group(1) + "\n" + audio_handler,
+            chat_template,
+        )
+
+        if modified_template != chat_template:
+            config["chat_template"] = modified_template
+            print("  - Added audio support to existing chat_template")
+        else:
+            # Fallback: use our predefined template
+            print("  - Warning: Could not patch existing template, using predefined template")
+            config["chat_template"] = VIBEVOICE_CHAT_TEMPLATE
+
+    # 4. Update model_max_length for long audio support
+    config["model_max_length"] = 131072
+
+    # 5. Add add_bos_token if not present
+    config.setdefault("add_bos_token", False)
+
+    # Write back
+    with open(config_path, "w", encoding="utf-8") as f:
+        json.dump(config, f, indent=2, ensure_ascii=False)
+
+    print(f"Patched {config_path}")
+
+
+def patch_tokenizer_json(output_dir: str) -> None:
+    """
+    Patch tokenizer.json with VibeVoice audio tokens.
+
+    Every token of ``ALL_EXTENDED_TOKENS`` whose ID is not yet listed in the
+    ``added_tokens`` array of ``<output_dir>/tokenizer.json`` is appended, and
+    the file is rewritten in place. A missing ``added_tokens`` array is created.
+
+    Args:
+        output_dir: Directory containing ``tokenizer.json``.
+    """
+    tokenizer_path = os.path.join(output_dir, "tokenizer.json")
+
+    with open(tokenizer_path, "r", encoding="utf-8") as f:
+        tokenizer = json.load(f)
+
+    # Tokens that must NOT be flagged "special" (tool_call and code tokens).
+    non_special_tokens = frozenset((
+        "<tool_call>", "</tool_call>", "<|fim_prefix|>",
+        "<|fim_middle|>", "<|fim_suffix|>", "<|fim_pad|>",
+        "<|repo_name|>", "<|file_sep|>",
+    ))
+
+    # Create the array when absent so the appends below cannot raise KeyError.
+    added_tokens = tokenizer.setdefault("added_tokens", [])
+
+    # Find existing token IDs to avoid duplicates
+    existing_ids = {entry.get("id") for entry in added_tokens}
+
+    # Add ALL extended tokens (Qwen2.5 + VibeVoice audio)
+    for token, token_id in ALL_EXTENDED_TOKENS.items():
+        if token_id in existing_ids:
+            continue
+        added_tokens.append({
+            "id": token_id,
+            "content": token,
+            "single_word": False,
+            "lstrip": False,
+            "rstrip": False,
+            "normalized": False,
+            "special": token not in non_special_tokens,
+        })
+
+    # Write back
+    with open(tokenizer_path, "w", encoding="utf-8") as f:
+        json.dump(tokenizer, f, indent=2, ensure_ascii=False)
+
+    print(f"Patched {tokenizer_path}")
+
+
+def generate_added_tokens_json(output_dir: str) -> None:
+    """
+    Write added_tokens.json, derived from tokenizer_config.json.
+
+    Reads ``added_tokens_decoder`` from ``<output_dir>/tokenizer_config.json``
+    and stores the inverse mapping (token content -> integer ID) in
+    ``<output_dir>/added_tokens.json``. Entries without content are skipped.
+    """
+    source_path = os.path.join(output_dir, "tokenizer_config.json")
+    with open(source_path, "r", encoding="utf-8") as src:
+        decoder = json.load(src).get("added_tokens_decoder", {})
+
+    content_to_id = {
+        info.get("content"): int(raw_id)
+        for raw_id, info in decoder.items()
+        if info.get("content")
+    }
+
+    output_path = os.path.join(output_dir, "added_tokens.json")
+    with open(output_path, "w", encoding="utf-8") as dst:
+        json.dump(content_to_id, dst, indent=2, ensure_ascii=False)
+
+    print(f"Generated {output_path}")
+
+
+def generate_special_tokens_map_json(output_dir: str) -> None:
+    """
+    Write special_tokens_map.json containing the VibeVoice special tokens.
+
+    ``additional_special_tokens`` lists the VibeVoice audio tokens followed by
+    three extra tokens (object_ref start/end and box_start); eos, pad and unk
+    are all set to ``<|endoftext|>``.
+    """
+    def _entry(token: str) -> Dict[str, Any]:
+        # One token descriptor, with every flag switched off.
+        return {
+            "content": token,
+            "lstrip": False,
+            "normalized": False,
+            "rstrip": False,
+            "single_word": False,
+        }
+
+    # Audio tokens come first, then some commonly used special tokens.
+    listed_tokens = list(VIBEVOICE_AUDIO_TOKENS.keys())
+    listed_tokens += ["<|object_ref_start|>", "<|object_ref_end|>", "<|box_start|>"]
+
+    special_tokens_map = {
+        "additional_special_tokens": [_entry(token) for token in listed_tokens],
+        "eos_token": "<|endoftext|>",
+        "pad_token": "<|endoftext|>",
+        "unk_token": "<|endoftext|>",
+    }
+
+    output_path = os.path.join(output_dir, "special_tokens_map.json")
+    with open(output_path, "w", encoding="utf-8") as out:
+        json.dump(special_tokens_map, out, indent=2, ensure_ascii=False)
+
+    print(f"Generated {output_path}")
+
+
+def generate_vibevoice_tokenizer_files(output_dir: str, qwen_model: str = DEFAULT_QWEN_MODEL) -> None:
+    """
+    Generate all 6 VibeVoice tokenizer files.
+
+    Files generated:
+    1. vocab.json - from Qwen2.5 (unchanged)
+    2. merges.txt - from Qwen2.5 (unchanged)
+    3. tokenizer.json - from Qwen2.5 + audio tokens
+    4. tokenizer_config.json - from Qwen2.5 + audio tokens + chat_template
+    5. added_tokens.json - generated from tokenizer_config.json
+    6. special_tokens_map.json - generated with VibeVoice tokens
+
+    Args:
+        output_dir: Directory that receives the files.
+        qwen_model: Model id the base tokenizer files are downloaded from.
+    """
+    print(f"=== Generating VibeVoice tokenizer files to {output_dir} ===\n")
+
+    # Step 1: Download base files from the Qwen model
+    download_qwen_tokenizer_files(output_dir, qwen_model)
+
+    # Step 2: Patch tokenizer_config.json
+    patch_tokenizer_config(output_dir)
+
+    # Step 3: Patch tokenizer.json
+    patch_tokenizer_json(output_dir)
+
+    # Step 4: Generate added_tokens.json (reads the patched tokenizer_config.json,
+    # so it must run after step 2)
+    generate_added_tokens_json(output_dir)
+
+    # Step 5: Generate special_tokens_map.json
+    generate_special_tokens_map_json(output_dir)
+
+    print(f"\n✅ All 6 tokenizer files generated in {output_dir}")
+
+
+def compare_json_files(file1: str, file2: str, name: str) -> Dict[str, Any]:
+    """Compare two JSON files and return differences.
+
+    Args:
+        file1: Path of the generated file (the "generated" side in messages).
+        file2: Path of the file to compare against.
+        name: Label stored under ``"name"`` in the result.
+
+    Returns:
+        A dict with ``"name"``, ``"identical"`` (bool) and ``"differences"``
+        (list of human-readable strings, in a deterministic order).
+    """
+    result = {
+        "name": name,
+        "identical": False,
+        "differences": [],
+    }
+
+    if not os.path.exists(file1):
+        result["differences"].append(f"File 1 not found: {file1}")
+        return result
+
+    if not os.path.exists(file2):
+        result["differences"].append(f"File 2 not found: {file2}")
+        return result
+
+    with open(file1, "r", encoding="utf-8") as f:
+        data1 = json.load(f)
+
+    with open(file2, "r", encoding="utf-8") as f:
+        data2 = json.load(f)
+
+    if data1 == data2:
+        result["identical"] = True
+        return result
+
+    # Find specific differences
+    def find_diff(d1, d2, path=""):
+        diffs = []
+        if isinstance(d1, dict) and isinstance(d2, dict):
+            # Sorted so the report is stable across runs; callers may show
+            # only the first few entries.
+            for k in sorted(set(d1) | set(d2)):
+                new_path = f"{path}.{k}" if path else k
+                if k not in d1:
+                    diffs.append(f"Missing in generated: {new_path}")
+                elif k not in d2:
+                    diffs.append(f"Extra in generated: {new_path}")
+                else:
+                    diffs.extend(find_diff(d1[k], d2[k], new_path))
+        elif isinstance(d1, list) and isinstance(d2, list):
+            if len(d1) != len(d2):
+                diffs.append(f"{path}: list length differs ({len(d1)} vs {len(d2)})")
+            # For lists, just check if they're equal (detailed diff is complex)
+            if d1 != d2:
+                diffs.append(f"{path}: list content differs")
+        elif d1 != d2:
+            # Truncate long values for readability
+            v1 = str(d1)[:100] + "..." if len(str(d1)) > 100 else str(d1)
+            v2 = str(d2)[:100] + "..." if len(str(d2)) > 100 else str(d2)
+            diffs.append(f"{path}: '{v1}' vs '{v2}'")
+        return diffs
+
+    result["differences"] = find_diff(data1, data2)
+    return result
+
+
+def compare_text_files(file1: str, file2: str, name: str) -> Dict[str, Any]:
+    """Compare two text files.
+
+    Args:
+        file1: Path of the first file.
+        file2: Path of the second file.
+        name: Label stored under ``"name"`` in the result.
+
+    Returns:
+        A dict with ``"name"``, ``"identical"`` (bool) and ``"differences"``.
+        For differing files the list holds the line counts followed by the
+        location of the first difference.
+    """
+    result = {
+        "name": name,
+        "identical": False,
+        "differences": [],
+    }
+
+    if not os.path.exists(file1):
+        result["differences"].append(f"File 1 not found: {file1}")
+        return result
+
+    if not os.path.exists(file2):
+        result["differences"].append(f"File 2 not found: {file2}")
+        return result
+
+    with open(file1, "r", encoding="utf-8") as f:
+        content1 = f.read()
+
+    with open(file2, "r", encoding="utf-8") as f:
+        content2 = f.read()
+
+    if content1 == content2:
+        result["identical"] = True
+        return result
+
+    lines1 = content1.splitlines()
+    lines2 = content2.splitlines()
+    result["differences"].append(f"Line count: {len(lines1)} vs {len(lines2)}")
+
+    # Find first difference
+    for line_no, (l1, l2) in enumerate(zip(lines1, lines2), start=1):
+        if l1 != l2:
+            result["differences"].append(f"First diff at line {line_no}")
+            break
+    else:
+        # All shared lines match: either one file has extra lines, or the
+        # files differ only in how lines are terminated.
+        if len(lines1) != len(lines2):
+            first_extra = min(len(lines1), len(lines2)) + 1
+            result["differences"].append(f"First diff at line {first_extra}")
+        else:
+            result["differences"].append(
+                "Files differ only in line endings or a trailing newline"
+            )
+
+    return result
+
+
+def compare_with_reference(generated_dir: str, reference_dir: str) -> None:
+    """Compare generated files with reference files.
+
+    Prints one status line per tokenizer file, with up to five differences for
+    each file that does not match, followed by an overall summary.
+
+    Args:
+        generated_dir: Directory holding the generated tokenizer files.
+        reference_dir: Directory holding the reference tokenizer files.
+    """
+    print(f"\n=== Comparing generated files with reference ===")
+    print(f"Generated: {generated_dir}")
+    print(f"Reference: {reference_dir}\n")
+
+    files_to_compare = [
+        ("vocab.json", "json"),
+        ("merges.txt", "text"),
+        ("tokenizer.json", "json"),
+        ("tokenizer_config.json", "json"),
+        ("added_tokens.json", "json"),
+        ("special_tokens_map.json", "json"),
+    ]
+
+    max_shown = 5  # Differences listed per file before summarising the rest
+    all_identical = True
+
+    for filename, file_type in files_to_compare:
+        gen_file = os.path.join(generated_dir, filename)
+        ref_file = os.path.join(reference_dir, filename)
+
+        if file_type == "json":
+            result = compare_json_files(gen_file, ref_file, filename)
+        else:
+            result = compare_text_files(gen_file, ref_file, filename)
+
+        if result["identical"]:
+            print(f"✅ {filename}: IDENTICAL")
+            continue
+
+        all_identical = False
+        print(f"❌ {filename}: DIFFERENT")
+        differences = result["differences"]
+        for diff in differences[:max_shown]:
+            print(f"   - {diff}")
+        if len(differences) > max_shown:
+            print(f"   ... and {len(differences) - max_shown} more differences")
+
+    print()
+    if all_identical:
+        print("🎉 All files are identical!")
+    else:
+        print("⚠️  Some files have differences. See details above.")
+
+
+def main():
+    """Command-line entry point: generate the tokenizer files, optionally compare.
+
+    Without ``--output`` the files go to a fresh temporary directory. That
+    directory is deleted afterwards unless ``--compare`` is given, in which
+    case it is kept and its location is printed.
+    """
+    parser = argparse.ArgumentParser(
+        description="Generate VibeVoice tokenizer files from Qwen2 base"
+    )
+    parser.add_argument(
+        "--output", "-o",
+        type=str,
+        default=None,
+        help="Output directory for generated files (default: temp directory)"
+    )
+    parser.add_argument(
+        "--compare", "-c",
+        type=str,
+        default=None,
+        help="Reference directory to compare generated files against"
+    )
+    parser.add_argument(
+        "--qwen-model",
+        type=str,
+        default=DEFAULT_QWEN_MODEL,
+        help=f"Qwen model to download base tokenizer from (default: {DEFAULT_QWEN_MODEL})"
+    )
+
+    args = parser.parse_args()
+
+    # Determine output directory
+    if args.output:
+        output_dir = args.output
+        cleanup = False
+    else:
+        output_dir = tempfile.mkdtemp(prefix="vibevoice_tokenizer_")
+        cleanup = not args.compare  # Only cleanup if not comparing
+
+    try:
+        # Generate files
+        generate_vibevoice_tokenizer_files(output_dir, args.qwen_model)
+
+        # Compare if requested
+        if args.compare:
+            compare_with_reference(output_dir, args.compare)
+
+        # Only advertise the temporary directory when it will still exist
+        # after this function returns.
+        if not args.output and not cleanup:
+            print(f"\nGenerated files are in: {output_dir}")
+
+    finally:
+        # ``cleanup`` is only ever True for a temporary directory created above.
+        if cleanup:
+            print(f"\nCleaning up temporary directory: {output_dir}")
+            shutil.rmtree(output_dir, ignore_errors=True)
+
+
+if __name__ == "__main__":
+    main()