import gzip
import os
import shutil
import tarfile
import zipfile
import requests
import logging
# Configure the root logger with the standard-library defaults at import time.
logging.basicConfig()
# Module-wide logger; `DownloadDataset` instances reuse this very object as `logger_download`.
module_logger = logging.getLogger(__name__)
# Start at INFO; the constructor of `DownloadDataset` overrides this level with its `verbose` argument.
module_logger.setLevel(logging.INFO)
[docs]class DownloadDataset:
r"""Download class for datasets.
Provides static methods and functions for downloading and unpacking the data.
They are intentionally kept general and can also be used without this class definition.
By default, downloading is handled by :obj:`download_dataset_to_disk`, which is already called in :obj:`__init__`.
Dataset-specific functions like :obj:`prepare_data` must be implemented in subclasses.
.. note::
By default, :obj:`DownloadDataset` uses the main directory '~/.kgcnn/datasets' as the location for
downloaded datasets.
Classes in :obj:`kgcnn.data.datasets` inherit from this class, but :obj:`DownloadDataset` can also be
used as a member via composition.
.. warning::
Downloads are not checked for safety or malware. Use with caution!
Example on how to use :obj:`DownloadDataset` standalone:
.. code-block:: python
from kgcnn.data.download import DownloadDataset
download = DownloadDataset(
download_url="https://github.com/aimat-lab/gcnn_keras/blob/master/README.md",
data_main_dir="./",
data_directory_name="",
download_file_name="README.html",
reload=True,
execute_download_dataset_on_init=False
)
download.download_dataset_to_disk()
"""
def __init__(self,
             dataset_name: str = None,
             download_file_name: str = None,
             data_directory_name: str = None,
             unpack_directory_name: str = None,
             extract_file_name: str = None,
             download_url: str = None,
             unpack_tar: bool = False,
             unpack_zip: bool = False,
             extract_gz: bool = False,
             reload: bool = False,
             verbose: int = 10,
             data_main_dir: str = os.path.join(os.path.expanduser("~"), ".kgcnn", "datasets"),
             execute_download_dataset_on_init: bool = True
             ):
    r"""Store the download configuration and, unless disabled, run the download immediately.

    The arguments describe what to download, where to store it and whether the downloaded file
    has to be unpacked. With the default settings :obj:`download_dataset_to_disk` is called at
    the end of the constructor.

    Args:
        dataset_name (str): Name of the dataset to download (optional). Default is None.
        download_file_name (str): Name of the file that the source is downloaded to. Default is None.
        data_directory_name (str): Name of the directory inside `data_main_dir` in which the file
            is saved. Default is None.
        unpack_directory_name (str): Name of a new directory inside `data_directory_name` into
            which an archive is unpacked. Default is None.
        extract_file_name (str): Name of a gz-file to extract. Default is None.
        download_url (str): Url of the file to download. Default is None.
        unpack_tar (bool): Whether to unpack a tar-archive. Default is False.
        unpack_zip (bool): Whether to unpack a zip-archive. Default is False.
        extract_gz (bool): Whether to unpack a gz-archive. Default is False.
        reload (bool): Whether to download and unpack again even if data is already present.
            Stored as `download_reload`. Default is False.
        verbose (int): Logging level. Default is 10.
        data_main_dir (str): Main directory for all datasets. Default is '~/.kgcnn/datasets'
            with '~' expanded to the home directory of the user.
        execute_download_dataset_on_init (bool): Whether to start the download on class
            construction. Default is True.
    """
    # Where the data is stored on disk.
    self.data_main_dir = data_main_dir
    self.data_directory_name = data_directory_name
    self.unpack_directory_name = unpack_directory_name
    self.extract_file_name = extract_file_name

    # What is downloaded and under which name.
    self.dataset_name = dataset_name
    self.download_url = download_url
    self.download_file_name = download_file_name
    self.download_reload = reload

    # Which post-processing steps run after the download.
    self.unpack_tar = unpack_tar
    self.unpack_zip = unpack_zip
    self.extract_gz = extract_gz

    # The logger is the shared module-level logger, so setting the level here affects
    # every instance that uses it.
    self.logger_download = module_logger
    self.logger_download.setLevel(verbose)

    self.execute_download_dataset_on_init = bool(execute_download_dataset_on_init)
    if self.execute_download_dataset_on_init:
        # Run the download as part of the construction.
        self.download_dataset_to_disk()
