.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/gaussian_processes/plot_sparse_gp_classification_keras.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_gaussian_processes_plot_sparse_gp_classification_keras.py:


Variational Sparse GP for Binary Classification
===============================================

Here we fit a sparse variational Gaussian process classifier to a synthetic
binary classification problem. Since the Bernoulli likelihood makes the (log)
marginal likelihood intractable, we instead maximize the evidence lower bound
(ELBO), fitting the kernel hyperparameters jointly with the variational
parameters. This is a variational approximation to empirical Bayes, or type-II
maximum likelihood estimation.

.. GENERATED FROM PYTHON SOURCE LINES 10-30

.. code-block:: default


    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp

    import matplotlib.pyplot as plt
    import seaborn as sns

    import pandas as pd

    from tensorflow.keras.layers import Layer, InputLayer
    from tensorflow.keras.initializers import Identity, Constant

    from etudes.datasets import make_classification_dataset
    from etudes.plotting import fill_between_stddev
    from etudes.utils import get_kl_weight
    from collections import defaultdict

.. GENERATED FROM PYTHON SOURCE LINES 32-66
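
Because the Bernoulli likelihood is not conjugate to the Gaussian process
prior, the training code further down does not maximize the log marginal
likelihood :math:`\log p(\mathbf{y})` directly, but a lower bound on it. In the
standard sparse variational form (the notation here is ours, not the
source's), :math:`\mathbf{u}` denotes the function values at the inducing index
points, :math:`q(\mathbf{u})` their variational distribution, and
:math:`q(f_n)` the resulting approximate posterior marginal at input
:math:`x_n`:

.. math::

    \log p(\mathbf{y}) \geq \mathcal{L}
    = \sum_{n=1}^{N} \mathbb{E}_{q(f_n)}\left[ \log p(y_n \mid f_n) \right]
    - \mathrm{KL}\left[ q(\mathbf{u}) \,\|\, p(\mathbf{u}) \right]

Multiplying an unbiased minibatch estimate of :math:`\mathcal{L}` (batch
:math:`\mathcal{B}` of size :math:`B`) by :math:`B / N` gives

.. math::

    \frac{B}{N} \mathcal{L}
    \approx \sum_{n \in \mathcal{B}} \mathbb{E}_{q(f_n)}\left[ \log p(y_n \mid f_n) \right]
    - \frac{B}{N} \, \mathrm{KL}\left[ q(\mathbf{u}) \,\|\, p(\mathbf{u}) \right]

The `nelbo` function defined later returns `- ell + kl_weight * kl`, which is
the negative of this expression if `get_kl_weight(num_train, batch_size)`
returns :math:`B / N`. We have not checked the `etudes` helper, so treat that
correspondence as an assumption.
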
.. code-block:: default


    # shortcuts
    tfd = tfp.distributions
    kernels = tfp.math.psd_kernels

    # constants
    num_train = 2048  # nbr training points in synthetic dataset
    num_test = 40
    num_features = 1  # dimensionality
    num_index_points = 256  # nbr of index points
    num_samples = 25
    quadrature_size = 20
    num_inducing_points = 50
    num_epochs = 1000
    batch_size = 64
    shuffle_buffer_size = 500

    jitter = 1e-6

    kernel_cls = kernels.MaternFiveHalves

    seed = 8888  # set random seed for reproducibility
    random_state = np.random.RandomState(seed)

    x_min, x_max = -5.0, 5.0
    y_min, y_max = -6.0, 4.0

    # index points
    X_q = np.linspace(x_min, x_max, num_index_points).reshape(-1, num_features)

    golden_ratio = 0.5 * (1 + np.sqrt(5))

.. GENERATED FROM PYTHON SOURCE LINES 67-69

Synthetic dataset
-----------------

.. GENERATED FROM PYTHON SOURCE LINES 69-75

.. code-block:: default


    p = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),
        components_distribution=tfd.Normal(loc=[2.0, -3.0], scale=[1.0, 0.5]))
    q = tfd.Normal(loc=0.0, scale=2.0)

.. GENERATED FROM PYTHON SOURCE LINES 76-90

.. code-block:: default


    def load_data(num_samples, rate=0.5, dtype="float64", seed=None):

        num_p = int(num_samples * rate)
        num_q = num_samples - num_p

        X_p = p.sample(sample_shape=(num_p, 1), seed=seed).numpy()
        X_q = q.sample(sample_shape=(num_q, 1), seed=seed).numpy()

        X, y = make_classification_dataset(X_p, X_q, dtype=dtype)

        return X, y

.. GENERATED FROM PYTHON SOURCE LINES 91-97

.. code-block:: default


    def logit(x):
        return p.log_prob(x) - q.log_prob(x)

.. GENERATED FROM PYTHON SOURCE LINES 98-104

.. code-block:: default


    def density_ratio(x):
        return tf.exp(logit(x))

.. GENERATED FROM PYTHON SOURCE LINES 105-111

.. code-block:: default


    def optimal_score(x):
        return tf.sigmoid(logit(x))

.. GENERATED FROM PYTHON SOURCE LINES 112-113

Probability densities

.. GENERATED FROM PYTHON SOURCE LINES 113-127
.. code-block:: default


    fig, ax = plt.subplots()

    ax.plot(X_q, q.prob(X_q), label='$q(x)$')
    ax.plot(X_q, p.prob(X_q), label='$p(x)$')

    ax.set_xlabel('$x$')
    ax.set_ylabel('density')

    ax.legend()

    plt.show()



.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_001.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 128-129

Log density ratio, log-odds, or logits.

.. GENERATED FROM PYTHON SOURCE LINES 129-141

.. code-block:: default


    fig, ax = plt.subplots()

    ax.plot(X_q, logit(X_q), c='k',
            label=r"$f(x) = \log p(x) - \log q(x)$")

    ax.set_xlabel('$x$')
    ax.set_ylabel('$f(x)$')

    ax.legend()

    plt.show()



.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_002.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 142-143

Density ratio.

.. GENERATED FROM PYTHON SOURCE LINES 143-155

.. code-block:: default


    fig, ax = plt.subplots()

    ax.plot(X_q, density_ratio(X_q), c='k',
            label=r"$r(x) = \exp f(x)$")

    ax.set_xlabel('$x$')
    ax.set_ylabel('$r(x)$')

    ax.legend()

    plt.show()



.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_003.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 156-157

Create classification dataset.

.. GENERATED FROM PYTHON SOURCE LINES 157-161

.. code-block:: default


    X_train, y_train = load_data(num_train, seed=seed)
    X_test, y_test = load_data(num_test, seed=seed)

.. GENERATED FROM PYTHON SOURCE LINES 162-163

Dataset visualized against the Bayes optimal classifier.

.. GENERATED FROM PYTHON SOURCE LINES 163-177
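
Why is :math:`\sigma(f(x))` Bayes-optimal here? `load_data` draws equally many
samples from each density (`rate=0.5`), and the tick labels in the plot below
indicate that class 1 is drawn from :math:`p(x)` and class 0 from :math:`q(x)`.
Bayes' rule then gives :math:`P(y = 1 \mid x) = p(x) / (p(x) + q(x)) =
\sigma(\log p(x) - \log q(x))`. With unbalanced classes, the log prior odds
:math:`\log(\text{rate} / (1 - \text{rate}))` would have to be added to the
logit. The NumPy/SciPy sketch below checks this identity numerically; as an
assumption, it re-creates :math:`p` and :math:`q` with `scipy.stats.norm`
instead of calling the TFP objects defined above.

.. code-block:: default

    import numpy as np
    from scipy.stats import norm
    from scipy.special import expit, logsumexp


    def log_p(x):
        # Mixture 0.3 * N(2, 1) + 0.7 * N(-3, 0.5), mirroring the TFP `p`
        comps = np.stack([np.log(0.3) + norm.logpdf(x, loc=2.0, scale=1.0),
                          np.log(0.7) + norm.logpdf(x, loc=-3.0, scale=0.5)])
        return logsumexp(comps, axis=0)


    def log_q(x):
        # N(0, 2), mirroring the TFP `q`
        return norm.logpdf(x, loc=0.0, scale=2.0)


    def bayes_posterior(x, prior_p=0.5):
        # P(y = 1 | x) by Bayes' rule; class 1 draws from p, class 0 from q
        px, qx = np.exp(log_p(x)), np.exp(log_q(x))
        return prior_p * px / (prior_p * px + (1.0 - prior_p) * qx)


    x = np.linspace(-5.0, 5.0, 256)
    score = expit(log_p(x) - log_q(x))  # sigma(f(x)), as in `optimal_score`

    # With balanced classes the two curves agree up to floating-point error
    print(np.max(np.abs(score - bayes_posterior(x))))

Passing `prior_p=0.3` to `bayes_posterior` shows that the plain sigmoid of the
log density ratio is no longer the posterior once the classes are unbalanced.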
.. code-block:: default


    fig, ax = plt.subplots()

    ax.plot(X_q, optimal_score(X_q), c='k', label=r"$\pi(x) = \sigma(f(x))$")

    ax.scatter(X_train, y_train, c=y_train, s=12.**2,
               marker='s', alpha=0.1, cmap="coolwarm_r")

    ax.set_yticks([0, 1])
    ax.set_yticklabels([r"$x_q \sim q(x)$", r"$x_p \sim p(x)$"])
    ax.set_xlabel('$x$')

    ax.legend()

    plt.show()



.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_004.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img



.. GENERATED FROM PYTHON SOURCE LINES 178-183

We encapsulate the variational Gaussian process, together with the
initialization of its variables, in a Keras layer that subclasses TensorFlow
Probability's `DistributionLambda`. The implementation stays clean and simple
because we restrict it to a single output (`event_shape = ()`) and
`feature_ndim = 1` (i.e., each input is a vector rather than a matrix or
higher-order tensor).

.. GENERATED FROM PYTHON SOURCE LINES 183-264

.. code-block:: default


    class VariationalGaussianProcess1D(tfp.layers.DistributionLambda):

        def __init__(self, kernel_wrapper, num_inducing_points,
                     inducing_index_points_initializer, mean_fn=None, jitter=1e-6,
                     convert_to_tensor_fn=tfd.Distribution.sample, **kwargs):

            def make_distribution(x):

                return VariationalGaussianProcess1D.new(
                    x, kernel_wrapper=self.kernel_wrapper,
                    inducing_index_points=self.inducing_index_points,
                    variational_inducing_observations_loc=(
                        self.variational_inducing_observations_loc),
                    variational_inducing_observations_scale=(
                        self.variational_inducing_observations_scale),
                    mean_fn=self.mean_fn,
                    observation_noise_variance=tf.exp(
                        self.log_observation_noise_variance),
                    jitter=self.jitter)

            super(VariationalGaussianProcess1D, self).__init__(
                make_distribution_fn=make_distribution,
                convert_to_tensor_fn=convert_to_tensor_fn,
                dtype=kernel_wrapper.dtype)

            self.kernel_wrapper = kernel_wrapper
            self.inducing_index_points_initializer = inducing_index_points_initializer
            self.num_inducing_points = num_inducing_points
            self.mean_fn = mean_fn
            self.jitter = jitter

            self._dtype = self.kernel_wrapper.dtype

        def build(self, input_shape):

            input_dim = input_shape[-1]

            # TODO: Fix initialization!
            self.inducing_index_points = self.add_weight(
                name="inducing_index_points",
                shape=(self.num_inducing_points, input_dim),
                initializer=self.inducing_index_points_initializer,
                dtype=self.dtype)

            self.variational_inducing_observations_loc = self.add_weight(
                name="variational_inducing_observations_loc",
                shape=(self.num_inducing_points,),
                initializer="zeros", dtype=self.dtype)

            self.variational_inducing_observations_scale = self.add_weight(
                name="variational_inducing_observations_scale",
                shape=(self.num_inducing_points, self.num_inducing_points),
                initializer=Identity(gain=1.0), dtype=self.dtype)

            self.log_observation_noise_variance = self.add_weight(
                name="log_observation_noise_variance",
                initializer=Constant(-5.0), dtype=self.dtype)

        @staticmethod
        def new(x, kernel_wrapper, inducing_index_points, mean_fn,
                variational_inducing_observations_loc,
                variational_inducing_observations_scale,
                observation_noise_variance, jitter, name=None):

            # ind = tfd.Independent(base, reinterpreted_batch_ndims=1)
            # bijector = tfp.bijectors.Transpose(rightmost_transposed_ndims=2)
            # d = tfd.TransformedDistribution(ind, bijector=bijector)

            return tfd.VariationalGaussianProcess(
                kernel=kernel_wrapper.kernel, index_points=x,
                inducing_index_points=inducing_index_points,
                variational_inducing_observations_loc=(
                    variational_inducing_observations_loc),
                variational_inducing_observations_scale=(
                    variational_inducing_observations_scale),
                mean_fn=mean_fn,
                observation_noise_variance=observation_noise_variance,
                jitter=jitter)

.. GENERATED FROM PYTHON SOURCE LINES 265-266

Kernel wrapper layer

.. GENERATED FROM PYTHON SOURCE LINES 266-296
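
The wrapper layer below keeps `log_amplitude` and `log_length_scale` as
unconstrained variables and exponentiates them when it builds the kernel, which
keeps both parameters positive while the optimizer works on the whole real
line. With `kernel_cls = kernels.MaternFiveHalves`, the covariance in its
standard form is :math:`k(x, x') = \sigma^2 \left(1 + \sqrt{5}\, r / \ell +
5 r^2 / (3 \ell^2)\right) \exp(-\sqrt{5}\, r / \ell)` with :math:`r = |x - x'|`.
The NumPy sketch below implements that formula independently of TFP, for
illustration only; the function name `matern52` is hypothetical.

.. code-block:: default

    import numpy as np


    def matern52(x1, x2, log_amplitude=0.0, log_length_scale=0.0):
        """Matern-5/2 kernel matrix between 1-D input arrays x1 and x2.

        Parameters are passed in log space and exponentiated, mirroring how
        KernelWrapper stores `log_amplitude` and `log_length_scale`.
        """
        amplitude = np.exp(log_amplitude)
        length_scale = np.exp(log_length_scale)
        # Scaled pairwise distances r / length_scale, shape (len(x1), len(x2))
        r = np.abs(x1[:, None] - x2[None, :]) / length_scale
        s5r = np.sqrt(5.0) * r
        # s5r**2 / 3 equals 5 r^2 / (3 l^2)
        return amplitude**2 * (1.0 + s5r + s5r**2 / 3.0) * np.exp(-s5r)


    x = np.linspace(-5.0, 5.0, 6)
    K = matern52(x, x)  # zero log-parameters: amplitude = length scale = 1

Because `KernelWrapper` initializes both log-parameters to zero, training starts
from unit amplitude and unit length scale, which is what `K` uses here.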
.. code-block:: default


    class KernelWrapper(Layer):

        # TODO: Support automatic relevance determination
        def __init__(self, kernel_cls=kernels.ExponentiatedQuadratic,
                     dtype=None, **kwargs):

            super(KernelWrapper, self).__init__(dtype=dtype, **kwargs)

            self.kernel_cls = kernel_cls

            self.log_amplitude = self.add_weight(
                name="log_amplitude",
                initializer="zeros", dtype=dtype)

            self.log_length_scale = self.add_weight(
                name="log_length_scale",
                initializer="zeros", dtype=dtype)

        def call(self, x):
            # Never called -- this is just a layer so it can hold variables
            # in a way Keras understands.
            return x

        @property
        def kernel(self):
            return self.kernel_cls(amplitude=tf.exp(self.log_amplitude),
                                   length_scale=tf.exp(self.log_length_scale))

.. GENERATED FROM PYTHON SOURCE LINES 297-298

Bernoulli likelihood for binary classification.

.. GENERATED FROM PYTHON SOURCE LINES 298-304

.. code-block:: default


    def make_binary_classification_likelihood(f):

        return tfd.Independent(tfd.Bernoulli(logits=f),
                               reinterpreted_batch_ndims=1)

.. GENERATED FROM PYTHON SOURCE LINES 305-312

.. code-block:: default


    def log_likelihood(y, f):

        likelihood = make_binary_classification_likelihood(f)
        return likelihood.log_prob(y)

.. GENERATED FROM PYTHON SOURCE LINES 313-314

Helper model factory function.

.. GENERATED FROM PYTHON SOURCE LINES 314-334

.. code-block:: default


    def build_model(input_dim, jitter=1e-6):

        inducing_index_points_initial = random_state.choice(X_train.squeeze(),
                                                            num_inducing_points) \
                                                    .reshape(-1, num_features)

        inducing_index_points_initializer = (
            tf.constant_initializer(inducing_index_points_initial))

        return tf.keras.Sequential([
            InputLayer(input_shape=(input_dim,)),
            VariationalGaussianProcess1D(
                kernel_wrapper=KernelWrapper(kernel_cls=kernel_cls,
                                             dtype=tf.float64),
                num_inducing_points=num_inducing_points,
                inducing_index_points_initializer=inducing_index_points_initializer,
                jitter=jitter)
        ])

.. GENERATED FROM PYTHON SOURCE LINES 335-339
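
`build_model` initializes the inducing index points by subsampling the training
inputs with `random_state.choice`, which samples with replacement by default.
With 50 draws from 2,048 inputs, the chance of selecting at least one input
twice is roughly 45%, and a repeated inducing point makes the inducing-point
covariance matrix singular, leaving only `jitter` to keep it invertible. The
sketch below shows a variant that subsamples without replacement;
`init_inducing_points` is a hypothetical helper and `X` is a random stand-in
for `X_train`.

.. code-block:: default

    import numpy as np


    def init_inducing_points(X, num_inducing, rng):
        # Hypothetical helper: pick distinct rows of X (shape (n, d)) without
        # replacement, unlike the with-replacement default used in build_model.
        idx = rng.choice(len(X), size=num_inducing, replace=False)
        return X[idx]


    rng = np.random.RandomState(8888)
    X = rng.randn(2048, 1)  # random stand-in for X_train
    Z = init_inducing_points(X, 50, rng)
    print(Z.shape, len(np.unique(Z)))  # (50, 1) 50

    # Probability that 50 with-replacement draws from 2048 items repeat one
    p_repeat = 1.0 - np.prod(1.0 - np.arange(50) / 2048.0)
    print(round(p_repeat, 2))  # 0.45

The resulting array could be passed to `tf.constant_initializer` exactly as
`build_model` does with its own subsample.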
.. code-block:: default


    model = build_model(input_dim=num_features, jitter=jitter)
    optimizer = tf.keras.optimizers.Adam()


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gaussian_process.py:311: UserWarning: Unable to detect statically whether the number of index_points is 1. As a result, defaulting to treating the marginal GP at `index_points` as a multivariate Gaussian. This makes some methods, like `cdf` unavailable.
      'Unable to detect statically whether the number of index_points is '



.. GENERATED FROM PYTHON SOURCE LINES 340-356

.. code-block:: default


    @tf.function
    def nelbo(X_batch, y_batch):

        qf = model(X_batch)

        ell = qf.surrogate_posterior_expected_log_likelihood(
            observations=y_batch,
            log_likelihood_fn=log_likelihood,
            quadrature_size=quadrature_size)

        kl = qf.surrogate_posterior_kl_divergence_prior()
        kl_weight = get_kl_weight(num_train, batch_size)

        return - ell + kl_weight * kl

.. GENERATED FROM PYTHON SOURCE LINES 357-369

.. code-block:: default


    @tf.function
    def train_step(X_batch, y_batch):

        with tf.GradientTape() as tape:
            loss = nelbo(X_batch, y_batch)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        return loss

.. GENERATED FROM PYTHON SOURCE LINES 370-375

.. code-block:: default


    dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)) \
                             .shuffle(seed=seed, buffer_size=shuffle_buffer_size) \
                             .batch(batch_size, drop_remainder=True)

.. GENERATED FROM PYTHON SOURCE LINES 376-382

.. code-block:: default


    keys = ["inducing_index_points",
            "variational_inducing_observations_loc",
            "variational_inducing_observations_scale",
            "log_observation_noise_variance",
            "log_amplitude", "log_length_scale"]

.. GENERATED FROM PYTHON SOURCE LINES 383-402
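
`nelbo` asks `surrogate_posterior_expected_log_likelihood` for the expected
log-likelihood with `quadrature_size=20`. Our understanding is that TFP
approximates each one-dimensional expectation over :math:`q(f_n)` with
Gauss-Hermite quadrature; the NumPy sketch below illustrates that technique for
the Bernoulli likelihood used here, :math:`\log p(y \mid f) = y f - \log(1 +
e^f)`, and compares it with a Monte Carlo estimate. It is not TFP's
implementation, and the marginal means and standard deviations are random
stand-ins for those of the fitted model.

.. code-block:: default

    import numpy as np


    def bernoulli_log_lik(y, f):
        # log p(y | f) for a Bernoulli with logits f: y * f - softplus(f)
        return y * f - np.logaddexp(0.0, f)


    def expected_log_lik_gh(y, mean, stddev, quadrature_size=20):
        # Gauss-Hermite nodes/weights for the weight function exp(-t**2)
        t, w = np.polynomial.hermite.hermgauss(quadrature_size)
        # Change of variables f = mean + sqrt(2) * stddev * t, per data point
        f = mean[..., None] + np.sqrt(2.0) * stddev[..., None] * t
        ell = (w * bernoulli_log_lik(y[..., None], f)).sum(axis=-1) / np.sqrt(np.pi)
        # Sum over data points, as tfd.Independent(..., reinterpreted_batch_ndims=1) does
        return ell.sum()


    rng = np.random.RandomState(0)
    y = rng.binomial(1, 0.5, size=64).astype(float)
    mean, stddev = rng.randn(64), np.exp(rng.randn(64) * 0.3)

    gh = expected_log_lik_gh(y, mean, stddev)

    # Monte Carlo check with 100,000 samples of f for each of the 64 points
    f_mc = mean + stddev * rng.randn(100000, 64)
    mc = bernoulli_log_lik(y, f_mc).sum(axis=1).mean()
    print(gh, mc)

Twenty nodes integrate polynomials up to degree 39 exactly against a Gaussian,
so for a smooth log-likelihood like this one the quadrature estimate is
deterministic and typically far more accurate than a Monte Carlo estimate of
similar cost.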
.. code-block:: default


    history = defaultdict(list)

    for epoch in range(num_epochs):

        for step, (X_batch, y_batch) in enumerate(dataset):

            loss = train_step(X_batch, y_batch)

        print("epoch={epoch:04d}, loss={loss:.4f}"
              .format(epoch=epoch, loss=loss.numpy()))

        history["nelbo"].append(loss.numpy())

        for key, tensor in zip(keys, model.get_weights()):

            history[key].append(tensor)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    epoch=0000, loss=83688.7371
    epoch=0001, loss=58054.7560
    epoch=0002, loss=42126.1083
    epoch=0003, loss=29983.7028
    epoch=0004, loss=21101.8620
    epoch=0005, loss=15167.4498
    epoch=0006, loss=11310.4198
    epoch=0007, loss=8797.5941
    epoch=0008, loss=7048.9826
    epoch=0009, loss=5827.5657
    epoch=0010, loss=4929.9037
    epoch=0011, loss=4249.9116
    epoch=0012, loss=3700.4717
    epoch=0013, loss=3270.4793
    epoch=0014, loss=2917.3544
    epoch=0015, loss=2618.1428
    epoch=0016, loss=2370.7512
    epoch=0017, loss=2152.5240
    epoch=0018, loss=1970.7316
    epoch=0019, loss=1810.3165
    epoch=0020, loss=1673.9554
    epoch=0021, loss=1546.4203
    epoch=0022, loss=1445.2123
    epoch=0023, loss=1346.2835
    epoch=0024, loss=1259.0919
    epoch=0025, loss=1178.9194
    epoch=0026, loss=1110.1653
    epoch=0027, loss=1043.3796
    epoch=0028, loss=985.8623
    epoch=0029, loss=935.3818
    epoch=0030, loss=883.3387
    epoch=0031, loss=846.6758
    epoch=0032, loss=801.0132
    epoch=0033, loss=761.0347
    epoch=0034, loss=725.8216
    epoch=0035, loss=692.4462
    epoch=0036, loss=664.7324
    epoch=0037, loss=636.0242
    epoch=0038, loss=612.9439
    epoch=0039, loss=583.1268
    epoch=0040, loss=559.6087
    epoch=0041, loss=541.4257
    epoch=0042, loss=517.6200
    epoch=0043, loss=498.7617
    epoch=0044, loss=484.0547
    epoch=0045, loss=469.5395
    epoch=0046, loss=448.1870
    epoch=0047, loss=433.6379
    epoch=0048, loss=418.5844
    epoch=0049, loss=404.7501
    epoch=0050, loss=396.1823
    epoch=0051, loss=383.1807
    epoch=0052, loss=369.7613
    epoch=0053, loss=356.8915
    epoch=0054, loss=349.1055
    epoch=0055, loss=337.6789
    epoch=0056, loss=326.7187
    epoch=0057, loss=321.0192
    epoch=0058, loss=310.2008
    epoch=0059, loss=301.8602
    epoch=0060, loss=290.9317
    epoch=0061, loss=285.6259
    epoch=0062, loss=282.6673
    epoch=0063, loss=268.9425
    epoch=0064, loss=263.4309
    epoch=0065, loss=259.8918
    epoch=0066, loss=256.2367
    epoch=0067, loss=245.6845
    epoch=0068, loss=239.3562
    epoch=0069, loss=233.7618
    epoch=0070, loss=229.4667
    epoch=0071, loss=223.0711
    epoch=0072, loss=217.4910
    epoch=0073, loss=213.8025
    epoch=0074, loss=207.8142
    epoch=0075, loss=213.1301
    epoch=0076, loss=199.0433
    epoch=0077, loss=196.4339
    epoch=0078, loss=195.7780
    epoch=0079, loss=191.0600
    epoch=0080, loss=181.2908
    epoch=0081, loss=180.3356
    epoch=0082, loss=177.3560
    epoch=0083, loss=175.0715
    epoch=0084, loss=170.5514
    epoch=0085, loss=165.8833
    epoch=0086, loss=163.8111
    epoch=0087, loss=157.9501
    epoch=0088, loss=158.0125
    epoch=0089, loss=155.8078
    epoch=0090, loss=154.0804
    epoch=0091, loss=148.0916
    epoch=0092, loss=151.4859
    epoch=0093, loss=144.5506
    epoch=0094, loss=140.4436
    epoch=0095, loss=140.8091
    epoch=0096, loss=136.7062
    epoch=0097, loss=133.6574
    epoch=0098, loss=128.8100
    epoch=0099, loss=129.7225
    epoch=0100, loss=120.3902
    epoch=0101, loss=124.8359
    epoch=0102, loss=120.4163
    epoch=0103, loss=120.3068
    epoch=0104, loss=116.1298
    epoch=0105, loss=115.4547
    epoch=0106, loss=118.7081
    epoch=0107, loss=114.6724
    epoch=0108, loss=110.4259
    epoch=0109, loss=115.2138
    epoch=0110, loss=106.8103
    epoch=0111, loss=108.8230
    epoch=0112, loss=105.5960
    epoch=0113, loss=104.1928
    epoch=0114, loss=103.6297
    epoch=0115, loss=107.9894
    epoch=0116, loss=101.7562
    epoch=0117, loss=100.6693
    epoch=0118, loss=99.2496
    epoch=0119, loss=96.7935
    epoch=0120, loss=93.9617
    epoch=0121, loss=92.8867
    epoch=0122, loss=91.7722
    epoch=0123, loss=93.2940
    epoch=0124, loss=90.0586
    epoch=0125, loss=89.2878
    epoch=0126, loss=87.5611
    epoch=0127, loss=85.0880
    epoch=0128, loss=84.8172
    epoch=0129, loss=88.2420
    epoch=0130, loss=85.9797
    epoch=0131, loss=79.7830
    epoch=0132, loss=76.0578
    epoch=0133, loss=85.2679
    epoch=0134, loss=79.1043
    epoch=0135, loss=77.3490
    epoch=0136, loss=76.7453
    epoch=0137, loss=81.7491
    epoch=0138, loss=78.8741
    epoch=0139, loss=79.0682
    epoch=0140, loss=71.8925
    epoch=0141, loss=68.2520
    epoch=0142, loss=68.2822
    epoch=0143, loss=71.6652
    epoch=0144, loss=65.2787
    epoch=0145, loss=72.3872
    epoch=0146, loss=74.1184
    epoch=0147, loss=65.5071
    epoch=0148, loss=61.3433
    epoch=0149, loss=68.5294
    epoch=0150, loss=65.9604
    epoch=0151, loss=67.6910
    epoch=0152, loss=71.8157
    epoch=0153, loss=59.9562
    epoch=0154, loss=67.7029
    epoch=0155, loss=60.3049
    epoch=0156, loss=61.8187
    epoch=0157, loss=65.2218
    epoch=0158, loss=65.0102
    epoch=0159, loss=59.5548
    epoch=0160, loss=59.7366
    epoch=0161, loss=58.5197
    epoch=0162, loss=49.5367
    epoch=0163, loss=53.2743
    epoch=0164, loss=54.8077
    epoch=0165, loss=51.4587
    epoch=0166, loss=52.5808
    epoch=0167, loss=58.1869
    epoch=0168, loss=54.1942
    epoch=0169, loss=59.7821
    epoch=0170, loss=49.5684
    epoch=0171, loss=52.9581
    epoch=0172, loss=57.4065
    epoch=0173, loss=50.8347
    epoch=0174, loss=49.2007
    epoch=0175, loss=55.5615
    epoch=0176, loss=53.7488
    epoch=0177, loss=58.8213
    epoch=0178, loss=56.4820
    epoch=0179, loss=58.3914
    epoch=0180, loss=48.2207
    epoch=0181, loss=46.4462
    epoch=0182, loss=48.3633
    epoch=0183, loss=51.8320
    epoch=0184, loss=50.4845
    epoch=0185, loss=48.7852
    epoch=0186, loss=46.9084
    epoch=0187, loss=41.9072
    epoch=0188, loss=49.3036
    epoch=0189, loss=49.3320
    epoch=0190, loss=51.7469
    epoch=0191, loss=48.3297
    epoch=0192, loss=44.3409
    epoch=0193, loss=42.1941
    epoch=0194, loss=50.0154
    epoch=0195, loss=46.3124
    epoch=0196, loss=50.2283
    epoch=0197, loss=45.3228
    epoch=0198, loss=52.0803
    epoch=0199, loss=43.0025
    epoch=0200, loss=44.9862
    epoch=0201, loss=40.3760
    epoch=0202, loss=44.1481
    epoch=0203, loss=50.6961
    epoch=0204, loss=46.7251
    epoch=0205, loss=45.1247
    epoch=0206, loss=43.0848
    epoch=0207, loss=46.8332
    epoch=0208, loss=33.5887
    epoch=0209, loss=36.3205
    epoch=0210, loss=38.5064
    epoch=0211, loss=40.7160
    epoch=0212, loss=44.2748
    epoch=0213, loss=46.7911
    epoch=0214, loss=45.7236
    epoch=0215, loss=46.1145
    epoch=0216, loss=37.8869
    epoch=0217, loss=47.6046
    epoch=0218, loss=41.1444
    epoch=0219, loss=39.6064
    epoch=0220, loss=41.7515
    epoch=0221, loss=47.2257
    epoch=0222, loss=44.0899
    epoch=0223, loss=36.8878
    epoch=0224, loss=42.9716
    epoch=0225, loss=41.7179
    epoch=0226, loss=38.1457
    epoch=0227, loss=35.3032
    epoch=0228, loss=37.0043
    epoch=0229, loss=42.6077
    epoch=0230, loss=40.3539
    epoch=0231, loss=39.6932
    epoch=0232, loss=40.3788
    epoch=0233, loss=40.9728
    epoch=0234, loss=43.0455
    epoch=0235, loss=34.4279
    epoch=0236, loss=39.5351
    epoch=0237, loss=36.6007
    epoch=0238, loss=38.0377
    epoch=0239, loss=38.3556
    epoch=0240, loss=38.0829
    epoch=0241, loss=40.0411
    epoch=0242, loss=44.2893
    epoch=0243, loss=47.8129
    epoch=0244, loss=40.0995
    epoch=0245, loss=34.5026
    epoch=0246, loss=33.8639
    epoch=0247, loss=27.7273
    epoch=0248, loss=37.1499
    epoch=0249, loss=35.1662
    epoch=0250, loss=36.6678
    epoch=0251, loss=43.0815
    epoch=0252, loss=34.6902
    epoch=0253, loss=31.2057
    epoch=0254, loss=32.8320
    epoch=0255, loss=33.0954
    epoch=0256, loss=36.8544
    epoch=0257, loss=40.3444
    epoch=0258, loss=33.1079
    epoch=0259, loss=35.5396
    epoch=0260, loss=33.6233
    epoch=0261, loss=34.1449
    epoch=0262, loss=38.5085
    epoch=0263, loss=36.7990
    epoch=0264, loss=33.6773
    epoch=0265, loss=37.3847
    epoch=0266, loss=35.2265
    epoch=0267, loss=37.9969
    epoch=0268, loss=31.5359
    epoch=0269, loss=34.2021
    epoch=0270, loss=32.6648
    epoch=0271, loss=30.7869
    epoch=0272, loss=34.8629
    epoch=0273, loss=37.1946
    epoch=0274, loss=38.3552
    epoch=0275, loss=35.3011
    epoch=0276, loss=34.1053
    epoch=0277, loss=41.5445
    epoch=0278, loss=38.9380
    epoch=0279, loss=29.0395
    epoch=0280, loss=38.3291
    epoch=0281, loss=31.3095
    epoch=0282, loss=37.1072
    epoch=0283, loss=31.1259
    epoch=0284, loss=32.1109
    epoch=0285, loss=30.8005
    epoch=0286, loss=35.4642
    epoch=0287, loss=38.1415
    epoch=0288, loss=31.6439
    epoch=0289, loss=29.5798
    epoch=0290, loss=36.0997
    epoch=0291, loss=32.4649
    epoch=0292, loss=33.2411
    epoch=0293, loss=30.5523
    epoch=0294, loss=23.8734
    epoch=0295, loss=27.2925
    epoch=0296, loss=36.2676
    epoch=0297, loss=37.2896
    epoch=0298, loss=34.1515
    epoch=0299, loss=33.6274
    epoch=0300, loss=36.6911
    epoch=0301, loss=30.2100
    epoch=0302, loss=29.0880
    epoch=0303, loss=36.4233
    epoch=0304, loss=35.4147
    epoch=0305, loss=34.4470
    epoch=0306, loss=33.3925
    epoch=0307, loss=33.6304
    epoch=0308, loss=34.1828
    epoch=0309, loss=32.4087
    epoch=0310, loss=31.3047
    epoch=0311, loss=30.2233
    epoch=0312, loss=26.5028
    epoch=0313, loss=31.8921
    epoch=0314, loss=37.1131
    epoch=0315, loss=33.3578
    epoch=0316, loss=31.3482
    epoch=0317, loss=34.8639
    epoch=0318, loss=36.3711
    epoch=0319, loss=36.1821
    epoch=0320, loss=30.3500
    epoch=0321, loss=30.7098
    epoch=0322, loss=25.4003
    epoch=0323, loss=34.0975
    epoch=0324, loss=29.7466
    epoch=0325, loss=36.4988
    epoch=0326, loss=29.9902
    epoch=0327, loss=25.2689
    epoch=0328, loss=35.8094
    epoch=0329, loss=39.8636
    epoch=0330, loss=29.5660
    epoch=0331, loss=34.1648
    epoch=0332, loss=33.9564
    epoch=0333, loss=32.2371
    epoch=0334, loss=32.0714
    epoch=0335, loss=29.1359
    epoch=0336, loss=27.9533
    epoch=0337, loss=30.0461
    epoch=0338, loss=27.1369
    epoch=0339, loss=31.6953
    epoch=0340, loss=30.7300
    epoch=0341, loss=32.9143
    epoch=0342, loss=31.5164
    epoch=0343, loss=29.7537
    epoch=0344, loss=30.9467
    epoch=0345, loss=35.6777
    epoch=0346, loss=34.4547
    epoch=0347, loss=29.6856
    epoch=0348, loss=31.2083
    epoch=0349, loss=40.4076
    epoch=0350, loss=27.9872
    epoch=0351, loss=29.7655
    epoch=0352, loss=26.4337
    epoch=0353, loss=33.4726
    epoch=0354, loss=33.3741
    epoch=0355, loss=34.5716
    epoch=0356, loss=26.9817
    epoch=0357, loss=28.4823
    epoch=0358, loss=35.5149
    epoch=0359, loss=31.9134
    epoch=0360, loss=22.4320
    epoch=0361, loss=26.6444
    epoch=0362, loss=36.2608
    epoch=0363, loss=37.0368
    epoch=0364, loss=37.8973
    epoch=0365, loss=38.1220
    epoch=0366, loss=30.1391
    epoch=0367, loss=34.4883
    epoch=0368, loss=29.2819
    epoch=0369, loss=32.3203
    epoch=0370, loss=33.3812
    epoch=0371, loss=27.8247
    epoch=0372, loss=40.8224
    epoch=0373, loss=24.7654
    epoch=0374, loss=29.6273
    epoch=0375, loss=30.0148
    epoch=0376, loss=25.0442
    epoch=0377, loss=29.4838
    epoch=0378, loss=30.1864
    epoch=0379, loss=38.5230
    epoch=0380, loss=25.7349
    epoch=0381, loss=27.2724
    epoch=0382, loss=29.4689
    epoch=0383, loss=30.5501
    epoch=0384, loss=43.1093
    epoch=0385, loss=22.1201
    epoch=0386, loss=39.9249
    epoch=0387, loss=32.1775
    epoch=0388, loss=30.8506
    epoch=0389, loss=33.3784
    epoch=0390, loss=28.2816
    epoch=0391, loss=33.0714
    epoch=0392, loss=37.8322
    epoch=0393, loss=35.1359
    epoch=0394, loss=38.3259
    epoch=0395, loss=34.6878
    epoch=0396, loss=35.9231
    epoch=0397, loss=28.1549
    epoch=0398, loss=28.7206
    epoch=0399, loss=25.5446
    epoch=0400, loss=33.4617
    epoch=0401, loss=30.3626
    epoch=0402, loss=36.2228
    epoch=0403, loss=30.8657
    epoch=0404, loss=32.7538
    epoch=0405, loss=24.6781
    epoch=0406, loss=36.1495
    epoch=0407, loss=32.8352
    epoch=0408, loss=38.9192
    epoch=0409, loss=36.4745
    epoch=0410, loss=30.1079
    epoch=0411, loss=35.7376
    epoch=0412, loss=26.5665
    epoch=0413, loss=28.3328
    epoch=0414, loss=30.4497
    epoch=0415, loss=30.3573
    epoch=0416, loss=29.6778
    epoch=0417, loss=26.6646
    epoch=0418, loss=33.5846
    epoch=0419, loss=28.4296
    epoch=0420, loss=40.5777
    epoch=0421, loss=31.9579
    epoch=0422, loss=24.8175
    epoch=0423, loss=29.9179
    epoch=0424, loss=33.4627
    epoch=0425, loss=28.0348
    epoch=0426, loss=39.8553
    epoch=0427, loss=32.8669
    epoch=0428, loss=30.5330
    epoch=0429, loss=32.0896
    epoch=0430, loss=26.3181
    epoch=0431, loss=29.8527
    epoch=0432, loss=21.1602
    epoch=0433, loss=30.1509
    epoch=0434, loss=30.5755
    epoch=0435, loss=27.7651
    epoch=0436, loss=36.7598
    epoch=0437, loss=30.1598
    epoch=0438, loss=31.5458
    epoch=0439, loss=29.3167
    epoch=0440, loss=25.5753
    epoch=0441, loss=33.2169
    epoch=0442, loss=29.3014
    epoch=0443, loss=30.6140
    epoch=0444, loss=39.8261
    epoch=0445, loss=18.8122
    epoch=0446, loss=38.6028
    epoch=0447, loss=28.4286
    epoch=0448, loss=32.6361
    epoch=0449, loss=21.1802
    epoch=0450, loss=23.9867
    epoch=0451, loss=31.4449
    epoch=0452, loss=27.0395
    epoch=0453, loss=29.5076
    epoch=0454, loss=28.4418
    epoch=0455, loss=20.1761
    epoch=0456, loss=26.2913
    epoch=0457, loss=24.5854
    epoch=0458, loss=24.4506
    epoch=0459, loss=37.6108
    epoch=0460, loss=27.6040
    epoch=0461, loss=17.9364
    epoch=0462, loss=23.2051
    epoch=0463, loss=29.2800
    epoch=0464, loss=32.4361
    epoch=0465, loss=25.0572
    epoch=0466, loss=24.9538
    epoch=0467, loss=31.6631
    epoch=0468, loss=27.2633
    epoch=0469, loss=33.8013
    epoch=0470, loss=33.9849
    epoch=0471, loss=29.2885
    epoch=0472, loss=25.6215
    epoch=0473, loss=34.4632
    epoch=0474, loss=34.7711
    epoch=0475, loss=32.0922
    epoch=0476, loss=31.9866
    epoch=0477, loss=34.8060
    epoch=0478, loss=34.8214
    epoch=0479, loss=38.7582
    epoch=0480, loss=27.0959
    epoch=0481, loss=28.6286
    epoch=0482, loss=26.5750
    epoch=0483, loss=27.2744
    epoch=0484, loss=31.6943
    epoch=0485, loss=27.1727
    epoch=0486, loss=35.6702
    epoch=0487, loss=33.3673
    epoch=0488, loss=26.5954
    epoch=0489, loss=22.7762
    epoch=0490, loss=31.5625
    epoch=0491, loss=34.8244
    epoch=0492, loss=30.4988
    epoch=0493, loss=31.0691
    epoch=0494, loss=24.4411
    epoch=0495, loss=31.1617
    epoch=0496, loss=26.1924
    epoch=0497, loss=26.2127
    epoch=0498, loss=33.2170
    epoch=0499, loss=32.1868
    epoch=0500, loss=32.5412
    epoch=0501, loss=32.1374
    epoch=0502, loss=28.3969
    epoch=0503, loss=28.5351
    epoch=0504, loss=32.0663
    epoch=0505, loss=27.6775
    epoch=0506, loss=34.6730
    epoch=0507, loss=27.8387
    epoch=0508, loss=34.9139
    epoch=0509, loss=30.9424
    epoch=0510, loss=21.6705
    epoch=0511, loss=36.3246
    epoch=0512, loss=27.2223
    epoch=0513, loss=26.8271
    epoch=0514, loss=26.4202
    epoch=0515, loss=33.3292
    epoch=0516, loss=27.2856
    epoch=0517, loss=32.8151
    epoch=0518, loss=35.4326
    epoch=0519, loss=32.4986
    epoch=0520, loss=27.4436
    epoch=0521, loss=27.2153
    epoch=0522, loss=28.2741
    epoch=0523, loss=31.3475
    epoch=0524, loss=25.6270
    epoch=0525, loss=32.7321
    epoch=0526, loss=24.4846
    epoch=0527, loss=27.5130
    epoch=0528, loss=34.2964
    epoch=0529, loss=30.5890
    epoch=0530, loss=22.2985
    epoch=0531, loss=31.6849
    epoch=0532, loss=32.2470
    epoch=0533, loss=36.0476
    epoch=0534, loss=29.5351
    epoch=0535, loss=32.0569
    epoch=0536, loss=25.4508
    epoch=0537, loss=23.9327
    epoch=0538, loss=30.7464
    epoch=0539, loss=26.1415
    epoch=0540, loss=26.3222
    epoch=0541, loss=26.5467
    epoch=0542, loss=33.1505
    epoch=0543, loss=25.3844
    epoch=0544, loss=28.9720
    epoch=0545, loss=25.4541
    epoch=0546, loss=28.6486
    epoch=0547, loss=31.9400
    epoch=0548, loss=28.5519
    epoch=0549, loss=26.4234
    epoch=0550, loss=32.8854
    epoch=0551, loss=41.4835
    epoch=0552, loss=32.2330
    epoch=0553, loss=27.0701
    epoch=0554, loss=27.3950
    epoch=0555, loss=26.9011
    epoch=0556, loss=27.6771
    epoch=0557, loss=28.6061
    epoch=0558, loss=27.9106
    epoch=0559, loss=31.6639
    epoch=0560, loss=23.9012
    epoch=0561, loss=27.4612
    epoch=0562, loss=32.3914
    epoch=0563, loss=27.6704
    epoch=0564, loss=30.3432
    epoch=0565, loss=35.6091
    epoch=0566, loss=24.0133
    epoch=0567, loss=33.4245
    epoch=0568, loss=26.0107
    epoch=0569, loss=36.5293
    epoch=0570, loss=30.9391
    epoch=0571, loss=26.8901
    epoch=0572, loss=31.3761
    epoch=0573, loss=22.5602
    epoch=0574, loss=29.3992
    epoch=0575, loss=44.5518
    epoch=0576, loss=30.0017
    epoch=0577, loss=28.7804
    epoch=0578, loss=33.2926
    epoch=0579, loss=33.4243
    epoch=0580, loss=33.7876
    epoch=0581, loss=26.9251
    epoch=0582, loss=32.3470
    epoch=0583, loss=27.8890
    epoch=0584, loss=35.7873
    epoch=0585, loss=30.1516
    epoch=0586, loss=34.0975
    epoch=0587, loss=33.8867
    epoch=0588, loss=28.5880
    epoch=0589, loss=32.1572
    epoch=0590, loss=27.0512
    epoch=0591, loss=27.2561
    epoch=0592, loss=28.8458
    epoch=0593, loss=34.0907
    epoch=0594, loss=31.9035
    epoch=0595, loss=36.1714
    epoch=0596, loss=25.7288
    epoch=0597, loss=33.0792
    epoch=0598, loss=37.0893
    epoch=0599, loss=27.7460
    epoch=0600, loss=33.2528
    epoch=0601, loss=33.5350
    epoch=0602, loss=32.1877
    epoch=0603, loss=37.4387
    epoch=0604, loss=32.8316
    epoch=0605, loss=35.4501
    epoch=0606, loss=29.1970
    epoch=0607, loss=29.0604
    epoch=0608, loss=37.4593
    epoch=0609, loss=29.2131
    epoch=0610, loss=29.2925
    epoch=0611, loss=29.2586
    epoch=0612, loss=24.8837
    epoch=0613, loss=33.3018
    epoch=0614, loss=22.7285
    epoch=0615, loss=28.3767
    epoch=0616, loss=25.0335
    epoch=0617, loss=30.4135
    epoch=0618, loss=27.1625
    epoch=0619, loss=38.0184
    epoch=0620, loss=30.7108
    epoch=0621, loss=32.1971
    epoch=0622, loss=29.4523
    epoch=0623, loss=27.5280
    epoch=0624, loss=30.6639
    epoch=0625, loss=34.0091
    epoch=0626, loss=33.3729
    epoch=0627, loss=22.9472
    epoch=0628, loss=26.4402
    epoch=0629, loss=35.6636
    epoch=0630, loss=26.3517
    epoch=0631, loss=30.3457
    epoch=0632, loss=28.7809
    epoch=0633, loss=37.0943
    epoch=0634, loss=32.6953
    epoch=0635, loss=24.2005
    epoch=0636, loss=30.0030
    epoch=0637, loss=22.4156
    epoch=0638, loss=27.1513
    epoch=0639, loss=30.1138
    epoch=0640, loss=31.4409
    epoch=0641, loss=33.5067
    epoch=0642, loss=32.0439
    epoch=0643, loss=32.3260
    epoch=0644, loss=28.7807
    epoch=0645, loss=27.8976
    epoch=0646, loss=29.1612
    epoch=0647, loss=31.5936
    epoch=0648, loss=25.6046
    epoch=0649, loss=31.3582
    epoch=0650, loss=29.6808
    epoch=0651, loss=30.2134
    epoch=0652, loss=31.0002
    epoch=0653, loss=25.0183
    epoch=0654, loss=27.5577
    epoch=0655, loss=33.7742
    epoch=0656, loss=35.4759
    epoch=0657, loss=34.6184
    epoch=0658, loss=33.0030
    epoch=0659, loss=30.3795
    epoch=0660, loss=25.1782
    epoch=0661, loss=29.3501
    epoch=0662, loss=32.5010
    epoch=0663, loss=27.8927
    epoch=0664, loss=36.6008
    epoch=0665, loss=26.7161
    epoch=0666, loss=23.4961
    epoch=0667, loss=38.7474
    epoch=0668, loss=25.8379
    epoch=0669, loss=29.2888
    epoch=0670, loss=29.2714
    epoch=0671, loss=26.5862
    epoch=0672, loss=30.6372
    epoch=0673, loss=22.8480
    epoch=0674, loss=25.0936
    epoch=0675, loss=27.0401
    epoch=0676, loss=33.4890
    epoch=0677, loss=29.9665
    epoch=0678, loss=24.2670
    epoch=0679, loss=37.0768
    epoch=0680, loss=35.3357
    epoch=0681, loss=26.8299
    epoch=0682, loss=24.6501
    epoch=0683, loss=21.9212
    epoch=0684, loss=28.7373
    epoch=0685, loss=23.0443
    epoch=0686, loss=30.7751
    epoch=0687, loss=34.6857
    epoch=0688, loss=29.4374
    epoch=0689, loss=26.9692
    epoch=0690, loss=30.4589
    epoch=0691, loss=30.6934
    epoch=0692, loss=27.2462
    epoch=0693, loss=25.4722
    epoch=0694, loss=33.4113
    epoch=0695, loss=28.2299
    epoch=0696, loss=27.9278
    epoch=0697, loss=28.2355
    epoch=0698, loss=24.5265
    epoch=0699, loss=38.6529
    epoch=0700, loss=22.5823
    epoch=0701, loss=28.8159
    epoch=0702, loss=30.6030
    epoch=0703, loss=28.8008
    epoch=0704, loss=29.7304
    epoch=0705, loss=35.3726
    epoch=0706, loss=28.0533
    epoch=0707, loss=32.3399
    epoch=0708, loss=33.1600
    epoch=0709, loss=30.6655
    epoch=0710, loss=24.1939
    epoch=0711, loss=41.2005
    epoch=0712, loss=21.0756
    epoch=0713, loss=25.7090
    epoch=0714, loss=27.7683
    epoch=0715, loss=22.7833
    epoch=0716, loss=32.5791
    epoch=0717, loss=32.7902
    epoch=0718, loss=30.8091
    epoch=0719, loss=27.2202
    epoch=0720, loss=29.6537
    epoch=0721, loss=30.1864
    epoch=0722, loss=28.9362
    epoch=0723, loss=27.7320
    epoch=0724, loss=26.9590
    epoch=0725, loss=29.2898
    epoch=0726, loss=25.8810
    epoch=0727, loss=36.0706
    epoch=0728, loss=22.0726
    epoch=0729, loss=25.4777
    epoch=0730, loss=32.1601
    epoch=0731, loss=25.8966
    epoch=0732, loss=25.2156
    epoch=0733, loss=28.2761
    epoch=0734, loss=24.4889
    epoch=0735, loss=33.2960
    epoch=0736, loss=28.2445
    epoch=0737, loss=28.1270
    epoch=0738, loss=23.2748
    epoch=0739, loss=26.3298
    epoch=0740, loss=25.7911
    epoch=0741, loss=34.5597
    epoch=0742, loss=19.8165
    epoch=0743, loss=23.8939
    epoch=0744, loss=25.0206
    epoch=0745, loss=23.7858
    epoch=0746, loss=27.3312
    epoch=0747, loss=30.3652
    epoch=0748, loss=29.9311
    epoch=0749, loss=30.0926
    epoch=0750, loss=27.3512
    epoch=0751, loss=23.1224
    epoch=0752, loss=28.3635
    epoch=0753, loss=32.5327
    epoch=0754, loss=27.6659
    epoch=0755, loss=22.9115
    epoch=0756, loss=29.4907
    epoch=0757, loss=27.0478
    epoch=0758, loss=33.5934
    epoch=0759, loss=34.2653
    epoch=0760, loss=27.2891
    epoch=0761, loss=24.4841
    epoch=0762, loss=28.4861
    epoch=0763, loss=21.7775
    epoch=0764, loss=29.7306
    epoch=0765, loss=26.3709
    epoch=0766, loss=39.8319
    epoch=0767, loss=37.2172
    epoch=0768, loss=36.7648
    epoch=0769, loss=25.7041
    epoch=0770, loss=28.0844
    epoch=0771, loss=27.7131
    epoch=0772, loss=27.0056
    epoch=0773, loss=30.9536
    epoch=0774, loss=32.0072
    epoch=0775, loss=19.4594
    epoch=0776, loss=36.1352
    epoch=0777, loss=25.2046
    epoch=0778, loss=28.7529
    epoch=0779, loss=22.1777
    epoch=0780, loss=35.4953
    epoch=0781, loss=26.4261
    epoch=0782, loss=27.4257
    epoch=0783, loss=29.5290
    epoch=0784, loss=31.7643
    epoch=0785, loss=25.4182
    epoch=0786, loss=24.5568
    epoch=0787, loss=27.9693
    epoch=0788, loss=25.7238
    epoch=0789, loss=23.0230
    epoch=0790, loss=27.3617
    epoch=0791, loss=28.5302
    epoch=0792, loss=34.3995
    epoch=0793, loss=39.3020
    epoch=0794, loss=39.4359
    epoch=0795, loss=32.3973
    epoch=0796, loss=23.5648
    epoch=0797, loss=29.0293
    epoch=0798, loss=29.9338
    epoch=0799, loss=27.4944
    epoch=0800, loss=29.2912
    epoch=0801, loss=26.5246
    epoch=0802, loss=33.3934
    epoch=0803, loss=22.2256
    epoch=0804, loss=25.9302
    epoch=0805, loss=34.1813
    epoch=0806, loss=23.4320
    epoch=0807, loss=29.1830
    epoch=0808, loss=35.9738
    epoch=0809, loss=26.6954
    epoch=0810, loss=27.0675
    epoch=0811, loss=29.5133
    epoch=0812, loss=36.9584
    epoch=0813, loss=35.6409
    epoch=0814, loss=28.8566
    epoch=0815, loss=26.5897
    epoch=0816, loss=29.5652
    epoch=0817, loss=27.5943
    epoch=0818, loss=23.2622
    epoch=0819, loss=34.6450
    epoch=0820, loss=34.1187
    epoch=0821, loss=20.6681
    epoch=0822, loss=36.6363
    epoch=0823, loss=30.0078
    epoch=0824, loss=30.6091
    epoch=0825, loss=27.6910
    epoch=0826, loss=30.6062
    epoch=0827, loss=35.9787
    epoch=0828, loss=29.0582
    epoch=0829, loss=31.9850
    epoch=0830, loss=30.3401
    epoch=0831, loss=24.6613
    epoch=0832, loss=20.2683
    epoch=0833, loss=31.3213
    epoch=0834, loss=28.3993
    epoch=0835, loss=32.9188
    epoch=0836, loss=29.6197
    epoch=0837, loss=29.8607
    epoch=0838, loss=36.3002
    epoch=0839, loss=28.7022
    epoch=0840, loss=31.4818
    epoch=0841, loss=32.7796
    epoch=0842, loss=26.1925
    epoch=0843, loss=28.6750
    epoch=0844, loss=25.8718
    epoch=0845, loss=27.1945
    epoch=0846, loss=32.2494
    epoch=0847, loss=30.9900
    epoch=0848, loss=23.8044
    epoch=0849, loss=27.8261
    epoch=0850, loss=34.8777
    epoch=0851, loss=27.9111
    epoch=0852, loss=24.5413
    epoch=0853, loss=31.9728
    epoch=0854, loss=32.5463
    epoch=0855, loss=23.8105
    epoch=0856,
loss=23.9962 epoch=0857, loss=28.0755 epoch=0858, loss=30.4817 epoch=0859, loss=31.6064 epoch=0860, loss=25.3999 epoch=0861, loss=32.3398 epoch=0862, loss=32.8042 epoch=0863, loss=33.4416 epoch=0864, loss=25.8737 epoch=0865, loss=27.0345 epoch=0866, loss=28.4645 epoch=0867, loss=26.0198 epoch=0868, loss=22.0006 epoch=0869, loss=21.8023 epoch=0870, loss=26.9796 epoch=0871, loss=24.7600 epoch=0872, loss=21.8620 epoch=0873, loss=34.5223 epoch=0874, loss=32.3790 epoch=0875, loss=23.9359 epoch=0876, loss=27.6214 epoch=0877, loss=25.6304 epoch=0878, loss=27.9532 epoch=0879, loss=22.9013 epoch=0880, loss=29.7355 epoch=0881, loss=23.0440 epoch=0882, loss=31.5084 epoch=0883, loss=35.5792 epoch=0884, loss=31.2261 epoch=0885, loss=35.8895 epoch=0886, loss=29.0580 epoch=0887, loss=35.2827 epoch=0888, loss=27.2780 epoch=0889, loss=36.7490 epoch=0890, loss=26.4794 epoch=0891, loss=22.9410 epoch=0892, loss=35.5281 epoch=0893, loss=32.9946 epoch=0894, loss=25.7193 epoch=0895, loss=20.1898 epoch=0896, loss=28.7137 epoch=0897, loss=29.8084 epoch=0898, loss=22.0095 epoch=0899, loss=27.7857 epoch=0900, loss=29.1246 epoch=0901, loss=29.2783 epoch=0902, loss=30.5743 epoch=0903, loss=24.0386 epoch=0904, loss=33.8666 epoch=0905, loss=27.0704 epoch=0906, loss=28.9136 epoch=0907, loss=31.2436 epoch=0908, loss=28.4699 epoch=0909, loss=30.7925 epoch=0910, loss=29.6884 epoch=0911, loss=26.3277 epoch=0912, loss=37.9164 epoch=0913, loss=31.8990 epoch=0914, loss=28.7760 epoch=0915, loss=30.7502 epoch=0916, loss=26.4530 epoch=0917, loss=39.3759 epoch=0918, loss=26.1209 epoch=0919, loss=29.4831 epoch=0920, loss=29.1794 epoch=0921, loss=27.2733 epoch=0922, loss=37.4330 epoch=0923, loss=38.1942 epoch=0924, loss=28.4737 epoch=0925, loss=25.0199 epoch=0926, loss=25.8582 epoch=0927, loss=28.8155 epoch=0928, loss=27.3282 epoch=0929, loss=30.1471 epoch=0930, loss=28.7374 epoch=0931, loss=30.4134 epoch=0932, loss=30.1350 epoch=0933, loss=26.7751 epoch=0934, loss=31.4003 epoch=0935, loss=29.2187 epoch=0936, 
loss=32.5177 epoch=0937, loss=26.2173 epoch=0938, loss=35.1512 epoch=0939, loss=38.3276 epoch=0940, loss=37.5435 epoch=0941, loss=27.8002 epoch=0942, loss=25.4741 epoch=0943, loss=30.0411 epoch=0944, loss=30.6629 epoch=0945, loss=33.6017 epoch=0946, loss=27.0888 epoch=0947, loss=24.7382 epoch=0948, loss=24.0262 epoch=0949, loss=26.1200 epoch=0950, loss=29.0273 epoch=0951, loss=24.8881 epoch=0952, loss=31.3502 epoch=0953, loss=35.3200 epoch=0954, loss=29.1447 epoch=0955, loss=32.3892 epoch=0956, loss=28.6891 epoch=0957, loss=30.1271 epoch=0958, loss=31.7796 epoch=0959, loss=28.3434 epoch=0960, loss=29.5188 epoch=0961, loss=29.6104 epoch=0962, loss=29.3695 epoch=0963, loss=33.3180 epoch=0964, loss=30.5522 epoch=0965, loss=30.8621 epoch=0966, loss=21.5595 epoch=0967, loss=30.0009 epoch=0968, loss=21.4720 epoch=0969, loss=25.7199 epoch=0970, loss=22.2154 epoch=0971, loss=35.4136 epoch=0972, loss=36.0910 epoch=0973, loss=27.7546 epoch=0974, loss=27.8656 epoch=0975, loss=29.3763 epoch=0976, loss=27.2434 epoch=0977, loss=25.0235 epoch=0978, loss=30.6160 epoch=0979, loss=30.6950 epoch=0980, loss=28.1493 epoch=0981, loss=26.5703 epoch=0982, loss=28.4495 epoch=0983, loss=29.6340 epoch=0984, loss=28.7667 epoch=0985, loss=26.2836 epoch=0986, loss=32.8532 epoch=0987, loss=26.9663 epoch=0988, loss=31.8039 epoch=0989, loss=26.7345 epoch=0990, loss=25.3842 epoch=0991, loss=35.4913 epoch=0992, loss=26.6371 epoch=0993, loss=27.0929 epoch=0994, loss=26.5950 epoch=0995, loss=28.8731 epoch=0996, loss=22.7450 epoch=0997, loss=26.1217 epoch=0998, loss=29.9866 epoch=0999, loss=24.5195

.. GENERATED FROM PYTHON SOURCE LINES 403-405

Create a test-time model with a larger `jitter` to make it more robust to
unseen test inputs.

.. GENERATED FROM PYTHON SOURCE LINES 405-408

.. code-block:: default

    test_model = build_model(input_dim=num_features, jitter=2e-5)
    test_model.set_weights(model.get_weights())

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gaussian_process.py:311: UserWarning: Unable to detect statically whether the number of index_points is 1. As a result, defaulting to treating the marginal GP at `index_points` as a multivariate Gaussian. This makes some methods, like `cdf` unavailable.
      'Unable to detect statically whether the number of index_points is '

.. GENERATED FROM PYTHON SOURCE LINES 409-418

.. code-block:: default

    inducing_index_points_history = history.pop("inducing_index_points")
    variational_inducing_observations_loc_history = (
        history.pop("variational_inducing_observations_loc"))

    inducing_index_points = inducing_index_points_history[-1]
    variational_inducing_observations_loc = (
        variational_inducing_observations_loc_history[-1])

.. GENERATED FROM PYTHON SOURCE LINES 419-420

Log density ratio, log-odds, or logits.

.. GENERATED FROM PYTHON SOURCE LINES 420-445

.. code-block:: default

    fig, ax = plt.subplots()

    ax.plot(X_q, logit(X_q), c='k',
            label=r"$f(x) = \log p(x) - \log q(x)$")

    ax.plot(X_q, test_model(X_q).mean().numpy().T,
            label="posterior mean")
    fill_between_stddev(X_q.squeeze(),
                        test_model(X_q).mean().numpy().squeeze(),
                        test_model(X_q).stddev().numpy().squeeze(), alpha=0.1,
                        label="posterior std dev", ax=ax)

    ax.scatter(inducing_index_points, np.full_like(inducing_index_points, y_min),
               marker='^', c="tab:gray", label="inducing inputs", alpha=0.4)
    ax.scatter(inducing_index_points, variational_inducing_observations_loc,
               marker='+', c="tab:blue", label="inducing variable mean")

    ax.set_xlabel('$x$')
    ax.set_ylabel('$f(x)$')

    ax.legend()

    plt.show()

.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_005.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 446-447

Density ratio.

.. GENERATED FROM PYTHON SOURCE LINES 447-468

.. code-block:: default

    d = tfd.Independent(tfd.LogNormal(loc=test_model(X_q).mean(),
                                      scale=test_model(X_q).stddev()),
                        reinterpreted_batch_ndims=1)

    fig, ax = plt.subplots()

    ax.plot(X_q, density_ratio(X_q), c='k',
            label=r"$r(x) = \exp f(x)$")

    ax.plot(X_q, d.mean().numpy().T, label="transformed posterior mean")
    fill_between_stddev(X_q.squeeze(),
                        d.mean().numpy().squeeze(),
                        d.stddev().numpy().squeeze(), alpha=0.1,
                        label="transformed posterior std dev", ax=ax)

    ax.set_xlabel('$x$')
    ax.set_ylabel('$r(x)$')

    ax.legend()

    plt.show()

.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_006.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 469-470

Predictive mean samples.

.. GENERATED FROM PYTHON SOURCE LINES 470-477

.. code-block:: default

    posterior_predictive = tf.keras.Sequential([
        test_model,
        tfp.layers.IndependentBernoulli(event_shape=(num_index_points,))
    ])

.. GENERATED FROM PYTHON SOURCE LINES 478-493

.. code-block:: default

    fig, ax = plt.subplots()

    ax.plot(X_q, posterior_predictive(X_q).mean())
    ax.plot(X_q, optimal_score(X_q), c='k',
            label=r"$\pi(x) = \sigma(f(x))$")

    ax.scatter(X_train, y_train, c=y_train, s=12.**2,
               marker='s', alpha=0.1, cmap="coolwarm_r")
    ax.set_yticks([0, 1])
    ax.set_yticklabels([r"$x_q \sim q(x)$", r"$x_p \sim p(x)$"])
    ax.set_xlabel('$x$')

    ax.legend()

    plt.show()

.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_007.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 494-509

.. code-block:: default

    def make_posterior_predictive(num_samples=None):

        def posterior_predictive(x):

            f_samples = test_model(x).sample(num_samples)

            return make_binary_classification_likelihood(f=f_samples)

        return posterior_predictive


    posterior_predictive = make_posterior_predictive(num_samples)

.. GENERATED FROM PYTHON SOURCE LINES 510-526

.. code-block:: default

    fig, ax = plt.subplots()

    ax.plot(X_q, posterior_predictive(X_q).mean().numpy().T,
            color="tab:blue", linewidth=0.4, alpha=0.6)
    ax.plot(X_q, optimal_score(X_q), c='k',
            label=r"$\pi(x) = \sigma(f(x))$")

    ax.scatter(X_train, y_train, c=y_train, s=12.**2,
               marker='s', alpha=0.1, cmap="coolwarm_r")
    ax.set_yticks([0, 1])
    ax.set_yticklabels([r"$x_q \sim q(x)$", r"$x_p \sim p(x)$"])
    ax.set_xlabel('$x$')

    ax.legend()

    plt.show()

.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_008.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 527-531

.. code-block:: default

    y_scores = posterior_predictive(X_test).mean().numpy()
    y_score_min, y_score_max = np.percentile(y_scores, q=[5, 95], axis=0)

.. GENERATED FROM PYTHON SOURCE LINES 532-553

.. code-block:: default

    fig, ax = plt.subplots()

    ax.plot(X_q, optimal_score(X_q), c='k',
            label=r"$\pi(x) = \sigma(f(x))$")

    ax.scatter(X_test, np.median(y_scores, axis=0), c="tab:blue",
               label="median", alpha=0.8)
    # raw string: "\%" is not a valid Python escape sequence
    ax.vlines(X_test, ymin=y_score_min, ymax=y_score_max, color="tab:blue",
              label=r"90\% confidence interval", alpha=0.8)

    ax.scatter(X_test, y_test, c=y_test, s=12.**2,
               marker='s', alpha=0.1, cmap="coolwarm_r")

    ax.set_ylabel(r"$\pi(x)$")
    ax.set_xlabel(r"$x$")
    ax.set_xlim(x_min, x_max)

    ax.legend()

    plt.show()

.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_009.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 554-567

.. code-block:: default

    def get_inducing_index_points_data(inducing_index_points):

        df = pd.DataFrame(np.hstack(inducing_index_points).T)
        df.index.name = "epoch"
        df.columns.name = "inducing index points"

        s = df.stack()
        s.name = 'x'

        return s.reset_index()

.. GENERATED FROM PYTHON SOURCE LINES 568-572

.. code-block:: default

    data = get_inducing_index_points_data(inducing_index_points_history)

.. GENERATED FROM PYTHON SOURCE LINES 573-584

.. code-block:: default

    fig, ax = plt.subplots()

    sns.lineplot(x='x', y="epoch", hue="inducing index points", palette="viridis",
                 sort=False, data=data, alpha=0.8, ax=ax)

    ax.set_xlabel(r'$x$')

    plt.show()

.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_010.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 585-589

.. code-block:: default

    variational_inducing_observations_scale_history = (
        history.pop("variational_inducing_observations_scale"))

.. GENERATED FROM PYTHON SOURCE LINES 590-609

.. code-block:: default

    fig, (ax1, ax2) = plt.subplots(ncols=2, sharex=True, sharey=True)

    im2 = ax2.imshow(variational_inducing_observations_scale_history[-1])
    vmin, vmax = im2.get_clim()

    im1 = ax1.imshow(variational_inducing_observations_scale_history[0],
                     vmin=vmin, vmax=vmax)

    fig.colorbar(im2, ax=[ax1, ax2], orientation="horizontal")

    ax1.set_xlabel(r"$i$")
    ax1.set_ylabel(r"$j$")

    ax2.set_xlabel(r"$i$")

    plt.show()

.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_011.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 610-615

.. code-block:: default

    history_df = pd.DataFrame(history)
    history_df.index.name = "epoch"
    history_df.reset_index(inplace=True)

.. GENERATED FROM PYTHON SOURCE LINES 616-624

.. code-block:: default

    fig, ax = plt.subplots()

    sns.lineplot(x="epoch", y="nelbo", data=history_df, alpha=0.8, ax=ax)
    ax.set_yscale("log")

    plt.show()

.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_012.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 625-629

.. code-block:: default

    parameters_df = history_df.drop(columns="nelbo") \
                              .rename(columns=lambda s: s.replace('_', ' '))

.. GENERATED FROM PYTHON SOURCE LINES 630-633

.. code-block:: default

    g = sns.PairGrid(parameters_df, hue="epoch", palette="RdYlBu", corner=True)
    g = g.map_lower(plt.scatter, facecolor="none", alpha=0.6)

.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_gp_classification_keras_013.png
    :alt: plot sparse gp classification keras
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 3 minutes 13.267 seconds)

.. _sphx_glr_download_auto_examples_gaussian_processes_plot_sparse_gp_classification_keras.py:

.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example

  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_sparse_gp_classification_keras.py `

  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_sparse_gp_classification_keras.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_