Variational Sparse GP for Binary Classification

Here we fit a sparse variational Gaussian process (SVGP) to a binary classification problem. Unlike in regression with a conjugate Gaussian likelihood, the Bernoulli likelihood makes the marginal likelihood intractable, so we cannot directly apply type-II maximum likelihood (empirical Bayes). Instead we maximize the evidence lower bound (ELBO), jointly over the variational parameters, the inducing input locations, and the kernel hyperparameters.
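
For reference, a sketch of the objective in the standard SVGP formulation (Titsias, 2009; Hensman et al., 2015), where $\mathbf{u}$ denotes the inducing variables with variational posterior $q(\mathbf{u})$ and $q(f_n)$ the induced Gaussian marginal at input $\mathbf{x}_n$:

$$\mathrm{ELBO}(q) = \sum_{n=1}^{N} \mathbb{E}_{q(f_n)}[\log p(y_n \mid f_n)] - \mathrm{KL}[\,q(\mathbf{u}) \,\|\, p(\mathbf{u})\,]$$

The one-dimensional Gaussian expectations are approximated below by Gauss-Hermite quadrature (controlled by quadrature_size), and the sum over data points admits unbiased minibatch estimates, which is what makes stochastic optimization with Adam possible.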

from collections import defaultdict

import numpy as np

import tensorflow as tf
import tensorflow_probability as tfp

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from tensorflow.keras.layers import Layer, InputLayer
from tensorflow.keras.initializers import Identity, Constant

from etudes.datasets import make_classification_dataset
from etudes.plotting import fill_between_stddev
from etudes.utils import get_kl_weight

# shortcuts
tfd = tfp.distributions
kernels = tfp.math.psd_kernels


# constants
num_train = 2048  # number of training points in the synthetic dataset
num_test = 40
num_features = 1  # dimensionality
num_index_points = 256  # number of index points
num_samples = 25
quadrature_size = 20

num_inducing_points = 50
num_epochs = 1000
batch_size = 64
shuffle_buffer_size = 500

jitter = 1e-6

kernel_cls = kernels.MaternFiveHalves

seed = 8888  # set random seed for reproducibility
random_state = np.random.RandomState(seed)

x_min, x_max = -5.0, 5.0
y_min, y_max = -6.0, 4.0

# index points
X_q = np.linspace(x_min, x_max, num_index_points).reshape(-1, num_features)

golden_ratio = 0.5 * (1 + np.sqrt(5))

Synthetic dataset

p = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(loc=[2.0, -3.0], scale=[1.0, 0.5]))
q = tfd.Normal(loc=0.0, scale=2.0)


def load_data(num_samples, rate=0.5, dtype="float64", seed=None):

    num_p = int(num_samples * rate)
    num_q = num_samples - num_p

    X_p = p.sample(sample_shape=(num_p, 1), seed=seed).numpy()
    X_q = q.sample(sample_shape=(num_q, 1), seed=seed).numpy()

    X, y = make_classification_dataset(X_p, X_q, dtype=dtype)

    return X, y


def logit(x):

    return p.log_prob(x) - q.log_prob(x)


def density_ratio(x):

    return tf.exp(logit(x))


def optimal_score(x):

    return tf.sigmoid(logit(x))
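
The identity this example exploits (a standard fact, not specific to this script): if samples from $p(x)$ are labeled $y = 1$ and samples from $q(x)$ are labeled $y = 0$ in equal proportion (rate=0.5 in load_data above), then the Bayes-optimal classifier score is

$$\pi(x) = P(y = 1 \mid x) = \frac{p(x)}{p(x) + q(x)} = \sigma(\log p(x) - \log q(x)) = \sigma(f(x))$$

so a well-calibrated probabilistic classifier recovers the log density ratio $f(x)$ through its logits. This is what optimal_score computes above.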

Probability densities

fig, ax = plt.subplots()

ax.plot(X_q, q.prob(X_q), label='$q(x)$')
ax.plot(X_q, p.prob(X_q), label='$p(x)$')

ax.set_xlabel('$x$')
ax.set_ylabel('density')

ax.legend()

plt.show()

Log density ratio, log-odds, or logits.

fig, ax = plt.subplots()

ax.plot(X_q, logit(X_q), c='k', label=r"$f(x) = \log p(x) - \log q(x)$")

ax.set_xlabel('$x$')
ax.set_ylabel('$f(x)$')

ax.legend()

plt.show()

Density ratio.

fig, ax = plt.subplots()

ax.plot(X_q, density_ratio(X_q), c='k', label=r"$r(x) = \exp f(x)$")

ax.set_xlabel('$x$')
ax.set_ylabel('$r(x)$')

ax.legend()

plt.show()

Create classification dataset.

X_train, y_train = load_data(num_train, seed=seed)
X_test, y_test = load_data(num_test, seed=seed)

Dataset visualized against the Bayes optimal classifier.

fig, ax = plt.subplots()

ax.plot(X_q, optimal_score(X_q), c='k', label=r"$\pi(x) = \sigma(f(x))$")
ax.scatter(X_train, y_train, c=y_train, s=12.**2,
           marker='s', alpha=0.1, cmap="coolwarm_r")
ax.set_yticks([0, 1])
ax.set_yticklabels([r"$x_q \sim q(x)$", r"$x_p \sim p(x)$"])
ax.set_xlabel('$x$')

ax.legend()

plt.show()

Here we encapsulate the variational Gaussian process (together with its variable initialization) in a Keras / TensorFlow Probability layer. This stays clean and simple as long as we restrict ourselves to single outputs (event_shape = ()) and feature_ndim = 1, i.e. inputs that are plain vectors rather than matrices or higher-order tensors.

class VariationalGaussianProcess1D(tfp.layers.DistributionLambda):

    def __init__(self, kernel_wrapper, num_inducing_points,
                 inducing_index_points_initializer, mean_fn=None, jitter=1e-6,
                 convert_to_tensor_fn=tfd.Distribution.sample, **kwargs):

        def make_distribution(x):

            return VariationalGaussianProcess1D.new(
                x, kernel_wrapper=self.kernel_wrapper,
                inducing_index_points=self.inducing_index_points,
                variational_inducing_observations_loc=(
                    self.variational_inducing_observations_loc),
                variational_inducing_observations_scale=(
                    self.variational_inducing_observations_scale),
                mean_fn=self.mean_fn,
                observation_noise_variance=tf.exp(
                    self.log_observation_noise_variance),
                jitter=self.jitter)

        super(VariationalGaussianProcess1D, self).__init__(
            make_distribution_fn=make_distribution,
            convert_to_tensor_fn=convert_to_tensor_fn,
            dtype=kernel_wrapper.dtype)

        self.kernel_wrapper = kernel_wrapper
        self.inducing_index_points_initializer = inducing_index_points_initializer
        self.num_inducing_points = num_inducing_points
        self.mean_fn = mean_fn
        self.jitter = jitter

        self._dtype = self.kernel_wrapper.dtype

    def build(self, input_shape):

        input_dim = input_shape[-1]

        # TODO: Fix initialization!
        self.inducing_index_points = self.add_weight(
            name="inducing_index_points",
            shape=(self.num_inducing_points, input_dim),
            initializer=self.inducing_index_points_initializer,
            dtype=self.dtype)

        self.variational_inducing_observations_loc = self.add_weight(
            name="variational_inducing_observations_loc",
            shape=(self.num_inducing_points,),
            initializer="zeros", dtype=self.dtype)

        self.variational_inducing_observations_scale = self.add_weight(
            name="variational_inducing_observations_scale",
            shape=(self.num_inducing_points, self.num_inducing_points),
            initializer=Identity(gain=1.0), dtype=self.dtype)

        self.log_observation_noise_variance = self.add_weight(
            name="log_observation_noise_variance",
            initializer=Constant(-5.0), dtype=self.dtype)

    @staticmethod
    def new(x, kernel_wrapper, inducing_index_points, mean_fn,
            variational_inducing_observations_loc,
            variational_inducing_observations_scale,
            observation_noise_variance, jitter, name=None):

        return tfd.VariationalGaussianProcess(
            kernel=kernel_wrapper.kernel, index_points=x,
            inducing_index_points=inducing_index_points,
            variational_inducing_observations_loc=(
                variational_inducing_observations_loc),
            variational_inducing_observations_scale=(
                variational_inducing_observations_scale),
            mean_fn=mean_fn,
            observation_noise_variance=observation_noise_variance,
            jitter=jitter)

Kernel wrapper layer

class KernelWrapper(Layer):

    # TODO: Support automatic relevance determination
    def __init__(self, kernel_cls=kernels.ExponentiatedQuadratic,
                 dtype=None, **kwargs):

        super(KernelWrapper, self).__init__(dtype=dtype, **kwargs)

        self.kernel_cls = kernel_cls

        self.log_amplitude = self.add_weight(
            name="log_amplitude",
            initializer="zeros", dtype=dtype)

        self.log_length_scale = self.add_weight(
            name="log_length_scale",
            initializer="zeros", dtype=dtype)

    def call(self, x):
        # Never called -- this is just a layer so it can hold variables
        # in a way Keras understands.
        return x

    @property
    def kernel(self):
        return self.kernel_cls(amplitude=tf.exp(self.log_amplitude),
                               length_scale=tf.exp(self.log_length_scale))
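
A quick usage sketch (mine, not part of the training pipeline): the log-parameterization keeps both hyperparameters positive under unconstrained gradient updates, with no need for bijectors or constrained optimizers.

# Sketch: instantiate the wrapper and inspect the kernel it exposes.
kernel_wrapper = KernelWrapper(kernel_cls=kernel_cls, dtype=tf.float64)

k = kernel_wrapper.kernel
# Both hyperparameters start at exp(0) = 1.
print(k.amplitude.numpy(), k.length_scale.numpy())  # 1.0 1.0

# Kernel matrix over the index points: shape (num_index_points, num_index_points).
K = k.matrix(X_q, X_q)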

Bernoulli likelihood for binary classification.

def make_binary_classification_likelihood(f):

    return tfd.Independent(tfd.Bernoulli(logits=f),
                           reinterpreted_batch_ndims=1)


def log_likelihood(y, f):

    likelihood = make_binary_classification_likelihood(f)
    return likelihood.log_prob(y)
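
A shape sanity check with toy inputs (hypothetical, just to illustrate the broadcasting): tfd.Independent reinterprets the last axis as the event, so a batch of function samples yields one log-likelihood per sample.

f = tf.zeros((3, 5), dtype=tf.float64)  # three function samples over five points
y = tf.ones(5, dtype=tf.float64)        # one set of binary labels
print(log_likelihood(y, f).shape)       # (3,)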

Helper factory function for building the model.

def build_model(input_dim, jitter=1e-6):

    inducing_index_points_initial = random_state.choice(
        X_train.squeeze(), num_inducing_points).reshape(-1, num_features)

    inducing_index_points_initializer = (
        tf.constant_initializer(inducing_index_points_initial))

    return tf.keras.Sequential([
        InputLayer(input_shape=(input_dim,)),
        VariationalGaussianProcess1D(
            kernel_wrapper=KernelWrapper(kernel_cls=kernel_cls,
                                         dtype=tf.float64),
            num_inducing_points=num_inducing_points,
            inducing_index_points_initializer=inducing_index_points_initializer,
            jitter=jitter)
    ])


model = build_model(input_dim=num_features, jitter=jitter)
optimizer = tf.keras.optimizers.Adam()

Out:

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gaussian_process.py:311: UserWarning: Unable to detect statically whether the number of index_points is 1. As a result, defaulting to treating the marginal GP at `index_points` as a multivariate Gaussian. This makes some methods, like `cdf` unavailable.
  'Unable to detect statically whether the number of index_points is '
@tf.function
def nelbo(X_batch, y_batch):

    qf = model(X_batch)

    ell = qf.surrogate_posterior_expected_log_likelihood(
        observations=y_batch,
        log_likelihood_fn=log_likelihood,
        quadrature_size=quadrature_size)

    kl = qf.surrogate_posterior_kl_divergence_prior()
    kl_weight = get_kl_weight(num_train, batch_size)

    return - ell + kl_weight * kl
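
get_kl_weight is imported from the companion etudes.utils module and its definition is not reproduced here. For the minibatch objective above to be a consistent estimate of the full NELBO, it should scale the KL term by the fraction of the data seen in each batch; a minimal sketch under that assumption (the name get_kl_weight_sketch is mine):

def get_kl_weight_sketch(num_train, batch_size):
    # The expected log-likelihood is summed over `batch_size` of the
    # `num_train` points, so the full-dataset KL term is down-weighted
    # by the same fraction to keep the two terms on a common scale.
    return batch_size / num_train
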
@tf.function
def train_step(X_batch, y_batch):

    with tf.GradientTape() as tape:
        loss = nelbo(X_batch, y_batch)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss


dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)) \
                         .shuffle(seed=seed, buffer_size=shuffle_buffer_size) \
                         .batch(batch_size, drop_remainder=True)

keys = ["inducing_index_points",
        "variational_inducing_observations_loc",
        "variational_inducing_observations_scale",
        "log_observation_noise_variance",
        "log_amplitude", "log_length_scale"]
history = defaultdict(list)

for epoch in range(num_epochs):

    for step, (X_batch, y_batch) in enumerate(dataset):

        loss = train_step(X_batch, y_batch)

    print("epoch={epoch:04d}, loss={loss:.4f}"
          .format(epoch=epoch, loss=loss.numpy()))

    history["nelbo"].append(loss.numpy())

    for key, tensor in zip(keys, model.get_weights()):

        history[key].append(tensor)

Out:

epoch=0000, loss=83688.7371
epoch=0001, loss=58054.7560
epoch=0002, loss=42126.1083
epoch=0003, loss=29983.7028
epoch=0004, loss=21101.8620
epoch=0005, loss=15167.4498
epoch=0006, loss=11310.4198
epoch=0007, loss=8797.5941
epoch=0008, loss=7048.9826
epoch=0009, loss=5827.5657
epoch=0010, loss=4929.9037
...
epoch=0996, loss=22.7450
epoch=0997, loss=26.1217
epoch=0998, loss=29.9866
epoch=0999, loss=24.5195

Create a test-time model with a higher jitter, making the predictive distribution more robust at unseen test inputs.

test_model = build_model(input_dim=num_features, jitter=2e-5)
test_model.set_weights(model.get_weights())

Out:

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gaussian_process.py:311: UserWarning: Unable to detect statically whether the number of index_points is 1. As a result, defaulting to treating the marginal GP at `index_points` as a multivariate Gaussian. This makes some methods, like `cdf` unavailable.
  'Unable to detect statically whether the number of index_points is '
inducing_index_points_history = history.pop("inducing_index_points")
variational_inducing_observations_loc_history = (
    history.pop("variational_inducing_observations_loc"))

inducing_index_points = inducing_index_points_history[-1]
variational_inducing_observations_loc = (
    variational_inducing_observations_loc_history[-1])

Log density ratio, log-odds, or logits.

fig, ax = plt.subplots()

ax.plot(X_q, logit(X_q), c='k',
        label=r"$f(x) = \log p(x) - \log q(x)$")

ax.plot(X_q, test_model(X_q).mean().numpy().T,
        label="posterior mean")
fill_between_stddev(X_q.squeeze(),
                    test_model(X_q).mean().numpy().squeeze(),
                    test_model(X_q).stddev().numpy().squeeze(), alpha=0.1,
                    label="posterior std dev", ax=ax)

ax.scatter(inducing_index_points, np.full_like(inducing_index_points, y_min),
           marker='^', c="tab:gray", label="inducing inputs", alpha=0.4)
ax.scatter(inducing_index_points, variational_inducing_observations_loc,
           marker='+', c="tab:blue", label="inducing variable mean")

ax.set_xlabel('$x$')
ax.set_ylabel('$f(x)$')

ax.legend()

plt.show()

Density ratio.
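
Since the approximate posterior over $f(x)$ at each index point is Gaussian with mean $\mu(x)$ and standard deviation $\sigma(x)$, the induced posterior over the ratio $r(x) = \exp f(x)$ is log-normal, with closed-form moments

$$\mathbb{E}[r(x)] = \exp(\mu(x) + \sigma^2(x)/2), \qquad \mathrm{Var}[r(x)] = (\exp(\sigma^2(x)) - 1) \exp(2\mu(x) + \sigma^2(x))$$

which is exactly what wrapping the predictive marginals in tfd.LogNormal gives us below.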

d = tfd.Independent(tfd.LogNormal(loc=test_model(X_q).mean(),
                                  scale=test_model(X_q).stddev()),
                    reinterpreted_batch_ndims=1)

fig, ax = plt.subplots()

ax.plot(X_q, density_ratio(X_q), c='k', label=r"$r(x) = \exp f(x)$")
ax.plot(X_q, d.mean().numpy().T, label="transformed posterior mean")
fill_between_stddev(X_q.squeeze(),
                    d.mean().numpy().squeeze(),
                    d.stddev().numpy().squeeze(), alpha=0.1,
                    label="transformed posterior std dev", ax=ax)

ax.set_xlabel('$x$')
ax.set_ylabel('$r(x)$')

ax.legend()

plt.show()

Predictive mean samples. Since the VGP layer's convert_to_tensor_fn is Distribution.sample, stacking it under an IndependentBernoulli head draws a single function sample f per call and interprets it as logits, so the head's mean is $\sigma(f)$ for that one sample.

posterior_predictive = tf.keras.Sequential([
    test_model,
    tfp.layers.IndependentBernoulli(event_shape=(num_index_points,))
])
fig, ax = plt.subplots()

ax.plot(X_q, posterior_predictive(X_q).mean())
ax.plot(X_q, optimal_score(X_q), c='k', label=r"$\pi(x) = \sigma(f(x))$")
ax.scatter(X_train, y_train, c=y_train, s=12.**2,
           marker='s', alpha=0.1, cmap="coolwarm_r")
ax.set_yticks([0, 1])
ax.set_yticklabels([r"$x_q \sim q(x)$", r"$x_p \sim p(x)$"])
ax.set_xlabel('$x$')

ax.legend()

plt.show()
def make_posterior_predictive(num_samples=None):

    def posterior_predictive(x):

        f_samples = test_model(x).sample(num_samples)

        return make_binary_classification_likelihood(f=f_samples)

    return posterior_predictive


posterior_predictive = make_posterior_predictive(num_samples)
fig, ax = plt.subplots()

ax.plot(X_q, posterior_predictive(X_q).mean().numpy().T, color="tab:blue",
        linewidth=0.4, alpha=0.6)
ax.plot(X_q, optimal_score(X_q), c='k', label=r"$\pi(x) = \sigma(f(x))$")
ax.scatter(X_train, y_train, c=y_train, s=12.**2,
           marker='s', alpha=0.1, cmap="coolwarm_r")
ax.set_yticks([0, 1])
ax.set_yticklabels([r"$x_q \sim q(x)$", r"$x_p \sim p(x)$"])
ax.set_xlabel('$x$')

ax.legend()

plt.show()
y_scores = posterior_predictive(X_test).mean().numpy()
y_score_min, y_score_max = np.percentile(y_scores, q=[5, 95], axis=0)
fig, ax = plt.subplots()

ax.plot(X_q, optimal_score(X_q), c='k', label=r"$\pi(x) = \sigma(f(x))$")

ax.scatter(X_test, np.median(y_scores, axis=0), c="tab:blue", label="median",
           alpha=0.8)
ax.vlines(X_test, ymin=y_score_min, ymax=y_score_max, color="tab:blue",
          label="90\% confidence interval", alpha=0.8)

ax.scatter(X_test, y_test, c=y_test, s=12.**2,
           marker='s', alpha=0.1, cmap="coolwarm_r")

ax.set_ylabel(r"$\pi(x)$")
ax.set_xlabel(r"$x$")
ax.set_xlim(x_min, x_max)

ax.legend()

plt.show()
def get_inducing_index_points_data(inducing_index_points):

    df = pd.DataFrame(np.hstack(inducing_index_points).T)
    df.index.name = "epoch"
    df.columns.name = "inducing index points"

    s = df.stack()
    s.name = 'x'

    return s.reset_index()
data = get_inducing_index_points_data(inducing_index_points_history)
fig, ax = plt.subplots()

sns.lineplot(x='x', y="epoch", hue="inducing index points", palette="viridis",
             sort=False, data=data, alpha=0.8, ax=ax)

ax.set_xlabel(r'$x$')

plt.show()
variational_inducing_observations_scale_history = (
    history.pop("variational_inducing_observations_scale"))
fig, (ax1, ax2) = plt.subplots(ncols=2, sharex=True, sharey=True)

im2 = ax2.imshow(variational_inducing_observations_scale_history[-1])

vmin, vmax = im2.get_clim()
im1 = ax1.imshow(variational_inducing_observations_scale_history[0],
                 vmin=vmin, vmax=vmax)

fig.colorbar(im2, ax=[ax1, ax2],
             orientation="horizontal")

ax1.set_xlabel(r"$i$")
ax1.set_ylabel(r"$j$")

ax2.set_xlabel(r"$i$")

plt.show()
history_df = pd.DataFrame(history)
history_df.index.name = "epoch"
history_df.reset_index(inplace=True)
fig, ax = plt.subplots()

sns.lineplot(x="epoch", y="nelbo", data=history_df, alpha=0.8, ax=ax)
ax.set_yscale("log")

plt.show()
parameters_df = history_df.drop(columns="nelbo") \
                          .rename(columns=lambda s: s.replace('_', ' '))
g = sns.PairGrid(parameters_df, hue="epoch", palette="RdYlBu", corner=True)
g = g.map_lower(plt.scatter, facecolor="none", alpha=0.6)

Total running time of the script: (3 minutes 13.267 seconds)
