Layers¶
The most general layers in kgcnn take normal and ragged tensors as input. The graph-oriented operations are implemented as Keras layers; those that are kept maintained across different models, with proper documentation, are located in kgcnn.layers. These are:

activ: Activation layers with learnable parameters.
aggr: Aggregation layers, e.g. for aggregating edge messages.
attention: Layers for graph attention.
casting: Layers for casting tensor formats.
conv: Basic convolution layers.
gather: Layers around tf.gather.
geom: Geometry operations.
message: Message passing base layer.
mlp: Multi-layer perceptron for graphs.
modules: Keras layers and modules to support ragged tensor input.
norm: Normalization layers for graph tensors.
polynom: Layers for polynomials.
pooling: General layers for standard aggregation and pooling.
relational: Relational message processing.
scale: Scaling layer to (constantly) rescale e.g. graph output.
set2set: Set2Set type architectures, e.g. for pooling nodes.
update: Some node/edge update layers.
NOTE: Please check https://kgcnn.readthedocs.io/en/latest/kgcnn.layers.html for documentation of each layer.
Implementation details¶
The following steps, which are most representative of GNNs, have corresponding layers in kgcnn.layers.
Casting¶
Cast batched node and edge indices to a (single) disjoint graph representation as in PyTorch Geometric (PyG). In PyG, a batch of graphs is represented by a single graph that contains disjoint sub-graphs, and the batch assignment is passed as batch ID tensors: graph_id_node
and graph_id_edge
. For Keras, padded tensors can be used as input to Keras models:
[1]:
from keras import ops
nodes = ops.convert_to_tensor([[[0.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]]])
edges = ops.convert_to_tensor([[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], [-1.0, 1.0, 1.0]]])
edge_indices = ops.convert_to_tensor([[[0, 0], [0, 1], [1, 0], [1, 1]], [[0, 0], [0, 1], [1, 0], [1, 1]]], dtype="int64")
node_mask = ops.convert_to_tensor([[True, False], [True, True]])
edge_mask = ops.convert_to_tensor([[True, False, False, False], [True, True, True, False]])
[2]:
from kgcnn.layers.casting import CastBatchedIndicesToDisjoint
disjoint_tensors = CastBatchedIndicesToDisjoint(uses_mask=True)([nodes, edge_indices, node_mask, edge_mask])
node_attr, disjoint_index, graph_id_node, graph_id_edge, node_id, edge_id, node_count, edge_count = disjoint_tensors
print("Disjoint index:\n", disjoint_index)
print("Node attributes:\n", node_attr)
print("Batch ID nodes:\n", graph_id_node)
Disjoint index:
tensor([[0, 1, 1, 2],
[0, 1, 2, 1]], device='cuda:0')
Node attributes:
tensor([[0., 1.],
[1., 0.],
[1., 1.]], device='cuda:0')
Batch ID nodes:
tensor([0, 1, 1], device='cuda:0')
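The index shifting the casting layer performs can be sketched in plain NumPy (a simplified illustration of the idea, not the layer's actual implementation): node features of all graphs are stacked, and each graph's edge indices are offset by the number of nodes that come before it.

```python
import numpy as np

# Two toy graphs with 1 and 2 nodes (node counts match the masked example above).
nodes_g1 = np.array([[0.0, 1.0]])
nodes_g2 = np.array([[1.0, 0.0], [1.0, 1.0]])
edges_g1 = np.array([[0, 0]])            # self-loop on the single node
edges_g2 = np.array([[0, 1], [1, 0]])

# Disjoint representation: stack nodes and shift the second graph's indices
# by the number of nodes that come before it.
node_attr = np.concatenate([nodes_g1, nodes_g2], axis=0)
offset = len(nodes_g1)
disjoint_index = np.concatenate([edges_g1, edges_g2 + offset], axis=0).T  # shape (2, num_edges)

# Batch ID tensors record which graph each node/edge belongs to.
graph_id_node = np.array([0] * len(nodes_g1) + [1] * len(nodes_g2))
graph_id_edge = np.array([0] * len(edges_g1) + [1] * len(edges_g2))

print(disjoint_index)  # [[0 1 2], [0 2 1]]
print(graph_id_node)   # [0 1 1]
```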
Note that ragged tensors can also be used as input to Keras models, which is more effective and less costly in casting, but they are only supported for the TensorFlow backend for now. If the tensor shape must not change, e.g. for JAX, padded disjoint output can also be generated with:
[3]:
disjoint_tensors = CastBatchedIndicesToDisjoint(uses_mask=True, padded_disjoint=True)([nodes, edge_indices, node_mask, edge_mask])
node_attr, disjoint_index, graph_id_node, graph_id_edge, node_id, edge_id, node_count, edge_count = disjoint_tensors
print("Disjoint index:\n", disjoint_index)
print("Node attributes:\n", node_attr)
print("Batch ID nodes:\n", graph_id_node)
Disjoint index:
tensor([[0, 1, 0, 0, 0, 3, 3, 4, 0],
[0, 1, 0, 0, 0, 3, 4, 3, 0]], device='cuda:0')
Node attributes:
tensor([[0., 0.],
[0., 1.],
[0., 0.],
[1., 0.],
[1., 1.]], device='cuda:0')
Batch ID nodes:
tensor([0, 1, 0, 2, 2], device='cuda:0')
Here nodes and edges with ID 0 are dummy nodes and can later be removed. They take part in message passing without interfering with the other subgraphs. However, constructing padded (disjoint) batches directly is much more effective, but requires a data loader, i.e. kgcnn.io
.
Gather¶
Selecting nodes via edge indices is simply realised with take
and is carried out by the Keras layer with some options:
[4]:
from kgcnn.layers.gather import GatherNodes
nodes_per_edge = GatherNodes(split_indices=(0, 1), concat_axis=1)([node_attr, disjoint_index])
nodes_in, nodes_out = GatherNodes(split_indices=(0, 1), concat_axis=None)([node_attr, disjoint_index])
print(nodes_per_edge.shape)
print(nodes_in.shape, nodes_out.shape)
torch.Size([9, 4])
torch.Size([9, 2]) torch.Size([9, 2])
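The underlying operation can be sketched in plain NumPy (an illustration of the mechanism, not the layer's actual implementation): each row of the disjoint index selects node features per edge via take, and concat_axis controls whether sender and receiver features are concatenated along the feature axis.

```python
import numpy as np

node_attr = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
disjoint_index = np.array([[0, 1, 1, 2],
                           [0, 1, 2, 1]])

# Select sender (index 0) and receiver (index 1) node features for every edge.
nodes_in = np.take(node_attr, disjoint_index[0], axis=0)   # shape (4, 2)
nodes_out = np.take(node_attr, disjoint_index[1], axis=0)  # shape (4, 2)

# concat_axis=1 corresponds to concatenating along the feature axis.
nodes_per_edge = np.concatenate([nodes_in, nodes_out], axis=1)  # shape (4, 4)
print(nodes_per_edge.shape)
```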
Convolution¶
Convolution per node can now be done with, for example, a standard Keras Dense
layer.
[5]:
from keras.layers import Dense
edges_transformed = Dense(units=16, use_bias=True, activation="swish")(nodes_per_edge)
print(edges_transformed.shape)
torch.Size([9, 16])
Aggregation¶
Aggregation of edges per node can be done with scatter or segment operations. For backward compatibility and without any additional transformation, AggregateLocalEdges
offers a direct approach. The node tensor must additionally be provided to determine the target shape (batch dimension), but it can also be used directly as the tensor to aggregate edges into.
[6]:
from kgcnn.layers.aggr import AggregateLocalEdges
edges_aggregated = AggregateLocalEdges(pooling_method="scatter_sum", pooling_index=0)([node_attr, edges_transformed, disjoint_index])
print(edges_aggregated.shape)
torch.Size([5, 16])
The basic aggregation layer design is currently:
[7]:
from kgcnn.layers.aggr import Aggregate
edges_aggregated = Aggregate(pooling_method="scatter_sum")([edges_transformed, disjoint_index[0], node_attr])
print(edges_aggregated.shape)
torch.Size([5, 16])
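The scatter_sum aggregation itself can be sketched in plain NumPy (a simplified illustration, not the layer's actual implementation): edge features are accumulated at the node they point to, here taken from row 0 of the disjoint index as with pooling_index=0.

```python
import numpy as np

edge_feat = np.array([[1.0], [2.0], [3.0], [4.0]])  # one feature per edge
receiver = np.array([0, 1, 2, 1])                   # disjoint_index[0], i.e. pooling_index=0
num_nodes = 3                                        # taken from the node tensor's shape

# scatter_sum: accumulate each edge feature at its receiving node.
aggregated = np.zeros((num_nodes, edge_feat.shape[1]))
np.add.at(aggregated, receiver, edge_feat)
print(aggregated)  # [[1.], [6.], [3.]]
```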
Pooling¶
For graph-level embeddings, nodes or edges are pooled per graph. This requires the graph batch ID tensor and can be done with Aggregate
in the same way; this is used in kgcnn.layers.pooling
. For reference we can use the node_count
tensor.
[8]:
from kgcnn.layers.pooling import PoolingNodes
graph_output = PoolingNodes()([node_count, edges_aggregated, graph_id_node])
print(graph_output)
tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[-0.2240, -0.1971, 0.4667, -0.2024, -0.1236, 0.4036, -0.0160, 0.2721,
-0.0063, -0.1154, 0.6441, 0.4041, -0.2673, 0.4717, -0.2080, 0.1283],
[-0.7468, -0.1389, 0.2592, -0.3825, 0.2881, 0.7621, 0.9968, 0.7264,
0.4894, 0.4421, -0.4755, 0.6927, -0.3123, -0.3772, 0.4574, 1.0335]],
device='cuda:0', grad_fn=<ScatterReduceBackward0>)
Since we used a padded disjoint representation, graph 0 is a dummy graph that collects the dummy nodes. It must be removed to obtain the final graph embeddings for the two samples in the batch:
[9]:
out = graph_output[1:]
out
[9]:
tensor([[-0.2240, -0.1971, 0.4667, -0.2024, -0.1236, 0.4036, -0.0160, 0.2721,
-0.0063, -0.1154, 0.6441, 0.4041, -0.2673, 0.4717, -0.2080, 0.1283],
[-0.7468, -0.1389, 0.2592, -0.3825, 0.2881, 0.7621, 0.9968, 0.7264,
0.4894, 0.4421, -0.4755, 0.6927, -0.3123, -0.3772, 0.4574, 1.0335]],
device='cuda:0', grad_fn=<SliceBackward0>)
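Pooling uses the same scatter mechanism as edge aggregation, only with the graph batch ID tensor as the index. A plain NumPy sketch (an illustration, not the layer's actual implementation) of sum-pooling nodes per graph:

```python
import numpy as np

node_feat = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
graph_id_node = np.array([0, 1, 1])  # which graph each node belongs to
num_graphs = 2

# Sum node features per graph via scatter-add over the batch ID tensor.
graph_output = np.zeros((num_graphs, node_feat.shape[1]))
np.add.at(graph_output, graph_id_node, node_feat)
print(graph_output)  # [[1. 2.], [8. 10.]]
```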
NOTE: You can find this page as jupyter notebook in https://github.com/aimat-lab/gcnn_keras/tree/master/docs/source