import keras as ks
import keras.metrics
import numpy as np
from keras import ops
import keras.saving
@ks.saving.register_keras_serializable(package='kgcnn', name='ScaledMeanAbsoluteError')
class ScaledMeanAbsoluteError(ks.metrics.MeanAbsoluteError):
    """Metric for a scaled mean absolute error (MAE), which can undo a pre-scaling of the targets.

    Only intended as a metric; this allows reporting the MAE in correct units or
    absolute values during fit.
    """

    def __init__(self, scaling_shape=(), name='mean_absolute_error', dtype_scale: str = None, **kwargs):
        """Initialize metric.

        Args:
            scaling_shape (tuple): Shape of the scale variable. Default is ().
            name (str): Name of the metric. Default is 'mean_absolute_error'.
            dtype_scale (str): Optional dtype for the scale variable. Defaults to the metric dtype.
        """
        super(ScaledMeanAbsoluteError, self).__init__(name=name, **kwargs)
        self.scaling_shape = scaling_shape
        self.dtype_scale = dtype_scale
        # Metric weight holding the scale; initialized to ones (identity scaling).
        self.scale = self.add_variable(
            shape=scaling_shape,
            initializer=ks.initializers.Ones(),
            name='kgcnn_scale_mae',
            dtype=self.dtype_scale if self.dtype_scale is not None else self.dtype
        )

    def reset_state(self):
        """Reset the accumulator variables but keep the scale variable untouched."""
        for v in self.variables:
            if 'kgcnn_scale_mae' not in v.name:
                v.assign(ops.zeros(v.shape, dtype=v.dtype))

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Undo the target pre-scaling on both tensors before accumulating the MAE."""
        y_true = self.scale * ops.cast(y_true, dtype=self.scale.dtype)
        y_pred = self.scale * ops.cast(y_pred, dtype=self.scale.dtype)
        return super(ScaledMeanAbsoluteError, self).update_state(y_true, y_pred, sample_weight=sample_weight)

    def get_config(self):
        """Returns the serializable config of the metric."""
        conf = super(ScaledMeanAbsoluteError, self).get_config()
        conf.update({"scaling_shape": self.scaling_shape, "dtype_scale": self.dtype_scale})
        return conf

    def set_scale(self, scale):
        """Set the scale from a numpy array. Usually used with broadcasting.

        Args:
            scale: Array-like scale; will be cast to the dtype of the scale variable.
        """
        # Bug fix: cast to the variable's dtype, not the input's own dtype — the
        # original `ops.cast(scale, dtype=scale.dtype)` was a no-op and would fail
        # for inputs (e.g. plain lists) that have no `.dtype` attribute.
        self.scale.assign(ops.cast(scale, dtype=self.scale.dtype))
@ks.saving.register_keras_serializable(package='kgcnn', name='ScaledRootMeanSquaredError')
class ScaledRootMeanSquaredError(ks.metrics.RootMeanSquaredError):
    """Metric for a scaled root mean squared error (RMSE), which can undo a pre-scaling of the targets.

    Only intended as a metric; this allows reporting the RMSE in correct units or
    absolute values during fit.
    """

    def __init__(self, scaling_shape=(), name='root_mean_squared_error', dtype_scale: str = None, **kwargs):
        """Initialize metric.

        Args:
            scaling_shape (tuple): Shape of the scale variable. Default is ().
            name (str): Name of the metric. Default is 'root_mean_squared_error'.
            dtype_scale (str): Optional dtype for the scale variable. Defaults to the metric dtype.
        """
        super(ScaledRootMeanSquaredError, self).__init__(name=name, **kwargs)
        self.scaling_shape = scaling_shape
        self.dtype_scale = dtype_scale
        # Metric weight holding the scale; initialized to ones (identity scaling).
        self.scale = self.add_variable(
            shape=scaling_shape,
            initializer=ks.initializers.Ones(),
            name='kgcnn_scale_rmse',
            dtype=self.dtype_scale if self.dtype_scale is not None else self.dtype
        )

    def reset_state(self):
        """Reset the accumulator variables but keep the scale variable untouched."""
        for v in self.variables:
            if 'kgcnn_scale_rmse' not in v.name:
                v.assign(ops.zeros(v.shape, dtype=v.dtype))

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Undo the target pre-scaling on both tensors before accumulating the RMSE."""
        y_true = self.scale * ops.cast(y_true, dtype=self.scale.dtype)
        y_pred = self.scale * ops.cast(y_pred, dtype=self.scale.dtype)
        return super(ScaledRootMeanSquaredError, self).update_state(y_true, y_pred, sample_weight=sample_weight)

    def get_config(self):
        """Returns the serializable config of the metric."""
        conf = super(ScaledRootMeanSquaredError, self).get_config()
        conf.update({"scaling_shape": self.scaling_shape, "dtype_scale": self.dtype_scale})
        return conf

    def set_scale(self, scale):
        """Set the scale from a numpy array. Usually used with broadcasting.

        Args:
            scale: Array-like scale; will be cast to the dtype of the scale variable.
        """
        # Bug fix: cast to the variable's dtype, not the input's own dtype — the
        # original `ops.cast(scale, dtype=scale.dtype)` was a no-op and would fail
        # for inputs (e.g. plain lists) that have no `.dtype` attribute.
        self.scale.assign(ops.cast(scale, dtype=self.scale.dtype))
