# import keras_core as ks
import functools
import logging
from math import inf
from typing import Union
from copy import deepcopy
import importlib
# Module-level logger shared by the helpers below.
# NOTE: basicConfig() runs at import time; it only configures the root logger if the root
# logger has no handlers yet, so applications that configured logging first are unaffected.
logging.basicConfig()
module_logger = logging.getLogger(name=__name__)
module_logger.setLevel(level=logging.INFO)
def get_model_class(module_name: str, class_name: Union[str, None]):
    r"""Helper function to get a model class (or model factory function) by string identifier.

    Args:
        module_name (str): Name of the module of the model. If it does not start with ``"kgcnn."``,
            it is interpreted as a sub-module of ``kgcnn.literature``.
        class_name (str): Name of the model class or function inside that module. If `None` or an
            empty string, ``"make_model"`` is used.

    Returns:
        The attribute ``class_name`` of the imported module, e.g. a model class or a ``make_model``
        function.

    Raises:
        NotImplementedError: If the module can not be imported because it (or one of its imports)
            is not found. The original ``ModuleNotFoundError`` is chained as ``__cause__``.
        AttributeError: If the module exists but has no attribute ``class_name``.
    """
    if not module_name.startswith("kgcnn."):
        # Assume that this is simply the module name in kgcnn.literature.
        module_name = "kgcnn.literature.%s" % module_name
    if class_name is None or class_name == "":
        # Assume that the "make_model" function is used.
        class_name = "make_model"
    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        # Report the module that failed to import (not `class_name`, which is usually just
        # "make_model"). Chaining keeps the real cause visible, e.g. a missing third-party
        # dependency imported by an existing model module.
        raise NotImplementedError(
            "Unknown model identifier: can not import module '%s' to get '%s' for a model in kgcnn." % (
                module_name, class_name)) from err
    return getattr(module, class_name)
def update_model_kwargs_logic(default_kwargs: dict = None, user_kwargs: dict = None,
                              update_recursive: Union[int, float] = inf):
    r"""Merge user kwargs into a copy of the default kwargs, recursing into nested dicts.

    This behaves like a nested ``dict.update()``. When a value is a dict both in the defaults and
    in the user input, the two are merged key by key, down to ``update_recursive`` levels, instead
    of the user dict replacing the default one. This is convenient when kwargs hold layer kwargs
    that are unpacked later: users only need to pass the entries that differ from the defaults.

    Args:
        default_kwargs (dict): Default values. ``None`` is treated as ``{}``. Not modified.
        user_kwargs (dict): Values applied on top of the defaults. ``None`` is treated as ``{}``.
            Every top-level key must also be present in ``default_kwargs``.
        update_recursive (int): Maximum depth of nested dicts that are merged rather than replaced.
            Default is `inf`.

    Returns:
        dict: Deep copy of ``default_kwargs`` updated with ``user_kwargs``.

    Raises:
        ValueError: If a top-level key of ``user_kwargs`` is missing from ``default_kwargs``.
    """
    defaults = {} if default_kwargs is None else default_kwargs
    overrides = {} if user_kwargs is None else user_kwargs

    # Reject unknown top-level keys; the first offending key (in user order) is reported.
    unknown_keys = [name for name in overrides if name not in defaults]
    if unknown_keys:
        raise ValueError("Model kwarg {0} not in default arguments {1}".format(unknown_keys[0], defaults.keys()))

    def _merge_into(target, source, max_depth=inf, depth=0):
        # Merge `source` into `target` in place (recursively up to `max_depth`) and return `target`.
        for name, new_value in source.items():
            if name not in target:
                # Only reachable below the top level, since top-level keys were validated above.
                module_logger.warning("Model kwargs: Unknown key {0} with value {1}".format(name, new_value))
                target[name] = new_value
            elif not isinstance(target[name], dict):
                # Plain default value: the user value simply wins.
                target[name] = new_value
            elif not isinstance(new_value, dict):
                module_logger.warning("Model kwargs: Overwriting dictionary of {0} with {1}".format(name, new_value))
                target[name] = new_value
            elif depth < max_depth:
                target[name] = _merge_into(target[name], new_value, max_depth=max_depth, depth=depth + 1)
            else:
                # Depth limit reached: replace the nested dict instead of merging it.
                target[name] = new_value
        return target

    # Work on a deep copy so the caller's defaults are never mutated.
    return _merge_into(deepcopy(defaults), overrides, update_recursive, 0)
