# Source code for kgcnn.models.force

import keras as ks
import keras.saving
from typing import Union
from kgcnn.models.utils import get_model_class
from keras.saving import deserialize_keras_object, serialize_keras_object
from keras.backend import backend


# Keras 3.0.0 does not provide an `ops.gradient()` function yet, so the gradient
# has to be computed with backend-native tooling. Only the framework that matches
# the active keras backend is imported; any other backend is rejected at import time.
if backend() == "torch":
    import torch
elif backend() == "tensorflow":
    import tensorflow as tf
else:
    raise NotImplementedError("Backend '%s' not supported for force model." % backend())


@ks.saving.register_keras_serializable(package='kgcnn', name='EnergyForceModel')
class EnergyForceModel(ks.models.Model):
    r"""Outer model to wrap a normal invariant GNN to predict forces from energy predictions via partial derivatives.

    The Force :math:`\vec{F_i}` on Atom :math:`i` is given by

    .. math::

        \vec{F_i} = - \vec{\nabla}_i E_{\text{total}}

    Note that the model simply returns the tensor type of the coordinate input for forces.
    No casting is done by this class. This means that the model returns a ragged, disjoint or padded tensor
    depending on the tensor type of the coordinates.
    """

    def __init__(self,
                 inputs: Union[dict, list] = None,
                 model_energy=None,
                 coordinate_input: Union[int, str] = 1,
                 output_as_dict: Union[bool, list, tuple] = True,
                 ragged_validate: bool = False,
                 output_to_tensor: bool = True,
                 output_squeeze_states: bool = False,
                 nested_model_config: bool = True,
                 is_physical_force: bool = True,
                 use_batch_jacobian: bool = None,
                 name: str = None,
                 outputs: Union[dict, list] = None
                 ):
        """Initialize Force model with an energy model.

        Args:
            inputs (list): List of inputs as dictionary kwargs of keras input layers.
            model_energy (ks.models.Model, dict): Keras model or deserialization dictionary for a keras model.
            coordinate_input (int): Position of the coordinate input.
            output_as_dict (bool, tuple): Names for the output if a dictionary should be returned. Or simply a bool
                which will use the names "energy" and "force".
            ragged_validate (bool): Whether to validate ragged or jagged tensors.
            output_to_tensor: Deprecated.
            output_squeeze_states (bool): Whether to squeeze state/energy dimension for forces in case of a single
                energy value.
            nested_model_config (bool): Whether there is a config for the energy model.
            is_physical_force (bool): Whether to return the physical force, e.g. the negative gradient of the energy.
            use_batch_jacobian: Deprecated.
            name (str): Name of the model.
            outputs: List of outputs as dictionary kwargs similar to inputs. Not used by the model but can be
                passed for external use.

        Raises:
            ValueError: If `model_energy` is None.
            TypeError: If `model_energy` is neither a dict nor a `ks.models.Model`.
            NotImplementedError: If the active keras backend is neither "tensorflow" nor "torch".
        """
        # Hand `name` to the base class instead of assigning `self.name` afterwards, so that the
        # automatically generated name is kept when `name` is None.
        super().__init__(name=name)
        if model_energy is None:
            raise ValueError("Require valid model in `model_energy` for force prediction.")
        # Input for model_energy. Kept unchanged for serialization in `get_config` .
        self._model_energy = model_energy
        if isinstance(model_energy, ks.models.Model):
            # Ignoring module_name and class_name.
            self.energy_model = model_energy
        elif isinstance(model_energy, dict):
            if "module_name" not in model_energy:
                # Plain keras serialization dictionary.
                self.energy_model = deserialize_keras_object(model_energy)
            else:
                # Dictionary that names the model class via module and class name.
                self.energy_model_class = get_model_class(model_energy["module_name"], model_energy["class_name"])
                self.energy_model = self.energy_model_class(**model_energy["config"])
        else:
            raise TypeError("Input `model_energy` must be dict or `ks.models.Model` . Can not deserialize model.")

        # Additional parameters of io and behavior of this class.
        self.ragged_validate = ragged_validate
        self.coordinate_input = coordinate_input
        self.output_squeeze_states = output_squeeze_states
        self.is_physical_force = is_physical_force
        self.nested_model_config = nested_model_config
        self._force_outputs = outputs

        self.output_as_dict = output_as_dict
        if isinstance(output_as_dict, bool):
            self.output_as_dict_use = output_as_dict
            self.output_as_dict_names = ("energy", "force")
        elif isinstance(output_as_dict, (list, tuple)):
            self.output_as_dict_use = True
            self.output_as_dict_names = (output_as_dict[0], output_as_dict[1])
        else:
            self.output_as_dict_use = False

        # We can try to infer the model inputs from energy model, if not given explicit.
        self._inputs_to_force_model = inputs
        if self._inputs_to_force_model is None:
            if self.nested_model_config and isinstance(model_energy, dict):
                # A serialization dictionary is not guaranteed to have an "inputs" entry in its config.
                # Leave the inputs undefined in that case instead of failing with a KeyError.
                energy_config = model_energy.get("config")
                if isinstance(energy_config, dict):
                    self._inputs_to_force_model = energy_config.get("inputs")

        # Select the backend specific gradient implementation once.
        if backend() == "tensorflow":
            self._call_grad_backend = self._call_grad_tf
        elif backend() == "torch":
            self._call_grad_backend = self._call_grad_torch
        else:
            raise NotImplementedError("Backend '%s' not supported for force model." % backend())

    def build(self, input_shape):
        """Build the wrapped energy model with `input_shape` and mark this model as built."""
        self.energy_model.build(input_shape)
        self.built = True

    def _call_grad_tf(self, inputs, training=False, **kwargs):
        """Compute energy and its gradient with respect to the coordinates for the tensorflow backend.

        Args:
            inputs: Model inputs. The coordinates are taken from `inputs[self.coordinate_input]` .
            training (bool): Passed on to the energy model.

        Returns:
            tuple: Energy model output and the gradient. For ragged coordinates the gradient is
            returned as ragged tensor with the row splits of the coordinates.
        """
        x_in = inputs[self.coordinate_input]
        with tf.GradientTape(persistent=True) as tape:
            if isinstance(x_in, tf.RaggedTensor):
                # The gradient is taken with respect to the flat values of the ragged tensor.
                x, splits = x_in.values, x_in.row_splits
            else:
                x, splits = x_in, None
            tape.watch(x)
            eng = self.energy_model(inputs, training=training, **kwargs)
            # Sum over the first axis; presumably the batch dimension - TODO confirm.
            eng_sum = tf.reduce_sum(eng, axis=0, keepdims=False)
            # Slicing has to be recorded by the tape, so it happens inside the context.
            e_grad = [eng_sum[i] for i in range(eng_sum.shape[-1])]

        # One gradient per entry of the last energy axis, stacked on a new last axis.
        e_grad = [tf.expand_dims(tape.gradient(e_i, x), axis=-1) for e_i in e_grad]
        e_grad = tf.concat(e_grad, axis=-1)

        if self.output_squeeze_states:
            e_grad = tf.squeeze(e_grad, axis=-1)

        if isinstance(x_in, tf.RaggedTensor):
            e_grad = tf.RaggedTensor.from_row_splits(e_grad, splits, validate=self.ragged_validate)
        return eng, e_grad

    def _call_grad_torch(self, inputs, training=False, **kwargs):
        """Compute energy and its gradient with respect to the coordinates for the torch backend.

        Args:
            inputs: Model inputs. The coordinates are taken from `inputs[self.coordinate_input]` .
            training (bool): Passed on to the energy model.

        Returns:
            tuple: Energy model output and the gradient.
        """
        x = inputs[self.coordinate_input]
        with torch.enable_grad():
            x.requires_grad = True
            eng = self.energy_model.call(inputs, training=training, **kwargs)
            # Sum over the first axis; presumably the batch dimension - TODO confirm.
            eng_sum = eng.sum(dim=0)
            # One gradient per entry of the last energy axis, stacked on a new last axis.
            # `create_graph=True` keeps the graph of the gradient itself.
            e_grad = torch.cat([
                torch.unsqueeze(torch.autograd.grad(eng_sum[i], x, create_graph=True)[0], dim=-1)
                for i in range(eng.shape[-1])], dim=-1)

        if self.output_squeeze_states:
            e_grad = torch.squeeze(e_grad, dim=-1)
        return eng, e_grad

    def call(self, inputs, training=False, **kwargs):
        """Forward pass that returns energy and force (or plain gradient).

        Args:
            inputs: Model inputs that are passed to the energy model.
            training (bool): Passed on to the energy model.

        Returns:
            dict, tuple: Energy and gradient, either as dictionary with the configured output names or
            as tuple `(energy, gradient)` . The gradient is negated if `is_physical_force` is set.
        """
        eng, e_grad = self._call_grad_backend(inputs, training=training, **kwargs)

        if self.is_physical_force:
            e_grad = -e_grad

        if self.output_as_dict_use:
            return {self.output_as_dict_names[0]: eng, self.output_as_dict_names[1]: e_grad}
        else:
            return eng, e_grad

    def get_config(self):
        """Get config.

        Returns:
            dict: Keyword arguments of the constructor to rebuild this model.
        """
        # Keras model does not provide config from base class.
        conf = {}
        # Serialize class if _model_energy is not dict.
        if isinstance(self._model_energy, dict):
            model_energy = self._model_energy
        else:
            model_energy = serialize_keras_object(self._model_energy)
        conf.update({
            "model_energy": model_energy,
            "coordinate_input": self.coordinate_input,
            "output_as_dict": self.output_as_dict,
            "ragged_validate": self.ragged_validate,
            "output_squeeze_states": self.output_squeeze_states,
            "nested_model_config": self.nested_model_config,
            # Without this entry a reloaded model always falls back to the default sign convention.
            "is_physical_force": self.is_physical_force,
            "name": self.name,
            "inputs": self._inputs_to_force_model,
            "outputs": self._force_outputs
        })
        return conf