Literature

A set of popular graph neural network architectures is already implemented in kgcnn. They can be found in kgcnn.literature. Most models are set up in the functional keras API. Information on hyperparameters, training and benchmarking can be found below.

Training Scripts

Currently, there are the training scripts train_graph.py, train_node.py and train_force.py.

NOTE: They are tightly integrated with kgcnn models and datasets, which is why a custom training script can be preferable for models not in kgcnn.literature.

Training scripts can be started with:

python3 train_node.py --hyper hyper/hyper_cora.py --category GCN
python3 train_graph.py --hyper hyper/hyper_esol.py --category GIN

Here hyper_cora.py and hyper_esol.py store the hyperparameters and must either be in the same folder or be given as a path to a .py file.

In principle, training can be fully configured with a serialized hyperparameter file such as ‘hyper.json’ or ‘hyper.yaml’. If a Python file ‘hyper.py’ is used, then hyper = {...} must be set in the Python script, in which case the items do not necessarily need to be in serialized form.

[1]:
hyper = {
    "info":{
        # General information for training run
        "kgcnn_version": "4.0.0", # Version
        "postfix": "" # Postfix for output folder.
    },
    "model": {
        # Model specific parameter, see kgcnn.literature.
    },
    "data": {
        # Data specific parameters.
    },
    "dataset": {
        # Dataset specific parameters.
    },
    "training": {
        "fit": {
            # serialized keras fit arguments.
        },
        "compile": {
            # serialized keras compile arguments.
        },
        "cross_validation": {
            # serialized parameters for cross-validation.
        },
        "scaler": {
            # serialized parameters for scaler.
            # Only add when training for regression.
        }
    }
}
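
Since all items above are serializable, the dictionary can be written to and read back from a ‘hyper.json’ file with the standard library. A minimal sketch (the file name is just an example):

import json

# Save the fully serialized hyperparameters.
with open("hyper.json", "w") as f:
    json.dump(hyper, f, indent=2)

# Load them back for a training run.
with open("hyper.json") as f:
    hyper = json.load(f)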

Data hyperparameter

The kwargs for the dataset are not fully identical and vary a little depending on the dataset. However, the most common ones are listed below.

[2]:
hyper.update({
    "data":{
        # Other optional entries (depend on the training script)
        "data_unit": "mol/L",
    },
    "dataset": {
        "class_name": "QM9Dataset", # Name of the dataset
        "module_name": "kgcnn.data.datasets.QM9Dataset",

        # Config like filepath etc., leave empty for pre-defined datasets
        "config": {},

        # Methods to run on dataset, i.e. the list of graphs
        "methods": [
            {"prepare_data": {}}, # Used for cache and pre-compute data, leave out for pre-defined datasets
            {"read_in_memory": {}}, # Used for reading into memory, leave out for pre-defined datasets

            # Example methods to run over each graph in the list using the `map_list` method.
            # The string 'set_range' refers to a graph preprocessor (legacy shorthand access).
            {"map_list": {"method": "set_range", "max_distance": 4, "max_neighbours": 30}},
            {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_edges",
                          "count_edges": "edge_indices", "count_nodes": "node_attributes", "total_nodes": "total_nodes"}},
        ]
    }
})
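
The entries in "methods" correspond to calling the dataset's methods of the same name, in order. A minimal sketch of the equivalent direct calls (class and module names taken from the "dataset" section above; assuming the pre-defined dataset can be instantiated without arguments):

from kgcnn.data.datasets.QM9Dataset import QM9Dataset

dataset = QM9Dataset()   # pre-defined dataset, "config" stays empty
dataset.prepare_data()   # cache and pre-compute data
dataset.read_in_memory() # read graphs into memory

# Run a graph preprocessor over each graph in the list.
dataset.map_list(method="set_range", max_distance=4, max_neighbours=30)
dataset.map_list(method="count_nodes_and_edges", total_edges="total_edges",
                 count_edges="edge_indices", count_nodes="node_attributes",
                 total_nodes="total_nodes")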

Model hyperparameter

The model parameters can be reviewed from the default values in kgcnn.literature. Mostly, the model input and output have to be matched to the data representation, that is, the type of input and its shape. An input-type checker is available via kgcnn.data.base.MemoryGraphDataset, which has an assert_valid_model_input method. In inputs, a list of kwargs must be given, each of which is unpacked into the corresponding keras.layers.Input. The order matters and is model dependent.

Moreover, the naming of the model inputs is used to link the tensor properties of the dataset with the model inputs. The output dimension of either the node or graph embedding can be set for most models with the "output_mlp" argument.

[3]:
hyper.update({
    "model":{
        "module_name": "kgcnn.literature.GCN",
        "class_name": "make_model",
        "config":{
            "inputs": [
                {"shape": [None, 100], "name": "node_attributes", "dtype": "float32"},
                {"shape": [None, 2], "name": "edge_indices", "dtype": "int64"},
                {"shape": (), "name": "total_nodes", "dtype": "int64"},
                {"shape": (), "name": "total_edges", "dtype": "int64"}
            ],
            # More model specific kwargs, like:
            "depth": 5,
            # Output part defining model output
            "output_embedding": "graph",
            "output_mlp": {"use_bias": [True, True, False], "units": [140, 70, 70],
                           "activation": ["relu", "relu", "softmax"]}
        }
    }
})
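
A training script resolves module_name and class_name and unpacks the config into the model function. A minimal sketch of the equivalent direct call (make_model and its import path are taken from the hyperparameters above; the commented check assumes a dataset object as in the previous section):

from kgcnn.literature.GCN import make_model

# Build the model by unpacking the stored config.
model = make_model(**hyper["model"]["config"])

# Optional consistency check of input names, shapes and dtypes against the dataset:
# dataset.assert_valid_model_input(hyper["model"]["config"]["inputs"])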

Training hyperparameter

The kwargs for training simply set the arguments for model.compile(**kwargs) and model.fit(**kwargs), matching the keras arguments, as well as for the k-fold split from scikit-learn. The kwargs are expected to be fully serialized if the hyperparameters are to be saved to JSON.

[4]:
import keras as ks
hyper.update({
    "training":{
        # Cross-validation of the data
        "cross_validation": {
            "class_name": "KFold",
            "config": {"n_splits": 5, "random_state": 42, "shuffle": True}
        },
        # Standard scaler for regression targets
        "scaler": {
            "class_name": "StandardScaler",
            "module_name": "kgcnn.data.transform.scaler.standard",
            "config": {"with_std": True, "with_mean": True, "copy": True}
        },
        # Keras model compile and fit
        "compile": {
            "loss": "categorical_crossentropy",
            "optimizer": ks.saving.serialize_keras_object(
                ks.optimizers.Adam(learning_rate=0.001))
        },
        "fit": {
            "batch_size": 32, "epochs": 800, "verbose": 2,
            "callbacks": []
        }
    }
})
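
A sketch of how a training script may consume these entries (the splitter class comes from scikit-learn; the optimizer is restored with the deserialization counterpart of the serialize call above; the commented lines are placeholders, assuming a built model and loaded data):

import keras as ks
from sklearn.model_selection import KFold

# Rebuild the k-fold splitter from its serialized config.
kfold = KFold(**hyper["training"]["cross_validation"]["config"])

# Restore the serialized optimizer before compiling.
compile_kwargs = dict(hyper["training"]["compile"])
compile_kwargs["optimizer"] = ks.saving.deserialize_keras_object(compile_kwargs["optimizer"])
# model.compile(**compile_kwargs)
# model.fit(x_train, y_train, **hyper["training"]["fit"])  # placeholders for the split data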

Info

Some general information on the training run, such as the kgcnn version used or a postfix for the output files.

[5]:
hyper.update({
    "info":{ # Generla information
        "postfix": "_v1", # Appends _v1 to output folder
        "postfix_file": "_run2", # Appends _run2 to info files
        "kgcnn_version": "4.0.0"
    }
})

Benchmarks

[6]:
from IPython.display import Markdown, display
Markdown(open('../../training/results/README.md', encoding='utf-8').read())
[6]:

Summary of Benchmark Training

Note that these are the results for models within the kgcnn implementation, and that training is not always done with optimal hyperparameters or splits when comparing with the literature. This table is generated automatically from keras history logs. Model weights and training statistics plots are not uploaded to GitHub due to their file size.

Max. or Min. denotes the best test error observed for any epoch during training. To show the overall best test error, run python3 summary.py --min_max True. If not noted otherwise, we use a (fixed) random k-fold split for validation errors.

ClinToxDataset

ClinTox (MoleculeNet) consists of 1478 compounds as SMILES and data on drugs approved by the FDA and those that have failed clinical trials for toxicity reasons. We use random 5-fold cross-validation. The first label ‘approved’ is chosen as the target.

| model | kgcnn | epochs | Accuracy | AUC(ROC) |
| --- | --- | --- | --- | --- |
| DMPNN | 4.0.0 | 50 | 0.9480 ± 0.0138 | 0.8297 ± 0.0568 |
| GAT | 4.0.0 | 50 | 0.9480 ± 0.0070 | 0.8512 ± 0.0468 |
| GATv2 | 4.0.0 | 50 | 0.9372 ± 0.0155 | 0.8587 ± 0.0754 |
| GCN | 4.0.0 | 50 | 0.9432 ± 0.0155 | 0.8555 ± 0.0593 |
| GIN | 4.0.0 | 50 | 0.9412 ± 0.0034 | 0.8066 ± 0.0636 |
| GraphSAGE | 4.0.0 | 100 | 0.9412 ± 0.0073 | 0.8013 ± 0.0422 |
| Schnet | 4.0.0 | 50 | 0.9277 ± 0.0102 | 0.6562 ± 0.0760 |

CoraDataset

Cora dataset of 19793 publications with 8710 sparse node attributes and 70 node classes. Here we use random 5-fold cross-validation on nodes.

| model | kgcnn | epochs | Categorical accuracy |
| --- | --- | --- | --- |
| DMPNN | 4.0.0 | 300 | 0.2476 ± 0.1706 |
| GAT | 4.0.0 | 250 | 0.6157 ± 0.0071 |
| GATv2 | 4.0.0 | 1000 | 0.6211 ± 0.0048 |
| GCN | 4.0.0 | 300 | 0.6232 ± 0.0054 |
| GIN | 4.0.0 | 800 | 0.6263 ± 0.0080 |
| GraphSAGE | 4.0.0 | 600 | 0.6151 ± 0.0053 |

CoraLuDataset

Cora dataset after Lu et al. (2003) of 2708 publications with 1433 sparse attributes and 7 node classes. Here we use random 5-fold cross-validation on nodes.

| model | kgcnn | epochs | Categorical accuracy |
| --- | --- | --- | --- |
| DMPNN | 4.0.0 | 300 | 0.8357 ± 0.0156 |
| GAT | 4.0.0 | 250 | 0.8397 ± 0.0122 |
| GATv2 | 4.0.0 | 250 | 0.8331 ± 0.0104 |
| GCN | 4.0.0 | 300 | 0.8072 ± 0.0109 |
| GIN | 4.0.0 | 500 | 0.8279 ± 0.0170 |
| GraphSAGE | 4.0.0 | 500 | 0.8497 ± 0.0100 |

ESOLDataset

ESOL consists of 1128 compounds as SMILES and their corresponding water solubility in log10(mol/L). We use random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [log mol/L] | RMSE [log mol/L] |
| --- | --- | --- | --- | --- |
| AttentiveFP | 4.0.0 | 200 | 0.4351 ± 0.0110 | 0.6080 ± 0.0207 |
| DGIN | 4.0.0 | 300 | 0.4434 ± 0.0252 | 0.6225 ± 0.0420 |
| DMPNN | 4.0.0 | 300 | 0.4401 ± 0.0165 | 0.6203 ± 0.0292 |
| EGNN | 4.0.0 | 800 | 0.4507 ± 0.0152 | 0.6563 ± 0.0370 |
| GAT | 4.0.0 | 500 | 0.4818 ± 0.0240 | 0.6919 ± 0.0694 |
| GATv2 | 4.0.0 | 500 | 0.4598 ± 0.0234 | 0.6650 ± 0.0409 |
| GCN | 4.0.0 | 800 | 0.4613 ± 0.0205 | 0.6534 ± 0.0513 |
| GIN | 4.0.0 | 300 | 0.5369 ± 0.0334 | 0.7954 ± 0.0861 |
| GNNFilm | 4.0.0 | 800 | 0.4854 ± 0.0368 | 0.6724 ± 0.0436 |
| GraphSAGE | 4.0.0 | 500 | 0.4874 ± 0.0228 | 0.6982 ± 0.0608 |
| HDNNP2nd | 4.0.0 | 500 | 0.7857 ± 0.0986 | 1.0467 ± 0.1367 |
| INorp | 4.0.0 | 500 | 0.5055 ± 0.0436 | 0.7297 ± 0.0786 |
| MAT | 4.0.0 | 400 | 0.5064 ± 0.0299 | 0.7194 ± 0.0630 |
| MEGAN | 4.0.0 | 400 | 0.4281 ± 0.0201 | 0.6062 ± 0.0252 |
| Megnet | 4.0.0 | 800 | 0.5679 ± 0.0310 | 0.8196 ± 0.0480 |
| MXMNet | 4.0.0 | 900 | 0.6486 ± 0.0633 | 1.0123 ± 0.2059 |
| NMPN | 4.0.0 | 800 | 0.5046 ± 0.0266 | 0.7193 ± 0.0607 |
| PAiNN | 4.0.0 | 250 | 0.4857 ± 0.0598 | 0.6650 ± 0.0674 |
| RGCN | 4.0.0 | 800 | 0.4703 ± 0.0251 | 0.6529 ± 0.0318 |
| rGIN | 4.0.0 | 300 | 0.5196 ± 0.0351 | 0.7142 ± 0.0263 |
| Schnet | 4.0.0 | 800 | 0.4777 ± 0.0294 | 0.6977 ± 0.0538 |

FreeSolvDataset

FreeSolv (MoleculeNet) consists of 642 compounds as SMILES and their corresponding hydration free energy for small neutral molecules in water. We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [log mol/L] | RMSE [log mol/L] |
| --- | --- | --- | --- | --- |
| CMPNN | 4.0.0 | 600 | 0.5202 ± 0.0504 | 0.9339 ± 0.1286 |
| DGIN | 4.0.0 | 300 | 0.5489 ± 0.0374 | 0.9448 ± 0.0787 |
| DimeNetPP | 4.0.0 | 872 | 0.6167 ± 0.0719 | 1.0302 ± 0.1717 |
| DMPNN | 4.0.0 | 300 | 0.5487 ± 0.0754 | 0.9206 ± 0.1889 |
| EGNN | 4.0.0 | 800 | 0.5386 ± 0.0548 | 1.0363 ± 0.1237 |
| GAT | 4.0.0 | 500 | 0.6051 ± 0.0861 | 1.0326 ± 0.1819 |
| GATv2 | 4.0.0 | 500 | 0.6151 ± 0.0247 | 1.0535 ± 0.0817 |
| GCN | 4.0.0 | 800 | 0.6400 ± 0.0834 | 1.0876 ± 0.1393 |
| GIN | 4.0.0 | 300 | 0.8100 ± 0.1016 | 1.2695 ± 0.1192 |
| GNNFilm | 4.0.0 | 800 | 0.6562 ± 0.0552 | 1.1597 ± 0.1245 |
| GraphSAGE | 4.0.0 | 500 | 0.5894 ± 0.0675 | 1.0009 ± 0.1491 |
| HamNet | 4.0.0 | 400 | 0.6619 ± 0.0428 | 1.1410 ± 0.1120 |
| HDNNP2nd | 4.0.0 | 500 | 1.0201 ± 0.1559 | 1.6351 ± 0.3419 |
| INorp | 4.0.0 | 500 | 0.6612 ± 0.0188 | 1.1155 ± 0.1061 |
| MAT | 4.0.0 | 400 | 0.8115 ± 0.0649 | 1.3099 ± 0.1235 |
| MEGAN | 4.0.0 | 400 | 0.6303 ± 0.0550 | 1.0429 ± 0.1031 |
| Megnet | 4.0.0 | 800 | 0.8878 ± 0.0528 | 1.4134 ± 0.1200 |
| MoGAT | 4.0.0 | 200 | 0.7097 ± 0.0374 | 1.0911 ± 0.1334 |
| MXMNet | 4.0.0 | 900 | 1.1386 ± 0.1979 | 3.0487 ± 2.1757 |
| RGCN | 4.0.0 | 800 | 0.5128 ± 0.0810 | 0.9228 ± 0.1887 |
| rGIN | 4.0.0 | 300 | 0.8503 ± 0.0613 | 1.3285 ± 0.0976 |
| Schnet | 4.0.0 | 800 | 0.6070 ± 0.0285 | 1.0603 ± 0.0549 |

ISO17Dataset

The database consists of 129 molecules, each containing 5000 conformational geometries, energies and forces with a resolution of 1 femtosecond in the molecular dynamics trajectories. The molecules were randomly drawn from the largest set of isomers in the QM9 dataset.

| model | kgcnn | epochs | Energy (test_within) | Force (test_within) |
| --- | --- | --- | --- | --- |
| Schnet.EnergyForceModel | 4.0.0 | 1000 | 0.0061 ± nan | 0.0134 ± nan |

LipopDataset

Lipophilicity (MoleculeNet) consists of 4200 compounds as SMILES. Graph labels for regression are the octanol/water distribution coefficient (logD at pH 7.4). We use random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [log mol/L] | RMSE [log mol/L] |
| --- | --- | --- | --- | --- |
| DMPNN | 4.0.0 | 300 | 0.3814 ± 0.0064 | 0.5462 ± 0.0095 |
| GAT | 4.0.0 | 500 | 0.5168 ± 0.0088 | 0.7220 ± 0.0098 |
| GATv2 | 4.0.0 | 500 | 0.4342 ± 0.0104 | 0.6056 ± 0.0114 |
| GCN | 4.0.0 | 800 | 0.4960 ± 0.0107 | 0.6833 ± 0.0155 |
| GIN | 4.0.0 | 300 | 0.4745 ± 0.0101 | 0.6658 ± 0.0159 |
| GraphSAGE | 4.0.0 | 500 | 0.4333 ± 0.0217 | 0.6218 ± 0.0318 |
| Schnet | 4.0.0 | 800 | 0.5657 ± 0.0202 | 0.7485 ± 0.0245 |

MD17Dataset

Energies and forces for molecular dynamics trajectories of eight organic molecules. All geometries are in Å, energy labels in kcal/mol and force labels in kcal/mol/Å. We use the preset train-test split: training on 1000 geometries, testing on 500/1000 geometries. Errors are the MAE for forces. Results are for the CCSD and CCSD(T) data in MD17.

| model | kgcnn | epochs | Aspirin | Toluene | Malonaldehyde | Benzene | Ethanol |
| --- | --- | --- | --- | --- | --- | --- | --- |
| PAiNN.EnergyForceModel | 4.0.0 | 1000 | nan ± nan | nan ± nan | nan ± nan | nan ± nan | 0.5805 ± nan |
| Schnet.EnergyForceModel | 4.0.0 | 1000 | 1.2173 ± nan | 0.7395 ± nan | 0.8444 ± nan | 0.3353 ± nan | 0.4832 ± nan |

MD17RevisedDataset

Energies and forces for molecular dynamics trajectories. All geometries are in Å, energy labels in kcal/mol and force labels in kcal/mol/Å. We use the preset train-test split: training on 1000 geometries, testing on 500/1000 geometries. Errors are the MAE for forces.

| model | kgcnn | epochs | Aspirin | Toluene | Malonaldehyde | Benzene | Ethanol |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Schnet.EnergyForceModel | 4.0.0 | 1000 | 1.0389 ± 0.0071 | 0.5482 ± 0.0105 | 0.6727 ± 0.0132 | 0.2525 ± 0.0091 | 0.4471 ± 0.0199 |

MatProjectDielectricDataset

Materials Project dataset from Matbench with 4764 crystal structures and their corresponding refractive index (unitless). We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [no unit] | RMSE [no unit] |
| --- | --- | --- | --- | --- |
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.3180 ± 0.0359 | 1.8509 ± 0.5854 |

MatProjectEFormDataset

Materials Project dataset from Matbench with 132752 crystal structures and their corresponding formation energy in [eV/atom]. We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [eV/atom] | RMSE [eV/atom] |
| --- | --- | --- | --- | --- |
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.0211 ± 0.0003 | 0.0510 ± 0.0024 |

MatProjectGapDataset

Materials Project dataset from Matbench with 106113 crystal structures and their band gap as calculated by PBE DFT from the Materials Project, in eV. We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [eV] | RMSE [eV] |
| --- | --- | --- | --- | --- |
| Schnet.make_crystal_model | 4.0.0 | 800 | 1.2226 ± 1.0573 | 58.3713 ± 114.2957 |

MatProjectIsMetalDataset

Materials Project dataset from Matbench with 106113 crystal structures and their corresponding metallicity determined with pymatgen: 1 if the compound is a metal, 0 if it is not. We use a random 5-fold cross-validation.

| model | kgcnn | epochs | Accuracy | AUC |
| --- | --- | --- | --- | --- |
| Schnet.make_crystal_model | 4.0.0 | 80 | 0.8953 ± 0.0058 | 0.9506 ± 0.0053 |

MatProjectJdft2dDataset

Materials Project dataset from Matbench with 636 crystal structures and their corresponding exfoliation energy (meV/atom). We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [meV/atom] | RMSE [meV/atom] |
| --- | --- | --- | --- | --- |
| CGCNN.make_crystal_model | 4.0.0 | 1000 | 57.6974 ± 18.0803 | 140.6167 ± 44.8418 |
| DimeNetPP.make_crystal_model | 4.0.0 | 780 | 50.2880 ± 11.4199 | 126.0600 ± 38.3769 |
| PAiNN.make_crystal_model | 4.0.0 | 800 | 49.3889 ± 11.5376 | 121.7087 ± 30.0472 |
| Schnet.make_crystal_model | 4.0.0 | 800 | 45.2412 ± 11.6395 | 115.6890 ± 39.0929 |

MatProjectLogGVRHDataset

Materials Project dataset from Matbench with 10987 crystal structures and their corresponding base-10 logarithm of the DFT Voigt-Reuss-Hill average shear moduli in GPa. We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [log(GPa)] | RMSE [log(GPa)] |
| --- | --- | --- | --- | --- |
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.0836 ± 0.0021 | 0.1296 ± 0.0044 |

MatProjectLogKVRHDataset

Materials Project dataset from Matbench with 10987 crystal structures and their corresponding base-10 logarithm of the DFT Voigt-Reuss-Hill average bulk moduli in GPa. We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [log(GPa)] | RMSE [log(GPa)] |
| --- | --- | --- | --- | --- |
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.0635 ± 0.0016 | 0.1186 ± 0.0044 |

MatProjectPerovskitesDataset

Materials Project dataset from Matbench with 18928 crystal structures and their corresponding heat of formation of the entire 5-atom perovskite cell in eV. We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [eV] | RMSE [eV] |
| --- | --- | --- | --- | --- |
| Schnet.make_crystal_model | 4.0.0 | 800 | 0.0381 ± 0.0005 | 0.0645 ± 0.0024 |

MatProjectPhononsDataset

Materials Project dataset from Matbench with 1265 crystal structures and their corresponding vibration properties in [1/cm]. We use a random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [1/cm] | RMSE [1/cm] |
| --- | --- | --- | --- | --- |
| Schnet.make_crystal_model | 4.0.0 | 800 | 43.0692 ± 3.6227 | 88.5151 ± 20.0244 |

MUTAGDataset

MUTAG dataset from TUDataset for classification with 188 graphs. We use random 5-fold cross-validation.

| model | kgcnn | epochs | Accuracy | AUC(ROC) |
| --- | --- | --- | --- | --- |
| DMPNN | 4.0.0 | 300 | 0.8407 ± 0.0463 | 0.8567 ± 0.0511 |
| GAT | 4.0.0 | 500 | 0.8141 ± 0.1077 | 0.8671 ± 0.0923 |
| GATv2 | 4.0.0 | 500 | 0.8193 ± 0.0945 | 0.8379 ± 0.1074 |
| GCN | 4.0.0 | 800 | 0.7716 ± 0.0531 | 0.7956 ± 0.0909 |
| GIN | 4.0.0 | 300 | 0.8091 ± 0.0781 | 0.8693 ± 0.0855 |
| GraphSAGE | 4.0.0 | 500 | 0.8357 ± 0.0798 | 0.8533 ± 0.0824 |

MutagenicityDataset

Mutagenicity dataset from TUDataset for classification with 4337 graphs. The dataset was cleaned for unconnected atoms. We use random 5-fold cross-validation.

| model | kgcnn | epochs | Accuracy | AUC(ROC) |
| --- | --- | --- | --- | --- |
| DMPNN | 4.0.0 | 300 | 0.8266 ± 0.0059 | 0.8708 ± 0.0076 |
| GAT | 4.0.0 | 500 | 0.7989 ± 0.0114 | 0.8290 ± 0.0112 |
| GATv2 | 4.0.0 | 200 | 0.7674 ± 0.0048 | 0.8423 ± 0.0064 |
| GCN | 4.0.0 | 800 | 0.7955 ± 0.0154 | 0.8191 ± 0.0137 |
| GIN | 4.0.0 | 300 | 0.8118 ± 0.0091 | 0.8492 ± 0.0077 |
| GraphSAGE | 4.0.0 | 500 | 0.8195 ± 0.0126 | 0.8515 ± 0.0083 |

PROTEINSDataset

TUDataset of proteins that are classified as enzymes or non-enzymes. Nodes represent the amino acids of the protein. We use random 5-fold cross-validation.

| model | kgcnn | epochs | Accuracy | AUC(ROC) |
| --- | --- | --- | --- | --- |
| DMPNN | 4.0.0 | 300 | 0.7287 ± 0.0253 | 0.7970 ± 0.0343 |
| GAT | 4.0.0 | 500 | 0.7314 ± 0.0357 | 0.7899 ± 0.0468 |
| GATv2 | 4.0.0 | 500 | 0.6720 ± 0.0595 | 0.6850 ± 0.0938 |
| GCN | 4.0.0 | 800 | 0.7017 ± 0.0303 | 0.7211 ± 0.0254 |
| GIN | 4.0.0 | 150 | 0.7224 ± 0.0343 | 0.7905 ± 0.0528 |
| GraphSAGE | 4.0.0 | 500 | 0.7009 ± 0.0398 | 0.7263 ± 0.0453 |

QM7Dataset

The QM7 dataset is a subset of GDB-13, with molecules of up to 23 atoms (including 7 heavy atoms: C, N, O, and S), totalling 7165 molecules. We use the dataset-specific 5-fold cross-validation. The atomization energies are given in kcal/mol and range from -800 to -2000 kcal/mol.

| model | kgcnn | epochs | MAE [kcal/mol] | RMSE [kcal/mol] |
| --- | --- | --- | --- | --- |
| Schnet | 4.0.0 | 800 | 3.4313 ± 0.4757 | 10.8978 ± 7.3863 |

QM9Dataset

QM9 dataset of 134k stable small organic molecules made up of C, H, O, N and F. Labels include geometric, energetic, electronic, and thermodynamic properties. We use random 5-fold cross-validation, but not all splits are evaluated, for cheaper evaluation. Test errors are MAE and are given in [eV] for energies.

| model | kgcnn | epochs | HOMO [eV] | LUMO [eV] | U0 [eV] | H [eV] | G [eV] |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Schnet | 4.0.0 | 800 | 0.0402 ± 0.0007 | 0.0340 ± 0.0001 | 0.0142 ± 0.0002 | 0.0146 ± 0.0002 | 0.0143 ± 0.0002 |

SIDERDataset

SIDER (MoleculeNet) consists of 1427 compounds as SMILES and data on adverse drug reactions (ADR), grouped into 27 system organ classes. We use random 5-fold cross-validation.

| model | kgcnn | epochs | Accuracy | AUC(ROC) |
| --- | --- | --- | --- | --- |
| DMPNN | 4.0.0 | 50 | 0.7519 ± 0.0055 | 0.6280 ± 0.0173 |
| GAT | 4.0.0 | 50 | 0.7595 ± 0.0034 | 0.6224 ± 0.0106 |
| GATv2 | 4.0.0 | 50 | 0.7548 ± 0.0052 | 0.6152 ± 0.0154 |
| GIN | 4.0.0 | 50 | 0.7472 ± 0.0055 | 0.5995 ± 0.0058 |
| GraphSAGE | 4.0.0 | 30 | 0.7547 ± 0.0043 | 0.6038 ± 0.0108 |
| Schnet | 4.0.0 | 50 | 0.7583 ± 0.0076 | 0.6119 ± 0.0159 |

Tox21MolNetDataset

Tox21 (MoleculeNet) consists of 7831 compounds as SMILES and 12 different targets relevant to drug toxicity. We use random 5-fold cross-validation.

| model | kgcnn | epochs | Accuracy | AUC(ROC) | BACC |
| --- | --- | --- | --- | --- | --- |
| DMPNN | 4.0.0 | 50 | 0.9272 ± 0.0024 | 0.8321 ± 0.0103 | 0.6995 ± 0.0130 |
| GAT | 4.0.0 | 50 | 0.9243 ± 0.0022 | 0.8279 ± 0.0092 | 0.6504 ± 0.0074 |
| GATv2 | 4.0.0 | 50 | 0.9222 ± 0.0019 | 0.8251 ± 0.0069 | 0.6760 ± 0.0140 |
| GIN | 4.0.0 | 50 | 0.9220 ± 0.0024 | 0.7986 ± 0.0180 | 0.6741 ± 0.0143 |
| GraphSAGE | 4.0.0 | 100 | 0.9180 ± 0.0027 | 0.7976 ± 0.0087 | 0.6755 ± 0.0047 |
| Schnet | 4.0.0 | 50 | 0.9128 ± 0.0030 | 0.7719 ± 0.0139 | 0.6639 ± 0.0162 |

NOTE: You can find this page as a Jupyter notebook at https://github.com/aimat-lab/gcnn_keras/tree/master/docs/source