fix: restore sequential encoder (batch encoder causes OOM)
Batch encoder across multiple requests caused GPU OOM when vLLM scheduler sends many audio items at once. The encoder intermediates (~700MB per 69s audio) compete with KV cache for GPU memory. Sequential encoding is stable and proven correct. The encoder (267ms per request) is not the primary throughput bottleneck when encoder cache is enabled (default). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
+1
-23
@@ -1049,9 +1049,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
embeddings = []
|
||||
|
||||
# Get model device for tensor placement.
|
||||
# dtype is NOT set here — audio_encoder.forward() handles it internally:
|
||||
# input: converted to fp32 (self._audio_encoder_dtype)
|
||||
# output: converted to bfloat16 (self._lm_dtype)
|
||||
try:
|
||||
device = next(self.audio_encoder.parameters()).device
|
||||
except StopIteration:
|
||||
@@ -1061,63 +1058,44 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
|
||||
if isinstance(raw_audio, torch.Tensor):
|
||||
if raw_audio.dim() == 3:
|
||||
# Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
|
||||
num_audios = raw_audio.shape[0]
|
||||
audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
|
||||
elif raw_audio.dim() == 2:
|
||||
# Shape: [batch_size, seq_len]
|
||||
num_audios = raw_audio.shape[0]
|
||||
audio_list = [raw_audio[i] for i in range(num_audios)]
|
||||
else:
|
||||
# Single 1D tensor
|
||||
audio_list = [raw_audio]
|
||||
elif isinstance(raw_audio, (list, tuple)):
|
||||
audio_list = list(raw_audio)
|
||||
else:
|
||||
# Single tensor
|
||||
audio_list = [raw_audio]
|
||||
|
||||
for i, audio_tensor in enumerate(audio_list):
|
||||
try:
|
||||
if isinstance(audio_tensor, list):
|
||||
audio_tensor = torch.stack(audio_tensor)
|
||||
|
||||
# Ensure tensor
|
||||
if not isinstance(audio_tensor, torch.Tensor):
|
||||
audio_tensor = torch.tensor(audio_tensor)
|
||||
|
||||
# Only place on correct device; audio_encoder.forward() handles dtype
|
||||
audio_tensor = audio_tensor.to(device=device)
|
||||
|
||||
# Get actual length if available, otherwise use full length
|
||||
if raw_audio_lengths and i < len(raw_audio_lengths):
|
||||
actual_len = int(raw_audio_lengths[i])
|
||||
if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
|
||||
# Truncate from the last dimension (sequence length)
|
||||
audio_tensor = audio_tensor[..., :actual_len]
|
||||
|
||||
# Skip if audio is too short (< 1 frame)
|
||||
if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz
|
||||
if audio_tensor.numel() < 160:
|
||||
continue
|
||||
|
||||
# Encode audio through VibeVoice encoder
|
||||
audio_embeds = self.audio_encoder(
|
||||
audio_tensor,
|
||||
use_streaming=use_streaming_flag,
|
||||
segment_duration_s=streaming_segment_duration,
|
||||
)
|
||||
|
||||
# audio_embeds shape: [1, seq_len, hidden_size]
|
||||
# We need to return it as a single embedding tensor per audio
|
||||
final_embed = audio_embeds.squeeze(0)
|
||||
embeddings.append(final_embed)
|
||||
|
||||
except Exception as e:
|
||||
# Log error but don't crash - this helps debug profiling issues
|
||||
print(f"[VibeVoice] Error encoding audio {i}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# Return empty embedding to avoid crash
|
||||
continue
|
||||
|
||||
return tuple(embeddings)
|
||||
|
||||
Reference in New Issue
Block a user