@fernandocamargoai
Created May 30, 2020 12:25
Custom JsonDataset

import functools
import os
import warnings
from glob import glob
from multiprocessing import Pool
from typing import Optional, Iterator, List

import numpy as np
import ujson as json
from gluonts.dataset.common import DataEntry, Dataset, ProcessDataEntry, SourceContext
from gluonts.dataset.field_names import FieldName
from gluonts.transform import TransformedDataset
from tqdm import tqdm


class JsonFile(object):
    """
    A type that draws data from a JSON file.

    Parameters
    ----------
    path
        Path of the file to load data from. This should be a valid
        JSON file.
    process
        Callable that converts the raw parsed JSON into a ``DataEntry``.
    cache
        Whether to keep the processed entry in memory after the first read.
    preload
        Whether to read and process the file eagerly on construction
        (implies ``cache``).
    check_data
        Whether to validate the entry for non-finite values and
        out-of-range dynamic real features.
    """

    def __init__(self, path: str, process: ProcessDataEntry, cache: bool = False,
                 preload: bool = False, check_data: bool = False) -> None:
        self.path = path
        self.process = process
        self.cache = cache or preload
        self.check_data = check_data
        self._data_cache: Optional[dict] = None
        if preload:
            self.get_data()

    def _check_data(self, data: DataEntry):
        for key, value in data.items():
            if isinstance(value, np.ndarray):
                if not np.isfinite(value).all():
                    raise ValueError(
                        f"The key {key} of the source {data['source']} contains an invalid value"
                    )
                if key == FieldName.FEAT_DYNAMIC_REAL:
                    # Dynamic real features are expected to be scaled to [0, 1].
                    if value.max() > 1.0:
                        warnings.warn(f"The key {key} of the source {data['source']} contains a value above 1.0")
                    if value.min() < 0.0:
                        warnings.warn(f"The key {key} of the source {data['source']} contains a value below 0.0")

    def get_data(self) -> dict:
        if self._data_cache is not None:
            return self._data_cache
        with open(self.path, "r") as f:
            data = self.process(json.load(f))
            data["source"] = SourceContext(source=self.path, row=0)
            if self.check_data:
                self._check_data(data)
            if self.cache:
                self._data_cache = data
            return data
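
# A minimal sketch of using JsonFile on its own (the file path and the "D"
# frequency below are hypothetical examples, not part of the original gist):
#
#     process = ProcessDataEntry("D", one_dim_target=True)
#     entry = JsonFile("series/0001.json", process, cache=True).get_data()
#     print(entry["start"], entry["target"])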


class JsonDataset(Dataset):
    """
    Dataset that loads JSON files contained in a path.

    Parameters
    ----------
    path
        Path containing the dataset files. Each file should end with .json
        and hold a single entry, for instance:
        {"start": "2014-09-07", "target": [0.1, 0.2]}.
    freq
        Frequency of the observations in the time series.
        Must be a valid Pandas frequency.
    one_dim_target
        Whether to accept only univariate target time series.
    cache
        Whether to keep processed entries in memory after the first read.
    preload
        Whether to read and process all files eagerly, in parallel,
        on construction (implies ``cache``).
    check_data
        Whether to validate each entry for non-finite values and
        out-of-range dynamic real features.
    """

    def __init__(
        self,
        path: str,
        freq: str,
        one_dim_target: bool = True,
        cache: bool = False,
        preload: bool = False,
        check_data: bool = False,
    ) -> None:
        self.cache = cache or preload
        self.check_data = check_data
        self.path = path
        process = ProcessDataEntry(freq, one_dim_target=one_dim_target)
        files = self.files()
        if preload:
            print("Preloading dataset...")
            with Pool(os.cpu_count()) as pool:
                # imap (unlike map) yields results as workers finish, so tqdm
                # can report actual progress instead of a single final jump.
                self._json_files = list(
                    tqdm(
                        pool.imap(
                            functools.partial(
                                JsonFile,
                                process=process,
                                cache=cache,
                                preload=preload,
                                check_data=check_data,
                            ),
                            files,
                        ),
                        total=len(files),
                    )
                )
        else:
            # check_data must be forwarded here too, otherwise it is silently
            # ignored whenever preload is off.
            self._json_files = [
                JsonFile(file_path, process, cache, check_data=check_data)
                for file_path in files
            ]
        self._len = len(self._json_files)
        if self._len == 0:
            raise OSError(f"no valid file found in {path}")

    def __iter__(self) -> Iterator[DataEntry]:
        for json_file in self._json_files:
            yield json_file.get_data()

    def __len__(self):
        return self._len

    def files(self) -> List[str]:
        """
        List the files that compose the dataset.

        Returns
        -------
        List[str]
            Paths of all .json files composing the dataset.
        """
        return glob(os.path.join(self.path, "*.json"))


class SameSizeTransformedDataset(TransformedDataset):
    """
    TransformedDataset that reports the length of the underlying dataset,
    assuming the transformation yields exactly one entry per input entry.
    """

    def __len__(self):
        return len(self.base_dataset)
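

# Usage sketch (not part of the original gist): the directory name and
# frequency below are illustrative placeholders.
if __name__ == "__main__":
    # Build a dataset from a directory of *.json files, one entry per file.
    dataset = JsonDataset("data/train", freq="D", cache=True, check_data=True)
    print(f"{len(dataset)} series found")
    for entry in dataset:
        print(entry["start"], entry["target"].shape)

    # SameSizeTransformedDataset keeps len() well-defined after applying a
    # transformation; its constructor follows gluonts' TransformedDataset,
    # whose signature varies across gluonts versions, e.g.:
    #     transformed = SameSizeTransformedDataset(dataset, transformations)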