Source code for kgcnn.data.crystal

import os
import numpy as np
from collections import defaultdict
from typing import Dict, Callable, List, Union
import pandas as pd
import pymatgen
import pymatgen.io.cif
import pymatgen.core.structure
import pymatgen.symmetry.structure
from kgcnn.utils.serial import deserialize
from kgcnn.data.base import MemoryGraphDataset
from kgcnn.data.utils import save_json_file, load_json_file
from kgcnn.crystal.base import CrystalPreprocessor
from kgcnn.graph.base import GraphDict


[docs]class CrystalDataset(MemoryGraphDataset): r"""Class for making graph dataset from periodic structures such as crystals. The dataset class requires a :obj:`data_directory` to store a table '.csv' file containing labels and information of the structures stored in multiple (CIF, POSCAR, ...) files in :obj:`file_directory` . The file names must be included in the '.csv' table. The table file must have one line of header with column names! .. code-block:: console ├── data_directory ├── file_directory │ ├── *.cif │ ├── *.cif │ └── ... ├── file_name.csv ├── file_name.pymatgen.json └── dataset_name.kgcnn.pickle This class uses :obj:`pymatgen.core.structure.Structure` and therefore requires :obj:`pymatgen` to be installed. A '.pymatgen.json' serialized file is generated to store a list of structures from single '.cif' files via :obj:`prepare_data()` . Consequently, a 'file_name.pymatgen.json' can be directly stored in :obj:`data_directory`. In this, case :obj:`prepare_data()` does not have to be used. Additionally, a table file 'file_name.csv' that lists the single file names and possible labels or classification targets is required. .. code-block:: python from kgcnn.data.crystal import CrystalDataset dataset = CrystalDataset( data_directory="data_directory/", dataset_name="ExampleCrystal", file_name="file_name.csv", file_directory="file_directory") dataset.prepare_data(file_column_name="file_name", overwrite=True) dataset.read_in_memory(label_column_name="label") """ _default_loop_update_info = 5000
[docs] def __init__(self, data_directory: str = None, dataset_name: str = None, file_name: str = None, file_directory: str = None, file_name_pymatgen_json: str = None, verbose: int = 10): r"""Initialize a base class of :obj:`CrystalDataset`. Args: data_directory (str): Full path to directory of the dataset. Default is None. file_name (str): Filename for dataset to read into memory. This is a table file. The '.csv' should contain file names that are expected to be CIF-files in :obj:`file_directory`. Default is None. file_directory (str): Name or relative path from :obj:`data_directory` to a directory containing sorted 'cif' files. Default is None. file_name_pymatgen_json (str): This class will generate a 'json' file with pymatgen structures. You can specify the file name of that file with this argument. By default, it will be named from :obj:`file_name` when passed None. dataset_name (str): Name of the dataset. Important for naming and saving files. Default is None. verbose (int): Logging level. Default is 10. """ super(CrystalDataset, self).__init__( data_directory=data_directory, dataset_name=dataset_name, file_name=file_name, verbose=verbose, file_directory=file_directory) self._structs = None self.file_name_pymatgen_json = file_name_pymatgen_json self.label_units = None self.label_names = None
@property def pymatgen_json_file_path(self): """Internal file name for the pymatgen serialization information to store to disk.""" self._verify_data_directory() if self.file_name_pymatgen_json is None: file_name = os.path.splitext(self.file_name)[0] + ".pymatgen.json" else: file_name = self.file_name_pymatgen_json return os.path.join(self.data_directory, file_name) @staticmethod def _pymatgen_serialize_structs(structs: List) -> List[dict]: dicts = [] for s in structs: d = s.as_dict() # Module information should be already obtained from as_dict(). # d["@module"] = type(s).__module__ # d["@class"] = type(s).__name__ dicts.append(d) return dicts @staticmethod def _pymatgen_deserialize_dicts(dicts: List[dict], to_unit_cell: bool = False) -> list: structs = [] for x in dicts: # TODO: We could check symmetry or @module, @class items in dict. s = pymatgen.core.structure.Structure.from_dict(x) structs.append(s) if to_unit_cell: for site in s.sites: site.to_unit_cell(in_place=True) return structs
[docs] def save_structures_to_json_file(self, structs: list, file_path: str = None): """Save a list of pymatgen structures to file. Args: structs (list): List of pymatgen structures. file_path (str): File path to store structures to disk, uses class-default. Default is None. Returns: None. """ if file_path is None: file_path = self.pymatgen_json_file_path self.info("Exporting as dict for pymatgen ...") dicts = self._pymatgen_serialize_structs(structs) self.info("Saving structures as .json ...") save_json_file(dicts, file_path)
@staticmethod def _pymatgen_parse_file_to_structure(cif_file: str): # TODO: We can add flexible parsing to include other than just CIF from file here. structures = pymatgen.io.cif.CifParser(cif_file).get_structures() return structures
[docs] def prepare_data(self, file_column_name: str = None, overwrite: bool = False): r"""Default preparation for crystal datasets. Try to load all crystal structures from single files and save them as a pymatgen json serialization. Can load multiple CIF files from a table that keeps file names and possible labels or additional information. Args: file_column_name (str): Name of the column that has file names found in file_directory. Default is None. overwrite (bool): Whether to rerun the data extraction. Default is False. Returns: self """ if os.path.exists(self.pymatgen_json_file_path) and not overwrite: self.info("Pickled pymatgen structures already exist. Do nothing.") return self self.info("Searching for structure files in '%s'" % self.file_directory_path) structs = self.collect_files_in_file_directory( file_column_name=file_column_name, table_file_path=None, read_method_file=self._pymatgen_parse_file_to_structure, update_counter=self._default_loop_update_info, append_file_content=True, read_method_return_list=True ) self.save_structures_to_json_file(structs) return self
[docs] def get_structures_from_json_file(self, file_path: str = None) -> List: """Load pymatgen serialized json-file into memory. Structures are not added to :obj:`CrystalDataset` but returned by this function. Args: file_path (str): File path to json-file, uses class default. Default is None. Returns: list: List of pymatgen structures. """ if file_path is None: file_path = self.pymatgen_json_file_path if not os.path.exists(file_path): raise FileNotFoundError("Cannot find .json file for `CrystalDataset`. Please `prepare_data()`.") self.info("Reading structures from .json ...") return self._pymatgen_deserialize_dicts(load_json_file(file_path))
[docs] def _map_callbacks(self, structs: list, data: pd.Series, callbacks: Dict[ str, Callable[[pymatgen.core.structure.Structure, pd.Series], Union[np.ndarray, None]]], assign_to_self: bool = True) -> dict: """Map callbacks on a data series object plus structure list. Args: structs (list): List of pymatgen structures. data (pd.Series, pd.DataFrame): Data Frame matching the structure list. callbacks (dict): Dictionary of callbacks that take a data object plus pymatgen structure as argument. assign_to_self (bool): Whether to already assign the output of callbacks to this class. Returns: dict: Values of callbacks. """ # The dictionaries values are lists, one for each attribute defines in "callbacks" and each value in those # lists corresponds to one structure in the dataset. value_lists = defaultdict(list) for index, st in enumerate(structs): for name, callback in callbacks.items(): if st is None: value_lists[name].append(None) else: data_dict = data.loc[index] value = callback(st, data_dict) value_lists[name].append(value) if index % self._default_loop_update_info == 0: self.info(" ... read structures {0} from {1}".format(index, len(structs))) # The string key names of the original "callbacks" dict are also used as the names of the properties which are # assigned if assign_to_self: for name, values in value_lists.items(): self.assign_property(name, values) return value_lists
[docs] def read_in_memory(self, label_column_name: str = None, additional_callbacks: Dict[ str, Callable[[pymatgen.core.structure.Structure, pd.Series], None]] = None ): """Read structures from pymatgen json serialization and convert them into graph information. Args: label_column_name (str): Columns of labels for graph in table file. Default is None. additional_callbacks (dict): Callbacks to add during read into memory. Returns: self """ if additional_callbacks is None: additional_callbacks = {} self.info("Making node features from structure...") callbacks = {"graph_labels": lambda st, ds: ds[label_column_name] if label_column_name is not None else None, "node_coordinates": lambda st, ds: np.array(st.cart_coords, dtype="float"), "node_frac_coordinates": lambda st, ds: np.array(st.frac_coords, dtype="float"), "graph_lattice": lambda st, ds: np.ascontiguousarray(np.array(st.lattice.matrix), dtype="float"), "abc": lambda st, ds: np.array(st.lattice.abc), "charge": lambda st, ds: np.array([st.charge], dtype="float"), "volume": lambda st, ds: np.array([st.lattice.volume], dtype="float"), "node_number": lambda st, ds: np.array(st.atomic_numbers, dtype="int"), **additional_callbacks } self._map_callbacks(structs=self.get_structures_from_json_file(), data=self.read_in_table_file(file_path=self.file_path).data_frame, callbacks=callbacks) return self
[docs] def set_representation(self, pre_processor: Union[CrystalPreprocessor, dict], reset_graphs: bool = False): r"""Build a graph representation for this dataset using :obj:`kgcnn.crystal` . Args: pre_processor (CrystalPreprocessor): Crystal preprocessor to use. reset_graphs (bool): Whether to reset the graph information. Default is False. Returns: """ if reset_graphs: self.clear() if isinstance(pre_processor, dict): pre_processor = deserialize(pre_processor) # Read pymatgen JSON file from file. structs = self.get_structures_from_json_file() if reset_graphs: self.empty(len(structs)) pre_processor.output_graph_as_dict = True for index, s in enumerate(structs): g = pre_processor(s) self[index].update(g) if index % self._default_loop_update_info == 0: self.info(" ... preprocess structures {0} from {1}".format(index, len(structs))) return self