From cef628e1b548a0a6f6f3472f8ffe6763f6010f1a Mon Sep 17 00:00:00 2001 From: pengzhiliang <1083127130@qq.com> Date: Thu, 22 Jan 2026 05:20:25 -0800 Subject: [PATCH] update ft code --- finetuning/README.md | 30 +++++++++++++++++++++++------- finetuning/inference_lora.py | 6 ------ finetuning/lora_finetune.py | 29 ++++++++++++----------------- finetuning/toy_dataset/0.json | 4 ++-- finetuning/toy_dataset/1.json | 6 +++--- 5 files changed, 40 insertions(+), 35 deletions(-) diff --git a/finetuning/README.md b/finetuning/README.md index 7c2b8ff..b10bf67 100644 --- a/finetuning/README.md +++ b/finetuning/README.md @@ -5,6 +5,9 @@ This directory contains scripts for LoRA (Low-Rank Adaptation) fine-tuning of th ## Requirements ```bash +# you need to install vibevoice first +# pip install -e .[asr] + pip install peft accelerate ``` @@ -52,23 +55,36 @@ Each JSON file should have the following structure: "end": 77.88 } ], - "hotwords": ["Tea Brew", "Aiden Host"] // optional + "customized_context": ["Tea Brew", "Aiden Host", "The property is near Meter Street."] // optional, domain-specific terms or context sentences } ``` ## Training -### Basic Usage +### Basic ```bash -python lora_finetune.py \ +# 1 GPU +torchrun --nproc_per_node=1 lora_finetune.py \ --model_path microsoft/VibeVoice-ASR \ --data_dir ./toy_dataset \ --output_dir ./output \ --num_train_epochs 3 \ --per_device_train_batch_size 1 \ --learning_rate 1e-4 \ - --bf16 + --bf16 \ + --report_to none + +# Specific GPUs (e.g., GPU 0,1,2,3) +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 lora_finetune.py \ + --model_path microsoft/VibeVoice-ASR \ + --data_dir ./toy_dataset \ + --output_dir ./output \ + --num_train_epochs 3 \ + --per_device_train_batch_size 1 \ + --learning_rate 1e-4 \ + --bf16 \ + --report_to none ``` ### Full Options @@ -76,7 +92,7 @@ python lora_finetune.py \ The script uses HuggingFace's `TrainingArguments`, so all standard options are available: ```bash -python lora_finetune.py \ +torchrun --nproc_per_node=4 lora_finetune.py \ --model_path microsoft/VibeVoice-ASR \ --data_dir ./toy_dataset \ --output_dir ./output \ @@ -108,7 +124,7 @@ python lora_finetune.py \ | `--gradient_accumulation_steps` | 1 | Effective batch size = batch_size × grad_accum | | `--learning_rate` | 5e-5 | Learning rate (1e-4 to 2e-4 typical for LoRA) | | `--gradient_checkpointing` | False | Enable to reduce memory usage | -| `--use_hotwords` | True | Include hotwords from JSON as context | +| `--use_customized_context` | True | Include customized_context from JSON as additional context | | `--max_audio_length` | None | Skip audio longer than this (seconds) | ## Inference with Fine-tuned Model @@ -118,7 +134,7 @@ python inference_lora.py \ --base_model microsoft/VibeVoice-ASR \ --lora_path ./output \ --audio_file ./toy_dataset/0.mp3 \ - --context_info "Hotwords: Tea Brew, Aiden Host" + --context_info "Tea Brew, Aiden Host" ``` ## Merging LoRA Weights (Optional) diff --git a/finetuning/inference_lora.py b/finetuning/inference_lora.py index 4c9a316..d6d8163 100644 --- a/finetuning/inference_lora.py +++ b/finetuning/inference_lora.py @@ -11,17 +11,11 @@ Usage: --audio_file ./toy_dataset/0.mp3 """ -import os -import sys import argparse import torch -from pathlib import Path from peft import PeftModel -# Add parent directory to path for vibevoice imports -sys.path.insert(0, str(Path(__file__).parent.parent)) - from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor diff --git a/finetuning/lora_finetune.py b/finetuning/lora_finetune.py index e5b64b0..0c4f9c9 100644 --- a/finetuning/lora_finetune.py +++ b/finetuning/lora_finetune.py @@ -6,8 +6,6 @@ This script implements LoRA (Low-Rank Adaptation) fine-tuning for the VibeVoice It uses PEFT (Parameter-Efficient Fine-Tuning) library for efficient training. """ -import os -import sys import json import logging from pathlib import Path @@ -31,9 +29,6 @@ from peft import ( TaskType, ) -# Add parent directory to path for vibevoice imports -sys.path.insert(0, str(Path(__file__).parent.parent)) - from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor @@ -66,9 +61,9 @@ class DataArguments: default=None, metadata={"help": "Maximum audio length in seconds (default: no limit)"} ) - use_hotwords: bool = field( + use_customized_context: bool = field( default=True, - metadata={"help": "Whether to use hotwords from JSON"} + metadata={"help": "Whether to use customized_context from JSON as additional context"} ) @@ -174,7 +169,7 @@ class VibeVoiceASRDataset(Dataset): }, ... ], - "hotwords": ["Tea Brew", "Aiden Host", ...] # optional + "customized_context": ["Tea Brew", "The property is near Meter Street."] # optional } """ @@ -183,7 +178,7 @@ class VibeVoiceASRDataset(Dataset): data_dir: str, processor: VibeVoiceASRProcessor, max_audio_length: Optional[float] = None, # in seconds - use_hotwords: bool = True, + use_customized_context: bool = True, ): """ Initialize the dataset. @@ -192,12 +187,12 @@ class VibeVoiceASRDataset(Dataset): data_dir: Directory containing audio files and JSON labels processor: VibeVoice ASR processor max_audio_length: Maximum audio length in seconds (None = no limit) - use_hotwords: Whether to include hotwords in context + use_customized_context: Whether to include customized_context in prompt """ self.data_dir = Path(data_dir) self.processor = processor self.max_audio_length = max_audio_length - self.use_hotwords = use_hotwords + self.use_customized_context = use_customized_context # Find all JSON files self.samples = self._load_samples() @@ -284,12 +279,12 @@ class VibeVoiceASRDataset(Dataset): data = sample["data"] audio_path = sample["audio_path"] - # Prepare context info (hotwords) + # Prepare context info (customized_context) context_info = None - if self.use_hotwords and "hotwords" in data: - hotwords = data["hotwords"] - if hotwords: - context_info = "\n".join(hotwords) + if self.use_customized_context and "customized_context" in data: + customized_context = data["customized_context"] + if customized_context: + context_info = "\n".join(customized_context) # Process audio using the processor's internal method encoding = self.processor._process_single_audio( @@ -488,7 +483,7 @@ def train( data_dir=data_args.data_dir, processor=processor, max_audio_length=data_args.max_audio_length, - use_hotwords=data_args.use_hotwords, + use_customized_context=data_args.use_customized_context, ) if len(train_dataset) == 0: diff --git a/finetuning/toy_dataset/0.json b/finetuning/toy_dataset/0.json index 19c72b3..86a3498 100755 --- a/finetuning/toy_dataset/0.json +++ b/finetuning/toy_dataset/0.json @@ -63,11 +63,11 @@ "end": 351.73 } ], - "hotwords": [ + "customized_context": [ "Tea Brew", "Aiden Host", "Saeed Guest", "Rent Byte", - "Meter Street" + "The property is located near Meter Street." ] } \ No newline at end of file diff --git a/finetuning/toy_dataset/1.json b/finetuning/toy_dataset/1.json index 26844ac..1cdf118 100755 --- a/finetuning/toy_dataset/1.json +++ b/finetuning/toy_dataset/1.json @@ -69,11 +69,11 @@ "end": 328.27 } ], - "hotwords": [ + "customized_context": [ "Thandie", "Leila", "Zara", - "Coyle High", - "Crown Rites" + "The story takes place at Coyle High school.", + "Crown Rites is a campaign for hair rights." ] } \ No newline at end of file