diff --git a/finetuning/README.md b/finetuning/README.md new file mode 100644 index 0000000..7c2b8ff --- /dev/null +++ b/finetuning/README.md @@ -0,0 +1,138 @@ +# VibeVoice ASR LoRA Fine-tuning + +This directory contains scripts for LoRA (Low-Rank Adaptation) fine-tuning of the VibeVoice ASR model. + +## Requirements + +```bash +pip install peft accelerate +``` + +## Toy Dataset + +> **Note**: The `toy_dataset/` included in this directory contains **synthetic audio generated by VibeVoice TTS** for demonstration purposes only. It is NOT a production-quality dataset. +> +> When using your own data, you should: +> - Prepare real audio recordings with accurate transcriptions +> - Adjust hyperparameters (learning rate, epochs, LoRA rank) based on your dataset size and domain +> - Consider the audio quality and speaker diversity in your data + +## Data Format + +Training data should be organized as pairs of audio files and JSON labels in the same directory: + +``` +toy_dataset/ +├── 0.mp3 +├── 0.json +├── 1.mp3 +├── 1.json +└── ... +``` + +### JSON Label Format + +Each JSON file should have the following structure: + +```json +{ + "audio_duration": 351.73, + "audio_path": "0.mp3", + "segments": [ + { + "speaker": 0, + "text": "Hey everyone, welcome back...", + "start": 0.0, + "end": 38.68 + }, + { + "speaker": 1, + "text": "Thanks for having me...", + "start": 38.75, + "end": 77.88 + } + ], + "hotwords": ["Tea Brew", "Aiden Host"] // optional +} +``` + +## Training + +### Basic Usage + +```bash +python lora_finetune.py \ + --model_path microsoft/VibeVoice-ASR \ + --data_dir ./toy_dataset \ + --output_dir ./output \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --bf16 +``` + +### Full Options + +The script uses HuggingFace's `TrainingArguments`, so all standard options are available: + +```bash +python lora_finetune.py \ + --model_path microsoft/VibeVoice-ASR \ + --data_dir ./toy_dataset \ + --output_dir ./output \ + --lora_r 16 \ + --lora_alpha 32 \ + --lora_dropout 0.05 \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --learning_rate 1e-4 \ + --warmup_ratio 0.1 \ + --weight_decay 0.01 \ + --max_grad_norm 1.0 \ + --logging_steps 10 \ + --save_steps 100 \ + --gradient_checkpointing \ + --bf16 \ + --report_to none +``` + +### Key Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `--lora_r` | 16 | LoRA rank (lower = fewer params, higher = more expressive) | +| `--lora_alpha` | 32 | LoRA scaling factor (typically 2x rank) | +| `--lora_dropout` | 0.05 | Dropout for LoRA layers | +| `--per_device_train_batch_size` | 8 | Batch size per device | +| `--gradient_accumulation_steps` | 1 | Effective batch size = batch_size × grad_accum | +| `--learning_rate` | 5e-5 | Learning rate (1e-4 to 2e-4 typical for LoRA) | +| `--gradient_checkpointing` | False | Enable to reduce memory usage | +| `--use_hotwords` | True | Include hotwords from JSON as context | +| `--max_audio_length` | None | Skip audio longer than this (seconds) | + +## Inference with Fine-tuned Model + +```bash +python inference_lora.py \ + --base_model microsoft/VibeVoice-ASR \ + --lora_path ./output \ + --audio_file ./toy_dataset/0.mp3 \ + --context_info "Hotwords: Tea Brew, Aiden Host" +``` + +## Merging LoRA Weights (Optional) + +To merge LoRA weights into the base model for faster inference: + +```python +from peft import PeftModel + +# Load base model + LoRA +model = VibeVoiceASRForConditionalGeneration.from_pretrained("microsoft/VibeVoice-ASR", ...) +model = PeftModel.from_pretrained(model, "./output") + +# Merge and save +model = model.merge_and_unload() +model.save_pretrained("./merged_model") +``` diff --git a/finetuning/inference_lora.py b/finetuning/inference_lora.py new file mode 100644 index 0000000..4c9a316 --- /dev/null +++ b/finetuning/inference_lora.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python +""" +Inference with LoRA Fine-tuned VibeVoice ASR Model + +This script loads a LoRA fine-tuned model and runs inference. + +Usage: + python inference_lora.py \ + --base_model microsoft/VibeVoice-ASR \ + --lora_path ./output \ + --audio_file ./toy_dataset/0.mp3 +""" + +import os +import sys +import argparse +import torch +from pathlib import Path + +from peft import PeftModel + +# Add parent directory to path for vibevoice imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration +from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor + + +def load_lora_model( + base_model_path: str, + lora_path: str, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, +): + """ + Load base model and merge with LoRA weights. + + Args: + base_model_path: Path to base pretrained model + lora_path: Path to LoRA adapter weights + device: Device to load model on + dtype: Data type for model + + Returns: + Tuple of (model, processor) + """ + print(f"Loading base model from {base_model_path}") + + # Load processor + processor = VibeVoiceASRProcessor.from_pretrained( + base_model_path, + language_model_pretrained_name="Qwen/Qwen2.5-7B" + ) + + # Load base model + model = VibeVoiceASRForConditionalGeneration.from_pretrained( + base_model_path, + dtype=dtype, + device_map=device if device == "auto" else None, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ) + + if device != "auto": + model = model.to(device) + + # Load LoRA adapter + print(f"Loading LoRA adapter from {lora_path}") + model = PeftModel.from_pretrained(model, lora_path) + + # Optionally merge LoRA weights into base model for faster inference + # model = model.merge_and_unload() + + model.eval() + print("Model loaded successfully") + + return model, processor + + +def transcribe( + model, + processor, + audio_path: str, + max_new_tokens: int = 4096, + temperature: float = 0.0, + context_info: str = None, + device: str = "cuda", +): + """ + Transcribe an audio file using the LoRA fine-tuned model. + + Args: + model: The LoRA fine-tuned model + processor: The processor + audio_path: Path to audio file + max_new_tokens: Maximum tokens to generate + temperature: Sampling temperature (0 = greedy) + context_info: Optional context info (e.g., hotwords) + device: Device + + Returns: + Transcription result + """ + print(f"\nTranscribing: {audio_path}") + + # Process audio + inputs = processor( + audio=audio_path, + sampling_rate=None, + return_tensors="pt", + padding=True, + add_generation_prompt=True, + context_info=context_info, + ) + + # Move to device + inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in inputs.items()} + + # Generation config + gen_config = { + "max_new_tokens": max_new_tokens, + "pad_token_id": processor.pad_id, + "eos_token_id": processor.tokenizer.eos_token_id, + "do_sample": temperature > 0, + } + if temperature > 0: + gen_config["temperature"] = temperature + gen_config["top_p"] = 0.9 + + # Generate + with torch.no_grad(): + output_ids = model.generate(**inputs, **gen_config) + + # Decode + input_length = inputs['input_ids'].shape[1] + generated_ids = output_ids[0, input_length:] + generated_text = processor.decode(generated_ids, skip_special_tokens=True) + + # Parse structured output + try: + segments = processor.post_process_transcription(generated_text) + except Exception as e: + print(f"Warning: Failed to parse structured output: {e}") + segments = [] + + return { + "raw_text": generated_text, + "segments": segments, + } + + +def main(): + parser = argparse.ArgumentParser(description="Inference with LoRA Fine-tuned VibeVoice ASR") + + parser.add_argument( + "--base_model", + type=str, + default="microsoft/VibeVoice-ASR", + help="Path to base pretrained model" + ) + parser.add_argument( + "--lora_path", + type=str, + required=True, + help="Path to LoRA adapter weights" + ) + parser.add_argument( + "--audio_file", + type=str, + required=True, + help="Path to audio file to transcribe" + ) + parser.add_argument( + "--context_info", + type=str, + default=None, + help="Optional context info (e.g., 'Hotwords: Tea Brew, Aiden Host')" + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=4096, + help="Maximum tokens to generate" + ) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature (0 = greedy)" + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to use" + ) + + args = parser.parse_args() + + # Load model + dtype = torch.bfloat16 if args.device != "cpu" else torch.float32 + model, processor = load_lora_model( + base_model_path=args.base_model, + lora_path=args.lora_path, + device=args.device, + dtype=dtype, + ) + + # Transcribe + result = transcribe( + model=model, + processor=processor, + audio_path=args.audio_file, + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + context_info=args.context_info, + device=args.device, + ) + + # Print results + print("\n" + "="*60) + print("Transcription Result") + print("="*60) + + print("\n--- Raw Output ---") + raw_text = result['raw_text'] + print(raw_text[:2000] + "..." if len(raw_text) > 2000 else raw_text) + + if result['segments']: + print(f"\n--- Structured Output ({len(result['segments'])} segments) ---") + for seg in result['segments'][:20]: + print(f"[{seg.get('start_time', 'N/A')} - {seg.get('end_time', 'N/A')}] " + f"Speaker {seg.get('speaker_id', 'N/A')}: {seg.get('text', '')[:80]}...") + if len(result['segments']) > 20: + print(f" ... and {len(result['segments']) - 20} more segments") + + +if __name__ == "__main__": + main() diff --git a/finetuning/lora_finetune.py b/finetuning/lora_finetune.py new file mode 100644 index 0000000..e5b64b0 --- /dev/null +++ b/finetuning/lora_finetune.py @@ -0,0 +1,561 @@ +#!/usr/bin/env python +""" +VibeVoice ASR LoRA Fine-tuning Script + +This script implements LoRA (Low-Rank Adaptation) fine-tuning for the VibeVoice ASR model. +It uses PEFT (Parameter-Efficient Fine-Tuning) library for efficient training. +""" + +import os +import sys +import json +import logging +from pathlib import Path +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +import numpy as np + +from transformers import ( + TrainingArguments, + Trainer, + HfArgumentParser, +) +from peft import ( + LoraConfig, + get_peft_model, + prepare_model_for_kbit_training, + TaskType, +) + +# Add parent directory to path for vibevoice imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration +from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor + +# Setup logging +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """Arguments for model configuration.""" + model_path: str = field( + default="microsoft/VibeVoice-ASR", + metadata={"help": "Path to pretrained model (HuggingFace model ID or local path)"} + ) + + +@dataclass +class DataArguments: + """Arguments for data configuration.""" + data_dir: str = field( + default="./toy_dataset", + metadata={"help": "Directory containing training data"} + ) + max_audio_length: Optional[float] = field( + default=None, + metadata={"help": "Maximum audio length in seconds (default: no limit)"} + ) + use_hotwords: bool = field( + default=True, + metadata={"help": "Whether to use hotwords from JSON"} + ) + + +@dataclass +class LoraArguments: + """Arguments for LoRA configuration.""" + lora_r: int = field( + default=16, + metadata={"help": "LoRA rank"} + ) + lora_alpha: int = field( + default=32, + metadata={"help": "LoRA alpha (scaling factor)"} + ) + lora_dropout: float = field( + default=0.05, + metadata={"help": "LoRA dropout"} + ) + + +@dataclass +class VibeVoiceASRDataCollator: + """ + Data collator for VibeVoice ASR fine-tuning. + Handles batching of variable-length audio and text sequences. + """ + processor: VibeVoiceASRProcessor + pad_token_id: int + label_pad_token_id: int = -100 + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + """ + Collate a batch of features into model inputs. + """ + # Separate inputs and labels + input_ids_list = [f["input_ids"] for f in features] + labels_list = [f["labels"] for f in features] + acoustic_mask_list = [f["acoustic_input_mask"] for f in features] + speech_list = [f["speech"] for f in features] + vae_tok_lens = [f["vae_tok_len"] for f in features] + + # Determine max lengths + max_seq_len = max(len(ids) for ids in input_ids_list) + max_speech_len = max(len(s) for s in speech_list) + max_vae_len = max(vae_tok_lens) + + batch_size = len(features) + + # Initialize padded tensors + input_ids = torch.full((batch_size, max_seq_len), self.pad_token_id, dtype=torch.long) + attention_mask = torch.zeros((batch_size, max_seq_len), dtype=torch.long) + labels = torch.full((batch_size, max_seq_len), self.label_pad_token_id, dtype=torch.long) + acoustic_input_mask = torch.zeros((batch_size, max_seq_len), dtype=torch.bool) + speech_tensors = torch.zeros((batch_size, max_speech_len), dtype=torch.float32) + speech_masks = torch.zeros((batch_size, max_vae_len), dtype=torch.bool) + + # Fill in the tensors (right padding for training) + # Note: processor uses left padding for inference/generation, but training uses right padding + for i, (ids, lbls, amask, speech, vae_len) in enumerate( + zip(input_ids_list, labels_list, acoustic_mask_list, speech_list, vae_tok_lens) + ): + seq_len = len(ids) + + # Right padding for input_ids and labels + input_ids[i, :seq_len] = torch.tensor(ids, dtype=torch.long) + attention_mask[i, :seq_len] = 1 + labels[i, :seq_len] = torch.tensor(lbls, dtype=torch.long) + acoustic_input_mask[i, :seq_len] = torch.tensor(amask, dtype=torch.bool) + + # Speech tensors (right padding, zeros work as padding) + speech_len = len(speech) + speech_tensors[i, :speech_len] = torch.tensor(speech, dtype=torch.float32) + speech_masks[i, :vae_len] = True + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "acoustic_input_mask": acoustic_input_mask, + "speech_tensors": speech_tensors, + "speech_masks": speech_masks, + } + + +class VibeVoiceASRDataset(Dataset): + """ + Dataset for VibeVoice ASR fine-tuning. + + Expected data format: + - Audio files: .mp3, .wav, .flac, etc. + - Label files: .json with matching name + + JSON format: + { + "audio_path": "0.mp3", + "audio_duration": 351.73, + "segments": [ + { + "speaker": 0, + "text": "Hey everyone, welcome back...", + "start": 0.0, + "end": 38.68 + }, + ... + ], + "hotwords": ["Tea Brew", "Aiden Host", ...] # optional + } + """ + + def __init__( + self, + data_dir: str, + processor: VibeVoiceASRProcessor, + max_audio_length: Optional[float] = None, # in seconds + use_hotwords: bool = True, + ): + """ + Initialize the dataset. + + Args: + data_dir: Directory containing audio files and JSON labels + processor: VibeVoice ASR processor + max_audio_length: Maximum audio length in seconds (None = no limit) + use_hotwords: Whether to include hotwords in context + """ + self.data_dir = Path(data_dir) + self.processor = processor + self.max_audio_length = max_audio_length + self.use_hotwords = use_hotwords + + # Find all JSON files + self.samples = self._load_samples() + logger.info(f"Loaded {len(self.samples)} samples from {data_dir}") + + def _load_samples(self) -> List[Dict[str, Any]]: + """Load and validate all samples from data directory.""" + samples = [] + + for json_path in sorted(self.data_dir.glob("*.json")): + try: + with open(json_path, "r", encoding="utf-8") as f: + data = json.load(f) + + # Get audio path from JSON + audio_filename = data.get("audio_path") + if not audio_filename: + logger.warning(f"No audio_path specified in {json_path}") + continue + + audio_path = self.data_dir / audio_filename + if not audio_path.exists(): + logger.warning(f"Audio file not found: {audio_path}") + continue + + # Optional: filter by duration + if self.max_audio_length is not None: + duration = data.get("audio_duration", float("inf")) + if duration > self.max_audio_length: + logger.info(f"Skipping {json_path.stem}: duration {duration:.1f}s > max {self.max_audio_length}s") + continue + + samples.append({ + "audio_path": str(audio_path), + "json_path": str(json_path), + "data": data, + }) + + except Exception as e: + logger.warning(f"Error loading {json_path}: {e}") + continue + + return samples + + def _format_transcription(self, segments: List[Dict], audio_duration: float) -> str: + """ + Format transcription segments into JSON output format. + + This matches the expected model output format used in training. + """ + formatted_segments = [] + + for seg in segments: + formatted_seg = {} + # Add timestamp + formatted_seg["Start"] = round(seg['start'], 2) + formatted_seg["End"] = round(seg['end'], 2) + # Add speaker if available + if "speaker" in seg: + formatted_seg["Speaker"] = seg["speaker"] + # Add content + formatted_seg["Content"] = seg.get("text", "") + formatted_segments.append(formatted_seg) + + # Return as compact JSON string (no spaces after separators) + return json.dumps(formatted_segments, ensure_ascii=False, separators=(',', ':')) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + """ + Get a single sample for training. + + Returns: + Dict with: + - input_ids: Token IDs for input (system + user + assistant prompt) + - labels: Token IDs for labels (-100 for non-predicted tokens) + - acoustic_input_mask: Mask for speech token positions + - speech: Raw audio array + - vae_tok_len: Number of speech tokens + """ + sample = self.samples[idx] + data = sample["data"] + audio_path = sample["audio_path"] + + # Prepare context info (hotwords) + context_info = None + if self.use_hotwords and "hotwords" in data: + hotwords = data["hotwords"] + if hotwords: + context_info = "\n".join(hotwords) + + # Process audio using the processor's internal method + encoding = self.processor._process_single_audio( + audio_path, + sampling_rate=None, + add_generation_prompt=True, + use_streaming=True, + context_info=context_info, + ) + + # Get the input tokens (system + user + generation prompt) + input_ids = encoding["input_ids"] + acoustic_input_mask = encoding["acoustic_input_mask"] + speech = encoding["speech"] + vae_tok_len = encoding["vae_tok_len"] + + # Format the target transcription + target_text = self._format_transcription( + data["segments"], + data.get("audio_duration", len(speech) / 24000) + ) + + # Encode target using apply_chat_template to match training format + # This adds the assistant role tokens (e.g., <|im_start|>assistant\n...<|im_end|>) + target_tokens = self.processor.tokenizer.apply_chat_template( + [{"role": "assistant", "content": target_text}], + tokenize=True, + add_generation_prompt=False, + ) + + # Combine input and target + full_input_ids = input_ids + target_tokens + full_acoustic_mask = acoustic_input_mask + [0] * len(target_tokens) + + # Create labels: -100 for input tokens, actual tokens for target + # We mask the input portion so loss is only computed on the response + labels = [-100] * len(input_ids) + target_tokens + + return { + "input_ids": full_input_ids, + "labels": labels, + "acoustic_input_mask": full_acoustic_mask, + "speech": speech, + "vae_tok_len": vae_tok_len, + } + + +def get_lora_config( + r: int = 16, + lora_alpha: int = 32, + lora_dropout: float = 0.05, + target_modules: Optional[List[str]] = None, +) -> LoraConfig: + """ + Create LoRA configuration for VibeVoice ASR model. + + We apply LoRA to the language model's attention layers and MLP, + following common practices for LLM fine-tuning. + + Args: + r: LoRA rank + lora_alpha: LoRA scaling factor + lora_dropout: Dropout for LoRA layers + target_modules: List of module names to apply LoRA to + + Returns: + LoraConfig object + """ + if target_modules is None: + # Target Qwen2 attention and MLP layers + # These are the common targets for language model fine-tuning + target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + + return LoraConfig( + r=r, + lora_alpha=lora_alpha, + target_modules=target_modules, + lora_dropout=lora_dropout, + bias="none", + task_type=TaskType.CAUSAL_LM, + ) + + +def setup_model_for_training( + model_path: str, + lora_config: LoraConfig, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + gradient_checkpointing: bool = True, +) -> Tuple[nn.Module, VibeVoiceASRProcessor]: + """ + Load and prepare model for LoRA training. + + Args: + model_path: Path to pretrained model + lora_config: LoRA configuration + device: Device to use + dtype: Data type for model + gradient_checkpointing: Whether to use gradient checkpointing + + Returns: + Tuple of (model, processor) + """ + logger.info(f"Loading model from {model_path}") + + # Load processor + processor = VibeVoiceASRProcessor.from_pretrained( + model_path, + language_model_pretrained_name="Qwen/Qwen2.5-7B" + ) + + # Load model + model = VibeVoiceASRForConditionalGeneration.from_pretrained( + model_path, + dtype=dtype, + device_map=device if device == "auto" else None, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ) + + if device != "auto": + model = model.to(device) + + # Freeze speech tokenizers (we only want to fine-tune the language model) + for name, param in model.named_parameters(): + if "acoustic_tokenizer" in name or "semantic_tokenizer" in name: + param.requires_grad = False + logger.debug(f"Frozen: {name}") + + # Apply LoRA + logger.info(f"Applying LoRA with config: r={lora_config.r}, alpha={lora_config.lora_alpha}") + model = get_peft_model(model, lora_config) + + # Print trainable parameters + model.print_trainable_parameters() + + # Enable gradient checkpointing if requested + if gradient_checkpointing: + model.enable_input_require_grads() + model.gradient_checkpointing_enable() + + return model, processor + + +def train( + model_args: ModelArguments, + data_args: DataArguments, + lora_args: LoraArguments, + training_args: TrainingArguments, + gradient_checkpointing: bool = True, +): + """ + Main training function for LoRA fine-tuning. + + Args: + model_args: Model configuration arguments + data_args: Data configuration arguments + lora_args: LoRA configuration arguments + training_args: HuggingFace TrainingArguments + gradient_checkpointing: Whether to use gradient checkpointing + """ + # Set seed + torch.manual_seed(training_args.seed) + np.random.seed(training_args.seed) + + # Setup LoRA config + lora_config = get_lora_config( + r=lora_args.lora_r, + lora_alpha=lora_args.lora_alpha, + lora_dropout=lora_args.lora_dropout, + ) + + # Determine device + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Load model and processor + dtype = torch.bfloat16 if device != "cpu" else torch.float32 + model, processor = setup_model_for_training( + model_path=model_args.model_path, + lora_config=lora_config, + device=device, + dtype=dtype, + gradient_checkpointing=gradient_checkpointing, + ) + + # Create dataset + train_dataset = VibeVoiceASRDataset( + data_dir=data_args.data_dir, + processor=processor, + max_audio_length=data_args.max_audio_length, + use_hotwords=data_args.use_hotwords, + ) + + if len(train_dataset) == 0: + logger.error("No training samples found!") + return + + # Create data collator + data_collator = VibeVoiceASRDataCollator( + processor=processor, + pad_token_id=processor.pad_id, + ) + + # Set some sensible defaults for audio training + training_args.dataloader_num_workers = 0 # Audio loading can be tricky with multiprocessing + training_args.remove_unused_columns = False # Keep all columns + + # Create trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + data_collator=data_collator, + ) + + # Train + logger.info("Starting training...") + logger.info(f" Num samples = {len(train_dataset)}") + logger.info(f" Num epochs = {training_args.num_train_epochs}") + logger.info(f" Batch size = {training_args.per_device_train_batch_size}") + logger.info(f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}") + total_steps = len(train_dataset) * int(training_args.num_train_epochs) // ( + training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps + ) + logger.info(f" Total optimization steps = {total_steps}") + + train_result = trainer.train() + + # Save final model + logger.info(f"Saving model to {training_args.output_dir}") + trainer.save_model(training_args.output_dir) + + # Save training metrics + metrics = train_result.metrics + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + + # Save processor config + processor.save_pretrained(training_args.output_dir) + + logger.info("Training complete!") + + return model, processor + + +def main(): + # Use HfArgumentParser to parse all argument dataclasses + parser = HfArgumentParser((ModelArguments, DataArguments, LoraArguments, TrainingArguments)) + model_args, data_args, lora_args, training_args = parser.parse_args_into_dataclasses() + + # Run training + train( + model_args=model_args, + data_args=data_args, + lora_args=lora_args, + training_args=training_args, + ) + + +if __name__ == "__main__": + main() diff --git a/finetuning/toy_dataset/0.json b/finetuning/toy_dataset/0.json new file mode 100755 index 0000000..19c72b3 --- /dev/null +++ b/finetuning/toy_dataset/0.json @@ -0,0 +1,73 @@ +{ + "audio_duration": 351.73333333333335, + "audio_path": "0.mp3", + "segments": [ + { + "speaker": 0, + "text": "Hey everyone, welcome back to Tea Brew. I’m Aidan Host, and, uh, sorry we’re a tad late today. Travel was a bit wild on the roads and our guest came in from near Meeter Street, so, you know, life happens. But we’re here now and buzzing to talk property tech. We’ve had loads of messages from landlords about spiraling energy costs in HMOs and, at the same time, a bunch of questions on streamlining tenant recruitment. So today we’re kinda merging both worlds: how a smart heating setup can cut bills, and how a software assistant like Rent Byte can take the grind out of advertising, onboarding, and maintenance reporting. And sitting with me is Sayid Guest, who’s, uh, a landlord turned gadget builder—haha—ready to demystify the thermostatic stuff that tenants can actually live with.", + "start": 0.0, + "end": 38.68 + }, + { + "speaker": 1, + "text": "Thanks, Aidan Host, and, yeah, hi everyone. I’m Sayid Guest, and this topic is personal for me. I started out managing a handful of HMOs, and—um—the classic scene was radiators roaring away with nobody home, windows cracked wide open. You pay for the heat and watch it pour right out. Not fun. There’s also that line we can’t cross: you can’t lock tenants out of controls. So I began tinkering, looking for a way to respect tenant comfort while, uh, controlling the waste. The device I built evolved from a scrappy prototype to a solid system, and the surprising bit was how immediate the savings were. Like, right from the first month, bills were trending down hard, and I thought, wow, landlords need something like this across the board.", + "start": 38.75, + "end": 77.88 + }, + { + "speaker": 0, + "text": "Exactly. And, actually, on the software side, we’ve seen a similar DIY-to-pro evolution. Rent Byte started because landlords were fed up with juggling spreadsheets, emails, and random listings. It’s, um, designed by landlords with tenants in mind—trust and transparency baked in. With Rent Byte, you can push your property ads out fast, track leads, run checks, and glide folks through the onboarding without the headache. Then, when people move in, it doesn’t stop; you’ve got in-app maintenance reporting, job tracking, and clear timelines, so tenants don’t feel, you know, ignored. Tea Brew listeners keep telling us the pain isn’t just finding residents, it’s keeping the whole machine humming. And if you pair that workflow with smart heating controls, you’re hitting cost, comfort, and communication all in one go.", + "start": 77.92, + "end": 115.18 + }, + { + "speaker": 2, + "text": "Hold on, can I jump in? So, um, I manage two HMOs over by Meeter Street, and the energy bills last winter were brutal. Everyone online was shouting tips—weather compensation, schedule tweaks, the whole shebang—but tenants still cranked it up when they felt chilly. We can’t, like, lock the thermostat, right? So how does your approach, Sayid Guest, keep the system fair? Tenants get access, but landlords don’t get burned—haha, terrible pun. Does it do occupancy sensing, or is it just a timer? And how do you stop the classic boost button from becoming a permanent ‘on’?", + "start": 115.19, + "end": 147.56 + }, + { + "speaker": 1, + "text": "Great questions. So it’s a combo. Think of it as a comfort-first schedule with protections. You set reasonable heating windows—morning and evening, say—and tenants can press a boost for extra time. But the boost is capped and resets, so it won’t, uh, run the boiler all day. For empty rooms and the, you know, window-wide-open scenario, if you add sensors, the system detects rapid drops or no movement and tapers the heat until conditions make sense. It’s not about denying warmth; it’s about stopping waste that literally no one benefits from. With HMOs you need communal logic—landings and kitchens matter—so multi-zone control helps keep spaces balanced. And landlords, like me, get analytics: you can see where energy is leaking and fine-tune settings. When I first tested it, I saw 30–50% reductions. That’s not a promise for every building—each setup is its own puzzle—but the pattern has been, um, consistently strong.", + "start": 147.56, + "end": 193.57 + }, + { + "speaker": 0, + "text": "Yeah, that resonates. I walked into a shared house once—near Meeter Street—and the thermostat was at 29, like sauna levels, with the window propped open. The tenants weren’t being malicious, they were just coping: drafty room, quick fix, crank the dial. So, as Tea Brew keeps saying, we need systems that, uh, balance tenant agency with sensible guardrails. And the nice thing with Rent Byte is that it keeps the conversation flowing. Tenants can raise an issue through maintenance reporting, and as soon as the ticket is created, the timeline starts. If it’s a cold spot or a radiator fault, you’ve got the history at your fingertips. That way, you don’t blame behavior when it’s actually a hardware problem.", + "start": 193.57, + "end": 225.73 + }, + { + "speaker": 2, + "text": "Okay, so picture this: I advertise a new room, get flooded with inquiries, then I’m buried in emails. The software side—like Rent Byte—would pull those leads into a pipeline, right? And if someone mentions the room feels chilly during a viewing, I could, uh, flag that straight away. Here’s the kicker: can your device talk to the platform? Like, if there’s a temp anomaly or a sensor alert, could it auto-create a maintenance ticket so I don’t miss it? And, sorry, I’m thinking out loud here, but could Rent Byte show those energy graphs inside the tenant portal without freaking people out—more like, you know, helpful nudges than lecturing?", + "start": 225.74, + "end": 258.74 + }, + { + "speaker": 1, + "text": "Haha, I love the way you think. Yes, integration is the future. We’ve built an API so platforms like Rent Byte can pull summary data—no one needs to drown in charts, just the, um, useful stuff. For example, you can surface a gentle insight: “Heating is already scheduled; boost is available for 30 minutes.” Tenants see options, not rules. If a sensor flags a stuck valve or a window open for ages, Rent Byte can spin up a maintenance ticket, assign it, and track the fix. And because Tea Brew listeners keep asking about transparency, we’ve found tenants appreciate seeing that there’s a fair schedule in place. By the way, thanks again for the invite, Aidan Host; getting this across without jargon is half the battle.", + "start": 258.74, + "end": 292.52 + }, + { + "speaker": 0, + "text": "Totally. And, uh, we’ve noticed another benefit: advertising feels cleaner when you can tell prospects, right up front, that the home uses sensible heating with tenant-controlled boosts. It sounds small, but it sets expectations and, you know, avoids future friction. On the admin side, Rent Byte keeps everything documented—from viewing notes to audit trails—so if someone on Meeter Street says their radiator’s been weird for weeks, you can point to the timeline and fix history quickly. Um, we’ve had a bunch of landlords message Tea Brew after implementing this kind of setup, saying the combo of software plus smart heating saved money and, honestly, reduced arguments. That’s the vibe we want.", + "start": 292.53, + "end": 323.63 + }, + { + "speaker": 2, + "text": "Same here. To wrap, uh, I’m thinking: start with clear, humane policies, pair them with tech that respects tenant comfort, and keep the communication channel open. Aidan Host, if folks want to try Rent Byte, what’s the first step? And, Sayid Guest, for the device, is there like a starter kit guide so, um, non-tech landlords don’t panic? Maybe put links under the episode—sorry—under Tea Brew show notes. I’ve got two more rooms to fill near Meeter Street, and it’d be great to kick this off before the cold snaps hit again.", + "start": 323.64, + "end": 351.73 + } + ], + "hotwords": [ + "Tea Brew", + "Aiden Host", + "Saeed Guest", + "Rent Byte", + "Meter Street" + ] +} \ No newline at end of file diff --git a/finetuning/toy_dataset/0.mp3 b/finetuning/toy_dataset/0.mp3 new file mode 100755 index 0000000..a994394 Binary files /dev/null and b/finetuning/toy_dataset/0.mp3 differ diff --git a/finetuning/toy_dataset/1.json b/finetuning/toy_dataset/1.json new file mode 100755 index 0000000..26844ac --- /dev/null +++ b/finetuning/toy_dataset/1.json @@ -0,0 +1,79 @@ +{ + "audio_duration": 328.26666666666665, + "audio_path": "1.mp3", + "segments": [ + { + "speaker": 0, + "text": "Welcome back to our Youth Month special. Um, before we dive in, Tandi, you ready? We’re honoring young folks who, like, you know, shook the ground. The day itself marks those anti–language policy protests in the 70s—students across the country, campuses, townships—standing up. And today, we’re asking you all on the WhatsApp line to share who inspired you: a teacher, a cousin, someone you admire. Also, we’ll talk about campaigns like Crown Wrights, because identity’s not cosmetic; it’s, uh, core.", + "start": 0.0, + "end": 33.15 + }, + { + "speaker": 1, + "text": "Yeah, yeah, I’m totally in, haha. Thanks, Leila. So, um, I’ve been thinking about how the current youth still carry heavy stuff—joblessness, violence against women, and, heartbreakingly, queer kids being targeted. It’s, like, gutting. Yet they keep going. I mean, this is Tandi speaking from the heart here. The person I want to celebrate is Zahra. She stood up at Coyl High when policies tried to tame her natural hair, and she, uh, didn’t flinch. She even wrote a children’s story as part of Crown Wrights, to help little ones see their coils as power instead of problem.", + "start": 33.24, + "end": 71.33 + }, + { + "speaker": 2, + "text": "Wow, okay, that hits hard. Zahra at Coyl High—I remember, like, seeing clips where she just, you know, stood there calmly while adults were telling her to “fix” herself. It gave me chills. Identity isn’t some minor detail; it’s a major, uh, principle—wait, I always mix that with principal, haha. Anyway, the bravery at Crown Wrights events has ripple effects. And Leila, you’ve talked about how hair, especially coily textures, can be policed as a way to, um, shrink someone’s confidence.", + "start": 71.52, + "end": 102.9 + }, + { + "speaker": 0, + "text": "Exactly! And thanks, Tandi, for, like, naming the hard stuff. The way that activist reframed “acceptable” appearance shows young people don’t need permission to be whole. When you own your look, you walk into boardrooms you never imagined. You, uh, sit at tables, you speak up. It’s the same energy that wins pageants and policies—like, a confidence that shifts rooms. And for everyone listening, send a voice note about your own champion and share stories about reclaiming a “crown.”", + "start": 102.9, + "end": 132.14 + }, + { + "speaker": 1, + "text": "Hold on—just to paint the scene. At Coyl High, Zahra didn’t wait for a senior to intervene; she, um, chose the moment. No hesitation. She knew the risk: future opportunities, social backlash, being labeled “difficult.” And she still stood up. For me, Tandi, that shows the lesson the old protests tried to teach—use your voice now, because silence, like, steals time. And you always say confidence is contagious; when one girl lifts her chin, dozens follow.", + "start": 132.17, + "end": 161.49 + }, + { + "speaker": 2, + "text": "Yeah, yes—exactly. And, mm, the thing that gets me is the dignity piece. People act like hair is trivial, but for Black girls, policing curls is a way to control presence. That campaign reframes it: from “problem” to “pride.” When that flip happens, you start applying for roles you thought were off-limits, you choose your course, you don’t break—uh, brake—just because someone says you don’t fit. And, you mentioned queer youth; the intersections matter. That stand speaks beyond curls, into, like, the right to be fully yourself.", + "start": 161.52, + "end": 196.53 + }, + { + "speaker": 0, + "text": "Speaking of which, Leila, could you share a listener shout-out? We got a bunch of WhatsApps about teachers, neighbors, and, uh, nurses who kept kids afloat. Also there’s a note praising the book in the Crown Wrights series—apparently the illustrations made a little one feel seen at school.", + "start": 196.91, + "end": 212.93 + }, + { + "speaker": 2, + "text": "This is Leila—oh yeah, totally. So, um, there’s a message from a parent who says their child used to hide her coils under a cap, and after reading that Crown Wrights story she walked into class with no fear. And another listener’s like, “my teacher changed my life,” which—haha—yes. The past protests weren’t just about language policy; they set a template: organize, insist, repeat. Honestly, Tandi, I feel like honoring Zahra honors every student who refuses to be edited at Coyl High-type schools, you know?", + "start": 213.59, + "end": 245.25 + }, + { + "speaker": 1, + "text": "Exactly! And, um, can I just say: elders sometimes tell us to wait for leaders to fix things. But the lesson, from then till now, is don’t outsource your voice. The moment you speak, the course of the moment changes—like, the crowd pivots from idle to action. And for anyone feeling alone, tag us and the campaign; you’ll find community fast.", + "start": 245.25, + "end": 266.64 + }, + { + "speaker": 0, + "text": "Right, and, wow, we’ve barely scratched the surface. I want to circle back because someone asked if honoring hair is, um, frivolous compared to, say, jobs or safety. My answer: it’s linked. Coyl High-style rules are part of a bigger system that decides who is welcome. Undo that, and you expand opportunity. And if you’re looking for names, Leila suggested a list: local mentors, teacher heroes, youth organizers. That point about resilience—phew—needed.", + "start": 266.64, + "end": 297.57 + }, + { + "speaker": 2, + "text": "Mhm, and one more thought, then we’ll wrap. Zahra didn’t become a symbol because she wanted fame; she became one because she refused to break. That stubborn joy? It’s, like, fuel. I’m Leila, and I’m grateful we still celebrate that day in June as a reminder to keep going. To every listener—send your piece, um, peace—haha—on the line. And thanks, Tandi, for bringing this story, and thanks to campaigns like Crown Wrights for keeping the flame on.", + "start": 297.78, + "end": 328.27 + } + ], + "hotwords": [ + "Thandie", + "Leila", + "Zara", + "Coyle High", + "Crown Rites" + ] +} \ No newline at end of file diff --git a/finetuning/toy_dataset/1.mp3 b/finetuning/toy_dataset/1.mp3 new file mode 100755 index 0000000..fa6a27e Binary files /dev/null and b/finetuning/toy_dataset/1.mp3 differ diff --git a/vibevoice/modular/modeling_vibevoice_asr.py b/vibevoice/modular/modeling_vibevoice_asr.py index 706bf00..4663d3f 100644 --- a/vibevoice/modular/modeling_vibevoice_asr.py +++ b/vibevoice/modular/modeling_vibevoice_asr.py @@ -364,7 +364,7 @@ class VibeVoiceASRForConditionalGeneration(VibeVoiceASRPreTrainedModel, Generati output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict - use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = use_cache if use_cache is not None else getattr(self.config, 'use_cache', False) # Process inputs if inputs_embeds is None and input_ids is not None: @@ -377,6 +377,8 @@ class VibeVoiceASRForConditionalGeneration(VibeVoiceASRPreTrainedModel, Generati speech_masks=speech_masks, speech_semantic_tensors=speech_semantic_tensors, ) + # Clone to avoid in-place operation on leaf variable during training + inputs_embeds = inputs_embeds.clone() inputs_embeds[acoustic_input_mask] = speech_features # Forward through the model @@ -402,7 +404,7 @@ class VibeVoiceASRForConditionalGeneration(VibeVoiceASRPreTrainedModel, Generati shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() + loss_fct = nn.CrossEntropyLoss(ignore_index=-100) shift_logits = shift_logits.view(-1, self.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism