Add XPU SDPA support
This commit is contained in:
@@ -37,7 +37,7 @@ class VibeVoiceASRBatchInference:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path: Path to the pretrained model
|
model_path: Path to the pretrained model
|
||||||
device: Device to run inference on (cuda, mps, cpu, auto)
|
device: Device to run inference on (cuda, mps, xpu, cpu, auto)
|
||||||
dtype: Data type for model weights
|
dtype: Data type for model weights
|
||||||
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
|
attn_implementation: Attention implementation to use ('flash_attention_2', 'sdpa', 'eager')
|
||||||
"""
|
"""
|
||||||
@@ -442,8 +442,8 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
type=str,
|
type=str,
|
||||||
default="cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"),
|
default="cuda" if torch.cuda.is_available() else ("xpu" if torch.xpu.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
|
||||||
choices=["cuda", "cpu", "mps", "auto"],
|
choices=["cuda", "cpu", "mps","xpu", "auto"],
|
||||||
help="Device to run inference on"
|
help="Device to run inference on"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -475,7 +475,7 @@ def main():
|
|||||||
type=str,
|
type=str,
|
||||||
default="auto",
|
default="auto",
|
||||||
choices=["flash_attention_2", "sdpa", "eager", "auto"],
|
choices=["flash_attention_2", "sdpa", "eager", "auto"],
|
||||||
help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU)"
|
help="Attention implementation to use. 'auto' will select the best available for your device (flash_attention_2 for CUDA, sdpa for MPS/CPU/XPU)"
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -490,7 +490,7 @@ def main():
|
|||||||
print("flash_attn not installed, falling back to sdpa")
|
print("flash_attn not installed, falling back to sdpa")
|
||||||
args.attn_implementation = "sdpa"
|
args.attn_implementation = "sdpa"
|
||||||
else:
|
else:
|
||||||
# MPS and CPU don't support flash_attention_2
|
# MPS/XPU/CPU don't support flash_attention_2
|
||||||
args.attn_implementation = "sdpa"
|
args.attn_implementation = "sdpa"
|
||||||
print(f"Auto-detected attention implementation: {args.attn_implementation}")
|
print(f"Auto-detected attention implementation: {args.attn_implementation}")
|
||||||
|
|
||||||
@@ -532,6 +532,8 @@ def main():
|
|||||||
# Handle MPS device and dtype
|
# Handle device-specific dtype (MPS, XPU and CPU all use float32)
|
||||||
if args.device == "mps":
|
if args.device == "mps":
|
||||||
model_dtype = torch.float32 # MPS works better with float32
|
model_dtype = torch.float32 # MPS works better with float32
|
||||||
|
elif args.device == "xpu":
|
||||||
|
model_dtype = torch.float32
|
||||||
elif args.device == "cpu":
|
elif args.device == "cpu":
|
||||||
model_dtype = torch.float32
|
model_dtype = torch.float32
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user