From 5cd81bb497dfb2b3173f6f851f533e22825c15b7 Mon Sep 17 00:00:00 2001 From: Jianwei Yu Date: Fri, 27 Mar 2026 18:48:06 +0000 Subject: [PATCH] fix: restore sequential encoder (batch encoder causes OOM) Batch encoder across multiple requests caused GPU OOM when vLLM scheduler sends many audio items at once. The encoder intermediates (~700MB per 69s audio) compete with KV cache for GPU memory. Sequential encoding is stable and proven correct. The encoder (267ms per request) is not the primary throughput bottleneck when encoder cache is enabled (default). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- vllm_plugin/model.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/vllm_plugin/model.py b/vllm_plugin/model.py index 9b40f71..bd9060f 100644 --- a/vllm_plugin/model.py +++ b/vllm_plugin/model.py @@ -1049,9 +1049,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): embeddings = [] # Get model device for tensor placement. - # dtype is NOT set here — audio_encoder.forward() handles it internally: - # input: converted to fp32 (self._audio_encoder_dtype) - # output: converted to bfloat16 (self._lm_dtype) try: device = next(self.audio_encoder.parameters()).device except StopIteration: @@ -1061,63 +1058,44 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): # vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len] if isinstance(raw_audio, torch.Tensor): if raw_audio.dim() == 3: - # Shape: [batch_size, 1, seq_len] - squeeze the middle dimension num_audios = raw_audio.shape[0] audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)] elif raw_audio.dim() == 2: - # Shape: [batch_size, seq_len] num_audios = raw_audio.shape[0] audio_list = [raw_audio[i] for i in range(num_audios)] else: - # Single 1D tensor audio_list = [raw_audio] elif isinstance(raw_audio, (list, tuple)): audio_list = list(raw_audio) else: - # Single tensor audio_list = [raw_audio] for i, audio_tensor in enumerate(audio_list): try: if isinstance(audio_tensor, list): audio_tensor = torch.stack(audio_tensor) - - # Ensure tensor if not isinstance(audio_tensor, torch.Tensor): audio_tensor = torch.tensor(audio_tensor) - - # Only place on correct device; audio_encoder.forward() handles dtype audio_tensor = audio_tensor.to(device=device) - - # Get actual length if available, otherwise use full length if raw_audio_lengths and i < len(raw_audio_lengths): actual_len = int(raw_audio_lengths[i]) if actual_len > 0 and actual_len <= audio_tensor.shape[-1]: - # Truncate from the last dimension (sequence length) audio_tensor = audio_tensor[..., :actual_len] - - # Skip if audio is too short (< 1 frame) - if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz + if audio_tensor.numel() < 160: continue - # Encode audio through VibeVoice encoder audio_embeds = self.audio_encoder( audio_tensor, use_streaming=use_streaming_flag, segment_duration_s=streaming_segment_duration, ) - - # audio_embeds shape: [1, seq_len, hidden_size] - # We need to return it as a single embedding tensor per audio final_embed = audio_embeds.squeeze(0) embeddings.append(final_embed) except Exception as e: - # Log error but don't crash - this helps debug profiling issues print(f"[VibeVoice] Error encoding audio {i}: {e}") import traceback traceback.print_exc() - # Return empty embedding to avoid crash continue return tuple(embeddings)