update ft code

This commit is contained in:
pengzhiliang
2026-01-22 05:20:25 -08:00
parent db2f1d9ff3
commit cef628e1b5
5 changed files with 40 additions and 35 deletions
+23 -7
View File
@@ -5,6 +5,9 @@ This directory contains scripts for LoRA (Low-Rank Adaptation) fine-tuning of th
## Requirements
```bash
# you need to install vibevoice first
# pip install -e .[asr]
pip install peft accelerate
```
@@ -52,23 +55,36 @@ Each JSON file should have the following structure:
"end": 77.88
}
],
"hotwords": ["Tea Brew", "Aiden Host"] // optional
"customized_context": ["Tea Brew", "Aiden Host", "The property is near Meter Street."] // optional, domain-specific terms or context sentences
}
```
## Training
### Basic Usage
### Basic
```bash
python lora_finetune.py \
# 1 GPU
torchrun --nproc_per_node=1 lora_finetune.py \
--model_path microsoft/VibeVoice-ASR \
--data_dir ./toy_dataset \
--output_dir ./output \
--num_train_epochs 3 \
--per_device_train_batch_size 1 \
--learning_rate 1e-4 \
--bf16
--bf16 \
--report_to none
# Specific GPUs (e.g., GPU 0,1,2,3)
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 lora_finetune.py \
--model_path microsoft/VibeVoice-ASR \
--data_dir ./toy_dataset \
--output_dir ./output \
--num_train_epochs 3 \
--per_device_train_batch_size 1 \
--learning_rate 1e-4 \
--bf16 \
--report_to none
```
### Full Options
@@ -76,7 +92,7 @@ python lora_finetune.py \
The script uses HuggingFace's `TrainingArguments`, so all standard options are available:
```bash
python lora_finetune.py \
torchrun --nproc_per_node=4 lora_finetune.py \
--model_path microsoft/VibeVoice-ASR \
--data_dir ./toy_dataset \
--output_dir ./output \
@@ -108,7 +124,7 @@ python lora_finetune.py \
| `--gradient_accumulation_steps` | 1 | Effective batch size = batch_size × grad_accum |
| `--learning_rate` | 5e-5 | Learning rate (1e-4 to 2e-4 typical for LoRA) |
| `--gradient_checkpointing` | False | Enable to reduce memory usage |
| `--use_hotwords` | True | Include hotwords from JSON as context |
| `--use_customized_context` | True | Include customized_context from JSON as additional context |
| `--max_audio_length` | None | Skip audio longer than this (seconds) |
## Inference with Fine-tuned Model
@@ -118,7 +134,7 @@ python inference_lora.py \
--base_model microsoft/VibeVoice-ASR \
--lora_path ./output \
--audio_file ./toy_dataset/0.mp3 \
--context_info "Hotwords: Tea Brew, Aiden Host"
--context_info "Tea Brew, Aiden Host"
```
## Merging LoRA Weights (Optional)
-6
View File
@@ -11,17 +11,11 @@ Usage:
--audio_file ./toy_dataset/0.mp3
"""
import os
import sys
import argparse
import torch
from pathlib import Path
from peft import PeftModel
# Add parent directory to path for vibevoice imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
+12 -17
View File
@@ -6,8 +6,6 @@ This script implements LoRA (Low-Rank Adaptation) fine-tuning for the VibeVoice
It uses PEFT (Parameter-Efficient Fine-Tuning) library for efficient training.
"""
import os
import sys
import json
import logging
from pathlib import Path
@@ -31,9 +29,6 @@ from peft import (
TaskType,
)
# Add parent directory to path for vibevoice imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
@@ -66,9 +61,9 @@ class DataArguments:
default=None,
metadata={"help": "Maximum audio length in seconds (default: no limit)"}
)
use_hotwords: bool = field(
use_customized_context: bool = field(
default=True,
metadata={"help": "Whether to use hotwords from JSON"}
metadata={"help": "Whether to use customized_context from JSON as additional context"}
)
@@ -174,7 +169,7 @@ class VibeVoiceASRDataset(Dataset):
},
...
],
"hotwords": ["Tea Brew", "Aiden Host", ...] # optional
"customized_context": ["Tea Brew", "The property is near Meter Street."] # optional
}
"""
@@ -183,7 +178,7 @@ class VibeVoiceASRDataset(Dataset):
data_dir: str,
processor: VibeVoiceASRProcessor,
max_audio_length: Optional[float] = None, # in seconds
use_hotwords: bool = True,
use_customized_context: bool = True,
):
"""
Initialize the dataset.
@@ -192,12 +187,12 @@ class VibeVoiceASRDataset(Dataset):
data_dir: Directory containing audio files and JSON labels
processor: VibeVoice ASR processor
max_audio_length: Maximum audio length in seconds (None = no limit)
use_hotwords: Whether to include hotwords in context
use_customized_context: Whether to include customized_context in prompt
"""
self.data_dir = Path(data_dir)
self.processor = processor
self.max_audio_length = max_audio_length
self.use_hotwords = use_hotwords
self.use_customized_context = use_customized_context
# Find all JSON files
self.samples = self._load_samples()
@@ -284,12 +279,12 @@ class VibeVoiceASRDataset(Dataset):
data = sample["data"]
audio_path = sample["audio_path"]
# Prepare context info (hotwords)
# Prepare context info (customized_context)
context_info = None
if self.use_hotwords and "hotwords" in data:
hotwords = data["hotwords"]
if hotwords:
context_info = "\n".join(hotwords)
if self.use_customized_context and "customized_context" in data:
customized_context = data["customized_context"]
if customized_context:
context_info = "\n".join(customized_context)
# Process audio using the processor's internal method
encoding = self.processor._process_single_audio(
@@ -488,7 +483,7 @@ def train(
data_dir=data_args.data_dir,
processor=processor,
max_audio_length=data_args.max_audio_length,
use_hotwords=data_args.use_hotwords,
use_customized_context=data_args.use_customized_context,
)
if len(train_dataset) == 0:
+2 -2
View File
@@ -63,11 +63,11 @@
"end": 351.73
}
],
"hotwords": [
"customized_context": [
"Tea Brew",
"Aiden Host",
"Saeed Guest",
"Rent Byte",
"Meter Street"
"The property is located near Meter Street."
]
}
+3 -3
View File
@@ -69,11 +69,11 @@
"end": 328.27
}
],
"hotwords": [
"customized_context": [
"Thandie",
"Leila",
"Zara",
"Coyle High",
"Crown Rites"
"The story takes place at Coyle High school.",
"Crown Rites is a campaign for hair rights."
]
}