Source code for kgcnn.graph.serial

import importlib
from typing import Union
from kgcnn.utils.serial import deserialize


def get_preprocessor(name: Union[str, dict], **kwargs):
    r"""Get a preprocessor.

    Args:
        name (str, dict): Serialization dictionary of class. This can also be a name of former graph functions for
            backward compatibility that now coincides with the processor's default name, or the class name of one
            of the supported processors (e.g. "MakeUndirectedEdges").
        kwargs: Kwargs for processor initialization, if :obj:`name` is string. They are not used if :obj:`name`
            is a serialization dictionary.

    Returns:
        GraphPreProcessorBase: Instance of graph preprocessor.

    Raises:
        KeyError: If :obj:`name` is a string that is neither a known function name nor a known class name.
    """
    # Maps legacy graph-function names to the class names in `kgcnn.graph.preprocessor`.
    preprocessor_identifier = {
        "make_undirected_edges": "MakeUndirectedEdges",
        "add_edge_self_loops": "AddEdgeSelfLoops",
        "sort_edge_indices": "SortEdgeIndices",
        "set_edge_indices_reverse": "SetEdgeIndicesReverse",
        "pad_property": "PadProperty",
        "set_edge_weights_uniform": "SetEdgeWeightsUniform",
        "normalize_edge_weights_sym": "NormalizeEdgeWeightsSymmetric",
        "set_range_from_edges": "SetRangeFromEdges",
        "set_range": "SetRange",
        "set_angle": "SetAngle",
        "set_range_periodic": "SetRangePeriodic",
        "expand_distance_gaussian_basis": "ExpandDistanceGaussianBasis",
        "atomic_charge_representation": "AtomicChargesRepresentation",
        "principal_moments_of_inertia": "PrincipalMomentsOfInertia",
        "count_nodes_and_edges": "CountNodesAndEdges",
        "make_dense_adjacency_matrix": "MakeDenseAdjacencyMatrix",
        "make_mask": "MakeMask",
    }
    if isinstance(name, dict):
        # Full serialization config: delegate to the generic deserializer (kwargs are not applied here).
        return deserialize(name)
    # Given as string: resolve a legacy function name, or accept a supported class name directly.
    if name in preprocessor_identifier:
        class_name = preprocessor_identifier[name]
    elif name in preprocessor_identifier.values():
        class_name = name
    else:
        # Fail before importing anything, with a message that lists the valid identifiers.
        raise KeyError(
            f"Unknown preprocessor '{name}'. Expected one of {sorted(preprocessor_identifier)}, "
            f"a matching class name, or a serialization dictionary."
        )
    obj_class = getattr(importlib.import_module("kgcnn.graph.preprocessor"), class_name)
    return obj_class(**kwargs)
def get_postprocessor(name: Union[str, dict], **kwargs):
    r"""Get a postprocessor.

    Args:
        name (str, dict): Serialization dictionary of class. This can also be a name of former graph functions for
            backward compatibility that now coincides with the processor's default name, or the class name of one
            of the supported processors (e.g. "ExtensiveEnergyForceScalerPostprocessor").
        kwargs: Kwargs for processor initialization, if :obj:`name` is string. They are not used if :obj:`name`
            is a serialization dictionary.

    Returns:
        GraphPostProcessorBase: Instance of graph postprocessor.

    Raises:
        KeyError: If :obj:`name` is a string that is neither a known function name nor a known class name.
    """
    # Maps legacy graph-function names to the class names in `kgcnn.graph.postprocessor`.
    postprocessor_identifier = {
        "extensive_energy_force_scaler": "ExtensiveEnergyForceScalerPostprocessor",
    }
    if isinstance(name, dict):
        # Full serialization config: delegate to the generic deserializer (kwargs are not applied here).
        return deserialize(name)
    # Given as string: resolve a legacy function name, or accept a supported class name directly.
    if name in postprocessor_identifier:
        class_name = postprocessor_identifier[name]
    elif name in postprocessor_identifier.values():
        class_name = name
    else:
        # Fail before importing anything, with a message that lists the valid identifiers.
        raise KeyError(
            f"Unknown postprocessor '{name}'. Expected one of {sorted(postprocessor_identifier)}, "
            f"a matching class name, or a serialization dictionary."
        )
    obj_class = getattr(importlib.import_module("kgcnn.graph.postprocessor"), class_name)
    return obj_class(**kwargs)