Source code for kgcnn.layers.attention

# import keras_core as ks
from kgcnn.layers.gather import GatherNodesIngoing, GatherNodesOutgoing
from keras.layers import Dense, Concatenate, Activation, Average, Layer
from kgcnn.layers.aggr import AggregateLocalEdgesAttention
from keras import ops


class AttentionHeadGAT(Layer):  # noqa
    r"""Computes the attention head according to `GAT <https://arxiv.org/abs/1710.10903>`__ .

    The attention coefficients are computed by :math:`a_{ij} = \sigma(a^T [W n_i || W n_j])`,
    optionally by :math:`a_{ij} = \sigma(a^T [W n_i || W n_j || e_{ij}])` with edges :math:`e_{ij}`.
    The attention is obtained by :math:`\alpha_{ij} = \text{softmax}_j (a_{ij})`.
    The messages are pooled by :math:`m_i = \sum_j \alpha_{ij} W n_j`
    and optionally passed through an activation :math:`h_i = \sigma(\sum_j \alpha_{ij} W n_j)`.
    If the graph has no self-loops, they must be added beforehand or external skip connections must be used.

    An edge is defined by the index tuple :math:`(i, j)` with the direction of the connection
    from :math:`j` to :math:`i`.
    """

    def __init__(self,
                 units,
                 use_edge_features=False,
                 use_final_activation=True,
                 has_self_loops=True,
                 activation="kgcnn>leaky_relu",
                 use_bias=True,
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 normalize_softmax: bool = False,
                 **kwargs):
        """Initialize layer.

        Args:
            units (int): Units for the linear transformation of node features before attention.
            use_edge_features (bool): Append edge features to attention computation. Default is False.
            use_final_activation (bool): Whether to apply the final activation to the output. Default is True.
            has_self_loops (bool): Whether the graph has self-loops. Not used by this layer. Default is True.
            activation (str): Activation function. Default is "kgcnn>leaky_relu".
            use_bias (bool): Use bias. Default is True.
            kernel_regularizer: Kernel regularization. Default is None.
            bias_regularizer: Bias regularization. Default is None.
            activity_regularizer: Activity regularization. Default is None.
            kernel_constraint: Kernel constraint. Default is None.
            bias_constraint: Bias constraint. Default is None.
            kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
            bias_initializer: Initializer for bias. Default is 'zeros'.
            normalize_softmax (bool): Whether to use a normalized softmax in the attention aggregation
                (passed to :obj:`AggregateLocalEdgesAttention`). Default is False.
        """
        super(AttentionHeadGAT, self).__init__(**kwargs)
        self.use_edge_features = use_edge_features
        self.use_final_activation = use_final_activation
        self.has_self_loops = has_self_loops
        self.normalize_softmax = normalize_softmax
        self.units = int(units)
        self.use_bias = use_bias
        kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer,
                       "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
                       "bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer,
                       "bias_initializer": bias_initializer}

        self.lay_linear_trafo = Dense(units, activation="linear", use_bias=use_bias, **kernel_args)
        self.lay_alpha = Dense(1, activation=activation, use_bias=False, **kernel_args)
        self.lay_gather_in = GatherNodesIngoing()
        self.lay_gather_out = GatherNodesOutgoing()
        self.lay_concat = Concatenate(axis=-1)
        self.lay_pool_attention = AggregateLocalEdgesAttention(normalize_softmax=normalize_softmax)
        if self.use_final_activation:
            self.lay_final_activ = Activation(activation=activation)

    def build(self, input_shape):
        """Build layer."""
        super(AttentionHeadGAT, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs (list): of [node, edges, edge_indices]

                - nodes (Tensor): Node embeddings of shape ([N], F)
                - edges (Tensor): Edge or message embeddings of shape ([M], F)
                - edge_indices (Tensor): Edge indices referring to nodes of shape (2, [M])

        Returns:
            Tensor: Embedding tensor of pooled edge attentions for each node.
        """
        node, edge, edge_index = inputs
        w_n = self.lay_linear_trafo(node, **kwargs)
        wn_in = self.lay_gather_in([w_n, edge_index], **kwargs)
        wn_out = self.lay_gather_out([w_n, edge_index], **kwargs)
        if self.use_edge_features:
            e_ij = self.lay_concat([wn_in, wn_out, edge], **kwargs)
        else:
            e_ij = self.lay_concat([wn_in, wn_out], **kwargs)
        a_ij = self.lay_alpha(e_ij, **kwargs)  # Attention logits per edge; should have shape ([M], 1).
        h_i = self.lay_pool_attention([node, wn_out, a_ij, edge_index], **kwargs)
        if self.use_final_activation:
            h_i = self.lay_final_activ(h_i, **kwargs)
        return h_i

    def get_config(self):
        """Update layer config."""
        config = super(AttentionHeadGAT, self).get_config()
        config.update({"use_edge_features": self.use_edge_features, "use_bias": self.use_bias,
                       "units": self.units, "has_self_loops": self.has_self_loops,
                       "normalize_softmax": self.normalize_softmax,
                       "use_final_activation": self.use_final_activation})
        conf_sub = self.lay_alpha.get_config()
        for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
                  "bias_constraint", "kernel_initializer", "bias_initializer", "activation"]:
            if x in conf_sub:
                config.update({x: conf_sub[x]})
        return config
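
# A minimal usage sketch (not part of the library): run ``AttentionHeadGAT`` on random
# disjoint-graph tensors. The shapes follow the ``call`` docstring above; the ``(2, [M])``
# index layout and the concrete sizes are illustrative assumptions.
def _example_attention_head_gat():
    import keras
    num_nodes, num_edges = 6, 10
    nodes = keras.random.normal((num_nodes, 16))  # ([N], F) node features
    edges = keras.random.normal((num_edges, 8))   # ([M], F) edge features
    edge_indices = keras.random.randint((2, num_edges), 0, num_nodes)  # (2, [M]) receiver/sender pairs
    layer = AttentionHeadGAT(units=32, use_edge_features=True)
    # Returns pooled attention messages per node, shape ([N], 32).
    return layer([nodes, edges, edge_indices])
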
class AttentionHeadGATV2(Layer):  # noqa
    r"""Computes the modified attention head according to `GATv2 <https://arxiv.org/pdf/2105.14491.pdf>`__ .

    The attention coefficients are computed by :math:`a_{ij} = a^T \sigma( W [n_i || n_j] )`,
    optionally by :math:`a_{ij} = a^T \sigma( W [n_i || n_j || e_{ij}] )` with edges :math:`e_{ij}`.
    The attention is obtained by :math:`\alpha_{ij} = \text{softmax}_j (a_{ij})`.
    The messages are pooled by :math:`m_i = \sum_j \alpha_{ij} W n_j`
    and optionally passed through an activation :math:`h_i = \sigma(\sum_j \alpha_{ij} W n_j)`.
    If the graph has no self-loops, they must be added beforehand or external skip connections must be used.

    An edge is defined by the index tuple :math:`(i, j)` with the direction of the connection
    from :math:`j` to :math:`i`.
    """

    def __init__(self,
                 units,
                 use_edge_features=False,
                 use_final_activation=True,
                 has_self_loops=True,
                 activation="kgcnn>leaky_relu",
                 use_bias=True,
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 normalize_softmax: bool = False,
                 **kwargs):
        """Initialize layer.

        Args:
            units (int): Units for the linear transformation of node features before attention.
            use_edge_features (bool): Append edge features to attention computation. Default is False.
            use_final_activation (bool): Whether to apply the final activation to the output. Default is True.
            has_self_loops (bool): Whether the graph has self-loops. Not used by this layer. Default is True.
            activation (str): Activation function. Default is "kgcnn>leaky_relu".
            use_bias (bool): Use bias. Default is True.
            kernel_regularizer: Kernel regularization. Default is None.
            bias_regularizer: Bias regularization. Default is None.
            activity_regularizer: Activity regularization. Default is None.
            kernel_constraint: Kernel constraint. Default is None.
            bias_constraint: Bias constraint. Default is None.
            kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
            bias_initializer: Initializer for bias. Default is 'zeros'.
            normalize_softmax (bool): Whether to use a normalized softmax in the attention aggregation
                (passed to :obj:`AggregateLocalEdgesAttention`). Default is False.
        """
        super(AttentionHeadGATV2, self).__init__(**kwargs)
        self.use_edge_features = use_edge_features
        self.use_final_activation = use_final_activation
        self.has_self_loops = has_self_loops
        self.units = int(units)
        self.normalize_softmax = normalize_softmax
        self.use_bias = use_bias
        kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer,
                       "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
                       "bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer,
                       "bias_initializer": bias_initializer}

        self.lay_linear_trafo = Dense(units, activation="linear", use_bias=use_bias, **kernel_args)
        self.lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
        self.lay_alpha = Dense(1, activation="linear", use_bias=False, **kernel_args)
        self.lay_gather_in = GatherNodesIngoing()
        self.lay_gather_out = GatherNodesOutgoing()
        self.lay_concat = Concatenate(axis=-1)
        self.lay_pool_attention = AggregateLocalEdgesAttention(normalize_softmax=normalize_softmax)
        if self.use_final_activation:
            self.lay_final_activ = Activation(activation=activation)

    def build(self, input_shape):
        """Build layer."""
        super(AttentionHeadGATV2, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Forward pass.

        Args:
            inputs (list): of [node, edges, edge_indices]

                - nodes (Tensor): Node embeddings of shape ([N], F)
                - edges (Tensor): Edge or message embeddings of shape ([M], F)
                - edge_indices (Tensor): Edge indices referring to nodes of shape (2, [M])

        Returns:
            Tensor: Embedding tensor of pooled edge attentions for each node.
        """
        node, edge, edge_index = inputs
        w_n = self.lay_linear_trafo(node, **kwargs)
        n_in = self.lay_gather_in([node, edge_index], **kwargs)
        n_out = self.lay_gather_out([node, edge_index], **kwargs)
        wn_out = self.lay_gather_out([w_n, edge_index], **kwargs)
        if self.use_edge_features:
            e_ij = self.lay_concat([n_in, n_out, edge], **kwargs)
        else:
            e_ij = self.lay_concat([n_in, n_out], **kwargs)
        a_ij = self.lay_alpha_activation(e_ij, **kwargs)
        a_ij = self.lay_alpha(a_ij, **kwargs)
        h_i = self.lay_pool_attention([node, wn_out, a_ij, edge_index], **kwargs)
        if self.use_final_activation:
            h_i = self.lay_final_activ(h_i, **kwargs)
        return h_i

    def get_config(self):
        """Update layer config."""
        config = super(AttentionHeadGATV2, self).get_config()
        config.update({"use_edge_features": self.use_edge_features, "use_bias": self.use_bias,
                       "units": self.units, "has_self_loops": self.has_self_loops,
                       "normalize_softmax": self.normalize_softmax,
                       "use_final_activation": self.use_final_activation})
        conf_sub = self.lay_alpha_activation.get_config()
        for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
                  "bias_constraint", "kernel_initializer", "bias_initializer", "activation"]:
            if x in conf_sub:
                config.update({x: conf_sub[x]})
        return config
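
# A minimal usage sketch (not part of the library): ``AttentionHeadGATV2`` is called like
# ``AttentionHeadGAT``; the difference lies only in how the attention logits are scored
# (non-linearity applied before the final projection to one channel). The shapes and sizes
# below are illustrative assumptions following the ``call`` docstring above.
def _example_attention_head_gatv2():
    import keras
    num_nodes, num_edges = 6, 10
    nodes = keras.random.normal((num_nodes, 16))
    edges = keras.random.normal((num_edges, 8))
    edge_indices = keras.random.randint((2, num_edges), 0, num_nodes)
    layer = AttentionHeadGATV2(units=32, use_edge_features=False, use_final_activation=True)
    return layer([nodes, edges, edge_indices])  # ([N], 32) node embeddings
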
class MultiHeadGATV2Layer(AttentionHeadGATV2):  # noqa
    """Multi-head version of :obj:`AttentionHeadGATV2` that additionally returns the attention
    logits of every head, which are needed e.g. in MEGAN to compute edge attention values."""

    def __init__(self,
                 units: int,
                 num_heads: int,
                 activation: str = 'kgcnn>leaky_relu',
                 use_bias: bool = True,
                 concat_heads: bool = True,
                 **kwargs):
        super(MultiHeadGATV2Layer, self).__init__(
            units=units,
            activation=activation,
            use_bias=use_bias,
            **kwargs
        )
        self.num_heads = num_heads
        self.concat_heads = concat_heads

        self.head_layers = []
        for _ in range(num_heads):
            lay_linear = Dense(units, activation=activation, use_bias=use_bias)
            lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias)
            lay_alpha = Dense(1, activation='linear', use_bias=False)
            self.head_layers.append((lay_linear, lay_alpha_activation, lay_alpha))

        self.lay_concat_alphas = Concatenate(axis=-2)
        self.lay_concat_embeddings = Concatenate(axis=-2)
        self.lay_pool_attention = AggregateLocalEdgesAttention()
        # self.lay_pool = AggregateLocalEdges()
        if self.concat_heads:
            self.lay_combine_heads = Concatenate(axis=-1)
        else:
            self.lay_combine_heads = Average()

    def __call__(self, inputs, **kwargs):
        node, edge, edge_index = inputs

        # "a_ij" is a single-channel edge attention logits tensor. "a_ijs" is consequently the list
        # which stores these tensors for each attention head.
        # "h_i" is a single-head node embedding tensor. "h_is" is consequently the list which stores
        # these tensors for each attention head.
        a_ijs = []
        h_is = []
        for k, (lay_linear, lay_alpha_activation, lay_alpha) in enumerate(self.head_layers):
            # Same computation as in the parent class, but with the per-head layers.
            w_n = lay_linear(node, **kwargs)
            n_in = self.lay_gather_in([node, edge_index], **kwargs)
            n_out = self.lay_gather_out([node, edge_index], **kwargs)
            wn_out = self.lay_gather_out([w_n, edge_index], **kwargs)
            if self.use_edge_features:
                e_ij = self.lay_concat([n_in, n_out, edge], **kwargs)
            else:
                e_ij = self.lay_concat([n_in, n_out], **kwargs)

            # a_ij: ([batch], [M], 1)
            a_ij = lay_alpha_activation(e_ij, **kwargs)
            a_ij = lay_alpha(a_ij, **kwargs)

            # h_i: ([batch], [N], F)
            h_i = self.lay_pool_attention([node, wn_out, a_ij, edge_index], **kwargs)

            if self.use_final_activation:
                h_i = self.lay_final_activ(h_i, **kwargs)

            # a_ij after expand: ([batch], [M], 1, 1)
            a_ij = ops.expand_dims(a_ij, axis=-2)
            a_ijs.append(a_ij)

            # h_i = tf.expand_dims(h_i, axis=-2)
            h_is.append(h_i)

        a_ijs = self.lay_concat_alphas(a_ijs)
        h_is = self.lay_combine_heads(h_is)

        # An important modification here is that this layer also returns the attention coefficients,
        # because MEGAN needs those to calculate the edge attention values.
        # h_is: ([batch], [N], K * F) or ([batch], [N], F)
        # a_ijs: ([batch], [M], K, 1)
        return h_is, a_ijs

    def get_config(self):
        """Update layer config."""
        config = super(MultiHeadGATV2Layer, self).get_config()
        config.update({
            'num_heads': self.num_heads,
            'concat_heads': self.concat_heads
        })
        return config
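
# A minimal usage sketch (not part of the library): unlike the single-head layers above,
# ``MultiHeadGATV2Layer`` returns a tuple of node embeddings and per-head attention logits.
# The shapes and sizes below are illustrative assumptions.
def _example_multi_head_gatv2():
    import keras
    num_nodes, num_edges, num_heads = 6, 10, 4
    nodes = keras.random.normal((num_nodes, 16))
    edges = keras.random.normal((num_edges, 8))
    edge_indices = keras.random.randint((2, num_edges), 0, num_nodes)
    layer = MultiHeadGATV2Layer(units=32, num_heads=num_heads, concat_heads=True)
    node_embeddings, attention_logits = layer([nodes, edges, edge_indices])
    # node_embeddings: ([N], num_heads * units) with concat_heads=True, else ([N], units).
    # attention_logits: ([M], num_heads, 1), one logit per edge and head.
    return node_embeddings, attention_logits
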
class AttentiveHeadFP(Layer):
    r"""Computes the attention head for the `Attentive FP <https://doi.org/10.1021/acs.jmedchem.9b00959>`__ model.

    The attention coefficients are computed by :math:`a_{ij} = a^T \sigma_1( W_1 [h_i || h_j] )`.
    The initial representations :math:`h_i` and :math:`h_j` must be calculated beforehand.
    The attention is obtained by :math:`\alpha_{ij} = \text{softmax}_j (a_{ij})`.
    Finally, the context is pooled as :math:`C_i = \sigma_2(\sum_j \alpha_{ij} W_2 h_j)`.

    An edge is defined by the index tuple :math:`(i, j)` with the direction of the connection
    from :math:`j` to :math:`i`.
    """

    def __init__(self,
                 units,
                 use_edge_features=False,
                 activation='kgcnn>leaky_relu',
                 activation_context="elu",
                 use_bias=True,
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 **kwargs):
        """Initialize layer.

        Args:
            units (int): Units for the linear transformation of node features before attention.
            use_edge_features (bool): Append edge features to attention computation. Default is False.
            activation (str): Activation function. Default is 'kgcnn>leaky_relu'.
            activation_context (str): Activation function for the context. Default is "elu".
            use_bias (bool): Use bias. Default is True.
            kernel_regularizer: Kernel regularization. Default is None.
            bias_regularizer: Bias regularization. Default is None.
            activity_regularizer: Activity regularization. Default is None.
            kernel_constraint: Kernel constraint. Default is None.
            bias_constraint: Bias constraint. Default is None.
            kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
            bias_initializer: Initializer for bias. Default is 'zeros'.
        """
        super(AttentiveHeadFP, self).__init__(**kwargs)
        self.use_edge_features = use_edge_features
        self.units = int(units)
        self.use_bias = use_bias
        kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer,
                       "bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
                       "bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer,
                       "bias_initializer": bias_initializer}

        self.lay_linear_trafo = Dense(units, activation="linear", use_bias=use_bias, **kernel_args)
        self.lay_alpha_activation = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
        self.lay_alpha = Dense(1, activation="linear", use_bias=False, **kernel_args)
        self.lay_gather_in = GatherNodesIngoing()
        self.lay_gather_out = GatherNodesOutgoing()
        self.lay_concat = Concatenate(axis=-1)
        self.lay_pool_attention = AggregateLocalEdgesAttention()
        self.lay_final_activ = Activation(activation=activation_context)
        if use_edge_features:
            self.lay_fc1 = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
            self.lay_fc2 = Dense(units, activation=activation, use_bias=use_bias, **kernel_args)
            self.lay_concat_edge = Concatenate(axis=-1)

    def build(self, input_shape):
        """Build layer."""
        super(AttentiveHeadFP, self).build(input_shape)

    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs (list): of [node, edges, edge_indices]

                - nodes (Tensor): Node embeddings of shape ([N], F)
                - edges (Tensor): Edge or message embeddings of shape ([M], F)
                - edge_indices (Tensor): Edge indices referring to nodes of shape (2, [M])

        Returns:
            Tensor: Hidden tensor of pooled edge attentions for each node.
        """
        node, edge, edge_index = inputs

        if self.use_edge_features:
            n_in = self.lay_gather_in([node, edge_index], **kwargs)
            n_out = self.lay_gather_out([node, edge_index], **kwargs)
            n_in = self.lay_fc1(n_in, **kwargs)
            n_out = self.lay_concat_edge([n_out, edge], **kwargs)
            n_out = self.lay_fc2(n_out, **kwargs)
        else:
            n_in = self.lay_gather_in([node, edge_index], **kwargs)
            n_out = self.lay_gather_out([node, edge_index], **kwargs)

        wn_out = self.lay_linear_trafo(n_out, **kwargs)
        e_ij = self.lay_concat([n_in, n_out], **kwargs)
        e_ij = self.lay_alpha_activation(e_ij, **kwargs)  # Maybe uses the original GAT definition.
        # a_ij = e_ij
        a_ij = self.lay_alpha(e_ij, **kwargs)  # Should have shape ([M], 1); not fully clear in the original paper.
        n_i = self.lay_pool_attention([node, wn_out, a_ij, edge_index], **kwargs)
        out = self.lay_final_activ(n_i, **kwargs)
        return out

    def get_config(self):
        """Update layer config."""
        config = super(AttentiveHeadFP, self).get_config()
        config.update({"use_edge_features": self.use_edge_features, "use_bias": self.use_bias,
                       "units": self.units})
        conf_sub = self.lay_alpha_activation.get_config()
        for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
                  "bias_constraint", "kernel_initializer", "bias_initializer", "activation"]:
            if x in conf_sub.keys():
                config.update({x: conf_sub[x]})
        conf_context = self.lay_final_activ.get_config()
        config.update({"activation_context": conf_context["activation"]})
        return config
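
# A minimal usage sketch (not part of the library): one Attentive FP attention step that mixes
# edge features into the gathered neighbour states before pooling the context per node.
# The shapes and sizes below are illustrative assumptions following the ``call`` docstring above.
def _example_attentive_head_fp():
    import keras
    num_nodes, num_edges = 6, 10
    nodes = keras.random.normal((num_nodes, 16))
    edges = keras.random.normal((num_edges, 8))
    edge_indices = keras.random.randint((2, num_edges), 0, num_nodes)
    layer = AttentiveHeadFP(units=32, use_edge_features=True)
    return layer([nodes, edges, edge_indices])  # pooled context per node of shape ([N], 32)
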