Source code for etudes.datasets.decorators

import numpy as np


[docs]def binarize(positive_label=3, negative_label=5): """ MNIST binary classification. Examples -------- .. plot:: :context: close-figs import tensorflow as tf from etudes.datasets import binarize from etudes.plotting import plot_image_grid @binarize(positive_label=2, negative_label=7) def binary_mnist_load_data(): return tf.keras.datasets.mnist.load_data() (X_train, Y_train), (X_test, Y_test) = binary_mnist_load_data() num_train, img_rows, img_cols = X_train.shape num_test, img_rows, img_cols = X_test.shape fig, (ax1, ax2) = plt.subplots(ncols=2) plot_image_grid(ax1, X_train[Y_train == 0], shape=(img_rows, img_cols), nrows=10, cmap="cividis") plot_image_grid(ax2, X_train[Y_train == 1], shape=(img_rows, img_cols), nrows=10, cmap="cividis") plt.show() """ # TODO: come up with remote descriptive name def d(X, y, label, new_label=1): X_val = X[y == label] y_val = np.full(len(X_val), new_label) return X_val, y_val def binarize_decorator(load_data_fn): def new_load_data_fn(): (X_train, Y_train), (X_test, Y_test) = load_data_fn() X_train_pos, Y_train_pos = d(X_train, Y_train, label=positive_label, new_label=1) X_train_neg, Y_train_neg = d(X_train, Y_train, label=negative_label, new_label=0) X_train_new = np.vstack([X_train_pos, X_train_neg]) Y_train_new = np.hstack([Y_train_pos, Y_train_neg]) X_test_pos, Y_test_pos = d(X_test, Y_test, label=positive_label, new_label=1) X_test_neg, Y_test_neg = d(X_test, Y_test, label=negative_label, new_label=0) X_test_new = np.vstack([X_test_pos, X_test_neg]) Y_test_new = np.hstack([Y_test_pos, Y_test_neg]) return (X_train_new, Y_train_new), (X_test_new, Y_test_new) return new_load_data_fn return binarize_decorator