# Source code for kgcnn.literature.MAT._make

import keras as ks
from keras.backend import backend as backend_to_use
from kgcnn.layers.modules import Embedding
from kgcnn.layers.mlp import MLP
from kgcnn.models.utils import update_model_kwargs
from ._layers import MATAttentionHead, MATDistanceMatrix, MATReduceMask, MATGlobalPool, MATExpandMask

# Model version tag, taken from the commit date of this model in the literature folder.
# Bump it whenever the model is changed in a significant way.
__model_version__ = "2023-12-08"

# Keras backends this model supports; importing this module under any other backend raises.
__kgcnn_model_backend_supported__ = ["tensorflow", "torch", "jax"]
if backend_to_use() not in __kgcnn_model_backend_supported__:
    raise NotImplementedError(
        "Backend '{}' for model 'MAT' is not supported.".format(backend_to_use())
    )

# Implementation of MAT in (multi-backend) Keras from the paper:
# Molecule Attention Transformer
# Łukasz Maziarka, Tomasz Danel, Sławomir Mucha, Krzysztof Rataj, Jacek Tabor, Stanisław Jastrzębski
# https://arxiv.org/abs/2002.08264
# https://github.com/ardigen/MAT
# https://github.com/lucidrains/molecule-attention-transformer


# Default keyword arguments of `make_model`; this dict is handed to the `update_model_kwargs` decorator.
model_default = dict(
    name="MAT",
    # Keras `Input` specifications; the order must match the unpacking order in `make_model`.
    inputs=[
        dict(shape=(None,), name="node_number", dtype="int64"),
        dict(shape=(None, 3), name="node_coordinates", dtype="float32"),
        dict(shape=(None, None, 1), name="adjacency_matrix", dtype="float32"),
        dict(shape=(None,), name="node_mask", dtype="bool"),
        dict(shape=(None, None), name="adjacency_mask", dtype="bool"),
    ],
    input_tensor_type="padded",
    input_embedding=None,  # listed as deprecated in the `update_model_kwargs` call
    input_node_embedding=dict(input_dim=95, output_dim=64),
    input_edge_embedding=dict(input_dim=95, output_dim=64),
    max_atoms=None,
    distance_matrix_kwargs=dict(trafo="exp"),
    attention_kwargs=dict(
        units=8,
        lambda_attention=0.3,
        lambda_distance=0.3,
        lambda_adjacency=None,
        dropout=0.1,
        add_identity=False,
    ),
    feed_forward_kwargs=dict(units=[32, 32, 32], activation=["relu", "relu", "linear"]),
    embedding_units=32,
    depth=5,
    heads=8,
    merge_heads="concat",
    verbose=10,
    pooling_kwargs=dict(pooling_method="sum"),
    output_embedding="graph",
    output_to_tensor=None,  # listed as deprecated in the `update_model_kwargs` call
    output_mlp=dict(
        use_bias=[True, True, True],
        units=[32, 16, 1],
        activation=["relu", "relu", "linear"],
    ),
    output_tensor_type="padded",
)


