From 5cf026569e9ddd72c26a3e6849f8da04adf38d89 Mon Sep 17 00:00:00 2001
From: ThanhNguyxn
Date: Sat, 24 Jan 2026 12:12:48 +0700
Subject: [PATCH] fix: handle torch.dtype serialization in config classes

Fixes #199 - Object of type dtype is not JSON serializable

When loading models with torch_dtype as a torch.dtype object (e.g.,
torch.bfloat16), transformers would fail to serialize the config to JSON
for logging purposes, raising TypeError.

This fix:
- Adds _convert_dtype_to_string() helper function to convert torch.dtype
  objects to their string representation (e.g., 'bfloat16'), under both
  the 'torch_dtype' and 'dtype' keys and inside nested dictionaries
- Overrides to_dict() method in VibeVoiceConfig, VibeVoiceASRConfig, and
  VibeVoiceStreamingConfig to apply this conversion

The fix is backward compatible - string dtype values and None values
continue to work as expected.
---
 vibevoice/modular/configuration_vibevoice.py  | 49 +++++++++++++++++++
 .../configuration_vibevoice_streaming.py      | 11 +++-
 2 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/vibevoice/modular/configuration_vibevoice.py b/vibevoice/modular/configuration_vibevoice.py
index d5e3149..1845113 100644
--- a/vibevoice/modular/configuration_vibevoice.py
+++ b/vibevoice/modular/configuration_vibevoice.py
@@ -2,6 +2,7 @@
 
 from typing import Dict, List, Optional, Tuple
 
+import torch
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
@@ -10,6 +11,37 @@ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
 logger = logging.get_logger(__name__)
 
 
+def _convert_dtype_to_string(config_dict: dict) -> dict:
+    """
+    Convert torch.dtype objects to their string representation for JSON serialization.
+
+    This fixes the "Object of type dtype is not JSON serializable" error that occurs
+    when transformers tries to log/serialize the config with torch_dtype as a torch.dtype object.
+
+    Both the ``torch_dtype`` key and the ``dtype`` key are handled, and nested
+    dictionaries are converted recursively. String and None values are left
+    untouched.
+
+    Args:
+        config_dict: Serialized config dictionary. It is modified in place.
+
+    Returns:
+        The same ``config_dict`` object, so the call can be used in a return statement.
+
+    See: https://github.com/microsoft/VibeVoice/issues/199
+    """
+    for key in ("torch_dtype", "dtype"):
+        value = config_dict.get(key)
+        if isinstance(value, torch.dtype):
+            # Convert torch.dtype to string (e.g., torch.bfloat16 -> "bfloat16")
+            config_dict[key] = str(value).replace("torch.", "")
+    for value in config_dict.values():
+        # A nested dict can carry its own dtype entry; json.dumps would fail on it too.
+        if isinstance(value, dict):
+            _convert_dtype_to_string(value)
+    return config_dict
+
+
 class VibeVoiceAcousticTokenizerConfig(PretrainedConfig):
     model_type = "vibevoice_acoustic_tokenizer"
 
@@ -259,6 +291,14 @@ class VibeVoiceConfig(PretrainedConfig):
         """
         return self.decoder_config
 
+    def to_dict(self):
+        """
+        Override to_dict to handle torch.dtype serialization.
+
+        Fixes: https://github.com/microsoft/VibeVoice/issues/199
+        """
+        output = super().to_dict()
+        return _convert_dtype_to_string(output)
 
 class VibeVoiceASRConfig(PretrainedConfig):
     model_type = "vibevoice"
@@ -328,6 +368,15 @@ class VibeVoiceASRConfig(PretrainedConfig):
 
         super().__init__(**kwargs)
 
+    def to_dict(self):
+        """
+        Override to_dict to handle torch.dtype serialization.
+
+        Fixes: https://github.com/microsoft/VibeVoice/issues/199
+        """
+        output = super().to_dict()
+        return _convert_dtype_to_string(output)
+
     def get_text_config(self, decoder: bool = False):
         """Return the text (decoder) config for generation."""
         return self.decoder_config
diff --git a/vibevoice/modular/configuration_vibevoice_streaming.py b/vibevoice/modular/configuration_vibevoice_streaming.py
index 426adff..f27d983 100644
--- a/vibevoice/modular/configuration_vibevoice_streaming.py
+++ b/vibevoice/modular/configuration_vibevoice_streaming.py
@@ -1,11 +1,11 @@
 """ VibeVoice Streaming model configuration"""
 
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
 from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
 
-from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceDiffusionHeadConfig
+from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceDiffusionHeadConfig, _convert_dtype_to_string
 
 logger = logging.get_logger(__name__)
 
@@ -80,6 +80,15 @@ class VibeVoiceStreamingConfig(PretrainedConfig):
 
         super().__init__(**kwargs)
 
+    def to_dict(self):
+        """
+        Override to_dict to handle torch.dtype serialization.
+
+        Fixes: https://github.com/microsoft/VibeVoice/issues/199
+        """
+        output = super().to_dict()
+        return _convert_dtype_to_string(output)
+
 __all__ = [
     "VibeVoiceStreamingConfig"
 ]
\ No newline at end of file