fix: handle torch.dtype serialization in config classes
Fixes #199 - Object of type dtype is not JSON serializable When loading models with torch_dtype as a torch.dtype object (e.g., torch.bfloat16), transformers would fail to serialize the config to JSON for logging purposes, raising TypeError. This fix: - Adds _convert_dtype_to_string() helper function to convert torch.dtype objects to their string representation (e.g., 'bfloat16') - Overrides to_dict() method in VibeVoiceConfig, VibeVoiceASRConfig, and VibeVoiceStreamingConfig to apply this conversion The fix is backward compatible - string dtype values and None values continue to work as expected.
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
@@ -10,6 +11,23 @@ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_dtype_to_string(config_dict: dict) -> dict:
|
||||
"""
|
||||
Convert torch.dtype objects to their string representation for JSON serialization.
|
||||
|
||||
This fixes the "Object of type dtype is not JSON serializable" error that occurs
|
||||
when transformers tries to log/serialize the config with torch_dtype as a torch.dtype object.
|
||||
|
||||
See: https://github.com/microsoft/VibeVoice/issues/199
|
||||
"""
|
||||
if "torch_dtype" in config_dict and config_dict["torch_dtype"] is not None:
|
||||
dtype = config_dict["torch_dtype"]
|
||||
if isinstance(dtype, torch.dtype):
|
||||
# Convert torch.dtype to string (e.g., torch.bfloat16 -> "bfloat16")
|
||||
config_dict["torch_dtype"] = str(dtype).replace("torch.", "")
|
||||
return config_dict
|
||||
|
||||
|
||||
class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
|
||||
model_type = "vibevoice_acoustic_tokenizer"
|
||||
|
||||
@@ -259,6 +277,14 @@ class VibeVoiceConfig(PretrainedConfig):
|
||||
"""
|
||||
return self.decoder_config
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Override to_dict to handle torch.dtype serialization.
|
||||
|
||||
Fixes: https://github.com/microsoft/VibeVoice/issues/199
|
||||
"""
|
||||
output = super().to_dict()
|
||||
return _convert_dtype_to_string(output)
|
||||
|
||||
class VibeVoiceASRConfig(PretrainedConfig):
|
||||
model_type = "vibevoice"
|
||||
@@ -328,6 +354,15 @@ class VibeVoiceASRConfig(PretrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Override to_dict to handle torch.dtype serialization.
|
||||
|
||||
Fixes: https://github.com/microsoft/VibeVoice/issues/199
|
||||
"""
|
||||
output = super().to_dict()
|
||||
return _convert_dtype_to_string(output)
|
||||
|
||||
def get_text_config(self, decoder: bool = False):
|
||||
"""Return the text (decoder) config for generation."""
|
||||
return self.decoder_config
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
""" VibeVoice Streaming model configuration"""
|
||||
|
||||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
||||
|
||||
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceDiffusionHeadConfig
|
||||
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceDiffusionHeadConfig, _convert_dtype_to_string
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -80,6 +81,15 @@ class VibeVoiceStreamingConfig(PretrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Override to_dict to handle torch.dtype serialization.
|
||||
|
||||
Fixes: https://github.com/microsoft/VibeVoice/issues/199
|
||||
"""
|
||||
output = super().to_dict()
|
||||
return _convert_dtype_to_string(output)
|
||||
|
||||
__all__ = [
|
||||
"VibeVoiceStreamingConfig"
|
||||
]
|
||||
Reference in New Issue
Block a user