Source code for kgcnn.backend

import tree
from keras import KerasTensor
from keras.backend import backend


def any_symbolic_tensors(args=None, kwargs=None):
    """Report whether any leaf of ``args`` or ``kwargs`` is a ``KerasTensor``.

    Args:
        args: Positional arguments to inspect. A falsy value (including the
            default ``None``) is treated as an empty tuple.
        kwargs: Keyword arguments to inspect. A falsy value (including the
            default ``None``) is treated as an empty dict.

    Returns:
        ``True`` if at least one element produced by ``tree.flatten`` on the
        pair ``(args, kwargs)`` is a ``KerasTensor`` instance, else ``False``.
    """
    positional = args if args else ()
    keyword = kwargs if kwargs else {}
    leaves = tree.flatten((positional, keyword))
    return any(isinstance(leaf, KerasTensor) for leaf in leaves)
# Import backend functions. if backend() == "tensorflow": from kgcnn.backend._tensorflow import * elif backend() == "jax": from kgcnn.backend._jax import * elif backend() == "torch": from kgcnn.backend._torch import * elif backend() == "numpy": from kgcnn.backend._numpy import * else: raise ValueError(f"Unable to import backend : {backend()}")