Add XPU sdpa Support
This commit is contained in:
@@ -37,7 +37,7 @@ class VibeVoiceASRBatchInference:
|
||||
|
||||
Args:
|
||||
model_path: Path to the pretrained model
|
||||
device: Device to run inference on (cuda, mps, cpu, auto)
|
||||
device: Device to run inference on (cuda, mps, xpu, cpu, auto)
|
||||
dtype: Data type for model weights
|
||||
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
|
||||
"""
|
||||
@@ -442,8 +442,8 @@ def main():
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"),
|
||||
choices=["cuda", "cpu", "mps", "auto"],
|
||||
default="cuda" if torch.cuda.is_available() else ("xpu" if torch.backends.xpu.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu") ),
|
||||
choices=["cuda", "cpu", "mps","xpu", "auto"],
|
||||
help="Device to run inference on"
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -475,7 +475,7 @@ def main():
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["flash_attention_2", "sdpa", "eager", "auto"],
|
||||
help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU)"
|
||||
help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU/XPU)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -490,7 +490,7 @@ def main():
|
||||
print("flash_attn not installed, falling back to sdpa")
|
||||
args.attn_implementation = "sdpa"
|
||||
else:
|
||||
# MPS and CPU don't support flash_attention_2
|
||||
# MPS/XPU/CPU don't support flash_attention_2
|
||||
args.attn_implementation = "sdpa"
|
||||
print(f"Auto-detected attention implementation: {args.attn_implementation}")
|
||||
|
||||
@@ -532,6 +532,8 @@ def main():
|
||||
# Handle MPS device and dtype
|
||||
if args.device == "mps":
|
||||
model_dtype = torch.float32 # MPS works better with float32
|
||||
elif args.device == "xpu":
|
||||
model_dtype = torch.float32
|
||||
elif args.device == "cpu":
|
||||
model_dtype = torch.float32
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user