Source code for kgcnn.layers.message

from typing import Optional

import keras as ks
from keras import ops
from kgcnn.layers.gather import GatherNodes
from kgcnn.layers.aggr import AggregateLocalEdges

@ks.saving.register_keras_serializable(package='kgcnn', name='MessagePassingBase')
class MessagePassingBase(ks.layers.Layer):
    r"""Base layer for message-passing type networks.

    This is a general frame to implement custom message and update functions.
    The idea is to create a subclass of :obj:`MessagePassingBase` and then just implement the methods
    :obj:`message_function` and :obj:`update_nodes`. The pooling or aggregating is handled by the built-in
    :obj:`AggregateLocalEdges`. Alternatively, :obj:`aggregate_message` can be overridden as well.
    The original message passing scheme was proposed by `NMPNN <https://arxiv.org/abs/1704.01212>`__ .
    """

    def __init__(self, pooling_method: str = "scatter_sum", use_id_tensors: Optional[int] = None, **kwargs):
        r"""Initialize :obj:`MessagePassingBase` layer.

        Args:
            pooling_method (str): Aggregation method of edges. Default is "scatter_sum".
            use_id_tensors (int): Whether :obj:`call` receives graph ID information, which is passed on to the
                message and update functions. Specifies the number of additional tensors at the end of the
                inputs. `None` (default) or a non-positive number means no ID tensors are expected.
        """
        super().__init__(**kwargs)
        self.pooling_method = pooling_method
        self.use_id_tensors = use_id_tensors
        # `concat_axis=None` is required because `call` unpacks the gathered result into two tensors.
        self.lay_gather = GatherNodes(concat_axis=None)
        self.lay_pool_default = AggregateLocalEdges(pooling_method=self.pooling_method)

    def build(self, input_shape):
        """Build layer."""
        super().build(input_shape)

    def message_function(self, inputs, **kwargs):
        r"""Defines the message function, i.e. a method that generates a message from node and edge embeddings at a
        certain depth (not considered here).

        If `use_id_tensors` is set, the ID tensors are appended to the end of `inputs` .

        Args:
            inputs: [nodes_in, nodes_out, edges]

                - nodes_in (Tensor): Receiving node embeddings of shape ([M], F)
                - nodes_out (Tensor): Sending node embeddings of shape ([M], F)
                - edges (Tensor, optional): Edge or message embeddings of shape ([M], F)

        Returns:
            Tensor: Messages for each edge of shape ([M], F)
        """
        raise NotImplementedError(
            "A method to generate messages must be implemented in sub-class of `MessagePassingBase`.")

    def aggregate_message(self, inputs, **kwargs):
        r"""Pre-defined message aggregation that uses :obj:`AggregateLocalEdges`.

        Args:
            inputs: [nodes, edges, edge_index]

                - nodes (Tensor): Node embeddings of shape ([N], F)
                - edges (Tensor): Edge or message embeddings of shape ([M], F)
                - edge_index (Tensor): Edge indices referring to nodes of shape (2, [M])

        Returns:
            Tensor: Aggregated edge embeddings per node of shape ([N], F)
        """
        return self.lay_pool_default(inputs, **kwargs)

    def update_nodes(self, inputs, **kwargs):
        r"""Defines the update function, i.e. a method that updates the node embeddings from aggregated messages.

        If `use_id_tensors` is set, the ID tensors are appended to the end of `inputs` .

        Args:
            inputs: [nodes, node_updates]

                - nodes (Tensor): Node embeddings (from previous step) of shape ([N], F)
                - node_updates (Tensor): Updates for nodes of shape ([N], F)

        Returns:
            Tensor: Updated node embeddings (for next step) of shape ([N], F)
        """
        raise NotImplementedError(
            "A method to update nodes must be implemented in sub-class of `MessagePassingBase`.")

    def call(self, inputs, **kwargs):
        r"""Pre-implemented standard message passing scheme using :obj:`update_nodes`, :obj:`aggregate_message` and
        :obj:`message_function`.

        Args:
            inputs: [nodes, edges, edge_index], optionally followed by `use_id_tensors` ID tensors.

                - nodes (Tensor): Node embeddings of shape ([N], F)
                - edges (Tensor, optional): Edge or message embeddings of shape ([M], F)
                - edge_index (Tensor): Edge indices referring to nodes of shape (2, [M])

        Returns:
            Tensor: Updated node embeddings of shape ([N], F)
        """
        num_ids = int(self.use_id_tensors) if self.use_id_tensors is not None else 0
        if num_ids > 0:
            # Guarding against zero matters: `inputs[-0:]` would select *all* inputs as IDs.
            # `list(...)` keeps the concatenations below valid when `inputs` is a tuple.
            ids = list(inputs[-num_ids:])
            inputs = inputs[:-num_ids]
        else:
            ids = []

        # Edges are optional; with only two remaining inputs they are [nodes, edge_index].
        if len(inputs) == 2:
            nodes, edge_index = inputs[:2]
            edges = None
        else:
            nodes, edges, edge_index = inputs[:3]

        n_in, n_out = self.lay_gather([nodes, edge_index], **kwargs)

        message_inputs = [n_in, n_out] if edges is None else [n_in, n_out, edges]
        msg = self.message_function(message_inputs + ids, **kwargs)

        pool_n = self.aggregate_message([nodes, msg, edge_index], **kwargs)
        n_new = self.update_nodes([nodes, pool_n] + ids, **kwargs)
        return n_new

    def get_config(self):
        """Update config."""
        config = super().get_config()
        config.update({"pooling_method": self.pooling_method, "use_id_tensors": self.use_id_tensors})
        return config
@ks.saving.register_keras_serializable(package='kgcnn', name='MatMulMessages')
class MatMulMessages(ks.layers.Layer):
    r"""Linear transformation of edges or messages, i.e. matrix multiplication.

    The message dimension must be suitable for matrix multiplication. The actual matrix is not a trainable weight of
    this layer but passed as input.
    This was proposed by `NMPNN <https://arxiv.org/abs/1704.01212>`__ .
    For each node or edge :math:`i` the output is given by:

    .. math::

        x_i' = \mathbf{A_i} \; x_i
    """

    def __init__(self, **kwargs):
        """Initialize layer."""
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Build layer."""
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs (list): [mat, edges]

                - mat (Tensor): Transformation matrix for each message of shape ([M], F', F).
                - edges (Tensor): Edge embeddings or messages ([M], F)

        Returns:
            Tensor: Transformation of messages by matrix multiplication of shape ([M], F')
        """
        mat, edges = inputs
        # Insert an axis at position 1 so the message broadcasts against every row of its matrix;
        # summing the element-wise product over axis 2 then yields the matrix-vector product.
        edges_expanded = ops.expand_dims(edges, axis=1)
        return ops.sum(mat * edges_expanded, axis=2)

    def get_config(self):
        """Update layer config."""
        config = super().get_config()
        return config