Source code for kgcnn.literature.MEGAN._make

from ._model import MEGAN
from kgcnn.models.utils import update_model_kwargs
from kgcnn.layers.modules import Input
from keras.backend import backend as backend_to_use
from kgcnn.models.casting import (template_cast_output, template_cast_list_input,
                                  template_cast_list_input_docs, template_cast_output_docs)
from kgcnn.ops.activ import *
import keras as ks

# Keep track of model version from commit date in literature.
# To be updated if model is changed in a significant way.
__model_version__ = "2023-12-08"

# Supported backends
__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"]
if backend_to_use() not in __kgcnn_model_backend_supported__:
    raise NotImplementedError("Backend '%s' for model 'MEGAN' is not supported." % backend_to_use())


# Implementation of MEGAN in `keras` from paper:
# 'MEGAN: Multi-Explanation Graph Attention Network'
# by Jonas Teufel, Luca Torresi, Patrick Reiser, Pascal Friederich
# https://arxiv.org/abs/2211.13236
# https://github.com/aimat-lab/graph_attention_student


model_default = {
    "name": "MEGAN",
    "inputs": [
        {"shape": (None, 128), "name": "node_attributes", "dtype": "float32"},
        {"shape": (None, 64), "name": "edge_attributes", "dtype": "float32"},
        {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"},
        {"shape": (1,), "name": "graph_labels", "dtype": "float32"},
        {"shape": (), "name": "total_nodes", "dtype": "int64"},
        {"shape": (), "name": "total_edges", "dtype": "int64"}
    ],
    "input_tensor_type": "padded",
    "input_embedding": None,  # deprecated
    "cast_disjoint_kwargs": {},
    "units": [128],
    "activation": "kgcnn>leaky_relu",
    "use_bias": True,
    "dropout_rate": 0.0,
    "use_edge_features": True,
    "input_node_embedding": None,
    # node/edge importance related arguments
    "importance_units": [],
    "importance_channels": 2,
    "importance_activation": "sigmoid",  # do not change
    "importance_dropout_rate": 0.0,  # do not change
    "importance_factor": 0.0,
    "importance_multiplier": 10.0,
    "sparsity_factor": 0.0,
    "concat_heads": True,
    # mlp tail end related arguments
    "final_units": [1],
    "final_dropout_rate": 0.0,
    "final_activation": "linear",
    "final_pooling": "sum",
    "regression_limits": None,
    "regression_reference": None,
    "return_importances": True,
    "output_tensor_type": "padded",
    "output_embedding": "graph",
}
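
# A minimal configuration sketch (an editorial addition, not from the original
# source): ``model_default`` is merged with user kwargs by the
# ``update_model_kwargs`` decorator below, so any subset of these keys can be
# overridden when calling ``make_model``. The feature dimensions here (41 node
# and 11 edge features) are dataset-specific assumptions for illustration only.
#
#     model = make_model(
#         inputs=[
#             {"shape": (None, 41), "name": "node_attributes", "dtype": "float32"},
#             {"shape": (None, 11), "name": "edge_attributes", "dtype": "float32"},
#             {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"},
#             {"shape": (1,), "name": "graph_labels", "dtype": "float32"},
#             {"shape": (), "name": "total_nodes", "dtype": "int64"},
#             {"shape": (), "name": "total_edges", "dtype": "int64"},
#         ],
#         importance_channels=2,
#         final_units=[16, 1],
#     )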


@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
def make_model(inputs: list = None,
               name: str = None,
               input_tensor_type: str = None,
               cast_disjoint_kwargs: dict = None,
               units: list = None,
               activation: str = None,
               use_bias: bool = None,
               dropout_rate: float = None,
               use_edge_features: bool = None,
               input_embedding: dict = None,  # deprecated
               input_node_embedding: dict = None,
               # node/edge importance related arguments
               importance_units: list = None,
               importance_channels: int = None,
               importance_activation: str = None,  # do not change
               importance_dropout_rate: float = None,  # do not change
               importance_factor: float = None,
               importance_multiplier: float = None,
               sparsity_factor: float = None,
               concat_heads: bool = None,
               # mlp tail end related arguments
               final_units: list = None,
               final_dropout_rate: float = None,
               final_activation: str = None,
               final_pooling: str = None,
               regression_limits: tuple = None,
               regression_reference: float = None,
               return_importances: bool = True,
               output_embedding: dict = None,
               output_tensor_type: str = None
               ):
    r"""Functional model definition of MEGAN. Please check documentation of :obj:`kgcnn.literature.MEGAN` .

    **Model inputs**:
    Model uses the list template of inputs and standard output template. The supported inputs are
    :obj:`[nodes, edges, edge_indices, graph_labels, ...]` with '...' indicating mask or ID tensors
    following the template below. Graph labels are used to generate explanations but not to influence
    model output.

    %s

    **Model outputs**:
    The standard output template:

    %s

    Args:
        name: Name of the model.
        inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input`. Order must match model definition.
        input_tensor_type (str): Input type of graph tensor. Default is "padded".
        cast_disjoint_kwargs (dict): Dictionary of arguments for casting layer.
        units: A list of ints where each element configures an additional attention layer. The numeric value
            determines the number of hidden units to be used in the attention heads of that layer.
        activation: The activation function to be used within the attention layers of the network.
        use_bias: Whether the layers of the network should use bias weights at all.
        dropout_rate: The dropout rate to be applied after *each* of the attention layers of the network.
        input_node_embedding: Dictionary of embedding kwargs for input embedding layer.
        use_edge_features: Whether edge features should be used. Generally the network supports the usage of
            edge features, but if the input data does not contain edge features, this should be set to False.
        importance_units: A list of ints where each element configures another dense layer in the subnetwork
            that produces the node importance tensor from the main node embeddings. The numeric value
            determines the number of hidden units in that layer.
        importance_channels: The int number of explanation channels to be produced by the network. This is
            the value referred to as "K". Note that this will also determine the number of attention heads
            used within the attention subnetwork.
        importance_factor: The weight of the explanation-only train step. If this is set to exactly zero,
            the explanation train step will not be executed at all (less computationally expensive).
        importance_multiplier: An additional hyperparameter of the explanation-only train step. This is
            essentially the scaling factor that is applied to the values of the dataset such that the target
            values can reasonably be approximated by a sum of [0, 1] importance values.
        sparsity_factor: The coefficient for the sparsity regularization of the node importance tensor.
        concat_heads: Whether to concat the heads of the attention subnetwork. The default is True. In that
            case the output of each individual attention head is concatenated and the concatenated vector is
            then used as the input of the next attention layer's heads. If this is False, the vectors are
            average pooled instead.
        final_units: A list of ints where each element configures another dense layer in the MLP at the tail
            end of the network. The numeric value determines the number of hidden units in that layer. Note
            that the final element in this list has to be the same as the dimension to be expected for the
            samples of the training dataset!
        final_dropout_rate: The dropout rate to be applied after *every* layer of the final MLP.
        final_activation: The activation to be applied at the very last layer of the MLP to produce the
            actual output of the network.
        final_pooling: The pooling method to be used during the global pooling phase in the network.
        regression_limits: A tuple where the first value is the lower limit for the expected value range of
            the regression task and the second value the upper limit.
        regression_reference: A reference value which is inside the range of expected values (best if it is
            in the middle, but it does not have to be). Choosing different references will result in
            different explanations.
        return_importances: Whether the importance / explanation tensors should be returned as an output of
            the model. If this is True, the output of the model will be a 3-tuple:
            (output, node importances, edge importances); otherwise it is just the output itself.
        output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
        output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".

    Returns:
        :obj:`keras.models.Model`
    """
    model_inputs = [Input(**x) for x in inputs]

    di_inputs = template_cast_list_input(
        model_inputs,
        input_tensor_type=input_tensor_type,
        cast_disjoint_kwargs=cast_disjoint_kwargs,
        mask_assignment=[0, 1, 1, None],
        index_assignment=[None, None, 0, None]
    )

    n, ed, disjoint_indices, gs, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = di_inputs

    # Wrapping disjoint model.
    out = MEGAN(
        units=units,
        activation=activation,
        use_bias=use_bias,
        dropout_rate=dropout_rate,
        use_edge_features=use_edge_features,
        input_node_embedding=input_node_embedding,
        importance_units=importance_units,
        importance_channels=importance_channels,
        importance_activation=importance_activation,  # do not change
        importance_dropout_rate=importance_dropout_rate,  # do not change
        importance_factor=importance_factor,
        importance_multiplier=importance_multiplier,
        sparsity_factor=sparsity_factor,
        concat_heads=concat_heads,
        final_units=final_units,
        final_dropout_rate=final_dropout_rate,
        final_activation=final_activation,
        final_pooling=final_pooling,
        regression_limits=regression_limits,
        regression_reference=regression_reference,
        return_importances=return_importances
    )([n, ed, disjoint_indices, gs, batch_id_node, count_nodes])

    # Output embedding choice.
    out = template_cast_output(
        [out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges],
        output_embedding=output_embedding,
        output_tensor_type=output_tensor_type,
        input_tensor_type=input_tensor_type,
        cast_disjoint_kwargs=cast_disjoint_kwargs,
    )

    model = ks.models.Model(inputs=model_inputs, outputs=out, name=name)
    model.__kgcnn_model_version__ = __model_version__
    return model
make_model.__doc__ = make_model.__doc__ % (template_cast_list_input_docs, template_cast_output_docs)
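

if __name__ == "__main__":
    # Minimal smoke-test sketch (an editorial addition, not part of the original
    # module). The random shapes below are assumptions that match the default
    # input spec in ``model_default`` and demonstrate the padded input template.
    import numpy as np

    model = make_model(**model_default)
    batch_size, num_nodes, num_edges = 2, 10, 20
    example_inputs = [
        np.random.rand(batch_size, num_nodes, 128).astype("float32"),  # node_attributes
        np.random.rand(batch_size, num_edges, 64).astype("float32"),  # edge_attributes
        np.random.randint(0, num_nodes, (batch_size, num_edges, 2), dtype="int64"),  # edge_indices
        np.random.rand(batch_size, 1).astype("float32"),  # graph_labels
        np.full((batch_size,), num_nodes, dtype="int64"),  # total_nodes
        np.full((batch_size,), num_edges, dtype="int64"),  # total_edges
    ]
    # With ``return_importances=True`` the model is expected to output the triple
    # (prediction, node_importances, edge_importances).
    outputs = model.predict(example_inputs)
    for o in outputs:
        print(getattr(o, "shape", None))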