fix
This commit is contained in:
@@ -15,7 +15,6 @@ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
|
|||||||
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
|
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast
|
||||||
|
|
||||||
from .model import VibeVoiceForCausalLM
|
from .model import VibeVoiceForCausalLM
|
||||||
from .inputs import vibevoice_audio_input_mapper
|
|
||||||
|
|
||||||
|
|
||||||
def register_vibevoice():
|
def register_vibevoice():
|
||||||
|
|||||||
+37
-184
@@ -5,17 +5,11 @@ This module implements the VibeVoice ASR model with full vLLM multimodal registr
|
|||||||
integration for speech-to-text inference.
|
integration for speech-to-text inference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal
|
from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from io import BytesIO
|
|
||||||
import tempfile
|
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
|
||||||
@@ -29,32 +23,12 @@ import base64
|
|||||||
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
|
from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer
|
||||||
|
|
||||||
|
|
||||||
def _suffix_from_media_type(media_type: str | None) -> str:
|
def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]:
|
||||||
if not media_type:
|
"""Load audio bytes using FFmpeg via stdin-pipe decoding.
|
||||||
return ".bin"
|
|
||||||
mt = media_type.lower().strip()
|
|
||||||
if mt in ("audio/wav", "audio/x-wav", "audio/wave"):
|
|
||||||
return ".wav"
|
|
||||||
if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"):
|
|
||||||
return ".mp3"
|
|
||||||
if mt in ("audio/flac",):
|
|
||||||
return ".flac"
|
|
||||||
if mt in ("audio/ogg", "audio/opus"):
|
|
||||||
return ".ogg"
|
|
||||||
if mt in ("audio/mp4", "audio/m4a"):
|
|
||||||
return ".m4a"
|
|
||||||
if mt in ("video/mp4",):
|
|
||||||
return ".mp4"
|
|
||||||
return ".bin"
|
|
||||||
|
|
||||||
|
|
||||||
def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]:
|
|
||||||
"""Load audio bytes using FFmpeg.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
|
Tuple of (audio_waveform, sample_rate). Sample rate is always 24000.
|
||||||
"""
|
"""
|
||||||
# Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency.
|
|
||||||
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
|
audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000)
|
||||||
normalizer = AudioNormalizer()
|
normalizer = AudioNormalizer()
|
||||||
audio = normalizer(audio)
|
audio = normalizer(audio)
|
||||||
@@ -79,10 +53,10 @@ class _PatchedAudioMediaIO(_OriginalAudioMediaIO):
|
|||||||
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
|
"""AudioMediaIO implementation using FFmpeg for audio decoding."""
|
||||||
|
|
||||||
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
|
def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]:
|
||||||
return _ffmpeg_load_bytes(data, media_type=None)
|
return _ffmpeg_load_bytes(data)
|
||||||
|
|
||||||
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
|
def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]:
|
||||||
return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type)
|
return _ffmpeg_load_bytes(base64.b64decode(data))
|
||||||
|
|
||||||
def load_file(self, filepath) -> tuple[np.ndarray, int]:
|
def load_file(self, filepath) -> tuple[np.ndarray, int]:
|
||||||
return _ffmpeg_load_file(filepath)
|
return _ffmpeg_load_file(filepath)
|
||||||
@@ -96,17 +70,13 @@ _vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO
|
|||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
from transformers import Qwen2Config, BatchFeature
|
from transformers import BatchFeature
|
||||||
from transformers.models.whisper import WhisperFeatureExtractor
|
from transformers.models.whisper import WhisperFeatureExtractor
|
||||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
||||||
from vllm.config import VllmConfig, ModelConfig
|
|
||||||
from vllm.config.speech_to_text import SpeechToTextConfig
|
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.parse import MultiModalDataParser
|
from vllm.multimodal.parse import MultiModalDataParser
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
|
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings
|
||||||
from vllm.inputs import PromptType
|
|
||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
init_vllm_registered_model,
|
init_vllm_registered_model,
|
||||||
maybe_prefix,
|
maybe_prefix,
|
||||||
@@ -873,17 +843,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
with a causal language model for text generation.
|
with a causal language model for text generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# SupportsTranscription interface
|
|
||||||
supports_transcription: ClassVar[Literal[True]] = True
|
|
||||||
supports_transcription_only: ClassVar[bool] = False
|
|
||||||
supports_segment_timestamp: ClassVar[bool] = False
|
|
||||||
|
|
||||||
# Supported languages (Chinese as primary target)
|
|
||||||
supported_languages: ClassVar[Mapping[str, str]] = {
|
|
||||||
"zh": "Chinese",
|
|
||||||
"en": "English",
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||||
"""Return the placeholder string format for a given modality.
|
"""Return the placeholder string format for a given modality.
|
||||||
@@ -897,112 +856,10 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
return "<|AUDIO|>"
|
return "<|AUDIO|>"
|
||||||
raise ValueError(f"Unsupported modality: {modality}")
|
raise ValueError(f"Unsupported modality: {modality}")
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_generation_prompt(
|
|
||||||
cls,
|
|
||||||
audio: np.ndarray,
|
|
||||||
stt_config: SpeechToTextConfig,
|
|
||||||
model_config: ModelConfig,
|
|
||||||
language: str | None,
|
|
||||||
task_type: Literal["transcribe", "translate"],
|
|
||||||
request_prompt: str,
|
|
||||||
to_language: str | None,
|
|
||||||
) -> PromptType:
|
|
||||||
"""Get the prompt for the ASR model.
|
|
||||||
|
|
||||||
Generates a chat-formatted prompt for speech-to-text transcription
|
|
||||||
with JSON output format.
|
|
||||||
"""
|
|
||||||
# If user provides custom prompt, use it
|
|
||||||
if request_prompt:
|
|
||||||
return request_prompt
|
|
||||||
|
|
||||||
# Calculate audio duration for the prompt
|
|
||||||
# Audio should be at 24kHz, so duration = len(audio) / 24000
|
|
||||||
duration = len(audio) / 24000 if audio is not None else 10.0
|
|
||||||
|
|
||||||
system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format."
|
|
||||||
show_keys = ["Start time", "End time", "Speaker ID", "Content"]
|
|
||||||
user_suffix = (
|
|
||||||
f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: "
|
|
||||||
+ ", ".join(show_keys)
|
|
||||||
)
|
|
||||||
|
|
||||||
# IMPORTANT: keep <|AUDIO|> as the only placeholder token here.
|
|
||||||
# `_get_prompt_updates` expands it into repeated `<|AUDIO|>` placeholders.
|
|
||||||
user_content = "<|AUDIO|>\n" + user_suffix
|
|
||||||
|
|
||||||
prompt = (
|
|
||||||
f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
|
|
||||||
f"<|im_start|>user\n{user_content}<|im_end|>\n"
|
|
||||||
"<|im_start|>assistant\n"
|
|
||||||
)
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_speech_to_text_config(
|
|
||||||
cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
|
|
||||||
) -> SpeechToTextConfig:
|
|
||||||
"""Get the speech to text config for the ASR model."""
|
|
||||||
return SpeechToTextConfig(
|
|
||||||
language=None, # Auto-detect or use request language
|
|
||||||
task_type=task_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_num_audio_tokens(
|
|
||||||
cls,
|
|
||||||
audio_duration_s: float,
|
|
||||||
stt_config: SpeechToTextConfig,
|
|
||||||
model_config: ModelConfig,
|
|
||||||
) -> int | None:
|
|
||||||
"""Estimate number of audio tokens from duration.
|
|
||||||
|
|
||||||
Returns the number of audio EMBEDDING positions (speech_pad_id tokens).
|
|
||||||
Note: _get_prompt_updates actually generates:
|
|
||||||
[speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id]
|
|
||||||
So total prompt tokens = N + 3, but this returns N (the embedding count).
|
|
||||||
"""
|
|
||||||
sampling_rate = 24000
|
|
||||||
compress_ratio = 3200
|
|
||||||
samples = int(audio_duration_s * sampling_rate)
|
|
||||||
num_tokens = int(np.ceil(samples / compress_ratio))
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_other_languages(cls) -> Mapping[str, str]:
|
|
||||||
"""Get languages from Whisper map not natively supported."""
|
|
||||||
# Import LANGUAGES from vllm
|
|
||||||
try:
|
|
||||||
from vllm.transformers_utils.tokenizer import LANGUAGES
|
|
||||||
except ImportError:
|
|
||||||
# Fallback to empty dict if import fails
|
|
||||||
return {}
|
|
||||||
return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_language(cls, language: str | None) -> str | None:
|
|
||||||
"""Validate the language code."""
|
|
||||||
if language is None or language in cls.supported_languages:
|
|
||||||
return language
|
|
||||||
elif language in cls.get_other_languages():
|
|
||||||
print(f"Warning: Language {language!r} is not natively supported")
|
|
||||||
return language
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported language: {language!r}. "
|
|
||||||
f"Supported: {list(cls.supported_languages.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Keep a copy of the resolved model path for any custom weight-loading
|
|
||||||
# logic (e.g., loading audio encoder weights in fp32 directly from
|
|
||||||
# safetensors shards).
|
|
||||||
self._model_path = vllm_config.model_config.model
|
|
||||||
|
|
||||||
self.audio_encoder = VibeVoiceAudioEncoder(config)
|
self.audio_encoder = VibeVoiceAudioEncoder(config)
|
||||||
|
|
||||||
@@ -1100,14 +957,14 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
# Process each audio through the VibeVoice encoder
|
# Process each audio through the VibeVoice encoder
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
# Get model device and dtype for alignment
|
# Get model device for tensor placement.
|
||||||
|
# dtype is NOT set here — audio_encoder.forward() handles it internally:
|
||||||
|
# input: converted to fp32 (self._audio_encoder_dtype)
|
||||||
|
# output: converted to bfloat16 (self._lm_dtype)
|
||||||
try:
|
try:
|
||||||
device = next(self.audio_encoder.parameters()).device
|
device = next(self.audio_encoder.parameters()).device
|
||||||
dtype = next(self.audio_encoder.parameters()).dtype
|
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# Fallback if no parameters (shouldn't happen)
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
# Handle both stacked tensor and list of tensors
|
# Handle both stacked tensor and list of tensors
|
||||||
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
|
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
|
||||||
@@ -1138,7 +995,7 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
if not isinstance(audio_tensor, torch.Tensor):
|
if not isinstance(audio_tensor, torch.Tensor):
|
||||||
audio_tensor = torch.tensor(audio_tensor)
|
audio_tensor = torch.tensor(audio_tensor)
|
||||||
|
|
||||||
# Let vLLM handle dtype (bfloat16 by default)
|
# Only place on correct device; audio_encoder.forward() handles dtype
|
||||||
audio_tensor = audio_tensor.to(device=device)
|
audio_tensor = audio_tensor.to(device=device)
|
||||||
|
|
||||||
# Get actual length if available, otherwise use full length
|
# Get actual length if available, otherwise use full length
|
||||||
@@ -1294,35 +1151,31 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
intermediate_tensors: Intermediate tensors for pipeline parallelism.
|
intermediate_tensors: Intermediate tensors for pipeline parallelism.
|
||||||
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
|
inputs_embeds: Pre-computed embeddings (from multimodal merge or decode).
|
||||||
"""
|
"""
|
||||||
try:
|
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
|
||||||
# PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode)
|
# Only compute from input_ids if inputs_embeds is not available
|
||||||
# Only compute from input_ids if inputs_embeds is not available
|
if inputs_embeds is None and input_ids is not None:
|
||||||
if inputs_embeds is None and input_ids is not None:
|
# Compute embeddings from input_ids
|
||||||
# Compute embeddings from input_ids
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
||||||
|
# If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
|
||||||
# If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds
|
if intermediate_tensors is not None:
|
||||||
if intermediate_tensors is not None:
|
inputs_embeds = None
|
||||||
inputs_embeds = None
|
|
||||||
|
# Get the inner model - handle both wrapped and direct language models
|
||||||
# Get the inner model - handle both wrapped and direct language models
|
language_model = self.language_model
|
||||||
language_model = self.language_model
|
if hasattr(language_model, "language_model"):
|
||||||
if hasattr(language_model, "language_model"):
|
language_model = language_model.language_model
|
||||||
language_model = language_model.language_model
|
|
||||||
|
# Call the language model's model (Qwen2Model)
|
||||||
# Call the language model's model (Qwen2Model)
|
# vLLM V1 passes kv_caches and attn_metadata via context, not arguments
|
||||||
# vLLM V1 passes kv_caches and attn_metadata via context, not arguments
|
# IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
|
||||||
# IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding
|
hidden_states = language_model.model(
|
||||||
hidden_states = language_model.model(
|
input_ids=None, # Always None when we have inputs_embeds
|
||||||
input_ids=None, # Always None when we have inputs_embeds
|
positions=positions,
|
||||||
positions=positions,
|
intermediate_tensors=intermediate_tensors,
|
||||||
intermediate_tensors=intermediate_tensors,
|
inputs_embeds=inputs_embeds
|
||||||
inputs_embeds=inputs_embeds
|
)
|
||||||
)
|
return hidden_states
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# Alias for training checkpoint compatibility
|
# Alias for training checkpoint compatibility
|
||||||
|
|||||||
Reference in New Issue
Block a user