Source code for kgcnn.layers.set2set

import keras as ks
from keras import ops
from kgcnn.ops.scatter import scatter_reduce_sum, scatter_reduce_max
from kgcnn.layers.aggr import Aggregate

# Order Matters: Sequence to sequence for sets
# by Vinyals et al. 2016
# https://arxiv.org/abs/1511.06391
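# For orientation, the readout loop implemented in `PoolingSet2SetEncoder.update_q` below
# corresponds to the Set2Set processing block (notation per graph, nodes i with memory m_i):
#
#   q_t     = LSTM(q*_{t-1})
#   e_{i,t} = f(m_i, q_t)            # score from `f_et`, e.g. dot product for pooling_method='sum'
#   a_{i,t} = softmax_i(e_{i,t})     # normalized over the nodes of one graph
#   r_t     = sum_i a_{i,t} * m_i
#   q*_t    = concat([q_t, r_t])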


class PoolingSet2SetEncoder(ks.layers.Layer):
    r"""Pooling of node or edge embeddings by the Set2Set encoder part of the layer first proposed
    by `NMPNN <http://arxiv.org/abs/1704.01212>`__ . The reading-to-memory step has to be handled
    separately. Uses a keras LSTM layer for the updates.
    """
    def __init__(self,
                 # Args
                 channels,
                 T=3,  # noqa
                 pooling_method='mean',
                 init_qstar='mean',
                 # Args for LSTM
                 activation="tanh",
                 recurrent_activation="sigmoid",
                 use_bias=True,
                 kernel_initializer="glorot_uniform",
                 recurrent_initializer="orthogonal",
                 bias_initializer="zeros",
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.0,
                 recurrent_dropout=0.0,
                 implementation=2,
                 return_sequences=False,  # Should not be changed here
                 return_state=False,  # Should not be changed here
                 go_backwards=False,  # Should not be changed here
                 stateful=False,
                 # time_major=False,
                 unroll=False,
                 **kwargs):
        """Initialize layer.

        Args:
            channels (int): Number of channels for the LSTM update.
            T (int): Number of iterations. Default is T=3.
            pooling_method: Pooling method for PoolingSet2SetEncoder. Default is 'mean'.
            init_qstar: How to generate the first q_star vector. Default is 'mean'.
            activation: Activation function to use. Default: hyperbolic tangent (`tanh`).
                If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
            recurrent_activation: Activation function to use for the recurrent step.
                Default: sigmoid (`sigmoid`). If you pass `None`, no activation is applied
                (ie. "linear" activation: `a(x) = x`).
            use_bias: Boolean (default `True`), whether the layer uses a bias vector.
            kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear
                transformation of the inputs. Default: `glorot_uniform`.
            recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for
                the linear transformation of the recurrent state. Default: `orthogonal`.
            bias_initializer: Initializer for the bias vector. Default: `zeros`.
            unit_forget_bias: Boolean (default `True`). If True, add 1 to the bias of the forget
                gate at initialization. Setting it to true will also force `bias_initializer="zeros"`.
                This is recommended in
                [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
            kernel_regularizer: Regularizer function applied to the `kernel` weights matrix.
                Default: `None`.
            recurrent_regularizer: Regularizer function applied to the `recurrent_kernel` weights
                matrix. Default: `None`.
            bias_regularizer: Regularizer function applied to the bias vector. Default: `None`.
            activity_regularizer: Regularizer function applied to the output of the layer
                (its "activation"). Default: `None`.
            kernel_constraint: Constraint function applied to the `kernel` weights matrix.
                Default: `None`.
            recurrent_constraint: Constraint function applied to the `recurrent_kernel` weights
                matrix. Default: `None`.
            bias_constraint: Constraint function applied to the bias vector. Default: `None`.
            dropout: Float between 0 and 1. Fraction of the units to drop for the linear
                transformation of the inputs. Default: 0.
            recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear
                transformation of the recurrent state. Default: 0.
            return_sequences: Boolean. Whether to return the last output in the output sequence,
                or the full sequence. Default: `False`.
            return_state: Boolean. Whether to return the last state in addition to the output.
                Default: `False`.
            go_backwards: Boolean (default `False`). If True, process the input sequence backwards
                and return the reversed sequence.
            stateful: Boolean (default `False`). If True, the last state for each sample at index i
                in a batch will be used as initial state for the sample of index i in the
                following batch.
            unroll: Boolean (default `False`). If True, the network will be unrolled, else a
                symbolic loop will be used. Unrolling can speed up an RNN, although it tends to be
                more memory-intensive. Unrolling is only suitable for short sequences.
        """
        super(PoolingSet2SetEncoder, self).__init__(**kwargs)
        # Number of channels to use in LSTM.
        self.channels = channels
        self.T = T  # Number of iterations to work on memory.
        self.pooling_method = pooling_method
        self.init_qstar = init_qstar
        # Reduction of messages for f_et.
        self._reduce_keys = {
            "sum": ops.sum, "mean": ops.mean, "max": ops.max, "min": ops.min, "var": ops.var,
        }
        if self.pooling_method not in self._reduce_keys:
            raise ValueError("ERROR:kgcnn: Unknown reduction '%s', choose one of '%s'." % (
                self.pooling_method, self._reduce_keys.keys()))
        self._reduce = self._reduce_keys[self.pooling_method]

        self._pool_init = None
        if self.init_qstar in ["0", "zeros", "zero"]:
            self.qstar0 = self.init_qstar_0
        elif self.init_qstar in ["ref", "reference", "input"]:
            self.qstar0 = self.init_qstar_ref
        else:
            self._pool_init = Aggregate(pooling_method=self.init_qstar)
            self.qstar0 = self.init_qstar_pool

        # LSTM layer to run on m.
        self.lay_lstm = ks.layers.LSTM(
            channels,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            unit_forget_bias=unit_forget_bias,
            kernel_regularizer=kernel_regularizer,
            recurrent_regularizer=recurrent_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            implementation=implementation,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            # time_major=time_major,
            unroll=unroll
        )
    def build(self, input_shape):
        """Build layer."""
        assert len(input_shape) == 3
        ref_shape, attr_shape, index_shape = [list(x) for x in input_shape]
        if self._pool_init is not None:
            self._pool_init.build([attr_shape, index_shape, ref_shape])
        self.lay_lstm.build(tuple(ref_shape[:1] + [1] + [2 * self.channels]))
        self.built = True
    def compute_output_shape(self, input_shape):
        assert len(input_shape) == 3
        ref_shape, attr_shape, index_shape = [list(x) for x in input_shape]
        return tuple(ref_shape[:1] + [1] + [2 * self.channels])
    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs: [reference, nodes, batch_index]

                - reference (Tensor): Reference for aggregation of shape `(batch, ...)` .
                - nodes (Tensor): Node embeddings of shape `([N], F)` .
                - batch_index (Tensor): Batch assignment of shape `([N], )` .

        Returns:
            Tensor: Pooled tensor q_star of shape `(batch, 1, 2*channels)` .
        """
        ref, x, batch_index = inputs
        # Reading to memory is removed here; it has to be done separately.
        m = x  # ([N], F)
        # Initialize q0 and r0.
        q_star = self.qstar0(m, batch_index, ref)
        # Start loop of T iterations.
        for i in range(0, self.T):
            q = self.lay_lstm(q_star)  # (batch, channels)
            q_star = self.update_q(q, m, batch_index, ref)
        return q_star
    def update_q(self, q, m, batch_index, ref):
        """Update q_star with one attention-weighted readout step."""
        qt = ops.take(q, batch_index, axis=0)  # ([N], channels)
        et = self.f_et(m, qt)  # ([N], 1)
        # Get at = exp(et)/sum(et), with a per-sample maximum subtracted for numerical stability.
        at = ops.exp(et - self._get_scale_per_sample(et, batch_index, ref))  # ([N], 1)
        norm = self._get_norm(at, batch_index, ref)  # ([N], 1)
        at = norm * at  # ([N], 1) x ([N], 1)
        # Calculate rt.
        # at = ops.expand_dims(at, axis=1)
        rt = m * at  # ([N], F) x ([N], 1)
        rt = self._pool_sum([rt, batch_index, ref])  # (batch, F)
        # qstar = [q, r]
        q_star = ops.concatenate([q, rt], axis=1)  # (batch, 2*channels)
        q_star = ops.expand_dims(q_star, axis=1)  # (batch, 1, 2*channels)
        return q_star
    def f_et(self, fm, fq):
        r"""Function to compute a scalar score from :math:`m` and :math:`q` .
        Uses the :obj:`pooling_method` argument of the layer.

        Args:
            fm (Tensor): of shape `([N], F)` .
            fq (Tensor): of shape `([N], F)` .

        Returns:
            Tensor: et of shape `([N], 1)` .
        """
        return self._reduce(fm * fq, axis=1, keepdims=True)  # ([N], 1)
    @staticmethod
    def _get_scale_per_batch(x):
        """Get re-scaling for the batch."""
        return ops.max(x, axis=0, keepdims=True)
    def _get_scale_per_sample(self, x, ind, ref):
        """Get re-scaling for the sample."""
        out = self._pool_max([x, ind, ref])  # (batch,)
        out = ops.take(out, ind, axis=0)  # (batch*num,)
        return out
    def _get_norm(self, x, ind, ref):
        """Compute norm."""
        norm = self._pool_sum([x, ind, ref])  # (batch,)
        norm = ops.divide(ops.convert_to_tensor(1.0, dtype=norm.dtype), norm)  # (batch,)
        norm = ops.where(ops.logical_or(ops.isnan(norm), ops.isinf(norm)), 0., norm)
        norm = ops.take(norm, ind, axis=0)  # (batch*num,)
        return norm
    def init_qstar_ref(self, m, batch_index, reference):
        """Initialize q0 from the reference tensor."""
        return reference
    def init_qstar_0(self, m, batch_index, reference):
        """Initialize the q0 with zeros."""
        batch_shape = ops.shape(reference)[0]
        return ops.zeros((batch_shape, 1, 2 * self.channels), dtype=m.dtype)
    def init_qstar_pool(self, m, batch_index, reference):
        """Initialize the q0 by pooling the nodes with the :obj:`init_qstar` method."""
        q = self._pool_init([m, batch_index, reference])  # (batch, F)
        qstar = self.update_q(q, m, batch_index, reference)
        return qstar
    def get_config(self):
        """Make config for layer."""
        config = super(PoolingSet2SetEncoder, self).get_config()
        config.update({
            "channels": self.channels, "T": self.T,
            "pooling_method": self.pooling_method, "init_qstar": self.init_qstar})
        lstm_conf = self.lay_lstm.get_config()
        lstm_param = [
            "activation", "recurrent_activation", "use_bias", "kernel_initializer",
            "recurrent_initializer", "bias_initializer", "unit_forget_bias",
            "kernel_regularizer", "recurrent_regularizer", "bias_regularizer",
            "activity_regularizer", "kernel_constraint", "recurrent_constraint",
            "bias_constraint", "dropout", "recurrent_dropout", "implementation",
            "return_sequences",  # Should not be changed here
            "return_state",  # Should not be changed here
            "go_backwards",  # Should not be changed here
            "stateful",
            # "time_major",
            "unroll"]
        for x in lstm_param:
            if x in lstm_conf.keys():
                config.update({x: lstm_conf[x]})
        return config
    def _pool_sum(self, inputs):
        """Scatter-sum values into per-graph segments given by indices."""
        values, indices, ref = inputs
        shape_ = ops.shape(ref)[:1] + ops.shape(values)[1:]
        return scatter_reduce_sum(indices, values, shape=shape_)

    def _pool_max(self, inputs):
        """Scatter-max values into per-graph segments given by indices."""
        values, indices, ref = inputs
        shape_ = ops.shape(ref)[:1] + ops.shape(values)[1:]
        return scatter_reduce_max(indices, values, shape=shape_)
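A minimal usage sketch, not part of the module: the layer is called with `[reference, nodes, batch_index]`
and returns one pooled vector of shape `(batch, 1, 2*channels)` per graph. The tensor values and the shape
of the reference tensor below are illustrative assumptions; with `init_qstar="zeros"` only its first
(batch) dimension is used.

import numpy as np
from kgcnn.layers.set2set import PoolingSet2SetEncoder

nodes = np.random.rand(5, 16).astype("float32")         # ([N], F): 5 nodes with 16 features (arbitrary)
batch_index = np.array([0, 0, 0, 1, 1], dtype="int64")  # ([N], ): node-to-graph assignment for 2 graphs
reference = np.zeros((2, 1), dtype="float32")           # (batch, ...): assumed shape, only batch dim matters here

layer = PoolingSet2SetEncoder(channels=16, T=3, pooling_method="sum", init_qstar="zeros")
q_star = layer([reference, nodes, batch_index])         # (batch, 1, 2*channels) -> (2, 1, 32)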