diff --git a/docs/vibevoice-vllm-asr.md b/docs/vibevoice-vllm-asr.md index f712d3b..489389e 100644 --- a/docs/vibevoice-vllm-asr.md +++ b/docs/vibevoice-vllm-asr.md @@ -52,15 +52,25 @@ docker logs -f vibevoice-vllm Once the vLLM server is running, test it with the provided script: ```bash -# Run the test (use container path /app/...) +# Basic transcription docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav + +# With hotwords for better recognition of specific terms +docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav --hotwords "Microsoft,VibeVoice" + ``` ```bash -# Run the recover_test (use container path /app/...) +# With auto-recovery from repetition loops (for long audio) docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav + +# Auto-recover with hotwords +docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav --hotwords "Microsoft,VibeVoice" ``` -> **Note**: The audio file must be inside the mounted directory (`/app` in the container). Copy your audio to the VibeVoice folder before testing. + +> **Note**: +> - The audio/video file must be inside the mounted directory (`/app` in the container). Copy your files to the VibeVoice folder before testing. +> - Hotwords help improve recognition of domain-specific terms like proper nouns, technical terms, and speaker names. ### Environment Variables diff --git a/vllm_plugin/__init__.py b/vllm_plugin/__init__.py index b696a45..989acb4 100644 --- a/vllm_plugin/__init__.py +++ b/vllm_plugin/__init__.py @@ -15,7 +15,6 @@ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast from .model import VibeVoiceForCausalLM -from .inputs import vibevoice_audio_input_mapper def register_vibevoice(): diff --git a/vllm_plugin/model.py b/vllm_plugin/model.py index bcb4ca3..2ca3cd2 100644 --- a/vllm_plugin/model.py +++ b/vllm_plugin/model.py @@ -5,17 +5,11 @@ This module implements the VibeVoice ASR model with full vLLM multimodal registr integration for speech-to-text inference. """ -from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal -import json -import math +from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence import os -import sys -from pathlib import Path import torch import torch.nn as nn import numpy as np -from io import BytesIO -import tempfile import base64 @@ -29,32 +23,12 @@ import base64 from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer -def _suffix_from_media_type(media_type: str | None) -> str: - if not media_type: - return ".bin" - mt = media_type.lower().strip() - if mt in ("audio/wav", "audio/x-wav", "audio/wave"): - return ".wav" - if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"): - return ".mp3" - if mt in ("audio/flac",): - return ".flac" - if mt in ("audio/ogg", "audio/opus"): - return ".ogg" - if mt in ("audio/mp4", "audio/m4a"): - return ".m4a" - if mt in ("video/mp4",): - return ".mp4" - return ".bin" - - -def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]: - """Load audio bytes using FFmpeg. +def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]: + """Load audio bytes using FFmpeg via stdin-pipe decoding. Returns: Tuple of (audio_waveform, sample_rate). Sample rate is always 24000. """ - # Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency. audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000) normalizer = AudioNormalizer() audio = normalizer(audio) @@ -79,10 +53,10 @@ class _PatchedAudioMediaIO(_OriginalAudioMediaIO): """AudioMediaIO implementation using FFmpeg for audio decoding.""" def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]: - return _ffmpeg_load_bytes(data, media_type=None) + return _ffmpeg_load_bytes(data) def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]: - return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type) + return _ffmpeg_load_bytes(base64.b64decode(data)) def load_file(self, filepath) -> tuple[np.ndarray, int]: return _ffmpeg_load_file(filepath) @@ -96,17 +70,13 @@ _vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO # ============================================================================ -from transformers import Qwen2Config, BatchFeature +from transformers import BatchFeature from transformers.models.whisper import WhisperFeatureExtractor -from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.config import VllmConfig, ModelConfig -from vllm.config.speech_to_text import SpeechToTextConfig +from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import MultiModalDataParser from vllm.sequence import IntermediateTensors from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings -from vllm.inputs import PromptType from vllm.model_executor.models.utils import ( init_vllm_registered_model, maybe_prefix, @@ -558,7 +528,7 @@ class VibeVoiceProcessingInfo(BaseProcessingInfo): return tokens def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {"audio": None} + return {"audio": 1} class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]): @@ -873,17 +843,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): with a causal language model for text generation. """ - # SupportsTranscription interface - supports_transcription: ClassVar[Literal[True]] = True - supports_transcription_only: ClassVar[bool] = False - supports_segment_timestamp: ClassVar[bool] = False - - # Supported languages (Chinese as primary target) - supported_languages: ClassVar[Mapping[str, str]] = { - "zh": "Chinese", - "en": "English", - } - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: """Return the placeholder string format for a given modality. @@ -897,112 +856,10 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): return "<|AUDIO|>" raise ValueError(f"Unsupported modality: {modality}") - @classmethod - def get_generation_prompt( - cls, - audio: np.ndarray, - stt_config: SpeechToTextConfig, - model_config: ModelConfig, - language: str | None, - task_type: Literal["transcribe", "translate"], - request_prompt: str, - to_language: str | None, - ) -> PromptType: - """Get the prompt for the ASR model. - - Generates a chat-formatted prompt for speech-to-text transcription - with JSON output format. - """ - # If user provides custom prompt, use it - if request_prompt: - return request_prompt - - # Calculate audio duration for the prompt - # Audio should be at 24kHz, so duration = len(audio) / 24000 - duration = len(audio) / 24000 if audio is not None else 10.0 - - system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format." - show_keys = ["Start time", "End time", "Speaker ID", "Content"] - user_suffix = ( - f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " - + ", ".join(show_keys) - ) - - # IMPORTANT: keep <|AUDIO|> as the only placeholder token here. - # `_get_prompt_updates` expands it into repeated `<|AUDIO|>` placeholders. - user_content = "<|AUDIO|>\n" + user_suffix - - prompt = ( - f"<|im_start|>system\n{system_prompt}<|im_end|>\n" - f"<|im_start|>user\n{user_content}<|im_end|>\n" - "<|im_start|>assistant\n" - ) - return prompt - - @classmethod - def get_speech_to_text_config( - cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"] - ) -> SpeechToTextConfig: - """Get the speech to text config for the ASR model.""" - return SpeechToTextConfig( - language=None, # Auto-detect or use request language - task_type=task_type, - ) - - @classmethod - def get_num_audio_tokens( - cls, - audio_duration_s: float, - stt_config: SpeechToTextConfig, - model_config: ModelConfig, - ) -> int | None: - """Estimate number of audio tokens from duration. - - Returns the number of audio EMBEDDING positions (speech_pad_id tokens). - Note: _get_prompt_updates actually generates: - [speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id] - So total prompt tokens = N + 3, but this returns N (the embedding count). - """ - sampling_rate = 24000 - compress_ratio = 3200 - samples = int(audio_duration_s * sampling_rate) - num_tokens = int(np.ceil(samples / compress_ratio)) - return num_tokens - - @classmethod - def get_other_languages(cls) -> Mapping[str, str]: - """Get languages from Whisper map not natively supported.""" - # Import LANGUAGES from vllm - try: - from vllm.transformers_utils.tokenizer import LANGUAGES - except ImportError: - # Fallback to empty dict if import fails - return {} - return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages} - - @classmethod - def validate_language(cls, language: str | None) -> str | None: - """Validate the language code.""" - if language is None or language in cls.supported_languages: - return language - elif language in cls.get_other_languages(): - print(f"Warning: Language {language!r} is not natively supported") - return language - else: - raise ValueError( - f"Unsupported language: {language!r}. " - f"Supported: {list(cls.supported_languages.keys())}" - ) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config - - # Keep a copy of the resolved model path for any custom weight-loading - # logic (e.g., loading audio encoder weights in fp32 directly from - # safetensors shards). - self._model_path = vllm_config.model_config.model self.audio_encoder = VibeVoiceAudioEncoder(config) @@ -1100,14 +957,14 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # Process each audio through the VibeVoice encoder embeddings = [] - # Get model device and dtype for alignment + # Get model device for tensor placement. + # dtype is NOT set here — audio_encoder.forward() handles it internally: + # input: converted to fp32 (self._audio_encoder_dtype) + # output: converted to bfloat16 (self._lm_dtype) try: device = next(self.audio_encoder.parameters()).device - dtype = next(self.audio_encoder.parameters()).dtype except StopIteration: - # Fallback if no parameters (shouldn't happen) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - dtype = torch.bfloat16 # Handle both stacked tensor and list of tensors # vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len] @@ -1138,7 +995,7 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if not isinstance(audio_tensor, torch.Tensor): audio_tensor = torch.tensor(audio_tensor) - # Let vLLM handle dtype (bfloat16 by default) + # Only place on correct device; audio_encoder.forward() handles dtype audio_tensor = audio_tensor.to(device=device) # Get actual length if available, otherwise use full length @@ -1294,35 +1151,31 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): intermediate_tensors: Intermediate tensors for pipeline parallelism. inputs_embeds: Pre-computed embeddings (from multimodal merge or decode). """ - try: - # PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode) - # Only compute from input_ids if inputs_embeds is not available - if inputs_embeds is None and input_ids is not None: - # Compute embeddings from input_ids - inputs_embeds = self.get_input_embeddings()(input_ids) - - # If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds - if intermediate_tensors is not None: - inputs_embeds = None - - # Get the inner model - handle both wrapped and direct language models - language_model = self.language_model - if hasattr(language_model, "language_model"): - language_model = language_model.language_model - - # Call the language model's model (Qwen2Model) - # vLLM V1 passes kv_caches and attn_metadata via context, not arguments - # IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding - hidden_states = language_model.model( - input_ids=None, # Always None when we have inputs_embeds - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds - ) - return hidden_states - - except Exception as e: - raise + # PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode) + # Only compute from input_ids if inputs_embeds is not available + if inputs_embeds is None and input_ids is not None: + # Compute embeddings from input_ids + inputs_embeds = self.get_input_embeddings()(input_ids) + + # If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds + if intermediate_tensors is not None: + inputs_embeds = None + + # Get the inner model - handle both wrapped and direct language models + language_model = self.language_model + if hasattr(language_model, "language_model"): + language_model = language_model.language_model + + # Call the language model's model (Qwen2Model) + # vLLM V1 passes kv_caches and attn_metadata via context, not arguments + # IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding + hidden_states = language_model.model( + input_ids=None, # Always None when we have inputs_embeds + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds + ) + return hidden_states # Alias for training checkpoint compatibility diff --git a/vllm_plugin/tests/52min.mp3 b/vllm_plugin/tests/52min.mp3 deleted file mode 100644 index 0ed89ef..0000000 Binary files a/vllm_plugin/tests/52min.mp3 and /dev/null differ diff --git a/vllm_plugin/tests/test_api.py b/vllm_plugin/tests/test_api.py index af128bc..4076c20 100644 --- a/vllm_plugin/tests/test_api.py +++ b/vllm_plugin/tests/test_api.py @@ -1,14 +1,23 @@ #!/usr/bin/env python3 """ -Test VibeVoice vLLM API with Streaming (Real-time output). +Test VibeVoice vLLM API with Streaming and Optional Hotwords Support. + +This script tests ASR transcription via the vLLM OpenAI-compatible API. +By default, it runs standard transcription without hotwords. + +Optionally, you can provide hotwords (context_info) to improve recognition +of domain-specific content like proper nouns, technical terms, and speaker names. +Hotwords are embedded in the prompt as "with extra info: {hotwords}". Usage: - python test_api.py [audio_path] [--url URL] + python test_api_with_hotwords.py [audio_path] [--url URL] [--hotwords "word1,word2"] Examples: - python test_api.py # Use default audio - python test_api.py /path/to/audio.wav # Specify audio file - python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL + # Standard transcription (no hotwords) + python3 test_api.py audio.wav + + # With hotwords for better recognition of specific terms + python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice" """ import requests import json @@ -21,38 +30,38 @@ import argparse def _guess_mime_type(path: str) -> str: + """Guess MIME type from file extension.""" ext = os.path.splitext(path)[1].lower() - if ext == ".wav": - return "audio/wav" - if ext in (".mp3",): - return "audio/mpeg" - if ext in (".m4a",): - return "audio/mp4" - if ext in (".mp4", ".m4v", ".mov", ".webm"): - return "video/mp4" - if ext in (".flac",): - return "audio/flac" - if ext in (".ogg", ".opus"): - return "audio/ogg" - return "application/octet-stream" + mime_map = { + ".wav": "audio/wav", + ".mp3": "audio/mpeg", + ".m4a": "audio/mp4", + ".mp4": "video/mp4", + ".flac": "audio/flac", + ".ogg": "audio/ogg", + ".opus": "audio/ogg", + } + return mime_map.get(ext, "application/octet-stream") def _get_duration_seconds_ffprobe(path: str) -> float: """Get audio duration using ffprobe.""" cmd = [ - "ffprobe", - "-v", - "error", - "-show_entries", - "format=duration", - "-of", - "default=noprint_wrappers=1:nokey=1", + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "default=noprint_wrappers=1:nokey=1", path, ] out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip() return float(out) +def _is_video_file(path: str) -> bool: + """Check if the file is a video file that needs audio extraction.""" + ext = os.path.splitext(path)[1].lower() + return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv") + + def _extract_audio_from_video(video_path: str) -> str: """ Extract audio from video file (mp4/mov/webm) to a temporary mp3 file. @@ -74,26 +83,40 @@ def _extract_audio_from_video(video_path: str) -> str: return audio_path -def _is_video_file(path: str) -> bool: - """Check if the file is a video file that needs audio extraction.""" - ext = os.path.splitext(path)[1].lower() - return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv") - - -def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"): - """Test ASR transcription with streaming output.""" +def test_transcription_with_hotwords( + audio_path: str, + context_info: str = None, + base_url: str = "http://localhost:8000", +): + """ + Test ASR transcription with customized hotwords. - print(f"Loading audio from: {audio_path}") + Hotwords are embedded in the prompt text as "with extra info: {hotwords}". + This helps the model recognize domain-specific terms more accurately. + + Args: + audio_path: Path to the audio file + context_info: Hotwords string (e.g., "Microsoft,Azure,VibeVoice") + base_url: vLLM server URL + """ + + print(f"=" * 70) + print(f"Testing Customized Hotwords Support") + print(f"=" * 70) + print(f"Input file: {audio_path}") + print(f"Hotwords: {context_info or '(none)'}") + print() # Handle video files: extract audio first temp_audio_path = None actual_audio_path = audio_path if _is_video_file(audio_path): - print(f"Detected video file, extracting audio...") + print(f"šŸŽ¬ Detected video file, extracting audio...") temp_audio_path = _extract_audio_from_video(audio_path) actual_audio_path = temp_audio_path - print(f"Audio extracted to: {temp_audio_path}") + print(f"āœ… Audio extracted to: {temp_audio_path}") + # Load audio try: duration = _get_duration_seconds_ffprobe(actual_audio_path) print(f"Audio duration: {duration:.2f} seconds") @@ -106,16 +129,30 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000") except Exception as e: print(f"Error preparing audio: {e}") + # Cleanup temp file if created + if temp_audio_path and os.path.exists(temp_audio_path): + os.remove(temp_audio_path) return # Build the request url = f"{base_url}/v1/chat/completions" show_keys = ["Start time", "End time", "Speaker ID", "Content"] - prompt_text = ( - f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " - + ", ".join(show_keys) - ) + + # Build prompt with optional hotwords + # Hotwords are embedded as "with extra info: {hotwords}" in the prompt + if context_info and context_info.strip(): + prompt_text = ( + f"This is a {duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\n" + f"Please transcribe it with these keys: " + ", ".join(show_keys) + ) + print(f"\nšŸ“ Hotwords embedded in prompt: '{context_info}'") + else: + prompt_text = ( + f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " + + ", ".join(show_keys) + ) + print(f"\nšŸ“ No hotwords provided") mime = _guess_mime_type(actual_audio_path) data_url = f"data:{mime};base64,{audio_b64}" @@ -139,20 +176,19 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000") "temperature": 0.0, "stream": True, "top_p": 1.0, - "repetition_penalty": 1.0, } - print(f"\nSending request to {url} (Streaming Mode)...") - print(f"Prompt: {prompt_text}") - print("-" * 60) + print(f"\n{'=' * 70}") + print(f"Sending request to {url}") + print(f"{'=' * 70}") t0 = time.time() try: - response = requests.post(url, json=payload, stream=True, timeout=12000) if response.status_code == 200: - print("Response received. Streaming content:\n") + print("\nāœ… Response received. Streaming content:\n") + print("-" * 50) printed = "" for line in response.iter_lines(): @@ -162,92 +198,72 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000") if decoded_line.startswith("data: "): json_str = decoded_line[6:] if json_str.strip() == "[DONE]": - print("\n\n[Finished]") + print("\n" + "-" * 50) + print("āœ… [Finished]") break try: data = json.loads(json_str) - delta = data['choices'][0]['delta'] content = delta.get('content', '') if content: - - # vLLM/OpenAI-compatible streams may emit either - # incremental deltas OR the full accumulated text. - # Only print the newly-added part to avoid repeats. if content.startswith(printed): to_print = content[len(printed):] else: to_print = content - if to_print: print(to_print, end='', flush=True) printed += to_print except json.JSONDecodeError: pass else: - print(f"Error: {response.status_code}") + print(f"āŒ Error: {response.status_code}") print(response.text) except requests.exceptions.Timeout: - print("\nRequest timed out!") + print("āŒ Request timed out!") except Exception as e: - print(f"\nError: {e}") + print(f"āŒ Error: {e}") - print(f"\n{'-'*60}") - print(f"Total time elapsed: {time.time() - t0:.2f}s") + elapsed = time.time() - t0 + print(f"\n{'=' * 70}") + print(f"ā±ļø Total time elapsed: {elapsed:.2f}s") + print(f"šŸ“Š RTF (Real-Time Factor): {elapsed / duration:.2f}x") + print(f"{'=' * 70}") # Cleanup temp audio file if created if temp_audio_path and os.path.exists(temp_audio_path): os.remove(temp_audio_path) - print(f"Cleaned up temp file: {temp_audio_path}") + print(f"šŸ—‘ļø Cleaned up temp file: {temp_audio_path}") def main(): parser = argparse.ArgumentParser( - description="Test VibeVoice vLLM API with streaming output" + description="Test VibeVoice vLLM API with Customized Hotwords" ) parser.add_argument( "audio_path", - nargs="?", - default=None, help="Path to audio file (wav, mp3, flac, etc.) or video file" ) parser.add_argument( "--url", default="http://localhost:8000", - help="vLLM server base URL (default: http://localhost:8000)" + help="vLLM server URL (default: http://localhost:8000)" + ) + parser.add_argument( + "--hotwords", + type=str, + default=None, + help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')" ) args = parser.parse_args() - # Find default audio if not specified - audio_path = args.audio_path - if audio_path is None: - # Try to find a sample audio in common locations - possible_paths = [ - # In VibeVoice demo folder - os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "en-Carter_man.wav"), - os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"), - # Relative to current directory - "demo/voices/en-Carter_man.wav", - "demo/voices/zh-Anchen_man_bgm.wav", - ] - - for path in possible_paths: - if os.path.exists(path): - audio_path = path - break - - if audio_path is None: - print("Error: No audio file specified and no default audio found.") - print("Usage: python test_api.py ") - sys.exit(1) - - if not os.path.exists(audio_path): - print(f"Error: Audio file not found: {audio_path}") - sys.exit(1) - - test_transcription(audio_path, args.url) + # Run test + test_transcription_with_hotwords( + audio_path=args.audio_path, + context_info=args.hotwords, + base_url=args.url, + ) if __name__ == "__main__": diff --git a/vllm_plugin/tests/test_api_auto_recover.py b/vllm_plugin/tests/test_api_auto_recover.py index 4fa053e..482e258 100644 --- a/vllm_plugin/tests/test_api_auto_recover.py +++ b/vllm_plugin/tests/test_api_auto_recover.py @@ -1,18 +1,36 @@ #!/usr/bin/env python3 """ -VibeVoice vLLM API with Auto-Recovery from Repetition Loops. +Test VibeVoice vLLM API with Streaming, Hotwords, and Auto-Recovery. -Strategy: -1. Start with greedy decoding (temperature=0, top_p=1.0) -2. Stream and detect repetition patterns in real-time -3. Only output content up to (current_length - window_size) at segment boundaries -4. When loop detected: - - Truncate to last complete segment boundary (},) - - Recovery with temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95 -5. Max 3 retries, if all fail output error message +This script tests ASR transcription with automatic recovery from repetition loops. +Supports optional hotwords to improve recognition of domain-specific terms. -User sees: clean streaming transcription output (only complete segments) -Internal: automatic recovery from repetition loops (silent) +Features: +- Streaming output with real-time repetition detection +- Auto-recovery when model enters repetition loops +- Optional hotwords support (embedded in prompt as "with extra info: {hotwords}") +- Video file support (auto-extracts audio) + +Recovery Strategy: +1. First attempt: greedy decoding (temperature=0, top_p=1.0) +2. If loop detected: retry with temperature=0.2/0.3/0.4, top_p=0.95 +3. Max 3 retries, truncate to last complete segment boundary + +Usage: + python test_api_auto_recover.py [output_path] [--url URL] [--hotwords "word1,word2"] [--debug] + +Examples: + # Basic usage + python3 test_api_auto_recover.py audio.wav + + # With hotwords + python3 test_api_auto_recover.py audio.wav --hotwords "Microsoft,VibeVoice" + + # Save result to file + python3 test_api_auto_recover.py audio.wav result.txt + + # Debug mode (show recovery info) + python3 test_api_auto_recover.py audio.wav --debug """ import requests import json @@ -22,6 +40,7 @@ import sys import os import subprocess import re +import argparse from collections import Counter @@ -441,30 +460,41 @@ def stream_with_recovery( return None -def test_transcription_with_recovery(): - """Main test function with auto-recovery.""" +def test_transcription_with_recovery( + audio_path: str, + output_path: str = None, + base_url: str = "http://localhost:8000", + hotwords: str = None, + debug: bool = False, +): + """ + Test ASR transcription with auto-recovery from repetition loops. - # Parse arguments - debug = "--debug" in sys.argv or "-debug" in sys.argv - args = [a for a in sys.argv[1:] if not a.startswith("-")] + Args: + audio_path: Path to the audio file + output_path: Optional path to save transcription result + base_url: vLLM server URL + hotwords: Hotwords string (e.g., "Microsoft,Azure,VibeVoice") + debug: Show recovery debug info + """ - audio_path = ( - args[0] - ) - - output_path = args[1] if len(args) > 1 else None - - print(f"Loading audio from: {audio_path}") + print(f"=" * 70) + print(f"Testing with Auto-Recovery") + print(f"=" * 70) + print(f"Input file: {audio_path}") + print(f"Hotwords: {hotwords or '(none)'}") + print() # Handle video files: extract audio first temp_audio_path = None actual_audio_path = audio_path if _is_video_file(audio_path): - print(f"Detected video file, extracting audio...") + print(f"šŸŽ¬ Detected video file, extracting audio...") temp_audio_path = _extract_audio_from_video(audio_path) actual_audio_path = temp_audio_path - print(f"Audio extracted to: {temp_audio_path}") + print(f"āœ… Audio extracted to: {temp_audio_path}") + # Load audio try: duration = _get_duration_seconds_ffprobe(actual_audio_path) print(f"Audio duration: {duration:.2f} seconds") @@ -476,16 +506,29 @@ def test_transcription_with_recovery(): print(f"Audio size: {len(audio_bytes)} bytes") except Exception as e: - print(f"Error preparing audio: {e}") + print(f"āŒ Error preparing audio: {e}") + # Cleanup temp file if created + if temp_audio_path and os.path.exists(temp_audio_path): + os.remove(temp_audio_path) return - url = "http://localhost:8000/v1/chat/completions" + url = f"{base_url}/v1/chat/completions" show_keys = ["Start time", "End time", "Speaker ID", "Content"] - prompt_text = ( - f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " - + ", ".join(show_keys) - ) + + # Build prompt with optional hotwords + if hotwords and hotwords.strip(): + prompt_text = ( + f"This is a {duration:.2f} seconds audio, with extra info: {hotwords.strip()}\n\n" + f"Please transcribe it with these keys: " + ", ".join(show_keys) + ) + print(f"\nšŸ“ Hotwords embedded in prompt: '{hotwords}'") + else: + prompt_text = ( + f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " + + ", ".join(show_keys) + ) + print(f"\nšŸ“ No hotwords provided") mime = _guess_mime_type(actual_audio_path) data_url = f"data:{mime};base64,{audio_b64}" @@ -505,12 +548,13 @@ def test_transcription_with_recovery(): } ] - print(f"\nSending request to {url} (Streaming Mode)...") - print(f"Prompt: {prompt_text}") - print("-" * 60) - print("Response received. Streaming content:\n") + print(f"\n{'=' * 70}") + print(f"Sending request to {url}") + print(f"{'=' * 70}") t0 = time.time() + print("\nāœ… Response received. Streaming content:\n") + print("-" * 50) result = stream_with_recovery( url=url, @@ -522,27 +566,73 @@ def test_transcription_with_recovery(): debug=debug, ) - print("\n[Finished]") - print("-" * 60) - print(f"Total time elapsed: {time.time() - t0:.2f}s") + elapsed = time.time() - t0 + print("-" * 50) + print("āœ… [Finished]") + print(f"\n{'=' * 70}") + print(f"ā±ļø Total time elapsed: {elapsed:.2f}s") + print(f"{'=' * 70}") if result is None: - print("Transcription failed") + print("āŒ Transcription failed") return - print(f"Final output length: {len(result)} chars") + print(f"šŸ“„ Final output length: {len(result)} chars") # Optionally save result if output_path: with open(output_path, "w", encoding="utf-8") as f: f.write(result) - print(f"Result saved to: {output_path}") + print(f"šŸ’¾ Result saved to: {output_path}") # Cleanup temp audio file if created if temp_audio_path and os.path.exists(temp_audio_path): os.remove(temp_audio_path) - print(f"Cleaned up temp file: {temp_audio_path}") + print(f"šŸ—‘ļø Cleaned up temp file: {temp_audio_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Test VibeVoice vLLM API with auto-recovery from repetition loops" + ) + parser.add_argument( + "audio_path", + help="Path to audio file (wav, mp3, flac, etc.) or video file" + ) + parser.add_argument( + "output_path", + nargs="?", + default=None, + help="Optional path to save transcription result" + ) + parser.add_argument( + "--url", + default="http://localhost:8000", + help="vLLM server URL (default: http://localhost:8000)" + ) + parser.add_argument( + "--hotwords", + type=str, + default=None, + help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')" + ) + parser.add_argument( + "--debug", + action="store_true", + help="Show recovery debug info" + ) + + args = parser.parse_args() + + # Run test + test_transcription_with_recovery( + audio_path=args.audio_path, + output_path=args.output_path, + base_url=args.url, + hotwords=args.hotwords, + debug=args.debug, + ) if __name__ == "__main__": - test_transcription_with_recovery() + main() diff --git a/vllm_plugin/tests/zeo.mp3 b/vllm_plugin/tests/zeo.mp3 deleted file mode 100644 index e149f94..0000000 Binary files a/vllm_plugin/tests/zeo.mp3 and /dev/null differ