# Source code for kgcnn.layers.mlp

import keras as ks
from keras.layers import Dense, Layer, Activation, Dropout
from keras.layers import LayerNormalization, GroupNormalization, BatchNormalization, UnitNormalization
from kgcnn.layers.norm import (GraphNormalization, GraphInstanceNormalization,
                               GraphBatchNormalization, GraphLayerNormalization)
from kgcnn.layers.norm import global_normalization_args as global_normalization_args_graph
from kgcnn.layers.relational import RelationalDense


# Names of the constructor keyword arguments that each keras normalization layer is given,
# keyed by the layer's class name. `_MLPBase._get_conf_for_keys` iterates over the selected
# value to collect the per-layer keyword arguments, so every value has to be a tuple of names.
global_normalization_args = {
    # The trailing comma is required: `("axis")` is only the string "axis", which would be
    # iterated character by character ("a", "x", "i", "s") instead of as one argument name.
    "UnitNormalization": (
        "axis",
    ),
    "BatchNormalization": (
        "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer",
        "gamma_regularizer", "beta_constraint", "gamma_constraint", "momentum", "moving_mean_initializer",
        "moving_variance_initializer"
    ),
    "GroupNormalization": (
        "groups", "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer",
        "gamma_regularizer", "beta_constraint", "gamma_constraint"
    ),
    "LayerNormalization": (
        "axis", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer", "beta_regularizer",
        "gamma_regularizer", "beta_constraint", "gamma_constraint"
    )
}
# Merge in the argument table imported from `kgcnn.layers.norm` (presumably keyed by the
# `Graph*` normalization layer names -- confirm there). `dict.update` lets an entry from that
# table replace an entry of the same name defined above.
global_normalization_args.update(global_normalization_args_graph)


