Literature¶
A number of popular graph network architectures are already implemented in kgcnn. They can be found in kgcnn.literature. Most models are set up with the functional Keras API. Information on hyperparameters, training and benchmarking can be found below.
AttentiveFP: Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism by Xiong et al. (2019)
CGCNN: Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties by Xie et al. (2018)
CMPNN: Communicative Representation Learning on Attributed Molecular Graphs by Song et al. (2020)
DGIN: Improved Lipophilicity and Aqueous Solubility Prediction with Composite Graph Neural Networks by Wieder et al. (2021)
DimeNetPP: Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules by Klicpera et al. (2020)
DMPNN: Analyzing Learned Molecular Representations for Property Prediction by Yang et al. (2019)
EGNN: E(n) Equivariant Graph Neural Networks by Satorras et al. (2021)
GAT: Graph Attention Networks by Veličković et al. (2018)
GATv2: How Attentive are Graph Attention Networks? by Brody et al. (2021)
GCN: Semi-Supervised Classification with Graph Convolutional Networks by Kipf et al. (2016)
GIN: How Powerful are Graph Neural Networks? by Xu et al. (2019)
GNNExplainer: GNNExplainer: Generating Explanations for Graph Neural Networks by Ying et al. (2019)
GNNFilm: GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation by Marc Brockschmidt (2020)
GraphSAGE: Inductive Representation Learning on Large Graphs by Hamilton et al. (2017)
HamNet: HamNet: Conformation-Guided Molecular Representation with Hamiltonian Neural Networks by Li et al. (2021)
HDNNP2nd: Atom-centered symmetry functions for constructing high-dimensional neural network potentials by Jörg Behler (2011)
INorp: Interaction Networks for Learning about Objects, Relations and Physics by Battaglia et al. (2016)
MAT: Molecule Attention Transformer by Maziarka et al. (2020)
MEGAN: MEGAN: Multi-explanation Graph Attention Network by Teufel et al. (2023)
Megnet: Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals by Chen et al. (2019)
MoGAT: Multi-order graph attention network for water solubility prediction and interpretation by Lee et al. (2023)
MXMNet: Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures by Zhang et al. (2020)
NMPN: Neural Message Passing for Quantum Chemistry by Gilmer et al. (2017)
PAiNN: Equivariant message passing for the prediction of tensorial properties and molecular spectra by Schütt et al. (2020)
RGCN: Modeling Relational Data with Graph Convolutional Networks by Schlichtkrull et al. (2017)
rGIN: Random Features Strengthen Graph Neural Networks by Sato et al. (2020)
Schnet: SchNet – A deep learning architecture for molecules and materials by Schütt et al. (2017)
Training Scripts¶
Currently there are the training scripts train_graph.py, train_node.py and train_force.py.
NOTE: They are tightly integrated with kgcnn models and datasets, which is why a custom training script can be favorable for models not in kgcnn.literature.
Training scripts can be started with:
python3 train_node.py --hyper hyper/hyper_cora.py --category GCN
python3 train_graph.py --hyper hyper/hyper_esol.py --category GIN
Here hyper_esol.py stores the hyperparameters and must either be in the same folder or be given as a path to a .py file.
In principle, training can be fully configured with a serialized hyperparameter file such as ‘hyper.json’ or ‘hyper.yaml’. If a Python file ‘hyper.py’ is used, then hyper = {...} must be set in the script, in which case the items do not necessarily need to be in serialized form.
[1]:
hyper = {
"info":{
# General information for training run
"kgcnn_version": "4.0.0", # Version
"postfix": "" # Postfix for output folder.
},
"model": {
# Model specific parameter, see kgcnn.literature.
},
"data": {
# Data specific parameters.
},
"dataset": {
# Dataset specific parameters.
},
"training": {
"fit": {
# serialized keras fit arguments.
},
"compile": {
# serialized keras compile arguments.
},
"cross_validation": {
# serialized parameters for cross-validation.
},
"scaler": {
# serialized parameters for scaler.
# Only add when training for regression.
}
}
}
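Such a fully serialized hyperparameter dictionary can be written to and read back from ‘hyper.json’ with the standard library alone; a minimal sketch of the round trip (plain Python, independent of kgcnn):

```python
import json

# Minimal sketch: round-trip a fully serialized hyperparameter dict
# through 'hyper.json'. Assumes every item is JSON-serializable.
hyper = {
    "info": {"kgcnn_version": "4.0.0", "postfix": ""},
    "model": {}, "data": {}, "dataset": {},
    "training": {"fit": {}, "compile": {}, "cross_validation": {}},
}

with open("hyper.json", "w") as f:
    json.dump(hyper, f, indent=2)

with open("hyper.json") as f:
    hyper_loaded = json.load(f)

assert hyper_loaded == hyper  # identical after the round trip
```

A ‘hyper.yaml’ file works the same way with a YAML loader instead of the json module.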
Data hyperparameter¶
The kwargs for the dataset are not fully identical and vary a little depending on the dataset. However, the most common ones are listed below.
[2]:
hyper.update({
"data":{
# Other optional entries (depends on the training script)
"data_unit": "mol/L",
},
"dataset": {
"class_name": "QM9Dataset", # Name of the dataset
"module_name": "kgcnn.data.datasets.QM9Dataset",
# Config like filepath etc., leave empty for pre-defined datasets
"config": {},
# Methods to run on dataset, i.e. the list of graphs
"methods": [
{"prepare_data": {}}, # Used for cache and pre-compute data, leave out for pre-defined datasets
{"read_in_memory": {}}, # Used for reading into memory, leave out for pre-defined datasets
# Example method to run over each graph in the list using `map_list` method.
# The string 'set_range' refers to a preprocessor. Legacy short access to graph preprocessors.
{"map_list": {"method": "set_range", "max_distance": 4, "max_neighbours": 30}},
{"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_edges",
"count_edges": "edge_indices", "count_nodes": "node_attributes", "total_nodes": "total_nodes"}},
]
}
})
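To illustrate the idea behind the "map_list" entries above: each listed method is applied to every graph in the dataset list. The following is a hypothetical, simplified sketch of that pattern using plain dicts in place of kgcnn graph containers; it is not kgcnn's actual implementation.

```python
# Hypothetical sketch of the "map_list" idea: run a preprocessing
# function over every graph in the dataset list.
def map_list(graphs, method, **kwargs):
    return [method(g, **kwargs) for g in graphs]

def count_nodes_and_edges(graph, count_nodes="node_attributes",
                          total_nodes="total_nodes"):
    # Store the node count under a new key, as used for model input later.
    graph[total_nodes] = len(graph[count_nodes])
    return graph

dataset = [{"node_attributes": [[0.0], [1.0], [2.0]]}]
dataset = map_list(dataset, count_nodes_and_edges)
print(dataset[0]["total_nodes"])  # 3
```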
Model hyperparameter¶
The model parameters can be reviewed from the default values in kgcnn.literature. Mostly, the model input and output have to be matched to the data representation, that is, the type of each input and its shape. An input-type checker, assert_valid_model_input, is available from kgcnn.data.base.MemoryGraphDataset. In inputs, a list of kwargs must be given, each of which is unpacked into the corresponding tf.keras.layers.Input. The order matters and is model dependent.
Moreover, the naming of the model inputs is used to link the tensor properties of the dataset with the model input. The output dimension of either node or graph embedding can be set for most models with the “output_mlp” argument.
[3]:
hyper.update({
"model":{
"module_name": "kgcnn.literature.GCN",
"class_name": "make_model",
"config":{
"inputs": [
{"shape": [None, 100], "name": "node_attributes", "dtype": "float32"},
{"shape": [None, 2], "name": "edge_indices", "dtype": "int64"},
{"shape": (), "name": "total_nodes", "dtype": "int64"},
{"shape": (), "name": "total_edges", "dtype": "int64"}
],
# More model specific kwargs, like:
"depth": 5,
# Output part defining model output
"output_embedding": "graph",
"output_mlp": {"use_bias": [True, True, False], "units": [140, 70, 70],
"activation": ["relu", "relu", "softmax"]}
}
}
})
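Since mismatched input definitions are a common source of errors, a simple sanity check over the inputs list can help. The following is a hypothetical helper in the spirit of assert_valid_model_input, not the actual kgcnn checker:

```python
def check_input_spec(inputs, required=("shape", "name", "dtype")):
    # Hypothetical helper: verify that each input kwarg dict defines at
    # least shape, name and dtype before it is unpacked into an Input layer.
    for kw in inputs:
        missing = [key for key in required if key not in kw]
        if missing:
            raise ValueError(
                f"Input '{kw.get('name', '?')}' is missing {missing}")

inputs = [
    {"shape": [None, 100], "name": "node_attributes", "dtype": "float32"},
    {"shape": [None, 2], "name": "edge_indices", "dtype": "int64"},
    {"shape": (), "name": "total_nodes", "dtype": "int64"},
    {"shape": (), "name": "total_edges", "dtype": "int64"},
]
check_input_spec(inputs)  # passes silently for a complete spec
```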
Training hyperparameter¶
The kwargs for training simply set arguments for model.compile(**kwargs) and model.fit(**kwargs) that match Keras arguments, as well as for the k-fold split from scikit-learn. The kwargs are expected to be fully serialized if the hyperparameters are to be saved to JSON.
[4]:
import keras as ks
hyper.update({
"training":{
# Cross-validation of the data
"cross_validation": {
"class_name": "KFold",
"config": {"n_splits": 5, "random_state": 42, "shuffle": True}
},
# Standard scaler for regression targets
"scaler": {
"class_name": "StandardScaler",
"module_name": "kgcnn.data.transform.scaler.standard",
"config": {"with_std": True, "with_mean": True, "copy": True}
},
# Keras model compile and fit
"compile": {
"loss": "categorical_crossentropy",
"optimizer": ks.saving.serialize_keras_object(
ks.optimizers.Adam(learning_rate=0.001))
},
"fit": {
"batch_size": 32, "epochs": 800, "verbose": 2,
"callbacks": []
}
}
})
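The "class_name"/"module_name"/"config" entries for "cross_validation" and "scaler" follow a common deserialization pattern: import the module, look up the class, and instantiate it with the config. A minimal sketch of that pattern (a hypothetical helper, not kgcnn's actual deserialization code), demonstrated with a standard-library class in place of KFold or StandardScaler:

```python
import importlib

def from_config(spec, default_module=None):
    # Hypothetical helper illustrating the serialization pattern
    # {"class_name": ..., "module_name": ..., "config": {...}}.
    module = importlib.import_module(spec.get("module_name") or default_module)
    cls = getattr(module, spec["class_name"])
    return cls(**spec.get("config", {}))

# Demonstrated with fractions.Fraction instead of a scikit-learn class:
spec = {"class_name": "Fraction", "module_name": "fractions",
        "config": {"numerator": 3, "denominator": 4}}
obj = from_config(spec)
print(obj)  # 3/4
```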
Info¶
Some general information on the training run, such as the kgcnn version used or a postfix for the output files.
[5]:
hyper.update({
"info":{ # Generla information
"postfix": "_v1", # Appends _v1 to output folder
"postfix_file": "_run2", # Appends _run2 to info files
"kgcnn_version": "4.0.0"
}
})
Benchmarks¶
[6]:
from IPython.display import Markdown, display
Markdown(open('../../training/results/README.md', encoding='utf-8').read())
[6]:
Summary of Benchmark Training¶
Note that these are the results for models within the kgcnn implementation, and that training is not always done with optimal hyperparameters or splits when comparing with the literature. This table is generated automatically from Keras history logs. Model weights and training statistics plots are not uploaded on GitHub due to their file size.
Max. or Min. denotes the best test error observed for any epoch during training. To show the overall best test error, run python3 summary.py --min_max True. If not noted otherwise, we use a (fixed) random k-fold split for validation errors.
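The entries in the tables below report the mean ± standard deviation of the test error over the k folds. A minimal sketch of how such a summary line can be computed from per-fold errors (plain Python, independent of the actual summary.py script):

```python
from statistics import mean, stdev

def summarize(fold_errors):
    # Format k-fold test errors as "mean ± std", as in the tables below.
    return f"{mean(fold_errors):.4f} \u00b1 {stdev(fold_errors):.4f}"

print(summarize([0.44, 0.46, 0.43, 0.45, 0.47]))  # prints "0.4500 ± 0.0158"
```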
ClinToxDataset¶
ClinTox (MoleculeNet) consists of 1478 compounds as smiles and data of drugs approved by the FDA and those that have failed clinical trials for toxicity reasons. We use random 5-fold cross-validation. The first label ‘approved’ is chosen as target.
| model | kgcnn | epochs | Accuracy | AUC(ROC) |
|---|---|---|---|---|
| DMPNN | 4.0.0 | 50 | 0.9480 ± 0.0138 | 0.8297 ± 0.0568 |
| GAT | 4.0.0 | 50 | 0.9480 ± 0.0070 | 0.8512 ± 0.0468 |
| GATv2 | 4.0.0 | 50 | 0.9372 ± 0.0155 | 0.8587 ± 0.0754 |
| GCN | 4.0.0 | 50 | 0.9432 ± 0.0155 | 0.8555 ± 0.0593 |
| GIN | 4.0.0 | 50 | 0.9412 ± 0.0034 | 0.8066 ± 0.0636 |
| GraphSAGE | 4.0.0 | 100 | 0.9412 ± 0.0073 | 0.8013 ± 0.0422 |
| Schnet | 4.0.0 | 50 | 0.9277 ± 0.0102 | 0.6562 ± 0.0760 |
CoraDataset¶
Cora Dataset of 19793 publications and 8710 sparse node attributes and 70 node classes. Here we use random 5-fold cross-validation on nodes.
| model | kgcnn | epochs | Categorical accuracy |
|---|---|---|---|
| DMPNN | 4.0.0 | 300 | 0.2476 ± 0.1706 |
| GAT | 4.0.0 | 250 | 0.6157 ± 0.0071 |
| GATv2 | 4.0.0 | 1000 | 0.6211 ± 0.0048 |
| GCN | 4.0.0 | 300 | 0.6232 ± 0.0054 |
| GIN | 4.0.0 | 800 | 0.6263 ± 0.0080 |
| GraphSAGE | 4.0.0 | 600 | 0.6151 ± 0.0053 |
CoraLuDataset¶
Cora Dataset after Lu et al. (2003) of 2708 publications and 1433 sparse attributes and 7 node classes. Here we use random 5-fold cross-validation on nodes.
| model | kgcnn | epochs | Categorical accuracy |
|---|---|---|---|
| DMPNN | 4.0.0 | 300 | 0.8357 ± 0.0156 |
| GAT | 4.0.0 | 250 | 0.8397 ± 0.0122 |
| GATv2 | 4.0.0 | 250 | 0.8331 ± 0.0104 |
| GCN | 4.0.0 | 300 | 0.8072 ± 0.0109 |
| GIN | 4.0.0 | 500 | 0.8279 ± 0.0170 |
| GraphSAGE | 4.0.0 | 500 | 0.8497 ± 0.0100 |
ESOLDataset¶
ESOL consists of 1128 compounds as smiles and their corresponding water solubility in log10(mol/L). We use random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [log mol/L] | RMSE [log mol/L] |
|---|---|---|---|---|
| AttentiveFP | 4.0.0 | 200 | 0.4351 ± 0.0110 | 0.6080 ± 0.0207 |
| DGIN | 4.0.0 | 300 | 0.4434 ± 0.0252 | 0.6225 ± 0.0420 |
| DMPNN | 4.0.0 | 300 | 0.4401 ± 0.0165 | 0.6203 ± 0.0292 |
| EGNN | 4.0.0 | 800 | 0.4507 ± 0.0152 | 0.6563 ± 0.0370 |
| GAT | 4.0.0 | 500 | 0.4818 ± 0.0240 | 0.6919 ± 0.0694 |
| GATv2 | 4.0.0 | 500 | 0.4598 ± 0.0234 | 0.6650 ± 0.0409 |
| GCN | 4.0.0 | 800 | 0.4613 ± 0.0205 | 0.6534 ± 0.0513 |
| GIN | 4.0.0 | 300 | 0.5369 ± 0.0334 | 0.7954 ± 0.0861 |
| GNNFilm | 4.0.0 | 800 | 0.4854 ± 0.0368 | 0.6724 ± 0.0436 |
| GraphSAGE | 4.0.0 | 500 | 0.4874 ± 0.0228 | 0.6982 ± 0.0608 |
| HDNNP2nd | 4.0.0 | 500 | 0.7857 ± 0.0986 | 1.0467 ± 0.1367 |
| INorp | 4.0.0 | 500 | 0.5055 ± 0.0436 | 0.7297 ± 0.0786 |
| MAT | 4.0.0 | 400 | 0.5064 ± 0.0299 | 0.7194 ± 0.0630 |
| MEGAN | 4.0.0 | 400 | 0.4281 ± 0.0201 | 0.6062 ± 0.0252 |
| Megnet | 4.0.0 | 800 | 0.5679 ± 0.0310 | 0.8196 ± 0.0480 |
| MXMNet | 4.0.0 | 900 | 0.6486 ± 0.0633 | 1.0123 ± 0.2059 |
| NMPN | 4.0.0 | 800 | 0.5046 ± 0.0266 | 0.7193 ± 0.0607 |
| PAiNN | 4.0.0 | 250 | 0.4857 ± 0.0598 | 0.6650 ± 0.0674 |
| RGCN | 4.0.0 | 800 | 0.4703 ± 0.0251 | 0.6529 ± 0.0318 |
| rGIN | 4.0.0 | 300 | 0.5196 ± 0.0351 | 0.7142 ± 0.0263 |
| Schnet | 4.0.0 | 800 | 0.4777 ± 0.0294 | 0.6977 ± 0.0538 |
FreeSolvDataset¶
FreeSolv (MoleculeNet) consists of 642 compounds as smiles and their corresponding hydration free energy for small neutral molecules in water. We use a random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [log mol/L] | RMSE [log mol/L] |
|---|---|---|---|---|
| CMPNN | 4.0.0 | 600 | 0.5202 ± 0.0504 | 0.9339 ± 0.1286 |
| DGIN | 4.0.0 | 300 | 0.5489 ± 0.0374 | 0.9448 ± 0.0787 |
| DimeNetPP | 4.0.0 | 872 | 0.6167 ± 0.0719 | 1.0302 ± 0.1717 |
| DMPNN | 4.0.0 | 300 | 0.5487 ± 0.0754 | 0.9206 ± 0.1889 |
| EGNN | 4.0.0 | 800 | 0.5386 ± 0.0548 | 1.0363 ± 0.1237 |
| GAT | 4.0.0 | 500 | 0.6051 ± 0.0861 | 1.0326 ± 0.1819 |
| GATv2 | 4.0.0 | 500 | 0.6151 ± 0.0247 | 1.0535 ± 0.0817 |
| GCN | 4.0.0 | 800 | 0.6400 ± 0.0834 | 1.0876 ± 0.1393 |
| GIN | 4.0.0 | 300 | 0.8100 ± 0.1016 | 1.2695 ± 0.1192 |
| GNNFilm | 4.0.0 | 800 | 0.6562 ± 0.0552 | 1.1597 ± 0.1245 |
| GraphSAGE | 4.0.0 | 500 | 0.5894 ± 0.0675 | 1.0009 ± 0.1491 |
| HamNet | 4.0.0 | 400 | 0.6619 ± 0.0428 | 1.1410 ± 0.1120 |
| HDNNP2nd | 4.0.0 | 500 | 1.0201 ± 0.1559 | 1.6351 ± 0.3419 |
| INorp | 4.0.0 | 500 | 0.6612 ± 0.0188 | 1.1155 ± 0.1061 |
| MAT | 4.0.0 | 400 | 0.8115 ± 0.0649 | 1.3099 ± 0.1235 |
| MEGAN | 4.0.0 | 400 | 0.6303 ± 0.0550 | 1.0429 ± 0.1031 |
| Megnet | 4.0.0 | 800 | 0.8878 ± 0.0528 | 1.4134 ± 0.1200 |
| MoGAT | 4.0.0 | 200 | 0.7097 ± 0.0374 | 1.0911 ± 0.1334 |
| MXMNet | 4.0.0 | 900 | 1.1386 ± 0.1979 | 3.0487 ± 2.1757 |
| RGCN | 4.0.0 | 800 | 0.5128 ± 0.0810 | 0.9228 ± 0.1887 |
| rGIN | 4.0.0 | 300 | 0.8503 ± 0.0613 | 1.3285 ± 0.0976 |
| Schnet | 4.0.0 | 800 | 0.6070 ± 0.0285 | 1.0603 ± 0.0549 |
ISO17Dataset¶
The database consists of 129 molecules, each containing 5,000 conformational geometries, energies and forces with a resolution of 1 femtosecond in the molecular dynamics trajectories. The molecules were randomly drawn from the largest set of isomers in the QM9 dataset.
| model | kgcnn | epochs | Energy (test_within) | Force (test_within) |
|---|---|---|---|---|
| Schnet.EnergyForceModel | 4.0.0 | 1000 | 0.0061 ± nan | 0.0134 ± nan |
LipopDataset¶
Lipophilicity (MoleculeNet) consists of 4200 compounds as smiles. Graph labels for regression are octanol/water distribution coefficient (logD at pH 7.4). We use random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [log mol/L] | RMSE [log mol/L] |
|---|---|---|---|---|
| DMPNN | 4.0.0 | 300 | 0.3814 ± 0.0064 | 0.5462 ± 0.0095 |
| GAT | 4.0.0 | 500 | 0.5168 ± 0.0088 | 0.7220 ± 0.0098 |
| GATv2 | 4.0.0 | 500 | 0.4342 ± 0.0104 | 0.6056 ± 0.0114 |
| GCN | 4.0.0 | 800 | 0.4960 ± 0.0107 | 0.6833 ± 0.0155 |
| GIN | 4.0.0 | 300 | 0.4745 ± 0.0101 | 0.6658 ± 0.0159 |
| GraphSAGE | 4.0.0 | 500 | 0.4333 ± 0.0217 | 0.6218 ± 0.0318 |
| Schnet | 4.0.0 | 800 | 0.5657 ± 0.0202 | 0.7485 ± 0.0245 |
MD17Dataset¶
Energies and forces for molecular dynamics trajectories of eight organic molecules. All geometries in A, energy labels in kcal/mol and force labels in kcal/mol/A. We use preset train-test split. Training on 1000 geometries, test on 500/1000 geometries. Errors are MAE for forces. Results are for the CCSD and CCSD(T) data in MD17.
| model | kgcnn | epochs | Aspirin | Toluene | Malonaldehyde | Benzene | Ethanol |
|---|---|---|---|---|---|---|---|
| PAiNN.EnergyForceModel | 4.0.0 | 1000 | nan ± nan | nan ± nan | nan ± nan | nan ± nan | 0.5805 ± nan |
| Schnet.EnergyForceModel | 4.0.0 | 1000 | 1.2173 ± nan | 0.7395 ± nan | 0.8444 ± nan | 0.3353 ± nan | 0.4832 ± nan |
MD17RevisedDataset¶
Energies and forces for molecular dynamics trajectories. All geometries in A, energy labels in kcal/mol and force labels in kcal/mol/A. We use preset train-test split. Training on 1000 geometries, test on 500/1000 geometries. Errors are MAE for forces.
| model | kgcnn | epochs | Aspirin | Toluene | Malonaldehyde | Benzene | Ethanol |
|---|---|---|---|---|---|---|---|
| Schnet.EnergyForceModel | 4.0.0 | 1000 | 1.0389 ± 0.0071 | 0.5482 ± 0.0105 | 0.6727 ± 0.0132 | 0.2525 ± 0.0091 | 0.4471 ± 0.0199 |
MatProjectDielectricDataset¶
Materials Project dataset from Matbench with 4764 crystal structures and their corresponding Refractive index (unitless). We use a random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [no unit] | RMSE [no unit] |
|---|---|---|---|---|
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.3180 ± 0.0359 | 1.8509 ± 0.5854 |
MatProjectEFormDataset¶
Materials Project dataset from Matbench with 132752 crystal structures and their corresponding formation energy in [eV/atom]. We use a random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [eV/atom] | RMSE [eV/atom] |
|---|---|---|---|---|
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.0211 ± 0.0003 | 0.0510 ± 0.0024 |
MatProjectGapDataset¶
Materials Project dataset from Matbench with 106113 crystal structures and their band gap as calculated by PBE DFT from the Materials Project, in eV. We use a random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [eV] | RMSE [eV] |
|---|---|---|---|---|
| Schnet.make_crystal_model | 4.0.0 | 800 | 1.2226 ± 1.0573 | 58.3713 ± 114.2957 |
MatProjectIsMetalDataset¶
Materials Project dataset from Matbench with 106113 crystal structures and their corresponding Metallicity determined with pymatgen. 1 if the compound is a metal, 0 if the compound is not a metal. We use a random 5-fold cross-validation.
| model | kgcnn | epochs | Accuracy | AUC |
|---|---|---|---|---|
| Schnet.make_crystal_model | 4.0.0 | 80 | 0.8953 ± 0.0058 | 0.9506 ± 0.0053 |
MatProjectJdft2dDataset¶
Materials Project dataset from Matbench with 636 crystal structures and their corresponding Exfoliation energy (meV/atom). We use a random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [meV/atom] | RMSE [meV/atom] |
|---|---|---|---|---|
| CGCNN.make_crystal_model | 4.0.0 | 1000 | 57.6974 ± 18.0803 | 140.6167 ± 44.8418 |
| DimeNetPP.make_crystal_model | 4.0.0 | 780 | 50.2880 ± 11.4199 | 126.0600 ± 38.3769 |
| PAiNN.make_crystal_model | 4.0.0 | 800 | 49.3889 ± 11.5376 | 121.7087 ± 30.0472 |
| Schnet.make_crystal_model | 4.0.0 | 800 | 45.2412 ± 11.6395 | 115.6890 ± 39.0929 |
MatProjectLogGVRHDataset¶
Materials Project dataset from Matbench with 10987 crystal structures and their corresponding Base 10 logarithm of the DFT Voigt-Reuss-Hill average shear moduli in GPa. We use a random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [log(GPa)] | RMSE [log(GPa)] |
|---|---|---|---|---|
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.0836 ± 0.0021 | 0.1296 ± 0.0044 |
MatProjectLogKVRHDataset¶
Materials Project dataset from Matbench with 10987 crystal structures and their corresponding Base 10 logarithm of the DFT Voigt-Reuss-Hill average bulk moduli in GPa. We use a random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [log(GPa)] | RMSE [log(GPa)] |
|---|---|---|---|---|
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.0635 ± 0.0016 | 0.1186 ± 0.0044 |
MatProjectPerovskitesDataset¶
Materials Project dataset from Matbench with 18928 crystal structures and their corresponding Heat of formation of the entire 5-atom perovskite cell in eV. We use a random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [eV] | RMSE [eV] |
|---|---|---|---|---|
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.0381 ± 0.0005 | 0.0645 ± 0.0024 |
MatProjectPhononsDataset¶
Materials Project dataset from Matbench with 1,265 crystal structures and their corresponding vibration properties in [1/cm]. We use a random 5-fold cross-validation.
| model | kgcnn | epochs | MAE [1/cm] | RMSE [1/cm] |
|---|---|---|---|---|
| Schnet.make_crystal_model | 4.0.0 | 800 | 43.0692 ± 3.6227 | 88.5151 ± 20.0244 |
MUTAGDataset¶
MUTAG dataset from TUDataset for classification with 188 graphs. We use random 5-fold cross-validation.
| model | kgcnn | epochs | Accuracy | AUC(ROC) |
|---|---|---|---|---|
| DMPNN | 4.0.0 | 300 | 0.8407 ± 0.0463 | 0.8567 ± 0.0511 |
| GAT | 4.0.0 | 500 | 0.8141 ± 0.1077 | 0.8671 ± 0.0923 |
| GATv2 | 4.0.0 | 500 | 0.8193 ± 0.0945 | 0.8379 ± 0.1074 |
| GCN | 4.0.0 | 800 | 0.7716 ± 0.0531 | 0.7956 ± 0.0909 |
| GIN | 4.0.0 | 300 | 0.8091 ± 0.0781 | 0.8693 ± 0.0855 |
| GraphSAGE | 4.0.0 | 500 | 0.8357 ± 0.0798 | 0.8533 ± 0.0824 |
MutagenicityDataset¶
Mutagenicity dataset from TUDataset for classification with 4337 graphs. The dataset was cleaned for unconnected atoms. We use random 5-fold cross-validation.
| model | kgcnn | epochs | Accuracy | AUC(ROC) |
|---|---|---|---|---|
| DMPNN | 4.0.0 | 300 | 0.8266 ± 0.0059 | 0.8708 ± 0.0076 |
| GAT | 4.0.0 | 500 | 0.7989 ± 0.0114 | 0.8290 ± 0.0112 |
| GATv2 | 4.0.0 | 200 | 0.7674 ± 0.0048 | 0.8423 ± 0.0064 |
| GCN | 4.0.0 | 800 | 0.7955 ± 0.0154 | 0.8191 ± 0.0137 |
| GIN | 4.0.0 | 300 | 0.8118 ± 0.0091 | 0.8492 ± 0.0077 |
| GraphSAGE | 4.0.0 | 500 | 0.8195 ± 0.0126 | 0.8515 ± 0.0083 |
PROTEINSDataset¶
TUDataset of proteins that are classified as enzymes or non-enzymes. Nodes represent the amino acids of the protein. We use random 5-fold cross-validation.
| model | kgcnn | epochs | Accuracy | AUC(ROC) |
|---|---|---|---|---|
| DMPNN | 4.0.0 | 300 | 0.7287 ± 0.0253 | 0.7970 ± 0.0343 |
| GAT | 4.0.0 | 500 | 0.7314 ± 0.0357 | 0.7899 ± 0.0468 |
| GATv2 | 4.0.0 | 500 | 0.6720 ± 0.0595 | 0.6850 ± 0.0938 |
| GCN | 4.0.0 | 800 | 0.7017 ± 0.0303 | 0.7211 ± 0.0254 |
| GIN | 4.0.0 | 150 | 0.7224 ± 0.0343 | 0.7905 ± 0.0528 |
| GraphSAGE | 4.0.0 | 500 | 0.7009 ± 0.0398 | 0.7263 ± 0.0453 |
QM7Dataset¶
The QM7 dataset is a subset of GDB-13. It contains molecules of up to 23 atoms (including 7 heavy atoms: C, N, O, and S), totalling 7165 molecules. We use dataset-specific 5-fold cross-validation. The atomization energies are given in kcal/mol and range from -800 to -2000 kcal/mol.
| model | kgcnn | epochs | MAE [kcal/mol] | RMSE [kcal/mol] |
|---|---|---|---|---|
| Schnet | 4.0.0 | 800 | 3.4313 ± 0.4757 | 10.8978 ± 7.3863 |
QM9Dataset¶
QM9 dataset of 134k stable small organic molecules made up of C,H,O,N,F. Labels include geometric, energetic, electronic, and thermodynamic properties. We use a random 5-fold cross-validation, but not all splits are evaluated for cheaper evaluation. Test errors are MAE and for energies are given in [eV].
| model | kgcnn | epochs | HOMO [eV] | LUMO [eV] | U0 [eV] | H [eV] | G [eV] |
|---|---|---|---|---|---|---|---|
| Schnet | 4.0.0 | 800 | 0.0402 ± 0.0007 | 0.0340 ± 0.0001 | 0.0142 ± 0.0002 | 0.0146 ± 0.0002 | 0.0143 ± 0.0002 |
SIDERDataset¶
SIDER (MoleculeNet) consists of 1427 compounds as smiles and data for adverse drug reactions (ADR), grouped into 27 system organ classes. We use random 5-fold cross-validation.
| model | kgcnn | epochs | Accuracy | AUC(ROC) |
|---|---|---|---|---|
| DMPNN | 4.0.0 | 50 | 0.7519 ± 0.0055 | 0.6280 ± 0.0173 |
| GAT | 4.0.0 | 50 | 0.7595 ± 0.0034 | 0.6224 ± 0.0106 |
| GATv2 | 4.0.0 | 50 | 0.7548 ± 0.0052 | 0.6152 ± 0.0154 |
| GIN | 4.0.0 | 50 | 0.7472 ± 0.0055 | 0.5995 ± 0.0058 |
| GraphSAGE | 4.0.0 | 30 | 0.7547 ± 0.0043 | 0.6038 ± 0.0108 |
| Schnet | 4.0.0 | 50 | 0.7583 ± 0.0076 | 0.6119 ± 0.0159 |
Tox21MolNetDataset¶
Tox21 (MoleculeNet) consists of 7831 compounds as smiles and 12 different targets relevant to drug toxicity. We use random 5-fold cross-validation.
| model | kgcnn | epochs | Accuracy | AUC(ROC) | BACC |
|---|---|---|---|---|---|
| DMPNN | 4.0.0 | 50 | 0.9272 ± 0.0024 | 0.8321 ± 0.0103 | 0.6995 ± 0.0130 |
| GAT | 4.0.0 | 50 | 0.9243 ± 0.0022 | 0.8279 ± 0.0092 | 0.6504 ± 0.0074 |
| GATv2 | 4.0.0 | 50 | 0.9222 ± 0.0019 | 0.8251 ± 0.0069 | 0.6760 ± 0.0140 |
| GIN | 4.0.0 | 50 | 0.9220 ± 0.0024 | 0.7986 ± 0.0180 | 0.6741 ± 0.0143 |
| GraphSAGE | 4.0.0 | 100 | 0.9180 ± 0.0027 | 0.7976 ± 0.0087 | 0.6755 ± 0.0047 |
| Schnet | 4.0.0 | 50 | 0.9128 ± 0.0030 | 0.7719 ± 0.0139 | 0.6639 ± 0.0162 |
NOTE: You can find this page as a Jupyter notebook at https://github.com/aimat-lab/gcnn_keras/tree/master/docs/source