Source code for kgcnn.layers.pooling

import keras as ks
from keras.layers import Layer, Dense, Concatenate, GRUCell, Activation
from kgcnn.layers.gather import GatherState
from keras import ops
from kgcnn.ops.scatter import scatter_reduce_softmax
from kgcnn.layers.aggr import Aggregate


class PoolingNodes(Layer):
    r"""Main layer to pool node or edge attributes. Uses :obj:`Aggregate` layer."""

    def __init__(self, pooling_method="scatter_sum", **kwargs):
        """Initialize layer.

        Args:
            pooling_method (str): Pooling method to use i.e. segment_function. Default is 'scatter_sum'.
        """
        super(PoolingNodes, self).__init__(**kwargs)
        self.pooling_method = pooling_method
        self._to_aggregate = Aggregate(pooling_method=pooling_method)

    def build(self, input_shape):
        """Build layer."""
        self._to_aggregate.build([input_shape[1], input_shape[2], input_shape[0]])
        self.built = True

    def compute_output_shape(self, input_shape):
        """Compute output shape."""
        return self._to_aggregate.compute_output_shape([input_shape[1], input_shape[2], input_shape[0]])

    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs: [reference, attr, batch_index]

                - reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
                - attr (Tensor): Node or edge embeddings of shape `([N], F)` .
                - batch_index (Tensor): Batch assignment of shape `([N], )` .

        Returns:
            Tensor: Embedding tensor of pooled nodes of shape `(batch, F)` .
        """
        reference, x, idx = inputs
        return self._to_aggregate([x, idx, reference])

    def get_config(self):
        """Update layer config."""
        config = super(PoolingNodes, self).get_config()
        config.update({"pooling_method": self.pooling_method})
        return config
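# Illustrative usage sketch (not part of the original module, values are made up):
# pool a disjoint batch of two graphs with three and two nodes into one vector per graph.
example_ref = ops.zeros((2,))  # one reference entry per graph in the batch
example_nodes = ops.ones((5, 16))  # five nodes with 16 features each
example_batch_id = ops.convert_to_tensor([0, 0, 0, 1, 1], dtype="int64")
example_pooled = PoolingNodes(pooling_method="scatter_sum")(
    [example_ref, example_nodes, example_batch_id])  # shape (2, 16): node features summed per graph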
class PoolingWeightedNodes(Layer):
    r"""Weighted pooling of all embeddings of edges or nodes per batch to obtain a graph level embedding.

    .. note::

        In addition to pooling embeddings, a weight tensor must be supplied that scales each embedding before
        pooling. Must broadcast.
    """

    def __init__(self, pooling_method="scatter_sum", **kwargs):
        """Initialize layer.

        Args:
            pooling_method (str): Pooling method to use i.e. segment_function. Default is 'scatter_sum'.
        """
        super(PoolingWeightedNodes, self).__init__(**kwargs)
        self.pooling_method = pooling_method
        self._to_aggregate = Aggregate(pooling_method=pooling_method)

    def build(self, input_shape):
        """Build layer."""
        assert len(input_shape) == 4
        ref_shape, attr_shape, weights_shape, index_shape = [list(x) for x in input_shape]
        self._to_aggregate.build([tuple(x) for x in [attr_shape, index_shape, ref_shape]])
        self.built = True

    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs: [reference, attr, weights, batch_index]

                - reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
                - attr (Tensor): Node or edge embeddings of shape `([N], F)` .
                - weights (Tensor): Node or message weights. Must broadcast to nodes. Shape `([N], 1)` .
                - batch_index (Tensor): Batch assignment of shape `([N], )` .

        Returns:
            Tensor: Embedding tensor of pooled nodes of shape `(batch, F)` .
        """
        # Need ragged input but can be generalized in the future.
        reference, x, w, idx = inputs
        xw = ops.broadcast_to(ops.cast(w, dtype=x.dtype), ops.shape(x)) * x
        return self._to_aggregate([xw, idx, reference])

    def get_config(self):
        """Update layer config."""
        config = super(PoolingWeightedNodes, self).get_config()
        config.update({"pooling_method": self.pooling_method})
        return config
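# Illustrative usage sketch (not part of the original module, values are made up):
# weighted pooling with a per-node weight column that scales each embedding before aggregation.
example_ref_w = ops.zeros((2,))
example_nodes_w = ops.ones((5, 16))
example_weights = ops.convert_to_tensor([[0.5], [1.0], [1.0], [0.25], [0.75]])
example_batch_id_w = ops.convert_to_tensor([0, 0, 0, 1, 1], dtype="int64")
example_pooled_w = PoolingWeightedNodes(pooling_method="scatter_sum")(
    [example_ref_w, example_nodes_w, example_weights, example_batch_id_w])  # shape (2, 16)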
class PoolingEmbeddingAttention(Layer):
    r"""Pooling of all embeddings of edges or nodes per batch to obtain a graph level embedding in form of a
    :obj:`Tensor` .

    Uses attention for pooling, i.e. :math:`s = \sum_i \alpha_{i} n_i` .
    The attention is computed via :math:`\alpha_i = \text{softmax}_i(a_i)` from the attention
    coefficients :math:`a_i` .
    The attention coefficients must be computed beforehand from node or edge features or by
    :math:`\sigma( W [s || n_i])` and are passed to this layer as input.
    Thereby this layer has no weights and only does pooling.
    In summary, :math:`s = \sum_i \text{softmax}_i(a_i) n_i` is computed by the layer.
    """

    def __init__(self, softmax_method="scatter_softmax", pooling_method="scatter_sum",
                 normalize_softmax: bool = False, **kwargs):
        """Initialize layer.

        Args:
            softmax_method (str): Method to compute the softmax over the batch assignment.
                Default is 'scatter_softmax'.
            pooling_method (str): Pooling method for aggregation. Default is 'scatter_sum'.
            normalize_softmax (bool): Whether to normalize the softmax. Default is False.
        """
        super(PoolingEmbeddingAttention, self).__init__(**kwargs)
        self.normalize_softmax = normalize_softmax
        self.pooling_method = pooling_method
        self.softmax_method = softmax_method
        self.to_aggregate = Aggregate(pooling_method=pooling_method)

    def build(self, input_shape):
        """Build layer."""
        assert len(input_shape) == 4
        ref_shape, attr_shape, attention_shape, index_shape = [list(x) for x in input_shape]
        self.to_aggregate.build([tuple(x) for x in [attr_shape, index_shape, ref_shape]])
        self.built = True

    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs: [reference, attr, attention, batch_index]

                - reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
                - attr (Tensor): Node or edge embeddings of shape `([N], F)` .
                - attention (Tensor): Attention coefficients of shape `([N], 1)` .
                - batch_index (Tensor): Batch assignment of shape `([N], )` .

        Returns:
            Tensor: Embedding tensor of pooled nodes of shape `(batch, F)` .
        """
        reference, attr, attention, batch_index = inputs
        shape_attention = ops.shape(reference)[:1] + ops.shape(attention)[1:]
        a = scatter_reduce_softmax(batch_index, attention, shape=shape_attention, normalize=self.normalize_softmax)
        x = attr * ops.broadcast_to(a, ops.shape(attr))
        return self.to_aggregate([x, batch_index, reference])

    def get_config(self):
        """Update layer config."""
        config = super(PoolingEmbeddingAttention, self).get_config()
        config.update({"normalize_softmax": self.normalize_softmax, "pooling_method": self.pooling_method,
                       "softmax_method": self.softmax_method})
        return config
PoolingNodesAttention = PoolingEmbeddingAttention
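# Illustrative usage sketch (not part of the original module, values are made up):
# attention pooling with precomputed per-node attention logits; the layer applies a
# per-graph softmax to the logits and sums the weighted node features.
example_ref_a = ops.zeros((2,))
example_nodes_a = ops.ones((5, 16))
example_logits = ops.convert_to_tensor([[0.1], [0.7], [0.2], [1.5], [-0.5]])
example_batch_id_a = ops.convert_to_tensor([0, 0, 0, 1, 1], dtype="int64")
example_pooled_a = PoolingNodesAttention()(
    [example_ref_a, example_nodes_a, example_logits, example_batch_id_a])  # shape (2, 16)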
class PoolingNodesAttentive(Layer):
    r"""Computes the attentive pooling for node embeddings for the
    `Attentive FP <https://doi.org/10.1021/acs.jmedchem.9b00959>`__ model.
    """

    def __init__(self, units,
                 depth=3,
                 pooling_method="sum",
                 activation='kgcnn>leaky_relu',
                 activation_context="elu",
                 use_bias=True,
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 recurrent_activation='sigmoid',
                 recurrent_initializer='orthogonal',
                 recurrent_regularizer=None,
                 recurrent_constraint=None,
                 dropout=0.0,
                 recurrent_dropout=0.0,
                 reset_after=True,
                 **kwargs):
        """Initialize layer.

        Args:
            units (int): Units for the linear trafo of node features before attention.
            pooling_method (str): Initial pooling before iteration. Default is "sum".
            depth (int): Number of iterations for graph embedding. Default is 3.
            activation (str): Activation. Default is {"class_name": "kgcnn>leaky_relu", "config": {"alpha": 0.2}}.
            activation_context (str): Activation function for context. Default is "elu".
            use_bias (bool): Use bias. Default is True.
            kernel_regularizer: Kernel regularization. Default is None.
            bias_regularizer: Bias regularization. Default is None.
            activity_regularizer: Activity regularization. Default is None.
            kernel_constraint: Kernel constraints. Default is None.
            bias_constraint: Bias constraints. Default is None.
            kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
            bias_initializer: Initializer for bias. Default is 'zeros'.
        """
        super(PoolingNodesAttentive, self).__init__(**kwargs)
        self.pooling_method = pooling_method
        self.depth = depth
        self.units = int(units)
        kernel_args = {"use_bias": use_bias, "kernel_regularizer": kernel_regularizer,
                       "activity_regularizer": activity_regularizer, "bias_regularizer": bias_regularizer,
                       "kernel_constraint": kernel_constraint, "bias_constraint": bias_constraint,
                       "kernel_initializer": kernel_initializer, "bias_initializer": bias_initializer}
        gru_args = {"recurrent_activation": recurrent_activation,
                    "use_bias": use_bias, "kernel_initializer": kernel_initializer,
                    "recurrent_initializer": recurrent_initializer, "bias_initializer": bias_initializer,
                    "kernel_regularizer": kernel_regularizer, "recurrent_regularizer": recurrent_regularizer,
                    "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
                    "recurrent_constraint": recurrent_constraint, "bias_constraint": bias_constraint,
                    "dropout": dropout, "recurrent_dropout": recurrent_dropout, "reset_after": reset_after}
        self.lay_linear_trafo = Dense(units, activation="linear", **kernel_args)
        self.lay_alpha = Dense(1, activation=activation, **kernel_args)
        self.lay_gather_s = GatherState()
        self.lay_concat = Concatenate(axis=-1)
        self.lay_pool_start = PoolingNodes(pooling_method=self.pooling_method)
        self.lay_pool_attention = PoolingNodesAttention()
        self.lay_final_activ = Activation(activation=activation_context)
        self.lay_gru = GRUCell(units=units, activation="tanh", **gru_args)

    def build(self, input_shape):
        """Build layer."""
        super(PoolingNodesAttentive, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs: [reference, nodes, batch_index]

                - reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
                - nodes (Tensor): Node embeddings of shape `([N], F)` .
                - batch_index (Tensor): Batch assignment of shape `([N], )` .

        Returns:
            Tensor: Hidden tensor of pooled node attentions of shape `(batch, F)` .
        """
        ref, node, batch_index = inputs
        h = self.lay_pool_start([ref, node, batch_index], **kwargs)
        wn = self.lay_linear_trafo(node, **kwargs)
        for _ in range(self.depth):
            hv = self.lay_gather_s([h, batch_index], **kwargs)
            ev = self.lay_concat([hv, node], **kwargs)
            av = self.lay_alpha(ev, **kwargs)
            cont = self.lay_pool_attention([ref, wn, av, batch_index], **kwargs)
            cont = self.lay_final_activ(cont, **kwargs)
            h, _ = self.lay_gru(cont, h, **kwargs)
        out = h
        return out

    def get_config(self):
        """Update layer config."""
        config = super(PoolingNodesAttentive, self).get_config()
        config.update({"units": self.units, "depth": self.depth, "pooling_method": self.pooling_method})
        conf_sub = self.lay_alpha.get_config()
        for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
                  "bias_constraint", "kernel_initializer", "bias_initializer", "activation", "use_bias"]:
            if x in conf_sub.keys():
                config.update({x: conf_sub[x]})
        conf_context = self.lay_final_activ.get_config()
        config.update({"activation_context": conf_context["activation"]})
        conf_gru = self.lay_gru.get_config()
        for x in ["recurrent_activation", "recurrent_initializer", "recurrent_regularizer", "recurrent_constraint",
                  "dropout", "recurrent_dropout", "reset_after"]:
            if x in conf_gru.keys():
                config.update({x: conf_gru[x]})
        return config
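# Illustrative usage sketch (not part of the original module, values are made up):
# Attentive-FP style readout that refines a graph state with a GRU over `depth` attention-pooling
# steps. Note the node feature dimension matches `units`, since the initial pooled state feeds the GRU.
example_ref_fp = ops.zeros((2,))
example_nodes_fp = ops.ones((5, 32))
example_batch_id_fp = ops.convert_to_tensor([0, 0, 0, 1, 1], dtype="int64")
example_graph_state = PoolingNodesAttentive(units=32, depth=2)(
    [example_ref_fp, example_nodes_fp, example_batch_id_fp])  # shape (2, 32)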