Add vLLM plugin support for high-performance ASR serving
This commit is contained in:
@@ -0,0 +1,112 @@
|
||||
# VibeVoice vLLM ASR Deployment
|
||||
|
||||
<a href="https://huggingface.co/microsoft/VibeVoice-ASR"><img alt="Huggingface" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-VibeVoice--ASR-blue"></a>
|
||||
|
||||
Deploy VibeVoice ASR model as a high-performance API service using [vLLM](https://github.com/vllm-project/vllm). This plugin provides OpenAI-compatible API endpoints for speech-to-text transcription with streaming support.
|
||||
|
||||
## 🔥 Key Features
|
||||
|
||||
- **🚀 High-Performance Serving**: Optimized for high-throughput ASR inference with vLLM's continuous batching
|
||||
- **📡 OpenAI-Compatible API**: Standard `/v1/chat/completions` endpoint with streaming support
|
||||
- **🎵 Long Audio Support**: Process 60+ minutes of audio in a single request
|
||||
- **🔌 Plugin Architecture**: No vLLM source code modification required - just install and run
|
||||
|
||||
## 🛠️ Installation
|
||||
|
||||
Using Official vLLM Docker Image (Recommended)
|
||||
|
||||
```bash
|
||||
# 1. Pull the official vLLM image
|
||||
docker pull vllm/vllm-openai:latest
|
||||
|
||||
# 2. Start an interactive container
|
||||
docker run -it --gpus all --name vibevoice-vllm \
|
||||
--ipc=host \
|
||||
-p 8000:8000 \
|
||||
-e VIBEVOICE_FFMPEG_MAX_CONCURRENCY=64 \
|
||||
-e PYTORCH_ALLOC_CONF=expandable_segments:True \
|
||||
-v /path/to/models:/models \
|
||||
-v /path/to/VibeVoice:/app \
|
||||
-w /app \
|
||||
--entrypoint bash \
|
||||
vllm/vllm-openai:latest
|
||||
|
||||
# 3. Inside container: Install system dependencies
|
||||
bash vllm_plugin/scripts/install_deps.sh
|
||||
|
||||
# 4. Inside container: Install VibeVoice with vLLM support
|
||||
pip install -e .[vllm]
|
||||
|
||||
# 5. Inside container: (Optional) Generate tokenizer files if needed
|
||||
python3 -m vllm_plugin.tools.generate_tokenizer_files --output /models/your_model
|
||||
|
||||
# 6. Inside container: Start vLLM server
|
||||
vllm serve /models/your_model \
|
||||
--served-model-name vibevoice \
|
||||
--trust-remote-code \
|
||||
--dtype bfloat16 \
|
||||
--max-num-seqs 64 \
|
||||
--max-model-len 65536 \
|
||||
--max-num-batched-tokens 32768 \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--enforce-eager \
|
||||
--no-enable-prefix-caching \
|
||||
--enable-chunked-prefill \
|
||||
--chat-template-content-format openai \
|
||||
--tensor-parallel-size 1 \
|
||||
--allowed-local-media-path /app \
|
||||
--port 8000
|
||||
```
|
||||
|
||||
> **Note**: This approach allows you to switch models, adjust parameters, and debug issues without rebuilding the container.
|
||||
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Test the API
|
||||
|
||||
Once the vLLM server is running, test it with the provided script:
|
||||
|
||||
```bash
|
||||
# Run the test script (inside container)
|
||||
python3 vllm_plugin/tests/test_api.py /path/to/audio.wav
|
||||
```
|
||||
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Description | Default |
|
||||
|----------|-------------|---------|
|
||||
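| `VIBEVOICE_FFMPEG_MAX_CONCURRENCY` | Maximum number of concurrent FFmpeg processes for audio decoding (unset, `0`, or negative = no limit) | unset (no limit); `64` is used in the Docker command above |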
|
||||
| `PYTORCH_ALLOC_CONF` | PyTorch memory allocator config (named `PYTORCH_CUDA_ALLOC_CONF` in older PyTorch releases) | `expandable_segments:True` (recommended) |
|
||||
|
||||
|
||||
|
||||
## 📊 Performance Tips
|
||||
|
||||
1. **GPU Memory**: Use `--gpu-memory-utilization 0.9` for maximum throughput if you have a dedicated GPU
|
||||
2. **Batch Size**: Increase `--max-num-seqs` for higher concurrency (requires more GPU memory)
|
||||
3. **FFmpeg Concurrency**: Tune `VIBEVOICE_FFMPEG_MAX_CONCURRENCY` based on CPU cores
|
||||
|
||||
## 🚨 Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **"CUDA out of memory"**
|
||||
- Reduce `--gpu-memory-utilization`
|
||||
- Reduce `--max-num-seqs`
|
||||
- Use smaller `--max-model-len`
|
||||
|
||||
2. **"Audio decoding failed"**
|
||||
- Ensure FFmpeg is installed: `ffmpeg -version`
|
||||
- Check audio file format is supported
|
||||
|
||||
3. **"Model not found"**
|
||||
- Ensure model path contains `config.json` and model weights
|
||||
- Generate tokenizer files if missing
|
||||
|
||||
4. **"Plugin not loaded"**
|
||||
- Verify installation: `pip show vibevoice`
|
||||
- Check entry point: `pip show -f vibevoice | grep entry`
|
||||
|
||||
|
||||
@@ -45,9 +45,20 @@ asr = [
|
||||
"pydub" # for visualization
|
||||
]
|
||||
|
||||
vllm = [
|
||||
"transformers>=4.51.3",
|
||||
"fastapi",
|
||||
"uvicorn[standard]",
|
||||
"requests",
|
||||
]
|
||||
|
||||
[project.entry-points."vllm.general_plugins"]
|
||||
vibevoice = "vllm_plugin:register_vibevoice"
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://github.com/microsoft/VibeVoice"
|
||||
"Bug Tracker" = "https://github.com/microsoft/VibeVoice/issues"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["vibevoice*", "vllm_plugin*"]
|
||||
|
||||
@@ -240,6 +240,26 @@ class VibeVoiceConfig(PretrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_text_config(self, decoder=False):
    """Return the sub-config that describes the text (language-model) side.

    vLLM calls this on multimodal configs to read hidden_size,
    num_attention_heads and related fields for memory profiling and model
    execution. For VibeVoice that information lives in ``decoder_config``
    (a Qwen2Config).

    Args:
        decoder: Accepted for signature compatibility with encoder-decoder
            models. It is not consulted here; the decoder config is
            returned either way.

    Returns:
        ``self.decoder_config``.
    """
    text_config = self.decoder_config
    return text_config
|
||||
|
||||
|
||||
class VibeVoiceASRConfig(PretrainedConfig):
|
||||
model_type = "vibevoice"
|
||||
is_composition = True
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
from subprocess import run
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
@@ -57,6 +60,7 @@ def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24
|
||||
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-loglevel", "error",
|
||||
"-nostdin",
|
||||
"-threads", "0",
|
||||
"-i", file,
|
||||
@@ -64,14 +68,84 @@ def load_audio_use_ffmpeg(file: str, resample: bool = False, target_sr: int = 24
|
||||
"-ac", "1",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(sr_to_use),
|
||||
"-"
|
||||
"-",
|
||||
]
|
||||
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
out = _run_ffmpeg(cmd).stdout
|
||||
audio_data = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||
|
||||
return audio_data, sr_to_use
|
||||
|
||||
|
||||
def _get_ffmpeg_max_concurrency() -> int:
    """Read the FFmpeg concurrency limit from ``VIBEVOICE_FFMPEG_MAX_CONCURRENCY``.

    Returns:
        The integer value of the environment variable. ``0`` is returned when
        the variable is unset, blank, or not a valid integer. Zero and
        negative values mean "no explicit limit" to the caller.
    """
    raw = os.getenv("VIBEVOICE_FFMPEG_MAX_CONCURRENCY", "")
    if not raw.strip():
        return 0
    try:
        return int(raw)
    except ValueError:
        # A malformed value must not break import; treat it as "no limit".
        return 0


# Evaluated once at import time: a positive limit yields a shared semaphore,
# anything else disables throttling (``_FFMPEG_SEM`` is None).
_FFMPEG_MAX_CONCURRENCY = _get_ffmpeg_max_concurrency()
_FFMPEG_SEM = threading.Semaphore(_FFMPEG_MAX_CONCURRENCY) if _FFMPEG_MAX_CONCURRENCY > 0 else None
|
||||
|
||||
|
||||
def _run_ffmpeg(cmd: list, *, stdin_bytes: Optional[bytes] = None):
    """Run an ffmpeg command, optionally throttled by the global semaphore.

    Under vLLM multi-request concurrency, spawning too many ffmpeg processes
    can saturate CPU/IO and cause request failures or timeouts, so calls are
    serialized through ``_FFMPEG_SEM`` when a limit is configured.

    Args:
        cmd: Full argv list, starting with the ffmpeg executable.
        stdin_bytes: Data fed to the process on stdin, or None for no input.

    Returns:
        The ``CompletedProcess`` from ``subprocess.run`` with stdout and
        stderr captured.

    Raises:
        subprocess.CalledProcessError: If the process exits non-zero
            (``check=True``).
    """
    run_kwargs = {"capture_output": True, "check": True, "input": stdin_bytes}
    if _FFMPEG_SEM is None:
        # No limit configured: run immediately.
        return run(cmd, **run_kwargs)
    with _FFMPEG_SEM:
        return run(cmd, **run_kwargs)
|
||||
|
||||
|
||||
def load_audio_bytes_use_ffmpeg(data: bytes, *, resample: bool = False, target_sr: int = 24000):
    """Decode in-memory audio by piping it through ffmpeg's stdin.

    Avoids the temp-file round trip, which cuts filesystem IO and contention
    when many requests are decoded at once.

    Parameters
    ----------
    data: bytes
        The encoded audio bytes.
    resample: bool
        Must be True. The original sample rate cannot be probed cheaply from
        a stdin stream, so the caller has to request an explicit rate.
    target_sr: int
        Sample rate of the decoded output.

    Returns
    -------
    A tuple containing:
        - A NumPy float32 array holding the mono waveform (16-bit PCM
          samples divided by 32768)
        - The sample rate (``target_sr``)

    Raises
    ------
    ValueError
        If ``resample`` is False.
    """
    if resample is False or not resample:
        # Keep the requirement explicit rather than guessing a sample rate.
        raise ValueError("load_audio_bytes_use_ffmpeg requires resample=True")

    ffmpeg_args = ["ffmpeg"]
    ffmpeg_args += ["-loglevel", "error"]
    ffmpeg_args += ["-threads", "0"]
    ffmpeg_args += ["-i", "pipe:0"]          # read the encoded audio from stdin
    ffmpeg_args += ["-f", "s16le"]           # raw signed 16-bit little-endian
    ffmpeg_args += ["-ac", "1"]              # mono
    ffmpeg_args += ["-acodec", "pcm_s16le"]
    ffmpeg_args += ["-ar", str(target_sr)]
    ffmpeg_args += ["-"]                     # write PCM to stdout

    raw_pcm = _run_ffmpeg(ffmpeg_args, stdin_bytes=data).stdout
    samples = np.frombuffer(raw_pcm, np.int16).flatten()
    waveform = samples.astype(np.float32) / 32768.0
    return waveform, target_sr
|
||||
|
||||
|
||||
class AudioNormalizer:
|
||||
"""
|
||||
Audio normalization class for VibeVoice tokenizer.
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
"""VibeVoice vLLM Plugin - Registers VibeVoice model for vLLM inference.
|
||||
|
||||
This plugin enables VibeVoice ASR models to be loaded and served through vLLM.
|
||||
It registers the model architecture, configuration, tokenizer, and processor
|
||||
with their respective registries.
|
||||
|
||||
The plugin is automatically loaded by vLLM via the 'vllm.general_plugins'
|
||||
entry point defined in pyproject.toml.
|
||||
"""
|
||||
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from transformers import AutoConfig, AutoTokenizer, Qwen2Tokenizer, AutoProcessor, Qwen2AudioProcessor
|
||||
|
||||
from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
|
||||
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
|
||||
|
||||
from .model import VibeVoiceForCausalLM
|
||||
from .inputs import vibevoice_audio_input_mapper
|
||||
|
||||
|
||||
def register_vibevoice():
    """Register VibeVoice with transformers and with vLLM's model registry.

    vLLM invokes this through the ``vllm.general_plugins`` entry point. It
    registers:
    - VibeVoiceConfig with AutoConfig (model type "vibevoice")
    - VibeVoiceASRTextTokenizerFast with AutoTokenizer (for ASR)
    - Qwen2AudioProcessor with AutoProcessor
    - VibeVoiceForCausalLM with the vLLM ModelRegistry, under both the
      "VibeVoice" and "VibeVoiceForASRTraining" architecture names

    Tokenizer and processor registration stay best-effort, but a failure is
    now logged instead of being dropped silently.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Register the configuration class with transformers.
    # exist_ok=True: override rather than raise if the "vibevoice" model type
    # is already known to transformers.
    AutoConfig.register("vibevoice", VibeVoiceConfig, exist_ok=True)

    # Register the tokenizer with transformers.
    # IMPORTANT (ASR): Align with the PyTorch ASR path.
    # VibeVoiceASRTextTokenizerFast maps:
    #   speech_start_id -> <|object_ref_start|>
    #   speech_pad_id   -> <|box_start|>
    #   speech_end_id   -> <|object_ref_end|>
    # This significantly affects ASR quality even when requests succeed, so a
    # failed registration must at least be visible in the logs.
    try:
        AutoTokenizer.register(
            VibeVoiceConfig,
            slow_tokenizer_class=Qwen2Tokenizer,
            fast_tokenizer_class=VibeVoiceASRTextTokenizerFast,
        )
    except Exception:
        # May already be registered; keep going but leave a trace.
        logger.warning(
            "VibeVoice tokenizer registration skipped (it may already be registered)",
            exc_info=True,
        )

    # Register the processor with transformers.
    try:
        AutoProcessor.register(VibeVoiceConfig, processor_class=Qwen2AudioProcessor)
    except Exception:
        # May already be registered; keep going but leave a trace.
        logger.warning(
            "VibeVoice processor registration skipped (it may already be registered)",
            exc_info=True,
        )

    # Register the model class under the architecture names that may appear
    # in the "architectures" list of config.json.
    for architecture in ("VibeVoice", "VibeVoiceForASRTraining"):
        ModelRegistry.register_model(architecture, VibeVoiceForCausalLM)
|
||||
|
||||
|
||||
# Note: This function is called via vllm.general_plugins entry point
|
||||
# defined in pyproject.toml, ensuring it runs in all vLLM processes
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Audio input mapper for vLLM multimodal pipeline.
|
||||
|
||||
This module handles audio data loading and preprocessing for VibeVoice ASR inference.
|
||||
It converts various audio input formats (path, bytes, numpy array) into tensors
|
||||
that can be processed by the VibeVoice model.
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Union, List
|
||||
from vllm.multimodal.inputs import MultiModalInputs
|
||||
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
|
||||
|
||||
|
||||
def load_audio(audio_path: str, target_sr: int = 24000) -> np.ndarray:
    """Decode an audio file with FFmpeg and normalize it.

    Args:
        audio_path: Path to the audio file (any format FFmpeg can read).
        target_sr: Sample rate to resample to (24 kHz by default for VibeVoice).

    Returns:
        The normalized waveform as a numpy array.
    """
    # FFmpeg does the decoding and the resampling in one pass; the returned
    # sample rate is not needed here.
    waveform, _ = load_audio_use_ffmpeg(audio_path, resample=True, target_sr=target_sr)
    normalize = AudioNormalizer()
    return normalize(waveform)
|
||||
|
||||
|
||||
def vibevoice_audio_input_mapper(ctx, data: Union[str, bytes, np.ndarray, List[str]]) -> MultiModalInputs:
    """Map audio input data to vLLM MultiModalInputs format.

    Args:
        ctx: vLLM context (unused).
        data: Audio data in one of these formats:
            - str: Path to an audio file
            - bytes: Encoded audio bytes (any format FFmpeg supports)
            - np.ndarray: Pre-loaded audio waveform, used as-is (no
              resampling or normalization is applied)
            - list: Only the first element is used

    Returns:
        MultiModalInputs containing:
            - audio: Audio tensor (float32)
            - audio_length: Size of the tensor's first dimension

    Raises:
        ValueError: If ``data`` is an empty list or an unsupported type.
    """
    # Handle list input (take first item).
    if isinstance(data, list):
        if not data:
            # Previously this surfaced as a bare IndexError.
            raise ValueError("Unsupported audio data: empty list")
        data = data[0]

    if isinstance(data, str):
        # File path: load_audio decodes and normalizes.
        audio_waveform = load_audio(data)
    elif isinstance(data, bytes):
        # Decode bytes directly via ffmpeg stdin pipe to avoid temp-file IO,
        # then normalize the same way as the file-path branch.
        decoded, _sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
        audio_waveform = AudioNormalizer()(decoded)
    elif isinstance(data, np.ndarray):
        # Already loaded numpy array.
        audio_waveform = data
    else:
        raise ValueError(f"Unsupported audio data type: {type(data)}")

    # Convert to tensor.
    audio_tensor = torch.from_numpy(audio_waveform).float()
    audio_length = audio_tensor.shape[0]

    return MultiModalInputs({
        "audio": audio_tensor,
        "audio_length": audio_length,
    })
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
# Install system dependencies for VibeVoice vLLM plugin
# Run this script inside the vLLM container before using the plugin

# Abort on errors, on unset variables, and on failures inside pipelines.
set -euo pipefail

echo "Installing system dependencies for VibeVoice vLLM plugin..."

# Keep apt from stopping for interactive configuration prompts.
export DEBIAN_FRONTEND=noninteractive

# Update package list
apt-get update

# Install FFmpeg and audio processing libraries
apt-get install -y \
    ffmpeg \
    libsndfile1 \
    git

echo "✅ System dependencies installed successfully!"
echo ""
echo "Next steps:"
echo "  1. Install VibeVoice: pip install -e .[vllm]"
echo "  2. Generate tokenizer files (if needed): python3 -m vllm_plugin.tools.generate_tokenizer_files --output /path/to/model"
echo "  3. Start vLLM server: vllm serve <model_path> --trust-remote-code --enforce-eager --no-enable-prefix-caching"
|
||||
@@ -0,0 +1,254 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test VibeVoice vLLM API with Streaming (Real-time output).
|
||||
|
||||
Usage:
|
||||
python test_api.py [audio_path] [--url URL]
|
||||
|
||||
Examples:
|
||||
python test_api.py # Use default audio
|
||||
python test_api.py /path/to/audio.wav # Specify audio file
|
||||
python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
import base64
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
|
||||
|
||||
def _guess_mime_type(path: str) -> str:
    """Return the MIME type used in the data URL, chosen by file extension.

    The extension is compared case-insensitively. Unknown extensions map to
    "application/octet-stream".
    """
    mime_by_extension = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".m4a": "audio/mp4",
        ".mp4": "video/mp4",
        ".m4v": "video/mp4",
        ".mov": "video/mp4",
        ".webm": "video/mp4",
        ".flac": "audio/flac",
        ".ogg": "audio/ogg",
        ".opus": "audio/ogg",
    }
    _, extension = os.path.splitext(path)
    return mime_by_extension.get(extension.lower(), "application/octet-stream")
|
||||
|
||||
|
||||
def _get_duration_seconds_ffprobe(path: str) -> float:
    """Get audio duration in seconds using ffprobe.

    Args:
        path: Path to the media file.

    Returns:
        The container-level duration reported by ffprobe, as a float.

    Raises:
        subprocess.CalledProcessError: If ffprobe exits non-zero.
        ValueError: If ffprobe's output is not a number.
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        path,
    ]
    # Capture stderr separately: merging it into stdout would let any
    # diagnostic line corrupt the value parsed below.
    completed = subprocess.run(cmd, check=True, capture_output=True)
    return float(completed.stdout.decode("utf-8").strip())
|
||||
|
||||
|
||||
def _extract_audio_from_video(video_path: str) -> str:
    """Extract the audio track of a video file into a temporary mp3 file.

    Args:
        video_path: Path to the video file (mp4/mov/webm, ...).

    Returns:
        Path to the extracted mp3 file. The caller is responsible for
        deleting it.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits non-zero. The
            temporary file is removed before the error propagates.
    """
    import tempfile

    # Reserve a temp path with an .mp3 extension; ffmpeg overwrites it (-y).
    fd, audio_path = tempfile.mkstemp(suffix=".mp3")
    os.close(fd)

    cmd = [
        "ffmpeg", "-y", "-i", video_path,
        "-vn",  # No video
        "-acodec", "libmp3lame",
        "-q:a", "2",  # High quality
        audio_path,
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except BaseException:
        # Do not leave the placeholder behind when extraction fails.
        try:
            os.remove(audio_path)
        except OSError:
            pass
        raise
    return audio_path
|
||||
|
||||
|
||||
def _is_video_file(path: str) -> bool:
    """Tell whether ``path`` has a video extension, i.e. needs audio extraction."""
    video_extensions = (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
    _, extension = os.path.splitext(path)
    return extension.lower() in video_extensions
|
||||
|
||||
|
||||
def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"):
    """Send one media file to the chat-completions endpoint and stream the transcript.

    Video inputs are first converted to a temporary mp3. The audio is
    base64-encoded into a data URL and posted to
    ``{base_url}/v1/chat/completions`` with ``stream=True``; the streamed
    text is printed as it arrives. Request errors are printed, not raised.

    Args:
        audio_path: Path to an audio or video file.
        base_url: Base URL of the vLLM server.
    """
    print(f"Loading audio from: {audio_path}")

    temp_audio_path = None
    try:
        # Handle video files: extract audio first
        actual_audio_path = audio_path
        if _is_video_file(audio_path):
            print("Detected video file, extracting audio...")
            temp_audio_path = _extract_audio_from_video(audio_path)
            actual_audio_path = temp_audio_path
            print(f"Audio extracted to: {temp_audio_path}")

        try:
            duration = _get_duration_seconds_ffprobe(actual_audio_path)
            print(f"Audio duration: {duration:.2f} seconds")

            with open(actual_audio_path, "rb") as f:
                audio_bytes = f.read()

            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            print(f"Audio size: {len(audio_bytes)} bytes")
        except Exception as e:
            print(f"Error preparing audio: {e}")
            # The finally clause below still removes any temp file.
            return

        # Build the request
        url = f"{base_url}/v1/chat/completions"

        show_keys = ["Start time", "End time", "Speaker ID", "Content"]
        prompt_text = (
            f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
            + ", ".join(show_keys)
        )

        mime = _guess_mime_type(actual_audio_path)
        data_url = f"data:{mime};base64,{audio_b64}"

        payload = {
            "model": "vibevoice",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "audio_url", "audio_url": {"url": data_url}},
                        {"type": "text", "text": prompt_text}
                    ]
                }
            ],
            "max_tokens": 4096,
            "temperature": 0.0,
            "stream": True,
            "top_p": 1.0,
            "repetition_penalty": 1.0,
        }

        print(f"\nSending request to {url} (Streaming Mode)...")
        print(f"Prompt: {prompt_text}")
        print("-" * 60)

        t0 = time.time()
        try:
            # The context manager closes the streamed connection on every exit path.
            with requests.post(url, json=payload, stream=True, timeout=12000) as response:
                if response.status_code == 200:
                    print("Response received. Streaming content:\n")

                    printed = ""
                    for line in response.iter_lines():
                        if not line:
                            continue
                        decoded_line = line.decode('utf-8')
                        if not decoded_line.startswith("data: "):
                            continue

                        json_str = decoded_line[6:]
                        if json_str.strip() == "[DONE]":
                            print("\n\n[Finished]")
                            break

                        try:
                            data = json.loads(json_str)
                        except json.JSONDecodeError:
                            # Skip chunks that are not valid JSON.
                            continue

                        # A chunk without choices carries no text; skip it
                        # instead of aborting the whole stream.
                        choices = data.get('choices') or []
                        if not choices:
                            continue

                        delta = choices[0]['delta']
                        content = delta.get('content', '')
                        if not content:
                            continue

                        # vLLM/OpenAI-compatible streams may emit either
                        # incremental deltas OR the full accumulated text.
                        # Only print the newly-added part to avoid repeats.
                        if content.startswith(printed):
                            to_print = content[len(printed):]
                        else:
                            to_print = content

                        if to_print:
                            print(to_print, end='', flush=True)
                            printed += to_print
                else:
                    print(f"Error: {response.status_code}")
                    print(response.text)

        except requests.exceptions.Timeout:
            print("\nRequest timed out!")
        except Exception as e:
            print(f"\nError: {e}")

        print(f"\n{'-'*60}")
        print(f"Total time elapsed: {time.time() - t0:.2f}s")
    finally:
        # Cleanup temp audio file if created, including on early return or error.
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
            print(f"Cleaned up temp file: {temp_audio_path}")
|
||||
|
||||
|
||||
def main():
    """Parse the command line, pick an audio file, and run the streaming test.

    When no audio path is given, a few known demo files are tried in order.
    Exits with status 1 if no file can be found.
    """
    parser = argparse.ArgumentParser(
        description="Test VibeVoice vLLM API with streaming output"
    )
    parser.add_argument(
        "audio_path",
        nargs="?",
        default=None,
        help="Path to audio file (wav, mp3, flac, etc.) or video file"
    )
    parser.add_argument(
        "--url",
        default="http://localhost:8000",
        help="vLLM server base URL (default: http://localhost:8000)"
    )
    args = parser.parse_args()

    chosen_path = args.audio_path
    if chosen_path is None:
        # Fall back to a sample audio from common locations.
        script_dir = os.path.dirname(__file__)
        candidates = [
            # In VibeVoice demo folder
            os.path.join(script_dir, "..", "..", "demo", "voices", "en-Carter_man.wav"),
            os.path.join(script_dir, "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"),
            # Relative to current directory
            "demo/voices/en-Carter_man.wav",
            "demo/voices/zh-Anchen_man_bgm.wav",
        ]
        chosen_path = next(
            (candidate for candidate in candidates if os.path.exists(candidate)),
            None,
        )
        if chosen_path is None:
            print("Error: No audio file specified and no default audio found.")
            print("Usage: python test_api.py <audio_path>")
            sys.exit(1)

    if not os.path.exists(chosen_path):
        print(f"Error: Audio file not found: {chosen_path}")
        sys.exit(1)

    test_transcription(chosen_path, args.url)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,575 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Standalone tool to generate VibeVoice tokenizer files from Qwen2 base.
|
||||
|
||||
Downloads base tokenizer from Qwen2 and patches it with VibeVoice-specific
|
||||
audio tokens and chat template modifications.
|
||||
|
||||
Usage:
|
||||
python generate_tokenizer_files.py --output /path/to/output [--compare /path/to/reference]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
# Qwen2.5 extended tokens (151646-151664)
|
||||
# These are NOT in base Qwen2-7B but ARE in Qwen2.5 and Qwen2-VL
|
||||
# VibeVoice uses some of these for speech: object_ref_start/end, box_start
|
||||
# Token strings are listed in ID order; IDs run consecutively from 151646
# through 151664.
QWEN25_EXTENDED_TOKENS = dict(
    zip(
        (
            "<|object_ref_start|>",  # 151646, used as speech_start_id
            "<|object_ref_end|>",    # 151647, used as speech_end_id
            "<|box_start|>",         # 151648, used as speech_pad_id
            "<|box_end|>",           # 151649
            "<|quad_start|>",        # 151650
            "<|quad_end|>",          # 151651
            "<|vision_start|>",      # 151652
            "<|vision_end|>",        # 151653
            "<|vision_pad|>",        # 151654
            "<|image_pad|>",         # 151655
            "<|video_pad|>",         # 151656
            "<tool_call>",           # 151657
            "</tool_call>",          # 151658
            "<|fim_prefix|>",        # 151659
            "<|fim_middle|>",        # 151660
            "<|fim_suffix|>",        # 151661
            "<|fim_pad|>",           # 151662
            "<|repo_name|>",         # 151663
            "<|file_sep|>",          # 151664
        ),
        range(151646, 151665),
    )
)

# VibeVoice-specific audio tokens (IDs follow Qwen2.5's last token 151664)
VIBEVOICE_AUDIO_TOKENS = dict(
    zip(
        (
            "<|AUDIO|>",      # 151665
            "<|audio_bos|>",  # 151666
            "<|audio_eos|>",  # 151667
        ),
        range(151665, 151668),
    )
)

# All extended tokens (Qwen2.5 first, then VibeVoice)
ALL_EXTENDED_TOKENS = dict(QWEN25_EXTENDED_TOKENS, **VIBEVOICE_AUDIO_TOKENS)
|
||||
|
||||
# Chat template with audio support
|
||||
# Key modification: handles part['type'] == 'audio' or 'audio_url' -> '<|AUDIO|>'
|
||||
VIBEVOICE_CHAT_TEMPLATE = """{%- if tools %}
|
||||
{{- '<|im_start|>system\\n' }}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{%- if messages[0]['content'] is string %}
|
||||
{{- messages[0]['content'] }}
|
||||
{%- else %}
|
||||
{%- for part in messages[0]['content'] %}
|
||||
{%- if part['type'] == 'text' %}
|
||||
{{- part['text'] }}
|
||||
{%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %}
|
||||
{{- '<|AUDIO|>' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- 'You are a helpful assistant.' }}
|
||||
{%- endif %}
|
||||
{{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{{- '<|im_start|>system\\n' }}
|
||||
{%- if messages[0]['content'] is string %}
|
||||
{{- messages[0]['content'] }}
|
||||
{%- else %}
|
||||
{%- for part in messages[0]['content'] %}
|
||||
{%- if part['type'] == 'text' %}
|
||||
{{- part['text'] }}
|
||||
{%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %}
|
||||
{{- '<|AUDIO|>' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\\n' }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
|
||||
{{- '<|im_start|>' + message.role + '\\n' }}
|
||||
{%- if message['content'] is string %}
|
||||
{{- message['content'] }}
|
||||
{%- else %}
|
||||
{%- for part in message['content'] %}
|
||||
{%- if part['type'] == 'text' %}
|
||||
{{- part['text'] }}
|
||||
{%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %}
|
||||
{{- '<|AUDIO|>' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{{- '<|im_start|>' + message.role }}
|
||||
{%- if message.content %}
|
||||
{{- '\\n' + message.content }}
|
||||
{%- endif %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '\\n<tool_call>\\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{{- '}\\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>\\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\\n<tool_response>\\n' }}
|
||||
{{- message.content }}
|
||||
{{- '\\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\\n' }}
|
||||
{%- endif %}"""
|
||||
|
||||
|
||||
# Default to Qwen2.5-7B which has all the extended tokens (151646-151664)
|
||||
DEFAULT_QWEN_MODEL = "Qwen/Qwen2.5-7B"
|
||||
|
||||
|
||||
def download_qwen_tokenizer_files(output_dir: str, qwen_model: str = DEFAULT_QWEN_MODEL) -> None:
    """Download the base tokenizer files from a Qwen repository on the Hugging Face Hub.

    Args:
        output_dir: Directory to place the files in; created if missing.
        qwen_model: Hub repo id to download from. The default is Qwen2.5,
            which includes the extended tokens.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as exc:
        # Chain the original error so the real import failure stays visible.
        raise ImportError("Please install huggingface_hub: pip install huggingface_hub") from exc

    files_to_download = (
        "vocab.json",
        "merges.txt",
        "tokenizer.json",
        "tokenizer_config.json",
    )

    os.makedirs(output_dir, exist_ok=True)

    for filename in files_to_download:
        print(f"Downloading {filename} from {qwen_model}...")
        hf_hub_download(
            repo_id=qwen_model,
            filename=filename,
            local_dir=output_dir,
            local_dir_use_symlinks=False,
        )
|
||||
|
||||
|
||||
def patch_tokenizer_config(output_dir: str) -> None:
    """Patch tokenizer_config.json with VibeVoice audio tokens and chat template.

    The file is rewritten in place. Steps:
      1. Add every entry of ALL_EXTENDED_TOKENS that is missing from
         ``added_tokens_decoder``.
      2. Append the VibeVoice audio tokens to ``additional_special_tokens``.
      3. Make the chat template audio-aware: patch the existing template if
         it has a recognizable text-part branch, otherwise (or when there is
         no template at all) install VIBEVOICE_CHAT_TEMPLATE.
      4. Set ``model_max_length`` to 131072.
      5. Default ``add_bos_token`` to False when absent.

    Args:
        output_dir: Directory containing tokenizer_config.json.
    """
    import re

    config_path = os.path.join(output_dir, "tokenizer_config.json")

    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # 1. Add ALL extended tokens to added_tokens_decoder (Qwen2.5 + VibeVoice audio)
    # These tokens are NOT flagged "special"; every other extended token is.
    non_special_tokens = ("<tool_call>", "</tool_call>", "<|fim_prefix|>",
                          "<|fim_middle|>", "<|fim_suffix|>", "<|fim_pad|>",
                          "<|repo_name|>", "<|file_sep|>")
    added_tokens_decoder = config.setdefault("added_tokens_decoder", {})
    for token, token_id in ALL_EXTENDED_TOKENS.items():
        id_key = str(token_id)
        if id_key in added_tokens_decoder:
            continue
        added_tokens_decoder[id_key] = {
            "content": token,
            "lstrip": False,
            "normalized": False,
            "rstrip": False,
            "single_word": False,
            "special": token not in non_special_tokens,
        }

    # 2. Add audio tokens to additional_special_tokens
    additional_special_tokens = config.setdefault("additional_special_tokens", [])
    for token in VIBEVOICE_AUDIO_TOKENS:
        if token not in additional_special_tokens:
            additional_special_tokens.append(token)

    # 3. Modify chat_template to support audio
    chat_template = config.get("chat_template", "")
    if not chat_template:
        # Without a base template there is nothing to patch; install the
        # predefined audio-aware template so audio parts still render.
        print(" - No existing chat_template found, using predefined template")
        config["chat_template"] = VIBEVOICE_CHAT_TEMPLATE
    elif "<|AUDIO|>" not in chat_template:
        # Insert an audio branch right after the branch that renders text parts:
        #   {%- if part['type'] == 'text' %} ... {{- part['text'] }}
        audio_handler = """{%- elif part['type'] == 'audio' or part['type'] == 'audio_url' %}
{{- '<|AUDIO|>' }}"""

        pattern = r"(\{\%- if part\['type'\] == 'text' \%\}\s*\n\s*\{\{- part\['text'\] \}\})"
        # Newlines are written as the two-character escape so re.sub expands
        # them in the replacement template.
        replacement = r"\1\n " + audio_handler.replace("\n", r"\n")

        modified_template = re.sub(pattern, replacement, chat_template)

        if modified_template != chat_template:
            config["chat_template"] = modified_template
            print(" - Added audio support to existing chat_template")
        else:
            # Fallback: use our predefined template
            print(" - Warning: Could not patch existing template, using predefined template")
            config["chat_template"] = VIBEVOICE_CHAT_TEMPLATE

    # 4. Update model_max_length for long audio support
    config["model_max_length"] = 131072

    # 5. Add add_bos_token if not present
    config.setdefault("add_bos_token", False)

    # Write back
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    print(f"Patched {config_path}")
|
||||
|
||||
|
||||
def patch_tokenizer_json(output_dir: str) -> None:
    """
    Patch tokenizer.json with the extended tokens (Qwen2.5 + VibeVoice audio).

    Every entry of ``ALL_EXTENDED_TOKENS`` whose ID is not already present in
    the file's ``added_tokens`` list is appended to that list, and the file is
    rewritten in place. If the file has no ``added_tokens`` list, one is
    created.

    Args:
        output_dir: Directory containing the ``tokenizer.json`` to patch.

    Raises:
        FileNotFoundError: If ``tokenizer.json`` is missing from ``output_dir``.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Tokens written with "special": False; every other token gets True.
    non_special_tokens = (
        "<tool_call>", "</tool_call>", "<|fim_prefix|>",
        "<|fim_middle|>", "<|fim_suffix|>", "<|fim_pad|>",
        "<|repo_name|>", "<|file_sep|>",
    )

    tokenizer_path = os.path.join(output_dir, "tokenizer.json")

    with open(tokenizer_path, "r", encoding="utf-8") as f:
        tokenizer = json.load(f)

    # Create the list when the base file has none; previously the append
    # below raised KeyError in that case.
    added_tokens = tokenizer.setdefault("added_tokens", [])

    # IDs already present in the file, used to avoid duplicate entries.
    existing_ids = {token_entry.get("id") for token_entry in added_tokens}

    # Add ALL extended tokens (Qwen2.5 + VibeVoice audio)
    for token, token_id in ALL_EXTENDED_TOKENS.items():
        if token_id in existing_ids:
            continue
        added_tokens.append({
            "id": token_id,
            "content": token,
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": token not in non_special_tokens,
        })

    # Write back
    with open(tokenizer_path, "w", encoding="utf-8") as f:
        json.dump(tokenizer, f, indent=2, ensure_ascii=False)

    print(f"Patched {tokenizer_path}")
|
||||
|
||||
|
||||
def generate_added_tokens_json(output_dir: str) -> None:
    """
    Write added_tokens.json derived from tokenizer_config.json.

    Reads the ``added_tokens_decoder`` section of ``tokenizer_config.json``
    in ``output_dir`` and writes a ``{token content: integer id}`` mapping to
    ``added_tokens.json`` in the same directory. Entries whose ``content`` is
    missing or empty are left out.

    Args:
        output_dir: Directory holding the config and receiving the output.
    """
    source_path = os.path.join(output_dir, "tokenizer_config.json")
    with open(source_path, "r", encoding="utf-8") as src:
        decoder = json.load(src).get("added_tokens_decoder", {})

    # The decoder is keyed by the ID as a string; invert it to content -> int ID.
    token_to_id = {
        info.get("content"): int(raw_id)
        for raw_id, info in decoder.items()
        if info.get("content")
    }

    output_path = os.path.join(output_dir, "added_tokens.json")
    with open(output_path, "w", encoding="utf-8") as dst:
        json.dump(token_to_id, dst, indent=2, ensure_ascii=False)

    print(f"Generated {output_path}")
|
||||
|
||||
|
||||
def generate_special_tokens_map_json(output_dir: str) -> None:
    """
    Write special_tokens_map.json into ``output_dir``.

    The file declares ``<|endoftext|>`` as the eos, pad and unk token and
    lists, under ``additional_special_tokens``, every key of
    ``VIBEVOICE_AUDIO_TOKENS`` followed by three fixed tokens
    (object-ref start/end and box start).

    Args:
        output_dir: Directory that receives ``special_tokens_map.json``.
    """

    def _entry(text):
        # All listed tokens share the same flags; only the content differs.
        return {
            "content": text,
            "lstrip": False,
            "normalized": False,
            "rstrip": False,
            "single_word": False,
        }

    # Audio tokens first, then the commonly used special tokens.
    extra_tokens = [_entry(text) for text in VIBEVOICE_AUDIO_TOKENS.keys()]
    extra_tokens += [
        _entry(text)
        for text in ("<|object_ref_start|>", "<|object_ref_end|>", "<|box_start|>")
    ]

    payload = {
        "additional_special_tokens": extra_tokens,
        "eos_token": "<|endoftext|>",
        "pad_token": "<|endoftext|>",
        "unk_token": "<|endoftext|>",
    }

    output_path = os.path.join(output_dir, "special_tokens_map.json")
    with open(output_path, "w", encoding="utf-8") as dst:
        json.dump(payload, dst, indent=2, ensure_ascii=False)

    print(f"Generated {output_path}")
|
||||
|
||||
|
||||
def generate_vibevoice_tokenizer_files(output_dir: str, qwen_model: str = DEFAULT_QWEN_MODEL) -> None:
    """
    Produce the full set of 6 VibeVoice tokenizer files in ``output_dir``.

    Resulting files:
    1. vocab.json - from Qwen2.5 (unchanged)
    2. merges.txt - from Qwen2.5 (unchanged)
    3. tokenizer.json - from Qwen2.5 + audio tokens
    4. tokenizer_config.json - from Qwen2.5 + audio tokens + chat_template
    5. added_tokens.json - generated from tokenizer_config.json
    6. special_tokens_map.json - generated with VibeVoice tokens

    Args:
        output_dir: Directory the files are written to.
        qwen_model: Model identifier passed to the base-tokenizer download step.
    """
    print(f"=== Generating VibeVoice tokenizer files to {output_dir} ===\n")

    # The base files must be in place before anything is patched or derived.
    download_qwen_tokenizer_files(output_dir, qwen_model)

    # Order matters: added_tokens.json is read back from the patched
    # tokenizer_config.json, so the patch steps run first.
    pipeline = (
        patch_tokenizer_config,
        patch_tokenizer_json,
        generate_added_tokens_json,
        generate_special_tokens_map_json,
    )
    for step in pipeline:
        step(output_dir)

    print(f"\n✅ All 6 tokenizer files generated in {output_dir}")
|
||||
|
||||
|
||||
def compare_json_files(file1: str, file2: str, name: str) -> Dict[str, Any]:
    """Compare two JSON files and return differences.

    Differences are labelled on the assumption that ``file1`` is the generated
    file and ``file2`` the reference: a key only in ``file2`` is reported as
    "Missing in generated", a key only in ``file1`` as "Extra in generated".
    Object keys are visited in sorted order so the report is the same on
    every run.

    Args:
        file1: Path to the first (generated) JSON file.
        file2: Path to the second (reference) JSON file.
        name: Label stored under ``"name"`` in the result.

    Returns:
        A dict with ``"name"``, ``"identical"`` (bool) and ``"differences"``
        (list of human-readable strings). If either file is missing,
        ``"identical"`` is False and the single difference names the missing
        file.

    Raises:
        json.JSONDecodeError: If an existing file is not valid JSON.
    """
    result: Dict[str, Any] = {
        "name": name,
        "identical": False,
        "differences": [],
    }

    for label, path in (("File 1", file1), ("File 2", file2)):
        if not os.path.exists(path):
            result["differences"].append(f"{label} not found: {path}")
            return result

    with open(file1, "r", encoding="utf-8") as f:
        data1 = json.load(f)

    with open(file2, "r", encoding="utf-8") as f:
        data2 = json.load(f)

    if data1 == data2:
        result["identical"] = True
        return result

    def shorten(value):
        # Truncate long values for readability
        text = str(value)
        return text[:100] + "..." if len(text) > 100 else text

    def find_diff(d1, d2, path=""):
        diffs = []
        if isinstance(d1, dict) and isinstance(d2, dict):
            # JSON object keys are strings, so they sort; sorting keeps the
            # report order stable (set iteration order is not).
            for k in sorted(set(d1) | set(d2)):
                new_path = f"{path}.{k}" if path else k
                if k not in d1:
                    diffs.append(f"Missing in generated: {new_path}")
                elif k not in d2:
                    diffs.append(f"Extra in generated: {new_path}")
                else:
                    diffs.extend(find_diff(d1[k], d2[k], new_path))
        elif isinstance(d1, list) and isinstance(d2, list):
            # Lists are not diffed element by element. A length mismatch
            # already implies different content, so report only one of them.
            if len(d1) != len(d2):
                diffs.append(f"{path}: list length differs ({len(d1)} vs {len(d2)})")
            elif d1 != d2:
                diffs.append(f"{path}: list content differs")
        elif d1 != d2:
            diffs.append(f"{path}: '{shorten(d1)}' vs '{shorten(d2)}'")
        return diffs

    result["differences"] = find_diff(data1, data2)
    return result
|
||||
|
||||
|
||||
def compare_text_files(file1: str, file2: str, name: str) -> Dict[str, Any]:
    """Compare two text files.

    Args:
        file1: Path to the first text file.
        file2: Path to the second text file.
        name: Label stored under ``"name"`` in the result.

    Returns:
        A dict with ``"name"``, ``"identical"`` (bool) and ``"differences"``
        (list of strings). For differing files the list holds the two line
        counts followed by the 1-based number of the first differing line.
        When one file is a line-for-line prefix of the other, that number is
        the first line present in only one file. When the split lines are
        equal, a note says the content differs only in line terminators or a
        trailing newline. If either file is missing, the single difference
        names it.
    """
    result: Dict[str, Any] = {
        "name": name,
        "identical": False,
        "differences": [],
    }

    for label, path in (("File 1", file1), ("File 2", file2)):
        if not os.path.exists(path):
            result["differences"].append(f"{label} not found: {path}")
            return result

    with open(file1, "r", encoding="utf-8") as f:
        content1 = f.read()

    with open(file2, "r", encoding="utf-8") as f:
        content2 = f.read()

    if content1 == content2:
        result["identical"] = True
        return result

    lines1 = content1.splitlines()
    lines2 = content2.splitlines()
    diffs = result["differences"]
    diffs.append(f"Line count: {len(lines1)} vs {len(lines2)}")

    # Find first difference
    for line_no, (l1, l2) in enumerate(zip(lines1, lines2), start=1):
        if l1 != l2:
            diffs.append(f"First diff at line {line_no}")
            break
    else:
        # zip stops at the shorter file, so every shared line matched.
        if len(lines1) != len(lines2):
            first_extra = min(len(lines1), len(lines2)) + 1
            diffs.append(f"First diff at line {first_extra}")
        else:
            diffs.append("Content differs only in line terminators or trailing newline")

    return result
|
||||
|
||||
|
||||
def compare_with_reference(generated_dir: str, reference_dir: str) -> None:
    """Print a per-file comparison of generated tokenizer files against a reference.

    Each of the six expected tokenizer files is compared with the file of the
    same name in ``reference_dir``. ``merges.txt`` is compared as text and the
    others as JSON. At most five differences are printed per file, followed
    by an overall summary line.

    Args:
        generated_dir: Directory containing the generated files.
        reference_dir: Directory containing the reference files.
    """
    print("\n=== Comparing generated files with reference ===")
    print(f"Generated: {generated_dir}")
    print(f"Reference: {reference_dir}\n")

    expected_files = (
        ("vocab.json", "json"),
        ("merges.txt", "text"),
        ("tokenizer.json", "json"),
        ("tokenizer_config.json", "json"),
        ("added_tokens.json", "json"),
        ("special_tokens_map.json", "json"),
    )

    found_mismatch = False

    for filename, file_type in expected_files:
        comparer = compare_json_files if file_type == "json" else compare_text_files
        outcome = comparer(
            os.path.join(generated_dir, filename),
            os.path.join(reference_dir, filename),
            filename,
        )

        if outcome["identical"]:
            print(f"✅ {filename}: IDENTICAL")
            continue

        found_mismatch = True
        print(f"❌ {filename}: DIFFERENT")
        diffs = outcome["differences"]
        # Show first 5 differences, then summarise the rest.
        for diff in diffs[:5]:
            print(f" - {diff}")
        hidden = len(diffs) - 5
        if hidden > 0:
            print(f" ... and {hidden} more differences")

    print()
    if found_mismatch:
        print("⚠️ Some files have differences. See details above.")
    else:
        print("🎉 All files are identical!")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: generate tokenizer files and optionally compare them.

    Options:
        --output/-o: Directory for the generated files. When omitted, a fresh
            temporary directory is used.
        --compare/-c: Reference directory to compare the generated files with.
        --qwen-model: Model to download the base tokenizer from.

    A temporary directory is deleted at the end of the run unless
    ``--compare`` was given, in which case it is kept and its path is printed.
    A directory supplied via ``--output`` is never deleted.
    """
    parser = argparse.ArgumentParser(
        description="Generate VibeVoice tokenizer files from Qwen2 base"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default=None,
        help="Output directory for generated files (default: temp directory)"
    )
    parser.add_argument(
        "--compare", "-c",
        type=str,
        default=None,
        help="Reference directory to compare generated files against"
    )
    parser.add_argument(
        "--qwen-model",
        type=str,
        default=DEFAULT_QWEN_MODEL,
        help=f"Qwen model to download base tokenizer from (default: {DEFAULT_QWEN_MODEL})"
    )

    args = parser.parse_args()

    # Determine output directory
    using_temp_dir = not args.output
    if using_temp_dir:
        output_dir = tempfile.mkdtemp(prefix="vibevoice_tokenizer_")
        cleanup = not args.compare  # Only cleanup if not comparing
    else:
        output_dir = args.output
        cleanup = False

    try:
        # Generate files
        generate_vibevoice_tokenizer_files(output_dir, args.qwen_model)

        # Compare if requested
        if args.compare:
            compare_with_reference(output_dir, args.compare)

        # Only point the user at the temp dir when it will still exist after
        # this function returns; otherwise the message would be misleading.
        if using_temp_dir and not cleanup:
            print(f"\nGenerated files are in: {output_dir}")

    finally:
        # cleanup is only ever True for a temp dir created above.
        if cleanup:
            print(f"\nCleaning up temporary directory: {output_dir}")
            shutil.rmtree(output_dir, ignore_errors=True)
|
||||
|
||||
|
||||
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user