From a00f431e14fcbd624ba3888857a6bc8a7120f348 Mon Sep 17 00:00:00 2001
From: YaoyaoChang
Date: Mon, 26 Jan 2026 03:14:07 -0800
Subject: [PATCH] tts support latest transformers(4.57.6)

---
Review note: get_text_config() additionally accepts `encoder` so the override
stays call-compatible with the base-class signature. Line counts are unchanged;
the post-image blob id on the index line predates this edit.

 .../configuration_vibevoice_streaming.py      |   9 +
 .../modeling_vibevoice_streaming_inference.py | 207 ++++++++++++++++--
 2 files changed, 202 insertions(+), 14 deletions(-)

diff --git a/vibevoice/modular/configuration_vibevoice_streaming.py b/vibevoice/modular/configuration_vibevoice_streaming.py
index f27d983..2bd9d6e 100644
--- a/vibevoice/modular/configuration_vibevoice_streaming.py
+++ b/vibevoice/modular/configuration_vibevoice_streaming.py
@@ -81,6 +81,15 @@ class VibeVoiceStreamingConfig(PretrainedConfig):
 
         super().__init__(**kwargs)
 
+    def get_text_config(self, decoder=False, encoder=None):
+        """Returns the decoder config for any `decoder`/`encoder` request (required for transformers >= 4.57 cache compatibility)."""
+        return self.decoder_config
+
+    @property
+    def num_hidden_layers(self):
+        """Proxy to decoder_config.num_hidden_layers (required for transformers >= 4.57)."""
+        return self.decoder_config.num_hidden_layers
+
     def to_dict(self):
         """
         Override to_dict to handle torch.dtype serialization.
diff --git a/vibevoice/modular/modeling_vibevoice_streaming_inference.py b/vibevoice/modular/modeling_vibevoice_streaming_inference.py
index 3baab5d..251b927 100644
--- a/vibevoice/modular/modeling_vibevoice_streaming_inference.py
+++ b/vibevoice/modular/modeling_vibevoice_streaming_inference.py
@@ -1,11 +1,11 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union, Callable
 from tqdm import tqdm
+import inspect
 import torch
 import torch.nn as nn
 
 from transformers.models.auto import AutoModel, AutoModelForCausalLM
-
 from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList
 from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
 from transformers import modeling_utils
@@ -30,29 +30,110 @@ TTS_TEXT_WINDOW_SIZE = 5
 TTS_SPEECH_WINDOW_SIZE = 6
 
 
+# ============================================================================
+# Transformers >= 4.57 Compatibility Layer
+# The cache system was refactored in transformers 4.57, requiring these helpers.
+# ============================================================================
+
+class MockCacheLayer:
+    """
+    Mock cache layer for transformers >= 4.57 compatibility.
+    Provides the `layers` interface expected by DynamicCache in newer versions.
+    """
+
+    def __init__(self, key_cache, value_cache, parent_cache=None, layer_idx=0):
+        self.key_cache = key_cache
+        self.value_cache = value_cache
+        self._parent_cache = parent_cache
+        self._layer_idx = layer_idx
+
+    def get_mask_sizes(self, cache_position):
+        """Return KV length and offset for mask creation."""
+        kv_length = self.key_cache.shape[2] if self.key_cache is not None else 0
+        return kv_length, 0
+
+    def update(self, key_states, value_states, cache_kwargs=None):
+        """Update the cache with new key/value states."""
+        if self._parent_cache is None:
+            return self.key_cache, self.value_cache
+
+        parent = self._parent_cache
+        idx = self._layer_idx
+
+        # Extend cache lists if needed
+        while len(parent.key_cache) <= idx:
+            parent.key_cache.append(None)
+            parent.value_cache.append(None)
+
+        # Concatenate or initialize cache
+        if parent.key_cache[idx] is not None:
+            parent.key_cache[idx] = torch.cat([parent.key_cache[idx], key_states], dim=2)
+            parent.value_cache[idx] = torch.cat([parent.value_cache[idx], value_states], dim=2)
+        else:
+            parent.key_cache[idx] = key_states
+            parent.value_cache[idx] = value_states
+
+        # Update local references
+        self.key_cache = parent.key_cache[idx]
+        self.value_cache = parent.value_cache[idx]
+        return self.key_cache, self.value_cache
+
+
+def _ensure_cache_has_layers(cache):
+    """
+    Ensure the cache has all required attributes for transformers >= 4.57.
+    Creates MockCacheLayer wrappers to provide the expected `layers` interface.
+    """
+    if cache is None:
+        return cache
+
+    # Add required attributes (skip if read-only)
+    for attr, default in [('layer_class_to_replicate', None), ('offloading', False), ('is_compileable', False)]:
+        if not hasattr(cache, attr):
+            try:
+                setattr(cache, attr, default)
+            except AttributeError:
+                pass
+
+    # Build layers list from key_cache/value_cache
+    if hasattr(cache, 'key_cache') and hasattr(cache, 'value_cache'):
+        try:
+            cache.layers = [
+                MockCacheLayer(cache.key_cache[i], cache.value_cache[i], parent_cache=cache, layer_idx=i)
+                for i in range(len(cache.key_cache))
+            ]
+        except AttributeError:
+            pass
+    elif not hasattr(cache, 'layers'):
+        try:
+            cache.layers = []
+        except AttributeError:
+            pass
+
+    return cache
+
+
 def _update_model_kwargs_for_generation(
     outputs: ModelOutput,
     model_kwargs: Dict[str, Any],
     num_new_tokens: int = 1,
 ) -> Dict[str, Any]:
     """
