fix(demo): add MPS and CPU support for ASR inference demo
- Add MPS device choice and auto-detect MPS availability - Change default attention implementation to 'auto' with smart fallback - Auto-detect flash_attention_2 availability on CUDA, fallback to sdpa - Use sdpa for MPS and CPU devices (flash_attention_2 not supported) - Use float32 dtype for MPS/CPU devices for better compatibility Fixes #206
This commit is contained in:
@@ -30,14 +30,14 @@ class VibeVoiceASRBatchInference:
|
||||
model_path: str,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
attn_implementation: str = "flash_attention_2"
|
||||
attn_implementation: str = "sdpa"
|
||||
):
|
||||
"""
|
||||
Initialize the ASR batch inference pipeline.
|
||||
|
||||
Args:
|
||||
model_path: Path to the pretrained model
|
||||
device: Device to run inference on
|
||||
device: Device to run inference on (cuda, mps, cpu, auto)
|
||||
dtype: Data type for model weights
|
||||
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
|
||||
"""
|
||||
@@ -442,8 +442,8 @@ def main():
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
choices=["cuda", "cpu", "auto"],
|
||||
default="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"),
|
||||
choices=["cuda", "cpu", "mps", "auto"],
|
||||
help="Device to run inference on"
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -473,12 +473,27 @@ def main():
|
||||
parser.add_argument(
|
||||
"--attn_implementation",
|
||||
type=str,
|
||||
default="flash_attention_2",
|
||||
help="Attention implementation to use (default: flash_attention_2)"
|
||||
default="auto",
|
||||
choices=["flash_attention_2", "sdpa", "eager", "auto"],
|
||||
help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Auto-detect best attention implementation based on device
|
||||
if args.attn_implementation == "auto":
|
||||
if args.device == "cuda" and torch.cuda.is_available():
|
||||
try:
|
||||
import flash_attn
|
||||
args.attn_implementation = "flash_attention_2"
|
||||
except ImportError:
|
||||
print("flash_attn not installed, falling back to sdpa")
|
||||
args.attn_implementation = "sdpa"
|
||||
else:
|
||||
# MPS and CPU don't support flash_attention_2
|
||||
args.attn_implementation = "sdpa"
|
||||
print(f"Auto-detected attention implementation: {args.attn_implementation}")
|
||||
|
||||
# Collect audio files
|
||||
audio_files = []
|
||||
concatenated_audio = None # For storing concatenated dataset audio
|
||||
@@ -514,10 +529,18 @@ def main():
|
||||
print(f"\nConcatenated dataset audios: {len(concatenated_audio)} audio(s)")
|
||||
|
||||
# Initialize model
|
||||
# Handle MPS device and dtype
|
||||
if args.device == "mps":
|
||||
model_dtype = torch.float32 # MPS works better with float32
|
||||
elif args.device == "cpu":
|
||||
model_dtype = torch.float32
|
||||
else:
|
||||
model_dtype = torch.bfloat16
|
||||
|
||||
asr = VibeVoiceASRBatchInference(
|
||||
model_path=args.model_path,
|
||||
device=args.device,
|
||||
dtype=torch.bfloat16 if args.device != "cpu" else torch.float32,
|
||||
dtype=model_dtype,
|
||||
attn_implementation=args.attn_implementation
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user