etudes.datasets package¶
Submodules¶
etudes.datasets.base module¶
Datasets module.
-
etudes.datasets.base.
coal_mining_disasters_load_data
(base_dir='../datasets/')[source]¶ Coal mining disasters dataset.
Examples
from etudes.datasets import coal_mining_disasters_load_data

X, y = coal_mining_disasters_load_data()

fig, ax = plt.subplots()

ax.vlines(X.squeeze(), ymin=0, ymax=y, linewidth=0.5, alpha=0.8)

ax.set_xlabel("days")
ax.set_ylabel("incidents")

plt.show()
(Source code, png, hires.png, pdf)
-
etudes.datasets.base.
mauna_loa_load_dataframe
(base_dir='../datasets/')[source]¶ Mauna Loa dataset.
Examples
import seaborn as sns
from etudes.datasets import mauna_loa_load_dataframe

data = mauna_loa_load_dataframe()

g = sns.relplot(x='date', y='average', kind="line",
                data=data, height=5, aspect=1.5, alpha=0.8)
g.set_ylabels(r"average $\mathrm{CO}_2$ (ppm)")
(Source code, png, hires.png, pdf)
etudes.datasets.decorators module¶
-
etudes.datasets.decorators.
binarize
(positive_label=3, negative_label=5)[source]¶ MNIST binary classification.
Examples
import tensorflow as tf

from etudes.datasets import binarize
from etudes.plotting import plot_image_grid

@binarize(positive_label=2, negative_label=7)
def binary_mnist_load_data():
    return tf.keras.datasets.mnist.load_data()

(X_train, Y_train), (X_test, Y_test) = binary_mnist_load_data()

num_train, img_rows, img_cols = X_train.shape
num_test, img_rows, img_cols = X_test.shape

fig, (ax1, ax2) = plt.subplots(ncols=2)

plot_image_grid(ax1, X_train[Y_train == 0],
                shape=(img_rows, img_cols), nrows=10, cmap="cividis")

plot_image_grid(ax2, X_train[Y_train == 1],
                shape=(img_rows, img_cols), nrows=10, cmap="cividis")

plt.show()
etudes.datasets.networks module¶
etudes.datasets.synthetic module¶
-
etudes.datasets.synthetic.
make_classification_dataset
(X_pos, X_neg, shuffle=False, dtype='float64', random_state=None)[source]¶
-
etudes.datasets.synthetic.
make_regression_dataset
(latent_fn=<function synthetic_sinusoidal>)[source]¶ Make synthetic dataset.
Examples
Test
from etudes.datasets import synthetic_sinusoidal, make_regression_dataset

num_train = 64  # nbr training points in synthetic dataset
num_index_points = 256
num_features = 1
observation_noise_variance = 1e-1

f = synthetic_sinusoidal
X_pred = np.linspace(-0.6, 0.6, num_index_points).reshape(-1, num_features)

load_data = make_regression_dataset(f)
X_train, Y_train = load_data(num_train, num_features,
                             observation_noise_variance,
                             x_min=-0.5, x_max=0.5)

fig, ax = plt.subplots()

ax.plot(X_pred, f(X_pred), label="true")
ax.scatter(X_train, Y_train, marker='x', color='k',
           label="noisy observations")

ax.legend()

ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$y$')

plt.show()
Module contents¶
-
etudes.datasets.
make_regression_dataset
(latent_fn=<function synthetic_sinusoidal>)[source]¶ Make synthetic dataset.
Examples
Test
from etudes.datasets import synthetic_sinusoidal, make_regression_dataset

num_train = 64  # nbr training points in synthetic dataset
num_index_points = 256
num_features = 1
observation_noise_variance = 1e-1

f = synthetic_sinusoidal
X_pred = np.linspace(-0.6, 0.6, num_index_points).reshape(-1, num_features)

load_data = make_regression_dataset(f)
X_train, Y_train = load_data(num_train, num_features,
                             observation_noise_variance,
                             x_min=-0.5, x_max=0.5)

fig, ax = plt.subplots()

ax.plot(X_pred, f(X_pred), label="true")
ax.scatter(X_train, Y_train, marker='x', color='k',
           label="noisy observations")

ax.legend()

ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$y$')

plt.show()
-
etudes.datasets.
make_classification_dataset
(X_pos, X_neg, shuffle=False, dtype='float64', random_state=None)[source]¶
-
etudes.datasets.
mauna_loa_load_dataframe
(base_dir='../datasets/')[source]¶ Mauna Loa dataset.
Examples
import seaborn as sns
from etudes.datasets import mauna_loa_load_dataframe

data = mauna_loa_load_dataframe()

g = sns.relplot(x='date', y='average', kind="line",
                data=data, height=5, aspect=1.5, alpha=0.8)
g.set_ylabels(r"average $\mathrm{CO}_2$ (ppm)")
(Source code, png, hires.png, pdf)
-
etudes.datasets.
coal_mining_disasters_load_data
(base_dir='../datasets/')[source]¶ Coal mining disasters dataset.
Examples
from etudes.datasets import coal_mining_disasters_load_data

X, y = coal_mining_disasters_load_data()

fig, ax = plt.subplots()

ax.vlines(X.squeeze(), ymin=0, ymax=y, linewidth=0.5, alpha=0.8)

ax.set_xlabel("days")
ax.set_ylabel("incidents")

plt.show()
(Source code, png, hires.png, pdf)
-
etudes.datasets.
binary_mnist_load_data
()¶