fix: handle torch.dtype serialization in config classes

Fixes #199 - Object of type dtype is not JSON serializable

When loading models with torch_dtype as a torch.dtype object (e.g.,
torch.bfloat16), transformers would fail to serialize the config to
JSON for logging purposes, raising TypeError.

This fix:
- Adds a _convert_dtype_to_string() helper function that converts torch.dtype
  objects to their string representation (e.g., 'bfloat16')
- Overrides the to_dict() method in VibeVoiceConfig, VibeVoiceASRConfig,
  and VibeVoiceStreamingConfig to apply this conversion

The fix is backward compatible - string dtype values and None values
continue to work as expected.
This commit is contained in:
ThanhNguyxn
2026-01-24 12:12:48 +07:00
committed by YaoyaoChang
parent e67b15f47d
commit 5cf026569e
2 changed files with 46 additions and 1 deletions
@@ -2,6 +2,7 @@
from typing import Dict, List, Optional, Tuple
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
@@ -10,6 +11,23 @@ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
logger = logging.get_logger(__name__)
def _convert_dtype_to_string(config_dict: dict) -> dict:
"""
Convert torch.dtype objects to their string representation for JSON serialization.
This fixes the "Object of type dtype is not JSON serializable" error that occurs
when transformers tries to log/serialize the config with torch_dtype as a torch.dtype object.
See: https://github.com/microsoft/VibeVoice/issues/199
"""
if "torch_dtype" in config_dict and config_dict["torch_dtype"] is not None:
dtype = config_dict["torch_dtype"]
if isinstance(dtype, torch.dtype):
# Convert torch.dtype to string (e.g., torch.bfloat16 -> "bfloat16")
config_dict["torch_dtype"] = str(dtype).replace("torch.", "")
return config_dict
class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
model_type = "vibevoice_acoustic_tokenizer"
@@ -259,6 +277,14 @@ class VibeVoiceConfig(PretrainedConfig):
"""
return self.decoder_config
def to_dict(self):
    """
    Serialize this config to a dictionary that is safe to dump as JSON.

    The dictionary built by the parent class is passed through
    ``_convert_dtype_to_string`` so that a ``torch_dtype`` entry holding a
    torch.dtype object is emitted as a string instead.
    Fixes: https://github.com/microsoft/VibeVoice/issues/199
    """
    return _convert_dtype_to_string(super().to_dict())
class VibeVoiceASRConfig(PretrainedConfig):
model_type = "vibevoice"
@@ -328,6 +354,15 @@ class VibeVoiceASRConfig(PretrainedConfig):
super().__init__(**kwargs)
def to_dict(self):
    """
    Serialize this config to a dictionary that is safe to dump as JSON.

    The dictionary built by the parent class is passed through
    ``_convert_dtype_to_string`` so that a ``torch_dtype`` entry holding a
    torch.dtype object is emitted as a string instead.
    Fixes: https://github.com/microsoft/VibeVoice/issues/199
    """
    return _convert_dtype_to_string(super().to_dict())
def get_text_config(self, decoder: bool = False):
    """
    Return the text (decoder) config for generation.

    The ``decoder`` flag is accepted but not consulted here: the same
    ``decoder_config`` object is returned whatever its value.
    """
    text_config = self.decoder_config
    return text_config
@@ -1,11 +1,12 @@
""" VibeVoice Streaming model configuration"""
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceDiffusionHeadConfig
from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceDiffusionHeadConfig, _convert_dtype_to_string
logger = logging.get_logger(__name__)
@@ -80,6 +81,15 @@ class VibeVoiceStreamingConfig(PretrainedConfig):
super().__init__(**kwargs)
def to_dict(self):
    """
    Serialize this config to a dictionary that is safe to dump as JSON.

    The dictionary built by the parent class is passed through
    ``_convert_dtype_to_string`` so that a ``torch_dtype`` entry holding a
    torch.dtype object is emitted as a string instead.
    Fixes: https://github.com/microsoft/VibeVoice/issues/199
    """
    return _convert_dtype_to_string(super().to_dict())
__all__ = [
"VibeVoiceStreamingConfig"
]