fix: restore sequential encoder (batch encoder causes OOM)
Batch encoder across multiple requests caused GPU OOM when vLLM scheduler sends many audio items at once. The encoder intermediates (~700MB per 69s audio) compete with KV cache for GPU memory. Sequential encoding is stable and proven correct. The encoder (267ms per request) is not the primary throughput bottleneck when encoder cache is enabled (default). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
+1
-23
@@ -1049,9 +1049,6 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
# Get model device for tensor placement.
|
# Get model device for tensor placement.
|
||||||
# dtype is NOT set here — audio_encoder.forward() handles it internally:
|
|
||||||
# input: converted to fp32 (self._audio_encoder_dtype)
|
|
||||||
# output: converted to bfloat16 (self._lm_dtype)
|
|
||||||
try:
|
try:
|
||||||
device = next(self.audio_encoder.parameters()).device
|
device = next(self.audio_encoder.parameters()).device
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
@@ -1061,63 +1058,44 @@ class VibeVoiceForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
|
# vLLM batches as: [batch_size, 1, seq_len] or [batch_size, seq_len]
|
||||||
if isinstance(raw_audio, torch.Tensor):
|
if isinstance(raw_audio, torch.Tensor):
|
||||||
if raw_audio.dim() == 3:
|
if raw_audio.dim() == 3:
|
||||||
# Shape: [batch_size, 1, seq_len] - squeeze the middle dimension
|
|
||||||
num_audios = raw_audio.shape[0]
|
num_audios = raw_audio.shape[0]
|
||||||
audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
|
audio_list = [raw_audio[i].squeeze(0) for i in range(num_audios)]
|
||||||
elif raw_audio.dim() == 2:
|
elif raw_audio.dim() == 2:
|
||||||
# Shape: [batch_size, seq_len]
|
|
||||||
num_audios = raw_audio.shape[0]
|
num_audios = raw_audio.shape[0]
|
||||||
audio_list = [raw_audio[i] for i in range(num_audios)]
|
audio_list = [raw_audio[i] for i in range(num_audios)]
|
||||||
else:
|
else:
|
||||||
# Single 1D tensor
|
|
||||||
audio_list = [raw_audio]
|
audio_list = [raw_audio]
|
||||||
elif isinstance(raw_audio, (list, tuple)):
|
elif isinstance(raw_audio, (list, tuple)):
|
||||||
audio_list = list(raw_audio)
|
audio_list = list(raw_audio)
|
||||||
else:
|
else:
|
||||||
# Single tensor
|
|
||||||
audio_list = [raw_audio]
|
audio_list = [raw_audio]
|
||||||
|
|
||||||
for i, audio_tensor in enumerate(audio_list):
|
for i, audio_tensor in enumerate(audio_list):
|
||||||
try:
|
try:
|
||||||
if isinstance(audio_tensor, list):
|
if isinstance(audio_tensor, list):
|
||||||
audio_tensor = torch.stack(audio_tensor)
|
audio_tensor = torch.stack(audio_tensor)
|
||||||
|
|
||||||
# Ensure tensor
|
|
||||||
if not isinstance(audio_tensor, torch.Tensor):
|
if not isinstance(audio_tensor, torch.Tensor):
|
||||||
audio_tensor = torch.tensor(audio_tensor)
|
audio_tensor = torch.tensor(audio_tensor)
|
||||||
|
|
||||||
# Only place on correct device; audio_encoder.forward() handles dtype
|
|
||||||
audio_tensor = audio_tensor.to(device=device)
|
audio_tensor = audio_tensor.to(device=device)
|
||||||
|
|
||||||
# Get actual length if available, otherwise use full length
|
|
||||||
if raw_audio_lengths and i < len(raw_audio_lengths):
|
if raw_audio_lengths and i < len(raw_audio_lengths):
|
||||||
actual_len = int(raw_audio_lengths[i])
|
actual_len = int(raw_audio_lengths[i])
|
||||||
if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
|
if actual_len > 0 and actual_len <= audio_tensor.shape[-1]:
|
||||||
# Truncate from the last dimension (sequence length)
|
|
||||||
audio_tensor = audio_tensor[..., :actual_len]
|
audio_tensor = audio_tensor[..., :actual_len]
|
||||||
|
if audio_tensor.numel() < 160:
|
||||||
# Skip if audio is too short (< 1 frame)
|
|
||||||
if audio_tensor.numel() < 160: # Minimum ~1ms at 24kHz
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Encode audio through VibeVoice encoder
|
|
||||||
audio_embeds = self.audio_encoder(
|
audio_embeds = self.audio_encoder(
|
||||||
audio_tensor,
|
audio_tensor,
|
||||||
use_streaming=use_streaming_flag,
|
use_streaming=use_streaming_flag,
|
||||||
segment_duration_s=streaming_segment_duration,
|
segment_duration_s=streaming_segment_duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
# audio_embeds shape: [1, seq_len, hidden_size]
|
|
||||||
# We need to return it as a single embedding tensor per audio
|
|
||||||
final_embed = audio_embeds.squeeze(0)
|
final_embed = audio_embeds.squeeze(0)
|
||||||
embeddings.append(final_embed)
|
embeddings.append(final_embed)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log error but don't crash - this helps debug profiling issues
|
|
||||||
print(f"[VibeVoice] Error encoding audio {i}: {e}")
|
print(f"[VibeVoice] Error encoding audio {i}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# Return empty embedding to avoid crash
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return tuple(embeddings)
|
return tuple(embeddings)
|
||||||
|
|||||||
Reference in New Issue
Block a user