.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/gaussian_processes/plot_sparse_log_cox_gaussian_process_keras.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_gaussian_processes_plot_sparse_log_cox_gaussian_process_keras.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_gaussian_processes_plot_sparse_log_cox_gaussian_process_keras.py:


Variational Sparse Log Cox Gaussian Process
===========================================

Here we fit a sparse variational Gaussian process with a Poisson likelihood
(i.e. a log Gaussian Cox process) to the coal mining disasters dataset.
The kernel hyperparameters, the inducing point locations and the variational
parameters are learned jointly by maximizing the evidence lower bound (ELBO),
a lower bound on the (log) marginal likelihood, using mini-batch stochastic
optimization. This is a variational counterpart of empirical Bayes, or
type-II maximum likelihood estimation.

.. GENERATED FROM PYTHON SOURCE LINES 10-32

.. code-block:: default


    import numpy as np
    import tensorflow as tf
    import tensorflow_probability as tfp

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd

    from tensorflow.keras.layers import Layer, InputLayer
    from tensorflow.keras.initializers import Identity, Constant

    from sklearn.preprocessing import MinMaxScaler

    from etudes.datasets import coal_mining_disasters_load_data
    from etudes.plotting import fill_between_stddev
    from etudes.utils import get_kl_weight

    from collections import defaultdict


.. GENERATED FROM PYTHON SOURCE LINES 34-66

..
code-block:: default


    # shortcuts
    tfd = tfp.distributions
    kernels = tfp.math.psd_kernels

    # constants
    num_test = 40
    num_features = 1  # dimensionality
    num_index_points = 256  # nbr of index points
    num_samples = 25

    quadrature_size = 20
    num_inducing_points = 50
    num_epochs = 2000
    batch_size = 64
    shuffle_buffer_size = 500

    jitter = 1e-6

    kernel_cls = kernels.MaternFiveHalves

    seed = 8888  # set random seed for reproducibility
    random_state = np.random.RandomState(seed)

    x_min, x_max = 0.0, 1.0
    y_min, y_max = -0.05, 0.7

    # index points
    X_q = np.linspace(x_min, x_max, num_index_points).reshape(-1, num_features)


.. GENERATED FROM PYTHON SOURCE LINES 67-69

Coal mining disasters dataset
-----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 69-75

.. code-block:: default


    scaler = MinMaxScaler()
    Z, y = coal_mining_disasters_load_data(base_dir="../../datasets/")
    X = scaler.fit_transform(Z)
    y = y.astype(np.float64)

    # Number of training observations, taken from the data actually used.
    # It is passed to get_kl_weight to scale the KL term of the mini-batch
    # objective, so it must match the dataset size (a hard-coded value left
    # over from a synthetic-data example would mis-weight the KL term).
    num_train = len(X)


.. GENERATED FROM PYTHON SOURCE LINES 76-77

Observed incidents over time (line width proportional to ``y``).

.. GENERATED FROM PYTHON SOURCE LINES 77-89

.. code-block:: default


    fig, ax = plt.subplots()

    ax.vlines(Z.squeeze(), ymin=-0.025, ymax=0.0, linewidth=0.6 * y)

    ax.set_ylim(-0.05, 0.8)

    ax.set_xlabel("days")
    ax.set_ylabel("incidents")

    plt.show()




.. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_log_cox_gaussian_process_keras_001.png
    :alt: plot sparse log cox gaussian process keras
    :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 90-95

Encapsulate Variational Gaussian Process (particular variable initialization)
in a Keras / TensorFlow Probability Mixin Layer. Clean and simple if we
restrict to single-output (``event_shape = ()``) and ``feature_ndim = 1``
(i.e. inputs are simply vectors rather than matrices or tensors).

.. GENERATED FROM PYTHON SOURCE LINES 95-176

..
code-block:: default


    class VariationalGaussianProcess1D(tfp.layers.DistributionLambda):
        """Keras layer that outputs a single-output variational Gaussian process.

        The layer owns the inducing index points, the variational inducing
        observation mean and scale, and the log observation noise variance as
        Keras weights. The kernel hyperparameters are held by
        ``kernel_wrapper``. Calling the layer on a batch of index points ``x``
        returns a ``tfd.VariationalGaussianProcess`` evaluated at ``x``.

        Args:
            kernel_wrapper: Layer exposing a ``kernel`` property and a
                ``dtype``; its dtype is used for this layer and its weights.
            num_inducing_points: Number of inducing index points.
            inducing_index_points_initializer: Keras initializer for the
                ``(num_inducing_points, input_dim)`` inducing index points.
            mean_fn: Prior mean function, passed through to
                ``tfd.VariationalGaussianProcess`` as-is.
            jitter: Jitter passed to ``tfd.VariationalGaussianProcess``.
            convert_to_tensor_fn: How the output distribution is converted to
                a tensor (``DistributionLambda`` semantics).
            **kwargs: Forwarded to ``tfp.layers.DistributionLambda``
                (e.g. ``name``). ``dtype`` is ignored because it is always
                taken from ``kernel_wrapper``.
        """

        def __init__(self, kernel_wrapper, num_inducing_points,
                     inducing_index_points_initializer, mean_fn=None,
                     jitter=1e-6, convert_to_tensor_fn=tfd.Distribution.sample,
                     **kwargs):

            def make_distribution(x):

                return VariationalGaussianProcess1D.new(
                    x, kernel_wrapper=self.kernel_wrapper,
                    inducing_index_points=self.inducing_index_points,
                    variational_inducing_observations_loc=(
                        self.variational_inducing_observations_loc),
                    variational_inducing_observations_scale=(
                        self.variational_inducing_observations_scale),
                    mean_fn=self.mean_fn,
                    observation_noise_variance=tf.exp(
                        self.log_observation_noise_variance),
                    jitter=self.jitter)

            # dtype is dictated by the kernel wrapper; drop any conflicting
            # value instead of raising a duplicate-keyword TypeError.
            kwargs.pop("dtype", None)

            # Forward the remaining kwargs (e.g. ``name``); previously they
            # were accepted but silently discarded.
            super(VariationalGaussianProcess1D, self).__init__(
                make_distribution_fn=make_distribution,
                convert_to_tensor_fn=convert_to_tensor_fn,
                dtype=kernel_wrapper.dtype, **kwargs)

            self.kernel_wrapper = kernel_wrapper
            self.inducing_index_points_initializer = inducing_index_points_initializer
            self.num_inducing_points = num_inducing_points
            self.mean_fn = mean_fn
            self.jitter = jitter

            self._dtype = self.kernel_wrapper.dtype

        def build(self, input_shape):
            """Create the variational and observation-noise weights.

            ``input_shape[-1]`` sets the dimensionality of the inducing index
            points.
            """
            input_dim = input_shape[-1]

            # TODO: Fix initialization!
            self.inducing_index_points = self.add_weight(
                name="inducing_index_points",
                shape=(self.num_inducing_points, input_dim),
                initializer=self.inducing_index_points_initializer,
                dtype=self.dtype)

            self.variational_inducing_observations_loc = self.add_weight(
                name="variational_inducing_observations_loc",
                shape=(self.num_inducing_points,),
                initializer="zeros", dtype=self.dtype)

            # Full (num_inducing_points x num_inducing_points) scale matrix,
            # initialized to the identity.
            self.variational_inducing_observations_scale = self.add_weight(
                name="variational_inducing_observations_scale",
                shape=(self.num_inducing_points, self.num_inducing_points),
                initializer=Identity(gain=1.0), dtype=self.dtype)

            # Kept in log space and exponentiated in ``make_distribution`` so
            # the noise variance stays positive; starts at exp(-5).
            self.log_observation_noise_variance = self.add_weight(
                name="log_observation_noise_variance",
                initializer=Constant(-5.0), dtype=self.dtype)

        @staticmethod
        def new(x, kernel_wrapper, inducing_index_points, mean_fn,
                variational_inducing_observations_loc,
                variational_inducing_observations_scale,
                observation_noise_variance, jitter, name=None):
            """Build a ``tfd.VariationalGaussianProcess`` at index points ``x``.

            ``name`` is accepted for signature compatibility but is unused.
            """
            return tfd.VariationalGaussianProcess(
                kernel=kernel_wrapper.kernel,
                index_points=x,
                inducing_index_points=inducing_index_points,
                variational_inducing_observations_loc=(
                    variational_inducing_observations_loc),
                variational_inducing_observations_scale=(
                    variational_inducing_observations_scale),
                mean_fn=mean_fn,
                observation_noise_variance=observation_noise_variance,
                jitter=jitter)




.. GENERATED FROM PYTHON SOURCE LINES 177-178

Kernel wrapper layer

.. GENERATED FROM PYTHON SOURCE LINES 178-208

..
code-block:: default


    class KernelWrapper(Layer):
        """Keras layer holding PSD-kernel hyperparameters as trainable weights.

        Amplitude and length scale are stored in log space and exponentiated
        whenever ``kernel`` is accessed, so both stay positive while the
        underlying weights are optimized without constraints.
        """

        # TODO: Support automatic relevance determination
        def __init__(self, kernel_cls=kernels.ExponentiatedQuadratic,
                     dtype=None, **kwargs):

            super(KernelWrapper, self).__init__(dtype=dtype, **kwargs)

            self.kernel_cls = kernel_cls

            self.log_amplitude = self.add_weight(
                name="log_amplitude",
                initializer="zeros", dtype=dtype)

            self.log_length_scale = self.add_weight(
                name="log_length_scale",
                initializer="zeros", dtype=dtype)

        def call(self, x):
            # Never called -- this is just a layer so it can hold variables
            # in a way Keras understands.
            return x

        @property
        def kernel(self):
            """Kernel instance built from the current (exponentiated) weights."""
            return self.kernel_cls(amplitude=tf.exp(self.log_amplitude),
                                   length_scale=tf.exp(self.log_length_scale))




.. GENERATED FROM PYTHON SOURCE LINES 209-210

Poisson likelihood.

.. GENERATED FROM PYTHON SOURCE LINES 210-217

.. code-block:: default


    def make_poisson_likelihood(f):
        """Poisson distribution with log-rate ``f``.

        The rightmost batch dimension is reinterpreted as an event dimension
        (``reinterpreted_batch_ndims=1``), so ``log_prob`` sums over it.
        """
        return tfd.Independent(tfd.Poisson(log_rate=f),
                               reinterpreted_batch_ndims=1)



.. GENERATED FROM PYTHON SOURCE LINES 218-225

.. code-block:: default


    def log_likelihood(y, f):
        """Poisson log-probability of counts ``y`` given log-rates ``f``."""
        likelihood = make_poisson_likelihood(f)
        return likelihood.log_prob(y)


.. GENERATED FROM PYTHON SOURCE LINES 226-227

Helper Model factory method.

.. GENERATED FROM PYTHON SOURCE LINES 227-247

.. code-block:: default


    def build_model(input_dim, jitter=1e-6):
        """Build a ``Sequential`` model wrapping a ``VariationalGaussianProcess1D``.

        The inducing index points are initialized from ``num_inducing_points``
        training inputs drawn from ``X`` with the module-level ``random_state``.
        """
        # Sample *without* replacement: duplicated inducing locations make the
        # inducing-point kernel matrix singular (only ``jitter`` keeps it
        # invertible) and waste inducing points. Fall back to sampling with
        # replacement only when there are fewer inputs than inducing points.
        replace = len(X) < num_inducing_points
        inducing_index_points_initial = random_state.choice(
            X.squeeze(), num_inducing_points, replace=replace
        ).reshape(-1, num_features)

        inducing_index_points_initializer = (
            tf.constant_initializer(inducing_index_points_initial))

        return tf.keras.Sequential([
            InputLayer(input_shape=(input_dim,)),
            VariationalGaussianProcess1D(
                kernel_wrapper=KernelWrapper(kernel_cls=kernel_cls,
                                             dtype=tf.float64),
                num_inducing_points=num_inducing_points,
                inducing_index_points_initializer=(
                    inducing_index_points_initializer),
                jitter=jitter)
        ])


.. GENERATED FROM PYTHON SOURCE LINES 248-252

.. code-block:: default


    model = build_model(input_dim=num_features, jitter=jitter)
    optimizer = tf.keras.optimizers.Adam()