class _MLPBase(Layer):  # noqa
    r"""Shared argument handling for the MLP layers of this module.

    The constructor gathers the keyword arguments for the dense, activation, dropout and
    normalization sub-layers, expands each argument into a list with one entry per layer
    (the number of layers is ``len(units)``) and stores the result in ``_conf_mlp_kwargs``.
    It does not create any sub-layer itself; subclasses do that via :obj:`_get_conf_for_keys`.
    """

    def __init__(self,
                 units,
                 use_bias=True,
                 activation=None,
                 activity_regularizer=None,
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_initializer="glorot_uniform",
                 bias_initializer="zeros",
                 kernel_constraint=None,
                 bias_constraint=None,
                 # Normalization
                 use_normalization=False,
                 normalization_technique="BatchNormalization",
                 axis=-1,
                 momentum=0.99,
                 epsilon=0.001,
                 mean_shift=True,
                 center=True,
                 scale=True,
                 alpha_initializer="ones",
                 beta_initializer="zeros",
                 gamma_initializer="ones",
                 moving_mean_initializer="zeros",
                 moving_variance_initializer="ones",
                 alpha_regularizer=None,
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 alpha_constraint=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 # Dropout
                 use_dropout=False,
                 rate=None,
                 noise_shape=None,
                 seed=None,
                 # Graph
                 padded_disjoint: bool = False,
                 **kwargs):
        r"""
            activation: Activation function to use.
                If you don't specify anything, no activation is applied
                (i.e. "linear" activation: `a(x) = x`).
            use_bias: Boolean, whether the layer uses a bias vector.
            kernel_initializer: Initializer for the `kernel` weights matrix.
            bias_initializer: Initializer for the bias vector.
            kernel_regularizer: Regularizer function applied to
                the `kernel` weights matrix.
            bias_regularizer: Regularizer function applied to the bias vector.
            activity_regularizer: Regularizer function applied to
                the output of the layer (its "activation").
            kernel_constraint: Constraint function applied to
                the `kernel` weights matrix.
            bias_constraint: Constraint function applied to the bias vector.
            use_normalization: Whether to use a normalization layer in between.
            normalization_technique: Which keras normalization technique to apply.
                This can be either 'batch', 'layer', 'group' etc.
            axis: Integer, the axis that should be normalized (typically the features
                axis). For instance, after a `Conv2D` layer with
                `data_format="channels_first"`, set `axis=1` in `GraphBatchNormalization`.
                A list is read as one entry per layer; use a tuple for multiple axes.
            momentum: Momentum for the moving average.
            epsilon: Small float added to variance to avoid dividing by zero.
            mean_shift: Whether to apply alpha.
            center: If True, add offset of `beta` to normalized tensor. If False, `beta`
                is ignored.
            scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
                next layer is linear (also e.g. `nn.relu`), this can be disabled since the
                scaling will be done by the next layer.
            alpha_initializer: Initializer for the alpha weight. Defaults to 'ones'.
            beta_initializer: Initializer for the beta weight.
            gamma_initializer: Initializer for the gamma weight.
            moving_mean_initializer: Initializer for the moving mean.
            moving_variance_initializer: Initializer for the moving variance.
            alpha_regularizer: Optional regularizer for the alpha weight.
            beta_regularizer: Optional regularizer for the beta weight.
            gamma_regularizer: Optional regularizer for the gamma weight.
            beta_constraint: Optional constraint for the beta weight.
            gamma_constraint: Optional constraint for the gamma weight.
            alpha_constraint: Optional constraint for the alpha weight.
            use_dropout: Whether to use dropout layers in between.
            rate: Float between 0 and 1. Fraction of the input units to drop.
            noise_shape: 1D integer tensor representing the shape of the
                binary dropout mask that will be multiplied with the input.
                For instance, if your inputs have shape `(batch_size, timesteps, features)` and
                you want the dropout mask to be the same for all timesteps,
                you can use `noise_shape=(batch_size, 1, features)`.
                A list is read as one entry per layer; use a tuple for a single shape.
            seed: A Python integer to use as random seed.
            padded_disjoint: Boolean flag that is stored per layer like the other arguments.
                NOTE(review): presumably consumed by the graph normalization layers that list
                it among their arguments; confirm in `kgcnn.layers.norm`. Defaults to False.
        """
        super().__init__(**kwargs)
        # Constructor arguments by name; the argument groups below are looked up in here.
        local_kw = locals()
        # List for groups of arguments.
        self._key_list_act = [
            "activation", "activity_regularizer"
        ]
        self._key_list_dense = [
            "units", "use_bias", "kernel_regularizer", "bias_regularizer", "kernel_initializer", "bias_initializer",
            "kernel_constraint", "bias_constraint"
        ]
        self._key_list_norm_all = [
            "axis", "momentum", "epsilon", "center", "scale", "beta_initializer", "gamma_initializer",
            "moving_mean_initializer", "moving_variance_initializer", "beta_regularizer",
            "gamma_regularizer", "beta_constraint", "gamma_constraint", "alpha_initializer", "alpha_regularizer",
            "alpha_constraint", "mean_shift"
        ]
        self._key_list_dropout = ["rate", "noise_shape", "seed"]
        self._key_list_use = ["use_dropout", "use_normalization", "normalization_technique"]
        self._key_list_init = [
            "kernel_initializer", "bias_initializer", "beta_initializer", "gamma_initializer",
            "moving_mean_initializer", "moving_variance_initializer", "alpha_initializer"
        ]
        self._key_list_reg = [
            "activity_regularizer", "kernel_regularizer", "bias_regularizer", "beta_regularizer", "gamma_regularizer",
            "alpha_regularizer"
        ]
        self._key_list_const = [
            "kernel_constraint", "bias_constraint", "beta_constraint", "gamma_constraint", "alpha_constraint"
        ]
        self._key_list_general = [
            "padded_disjoint"
        ]
        self._key_dict_norm = global_normalization_args

        # Summarize all arguments. `dict.fromkeys` removes duplicates but, unlike `set`, keeps
        # the order stable between interpreter runs (config order, first reported mismatch).
        all_keys = self._key_list_act + self._key_list_dense + self._key_list_norm_all + self._key_list_dropout
        all_keys += self._key_list_use + self._key_list_general
        self._key_list = list(dict.fromkeys(all_keys))

        # Dictionary of kwargs for MLP.
        mlp_kwargs = {key: local_kw[key] for key in self._key_list}

        # Everything should be defined by units.
        if isinstance(units, int):
            mlp_kwargs["units"] = [units]
        if not isinstance(mlp_kwargs["units"], list):
            raise ValueError("Units must be a list or a single int for `MLP`.")
        self._depth = len(mlp_kwargs["units"])
        # Special case: a tuple for `axis` means several axes of one layer, so only a list is
        # read as one entry per layer. Anything else is repeated for every layer.
        if not isinstance(axis, list):
            mlp_kwargs["axis"] = [axis for _ in range(self._depth)]
        # Special case: same rule for `noise_shape`, where a tuple is a single shape.
        if not isinstance(noise_shape, list):
            mlp_kwargs["noise_shape"] = [noise_shape for _ in range(self._depth)]

        # Make every argument a list with one entry per layer. Sequences are copied so that
        # objects owned by the caller are never modified and tuples become mutable lists.
        for key, value in mlp_kwargs.items():
            if isinstance(value, (list, tuple)):
                mlp_kwargs[key] = list(value)
            else:
                mlp_kwargs[key] = [value for _ in range(self._depth)]

        # Check correct length for all arguments.
        for key, value in mlp_kwargs.items():
            if self._depth != len(value):
                raise ValueError(
                    "Provide matching list of units '%s' and '%s' or simply a single value." % (
                        mlp_kwargs["units"], key))

        # Deserialize initializer, regularizes, constraints and activation.
        for sl, sm in [
            (self._key_list_init, ks.initializers.get), (self._key_list_reg, ks.regularizers.get),
            (self._key_list_const, ks.constraints.get), (["activation"], ks.activations.get)
        ]:
            for key in sl:
                mlp_kwargs[key] = [sm(x) for x in mlp_kwargs[key]]

        # Fix synonyms for normalization layer. A new list is built instead of assigning into
        # the existing one; names without a synonym are kept unchanged.
        norm_synonyms = {
            "batch": "BatchNormalization", "layer": "LayerNormalization", "group": "GroupNormalization",
            "graph": "GraphNormalization", "graph_instance": "GraphInstanceNormalization",
            "unit_norm": "UnitNormalization", "norm": "Normalization", "graph_layer": "GraphLayerNormalization",
            "graph_batch": "GraphBatchNormalization"
        }
        mlp_kwargs["normalization_technique"] = [
            norm_synonyms.get(x, x) for x in mlp_kwargs["normalization_technique"]
        ]

        # Assign separate '_conf_' for use keys.
        # All '_conf_' kwargs in '_conf_mlp_kwargs'.
        for key in self._key_list_use:
            setattr(self, "_conf_" + key, mlp_kwargs[key])
        self._conf_mlp_kwargs = mlp_kwargs

    def _get_conf_for_keys(self, key_list_to_fetch: list, name_postfix: str, i_layer: int):
        """Collect the keyword arguments of one sub-layer.

        Args:
            key_list_to_fetch: Names of the arguments to pick from the stored configuration.
                Iterated once, so lists and tuples of names both work.
            name_postfix: Tag of the sub-layer kind that becomes part of the generated name.
            i_layer: Index of the layer within the MLP.

        Returns:
            dict: The selected arguments for layer `i_layer` plus a `name` entry of the
            form `<layer name>_<name_postfix>_<i_layer>`.
        """
        out_kwargs = {key: self._conf_mlp_kwargs[key][i_layer] for key in key_list_to_fetch}
        out_kwargs["name"] = self.name + "_" + name_postfix + "_" + str(i_layer)
        return out_kwargs

    def build(self, input_shape):
        """Build layer."""
        super().build(input_shape)

    def get_config(self):
        """Return the layer config with every MLP argument as a per-layer list.

        Initializers, regularizers, constraints and activations are stored in their
        serialized form.
        """
        config = super().get_config()
        for key in self._key_list:
            config.update({key: self._conf_mlp_kwargs[key]})

        # Serialize initializer, regularizes, constraints and activation.
        for sl, sm in [
            (self._key_list_init, ks.initializers.serialize), (self._key_list_reg, ks.regularizers.serialize),
            (self._key_list_const, ks.constraints.serialize), (["activation"], ks.activations.serialize)
        ]:
            for key in sl:
                config.update({key: [sm(x) for x in self._conf_mlp_kwargs[key]]})
        return config


