This commit is contained in:
YingboHAO
2026-02-06 14:38:16 +00:00
parent 7761242bf3
commit 0508c3e86f
2 changed files with 37 additions and 185 deletions
-1
View File
@@ -15,7 +15,6 @@ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
from .model import VibeVoiceForCausalLM
from .inputs import vibevoice_audio_input_mapper
def register_vibevoice():
+37 -184
View File
@@ -5,17 +5,11 @@ This module implements the VibeVoice ASR model with full vLLM multimodal registr
integration for speech-to-text inference.
"""
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal
import json
import math
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
from io import BytesIO
import tempfile
import base64
@@ -29,32 +23,12 @@ import base64
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
def _suffix_from_media_type(media_type: str | None) -> str:
if not media_type:
return ".bin"
mt = media_type.lower().strip()
if mt in ("audio/wav", "audio/x-wav", "audio/wave"):
return ".wav"
if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"):
return ".mp3"
if mt in ("audio/flac",):
return ".flac"
if mt in ("audio/ogg", "audio/opus"):
return ".ogg"
if mt in ("audio/mp4", "audio/m4a"):
return ".m4a"
if mt in ("video/mp4",):
return ".mp4"
return ".bin"
def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]:
"""Load audio bytes using FFmpeg.
def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]:
"""Load audio bytes using FFmpeg via stdin-pipe decoding.
Returns:
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
"""
# Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency.
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
normalizer = AudioNormalizer()
audio = normalizer(audio)
@@ -79,10 +53,10 @@ class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(data, media_type=None)
return _ffmpeg_load_bytes(data)
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
return _ffmpeg_load_bytes(base64.b64decode(data))
def load_file(self, filepath) -> tuple[np.ndarray, int]:
return _ffmpeg_load_file(filepath)
@@ -96,17 +70,13 @@ _vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
# ============================================================================
from transformers import Qwen2Config, BatchFeature
from transformers import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.config import VllmConfig, ModelConfig
from vllm.config.speech_to_text import SpeechToTextConfig
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import MultiModalDataParser
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
from vllm.inputs import PromptType
from vllm.model_executor.models.utils import (
init_vllm_registered_model,
maybe_prefix,
@@ -873,17 +843,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
with a causal language model for text generation.
"""
# SupportsTranscription interface
# NOTE(review): the class bases visible at the class statement are nn.Module,
# SupportsMultiModal and SupportsPP; SupportsTranscription itself is not
# listed, so these flags are presumably picked up structurally - confirm.
supports_transcription: ClassVar[Literal[True]] = True
supports_transcription_only: ClassVar[bool] = False
supports_segment_timestamp: ClassVar[bool] = False
# Supported languages (Chinese as primary target)
# Maps language code -> display name. validate_language() accepts these codes
# without a warning and get_other_languages() excludes them from its result.
supported_languages: ClassVar[Mapping[str, str]] = {
"zh": "Chinese",
"en": "English",
}
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""Return the placeholder string format for a given modality.
@@ -897,112 +856,10 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return "<|AUDIO|>"
raise ValueError(f"Unsupported modality: {modality}")
@classmethod
def get_generation_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType:
    """Build the chat-formatted transcription prompt for the ASR model.

    A non-empty ``request_prompt`` is returned unchanged. Otherwise the
    prompt consists of a system turn asking for JSON output, a user turn
    holding a single ``<|AUDIO|>`` placeholder plus a sentence stating the
    clip duration and the requested keys, and an opened assistant turn.

    ``stt_config``, ``model_config``, ``language``, ``task_type`` and
    ``to_language`` are part of the interface but are not read here.

    Args:
        audio: Waveform samples; only its length is used. ``None`` makes
            the prompt state a duration of 10 seconds.
        request_prompt: Caller-supplied prompt that overrides the default.

    Returns:
        The prompt string.
    """
    # A caller-supplied prompt wins over the generated one.
    if request_prompt:
        return request_prompt

    # The duration is derived assuming 24 kHz samples; 10.0 s is the
    # placeholder value when no waveform was handed in.
    if audio is None:
        seconds = 10.0
    else:
        seconds = len(audio) / 24000

    system_text = "You are a helpful assistant that transcribes audio input into text output in JSON format."
    requested_keys = ", ".join(["Start time", "End time", "Speaker ID", "Content"])
    instruction = (
        f"This is a {seconds:.2f} seconds audio, please transcribe it with these keys: "
        + requested_keys
    )

    # IMPORTANT: <|AUDIO|> must stay the only placeholder token in the user
    # turn; per the original note, `_get_prompt_updates` expands it into
    # repeated `<|AUDIO|>` placeholders.
    user_text = "<|AUDIO|>\n" + instruction

    turns = (
        f"<|im_start|>system\n{system_text}<|im_end|>\n",
        f"<|im_start|>user\n{user_text}<|im_end|>\n",
        "<|im_start|>assistant\n",
    )
    return "".join(turns)
@classmethod
def get_speech_to_text_config(
    cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
) -> SpeechToTextConfig:
    """Create the speech-to-text configuration for the requested task.

    ``model_config`` is accepted for interface compatibility and is not read.

    Args:
        task_type: Forwarded unchanged into the configuration.

    Returns:
        A ``SpeechToTextConfig`` with ``language`` left as ``None`` (to be
        auto-detected or taken from the request) and the given ``task_type``.
    """
    # NOTE(review): only language and task_type are set; every other config
    # field keeps its default - confirm the default sample rate matches the
    # 24 kHz audio assumed elsewhere in this file.
    config_fields = {
        "language": None,  # Auto-detect or use request language
        "task_type": task_type,
    }
    return SpeechToTextConfig(**config_fields)
@classmethod
def get_num_audio_tokens(
    cls,
    audio_duration_s: float,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
) -> int | None:
    """Estimate how many audio embedding positions a clip occupies.

    The duration is converted to a whole number of samples at 24000 samples
    per second, then divided by 3200 samples per token and rounded up.
    ``stt_config`` and ``model_config`` are not read.

    The returned value N counts only the audio embedding positions
    (speech_pad_id tokens). Per the original note, ``_get_prompt_updates``
    emits ``[speech_start_id] + [speech_pad_id] * N + [speech_end_id] +
    [newline_id]``, so the prompt holds N + 3 tokens in total.

    Args:
        audio_duration_s: Clip length in seconds.

    Returns:
        The estimated number of audio embedding positions.
    """
    samples_per_second = 24000
    samples_per_token = 3200
    # Truncate to whole samples first, then round the token count up so a
    # partial final window still gets a position.
    total_samples = int(audio_duration_s * samples_per_second)
    token_count = np.ceil(total_samples / samples_per_token)
    return int(token_count)
@classmethod
def get_other_languages(cls) -> Mapping[str, str]:
    """Return the languages of vLLM's ``LANGUAGES`` map that are not native.

    ``LANGUAGES`` is imported lazily from ``vllm.transformers_utils.tokenizer``.
    When that import fails, an empty dict is returned.

    Returns:
        Every ``code -> name`` entry of ``LANGUAGES`` whose code is absent
        from ``cls.supported_languages``, or ``{}`` if the map is unavailable.
    """
    try:
        from vllm.transformers_utils.tokenizer import LANGUAGES
    except ImportError:
        # The language map is unavailable: report no additional languages.
        return {}
    else:
        native_codes = cls.supported_languages
        return {
            code: name
            for code, name in LANGUAGES.items()
            if code not in native_codes
        }
@classmethod
def validate_language(cls, language: str | None) -> str | None:
    """Validate a requested language code.

    Args:
        language: Language code from the request, or ``None`` when the
            caller did not specify one.

    Returns:
        ``language`` unchanged when it is ``None``, is a key of
        ``cls.supported_languages``, or is a key of
        ``cls.get_other_languages()``. In the last case a warning is logged
        because the language is accepted but not natively supported.

    Raises:
        ValueError: If ``language`` is in neither mapping.
    """
    if language is None or language in cls.supported_languages:
        return language

    if language in cls.get_other_languages():
        # Report through logging rather than print() so the message reaches
        # the configured log handlers instead of raw stdout.
        import logging

        logging.getLogger(__name__).warning(
            "Language %r is not natively supported", language
        )
        return language

    raise ValueError(
        f"Unsupported language: {language!r}. "
        f"Supported: {list(cls.supported_languages.keys())}"
    )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
# Keep a copy of the resolved model path for any custom weight-loading
# logic (e.g., loading audio encoder weights in fp32 directly from
# safetensors shards).
self._model_path = vllm_config.model_config.model
self.audio_encoder = VibeVoiceAudioEncoder(config)
@@ -1100,14 +957,14 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# Process each audio through the VibeVoice encoder
embeddings = []
# Get model device and dtype for alignment
# Get model device for tensor placement.
# dtype is NOT set here — audio_encoder.forward() handles it internally:
# input: converted to fp32 (self._audio_encoder_dtype)
# output: converted to bfloat16 (self._lm_dtype)
try:
device = next(self.audio_encoder.parameters()).device
dtype = next(self.audio_encoder.parameters()).dtype
except StopIteration:
# Fallback if no parameters (shouldn't happen)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
# Handle both stacked tensor and list of tensors
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
@@ -1138,7 +995,7 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if not isinstance(audio_tensor, torch.Tensor):
audio_tensor = torch.tensor(audio_tensor)
# Let vLLM handle dtype (bfloat16 by default)
# Only place on correct device; audio_encoder.forward() handles dtype
audio_tensor = audio_tensor.to(device=device)
# Get actual length if available, otherwise use full length
@@ -1294,35 +1151,31 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Intermediate tensors for pipeline parallelism.
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
"""
try:
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
# Only compute from input_ids if inputs_embeds is not available
if inputs_embeds is None and input_ids is not None:
# Compute embeddings from input_ids
inputs_embeds = self.get_input_embeddings()(input_ids)
# If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
if intermediate_tensors is not None:
inputs_embeds = None
# Get the inner model - handle both wrapped and direct language models
language_model = self.language_model
if hasattr(language_model, "language_model"):
language_model = language_model.language_model
# Call the language model's model (Qwen2Model)
# vLLM V1 passes kv_caches and attn_metadata via context, not arguments
# IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
hidden_states = language_model.model(
input_ids=None, # Always None when we have inputs_embeds
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds
)
return hidden_states
except Exception as e:
raise
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
# Only compute from input_ids if inputs_embeds is not available
if inputs_embeds is None and input_ids is not None:
# Compute embeddings from input_ids
inputs_embeds = self.get_input_embeddings()(input_ids)
# If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
if intermediate_tensors is not None:
inputs_embeds = None
# Get the inner model - handle both wrapped and direct language models
language_model = self.language_model
if hasattr(language_model, "language_model"):
language_model = language_model.language_model
# Call the language model's model (Qwen2Model)
# vLLM V1 passes kv_caches and attn_metadata via context, not arguments
# IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
hidden_states = language_model.model(
input_ids=None, # Always None when we have inputs_embeds
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds
)
return hidden_states
# Alias for training checkpoint compatibility