..
rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gaussian_process.py:311: UserWarning: Unable to detect statically whether the number of index_points is 1. As a result, defaulting to treating the marginal GP at `index_points` as a multivariate Gaussian. This makes some methods, like `cdf` unavailable.
      'Unable to detect statically whether the number of index_points is '




.. GENERATED FROM PYTHON SOURCE LINES 253-269

.. code-block:: default


    @tf.function
    def nelbo(X_batch, y_batch):
        """Negative ELBO on a single mini-batch.

        Sum of the KL divergence from the surrogate posterior to the prior
        (scaled by ``get_kl_weight(num_train, batch_size)``) and the negated
        quadrature estimate of the expected Poisson log-likelihood.
        """
        posterior = model(X_batch)

        kl_weight = get_kl_weight(num_train, batch_size)
        kl_term = kl_weight * posterior.surrogate_posterior_kl_divergence_prior()

        expected_ll = posterior.surrogate_posterior_expected_log_likelihood(
            observations=y_batch,
            log_likelihood_fn=log_likelihood,
            quadrature_size=quadrature_size)

        return kl_term - expected_ll


.. GENERATED FROM PYTHON SOURCE LINES 270-282

.. code-block:: default


    @tf.function
    def train_step(X_batch, y_batch):
        """Apply one optimizer update on the mini-batch; return its NELBO."""
        with tf.GradientTape() as tape:
            loss = nelbo(X_batch, y_batch)

        variables = model.trainable_variables
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))

        return loss


.. GENERATED FROM PYTHON SOURCE LINES 283-288

.. code-block:: default


    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.shuffle(seed=seed, buffer_size=shuffle_buffer_size)
    dataset = dataset.batch(batch_size, drop_remainder=True)


.. GENERATED FROM PYTHON SOURCE LINES 289-295

.. code-block:: default


    # Names for the arrays returned by ``model.get_weights()``; assumed to be
    # in the same order, since the training loop zips the two together.
    keys = [
        "inducing_index_points",
        "variational_inducing_observations_loc",
        "variational_inducing_observations_scale",
        "log_observation_noise_variance",
        "log_amplitude",
        "log_length_scale",
    ]


.. GENERATED FROM PYTHON SOURCE LINES 296-314

..
code-block:: default


    history = defaultdict(list)

    for epoch in range(num_epochs):

        for X_batch, y_batch in dataset:

            loss = train_step(X_batch, y_batch)

        # NOTE: ``loss`` holds the NELBO of the *last* mini-batch of the epoch
        # only, so the printed and recorded value is a noisy per-batch figure.
        loss_value = loss.numpy()
        print("epoch={epoch:04d}, loss={loss:.4f}"
              .format(epoch=epoch, loss=loss_value))
        history["nelbo"].append(loss_value)

        # Snapshot every weight after each epoch, named by ``keys``.
        for key, tensor in zip(keys, model.get_weights()):
            history[key].append(tensor)


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    epoch=0000, loss=503718.0016
    epoch=0001, loss=475124.6687
    epoch=0002, loss=449182.0495
    epoch=0003, loss=424675.8714
    epoch=0004, loss=400130.0089
    epoch=0005, loss=377441.1128
    epoch=0006, loss=355909.6056
    epoch=0007, loss=335670.5052
    epoch=0008, loss=316633.7488
    epoch=0009, loss=298586.9228
    epoch=0010, loss=281445.5673
    epoch=0011, loss=265165.5582
    epoch=0012, loss=249702.8881
    epoch=0013, loss=235009.1906
    epoch=0014, loss=221047.5870
    epoch=0015, loss=207782.8945
    epoch=0016, loss=195182.2574
    epoch=0017, loss=183225.4848
    epoch=0018, loss=171885.0519
    epoch=0019, loss=161128.6349
    epoch=0020, loss=150953.2861
    epoch=0021, loss=141324.5595
    epoch=0022, loss=132227.2517
    epoch=0023, loss=123644.0793
    epoch=0024, loss=115545.0756
    epoch=0025, loss=107924.3415
    epoch=0026, loss=100754.6539
    epoch=0027, loss=94018.2818
    epoch=0028, loss=87691.2610
    epoch=0029, loss=81765.3992
    epoch=0030, loss=76209.4874
    epoch=0031, loss=71007.1965
    epoch=0032, loss=66139.6203
    epoch=0033, loss=61591.6754
    epoch=0034, loss=57337.9688
    epoch=0035, loss=53372.9955
    epoch=0036, loss=49671.7292
    epoch=0037, loss=46239.5650
    epoch=0038, loss=43051.8211
    epoch=0039, loss=40093.4649
    epoch=0040, loss=37367.3047
    epoch=0041, loss=34844.5363
    epoch=0042, loss=32521.0955
    epoch=0043, loss=30376.2433
    epoch=0044, loss=28401.9147
    epoch=0045, loss=26589.4699
    epoch=0046, loss=24918.4878
    epoch=0047, loss=23379.7795
    epoch=0048, loss=21963.8164
    epoch=0049, loss=20661.9284
    epoch=0050, loss=19461.8934
    epoch=0051, loss=18350.8827
    epoch=0052, loss=17330.1188
    epoch=0053, loss=16388.4615
    epoch=0054, loss=15513.8981
