Source code for kgcnn.initializers.initializers

import keras as ks
from keras import ops


[docs]def _compute_fans(shape): """Computes the number of input and output units for a weight shape. Taken from original TensorFlow implementation and copied here for static reference. Args: shape: Integer shape tuple or tensor shape. Returns: A tuple of integer scalars (fan_in, fan_out). """ if len(shape) < 1: # Just to avoid errors for constants. fan_in = fan_out = 1 elif len(shape) == 1: fan_in = fan_out = shape[0] elif len(shape) == 2: fan_in = shape[0] fan_out = shape[1] else: # Assuming convolution kernels (2D, 3D, or more). # kernel shape: (..., input_depth, depth) receptive_field_size = 1 for dim in shape[:-2]: receptive_field_size *= dim fan_in = shape[-2] * receptive_field_size fan_out = shape[-1] * receptive_field_size return int(fan_in), int(fan_out)
[docs]@ks.utils.register_keras_serializable(package='kgcnn', name='glorot_orthogonal') class GlorotOrthogonal(ks.initializers.Orthogonal): r"""Combining Glorot variance and Orthogonal initializer. Generate a weight matrix with variance according to Glorot initialization. Based on a random (semi-) orthogonal matrix neural networks are expected to learn better when features are de-correlated. This is stated by e.g.: * "Reducing over-fitting in deep networks by de-correlating representations" by M. Cogswell et al. (2016) `<https://arxiv.org/abs/1511.06068>`_ . * "Dropout: a simple way to prevent neural networks from over-fitting" by N. Srivastava et al. (2014) `<https://dl.acm.org/doi/10.5555/2627435.2670313>`_ . * "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks" by A. M. Saxe et al. (2013) `<https://arxiv.org/abs/1312.6120>`_ . This implementation has been borrowed and slightly modified from `DimeNetPP <https://arxiv.org/abs/2011.14115>`__ . """ def __init__(self, gain=1.0, seed=None, scale=1.0, mode='fan_avg'): super(GlorotOrthogonal, self).__init__(gain=gain, seed=seed) self.scale = scale self.mode = mode def __call__(self, shape, dtype="float32", **kwargs): weight_kernel = super(GlorotOrthogonal, self).__call__(shape, dtype=dtype, **kwargs) # Original implementation from DimeNet. # assert len(shape) == 2 # W = self.orth_init(shape, dtype) # W *= tf.sqrt(self.scale / ((shape[0] + shape[1]) * tf.math.reduce_variance(W))) # scale = 2.0 # Adapted with mode and scale chosen by class. Default values should match original version, to be used # for DimeNet model implementation. scale = self.scale fan_in, fan_out = _compute_fans(shape) if self.mode == "fan_in": scale /= max(1., fan_in) elif self.mode == "fan_out": scale /= max(1., fan_out) else: scale /= max(1., (fan_in + fan_out) / 2.) stddev = ops.sqrt(scale/ops.var(weight_kernel)) weight_kernel *= stddev return weight_kernel
[docs] def get_config(self): """Get keras config.""" config = super(GlorotOrthogonal, self).get_config() config.update({"scale": self.scale, "mode": self.mode}) return config
[docs]@ks.utils.register_keras_serializable(package='kgcnn', name='he_orthogonal') class HeOrthogonal(ks.initializers.Orthogonal): """Combining He variance and Orthogonal initializer. Generate a weight matrix with variance according to He initialization. Based on a random (semi-)orthogonal matrix neural networks are expected to learn better when features are de-correlated. This is stated by e.g.: * "Reducing over-fitting in deep networks by de-correlating representations" by M. Cogswell et al. (2016) `<https://arxiv.org/abs/1511.06068>`_ . * "Dropout: a simple way to prevent neural networks from over-fitting" by N. Srivastava et al. (2014) `<https://dl.acm.org/doi/10.5555/2627435.2670313>`_ . * "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks" by A. M. Saxe et al. (2013) `<https://arxiv.org/abs/1312.6120>`_ . This implementation has been borrowed and slightly modified from `GemNet <https://arxiv.org/abs/2106.08903>`__ . """ def __init__(self, gain=1.0, seed=None, scale=1.0, mode='fan_in'): super(HeOrthogonal, self).__init__(gain=gain, seed=seed) self.scale = scale self.mode = mode def __call__(self, shape, dtype="float32", **kwargs): weight_kernel = super(HeOrthogonal, self).__call__(shape, dtype=dtype, **kwargs) # Original reference implementation was designed for kernel rank={2,3}. # fan_in = shape[0] # if len(shape) == 3: # fan_in = fan_in * shape[1] # Tried to generalize with keras _compute_fans that is extends to convolutional kernels. # Although, not really meaningful in standard GNN applications. # Optionally use other scales. scale = self.scale fan_in, fan_out = _compute_fans(shape) if self.mode == "fan_in": scale /= max(1., fan_in) elif self.mode == "fan_out": scale /= max(1., fan_out) else: scale /= max(1., (fan_in + fan_out) / 2.) weight_kernel = self._standardize(weight_kernel, shape) # Original reference implementation with 1/fan_in changed to scale=scale/fan_in # W *= tf.sqrt(1 / fan_in) # variance decrease is addressed in the dense layers weight_kernel *= ops.sqrt(scale) return weight_kernel
[docs] @staticmethod def _standardize(kernel, shape): r"""Standardize kernel over `fan_in` dimensions. Args: kernel: Kernel variable. shape: Shape of the kernel. """ # Original doc string: Makes sure that N*Var(W) = 1 and E[W] = 0 # From original implementation as comments. # eps = 1e-6 eps = ks.backend.epsilon() # if len(shape) == 3: # axis = [0, 1] # last dimension is output dimension if len(shape) == 0: # Constant does not really have variance. # Moreover, Orthogonal initializer should throw error. return kernel if len(shape) >= 3: axis = [i for i in range(len(shape)-1)] else: axis = 0 mean = ops.mean(kernel, axis=axis, keepdims=True) var = ops.var(kernel, axis=axis, keepdims=True) kernel = (kernel - mean) / ops.sqrt(var + eps) return kernel
[docs] def get_config(self): """Get keras config.""" config = super(HeOrthogonal, self).get_config() config.update({"scale": self.scale, "mode": self.mode}) return config