diff --git a/vllm_plugin/model.py b/vllm_plugin/model.py index 9b40f71..bd9060f 100644 --- a/vllm_plugin/model.py +++ b/vllm_plugin/model.py @@ -1049,9 +1049,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): embeddings = [] # Get model device for tensor placement. - # dtype is NOT set here — audio_encoder.forward() handles it internally: - # input: converted to fp32 (self._audio_encoder_dtype) - # output: converted to bfloat16 (self._lm_dtype) try: device = next(self.audio_encoder.parameters()).device except StopIteration: @@ -1061,63 +1058,44 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len] if isinstance(raw_audio, torch.Tensor): if raw_audio.dim() == 3: - # Shape: [batch_size, 1, seq_len] - squeeze the middle dimension num_audios = raw_audio.shape[0] audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)] elif raw_audio.dim() == 2: - # Shape: [batch_size, seq_len] num_audios = raw_audio.shape[0] audio_list = [raw_audio[i] for i in range(num_audios)] else: - # Single 1D tensor audio_list = [raw_audio] elif isinstance(raw_audio, (list, tuple)): audio_list = list(raw_audio) else: - # Single tensor audio_list = [raw_audio] for i, audio_tensor in enumerate(audio_list): try: if isinstance(audio_tensor, list): audio_tensor = torch.stack(audio_tensor) - - # Ensure tensor if not isinstance(audio_tensor, torch.Tensor): audio_tensor = torch.tensor(audio_tensor) - - # Only place on correct device; audio_encoder.forward() handles dtype audio_tensor = audio_tensor.to(device=device) - - # Get actual length if available, otherwise use full length if raw_audio_lengths and i < len(raw_audio_lengths): actual_len = int(raw_audio_lengths[i]) if actual_len > 0 and actual_len <= audio_tensor.shape[-1]: - # Truncate from the last dimension (sequence length) audio_tensor = audio_tensor[..., :actual_len] - - # Skip if audio is too short (< 1 frame) - if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz + if audio_tensor.numel() < 160: continue - # Encode audio through VibeVoice encoder audio_embeds = self.audio_encoder( audio_tensor, use_streaming=use_streaming_flag, segment_duration_s=streaming_segment_duration, ) - - # audio_embeds shape: [1, seq_len, hidden_size] - # We need to return it as a single embedding tensor per audio final_embed = audio_embeds.squeeze(0) embeddings.append(final_embed) except Exception as e: - # Log error but don't crash - this helps debug profiling issues print(f"[VibeVoice] Error encoding audio {i}: {e}") import traceback traceback.print_exc() - # Return empty embedding to avoid crash continue return tuple(embeddings)