Add XPU sdpa Support

This commit is contained in:
DDXDB
2026-01-26 14:00:31 +08:00
committed by GitHub
parent 523713e806
commit 1c5dbc4190
+7 -5
View File
@@ -37,7 +37,7 @@ class VibeVoiceASRBatchInference:
Args:
model_path: Path to the pretrained model
device: Device to run inference on (cuda, mps, cpu, auto)
device: Device to run inference on (cuda, mps, xpu, cpu, auto)
dtype: Data type for model weights
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
"""
@@ -442,8 +442,8 @@ def main():
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"),
choices=["cuda", "cpu", "mps", "auto"],
default="cuda" if torch.cuda.is_available() else ("xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
choices=["cuda", "cpu", "mps","xpu", "auto"],
help="Device to run inference on"
)
parser.add_argument(
@@ -475,7 +475,7 @@ def main():
type=str,
default="auto",
choices=["flash_attention_2", "sdpa", "eager", "auto"],
help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU)"
help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU/XPU)"
)
args = parser.parse_args()
@@ -490,7 +490,7 @@ def main():
print("flash_attn not installed, falling back to sdpa")
args.attn_implementation = "sdpa"
else:
# MPS and CPU don't support flash_attention_2
# MPS/XPU/CPU don't support flash_attention_2
args.attn_implementation = "sdpa"
print(f"Auto-detected attention implementation: {args.attn_implementation}")
@@ -532,6 +532,8 @@ def main():
# Pick the model dtype per device: MPS, XPU and CPU all run in float32
if args.device == "mps":
model_dtype = torch.float32 # MPS works better with float32
elif args.device == "xpu":
model_dtype = torch.float32
elif args.device == "cpu":
model_dtype = torch.float32
else: