Models

Some example data to show model inputs.

[1]:
%%capture
import keras as ks

Functional API

Like most models in kgcnn.literature, models can be set up with the Keras functional API. Here is an example of a simple message passing GNN. The layers are taken from kgcnn.layers; see the documentation of the layers for further details.

[2]:
from kgcnn.layers.casting import CastRaggedIndicesToDisjoint
from kgcnn.layers.gather import GatherNodes
from kgcnn.layers.pooling import PoolingNodes
from kgcnn.layers.aggr import AggregateLocalEdges
from kgcnn.layers.modules import Input

# Ragged inputs: node features and edge index pairs, one ragged row per graph.
node_input = Input(shape=(None, 1), dtype="float32", ragged=True)
edge_index_input = Input(shape=(None, 2), dtype="int64", ragged=True)

# The model is built with ragged input, which is cast to the disjoint
# representation first. Only four of the eight cast outputs are used below.
(nodes, edge_indices, node_batch_id,
 _, _, _,
 node_count, _) = CastRaggedIndicesToDisjoint()([node_input, edge_index_input])

# Message step: gather node features along the edges and transform them.
gathered_nodes = GatherNodes()([nodes, edge_indices])
edge_messages = ks.layers.Dense(64, activation='relu')(gathered_nodes)

# Update step: aggregate messages per node and join them with the node features.
aggregated_messages = AggregateLocalEdges()([nodes, edge_messages, edge_indices])
updated_nodes = ks.layers.Concatenate()([nodes, aggregated_messages])
node_embedding = ks.layers.Dense(1)(updated_nodes)

# Readout: pool the node embeddings into one output per graph.
graph_embedding = PoolingNodes()([node_count, node_embedding, node_batch_id])

message_passing = ks.models.Model(inputs=[node_input, edge_index_input], outputs=graph_embedding)

Subclassing Model

A model can also be constructed by subclassing keras.models.Model, in which case the call method must be implemented.

[3]:
from kgcnn.layers.casting import CastBatchedIndicesToDisjoint

class MessagePassingModel(ks.models.Model):
    """Simple message passing GNN defined by subclassing a Keras model.

    The model is called with padded input
    ``[nodes, indices, total_nodes, total_edges]``, which is cast to the
    disjoint representation inside ``call``. It returns the pooled graph
    embedding.
    """

    def __init__(self):
        super().__init__()
        # Sub-layers are created here, in this order, and applied in ``call``.
        self._layer_casting = CastBatchedIndicesToDisjoint(uses_mask=False)
        self._layer_gather_nodes = GatherNodes()
        self._layer_dense = ks.layers.Dense(64, activation='relu')
        self._layer_aggregate_edges = AggregateLocalEdges()
        self._layer_cat = ks.layers.Concatenate(axis=-1)
        self._layer_dense_last = ks.layers.Dense(1)
        self._layer_pool_nodes = PoolingNodes()

    def build(self, input_shape):
        # Defers to the parent implementation; no extra weights are created here.
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        # Padded input: nodes, indices, total_nodes, total_edges.
        (nodes, edge_indices, node_batch_id,
         _, _, _,
         node_count, _) = self._layer_casting(inputs)
        # Message step: gather node features along the edges, then transform.
        edge_messages = self._layer_dense(self._layer_gather_nodes([nodes, edge_indices]))
        # Update step: aggregate per node and join with the node features.
        aggregated = self._layer_aggregate_edges([nodes, edge_messages, edge_indices])
        node_embedding = self._layer_dense_last(self._layer_cat([nodes, aggregated]))
        # Readout: one pooled value per graph.
        return self._layer_pool_nodes([node_count, node_embedding, node_batch_id])


message_passing_2 = MessagePassingModel()

Templates

Layers can also be subclassed to create a GNN, for example the message passing base layer, for which only message_function and update_nodes must be implemented.

[4]:
from kgcnn.layers.message import MessagePassingBase

class MyMessageNN(MessagePassingBase):
    """Message passing layer with a dense message function.

    Messages are computed by concatenating the two gathered node-feature
    tensors passed to ``message_function`` and applying a ``Dense`` layer.
    Node updates concatenate the node features with the aggregated messages.

    Args:
        units: Number of units of the dense message layer.
        **kwargs: Forwarded to ``MessagePassingBase``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        # Stored so that ``get_config`` can serialize the constructor argument.
        self.units = units
        self.dense = ks.layers.Dense(units, activation='relu')
        # The same weight-free Concatenate layer is reused in
        # ``message_function`` and ``update_nodes``.
        self.cat = ks.layers.Concatenate(axis=-1)

    def message_function(self, inputs, **kwargs):
        """Compute edge messages from the pair ``(n_in, n_out)``.

        NOTE(review): ``n_in``/``n_out`` are presumably the gathered features of
        the two nodes of each edge, as provided by ``MessagePassingBase``.
        """
        n_in, n_out = inputs
        n_in_out = self.cat([n_in, n_out])
        return self.dense(n_in_out, **kwargs)

    def update_nodes(self, inputs, **kwargs):
        """Concatenate node features with their aggregated messages."""
        nodes, nodes_update = inputs
        return self.cat([nodes, nodes_update], **kwargs)

    def get_config(self):
        """Return the layer config including ``units``, so ``from_config`` works."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

# Here the disjoint representation is passed to the model directly.
# Node features, shape (num_nodes, 1): the Keras batch axis runs over nodes.
node_attributes = ks.layers.Input(shape=(1, ), dtype="float32")
# Edge indices: the data generator yields them transposed to shape (2, num_edges).
edge_index_pairs = ks.layers.Input(shape=(None, ), dtype="int64")
# Graph index of every node, used for pooling.
node_batch_id = ks.layers.Input(shape=(), dtype="int64")
# Number of nodes of every graph in the batch.
node_count = ks.layers.Input(shape=(), dtype="int64")
disjoint_inputs = [node_attributes, edge_index_pairs, node_batch_id, node_count]

# Layers are created in the same order as they are applied.
message_layer = MyMessageNN(units=64)
updated_nodes = message_layer([node_attributes, edge_index_pairs])
node_embedding = ks.layers.Dense(1)(updated_nodes)
graph_embedding = PoolingNodes()([node_count, node_embedding, node_batch_id])

message_passing_3 = ks.models.Model(inputs=disjoint_inputs, outputs=graph_embedding)

Loading options

There are many options to load data into a Keras model, depending on the size and location of the data to pass to the model. Speed and utility may differ depending on the loading method. For more examples, see https://github.com/aimat-lab/gcnn_keras/blob/master/notebooks/tutorial_model_loading_options.ipynb.

In general, padded tensors are the most convenient and natural input for Keras but come with a significant performance drop. Ragged tensors are currently restricted to TensorFlow but will likely be extended to PyTorch as well. Direct disjoint input is the most efficient but requires a data loader.

1. Padded Tensor

The simplest way to pass tensors to the model is to pad them to the same size. The model then requires additional information about the padding: either a length tensor or a mask.

[5]:
from keras import ops
# Four graphs padded to 2 nodes and 3 edges each. Only the first graph fills
# every slot; the remaining slots of the other graphs are zero-padded.
example_nodes = ops.convert_to_tensor(
    [[[1.], [2.]]] + [[[value], [0.0]] for value in (1.0, 2.0, 4.0)]
)
example_indices = ops.convert_to_tensor(
    [[[0, 1], [1, 0], [1, 1]]] + [[[0, 0], [0, 0], [0, 0]] for _ in range(3)],
    dtype="int64",
)
# Number of valid (non-padded) nodes and edges per graph.
example_total_nodes = ops.convert_to_tensor([2, 1, 1, 1], dtype="int64")
example_total_edges = ops.convert_to_tensor([3, 1, 1, 1], dtype="int64")
example_graph_labels = ops.convert_to_tensor([[1.0], [0.1], [0.3], [0.6]], dtype="float32")
[6]:
# Padded input order: nodes, indices, total_nodes, total_edges.
padded_inputs = [example_nodes, example_indices, example_total_nodes, example_total_edges]
message_passing_2.predict(padded_inputs)
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 818ms/step
[6]:
array([[1.340108  ],
       [0.26912418],
       [0.53824836],
       [1.0764967 ]], dtype=float32)
[7]:
loss_name = "mean_absolute_error"
message_passing_2.compile(loss=loss_name)
# Same padded input order as for prediction.
padded_train_inputs = [example_nodes, example_indices, example_total_nodes, example_total_edges]
message_passing_2.fit(x=padded_train_inputs, y=example_graph_labels, batch_size=2, epochs=7)
Epoch 1/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.2488
Epoch 2/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.1036
Epoch 3/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0680
Epoch 4/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.1017
Epoch 5/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0918
Epoch 6/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0945
Epoch 7/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0903
[7]:
<keras.src.callbacks.history.History at 0x2601c652dd0>

2. Ragged input

Ragged (or jagged) tensor input is more data-efficient.

[8]:
import tensorflow as tf
from keras.backend import backend
# Ragged example tensors are only built for the TensorFlow backend for now.
backend_name = backend()
if backend_name == "tensorflow":
    # ragged_rank=1: only the node/edge axis is ragged; the feature axis is uniform.
    example_nodes = tf.ragged.constant([[[1.], [2.]], [[1.0]], [[2.0]], [[4.0]]], ragged_rank=1)
    example_indices = tf.ragged.constant([[[0, 1], [1, 0], [1, 1]], [[0, 0]], [[0, 0]], [[0, 0]]], dtype="int64", ragged_rank=1)
    print(example_nodes.shape, example_indices.shape)
else:
    # A jagged tensor type (e.g. torchrec's JaggedTensor for torch) could be
    # used here in the future; no other backend is supported yet.
    raise NotImplementedError(
        f"The ragged input example is not implemented for backend '{backend_name}'."
    )
(4, None, 1) (4, None, 2)
[9]:
# Ragged input order: nodes, edge indices.
ragged_inputs = [example_nodes, example_indices]
message_passing.predict(ragged_inputs)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 352ms/step
[9]:
array([[-2.2719753],
       [-0.5563201],
       [-1.1126401],
       [-2.2252803]], dtype=float32)
[10]:
loss_name = "mean_absolute_error"
message_passing.compile(loss=loss_name)
ragged_train_inputs = [example_nodes, example_indices]
message_passing.fit(x=ragged_train_inputs, y=example_graph_labels, batch_size=2, epochs=7)
Epoch 1/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 1.6347
Epoch 2/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 1.6999
Epoch 3/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 1.6529
Epoch 4/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 1.4860
Epoch 5/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 1.1121
Epoch 6/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 1.2790
Epoch 7/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 1.1712
[10]:
<keras.src.callbacks.history.History at 0x2601f90ece0>

3. Direct disjoint input via a data loader

We need to construct a data pipeline. Fully working data pipelines will be provided in kgcnn.io. They can be based on either tf.data or torch.utils.data.DataLoader.

[11]:
# The same four graphs as plain (unpadded) Python lists.
example_nodes = [[[1.], [2.]], [[1.0]], [[2.0]], [[4.0]]]
example_indices = [[[0, 1], [1, 0], [1, 1]], [[0, 0]], [[0, 0]], [[0, 0]]]
example_graph_labels = [[1.0], [0.1], [0.3], [0.6]]
print(*(len(items) for items in (example_nodes, example_indices, example_graph_labels)))
4 4 4
[12]:
batch_size = 2
# Number of graphs to iterate over. It is derived from the data instead of
# being hard-coded, so every graph is used for any ``batch_size`` and the
# x-batches stay aligned with the label batches below.
data_length = len(example_nodes)


# Minimal example to generate disjoint input from basic operations.
def gen():
    """Yield disjoint batches ``(nodes, indices, batch_id, total_nodes)``.

    For every ``batch_size`` graphs:
    - node features are concatenated along the first axis;
    - edge indices are shifted by the node offset of their graph, so that
      they index into the concatenated node tensor;
    - edge indices are transposed to shape ``(2, num_edges)``.
    """
    for i in range(0, data_length, batch_size):
        batch_nodes = example_nodes[i:i + batch_size]
        batch_indices = example_indices[i:i + batch_size]
        gen_nodes = tf.concat(batch_nodes, axis=0)
        gen_total_nodes = tf.constant([len(x) for x in batch_nodes], dtype="int64")
        gen_total_edges = tf.constant([len(x) for x in batch_indices], dtype="int64")
        # Graph index of every node, e.g. [0, 0, 1] for graphs with 2 and 1 nodes.
        gen_batch_id = tf.repeat(tf.range(len(gen_total_nodes), dtype="int64"), gen_total_nodes)
        gen_indices = tf.cast(tf.concat(batch_indices, axis=0), dtype="int64")
        # Start offset of each graph's nodes in the concatenated node tensor.
        gen_node_splits = tf.pad(tf.cumsum(gen_total_nodes), [[1, 0]])
        gen_indices_offset = tf.expand_dims(tf.repeat(gen_node_splits[:-1], gen_total_edges), axis=-1)
        # Shift to batch-global node indices, then transpose to (2, num_edges).
        gen_indices = tf.transpose(gen_indices + gen_indices_offset)
        yield (gen_nodes, gen_indices, gen_batch_id, gen_total_nodes)


ds_x_batch = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None, 1), dtype="float32"),
        tf.TensorSpec(shape=(2, None), dtype="int64"),
        tf.TensorSpec(shape=(None, ), dtype="int64"),
        tf.TensorSpec(shape=(None, ), dtype="int64"),
    )
)
ds_y_batch = tf.data.Dataset.from_tensor_slices(tf.constant(example_graph_labels)).batch(batch_size)
ds_batch = tf.data.Dataset.zip((ds_x_batch, ds_y_batch))
[13]:
message_passing_3.compile(loss="mean_absolute_error")
# The dataset already yields complete disjoint batches, so batch_size and
# steps_per_epoch are left as None.
fit_options = dict(epochs=7, batch_size=None, steps_per_epoch=None, verbose=1)
message_passing_3.fit(ds_batch, **fit_options)
Epoch 1/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 2s 905ms/step - loss: 0.3999
Epoch 2/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step - loss: 0.2608
Epoch 3/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.2581
Epoch 4/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step - loss: 0.2573
Epoch 5/7
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 30ms/step - loss: 0.4163
C:\Users\patri\anaconda3\envs\gcnn_keras_test\lib\contextlib.py:153: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
  self.gen.throw(typ, value, traceback)
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step - loss: 0.2573
Epoch 6/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 26ms/step - loss: 0.2686
Epoch 7/7
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step - loss: 0.2421
[13]:
<keras.src.callbacks.history.History at 0x2601fb2fe50>

NOTE: You can find this page as a Jupyter notebook at https://github.com/aimat-lab/gcnn_keras/tree/master/docs/source