Files
VibeVoice/demo/vibevoice_asr_inference_from_file.py
T
2026-01-26 14:00:31 +08:00

580 lines
20 KiB
Python

#!/usr/bin/env python
"""
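VibeVoice ASR Batch Inference Demo Script
This script runs batched ASR inference with the VibeVoice ASR model on audio
files (--audio_files / --audio_dir) or on long audios built by concatenating
samples of a HuggingFace dataset (--dataset), and prints the transcriptions.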
"""
import argparse
import itertools
import json
import os
import re
import sys
import time
from functools import wraps
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

import numpy as np
import torch

from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
class VibeVoiceASRBatchInference:
    """Batch inference wrapper for VibeVoice ASR model.

    Loads a ``VibeVoiceASRProcessor`` and a ``VibeVoiceASRForConditionalGeneration``
    from ``model_path`` and transcribes lists of audio inputs, either as a single
    batch (``transcribe_batch``) or split into fixed-size batches
    (``transcribe_with_batching``).
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        attn_implementation: str = "sdpa"
    ):
        """
        Initialize the ASR batch inference pipeline.

        Args:
            model_path: Path to the pretrained model
            device: Device to run inference on (cuda, mps, xpu, cpu, auto)
            dtype: Data type for model weights
            attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
        """
        print(f"Loading VibeVoice ASR model from {model_path}")
        # Load processor (the Qwen name is presumably where the text tokenizer comes from -- confirm)
        self.processor = VibeVoiceASRProcessor.from_pretrained(
            model_path,
            language_model_pretrained_name="Qwen/Qwen2.5-7B"
        )
        # Load model with specified attention implementation
        print(f"Using attention implementation: {attn_implementation}")
        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
            model_path,
            dtype=dtype,
            # "auto" lets device_map place the weights; any explicit device is applied below.
            device_map=device if device == "auto" else None,
            attn_implementation=attn_implementation,
            trust_remote_code=True
        )
        if device != "auto":
            self.model = self.model.to(device)
        # With device_map="auto", report where the first parameter ended up.
        self.device = device if device != "auto" else next(self.model.parameters()).device
        self.dtype = dtype
        self.model.eval()
        print(f"Model loaded successfully on {self.device}")

    def _prepare_generation_config(
        self,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 1.0,
        do_sample: bool = True,
        num_beams: int = 1,
    ) -> dict:
        """Build the keyword arguments passed to ``model.generate``.

        Beam search (``num_beams > 1``) always disables sampling. Otherwise sampling is
        used only when ``do_sample`` is set AND ``temperature > 0``. A non-positive
        temperature is not a valid sampling temperature, so it falls back to greedy
        decoding instead of failing inside ``generate``.
        """
        config = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.processor.pad_id,
            "eos_token_id": self.processor.tokenizer.eos_token_id,
        }
        if num_beams > 1:
            config["num_beams"] = num_beams
            config["do_sample"] = False  # Beam search doesn't use sampling
        else:
            use_sampling = do_sample and temperature > 0
            config["do_sample"] = use_sampling
            # Only set temperature and top_p when sampling is enabled
            if use_sampling:
                config["temperature"] = temperature
                config["top_p"] = top_p
        return config

    def transcribe_batch(
        self,
        audio_inputs: List,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 1.0,
        do_sample: bool = True,
        num_beams: int = 1,
        index_offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """
        Transcribe multiple audio files/arrays in a single batch.

        Args:
            audio_inputs: Audio inputs passed as-is to the processor (e.g. file paths or
                waveform arrays). Strings are reported by their path and dicts with an
                ``'id'`` key by that id; anything else is named ``audio_<index>``.
            max_new_tokens: Maximum tokens to generate
            temperature: Temperature for sampling (<= 0 means greedy decoding)
            top_p: Top-p for nucleus sampling
            do_sample: Whether to use sampling
            num_beams: Number of beams for beam search (1 disables beam search)
            index_offset: Added to the position in this batch when naming unnamed
                inputs, so names stay unique across several batches.

        Returns:
            One dict per input with keys ``file``, ``raw_text``, ``segments`` and
            ``generation_time`` (the batch's generation time split evenly per sample).
        """
        if len(audio_inputs) == 0:
            return []
        batch_size = len(audio_inputs)
        print(f"\nProcessing batch of {batch_size} audio(s)...")

        # Process all audio together
        inputs = self.processor(
            audio=audio_inputs,
            sampling_rate=None,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True
        )
        # Move tensors to the model device; leave non-tensor entries untouched
        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                  for k, v in inputs.items()}

        # Print batch info
        print(f"  Input IDs shape: {inputs['input_ids'].shape}")
        print(f"  Speech tensors shape: {inputs['speech_tensors'].shape}")
        print(f"  Attention mask shape: {inputs['attention_mask'].shape}")

        generation_config = self._prepare_generation_config(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            num_beams=num_beams,
        )
        start_time = time.perf_counter()
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                **generation_config
            )
        generation_time = time.perf_counter() - start_time

        results = []
        eos_token_id = self.processor.tokenizer.eos_token_id
        # NOTE(review): slicing every row at the padded prompt length assumes all prompts
        # end at the same column (e.g. left padding) -- confirm the processor's padding side.
        input_length = inputs['input_ids'].shape[1]
        for i, audio_input in enumerate(audio_inputs):
            # Generated tokens for this sample (excluding input tokens)
            generated_ids = output_ids[i, input_length:]
            # Cut after the first EOS so trailing padding is not decoded
            eos_positions = (generated_ids == eos_token_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                generated_ids = generated_ids[:eos_positions[0] + 1]
            generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)

            # Parse structured output; best effort, the raw text is kept either way
            try:
                transcription_segments = self.processor.post_process_transcription(generated_text)
            except Exception as e:
                print(f"Warning: Failed to parse structured output: {e}")
                transcription_segments = []

            # Get file name based on input type
            if isinstance(audio_input, str):
                file_name = audio_input
            elif isinstance(audio_input, dict) and 'id' in audio_input:
                file_name = audio_input['id']
            else:
                file_name = f"audio_{index_offset + i}"

            results.append({
                "file": file_name,
                "raw_text": generated_text,
                "segments": transcription_segments,
                "generation_time": generation_time / batch_size,
            })

        print(f"  Total generation time: {generation_time:.2f}s")
        print(f"  Average time per sample: {generation_time/batch_size:.2f}s")
        return results

    def transcribe_with_batching(
        self,
        audio_inputs: List,
        batch_size: int = 4,
        max_new_tokens: int = 512,
        temperature: float = 0.0,
        top_p: float = 1.0,
        do_sample: bool = True,
        num_beams: int = 1,
    ) -> List[Dict[str, Any]]:
        """
        Transcribe multiple audio files/arrays with automatic batching.

        Args:
            audio_inputs: Audio inputs accepted by ``transcribe_batch``
            batch_size: Number of samples per batch (must be >= 1)
            max_new_tokens: Maximum tokens to generate
            temperature: Temperature for sampling (<= 0 means greedy decoding)
            top_p: Top-p for nucleus sampling
            do_sample: Whether to use sampling
            num_beams: Number of beams for beam search (1 disables beam search)

        Returns:
            List of transcription results, in input order

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        all_results: List[Dict[str, Any]] = []
        num_batches = (len(audio_inputs) + batch_size - 1) // batch_size
        for batch_idx, start in enumerate(range(0, len(audio_inputs), batch_size)):
            batch_inputs = audio_inputs[start:start + batch_size]
            print(f"\n{'='*60}")
            print(f"Processing batch {batch_idx + 1}/{num_batches}")
            batch_results = self.transcribe_batch(
                batch_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                num_beams=num_beams,
                # Keeps fallback names ("audio_<n>") unique across batches
                index_offset=start,
            )
            all_results.extend(batch_results)
        return all_results
def print_result(
    result: Dict[str, Any],
    max_chars: int = 500,
    max_segments: int = 50,
) -> None:
    """Pretty print a single transcription result.

    Args:
        result: A result dict with keys ``file``, ``generation_time``, ``raw_text`` and
            ``segments`` (a list of dicts read via ``start_time``, ``end_time``,
            ``speaker_id`` and ``text``; missing keys print as ``N/A`` / empty text).
        max_chars: The raw text is cut to this many characters, with "..." appended
            when it was longer.
        max_segments: At most this many segments are printed; the remainder is summarised.
    """
    raw_text = result['raw_text']
    segments = result['segments']

    print(f"\nFile: {result['file']}")
    print(f"Generation Time: {result['generation_time']:.2f}s")
    print("\n--- Raw Output ---")
    print(raw_text[:max_chars] + "..." if len(raw_text) > max_chars else raw_text)

    if segments:
        print(f"\n--- Structured Output ({len(segments)} segments) ---")
        for seg in segments[:max_segments]:
            # Segment text is printed in full, so no truncation marker is appended.
            print(f"[{seg.get('start_time', 'N/A')} - {seg.get('end_time', 'N/A')}] "
                  f"Speaker {seg.get('speaker_id', 'N/A')}: {seg.get('text', '')}")
        if len(segments) > max_segments:
            print(f" ... and {len(segments) - max_segments} more segments")
def _resample_linear(audio: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
    """Resample ``audio`` from ``sr`` to ``target_sr`` Hz by linear interpolation.

    Demo quality only: there is no anti-aliasing filter. ``np.interp`` works on 1-D
    data, so this assumes mono waveforms -- confirm for multi-channel datasets.
    """
    if sr == target_sr or len(audio) == 0:
        return audio
    new_length = int(len(audio) / sr * target_sr)
    return np.interp(
        np.linspace(0, len(audio) - 1, new_length),
        np.arange(len(audio)),
        audio,
    )


def _iter_resampled_audio(samples: Iterable[Dict[str, Any]], target_sr: int) -> Iterator[np.ndarray]:
    """Yield the waveform of each sample that has an 'audio' field, resampled to ``target_sr``."""
    processed = 0
    for sample in samples:
        if 'audio' not in sample:
            continue
        audio_data = sample['audio']
        yield _resample_linear(np.asarray(audio_data['array']), audio_data['sampling_rate'], target_sr)
        processed += 1
        if processed % 100 == 0:
            print(f" Processed {processed} samples...")


def _pack_waveforms(
    waveforms: Iterable[np.ndarray],
    max_samples: int,
    min_fill: int,
) -> Iterator[Tuple[np.ndarray, int, bool]]:
    """Concatenate ``waveforms`` into one stream and cut it into chunks of at most ``max_samples``.

    A waveform that does not fit fills the rest of the current chunk, and its tail starts
    the next chunk, so no audio is duplicated or dropped. If only ``min_fill`` samples or
    fewer of room are left, the chunk is closed instead and the waveform starts the next one.

    Yields:
        ``(array, samples_used, continued)``. ``samples_used`` counts the waveforms that
        contribute to the chunk. ``continued`` is True when the chunk starts with the tail
        of a waveform that the previous chunk already counted.
    """
    pieces: List[np.ndarray] = []
    filled = 0
    used = 0
    continued = False
    for wave in waveforms:
        pending = wave
        while len(pending) > 0:
            space = max_samples - filled
            if pieces and len(pending) > space and space <= min_fill:
                # Too little room to be worth topping up: close the chunk, keep the waveform whole.
                yield np.concatenate(pieces), used, continued
                pieces, filled, used, continued = [], 0, 0, False
                continue
            piece = pending[:space]
            pieces.append(piece)
            filled += len(piece)
            used += 1
            pending = pending[len(piece):]
            if len(pending) > 0:
                # Chunk is full; the rest of this waveform continues in the next chunk.
                yield np.concatenate(pieces), used, continued
                pieces, filled, used, continued = [], 0, 0, True
    if pieces:
        yield np.concatenate(pieces), used, continued


def load_dataset_and_concatenate(
    dataset_name: str,
    split: str,
    max_duration: float,
    num_audios: int,
    target_sr: int = 24000,
    samples: Optional[Iterable[Dict[str, Any]]] = None,
) -> Optional[List[np.ndarray]]:
    """
    Load a HuggingFace dataset and concatenate audio samples into long audio chunks.
    (Note, just for demo purpose, not for benchmark evaluation)

    Samples are resampled to ``target_sr`` and concatenated in stream order. The stream
    is cut every ``max_duration`` seconds. A sample that straddles a cut is split, and
    its tail starts the next chunk, so no audio is duplicated. When 0.5 s or less of room
    is left in a chunk, the chunk is closed early and the next sample starts a new one.

    Args:
        dataset_name: HuggingFace dataset name (e.g., 'openslr/librispeech_asr')
        split: Dataset split to use (e.g., 'test', 'test.other')
        max_duration: Maximum duration in seconds for each concatenated audio
        num_audios: Number of concatenated audios to create
        target_sr: Target sample rate (default: 24000)
        samples: Optional already-loaded iterable of samples. Each sample is a dict
            whose 'audio' entry holds 'array' and 'sampling_rate'. When given,
            ``load_dataset`` is not called and ``dataset_name``/``split`` are only logged.

    Returns:
        List of concatenated audio arrays, or None if loading fails or no audio was found
    """
    if samples is None:
        try:
            from datasets import load_dataset
            import torchcodec  # noqa: F401 -- just for decode audio in datasets
        except ImportError:
            print("Please install it with: pip install datasets torchcodec")
            return None

    print(f"\nLoading dataset: {dataset_name} (split: {split})")
    print(f"Will create {num_audios} concatenated audio(s), each up to {max_duration:.1f}s ({max_duration/3600:.2f} hours)")

    max_samples = int(max_duration * target_sr)
    if max_samples <= 0:
        print(f"Error: max_duration={max_duration} is too short to hold any audio")
        return None
    # Only top up a chunk if more than 0.5 s of room remains.
    min_fill = int(0.5 * target_sr)

    try:
        if samples is None:
            # Use streaming to avoid downloading the entire dataset
            samples = load_dataset(dataset_name, split=split, streaming=True)
            print("Dataset loaded in streaming mode")

        chunks = _pack_waveforms(_iter_resampled_audio(samples, target_sr), max_samples, min_fill)
        concatenated_audios: List[Dict[str, Any]] = []
        # islice stops pulling from the (possibly streaming) source once enough chunks exist.
        for array, samples_used, continued in itertools.islice(chunks, max(num_audios, 0)):
            duration = len(array) / target_sr
            concatenated_audios.append({
                'array': array,
                'duration': duration,
                'samples_used': samples_used,
                'continued': continued,
            })
            print(f" Created audio {len(concatenated_audios)}: {duration:.1f}s from {samples_used} samples")

        if not concatenated_audios:
            print("Warning: No audio samples found in dataset")
            return None

        result = [a['array'] for a in concatenated_audios]
        total_duration = sum(a['duration'] for a in concatenated_audios)
        # A sample split across a cut is counted once, not once per chunk.
        total_samples = sum(a['samples_used'] - int(a['continued']) for a in concatenated_audios)
        print(f"\nCreated {len(result)} concatenated audio(s), total {total_duration:.1f}s ({total_duration/60:.1f} min) from {total_samples} samples")
        return result
    except Exception as e:
        # Demo boundary: report the failure and let the caller fall back.
        print(f"Error loading dataset: {e}")
        import traceback
        traceback.print_exc()
        return None
def _default_device() -> str:
    """Return the preferred inference device: CUDA, then Intel XPU, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.xpu is the public XPU API; getattr guards builds that do not ship it.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def main():
    """CLI entry point: parse arguments, gather audio inputs, run batched ASR and print the results."""
    parser = argparse.ArgumentParser(description="VibeVoice ASR Batch Inference Demo")
    parser.add_argument(
        "--model_path",
        type=str,
        default="",
        help="Path to the model checkpoint"
    )
    parser.add_argument(
        "--audio_files",
        type=str,
        nargs='+',
        required=False,
        help="Paths to audio files for transcription"
    )
    parser.add_argument(
        "--audio_dir",
        type=str,
        required=False,
        help="Directory containing audio files for batch transcription"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=False,
        help="HuggingFace dataset name (e.g., 'openslr/librispeech_asr')"
    )
    parser.add_argument(
        "--split",
        type=str,
        default="test",
        help="Dataset split to use (e.g., 'test', 'test.other', 'test.clean')"
    )
    parser.add_argument(
        "--max_duration",
        type=float,
        default=3600.0,
        help="Maximum duration in seconds for concatenated dataset audio (default: 3600 = 1 hour)"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2,
        help="Batch size for processing multiple files"
    )
    parser.add_argument(
        "--device",
        type=str,
        default=_default_device(),
        choices=["cuda", "cpu", "mps", "xpu", "auto"],
        help="Device to run inference on"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=32768,
        help="Maximum number of tokens to generate"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for sampling (0 = greedy decoding)"
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=1.0,
        help="Top-p for nucleus sampling"
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
        help="Number of beams for beam search. Use 1 for greedy/sampling"
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="auto",
        choices=["flash_attention_2", "sdpa", "eager", "auto"],
        help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU/XPU)"
    )
    args = parser.parse_args()

    # Fail early with a clear message instead of an opaque from_pretrained("") error.
    if not args.model_path:
        parser.error("--model_path is required")
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    # Auto-detect best attention implementation based on device
    if args.attn_implementation == "auto":
        if args.device == "cuda" and torch.cuda.is_available():
            try:
                import flash_attn  # noqa: F401 -- only probing availability
                args.attn_implementation = "flash_attention_2"
            except ImportError:
                print("flash_attn not installed, falling back to sdpa")
                args.attn_implementation = "sdpa"
        else:
            # MPS/XPU/CPU don't support flash_attention_2
            args.attn_implementation = "sdpa"
        print(f"Auto-detected attention implementation: {args.attn_implementation}")

    # Collect audio files
    audio_files = []
    concatenated_audio = None  # For storing concatenated dataset audio
    if args.audio_files:
        audio_files.extend(args.audio_files)
    if args.audio_dir:
        import glob
        extensions = ["*.wav", "*.mp3", "*.flac", "*.mp4", "*.m4a", "*.webm"]
        # Sorted so the processing order does not depend on filesystem listing order.
        audio_files.extend(sorted(
            path
            for ext in extensions
            for path in glob.glob(os.path.join(args.audio_dir, ext))
        ))
    if args.dataset:
        concatenated_audio = load_dataset_and_concatenate(
            dataset_name=args.dataset,
            split=args.split,
            max_duration=args.max_duration,
            num_audios=args.batch_size,
        )
        if concatenated_audio is None:
            return

    if len(audio_files) == 0 and concatenated_audio is None:
        print("No audio files provided. Please specify --audio_files, --audio_dir, or --dataset.")
        return

    if audio_files:
        print(f"\nAudio files to process ({len(audio_files)}):")
        for f in audio_files:
            print(f" - {f}")
    if concatenated_audio:
        print(f"\nConcatenated dataset audios: {len(concatenated_audio)} audio(s)")

    # Initialize model: MPS/XPU/CPU run in float32 (MPS works better with float32),
    # CUDA and "auto" use bfloat16.
    model_dtype = torch.float32 if args.device in ("mps", "xpu", "cpu") else torch.bfloat16
    asr = VibeVoiceASRBatchInference(
        model_path=args.model_path,
        device=args.device,
        dtype=model_dtype,
        attn_implementation=args.attn_implementation
    )

    # If temperature is 0, use greedy decoding (no sampling)
    do_sample = args.temperature > 0

    # Combine all audio inputs
    all_audio_inputs = audio_files + (concatenated_audio or [])
    print("\n" + "="*80)
    print(f"Processing {len(all_audio_inputs)} audio(s)")
    print("="*80)
    all_results = asr.transcribe_with_batching(
        all_audio_inputs,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        do_sample=do_sample,
        num_beams=args.num_beams,
    )

    # Print results
    print("\n" + "="*80)
    print("Results")
    print("="*80)
    for result in all_results:
        print("\n" + "-"*60)
        print_result(result)
if __name__ == "__main__":
    # Entry point: only run the CLI when executed as a script, not on import.
    main()