Source code for kgcnn.models.serial

import logging
import keras as ks
from kgcnn.models.utils import get_model_class


[docs]def deserialize(obj_dict: dict) -> ks.models.Model: if isinstance(obj_dict, ks.models.Model): return obj_dict if not isinstance(obj_dict, dict): raise ValueError("Can not deserialize model with '%s'." % obj_dict) if "config" not in obj_dict: obj_dict.update({"config": {}}) class_name = obj_dict.get("class_name") module_name = obj_dict.get("module_name") model_cls = get_model_class(module_name=module_name, class_name=class_name) obj = model_cls(**obj_dict["config"]) if hasattr(obj, "set_weights") and "weights" in obj_dict: obj.set_weights(obj_dict["weights"]) # Call class methods if methods are in obj_dict. # Order is important here. if "methods" in obj_dict: method_list = obj_dict["methods"] if hasattr(obj, "set_methods"): obj.set_methods(method_list) else: # Try setting them manually. for method_item in method_list: for method, kwargs in method_item.items(): if hasattr(obj, method): getattr(obj, method)(**kwargs) else: logging.error("Class for deserialization does not have method '%s'." % method) return obj