From 523713e80643b3398e3118ba4eb554bbfce270a8 Mon Sep 17 00:00:00 2001 From: ThanhNguyxn Date: Sat, 24 Jan 2026 19:03:04 +0700 Subject: [PATCH] fix(demo): add MPS and CPU support for ASR inference demo - Add MPS device choice and auto-detect MPS availability - Change default attention implementation to 'auto' with smart fallback - Auto-detect flash_attention_2 availability on CUDA, fallback to sdpa - Use sdpa for MPS and CPU devices (flash_attention_2 not supported) - Use float32 dtype for MPS/CPU devices for better compatibility Fixes #206 --- demo/vibevoice_asr_inference_from_file.py | 37 ++++++++++++++++++----- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/demo/vibevoice_asr_inference_from_file.py b/demo/vibevoice_asr_inference_from_file.py index e7d95d4..082b9b2 100644 --- a/demo/vibevoice_asr_inference_from_file.py +++ b/demo/vibevoice_asr_inference_from_file.py @@ -30,14 +30,14 @@ class VibeVoiceASRBatchInference: model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, - attn_implementation: str = "flash_attention_2" + attn_implementation: str = "sdpa" ): """ Initialize the ASR batch inference pipeline. Args: model_path: Path to the pretrained model - device: Device to run inference on + device: Device to run inference on (cuda, mps, cpu, auto) dtype: Data type for model weights attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager') """ @@ -442,8 +442,8 @@ def main(): parser.add_argument( "--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu", - choices=["cuda", "cpu", "auto"], + default="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"), + choices=["cuda", "cpu", "mps", "auto"], help="Device to run inference on" ) parser.add_argument( @@ -473,12 +473,27 @@ def main(): parser.add_argument( "--attn_implementation", type=str, - default="flash_attention_2", - help="Attention implementation to use (default: flash_attention_2)" + default="auto", + choices=["flash_attention_2", "sdpa", "eager", "auto"], + help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU)" ) args = parser.parse_args() + # Auto-detect best attention implementation based on device + if args.attn_implementation == "auto": + if args.device == "cuda" and torch.cuda.is_available(): + try: + import flash_attn + args.attn_implementation = "flash_attention_2" + except ImportError: + print("flash_attn not installed, falling back to sdpa") + args.attn_implementation = "sdpa" + else: + # MPS and CPU don't support flash_attention_2 + args.attn_implementation = "sdpa" + print(f"Auto-detected attention implementation: {args.attn_implementation}") + # Collect audio files audio_files = [] concatenated_audio = None # For storing concatenated dataset audio @@ -514,10 +529,18 @@ def main(): print(f"\nConcatenated dataset audios: {len(concatenated_audio)} audio(s)") # Initialize model + # Handle MPS device and dtype + if args.device == "mps": + model_dtype = torch.float32 # MPS works better with float32 + elif args.device == "cpu": + model_dtype = torch.float32 + else: + model_dtype = torch.bfloat16 + asr = VibeVoiceASRBatchInference( model_path=args.model_path, device=args.device, - dtype=torch.bfloat16 if args.device != "cpu" else torch.float32, + dtype=model_dtype, attn_implementation=args.attn_implementation )