Source code for kgcnn.layers.update

import keras as ks
from keras.layers import Dense, Add, Layer


class GRUUpdate(Layer):
    r"""Update node or edge embeddings with a gated recurrent unit (GRU) cell.

    The update scheme was proposed by `NMPNN <http://arxiv.org/abs/1704.01212>`__ .
    """

    def __init__(self, units,
                 activation='tanh', recurrent_activation='sigmoid',
                 use_bias=True, kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros', kernel_regularizer=None,
                 recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None,
                 recurrent_constraint=None, bias_constraint=None, dropout=0.0,
                 recurrent_dropout=0.0, reset_after=True,
                 **kwargs):
        r"""Set up the layer and the wrapped GRU cell.

        All arguments except ``kwargs`` are forwarded unchanged to
        ``ks.layers.GRUCell``; ``kwargs`` go to the base ``Layer``.

        Args:
            units (int): Units for GRU.
            activation: Activation function to use. Default: hyperbolic tangent (`tanh`).
                Passing `None` applies no activation (i.e. "linear" activation: `a(x) = x`).
            recurrent_activation: Activation function for the recurrent step.
                Default: sigmoid (`sigmoid`). Passing `None` applies no activation
                (i.e. "linear" activation: `a(x) = x`).
            use_bias: Boolean (default `True`), whether the layer uses a bias vector.
            kernel_initializer: Initializer for the `kernel` weights matrix, used for
                the linear transformation of the inputs. Default: `glorot_uniform`.
            recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix,
                used for the linear transformation of the recurrent state.
                Default: `orthogonal`.
            bias_initializer: Initializer for the bias vector. Default: `zeros`.
            kernel_regularizer: Regularizer function applied to the `kernel` weights
                matrix. Default: `None`.
            recurrent_regularizer: Regularizer function applied to the
                `recurrent_kernel` weights matrix. Default: `None`.
            bias_regularizer: Regularizer function applied to the bias vector.
                Default: `None`.
            kernel_constraint: Constraint function applied to the `kernel` weights
                matrix. Default: `None`.
            recurrent_constraint: Constraint function applied to the
                `recurrent_kernel` weights matrix. Default: `None`.
            bias_constraint: Constraint function applied to the bias vector.
                Default: `None`.
            dropout: Float between 0 and 1. Fraction of the units to drop for the
                linear transformation of the inputs. Default: 0.
            recurrent_dropout: Float between 0 and 1. Fraction of the units to drop
                for the linear transformation of the recurrent state. Default: 0.
            reset_after: GRU convention (whether to apply reset gate after or before
                matrix multiplication). False = "before", True = "after"
                (default and CuDNN compatible).
        """
        super().__init__(**kwargs)
        self.units = units
        # Collect the cell configuration once so the forwarding is easy to audit.
        cell_args = dict(
            units=units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            reset_after=reset_after,
        )
        self.gru_cell = ks.layers.GRUCell(**cell_args)

    def build(self, input_shape):
        """Build the layer by delegating to the base class."""
        super().build(input_shape)

    def call(self, inputs, mask=None, **kwargs):
        """Run the GRU cell with the updates as input and the nodes as state.

        Args:
            inputs (list): [nodes, updates]

                - nodes (Tensor): Node embeddings of shape ([N], F)
                - updates (Tensor): Matching node updates of shape ([N], F)
            mask: Mask for inputs. Default is None. Accepted but not used here.

        Returns:
            Tensor: Updated nodes of shape ([N], F)
        """
        nodes, updates = inputs
        # The cell returns a pair; only its first element is handed back.
        updated, _ = self.gru_cell(updates, nodes, **kwargs)
        return updated

    def get_config(self):
        """Return the layer config extended by the GRU cell's parameters."""
        config = super().get_config()
        cell_config = self.gru_cell.get_config()
        exported_keys = (
            "units", "activation", "recurrent_activation", "use_bias",
            "kernel_initializer", "recurrent_initializer", "bias_initializer",
            "kernel_regularizer", "recurrent_regularizer", "bias_regularizer",
            "kernel_constraint", "recurrent_constraint", "bias_constraint",
            "dropout", "recurrent_dropout", "reset_after",
        )
        # Copy only the keys that the cell actually reports.
        config.update({key: cell_config[key] for key in exported_keys if key in cell_config})
        return config
class ResidualLayer(Layer):
    r"""Residual block of two dense layers whose output is added to the input.

    Defined as in `DimNetPP <https://arxiv.org/abs/2011.14115>`__ .
    """

    def __init__(self, units,
                 use_bias=True,
                 activation='kgcnn>swish',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 **kwargs):
        """Set up the layer and its sub-layers.

        Both dense sub-layers are created with the same arguments.

        Args:
            units: Dimension of the kernel.
            use_bias (bool, optional): Use bias. Defaults to True.
            activation (str): Activation function. Default is "kgcnn>swish".
            kernel_regularizer: Kernel regularization. Default is None.
            bias_regularizer: Bias regularization. Default is None.
            activity_regularizer: Activity regularization. Default is None.
            kernel_constraint: Kernel constraints. Default is None.
            bias_constraint: Bias constraints. Default is None.
            kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
            bias_initializer: Initializer for bias. Default is 'zeros'.
        """
        super().__init__(**kwargs)
        shared_dense_args = dict(
            units=units,
            activation=activation,
            use_bias=use_bias,
            kernel_regularizer=kernel_regularizer,
            activity_regularizer=activity_regularizer,
            bias_regularizer=bias_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
        )
        # Two independent Dense layers with identical settings, then the skip addition.
        self.dense_1 = Dense(**shared_dense_args)
        self.dense_2 = Dense(**shared_dense_args)
        self.add_end = Add()

    def build(self, input_shape):
        """Build the layer by delegating to the base class."""
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        """Apply both dense layers in sequence and add the result to the input.

        Args:
            inputs (Tensor): Node or edge embedding of shape ([N], F)

        Returns:
            Tensor: Node or edge embedding of shape ([N], F)
        """
        hidden = self.dense_1(inputs, **kwargs)
        hidden = self.dense_2(hidden, **kwargs)
        # Skip connection: the original input is added to the transformed features.
        return self.add_end([inputs, hidden], **kwargs)

    def get_config(self):
        """Return the layer config extended by the first dense layer's parameters."""
        config = super().get_config()
        dense_config = self.dense_1.get_config()
        exported_keys = (
            "kernel_regularizer", "activity_regularizer", "bias_regularizer",
            "kernel_constraint", "bias_constraint", "kernel_initializer",
            "bias_initializer", "activation", "use_bias", "units",
        )
        # Copy only the keys that the dense layer actually reports.
        config.update({key: dense_config[key] for key in exported_keys if key in dense_config})
        return config