epoch=0055, loss=14703.6937 epoch=0056, loss=13951.6997 epoch=0057, loss=13263.4179 epoch=0058, loss=12618.2228 epoch=0059, loss=12017.4802 epoch=0060, loss=11457.0539 epoch=0061, loss=10938.7167 epoch=0062, loss=10452.8857 epoch=0063, loss=9995.8923 epoch=0064, loss=9572.3354 epoch=0065, loss=9176.7417 epoch=0066, loss=8802.5129 epoch=0067, loss=8456.5162 epoch=0068, loss=8121.0805 epoch=0069, loss=7811.6885 epoch=0070, loss=7518.2041 epoch=0071, loss=7243.9737 epoch=0072, loss=6983.4403 epoch=0073, loss=6738.4308 epoch=0074, loss=6505.0204 epoch=0075, loss=6280.9981 epoch=0076, loss=6073.2072 epoch=0077, loss=5878.0479 epoch=0078, loss=5692.3076 epoch=0079, loss=5507.7018 epoch=0080, loss=5332.9441 epoch=0081, loss=5175.1431 epoch=0082, loss=5018.8195 epoch=0083, loss=4878.9884 epoch=0084, loss=4732.3058 epoch=0085, loss=4601.0597 epoch=0086, loss=4469.2106 epoch=0087, loss=4349.7136 epoch=0088, loss=4237.6983 epoch=0089, loss=4116.0311 epoch=0090, loss=4006.7375 epoch=0091, loss=3904.2931 epoch=0092, loss=3799.8214 epoch=0093, loss=3706.6880 epoch=0094, loss=3612.3212 epoch=0095, loss=3531.6876 epoch=0096, loss=3440.0799 epoch=0097, loss=3364.8440 epoch=0098, loss=3280.9135 epoch=0099, loss=3216.7467 epoch=0100, loss=3130.4951 epoch=0101, loss=3059.3164 epoch=0102, loss=2990.5427 epoch=0103, loss=2924.2639 epoch=0104, loss=2874.4743 epoch=0105, loss=2803.1734 epoch=0106, loss=2747.5532 epoch=0107, loss=2683.1805 epoch=0108, loss=2636.1819 epoch=0109, loss=2576.8733 epoch=0110, loss=2518.7064 epoch=0111, loss=2480.2028 epoch=0112, loss=2430.1288 epoch=0113, loss=2376.4723 epoch=0114, loss=2334.9129 epoch=0115, loss=2289.6594 epoch=0116, loss=2243.3222 epoch=0117, loss=2204.1842 epoch=0118, loss=2162.5270 epoch=0119, loss=2126.7130 epoch=0120, loss=2084.3528 epoch=0121, loss=2039.8391 epoch=0122, loss=2011.7281 epoch=0123, loss=1976.2229 epoch=0124, loss=1942.0064 epoch=0125, loss=1901.3520 epoch=0126, loss=1879.5601 epoch=0127, loss=1842.2692 epoch=0128, 
loss=1810.3070 epoch=0129, loss=1783.4963 epoch=0130, loss=1749.7177 epoch=0131, loss=1728.7218 epoch=0132, loss=1696.4940 epoch=0133, loss=1667.3415 epoch=0134, loss=1646.6158 epoch=0135, loss=1617.1468 epoch=0136, loss=1595.6100 epoch=0137, loss=1562.6681 epoch=0138, loss=1545.7072 epoch=0139, loss=1511.5130 epoch=0140, loss=1489.9823 epoch=0141, loss=1472.1741 epoch=0142, loss=1459.4709 epoch=0143, loss=1427.1879 epoch=0144, loss=1410.0700 epoch=0145, loss=1392.4305 epoch=0146, loss=1358.6774 epoch=0147, loss=1356.2028 epoch=0148, loss=1326.9513 epoch=0149, loss=1311.4088 epoch=0150, loss=1296.7058 epoch=0151, loss=1276.8852 epoch=0152, loss=1254.6550 epoch=0153, loss=1239.3621 epoch=0154, loss=1230.0371 epoch=0155, loss=1209.9762 epoch=0156, loss=1187.4667 epoch=0157, loss=1177.2099 epoch=0158, loss=1154.8790 epoch=0159, loss=1141.5911 epoch=0160, loss=1127.3600 epoch=0161, loss=1123.1947 epoch=0162, loss=1105.7010 epoch=0163, loss=1084.5280 epoch=0164, loss=1076.3502 epoch=0165, loss=1051.0518 epoch=0166, loss=1052.1513 epoch=0167, loss=1029.4153 epoch=0168, loss=1012.4560 epoch=0169, loss=1006.7754 epoch=0170, loss=994.1713 epoch=0171, loss=979.2590 epoch=0172, loss=965.2443 epoch=0173, loss=967.2261 epoch=0174, loss=947.9689 epoch=0175, loss=941.7749 epoch=0176, loss=930.2486 epoch=0177, loss=913.6552 epoch=0178, loss=905.7848 epoch=0179, loss=895.1102 epoch=0180, loss=882.7923 epoch=0181, loss=878.9853 epoch=0182, loss=855.4525 epoch=0183, loss=848.5580 epoch=0184, loss=854.4321 epoch=0185, loss=837.1378 epoch=0186, loss=820.3715 epoch=0187, loss=810.3195 epoch=0188, loss=813.9388 epoch=0189, loss=795.8079 epoch=0190, loss=797.2823 epoch=0191, loss=778.9154 epoch=0192, loss=767.1228 epoch=0193, loss=765.8746 epoch=0194, loss=751.7439 epoch=0195, loss=735.2551 epoch=0196, loss=736.6673 epoch=0197, loss=724.3290 epoch=0198, loss=734.2114 epoch=0199, loss=712.6057 epoch=0200, loss=712.9040 epoch=0201, loss=697.5288 epoch=0202, loss=704.0655 epoch=0203, 
loss=692.6003 epoch=0204, loss=685.3066 epoch=0205, loss=674.1011 epoch=0206, loss=665.7722 epoch=0207, loss=654.3170 epoch=0208, loss=654.4745 epoch=0209, loss=639.0899 epoch=0210, loss=638.7625 epoch=0211, loss=623.6700 epoch=0212, loss=634.3182 epoch=0213, loss=613.1934 epoch=0214, loss=615.6673 epoch=0215, loss=611.3304 epoch=0216, loss=604.6876 epoch=0217, loss=594.3960 epoch=0218, loss=594.2117 epoch=0219, loss=573.6360 epoch=0220, loss=577.7584 epoch=0221, loss=570.8481 epoch=0222, loss=560.8061 epoch=0223, loss=573.5160 epoch=0224, loss=552.8636 epoch=0225, loss=549.9146 epoch=0226, loss=546.9251 epoch=0227, loss=532.6548 epoch=0228, loss=539.0560 epoch=0229, loss=528.0483 epoch=0230, loss=522.8893 epoch=0231, loss=522.7825 epoch=0232, loss=510.8447 epoch=0233, loss=509.5694 epoch=0234, loss=516.3404 epoch=0235, loss=511.4087 epoch=0236, loss=494.5026 epoch=0237, loss=503.1661 epoch=0238, loss=492.0830 epoch=0239, loss=479.2352 epoch=0240, loss=475.4360 epoch=0241, loss=475.0072 epoch=0242, loss=465.4977 epoch=0243, loss=475.3879 epoch=0244, loss=461.6061 epoch=0245, loss=461.3810 epoch=0246, loss=448.1367 epoch=0247, loss=462.1360 epoch=0248, loss=452.7104 epoch=0249, loss=442.0678 epoch=0250, loss=446.5259 epoch=0251, loss=429.6645 epoch=0252, loss=432.3712 epoch=0253, loss=432.3731 epoch=0254, loss=428.3459 epoch=0255, loss=427.4889 epoch=0256, loss=415.5806 epoch=0257, loss=408.1475 epoch=0258, loss=398.6794 epoch=0259, loss=418.0139 epoch=0260, loss=407.5253 epoch=0261, loss=403.9192 epoch=0262, loss=393.8357 epoch=0263, loss=388.2981 epoch=0264, loss=392.3414 epoch=0265, loss=395.5534 epoch=0266, loss=390.2889 epoch=0267, loss=385.2252 epoch=0268, loss=378.7956 epoch=0269, loss=381.6068 epoch=0270, loss=369.0336 epoch=0271, loss=378.1374 epoch=0272, loss=370.4357 epoch=0273, loss=366.2946 epoch=0274, loss=352.2604 epoch=0275, loss=372.7704 epoch=0276, loss=356.5688 epoch=0277, loss=352.5583 epoch=0278, loss=350.4578 epoch=0279, loss=350.4543 
epoch=0280, loss=342.7281 epoch=0281, loss=341.7187 epoch=0282, loss=337.6962 epoch=0283, loss=336.7313 epoch=0284, loss=348.5399 epoch=0285, loss=331.6232 epoch=0286, loss=320.0296 epoch=0287, loss=349.6291 epoch=0288, loss=330.3633 epoch=0289, loss=321.4751 epoch=0290, loss=308.3491 epoch=0291, loss=333.5266 epoch=0292, loss=316.8104 epoch=0293, loss=319.8715 epoch=0294, loss=320.6826 epoch=0295, loss=304.9669 epoch=0296, loss=301.5010 epoch=0297, loss=306.3954 epoch=0298, loss=303.1771 epoch=0299, loss=308.6068 epoch=0300, loss=303.4759 epoch=0301, loss=291.6490 epoch=0302, loss=294.1003 epoch=0303, loss=293.1295 epoch=0304, loss=288.0081 epoch=0305, loss=290.3008 epoch=0306, loss=289.7989 epoch=0307, loss=282.0794 epoch=0308, loss=284.0127 epoch=0309, loss=278.3577 epoch=0310, loss=283.5266 epoch=0311, loss=270.4285 epoch=0312, loss=292.5535 epoch=0313, loss=269.1271 epoch=0314, loss=267.3524 epoch=0315, loss=282.8746 epoch=0316, loss=267.9802 epoch=0317, loss=255.7895 epoch=0318, loss=263.4444 epoch=0319, loss=262.7765 epoch=0320, loss=269.5424 epoch=0321, loss=247.4822 epoch=0322, loss=245.4533 epoch=0323, loss=251.4737 epoch=0324, loss=256.4097 epoch=0325, loss=248.1391 epoch=0326, loss=245.7778 epoch=0327, loss=239.1502 epoch=0328, loss=231.2579 epoch=0329, loss=231.9704 epoch=0330, loss=230.7200 epoch=0331, loss=244.4598 epoch=0332, loss=228.0484 epoch=0333, loss=232.7435 epoch=0334, loss=231.0531 epoch=0335, loss=229.9803 epoch=0336, loss=237.8903 epoch=0337, loss=242.0139 epoch=0338, loss=240.8350 epoch=0339, loss=228.8982 epoch=0340, loss=227.9539 epoch=0341, loss=221.3632 epoch=0342, loss=226.5502 epoch=0343, loss=225.5825 epoch=0344, loss=221.7293 epoch=0345, loss=220.6527 epoch=0346, loss=224.6299 epoch=0347, loss=218.0827 epoch=0348, loss=207.6847 epoch=0349, loss=215.5031 epoch=0350, loss=206.6249 epoch=0351, loss=219.5271 epoch=0352, loss=209.6862 epoch=0353, loss=207.3360 epoch=0354, loss=197.6009 epoch=0355, loss=205.0703 epoch=0356, 
loss=206.5529 epoch=0357, loss=194.7957 epoch=0358, loss=197.0495 epoch=0359, loss=194.9799 epoch=0360, loss=193.0936 epoch=0361, loss=201.0962 epoch=0362, loss=194.8401 epoch=0363, loss=187.5098 epoch=0364, loss=196.7941 epoch=0365, loss=203.5231 epoch=0366, loss=196.1433 epoch=0367, loss=184.3774 epoch=0368, loss=190.3704 epoch=0369, loss=183.6483 epoch=0370, loss=187.4339 epoch=0371, loss=195.7501 epoch=0372, loss=175.1536 epoch=0373, loss=188.3401 epoch=0374, loss=188.6235 epoch=0375, loss=177.8545 epoch=0376, loss=188.3338 epoch=0377, loss=195.2243 epoch=0378, loss=176.8593 epoch=0379, loss=169.6181 epoch=0380, loss=190.8619 epoch=0381, loss=164.9291 epoch=0382, loss=182.7612 epoch=0383, loss=183.4093 epoch=0384, loss=185.6163 epoch=0385, loss=177.8770 epoch=0386, loss=166.0440 epoch=0387, loss=173.9915 epoch=0388, loss=174.3544 epoch=0389, loss=166.0017 epoch=0390, loss=152.6976 epoch=0391, loss=160.4803 epoch=0392, loss=172.6038 epoch=0393, loss=156.3226 epoch=0394, loss=159.6409 epoch=0395, loss=166.7436 epoch=0396, loss=159.3685 epoch=0397, loss=154.3180 epoch=0398, loss=167.5721 epoch=0399, loss=163.5109 epoch=0400, loss=171.3971 epoch=0401, loss=160.6082 epoch=0402, loss=145.6805 epoch=0403, loss=155.7041 epoch=0404, loss=156.1276 epoch=0405, loss=152.6160 epoch=0406, loss=162.6592 epoch=0407, loss=142.9762 epoch=0408, loss=159.5509 epoch=0409, loss=151.7371 epoch=0410, loss=152.2857 epoch=0411, loss=147.3106 epoch=0412, loss=154.5352 epoch=0413, loss=146.3088 epoch=0414, loss=152.1442 epoch=0415, loss=145.5159 epoch=0416, loss=156.1471 epoch=0417, loss=144.5401 epoch=0418, loss=143.0060 epoch=0419, loss=138.8478 epoch=0420, loss=144.5675 epoch=0421, loss=141.7152 epoch=0422, loss=152.1052 epoch=0423, loss=135.4647 epoch=0424, loss=144.8311 epoch=0425, loss=136.1836 epoch=0426, loss=136.0509 epoch=0427, loss=131.0169 epoch=0428, loss=130.8113 epoch=0429, loss=140.1047 epoch=0430, loss=135.6104 epoch=0431, loss=139.6994 epoch=0432, loss=135.0475 
epoch=0433, loss=124.3657 epoch=0434, loss=128.8028 epoch=0435, loss=135.1245 epoch=0436, loss=126.3713 epoch=0437, loss=126.9186 epoch=0438, loss=127.0762 epoch=0439, loss=132.2666 epoch=0440, loss=117.0638 epoch=0441, loss=131.8005 epoch=0442, loss=121.6616 epoch=0443, loss=127.8722 epoch=0444, loss=132.6087 epoch=0445, loss=119.4095 epoch=0446, loss=120.3532 epoch=0447, loss=124.0202 epoch=0448, loss=119.2436 epoch=0449, loss=117.3810 epoch=0450, loss=125.5332 epoch=0451, loss=116.9169 epoch=0452, loss=117.7422 epoch=0453, loss=118.8801 epoch=0454, loss=125.7930 epoch=0455, loss=127.4539 epoch=0456, loss=110.2093 epoch=0457, loss=115.6084 epoch=0458, loss=139.7687 epoch=0459, loss=122.2508 epoch=0460, loss=114.1477 epoch=0461, loss=125.7758 epoch=0462, loss=122.3530 epoch=0463, loss=118.7015 epoch=0464, loss=118.7259 epoch=0465, loss=105.7424 epoch=0466, loss=128.8635 epoch=0467, loss=126.3125 epoch=0468, loss=132.5865 epoch=0469, loss=105.5933 epoch=0470, loss=107.0141 epoch=0471, loss=109.4016 epoch=0472, loss=104.4746 epoch=0473, loss=103.8465 epoch=0474, loss=95.5309 epoch=0475, loss=109.6063 epoch=0476, loss=113.3386 epoch=0477, loss=99.2031 epoch=0478, loss=111.3675 epoch=0479, loss=106.6856 epoch=0480, loss=121.9445 epoch=0481, loss=117.0523 epoch=0482, loss=119.6875 epoch=0483, loss=112.8985 epoch=0484, loss=110.3964 epoch=0485, loss=119.5072 epoch=0486, loss=104.1824 epoch=0487, loss=108.1442 epoch=0488, loss=93.3669 epoch=0489, loss=104.0165 epoch=0490, loss=101.3833 epoch=0491, loss=109.5744 epoch=0492, loss=100.7609 epoch=0493, loss=97.6215 epoch=0494, loss=107.8976 epoch=0495, loss=91.9219 epoch=0496, loss=103.6455 epoch=0497, loss=100.3046 epoch=0498, loss=87.8629 epoch=0499, loss=102.6298 epoch=0500, loss=97.4680 epoch=0501, loss=93.2479 epoch=0502, loss=102.4228 epoch=0503, loss=99.8412 epoch=0504, loss=95.0297 epoch=0505, loss=110.5724 epoch=0506, loss=92.3500 epoch=0507, loss=93.5608 epoch=0508, loss=104.3107 epoch=0509, loss=90.2806 
epoch=0510, loss=99.3023 epoch=0511, loss=89.7588 epoch=0512, loss=88.0880 epoch=0513, loss=102.7510 epoch=0514, loss=85.3267 epoch=0515, loss=83.1560 epoch=0516, loss=86.8379 epoch=0517, loss=84.2890 epoch=0518, loss=84.6089 epoch=0519, loss=90.2681 epoch=0520, loss=98.1327 epoch=0521, loss=94.0007 epoch=0522, loss=89.3772 epoch=0523, loss=82.4729 epoch=0524, loss=86.4087 epoch=0525, loss=91.0402 epoch=0526, loss=87.6877 epoch=0527, loss=87.4722 epoch=0528, loss=81.0174 epoch=0529, loss=82.1088 epoch=0530, loss=88.7508 epoch=0531, loss=102.3531 epoch=0532, loss=84.8533 epoch=0533, loss=94.7518 epoch=0534, loss=86.7407 epoch=0535, loss=83.1817 epoch=0536, loss=79.8289 epoch=0537, loss=94.1217 epoch=0538, loss=82.4183 epoch=0539, loss=88.1212 epoch=0540, loss=79.6697 epoch=0541, loss=82.3980 epoch=0542, loss=85.3741 epoch=0543, loss=75.5057 epoch=0544, loss=79.6781 epoch=0545, loss=68.5751 epoch=0546, loss=86.9400 epoch=0547, loss=91.1889 epoch=0548, loss=82.6861 epoch=0549, loss=78.8321 epoch=0550, loss=86.1787 epoch=0551, loss=88.3150 epoch=0552, loss=82.9925 epoch=0553, loss=81.7292 epoch=0554, loss=66.1165 epoch=0555, loss=84.6366 epoch=0556, loss=92.2174 epoch=0557, loss=83.1902 epoch=0558, loss=74.3282 epoch=0559, loss=89.9474 epoch=0560, loss=89.7887 epoch=0561, loss=86.8922 epoch=0562, loss=76.5947 epoch=0563, loss=71.9305 epoch=0564, loss=84.6352 epoch=0565, loss=86.1365 epoch=0566, loss=89.1902 epoch=0567, loss=87.1468 epoch=0568, loss=66.7683 epoch=0569, loss=68.3575 epoch=0570, loss=73.0146 epoch=0571, loss=67.2949 epoch=0572, loss=78.2781 epoch=0573, loss=74.8641 epoch=0574, loss=89.5783 epoch=0575, loss=83.2646 epoch=0576, loss=75.5487 epoch=0577, loss=71.8853 epoch=0578, loss=62.4595 epoch=0579, loss=75.7681 epoch=0580, loss=75.3446 epoch=0581, loss=60.4439 epoch=0582, loss=80.5004 epoch=0583, loss=76.4536 epoch=0584, loss=64.6321 epoch=0585, loss=62.7810 epoch=0586, loss=69.3928 epoch=0587, loss=66.0748 epoch=0588, loss=67.1042 epoch=0589, 
loss=67.3756 epoch=0590, loss=62.8932 epoch=0591, loss=71.8032 epoch=0592, loss=70.5042 epoch=0593, loss=63.6476 epoch=0594, loss=57.6988 epoch=0595, loss=59.8144 epoch=0596, loss=74.8017 epoch=0597, loss=81.8788 epoch=0598, loss=77.6706 epoch=0599, loss=68.4422 epoch=0600, loss=64.2554 epoch=0601, loss=65.3775 epoch=0602, loss=64.8156 epoch=0603, loss=67.2893 epoch=0604, loss=62.5748 epoch=0605, loss=66.9849 epoch=0606, loss=65.8462 epoch=0607, loss=75.3093 epoch=0608, loss=79.2133 epoch=0609, loss=69.7162 epoch=0610, loss=59.6053 epoch=0611, loss=80.3660 epoch=0612, loss=63.0438 epoch=0613, loss=58.2483 epoch=0614, loss=59.0454 epoch=0615, loss=58.6386 epoch=0616, loss=58.9345 epoch=0617, loss=77.5252 epoch=0618, loss=79.6157 epoch=0619, loss=55.6701 epoch=0620, loss=79.2709 epoch=0621, loss=64.5792 epoch=0622, loss=71.3175 epoch=0623, loss=63.5462 epoch=0624, loss=71.6826 epoch=0625, loss=58.9421 epoch=0626, loss=66.4716 epoch=0627, loss=62.3693 epoch=0628, loss=63.5122 epoch=0629, loss=69.9511 epoch=0630, loss=47.5289 epoch=0631, loss=57.6975 epoch=0632, loss=67.8547 epoch=0633, loss=58.8649 epoch=0634, loss=56.6639 epoch=0635, loss=69.4698 epoch=0636, loss=53.4113 epoch=0637, loss=59.6261 epoch=0638, loss=64.1503 epoch=0639, loss=64.2238 epoch=0640, loss=67.4349 epoch=0641, loss=68.9233 epoch=0642, loss=58.3884 epoch=0643, loss=52.5806 epoch=0644, loss=67.4768 epoch=0645, loss=55.4877 epoch=0646, loss=53.7203 epoch=0647, loss=58.3251 epoch=0648, loss=59.4838 epoch=0649, loss=59.8746 epoch=0650, loss=58.7315 epoch=0651, loss=79.7500 epoch=0652, loss=58.5452 epoch=0653, loss=60.7954 epoch=0654, loss=54.9007 epoch=0655, loss=65.4627 epoch=0656, loss=61.3134 epoch=0657, loss=53.3750 epoch=0658, loss=54.0795 epoch=0659, loss=72.4057 epoch=0660, loss=53.6076 epoch=0661, loss=81.3088 epoch=0662, loss=58.5682 epoch=0663, loss=58.1461 epoch=0664, loss=57.2060 epoch=0665, loss=71.0399 epoch=0666, loss=65.0347 epoch=0667, loss=49.6852 epoch=0668, loss=53.1331 epoch=0669, 
loss=50.6020 epoch=0670, loss=60.3321 epoch=0671, loss=61.5717 epoch=0672, loss=53.9444 epoch=0673, loss=51.0058 epoch=0674, loss=49.6867 epoch=0675, loss=49.3432 epoch=0676, loss=47.9196 epoch=0677, loss=68.2624 epoch=0678, loss=57.4226 epoch=0679, loss=62.1165 epoch=0680, loss=46.3444 epoch=0681, loss=50.4172 epoch=0682, loss=49.5411 epoch=0683, loss=62.4633 epoch=0684, loss=50.0264 epoch=0685, loss=58.2416 epoch=0686, loss=50.6320 epoch=0687, loss=49.0191 epoch=0688, loss=48.2821 epoch=0689, loss=42.7373 epoch=0690, loss=58.3171 epoch=0691, loss=55.5937 epoch=0692, loss=44.9044 epoch=0693, loss=55.7095 epoch=0694, loss=56.4527 epoch=0695, loss=44.5670 epoch=0696, loss=51.8081 epoch=0697, loss=50.0144 epoch=0698, loss=58.4083 epoch=0699, loss=48.1501 epoch=0700, loss=54.2906 epoch=0701, loss=53.6182 epoch=0702, loss=39.3437 epoch=0703, loss=61.8583 epoch=0704, loss=58.8785 epoch=0705, loss=46.4308 epoch=0706, loss=62.9105 epoch=0707, loss=56.4528 epoch=0708, loss=55.8068 epoch=0709, loss=43.4011 epoch=0710, loss=44.8983 epoch=0711, loss=44.0255 epoch=0712, loss=54.1520 epoch=0713, loss=46.9686 epoch=0714, loss=45.5972 epoch=0715, loss=49.8532 epoch=0716, loss=43.7185 epoch=0717, loss=49.9352 epoch=0718, loss=60.2397 epoch=0719, loss=52.1619 epoch=0720, loss=46.1009 epoch=0721, loss=53.7403 epoch=0722, loss=43.8176 epoch=0723, loss=50.5373 epoch=0724, loss=56.5087 epoch=0725, loss=51.5442 epoch=0726, loss=55.1669 epoch=0727, loss=40.8989 epoch=0728, loss=48.9067 epoch=0729, loss=45.8273 epoch=0730, loss=44.7350 epoch=0731, loss=46.5871 epoch=0732, loss=41.1744 epoch=0733, loss=84.7797 epoch=0734, loss=48.4741 epoch=0735, loss=39.9638 epoch=0736, loss=41.6540 epoch=0737, loss=61.9457 epoch=0738, loss=46.0518 epoch=0739, loss=52.9721 epoch=0740, loss=40.0342 epoch=0741, loss=43.8131 epoch=0742, loss=44.1711 epoch=0743, loss=51.0378 epoch=0744, loss=47.2733 epoch=0745, loss=46.5394 epoch=0746, loss=35.1426 epoch=0747, loss=54.2852 epoch=0748, loss=51.4293 epoch=0749, 
loss=50.2110 epoch=0750, loss=49.1398 epoch=0751, loss=45.6933 epoch=0752, loss=57.4035 epoch=0753, loss=54.6156 epoch=0754, loss=44.5137 epoch=0755, loss=54.1820 epoch=0756, loss=39.1424 epoch=0757, loss=54.4399 epoch=0758, loss=36.3383 epoch=0759, loss=49.9852 epoch=0760, loss=53.9937 epoch=0761, loss=43.7791 epoch=0762, loss=46.2205 epoch=0763, loss=62.5270 epoch=0764, loss=59.7127 epoch=0765, loss=39.0664 epoch=0766, loss=44.7098 epoch=0767, loss=43.4895 epoch=0768, loss=49.4405 epoch=0769, loss=54.8502 epoch=0770, loss=42.0123 epoch=0771, loss=62.3534 epoch=0772, loss=38.1141 epoch=0773, loss=43.5459 epoch=0774, loss=41.6324 epoch=0775, loss=48.1968 epoch=0776, loss=46.0691 epoch=0777, loss=51.6692 epoch=0778, loss=46.6518 epoch=0779, loss=63.3883 epoch=0780, loss=52.5965 epoch=0781, loss=41.4272 epoch=0782, loss=48.3477 epoch=0783, loss=43.6215 epoch=0784, loss=44.7427 epoch=0785, loss=51.7438 epoch=0786, loss=32.6199 epoch=0787, loss=40.2166 epoch=0788, loss=62.7569 epoch=0789, loss=52.6731 epoch=0790, loss=46.1543 epoch=0791, loss=53.3033 epoch=0792, loss=44.2406 epoch=0793, loss=49.6839 epoch=0794, loss=37.8175 epoch=0795, loss=44.3829 epoch=0796, loss=48.2948 epoch=0797, loss=48.3759 epoch=0798, loss=47.0026 epoch=0799, loss=34.5370 epoch=0800, loss=49.3600 epoch=0801, loss=42.6382 epoch=0802, loss=31.4616 epoch=0803, loss=30.3515 epoch=0804, loss=38.8228 epoch=0805, loss=48.9507 epoch=0806, loss=46.5755 epoch=0807, loss=46.1173 epoch=0808, loss=48.8414 epoch=0809, loss=58.1899 epoch=0810, loss=66.8658 epoch=0811, loss=47.4377 epoch=0812, loss=47.3466 epoch=0813, loss=48.9937 epoch=0814, loss=39.4674 epoch=0815, loss=44.8714 epoch=0816, loss=43.2242 epoch=0817, loss=45.2122 epoch=0818, loss=38.3916 epoch=0819, loss=36.9446 epoch=0820, loss=54.9395 epoch=0821, loss=42.1571 epoch=0822, loss=46.8208 epoch=0823, loss=49.8680 epoch=0824, loss=31.9886 epoch=0825, loss=44.2595 epoch=0826, loss=47.7277 epoch=0827, loss=40.7836 epoch=0828, loss=50.3694 epoch=0829, 
loss=31.0777 epoch=0830, loss=51.3054 epoch=0831, loss=45.2598 epoch=0832, loss=45.5807 epoch=0833, loss=38.6982 epoch=0834, loss=54.1841 epoch=0835, loss=37.5695 epoch=0836, loss=50.5953 epoch=0837, loss=55.2999 epoch=0838, loss=61.0726 epoch=0839, loss=41.8046 epoch=0840, loss=57.9198 epoch=0841, loss=50.9733 epoch=0842, loss=52.0090 epoch=0843, loss=49.6375 epoch=0844, loss=33.1304 epoch=0845, loss=35.4482 epoch=0846, loss=40.9187 epoch=0847, loss=32.0274 epoch=0848, loss=54.8925 epoch=0849, loss=55.4495 epoch=0850, loss=39.9136 epoch=0851, loss=39.5183 epoch=0852, loss=44.2735 epoch=0853, loss=42.9308 epoch=0854, loss=33.3962 epoch=0855, loss=42.6717 epoch=0856, loss=33.2201 epoch=0857, loss=44.9332 epoch=0858, loss=42.4207 epoch=0859, loss=43.9101 epoch=0860, loss=42.7964 epoch=0861, loss=28.6696 epoch=0862, loss=42.3850 epoch=0863, loss=45.0092 epoch=0864, loss=38.7697 epoch=0865, loss=38.6989 epoch=0866, loss=42.1009 epoch=0867, loss=44.3899 epoch=0868, loss=55.5179 epoch=0869, loss=37.4072 epoch=0870, loss=49.8416 epoch=0871, loss=22.9407 epoch=0872, loss=53.3703 epoch=0873, loss=49.7300 epoch=0874, loss=41.9710 epoch=0875, loss=33.5178 epoch=0876, loss=40.1948 epoch=0877, loss=38.5868 epoch=0878, loss=48.0192 epoch=0879, loss=56.0664 epoch=0880, loss=29.9777 epoch=0881, loss=53.9331 epoch=0882, loss=35.7819 epoch=0883, loss=46.7845 epoch=0884, loss=51.1198 epoch=0885, loss=41.6742 epoch=0886, loss=30.7241 epoch=0887, loss=50.6927 epoch=0888, loss=38.3675 epoch=0889, loss=40.7640 epoch=0890, loss=44.0895 epoch=0891, loss=43.9981 epoch=0892, loss=31.3808 epoch=0893, loss=41.1945 epoch=0894, loss=30.9972 epoch=0895, loss=42.1725 epoch=0896, loss=36.7186 epoch=0897, loss=48.3122 epoch=0898, loss=36.7973 epoch=0899, loss=32.2036 epoch=0900, loss=31.9240 epoch=0901, loss=30.0664 epoch=0902, loss=39.6483 epoch=0903, loss=49.0898 epoch=0904, loss=43.9462 epoch=0905, loss=46.5315 epoch=0906, loss=35.1399 epoch=0907, loss=43.3344 epoch=0908, loss=39.6814 epoch=0909, 
loss=49.7920 epoch=0910, loss=42.8614 epoch=0911, loss=42.2364 epoch=0912, loss=29.1351 epoch=0913, loss=36.6785 epoch=0914, loss=42.9751 epoch=0915, loss=39.4084 epoch=0916, loss=22.3465 epoch=0917, loss=36.7195 epoch=0918, loss=39.6980 epoch=0919, loss=39.0278 epoch=0920, loss=29.5661 epoch=0921, loss=36.6211 epoch=0922, loss=39.0927 epoch=0923, loss=53.0615 epoch=0924, loss=54.3734 epoch=0925, loss=49.3880 epoch=0926, loss=44.8807 epoch=0927, loss=30.4433 epoch=0928, loss=30.8365 epoch=0929, loss=43.8575 epoch=0930, loss=40.3110 epoch=0931, loss=39.0222 epoch=0932, loss=35.6125 epoch=0933, loss=40.0691 epoch=0934, loss=37.4746 epoch=0935, loss=50.4772 epoch=0936, loss=35.9687 epoch=0937, loss=43.3802 epoch=0938, loss=35.6311 epoch=0939, loss=34.9448 epoch=0940, loss=39.2193 epoch=0941, loss=47.5750 epoch=0942, loss=40.5120 epoch=0943, loss=41.5842 epoch=0944, loss=36.9813 epoch=0945, loss=39.4889 epoch=0946, loss=40.2740 epoch=0947, loss=54.9169 epoch=0948, loss=38.1466 epoch=0949, loss=38.7406 epoch=0950, loss=39.3057 epoch=0951, loss=43.0153 epoch=0952, loss=29.6520 epoch=0953, loss=35.1728 epoch=0954, loss=29.8932 epoch=0955, loss=28.1536 epoch=0956, loss=38.3106 epoch=0957, loss=38.9041 epoch=0958, loss=44.8865 epoch=0959, loss=34.7237 epoch=0960, loss=35.5248 epoch=0961, loss=31.2927 epoch=0962, loss=45.6466 epoch=0963, loss=49.5796 epoch=0964, loss=38.5878 epoch=0965, loss=43.0185 epoch=0966, loss=41.3775 epoch=0967, loss=35.3072 epoch=0968, loss=40.4946 epoch=0969, loss=38.3144 epoch=0970, loss=37.1825 epoch=0971, loss=33.5633 epoch=0972, loss=44.8774 epoch=0973, loss=30.5078 epoch=0974, loss=54.0708 epoch=0975, loss=40.4760 epoch=0976, loss=44.6554 epoch=0977, loss=43.8926 epoch=0978, loss=35.7388 epoch=0979, loss=31.7607 epoch=0980, loss=41.0082 epoch=0981, loss=54.1730 epoch=0982, loss=35.5671 epoch=0983, loss=54.8984 epoch=0984, loss=32.4584 epoch=0985, loss=36.8033 epoch=0986, loss=28.5383 epoch=0987, loss=31.1247 epoch=0988, loss=44.6059 epoch=0989, 
loss=28.4877 epoch=0990, loss=47.7916 epoch=0991, loss=41.5727 epoch=0992, loss=25.4704 epoch=0993, loss=27.3044 epoch=0994, loss=40.2768 epoch=0995, loss=44.7830 epoch=0996, loss=25.4001 epoch=0997, loss=45.7636 epoch=0998, loss=24.7614 epoch=0999, loss=35.8817 epoch=1000, loss=31.5747 epoch=1001, loss=46.7773 epoch=1002, loss=38.6838 epoch=1003, loss=43.0371 epoch=1004, loss=41.6749 epoch=1005, loss=36.8893 epoch=1006, loss=41.3862 epoch=1007, loss=41.6232 epoch=1008, loss=29.2862 epoch=1009, loss=34.4208 epoch=1010, loss=50.0407 epoch=1011, loss=41.5424 epoch=1012, loss=32.5476 epoch=1013, loss=45.9401 epoch=1014, loss=43.9288 epoch=1015, loss=28.2446 epoch=1016, loss=37.2122 epoch=1017, loss=38.2516 epoch=1018, loss=25.0566 epoch=1019, loss=26.2989 epoch=1020, loss=46.4798 epoch=1021, loss=30.7958 epoch=1022, loss=38.4176 epoch=1023, loss=36.8788 epoch=1024, loss=33.6224 epoch=1025, loss=44.1153 epoch=1026, loss=49.2197 epoch=1027, loss=33.1222 epoch=1028, loss=42.2189 epoch=1029, loss=31.0787 epoch=1030, loss=47.6945 epoch=1031, loss=36.4322 epoch=1032, loss=44.9050 epoch=1033, loss=37.7846 epoch=1034, loss=55.0392 epoch=1035, loss=43.4899 epoch=1036, loss=32.6500 epoch=1037, loss=46.7317 epoch=1038, loss=26.9375 epoch=1039, loss=33.0094 epoch=1040, loss=60.7153 epoch=1041, loss=45.5859 epoch=1042, loss=42.2432 epoch=1043, loss=50.3978 epoch=1044, loss=36.0496 epoch=1045, loss=39.4132 epoch=1046, loss=38.4170 epoch=1047, loss=35.5436 epoch=1048, loss=34.5298 epoch=1049, loss=47.9556 epoch=1050, loss=29.9514 epoch=1051, loss=41.5883 epoch=1052, loss=40.4844 epoch=1053, loss=32.1343 epoch=1054, loss=28.9982 epoch=1055, loss=40.7955 epoch=1056, loss=33.8277 epoch=1057, loss=37.2696 epoch=1058, loss=32.4221 epoch=1059, loss=35.5141 epoch=1060, loss=46.4895 epoch=1061, loss=33.7861 epoch=1062, loss=37.9355 epoch=1063, loss=28.9689 epoch=1064, loss=37.1848 epoch=1065, loss=34.9959 epoch=1066, loss=34.3465 epoch=1067, loss=54.1800 epoch=1068, loss=33.4773 epoch=1069, 
loss=37.8777 epoch=1070, loss=29.8906 epoch=1071, loss=36.3135 epoch=1072, loss=32.1539 epoch=1073, loss=53.2417 epoch=1074, loss=35.5554 epoch=1075, loss=36.0650 epoch=1076, loss=35.3531 epoch=1077, loss=34.6156 epoch=1078, loss=50.0658 epoch=1079, loss=32.4242 epoch=1080, loss=43.8290 epoch=1081, loss=33.9779 epoch=1082, loss=45.4978 epoch=1083, loss=36.2925 epoch=1084, loss=50.7296 epoch=1085, loss=29.3651 epoch=1086, loss=41.1270 epoch=1087, loss=43.6603 epoch=1088, loss=26.5921 epoch=1089, loss=31.1718 epoch=1090, loss=45.6769 epoch=1091, loss=21.7876 epoch=1092, loss=40.9696 epoch=1093, loss=43.9557 epoch=1094, loss=34.7761 epoch=1095, loss=39.4589 epoch=1096, loss=50.8497 epoch=1097, loss=39.5860 epoch=1098, loss=33.2066 epoch=1099, loss=39.5427 epoch=1100, loss=32.6381 epoch=1101, loss=37.2633 epoch=1102, loss=32.2884 epoch=1103, loss=38.1605 epoch=1104, loss=30.7495 epoch=1105, loss=35.5737 epoch=1106, loss=30.1897 epoch=1107, loss=38.1585 epoch=1108, loss=28.4092 epoch=1109, loss=29.7329 epoch=1110, loss=37.1669 epoch=1111, loss=29.4979 epoch=1112, loss=35.8428 epoch=1113, loss=34.8263 epoch=1114, loss=47.5728 epoch=1115, loss=37.1924 epoch=1116, loss=46.1044 epoch=1117, loss=34.6153 epoch=1118, loss=61.9499 epoch=1119, loss=35.0889 epoch=1120, loss=22.9839 epoch=1121, loss=29.6082 epoch=1122, loss=34.8566 epoch=1123, loss=35.1996 epoch=1124, loss=32.4029 epoch=1125, loss=38.2590 epoch=1126, loss=34.7926 epoch=1127, loss=30.0309 epoch=1128, loss=33.1623 epoch=1129, loss=36.3337 epoch=1130, loss=48.0137 epoch=1131, loss=30.0274 epoch=1132, loss=28.3043 epoch=1133, loss=23.4344 epoch=1134, loss=23.2265 epoch=1135, loss=40.6957 epoch=1136, loss=33.3754 epoch=1137, loss=36.1965 epoch=1138, loss=41.8181 epoch=1139, loss=25.8868 epoch=1140, loss=28.9432 epoch=1141, loss=33.7793 epoch=1142, loss=35.4136 epoch=1143, loss=29.5980 epoch=1144, loss=42.2540 epoch=1145, loss=26.7558 epoch=1146, loss=37.7473 epoch=1147, loss=37.1238 epoch=1148, loss=32.0596 epoch=1149, 
loss=39.6587 epoch=1150, loss=28.4213 epoch=1151, loss=28.8102 epoch=1152, loss=39.1638 epoch=1153, loss=27.3062 epoch=1154, loss=42.6782 epoch=1155, loss=32.8738 epoch=1156, loss=25.7113 epoch=1157, loss=34.6483 epoch=1158, loss=26.3806 epoch=1159, loss=31.2331 epoch=1160, loss=50.5903 epoch=1161, loss=24.9333 epoch=1162, loss=36.7008 epoch=1163, loss=37.0426 epoch=1164, loss=32.0863 epoch=1165, loss=24.1840 epoch=1166, loss=38.1325 epoch=1167, loss=38.7330 epoch=1168, loss=48.2531 epoch=1169, loss=45.8834 epoch=1170, loss=32.5161 epoch=1171, loss=36.7577 epoch=1172, loss=36.6310 epoch=1173, loss=32.6013 epoch=1174, loss=29.5227 epoch=1175, loss=35.8207 epoch=1176, loss=35.8144 epoch=1177, loss=43.7553 epoch=1178, loss=31.8759 epoch=1179, loss=32.9536 epoch=1180, loss=34.0385 epoch=1181, loss=33.6179 epoch=1182, loss=48.9506 epoch=1183, loss=39.0245 epoch=1184, loss=44.1334 epoch=1185, loss=40.7664 epoch=1186, loss=33.6860 epoch=1187, loss=27.1803 epoch=1188, loss=22.2602 epoch=1189, loss=40.0744 epoch=1190, loss=34.6431 epoch=1191, loss=32.6041 epoch=1192, loss=37.8065 epoch=1193, loss=39.2743 epoch=1194, loss=28.5010 epoch=1195, loss=36.9890 epoch=1196, loss=38.7774 epoch=1197, loss=40.2295 epoch=1198, loss=45.9734 epoch=1199, loss=32.7846 epoch=1200, loss=35.4219 epoch=1201, loss=31.6757 epoch=1202, loss=20.8077 epoch=1203, loss=42.8190 epoch=1204, loss=35.4590 epoch=1205, loss=33.9555 epoch=1206, loss=38.9901 epoch=1207, loss=35.0332 epoch=1208, loss=27.4473 epoch=1209, loss=20.5175 epoch=1210, loss=34.2019 epoch=1211, loss=41.0577 epoch=1212, loss=34.4536 epoch=1213, loss=28.7619 epoch=1214, loss=35.7453 epoch=1215, loss=31.2096 epoch=1216, loss=35.4969 epoch=1217, loss=32.5666 epoch=1218, loss=40.1697 epoch=1219, loss=21.2583 epoch=1220, loss=36.9934 epoch=1221, loss=33.8079 epoch=1222, loss=30.8306 epoch=1223, loss=39.5826 epoch=1224, loss=40.4954 epoch=1225, loss=27.8841 epoch=1226, loss=32.9384 epoch=1227, loss=32.3499 epoch=1228, loss=26.3457 epoch=1229, 
loss=34.5355 epoch=1230, loss=45.5821 epoch=1231, loss=40.1966 epoch=1232, loss=19.0265 epoch=1233, loss=46.5220 epoch=1234, loss=43.6520 epoch=1235, loss=36.3505 epoch=1236, loss=34.3163 epoch=1237, loss=44.1153 epoch=1238, loss=39.0215 epoch=1239, loss=35.1996 epoch=1240, loss=40.9362 epoch=1241, loss=40.2023 epoch=1242, loss=28.3928 epoch=1243, loss=24.3768 epoch=1244, loss=40.4758 epoch=1245, loss=27.4730 epoch=1246, loss=24.2012 epoch=1247, loss=46.7035 epoch=1248, loss=43.2430 epoch=1249, loss=29.7535 epoch=1250, loss=21.3884 epoch=1251, loss=33.8545 epoch=1252, loss=25.7920 epoch=1253, loss=37.9737 epoch=1254, loss=38.1511 epoch=1255, loss=42.9844 epoch=1256, loss=34.4310 epoch=1257, loss=24.3918 epoch=1258, loss=22.3792 epoch=1259, loss=35.8306 epoch=1260, loss=41.1900 epoch=1261, loss=53.3297 epoch=1262, loss=39.2182 epoch=1263, loss=42.3259 epoch=1264, loss=38.3113 epoch=1265, loss=32.0263 epoch=1266, loss=44.0778 epoch=1267, loss=27.5000 epoch=1268, loss=43.9579 epoch=1269, loss=32.4862 epoch=1270, loss=34.7368 epoch=1271, loss=36.3911 epoch=1272, loss=41.4185 epoch=1273, loss=41.6621 epoch=1274, loss=48.0827 epoch=1275, loss=37.3628 epoch=1276, loss=32.3777 epoch=1277, loss=26.1825 epoch=1278, loss=23.6553 epoch=1279, loss=23.5205 epoch=1280, loss=37.0725 epoch=1281, loss=26.9658 epoch=1282, loss=36.7055 epoch=1283, loss=40.5629 epoch=1284, loss=34.3698 epoch=1285, loss=65.5482 epoch=1286, loss=25.3339 epoch=1287, loss=40.3472 epoch=1288, loss=35.8977 epoch=1289, loss=35.3034 epoch=1290, loss=24.6489 epoch=1291, loss=30.9076 epoch=1292, loss=32.2027 epoch=1293, loss=30.1449 epoch=1294, loss=38.2282 epoch=1295, loss=33.8271 epoch=1296, loss=36.1534 epoch=1297, loss=31.7768 epoch=1298, loss=31.2463 epoch=1299, loss=33.4478 epoch=1300, loss=32.8936 epoch=1301, loss=24.6013 epoch=1302, loss=29.7237 epoch=1303, loss=33.0355 epoch=1304, loss=29.6027 epoch=1305, loss=20.6213 epoch=1306, loss=39.2912 epoch=1307, loss=33.8585 epoch=1308, loss=27.9786 epoch=1309, 
loss=37.9520 epoch=1310, loss=27.3741 epoch=1311, loss=31.4076 epoch=1312, loss=43.2102 epoch=1313, loss=47.3487 epoch=1314, loss=23.7160 epoch=1315, loss=39.3521 epoch=1316, loss=34.5424 epoch=1317, loss=42.2042 epoch=1318, loss=48.8755 epoch=1319, loss=38.0471 epoch=1320, loss=45.6290 epoch=1321, loss=45.8212 epoch=1322, loss=25.6850 epoch=1323, loss=28.9088 epoch=1324, loss=24.9450 epoch=1325, loss=46.3202 epoch=1326, loss=38.0713 epoch=1327, loss=35.5920 epoch=1328, loss=36.7340 epoch=1329, loss=36.7376 epoch=1330, loss=33.7929 epoch=1331, loss=28.7036 epoch=1332, loss=34.7067 epoch=1333, loss=34.2267 epoch=1334, loss=28.4187 epoch=1335, loss=48.9013 epoch=1336, loss=39.2364 epoch=1337, loss=33.5179 epoch=1338, loss=27.4427 epoch=1339, loss=40.8810 epoch=1340, loss=32.8109 epoch=1341, loss=36.2374 epoch=1342, loss=34.9915 epoch=1343, loss=48.1464 epoch=1344, loss=35.1412 epoch=1345, loss=38.2447 epoch=1346, loss=32.2756 epoch=1347, loss=34.8562 epoch=1348, loss=32.7381 epoch=1349, loss=45.5488 epoch=1350, loss=42.3326 epoch=1351, loss=34.9011 epoch=1352, loss=27.3625 epoch=1353, loss=29.7940 epoch=1354, loss=30.0752 epoch=1355, loss=39.7541 epoch=1356, loss=21.0593 epoch=1357, loss=39.7403 epoch=1358, loss=25.7632 epoch=1359, loss=27.4025 epoch=1360, loss=42.4548 epoch=1361, loss=30.8795 epoch=1362, loss=34.3792 epoch=1363, loss=26.9746 epoch=1364, loss=46.9133 epoch=1365, loss=44.5003 epoch=1366, loss=34.8450 epoch=1367, loss=36.5696 epoch=1368, loss=28.2940 epoch=1369, loss=31.7687 epoch=1370, loss=31.1225 epoch=1371, loss=28.7888 epoch=1372, loss=41.3312 epoch=1373, loss=24.3718 epoch=1374, loss=30.1667 epoch=1375, loss=43.1038 epoch=1376, loss=35.0408 epoch=1377, loss=45.6852 epoch=1378, loss=23.8897 epoch=1379, loss=40.9500 epoch=1380, loss=27.0842 epoch=1381, loss=37.4781 epoch=1382, loss=27.9361 epoch=1383, loss=38.1767 epoch=1384, loss=30.1983 epoch=1385, loss=35.0150 epoch=1386, loss=29.0877 epoch=1387, loss=29.1462 epoch=1388, loss=38.7472 epoch=1389, 
loss=34.5128 epoch=1390, loss=26.4847 epoch=1391, loss=37.5297 epoch=1392, loss=46.5221 epoch=1393, loss=30.3338 epoch=1394, loss=43.5524 epoch=1395, loss=47.7259 epoch=1396, loss=38.2131 epoch=1397, loss=28.8958 epoch=1398, loss=38.0463 epoch=1399, loss=32.1844 epoch=1400, loss=31.8049 epoch=1401, loss=36.0688 epoch=1402, loss=36.3782 epoch=1403, loss=34.5230 epoch=1404, loss=31.9884 epoch=1405, loss=26.4623 epoch=1406, loss=28.9178 epoch=1407, loss=34.3496 epoch=1408, loss=42.2160 epoch=1409, loss=45.9547 epoch=1410, loss=28.7658 epoch=1411, loss=29.1827 epoch=1412, loss=38.7917 epoch=1413, loss=36.0446 epoch=1414, loss=28.3945 epoch=1415, loss=38.4445 epoch=1416, loss=48.1248 epoch=1417, loss=24.3291 epoch=1418, loss=32.1941 epoch=1419, loss=40.7814 epoch=1420, loss=34.9021 epoch=1421, loss=28.1727 epoch=1422, loss=33.6879 epoch=1423, loss=42.9682 epoch=1424, loss=28.9138 epoch=1425, loss=34.5130 epoch=1426, loss=28.2958 epoch=1427, loss=34.0243 epoch=1428, loss=52.9340 epoch=1429, loss=36.5021 epoch=1430, loss=29.9891 epoch=1431, loss=46.2166 epoch=1432, loss=37.6546 epoch=1433, loss=26.2562 epoch=1434, loss=29.1484 epoch=1435, loss=33.7357 epoch=1436, loss=35.7856 epoch=1437, loss=43.1520 epoch=1438, loss=23.5223 epoch=1439, loss=37.2715 epoch=1440, loss=37.2932 epoch=1441, loss=39.1635 epoch=1442, loss=39.8090 epoch=1443, loss=37.1292 epoch=1444, loss=42.5673 epoch=1445, loss=30.0301 epoch=1446, loss=40.1862 epoch=1447, loss=26.3388 epoch=1448, loss=26.1381 epoch=1449, loss=31.9983 epoch=1450, loss=35.8569 epoch=1451, loss=18.6316 epoch=1452, loss=35.4459 epoch=1453, loss=23.9700 epoch=1454, loss=28.1707 epoch=1455, loss=46.6944 epoch=1456, loss=36.3569 epoch=1457, loss=25.8404 epoch=1458, loss=37.0917 epoch=1459, loss=41.3789 epoch=1460, loss=31.4906 epoch=1461, loss=32.2165 epoch=1462, loss=35.4388 epoch=1463, loss=36.7359 epoch=1464, loss=54.2147 epoch=1465, loss=40.9424 epoch=1466, loss=35.2999 epoch=1467, loss=33.0847 epoch=1468, loss=37.4095 epoch=1469, 
loss=25.4988 epoch=1470, loss=37.7319 epoch=1471, loss=33.4607 epoch=1472, loss=29.9038 epoch=1473, loss=30.6572 epoch=1474, loss=32.1203 epoch=1475, loss=31.6620 epoch=1476, loss=32.0512 epoch=1477, loss=34.6661 epoch=1478, loss=26.7730 epoch=1479, loss=45.8692 epoch=1480, loss=34.8963 epoch=1481, loss=41.5471 epoch=1482, loss=55.0377 epoch=1483, loss=32.9687 epoch=1484, loss=36.4512 epoch=1485, loss=36.8798 epoch=1486, loss=28.7513 epoch=1487, loss=27.4820 epoch=1488, loss=33.4598 epoch=1489, loss=28.5997 epoch=1490, loss=34.9837 epoch=1491, loss=34.0515 epoch=1492, loss=34.0678 epoch=1493, loss=40.5191 epoch=1494, loss=40.1758 epoch=1495, loss=32.9277 epoch=1496, loss=35.6981 epoch=1497, loss=37.2294 epoch=1498, loss=25.2547 epoch=1499, loss=32.5763 epoch=1500, loss=19.6940 epoch=1501, loss=25.7013 epoch=1502, loss=33.3048 epoch=1503, loss=37.1762 epoch=1504, loss=28.7383 epoch=1505, loss=23.0517 epoch=1506, loss=44.1326 epoch=1507, loss=34.9054 epoch=1508, loss=45.7176 epoch=1509, loss=38.1900 epoch=1510, loss=39.1193 epoch=1511, loss=45.1884 epoch=1512, loss=31.4161 epoch=1513, loss=20.9701 epoch=1514, loss=26.8097 epoch=1515, loss=35.9537 epoch=1516, loss=40.7580 epoch=1517, loss=36.0942 epoch=1518, loss=36.3071 epoch=1519, loss=30.0816 epoch=1520, loss=37.9924 epoch=1521, loss=33.5203 epoch=1522, loss=28.2977 epoch=1523, loss=25.2874 epoch=1524, loss=34.9287 epoch=1525, loss=32.7701 epoch=1526, loss=29.0807 epoch=1527, loss=37.3032 epoch=1528, loss=30.2755 epoch=1529, loss=27.3569 epoch=1530, loss=36.5855 epoch=1531, loss=24.0321 epoch=1532, loss=47.2285 epoch=1533, loss=40.5877 epoch=1534, loss=39.0814 epoch=1535, loss=23.1268 epoch=1536, loss=29.5388 epoch=1537, loss=29.2759 epoch=1538, loss=33.0768 epoch=1539, loss=36.9392 epoch=1540, loss=44.2638 epoch=1541, loss=51.0828 epoch=1542, loss=31.2280 epoch=1543, loss=31.0411 epoch=1544, loss=29.3206 epoch=1545, loss=25.4009 epoch=1546, loss=29.0619 epoch=1547, loss=29.3662 epoch=1548, loss=43.8614 epoch=1549, 
loss=33.0834 epoch=1550, loss=38.1879 epoch=1551, loss=36.0901 epoch=1552, loss=33.1927 epoch=1553, loss=31.8870 epoch=1554, loss=41.4068 epoch=1555, loss=36.1541 epoch=1556, loss=31.7799 epoch=1557, loss=30.4052 epoch=1558, loss=27.6859 epoch=1559, loss=30.0883 epoch=1560, loss=37.6626 epoch=1561, loss=37.5391 epoch=1562, loss=37.2008 epoch=1563, loss=35.7859 epoch=1564, loss=31.0485 epoch=1565, loss=42.8876 epoch=1566, loss=33.3878 epoch=1567, loss=33.9048 epoch=1568, loss=28.5611 epoch=1569, loss=56.6481 epoch=1570, loss=37.7521 epoch=1571, loss=34.0539 epoch=1572, loss=38.2804 epoch=1573, loss=33.0363 epoch=1574, loss=49.9077 epoch=1575, loss=30.2550 epoch=1576, loss=34.8463 epoch=1577, loss=33.5766 epoch=1578, loss=29.2131 epoch=1579, loss=29.7863 epoch=1580, loss=29.3244 epoch=1581, loss=41.7618 epoch=1582, loss=31.1525 epoch=1583, loss=26.3711 epoch=1584, loss=31.2908 epoch=1585, loss=41.2849 epoch=1586, loss=30.1948 epoch=1587, loss=43.9187 epoch=1588, loss=30.0660 epoch=1589, loss=36.5963 epoch=1590, loss=39.6431 epoch=1591, loss=30.0314 epoch=1592, loss=33.8546 epoch=1593, loss=47.1560 epoch=1594, loss=34.4838 epoch=1595, loss=33.3091 epoch=1596, loss=36.6789 epoch=1597, loss=42.9379 epoch=1598, loss=37.5283 epoch=1599, loss=40.5685 epoch=1600, loss=28.9754 epoch=1601, loss=40.8618 epoch=1602, loss=40.8115 epoch=1603, loss=35.8314 epoch=1604, loss=29.7943 epoch=1605, loss=36.5864 epoch=1606, loss=34.8996 epoch=1607, loss=37.0557 epoch=1608, loss=25.0684 epoch=1609, loss=37.0592 epoch=1610, loss=31.3310 epoch=1611, loss=29.1245 epoch=1612, loss=45.6712 epoch=1613, loss=34.1055 epoch=1614, loss=43.7048 epoch=1615, loss=52.7717 epoch=1616, loss=32.6442 epoch=1617, loss=30.8302 epoch=1618, loss=38.3865 epoch=1619, loss=31.0302 epoch=1620, loss=36.1039 epoch=1621, loss=26.7771 epoch=1622, loss=23.1458 epoch=1623, loss=31.1163 epoch=1624, loss=35.4052 epoch=1625, loss=29.2810 epoch=1626, loss=26.3598 epoch=1627, loss=32.1139 epoch=1628, loss=30.6004 epoch=1629, 
loss=27.8413 epoch=1630, loss=32.1856 epoch=1631, loss=60.2945 epoch=1632, loss=32.8952 epoch=1633, loss=28.2247 epoch=1634, loss=40.7143 epoch=1635, loss=29.2681 epoch=1636, loss=34.8334 epoch=1637, loss=35.9697 epoch=1638, loss=25.3400 epoch=1639, loss=27.4441 epoch=1640, loss=53.5048 epoch=1641, loss=27.4171 epoch=1642, loss=28.0155 epoch=1643, loss=46.8164 epoch=1644, loss=28.9093 epoch=1645, loss=32.7978 epoch=1646, loss=50.6372 epoch=1647, loss=29.7534 epoch=1648, loss=30.3780 epoch=1649, loss=34.3235 epoch=1650, loss=29.1777 epoch=1651, loss=37.8256 epoch=1652, loss=34.4120 epoch=1653, loss=39.7161 epoch=1654, loss=36.0549 epoch=1655, loss=29.4830 epoch=1656, loss=35.1932 epoch=1657, loss=46.3637 epoch=1658, loss=22.3256 epoch=1659, loss=34.9382 epoch=1660, loss=29.0200 epoch=1661, loss=31.4598 epoch=1662, loss=40.1114 epoch=1663, loss=33.6175 epoch=1664, loss=30.2861 epoch=1665, loss=40.6111 epoch=1666, loss=58.0736 epoch=1667, loss=39.6817 epoch=1668, loss=36.3547 epoch=1669, loss=34.3968 epoch=1670, loss=45.9183 epoch=1671, loss=40.1887 epoch=1672, loss=43.8078 epoch=1673, loss=46.5437 epoch=1674, loss=23.7588 epoch=1675, loss=38.5523 epoch=1676, loss=40.9734 epoch=1677, loss=39.5176 epoch=1678, loss=31.7523 epoch=1679, loss=26.0209 epoch=1680, loss=28.3806 epoch=1681, loss=34.1911 epoch=1682, loss=35.0529 epoch=1683, loss=34.5082 epoch=1684, loss=31.7266 epoch=1685, loss=36.3510 epoch=1686, loss=50.4109 epoch=1687, loss=22.8710 epoch=1688, loss=58.3529 epoch=1689, loss=29.7358 epoch=1690, loss=39.1148 epoch=1691, loss=32.0888 epoch=1692, loss=25.9282 epoch=1693, loss=33.8056 epoch=1694, loss=31.5784 epoch=1695, loss=31.4391 epoch=1696, loss=30.5438 epoch=1697, loss=38.3535 epoch=1698, loss=34.4427 epoch=1699, loss=33.2623 epoch=1700, loss=31.7864 epoch=1701, loss=42.3128 epoch=1702, loss=38.4971 epoch=1703, loss=31.6685 epoch=1704, loss=34.9853 epoch=1705, loss=18.5508 epoch=1706, loss=40.1604 epoch=1707, loss=28.1061 epoch=1708, loss=37.5067 epoch=1709, 
loss=33.8658 epoch=1710, loss=30.8775 epoch=1711, loss=49.1571 epoch=1712, loss=28.7557 epoch=1713, loss=40.6502 epoch=1714, loss=39.2319 epoch=1715, loss=41.1720 epoch=1716, loss=28.6397 epoch=1717, loss=43.3470 epoch=1718, loss=31.4252 epoch=1719, loss=41.6643 epoch=1720, loss=31.6779 epoch=1721, loss=56.8492 epoch=1722, loss=36.2753 epoch=1723, loss=40.1690 epoch=1724, loss=42.5956 epoch=1725, loss=30.4588 epoch=1726, loss=52.1502 epoch=1727, loss=31.6561 epoch=1728, loss=32.3665 epoch=1729, loss=32.0891 epoch=1730, loss=29.9402 epoch=1731, loss=31.4298 epoch=1732, loss=45.5506 epoch=1733, loss=40.9761 epoch=1734, loss=27.5720 epoch=1735, loss=41.0700 epoch=1736, loss=40.5480 epoch=1737, loss=30.2543 epoch=1738, loss=45.0808 epoch=1739, loss=32.1093 epoch=1740, loss=52.5572 epoch=1741, loss=40.4674 epoch=1742, loss=43.8624 epoch=1743, loss=43.5267 epoch=1744, loss=25.0396 epoch=1745, loss=19.2728 epoch=1746, loss=38.8592 epoch=1747, loss=27.5246 epoch=1748, loss=38.7391 epoch=1749, loss=33.6800 epoch=1750, loss=33.0100 epoch=1751, loss=31.9837 epoch=1752, loss=43.6407 epoch=1753, loss=30.2375 epoch=1754, loss=43.9711 epoch=1755, loss=30.8167 epoch=1756, loss=32.8285 epoch=1757, loss=30.8861 epoch=1758, loss=34.9955 epoch=1759, loss=26.9971 epoch=1760, loss=46.0442 epoch=1761, loss=41.6299 epoch=1762, loss=37.3062 epoch=1763, loss=43.3887 epoch=1764, loss=28.2057 epoch=1765, loss=42.0531 epoch=1766, loss=42.6729 epoch=1767, loss=46.9954 epoch=1768, loss=34.0256 epoch=1769, loss=28.6126 epoch=1770, loss=36.1191 epoch=1771, loss=31.3382 epoch=1772, loss=32.0284 epoch=1773, loss=30.1594 epoch=1774, loss=35.4324 epoch=1775, loss=35.1956 epoch=1776, loss=31.3684 epoch=1777, loss=46.1910 epoch=1778, loss=27.7542 epoch=1779, loss=27.1898 epoch=1780, loss=21.9117 epoch=1781, loss=36.2531 epoch=1782, loss=43.3739 epoch=1783, loss=27.5341 epoch=1784, loss=24.9797 epoch=1785, loss=41.3120 epoch=1786, loss=34.2099 epoch=1787, loss=54.7183 epoch=1788, loss=30.1699 epoch=1789, 
loss=37.7671 epoch=1790, loss=47.7613 epoch=1791, loss=28.9987 epoch=1792, loss=28.2261 epoch=1793, loss=29.5456 epoch=1794, loss=23.8020 epoch=1795, loss=28.6554 epoch=1796, loss=27.5830 epoch=1797, loss=39.7724 epoch=1798, loss=37.2909 epoch=1799, loss=32.4331 epoch=1800, loss=28.8356 epoch=1801, loss=39.7177 epoch=1802, loss=32.5468 epoch=1803, loss=29.4262 epoch=1804, loss=26.1700 epoch=1805, loss=29.5171 epoch=1806, loss=51.8968 epoch=1807, loss=34.4221 epoch=1808, loss=33.7871 epoch=1809, loss=32.0702 epoch=1810, loss=31.3400 epoch=1811, loss=27.9518 epoch=1812, loss=39.6344 epoch=1813, loss=35.9245 epoch=1814, loss=42.7269 epoch=1815, loss=37.4334 epoch=1816, loss=35.6348 epoch=1817, loss=40.1131 epoch=1818, loss=38.9279 epoch=1819, loss=35.3195 epoch=1820, loss=48.8344 epoch=1821, loss=30.6186 epoch=1822, loss=36.9466 epoch=1823, loss=28.7751 epoch=1824, loss=27.8019 epoch=1825, loss=39.9999 epoch=1826, loss=29.6051 epoch=1827, loss=35.7877 epoch=1828, loss=35.4209 epoch=1829, loss=39.5553 epoch=1830, loss=37.8764 epoch=1831, loss=42.5659 epoch=1832, loss=30.1320 epoch=1833, loss=46.6477 epoch=1834, loss=38.4669 epoch=1835, loss=25.2448 epoch=1836, loss=33.5906 epoch=1837, loss=31.8231 epoch=1838, loss=43.5205 epoch=1839, loss=30.8798 epoch=1840, loss=32.5020 epoch=1841, loss=34.9660 epoch=1842, loss=32.1129 epoch=1843, loss=35.1489 epoch=1844, loss=35.3684 epoch=1845, loss=34.0432 epoch=1846, loss=47.0925 epoch=1847, loss=36.6252 epoch=1848, loss=35.7320 epoch=1849, loss=37.8639 epoch=1850, loss=32.5404 epoch=1851, loss=46.5566 epoch=1852, loss=31.8804 epoch=1853, loss=31.8740 epoch=1854, loss=28.8555 epoch=1855, loss=39.6632 epoch=1856, loss=35.9543 epoch=1857, loss=33.4105 epoch=1858, loss=42.2088 epoch=1859, loss=30.4224 epoch=1860, loss=24.3301 epoch=1861, loss=37.5396 epoch=1862, loss=24.0663 epoch=1863, loss=46.6297 epoch=1864, loss=25.2938 epoch=1865, loss=52.2606 epoch=1866, loss=33.0870 epoch=1867, loss=29.6333 epoch=1868, loss=34.0967 epoch=1869, 
loss=46.1038 epoch=1870, loss=44.0608 epoch=1871, loss=42.7982 epoch=1872, loss=30.8490 epoch=1873, loss=35.0194 epoch=1874, loss=25.9710 epoch=1875, loss=35.7427 epoch=1876, loss=35.8160 epoch=1877, loss=43.2334 epoch=1878, loss=26.2880 epoch=1879, loss=32.7721 epoch=1880, loss=34.7454 epoch=1881, loss=36.5345 epoch=1882, loss=39.8898 epoch=1883, loss=34.8573 epoch=1884, loss=22.9588 epoch=1885, loss=28.2082 epoch=1886, loss=47.4568 epoch=1887, loss=36.7360 epoch=1888, loss=32.4228 epoch=1889, loss=27.7855 epoch=1890, loss=26.3805 epoch=1891, loss=31.8914 epoch=1892, loss=45.6464 epoch=1893, loss=30.4516 epoch=1894, loss=26.3911 epoch=1895, loss=28.8419 epoch=1896, loss=37.7788 epoch=1897, loss=37.7130 epoch=1898, loss=40.2391 epoch=1899, loss=49.8949 epoch=1900, loss=33.5475 epoch=1901, loss=36.0506 epoch=1902, loss=38.4378 epoch=1903, loss=32.5662 epoch=1904, loss=33.0551 epoch=1905, loss=32.6487 epoch=1906, loss=32.2130 epoch=1907, loss=44.5138 epoch=1908, loss=25.4023 epoch=1909, loss=37.0736 epoch=1910, loss=39.4626 epoch=1911, loss=34.6423 epoch=1912, loss=30.6426 epoch=1913, loss=46.5408 epoch=1914, loss=31.7788 epoch=1915, loss=36.8644 epoch=1916, loss=40.8513 epoch=1917, loss=31.2167 epoch=1918, loss=29.0489 epoch=1919, loss=36.5178 epoch=1920, loss=26.8812 epoch=1921, loss=44.3916 epoch=1922, loss=33.4103 epoch=1923, loss=43.9588 epoch=1924, loss=43.0485 epoch=1925, loss=32.9726 epoch=1926, loss=39.2526 epoch=1927, loss=33.8366 epoch=1928, loss=29.3006 epoch=1929, loss=25.8917 epoch=1930, loss=46.2496 epoch=1931, loss=35.0287 epoch=1932, loss=39.5581 epoch=1933, loss=43.1998 epoch=1934, loss=43.7446 epoch=1935, loss=50.5408 epoch=1936, loss=38.4567 epoch=1937, loss=33.6322 epoch=1938, loss=40.0772 epoch=1939, loss=32.5416 epoch=1940, loss=28.5272 epoch=1941, loss=19.7723 epoch=1942, loss=43.1280 epoch=1943, loss=24.6360 epoch=1944, loss=27.3806 epoch=1945, loss=43.2488 epoch=1946, loss=32.7586 epoch=1947, loss=32.0554 epoch=1948, loss=34.7635 epoch=1949, 
loss=34.7666 epoch=1950, loss=26.2582 epoch=1951, loss=39.6517 epoch=1952, loss=31.0023 epoch=1953, loss=41.1355 epoch=1954, loss=24.0179 epoch=1955, loss=35.2110 epoch=1956, loss=41.1189 epoch=1957, loss=36.2977 epoch=1958, loss=43.1321 epoch=1959, loss=38.5753 epoch=1960, loss=38.8746 epoch=1961, loss=36.1261 epoch=1962, loss=28.0169 epoch=1963, loss=26.2852 epoch=1964, loss=41.9416 epoch=1965, loss=40.1778 epoch=1966, loss=40.0545 epoch=1967, loss=37.0752 epoch=1968, loss=27.1719 epoch=1969, loss=27.8567 epoch=1970, loss=33.7178 epoch=1971, loss=35.1780 epoch=1972, loss=33.4344 epoch=1973, loss=30.2588 epoch=1974, loss=41.6152 epoch=1975, loss=31.6988 epoch=1976, loss=40.2418 epoch=1977, loss=40.8906 epoch=1978, loss=25.8763 epoch=1979, loss=28.6479 epoch=1980, loss=26.2841 epoch=1981, loss=37.0122 epoch=1982, loss=36.8730 epoch=1983, loss=39.5089 epoch=1984, loss=19.0150 epoch=1985, loss=37.0378 epoch=1986, loss=45.1676 epoch=1987, loss=34.5720 epoch=1988, loss=46.5605 epoch=1989, loss=44.1837 epoch=1990, loss=42.8904 epoch=1991, loss=47.9173 epoch=1992, loss=37.9717 epoch=1993, loss=32.5989 epoch=1994, loss=27.9967 epoch=1995, loss=21.6590 epoch=1996, loss=32.9677 epoch=1997, loss=34.2821 epoch=1998, loss=29.9108 epoch=1999, loss=32.4362 .. GENERATED FROM PYTHON SOURCE LINES 315-324 .. code-block:: default inducing_index_points_history = history.pop("inducing_index_points") variational_inducing_observations_loc_history = ( history.pop("variational_inducing_observations_loc")) inducing_index_points = inducing_index_points_history[-1] variational_inducing_observations_loc = ( variational_inducing_observations_loc_history[-1]) .. GENERATED FROM PYTHON SOURCE LINES 325-326 Log density ratio, log-odds, or logits. .. GENERATED FROM PYTHON SOURCE LINES 326-348 .. 
code-block:: default fig, ax = plt.subplots() ax.plot(X_q, model(X_q).mean().numpy().T, label="posterior mean") fill_between_stddev(X_q.squeeze(), model(X_q).mean().numpy().squeeze(), model(X_q).stddev().numpy().squeeze(), alpha=0.1, label="posterior std dev", ax=ax) ax.scatter(inducing_index_points, np.full_like(inducing_index_points, -3.5), marker='^', c="tab:gray", label="inducing inputs", alpha=0.4) ax.scatter(inducing_index_points, variational_inducing_observations_loc, marker='+', c="tab:blue", label="inducing variable mean") ax.set_xlabel(r"$x$") ax.set_ylabel(r"$\log \lambda(x)$") ax.legend() plt.show() .. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_log_cox_gaussian_process_keras_002.png :alt: plot sparse log cox gaussian process keras :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 349-352 .. code-block:: default Z_q = scaler.inverse_transform(X_q) .. GENERATED FROM PYTHON SOURCE LINES 353-357 .. code-block:: default d = tfd.Independent(tfd.LogNormal(loc=model(X_q).mean(), scale=model(X_q).stddev()), reinterpreted_batch_ndims=1) .. GENERATED FROM PYTHON SOURCE LINES 358-359 Density ratio. .. GENERATED FROM PYTHON SOURCE LINES 359-380 .. code-block:: default fig, ax = plt.subplots() ax.plot(X_q, d.mean().numpy().T, label="transformed posterior mean") fill_between_stddev(X_q.squeeze(), d.mean().numpy().squeeze(), d.stddev().numpy().squeeze(), alpha=0.1, label="transformed posterior std dev", ax=ax) ax.vlines(X.squeeze(), ymin=-0.025, ymax=0.0, linewidth=0.6 * y) ax.set_xlabel('$x$') ax.set_ylim(y_min, y_max) ax.set_xlabel(r"$x$") ax.set_ylabel(r"$\lambda(x)$") ax.legend() plt.show() .. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_log_cox_gaussian_process_keras_003.png :alt: plot sparse log cox gaussian process keras :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 381-382 Predictive mean samples. .. GENERATED FROM PYTHON SOURCE LINES 382-386 .. 
code-block:: default posterior_predictive = tf.keras.Sequential([ model, tfp.layers.IndependentPoisson(event_shape=(num_index_points,)) ]) .. GENERATED FROM PYTHON SOURCE LINES 387-400 .. code-block:: default fig, ax = plt.subplots() ax.plot(X_q, posterior_predictive(X_q).mean()) ax.vlines(X.squeeze(), ymin=-0.025, ymax=0.0, linewidth=0.6 * y) ax.set_xlabel('$x$') ax.set_ylim(y_min, y_max) # ax.legend() plt.show() .. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_log_cox_gaussian_process_keras_004.png :alt: plot sparse log cox gaussian process keras :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 401-413 .. code-block:: default def make_posterior_predictive(num_samples=None, seed=None): def posterior_predictive(x): f_samples = model(x).sample(num_samples, seed=seed) return make_poisson_likelihood(f=f_samples) return posterior_predictive .. GENERATED FROM PYTHON SOURCE LINES 414-418 .. code-block:: default posterior_predictive = make_posterior_predictive(num_samples, seed=seed) .. GENERATED FROM PYTHON SOURCE LINES 419-433 .. code-block:: default fig, ax = plt.subplots() ax.plot(X_q, posterior_predictive(X_q).mean().numpy().T, color="tab:blue", linewidth=0.8, alpha=0.6) ax.vlines(X.squeeze(), ymin=-0.025, ymax=0.0, linewidth=0.6 * y) ax.set_xlabel('$x$') ax.set_ylim(y_min, y_max) # ax.legend() plt.show() .. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_log_cox_gaussian_process_keras_005.png :alt: plot sparse log cox gaussian process keras :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 434-447 .. code-block:: default def get_inducing_index_points_data(inducing_index_points): df = pd.DataFrame(np.hstack(inducing_index_points).T) df.index.name = "epoch" df.columns.name = "inducing index points" s = df.stack() s.name = 'x' return s.reset_index() .. GENERATED FROM PYTHON SOURCE LINES 448-452 .. code-block:: default data = get_inducing_index_points_data(inducing_index_points_history) .. 
GENERATED FROM PYTHON SOURCE LINES 453-464 .. code-block:: default fig, ax = plt.subplots() sns.lineplot(x='x', y="epoch", hue="inducing index points", palette="viridis", sort=False, data=data, alpha=0.8, ax=ax) ax.set_xlabel(r'$x$') plt.show() .. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_log_cox_gaussian_process_keras_006.png :alt: plot sparse log cox gaussian process keras :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 465-469 .. code-block:: default variational_inducing_observations_scale_history = ( history.pop("variational_inducing_observations_scale")) .. GENERATED FROM PYTHON SOURCE LINES 470-487 .. code-block:: default fig, (ax1, ax2) = plt.subplots(ncols=2, sharex=True, sharey=True) im1 = ax1.imshow(variational_inducing_observations_scale_history[0], vmin=-0.1, vmax=1.1) im2 = ax2.imshow(variational_inducing_observations_scale_history[-1], vmin=-0.1, vmax=1.1) fig.colorbar(im2, ax=[ax1, ax2], extend="both", orientation="horizontal") ax1.set_xlabel(r"$i$") ax1.set_ylabel(r"$j$") ax2.set_xlabel(r"$i$") plt.show() .. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_log_cox_gaussian_process_keras_007.png :alt: plot sparse log cox gaussian process keras :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 488-493 .. code-block:: default history_df = pd.DataFrame(history) history_df.index.name = "epoch" history_df.reset_index(inplace=True) .. GENERATED FROM PYTHON SOURCE LINES 494-502 .. code-block:: default fig, ax = plt.subplots() sns.lineplot(x="epoch", y="nelbo", data=history_df, alpha=0.8, ax=ax) ax.set_yscale("log") plt.show() .. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_log_cox_gaussian_process_keras_008.png :alt: plot sparse log cox gaussian process keras :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 503-507 .. 
code-block:: default parameters_df = history_df.drop(columns="nelbo") \ .rename(columns=lambda s: s.replace('_', ' ')) .. GENERATED FROM PYTHON SOURCE LINES 508-511 .. code-block:: default g = sns.PairGrid(parameters_df, hue="epoch", palette="RdYlBu", corner=True) g = g.map_lower(plt.scatter, facecolor="none", alpha=0.6) .. image:: /auto_examples/gaussian_processes/images/sphx_glr_plot_sparse_log_cox_gaussian_process_keras_009.png :alt: plot sparse log cox gaussian process keras :class: sphx-glr-single-img .. rst-class:: sphx-glr-timing **Total running time of the script:** ( 3 minutes 42.131 seconds) .. _sphx_glr_download_auto_examples_gaussian_processes_plot_sparse_log_cox_gaussian_process_keras.py: .. only :: html .. container:: sphx-glr-footer :class: sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_sparse_log_cox_gaussian_process_keras.py <plot_sparse_log_cox_gaussian_process_keras.py>` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_sparse_log_cox_gaussian_process_keras.ipynb <plot_sparse_log_cox_gaussian_process_keras.ipynb>` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_