kgcnn.literature.GNNExplain package

Module contents

class kgcnn.literature.GNNExplain.GNNExplainer(gnn, gnnexplaineroptimizer_options=None, compile_options=None, fit_options=None, **kwargs)

Bases: object

GNNExplainer explains the decisions of a GNN that implements the GNNInterface.

See Ying et al. (https://arxiv.org/abs/1903.03894) for details on how such an explanation is found. Note that this implementation is inspired by the paper by Ying et al., but differs in some aspects.

class InspectionCallback(graph_instance)

Bases: keras.src.callbacks.callback.Callback

Callback class that collects the inspection information when ‘inspection’ is set to True for the ‘GNNExplainer.explain’ method.

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.
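The epoch hooks above follow the usual Keras callback pattern. The plain-Python sketch below mimics that pattern with a hypothetical ToyInspectionCallback that appends per-epoch values to a dictionary, driven by a hypothetical toy_fit loop. The recorded fields (epoch and loss) are assumptions for illustration; the exact inspection information collected by InspectionCallback is not documented on this page, and the real class subclasses the Keras Callback rather than a plain object.

```python
class ToyInspectionCallback:
    """Hypothetical stand-in that mimics the Keras Callback epoch hooks."""

    def __init__(self):
        # Assumed layout of the collected inspection information.
        self.inspection_information = {"epoch": [], "loss": []}

    def on_epoch_begin(self, epoch, logs=None):
        pass  # nothing is recorded before an epoch in this sketch

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.inspection_information["epoch"].append(epoch)
        self.inspection_information["loss"].append(logs.get("loss"))


def toy_fit(num_epochs, callbacks):
    """Hypothetical training loop that calls the hooks the way Keras fit does."""
    loss = 1.0
    for epoch in range(num_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        loss *= 0.5  # stand-in for one epoch of gradient descent on the masks
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={"loss": loss})


cb = ToyInspectionCallback()
toy_fit(3, [cb])
print(cb.inspection_information)  # {'epoch': [0, 1, 2], 'loss': [0.5, 0.25, 0.125]}
```

Collecting values at every epoch is also why inspection makes the optimization slower.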

__init__(gnn, gnnexplaineroptimizer_options=None, compile_options=None, fit_options=None, **kwargs)

Constructs a GNNExplainer instance for the given gnn.

Parameters
  • gnn – An instance of a class which implements the GNNInterface.

  • gnnexplaineroptimizer_options (dict, optional) – Parameters in this dict are forwarded to the constructor of the GNNExplainerOptimizer (see docstring of GNNExplainerOptimizer.__init__). Defaults to {}.

  • compile_options (dict, optional) – Parameters in this dict are forwarded to the keras.Model.compile method of the GNNExplainerOptimizer. Can be used to customize the optimization process of the GNNExplainerOptimizer. Defaults to {}.

  • fit_options (dict, optional) – Parameters in this dict are forwarded to the keras.Model.fit method of the GNNExplainerOptimizer. Defaults to {}.
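The forwarding described above can be pictured with a minimal sketch. ToyOptimizer and run_with_options are hypothetical stand-ins written for this illustration (the real GNNExplainerOptimizer is a keras Model); the sketch only shows how each dict is unpacked into the constructor, the compile call, and the fit call respectively.

```python
class ToyOptimizer:
    """Hypothetical stand-in that records the options it receives."""

    def __init__(self, gnn_model, graph_instance, edge_mask_loss_weight=1e-4,
                 edge_mask_norm_ord=1, **kwargs):
        self.edge_mask_loss_weight = edge_mask_loss_weight
        self.edge_mask_norm_ord = edge_mask_norm_ord
        self.compile_kwargs = {}
        self.fit_kwargs = {}

    def compile(self, **kwargs):
        self.compile_kwargs = kwargs

    def fit(self, x=None, y=None, **kwargs):
        self.fit_kwargs = kwargs


def run_with_options(gnn, graph_instance, gnnexplaineroptimizer_options=None,
                     compile_options=None, fit_options=None):
    # None defaults are replaced by empty dicts, matching "Defaults to {}".
    gnnexplaineroptimizer_options = gnnexplaineroptimizer_options or {}
    compile_options = compile_options or {}
    fit_options = fit_options or {}
    optimizer = ToyOptimizer(gnn, graph_instance, **gnnexplaineroptimizer_options)
    optimizer.compile(**compile_options)
    optimizer.fit(graph_instance, **fit_options)
    return optimizer


opt = run_with_options(
    gnn=None, graph_instance=None,
    gnnexplaineroptimizer_options={"edge_mask_loss_weight": 1e-3},
    compile_options={"loss": "mse"},
    fit_options={"epochs": 100, "verbose": 0},
)
print(opt.edge_mask_loss_weight, opt.compile_kwargs, opt.fit_kwargs)
# 0.001 {'loss': 'mse'} {'epochs': 100, 'verbose': 0}
```

Keeping the three dicts separate lets mask hyperparameters, optimizer settings, and training length be tuned independently.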

explain(graph_instance, output_to_explain=None, inspection=False, **kwargs)

Finds the masks that explain the decision of self.gnn on the given graph_instance. Apart from the optional inspection information (see the inspection parameter), this method has no return value and only has side effects. To get the explanation that was found, call get_explanation after calling this method. This method just instantiates a GNNExplainerOptimizer, which then finds the masks for the explanation via gradient descent.

Parameters
  • graph_instance – The graph input to the GNN to which an explanation should be found.

  • output_to_explain (optional) – Set this parameter to the output which should be explained. By default the GNNExplainer explains the output of self.gnn on the given graph_instance. Defaults to None.

  • inspection (optional) – If inspection is set to True, this method returns information about the optimization process as a dictionary. Be aware that inspection results in longer runtimes. Defaults to False.

get_explanation(**kwargs)

Returns the explanation (derived from the learned masks) of the decision on the graph that was previously passed to the explain method. Important: the explain method must always be called before calling this method. Internally this method just calls the GNNInterface.get_explanation method implemented by self.gnn, with the masks found by the GNNExplainerOptimizer as parameters.

Raises

Exception – If the explain method has not been called before, this method raises an Exception.

Returns

The explanation which is returned by GNNInterface.get_explanation implemented by the self.gnn, parametrized by the learned masks.
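The required call order can be sketched as follows: explain stores the masks, and get_explanation raises an Exception until they exist, then delegates to the get_explanation method of the wrapped GNN. ToyExplainer and EchoGNN are hypothetical, and the fixed masks are a placeholder for the gradient-descent optimization, not learned values.

```python
class EchoGNN:
    """Hypothetical GNN wrapper whose get_explanation just packages the masks."""

    def get_explanation(self, gnn_input, edge_mask, feature_mask, node_mask, **kwargs):
        return {"edges": edge_mask, "features": feature_mask, "nodes": node_mask}


class ToyExplainer:
    def __init__(self, gnn):
        self.gnn = gnn
        self.graph_instance = None
        self.masks = None  # only set by explain()

    def explain(self, graph_instance, **kwargs):
        # Placeholder for the mask optimization: every mask is fixed to 0.5.
        self.graph_instance = graph_instance
        self.masks = {"edge_mask": [0.5], "feature_mask": [0.5], "node_mask": [0.5]}

    def get_explanation(self, **kwargs):
        if self.masks is None:
            raise Exception("explain() must be called before get_explanation().")
        # Delegate to the GNN, passing the stored graph and masks.
        return self.gnn.get_explanation(self.graph_instance, **self.masks, **kwargs)


explainer = ToyExplainer(EchoGNN())
try:
    explainer.get_explanation()
except Exception as err:
    print("error:", err)  # error: explain() must be called before get_explanation().
explainer.explain(graph_instance="toy-graph")
print(explainer.get_explanation())
# {'edges': [0.5], 'features': [0.5], 'nodes': [0.5]}
```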

present_explanation(explanation, **kwargs)

Takes an explanation generated by get_explanation and presents it. Internally this method just calls the GNNInterface.present_explanation method implemented by self.gnn.

Parameters

explanation – The explanation (obtained by get_explanation) which should be presented.

Returns

A presentation of the given explanation.

class kgcnn.literature.GNNExplain.GNNExplainerOptimizer(*args, **kwargs)

Bases: keras.src.models.model.Model

The GNNExplainerOptimizer solves the optimization problem used to find the masks, which can then be used to explain decisions of GNNs.

__init__(gnn_model, graph_instance, edge_mask_loss_weight=0.0001, edge_mask_norm_ord=1, feature_mask_loss_weight=0.0001, feature_mask_norm_ord=1, node_mask_loss_weight=0.0, node_mask_norm_ord=1, **kwargs)

Constructs a GNNExplainerOptimizer instance with the given parameters.

Parameters
  • gnn_model (GNNInterface) – An instance of a class which implements the methods of the GNNInterface.

  • graph_instance – The graph to which the masks should be found.

  • edge_mask_loss_weight (float, optional) – The weight of the edge mask loss term in the optimization problem. Defaults to 1e-4.

  • edge_mask_norm_ord (float, optional) – The norm p value for the p-norm, which is used on the edge mask. Smaller values encourage sparser masks. Defaults to 1.

  • feature_mask_loss_weight (float, optional) – The weight of the feature mask loss term in the optimization problem. Defaults to 1e-4.

  • feature_mask_norm_ord (float, optional) – The norm p value for the p-norm, which is used on the feature mask. Smaller values encourage sparser masks. Defaults to 1.

  • node_mask_loss_weight (float, optional) – The weight of the node mask loss term in the optimization problem. Defaults to 0.0.

  • node_mask_norm_ord (float, optional) – The norm p value for the p-norm, which is used on the node mask. Smaller values encourage sparser masks. Defaults to 1.
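One plausible reading of these parameters is sketched below in NumPy, under the assumption that each mask contributes a term weight × ‖mask‖_p that is added to the prediction loss configured via compile_options. The helper names are hypothetical, and the exact form of the loss in kgcnn (for example, whether the masks are transformed before the norm is taken) is not stated in this documentation.

```python
import numpy as np


def p_norm(mask, p):
    """(sum_i |m_i|^p)^(1/p) over all mask entries."""
    return float((np.abs(mask) ** p).sum() ** (1.0 / p))


def mask_regularization(edge_mask, feature_mask, node_mask,
                        edge_mask_loss_weight=1e-4, edge_mask_norm_ord=1,
                        feature_mask_loss_weight=1e-4, feature_mask_norm_ord=1,
                        node_mask_loss_weight=0.0, node_mask_norm_ord=1):
    """Weighted sum of mask p-norms, using the documented default values."""
    return (edge_mask_loss_weight * p_norm(edge_mask, edge_mask_norm_ord)
            + feature_mask_loss_weight * p_norm(feature_mask, feature_mask_norm_ord)
            + node_mask_loss_weight * p_norm(node_mask, node_mask_norm_ord))


edge_mask = np.full((4, 1), 0.5)   # [num_edges, 1]
feature_mask = np.ones((3, 1))     # [num_node_features, 1]
node_mask = np.ones((5, 1))        # [num_nodes, 1]
penalty = mask_regularization(edge_mask, feature_mask, node_mask)
print(round(penalty, 6))  # 1e-4 * 2.0 + 1e-4 * 3.0 + 0.0 * 5.0 = 0.0005
```

With the default node_mask_loss_weight of 0.0, the node mask term drops out entirely, so only the edge and feature masks are regularized unless the weight is raised.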

call(graph_input, training: bool = False, **kwargs)

Calls the GNN model.

Parameters
  • graph_input – Graph input.

  • training (bool) – Whether the model is run in training mode. Default is False.

Returns

Masked prediction of GNN model.

Return type

Tensor

get_mask(mask_identifier)

class kgcnn.literature.GNNExplain.GNNInterface

Bases: object

An interface class which should be implemented by a Graph Neural Network (GNN) model to make it explainable. This class is just an interface, which is used by the GNNExplainer and should be implemented in a subclass. The implementation of this class could be a wrapper around an existing TensorFlow/Keras GNN. The outputs of the methods predict and masked_predict should have the same dimension and correspond to the output to be explained.

get_explanation(gnn_input, edge_mask, feature_mask, node_mask, **kwargs)

Takes the graph input and the masks learned by the GNNExplainer and combines them into some form of explanation. The explanation could, e.g., consist of a networkx graph that has mask values as labels of its nodes/edges, and a dict for the feature explanation values.

Parameters
  • gnn_input – The input graph for which the masks were found by the GNNExplainer.

  • edge_mask – A Tensor of shape [get_number_of_edges(self, gnn_input), 1], which was found by the GNNExplainer.

  • feature_mask – A Tensor of shape [get_number_of_node_features(self, gnn_input), 1], which was found by the GNNExplainer.

  • node_mask – A Tensor of shape [get_number_of_nodes(self, gnn_input), 1], which was found by the GNNExplainer.

Raises

NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
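As one illustration of such an explanation, the sketch below pairs each edge of a toy graph with its mask value, keeps the top-k edges, and builds per-feature and per-node dictionaries. The graph representation (an edge_index array of shape [num_edges, 2]), the function name toy_get_explanation, and the top_k and feature_names parameters are assumptions for this example; networkx is not used so the block runs with NumPy alone.

```python
import numpy as np


def toy_get_explanation(edge_index, edge_mask, feature_mask, node_mask,
                        feature_names, top_k=2):
    """Hypothetical explanation builder from GNNExplainer-style masks."""
    edge_scores = edge_mask.reshape(-1)         # [num_edges, 1] -> [num_edges]
    order = np.argsort(-edge_scores)[:top_k]    # highest-scoring edges first
    top_edges = [(int(edge_index[i, 0]), int(edge_index[i, 1]), float(edge_scores[i]))
                 for i in order]
    features = {name: float(v) for name, v in zip(feature_names, feature_mask.reshape(-1))}
    nodes = {i: float(v) for i, v in enumerate(node_mask.reshape(-1))}
    return {"top_edges": top_edges, "features": features, "nodes": nodes}


edge_index = np.array([[0, 1], [1, 2], [2, 0]])
explanation = toy_get_explanation(
    edge_index,
    edge_mask=np.array([[0.9], [0.1], [0.6]]),
    feature_mask=np.array([[0.8], [0.2]]),
    node_mask=np.array([[1.0], [1.0], [0.5]]),
    feature_names=["charge", "mass"],
)
print(explanation["top_edges"])  # [(0, 1, 0.9), (2, 0, 0.6)]
```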

get_number_of_edges(gnn_input)

Returns the number of edges in the gnn_input graph.

Parameters

gnn_input – The input graph whose edges are counted.

Raises

NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.

get_number_of_node_features(gnn_input)

Returns the number of node features of the given gnn_input.

Parameters

gnn_input – The input graph whose node features are counted.

Raises

NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.

get_number_of_nodes(gnn_input)

Returns the number of nodes in the gnn_input graph.

Parameters

gnn_input – The input graph whose nodes are counted.

Raises

NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
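For a graph stored as a tuple of a node feature matrix of shape [num_nodes, num_node_features] and an edge index array of shape [num_edges, 2], the three counting methods reduce to reading array shapes. This tuple layout and the ToyGraphCounts class are hypothetical; a real implementation must match whatever input format the wrapped GNN uses.

```python
import numpy as np


class ToyGraphCounts:
    """Hypothetical implementation of the three counting methods."""

    def get_number_of_nodes(self, gnn_input):
        node_features, _ = gnn_input
        return node_features.shape[0]

    def get_number_of_node_features(self, gnn_input):
        node_features, _ = gnn_input
        return node_features.shape[1]

    def get_number_of_edges(self, gnn_input):
        _, edge_index = gnn_input
        return edge_index.shape[0]


graph = (np.zeros((5, 3)), np.array([[0, 1], [1, 2], [2, 3], [3, 4]]))
counts = ToyGraphCounts()
print(counts.get_number_of_nodes(graph),
      counts.get_number_of_node_features(graph),
      counts.get_number_of_edges(graph))  # 5 3 4
```

These counts fix the mask shapes [num_edges, 1], [num_node_features, 1], and [num_nodes, 1] that appear in the method descriptions of this interface.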

masked_predict(gnn_input, edge_mask, feature_mask, node_mask, **kwargs)

Returns the prediction for the gnn_input when it is masked by the three given masks.

Parameters
  • gnn_input – The input graph which should be masked before the GNN makes a prediction.

  • edge_mask – A Tensor of shape [get_number_of_edges(self, gnn_input), 1], which should mask the edges of the input graph.

  • feature_mask – A Tensor of shape [get_number_of_node_features(self, gnn_input), 1], which should mask the node features in the input graph.

  • node_mask – A Tensor of shape [get_number_of_nodes(self, gnn_input), 1], which should mask the nodes of the input graph.

Raises

NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
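A minimal NumPy sketch of masked_predict for a toy one-layer model, assuming the masks are applied multiplicatively: the edge mask scales each message, the transposed feature mask scales the feature columns, and the node mask scales each node's feature row. The model (sum aggregation followed by a mean readout) and the class name ToySumGNN are hypothetical. With all masks set to one, the masked prediction reproduces predict and has the same shape, as the GNNInterface requires.

```python
import numpy as np


class ToySumGNN:
    """Hypothetical one-layer GNN: each node adds weighted features of its in-neighbours."""

    def predict(self, gnn_input, **kwargs):
        node_features, edge_index = gnn_input
        # Unmasked prediction: every edge has weight 1.
        edge_weights = np.ones((edge_index.shape[0], 1))
        return self._forward(node_features, edge_index, edge_weights)

    def masked_predict(self, gnn_input, edge_mask, feature_mask, node_mask, **kwargs):
        node_features, edge_index = gnn_input
        # [N, F] * [1, F] scales feature columns; * [N, 1] scales whole nodes.
        masked_x = node_features * feature_mask.T * node_mask
        return self._forward(masked_x, edge_index, edge_mask)

    @staticmethod
    def _forward(x, edge_index, edge_weights):
        h = x.copy()
        src, dst = edge_index[:, 0], edge_index[:, 1]
        # Scatter-add each weighted message x[src] into its destination row.
        np.add.at(h, dst, edge_weights * x[src])
        # Mean readout gives a graph-level output of shape [1, F].
        return h.mean(axis=0, keepdims=True)


x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
edges = np.array([[0, 1], [1, 2]])
gnn = ToySumGNN()
ones = (np.ones((2, 1)), np.ones((2, 1)), np.ones((3, 1)))
full = gnn.masked_predict((x, edges), *ones)
dropped = gnn.masked_predict((x, edges), np.array([[1.0], [0.0]]), ones[1], ones[2])
print(full.shape, np.allclose(full, gnn.predict((x, edges))))  # (1, 2) True
```

Setting the second edge's mask value to 0 removes its message, which is the kind of change in the prediction the GNNExplainer uses to score edges.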

predict(gnn_input, **kwargs)

Returns the prediction for the gnn_input.

Parameters

gnn_input – The input graph for which the GNN should make a prediction.

Raises

NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.

present_explanation(explanation, **kwargs)

Takes an explanation generated by get_explanation and presents it to the user in a suitable way. The presentation of an explanation largely depends on the data domain and the targeted user group. Examples of presentations:

  • A visualization of the most relevant subgraph(s) to the decision

  • A visualization of the whole graph with highlighted parts

  • Bar diagrams for feature explanations

Parameters

explanation – An explanation of the GNN decision, in the form returned by the get_explanation method.

Raises

NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
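The bar-diagram option listed above can be sketched without any plotting library by rendering feature mask values as text bars. The function name toy_present_explanation, the explanation layout (a "features" mapping from feature name to mask value in [0, 1]), and the width parameter are assumptions for this example; a real implementation might use a plotting or graph drawing library instead.

```python
def toy_present_explanation(explanation, width=20):
    """Render feature mask values in [0, 1] as a text bar chart."""
    lines = []
    # Sort features by mask value, most relevant first.
    for name, value in sorted(explanation["features"].items(), key=lambda kv: -kv[1]):
        bar = "#" * round(value * width)
        lines.append(f"{name:<8} {bar} {value:.2f}")
    return "\n".join(lines)


out = toy_present_explanation({"features": {"mass": 0.25, "charge": 0.8}})
print(out)
# charge   ################ 0.80
# mass     ##### 0.25
```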