diff --git a/demo/vibevoice_asr_inference_from_file.py b/demo/vibevoice_asr_inference_from_file.py
index 082b9b2..368d701 100644
--- a/demo/vibevoice_asr_inference_from_file.py
+++ b/demo/vibevoice_asr_inference_from_file.py
@@ -37,7 +37,7 @@ class VibeVoiceASRBatchInference:
 
         Args:
             model_path: Path to the pretrained model
-            device: Device to run inference on (cuda, mps, cpu, auto)
+            device: Device to run inference on (cuda, mps, xpu, cpu, auto)
             dtype: Data type for model weights
             attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
         """
@@ -442,8 +442,8 @@ def main():
     parser.add_argument(
         "--device",
         type=str,
-        default="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"),
-        choices=["cuda", "cpu", "mps", "auto"],
+        default="cuda" if torch.cuda.is_available() else ("xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),  # hasattr guard: older torch builds have no torch.xpu
+        choices=["cuda", "cpu", "mps", "xpu", "auto"],
         help="Device to run inference on"
     )
     parser.add_argument(
@@ -475,7 +475,7 @@ def main():
         type=str,
         default="auto",
         choices=["flash_attention_2", "sdpa", "eager", "auto"],
-        help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU)"
+        help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU/XPU)"
     )
 
     args = parser.parse_args()
@@ -490,7 +490,7 @@ def main():
                 print("flash_attn not installed, falling back to sdpa")
                 args.attn_implementation = "sdpa"
         else:
-            # MPS and CPU don't support flash_attention_2
+            # MPS/XPU/CPU don't support flash_attention_2
             args.attn_implementation = "sdpa"
         print(f"Auto-detected attention implementation: {args.attn_implementation}")
 
@@ -532,6 +532,8 @@ def main():
     # Handle MPS device and dtype
     if args.device == "mps":
         model_dtype = torch.float32  # MPS works better with float32
+    elif args.device == "xpu":
+        model_dtype = torch.float32  # XPU gets the same float32 setting as the MPS/CPU branches
     elif args.device == "cpu":
         model_dtype = torch.float32
     else: