- Add MPS device choice and auto-detect MPS availability
- Change default attention implementation to 'auto' with smart fallback
- Auto-detect flash_attention_2 availability on CUDA, fallback to sdpa
- Use sdpa for MPS and CPU devices (flash_attention_2 not supported)
- Use float32 dtype for MPS/CPU devices for better compatibility
Fixes#206
Fixes#199 - Object of type dtype is not JSON serializable
When loading models with torch_dtype as a torch.dtype object (e.g.,
torch.bfloat16), transformers would fail to serialize the config to
JSON for logging purposes, raising TypeError.
This fix:
- Adds _convert_dtype_to_string() helper function to convert torch.dtype
objects to their string representation (e.g., 'bfloat16')
- Overrides to_dict() method in VibeVoiceConfig, VibeVoiceASRConfig,
and VibeVoiceStreamingConfig to apply this conversion
The fix is backward compatible - string dtype values and None values
continue to work as expected.
Add __init__.py files to vibevoice/modular and vibevoice/processor
directories to properly export classes and enable package imports.
This allows users to import the package after installation:
- from vibevoice import VibeVoiceStreamingForConditionalGenerationInference
- from vibevoice.modular import VibeVoiceStreamingConfig
- from vibevoice.processor import VibeVoiceStreamingProcessor
Fixes import errors when using `pip install -e .`