tts support latest transformers(4.57.6)
This commit is contained in:
@@ -81,6 +81,15 @@ class VibeVoiceStreamingConfig(PretrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_text_config(self, decoder=False):
|
||||
"""Returns the decoder config (required for transformers >= 4.57 cache compatibility)."""
|
||||
return self.decoder_config
|
||||
|
||||
@property
|
||||
def num_hidden_layers(self):
|
||||
"""Proxy to decoder_config.num_hidden_layers (required for transformers >= 4.57)."""
|
||||
return self.decoder_config.num_hidden_layers
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Override to_dict to handle torch.dtype serialization.
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
|
||||
from tqdm import tqdm
|
||||
import inspect
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers.models.auto import AutoModel, AutoModelForCausalLM
|
||||
|
||||
from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from transformers import modeling_utils
|
||||
@@ -30,29 +30,110 @@ TTS_TEXT_WINDOW_SIZE = 5
|
||||
TTS_SPEECH_WINDOW_SIZE = 6
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Transformers >= 4.57 Compatibility Layer
|
||||
# The cache system was refactored in transformers 4.57, requiring these helpers.
|
||||
# ============================================================================
|
||||
|
||||
class MockCacheLayer:
|
||||
"""
|
||||
Mock cache layer for transformers >= 4.57 compatibility.
|
||||
Provides the `layers` interface expected by DynamicCache in newer versions.
|
||||
"""
|
||||
|
||||
def __init__(self, key_cache, value_cache, parent_cache=None, layer_idx=0):
|
||||
self.key_cache = key_cache
|
||||
self.value_cache = value_cache
|
||||
self._parent_cache = parent_cache
|
||||
self._layer_idx = layer_idx
|
||||
|
||||
def get_mask_sizes(self, cache_position):
|
||||
"""Return KV length and offset for mask creation."""
|
||||
kv_length = self.key_cache.shape[2] if self.key_cache is not None else 0
|
||||
return kv_length, 0
|
||||
|
||||
def update(self, key_states, value_states, cache_kwargs=None):
|
||||
"""Update the cache with new key/value states."""
|
||||
if self._parent_cache is None:
|
||||
return self.key_cache, self.value_cache
|
||||
|
||||
parent = self._parent_cache
|
||||
idx = self._layer_idx
|
||||
|
||||
# Extend cache lists if needed
|
||||
while len(parent.key_cache) <= idx:
|
||||
parent.key_cache.append(None)
|
||||
parent.value_cache.append(None)
|
||||
|
||||
# Concatenate or initialize cache
|
||||
if parent.key_cache[idx] is not None:
|
||||
parent.key_cache[idx] = torch.cat([parent.key_cache[idx], key_states], dim=2)
|
||||
parent.value_cache[idx] = torch.cat([parent.value_cache[idx], value_states], dim=2)
|
||||
else:
|
||||
parent.key_cache[idx] = key_states
|
||||
parent.value_cache[idx] = value_states
|
||||
|
||||
# Update local references
|
||||
self.key_cache = parent.key_cache[idx]
|
||||
self.value_cache = parent.value_cache[idx]
|
||||
return self.key_cache, self.value_cache
|
||||
|
||||
|
||||
def _ensure_cache_has_layers(cache):
|
||||
"""
|
||||
Ensure the cache has all required attributes for transformers >= 4.57.
|
||||
Creates MockCacheLayer wrappers to provide the expected `layers` interface.
|
||||
"""
|
||||
if cache is None:
|
||||
return cache
|
||||
|
||||
# Add required attributes (skip if read-only)
|
||||
for attr, default in [('layer_class_to_replicate', None), ('offloading', False), ('is_compileable', False)]:
|
||||
if not hasattr(cache, attr):
|
||||
try:
|
||||
setattr(cache, attr, default)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Build layers list from key_cache/value_cache
|
||||
if hasattr(cache, 'key_cache') and hasattr(cache, 'value_cache'):
|
||||
try:
|
||||
cache.layers = [
|
||||
MockCacheLayer(cache.key_cache[i], cache.value_cache[i], parent_cache=cache, layer_idx=i)
|
||||
for i in range(len(cache.key_cache))
|
||||
]
|
||||
except AttributeError:
|
||||
pass
|
||||
elif not hasattr(cache, 'layers'):
|
||||
try:
|
||||
cache.layers = []
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return cache
|
||||
|
||||
|
||||
def _update_model_kwargs_for_generation(
|
||||
outputs: ModelOutput,
|
||||
model_kwargs: Dict[str, Any],
|
||||
num_new_tokens: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Update model_kwargs after adding new tokens.
|
||||
Update model_kwargs after adding new tokens (supports multi-token windows).
|
||||
|
||||
Mainly for the case num_new_tokens > 1 (e.g. a whole text window):
|
||||
- past_key_values: take from current outputs
|
||||
- attention_mask: append num_new_tokens ones
|
||||
- cache_position: advance by creating a range for all new positions
|
||||
Updates past_key_values, attention_mask, and cache_position for the next forward pass.
|
||||
"""
|
||||
|
||||
# update past_key_values keeping its naming used in model code
|
||||
model_kwargs["past_key_values"] = getattr(outputs, "past_key_values")
|
||||
model_kwargs["past_key_values"] = _ensure_cache_has_layers(outputs.past_key_values)
|
||||
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))], dim=-1
|
||||
)
|
||||
|
||||
model_kwargs["cache_position"] = torch.arange(model_kwargs["cache_position"][-1] + 1, model_kwargs["cache_position"][-1] + num_new_tokens + 1).to(model_kwargs["cache_position"].device)
|
||||
cache_pos = model_kwargs["cache_position"]
|
||||
model_kwargs["cache_position"] = torch.arange(
|
||||
cache_pos[-1] + 1, cache_pos[-1] + num_new_tokens + 1, device=cache_pos.device
|
||||
)
|
||||
|
||||
return model_kwargs
|
||||
|
||||
@@ -157,6 +238,101 @@ class VibeVoiceStreamingForConditionalGenerationInference(VibeVoiceStreamingPreT
|
||||
def set_ddpm_inference_steps(self, num_steps=None):
|
||||
self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Prepare model inputs for generation (transformers >= 4.57 compatible)."""
|
||||
model_inputs = {"cache_position": cache_position}
|
||||
|
||||
# Slice inputs when using cache
|
||||
if past_key_values is not None:
|
||||
model_inputs["past_key_values"] = past_key_values
|
||||
if inputs_embeds is not None and input_ids.shape[1] == 0:
|
||||
inputs_embeds = inputs_embeds[:, -cache_position.shape[0]:]
|
||||
elif inputs_embeds is not None or (cache_position is not None and cache_position[-1] >= input_ids.shape[1]):
|
||||
input_ids = input_ids[:, -cache_position.shape[0]:]
|
||||
elif cache_position is not None and input_ids.shape[1] != cache_position.shape[0]:
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# Set input_ids or inputs_embeds
|
||||
use_embeds = inputs_embeds is not None and (
|
||||
past_key_values is None or (cache_position is not None and len(cache_position) == inputs_embeds.shape[1])
|
||||
)
|
||||
if use_embeds:
|
||||
model_inputs["input_ids"] = None
|
||||
model_inputs["inputs_embeds"] = inputs_embeds
|
||||
else:
|
||||
model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format) if input_ids is not None else None
|
||||
model_inputs["inputs_embeds"] = None
|
||||
|
||||
if attention_mask is not None:
|
||||
model_inputs["attention_mask"] = attention_mask
|
||||
|
||||
# Create position_ids from attention_mask
|
||||
if attention_mask is not None and kwargs.get("position_ids") is None:
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
kwargs["position_ids"] = position_ids
|
||||
|
||||
# Slice position_ids when using cache
|
||||
if kwargs.get("position_ids") is not None:
|
||||
if past_key_values is not None:
|
||||
seq_len = model_inputs["inputs_embeds"].shape[1] if model_inputs.get("inputs_embeds") is not None else model_inputs["input_ids"].shape[1]
|
||||
model_inputs["position_ids"] = kwargs["position_ids"][:, -seq_len:].clone(memory_format=torch.contiguous_format)
|
||||
else:
|
||||
model_inputs["position_ids"] = kwargs.pop("position_ids").clone(memory_format=torch.contiguous_format)
|
||||
|
||||
# Forward remaining kwargs
|
||||
for key, value in kwargs.items():
|
||||
if key not in model_inputs:
|
||||
model_inputs[key] = value
|
||||
|
||||
model_inputs.pop("labels", None)
|
||||
return model_inputs
|
||||
|
||||
def _update_model_kwargs_for_generation(
|
||||
self,
|
||||
outputs,
|
||||
model_kwargs,
|
||||
is_encoder_decoder=False,
|
||||
num_new_tokens=1,
|
||||
):
|
||||
"""Override to ensure cache compatibility with transformers >= 4.57."""
|
||||
model_kwargs = super()._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder, num_new_tokens=num_new_tokens
|
||||
)
|
||||
if "past_key_values" in model_kwargs:
|
||||
model_kwargs["past_key_values"] = _ensure_cache_has_layers(model_kwargs["past_key_values"])
|
||||
return model_kwargs
|
||||
|
||||
def _init_cache_for_generation(self, generation_config, model_kwargs, batch_size, max_cache_length, device):
|
||||
"""
|
||||
Initialize cache for generation, handling different transformers versions.
|
||||
For transformers >= 4.57, returns None to let the model create the cache dynamically.
|
||||
"""
|
||||
try:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
sig = inspect.signature(DynamicCache.__init__)
|
||||
if 'config' in sig.parameters:
|
||||
# transformers >= 4.57: let model handle cache creation
|
||||
return None
|
||||
else:
|
||||
# Older versions: use parent method
|
||||
prep_sig = inspect.signature(self._prepare_cache_for_generation)
|
||||
if 'device' in prep_sig.parameters:
|
||||
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
|
||||
else:
|
||||
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length)
|
||||
return model_kwargs.get("past_key_values")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# @can_return_tuple
|
||||
def forward_lm(
|
||||
self,
|
||||
@@ -367,7 +543,10 @@ class VibeVoiceStreamingForConditionalGenerationInference(VibeVoiceStreamingPreT
|
||||
)
|
||||
|
||||
max_cache_length = generation_config.max_length - 1
|
||||
self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
|
||||
# Handle cache initialization for different transformers versions
|
||||
model_kwargs["past_key_values"] = self._init_cache_for_generation(
|
||||
generation_config, model_kwargs, batch_size, max_cache_length, device
|
||||
)
|
||||
model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
|
||||
for k, v in model_kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
|
||||
Reference in New Issue
Block a user