Source code for kgcnn.training.hyper

import os
import logging
import keras as ks
from typing import Union
from copy import deepcopy
from kgcnn.metrics.utils import merge_metrics
from kgcnn.data.utils import load_hyper_file, save_json_file
from keras.saving import deserialize_keras_object
import keras.metrics
import keras.losses

# Module-level logger: scripts importing this module get INFO-level output by default.
logging.basicConfig()  # Module logger
module_logger = logging.getLogger(__name__)
module_logger.setLevel(logging.INFO)


class HyperParameter:
    r"""A class to store hyperparameter for a specific dataset and model, exposing them for model
    training scripts.

    This includes training parameters and a set of general information like a path for output of the
    training stats or the expected version of :obj:`kgcnn`. The class methods will extract and
    possibly serialize or deserialize the necessary kwargs from the hyperparameter dictionary.

    .. code-block:: python

        hyper = HyperParameter(hyper_info={"model": {"config":{}}, "training": {}, "data":{"dataset":{}}})
    """

    def __init__(self, hyper_info: Union[str, dict],
                 hyper_category: str = None,
                 model_name: str = None,
                 model_module: str = None,
                 model_class: str = "make_model",
                 dataset_name: str = None,
                 dataset_class: str = None,
                 dataset_module: str = None):
        r"""Make a hyperparameter class instance.

        Required is the hyperparameter dictionary or path to a config file containing the
        hyperparameter. Furthermore, name of the dataset and model can be provided for checking the
        information in :obj:`hyper_info`.

        Args:
            hyper_info (str, dict): Hyperparameter dictionary or path to file.
            hyper_category (str): Category within hyperparameter.
            model_name (str, optional): Name of the model.
            model_module (str, optional): Name of the module of the model. Defaults to None.
            model_class (str, optional): Class name or make function for model. Defaults to 'make_model'.
            dataset_name (str, optional): Name of the dataset.
            dataset_class (str, optional): Class name of the dataset.
            dataset_module (str, optional): Module name of the dataset.

        Raises:
            TypeError: If `hyper_info` is neither a dict nor a file path.
            ValueError: If the category of hyperparameter cannot be resolved from `hyper_info`.
        """
        self._hyper = None
        self._hyper_category = hyper_category
        self._dataset_name = dataset_name  # NOTE: was assigned twice in a previous revision.
        self._model_name = model_name
        self._model_module = model_module
        self._model_class = model_class
        self._dataset_class = dataset_class
        self._dataset_module = dataset_module

        if isinstance(hyper_info, str):
            # Path to a config file (json/yaml) with the hyperparameter.
            self._hyper_all = load_hyper_file(hyper_info)
        elif isinstance(hyper_info, dict):
            self._hyper_all = hyper_info
        else:
            raise TypeError("`HyperParameter` requires valid hyper dictionary or path to file.")

        # If model and training section in hyper-dictionary, then this is a valid hyper setting.
        # If hyper is itself a dictionary with many models, pick the right model if model name is given.
        model_name_class = "%s.%s" % (model_name, model_class)
        if "model" in self._hyper_all and "training" in self._hyper_all:
            self._hyper = self._hyper_all
        elif hyper_category is not None:
            if hyper_category in self._hyper_all:
                self._hyper = self._hyper_all[self._hyper_category]
            else:
                raise ValueError("Category '%s' not in hyperparameter information." % hyper_category)
        elif model_name is not None and model_class == "make_model" and model_name in self._hyper_all:
            # Deprecated: infer category from plain model name for default make function.
            self._hyper = self._hyper_all[model_name]
        elif model_name is not None and model_class != "make_model" and model_name_class in self._hyper_all:
            # Deprecated: infer category from 'name.class' composite key.
            self._hyper = self._hyper_all[model_name_class]
        else:
            raise ValueError("".join(
                ["Not a valid hyper dictionary. If there are multiple hyperparameter settings in `hyper_info`, ",
                 "please specify `hyper_category` or the category is inferred from optional class/name ",
                 "information but which is deprecated."]))

    def verify(self, raise_error: bool = True, raise_warning: bool = False):
        """Logic to verify and optionally update hyperparameter dictionary.

        Args:
            raise_error (bool): Whether mismatches raise a ValueError instead of being logged.
            raise_warning (bool): Whether deprecation/update notices raise instead of being logged.
        """

        def error_msg(error_to_report):
            # Escalate to exception or log as error, depending on `raise_error`.
            if raise_error:
                raise ValueError(error_to_report)
            else:
                module_logger.error(error_to_report)

        def warning_msg(warning_to_report):
            # Escalate to exception or log as warning, depending on `raise_warning`.
            if raise_warning:
                raise ValueError(warning_to_report)
            else:
                module_logger.warning(warning_to_report)

        # Update some optional parameters in hyper. Issue a warning.
        if "config" not in self._hyper["model"] and "inputs" in self._hyper["model"]:
            # Older model config setting.
            warning_msg("Hyperparameter {'model': ...} changed to {'model': {'config': {...}}} .")
            self._hyper["model"] = {"config": deepcopy(self._hyper["model"])}
        if "class_name" not in self._hyper["model"] and self._model_class is not None:
            warning_msg("Adding model class from self to 'model': {'class_name': %s} ." % self._model_class)
            self._hyper["model"].update({"class_name": self._model_class})
        if "module_name" not in self._hyper["model"] and self._model_module is not None:
            warning_msg("Adding model module from self to 'model': {'module_name': %s} ." % self._model_module)
            self._hyper["model"].update({"module_name": self._model_module})
        if "info" not in self._hyper:
            self._hyper.update({"info": {}})
            warning_msg("Adding 'info' category to hyperparameter.")
        if "postfix_file" not in self._hyper["info"]:
            warning_msg("Adding 'postfix_file' to 'info' category in hyperparameter.")
            self._hyper["info"].update({"postfix_file": ""})
        if "dataset" in self._hyper["data"] and "dataset" not in self._hyper:
            warning_msg("Hyperparameter should have separate dataset category from kgcnn>=3.1.0 .")
            self._hyper["dataset"] = deepcopy(self._hyper["data"]["dataset"])
            self._hyper["data"].pop("dataset")

        # Errors if missmatch is found between class definition and information in hyper-dictionary.
        # In principle all information regarding model and dataset can be stored in hyper-dictionary.
        if "class_name" in self._hyper["dataset"] and self._dataset_class is not None:
            if self._dataset_class != self._hyper["dataset"]["class_name"]:
                error_msg("Dataset '%s' does not agree with hyperparameter '%s'." % (
                    self._dataset_class, self._hyper["dataset"]["class_name"]))
        if "class_name" in self._hyper["model"] and self._model_class is not None:
            if self._hyper["model"]["class_name"] != self._model_class:
                error_msg("Model generation '%s' does not agree with hyperparameter '%s'." % (
                    self._model_class, self._hyper["model"]["class_name"]))
        if "module_name" in self._hyper["model"] and self._model_module is not None:
            if self._hyper["model"]["module_name"] != self._model_module:
                error_msg("Model module '%s' does not agree with hyperparameter '%s'." % (
                    self._model_module, self._hyper["model"]["module_name"]))
        if "name" in self._hyper["model"]["config"] and self._model_name is not None:
            if self._hyper["model"]["config"]["name"] != self._model_name:
                error_msg("Model name '%s' does not agree with hyperparameter '%s'." % (
                    self._model_name, self._hyper["model"]["config"]["name"]))

    def __getitem__(self, item):
        # Return a deep copy so callers cannot mutate the stored hyperparameter.
        return deepcopy(self._hyper[item])

    def compile(self, loss=None, optimizer='rmsprop', metrics: list = None, weighted_metrics: list = None):
        r"""Generate kwargs for :obj:`tf.keras.Model.compile` from hyperparameter and default parameter.

        This function should handle deserialization of hyperparameter and, if not specified, fill
        them from default. Loss, optimizer are overwritten from hyperparameter, if available.
        Metrics in hyperparameter are added from function arguments. Note that otherwise metrics can
        not be deserialized, since `metrics` can include nested lists and a dictionary of model
        output names.

        .. warning::

            When using deserialization with this function, you must not name your model output
            "class_name", "module_name" and "config".

        Args:
            loss: Default loss for fit. Default is None.
            optimizer: Default optimizer. Default is "rmsprop".
            metrics (list): Default list of metrics. Default is None.
            weighted_metrics (list): Default list of weighted_metrics. Default is None.

        Returns:
            dict: Deserialized compile kwargs from hyperparameter.
        """
        hyper_compile = deepcopy(self._hyper["training"]["compile"]) if "compile" in self._hyper["training"] else {}
        if len(hyper_compile) == 0:
            module_logger.warning("Found no information for `compile` in hyperparameter.")
        reserved_compile_arguments = ["loss", "optimizer", "weighted_metrics", "metrics"]
        # Pass through any extra compile kwargs (e.g. jit_compile) untouched.
        hyper_compile_additional = {
            key: value for key, value in hyper_compile.items() if key not in reserved_compile_arguments}

        def nested_deserialize(m, get):
            """Deserialize nested list or dict objects for keras model output like loss or metrics."""
            if isinstance(m, (list, tuple)):
                return [nested_deserialize(x, get) for x in m]
            if isinstance(m, dict):
                if "class_name" in m and "config" in m:
                    # Here we have a serialization dict.
                    try:
                        return get(m)
                    except ValueError:
                        # Fall back to generic keras deserialization for custom objects.
                        return deserialize_keras_object(m)
                else:
                    # Model outputs as dict.
                    return {key: nested_deserialize(value, get) for key, value in m.items()}
            elif isinstance(m, str):
                return get(m)
            return m

        # Optimizer
        optimizer = nested_deserialize(
            (hyper_compile["optimizer"] if "optimizer" in hyper_compile else optimizer), ks.optimizers.get)

        # Loss
        loss = nested_deserialize((hyper_compile["loss"] if "loss" in hyper_compile else loss), ks.losses.get)

        # Metrics
        metrics = nested_deserialize(metrics, ks.metrics.get)
        weighted_metrics = nested_deserialize(weighted_metrics, ks.metrics.get)
        hyper_metrics = nested_deserialize(
            hyper_compile["metrics"], ks.metrics.get) if "metrics" in hyper_compile else None
        hyper_weighted_metrics = nested_deserialize(
            hyper_compile["weighted_metrics"], ks.metrics.get) if "weighted_metrics" in hyper_compile else None
        metrics = merge_metrics(hyper_metrics, metrics)
        weighted_metrics = merge_metrics(hyper_weighted_metrics, weighted_metrics)

        # Output deserialized compile kwargs.
        output = {"loss": loss, "optimizer": optimizer, "metrics": metrics, "weighted_metrics": weighted_metrics,
                  **hyper_compile_additional}
        module_logger.info("Deserialized compile kwargs '%s'." % output)
        return output

    def fit(self, epochs: int = 1, validation_freq: int = 1, batch_size: int = None, callbacks: list = None):
        """Select fit hyperparameter. Additional default values for the training scripts are given
        as functional kwargs. Functional kwargs are overwritten by hyperparameter.

        Args:
            epochs (int): Default number of epochs. Default is 1.
            validation_freq (int): Default validation frequency. Default is 1.
            batch_size (int): Default batch size. Default is None.
            callbacks (list): Default Callbacks. Default is None.

        Returns:
            dict: de-serialized fit kwargs from hyperparameter.
        """
        hyper_fit = deepcopy(self._hyper["training"]["fit"])
        reserved_fit_arguments = ["callbacks", "batch_size", "validation_freq", "epochs"]
        hyper_fit_additional = {key: value for key, value in hyper_fit.items() if key not in reserved_fit_arguments}

        if "epochs" in hyper_fit:
            epochs = hyper_fit["epochs"]
        if "validation_freq" in hyper_fit:
            validation_freq = hyper_fit["validation_freq"]
        if "batch_size" in hyper_fit:
            batch_size = hyper_fit["batch_size"]
        if "callbacks" in hyper_fit:
            if callbacks is None:
                callbacks = []
            for cb in hyper_fit["callbacks"]:
                # Serialized callbacks (str or dict) are deserialized; objects pass through.
                if isinstance(cb, (str, dict)):
                    callbacks += [deserialize_keras_object(cb)]
                else:
                    callbacks += [cb]

        out = {"batch_size": batch_size, "epochs": epochs, "validation_freq": validation_freq, "callbacks": callbacks}
        out.update(hyper_fit_additional)
        return out

    def results_file_path(self):
        r"""Make output folder for results based on hyperparameter and return path to that folder.
        The folder is set up as `'results'/dataset/model_name + post_fix`. Where model and dataset
        name must be set by this class. Postfix can be in hyperparameter setting.

        Returns:
            str: File-path or path object to result folder.
        """
        hyper_info = deepcopy(self._hyper["info"])
        # NOTE: 'postfix' names the result folder; 'postfix_file' (see verify) names result files.
        post_fix = str(hyper_info["postfix"]) if "postfix" in hyper_info else ""
        os.makedirs("results", exist_ok=True)
        os.makedirs(os.path.join("results", self.dataset_class), exist_ok=True)
        category = self.hyper_category.replace(".", "_")
        filepath = os.path.join("results", self.dataset_class, category + post_fix)
        os.makedirs(filepath, exist_ok=True)
        return filepath

    def save(self, file_path: str):
        """Save the hyperparameter to path.

        Args:
            file_path (str): Full file path to save hyperparameter to.
        """
        # Must make more refined saving and serialization here.
        save_json_file(self._hyper, file_path)

    @property
    def dataset_name(self):
        if "dataset" in self._hyper:
            if "config" in self._hyper["dataset"]:
                if "dataset_name" in self._hyper["dataset"]["config"]:
                    # Bug fix: previously returned config["dataset"], a key that was never checked.
                    return self._hyper["dataset"]["config"]["dataset_name"]
        return self._dataset_name

    @property
    def dataset_class(self):
        if "dataset" in self._hyper:
            if "class_name" in self._hyper["dataset"]:
                return self._hyper["dataset"]["class_name"]
        # Bug fix: previously fell back to self._dataset_name.
        return self._dataset_class

    @property
    def dataset_module(self):
        if "dataset" in self._hyper:
            if "module_name" in self._hyper["dataset"]:
                return self._hyper["dataset"]["module_name"]
        # Bug fix: previously fell back to self._dataset_name.
        return self._dataset_module

    @property
    def model_name(self):
        if "model" in self._hyper:
            if "config" in self._hyper["model"]:
                if "name" in self._hyper["model"]["config"]:
                    # Name is just called name not model_name to be more compatible with keras.
                    return self._hyper["model"]["config"]["name"]
        return self._model_name

    @property
    def model_class(self):
        if "model" in self._hyper:
            if "class_name" in self._hyper["model"]:
                return self._hyper["model"]["class_name"]
        return self._model_class

    @property
    def model_module(self):
        if "model" in self._hyper:
            if "module_name" in self._hyper["model"]:
                return self._hyper["model"]["module_name"]
        return self._model_module

    @property
    def hyper_category(self):
        if self._hyper_category is not None:
            return self._hyper_category
        # Deprecated fallback: compose the category from model name and class.
        return "%s.%s" % (self.model_name, self.model_class)