def download_dataset_to_disk(self):
    """Run directory setup, download and unpacking as configured by the attributes of `self`.

    Every step is optional: directories are only created if `data_directory_name` is set, the
    download only runs if `download_url` is set, and the tar, zip and gz steps only run if
    their respective flag is truthy.
    """
    log = self.logger_download

    def dataset_dir():
        # Joined on demand, so that no path is built when none of the steps below runs.
        return os.path.join(self.data_main_dir, self.data_directory_name)

    # Some datasets do not offer all information or require multiple files.
    log.info("Checking and possibly downloading dataset with name %s" % str(self.dataset_name))

    # Create the main directory and the directory of this dataset.
    if self.data_directory_name is not None:
        self.setup_dataset_main(self.data_main_dir, logger=log)
        self.setup_dataset_dir(self.data_main_dir, self.data_directory_name, logger=log)

    # Fetch the raw file.
    if self.download_url is not None:
        self.download_database(
            dataset_dir(), self.download_url, self.download_file_name,
            overwrite=self.download_reload, logger=log)

    # Optional post-processing of the downloaded file.
    if self.unpack_tar:
        self.unpack_tar_file(
            dataset_dir(), self.download_file_name, self.unpack_directory_name,
            overwrite=self.download_reload, logger=log)

    if self.unpack_zip:
        self.unpack_zip_file(
            dataset_dir(), self.download_file_name, self.unpack_directory_name,
            overwrite=self.download_reload, logger=log)

    if self.extract_gz:
        self.extract_gz_file(
            dataset_dir(), self.download_file_name, self.extract_file_name,
            overwrite=self.download_reload, logger=log)
[docs] @staticmethod
def setup_dataset_main(data_main_dir, logger=None):
"""Make the main-directory for all datasets to store data.
Args:
data_main_dir (str): Path to create directory.
logger: Logger to print information or warnings.
"""
os.makedirs(data_main_dir, exist_ok=True)
if logger is not None:
logger.info("Dataset directory located at %s" % data_main_dir)
[docs] @staticmethod
def setup_dataset_dir(data_main_dir: str, data_directory: str, logger=None):
"""Make directory for each dataset.
Args:
data_main_dir (str): Path-location of the directory for all datasets.
data_directory (str): Path of the directory for specific dataset to create.
logger: Logger to print information or warnings.
"""
if not os.path.exists(os.path.join(data_main_dir, data_directory)):
os.mkdir(os.path.join(data_main_dir, data_directory))
else:
if logger is not None:
logger.info("Dataset directory found. Done.")
[docs] @staticmethod
def download_database(path: str, download_url: str, filename: str, overwrite: bool = False, logger=None):
"""Download dataset file.
Args:
path (str): Target filepath to store file (without filename).
download_url (str): String of the download url to catch database from.
filename (str): Name the dataset is downloaded to.
overwrite (bool): Overwrite existing database. Default is False.
logger: Logger to print information or warnings.
Returns:
os.path: Filepath of downloaded file.
"""
def logg_info(msg):
if logger is not None:
logger.info(msg)
if os.path.exists(os.path.join(path, filename)) is False or overwrite:
logg_info("Downloading dataset... ")
r = requests.get(download_url, allow_redirects=True)
open(os.path.join(path, filename), 'wb').write(r.content)
else:
logg_info("Dataset found. Done.")
return os.path.join(path, filename)
[docs] @staticmethod
def unpack_tar_file(path: str, filename: str, unpack_directory: str, overwrite: bool = False, logger=None):
"""Extract tar-file.
Args:
path (str): Filepath where the tar-file to extract is located.
filename (str): Name of the dataset to extract.
unpack_directory (str): Directory to extract data to.
overwrite (bool): Overwrite existing database. Default is False.
logger: Logger to print information or warnings.
Returns:
os.path: Filepath of the extracted dataset folder.
"""
def logg_info(msg):
if logger is not None:
logger.info(msg)
if not os.path.exists(os.path.join(path, unpack_directory)):
logg_info("Creating directory... ")
os.mkdir(os.path.join(path, unpack_directory))
else:
logg_info("Directory for extraction exists. Done.")
if not overwrite:
logg_info("Not extracting tar File. Stopped.")
return os.path.join(path, unpack_directory) # Stop extracting here
logg_info("Read tar file... ")
archive = tarfile.open(os.path.join(path, filename), "r")
# Filelistnames = archive.getnames()
logg_info("Extracting tar file... ")
archive.extractall(os.path.join(path, unpack_directory))
archive.close()
return os.path.join(path, unpack_directory)
[docs] @staticmethod
def unpack_zip_file(path: str, filename: str, unpack_directory: str, overwrite: bool = False, logger=None):
"""Extract zip-file.
Args:
path (str): Filepath where the zip-file to extract is located.
filename (str): Name of the dataset to extract.
unpack_directory (str): Directory to extract data to.
overwrite (bool): Overwrite existing database. Default is False.
logger: Logger to print information or warnings.
Returns:
os.path: Filepath of the extracted dataset folder.
"""
def logg_info(msg):
if logger is not None:
logger.info(msg)
if os.path.exists(os.path.join(path, unpack_directory)):
logg_info("Directory for extraction exists. Done.")
if not overwrite:
logg_info("Not extracting zip file. Stopped.")
return os.path.join(path, unpack_directory)
logg_info("Read zip file ... ")
archive = zipfile.ZipFile(os.path.join(path, filename), "r")
# Filelistnames = archive.getnames()
logg_info("Extracting zip file...")
archive.extractall(os.path.join(path, unpack_directory))
archive.close()
return os.path.join(path, unpack_directory)