Source code for etudes.datasets.base

"""Datasets module."""

import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_probability as tfp

import json

from pathlib import Path
from .decorators import binarize

SEED = 42

tfd = tfp.distributions


def read_fcnet_data(f, max_configs=None, num_seeds=4):
    frames = []
    for config_num, config_str in enumerate(f.keys()):
        # note: `>=` so that at most `max_configs` configurations are read
        if max_configs is not None and config_num >= max_configs:
            break
        # each top-level key is a JSON-encoded hyperparameter configuration
        config = json.loads(config_str)
        for seed in range(num_seeds):
            config["seed"] = seed
            # metric arrays (one entry per epoch) are stored per seed
            for attr in f[config_str].keys():
                config[attr] = f[config_str][attr][seed]
            frame = pd.DataFrame(config)
            frame.index.name = "epoch"
            frame.reset_index(inplace=True)
            frames.append(frame)
    return pd.concat(frames, axis="index", ignore_index=True, sort=True)

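# A minimal usage sketch for ``read_fcnet_data``, assuming an HDF5 file from
# the FCNet tabular benchmarks opened with h5py (the filename below is
# illustrative, not part of this module):
#
#     import h5py
#
#     with h5py.File("fcnet_protein_structure_data.hdf5", "r") as f:
#         data = read_fcnet_data(f, max_configs=100)
#     print(data.head())
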
@binarize(positive_label=2, negative_label=7)
def binary_mnist_load_data():
    # MNIST, restricted by the `binarize` decorator to digits 2 and 7
    return tf.keras.datasets.mnist.load_data()

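# Usage sketch: the decorated loader returns MNIST restricted to two digit
# classes. The exact relabeling is determined by the ``binarize`` decorator in
# ``etudes.datasets.decorators`` (assumed here to map digit 2 to the positive
# class and digit 7 to the negative class):
#
#     (X_train, y_train), (X_test, y_test) = binary_mnist_load_data()
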
def get_sequence_path(sequence_num, base_dir="../datasets"):
    return Path(base_dir).joinpath("bee-dance", "zips", "data",
                                   f"sequence{sequence_num:d}", "btf")

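# For example, ``get_sequence_path(1)`` resolves to
# ``../datasets/bee-dance/zips/data/sequence1/btf`` under the default
# ``base_dir``.
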
bee_dance_filenames = dict(
    x="ximage.btf",
    y="yimage.btf",
    t="timage.btf",
    label="label0.btf",
    timestamp="timestamp.btf"
)

def read_sequence_column(sequence_num, col_name, base_dir="../datasets"):
    sequence_path = get_sequence_path(sequence_num, base_dir=base_dir)
    return pd.read_csv(sequence_path / bee_dance_filenames[col_name],
                       names=[col_name], header=None)

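# Usage sketch, assuming the bee-dance BTF files have been extracted under
# ``base_dir`` in the layout produced by ``get_sequence_path``:
#
#     x = read_sequence_column(1, "x")  # single-column DataFrame named "x"
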
def read_sequence(sequence_num, base_dir="../datasets"):
    left = None
    for col_name in bee_dance_filenames:
        right = read_sequence_column(sequence_num, col_name, base_dir=base_dir)
        if left is None:
            left = right
        else:
            left = pd.merge(left, right, left_index=True, right_index=True)

    # a change point occurs wherever the label differs from the previous frame;
    # the cumulative sum assigns a running phase index to each segment
    change_point = left.label != left.label.shift()
    phase = change_point.cumsum()

    return left.assign(change_point=change_point, phase=phase)

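# Usage sketch: the columns of one sequence are merged on their shared frame
# index, then annotated with label change points and a running phase index:
#
#     seq = read_sequence(1)
#     seq[["label", "change_point", "phase"]].head()
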
def load_bee_dance_dataframe(base_dir="../datasets"):
    sequences = []
    for sequence_num in range(1, 7):
        sequence = read_sequence(sequence_num, base_dir=base_dir) \
            .assign(sequence=sequence_num)
        sequences.append(sequence)
    return pd.concat(sequences, axis="index")

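# Usage sketch: all six sequences are stacked into one DataFrame, with a
# ``sequence`` column identifying the originating sequence:
#
#     bees = load_bee_dance_dataframe()
#     bees.groupby("sequence").size()
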
def coal_mining_disasters_load_data(base_dir="../datasets/"):
    """
    Coal mining disasters dataset.

    Examples
    --------

    .. plot::
        :context: close-figs

        from etudes.datasets import coal_mining_disasters_load_data

        X, y = coal_mining_disasters_load_data()

        fig, ax = plt.subplots()

        ax.vlines(X.squeeze(), ymin=0, ymax=y, linewidth=0.5, alpha=0.8)

        ax.set_xlabel("days")
        ax.set_ylabel("incidents")

        plt.show()
    """
    base = Path(base_dir).joinpath("coal-mining-disasters")
    data = pd.read_csv(base / "data.csv", names=["count", "days"],
                       header=None)

    X = np.expand_dims(data["days"].values, axis=-1)
    y = data["count"].values

    return X, y

def mauna_loa_load_dataframe(base_dir="../datasets/"):
    """
    Mauna Loa dataset.

    Examples
    --------

    .. plot::
        :context: close-figs

        import seaborn as sns
        from etudes.datasets import mauna_loa_load_dataframe

        data = mauna_loa_load_dataframe()

        g = sns.relplot(x='date', y='average', kind="line", data=data,
                        height=5, aspect=1.5, alpha=0.8)
        g.set_ylabels(r"average $\mathrm{CO}_2$ (ppm)")
    """
    base = Path(base_dir).joinpath("mauna-loa-co2")
    column_names = ["year", "month", "date", "average", "interpolated",
                    "trend", "num_days"]
    data = pd.read_csv(base / "co2_mm_mlo.txt", names=column_names,
                       comment="#", header=None, sep=r"\s+")
    # rows with a negative average are missing-value sentinels; drop them
    data = data[data.average > 0]
    return data