@update_model_kwargs(model_default, update_recursive=0, deprecated=["input_embedding", "output_to_tensor"])
def make_model(name: str = None,
               inputs: list = None,
               input_embedding: dict = None,  # noqa
               input_node_embedding: dict = None,
               input_tensor_type: str = None,
               input_edge_embedding: dict = None,
               distance_matrix_kwargs: dict = None,
               attention_kwargs: dict = None,
               max_atoms: int = None,
               feed_forward_kwargs: dict = None,
               embedding_units: int = None,
               depth: int = None,
               heads: int = None,
               merge_heads: str = None,
               verbose: int = None,  # noqa
               pooling_kwargs: dict = None,
               output_embedding: str = None,
               output_to_tensor: bool = None,  # noqa
               output_mlp: dict = None,
               output_tensor_type: str = None
               ):
    r"""Make `MAT <https://arxiv.org/pdf/2002.08264.pdf>`__ graph network via functional API.

    Default parameters can be found in :obj:`kgcnn.literature.MAT.model_default` .

    .. note::

        We added a linear layer to keep correct node embedding dimension.

    Inputs:
        list: `[node_attributes, node_coordinates, adjacency_matrix, node_mask, adjacency_mask]`

            - node_attributes (Tensor): Node attributes of shape `(batch, N, F)` or `(batch, N)`
              using an embedding layer.
            - node_coordinates (Tensor): Node (atomic) coordinates of shape `(batch, N, 3)`.
            - adjacency_matrix (Tensor): Edge attributes of shape `(batch, N, N, F)` or `(batch, N, N)`
              using an embedding layer.
            - node_mask (Tensor): Node mask of shape `(batch, N)`
            - adjacency_mask (Tensor): Adjacency mask of shape `(batch, N, N)`

    Outputs:
        Tensor: Graph embeddings of shape `(batch, L)` if :obj:`output_embedding="graph"`,
        or node-wise embeddings (multiplied by the node mask) if :obj:`output_embedding="node"`.

    Args:
        name (str): Name of the model. Should be "MAT".
        inputs (list): List of dictionaries unpacked in :obj:`keras.layers.Input`. Order must match model definition.
        input_embedding (dict): Deprecated.
        input_tensor_type (str): Input tensor type. Only padded tensors ("padded", "mask" or "masked") are valid
            for this implementation.
        input_node_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layers.
            If None, node attributes are used as given.
        input_edge_embedding (dict): Dictionary of embedding arguments unpacked in :obj:`Embedding` layers.
            If None, the adjacency matrix is used as given.
        distance_matrix_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`MATDistanceMatrix`.
        attention_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`MATAttentionHead`.
        max_atoms (int): Not used by this implementation.
        feed_forward_kwargs (dict): Dictionary of layer arguments unpacked in feed forward :obj:`MLP`.
        embedding_units (int): Units for node embedding.
        depth (int): Number of graph embedding units or depth of the network.
        heads (int): Number of attention heads.
        merge_heads (str): How to merge heads: 'add', 'sum' or 'reduce_sum' add the heads,
            any other value (e.g. 'concat') concatenates them.
        verbose (int): Level for print information. Not used by this implementation.
        pooling_kwargs (dict): Dictionary of layer arguments unpacked in :obj:`MATGlobalPool`.
        output_embedding (str): Main embedding task for graph network. Either "node" or "graph".
        output_to_tensor (bool): Deprecated.
        output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
            Defines number of model outputs and activation.
        output_tensor_type (str): Output tensor type. Only padded tensors ("padded", "mask" or "masked") are
            valid for this implementation.

    Returns:
        :obj:`keras.models.Model`

    Raises:
        ValueError: If `output_embedding` is neither "graph" nor "node".
    """
    assert input_tensor_type in ["padded", "mask", "masked"], "Only padded tensors are valid for this implementation."
    assert output_tensor_type in ["padded", "mask", "masked"], "Only padded tensors are valid for this implementation."
    # Fail fast, before any layer is created, instead of after building the whole network.
    if output_embedding not in ("graph", "node"):
        raise ValueError(
            f"Unsupported output embedding '{output_embedding}' for model `MAT`. Use 'graph' or 'node'.")

    # Make input
    node_input = ks.layers.Input(**inputs[0])
    xyz_input = ks.layers.Input(**inputs[1])
    adjacency_matrix = ks.layers.Input(**inputs[2])
    node_mask = ks.layers.Input(**inputs[3])
    adjacency_mask = ks.layers.Input(**inputs[4])

    use_node_embedding = input_node_embedding is not None
    use_edge_embedding = input_edge_embedding is not None

    # Embedding, if no feature dimension. Node embedding is created before edge embedding (layer naming order).
    n = Embedding(**input_node_embedding)(node_input) if use_node_embedding else node_input
    adj = Embedding(**input_edge_embedding)(adjacency_matrix) if use_edge_embedding else adjacency_matrix

    # Masks receive a trailing axis of size one.
    n_mask = MATExpandMask(axis=-1)(node_mask)
    adj_mask = MATExpandMask(axis=-1)(adjacency_mask)

    # Distance matrix (and its mask) computed from coordinates for padded tensors.
    dist, dist_mask = MATDistanceMatrix(**distance_matrix_kwargs)(xyz_input, mask=n_mask)

    # Adjacency is derived from edge input. If edge input has no last dimension and no embedding is used, then
    # adjacency matrix will have shape (batch, max_atoms, max_atoms) and edge input should be ones or weights or
    # bond degree. Otherwise, adjacency bears features expanded from edge attributes of shape
    # (batch, max_atoms, max_atoms, features).
    has_edge_dim = len(inputs[2]["shape"]) >= 3 or use_edge_embedding
    if has_edge_dim:
        # Assume that feature-wise attention is not desired for adjacency, reduce to single value.
        adj = ks.layers.Dense(1, use_bias=False)(adj)
    else:
        # Make sure that shape is (batch, max_atoms, max_atoms, 1).
        adj = MATExpandMask(axis=-1)(adj)

    h_mask = n_mask
    # Project node features to `embedding_units` so the residual additions below have matching dimensions.
    h = ks.layers.Dense(units=embedding_units, use_bias=False)(n)
    for _ in range(depth):
        # 1. Norm + Attention + Residual
        hn = ks.layers.LayerNormalization()(h)
        hs = [
            MATAttentionHead(**attention_kwargs)([hn, dist, adj], mask=[n_mask, dist_mask, adj_mask])
            for _ in range(heads)
        ]
        if merge_heads in ["add", "sum", "reduce_sum"]:
            hu = ks.layers.Add()(hs)
        else:
            hu = ks.layers.Concatenate(axis=-1)(hs)
        # Project merged heads back to `embedding_units` for the residual connection.
        hu = ks.layers.Dense(units=embedding_units, use_bias=False)(hu)
        h = ks.layers.Add()([h, hu])

        # 2. Norm + MLP + Residual
        hn = ks.layers.LayerNormalization()(h)
        hu = MLP(**feed_forward_kwargs)(hn)
        hu = ks.layers.Dense(units=embedding_units, use_bias=False)(hu)
        # Zero out the feed-forward update on padded nodes before the residual addition.
        hu = ks.layers.Multiply()([hu, h_mask])
        h = ks.layers.Add()([h, hu])

    # Pooling / output
    out_mask = h_mask
    out = ks.layers.LayerNormalization()(h)
    if output_embedding == "graph":
        out = ks.layers.Multiply()([out, out_mask])
        out = MATGlobalPool(**pooling_kwargs)(out, mask=out_mask)
        # Final prediction MLP for the output.
        out = MLP(**output_mlp)(out)
    else:  # "node", validated above.
        out = MLP(**output_mlp)(out)
        out = ks.layers.Multiply()([out, out_mask])

    model = ks.models.Model(
        inputs=[node_input, xyz_input, adjacency_matrix, node_mask, adjacency_mask],
        outputs=out,
        name=name
    )
    model.__kgcnn_model_version__ = __model_version__
    return model