class MLP(_MLPBase):  # noqa
    r"""Class for multilayer perceptron that consist of multiple feed-forward networks.

    The class contains arguments for :obj:`Dense` , :obj:`Dropout` and :obj:`BatchNormalization` or
    :obj:`LayerNormalization` or :obj:`GraphNormalization` since MLP is made up of stacked :obj:`Dense` layers
    with optional normalization and dropout to improve stability or regularization.
    Here, a list in place of arguments must be provided that applies to each layer.
    If not a list is given, then the single argument is used for each layer.
    The number of layers is determined by :obj:`units` argument, which should be list.

    Within one layer the sub-layers run in the order dense, dropout, normalization, activation.
    """

    # If child classes want to replace layers.
    _supress_dense = False

    def __init__(self, units, **kwargs):
        r"""Initialize with parameter for MLP layer that match :obj:`Dense` layer, including :obj:`Dropout` and
        :obj:`BatchNormalization` or :obj:`LayerNormalization` or :obj:`GraphNormalization` .

        Args:
            units: Positive integer, dimensionality of the output space.
        %s
        """
        super().__init__(units=units, **kwargs)
        # Maps the (already de-aliased) technique name to the layer class to instantiate.
        norm_classes = {
            "UnitNormalization": UnitNormalization,
            "BatchNormalization": BatchNormalization,
            "GroupNormalization": GroupNormalization,
            "LayerNormalization": LayerNormalization,
            "GraphNormalization": GraphNormalization,
            "GraphInstanceNormalization": GraphInstanceNormalization,
            "GraphLayerNormalization": GraphLayerNormalization,
            "GraphBatchNormalization": GraphBatchNormalization,
        }
        layer_ids = range(self._depth)

        # Subclasses that bring their own kernel layers switch this part off.
        if not self._supress_dense:
            dense_layers = []
            for idx in layer_ids:
                dense_conf = self._get_conf_for_keys(self._key_list_dense, "dense", idx)
                dense_layers.append(Dense(**dense_conf))
            self.mlp_dense_layer_list = dense_layers

        activation_layers = []
        for idx in layer_ids:
            act_conf = self._get_conf_for_keys(self._key_list_act, "act", idx)
            activation_layers.append(Activation(**act_conf))
        self.mlp_activation_layer_list = activation_layers

        # Layers without dropout keep a `None` placeholder so that indices stay aligned.
        dropout_layers = []
        for idx in layer_ids:
            if self._conf_use_dropout[idx]:
                drop_conf = self._get_conf_for_keys(self._key_list_dropout, "drop", idx)
                dropout_layers.append(Dropout(**drop_conf))
            else:
                dropout_layers.append(None)
        self.mlp_dropout_layer_list = dropout_layers

        # Same placeholder convention for normalization.
        norm_layers = []
        for idx in layer_ids:
            if self._conf_use_normalization[idx]:
                technique = self._conf_normalization_technique[idx]
                norm_cls = norm_classes[technique]
                norm_conf = self._get_conf_for_keys(self._key_dict_norm[technique], "norm", idx)
                norm_layers.append(norm_cls(**norm_conf))
            else:
                norm_layers.append(None)
        self.mlp_norm_layer_list = norm_layers

        # Normalization layers with "Graph" in their name are fed the additional input tensors.
        graph_flags = []
        for idx in layer_ids:
            if self._conf_use_normalization[idx]:
                graph_flags.append("Graph" in self._conf_normalization_technique[idx])
            else:
                graph_flags.append(False)
        self.is_graph_norm_layer = graph_flags

    def build(self, input_shape):
        """Build all sub-layers, propagating the feature shape from layer to layer.

        Args:
            input_shape: Shape of the feature tensor, or a list whose first entry is that
                shape and whose remaining entries are the shapes of the additional tensors
                handed to graph normalization layers.
        """
        if isinstance(input_shape, list):
            x_shape, x_graph = input_shape[0], input_shape[1:]
        else:
            x_shape, x_graph = input_shape, []

        for idx in range(self._depth):
            dense = self.mlp_dense_layer_list[idx]
            dense.build(x_shape)
            x_shape = dense.compute_output_shape(x_shape)
            if self._conf_use_dropout[idx]:
                self.mlp_dropout_layer_list[idx].build(x_shape)
            if self._conf_use_normalization[idx]:
                if self.is_graph_norm_layer[idx]:
                    norm_shape = [x_shape] + x_graph
                else:
                    norm_shape = x_shape
                self.mlp_norm_layer_list[idx].build(norm_shape)
            self.mlp_activation_layer_list[idx].build(x_shape)
        self.built = True

    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs (Tensor): Input tensor with last dimension not `None` . May also be a list
                whose first entry is that tensor; the remaining entries are passed along to
                graph normalization layers.

        Returns:
            Tensor: MLP forward pass.
        """
        if isinstance(inputs, list):
            x, batch = inputs[0], inputs[1:]
        else:
            x, batch = inputs, []

        for idx in range(self._depth):
            x = self.mlp_dense_layer_list[idx](x, **kwargs)
            if self._conf_use_dropout[idx]:
                x = self.mlp_dropout_layer_list[idx](x, **kwargs)
            if self._conf_use_normalization[idx]:
                norm_layer = self.mlp_norm_layer_list[idx]
                if self.is_graph_norm_layer[idx]:
                    x = norm_layer([x] + batch, **kwargs)
                else:
                    x = norm_layer(x, **kwargs)
            x = self.mlp_activation_layer_list[idx](x, **kwargs)
        return x

    def get_config(self):
        """Update config."""
        return super().get_config()
# Insert the shared argument documentation of `_MLPBase.__init__` at the `%s` placeholder of
# the `MLP.__init__` docstring. Docstrings are `None` when Python runs with `-OO`; skip the
# formatting then, otherwise importing this module would raise a `TypeError`.
if MLP.__init__.__doc__ is not None and _MLPBase.__init__.__doc__ is not None:
    MLP.__init__.__doc__ = MLP.__init__.__doc__ % _MLPBase.__init__.__doc__

# Normal MLP can pass additional tensors for normalization.
# Use as synonym here.
GraphMLP = MLP
class RelationalMLP(MLP):
    r"""Relational MLP which behaves like the standard MLP but uses :obj:`RelationalDense` , which applies
    a specific kernel transformation based on the provided relation.
    """

    # The parent constructor must not create plain `Dense` layers; they are created here.
    _supress_dense = True

    def __init__(self, units, num_relations: int, num_bases: int = None, num_blocks: int = None, **kwargs):
        """Initialize with parameter for MLP layer that match :obj:`Dense` layer, including :obj:`Dropout` and
        :obj:`BatchNormalization` or :obj:`LayerNormalization` or :obj:`GraphNormalization` .

        Args:
            units: Positive integer, dimensionality of the output space.
            num_relations: Number of relations expected to construct weights.
            num_bases: Number of kernel basis functions to construct relations. Default is None.
            num_blocks: Number of block-matrices to get for parameter reduction. Default is None.
        %s
        """
        super(RelationalMLP, self).__init__(units=units, **kwargs)
        self._conf_num_relations = num_relations
        self._conf_num_bases = num_bases
        self._conf_num_blocks = num_blocks
        self._conf_relational_kwargs = {
            "num_relations": self._conf_num_relations,
            "num_bases": self._conf_num_bases,
            "num_blocks": self._conf_num_blocks
        }
        # Override dense list with RelationalDense layer. Each one gets the per-layer dense
        # arguments plus the relation arguments, which are the same for all layers.
        self.mlp_dense_layer_list = [
            RelationalDense(
                **self._get_conf_for_keys(self._key_list_dense, "dense", i),
                **self._conf_relational_kwargs
            ) for i in range(self._depth)
        ]

    def build(self, input_shape):
        """Build all sub-layers, propagating the feature shape from layer to layer.

        Args:
            input_shape: Sequence `[feature_shape, relation_shape, ...]`. Entries after the
                second are the shapes of the additional tensors handed to graph
                normalization layers.
        """
        x_shape, r_shape = input_shape[0], input_shape[1]
        # The slice is empty when only two shapes are given. `list(...)` keeps the
        # concatenation below working when the shapes arrive as a tuple instead of a list.
        x_graph = list(input_shape[2:])
        for i in range(self._depth):
            self.mlp_dense_layer_list[i].build([x_shape, r_shape])
            x_shape = self.mlp_dense_layer_list[i].compute_output_shape([x_shape, r_shape])
            if self._conf_use_dropout[i]:
                self.mlp_dropout_layer_list[i].build(x_shape)
            if self._conf_use_normalization[i]:
                norm_shape = x_shape if not self.is_graph_norm_layer[i] else [x_shape] + x_graph
                self.mlp_norm_layer_list[i].build(norm_shape)
            self.mlp_activation_layer_list[i].build(x_shape)
        self.built = True

    def call(self, inputs, **kwargs):
        r"""Forward pass.

        Args:
            inputs: [features, relation]

                - features (Tensor): Input tensor with last dimension not `None` e.g. `(..., N)` .
                - relation (Tensor): Input tensor with relation information of shape e.g. `(..., )` of type 'int'.

                Further entries, if any, are passed along to graph normalization layers.

        Returns:
            Tensor: MLP forward pass.
        """
        x, relations = inputs[0], inputs[1]
        # Empty for two inputs; converted to a list so that `[x] + batch` also accepts a tuple.
        batch = list(inputs[2:])
        for i in range(self._depth):
            x = self.mlp_dense_layer_list[i]([x, relations], **kwargs)
            if self._conf_use_dropout[i]:
                x = self.mlp_dropout_layer_list[i](x, **kwargs)
            if self._conf_use_normalization[i]:
                if self.is_graph_norm_layer[i]:
                    x = self.mlp_norm_layer_list[i]([x] + batch, **kwargs)
                else:
                    x = self.mlp_norm_layer_list[i](x, **kwargs)
            x = self.mlp_activation_layer_list[i](x, **kwargs)
        out = x
        return out

    def get_config(self):
        """Return the MLP config extended by `num_relations`, `num_bases` and `num_blocks`."""
        config = super(RelationalMLP, self).get_config()
        config.update(self._conf_relational_kwargs)
        return config
# Insert the shared argument documentation of `_MLPBase.__init__` at the `%s` placeholder of
# the `RelationalMLP.__init__` docstring. Docstrings are `None` when Python runs with `-OO`;
# skip the formatting then, otherwise importing this module would raise a `TypeError`.
if RelationalMLP.__init__.__doc__ is not None and _MLPBase.__init__.__doc__ is not None:
    RelationalMLP.__init__.__doc__ = RelationalMLP.__init__.__doc__ % _MLPBase.__init__.__doc__