Source code for kgcnn.ops.scatter

import kgcnn.backend as kgcnn_backend
from keras import KerasTensor
from kgcnn.backend import any_symbolic_tensors
from keras import Operation


class _ScatterMax(Operation):
    """Keras operation that wraps ``kgcnn_backend.scatter_reduce_max``."""

    def call(self, indices, values, shape):
        """Run the backend scatter on the given tensors and return its result."""
        scattered = kgcnn_backend.scatter_reduce_max(indices, values, shape)
        return scattered

    def compute_output_spec(self, indices, values, shape):
        """Describe the symbolic output: target ``shape`` with the dtype of ``values``."""
        result_dtype = values.dtype
        return KerasTensor(shape, dtype=result_dtype)


def scatter_reduce_max(indices, values, shape):
    r"""Scatter values at indices into a new tensor of the given shape.

    The work is delegated to ``kgcnn_backend.scatter_reduce_max``. If any
    argument is a symbolic tensor, the call is routed through the
    :obj:`_ScatterMax` operation instead of being executed directly.

    Args:
        indices (Tensor): 1D Indices of shape `(M, )` .
        values (Tensor): Values of shape `(M, ...)` .
        shape (tuple): Target shape.

    Returns:
        Tensor: Scattered values of `shape` .
    """
    args = (indices, values, shape)
    if not any_symbolic_tensors(args):
        # Concrete inputs: call the backend directly.
        return kgcnn_backend.scatter_reduce_max(*args)
    return _ScatterMax().symbolic_call(*args)
class _ScatterMin(Operation):
    """Keras operation that wraps ``kgcnn_backend.scatter_reduce_min``."""

    def call(self, indices, values, shape):
        """Run the backend scatter on the given tensors and return its result."""
        scattered = kgcnn_backend.scatter_reduce_min(indices, values, shape)
        return scattered

    def compute_output_spec(self, indices, values, shape):
        """Describe the symbolic output: target ``shape`` with the dtype of ``values``."""
        result_dtype = values.dtype
        return KerasTensor(shape, dtype=result_dtype)
def scatter_reduce_min(indices, values, shape):
    r"""Scatter values at indices into a new tensor of the given shape.

    The work is delegated to ``kgcnn_backend.scatter_reduce_min``. If any
    argument is a symbolic tensor, the call is routed through the
    :obj:`_ScatterMin` operation instead of being executed directly.

    Args:
        indices (Tensor): 1D Indices of shape `(M, )` .
        values (Tensor): Values of shape `(M, ...)` .
        shape (tuple): Target shape.

    Returns:
        Tensor: Scattered values of `shape` .
    """
    args = (indices, values, shape)
    if not any_symbolic_tensors(args):
        # Concrete inputs: call the backend directly.
        return kgcnn_backend.scatter_reduce_min(*args)
    return _ScatterMin().symbolic_call(*args)
class _ScatterMean(Operation):
    """Keras operation that wraps ``kgcnn_backend.scatter_reduce_mean``."""

    def call(self, indices, values, shape):
        """Run the backend scatter on the given tensors and return its result."""
        scattered = kgcnn_backend.scatter_reduce_mean(indices, values, shape)
        return scattered

    def compute_output_spec(self, indices, values, shape):
        """Describe the symbolic output: target ``shape`` with the dtype of ``values``."""
        result_dtype = values.dtype
        return KerasTensor(shape, dtype=result_dtype)
def scatter_reduce_mean(indices, values, shape):
    r"""Scatter values at indices into a new tensor of the given shape.

    The work is delegated to ``kgcnn_backend.scatter_reduce_mean``. If any
    argument is a symbolic tensor, the call is routed through the
    :obj:`_ScatterMean` operation instead of being executed directly.

    Args:
        indices (Tensor): 1D Indices of shape `(M, )` .
        values (Tensor): Values of shape `(M, ...)` .
        shape (tuple): Target shape.

    Returns:
        Tensor: Scattered values of `shape` .
    """
    args = (indices, values, shape)
    if not any_symbolic_tensors(args):
        # Concrete inputs: call the backend directly.
        return kgcnn_backend.scatter_reduce_mean(*args)
    return _ScatterMean().symbolic_call(*args)
class _ScatterSum(Operation):
    """Keras operation that wraps ``kgcnn_backend.scatter_reduce_sum``."""

    def call(self, indices, values, shape):
        """Run the backend scatter on the given tensors and return its result."""
        scattered = kgcnn_backend.scatter_reduce_sum(indices, values, shape)
        return scattered

    def compute_output_spec(self, indices, values, shape):
        """Describe the symbolic output: target ``shape`` with the dtype of ``values``."""
        result_dtype = values.dtype
        return KerasTensor(shape, dtype=result_dtype)
def scatter_reduce_sum(indices, values, shape):
    r"""Scatter values at indices into a new tensor of the given shape.

    The work is delegated to ``kgcnn_backend.scatter_reduce_sum``. If any
    argument is a symbolic tensor, the call is routed through the
    :obj:`_ScatterSum` operation instead of being executed directly.

    Args:
        indices (Tensor): 1D Indices of shape `(M, )` .
        values (Tensor): Values of shape `(M, ...)` .
        shape (tuple): Target shape.

    Returns:
        Tensor: Scattered values of `shape` .
    """
    args = (indices, values, shape)
    if not any_symbolic_tensors(args):
        # Concrete inputs: call the backend directly.
        return kgcnn_backend.scatter_reduce_sum(*args)
    return _ScatterSum().symbolic_call(*args)
class _ScatterSoftmax(Operation):
    """Keras operation that wraps ``kgcnn_backend.scatter_reduce_softmax``.

    Args:
        normalize (bool): Flag stored on the operation and forwarded unchanged
            to the backend function on every call. Defaults to False.
    """

    def __init__(self, normalize: bool = False):
        super().__init__()
        self.normalize = normalize

    def call(self, indices, values, shape):
        """Run the backend softmax with the stored ``normalize`` flag."""
        return kgcnn_backend.scatter_reduce_softmax(
            indices, values, shape, normalize=self.normalize
        )

    def compute_output_spec(self, indices, values, shape):
        """Describe the symbolic output of the grouped softmax.

        The softmax yields one result per input value; ``shape`` is only the
        target shape of the intermediate scattered tensor. The symbolic output
        therefore follows ``values.shape`` and not ``shape``.
        """
        return KerasTensor(values.shape, dtype=values.dtype)
def scatter_reduce_softmax(indices, values, shape, normalize: bool = False):
    r"""Scatter values at indices to normalize values via softmax.

    The work is delegated to ``kgcnn_backend.scatter_reduce_softmax``. If any
    of `indices`, `values` or `shape` is a symbolic tensor, the call is routed
    through the :obj:`_ScatterSoftmax` operation instead of being executed
    directly.

    Args:
        indices (Tensor): 1D Indices of shape `(M, )` .
        values (Tensor): Values of shape `(M, ...)` .
        shape (tuple): Target shape of scattered tensor.
        normalize (bool): Forwarded unchanged to the backend implementation.
            Defaults to False.

    Returns:
        Tensor: Values with softmax computed by grouping at indices.
    """
    args = (indices, values, shape)
    if not any_symbolic_tensors(args):
        # Concrete inputs: call the backend directly.
        return kgcnn_backend.scatter_reduce_softmax(*args, normalize=normalize)
    softmax_op = _ScatterSoftmax(normalize=normalize)
    return softmax_op.symbolic_call(*args)