fix: restore sequential encoder (batch encoder causes OOM)

Batch encoder across multiple requests caused GPU OOM when vLLM
scheduler sends many audio items at once. The encoder intermediates
(~700MB per 69s audio) compete with KV cache for GPU memory.

Sequential encoding is stable and proven correct. The encoder
(267ms per request) is not the primary throughput bottleneck when
encoder cache is enabled (default).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Jianwei Yu
2026-03-27 18:48:06 +00:00
parent cd945395d4
commit 5cd81bb497
+1 -23
View File
@@ -1049,9 +1049,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
embeddings = [] embeddings = []
# Get model device for tensor placement. # Get model device for tensor placement.
# dtype is NOT set here — audio_encoder.forward() handles it internally:
# input: converted to fp32 (self._audio_encoder_dtype)
# output: converted to bfloat16 (self._lm_dtype)
try: try:
device = next(self.audio_encoder.parameters()).device device = next(self.audio_encoder.parameters()).device
except StopIteration: except StopIteration:
@@ -1061,63 +1058,44 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len] # vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
if isinstance(raw_audio, torch.Tensor): if isinstance(raw_audio, torch.Tensor):
if raw_audio.dim() == 3: if raw_audio.dim() == 3:
# Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
num_audios = raw_audio.shape[0] num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)] audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
elif raw_audio.dim() == 2: elif raw_audio.dim() == 2:
# Shape: [batch_size, seq_len]
num_audios = raw_audio.shape[0] num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i] for i in range(num_audios)] audio_list = [raw_audio[i] for i in range(num_audios)]
else: else:
# Single 1D tensor
audio_list = [raw_audio] audio_list = [raw_audio]
elif isinstance(raw_audio, (list, tuple)): elif isinstance(raw_audio, (list, tuple)):
audio_list = list(raw_audio) audio_list = list(raw_audio)
else: else:
# Single tensor
audio_list = [raw_audio] audio_list = [raw_audio]
for i, audio_tensor in enumerate(audio_list): for i, audio_tensor in enumerate(audio_list):
try: try:
if isinstance(audio_tensor, list): if isinstance(audio_tensor, list):
audio_tensor = torch.stack(audio_tensor) audio_tensor = torch.stack(audio_tensor)
# Ensure tensor
if not isinstance(audio_tensor, torch.Tensor): if not isinstance(audio_tensor, torch.Tensor):
audio_tensor = torch.tensor(audio_tensor) audio_tensor = torch.tensor(audio_tensor)
# Only place on correct device; audio_encoder.forward() handles dtype
audio_tensor = audio_tensor.to(device=device) audio_tensor = audio_tensor.to(device=device)
# Get actual length if available, otherwise use full length
if raw_audio_lengths and i < len(raw_audio_lengths): if raw_audio_lengths and i < len(raw_audio_lengths):
actual_len = int(raw_audio_lengths[i]) actual_len = int(raw_audio_lengths[i])
if actual_len > 0 and actual_len <= audio_tensor.shape[-1]: if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
# Truncate from the last dimension (sequence length)
audio_tensor = audio_tensor[..., :actual_len] audio_tensor = audio_tensor[..., :actual_len]
if audio_tensor.numel() < 160:
# Skip if audio is too short (< 1 frame)
if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz
continue continue
# Encode audio through VibeVoice encoder
audio_embeds = self.audio_encoder( audio_embeds = self.audio_encoder(
audio_tensor, audio_tensor,
use_streaming=use_streaming_flag, use_streaming=use_streaming_flag,
segment_duration_s=streaming_segment_duration, segment_duration_s=streaming_segment_duration,
) )
# audio_embeds shape: [1, seq_len, hidden_size]
# We need to return it as a single embedding tensor per audio
final_embed = audio_embeds.squeeze(0) final_embed = audio_embeds.squeeze(0)
embeddings.append(final_embed) embeddings.append(final_embed)
except Exception as e: except Exception as e:
# Log error but don't crash - this helps debug profiling issues
print(f"[VibeVoice] Error encoding audio {i}: {e}") print(f"[VibeVoice] Error encoding audio {i}: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
# Return empty embedding to avoid crash
continue continue
return tuple(embeddings) return tuple(embeddings)