From 0508c3e86fd373505eda021ec0ff7275912cac22 Mon Sep 17 00:00:00 2001 From: YingboHAO <3259482542@qq.com> Date: Fri, 6 Feb 2026 14:38:16 +0000 Subject: [PATCH] fix --- vllm_plugin/__init__.py | 1 - vllm_plugin/model.py | 221 +++++++--------------------------------- 2 files changed, 37 insertions(+), 185 deletions(-) diff --git a/vllm_plugin/__init__.py b/vllm_plugin/__init__.py index b696a45..989acb4 100644 --- a/vllm_plugin/__init__.py +++ b/vllm_plugin/__init__.py @@ -15,7 +15,6 @@ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceASRTextTokenizerFast from .model import VibeVoiceForCausalLM -from .inputs import vibevoice_audio_input_mapper def register_vibevoice(): diff --git a/vllm_plugin/model.py b/vllm_plugin/model.py index 178fb74..2ca3cd2 100644 --- a/vllm_plugin/model.py +++ b/vllm_plugin/model.py @@ -5,17 +5,11 @@ This module implements the VibeVoice ASR model with full vLLM multimodal registr integration for speech-to-text inference. """ -from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence, ClassVar, Literal -import json -import math +from typing import List, Optional, Tuple, Union, Dict, Any, Iterable, Mapping, Sequence import os -import sys -from pathlib import Path import torch import torch.nn as nn import numpy as np -from io import BytesIO -import tempfile import base64 @@ -29,32 +23,12 @@ import base64 from vibevoice.processor.audio_utils import load_audio_use_ffmpeg, load_audio_bytes_use_ffmpeg, AudioNormalizer -def _suffix_from_media_type(media_type: str | None) -> str: - if not media_type: - return ".bin" - mt = media_type.lower().strip() - if mt in ("audio/wav", "audio/x-wav", "audio/wave"): - return ".wav" - if mt in ("audio/mpeg", "audio/mp3", "audio/x-mp3"): - return ".mp3" - if mt in ("audio/flac",): - return ".flac" - if mt in ("audio/ogg", "audio/opus"): - return ".ogg" - if mt in ("audio/mp4", "audio/m4a"): - return ".m4a" - if mt in ("video/mp4",): - return ".mp4" - return ".bin" - - -def _ffmpeg_load_bytes(data: bytes, *, media_type: str | None = None) -> tuple[np.ndarray, int]: - """Load audio bytes using FFmpeg. +def _ffmpeg_load_bytes(data: bytes) -> tuple[np.ndarray, int]: + """Load audio bytes using FFmpeg via stdin-pipe decoding. Returns: Tuple of (audio_waveform, sample_rate). Sample rate is always 24000. """ - # Prefer stdin-pipe decoding to avoid temp-file IO under high concurrency. audio, sr = load_audio_bytes_use_ffmpeg(data, resample=True, target_sr=24000) normalizer = AudioNormalizer() audio = normalizer(audio) @@ -79,10 +53,10 @@ class _PatchedAudioMediaIO(_OriginalAudioMediaIO): """AudioMediaIO implementation using FFmpeg for audio decoding.""" def load_bytes(self, data: bytes) -> tuple[np.ndarray, int]: - return _ffmpeg_load_bytes(data, media_type=None) + return _ffmpeg_load_bytes(data) def load_base64(self, media_type: str, data: str) -> tuple[np.ndarray, int]: - return _ffmpeg_load_bytes(base64.b64decode(data), media_type=media_type) + return _ffmpeg_load_bytes(base64.b64decode(data)) def load_file(self, filepath) -> tuple[np.ndarray, int]: return _ffmpeg_load_file(filepath) @@ -96,17 +70,13 @@ _vllm_utils_module.AudioMediaIO = _PatchedAudioMediaIO # ============================================================================ -from transformers import Qwen2Config, BatchFeature +from transformers import BatchFeature from transformers.models.whisper import WhisperFeatureExtractor -from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.config import VllmConfig, ModelConfig -from vllm.config.speech_to_text import SpeechToTextConfig +from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.parse import MultiModalDataParser from vllm.sequence import IntermediateTensors from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP, MultiModalEmbeddings -from vllm.inputs import PromptType from vllm.model_executor.models.utils import ( init_vllm_registered_model, maybe_prefix, @@ -873,17 +843,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): with a causal language model for text generation. """ - # SupportsTranscription interface - supports_transcription: ClassVar[Literal[True]] = True - supports_transcription_only: ClassVar[bool] = False - supports_segment_timestamp: ClassVar[bool] = False - - # Supported languages (Chinese as primary target) - supported_languages: ClassVar[Mapping[str, str]] = { - "zh": "Chinese", - "en": "English", - } - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: """Return the placeholder string format for a given modality. @@ -897,112 +856,10 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): return "<|AUDIO|>" raise ValueError(f"Unsupported modality: {modality}") - @classmethod - def get_generation_prompt( - cls, - audio: np.ndarray, - stt_config: SpeechToTextConfig, - model_config: ModelConfig, - language: str | None, - task_type: Literal["transcribe", "translate"], - request_prompt: str, - to_language: str | None, - ) -> PromptType: - """Get the prompt for the ASR model. - - Generates a chat-formatted prompt for speech-to-text transcription - with JSON output format. - """ - # If user provides custom prompt, use it - if request_prompt: - return request_prompt - - # Calculate audio duration for the prompt - # Audio should be at 24kHz, so duration = len(audio) / 24000 - duration = len(audio) / 24000 if audio is not None else 10.0 - - system_prompt = "You are a helpful assistant that transcribes audio input into text output in JSON format." - show_keys = ["Start time", "End time", "Speaker ID", "Content"] - user_suffix = ( - f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " - + ", ".join(show_keys) - ) - - # IMPORTANT: keep <|AUDIO|> as the only placeholder token here. - # `_get_prompt_updates` expands it into repeated `<|AUDIO|>` placeholders. - user_content = "<|AUDIO|>\n" + user_suffix - - prompt = ( - f"<|im_start|>system\n{system_prompt}<|im_end|>\n" - f"<|im_start|>user\n{user_content}<|im_end|>\n" - "<|im_start|>assistant\n" - ) - return prompt - - @classmethod - def get_speech_to_text_config( - cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"] - ) -> SpeechToTextConfig: - """Get the speech to text config for the ASR model.""" - return SpeechToTextConfig( - language=None, # Auto-detect or use request language - task_type=task_type, - ) - - @classmethod - def get_num_audio_tokens( - cls, - audio_duration_s: float, - stt_config: SpeechToTextConfig, - model_config: ModelConfig, - ) -> int | None: - """Estimate number of audio tokens from duration. - - Returns the number of audio EMBEDDING positions (speech_pad_id tokens). - Note: _get_prompt_updates actually generates: - [speech_start_id] + [speech_pad_id] * N + [speech_end_id] + [newline_id] - So total prompt tokens = N + 3, but this returns N (the embedding count). - """ - sampling_rate = 24000 - compress_ratio = 3200 - samples = int(audio_duration_s * sampling_rate) - num_tokens = int(np.ceil(samples / compress_ratio)) - return num_tokens - - @classmethod - def get_other_languages(cls) -> Mapping[str, str]: - """Get languages from Whisper map not natively supported.""" - # Import LANGUAGES from vllm - try: - from vllm.transformers_utils.tokenizer import LANGUAGES - except ImportError: - # Fallback to empty dict if import fails - return {} - return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages} - - @classmethod - def validate_language(cls, language: str | None) -> str | None: - """Validate the language code.""" - if language is None or language in cls.supported_languages: - return language - elif language in cls.get_other_languages(): - print(f"Warning: Language {language!r} is not natively supported") - return language - else: - raise ValueError( - f"Unsupported language: {language!r}. " - f"Supported: {list(cls.supported_languages.keys())}" - ) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config - - # Keep a copy of the resolved model path for any custom weight-loading - # logic (e.g., loading audio encoder weights in fp32 directly from - # safetensors shards). - self._model_path = vllm_config.model_config.model self.audio_encoder = VibeVoiceAudioEncoder(config) @@ -1100,14 +957,14 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # Process each audio through the VibeVoice encoder embeddings = [] - # Get model device and dtype for alignment + # Get model device for tensor placement. + # dtype is NOT set here — audio_encoder.forward() handles it internally: + # input: converted to fp32 (self._audio_encoder_dtype) + # output: converted to bfloat16 (self._lm_dtype) try: device = next(self.audio_encoder.parameters()).device - dtype = next(self.audio_encoder.parameters()).dtype except StopIteration: - # Fallback if no parameters (shouldn't happen) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - dtype = torch.bfloat16 # Handle both stacked tensor and list of tensors # vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len] @@ -1138,7 +995,7 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if not isinstance(audio_tensor, torch.Tensor): audio_tensor = torch.tensor(audio_tensor) - # Let vLLM handle dtype (bfloat16 by default) + # Only place on correct device; audio_encoder.forward() handles dtype audio_tensor = audio_tensor.to(device=device) # Get actual length if available, otherwise use full length @@ -1294,35 +1151,31 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): intermediate_tensors: Intermediate tensors for pipeline parallelism. inputs_embeds: Pre-computed embeddings (from multimodal merge or decode). """ - try: - # PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode) - # Only compute from input_ids if inputs_embeds is not available - if inputs_embeds is None and input_ids is not None: - # Compute embeddings from input_ids - inputs_embeds = self.get_input_embeddings()(input_ids) - - # If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds - if intermediate_tensors is not None: - inputs_embeds = None - - # Get the inner model - handle both wrapped and direct language models - language_model = self.language_model - if hasattr(language_model, "language_model"): - language_model = language_model.language_model - - # Call the language model's model (Qwen2Model) - # vLLM V1 passes kv_caches and attn_metadata via context, not arguments - # IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding - hidden_states = language_model.model( - input_ids=None, # Always None when we have inputs_embeds - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds - ) - return hidden_states - - except Exception as e: - raise + # PRIORITY: Use inputs_embeds if provided (from vLLM multimodal merge or decode) + # Only compute from input_ids if inputs_embeds is not available + if inputs_embeds is None and input_ids is not None: + # Compute embeddings from input_ids + inputs_embeds = self.get_input_embeddings()(input_ids) + + # If we have intermediate tensors (pipeline parallelism), don't use inputs_embeds + if intermediate_tensors is not None: + inputs_embeds = None + + # Get the inner model - handle both wrapped and direct language models + language_model = self.language_model + if hasattr(language_model, "language_model"): + language_model = language_model.language_model + + # Call the language model's model (Qwen2Model) + # vLLM V1 passes kv_caches and attn_metadata via context, not arguments + # IMPORTANT: Pass input_ids=None when using inputs_embeds to avoid double embedding + hidden_states = language_model.model( + input_ids=None, # Always None when we have inputs_embeds + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds + ) + return hidden_states # Alias for training checkpoint compatibility