Source code for kgcnn.literature.GNNExplain._model

import keras as ks
from keras import ops
from kgcnn.ops.core import norm


class GNNInterface:
    """An interface class which should be implemented by a Graph Neural Network (GNN) model to make it explainable.

    This class is just an interface, which is used by the `GNNExplainer` and should be implemented in a subclass.
    The implementation of this class could be a wrapper around an existing Tensorflow/Keras GNN.
    The output of the methods `predict` and `masked_predict` should be of the same dimension as the output
    to be explained.
    """
    def predict(self, gnn_input, **kwargs):
        """Returns the prediction for the `gnn_input`.

        Args:
            gnn_input: The input graph for which the GNN should make a prediction.

        Raises:
            NotImplementedError: This is just an interface class, to indicate which methods should be implemented.
                Implement this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")
    def masked_predict(self, gnn_input, edge_mask, feature_mask, node_mask, **kwargs):
        """Returns the prediction for the `gnn_input` when it is masked by the three given masks.

        Args:
            gnn_input: The input graph which should be masked before the GNN makes a prediction.
            edge_mask: A `Tensor` of shape `[get_number_of_edges(self, gnn_input), 1]`,
                which should mask the edges of the input graph.
            feature_mask: A `Tensor` of shape `[get_number_of_node_features(self, gnn_input), 1]`,
                which should mask the node features in the input graph.
            node_mask: A `Tensor` of shape `[get_number_of_nodes(self, gnn_input), 1]`,
                which should mask the nodes of the input graph.

        Raises:
            NotImplementedError: This is just an interface class, to indicate which methods should be implemented.
                Implement this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")
    def get_number_of_nodes(self, gnn_input):
        """Returns the number of nodes in the `gnn_input` graph.

        Args:
            gnn_input: The input graph for which this function returns the number of nodes.

        Raises:
            NotImplementedError: This is just an interface class, to indicate which methods should be implemented.
                Implement this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")
    def get_number_of_edges(self, gnn_input):
        """Returns the number of edges in the `gnn_input` graph.

        Args:
            gnn_input: The input graph for which this function returns the number of edges.

        Raises:
            NotImplementedError: This is just an interface class, to indicate which methods should be implemented.
                Implement this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")
    def get_number_of_node_features(self, gnn_input):
        """Returns the number of node features for the corresponding `gnn_input`.

        Args:
            gnn_input: The input graph for which this function returns the number of node features.

        Raises:
            NotImplementedError: This is just an interface class, to indicate which methods should be implemented.
                Implement this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")
    def get_explanation(self, gnn_input, edge_mask, feature_mask, node_mask, **kwargs):
        """Takes the graph input and the masks learned by the GNNExplainer and combines them to some sort of explanation.

        The form of explanation could e.g. consist of a networkx graph, which has mask values as labels to
        nodes/edges, and a dict for the feature explanation values.

        Args:
            gnn_input: The input graph for which the masks were found by the GNNExplainer.
            edge_mask: A `Tensor` of shape `[get_number_of_edges(self, gnn_input), 1]`,
                which was found by the GNNExplainer.
            feature_mask: A `Tensor` of shape `[get_number_of_node_features(self, gnn_input), 1]`,
                which was found by the GNNExplainer.
            node_mask: A `Tensor` of shape `[get_number_of_nodes(self, gnn_input), 1]`,
                which was found by the GNNExplainer.

        Raises:
            NotImplementedError: This is just an interface class, to indicate which methods should be implemented.
                Implement this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")
    def present_explanation(self, explanation, **kwargs):
        """Takes an explanation, which was generated by `get_explanation`, and presents it to the user in a suitable way.

        The presentation of an explanation largely depends on the data domain and targeted user group.
        Examples for presentations:

        * A visualization of the most relevant subgraph(s) to the decision
        * A visualization of the whole graph with highlighted parts
        * Bar diagrams for feature explanations
        * ...

        Args:
            explanation: An explanation for the GNN decision, in the form returned by the `get_explanation` method.

        Raises:
            NotImplementedError: This is just an interface class, to indicate which methods should be implemented.
                Implement this method in a subclass.
        """
        raise NotImplementedError("Implement this method in a specific subclass.")
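
# Example (illustrative sketch): how the interface above could be implemented for a hypothetical
# Keras GNN `my_gnn` that takes the dense tensors `(node_features, edge_indices, edge_features)`
# as a single graph input. The class name, the assumed input format and the way the masks are
# applied are assumptions for illustration only; a real wrapper has to match the wrapped model.
class ExampleDenseGNNWrapper(GNNInterface):

    def __init__(self, my_gnn):
        self.my_gnn = my_gnn

    def predict(self, gnn_input, **kwargs):
        node_features, edge_indices, edge_features = gnn_input
        return self.my_gnn([node_features, edge_indices, edge_features], **kwargs)

    def masked_predict(self, gnn_input, edge_mask, feature_mask, node_mask, **kwargs):
        # Apply the soft masks multiplicatively before calling the wrapped model:
        # node_mask scales whole nodes, feature_mask scales feature columns, edge_mask scales edges.
        node_features, edge_indices, edge_features = gnn_input
        masked_nodes = node_features * node_mask * ops.transpose(feature_mask)
        masked_edges = edge_features * edge_mask
        return self.my_gnn([masked_nodes, edge_indices, masked_edges], **kwargs)

    def get_number_of_nodes(self, gnn_input):
        return gnn_input[0].shape[0]

    def get_number_of_edges(self, gnn_input):
        return gnn_input[1].shape[0]

    def get_number_of_node_features(self, gnn_input):
        return gnn_input[0].shape[-1]

    def get_explanation(self, gnn_input, edge_mask, feature_mask, node_mask, **kwargs):
        # Return the raw mask values; a richer implementation could build e.g. a networkx graph.
        return {"edge_mask": ops.convert_to_numpy(edge_mask),
                "feature_mask": ops.convert_to_numpy(feature_mask),
                "node_mask": ops.convert_to_numpy(node_mask)}

    def present_explanation(self, explanation, **kwargs):
        return explanation
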
class GNNExplainer:
    """`GNNExplainer` explains the decisions of a GNN, which implements `GNNInterface`.

    See Ying et al. (https://arxiv.org/abs/1903.03894) for details on how such an explanation is found.
    Note that this implementation is inspired by the paper by Ying et al., but differs in some aspects.
    """
    def __init__(self, gnn, gnnexplaineroptimizer_options=None, compile_options=None, fit_options=None, **kwargs):
        """Constructs a GNNExplainer instance for the given `gnn`.

        Args:
            gnn: An instance of a class which implements the `GNNInterface`.
            gnnexplaineroptimizer_options (dict, optional): Parameters in this dict are forwarded to the constructor
                of the `GNNExplainerOptimizer` (see docstring of `GNNExplainerOptimizer.__init__`). Defaults to {}.
            compile_options (dict, optional): Parameters in this dict are forwarded to the `keras.Model.compile`
                method of the `GNNExplainerOptimizer`. Can be used to customize the optimization process of the
                `GNNExplainerOptimizer`. Defaults to {}.
            fit_options (dict, optional): Parameters in this dict are forwarded to the `keras.Model.fit` method
                of the `GNNExplainerOptimizer`. Defaults to {}.
        """
        if gnnexplaineroptimizer_options is None:
            gnnexplaineroptimizer_options = {}
        if compile_options is None:
            compile_options = {}
        if fit_options is None:
            fit_options = {}
        self.gnn = gnn
        self.gnnx_optimizer = None
        self.graph_instance = None
        self.gnnexplaineroptimizer_options = gnnexplaineroptimizer_options
        # We need to save options as serialized version to recreate the optimizer on multiple explain calls.
        if "optimizer" in compile_options:
            if isinstance(compile_options["optimizer"], ks.optimizers.Optimizer):
                compile_options["optimizer"] = ks.saving.serialize_keras_object(compile_options["optimizer"])
        self.compile_options = compile_options
        self.fit_options = fit_options
    def explain(self, graph_instance, output_to_explain=None, inspection=False, **kwargs):
        """Finds the masks to the decision of the `self.gnn` on the given `graph_instance`.

        Unless `inspection` is set to True, this method does not have a return value, but only has side effects.
        To get the explanation which was found, call `get_explanation` after calling this method.
        This method just instantiates a `GNNExplainerOptimizer`, which then finds the masks for the explanation
        via gradient descent.

        Args:
            graph_instance: The graph input to the GNN for which an explanation should be found.
            output_to_explain (optional): Set this parameter to the output which should be explained.
                By default the GNNExplainer explains the output of the `self.gnn` on the given `graph_instance`.
                Defaults to None.
            inspection (optional): If `inspection` is set to True this function will return information about
                the optimization process in a dictionary form. Be aware that inspection results in longer
                runtimes. Defaults to False.
        """
        self.graph_instance = graph_instance

        # Add inspection callback to fit options, if inspection is True
        fit_options = self.fit_options.copy()
        if inspection:
            inspection_callback = self.InspectionCallback(self.graph_instance)
            if 'callbacks' in self.fit_options.keys():
                fit_options['callbacks'].append(inspection_callback)
            else:
                fit_options['callbacks'] = [inspection_callback]

        # Set up GNNExplainerOptimizer and optimize with respect to masks
        gnnx_optimizer = GNNExplainerOptimizer(
            self.gnn, graph_instance, **self.gnnexplaineroptimizer_options)
        self.gnnx_optimizer = gnnx_optimizer

        if output_to_explain is not None:
            if gnnx_optimizer._output_to_explain_as_variable:
                gnnx_optimizer.output_to_explain.assign(output_to_explain)
            else:
                gnnx_optimizer.output_to_explain = output_to_explain

        gnnx_optimizer.compile(**self.compile_options)
        gnnx_optimizer.fit(x=graph_instance, y=gnnx_optimizer.output_to_explain, **fit_options)

        # Read out information from inspection_callback
        if inspection:
            dict_fields = ['predictions', 'total_loss', 'edge_mask_loss', 'feature_mask_loss', 'node_mask_loss']
            inspection_information = {}
            for field in dict_fields:
                if hasattr(inspection_callback, field) and len(getattr(inspection_callback, field)) > 0:
                    inspection_information[field] = getattr(inspection_callback, field)
            return inspection_information
    def get_explanation(self, **kwargs):
        """Returns the explanation (derived from the learned masks) to a decision on the graph,
        which was passed to the `explain` method before.

        Important: The `explain` method should always be called before calling this method.

        Internally this method just calls the `GNNInterface.get_explanation` method implemented by the `self.gnn`
        with the masks found by the `GNNExplainerOptimizer` as parameters.

        Raises:
            Exception: If the `explain` method is not called before, this method raises an Exception.

        Returns:
            The explanation which is returned by `GNNInterface.get_explanation` implemented by the `self.gnn`,
            parametrized by the learned masks.
        """
        if self.graph_instance is None or self.gnnx_optimizer is None:
            raise Exception("You must first call explain on the GNNExplainer instance.")
        edge_mask = self.gnnx_optimizer.get_mask("edge")
        feature_mask = self.gnnx_optimizer.get_mask("feature")
        node_mask = self.gnnx_optimizer.get_mask("node")
        return self.gnn.get_explanation(self.graph_instance, edge_mask, feature_mask, node_mask, **kwargs)
    def present_explanation(self, explanation, **kwargs):
        """Takes an explanation, which was generated by `get_explanation`, and presents it.

        Internally this method just calls the `GNNInterface.present_explanation` method implemented by the `self.gnn`.

        Args:
            explanation: The explanation (obtained by `get_explanation`) which should be presented.

        Returns:
            A presentation of the given explanation.
        """
        return self.gnn.present_explanation(explanation, **kwargs)
    class InspectionCallback(ks.callbacks.Callback):
        """Callback class to get the inspection information, if 'inspection' is set to true
        for the 'GNNExplainer.explain' method."""

        def __init__(self, graph_instance):
            super(GNNExplainer.InspectionCallback, self).__init__()
            self.graph_instance = graph_instance
            self.predictions = []
            self.total_loss = []
            self.edge_mask_loss = []
            self.feature_mask_loss = []
            self.node_mask_loss = []
        def on_epoch_begin(self, epoch, logs=None):
            masked = ops.convert_to_numpy(self.model.call(self.graph_instance))[0]
            self.predictions.append(masked)
        def on_epoch_end(self, epoch, logs=None):
            """After epoch."""
            if self.model.edge_mask_loss_weight > 0:
                self.edge_mask_loss.append(ops.convert_to_numpy(self.model._metric_edge_tracker.result()))
                self.model._metric_edge_tracker.reset_state()
            if self.model.feature_mask_loss_weight > 0:
                self.feature_mask_loss.append(ops.convert_to_numpy(self.model._metric_feature_tracker.result()))
                self.model._metric_feature_tracker.reset_state()
            if self.model.node_mask_loss_weight > 0:
                self.node_mask_loss.append(ops.convert_to_numpy(self.model._metric_node_tracker.result()))
                self.model._metric_node_tracker.reset_state()
            self.total_loss.append(logs['loss'])
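
# Example (illustrative only): typical usage of the GNNExplainer defined above, assuming `wrapped_gnn`
# implements the GNNInterface and `graph` is a single input graph in the format expected by that wrapper.
# The chosen optimizer, loss, learning rate and epoch count below are assumptions, not prescribed defaults.
#
#     explainer = GNNExplainer(
#         wrapped_gnn,
#         gnnexplaineroptimizer_options={"edge_mask_loss_weight": 1e-3},
#         compile_options={"optimizer": ks.optimizers.Adam(learning_rate=0.1), "loss": "mean_squared_error"},
#         fit_options={"epochs": 100, "batch_size": 1, "verbose": 0},
#     )
#     inspection_info = explainer.explain(graph, inspection=True)
#     explanation = explainer.get_explanation()
#     explainer.present_explanation(explanation)
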
class GNNExplainerOptimizer(ks.Model):
    """The `GNNExplainerOptimizer` solves the optimization problem which is used to find masks,
    which then can be used to explain decisions by GNNs.
    """

    _output_to_explain_as_variable = False
    def __init__(self, gnn_model, graph_instance,
                 edge_mask_loss_weight=1e-4,
                 edge_mask_norm_ord=1,
                 feature_mask_loss_weight=1e-4,
                 feature_mask_norm_ord=1,
                 node_mask_loss_weight=0.0,
                 node_mask_norm_ord=1,
                 **kwargs):
        """Constructs a `GNNExplainerOptimizer` instance with the given parameters.

        Args:
            gnn_model (GNNInterface): An instance of a class which implements the methods of the `GNNInterface`.
            graph_instance: The graph for which the masks should be found.
            edge_mask_loss_weight (float, optional): The weight of the edge mask loss term in the optimization
                problem. Defaults to 1e-4.
            edge_mask_norm_ord (float, optional): The norm p value for the p-norm, which is used on the edge mask.
                Smaller values encourage sparser masks. Defaults to 1.
            feature_mask_loss_weight (float, optional): The weight of the feature mask loss term in the optimization
                problem. Defaults to 1e-4.
            feature_mask_norm_ord (float, optional): The norm p value for the p-norm, which is used on the feature
                mask. Smaller values encourage sparser masks. Defaults to 1.
            node_mask_loss_weight (float, optional): The weight of the node mask loss term in the optimization
                problem. Defaults to 0.0.
            node_mask_norm_ord (float, optional): The norm p value for the p-norm, which is used on the node mask.
                Smaller values encourage sparser masks. Defaults to 1.
        """
        super(GNNExplainerOptimizer, self).__init__(**kwargs)
        self.gnn_model = gnn_model
        self._metric_node_tracker = ks.metrics.Mean(name="mask_loss")
        self._metric_edge_tracker = ks.metrics.Mean(name="mask_loss")
        self._metric_feature_tracker = ks.metrics.Mean(name="mask_loss")
        self._edge_mask_dim = self.gnn_model.get_number_of_edges(graph_instance)
        self._feature_mask_dim = self.gnn_model.get_number_of_node_features(graph_instance)
        self._node_mask_dim = self.gnn_model.get_number_of_nodes(graph_instance)
        self.edge_mask = self.add_weight(
            name='edge_mask',
            shape=(self._edge_mask_dim, 1),
            initializer=ks.initializers.Constant(value=5.),
            dtype=self.dtype,
            trainable=True
        )
        self.feature_mask = self.add_weight(
            name='feature_mask',
            shape=(self._feature_mask_dim, 1),
            initializer=ks.initializers.Constant(value=5.),
            dtype=self.dtype,
            trainable=True
        )
        self.node_mask = self.add_weight(
            name='node_mask',
            shape=(self._node_mask_dim, 1),
            initializer=ks.initializers.Constant(value=5.),
            dtype=self.dtype,
            trainable=True
        )
        output_to_explain = gnn_model.predict(graph_instance)
        if self._output_to_explain_as_variable:
            self.output_to_explain = self.add_weight(
                name='output_to_explain',
                shape=output_to_explain.shape,
                initializer=ks.initializers.Constant(0.),
                dtype=output_to_explain.dtype,
                trainable=False
            )
            self.output_to_explain.assign(output_to_explain)
        else:
            self.output_to_explain = output_to_explain

        # Configuration Parameters
        self.edge_mask_loss_weight = edge_mask_loss_weight
        self.edge_mask_norm_ord = edge_mask_norm_ord
        self.feature_mask_loss_weight = feature_mask_loss_weight
        self.feature_mask_norm_ord = feature_mask_norm_ord
        self.node_mask_loss_weight = node_mask_loss_weight
        self.node_mask_norm_ord = node_mask_norm_ord
    def call(self, graph_input, training: bool = False, **kwargs):
        """Call GNN model.

        Args:
            graph_input: Graph input.
            training (bool): If training mode. Default is False.

        Returns:
            Tensor: Masked prediction of GNN model.
        """
        edge_mask = self.get_mask("edge")
        feature_mask = self.get_mask("feature")
        node_mask = self.get_mask("node")
        y_pred = self.gnn_model.masked_predict(graph_input, edge_mask, feature_mask, node_mask, training=training)
        # edge_mask loss
        if self.edge_mask_loss_weight > 0:
            loss = norm(ops.sigmoid(self.edge_mask), ord=self.edge_mask_norm_ord) * self.edge_mask_loss_weight
            self.add_loss(loss)
            self._metric_edge_tracker.update_state([loss])
        # feature_mask loss
        if self.feature_mask_loss_weight > 0:
            loss = norm(ops.sigmoid(self.feature_mask), ord=self.feature_mask_norm_ord) * self.feature_mask_loss_weight
            self.add_loss(loss)
            self._metric_feature_tracker.update_state([loss])
        # node_mask loss
        if self.node_mask_loss_weight > 0:
            loss = norm(ops.sigmoid(self.node_mask), ord=self.node_mask_norm_ord) * self.node_mask_loss_weight
            self.add_loss(loss)
            self._metric_node_tracker.update_state([loss])
        return y_pred
    def get_mask(self, mask_identifier):
        if mask_identifier == "edge":
            return self._get_mask(self.edge_mask, self.edge_mask_loss_weight)
        elif mask_identifier == "feature":
            return self._get_mask(self.feature_mask, self.feature_mask_loss_weight)
        elif mask_identifier == "node":
            return self._get_mask(self.node_mask, self.node_mask_loss_weight)
        raise Exception("mask_identifier must be 'edge', 'feature' or 'node'")
    def _get_mask(self, mask, weight):
        if weight > 0:
            return ops.sigmoid(mask)
        return ops.ones_like(mask)
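
# Summary of the objective solved above (for reference): per optimization step the GNNExplainerOptimizer
# minimizes the compiled prediction loss between the masked prediction and `output_to_explain`, plus the
# regularizers added in `call`:
#
#     L_total = L_pred(masked_predict(x, sigmoid(M_e), sigmoid(M_f), sigmoid(M_n)), output_to_explain)
#               + edge_mask_loss_weight    * ||sigmoid(M_e)||_p
#               + feature_mask_loss_weight * ||sigmoid(M_f)||_q
#               + node_mask_loss_weight    * ||sigmoid(M_n)||_r
#
# where M_e, M_f, M_n are the trainable mask logits and p, q, r the configured norm orders; masks whose
# loss weight is 0 are fixed to all-ones. The logits are initialized to 5.0, i.e. sigmoid(5.0) ≈ 0.993,
# so optimization starts from an (almost) unmasked graph and gradually suppresses parts of the input
# that are unimportant for the explained output.
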