fix(demo): add MPS and CPU support for ASR inference demo

- Add MPS device choice and auto-detect MPS availability
- Change default attention implementation to 'auto' with smart fallback
- Auto-detect flash_attention_2 availability on CUDA, fallback to sdpa
- Use sdpa for MPS and CPU devices (flash_attention_2 not supported)
- Use float32 dtype for MPS/CPU devices for better compatibility

Fixes #206
This commit is contained in:
ThanhNguyxn
2026-01-24 19:03:04 +07:00
committed by YaoyaoChang
parent 5cf026569e
commit 523713e806
+30 -7
View File
@@ -30,14 +30,14 @@ class VibeVoiceASRBatchInference:
model_path: str,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
attn_implementation: str = "flash_attention_2"
attn_implementation: str = "sdpa"
):
"""
Initialize the ASR batch inference pipeline.
Args:
model_path: Path to the pretrained model
device: Device to run inference on
device: Device to run inference on (cuda, mps, cpu, auto)
dtype: Data type for model weights
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
"""
@@ -442,8 +442,8 @@ def main():
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cuda", "cpu", "auto"],
default="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"),
choices=["cuda", "cpu", "mps", "auto"],
help="Device to run inference on"
)
parser.add_argument(
@@ -473,12 +473,27 @@ def main():
parser.add_argument(
"--attn_implementation",
type=str,
default="flash_attention_2",
help="Attention implementation to use (default: flash_attention_2)"
default="auto",
choices=["flash_attention_2", "sdpa", "eager", "auto"],
help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU)"
)
args = parser.parse_args()
# Auto-detect best attention implementation based on device
if args.attn_implementation == "auto":
if args.device == "cuda" and torch.cuda.is_available():
try:
import flash_attn
args.attn_implementation = "flash_attention_2"
except ImportError:
print("flash_attn not installed, falling back to sdpa")
args.attn_implementation = "sdpa"
else:
# MPS and CPU don't support flash_attention_2
args.attn_implementation = "sdpa"
print(f"Auto-detected attention implementation: {args.attn_implementation}")
# Collect audio files
audio_files = []
concatenated_audio = None # For storing concatenated dataset audio
@@ -514,10 +529,18 @@ def main():
print(f"\nConcatenated dataset audios: {len(concatenated_audio)} audio(s)")
# Initialize model
# Handle MPS device and dtype
if args.device == "mps":
model_dtype = torch.float32 # MPS works better with float32
elif args.device == "cpu":
model_dtype = torch.float32
else:
model_dtype = torch.bfloat16
asr = VibeVoiceASRBatchInference(
model_path=args.model_path,
device=args.device,
dtype=torch.bfloat16 if args.device != "cpu" else torch.float32,
dtype=model_dtype,
attn_implementation=args.attn_implementation
)