Merge pull request #233 from Damon-Salvetore/main
Add hot words support
This commit is contained in:
@@ -52,15 +52,25 @@ docker logs -f vibevoice-vllm
|
|||||||
Once the vLLM server is running, test it with the provided script:
|
Once the vLLM server is running, test it with the provided script:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run the test (use container path /app/...)
|
# Basic transcription
|
||||||
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav
|
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav
|
||||||
|
|
||||||
|
# With hotwords for better recognition of specific terms
|
||||||
|
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run the recover_test (use container path /app/...)
|
# With auto-recovery from repetition loops (for long audio)
|
||||||
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav
|
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav
|
||||||
|
|
||||||
|
# Auto-recover with hotwords
|
||||||
|
docker exec -it vibevoice-vllm python3 vllm_plugin/tests/test_api_auto_recover.py /app/audio.wav --hotwords "Microsoft,VibeVoice"
|
||||||
```
|
```
|
||||||
> **Note**: The audio file must be inside the mounted directory (`/app` in the container). Copy your audio to the VibeVoice folder before testing.
|
|
||||||
|
> **Note**:
|
||||||
|
> - The audio/video file must be inside the mounted directory (`/app` in the container). Copy your files to the VibeVoice folder before testing.
|
||||||
|
> - Hotwords help improve recognition of domain-specific terms like proper nouns, technical terms, and speaker names.
|
||||||
|
|
||||||
### Environment Variables
|
### Environment Variables
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
|
|||||||
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
|
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
|
||||||
|
|
||||||
from .model import VibeVoiceForCausalLM
|
from .model import VibeVoiceForCausalLM
|
||||||
from .inputs import vibevoice_audio_input_mapper
|
|
||||||
|
|
||||||
|
|
||||||
def register_vibevoice():
|
def register_vibevoice():
|
||||||
|
|||||||
+13
-160
@@ -5,17 +5,11 @@ This module implements the VibeVoice ASR model with full vLLM multimodal registr
|
|||||||
integration for speech-to-text inference.
|
integration for speech-to-text inference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal
|
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from io import BytesIO
|
|
||||||
import tempfile
|
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
|
||||||
@@ -29,32 +23,12 @@ import base64
|
|||||||
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
|
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
|
||||||
|
|
||||||
|
|
||||||
def _suffix_from_media_type(media_type: str | None) -> str:
|
def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]:
|
||||||
if not media_type:
|
"""Load audio bytes using FFmpeg via stdin-pipe decoding.
|
||||||
return ".bin"
|
|
||||||
mt = media_type.lower().strip()
|
|
||||||
if mt in ("audio/wav", "audio/x-wav", "audio/wave"):
|
|
||||||
return ".wav"
|
|
||||||
if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"):
|
|
||||||
return ".mp3"
|
|
||||||
if mt in ("audio/flac",):
|
|
||||||
return ".flac"
|
|
||||||
if mt in ("audio/ogg", "audio/opus"):
|
|
||||||
return ".ogg"
|
|
||||||
if mt in ("audio/mp4", "audio/m4a"):
|
|
||||||
return ".m4a"
|
|
||||||
if mt in ("video/mp4",):
|
|
||||||
return ".mp4"
|
|
||||||
return ".bin"
|
|
||||||
|
|
||||||
|
|
||||||
def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]:
|
|
||||||
"""Load audio bytes using FFmpeg.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
|
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
|
||||||
"""
|
"""
|
||||||
# Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency.
|
|
||||||
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
|
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
|
||||||
normalizer = AudioNormalizer()
|
normalizer = AudioNormalizer()
|
||||||
audio = normalizer(audio)
|
audio = normalizer(audio)
|
||||||
@@ -79,10 +53,10 @@ class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
|
|||||||
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
|
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
|
||||||
|
|
||||||
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
|
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
|
||||||
return _ffmpeg_load_bytes(data, media_type=None)
|
return _ffmpeg_load_bytes(data)
|
||||||
|
|
||||||
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
|
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
|
||||||
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
|
return _ffmpeg_load_bytes(base64.b64decode(data))
|
||||||
|
|
||||||
def load_file(self, filepath) -> tuple[np.ndarray, int]:
|
def load_file(self, filepath) -> tuple[np.ndarray, int]:
|
||||||
return _ffmpeg_load_file(filepath)
|
return _ffmpeg_load_file(filepath)
|
||||||
@@ -96,17 +70,13 @@ _vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
|
|||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
from transformers import Qwen2Config, BatchFeature
|
from transformers import BatchFeature
|
||||||
from transformers.models.whisper import WhisperFeatureExtractor
|
from transformers.models.whisper import WhisperFeatureExtractor
|
||||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
||||||
from vllm.config import VllmConfig, ModelConfig
|
|
||||||
from vllm.config.speech_to_text import SpeechToTextConfig
|
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.parse import MultiModalDataParser
|
from vllm.multimodal.parse import MultiModalDataParser
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
|
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
|
||||||
from vllm.inputs import PromptType
|
|
||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
init_vllm_registered_model,
|
init_vllm_registered_model,
|
||||||
maybe_prefix,
|
maybe_prefix,
|
||||||
@@ -558,7 +528,7 @@ class VibeVoiceProcessingInfo(BaseProcessingInfo):
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||||
return {"audio": None}
|
return {"audio": 1}
|
||||||
|
|
||||||
|
|
||||||
class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]):
|
class VibeVoiceDummyInputsBuilder(BaseDummyInputsBuilder[VibeVoiceProcessingInfo]):
|
||||||
@@ -873,17 +843,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
with a causal language model for text generation.
|
with a causal language model for text generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# SupportsTranscription interface
|
|
||||||
supports_transcription: ClassVar[Literal[True]] = True
|
|
||||||
supports_transcription_only: ClassVar[bool] = False
|
|
||||||
supports_segment_timestamp: ClassVar[bool] = False
|
|
||||||
|
|
||||||
# Supported languages (Chinese as primary target)
|
|
||||||
supported_languages: ClassVar[Mapping[str, str]] = {
|
|
||||||
"zh": "Chinese",
|
|
||||||
"en": "English",
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||||
"""Return the placeholder string format for a given modality.
|
"""Return the placeholder string format for a given modality.
|
||||||
@@ -897,113 +856,11 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
return "<|AUDIO|>"
|
return "<|AUDIO|>"
|
||||||
raise ValueError(f"Unsupported modality: {modality}")
|
raise ValueError(f"Unsupported modality: {modality}")
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_generation_prompt(
|
|
||||||
cls,
|
|
||||||
audio: np.ndarray,
|
|
||||||
stt_config: SpeechToTextConfig,
|
|
||||||
model_config: ModelConfig,
|
|
||||||
language: str | None,
|
|
||||||
task_type: Literal["transcribe", "translate"],
|
|
||||||
request_prompt: str,
|
|
||||||
to_language: str | None,
|
|
||||||
) -> PromptType:
|
|
||||||
"""Get the prompt for the ASR model.
|
|
||||||
|
|
||||||
Generates a chat-formatted prompt for speech-to-text transcription
|
|
||||||
with JSON output format.
|
|
||||||
"""
|
|
||||||
# If user provides custom prompt, use it
|
|
||||||
if request_prompt:
|
|
||||||
return request_prompt
|
|
||||||
|
|
||||||
# Calculate audio duration for the prompt
|
|
||||||
# Audio should be at 24kHz, so duration = len(audio) / 24000
|
|
||||||
duration = len(audio) / 24000 if audio is not None else 10.0
|
|
||||||
|
|
||||||
system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format."
|
|
||||||
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
|
|
||||||
user_suffix = (
|
|
||||||
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
|
||||||
+ ", ".join(show_keys)
|
|
||||||
)
|
|
||||||
|
|
||||||
# IMPORTANT: keep <|AUDIO|> as the only placeholder token here.
|
|
||||||
# `_get_prompt_updates` expands it into repeated `<|AUDIO|>` placeholders.
|
|
||||||
user_content = "<|AUDIO|>\n" + user_suffix
|
|
||||||
|
|
||||||
prompt = (
|
|
||||||
f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
|
|
||||||
f"<|im_start|>user\n{user_content}<|im_end|>\n"
|
|
||||||
"<|im_start|>assistant\n"
|
|
||||||
)
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_speech_to_text_config(
|
|
||||||
cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
|
|
||||||
) -> SpeechToTextConfig:
|
|
||||||
"""Get the speech to text config for the ASR model."""
|
|
||||||
return SpeechToTextConfig(
|
|
||||||
language=None, # Auto-detect or use request language
|
|
||||||
task_type=task_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_num_audio_tokens(
|
|
||||||
cls,
|
|
||||||
audio_duration_s: float,
|
|
||||||
stt_config: SpeechToTextConfig,
|
|
||||||
model_config: ModelConfig,
|
|
||||||
) -> int | None:
|
|
||||||
"""Estimate number of audio tokens from duration.
|
|
||||||
|
|
||||||
Returns the number of audio EMBEDDING positions (speech_pad_id tokens).
|
|
||||||
Note: _get_prompt_updates actually generates:
|
|
||||||
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
|
|
||||||
So total prompt tokens = N + 3, but this returns N (the embedding count).
|
|
||||||
"""
|
|
||||||
sampling_rate = 24000
|
|
||||||
compress_ratio = 3200
|
|
||||||
samples = int(audio_duration_s * sampling_rate)
|
|
||||||
num_tokens = int(np.ceil(samples / compress_ratio))
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_other_languages(cls) -> Mapping[str, str]:
|
|
||||||
"""Get languages from Whisper map not natively supported."""
|
|
||||||
# Import LANGUAGES from vllm
|
|
||||||
try:
|
|
||||||
from vllm.transformers_utils.tokenizer import LANGUAGES
|
|
||||||
except ImportError:
|
|
||||||
# Fallback to empty dict if import fails
|
|
||||||
return {}
|
|
||||||
return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_language(cls, language: str | None) -> str | None:
|
|
||||||
"""Validate the language code."""
|
|
||||||
if language is None or language in cls.supported_languages:
|
|
||||||
return language
|
|
||||||
elif language in cls.get_other_languages():
|
|
||||||
print(f"Warning: Language {language!r} is not natively supported")
|
|
||||||
return language
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported language: {language!r}. "
|
|
||||||
f"Supported: {list(cls.supported_languages.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Keep a copy of the resolved model path for any custom weight-loading
|
|
||||||
# logic (e.g., loading audio encoder weights in fp32 directly from
|
|
||||||
# safetensors shards).
|
|
||||||
self._model_path = vllm_config.model_config.model
|
|
||||||
|
|
||||||
self.audio_encoder = VibeVoiceAudioEncoder(config)
|
self.audio_encoder = VibeVoiceAudioEncoder(config)
|
||||||
|
|
||||||
# Pass decoder_config to the language model initialization
|
# Pass decoder_config to the language model initialization
|
||||||
@@ -1100,14 +957,14 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
# Process each audio through the VibeVoice encoder
|
# Process each audio through the VibeVoice encoder
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
# Get model device and dtype for alignment
|
# Get model device for tensor placement.
|
||||||
|
# dtype is NOT set here — audio_encoder.forward() handles it internally:
|
||||||
|
# input: converted to fp32 (self._audio_encoder_dtype)
|
||||||
|
# output: converted to bfloat16 (self._lm_dtype)
|
||||||
try:
|
try:
|
||||||
device = next(self.audio_encoder.parameters()).device
|
device = next(self.audio_encoder.parameters()).device
|
||||||
dtype = next(self.audio_encoder.parameters()).dtype
|
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# Fallback if no parameters (shouldn't happen)
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
# Handle both stacked tensor and list of tensors
|
# Handle both stacked tensor and list of tensors
|
||||||
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
|
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
|
||||||
@@ -1138,7 +995,7 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
if not isinstance(audio_tensor, torch.Tensor):
|
if not isinstance(audio_tensor, torch.Tensor):
|
||||||
audio_tensor = torch.tensor(audio_tensor)
|
audio_tensor = torch.tensor(audio_tensor)
|
||||||
|
|
||||||
# Let vLLM handle dtype (bfloat16 by default)
|
# Only place on correct device; audio_encoder.forward() handles dtype
|
||||||
audio_tensor = audio_tensor.to(device=device)
|
audio_tensor = audio_tensor.to(device=device)
|
||||||
|
|
||||||
# Get actual length if available, otherwise use full length
|
# Get actual length if available, otherwise use full length
|
||||||
@@ -1294,7 +1151,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
intermediate_tensors: Intermediate tensors for pipeline parallelism.
|
intermediate_tensors: Intermediate tensors for pipeline parallelism.
|
||||||
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
|
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
|
||||||
"""
|
"""
|
||||||
try:
|
|
||||||
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
|
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
|
||||||
# Only compute from input_ids if inputs_embeds is not available
|
# Only compute from input_ids if inputs_embeds is not available
|
||||||
if inputs_embeds is None and input_ids is not None:
|
if inputs_embeds is None and input_ids is not None:
|
||||||
@@ -1321,9 +1177,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# Alias for training checkpoint compatibility
|
# Alias for training checkpoint compatibility
|
||||||
VibeVoiceForASRTraining = VibeVoiceForCausalLM
|
VibeVoiceForASRTraining = VibeVoiceForCausalLM
|
||||||
|
|||||||
Binary file not shown.
+101
-85
@@ -1,14 +1,23 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Test VibeVoice vLLM API with Streaming (Real-time output).
|
Test VibeVoice vLLM API with Streaming and Optional Hotwords Support.
|
||||||
|
|
||||||
|
This script tests ASR transcription via the vLLM OpenAI-compatible API.
|
||||||
|
By default, it runs standard transcription without hotwords.
|
||||||
|
|
||||||
|
Optionally, you can provide hotwords (context_info) to improve recognition
|
||||||
|
of domain-specific content like proper nouns, technical terms, and speaker names.
|
||||||
|
Hotwords are embedded in the prompt as "with extra info: {hotwords}".
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python test_api.py [audio_path] [--url URL]
|
python test_api.py <audio_path> [--url URL] [--hotwords "word1,word2"]
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
python test_api.py # Use default audio
|
# Standard transcription (no hotwords)
|
||||||
python test_api.py /path/to/audio.wav # Specify audio file
|
python3 test_api.py audio.wav
|
||||||
python test_api.py /path/to/audio.mp3 --url http://localhost:8000 # Custom URL
|
|
||||||
|
# With hotwords for better recognition of specific terms
|
||||||
|
python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice"
|
||||||
"""
|
"""
|
||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
@@ -21,38 +30,38 @@ import argparse
|
|||||||
|
|
||||||
|
|
||||||
def _guess_mime_type(path: str) -> str:
|
def _guess_mime_type(path: str) -> str:
|
||||||
|
"""Guess MIME type from file extension."""
|
||||||
ext = os.path.splitext(path)[1].lower()
|
ext = os.path.splitext(path)[1].lower()
|
||||||
if ext == ".wav":
|
mime_map = {
|
||||||
return "audio/wav"
|
".wav": "audio/wav",
|
||||||
if ext in (".mp3",):
|
".mp3": "audio/mpeg",
|
||||||
return "audio/mpeg"
|
".m4a": "audio/mp4",
|
||||||
if ext in (".m4a",):
|
".mp4": "video/mp4",
|
||||||
return "audio/mp4"
|
".flac": "audio/flac",
|
||||||
if ext in (".mp4", ".m4v", ".mov", ".webm"):
|
".ogg": "audio/ogg",
|
||||||
return "video/mp4"
|
".opus": "audio/ogg",
|
||||||
if ext in (".flac",):
|
}
|
||||||
return "audio/flac"
|
return mime_map.get(ext, "application/octet-stream")
|
||||||
if ext in (".ogg", ".opus"):
|
|
||||||
return "audio/ogg"
|
|
||||||
return "application/octet-stream"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_duration_seconds_ffprobe(path: str) -> float:
|
def _get_duration_seconds_ffprobe(path: str) -> float:
|
||||||
"""Get audio duration using ffprobe."""
|
"""Get audio duration using ffprobe."""
|
||||||
cmd = [
|
cmd = [
|
||||||
"ffprobe",
|
"ffprobe", "-v", "error",
|
||||||
"-v",
|
"-show_entries", "format=duration",
|
||||||
"error",
|
"-of", "default=noprint_wrappers=1:nokey=1",
|
||||||
"-show_entries",
|
|
||||||
"format=duration",
|
|
||||||
"-of",
|
|
||||||
"default=noprint_wrappers=1:nokey=1",
|
|
||||||
path,
|
path,
|
||||||
]
|
]
|
||||||
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
|
out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
|
||||||
return float(out)
|
return float(out)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_video_file(path: str) -> bool:
|
||||||
|
"""Check if the file is a video file that needs audio extraction."""
|
||||||
|
ext = os.path.splitext(path)[1].lower()
|
||||||
|
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
|
||||||
|
|
||||||
|
|
||||||
def _extract_audio_from_video(video_path: str) -> str:
|
def _extract_audio_from_video(video_path: str) -> str:
|
||||||
"""
|
"""
|
||||||
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
|
Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
|
||||||
@@ -74,26 +83,40 @@ def _extract_audio_from_video(video_path: str) -> str:
|
|||||||
return audio_path
|
return audio_path
|
||||||
|
|
||||||
|
|
||||||
def _is_video_file(path: str) -> bool:
|
def test_transcription_with_hotwords(
|
||||||
"""Check if the file is a video file that needs audio extraction."""
|
audio_path: str,
|
||||||
ext = os.path.splitext(path)[1].lower()
|
context_info: str = None,
|
||||||
return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")
|
base_url: str = "http://localhost:8000",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test streaming ASR transcription, optionally with customized hotwords.
|
||||||
|
|
||||||
|
Hotwords are embedded in the prompt text as "with extra info: {hotwords}".
|
||||||
|
This helps the model recognize domain-specific terms more accurately.
|
||||||
|
|
||||||
def test_transcription(audio_path: str, base_url: str = "http://localhost:8000"):
|
Args:
|
||||||
"""Test ASR transcription with streaming output."""
|
audio_path: Path to the audio file
|
||||||
|
context_info: Optional hotwords string (e.g., "Microsoft,Azure,VibeVoice"); when None or blank, the standard prompt without hotwords is sent
|
||||||
|
base_url: vLLM server URL
|
||||||
|
"""
|
||||||
|
|
||||||
print(f"Loading audio from: {audio_path}")
|
print(f"=" * 70)
|
||||||
|
print(f"Testing Customized Hotwords Support")
|
||||||
|
print(f"=" * 70)
|
||||||
|
print(f"Input file: {audio_path}")
|
||||||
|
print(f"Hotwords: {context_info or '(none)'}")
|
||||||
|
print()
|
||||||
|
|
||||||
# Handle video files: extract audio first
|
# Handle video files: extract audio first
|
||||||
temp_audio_path = None
|
temp_audio_path = None
|
||||||
actual_audio_path = audio_path
|
actual_audio_path = audio_path
|
||||||
if _is_video_file(audio_path):
|
if _is_video_file(audio_path):
|
||||||
print(f"Detected video file, extracting audio...")
|
print(f"🎬 Detected video file, extracting audio...")
|
||||||
temp_audio_path = _extract_audio_from_video(audio_path)
|
temp_audio_path = _extract_audio_from_video(audio_path)
|
||||||
actual_audio_path = temp_audio_path
|
actual_audio_path = temp_audio_path
|
||||||
print(f"Audio extracted to: {temp_audio_path}")
|
print(f"✅ Audio extracted to: {temp_audio_path}")
|
||||||
|
|
||||||
|
# Load audio
|
||||||
try:
|
try:
|
||||||
duration = _get_duration_seconds_ffprobe(actual_audio_path)
|
duration = _get_duration_seconds_ffprobe(actual_audio_path)
|
||||||
print(f"Audio duration: {duration:.2f} seconds")
|
print(f"Audio duration: {duration:.2f} seconds")
|
||||||
@@ -106,16 +129,30 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error preparing audio: {e}")
|
print(f"Error preparing audio: {e}")
|
||||||
|
# Cleanup temp file if created
|
||||||
|
if temp_audio_path and os.path.exists(temp_audio_path):
|
||||||
|
os.remove(temp_audio_path)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Build the request
|
# Build the request
|
||||||
url = f"{base_url}/v1/chat/completions"
|
url = f"{base_url}/v1/chat/completions"
|
||||||
|
|
||||||
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
|
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
|
||||||
|
|
||||||
|
# Build prompt with optional hotwords
|
||||||
|
# Hotwords are embedded as "with extra info: {hotwords}" in the prompt
|
||||||
|
if context_info and context_info.strip():
|
||||||
|
prompt_text = (
|
||||||
|
f"This is a {duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\n"
|
||||||
|
f"Please transcribe it with these keys: " + ", ".join(show_keys)
|
||||||
|
)
|
||||||
|
print(f"\n📝 Hotwords embedded in prompt: '{context_info}'")
|
||||||
|
else:
|
||||||
prompt_text = (
|
prompt_text = (
|
||||||
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
||||||
+ ", ".join(show_keys)
|
+ ", ".join(show_keys)
|
||||||
)
|
)
|
||||||
|
print(f"\n📝 No hotwords provided")
|
||||||
|
|
||||||
mime = _guess_mime_type(actual_audio_path)
|
mime = _guess_mime_type(actual_audio_path)
|
||||||
data_url = f"data:{mime};base64,{audio_b64}"
|
data_url = f"data:{mime};base64,{audio_b64}"
|
||||||
@@ -139,20 +176,19 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
|
|||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"top_p": 1.0,
|
"top_p": 1.0,
|
||||||
"repetition_penalty": 1.0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"\nSending request to {url} (Streaming Mode)...")
|
print(f"\n{'=' * 70}")
|
||||||
print(f"Prompt: {prompt_text}")
|
print(f"Sending request to {url}")
|
||||||
print("-" * 60)
|
print(f"{'=' * 70}")
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
try:
|
try:
|
||||||
|
|
||||||
response = requests.post(url, json=payload, stream=True, timeout=12000)
|
response = requests.post(url, json=payload, stream=True, timeout=12000)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
print("Response received. Streaming content:\n")
|
print("\n✅ Response received. Streaming content:\n")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
printed = ""
|
printed = ""
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
@@ -162,92 +198,72 @@ def test_transcription(audio_path: str, base_url: str = "http://localhost:8000")
|
|||||||
if decoded_line.startswith("data: "):
|
if decoded_line.startswith("data: "):
|
||||||
json_str = decoded_line[6:]
|
json_str = decoded_line[6:]
|
||||||
if json_str.strip() == "[DONE]":
|
if json_str.strip() == "[DONE]":
|
||||||
print("\n\n[Finished]")
|
print("\n" + "-" * 50)
|
||||||
|
print("✅ [Finished]")
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
data = json.loads(json_str)
|
data = json.loads(json_str)
|
||||||
|
|
||||||
delta = data['choices'][0]['delta']
|
delta = data['choices'][0]['delta']
|
||||||
content = delta.get('content', '')
|
content = delta.get('content', '')
|
||||||
if content:
|
if content:
|
||||||
|
|
||||||
# vLLM/OpenAI-compatible streams may emit either
|
|
||||||
# incremental deltas OR the full accumulated text.
|
|
||||||
# Only print the newly-added part to avoid repeats.
|
|
||||||
if content.startswith(printed):
|
if content.startswith(printed):
|
||||||
to_print = content[len(printed):]
|
to_print = content[len(printed):]
|
||||||
else:
|
else:
|
||||||
to_print = content
|
to_print = content
|
||||||
|
|
||||||
if to_print:
|
if to_print:
|
||||||
print(to_print, end='', flush=True)
|
print(to_print, end='', flush=True)
|
||||||
printed += to_print
|
printed += to_print
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
print(f"Error: {response.status_code}")
|
print(f"❌ Error: {response.status_code}")
|
||||||
print(response.text)
|
print(response.text)
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
except requests.exceptions.Timeout:
|
||||||
print("\nRequest timed out!")
|
print("❌ Request timed out!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\nError: {e}")
|
print(f"❌ Error: {e}")
|
||||||
|
|
||||||
print(f"\n{'-'*60}")
|
elapsed = time.time() - t0
|
||||||
print(f"Total time elapsed: {time.time() - t0:.2f}s")
|
print(f"\n{'=' * 70}")
|
||||||
|
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
|
||||||
|
print(f"📊 RTF (Real-Time Factor): {elapsed / duration:.2f}x")
|
||||||
|
print(f"{'=' * 70}")
|
||||||
|
|
||||||
# Cleanup temp audio file if created
|
# Cleanup temp audio file if created
|
||||||
if temp_audio_path and os.path.exists(temp_audio_path):
|
if temp_audio_path and os.path.exists(temp_audio_path):
|
||||||
os.remove(temp_audio_path)
|
os.remove(temp_audio_path)
|
||||||
print(f"Cleaned up temp file: {temp_audio_path}")
|
print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Test VibeVoice vLLM API with streaming output"
|
description="Test VibeVoice vLLM API with Customized Hotwords"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"audio_path",
|
"audio_path",
|
||||||
nargs="?",
|
|
||||||
default=None,
|
|
||||||
help="Path to audio file (wav, mp3, flac, etc.) or video file"
|
help="Path to audio file (wav, mp3, flac, etc.) or video file"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--url",
|
"--url",
|
||||||
default="http://localhost:8000",
|
default="http://localhost:8000",
|
||||||
help="vLLM server base URL (default: http://localhost:8000)"
|
help="vLLM server URL (default: http://localhost:8000)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hotwords",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Find default audio if not specified
|
# Run test
|
||||||
audio_path = args.audio_path
|
test_transcription_with_hotwords(
|
||||||
if audio_path is None:
|
audio_path=args.audio_path,
|
||||||
# Try to find a sample audio in common locations
|
context_info=args.hotwords,
|
||||||
possible_paths = [
|
base_url=args.url,
|
||||||
# In VibeVoice demo folder
|
)
|
||||||
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "en-Carter_man.wav"),
|
|
||||||
os.path.join(os.path.dirname(__file__), "..", "..", "demo", "voices", "zh-Anchen_man_bgm.wav"),
|
|
||||||
# Relative to current directory
|
|
||||||
"demo/voices/en-Carter_man.wav",
|
|
||||||
"demo/voices/zh-Anchen_man_bgm.wav",
|
|
||||||
]
|
|
||||||
|
|
||||||
for path in possible_paths:
|
|
||||||
if os.path.exists(path):
|
|
||||||
audio_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if audio_path is None:
|
|
||||||
print("Error: No audio file specified and no default audio found.")
|
|
||||||
print("Usage: python test_api.py <audio_path>")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not os.path.exists(audio_path):
|
|
||||||
print(f"Error: Audio file not found: {audio_path}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
test_transcription(audio_path, args.url)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,18 +1,36 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
VibeVoice vLLM API with Auto-Recovery from Repetition Loops.
|
Test VibeVoice vLLM API with Streaming, Hotwords, and Auto-Recovery.
|
||||||
|
|
||||||
Strategy:
|
This script tests ASR transcription with automatic recovery from repetition loops.
|
||||||
1. Start with greedy decoding (temperature=0, top_p=1.0)
|
Supports optional hotwords to improve recognition of domain-specific terms.
|
||||||
2. Stream and detect repetition patterns in real-time
|
|
||||||
3. Only output content up to (current_length - window_size) at segment boundaries
|
|
||||||
4. When loop detected:
|
|
||||||
- Truncate to last complete segment boundary (},)
|
|
||||||
- Recovery with temperature=0.2/0.3/0.4 for retry 1/2/3, top_p=0.95
|
|
||||||
5. Max 3 retries, if all fail output error message
|
|
||||||
|
|
||||||
User sees: clean streaming transcription output (only complete segments)
|
Features:
|
||||||
Internal: automatic recovery from repetition loops (silent)
|
- Streaming output with real-time repetition detection
|
||||||
|
- Auto-recovery when model enters repetition loops
|
||||||
|
- Optional hotwords support (embedded in prompt as "with extra info: {hotwords}")
|
||||||
|
- Video file support (auto-extracts audio)
|
||||||
|
|
||||||
|
Recovery Strategy:
|
||||||
|
1. First attempt: greedy decoding (temperature=0, top_p=1.0)
|
||||||
|
2. If loop detected: retry with temperature=0.2/0.3/0.4, top_p=0.95
|
||||||
|
3. Max 3 retries, truncate to last complete segment boundary
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python test_api_auto_recover.py <audio_path> [output_path] [--url URL] [--hotwords "word1,word2"] [--debug]
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Basic usage
|
||||||
|
python3 test_api_auto_recover.py audio.wav
|
||||||
|
|
||||||
|
# With hotwords
|
||||||
|
python3 test_api_auto_recover.py audio.wav --hotwords "Microsoft,VibeVoice"
|
||||||
|
|
||||||
|
# Save result to file
|
||||||
|
python3 test_api_auto_recover.py audio.wav result.txt
|
||||||
|
|
||||||
|
# Debug mode (show recovery info)
|
||||||
|
python3 test_api_auto_recover.py audio.wav --debug
|
||||||
"""
|
"""
|
||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
@@ -22,6 +40,7 @@ import sys
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import re
|
import re
|
||||||
|
import argparse
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
|
|
||||||
@@ -441,30 +460,41 @@ def stream_with_recovery(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def test_transcription_with_recovery():
|
def test_transcription_with_recovery(
|
||||||
"""Main test function with auto-recovery."""
|
audio_path: str,
|
||||||
|
output_path: str = None,
|
||||||
|
base_url: str = "http://localhost:8000",
|
||||||
|
hotwords: str = None,
|
||||||
|
debug: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test ASR transcription with auto-recovery from repetition loops.
|
||||||
|
|
||||||
# Parse arguments
|
Args:
|
||||||
debug = "--debug" in sys.argv or "-debug" in sys.argv
|
audio_path: Path to the audio file
|
||||||
args = [a for a in sys.argv[1:] if not a.startswith("-")]
|
output_path: Optional path to save transcription result
|
||||||
|
base_url: vLLM server URL
|
||||||
|
hotwords: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
|
||||||
|
debug: Show recovery debug info
|
||||||
|
"""
|
||||||
|
|
||||||
audio_path = (
|
print(f"=" * 70)
|
||||||
args[0]
|
print(f"Testing with Auto-Recovery")
|
||||||
)
|
print(f"=" * 70)
|
||||||
|
print(f"Input file: {audio_path}")
|
||||||
output_path = args[1] if len(args) > 1 else None
|
print(f"Hotwords: {hotwords or '(none)'}")
|
||||||
|
print()
|
||||||
print(f"Loading audio from: {audio_path}")
|
|
||||||
|
|
||||||
# Handle video files: extract audio first
|
# Handle video files: extract audio first
|
||||||
temp_audio_path = None
|
temp_audio_path = None
|
||||||
actual_audio_path = audio_path
|
actual_audio_path = audio_path
|
||||||
if _is_video_file(audio_path):
|
if _is_video_file(audio_path):
|
||||||
print(f"Detected video file, extracting audio...")
|
print(f"🎬 Detected video file, extracting audio...")
|
||||||
temp_audio_path = _extract_audio_from_video(audio_path)
|
temp_audio_path = _extract_audio_from_video(audio_path)
|
||||||
actual_audio_path = temp_audio_path
|
actual_audio_path = temp_audio_path
|
||||||
print(f"Audio extracted to: {temp_audio_path}")
|
print(f"✅ Audio extracted to: {temp_audio_path}")
|
||||||
|
|
||||||
|
# Load audio
|
||||||
try:
|
try:
|
||||||
duration = _get_duration_seconds_ffprobe(actual_audio_path)
|
duration = _get_duration_seconds_ffprobe(actual_audio_path)
|
||||||
print(f"Audio duration: {duration:.2f} seconds")
|
print(f"Audio duration: {duration:.2f} seconds")
|
||||||
@@ -476,16 +506,29 @@ def test_transcription_with_recovery():
|
|||||||
print(f"Audio size: {len(audio_bytes)} bytes")
|
print(f"Audio size: {len(audio_bytes)} bytes")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error preparing audio: {e}")
|
print(f"❌ Error preparing audio: {e}")
|
||||||
|
# Cleanup temp file if created
|
||||||
|
if temp_audio_path and os.path.exists(temp_audio_path):
|
||||||
|
os.remove(temp_audio_path)
|
||||||
return
|
return
|
||||||
|
|
||||||
url = "http://localhost:8000/v1/chat/completions"
|
url = f"{base_url}/v1/chat/completions"
|
||||||
|
|
||||||
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
|
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
|
||||||
|
|
||||||
|
# Build prompt with optional hotwords
|
||||||
|
if hotwords and hotwords.strip():
|
||||||
|
prompt_text = (
|
||||||
|
f"This is a {duration:.2f} seconds audio, with extra info: {hotwords.strip()}\n\n"
|
||||||
|
f"Please transcribe it with these keys: " + ", ".join(show_keys)
|
||||||
|
)
|
||||||
|
print(f"\n📝 Hotwords embedded in prompt: '{hotwords}'")
|
||||||
|
else:
|
||||||
prompt_text = (
|
prompt_text = (
|
||||||
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
||||||
+ ", ".join(show_keys)
|
+ ", ".join(show_keys)
|
||||||
)
|
)
|
||||||
|
print(f"\n📝 No hotwords provided")
|
||||||
|
|
||||||
mime = _guess_mime_type(actual_audio_path)
|
mime = _guess_mime_type(actual_audio_path)
|
||||||
data_url = f"data:{mime};base64,{audio_b64}"
|
data_url = f"data:{mime};base64,{audio_b64}"
|
||||||
@@ -505,12 +548,13 @@ def test_transcription_with_recovery():
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
print(f"\nSending request to {url} (Streaming Mode)...")
|
print(f"\n{'=' * 70}")
|
||||||
print(f"Prompt: {prompt_text}")
|
print(f"Sending request to {url}")
|
||||||
print("-" * 60)
|
print(f"{'=' * 70}")
|
||||||
print("Response received. Streaming content:\n")
|
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
print("\n✅ Response received. Streaming content:\n")
|
||||||
|
print("-" * 50)
|
||||||
|
|
||||||
result = stream_with_recovery(
|
result = stream_with_recovery(
|
||||||
url=url,
|
url=url,
|
||||||
@@ -522,27 +566,73 @@ def test_transcription_with_recovery():
|
|||||||
debug=debug,
|
debug=debug,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\n[Finished]")
|
elapsed = time.time() - t0
|
||||||
print("-" * 60)
|
print("-" * 50)
|
||||||
print(f"Total time elapsed: {time.time() - t0:.2f}s")
|
print("✅ [Finished]")
|
||||||
|
print(f"\n{'=' * 70}")
|
||||||
|
print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
|
||||||
|
print(f"{'=' * 70}")
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
print("Transcription failed")
|
print("❌ Transcription failed")
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"Final output length: {len(result)} chars")
|
print(f"📄 Final output length: {len(result)} chars")
|
||||||
|
|
||||||
# Optionally save result
|
# Optionally save result
|
||||||
if output_path:
|
if output_path:
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
f.write(result)
|
f.write(result)
|
||||||
print(f"Result saved to: {output_path}")
|
print(f"💾 Result saved to: {output_path}")
|
||||||
|
|
||||||
# Cleanup temp audio file if created
|
# Cleanup temp audio file if created
|
||||||
if temp_audio_path and os.path.exists(temp_audio_path):
|
if temp_audio_path and os.path.exists(temp_audio_path):
|
||||||
os.remove(temp_audio_path)
|
os.remove(temp_audio_path)
|
||||||
print(f"Cleaned up temp file: {temp_audio_path}")
|
print(f"🗑️ Cleaned up temp file: {temp_audio_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Test VibeVoice vLLM API with auto-recovery from repetition loops"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"audio_path",
|
||||||
|
help="Path to audio file (wav, mp3, flac, etc.) or video file"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"output_path",
|
||||||
|
nargs="?",
|
||||||
|
default=None,
|
||||||
|
help="Optional path to save transcription result"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--url",
|
||||||
|
default="http://localhost:8000",
|
||||||
|
help="vLLM server URL (default: http://localhost:8000)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hotwords",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
action="store_true",
|
||||||
|
help="Show recovery debug info"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Run test
|
||||||
|
test_transcription_with_recovery(
|
||||||
|
audio_path=args.audio_path,
|
||||||
|
output_path=args.output_path,
|
||||||
|
base_url=args.url,
|
||||||
|
hotwords=args.hotwords,
|
||||||
|
debug=args.debug,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_transcription_with_recovery()
|
main()
|
||||||
|
|||||||
Binary file not shown.
Reference in New Issue
Block a user