kgcnn.literature.GNNExplain package
Module contents
-
class kgcnn.literature.GNNExplain.GNNExplainer(gnn, gnnexplaineroptimizer_options=None, compile_options=None, fit_options=None, **kwargs)
Bases: object
GNNExplainer explains the decisions of a GNN that implements the GNNInterface.
See Ying et al. (https://arxiv.org/abs/1903.03894) for details on how such an explanation is found. Note that this implementation is inspired by the paper by Ying et al., but differs in some aspects.
-
class InspectionCallback(graph_instance)
Bases: keras.src.callbacks.callback.Callback
Callback class that collects inspection information when 'inspection' is set to True in the 'GNNExplainer.explain' method.
-
on_epoch_begin(epoch, logs=None)
Called at the start of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters
epoch – Integer, index of epoch.
logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.
-
-
__init__(gnn, gnnexplaineroptimizer_options=None, compile_options=None, fit_options=None, **kwargs)
Constructs a GNNExplainer instance for the given gnn.
- Parameters
gnn – An instance of a class which implements the GNNInterface.
gnnexplaineroptimizer_options (dict, optional) – Parameters in this dict are forwarded to the constructor of the GNNExplainerOptimizer (see docstring of GNNExplainerOptimizer.__init__). Defaults to {}.
compile_options (dict, optional) – Parameters in this dict are forwarded to the keras.Model.compile method of the GNNExplainerOptimizer. Can be used to customize the optimization process of the GNNExplainerOptimizer. Defaults to {}.
fit_options (dict, optional) – Parameters in this dict are forwarded to the keras.Model.fit method of the GNNExplainerOptimizer. Defaults to {}.
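A sketch of how the three option dicts might be assembled. The keys in `gnnexplaineroptimizer_options` are the keyword arguments documented for `GNNExplainerOptimizer.__init__`; `optimizer`, `epochs` and `verbose` are standard arguments of `keras.Model.compile` and `keras.Model.fit`. The chosen values are arbitrary. The function `optimizer_kwargs` is a hypothetical stand-in that mirrors the documented constructor signature, used only to show which values are overridden and which fall back to their defaults; it is not part of kgcnn.

```python
# Forwarded to GNNExplainerOptimizer.__init__ (documented parameter names).
gnnexplaineroptimizer_options = {
    "edge_mask_loss_weight": 1e-3,
    "feature_mask_loss_weight": 1e-3,
    "node_mask_loss_weight": 0.0,
}
# Forwarded to keras.Model.compile and keras.Model.fit of the optimizer model.
compile_options = {"optimizer": "adam"}
fit_options = {"epochs": 100, "verbose": 0}


def optimizer_kwargs(gnn_model=None, graph_instance=None,
                     edge_mask_loss_weight=0.0001, edge_mask_norm_ord=1,
                     feature_mask_loss_weight=0.0001, feature_mask_norm_ord=1,
                     node_mask_loss_weight=0.0, node_mask_norm_ord=1, **kwargs):
    """Hypothetical stand-in mirroring the documented constructor signature."""
    return {
        "edge_mask_loss_weight": edge_mask_loss_weight,
        "edge_mask_norm_ord": edge_mask_norm_ord,
        "feature_mask_loss_weight": feature_mask_loss_weight,
        "feature_mask_norm_ord": feature_mask_norm_ord,
        "node_mask_loss_weight": node_mask_loss_weight,
        "node_mask_norm_ord": node_mask_norm_ord,
    }


# Keys not given in the dict keep their documented defaults.
resolved = optimizer_kwargs(**gnnexplaineroptimizer_options)

# With kgcnn installed, the dicts would be passed like this (`my_gnn` is hypothetical):
# from kgcnn.literature.GNNExplain import GNNExplainer
# explainer = GNNExplainer(my_gnn,
#                          gnnexplaineroptimizer_options=gnnexplaineroptimizer_options,
#                          compile_options=compile_options,
#                          fit_options=fit_options)
```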
-
explain(graph_instance, output_to_explain=None, inspection=False, **kwargs)
Finds the masks that explain the decision of self.gnn on the given graph_instance. This method has no return value; it only has side effects. To obtain the explanation that was found, call get_explanation after calling this method. Internally, this method instantiates a GNNExplainerOptimizer, which then finds the masks for the explanation via gradient descent.
- Parameters
graph_instance – The graph input to the GNN for which an explanation should be found.
output_to_explain (optional) – Set this parameter to the output that should be explained. By default, the GNNExplainer explains the output of self.gnn on the given graph_instance. Defaults to None.
inspection (optional) – If set to True, information about the optimization process is collected in dictionary form. Be aware that inspection results in longer runtimes. Defaults to False.
-
get_explanation(**kwargs)
Returns the explanation (derived from the learned masks) of the decision on the graph that was previously passed to the explain method. Important: the explain method must always be called before this method. Internally, this method calls the GNNInterface.get_explanation method implemented by self.gnn, passing the masks found by the GNNExplainerOptimizer as parameters.
- Raises
Exception – If the explain method has not been called before, this method raises an Exception.
- Returns
The explanation returned by GNNInterface.get_explanation implemented by self.gnn, parametrized by the learned masks.
-
present_explanation(explanation, **kwargs)
Takes an explanation generated by get_explanation and presents it. Internally, this method calls the GNNInterface.present_explanation method implemented by self.gnn.
- Parameters
explanation – The explanation (obtained by get_explanation) which should be presented.
- Returns
A presentation of the given explanation.
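The documented call order (explain, then get_explanation, then present_explanation) can be wrapped in a small helper. This is a sketch: `explain_and_present` is a hypothetical name, and `_RecordingExplainer` is a stand-in that only records calls so the sketch runs without kgcnn; with kgcnn installed, a real GNNExplainer instance would be passed instead.

```python
def explain_and_present(explainer, graph_instance, **explain_kwargs):
    """Run explain, get_explanation and present_explanation in the required order."""
    # explain() only has side effects and must run before get_explanation().
    explainer.explain(graph_instance, **explain_kwargs)
    explanation = explainer.get_explanation()
    return explainer.present_explanation(explanation)


class _RecordingExplainer:
    """Hypothetical stand-in for GNNExplainer that records method calls."""

    def __init__(self):
        self.calls = []
        self._explained = False

    def explain(self, graph_instance, **kwargs):
        self.calls.append("explain")
        self._explained = True

    def get_explanation(self, **kwargs):
        self.calls.append("get_explanation")
        if not self._explained:
            # Mirrors the documented behaviour: explain must be called first.
            raise Exception("explain() must be called before get_explanation()")
        return {"edge_mask": [0.9, 0.1]}

    def present_explanation(self, explanation, **kwargs):
        self.calls.append("present_explanation")
        return f"explanation: {explanation}"


stub = _RecordingExplainer()
result = explain_and_present(stub, graph_instance=None, inspection=False)
```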
-
class kgcnn.literature.GNNExplain.GNNExplainerOptimizer(*args, **kwargs)
Bases: keras.src.models.model.Model
The GNNExplainerOptimizer solves the optimization problem used to find the masks, which can then be used to explain decisions of GNNs.
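The kind of problem the GNNExplainerOptimizer solves can be illustrated with a toy, pure-Python sketch. The assumptions are ours, not kgcnn's: the "GNN" is a linear model whose output is the sum of per-edge contributions, each edge mask value is the sigmoid of a learnable logit, the prediction loss is a squared error against the unmasked output, and sparsity is encouraged by an L1 penalty weighted like edge_mask_loss_weight. Gradients are written out by hand instead of using Keras.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fit_edge_mask(contributions, loss_weight=0.01, lr=0.5, steps=200):
    """Toy mask optimization: keep the masked prediction close to the
    unmasked one while an L1 penalty pushes the mask values down."""
    target = sum(contributions)           # unmasked "prediction"
    logits = [0.0] * len(contributions)   # mask = sigmoid(logit), starts at 0.5
    for _ in range(steps):
        mask = [sigmoid(z) for z in logits]
        pred = sum(m * c for m, c in zip(mask, contributions))
        for e, (m, c) in enumerate(zip(mask, contributions)):
            # d/dm of (pred - target)^2 + loss_weight * sum(|m|), with m > 0
            grad_m = 2.0 * (pred - target) * c + loss_weight
            # Chain rule through the sigmoid: dm/dz = m * (1 - m)
            logits[e] -= lr * grad_m * m * (1.0 - m)
    return [sigmoid(z) for z in logits]


# Edge 1 contributes nothing to the toy model's output.
mask = fit_edge_mask([2.0, 0.0, 1.0])
```

The edge that does not contribute to the output ends up with a lower mask value than the edges the prediction depends on, which is the kind of signal an explanation is built from.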
-
__init__(gnn_model, graph_instance, edge_mask_loss_weight=0.0001, edge_mask_norm_ord=1, feature_mask_loss_weight=0.0001, feature_mask_norm_ord=1, node_mask_loss_weight=0.0, node_mask_norm_ord=1, **kwargs)
Constructs a GNNExplainerOptimizer instance with the given parameters.
- Parameters
gnn_model (GNNInterface) – An instance of a class which implements the methods of the GNNInterface.
graph_instance – The graph to which the masks should be found.
edge_mask_loss_weight (float, optional) – The weight of the edge mask loss term in the optimization problem. Defaults to 1e-4.
edge_mask_norm_ord (float, optional) – The norm p value for the p-norm, which is used on the edge mask. Smaller values encourage sparser masks. Defaults to 1.
feature_mask_loss_weight (float, optional) – The weight of the feature mask loss term in the optimization problem. Defaults to 1e-4.
feature_mask_norm_ord (float, optional) – The norm p value for the p-norm, which is used on the feature mask. Smaller values encourage sparser masks. Defaults to 1.
node_mask_loss_weight (float, optional) – The weight of the node mask loss term in the optimization problem. Defaults to 0.0.
node_mask_norm_ord (float, optional) – The norm p value for the p-norm, which is used on the node mask. Smaller values encourage sparser masks. Defaults to 1.
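Assuming each regularization term is the loss weight multiplied by the p-norm of the corresponding mask (our reading of the parameter descriptions, not confirmed against the kgcnn source), the penalty added to the prediction loss can be sketched as follows. Masks are flat lists of values, and the defaults are the documented ones.

```python
def p_norm(values, p):
    """p-norm of a flat list of mask values."""
    return sum(abs(v) ** p for v in values) ** (1.0 / p)


def mask_regularization(edge_mask, feature_mask, node_mask,
                        edge_mask_loss_weight=1e-4, edge_mask_norm_ord=1,
                        feature_mask_loss_weight=1e-4, feature_mask_norm_ord=1,
                        node_mask_loss_weight=0.0, node_mask_norm_ord=1):
    """Weighted sum of mask norms, one term per mask."""
    return (edge_mask_loss_weight * p_norm(edge_mask, edge_mask_norm_ord)
            + feature_mask_loss_weight * p_norm(feature_mask, feature_mask_norm_ord)
            + node_mask_loss_weight * p_norm(node_mask, node_mask_norm_ord))


# With the defaults, the node mask does not contribute (weight 0.0).
reg = mask_regularization([0.5, 0.5], [1.0], [0.2, 0.8])
```

With p = 1 the penalty is the sum of absolute mask values, which pushes individual entries toward zero; this is why smaller norm orders tend to give sparser masks.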
-
-
class kgcnn.literature.GNNExplain.GNNInterface
Bases: object
An interface class that a Graph Neural Network (GNN) model should implement to make it explainable. This class is only an interface used by the GNNExplainer; its methods must be implemented in a subclass. The implementation could, for example, be a wrapper around an existing TensorFlow/Keras GNN. The outputs of the methods predict and masked_predict should have the same dimension and should be the output that is to be explained.
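A minimal sketch of a class following the GNNInterface methods, written in pure Python so it runs without kgcnn or TensorFlow. The graph format (a dict with "node_features" and "edges"), the one-step sum aggregation and the name `ToyGNN` are all hypothetical; masks are flat lists standing in for the documented [n, 1] tensors. get_explanation and present_explanation are omitted for brevity. With kgcnn installed, the class would subclass GNNInterface.

```python
class ToyGNN:
    """Hypothetical toy model implementing the GNNInterface prediction methods.

    gnn_input: {"node_features": [[f0, f1, ...], ...], "edges": [(src, dst), ...]}
    """

    def get_number_of_nodes(self, gnn_input):
        return len(gnn_input["node_features"])

    def get_number_of_node_features(self, gnn_input):
        return len(gnn_input["node_features"][0])

    def get_number_of_edges(self, gnn_input):
        return len(gnn_input["edges"])

    def masked_predict(self, gnn_input, edge_mask, feature_mask, node_mask, **kwargs):
        # Mask node features per feature column and per node, then sum each node.
        node_sums = [
            node_mask[v] * sum(x * feature_mask[f] for f, x in enumerate(feats))
            for v, feats in enumerate(gnn_input["node_features"])
        ]
        # One message-passing step: each edge carries its masked source value.
        messages = [edge_mask[e] * node_sums[src]
                    for e, (src, _dst) in enumerate(gnn_input["edges"])]
        return sum(messages)  # graph-level scalar output

    def predict(self, gnn_input, **kwargs):
        # Unmasked prediction = masked prediction with all-ones masks,
        # so predict and masked_predict return outputs of the same dimension.
        return self.masked_predict(
            gnn_input,
            edge_mask=[1.0] * self.get_number_of_edges(gnn_input),
            feature_mask=[1.0] * self.get_number_of_node_features(gnn_input),
            node_mask=[1.0] * self.get_number_of_nodes(gnn_input),
        )


graph = {"node_features": [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]],
         "edges": [(0, 1), (1, 2), (2, 0)]}
gnn = ToyGNN()
full = gnn.predict(graph)
without_edge_1 = gnn.masked_predict(graph, edge_mask=[1.0, 0.0, 1.0],
                                    feature_mask=[1.0, 1.0], node_mask=[1.0, 1.0, 1.0])
```

Defining predict through masked_predict with all-ones masks is one simple way to guarantee that both methods return outputs of the same dimension, as the interface requires.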
-
get_explanation(gnn_input, edge_mask, feature_mask, node_mask, **kwargs)
Takes the graph input and the masks learned by the GNNExplainer and combines them into some form of explanation. The explanation could, for example, consist of a networkx graph that has the mask values as labels on its nodes and edges, plus a dict for the feature explanation values.
- Parameters
gnn_input – The input graph for which the GNNExplainer found the masks.
edge_mask – A Tensor of shape [get_number_of_edges(self, gnn_input), 1], which was found by the GNNExplainer.
feature_mask – A Tensor of shape [get_number_of_node_features(self, gnn_input), 1], which was found by the GNNExplainer.
node_mask – A Tensor of shape [get_number_of_nodes(self, gnn_input), 1], which was found by the GNNExplainer.
- Raises
NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
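A possible get_explanation, sketched as a standalone function: it attaches each mask value to the edge, node or feature it belongs to. It returns nested dicts instead of the networkx graph mentioned above so that no extra dependency is needed. The function name and the graph format (a dict with "node_features" and "edges") are hypothetical, and masks are flat lists standing in for [n, 1] tensors.

```python
def toy_get_explanation(gnn_input, edge_mask, feature_mask, node_mask):
    """Hypothetical get_explanation: map mask values onto graph elements."""
    return {
        # (source, target) edge tuple -> edge mask value
        "edges": {tuple(edge): float(m) for edge, m in zip(gnn_input["edges"], edge_mask)},
        # node index -> node mask value
        "nodes": {v: float(m) for v, m in enumerate(node_mask)},
        # feature index -> feature mask value
        "features": {f: float(m) for f, m in enumerate(feature_mask)},
    }


graph = {"node_features": [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]],
         "edges": [(0, 1), (1, 2), (2, 0)]}
explanation = toy_get_explanation(graph, edge_mask=[0.9, 0.1, 0.4],
                                  feature_mask=[0.8, 0.2], node_mask=[1.0, 0.5, 0.3])
```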
-
get_number_of_edges(gnn_input)
Returns the number of edges in the gnn_input graph.
- Parameters
gnn_input – The input graph whose number of edges is returned.
- Raises
NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
-
get_number_of_node_features(gnn_input)
Returns the number of node features of the given gnn_input.
- Parameters
gnn_input – The input graph whose number of node features is returned.
- Raises
NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
-
get_number_of_nodes(gnn_input)
Returns the number of nodes in the gnn_input graph.
- Parameters
gnn_input – The input graph whose number of nodes is returned.
- Raises
NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
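The mask shapes in the get_explanation and masked_predict docstrings are defined through these three counting methods, so an explainer presumably uses them to size its masks. A sketch, assuming all masks start at the same value: `make_initial_masks` and `_CountingGNN` are hypothetical, the graph format is a toy dict, and nested lists stand in for tensors of shape [n, 1].

```python
def make_initial_masks(gnn, gnn_input, value=1.0):
    """Build masks with the documented shapes [edges, 1], [node features, 1]
    and [nodes, 1], as nested lists."""
    return {
        "edge_mask": [[value] for _ in range(gnn.get_number_of_edges(gnn_input))],
        "feature_mask": [[value] for _ in range(gnn.get_number_of_node_features(gnn_input))],
        "node_mask": [[value] for _ in range(gnn.get_number_of_nodes(gnn_input))],
    }


class _CountingGNN:
    """Hypothetical minimal implementation of the three counting methods."""

    def get_number_of_edges(self, gnn_input):
        return len(gnn_input["edges"])

    def get_number_of_node_features(self, gnn_input):
        return len(gnn_input["node_features"][0])

    def get_number_of_nodes(self, gnn_input):
        return len(gnn_input["node_features"])


graph = {"node_features": [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]],
         "edges": [(0, 1), (1, 2), (2, 0), (1, 0)]}
masks = make_initial_masks(_CountingGNN(), graph)
```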
-
masked_predict(gnn_input, edge_mask, feature_mask, node_mask, **kwargs)
Returns the prediction for the gnn_input when it is masked by the three given masks.
- Parameters
gnn_input – The input graph, which is masked before the GNN makes a prediction.
edge_mask – A Tensor of shape [get_number_of_edges(self, gnn_input), 1], which should mask the edges of the input graph.
feature_mask – A Tensor of shape [get_number_of_node_features(self, gnn_input), 1], which should mask the node features in the input graph.
node_mask – A Tensor of shape [get_number_of_nodes(self, gnn_input), 1], which should mask the nodes in the input graph.
- Raises
NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
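When masked_predict wraps an existing GNN, one option is to apply the masks to the inputs and then run the wrapped model unchanged. The helper below sketches only the masking step under that assumption: the feature mask scales each feature column, the node mask scales every feature of a node, and the edge mask scales per-edge weights. `apply_masks` is hypothetical, and masks are flat lists standing in for [n, 1] tensors.

```python
def apply_masks(node_features, edge_weights, edge_mask, feature_mask, node_mask):
    """Hypothetical masking step for a masked_predict wrapper.

    node_features: per-node feature lists, shape [nodes][features]
    edge_weights:  per-edge weights, shape [edges]
    """
    masked_features = [
        # Scale each feature by its feature mask and by the node's mask value.
        [x * feature_mask[f] * node_mask[v] for f, x in enumerate(feats)]
        for v, feats in enumerate(node_features)
    ]
    # Scale each edge weight by its edge mask value.
    masked_edge_weights = [w * m for w, m in zip(edge_weights, edge_mask)]
    return masked_features, masked_edge_weights


feats, weights = apply_masks([[1.0, 2.0], [3.0, 4.0]], [1.0, 1.0],
                             edge_mask=[0.5, 0.0], feature_mask=[1.0, 0.0],
                             node_mask=[1.0, 0.5])
```

The masked features and edge weights would then be passed to the wrapped model's forward pass; whether that model accepts edge weights at all depends on its architecture.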
-
predict(gnn_input, **kwargs)
Returns the prediction for the gnn_input.
- Parameters
gnn_input – The input graph for which the GNN should make a prediction.
- Raises
NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
-
present_explanation(explanation, **kwargs)
Takes an explanation generated by get_explanation and presents it to the user in a suitable way. The presentation of an explanation largely depends on the data domain and the targeted user group. Examples of presentations:
A visualization of the most relevant subgraph(s) to the decision
A visualization of the whole graph with highlighted parts
Bar diagrams for feature explanations
…
- Parameters
explanation – An explanation of the GNN decision, in the form returned by the get_explanation method.
- Raises
NotImplementedError – This is just an interface class, to indicate which methods should be implemented. Implement this method in a subclass.
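As an example of the "bar diagrams for feature explanations" option, the sketch below renders feature mask values as a plain-text bar chart, most relevant feature first. The function name, the feature names and the assumption that mask values lie in [0, 1] are illustrative only; a real implementation might use a plotting library instead.

```python
def present_feature_bars(feature_explanation, feature_names=None, width=20):
    """Hypothetical present_explanation: text bar chart of feature mask values."""
    names = feature_names or [f"feature {i}" for i in range(len(feature_explanation))]
    # Sort so the most relevant feature (largest mask value) comes first.
    rows = sorted(zip(names, feature_explanation), key=lambda item: item[1], reverse=True)
    lines = []
    for name, value in rows:
        bar = "#" * round(value * width)  # assumes mask values lie in [0, 1]
        lines.append(f"{name:<12} {bar} {value:.2f}")
    return "\n".join(lines)


text = present_feature_bars([0.1, 0.75], feature_names=["charge", "mass"])
print(text)
```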
-