-    Update model_kwargs after adding new tokens.
-    
-    Mainly for the case num_new_tokens > 1 (e.g. a whole text window):
-    - past_key_values: take from current outputs
-    - attention_mask: append num_new_tokens ones
-    - cache_position: advance by creating a range for all new positions
+    Update model_kwargs after adding new tokens (supports multi-token windows).
+
+    Updates past_key_values, attention_mask, and cache_position for the next forward pass.
     """
-    
-    # update past_key_values keeping its naming used in model code
-    model_kwargs["past_key_values"] = getattr(outputs, "past_key_values")
-    
+    model_kwargs["past_key_values"] = _ensure_cache_has_layers(outputs.past_key_values)
+
     attention_mask = model_kwargs["attention_mask"]
     model_kwargs["attention_mask"] = torch.cat(
         [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))], dim=-1
     )
-    
-    model_kwargs["cache_position"] = torch.arange(model_kwargs["cache_position"][-1] + 1, model_kwargs["cache_position"][-1] + num_new_tokens + 1).to(model_kwargs["cache_position"].device)
+
+    cache_pos = model_kwargs["cache_position"]
+    model_kwargs["cache_position"] = torch.arange(
+        cache_pos[-1] + 1, cache_pos[-1] + num_new_tokens + 1, device=cache_pos.device
+    )
 
     return model_kwargs
 
@@ -157,6 +238,101 @@ class VibeVoiceStreamingForConditionalGenerationInference(VibeVoiceStreamingPreT
     def set_ddpm_inference_steps(self, num_steps=None):
         self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps
 
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        **kwargs,
+    ):
+        """Prepare model inputs for generation (transformers >= 4.57 compatible)."""
+        model_inputs = {"cache_position": cache_position}
+
+        # Slice inputs when using cache
+        if past_key_values is not None:
+            model_inputs["past_key_values"] = past_key_values
+            if inputs_embeds is not None and input_ids.shape[1] == 0:
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0]:]
+            elif inputs_embeds is not None or (cache_position is not None and cache_position[-1] >= input_ids.shape[1]):
+                input_ids = input_ids[:, -cache_position.shape[0]:]
+            elif cache_position is not None and input_ids.shape[1] != cache_position.shape[0]:
+                input_ids = input_ids[:, cache_position]
+
+        # Set input_ids or inputs_embeds
+        use_embeds = inputs_embeds is not None and (
+            past_key_values is None or (cache_position is not None and len(cache_position) == inputs_embeds.shape[1])
+        )
+        if use_embeds:
+            model_inputs["input_ids"] = None
+            model_inputs["inputs_embeds"] = inputs_embeds
+        else:
+            model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format) if input_ids is not None else None
+            model_inputs["inputs_embeds"] = None
+
+        if attention_mask is not None:
+            model_inputs["attention_mask"] = attention_mask
+
+        # Create position_ids from attention_mask
+        if attention_mask is not None and kwargs.get("position_ids") is None:
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            kwargs["position_ids"] = position_ids
+
+        # Slice position_ids when using cache
+        if kwargs.get("position_ids") is not None:
+            if past_key_values is not None:
+                seq_len = model_inputs["inputs_embeds"].shape[1] if model_inputs.get("inputs_embeds") is not None else model_inputs["input_ids"].shape[1]
+                model_inputs["position_ids"] = kwargs["position_ids"][:, -seq_len:].clone(memory_format=torch.contiguous_format)
+            else:
+                model_inputs["position_ids"] = kwargs.pop("position_ids").clone(memory_format=torch.contiguous_format)
+
+        # Forward remaining kwargs
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        model_inputs.pop("labels", None)
+        return model_inputs
+
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs,
+        model_kwargs,
+        is_encoder_decoder=False,
+        num_new_tokens=1,
+    ):
+        """Override to ensure cache compatibility with transformers >= 4.57."""
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder, num_new_tokens=num_new_tokens
+        )
+        if "past_key_values" in model_kwargs:
+            model_kwargs["past_key_values"] = _ensure_cache_has_layers(model_kwargs["past_key_values"])
+        return model_kwargs
+
+    def _init_cache_for_generation(self, generation_config, model_kwargs, batch_size, max_cache_length, device):
+        """
+        Initialize cache for generation, handling different transformers versions.
+        For transformers >= 4.57, returns None to let the model create the cache dynamically.
+        """
+        try:
+            from transformers.cache_utils import DynamicCache
+            if 'config' in inspect.signature(DynamicCache.__init__).parameters:
+                # transformers >= 4.57: let model handle cache creation
+                return None
+            # Older versions: use parent method; pass `device` only if its signature accepts it
+            args = (generation_config, model_kwargs, None, batch_size, max_cache_length)
+            if 'device' in inspect.signature(self._prepare_cache_for_generation).parameters:
+                args += (device,)
+            self._prepare_cache_for_generation(*args)
+            return model_kwargs.get("past_key_values")
+        except Exception as exc:
+            # Best effort: fall back to a model-created cache, but surface the failure instead of hiding it
+            import warnings
+            warnings.warn(f"Cache pre-allocation failed ({exc!r}); the model will create its cache dynamically.")
+            return None
+
     # @can_return_tuple
     def forward_lm(
         self,
@@ -367,7 +543,10 @@ class VibeVoiceStreamingForConditionalGenerationInference(VibeVoiceStreamingPreT
         )
 
         max_cache_length = generation_config.max_length - 1
-        self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device)
+        # Handle cache initialization for different transformers versions
+        model_kwargs["past_key_values"] = self._init_cache_for_generation(
+            generation_config, model_kwargs, batch_size, max_cache_length, device
+        )
         model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long)
         for k, v in model_kwargs.items():
             if isinstance(v, torch.Tensor):