# Bug fix: this class was registered under the name 'ScaledMeanAbsoluteError',
# colliding with the class of the same serialization name above in package 'kgcnn'.
# Register it under its own class name so deserialization resolves correctly.
@ks.saving.register_keras_serializable(package='kgcnn', name='ScaledForceMeanAbsoluteError')
class ScaledForceMeanAbsoluteError(ks.metrics.MeanMetricWrapper):
    """Metric for a scaled mean absolute error (MAE) of forces, which can undo a pre-scaling
    of the targets.

    Only intended as a metric; this allows reporting the MAE in correct units or
    absolute values during fit.
    """

    def __init__(self, scaling_shape=(1, 1), name='force_mean_absolute_error', dtype_scale: str = None,
                 squeeze_states: bool = True, find_padded_atoms: bool = True, **kwargs):
        """Initialize metric.

        Args:
            scaling_shape (tuple): Shape of the scale, e.g. ``(states, 1)``. Default is (1, 1).
            name (str): Name of the metric. Default is 'force_mean_absolute_error'.
            dtype_scale (str): Optional dtype for the scale variable. Defaults to the metric dtype.
            squeeze_states (bool): Whether a trailing singleton axis of the scale is squeezed.
            find_padded_atoms (bool): Whether all-zero force rows are treated as padding and
                excluded from the per-structure normalization.
        """
        super(ScaledForceMeanAbsoluteError, self).__init__(fn=self.fn_force_mae, name=name, **kwargs)
        self.scaling_shape = scaling_shape
        self.dtype_scale = dtype_scale
        self.squeeze_states = squeeze_states
        self.find_padded_atoms = find_padded_atoms
        # Optionally drop a trailing singleton axis, then insert two singleton axes after the
        # first so the variable broadcasts against force tensors — presumably of shape
        # (batch, N, 3, ...); TODO confirm against callers.
        if scaling_shape[-1] == 1 and squeeze_states and len(scaling_shape) > 1:
            scaling_shape = scaling_shape[:-1]
        scaling_shape = tuple(list(scaling_shape[:1]) + [1, 1] + list(scaling_shape[1:]))
        self.scale = self.add_variable(
            shape=scaling_shape,
            initializer=ks.initializers.Ones(),
            name='kgcnn_scale_mae',
            dtype=self.dtype_scale if self.dtype_scale is not None else self.dtype
        )

    def fn_force_mae(self, y_true, y_pred):
        """Compute the scaled per-structure force MAE.

        Assumes inputs of shape (batch, N, 3) with optional trailing state axes — verify
        against the model outputs.
        """
        if self.find_padded_atoms:
            # An atom counts as real if any component of its force row is non-zero.
            check_nonzero = ops.cast(ops.logical_not(
                ops.all(ops.isclose(y_true, ops.convert_to_tensor(0., dtype=y_true.dtype)), axis=2)), dtype="int32")
            row_count = ops.sum(check_nonzero, axis=1)
            # Guard against division by zero for fully padded structures.
            row_count = ops.where(row_count < 1, 1, row_count)
            norm = 1 / ops.cast(row_count, dtype=self.scale.dtype)
        else:
            norm = 1 / ops.shape(y_true)[1]
        y_true = self.scale * ops.cast(y_true, dtype=self.scale.dtype)
        y_pred = self.scale * ops.cast(y_pred, dtype=self.scale.dtype)
        diff = ops.abs(y_true - y_pred)
        # Mean over the xyz axis, sum over atoms, then normalize by the atom count.
        out = ops.sum(ops.mean(diff, axis=2), axis=1) * norm
        if not self.squeeze_states:
            out = ops.mean(out, axis=-1)
        return out

    def reset_state(self):
        """Reset the accumulator variables but keep the scale variable untouched."""
        for v in self.variables:
            if 'kgcnn_scale_mae' not in v.name:
                v.assign(ops.zeros(v.shape, dtype=v.dtype))

    def get_config(self):
        """Returns the serializable config of the metric."""
        # May not manage to deserialize `fn_force_mae`, set conf directly.
        conf = {"name": self.name, "dtype": self.dtype}
        conf.update({"scaling_shape": self.scaling_shape, "dtype_scale": self.dtype_scale,
                     "find_padded_atoms": self.find_padded_atoms, "squeeze_states": self.squeeze_states})
        return conf

    def set_scale(self, scale):
        """Set the scale from a numpy array. Usually used with broadcasting.

        Args:
            scale (np.ndarray): Scale array matching ``scaling_shape``.
        """
        scaling_shape = scale.shape
        if scaling_shape[-1] == 1 and self.squeeze_states and len(scaling_shape) > 1:
            scale = np.squeeze(scale, axis=-1)
        # Insert the (atoms, xyz) broadcast axes, mirroring the variable's shape in __init__.
        scale = np.expand_dims(np.expand_dims(scale, axis=1), axis=2)
        # Bug fix: cast to the variable's dtype, not the input's own dtype (a no-op).
        self.scale.assign(ops.cast(scale, dtype=self.scale.dtype))
