Custom JsonDataset
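A custom GluonTS Dataset that loads one JSON file per time series from a directory, with optional in-memory caching, parallel preloading, and basic validation of the loaded values.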
import functools
import os
import warnings
from glob import glob
from multiprocessing import Pool
from typing import Iterator, List, Optional

import numpy as np
import ujson as json
from gluonts.dataset.common import DataEntry, Dataset, ProcessDataEntry, SourceContext
from gluonts.dataset.field_names import FieldName
from gluonts.transform import TransformedDataset
from tqdm import tqdm
class JsonFile(object):
    """
    A type that draws data from a JSON file.

    Parameters
    ----------
    path
        Path of the file to load data from. This should be a valid
        JSON file.
    process
        Callable that turns the raw JSON content into a ``DataEntry``.
    cache
        Whether to keep the processed entry in memory after the first read.
    preload
        Whether to read and process the file eagerly on construction
        (implies ``cache``).
    check_data
        Whether to validate the entry for non-finite or out-of-range values.
    """

    def __init__(
        self,
        path: str,
        process: ProcessDataEntry,
        cache: bool = False,
        preload: bool = False,
        check_data: bool = False,
    ) -> None:
        self.path = path
        self.process = process
        self.cache = cache or preload
        self.check_data = check_data
        self._data_cache: Optional[DataEntry] = None
        if preload:
            self.get_data()
    def _check_data(self, data: DataEntry) -> None:
        for key, value in data.items():
            if isinstance(value, np.ndarray):
                if not np.isfinite(value).all():
                    raise ValueError(
                        f"The key {key} of the source {data['source']} contains an invalid value"
                    )
                if key == FieldName.FEAT_DYNAMIC_REAL:
                    if value.max() > 1.0:
                        warnings.warn(
                            f"The key {key} of the source {data['source']} contains a value above 1.0"
                        )
                    if value.min() < 0.0:
                        warnings.warn(
                            f"The key {key} of the source {data['source']} contains a value below 0.0"
                        )
    def get_data(self) -> DataEntry:
        # Serve from the cache when available.
        if self._data_cache is not None:
            return self._data_cache

        with open(self.path, "r") as f:
            data = self.process(json.load(f))
        data["source"] = SourceContext(source=self.path, row=0)
        if self.check_data:
            self._check_data(data)
        if self.cache:
            self._data_cache = data
        return data
class JsonDataset(Dataset):
    """
    Dataset that loads JSON files contained in a path.

    Parameters
    ----------
    path
        Path containing the dataset files. Each file should end with .json
        and contain a single entry, for instance:
        {"start": "2014-09-07", "target": [0.1, 0.2]}.
    freq
        Frequency of the observations in the time series.
        Must be a valid Pandas frequency.
    one_dim_target
        Whether to accept only univariate target time series.
    cache
        Whether each entry should be kept in memory after the first read.
    preload
        Whether to read and process all files eagerly, in parallel,
        on construction (implies ``cache``).
    check_data
        Whether to validate each entry for non-finite or out-of-range
        values.
    """

    def __init__(
        self,
        path: str,
        freq: str,
        one_dim_target: bool = True,
        cache: bool = False,
        preload: bool = False,
        check_data: bool = False,
    ) -> None:
        self.cache = cache or preload
        self.check_data = check_data
        self.path = path
        process = ProcessDataEntry(freq, one_dim_target=one_dim_target)
        files = self.files()
        if preload:
            # imap (rather than map) yields results as workers finish,
            # so tqdm can report actual progress instead of a completed list.
            with Pool(os.cpu_count()) as pool:
                print("Preloading dataset...")
                self._json_files = list(
                    tqdm(
                        pool.imap(
                            functools.partial(
                                JsonFile,
                                process=process,
                                cache=cache,
                                preload=preload,
                                check_data=check_data,
                            ),
                            files,
                        ),
                        total=len(files),
                    )
                )
        else:
            self._json_files = [
                JsonFile(file_path, process, cache, check_data=check_data)
                for file_path in files
            ]
        self._len = len(self._json_files)
        if self._len == 0:
            raise OSError(f"no valid file found in {path}")
    def __iter__(self) -> Iterator[DataEntry]:
        for json_file in self._json_files:
            yield json_file.get_data()

    def __len__(self) -> int:
        return self._len
    def files(self) -> List[str]:
        """
        List the files that compose the dataset.

        Returns
        -------
        List[str]
            List of the paths of all files composing the dataset.
        """
        return glob(os.path.join(self.path, "*.json"))
class SameSizeTransformedDataset(TransformedDataset):
    """
    A TransformedDataset whose transformation is assumed to map each entry
    to exactly one entry, so the base dataset's length can be reused.
    """

    def __len__(self) -> int:
        return len(self.base_dataset)
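A minimal usage sketch, assuming a directory of per-series JSON files; the path "data/train" and the daily frequency are placeholders, not part of the gist:

if __name__ == "__main__":
    # Hypothetical layout: data/train/*.json, one series per file.
    dataset = JsonDataset(
        path="data/train",
        freq="D",          # placeholder frequency
        cache=True,        # keep each processed entry in memory
        preload=True,      # load all files in parallel up front
        check_data=True,   # raise/warn on invalid values
    )

    for entry in dataset:
        print(entry["start"], len(entry["target"]))

The __main__ guard matters here: preloading spawns a multiprocessing Pool, which requires an importable entry point on platforms that start workers via spawn. SameSizeTransformedDataset complements the dataset: a transformation can in general change the number of entries, so a plain TransformedDataset cannot report a length, but when the transformation is one-to-one the base dataset's length remains valid.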