Update fine-tuning (ft) code
This commit is contained in:
+23
-7
@@ -5,6 +5,9 @@ This directory contains scripts for LoRA (Low-Rank Adaptation) fine-tuning of th
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
# you need to install vibevoice first
|
||||
# pip install -e .[asr]
|
||||
|
||||
pip install peft accelerate
|
||||
```
|
||||
|
||||
@@ -52,23 +55,36 @@ Each JSON file should have the following structure:
|
||||
"end": 77.88
|
||||
}
|
||||
],
|
||||
"hotwords": ["Tea Brew", "Aiden Host"] // optional
|
||||
"customized_context": ["Tea Brew", "Aiden Host", "The property is near Meter Street."] // optional, domain-specific terms or context sentences
|
||||
}
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Basic Usage
|
||||
### Basic
|
||||
|
||||
```bash
|
||||
python lora_finetune.py \
|
||||
# 1 GPU
|
||||
torchrun --nproc_per_node=1 lora_finetune.py \
|
||||
--model_path microsoft/VibeVoice-ASR \
|
||||
--data_dir ./toy_dataset \
|
||||
--output_dir ./output \
|
||||
--num_train_epochs 3 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--learning_rate 1e-4 \
|
||||
--bf16
|
||||
--bf16 \
|
||||
--report_to none
|
||||
|
||||
# Specific GPUs (e.g., GPU 0,1,2,3)
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 lora_finetune.py \
|
||||
--model_path microsoft/VibeVoice-ASR \
|
||||
--data_dir ./toy_dataset \
|
||||
--output_dir ./output \
|
||||
--num_train_epochs 3 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--learning_rate 1e-4 \
|
||||
--bf16 \
|
||||
--report_to none
|
||||
```
|
||||
|
||||
### Full Options
|
||||
@@ -76,7 +92,7 @@ python lora_finetune.py \
|
||||
The script uses Hugging Face's `TrainingArguments`, so all standard options are available:
|
||||
|
||||
```bash
|
||||
python lora_finetune.py \
|
||||
torchrun --nproc_per_node=4 lora_finetune.py \
|
||||
--model_path microsoft/VibeVoice-ASR \
|
||||
--data_dir ./toy_dataset \
|
||||
--output_dir ./output \
|
||||
@@ -108,7 +124,7 @@ python lora_finetune.py \
|
||||
| `--gradient_accumulation_steps` | 1 | Effective batch size = per-device batch size × grad_accum × number of GPUs |
|
||||
| `--learning_rate` | 5e-5 | Learning rate (1e-4 to 2e-4 typical for LoRA) |
|
||||
| `--gradient_checkpointing` | False | Enable to reduce memory usage |
|
||||
| `--use_hotwords` | True | Include hotwords from JSON as context |
|
||||
| `--use_customized_context` | True | Include customized_context from JSON as additional context |
|
||||
| `--max_audio_length` | None | Skip audio longer than this (seconds) |
|
||||
|
||||
## Inference with Fine-tuned Model
|
||||
@@ -118,7 +134,7 @@ python inference_lora.py \
|
||||
--base_model microsoft/VibeVoice-ASR \
|
||||
--lora_path ./output \
|
||||
--audio_file ./toy_dataset/0.mp3 \
|
||||
--context_info "Hotwords: Tea Brew, Aiden Host"
|
||||
--context_info "Tea Brew, Aiden Host"
|
||||
```
|
||||
|
||||
## Merging LoRA Weights (Optional)
|
||||
|
||||
@@ -11,17 +11,11 @@ Usage:
|
||||
--audio_file ./toy_dataset/0.mp3
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
from peft import PeftModel
|
||||
|
||||
# Add parent directory to path for vibevoice imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
|
||||
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
||||
|
||||
|
||||
+12
-17
@@ -6,8 +6,6 @@ This script implements LoRA (Low-Rank Adaptation) fine-tuning for the VibeVoice
|
||||
It uses PEFT (Parameter-Efficient Fine-Tuning) library for efficient training.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
@@ -31,9 +29,6 @@ from peft import (
|
||||
TaskType,
|
||||
)
|
||||
|
||||
# Add parent directory to path for vibevoice imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
|
||||
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor
|
||||
|
||||
@@ -66,9 +61,9 @@ class DataArguments:
|
||||
default=None,
|
||||
metadata={"help": "Maximum audio length in seconds (default: no limit)"}
|
||||
)
|
||||
use_hotwords: bool = field(
|
||||
use_customized_context: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use hotwords from JSON"}
|
||||
metadata={"help": "Whether to use customized_context from JSON as additional context"}
|
||||
)
|
||||
|
||||
|
||||
@@ -174,7 +169,7 @@ class VibeVoiceASRDataset(Dataset):
|
||||
},
|
||||
...
|
||||
],
|
||||
"hotwords": ["Tea Brew", "Aiden Host", ...] # optional
|
||||
"customized_context": ["Tea Brew", "The property is near Meter Street."] # optional
|
||||
}
|
||||
"""
|
||||
|
||||
@@ -183,7 +178,7 @@ class VibeVoiceASRDataset(Dataset):
|
||||
data_dir: str,
|
||||
processor: VibeVoiceASRProcessor,
|
||||
max_audio_length: Optional[float] = None, # in seconds
|
||||
use_hotwords: bool = True,
|
||||
use_customized_context: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the dataset.
|
||||
@@ -192,12 +187,12 @@ class VibeVoiceASRDataset(Dataset):
|
||||
data_dir: Directory containing audio files and JSON labels
|
||||
processor: VibeVoice ASR processor
|
||||
max_audio_length: Maximum audio length in seconds (None = no limit)
|
||||
use_hotwords: Whether to include hotwords in context
|
||||
use_customized_context: Whether to include customized_context in prompt
|
||||
"""
|
||||
self.data_dir = Path(data_dir)
|
||||
self.processor = processor
|
||||
self.max_audio_length = max_audio_length
|
||||
self.use_hotwords = use_hotwords
|
||||
self.use_customized_context = use_customized_context
|
||||
|
||||
# Find all JSON files
|
||||
self.samples = self._load_samples()
|
||||
@@ -284,12 +279,12 @@ class VibeVoiceASRDataset(Dataset):
|
||||
data = sample["data"]
|
||||
audio_path = sample["audio_path"]
|
||||
|
||||
# Prepare context info (hotwords)
|
||||
# Prepare context info (customized_context)
|
||||
context_info = None
|
||||
if self.use_hotwords and "hotwords" in data:
|
||||
hotwords = data["hotwords"]
|
||||
if hotwords:
|
||||
context_info = "\n".join(hotwords)
|
||||
if self.use_customized_context and "customized_context" in data:
|
||||
customized_context = data["customized_context"]
|
||||
if customized_context:
|
||||
context_info = "\n".join(customized_context)
|
||||
|
||||
# Process audio using the processor's internal method
|
||||
encoding = self.processor._process_single_audio(
|
||||
@@ -488,7 +483,7 @@ def train(
|
||||
data_dir=data_args.data_dir,
|
||||
processor=processor,
|
||||
max_audio_length=data_args.max_audio_length,
|
||||
use_hotwords=data_args.use_hotwords,
|
||||
use_customized_context=data_args.use_customized_context,
|
||||
)
|
||||
|
||||
if len(train_dataset) == 0:
|
||||
|
||||
@@ -63,11 +63,11 @@
|
||||
"end": 351.73
|
||||
}
|
||||
],
|
||||
"hotwords": [
|
||||
"customized_context": [
|
||||
"Tea Brew",
|
||||
"Aiden Host",
|
||||
"Saeed Guest",
|
||||
"Rent Byte",
|
||||
"Meter Street"
|
||||
"The property is located near Meter Street."
|
||||
]
|
||||
}
|
||||
@@ -69,11 +69,11 @@
|
||||
"end": 328.27
|
||||
}
|
||||
],
|
||||
"hotwords": [
|
||||
"customized_context": [
|
||||
"Thandie",
|
||||
"Leila",
|
||||
"Zara",
|
||||
"Coyle High",
|
||||
"Crown Rites"
|
||||
"The story takes place at Coyle High school.",
|
||||
"Crown Rites is a campaign for hair rights."
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user