tts support latest transformers(4.57.6)

This commit is contained in:
YaoyaoChang
2026-01-26 03:14:07 -08:00
parent c4ee4fe716
commit a00f431e14
2 changed files with 202 additions and 14 deletions
@@ -81,6 +81,15 @@ class VibeVoiceStreamingConfig(PretrainedConfig):
super().__init__(**kwargs)
def get_text_config(self, decoder=False, encoder=None):
    """
    Return the text (language-model) sub-config of this model.

    This config only carries one text sub-config, ``decoder_config``, so it is
    returned regardless of which side of the model the caller asks for.

    Args:
        decoder: Accepted for compatibility with the base-class API; ignored.
        encoder: Accepted so that callers using the newer upstream signature
            (which takes both ``decoder`` and ``encoder``) do not fail with a
            ``TypeError``; ignored.

    Returns:
        ``self.decoder_config``.
    """
    return self.decoder_config
@property
def num_hidden_layers(self):
    """Number of hidden layers, read through from ``decoder_config``.

    Exposed on the top-level config so that code which looks up
    ``config.num_hidden_layers`` directly (needed for transformers >= 4.57)
    finds the decoder's value.
    """
    decoder = self.decoder_config
    return decoder.num_hidden_layers
def to_dict(self):
"""
Override to_dict to handle torch.dtype serialization.
@@ -1,11 +1,11 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
from tqdm import tqdm
import inspect
import torch
import torch.nn as nn
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers import modeling_utils
@@ -30,29 +30,110 @@ TTS_TEXT_WINDOW_SIZE = 5
TTS_SPEECH_WINDOW_SIZE = 6
# ============================================================================
# Transformers >= 4.57 Compatibility Layer
# The cache system was refactored in transformers 4.57, requiring these helpers.
# ============================================================================
class MockCacheLayer:
    """
    Mock cache layer for transformers >= 4.57 compatibility.

    Wraps the key/value tensors of one layer so that a cache object exposing
    ``key_cache`` / ``value_cache`` lists can also offer a per-layer ``layers``
    interface.

    Args:
        key_cache: Cached key tensor for this layer, or ``None`` if empty.
        value_cache: Cached value tensor for this layer, or ``None`` if empty.
        parent_cache: Optional owning cache whose ``key_cache`` / ``value_cache``
            lists are kept in sync on every ``update``.
        layer_idx: Index of this layer inside the parent's lists.
    """

    def __init__(self, key_cache, value_cache, parent_cache=None, layer_idx=0):
        self.key_cache = key_cache
        self.value_cache = value_cache
        self._parent_cache = parent_cache
        self._layer_idx = layer_idx

    def get_mask_sizes(self, cache_position):
        """
        Return ``(kv_length, kv_offset)`` for mask creation.

        ``kv_length`` is the size of dimension 2 of the cached keys (0 when the
        layer is empty); the offset is always 0. ``cache_position`` is unused.

        NOTE(review): the upstream dynamic layer appears to add the query
        length taken from ``cache_position`` to the cached length -- confirm
        whether this mock should do the same.
        """
        kv_length = self.key_cache.shape[2] if self.key_cache is not None else 0
        return kv_length, 0

    @staticmethod
    def _append(cached, new_states):
        """Concatenate ``new_states`` after ``cached`` along dim 2 (or start the cache)."""
        if cached is None:
            return new_states
        return torch.cat([cached, new_states], dim=2)

    def update(self, key_states, value_states, cache_kwargs=None):
        """
        Append new key/value states and return the full cached tensors.

        When a parent cache is attached, its ``key_cache`` / ``value_cache``
        lists are extended (padded with ``None``) up to ``layer_idx`` and
        updated in place. Without a parent the layer accumulates the states
        locally, so the returned tensors always include the new states.

        Args:
            key_states: New key states, concatenated along dim 2.
            value_states: New value states, concatenated along dim 2.
            cache_kwargs: Accepted for interface compatibility; unused.

        Returns:
            Tuple ``(key_cache, value_cache)`` after the update.
        """
        parent = self._parent_cache
        if parent is None:
            # Stand-alone layer: keep the running cache on the layer itself
            # instead of dropping the new states.
            self.key_cache = self._append(self.key_cache, key_states)
            self.value_cache = self._append(self.value_cache, value_states)
            return self.key_cache, self.value_cache

        idx = self._layer_idx
        # Pad each list independently so both can be indexed at ``idx``.
        for store in (parent.key_cache, parent.value_cache):
            while len(store) <= idx:
                store.append(None)

        parent.key_cache[idx] = self._append(parent.key_cache[idx], key_states)
        parent.value_cache[idx] = self._append(parent.value_cache[idx], value_states)

        # Keep the local references pointing at the parent's current tensors.
        self.key_cache = parent.key_cache[idx]
        self.value_cache = parent.value_cache[idx]
        return self.key_cache, self.value_cache
def _ensure_cache_has_layers(cache):
"""
Ensure the cache has all required attributes for transformers >= 4.57.
Creates MockCacheLayer wrappers to provide the expected `layers` interface.
"""
if cache is None:
return cache
# Add required attributes (skip if read-only)
for attr, default in [('layer_class_to_replicate', None), ('offloading', False), ('is_compileable', False)]:
if not hasattr(cache, attr):
try:
setattr(cache, attr, default)
except AttributeError:
pass
# Build layers list from key_cache/value_cache
if hasattr(cache, 'key_cache') and hasattr(cache, 'value_cache'):
try:
cache.layers = [
MockCacheLayer(cache.key_cache[i], cache.value_cache[i], parent_cache=cache, layer_idx=i)
for i in range(len(cache.key_cache))
]
except AttributeError:
pass
elif not hasattr(cache, 'layers'):
try:
cache.layers = []
except AttributeError:
pass
return cache
def _update_model_kwargs_for_generation(
    outputs: ModelOutput,
    model_kwargs: Dict[str, Any],
    num_new_tokens: int = 1,
) -> Dict[str, Any]:
    """
    Update model_kwargs after adding new tokens (supports multi-token windows).

    Mainly for the case num_new_tokens > 1 (e.g. a whole text window):
    - past_key_values: taken from the current outputs and passed through
      ``_ensure_cache_has_layers``
    - attention_mask: extended with ``num_new_tokens`` ones
    - cache_position: replaced by the range of the ``num_new_tokens`` positions
      that follow the last current position

    Args:
        outputs: Model output carrying ``past_key_values``.
        model_kwargs: Generation kwargs; must contain ``attention_mask`` and
            ``cache_position``. Modified in place.
        num_new_tokens: Number of tokens appended in this step.

    Returns:
        The same ``model_kwargs`` dict, updated for the next forward pass.
    """
    model_kwargs["past_key_values"] = _ensure_cache_has_layers(outputs.past_key_values)

    attention_mask = model_kwargs["attention_mask"]
    model_kwargs["attention_mask"] = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))], dim=-1
    )

    # Advance exactly once: the new positions start right after the last one.
    cache_pos = model_kwargs["cache_position"]
    model_kwargs["cache_position"] = torch.arange(
        cache_pos[-1] + 1, cache_pos[-1] + num_new_tokens + 1, device=cache_pos.device
    )
    return model_kwargs
@@ -157,6 +238,101 @@ class VibeVoiceStreamingForConditionalGenerationInference(VibeVoiceStreamingPreT
def set_ddpm_inference_steps(self, num_steps=None):
    """Store the number of DDPM inference steps on ``self.ddpm_inference_steps``.

    A falsy ``num_steps`` (``None`` or 0) selects the default from
    ``self.config.diffusion_head_config.ddpm_num_inference_steps``.
    """
    if not num_steps:
        num_steps = self.config.diffusion_head_config.ddpm_num_inference_steps
    self.ddpm_inference_steps = num_steps
def prepare_inputs_for_generation(
    self,
    input_ids: torch.LongTensor,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    **kwargs,
):
    """
    Build the dict of forward inputs for one generation step (transformers >= 4.57 compatible).

    - With a cache, ``input_ids`` / ``inputs_embeds`` are cut down to the
      positions named by ``cache_position``.
    - ``inputs_embeds`` is used instead of ``input_ids`` only when there is no
      cache, or when its length matches ``cache_position``.
    - ``position_ids`` are taken from ``kwargs`` or derived from
      ``attention_mask`` (masked slots get position 1), and trimmed to the
      current input length when a cache is present.
    - Remaining ``kwargs`` are forwarded unless the key is already set;
      ``labels`` is always removed.

    Returns:
        Dict with ``cache_position``, ``input_ids``, ``inputs_embeds`` and,
        when available, ``past_key_values``, ``attention_mask``,
        ``position_ids`` plus any extra kwargs.
    """
    has_cache = past_key_values is not None
    has_embeds = inputs_embeds is not None
    has_positions = cache_position is not None

    model_inputs = {"cache_position": cache_position}

    # Keep only the not-yet-processed part of the inputs when a cache exists.
    if has_cache:
        model_inputs["past_key_values"] = past_key_values
        if has_embeds and input_ids.shape[1] == 0:
            inputs_embeds = inputs_embeds[:, -cache_position.shape[0]:]
        elif has_embeds or (has_positions and cache_position[-1] >= input_ids.shape[1]):
            input_ids = input_ids[:, -cache_position.shape[0]:]
        elif has_positions and input_ids.shape[1] != cache_position.shape[0]:
            input_ids = input_ids[:, cache_position]

    # Decide which of the two input forms is handed to the model.
    feed_embeds = has_embeds and (
        not has_cache or (has_positions and len(cache_position) == inputs_embeds.shape[1])
    )
    if feed_embeds:
        model_inputs.update(input_ids=None, inputs_embeds=inputs_embeds)
    else:
        ids_copy = None
        if input_ids is not None:
            ids_copy = input_ids.clone(memory_format=torch.contiguous_format)
        model_inputs.update(input_ids=ids_copy, inputs_embeds=None)

    if attention_mask is not None:
        model_inputs["attention_mask"] = attention_mask

    # Use caller-supplied position_ids, else derive them from the mask.
    position_ids = kwargs.get("position_ids")
    if position_ids is None and attention_mask is not None:
        running_index = attention_mask.long().cumsum(-1) - 1
        position_ids = running_index.masked_fill(attention_mask == 0, 1)

    if position_ids is not None:
        if has_cache:
            current = model_inputs.get("inputs_embeds")
            if current is None:
                current = model_inputs["input_ids"]
            seq_len = current.shape[1]
            position_ids = position_ids[:, -seq_len:]
        model_inputs["position_ids"] = position_ids.clone(memory_format=torch.contiguous_format)

    # Pass everything else through without overriding what was set above.
    for key, value in kwargs.items():
        model_inputs.setdefault(key, value)

    model_inputs.pop("labels", None)
    return model_inputs
def _update_model_kwargs_for_generation(
    self,
    outputs,
    model_kwargs,
    is_encoder_decoder=False,
    num_new_tokens=1,
):
    """
    Run the parent update, then make the resulting cache layer-compatible.

    Delegates to the superclass implementation and, when the updated kwargs
    contain ``past_key_values``, passes that cache through
    ``_ensure_cache_has_layers`` (transformers >= 4.57 compatibility).
    """
    updated = super()._update_model_kwargs_for_generation(
        outputs,
        model_kwargs,
        is_encoder_decoder=is_encoder_decoder,
        num_new_tokens=num_new_tokens,
    )
    if "past_key_values" not in updated:
        return updated

    updated["past_key_values"] = _ensure_cache_has_layers(updated["past_key_values"])
    return updated
def _init_cache_for_generation(self, generation_config, model_kwargs, batch_size, max_cache_length, device):
"""
Initialize cache for generation, handling different transformers versions.
For transformers >= 4.57, returns None to let the model create the cache dynamically.
"""
try:
from transformers.cache_utils import DynamicCache
sig = inspect.signature(DynamicCache.__init__)
if 'config' in sig.parameters:
# transformers >= 4.57: let model handle cache creation
return None
else:
# Older versions: use parent method
prep_sig = inspect.signature(self._prepare_cache_for_generation)
if 'device' in prep_sig.parameters:
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
else:
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length)
return model_kwargs.get("past_key_values")
except Exception:
return None
# @can_return_tuple
def forward_lm(
self,
@@ -367,7 +543,10 @@ class VibeVoiceStreamingForConditionalGenerationInference(VibeVoiceStreamingPreT
)
max_cache_length = generation_config.max_length - 1
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
# Handle cache initialization for different transformers versions
model_kwargs["past_key_values"] = self._init_cache_for_generation(
generation_config, model_kwargs, batch_size, max_cache_length, device
)
model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
for k, v in model_kwargs.items():
if isinstance(v, torch.Tensor):