@ks.saving.register_keras_serializable(package='kgcnn', name='BinaryAccuracyNoNaN')
class BinaryAccuracyNoNaN(ks.metrics.MeanMetricWrapper):
    """Binary accuracy where the per-sample normalization counts only non-NaN labels."""

    def __init__(self, name="binary_accuracy_no_nan", dtype=None, threshold=0.5, **kwargs):
        """Initialize metric.

        Args:
            name (str): Name of the metric. Default is 'binary_accuracy_no_nan'.
            dtype: Optional dtype of the metric.
            threshold (float): Decision threshold in (0, 1), or None. Default is 0.5.
        """
        # Reject thresholds outside the open interval (0, 1); None skips the check.
        if threshold is not None and not (0 < threshold < 1):
            raise ValueError(
                "Invalid value for argument `threshold`. "
                "Expected a value in interval (0, 1). "
                f"Received: threshold={threshold}"
            )
        # `threshold` is forwarded to the wrapped function by MeanMetricWrapper.
        super().__init__(
            fn=self._binary_accuracy_no_nan, name=name, dtype=dtype, threshold=threshold, **kwargs
        )
        self.threshold = threshold

    @staticmethod
    def _binary_accuracy_no_nan(y_true, y_pred, threshold=0.5):
        """Per-sample match count over the last axis, divided by the non-NaN label count."""
        labels = ops.convert_to_tensor(y_true)
        scores = ops.convert_to_tensor(y_pred)
        valid_mask = ops.cast(ops.logical_not(ops.isnan(labels)), labels.dtype)
        cutoff = ops.cast(threshold, scores.dtype)
        hard_predictions = ops.cast(scores > cutoff, labels.dtype)
        matches = ops.sum(ops.cast(
            ops.equal(labels, hard_predictions), dtype=ks.backend.floatx()), axis=-1)
        # NaN labels never compare equal, so they contribute to neither numerator nor denominator.
        valid_count = ops.sum(ops.cast(valid_mask, dtype=ks.backend.floatx()), axis=-1)
        return matches / valid_count

    def get_config(self):
        """Returns the serializable config of the metric."""
        return {"name": self.name, "dtype": self.dtype, "threshold": self.threshold}
@ks.saving.register_keras_serializable(package='kgcnn', name='AUCNoNaN')
class AUCNoNaN(ks.metrics.AUC):
    """AUC metric that masks NaN entries in ``y_true`` via zero sample weights."""

    def __init__(self, name="AUC_no_nan", **kwargs):
        """Initialize metric with default name 'AUC_no_nan'."""
        super(AUCNoNaN, self).__init__(name=name, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Zero-weight NaN labels, then defer to the base AUC update."""
        valid = ops.cast(ops.logical_not(ops.isnan(y_true)), y_true.dtype)
        weights = valid if sample_weight is None else sample_weight * valid
        return super(AUCNoNaN, self).update_state(y_true, y_pred, sample_weight=weights)
@ks.saving.register_keras_serializable(package='kgcnn', name='BalancedBinaryAccuracyNoNaN')
class BalancedBinaryAccuracyNoNaN(ks.metrics.SensitivityAtSpecificity):
    """Balanced binary accuracy, (sensitivity + specificity)/2, ignoring NaN labels."""

    def __init__(self, name="balanced_binary_accuracy_no_nan", class_id=None, num_thresholds=1,
                 specificity=0.5, **kwargs):
        """Initialize metric.

        Args:
            name (str): Name of the metric. Default is 'balanced_binary_accuracy_no_nan'.
            class_id: Optional class id for the base metric.
            num_thresholds (int): Number of thresholds. Default is 1.
            specificity (float): Specificity for the base metric. Default is 0.5.
        """
        super(BalancedBinaryAccuracyNoNaN, self).__init__(
            name=name, class_id=class_id, num_thresholds=num_thresholds,
            specificity=specificity, **kwargs)
        # NOTE(review): disabled presumably because a single threshold does not satisfy
        # the evenly-distributed assumption of the base class — confirm against keras internals.
        self._thresholds_distributed_evenly = False

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Update the state of the metric.

        Args:
            y_true: Ground truth label values. shape = `[batch_size, d0, .. dN-1]`
                or shape = `[batch_size, d0, .. dN-1, 1]` .
            y_pred: The predicted probability values. shape = `[batch_size, d0, .. dN]` .
            sample_weight: Optional sample_weight acts as a coefficient for the metric.
        """
        valid = ops.cast(ops.logical_not(ops.isnan(y_true)), y_true.dtype)
        weights = valid if sample_weight is None else sample_weight * valid
        return super(BalancedBinaryAccuracyNoNaN, self).update_state(
            y_true=y_true, y_pred=y_pred, sample_weight=weights)

    def result(self):
        """Compute (sensitivity + specificity) / 2 from the confusion-matrix counts."""
        eps = ks.config.epsilon()
        # Epsilon in the denominators guards against division by zero for empty classes.
        sensitivities = ops.divide(
            self.true_positives,
            self.true_positives + self.false_negatives + eps,
        )
        specificities = ops.divide(
            self.true_negatives,
            self.true_negatives + self.false_positives + eps,
        )
        return (sensitivities + specificities) / 2

    def get_config(self):
        """Returns the serializable config of the metric."""
        return super(BalancedBinaryAccuracyNoNaN, self).get_config()