Note
Click here to download the full example code
Variational Sparse Log Gaussian Cox Process
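Here we fit a sparse variational Gaussian process with a Poisson likelihood (a log Gaussian Cox process) to the coal mining disasters data. The variational parameters and the kernel hyperparameters are fitted jointly by maximizing the evidence lower bound (ELBO) on the (log) marginal likelihood. For the hyperparameters, this is a variational approximation to empirical Bayes, also known as type-II maximum likelihood estimation.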
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tensorflow.keras.layers import Layer, InputLayer
from tensorflow.keras.initializers import Identity, Constant
from sklearn.preprocessing import MinMaxScaler
from etudes.datasets import coal_mining_disasters_load_data
from etudes.plotting import fill_between_stddev
from etudes.utils import get_kl_weight
from collections import defaultdict
# shortcuts
tfd = tfp.distributions
kernels = tfp.math.psd_kernels

# constants
# Number of training points; used below (via `get_kl_weight`) to weight the
# KL term of the mini-batch objective.
# NOTE(review): 2048 looks carried over from a synthetic-data example; it
# should presumably match the number of rows of the coal mining data
# (`len(X)`) -- confirm before relying on the KL weighting.
num_train = 2048
num_test = 40
num_features = 1  # dimensionality of each input vector
num_index_points = 256  # nbr of points in the evaluation grid `X_q`
num_samples = 25
quadrature_size = 20  # quadrature points for the expected log-likelihood
num_inducing_points = 50
num_epochs = 2000
batch_size = 64
shuffle_buffer_size = 500
jitter = 1e-6  # passed to the variational GP layer as its `jitter`
kernel_cls = kernels.MaternFiveHalves
seed = 8888  # set random seed for reproducibility
random_state = np.random.RandomState(seed)
x_min, x_max = 0.0, 1.0  # range of the min-max scaled inputs
y_min, y_max = -0.05, 0.7

# index points: a regular grid over [x_min, x_max]
X_q = np.linspace(x_min, x_max, num_index_points).reshape(-1, num_features)
Coal mining disasters dataset
# Load the raw inputs `Z` and observations `y`, then rescale the inputs with
# a min-max scaler (kept around so the scaling can be reused later).
Z, y = coal_mining_disasters_load_data(base_dir="../../datasets/")
scaler = MinMaxScaler()
X = scaler.fit(Z).transform(Z)
# Observations as float64, matching the float64 model built below.
y = y.astype(np.float64)
Probability densities
# One short vertical tick per observation just below zero, with line width
# proportional to the observed value.
fig, ax = plt.subplots()
ax.vlines(Z.squeeze(), ymin=-0.025, ymax=0.0, linewidth=0.6 * y)
ax.set(ylim=(-0.05, 0.8), xlabel="days", ylabel="incidents")
plt.show()
Encapsulate the variational Gaussian process (with its particular variable initialization) in a Keras / TensorFlow Probability mixin layer. This stays clean and simple if we restrict ourselves to a single output (event_shape = ()) and feature_ndim = 1 (i.e. inputs are vectors rather than matrices or higher-order tensors).
class VariationalGaussianProcess1D(tfp.layers.DistributionLambda):
    """Keras layer whose output is a single-output variational Gaussian process.

    The layer owns the variational parameters (inducing index points and the
    loc / scale of the variational distribution over inducing observations)
    plus a log observation-noise variance. The kernel comes from
    ``kernel_wrapper``. Calling the layer on index points ``x`` returns a
    ``tfd.VariationalGaussianProcess`` over those points.

    Args:
        kernel_wrapper: layer exposing a ``kernel`` property and a ``dtype``;
            its dtype is used for this layer and all of its weights.
        num_inducing_points: number of inducing index points.
        inducing_index_points_initializer: Keras initializer for the inducing
            index points, of shape ``(num_inducing_points, input_dim)``.
        mean_fn: optional mean function, passed through to the GP.
        jitter: passed through to ``tfd.VariationalGaussianProcess``.
        convert_to_tensor_fn: how the output distribution is turned into a
            tensor when used downstream (default: draw a sample).
        **kwargs: forwarded to ``tfp.layers.DistributionLambda`` (e.g.
            ``name``). Any ``dtype`` entry is ignored because the dtype is
            dictated by ``kernel_wrapper``.
    """

    def __init__(self, kernel_wrapper, num_inducing_points,
                 inducing_index_points_initializer, mean_fn=None, jitter=1e-6,
                 convert_to_tensor_fn=tfd.Distribution.sample, **kwargs):

        def make_distribution(x):
            # The variational weights only exist after `build`, so they are
            # read from `self` lazily, each time the layer is called.
            return VariationalGaussianProcess1D.new(
                x, kernel_wrapper=self.kernel_wrapper,
                inducing_index_points=self.inducing_index_points,
                variational_inducing_observations_loc=(
                    self.variational_inducing_observations_loc),
                variational_inducing_observations_scale=(
                    self.variational_inducing_observations_scale),
                mean_fn=self.mean_fn,
                observation_noise_variance=tf.exp(
                    self.log_observation_noise_variance),
                jitter=self.jitter)

        # Forward remaining keyword arguments (e.g. `name`) to the base layer
        # instead of silently discarding them. The dtype always follows the
        # kernel wrapper, so a caller-supplied `dtype` is dropped to avoid a
        # duplicate-keyword TypeError.
        kwargs.pop("dtype", None)
        super(VariationalGaussianProcess1D, self).__init__(
            make_distribution_fn=make_distribution,
            convert_to_tensor_fn=convert_to_tensor_fn,
            dtype=kernel_wrapper.dtype, **kwargs)
        self.kernel_wrapper = kernel_wrapper
        self.inducing_index_points_initializer = inducing_index_points_initializer
        self.num_inducing_points = num_inducing_points
        self.mean_fn = mean_fn
        self.jitter = jitter
        self._dtype = self.kernel_wrapper.dtype

    def build(self, input_shape):
        """Create the variational weights for inputs with ``input_shape[-1]`` features."""
        input_dim = input_shape[-1]
        # TODO: Fix initialization!
        self.inducing_index_points = self.add_weight(
            name="inducing_index_points",
            shape=(self.num_inducing_points, input_dim),
            initializer=self.inducing_index_points_initializer,
            dtype=self.dtype)
        # Loc of the variational distribution over inducing observations.
        self.variational_inducing_observations_loc = self.add_weight(
            name="variational_inducing_observations_loc",
            shape=(self.num_inducing_points,),
            initializer="zeros", dtype=self.dtype)
        # Scale matrix of that distribution, initialized to the identity.
        self.variational_inducing_observations_scale = self.add_weight(
            name="variational_inducing_observations_scale",
            shape=(self.num_inducing_points, self.num_inducing_points),
            initializer=Identity(gain=1.0), dtype=self.dtype)
        # Scalar log observation-noise variance (exponentiated when used).
        self.log_observation_noise_variance = self.add_weight(
            name="log_observation_noise_variance",
            initializer=Constant(-5.0), dtype=self.dtype)

    @staticmethod
    def new(x, kernel_wrapper, inducing_index_points, mean_fn,
            variational_inducing_observations_loc,
            variational_inducing_observations_scale,
            observation_noise_variance, jitter, name=None):
        """Construct a ``tfd.VariationalGaussianProcess`` at index points ``x``.

        The kernel is read from ``kernel_wrapper.kernel``; every other argument
        is passed straight through. ``name`` is forwarded only when given, so
        the TFP default distribution name is kept otherwise.
        """
        name_kwargs = {} if name is None else {"name": name}
        return tfd.VariationalGaussianProcess(
            kernel=kernel_wrapper.kernel, index_points=x,
            inducing_index_points=inducing_index_points,
            variational_inducing_observations_loc=(
                variational_inducing_observations_loc),
            variational_inducing_observations_scale=(
                variational_inducing_observations_scale),
            mean_fn=mean_fn,
            observation_noise_variance=observation_noise_variance,
            jitter=jitter, **name_kwargs)
Kernel wrapper layer
class KernelWrapper(Layer):
    """Keras layer holding kernel hyperparameters as trainable weights.

    The amplitude and length scale are stored on the log scale (both start at
    zero, i.e. a value of one). They are exponentiated each time ``kernel`` is
    accessed, which keeps the effective hyperparameters positive.
    """
    # TODO: Support automatic relevance determination

    def __init__(self, kernel_cls=kernels.ExponentiatedQuadratic,
                 dtype=None, **kwargs):
        super(KernelWrapper, self).__init__(dtype=dtype, **kwargs)
        self.kernel_cls = kernel_cls

        def log_scalar(weight_name):
            # Scalar weight starting at zero on the log scale.
            return self.add_weight(name=weight_name, initializer="zeros",
                                   dtype=dtype)

        self.log_amplitude = log_scalar("log_amplitude")
        self.log_length_scale = log_scalar("log_length_scale")

    def call(self, x):
        # Identity pass-through. The layer is not meant to be applied; it
        # exists so Keras tracks the hyperparameter variables.
        return x

    @property
    def kernel(self):
        """Return a new ``kernel_cls`` instance built from the current weights."""
        amplitude = tf.exp(self.log_amplitude)
        length_scale = tf.exp(self.log_length_scale)
        return self.kernel_cls(amplitude=amplitude, length_scale=length_scale)
Poisson likelihood.
def make_poisson_likelihood(f):
    """Return a Poisson likelihood whose log-rate is ``f``.

    The rightmost batch axis of ``f`` is reinterpreted as an event axis, so
    ``log_prob`` sums the per-point log-probabilities over it.
    """
    poisson = tfd.Poisson(log_rate=f)
    return tfd.Independent(poisson, reinterpreted_batch_ndims=1)
def log_likelihood(y, f):
    """Log-probability of observations ``y`` under a Poisson with log-rate ``f``."""
    return make_poisson_likelihood(f).log_prob(y)
Helper model factory function.
def build_model(input_dim, jitter=1e-6):
    """Build a model wrapping a single ``VariationalGaussianProcess1D`` layer.

    The inducing index points are initialized to a random subset of rows of
    the module-level training inputs ``X``, drawn with ``random_state``.

    Args:
        input_dim: dimensionality of each input vector.
        jitter: passed to the variational GP layer.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    inputs = np.asarray(X).reshape(-1, input_dim)
    num_rows = inputs.shape[0]
    # Pick distinct training rows. Duplicated inducing points make the
    # inducing-point kernel matrix singular, leaving only `jitter` to keep it
    # invertible. Sample with replacement only if there are fewer rows than
    # inducing points.
    indices = random_state.choice(num_rows, size=num_inducing_points,
                                  replace=num_rows < num_inducing_points)
    inducing_index_points_initial = inputs[indices]
    inducing_index_points_initializer = (
        tf.constant_initializer(inducing_index_points_initial))
    return tf.keras.Sequential([
        InputLayer(input_shape=(input_dim,)),
        VariationalGaussianProcess1D(
            kernel_wrapper=KernelWrapper(kernel_cls=kernel_cls,
                                         dtype=tf.float64),
            num_inducing_points=num_inducing_points,
            inducing_index_points_initializer=inducing_index_points_initializer,
            jitter=jitter)
    ])
# Instantiate the model and the optimizer used by `train_step`.
model = build_model(num_features, jitter=jitter)
optimizer = tf.keras.optimizers.Adam()
Out:
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/distributions/gaussian_process.py:311: UserWarning: Unable to detect statically whether the number of index_points is 1. As a result, defaulting to treating the marginal GP at `index_points` as a multivariate Gaussian. This makes some methods, like `cdf` unavailable.
'Unable to detect statically whether the number of index_points is '
@tf.function
def nelbo(X_batch, y_batch):
    """Negative evidence lower bound of the model on one mini-batch.

    The expected log-likelihood of ``y_batch`` (estimated by quadrature with
    ``quadrature_size`` points) is subtracted from the KL divergence between
    the variational posterior and the prior. The KL term is scaled by
    ``get_kl_weight(num_train, batch_size)``.
    """
    posterior = model(X_batch)
    expected_ll = posterior.surrogate_posterior_expected_log_likelihood(
        observations=y_batch,
        log_likelihood_fn=log_likelihood,
        quadrature_size=quadrature_size)
    divergence = posterior.surrogate_posterior_kl_divergence_prior()
    weight = get_kl_weight(num_train, batch_size)
    return weight * divergence - expected_ll
@tf.function
def train_step(X_batch, y_batch):
    """Take one optimizer step on the negative ELBO and return the batch loss."""
    with tf.GradientTape() as tape:
        loss = nelbo(X_batch, y_batch)
    params = model.trainable_variables
    grads = tape.gradient(loss, params)
    optimizer.apply_gradients(zip(grads, params))
    return loss
# Mini-batch pipeline: seeded shuffle, and the final partial batch is dropped.
dataset = (tf.data.Dataset.from_tensor_slices((X, y))
           .shuffle(seed=seed, buffer_size=shuffle_buffer_size)
           .batch(batch_size, drop_remainder=True))

# Labels for the arrays returned by `model.get_weights()`.
# NOTE(review): assumes get_weights() yields the weights in exactly this
# order (variational GP layer weights, then kernel hyperparameters) -- confirm.
keys = ["inducing_index_points",
        "variational_inducing_observations_loc",
        "variational_inducing_observations_scale",
        "log_observation_noise_variance",
        "log_amplitude", "log_length_scale"]

history = defaultdict(list)
for epoch in range(num_epochs):
    for X_batch, y_batch in dataset:
        loss = train_step(X_batch, y_batch)
    # Only the loss of the epoch's last mini-batch is reported and recorded.
    loss_value = loss.numpy()
    print("epoch={epoch:04d}, loss={loss:.4f}"
          .format(epoch=epoch, loss=loss_value))
    history["nelbo"].append(loss_value)
    for key, weight in zip(keys, model.get_weights()):
        history[key].append(weight)
Out:
epoch=0000, loss=503718.0016
epoch=0001, loss=475124.6687
epoch=0002, loss=449182.0495
epoch=0003, loss=424675.8714
epoch=0004, loss=400130.0089
epoch=0005, loss=377441.1128
epoch=0006, loss=355909.6056
epoch=0007, loss=335670.5052
epoch=0008, loss=316633.7488
epoch=0009, loss=298586.9228
epoch=0010, loss=281445.5673
epoch=0011, loss=265165.5582
epoch=0012, loss=249702.8881
epoch=0013, loss=235009.1906
epoch=0014, loss=221047.5870
epoch=0015, loss=207782.8945
epoch=0016, loss=195182.2574
epoch=0017, loss=183225.4848
epoch=0018, loss=171885.0519
epoch=0019, loss=161128.6349
epoch=0020, loss=150953.2861
epoch=0021, loss=141324.5595
epoch=0022, loss=132227.2517
epoch=0023, loss=123644.0793
epoch=0024, loss=115545.0756
epoch=0025, loss=107924.3415
epoch=0026, loss=100754.6539
epoch=0027, loss=94018.2818
epoch=0028, loss=87691.2610
epoch=0029, loss=81765.3992
epoch=0030, loss=76209.4874
epoch=0031, loss=71007.1965
epoch=0032, loss=66139.6203
epoch=0033, loss=61591.6754
epoch=0034, loss=57337.9688
epoch=0035, loss=53372.9955
epoch=0036, loss=49671.7292
epoch=0037, loss=46239.5650
epoch=0038, loss=43051.8211
epoch=0039, loss=40093.4649
epoch=0040, loss=37367.3047
epoch=0041, loss=34844.5363
epoch=0042, loss=32521.0955
epoch=0043, loss=30376.2433
epoch=0044, loss=28401.9147
epoch=0045, loss=26589.4699
epoch=0046, loss=24918.4878
epoch=0047, loss=23379.7795
epoch=0048, loss=21963.8164
epoch=0049, loss=20661.9284
epoch=0050, loss=19461.8934
epoch=0051, loss=18350.8827
epoch=0052, loss=17330.1188
epoch=0053, loss=16388.4615
epoch=0054, loss=15513.8981
epoch=0055, loss=14703.6937
epoch=0056, loss=13951.6997
epoch=0057, loss=13263.4179
epoch=0058, loss=12618.2228
epoch=0059, loss=12017.4802
epoch=0060, loss=11457.0539
epoch=0061, loss=10938.7167
epoch=0062, loss=10452.8857
epoch=0063, loss=9995.8923
epoch=0064, loss=9572.3354
epoch=0065, loss=9176.7417
epoch=0066, loss=8802.5129
epoch=0067, loss=8456.5162
epoch=0068, loss=8121.0805
epoch=0069, loss=7811.6885
epoch=0070, loss=7518.2041
epoch=0071, loss=7243.9737
epoch=0072, loss=6983.4403
epoch=0073, loss=6738.4308
epoch=0074, loss=6505.0204
epoch=0075, loss=6280.9981
epoch=0076, loss=6073.2072
epoch=0077, loss=5878.0479
epoch=0078, loss=5692.3076
epoch=0079, loss=5507.7018
epoch=0080, loss=5332.9441
epoch=0081, loss=5175.1431
epoch=0082, loss=5018.8195
epoch=0083, loss=4878.9884
epoch=0084, loss=4732.3058
epoch=0085, loss=4601.0597
epoch=0086, loss=4469.2106
epoch=0087, loss=4349.7136
epoch=0088, loss=4237.6983
epoch=0089, loss=4116.0311
epoch=0090, loss=4006.7375
epoch=0091, loss=3904.2931
epoch=0092, loss=3799.8214
epoch=0093, loss=3706.6880
epoch=0094, loss=3612.3212
epoch=0095, loss=3531.6876
epoch=0096, loss=3440.0799
epoch=0097, loss=3364.8440
epoch=0098, loss=3280.9135
epoch=0099, loss=3216.7467
epoch=0100, loss=3130.4951
epoch=0101, loss=3059.3164
epoch=0102, loss=2990.5427
epoch=0103, loss=2924.2639
epoch=0104, loss=2874.4743
epoch=0105, loss=2803.1734
epoch=0106, loss=2747.5532
epoch=0107, loss=2683.1805
epoch=0108, loss=2636.1819
epoch=0109, loss=2576.8733
epoch=0110, loss=2518.7064
epoch=0111, loss=2480.2028
epoch=0112, loss=2430.1288
epoch=0113, loss=2376.4723
epoch=0114, loss=2334.9129
epoch=0115, loss=2289.6594
epoch=0116, loss=2243.3222
epoch=0117, loss=2204.1842
epoch=0118, loss=2162.5270
epoch=0119, loss=2126.7130
epoch=0120, loss=2084.3528
epoch=0121, loss=2039.8391
epoch=0122, loss=2011.7281
epoch=0123, loss=1976.2229
epoch=0124, loss=1942.0064
epoch=0125, loss=1901.3520
epoch=0126, loss=1879.5601
epoch=0127, loss=1842.2692
epoch=0128, loss=1810.3070
epoch=0129, loss=1783.4963
epoch=0130, loss=1749.7177
epoch=0131, loss=1728.7218
epoch=0132, loss=1696.4940
epoch=0133, loss=1667.3415
epoch=0134, loss=1646.6158
epoch=0135, loss=1617.1468
epoch=0136, loss=1595.6100
epoch=0137, loss=1562.6681
epoch=0138, loss=1545.7072
epoch=0139, loss=1511.5130
epoch=0140, loss=1489.9823
epoch=0141, loss=1472.1741
epoch=0142, loss=1459.4709
epoch=0143, loss=1427.1879
epoch=0144, loss=1410.0700
epoch=0145, loss=1392.4305
epoch=0146, loss=1358.6774
epoch=0147, loss=1356.2028
epoch=0148, loss=1326.9513
epoch=0149, loss=1311.4088
epoch=0150, loss=1296.7058
epoch=0151, loss=1276.8852
epoch=0152, loss=1254.6550
epoch=0153, loss=1239.3621
epoch=0154, loss=1230.0371
epoch=0155, loss=1209.9762
epoch=0156, loss=1187.4667
epoch=0157, loss=1177.2099
epoch=0158, loss=1154.8790
epoch=0159, loss=1141.5911
epoch=0160, loss=1127.3600
epoch=0161, loss=1123.1947
epoch=0162, loss=1105.7010
epoch=0163, loss=1084.5280
epoch=0164, loss=1076.3502
epoch=0165, loss=1051.0518
epoch=0166, loss=1052.1513
epoch=0167, loss=1029.4153
epoch=0168, loss=1012.4560
epoch=0169, loss=1006.7754
epoch=0170, loss=994.1713
epoch=0171, loss=979.2590
epoch=0172, loss=965.2443
epoch=0173, loss=967.2261
epoch=0174, loss=947.9689
epoch=0175, loss=941.7749
epoch=0176, loss=930.2486
epoch=0177, loss=913.6552
epoch=0178, loss=905.7848
epoch=0179, loss=895.1102
epoch=0180, loss=882.7923
epoch=0181, loss=878.9853
epoch=0182, loss=855.4525
epoch=0183, loss=848.5580
epoch=0184, loss=854.4321
epoch=0185, loss=837.1378
epoch=0186, loss=820.3715
epoch=0187, loss=810.3195
epoch=0188, loss=813.9388
epoch=0189, loss=795.8079
epoch=0190, loss=797.2823
epoch=0191, loss=778.9154
epoch=0192, loss=767.1228
epoch=0193, loss=765.8746
epoch=0194, loss=751.7439
epoch=0195, loss=735.2551
epoch=0196, loss=736.6673
epoch=0197, loss=724.3290
epoch=0198, loss=734.2114
epoch=0199, loss=712.6057
epoch=0200, loss=712.9040
epoch=0201, loss=697.5288
epoch=0202, loss=704.0655
epoch=0203, loss=692.6003
epoch=0204, loss=685.3066
epoch=0205, loss=674.1011
epoch=0206, loss=665.7722
epoch=0207, loss=654.3170
epoch=0208, loss=654.4745
epoch=0209, loss=639.0899
epoch=0210, loss=638.7625
epoch=0211, loss=623.6700
epoch=0212, loss=634.3182
epoch=0213, loss=613.1934
epoch=0214, loss=615.6673
epoch=0215, loss=611.3304
epoch=0216, loss=604.6876
epoch=0217, loss=594.3960
epoch=0218, loss=594.2117
epoch=0219, loss=573.6360
epoch=0220, loss=577.7584
epoch=0221, loss=570.8481
epoch=0222, loss=560.8061
epoch=0223, loss=573.5160
epoch=0224, loss=552.8636
epoch=0225, loss=549.9146
epoch=0226, loss=546.9251
epoch=0227, loss=532.6548
epoch=0228, loss=539.0560
epoch=0229, loss=528.0483
epoch=0230, loss=522.8893
epoch=0231, loss=522.7825
epoch=0232, loss=510.8447
epoch=0233, loss=509.5694
epoch=0234, loss=516.3404
epoch=0235, loss=511.4087
epoch=0236, loss=494.5026
epoch=0237, loss=503.1661
epoch=0238, loss=492.0830
epoch=0239, loss=479.2352
epoch=0240, loss=475.4360
epoch=0241, loss=475.0072
epoch=0242, loss=465.4977
epoch=0243, loss=475.3879
epoch=0244, loss=461.6061
epoch=0245, loss=461.3810
epoch=0246, loss=448.1367
epoch=0247, loss=462.1360
epoch=0248, loss=452.7104
epoch=0249, loss=442.0678
epoch=0250, loss=446.5259
epoch=0251, loss=429.6645
epoch=0252, loss=432.3712
epoch=0253, loss=432.3731
epoch=0254, loss=428.3459
epoch=0255, loss=427.4889
epoch=0256, loss=415.5806
epoch=0257, loss=408.1475
epoch=0258, loss=398.6794
epoch=0259, loss=418.0139
epoch=0260, loss=407.5253
epoch=0261, loss=403.9192
epoch=0262, loss=393.8357
epoch=0263, loss=388.2981
epoch=0264, loss=392.3414
epoch=0265, loss=395.5534
epoch=0266, loss=390.2889
epoch=0267, loss=385.2252
epoch=0268, loss=378.7956
epoch=0269, loss=381.6068
epoch=0270, loss=369.0336
epoch=0271, loss=378.1374
epoch=0272, loss=370.4357
epoch=0273, loss=366.2946
epoch=0274, loss=352.2604
epoch=0275, loss=372.7704
epoch=0276, loss=356.5688
epoch=0277, loss=352.5583
epoch=0278, loss=350.4578
epoch=0279, loss=350.4543
epoch=0280, loss=342.7281
epoch=0281, loss=341.7187
epoch=0282, loss=337.6962
epoch=0283, loss=336.7313
epoch=0284, loss=348.5399
epoch=0285, loss=331.6232
epoch=0286, loss=320.0296
epoch=0287, loss=349.6291
epoch=0288, loss=330.3633
epoch=0289, loss=321.4751
epoch=0290, loss=308.3491
epoch=0291, loss=333.5266
epoch=0292, loss=316.8104
epoch=0293, loss=319.8715
epoch=0294, loss=320.6826
epoch=0295, loss=304.9669
epoch=0296, loss=301.5010
epoch=0297, loss=306.3954
epoch=0298, loss=303.1771
epoch=0299, loss=308.6068
epoch=0300, loss=303.4759
epoch=0301, loss=291.6490
epoch=0302, loss=294.1003
epoch=0303, loss=293.1295
epoch=0304, loss=288.0081
epoch=0305, loss=290.3008
epoch=0306, loss=289.7989
epoch=0307, loss=282.0794
epoch=0308, loss=284.0127
epoch=0309, loss=278.3577
epoch=0310, loss=283.5266
epoch=0311, loss=270.4285
epoch=0312, loss=292.5535
epoch=0313, loss=269.1271
epoch=0314, loss=267.3524
epoch=0315, loss=282.8746
epoch=0316, loss=267.9802
epoch=0317, loss=255.7895
epoch=0318, loss=263.4444
epoch=0319, loss=262.7765
epoch=0320, loss=269.5424
epoch=0321, loss=247.4822
epoch=0322, loss=245.4533
epoch=0323, loss=251.4737
epoch=0324, loss=256.4097
epoch=0325, loss=248.1391
epoch=0326, loss=245.7778
epoch=0327, loss=239.1502
epoch=0328, loss=231.2579
epoch=0329, loss=231.9704
epoch=0330, loss=230.7200
epoch=0331, loss=244.4598
epoch=0332, loss=228.0484
epoch=0333, loss=232.7435
epoch=0334, loss=231.0531
epoch=0335, loss=229.9803
epoch=0336, loss=237.8903
epoch=0337, loss=242.0139
epoch=0338, loss=240.8350
epoch=0339, loss=228.8982
epoch=0340, loss=227.9539
epoch=0341, loss=221.3632
epoch=0342, loss=226.5502
epoch=0343, loss=225.5825
epoch=0344, loss=221.7293
epoch=0345, loss=220.6527
epoch=0346, loss=224.6299
epoch=0347, loss=218.0827
epoch=0348, loss=207.6847
epoch=0349, loss=215.5031
epoch=0350, loss=206.6249
epoch=0351, loss=219.5271
epoch=0352, loss=209.6862
epoch=0353, loss=207.3360
epoch=0354, loss=197.6009
epoch=0355, loss=205.0703
epoch=0356, loss=206.5529
epoch=0357, loss=194.7957
epoch=0358, loss=197.0495
epoch=0359, loss=194.9799
epoch=0360, loss=193.0936
epoch=0361, loss=201.0962
epoch=0362, loss=194.8401
epoch=0363, loss=187.5098
epoch=0364, loss=196.7941
epoch=0365, loss=203.5231
epoch=0366, loss=196.1433
epoch=0367, loss=184.3774
epoch=0368, loss=190.3704
epoch=0369, loss=183.6483
epoch=0370, loss=187.4339
epoch=0371, loss=195.7501
epoch=0372, loss=175.1536
epoch=0373, loss=188.3401
epoch=0374, loss=188.6235
epoch=0375, loss=177.8545
epoch=0376, loss=188.3338
epoch=0377, loss=195.2243
epoch=0378, loss=176.8593
epoch=0379, loss=169.6181
epoch=0380, loss=190.8619
epoch=0381, loss=164.9291
epoch=0382, loss=182.7612
epoch=0383, loss=183.4093
epoch=0384, loss=185.6163
epoch=0385, loss=177.8770
epoch=0386, loss=166.0440
epoch=0387, loss=173.9915
epoch=0388, loss=174.3544
epoch=0389, loss=166.0017
epoch=0390, loss=152.6976
epoch=0391, loss=160.4803
epoch=0392, loss=172.6038
epoch=0393, loss=156.3226
epoch=0394, loss=159.6409
epoch=0395, loss=166.7436
epoch=0396, loss=159.3685
epoch=0397, loss=154.3180
epoch=0398, loss=167.5721
epoch=0399, loss=163.5109
epoch=0400, loss=171.3971
epoch=0401, loss=160.6082
epoch=0402, loss=145.6805
epoch=0403, loss=155.7041
epoch=0404, loss=156.1276
epoch=0405, loss=152.6160
epoch=0406, loss=162.6592
epoch=0407, loss=142.9762
epoch=0408, loss=159.5509
epoch=0409, loss=151.7371
epoch=0410, loss=152.2857
epoch=0411, loss=147.3106
epoch=0412, loss=154.5352
epoch=0413, loss=146.3088
epoch=0414, loss=152.1442
epoch=0415, loss=145.5159
epoch=0416, loss=156.1471
epoch=0417, loss=144.5401
epoch=0418, loss=143.0060
epoch=0419, loss=138.8478
epoch=0420, loss=144.5675
epoch=0421, loss=141.7152
epoch=0422, loss=152.1052
epoch=0423, loss=135.4647
epoch=0424, loss=144.8311
epoch=0425, loss=136.1836
epoch=0426, loss=136.0509
epoch=0427, loss=131.0169
epoch=0428, loss=130.8113
epoch=0429, loss=140.1047
epoch=0430, loss=135.6104
epoch=0431, loss=139.6994
epoch=0432, loss=135.0475
epoch=0433, loss=124.3657
epoch=0434, loss=128.8028
epoch=0435, loss=135.1245
epoch=0436, loss=126.3713
epoch=0437, loss=126.9186
epoch=0438, loss=127.0762
epoch=0439, loss=132.2666
epoch=0440, loss=117.0638
epoch=0441, loss=131.8005
epoch=0442, loss=121.6616
epoch=0443, loss=127.8722
epoch=0444, loss=132.6087
epoch=0445, loss=119.4095
epoch=0446, loss=120.3532
epoch=0447, loss=124.0202
epoch=0448, loss=119.2436
epoch=0449, loss=117.3810
epoch=0450, loss=125.5332
epoch=0451, loss=116.9169
epoch=0452, loss=117.7422
epoch=0453, loss=118.8801
epoch=0454, loss=125.7930
epoch=0455, loss=127.4539
epoch=0456, loss=110.2093
epoch=0457, loss=115.6084
epoch=0458, loss=139.7687
epoch=0459, loss=122.2508
epoch=0460, loss=114.1477
epoch=0461, loss=125.7758
epoch=0462, loss=122.3530
epoch=0463, loss=118.7015
epoch=0464, loss=118.7259
epoch=0465, loss=105.7424
epoch=0466, loss=128.8635
epoch=0467, loss=126.3125
epoch=0468, loss=132.5865
epoch=0469, loss=105.5933
epoch=0470, loss=107.0141
epoch=0471, loss=109.4016
epoch=0472, loss=104.4746
epoch=0473, loss=103.8465
epoch=0474, loss=95.5309
epoch=0475, loss=109.6063
epoch=0476, loss=113.3386
epoch=0477, loss=99.2031
epoch=0478, loss=111.3675
epoch=0479, loss=106.6856
epoch=0480, loss=121.9445
epoch=0481, loss=117.0523
epoch=0482, loss=119.6875
epoch=0483, loss=112.8985
epoch=0484, loss=110.3964
epoch=0485, loss=119.5072
epoch=0486, loss=104.1824
epoch=0487, loss=108.1442
epoch=0488, loss=93.3669
epoch=0489, loss=104.0165
epoch=0490, loss=101.3833
epoch=0491, loss=109.5744
epoch=0492, loss=100.7609
epoch=0493, loss=97.6215
epoch=0494, loss=107.8976
epoch=0495, loss=91.9219
epoch=0496, loss=103.6455
epoch=0497, loss=100.3046
epoch=0498, loss=87.8629
epoch=0499, loss=102.6298
epoch=0500, loss=97.4680
epoch=0501, loss=93.2479
epoch=0502, loss=102.4228
epoch=0503, loss=99.8412
epoch=0504, loss=95.0297
epoch=0505, loss=110.5724
epoch=0506, loss=92.3500
epoch=0507, loss=93.5608
epoch=0508, loss=104.3107
epoch=0509, loss=90.2806
epoch=0510, loss=99.3023
epoch=0511, loss=89.7588
epoch=0512, loss=88.0880
epoch=0513, loss=102.7510
epoch=0514, loss=85.3267
epoch=0515, loss=83.1560
epoch=0516, loss=86.8379
epoch=0517, loss=84.2890
epoch=0518, loss=84.6089
epoch=0519, loss=90.2681
epoch=0520, loss=98.1327
epoch=0521, loss=94.0007
epoch=0522, loss=89.3772
epoch=0523, loss=82.4729
epoch=0524, loss=86.4087
epoch=0525, loss=91.0402
epoch=0526, loss=87.6877
epoch=0527, loss=87.4722
epoch=0528, loss=81.0174
epoch=0529, loss=82.1088
epoch=0530, loss=88.7508
epoch=0531, loss=102.3531
epoch=0532, loss=84.8533
epoch=0533, loss=94.7518
epoch=0534, loss=86.7407
epoch=0535, loss=83.1817
epoch=0536, loss=79.8289
epoch=0537, loss=94.1217
epoch=0538, loss=82.4183
epoch=0539, loss=88.1212
epoch=0540, loss=79.6697
epoch=0541, loss=82.3980
epoch=0542, loss=85.3741
epoch=0543, loss=75.5057
epoch=0544, loss=79.6781
epoch=0545, loss=68.5751
epoch=0546, loss=86.9400
epoch=0547, loss=91.1889
epoch=0548, loss=82.6861
epoch=0549, loss=78.8321
epoch=0550, loss=86.1787
epoch=0551, loss=88.3150
epoch=0552, loss=82.9925
epoch=0553, loss=81.7292
epoch=0554, loss=66.1165
epoch=0555, loss=84.6366
epoch=0556, loss=92.2174
epoch=0557, loss=83.1902
epoch=0558, loss=74.3282
epoch=0559, loss=89.9474
epoch=0560, loss=89.7887
epoch=0561, loss=86.8922
epoch=0562, loss=76.5947
epoch=0563, loss=71.9305
epoch=0564, loss=84.6352
epoch=0565, loss=86.1365
epoch=0566, loss=89.1902
epoch=0567, loss=87.1468
epoch=0568, loss=66.7683
epoch=0569, loss=68.3575
epoch=0570, loss=73.0146
epoch=0571, loss=67.2949
epoch=0572, loss=78.2781
epoch=0573, loss=74.8641
epoch=0574, loss=89.5783
epoch=0575, loss=83.2646
epoch=0576, loss=75.5487
epoch=0577, loss=71.8853
epoch=0578, loss=62.4595
epoch=0579, loss=75.7681
epoch=0580, loss=75.3446
epoch=0581, loss=60.4439
epoch=0582, loss=80.5004
epoch=0583, loss=76.4536
epoch=0584, loss=64.6321
epoch=0585, loss=62.7810
epoch=0586, loss=69.3928
epoch=0587, loss=66.0748
epoch=0588, loss=67.1042
epoch=0589, loss=67.3756
epoch=0590, loss=62.8932
epoch=0591, loss=71.8032
epoch=0592, loss=70.5042
epoch=0593, loss=63.6476
epoch=0594, loss=57.6988
epoch=0595, loss=59.8144
epoch=0596, loss=74.8017
epoch=0597, loss=81.8788
epoch=0598, loss=77.6706
epoch=0599, loss=68.4422
epoch=0600, loss=64.2554
epoch=0601, loss=65.3775
epoch=0602, loss=64.8156
epoch=0603, loss=67.2893
epoch=0604, loss=62.5748
epoch=0605, loss=66.9849
epoch=0606, loss=65.8462
epoch=0607, loss=75.3093
epoch=0608, loss=79.2133
epoch=0609, loss=69.7162
epoch=0610, loss=59.6053
epoch=0611, loss=80.3660
epoch=0612, loss=63.0438
epoch=0613, loss=58.2483
epoch=0614, loss=59.0454
epoch=0615, loss=58.6386
epoch=0616, loss=58.9345
epoch=0617, loss=77.5252
epoch=0618, loss=79.6157
epoch=0619, loss=55.6701
epoch=0620, loss=79.2709
epoch=0621, loss=64.5792
epoch=0622, loss=71.3175
epoch=0623, loss=63.5462
epoch=0624, loss=71.6826
epoch=0625, loss=58.9421
epoch=0626, loss=66.4716
epoch=0627, loss=62.3693
epoch=0628, loss=63.5122
epoch=0629, loss=69.9511
epoch=0630, loss=47.5289
epoch=0631, loss=57.6975
epoch=0632, loss=67.8547
epoch=0633, loss=58.8649
epoch=0634, loss=56.6639
epoch=0635, loss=69.4698
epoch=0636, loss=53.4113
epoch=0637, loss=59.6261
epoch=0638, loss=64.1503
epoch=0639, loss=64.2238
epoch=0640, loss=67.4349
epoch=0641, loss=68.9233
epoch=0642, loss=58.3884
epoch=0643, loss=52.5806
epoch=0644, loss=67.4768
epoch=0645, loss=55.4877
epoch=0646, loss=53.7203
epoch=0647, loss=58.3251
epoch=0648, loss=59.4838
epoch=0649, loss=59.8746
epoch=0650, loss=58.7315
epoch=0651, loss=79.7500
epoch=0652, loss=58.5452
epoch=0653, loss=60.7954
epoch=0654, loss=54.9007
epoch=0655, loss=65.4627
epoch=0656, loss=61.3134
epoch=0657, loss=53.3750
epoch=0658, loss=54.0795
epoch=0659, loss=72.4057
epoch=0660, loss=53.6076
epoch=0661, loss=81.3088
epoch=0662, loss=58.5682
epoch=0663, loss=58.1461
epoch=0664, loss=57.2060
epoch=0665, loss=71.0399
epoch=0666, loss=65.0347
epoch=0667, loss=49.6852
epoch=0668, loss=53.1331
epoch=0669, loss=50.6020
epoch=0670, loss=60.3321
epoch=0671, loss=61.5717
epoch=0672, loss=53.9444
epoch=0673, loss=51.0058
epoch=0674, loss=49.6867
epoch=0675, loss=49.3432
epoch=0676, loss=47.9196
epoch=0677, loss=68.2624
epoch=0678, loss=57.4226
epoch=0679, loss=62.1165
epoch=0680, loss=46.3444
epoch=0681, loss=50.4172
epoch=0682, loss=49.5411
epoch=0683, loss=62.4633
epoch=0684, loss=50.0264
epoch=0685, loss=58.2416
epoch=0686, loss=50.6320
epoch=0687, loss=49.0191
epoch=0688, loss=48.2821
epoch=0689, loss=42.7373
epoch=0690, loss=58.3171
epoch=0691, loss=55.5937
epoch=0692, loss=44.9044
epoch=0693, loss=55.7095
epoch=0694, loss=56.4527
epoch=0695, loss=44.5670
epoch=0696, loss=51.8081
epoch=0697, loss=50.0144
epoch=0698, loss=58.4083
epoch=0699, loss=48.1501
epoch=0700, loss=54.2906
epoch=0701, loss=53.6182
epoch=0702, loss=39.3437
epoch=0703, loss=61.8583
epoch=0704, loss=58.8785
epoch=0705, loss=46.4308
epoch=0706, loss=62.9105
epoch=0707, loss=56.4528
epoch=0708, loss=55.8068
epoch=0709, loss=43.4011
epoch=0710, loss=44.8983
epoch=0711, loss=44.0255
epoch=0712, loss=54.1520
epoch=0713, loss=46.9686
epoch=0714, loss=45.5972
epoch=0715, loss=49.8532
epoch=0716, loss=43.7185
epoch=0717, loss=49.9352
epoch=0718, loss=60.2397
epoch=0719, loss=52.1619
epoch=0720, loss=46.1009
epoch=0721, loss=53.7403
epoch=0722, loss=43.8176
epoch=0723, loss=50.5373
epoch=0724, loss=56.5087
epoch=0725, loss=51.5442
epoch=0726, loss=55.1669
epoch=0727, loss=40.8989
epoch=0728, loss=48.9067
epoch=0729, loss=45.8273
epoch=0730, loss=44.7350
epoch=0731, loss=46.5871
epoch=0732, loss=41.1744
epoch=0733, loss=84.7797
epoch=0734, loss=48.4741
epoch=0735, loss=39.9638
epoch=0736, loss=41.6540
epoch=0737, loss=61.9457
epoch=0738, loss=46.0518
epoch=0739, loss=52.9721
epoch=0740, loss=40.0342
epoch=0741, loss=43.8131
epoch=0742, loss=44.1711
epoch=0743, loss=51.0378
epoch=0744, loss=47.2733
epoch=0745, loss=46.5394
epoch=0746, loss=35.1426
epoch=0747, loss=54.2852
epoch=0748, loss=51.4293
epoch=0749, loss=50.2110
epoch=0750, loss=49.1398
epoch=0751, loss=45.6933
epoch=0752, loss=57.4035
epoch=0753, loss=54.6156
epoch=0754, loss=44.5137
epoch=0755, loss=54.1820
epoch=0756, loss=39.1424
epoch=0757, loss=54.4399
epoch=0758, loss=36.3383
epoch=0759, loss=49.9852
epoch=0760, loss=53.9937
epoch=0761, loss=43.7791
epoch=0762, loss=46.2205
epoch=0763, loss=62.5270
epoch=0764, loss=59.7127
epoch=0765, loss=39.0664
epoch=0766, loss=44.7098
epoch=0767, loss=43.4895
epoch=0768, loss=49.4405
epoch=0769, loss=54.8502
epoch=0770, loss=42.0123
epoch=0771, loss=62.3534
epoch=0772, loss=38.1141
epoch=0773, loss=43.5459
epoch=0774, loss=41.6324
epoch=0775, loss=48.1968
epoch=0776, loss=46.0691
epoch=0777, loss=51.6692
epoch=0778, loss=46.6518
epoch=0779, loss=63.3883
epoch=0780, loss=52.5965
epoch=0781, loss=41.4272
epoch=0782, loss=48.3477
epoch=0783, loss=43.6215
epoch=0784, loss=44.7427
epoch=0785, loss=51.7438
epoch=0786, loss=32.6199
epoch=0787, loss=40.2166
epoch=0788, loss=62.7569
epoch=0789, loss=52.6731
epoch=0790, loss=46.1543
epoch=0791, loss=53.3033
epoch=0792, loss=44.2406
epoch=0793, loss=49.6839
epoch=0794, loss=37.8175
epoch=0795, loss=44.3829
epoch=0796, loss=48.2948
epoch=0797, loss=48.3759
epoch=0798, loss=47.0026
epoch=0799, loss=34.5370
epoch=0800, loss=49.3600
epoch=0801, loss=42.6382
epoch=0802, loss=31.4616
epoch=0803, loss=30.3515
epoch=0804, loss=38.8228
epoch=0805, loss=48.9507
epoch=0806, loss=46.5755
epoch=0807, loss=46.1173
epoch=0808, loss=48.8414
epoch=0809, loss=58.1899
epoch=0810, loss=66.8658
epoch=0811, loss=47.4377
epoch=0812, loss=47.3466
epoch=0813, loss=48.9937
epoch=0814, loss=39.4674
epoch=0815, loss=44.8714
epoch=0816, loss=43.2242
epoch=0817, loss=45.2122
epoch=0818, loss=38.3916
epoch=0819, loss=36.9446
epoch=0820, loss=54.9395
epoch=0821, loss=42.1571
epoch=0822, loss=46.8208
epoch=0823, loss=49.8680
epoch=0824, loss=31.9886
epoch=0825, loss=44.2595
epoch=0826, loss=47.7277
epoch=0827, loss=40.7836
epoch=0828, loss=50.3694
epoch=0829, loss=31.0777
epoch=0830, loss=51.3054
epoch=0831, loss=45.2598
epoch=0832, loss=45.5807
epoch=0833, loss=38.6982
epoch=0834, loss=54.1841
epoch=0835, loss=37.5695
epoch=0836, loss=50.5953
epoch=0837, loss=55.2999
epoch=0838, loss=61.0726
epoch=0839, loss=41.8046
epoch=0840, loss=57.9198
epoch=0841, loss=50.9733
epoch=0842, loss=52.0090
epoch=0843, loss=49.6375
epoch=0844, loss=33.1304
epoch=0845, loss=35.4482
epoch=0846, loss=40.9187
epoch=0847, loss=32.0274
epoch=0848, loss=54.8925
epoch=0849, loss=55.4495
epoch=0850, loss=39.9136
epoch=0851, loss=39.5183
epoch=0852, loss=44.2735
epoch=0853, loss=42.9308
epoch=0854, loss=33.3962
epoch=0855, loss=42.6717
epoch=0856, loss=33.2201
epoch=0857, loss=44.9332
epoch=0858, loss=42.4207
epoch=0859, loss=43.9101
epoch=0860, loss=42.7964
epoch=0861, loss=28.6696
epoch=0862, loss=42.3850
epoch=0863, loss=45.0092
epoch=0864, loss=38.7697
epoch=0865, loss=38.6989
epoch=0866, loss=42.1009
epoch=0867, loss=44.3899
epoch=0868, loss=55.5179
epoch=0869, loss=37.4072
epoch=0870, loss=49.8416
epoch=0871, loss=22.9407
epoch=0872, loss=53.3703
epoch=0873, loss=49.7300
epoch=0874, loss=41.9710
epoch=0875, loss=33.5178
epoch=0876, loss=40.1948
epoch=0877, loss=38.5868
epoch=0878, loss=48.0192
epoch=0879, loss=56.0664
epoch=0880, loss=29.9777
epoch=0881, loss=53.9331
epoch=0882, loss=35.7819
epoch=0883, loss=46.7845
epoch=0884, loss=51.1198
epoch=0885, loss=41.6742
epoch=0886, loss=30.7241
epoch=0887, loss=50.6927
epoch=0888, loss=38.3675
epoch=0889, loss=40.7640
epoch=0890, loss=44.0895
epoch=0891, loss=43.9981
epoch=0892, loss=31.3808
epoch=0893, loss=41.1945
epoch=0894, loss=30.9972
epoch=0895, loss=42.1725
epoch=0896, loss=36.7186
epoch=0897, loss=48.3122
epoch=0898, loss=36.7973
epoch=0899, loss=32.2036
epoch=0900, loss=31.9240
epoch=0901, loss=30.0664
epoch=0902, loss=39.6483
epoch=0903, loss=49.0898
epoch=0904, loss=43.9462
epoch=0905, loss=46.5315
epoch=0906, loss=35.1399
epoch=0907, loss=43.3344
epoch=0908, loss=39.6814
epoch=0909, loss=49.7920
epoch=0910, loss=42.8614
epoch=0911, loss=42.2364
epoch=0912, loss=29.1351
epoch=0913, loss=36.6785
epoch=0914, loss=42.9751
epoch=0915, loss=39.4084
epoch=0916, loss=22.3465
epoch=0917, loss=36.7195
epoch=0918, loss=39.6980
epoch=0919, loss=39.0278
epoch=0920, loss=29.5661
epoch=0921, loss=36.6211
epoch=0922, loss=39.0927
epoch=0923, loss=53.0615
epoch=0924, loss=54.3734
epoch=0925, loss=49.3880
epoch=0926, loss=44.8807
epoch=0927, loss=30.4433
epoch=0928, loss=30.8365
epoch=0929, loss=43.8575
epoch=0930, loss=40.3110
epoch=0931, loss=39.0222
epoch=0932, loss=35.6125
epoch=0933, loss=40.0691
epoch=0934, loss=37.4746
epoch=0935, loss=50.4772
epoch=0936, loss=35.9687
epoch=0937, loss=43.3802
epoch=0938, loss=35.6311
epoch=0939, loss=34.9448
epoch=0940, loss=39.2193
epoch=0941, loss=47.5750
epoch=0942, loss=40.5120
epoch=0943, loss=41.5842
epoch=0944, loss=36.9813
epoch=0945, loss=39.4889
epoch=0946, loss=40.2740
epoch=0947, loss=54.9169
epoch=0948, loss=38.1466
epoch=0949, loss=38.7406
epoch=0950, loss=39.3057
epoch=0951, loss=43.0153
epoch=0952, loss=29.6520
epoch=0953, loss=35.1728
epoch=0954, loss=29.8932
epoch=0955, loss=28.1536
epoch=0956, loss=38.3106
epoch=0957, loss=38.9041
epoch=0958, loss=44.8865
epoch=0959, loss=34.7237
epoch=0960, loss=35.5248
epoch=0961, loss=31.2927
epoch=0962, loss=45.6466
epoch=0963, loss=49.5796
epoch=0964, loss=38.5878
epoch=0965, loss=43.0185
epoch=0966, loss=41.3775
epoch=0967, loss=35.3072
epoch=0968, loss=40.4946
epoch=0969, loss=38.3144
epoch=0970, loss=37.1825
epoch=0971, loss=33.5633
epoch=0972, loss=44.8774
epoch=0973, loss=30.5078
epoch=0974, loss=54.0708
epoch=0975, loss=40.4760
epoch=0976, loss=44.6554
epoch=0977, loss=43.8926
epoch=0978, loss=35.7388
epoch=0979, loss=31.7607
epoch=0980, loss=41.0082
epoch=0981, loss=54.1730
epoch=0982, loss=35.5671
epoch=0983, loss=54.8984
epoch=0984, loss=32.4584
epoch=0985, loss=36.8033
epoch=0986, loss=28.5383
epoch=0987, loss=31.1247
epoch=0988, loss=44.6059
epoch=0989, loss=28.4877
epoch=0990, loss=47.7916
epoch=0991, loss=41.5727
epoch=0992, loss=25.4704
epoch=0993, loss=27.3044
epoch=0994, loss=40.2768
epoch=0995, loss=44.7830
epoch=0996, loss=25.4001
epoch=0997, loss=45.7636
epoch=0998, loss=24.7614
epoch=0999, loss=35.8817
epoch=1000, loss=31.5747
epoch=1001, loss=46.7773
epoch=1002, loss=38.6838
epoch=1003, loss=43.0371
epoch=1004, loss=41.6749
epoch=1005, loss=36.8893
epoch=1006, loss=41.3862
epoch=1007, loss=41.6232
epoch=1008, loss=29.2862
epoch=1009, loss=34.4208
epoch=1010, loss=50.0407
epoch=1011, loss=41.5424
epoch=1012, loss=32.5476
epoch=1013, loss=45.9401
epoch=1014, loss=43.9288
epoch=1015, loss=28.2446
epoch=1016, loss=37.2122
epoch=1017, loss=38.2516
epoch=1018, loss=25.0566
epoch=1019, loss=26.2989
epoch=1020, loss=46.4798
epoch=1021, loss=30.7958
epoch=1022, loss=38.4176
epoch=1023, loss=36.8788
epoch=1024, loss=33.6224
epoch=1025, loss=44.1153
epoch=1026, loss=49.2197
epoch=1027, loss=33.1222
epoch=1028, loss=42.2189
epoch=1029, loss=31.0787
epoch=1030, loss=47.6945
epoch=1031, loss=36.4322
epoch=1032, loss=44.9050
epoch=1033, loss=37.7846
epoch=1034, loss=55.0392
epoch=1035, loss=43.4899
epoch=1036, loss=32.6500
epoch=1037, loss=46.7317
epoch=1038, loss=26.9375
epoch=1039, loss=33.0094
epoch=1040, loss=60.7153
epoch=1041, loss=45.5859
epoch=1042, loss=42.2432
epoch=1043, loss=50.3978
epoch=1044, loss=36.0496
epoch=1045, loss=39.4132
epoch=1046, loss=38.4170
epoch=1047, loss=35.5436
epoch=1048, loss=34.5298
epoch=1049, loss=47.9556
epoch=1050, loss=29.9514
epoch=1051, loss=41.5883
epoch=1052, loss=40.4844
epoch=1053, loss=32.1343
epoch=1054, loss=28.9982
epoch=1055, loss=40.7955
epoch=1056, loss=33.8277
epoch=1057, loss=37.2696
epoch=1058, loss=32.4221
epoch=1059, loss=35.5141
epoch=1060, loss=46.4895
epoch=1061, loss=33.7861
epoch=1062, loss=37.9355
epoch=1063, loss=28.9689
epoch=1064, loss=37.1848
epoch=1065, loss=34.9959
epoch=1066, loss=34.3465
epoch=1067, loss=54.1800
epoch=1068, loss=33.4773
epoch=1069, loss=37.8777
epoch=1070, loss=29.8906
epoch=1071, loss=36.3135
epoch=1072, loss=32.1539
epoch=1073, loss=53.2417
epoch=1074, loss=35.5554
epoch=1075, loss=36.0650
epoch=1076, loss=35.3531
epoch=1077, loss=34.6156
epoch=1078, loss=50.0658
epoch=1079, loss=32.4242
epoch=1080, loss=43.8290
epoch=1081, loss=33.9779
epoch=1082, loss=45.4978
epoch=1083, loss=36.2925
epoch=1084, loss=50.7296
epoch=1085, loss=29.3651
epoch=1086, loss=41.1270
epoch=1087, loss=43.6603
epoch=1088, loss=26.5921
epoch=1089, loss=31.1718
epoch=1090, loss=45.6769
epoch=1091, loss=21.7876
epoch=1092, loss=40.9696
epoch=1093, loss=43.9557
epoch=1094, loss=34.7761
epoch=1095, loss=39.4589
epoch=1096, loss=50.8497
epoch=1097, loss=39.5860
epoch=1098, loss=33.2066
epoch=1099, loss=39.5427
epoch=1100, loss=32.6381
epoch=1101, loss=37.2633
epoch=1102, loss=32.2884
epoch=1103, loss=38.1605
epoch=1104, loss=30.7495
epoch=1105, loss=35.5737
epoch=1106, loss=30.1897
epoch=1107, loss=38.1585
epoch=1108, loss=28.4092
epoch=1109, loss=29.7329
epoch=1110, loss=37.1669
epoch=1111, loss=29.4979
epoch=1112, loss=35.8428
epoch=1113, loss=34.8263
epoch=1114, loss=47.5728
epoch=1115, loss=37.1924
epoch=1116, loss=46.1044
epoch=1117, loss=34.6153
epoch=1118, loss=61.9499
epoch=1119, loss=35.0889
epoch=1120, loss=22.9839
epoch=1121, loss=29.6082
epoch=1122, loss=34.8566
epoch=1123, loss=35.1996
epoch=1124, loss=32.4029
epoch=1125, loss=38.2590
epoch=1126, loss=34.7926
epoch=1127, loss=30.0309
epoch=1128, loss=33.1623
epoch=1129, loss=36.3337
epoch=1130, loss=48.0137
epoch=1131, loss=30.0274
epoch=1132, loss=28.3043
epoch=1133, loss=23.4344
epoch=1134, loss=23.2265
epoch=1135, loss=40.6957
epoch=1136, loss=33.3754
epoch=1137, loss=36.1965
epoch=1138, loss=41.8181
epoch=1139, loss=25.8868
epoch=1140, loss=28.9432
epoch=1141, loss=33.7793
epoch=1142, loss=35.4136
epoch=1143, loss=29.5980
epoch=1144, loss=42.2540
epoch=1145, loss=26.7558
epoch=1146, loss=37.7473
epoch=1147, loss=37.1238
epoch=1148, loss=32.0596
epoch=1149, loss=39.6587
epoch=1150, loss=28.4213
epoch=1151, loss=28.8102
epoch=1152, loss=39.1638
epoch=1153, loss=27.3062
epoch=1154, loss=42.6782
epoch=1155, loss=32.8738
epoch=1156, loss=25.7113
epoch=1157, loss=34.6483
epoch=1158, loss=26.3806
epoch=1159, loss=31.2331
epoch=1160, loss=50.5903
epoch=1161, loss=24.9333
epoch=1162, loss=36.7008
epoch=1163, loss=37.0426
epoch=1164, loss=32.0863
epoch=1165, loss=24.1840
epoch=1166, loss=38.1325
epoch=1167, loss=38.7330
epoch=1168, loss=48.2531
epoch=1169, loss=45.8834
epoch=1170, loss=32.5161
epoch=1171, loss=36.7577
epoch=1172, loss=36.6310
epoch=1173, loss=32.6013
epoch=1174, loss=29.5227
epoch=1175, loss=35.8207
epoch=1176, loss=35.8144
epoch=1177, loss=43.7553
epoch=1178, loss=31.8759
epoch=1179, loss=32.9536
epoch=1180, loss=34.0385
epoch=1181, loss=33.6179
epoch=1182, loss=48.9506
epoch=1183, loss=39.0245
epoch=1184, loss=44.1334
epoch=1185, loss=40.7664
epoch=1186, loss=33.6860
epoch=1187, loss=27.1803
epoch=1188, loss=22.2602
epoch=1189, loss=40.0744
epoch=1190, loss=34.6431
epoch=1191, loss=32.6041
epoch=1192, loss=37.8065
epoch=1193, loss=39.2743
epoch=1194, loss=28.5010
epoch=1195, loss=36.9890
epoch=1196, loss=38.7774
epoch=1197, loss=40.2295
epoch=1198, loss=45.9734
epoch=1199, loss=32.7846
epoch=1200, loss=35.4219
epoch=1201, loss=31.6757
epoch=1202, loss=20.8077
epoch=1203, loss=42.8190
epoch=1204, loss=35.4590
epoch=1205, loss=33.9555
epoch=1206, loss=38.9901
epoch=1207, loss=35.0332
epoch=1208, loss=27.4473
epoch=1209, loss=20.5175
epoch=1210, loss=34.2019
epoch=1211, loss=41.0577
epoch=1212, loss=34.4536
epoch=1213, loss=28.7619
epoch=1214, loss=35.7453
epoch=1215, loss=31.2096
epoch=1216, loss=35.4969
epoch=1217, loss=32.5666
epoch=1218, loss=40.1697
epoch=1219, loss=21.2583
epoch=1220, loss=36.9934
epoch=1221, loss=33.8079
epoch=1222, loss=30.8306
epoch=1223, loss=39.5826
epoch=1224, loss=40.4954
epoch=1225, loss=27.8841
epoch=1226, loss=32.9384
epoch=1227, loss=32.3499
epoch=1228, loss=26.3457
epoch=1229, loss=34.5355
epoch=1230, loss=45.5821
epoch=1231, loss=40.1966
epoch=1232, loss=19.0265
epoch=1233, loss=46.5220
epoch=1234, loss=43.6520
epoch=1235, loss=36.3505
epoch=1236, loss=34.3163
epoch=1237, loss=44.1153
epoch=1238, loss=39.0215
epoch=1239, loss=35.1996
epoch=1240, loss=40.9362
epoch=1241, loss=40.2023
epoch=1242, loss=28.3928
epoch=1243, loss=24.3768
epoch=1244, loss=40.4758
epoch=1245, loss=27.4730
epoch=1246, loss=24.2012
epoch=1247, loss=46.7035
epoch=1248, loss=43.2430
epoch=1249, loss=29.7535
epoch=1250, loss=21.3884
epoch=1251, loss=33.8545
epoch=1252, loss=25.7920
epoch=1253, loss=37.9737
epoch=1254, loss=38.1511
epoch=1255, loss=42.9844
epoch=1256, loss=34.4310
epoch=1257, loss=24.3918
epoch=1258, loss=22.3792
epoch=1259, loss=35.8306
epoch=1260, loss=41.1900
epoch=1261, loss=53.3297
epoch=1262, loss=39.2182
epoch=1263, loss=42.3259
epoch=1264, loss=38.3113
epoch=1265, loss=32.0263
epoch=1266, loss=44.0778
epoch=1267, loss=27.5000
epoch=1268, loss=43.9579
epoch=1269, loss=32.4862
epoch=1270, loss=34.7368
epoch=1271, loss=36.3911
epoch=1272, loss=41.4185
epoch=1273, loss=41.6621
epoch=1274, loss=48.0827
epoch=1275, loss=37.3628
epoch=1276, loss=32.3777
epoch=1277, loss=26.1825
epoch=1278, loss=23.6553
epoch=1279, loss=23.5205
epoch=1280, loss=37.0725
epoch=1281, loss=26.9658
epoch=1282, loss=36.7055
epoch=1283, loss=40.5629
epoch=1284, loss=34.3698
epoch=1285, loss=65.5482
epoch=1286, loss=25.3339
epoch=1287, loss=40.3472
epoch=1288, loss=35.8977
epoch=1289, loss=35.3034
epoch=1290, loss=24.6489
epoch=1291, loss=30.9076
epoch=1292, loss=32.2027
epoch=1293, loss=30.1449
epoch=1294, loss=38.2282
epoch=1295, loss=33.8271
epoch=1296, loss=36.1534
epoch=1297, loss=31.7768
epoch=1298, loss=31.2463
epoch=1299, loss=33.4478
epoch=1300, loss=32.8936
epoch=1301, loss=24.6013
epoch=1302, loss=29.7237
epoch=1303, loss=33.0355
epoch=1304, loss=29.6027
epoch=1305, loss=20.6213
epoch=1306, loss=39.2912
epoch=1307, loss=33.8585
epoch=1308, loss=27.9786
epoch=1309, loss=37.9520
epoch=1310, loss=27.3741
epoch=1311, loss=31.4076
epoch=1312, loss=43.2102
epoch=1313, loss=47.3487
epoch=1314, loss=23.7160
epoch=1315, loss=39.3521
epoch=1316, loss=34.5424
epoch=1317, loss=42.2042
epoch=1318, loss=48.8755
epoch=1319, loss=38.0471
epoch=1320, loss=45.6290
epoch=1321, loss=45.8212
epoch=1322, loss=25.6850
epoch=1323, loss=28.9088
epoch=1324, loss=24.9450
epoch=1325, loss=46.3202
epoch=1326, loss=38.0713
epoch=1327, loss=35.5920
epoch=1328, loss=36.7340
epoch=1329, loss=36.7376
epoch=1330, loss=33.7929
epoch=1331, loss=28.7036
epoch=1332, loss=34.7067
epoch=1333, loss=34.2267
epoch=1334, loss=28.4187
epoch=1335, loss=48.9013
epoch=1336, loss=39.2364
epoch=1337, loss=33.5179
epoch=1338, loss=27.4427
epoch=1339, loss=40.8810
epoch=1340, loss=32.8109
epoch=1341, loss=36.2374
epoch=1342, loss=34.9915
epoch=1343, loss=48.1464
epoch=1344, loss=35.1412
epoch=1345, loss=38.2447
epoch=1346, loss=32.2756
epoch=1347, loss=34.8562
epoch=1348, loss=32.7381
epoch=1349, loss=45.5488
epoch=1350, loss=42.3326
epoch=1351, loss=34.9011
epoch=1352, loss=27.3625
epoch=1353, loss=29.7940
epoch=1354, loss=30.0752
epoch=1355, loss=39.7541
epoch=1356, loss=21.0593
epoch=1357, loss=39.7403
epoch=1358, loss=25.7632
epoch=1359, loss=27.4025
epoch=1360, loss=42.4548
epoch=1361, loss=30.8795
epoch=1362, loss=34.3792
epoch=1363, loss=26.9746
epoch=1364, loss=46.9133
epoch=1365, loss=44.5003
epoch=1366, loss=34.8450
epoch=1367, loss=36.5696
epoch=1368, loss=28.2940
epoch=1369, loss=31.7687
epoch=1370, loss=31.1225
epoch=1371, loss=28.7888
epoch=1372, loss=41.3312
epoch=1373, loss=24.3718
epoch=1374, loss=30.1667
epoch=1375, loss=43.1038
epoch=1376, loss=35.0408
epoch=1377, loss=45.6852
epoch=1378, loss=23.8897
epoch=1379, loss=40.9500
epoch=1380, loss=27.0842
epoch=1381, loss=37.4781
epoch=1382, loss=27.9361
epoch=1383, loss=38.1767
epoch=1384, loss=30.1983
epoch=1385, loss=35.0150
epoch=1386, loss=29.0877
epoch=1387, loss=29.1462
epoch=1388, loss=38.7472
epoch=1389, loss=34.5128
epoch=1390, loss=26.4847
epoch=1391, loss=37.5297
epoch=1392, loss=46.5221
epoch=1393, loss=30.3338
epoch=1394, loss=43.5524
epoch=1395, loss=47.7259
epoch=1396, loss=38.2131
epoch=1397, loss=28.8958
epoch=1398, loss=38.0463
epoch=1399, loss=32.1844
epoch=1400, loss=31.8049
epoch=1401, loss=36.0688
epoch=1402, loss=36.3782
epoch=1403, loss=34.5230
epoch=1404, loss=31.9884
epoch=1405, loss=26.4623
epoch=1406, loss=28.9178
epoch=1407, loss=34.3496
epoch=1408, loss=42.2160
epoch=1409, loss=45.9547
epoch=1410, loss=28.7658
epoch=1411, loss=29.1827
epoch=1412, loss=38.7917
epoch=1413, loss=36.0446
epoch=1414, loss=28.3945
epoch=1415, loss=38.4445
epoch=1416, loss=48.1248
epoch=1417, loss=24.3291
epoch=1418, loss=32.1941
epoch=1419, loss=40.7814
epoch=1420, loss=34.9021
epoch=1421, loss=28.1727
epoch=1422, loss=33.6879
epoch=1423, loss=42.9682
epoch=1424, loss=28.9138
epoch=1425, loss=34.5130
epoch=1426, loss=28.2958
epoch=1427, loss=34.0243
epoch=1428, loss=52.9340
epoch=1429, loss=36.5021
epoch=1430, loss=29.9891
epoch=1431, loss=46.2166
epoch=1432, loss=37.6546
epoch=1433, loss=26.2562
epoch=1434, loss=29.1484
epoch=1435, loss=33.7357
epoch=1436, loss=35.7856
epoch=1437, loss=43.1520
epoch=1438, loss=23.5223
epoch=1439, loss=37.2715
epoch=1440, loss=37.2932
epoch=1441, loss=39.1635
epoch=1442, loss=39.8090
epoch=1443, loss=37.1292
epoch=1444, loss=42.5673
epoch=1445, loss=30.0301
epoch=1446, loss=40.1862
epoch=1447, loss=26.3388
epoch=1448, loss=26.1381
epoch=1449, loss=31.9983
epoch=1450, loss=35.8569
epoch=1451, loss=18.6316
epoch=1452, loss=35.4459
epoch=1453, loss=23.9700
epoch=1454, loss=28.1707
epoch=1455, loss=46.6944
epoch=1456, loss=36.3569
epoch=1457, loss=25.8404
epoch=1458, loss=37.0917
epoch=1459, loss=41.3789
epoch=1460, loss=31.4906
epoch=1461, loss=32.2165
epoch=1462, loss=35.4388
epoch=1463, loss=36.7359
epoch=1464, loss=54.2147
epoch=1465, loss=40.9424
epoch=1466, loss=35.2999
epoch=1467, loss=33.0847
epoch=1468, loss=37.4095
epoch=1469, loss=25.4988
epoch=1470, loss=37.7319
epoch=1471, loss=33.4607
epoch=1472, loss=29.9038
epoch=1473, loss=30.6572
epoch=1474, loss=32.1203
epoch=1475, loss=31.6620
epoch=1476, loss=32.0512
epoch=1477, loss=34.6661
epoch=1478, loss=26.7730
epoch=1479, loss=45.8692
epoch=1480, loss=34.8963
epoch=1481, loss=41.5471
epoch=1482, loss=55.0377
epoch=1483, loss=32.9687
epoch=1484, loss=36.4512
epoch=1485, loss=36.8798
epoch=1486, loss=28.7513
epoch=1487, loss=27.4820
epoch=1488, loss=33.4598
epoch=1489, loss=28.5997
epoch=1490, loss=34.9837
epoch=1491, loss=34.0515
epoch=1492, loss=34.0678
epoch=1493, loss=40.5191
epoch=1494, loss=40.1758
epoch=1495, loss=32.9277
epoch=1496, loss=35.6981
epoch=1497, loss=37.2294
epoch=1498, loss=25.2547
epoch=1499, loss=32.5763
epoch=1500, loss=19.6940
epoch=1501, loss=25.7013
epoch=1502, loss=33.3048
epoch=1503, loss=37.1762
epoch=1504, loss=28.7383
epoch=1505, loss=23.0517
epoch=1506, loss=44.1326
epoch=1507, loss=34.9054
epoch=1508, loss=45.7176
epoch=1509, loss=38.1900
epoch=1510, loss=39.1193
epoch=1511, loss=45.1884
epoch=1512, loss=31.4161
epoch=1513, loss=20.9701
epoch=1514, loss=26.8097
epoch=1515, loss=35.9537
epoch=1516, loss=40.7580
epoch=1517, loss=36.0942
epoch=1518, loss=36.3071
epoch=1519, loss=30.0816
epoch=1520, loss=37.9924
epoch=1521, loss=33.5203
epoch=1522, loss=28.2977
epoch=1523, loss=25.2874
epoch=1524, loss=34.9287
epoch=1525, loss=32.7701
epoch=1526, loss=29.0807
epoch=1527, loss=37.3032
epoch=1528, loss=30.2755
epoch=1529, loss=27.3569
epoch=1530, loss=36.5855
epoch=1531, loss=24.0321
epoch=1532, loss=47.2285
epoch=1533, loss=40.5877
epoch=1534, loss=39.0814
epoch=1535, loss=23.1268
epoch=1536, loss=29.5388
epoch=1537, loss=29.2759
epoch=1538, loss=33.0768
epoch=1539, loss=36.9392
epoch=1540, loss=44.2638
epoch=1541, loss=51.0828
epoch=1542, loss=31.2280
epoch=1543, loss=31.0411
epoch=1544, loss=29.3206
epoch=1545, loss=25.4009
epoch=1546, loss=29.0619
epoch=1547, loss=29.3662
epoch=1548, loss=43.8614
epoch=1549, loss=33.0834
epoch=1550, loss=38.1879
epoch=1551, loss=36.0901
epoch=1552, loss=33.1927
epoch=1553, loss=31.8870
epoch=1554, loss=41.4068
epoch=1555, loss=36.1541
epoch=1556, loss=31.7799
epoch=1557, loss=30.4052
epoch=1558, loss=27.6859
epoch=1559, loss=30.0883
epoch=1560, loss=37.6626
epoch=1561, loss=37.5391
epoch=1562, loss=37.2008
epoch=1563, loss=35.7859
epoch=1564, loss=31.0485
epoch=1565, loss=42.8876
epoch=1566, loss=33.3878
epoch=1567, loss=33.9048
epoch=1568, loss=28.5611
epoch=1569, loss=56.6481
epoch=1570, loss=37.7521
epoch=1571, loss=34.0539
epoch=1572, loss=38.2804
epoch=1573, loss=33.0363
epoch=1574, loss=49.9077
epoch=1575, loss=30.2550
epoch=1576, loss=34.8463
epoch=1577, loss=33.5766
epoch=1578, loss=29.2131
epoch=1579, loss=29.7863
epoch=1580, loss=29.3244
epoch=1581, loss=41.7618
epoch=1582, loss=31.1525
epoch=1583, loss=26.3711
epoch=1584, loss=31.2908
epoch=1585, loss=41.2849
epoch=1586, loss=30.1948
epoch=1587, loss=43.9187
epoch=1588, loss=30.0660
epoch=1589, loss=36.5963
epoch=1590, loss=39.6431
epoch=1591, loss=30.0314
epoch=1592, loss=33.8546
epoch=1593, loss=47.1560
epoch=1594, loss=34.4838
epoch=1595, loss=33.3091
epoch=1596, loss=36.6789
epoch=1597, loss=42.9379
epoch=1598, loss=37.5283
epoch=1599, loss=40.5685
epoch=1600, loss=28.9754
epoch=1601, loss=40.8618
epoch=1602, loss=40.8115
epoch=1603, loss=35.8314
epoch=1604, loss=29.7943
epoch=1605, loss=36.5864
epoch=1606, loss=34.8996
epoch=1607, loss=37.0557
epoch=1608, loss=25.0684
epoch=1609, loss=37.0592
epoch=1610, loss=31.3310
epoch=1611, loss=29.1245
epoch=1612, loss=45.6712
epoch=1613, loss=34.1055
epoch=1614, loss=43.7048
epoch=1615, loss=52.7717
epoch=1616, loss=32.6442
epoch=1617, loss=30.8302
epoch=1618, loss=38.3865
epoch=1619, loss=31.0302
epoch=1620, loss=36.1039
epoch=1621, loss=26.7771
epoch=1622, loss=23.1458
epoch=1623, loss=31.1163
epoch=1624, loss=35.4052
epoch=1625, loss=29.2810
epoch=1626, loss=26.3598
epoch=1627, loss=32.1139
epoch=1628, loss=30.6004
epoch=1629, loss=27.8413
epoch=1630, loss=32.1856
epoch=1631, loss=60.2945
epoch=1632, loss=32.8952
epoch=1633, loss=28.2247
epoch=1634, loss=40.7143
epoch=1635, loss=29.2681
epoch=1636, loss=34.8334
epoch=1637, loss=35.9697
epoch=1638, loss=25.3400
epoch=1639, loss=27.4441
epoch=1640, loss=53.5048
epoch=1641, loss=27.4171
epoch=1642, loss=28.0155
epoch=1643, loss=46.8164
epoch=1644, loss=28.9093
epoch=1645, loss=32.7978
epoch=1646, loss=50.6372
epoch=1647, loss=29.7534
epoch=1648, loss=30.3780
epoch=1649, loss=34.3235
epoch=1650, loss=29.1777
epoch=1651, loss=37.8256
epoch=1652, loss=34.4120
epoch=1653, loss=39.7161
epoch=1654, loss=36.0549
epoch=1655, loss=29.4830
epoch=1656, loss=35.1932
epoch=1657, loss=46.3637
epoch=1658, loss=22.3256
epoch=1659, loss=34.9382
epoch=1660, loss=29.0200
epoch=1661, loss=31.4598
epoch=1662, loss=40.1114
epoch=1663, loss=33.6175
epoch=1664, loss=30.2861
epoch=1665, loss=40.6111
epoch=1666, loss=58.0736
epoch=1667, loss=39.6817
epoch=1668, loss=36.3547
epoch=1669, loss=34.3968
epoch=1670, loss=45.9183
epoch=1671, loss=40.1887
epoch=1672, loss=43.8078
epoch=1673, loss=46.5437
epoch=1674, loss=23.7588
epoch=1675, loss=38.5523
epoch=1676, loss=40.9734
epoch=1677, loss=39.5176
epoch=1678, loss=31.7523
epoch=1679, loss=26.0209
epoch=1680, loss=28.3806
epoch=1681, loss=34.1911
epoch=1682, loss=35.0529
epoch=1683, loss=34.5082
epoch=1684, loss=31.7266
epoch=1685, loss=36.3510
epoch=1686, loss=50.4109
epoch=1687, loss=22.8710
epoch=1688, loss=58.3529
epoch=1689, loss=29.7358
epoch=1690, loss=39.1148
epoch=1691, loss=32.0888
epoch=1692, loss=25.9282
epoch=1693, loss=33.8056
epoch=1694, loss=31.5784
epoch=1695, loss=31.4391
epoch=1696, loss=30.5438
epoch=1697, loss=38.3535
epoch=1698, loss=34.4427
epoch=1699, loss=33.2623
epoch=1700, loss=31.7864
epoch=1701, loss=42.3128
epoch=1702, loss=38.4971
epoch=1703, loss=31.6685
epoch=1704, loss=34.9853
epoch=1705, loss=18.5508
epoch=1706, loss=40.1604
epoch=1707, loss=28.1061
epoch=1708, loss=37.5067
epoch=1709, loss=33.8658
epoch=1710, loss=30.8775
epoch=1711, loss=49.1571
epoch=1712, loss=28.7557
epoch=1713, loss=40.6502
epoch=1714, loss=39.2319
epoch=1715, loss=41.1720
epoch=1716, loss=28.6397
epoch=1717, loss=43.3470
epoch=1718, loss=31.4252
epoch=1719, loss=41.6643
epoch=1720, loss=31.6779
epoch=1721, loss=56.8492
epoch=1722, loss=36.2753
epoch=1723, loss=40.1690
epoch=1724, loss=42.5956
epoch=1725, loss=30.4588
epoch=1726, loss=52.1502
epoch=1727, loss=31.6561
epoch=1728, loss=32.3665
epoch=1729, loss=32.0891
epoch=1730, loss=29.9402
epoch=1731, loss=31.4298
epoch=1732, loss=45.5506
epoch=1733, loss=40.9761
epoch=1734, loss=27.5720
epoch=1735, loss=41.0700
epoch=1736, loss=40.5480
epoch=1737, loss=30.2543
epoch=1738, loss=45.0808
epoch=1739, loss=32.1093
epoch=1740, loss=52.5572
epoch=1741, loss=40.4674
epoch=1742, loss=43.8624
epoch=1743, loss=43.5267
epoch=1744, loss=25.0396
epoch=1745, loss=19.2728
epoch=1746, loss=38.8592
epoch=1747, loss=27.5246
epoch=1748, loss=38.7391
epoch=1749, loss=33.6800
epoch=1750, loss=33.0100
epoch=1751, loss=31.9837
epoch=1752, loss=43.6407
epoch=1753, loss=30.2375
epoch=1754, loss=43.9711
epoch=1755, loss=30.8167
epoch=1756, loss=32.8285
epoch=1757, loss=30.8861
epoch=1758, loss=34.9955
epoch=1759, loss=26.9971
epoch=1760, loss=46.0442
epoch=1761, loss=41.6299
epoch=1762, loss=37.3062
epoch=1763, loss=43.3887
epoch=1764, loss=28.2057
epoch=1765, loss=42.0531
epoch=1766, loss=42.6729
epoch=1767, loss=46.9954
epoch=1768, loss=34.0256
epoch=1769, loss=28.6126
epoch=1770, loss=36.1191
epoch=1771, loss=31.3382
epoch=1772, loss=32.0284
epoch=1773, loss=30.1594
epoch=1774, loss=35.4324
epoch=1775, loss=35.1956
epoch=1776, loss=31.3684
epoch=1777, loss=46.1910
epoch=1778, loss=27.7542
epoch=1779, loss=27.1898
epoch=1780, loss=21.9117
epoch=1781, loss=36.2531
epoch=1782, loss=43.3739
epoch=1783, loss=27.5341
epoch=1784, loss=24.9797
epoch=1785, loss=41.3120
epoch=1786, loss=34.2099
epoch=1787, loss=54.7183
epoch=1788, loss=30.1699
epoch=1789, loss=37.7671
epoch=1790, loss=47.7613
epoch=1791, loss=28.9987
epoch=1792, loss=28.2261
epoch=1793, loss=29.5456
epoch=1794, loss=23.8020
epoch=1795, loss=28.6554
epoch=1796, loss=27.5830
epoch=1797, loss=39.7724
epoch=1798, loss=37.2909
epoch=1799, loss=32.4331
epoch=1800, loss=28.8356
epoch=1801, loss=39.7177
epoch=1802, loss=32.5468
epoch=1803, loss=29.4262
epoch=1804, loss=26.1700
epoch=1805, loss=29.5171
epoch=1806, loss=51.8968
epoch=1807, loss=34.4221
epoch=1808, loss=33.7871
epoch=1809, loss=32.0702
epoch=1810, loss=31.3400
epoch=1811, loss=27.9518
epoch=1812, loss=39.6344
epoch=1813, loss=35.9245
epoch=1814, loss=42.7269
epoch=1815, loss=37.4334
epoch=1816, loss=35.6348
epoch=1817, loss=40.1131
epoch=1818, loss=38.9279
epoch=1819, loss=35.3195
epoch=1820, loss=48.8344
epoch=1821, loss=30.6186
epoch=1822, loss=36.9466
epoch=1823, loss=28.7751
epoch=1824, loss=27.8019
epoch=1825, loss=39.9999
epoch=1826, loss=29.6051
epoch=1827, loss=35.7877
epoch=1828, loss=35.4209
epoch=1829, loss=39.5553
epoch=1830, loss=37.8764
epoch=1831, loss=42.5659
epoch=1832, loss=30.1320
epoch=1833, loss=46.6477
epoch=1834, loss=38.4669
epoch=1835, loss=25.2448
epoch=1836, loss=33.5906
epoch=1837, loss=31.8231
epoch=1838, loss=43.5205
epoch=1839, loss=30.8798
epoch=1840, loss=32.5020
epoch=1841, loss=34.9660
epoch=1842, loss=32.1129
epoch=1843, loss=35.1489
epoch=1844, loss=35.3684
epoch=1845, loss=34.0432
epoch=1846, loss=47.0925
epoch=1847, loss=36.6252
epoch=1848, loss=35.7320
epoch=1849, loss=37.8639
epoch=1850, loss=32.5404
epoch=1851, loss=46.5566
epoch=1852, loss=31.8804
epoch=1853, loss=31.8740
epoch=1854, loss=28.8555
epoch=1855, loss=39.6632
epoch=1856, loss=35.9543
epoch=1857, loss=33.4105
epoch=1858, loss=42.2088
epoch=1859, loss=30.4224
epoch=1860, loss=24.3301
epoch=1861, loss=37.5396
epoch=1862, loss=24.0663
epoch=1863, loss=46.6297
epoch=1864, loss=25.2938
epoch=1865, loss=52.2606
epoch=1866, loss=33.0870
epoch=1867, loss=29.6333
epoch=1868, loss=34.0967
epoch=1869, loss=46.1038
epoch=1870, loss=44.0608
epoch=1871, loss=42.7982
epoch=1872, loss=30.8490
epoch=1873, loss=35.0194
epoch=1874, loss=25.9710
epoch=1875, loss=35.7427
epoch=1876, loss=35.8160
epoch=1877, loss=43.2334
epoch=1878, loss=26.2880
epoch=1879, loss=32.7721
epoch=1880, loss=34.7454
epoch=1881, loss=36.5345
epoch=1882, loss=39.8898
epoch=1883, loss=34.8573
epoch=1884, loss=22.9588
epoch=1885, loss=28.2082
epoch=1886, loss=47.4568
epoch=1887, loss=36.7360
epoch=1888, loss=32.4228
epoch=1889, loss=27.7855
epoch=1890, loss=26.3805
epoch=1891, loss=31.8914
epoch=1892, loss=45.6464
epoch=1893, loss=30.4516
epoch=1894, loss=26.3911
epoch=1895, loss=28.8419
epoch=1896, loss=37.7788
epoch=1897, loss=37.7130
epoch=1898, loss=40.2391
epoch=1899, loss=49.8949
epoch=1900, loss=33.5475
epoch=1901, loss=36.0506
epoch=1902, loss=38.4378
epoch=1903, loss=32.5662
epoch=1904, loss=33.0551
epoch=1905, loss=32.6487
epoch=1906, loss=32.2130
epoch=1907, loss=44.5138
epoch=1908, loss=25.4023
epoch=1909, loss=37.0736
epoch=1910, loss=39.4626
epoch=1911, loss=34.6423
epoch=1912, loss=30.6426
epoch=1913, loss=46.5408
epoch=1914, loss=31.7788
epoch=1915, loss=36.8644
epoch=1916, loss=40.8513
epoch=1917, loss=31.2167
epoch=1918, loss=29.0489
epoch=1919, loss=36.5178
epoch=1920, loss=26.8812
epoch=1921, loss=44.3916
epoch=1922, loss=33.4103
epoch=1923, loss=43.9588
epoch=1924, loss=43.0485
epoch=1925, loss=32.9726
epoch=1926, loss=39.2526
epoch=1927, loss=33.8366
epoch=1928, loss=29.3006
epoch=1929, loss=25.8917
epoch=1930, loss=46.2496
epoch=1931, loss=35.0287
epoch=1932, loss=39.5581
epoch=1933, loss=43.1998
epoch=1934, loss=43.7446
epoch=1935, loss=50.5408
epoch=1936, loss=38.4567
epoch=1937, loss=33.6322
epoch=1938, loss=40.0772
epoch=1939, loss=32.5416
epoch=1940, loss=28.5272
epoch=1941, loss=19.7723
epoch=1942, loss=43.1280
epoch=1943, loss=24.6360
epoch=1944, loss=27.3806
epoch=1945, loss=43.2488
epoch=1946, loss=32.7586
epoch=1947, loss=32.0554
epoch=1948, loss=34.7635
epoch=1949, loss=34.7666
epoch=1950, loss=26.2582
epoch=1951, loss=39.6517
epoch=1952, loss=31.0023
epoch=1953, loss=41.1355
epoch=1954, loss=24.0179
epoch=1955, loss=35.2110
epoch=1956, loss=41.1189
epoch=1957, loss=36.2977
epoch=1958, loss=43.1321
epoch=1959, loss=38.5753
epoch=1960, loss=38.8746
epoch=1961, loss=36.1261
epoch=1962, loss=28.0169
epoch=1963, loss=26.2852
epoch=1964, loss=41.9416
epoch=1965, loss=40.1778
epoch=1966, loss=40.0545
epoch=1967, loss=37.0752
epoch=1968, loss=27.1719
epoch=1969, loss=27.8567
epoch=1970, loss=33.7178
epoch=1971, loss=35.1780
epoch=1972, loss=33.4344
epoch=1973, loss=30.2588
epoch=1974, loss=41.6152
epoch=1975, loss=31.6988
epoch=1976, loss=40.2418
epoch=1977, loss=40.8906
epoch=1978, loss=25.8763
epoch=1979, loss=28.6479
epoch=1980, loss=26.2841
epoch=1981, loss=37.0122
epoch=1982, loss=36.8730
epoch=1983, loss=39.5089
epoch=1984, loss=19.0150
epoch=1985, loss=37.0378
epoch=1986, loss=45.1676
epoch=1987, loss=34.5720
epoch=1988, loss=46.5605
epoch=1989, loss=44.1837
epoch=1990, loss=42.8904
epoch=1991, loss=47.9173
epoch=1992, loss=37.9717
epoch=1993, loss=32.5989
epoch=1994, loss=27.9967
epoch=1995, loss=21.6590
epoch=1996, loss=32.9677
epoch=1997, loss=34.2821
epoch=1998, loss=29.9108
epoch=1999, loss=32.4362
# Pull the per-epoch snapshots of the inducing-point parameters out of
# ``history``. They are popped, not just read, so they do not end up as
# columns of ``history_df`` when that frame is built from ``history`` later.
inducing_index_points_history = history.pop("inducing_index_points")
variational_inducing_observations_loc_history = history.pop(
    "variational_inducing_observations_loc")

# Values recorded at the last epoch, used in the plots that follow.
inducing_index_points, variational_inducing_observations_loc = (
    inducing_index_points_history[-1],
    variational_inducing_observations_loc_history[-1],
)
Posterior over the log intensity function, log λ(x).
# Evaluate the variational posterior at the query points once and reuse it.
# Each ``model(X_q)`` call re-runs the forward pass, and the original code
# did this five times to read the same mean and standard deviation.
posterior = model(X_q)
posterior_mean = posterior.mean().numpy()
posterior_stddev = posterior.stddev().numpy()

fig, ax = plt.subplots()
ax.plot(X_q, posterior_mean.T,
        label="posterior mean")
fill_between_stddev(X_q.squeeze(),
                    posterior_mean.squeeze(),
                    posterior_stddev.squeeze(), alpha=0.1,
                    label="posterior std dev", ax=ax)
# Inducing input locations, drawn as markers at a fixed height of -3.5.
ax.scatter(inducing_index_points, np.full_like(inducing_index_points, -3.5),
           marker='^', c="tab:gray", label="inducing inputs", alpha=0.4)
ax.scatter(inducing_index_points, variational_inducing_observations_loc,
           marker='+', c="tab:blue", label="inducing variable mean")
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$\log \lambda(x)$")
ax.legend()
plt.show()

# Query points mapped back to the original (unscaled) input units.
Z_q = scaler.inverse_transform(X_q)
# Exponentiating the Gaussian posterior over log λ(x) gives a log-normal
# distribution over the intensity λ(x) at each query point.
d = tfd.Independent(tfd.LogNormal(loc=posterior.mean(),
                                  scale=posterior.stddev()),
                    reinterpreted_batch_ndims=1)
Intensity function λ(x): the log-normal distribution obtained by exponentiating the Gaussian posterior.
# Mean of the log-normal intensity, computed once and reused below.
d_mean = d.mean().numpy()

fig, ax = plt.subplots()
ax.plot(X_q, d_mean.T, label="transformed posterior mean")
fill_between_stddev(X_q.squeeze(),
                    d_mean.squeeze(),
                    d.stddev().numpy().squeeze(), alpha=0.1,
                    label="transformed posterior std dev", ax=ax)
# Rug of observed incidents just below zero; line width scales with count.
ax.vlines(X.squeeze(), ymin=-0.025, ymax=0.0, linewidth=0.6 * y)
ax.set_ylim(y_min, y_max)
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$\lambda(x)$")
ax.legend()
plt.show()
Predictive mean samples.
# Append a Poisson observation layer to the GP model, so the output is a
# distribution over event counts at each of the query points.
posterior_predictive = tf.keras.Sequential([
    model,
    tfp.layers.IndependentPoisson(event_shape=(num_index_points,)),
])
predictive_mean = posterior_predictive(X_q).mean()

fig, ax = plt.subplots()
ax.plot(X_q, predictive_mean)
# Rug of observed incidents; line width scales with count.
ax.vlines(X.squeeze(), ymin=-0.025, ymax=0.0, linewidth=0.6 * y)
ax.set_xlabel('$x$')
ax.set_ylim(y_min, y_max)
plt.show()
def make_posterior_predictive(num_samples=None, seed=None):
    """Build a callable mapping inputs to a sampled Poisson predictive.

    The returned function draws samples of the latent function from the
    variational posterior ``model(x)`` and passes them to
    ``make_poisson_likelihood``.

    Parameters
    ----------
    num_samples : int or None
        Number of posterior function samples to draw. ``None`` draws a
        single sample with no extra leading sample dimension.
    seed : int or None
        Seed forwarded to ``Distribution.sample``.

    Returns
    -------
    callable
        ``posterior_predictive(x)``, which returns whatever
        ``make_poisson_likelihood(f=...)`` builds from the samples.
    """
    # ``Distribution.sample`` expects a shape for ``sample_shape``, and None
    # is not a documented value. Map None to ``()`` (one draw, no sample
    # axis) so the default argument works.
    sample_shape = () if num_samples is None else num_samples

    def posterior_predictive(x):
        f_samples = model(x).sample(sample_shape, seed=seed)
        return make_poisson_likelihood(f=f_samples)

    return posterior_predictive
posterior_predictive = make_posterior_predictive(num_samples, seed=seed)
# Predictive mean for each posterior sample. Samples appear to lie along the
# leading axis, hence the transpose so each sample is plotted as one curve.
sample_means = posterior_predictive(X_q).mean().numpy()

fig, ax = plt.subplots()
ax.plot(X_q, sample_means.T, color="tab:blue",
        linewidth=0.8, alpha=0.6)
# Rug of observed incidents; line width scales with count.
ax.vlines(X.squeeze(), ymin=-0.025, ymax=0.0, linewidth=0.6 * y)
ax.set_xlabel('$x$')
ax.set_ylim(y_min, y_max)
plt.show()
def get_inducing_index_points_data(inducing_index_points):
    """Reshape per-epoch inducing-input snapshots into a long-form table.

    Parameters
    ----------
    inducing_index_points : sequence of array-like
        One snapshot per epoch. The snapshots are joined column-wise with
        ``np.hstack``; presumably each is an ``(M, 1)`` array (single input
        feature), so after transposing, rows are epochs and columns are
        inducing points. Confirm for other shapes.

    Returns
    -------
    pandas.DataFrame
        Columns ``"epoch"``, ``"inducing index points"`` (point id) and
        ``"x"`` (location). There is one row per (epoch, point) pair, ordered
        by epoch.
    """
    wide = pd.DataFrame(np.hstack(inducing_index_points).T).rename_axis(
        index="epoch", columns="inducing index points")
    return wide.stack().rename('x').reset_index()
# Long-form table of every inducing input's location at every epoch.
data = get_inducing_index_points_data(inducing_index_points_history)

fig, ax = plt.subplots()
# sort=False joins points in row order, i.e. in epoch order, so each line
# traces one inducing input's path during training.
sns.lineplot(data=data, x='x', y="epoch", hue="inducing index points",
             palette="viridis", sort=False, alpha=0.8, ax=ax)
ax.set_xlabel(r'$x$')
plt.show()
# Popped, not just read, so this array-valued entry is not part of
# ``history`` when ``history_df`` is built from it.
variational_inducing_observations_scale_history = history.pop(
    "variational_inducing_observations_scale")
first_scale = variational_inducing_observations_scale_history[0]
last_scale = variational_inducing_observations_scale_history[-1]
# Shared colour range so the first and last epochs are directly comparable.
imshow_kwargs = dict(vmin=-0.1, vmax=1.1)

fig, (ax1, ax2) = plt.subplots(ncols=2, sharex=True, sharey=True)
im1 = ax1.imshow(first_scale, **imshow_kwargs)
im2 = ax2.imshow(last_scale, **imshow_kwargs)
fig.colorbar(im2, ax=[ax1, ax2], extend="both", orientation="horizontal")
for axis in (ax1, ax2):
    axis.set_xlabel(r"$i$")
ax1.set_ylabel(r"$j$")
plt.show()
# Tabulate the entries still left in ``history``, one row per epoch, with
# the epoch number as an explicit column.
history_df = pd.DataFrame(history).rename_axis("epoch").reset_index()

fig, ax = plt.subplots()
sns.lineplot(x="epoch", y="nelbo", data=history_df, alpha=0.8, ax=ax)
ax.set_yscale("log")
plt.show()

# Pairwise scatter of the remaining history columns, coloured by epoch.
# Underscores are replaced by spaces for readable axis labels.
parameters_df = (history_df.drop(columns="nelbo")
                 .rename(columns=lambda s: s.replace('_', ' ')))
g = sns.PairGrid(parameters_df, hue="epoch", palette="RdYlBu", corner=True)
g = g.map_lower(plt.scatter, facecolor="none", alpha=0.6)
# Without this the PairGrid figure is never displayed when run as a script,
# unlike every other figure in this example.
plt.show()
Total running time of the script: ( 3 minutes 42.131 seconds)