Source code for kgcnn.layers.scale

import keras as ks
from typing import Union
from kgcnn.layers.pooling import PoolingNodes
import numpy as np
from keras import ops


class StandardLabelScaler(ks.layers.Layer):  # noqa

    def __init__(self, scaling_shape: tuple = None, dtype_scale: str = "float64", trainable: bool = False,
                 name="StandardLabelScaler", **kwargs):
        r"""Initialize layer instance of :obj:`StandardLabelScaler` .

        Args:
            scaling_shape (tuple): Shape of the scale and mean weights. Default is None, in which case the
                shape is inferred from the input shape on build.
            dtype_scale (str): Data type of the scale and mean weights. Default is "float64".
            trainable (bool): Whether scale and mean weights are trainable. Default is False.
            name (str): Name of the layer. Default is "StandardLabelScaler".
        """
        super(StandardLabelScaler, self).__init__(**kwargs)
        self._scaling_shape = scaling_shape
        self.name = name
        self._weights_trainable = trainable
        self.dtype_scale = dtype_scale
        self.extensive = False
        if self._scaling_shape is not None:
            self._add_weights_for_scaling()

    def _add_weights_for_scaling(self):
        # Scale (standard deviation) and mean shift applied to the labels.
        self.scale_ = self.add_weight(
            shape=self._scaling_shape,
            initializer="ones",
            trainable=self._weights_trainable,
            dtype=self.dtype_scale
        )
        self.mean_ = self.add_weight(
            shape=self._scaling_shape,
            initializer="zeros",
            trainable=self._weights_trainable,
            dtype=self.dtype_scale
        )

    def build(self, input_shape):
        if self._scaling_shape is None:
            if input_shape is None:
                raise ValueError("Can not build scale and mean weights if `input_shape` not known.")
            # Replace unknown dimensions with 1 so that the weights broadcast against the input.
            self._scaling_shape = tuple([1 if i is None else i for i in input_shape])
            self._add_weights_for_scaling()
        self.built = True

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, inputs, **kwargs):
        # Revert standard label scaling: multiply by scale and add the mean shift.
        return ops.cast(inputs, dtype=self.dtype_scale) * self.scale_ + self.mean_

    def get_config(self):
        config = super(StandardLabelScaler, self).get_config()
        config.update({
            "scaling_shape": self._scaling_shape,
            "dtype_scale": self.dtype_scale,
            "trainable": self._weights_trainable
        })
        return config

    def set_scale(self, scaler):
        # Transfer scale and mean from a fitted pre-processing scaler object.
        self.set_weights([scaler.get_scaling(), scaler.get_mean_shift()])
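
# The following helper is an illustrative usage sketch and not part of the original module.
# It assumes a single regression target, i.e. labels of shape (batch, 1), and sets the
# scale and mean weights directly instead of transferring them from a fitted scaler.
def _example_standard_label_scaler():
    layer = StandardLabelScaler(scaling_shape=(1, 1))
    layer.set_weights([np.array([[2.0]]), np.array([[0.5]])])  # [scale_, mean_]
    labels = ops.convert_to_tensor(np.array([[1.0], [3.0]]))
    # Reverts standardization: labels * 2.0 + 0.5 -> [[2.5], [6.5]] in float64.
    return layer(labels)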

class ExtensiveMolecularLabelScaler(ks.layers.Layer):  # noqa

    max_atomic_number = 95

    def __init__(self, scaling_shape: tuple = None, dtype_scale: str = "float64", trainable: bool = False,
                 name="ExtensiveMolecularLabelScaler", **kwargs):
        r"""Initialize layer instance of :obj:`ExtensiveMolecularLabelScaler` .

        Args:
            scaling_shape (tuple): Shape of the scale weights. Default is None, in which case the shape is
                inferred from the input shape on build.
            dtype_scale (str): Data type of the scale and kernel weights. Default is "float64".
            trainable (bool): Whether scale and kernel weights are trainable. Default is False.
            name (str): Name of the layer. Default is "ExtensiveMolecularLabelScaler".
        """
        super(ExtensiveMolecularLabelScaler, self).__init__(**kwargs)
        self._scaling_shape = scaling_shape
        self.name = name
        self._weights_trainable = trainable
        self.dtype_scale = dtype_scale
        self.extensive = True
        self.layer_pool = PoolingNodes(pooling_method="scatter_sum")
        self._fit_atom_selection_mask = self.add_weight(
            shape=(self.max_atomic_number, ),
            trainable=False,
            dtype="bool",
            initializer="zeros"
        )
        if self._scaling_shape is not None:
            self._add_weights_for_scaling()

    def _add_weights_for_scaling(self):
        # Scale for the label residual and a per-element (ridge) kernel for the extensive contribution.
        self.scale_ = self.add_weight(
            shape=self._scaling_shape,
            initializer="ones",
            trainable=self._weights_trainable,
            dtype=self.dtype_scale
        )
        self.ridge_kernel_ = self.add_weight(
            shape=tuple([self.max_atomic_number] + list(self._scaling_shape[1:])),
            initializer="zeros",
            trainable=self._weights_trainable,
            dtype=self.dtype_scale
        )

    def build(self, input_shape):
        if self._scaling_shape is None:
            if input_shape is None:
                raise ValueError(
                    "Can not build scale and mean weights if `input_shape` and `scaling_shape` not known.")
            self._scaling_shape = tuple([1 if i is None else i for i in input_shape[0]])
            self._add_weights_for_scaling()
        self.built = True

    def compute_output_shape(self, input_shape):
        return input_shape[0]

    def call(self, inputs, **kwargs):
        # Inputs: graph labels, atomic numbers per node and the graph index of each node.
        graph, nodes, batch_id = inputs
        # Per-node contribution from the element-wise kernel.
        energy_per_node = ops.take(self.ridge_kernel_, nodes, axis=0)
        # Sum node contributions per graph to obtain the extensive part of the label.
        extensive_energies = self.layer_pool([graph, energy_per_node, batch_id])
        return ops.cast(graph, dtype=self.dtype_scale) * self.scale_ + extensive_energies

    def get_config(self):
        config = super(ExtensiveMolecularLabelScaler, self).get_config()
        config.update({
            "scaling_shape": self._scaling_shape,
            "dtype_scale": self.dtype_scale,
            "trainable": self._weights_trainable
        })
        return config

    def set_scale(self, scaler):
        # Transfer the fitted ridge-regression coefficients from a pre-processing scaler object.
        ridge_kernel = np.transpose(np.array(scaler.ridge.coef_))
        pos = np.sort(np.array(scaler._fit_atom_selection))
        mask = np.array(scaler._fit_atom_selection_mask)
        shape = tuple([int(self.max_atomic_number)] + list(ridge_kernel.shape[1:]))
        layer_kernel = np.zeros(shape)
        layer_kernel[pos] = ridge_kernel
        layer_kernel[0] = 0.  # Make sure 0 is always 0.
        self.set_weights([mask, scaler.get_scaling(), layer_kernel])
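
# Illustrative usage sketch, not part of the original module. It assumes disjoint graph input:
# per-graph labels of shape (batch, 1), a flat tensor of atomic numbers and the graph index of
# every node. With the default zero ridge kernel and unit scale the layer returns the labels
# unchanged; after `set_scale` the per-element contributions are added back.
def _example_extensive_molecular_label_scaler():
    layer = ExtensiveMolecularLabelScaler(scaling_shape=(1, 1))
    graph_labels = ops.convert_to_tensor(np.array([[0.1], [0.2]]))  # e.g. H2O and CH4
    atomic_numbers = ops.convert_to_tensor(np.array([8, 1, 1, 6, 1, 1, 1, 1]))
    batch_id = ops.convert_to_tensor(np.array([0, 0, 0, 1, 1, 1, 1, 1]))
    return layer([graph_labels, atomic_numbers, batch_id])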

class QMGraphLabelScaler(ks.layers.Layer):  # noqa

    max_atomic_number = 95

    def __init__(self, scaler_list: list = None, name="QMGraphLabelScaler", **kwargs):
        r"""Initialize layer instance of :obj:`QMGraphLabelScaler` .

        Args:
            scaler_list (list): List of scaler layers, one per label column.
            name (str): Name of the layer. Default is "QMGraphLabelScaler".
        """
        super(QMGraphLabelScaler, self).__init__(**kwargs)
        self._scaler_list = scaler_list
        self.name = name
        self.extensive = True

    def build(self, input_shape):
        # Build each wrapped scaler; extensive scalers receive the full list of input shapes.
        for scaler in self._scaler_list:
            if scaler.extensive:
                scaler.build(input_shape)
            else:
                scaler.build(input_shape[0])
        self.built = True

    def compute_output_shape(self, input_shape):
        return input_shape[0]

    def call(self, inputs, **kwargs):
        # Returns inputs unchanged.
        return inputs

    def get_config(self):
        config = super(QMGraphLabelScaler, self).get_config()
        config.update({})
        return config

    def set_scale(self, scaler):
        # Delegate to each wrapped scaler layer.
        for s in self._scaler_list:
            s.set_scale(scaler)
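
# Illustrative usage sketch, not part of the original module: wrap one scaling layer per label
# column, mixing extensive and standard (intensive) scalers, e.g. for QM9-style targets.
def _example_qm_graph_label_scaler():
    layer = QMGraphLabelScaler(scaler_list=[
        ExtensiveMolecularLabelScaler(scaling_shape=(1, 1)),
        StandardLabelScaler(scaling_shape=(1, 1)),
    ])
    # `set_scale` would forward a fitted pre-processing scaler to every wrapped layer.
    return layer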

def get(scale_name: str):
    """Resolve a scaling layer class from its string identifier."""
    scaler_reference = {
        "StandardLabelScaler": StandardLabelScaler,
        "ExtensiveMolecularLabelScaler": ExtensiveMolecularLabelScaler
    }
    if scale_name not in scaler_reference:
        raise ValueError("Unknown scale name '%s'. Available: %s" % (scale_name, list(scaler_reference.keys())))
    return scaler_reference[scale_name]
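
# Illustrative usage sketch, not part of the original module: resolve a scaling layer class by
# its string name and instantiate it. Note that `QMGraphLabelScaler` is not registered in the
# lookup above.
def _example_get_scaler_layer():
    layer_class = get("StandardLabelScaler")
    return layer_class(scaling_shape=(1, 1))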