Source code for etudes.gaussian_process

"""Main module."""
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow.keras.layers import Layer
from tensorflow.keras.initializers import Identity, Constant

# shortcuts
tfd = tfp.distributions
kernels = tfp.math.psd_kernels


def identity_initializer(shape, dtype=None):
    """Initializer yielding a (possibly batched) identity matrix."""
    *batch_shape, num_rows, num_columns = shape
    return tf.eye(num_rows, num_columns, batch_shape=batch_shape,
                  dtype=dtype)
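
# Example: for a weight of shape `(3, 4, 4)`, Keras calls the initializer
# as `identity_initializer((3, 4, 4), dtype=...)`, which returns a batch of
# three 4x4 identity matrices. This is the convention `GaussianProcessLayer`
# below relies on for its variational scale. A minimal check:
#
#   init = identity_initializer((3, 4, 4), dtype=tf.float64)
#   init.shape  # TensorShape([3, 4, 4])
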
class KernelWrapper(Layer):
    """Keras layer that holds the trainable parameters of a PSD kernel."""
    # TODO: Support automatic relevance determination

    def __init__(self, input_dim=1, kernel_cls=kernels.ExponentiatedQuadratic,
                 dtype=None, **kwargs):
        super(KernelWrapper, self).__init__(dtype=dtype, **kwargs)
        self.kernel_cls = kernel_cls
        # Parameters are stored on the log scale so that exponentiation
        # keeps them positive during unconstrained optimization.
        self.log_amplitude = self.add_weight(
            name="log_amplitude", initializer="zeros", dtype=dtype)
        self.log_length_scale = self.add_weight(
            name="log_length_scale", initializer="zeros", dtype=dtype)
        self.log_scale_diag = self.add_weight(
            name="log_scale_diag", shape=(input_dim,),
            initializer="zeros", dtype=dtype)

    def call(self, x):
        # Never called -- this is just a layer so it can hold variables
        # in a way Keras understands.
        return x

    @property
    def kernel(self):
        base_kernel = self.kernel_cls(
            amplitude=tf.exp(self.log_amplitude),
            length_scale=tf.exp(self.log_length_scale))
        return kernels.FeatureScaled(
            base_kernel, scale_diag=tf.exp(self.log_scale_diag))
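
# Example usage (a minimal sketch; the input dimension, sample size and
# dtype below are illustrative assumptions):
#
#   kernel_wrapper = KernelWrapper(input_dim=2, dtype="float64")
#   x = np.random.rand(5, 2)
#   K = kernel_wrapper.kernel.matrix(x, x)  # (5, 5) covariance matrix
#
# Training the enclosing model updates `log_amplitude`, `log_length_scale`
# and `log_scale_diag` directly, and the `kernel` property rebuilds the
# kernel from their current values on every access.
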
class VariationalGaussianProcessScalar(tfp.layers.DistributionLambda):
    """Keras layer that outputs a scalar variational Gaussian process."""

    def __init__(self, kernel_wrapper, num_inducing_points,
                 inducing_index_points_initializer, mean_fn=None,
                 jitter=1e-6, convert_to_tensor_fn=tfd.Distribution.sample,
                 **kwargs):

        def make_distribution(x):
            return VariationalGaussianProcessScalar.new(
                x, kernel_wrapper=self.kernel_wrapper,
                inducing_index_points=self.inducing_index_points,
                variational_inducing_observations_loc=(
                    self.variational_inducing_observations_loc),
                variational_inducing_observations_scale=(
                    self.variational_inducing_observations_scale),
                mean_fn=self.mean_fn,
                observation_noise_variance=tf.exp(
                    self.log_observation_noise_variance),
                jitter=self.jitter)

        super(VariationalGaussianProcessScalar, self).__init__(
            make_distribution_fn=make_distribution,
            convert_to_tensor_fn=convert_to_tensor_fn,
            dtype=kernel_wrapper.dtype)

        self.kernel_wrapper = kernel_wrapper
        self.inducing_index_points_initializer = (
            inducing_index_points_initializer)
        self.num_inducing_points = num_inducing_points
        self.mean_fn = mean_fn
        self.jitter = jitter
        self._dtype = self.kernel_wrapper.dtype

    def build(self, input_shape):
        input_dim = input_shape[-1]
        # TODO: Fix initialization!
        self.inducing_index_points = self.add_weight(
            name="inducing_index_points",
            shape=(self.num_inducing_points, input_dim),
            initializer=self.inducing_index_points_initializer,
            dtype=self.dtype)
        self.variational_inducing_observations_loc = self.add_weight(
            name="variational_inducing_observations_loc",
            shape=(self.num_inducing_points,),
            initializer="zeros", dtype=self.dtype)
        self.variational_inducing_observations_scale = self.add_weight(
            name="variational_inducing_observations_scale",
            shape=(self.num_inducing_points, self.num_inducing_points),
            initializer=Identity(gain=1.0), dtype=self.dtype)
        self.log_observation_noise_variance = self.add_weight(
            name="log_observation_noise_variance",
            initializer=Constant(-5.0), dtype=self.dtype)

    @staticmethod
    def new(x, kernel_wrapper, inducing_index_points, mean_fn,
            variational_inducing_observations_loc,
            variational_inducing_observations_scale,
            observation_noise_variance, jitter, name=None):
        # ind = tfd.Independent(base, reinterpreted_batch_ndims=1)
        # bijector = tfp.bijectors.Transpose(rightmost_transposed_ndims=2)
        # d = tfd.TransformedDistribution(ind, bijector=bijector)
        return tfd.VariationalGaussianProcess(
            kernel=kernel_wrapper.kernel,
            index_points=x,
            inducing_index_points=inducing_index_points,
            variational_inducing_observations_loc=(
                variational_inducing_observations_loc),
            variational_inducing_observations_scale=(
                variational_inducing_observations_scale),
            mean_fn=mean_fn,
            observation_noise_variance=observation_noise_variance,
            jitter=jitter)
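
# Example usage (a minimal sketch; the inducing-point count, initializer
# and batch size are illustrative assumptions):
#
#   vgp = VariationalGaussianProcessScalar(
#       kernel_wrapper=KernelWrapper(input_dim=1, dtype="float64"),
#       num_inducing_points=16,
#       inducing_index_points_initializer=(
#           tf.keras.initializers.RandomUniform(-1.0, 1.0)))
#   qf = vgp(np.random.rand(8, 1))  # a tfd.VariationalGaussianProcess
#
# Being a `DistributionLambda`, the layer returns a distribution object
# that is coerced to a tensor (here via `convert_to_tensor_fn`, i.e. a
# sample) when fed into downstream layers.
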
class GaussianProcessLayer(Layer):
    """Keras layer mapping inputs through `units` independent variational
    Gaussian processes."""

    def __init__(self, units, kernel_provider, num_inducing_points=64,
                 mean_fn=None, jitter=1e-6, **kwargs):
        # TODO: Maybe generalize to `event_shape`?
        self.units = units
        self.num_inducing_points = num_inducing_points
        self.kernel_provider = kernel_provider
        self.mean_fn = mean_fn
        self.jitter = jitter
        super(GaussianProcessLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        # TODO: Fix initialization!
        self.inducing_index_points = self.add_weight(
            name="inducing_index_points",
            shape=(self.units, self.num_inducing_points, input_dim),
            initializer=tf.keras.initializers.RandomUniform(-1, 1),
            trainable=True)
        self.variational_loc = self.add_weight(
            name="variational_inducing_observations_loc",
            shape=(self.units, self.num_inducing_points),
            initializer="zeros", trainable=True)
        self.variational_scale = self.add_weight(
            name="variational_inducing_observations_scale",
            shape=(self.units, self.num_inducing_points,
                   self.num_inducing_points),
            initializer=identity_initializer, trainable=True)
        super(GaussianProcessLayer, self).build(input_shape)

    def call(self, x):
        base = tfd.VariationalGaussianProcess(
            kernel=self.kernel_provider.kernel,
            index_points=x,
            inducing_index_points=self.inducing_index_points,
            variational_inducing_observations_loc=self.variational_loc,
            variational_inducing_observations_scale=self.variational_scale,
            mean_fn=self.mean_fn,
            # TODO: What does this mean in the non-Gaussian likelihood
            # context? Should keep it zero.
            predictive_noise_variance=1e-1,
            jitter=self.jitter)
        # Sum the KL divergences of the `units` independent processes.
        self.add_loss(
            tf.reduce_sum(base.surrogate_posterior_kl_divergence_prior()))
        # Transpose so samples have shape (batch_size, units) rather than
        # (units, batch_size).
        bijector = tfp.bijectors.Transpose(rightmost_transposed_ndims=2)
        qf = tfd.TransformedDistribution(
            tfd.Independent(base, reinterpreted_batch_ndims=1),
            bijector=bijector)
        return qf.sample()

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.units)
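
# Example usage (a minimal sketch; the unit count, input shape and kernel
# provider below are illustrative assumptions):
#
#   gp_layer = GaussianProcessLayer(
#       units=3, kernel_provider=KernelWrapper(input_dim=5))
#   f = gp_layer(tf.random.uniform([10, 5]))  # shape (10, 3)
#   gp_layer.losses  # contains the summed KL term added in `call`
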
def gp_sample_custom(gp, n_samples, seed=None):
    """Sample a GP by drawing from its marginal's base distribution and
    pushing the draws through the marginal's bijector."""
    gp_marginal = gp.get_marginal_distribution()
    base_samples = gp_marginal.distribution.sample(n_samples, seed=seed)
    gp_samples = gp_marginal.bijector.forward(base_samples)
    return gp_samples
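
# Example usage (a minimal sketch; the kernel and index points are
# illustrative assumptions):
#
#   gp = tfd.GaussianProcess(
#       kernel=kernels.ExponentiatedQuadratic(),
#       index_points=np.linspace(-1.0, 1.0, 25)[..., np.newaxis])
#   samples = gp_sample_custom(gp, n_samples=7, seed=42)  # shape (7, 25)
#
# Sampling through the base distribution makes it possible to reuse the
# same base draws across kernel hyperparameter settings when the seed is
# fixed.
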
def dataframe_from_gp_samples(gp_samples_arr, X_q, amplitude, length_scale,
                              n_samples):
    """Flatten an array of GP samples into a tidy ("long-form") DataFrame."""
    names = ["sample", "amplitude", "length_scale", "index_point"]
    v = [list(map(r"$i={}$".format, range(n_samples))),
         amplitude.squeeze(),
         length_scale.squeeze(),
         X_q.squeeze()]
    index = pd.MultiIndex.from_product(v, names=names)
    d = pd.DataFrame(gp_samples_arr.ravel(), index=index,
                     columns=["function_value"])
    return d.reset_index()
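
# A note on expected shapes (inferred from the `MultiIndex.from_product`
# call, so treat it as an assumption): `gp_samples_arr` should have shape
# `(n_samples, len(amplitude), len(length_scale), len(X_q))`, so that
# `ravel()` enumerates values in the same order as the product index.
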
def dataframe_from_gp_summary(gp_mean_arr, gp_stddev_arr, amplitude,
                              length_scale, index_point):
    """Flatten GP predictive means and stddevs into a tidy DataFrame."""
    names = ["amplitude", "length_scale", "index_point"]
    v = [amplitude.squeeze(), length_scale.squeeze(), index_point.squeeze()]
    index = pd.MultiIndex.from_product(v, names=names)
    d1 = pd.DataFrame(gp_mean_arr.ravel(), index=index, columns=["mean"])
    d2 = pd.DataFrame(gp_stddev_arr.ravel(), index=index, columns=["stddev"])
    data = pd.merge(d1, d2, left_index=True, right_index=True)
    return data.reset_index()
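
# As with `dataframe_from_gp_samples`, `gp_mean_arr` and `gp_stddev_arr`
# are assumed to have shape `(len(amplitude), len(length_scale),
# len(index_point))`; merging on the shared index yields one tidy frame
# with `mean` and `stddev` columns, ready for plotting.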