def update_model_kwargs(model_default, update_recursive=inf, deprecated: list = None):
    r"""Decorator factory that fills the kwargs of a model function from defaults.

    The decorated function is called with the result of
    ``update_model_kwargs_logic(model_default, kwargs, update_recursive)`` as its keyword arguments,
    i.e. a deep copy of ``model_default`` updated (recursively) with the caller's kwargs.

    Args:
        model_default (dict): Default keyword arguments of the decorated function.
        update_recursive (int): Max depth for merging nested dicts. Default is `inf`.
        deprecated (list): Currently unused; accepted for backward compatibility. Default is None.

    Returns:
        Callable: Decorator to apply to a model function.
    """
    def model_update_decorator(func):
        @functools.wraps(func)
        def update_wrapper(*args, **kwargs):
            updated_kwargs = update_model_kwargs_logic(model_default, kwargs, update_recursive)
            # NOTE: this changes the level of the shared module logger and persists after the call.
            if 'verbose' in updated_kwargs:
                module_logger.setLevel(updated_kwargs["verbose"])
            # Pass values as logging args instead of pre-formatting with `%`: with two or more
            # positional args, `"... %s" % args` raised TypeError instead of logging.
            module_logger.info("Updated model kwargs: '%s'.", updated_kwargs)
            if args:
                module_logger.error("Can only update kwargs, not %s", args)
            return func(*args, **updated_kwargs)
        return update_wrapper
    return model_update_decorator
def change_attributes_in_all_layers(model, attributes_to_change=None, layer_type=None):
    r"""Set attributes on all (optionally type-filtered) sub-layers of a model.

    .. warning::

        Changing attributes of already built layers can break the model, in particular attributes
        that relate to weights. Always verify model behaviour after calling this function.

    The layers visited are those returned by ``model._flatten_layers(include_self=False, recursive=True)``.
    A layer is skipped unless it is an instance of ``layer_type`` (every layer qualifies when
    ``layer_type`` is None). For each remaining layer, every attribute from ``attributes_to_change``
    that the layer already has is overwritten. A layer that had at least one attribute changed and
    was built is marked ``built = False``. If the model itself was built, it is marked
    ``built = False`` up front.

    Args:
        model (tf.keras.models.Model): Model to modify in place.
        attributes_to_change (dict): Mapping of attribute name to new value. Default is None (no changes).
        layer_type: Class (or tuple of classes) used for ``isinstance`` filtering. Default is None.

    Returns:
        tf.keras.models.Model: The same ``model`` object that was passed in.
    """
    if model.built:
        module_logger.warning("Model '%s' has already been built. Will set `built=False` and continue." % model.name)
        model.built = False
    new_attributes = {} if attributes_to_change is None else attributes_to_change
    for layer in model._flatten_layers(include_self=False, recursive=True):
        if layer_type is not None and not isinstance(layer, layer_type):
            continue
        touched = False
        # Set sequentially: only attributes the layer already has are overwritten.
        for attr_name, attr_value in new_attributes.items():
            if not hasattr(layer, attr_name):
                continue
            setattr(layer, attr_name, attr_value)
            touched = True
        if not (touched and layer.built):
            continue
        module_logger.warning(
            "Layer '%s' in model has already been built. Will set `built=False` and continue." % layer.name)
        layer.built = False
    return model