Source code for kgcnn.losses.losses

import keras as ks
from keras import ops
from keras.losses import Loss
from keras.losses import mean_absolute_error
import keras.saving


[docs]@ks.saving.register_keras_serializable(package='kgcnn', name='MeanAbsoluteError') class MeanAbsoluteError(Loss): def __init__(self, reduction="sum_over_batch_size", name="mean_absolute_error", dtype=None): super(MeanAbsoluteError, self).__init__(reduction=reduction, name=name, dtype=dtype)
[docs] def call(self, y_true, y_pred): return mean_absolute_error(y_true, y_pred)
[docs] def get_config(self): config = super(MeanAbsoluteError, self).get_config() return config
[docs]@ks.saving.register_keras_serializable(package='kgcnn', name='ForceMeanAbsoluteError') class ForceMeanAbsoluteError(Loss): def __init__(self, reduction="sum_over_batch_size", name="force_mean_absolute_error", squeeze_states: bool = True, find_padded_atoms: bool = True, dtype=None): super(ForceMeanAbsoluteError, self).__init__(reduction=reduction, name=name, dtype=dtype) self.squeeze_states = squeeze_states self.find_padded_atoms = find_padded_atoms
[docs] def call(self, y_true, y_pred): # Shape: (batch, N, 3, S) if self.find_padded_atoms: check_nonzero = ops.cast(ops.logical_not( ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=2)), dtype="int32") row_count = ops.sum(check_nonzero, axis=1) row_count = ops.where(row_count < 1, 1, row_count) norm = 1/ops.cast(row_count, dtype=y_true.dtype) else: norm = 1/ops.shape(y_true)[1] diff = ops.abs(y_true-y_pred) out = ops.sum(ops.mean(diff, axis=2), axis=1)*norm if not self.squeeze_states: out = ops.mean(out, axis=-1) return out
[docs] def get_config(self): config = super(ForceMeanAbsoluteError, self).get_config() config.update({"find_padded_atoms": self.find_padded_atoms, "squeeze_states": self.squeeze_states}) return config
[docs]@ks.saving.register_keras_serializable(package='kgcnn', name='BinaryCrossentropyNoNaN') class BinaryCrossentropyNoNaN(ks.losses.BinaryCrossentropy): def __init__(self, *args, **kwargs): super(BinaryCrossentropyNoNaN, self).__init__(*args, **kwargs)
[docs] def call(self, y_true, y_pred): is_nan = ops.isnan(y_true) y_pred = ops.where(is_nan, ops.zeros_like(y_pred), y_pred) y_true = ops.where(is_nan, ops.zeros_like(y_true), y_true) return super(BinaryCrossentropyNoNaN, self).call(y_true, y_pred)