fix: restore sequential encoder (batch encoder causes OOM)

Batch encoder across multiple requests caused GPU OOM when vLLM
scheduler sends many audio items at once. The encoder intermediates
(~700MB per 69s audio) compete with KV cache for GPU memory.

Sequential encoding is stable and proven correct. The encoder
(267ms per request) is not the primary throughput bottleneck when
encoder cache is enabled (default).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Jianwei Yu
2026-03-27 18:48:06 +00:00
parent cd945395d4
commit 5cd81bb497
+1 -23
View File
@@ -1049,9 +1049,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
embeddings = []
# Get model device for tensor placement.
# dtype is NOT set here — audio_encoder.forward() handles it internally:
# input: converted to fp32 (self._audio_encoder_dtype)
# output: converted to bfloat16 (self._lm_dtype)
try:
device = next(self.audio_encoder.parameters()).device
except StopIteration:
@@ -1061,63 +1058,44 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
if isinstance(raw_audio, torch.Tensor):
if raw_audio.dim() == 3:
# Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
elif raw_audio.dim() == 2:
# Shape: [batch_size, seq_len]
num_audios = raw_audio.shape[0]
audio_list = [raw_audio[i] for i in range(num_audios)]
else:
# Single 1D tensor
audio_list = [raw_audio]
elif isinstance(raw_audio, (list, tuple)):
audio_list = list(raw_audio)
else:
# Single tensor
audio_list = [raw_audio]
for i, audio_tensor in enumerate(audio_list):
try:
if isinstance(audio_tensor, list):
audio_tensor = torch.stack(audio_tensor)
# Ensure tensor
if not isinstance(audio_tensor, torch.Tensor):
audio_tensor = torch.tensor(audio_tensor)
# Only place on correct device; audio_encoder.forward() handles dtype
audio_tensor = audio_tensor.to(device=device)
# Get actual length if available, otherwise use full length
if raw_audio_lengths and i < len(raw_audio_lengths):
actual_len = int(raw_audio_lengths[i])
if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
# Truncate from the last dimension (sequence length)
audio_tensor = audio_tensor[..., :actual_len]
# Skip if audio is too short (< 1 frame)
if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz
if audio_tensor.numel() < 160:
continue
# Encode audio through VibeVoice encoder
audio_embeds = self.audio_encoder(
audio_tensor,
use_streaming=use_streaming_flag,
segment_duration_s=streaming_segment_duration,
)
# audio_embeds shape: [1, seq_len, hidden_size]
# We need to return it as a single embedding tensor per audio
final_embed = audio_embeds.squeeze(0)
embeddings.append(final_embed)
except Exception as e:
# Log error but don't crash - this helps debug profiling issues
print(f"[VibeVoice] Error encoding audio {i}: {e}")
import traceback
traceback.print_exc()
# Return empty embedding to avoid crash
continue
return tuple(embeddings)