.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/misc/plot_gauss_hermite_quadrature.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_misc_plot_gauss_hermite_quadrature.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_misc_plot_gauss_hermite_quadrature.py:


Divergence estimation with Gauss-Hermite Quadrature
===================================================

This example estimates the KL divergence between two univariate Gaussian
distributions using Monte Carlo sampling and Gauss-Hermite quadrature, and
compares both estimates against the exact value.

.. GENERATED FROM PYTHON SOURCE LINES 8-31

.. code-block:: default


    import numpy as np

    import tensorflow as tf
    import tensorflow_probability as tfp

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd

    tfd = tfp.distributions

    max_size = 300
    num_seeds = 10

    x_min, x_max = -15.0, 15.0

    num_query_points = 256
    num_features = 1

    # query index points
    X_pred = np.linspace(x_min, x_max, num_query_points)


.. GENERATED FROM PYTHON SOURCE LINES 33-34

We use two univariate Gaussians, :math:`p(x)` and :math:`q(x)`, as a running example.

.. GENERATED FROM PYTHON SOURCE LINES 34-38

.. code-block:: default


    p = tfd.Normal(loc=1.0, scale=1.0)
    q = tfd.Normal(loc=0.0, scale=2.0)


.. GENERATED FROM PYTHON SOURCE LINES 39-52

.. code-block:: default


    fig, ax = plt.subplots()

    ax.plot(X_pred, p.prob(X_pred), label='$p(x)$')
    ax.plot(X_pred, q.prob(X_pred), label='$q(x)$')

    ax.set_xlabel('$x$')
    ax.set_ylabel('density')

    ax.legend()

    plt.show()


.. image:: /auto_examples/misc/images/sphx_glr_plot_gauss_hermite_quadrature_001.png
    :alt: plot gauss hermite quadrature
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 53-55

Exact KL divergence (analytical)
--------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 55-60

.. code-block:: default


    kl_exact = tfd.kl_divergence(p, q).numpy()
    kl_exact


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none


    0.44314718


.. GENERATED FROM PYTHON SOURCE LINES 61-63

Approximate KL divergence (Monte Carlo)
---------------------------------------

..
   GENERATED FROM PYTHON SOURCE LINES 63-67

.. code-block:: default


    sample_size = 25
    seed = 8888


.. GENERATED FROM PYTHON SOURCE LINES 68-74

.. code-block:: default


    kl_monte_carlo = tfp.vi.monte_carlo_variational_loss(
        p.log_prob, q, sample_size=sample_size,
        discrepancy_fn=tfp.vi.kl_forward, seed=seed).numpy()
    kl_monte_carlo


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none


    0.326479


.. GENERATED FROM PYTHON SOURCE LINES 75-78

.. code-block:: default


    x_samples = p.sample(sample_size, seed=seed)


.. GENERATED FROM PYTHON SOURCE LINES 79-88

.. code-block:: default


    def log_ratio(x):
        return p.log_prob(x) - q.log_prob(x)


    def h(x):
        return tfp.vi.kl_forward(log_ratio(x))


.. GENERATED FROM PYTHON SOURCE LINES 89-101

.. code-block:: default


    fig, ax = plt.subplots()

    ax.plot(X_pred, h(X_pred))
    ax.scatter(x_samples, h(x_samples))

    ax.set_xlabel(r'$x$')
    ax.set_ylabel(r'$h(x)$')

    plt.show()


.. image:: /auto_examples/misc/images/sphx_glr_plot_gauss_hermite_quadrature_002.png
    :alt: plot gauss hermite quadrature
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 102-131

.. code-block:: default


    def divergence_monte_carlo(p, q, sample_size, under_p=True,
                               discrepancy_fn=tfp.vi.kl_forward, seed=None):

        def log_ratio(x):
            return p.log_prob(x) - q.log_prob(x)

        if under_p:
            # TODO: Raise exception if `p` is non-Gaussian.
            w = lambda x: tf.exp(-log_ratio(x))
            dist = p
        else:
            # TODO: Raise exception if `q` is non-Gaussian.
            w = lambda x: 1.0
            dist = q

        def fn(x):
            return w(x) * discrepancy_fn(log_ratio(x))

        x_samples = dist.sample(sample_size, seed=seed)

        # same as:
        # return tfp.monte_carlo.expectation(f=fn, samples=x_samples)
        return tf.reduce_mean(fn(x_samples), axis=-1)


.. GENERATED FROM PYTHON SOURCE LINES 132-136

.. code-block:: default


    divergence_monte_carlo(p, q, sample_size, under_p=False, seed=seed).numpy()


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none


    0.74457335


..
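
As a cross-check that does not depend on TensorFlow Probability, the sketch below re-implements both Monte Carlo estimators in plain NumPy and compares them with the closed-form KL divergence between two Gaussians. The helper ``normal_log_pdf``, the seed, and the sample size ``n`` are illustrative assumptions and are not part of the original example.

```python
import numpy as np


def normal_log_pdf(x, mu, sigma):
    """Log-density of N(mu, sigma^2) (hypothetical helper for this sketch)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)


# Same parameters as p and q in this example.
mu_p, sigma_p = 1.0, 1.0
mu_q, sigma_q = 0.0, 2.0

# Closed-form KL[p || q] between two univariate Gaussians.
kl_closed_form = (np.log(sigma_q / sigma_p)
                  + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sigma_q ** 2)
                  - 0.5)


def log_ratio_np(x):
    # log r(x) = log p(x) - log q(x)
    return normal_log_pdf(x, mu_p, sigma_p) - normal_log_pdf(x, mu_q, sigma_q)


rng = np.random.default_rng(0)  # arbitrary seed, assumed for reproducibility
n = 200_000                     # much larger than sample_size = 25 above

# Under p: KL[p || q] = E_p[log r(x)]
x_p = rng.normal(mu_p, sigma_p, size=n)
kl_under_p = np.mean(log_ratio_np(x_p))

# Under q: KL[p || q] = E_q[r(x) log r(x)]
x_q = rng.normal(mu_q, sigma_q, size=n)
lr = log_ratio_np(x_q)
kl_under_q = np.mean(np.exp(lr) * lr)

print(kl_closed_form, kl_under_p, kl_under_q)
```

With this many samples both estimates should land close to the closed-form value. At ``sample_size = 25`` the spread is much larger, which is consistent with the two Monte Carlo estimates reported above.

..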
   GENERATED FROM PYTHON SOURCE LINES 137-158

Approximate KL divergence (Gauss-Hermite Quadrature)
----------------------------------------------------

Consider a function :math:`f(x)` where the variable :math:`x` is normally
distributed, :math:`x \sim p(x) = \mathcal{N}(\mu, \sigma^2)`.
Then, to evaluate the expectation of :math:`f`, we can apply the change of variables

.. math::

    z = \frac{x - \mu}{\sqrt{2}\sigma} \Leftrightarrow x = \sqrt{2}\sigma z + \mu,

and use Gauss-Hermite quadrature, which leads to

.. math::

    \mathbb{E}_{p(x)}[f(x)] & = \int \frac{1}{\sigma \sqrt{2\pi}} \exp \left ( -\frac{(x - \mu)^2}{2\sigma^2} \right ) f(x) dx \\
    & = \frac{1}{\sqrt{\pi}} \int \exp ( - z^2 ) f(\sqrt{2}\sigma z + \mu) dz \\
    & \approx \frac{1}{\sqrt{\pi}} \sum_{i=1}^{m} w_i f(\sqrt{2}\sigma z_i + \mu)

where we've used integration by substitution with :math:`dx = \sqrt{2} \sigma dz`.
Here :math:`z_i` and :math:`w_i` are the nodes and weights of the
:math:`m`-point Gauss-Hermite rule.

.. GENERATED FROM PYTHON SOURCE LINES 158-161

.. code-block:: default


    quadrature_size = 25


.. GENERATED FROM PYTHON SOURCE LINES 162-171

.. code-block:: default


    def transform(x, loc, scale):
        return np.sqrt(2) * scale * x + loc


    X_samples, weights = np.polynomial.hermite.hermgauss(quadrature_size)


.. GENERATED FROM PYTHON SOURCE LINES 172-181

.. code-block:: default


    fig, ax = plt.subplots()

    ax.scatter(transform(X_samples, q.loc, q.scale), weights)

    ax.set_xlabel(r'$x_i$')
    ax.set_ylabel(r'$w_i$')

    plt.show()


.. image:: /auto_examples/misc/images/sphx_glr_plot_gauss_hermite_quadrature_003.png
    :alt: plot gauss hermite quadrature
    :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 182-195

.. code-block:: default


    fig, ax = plt.subplots()

    ax.plot(X_pred, h(X_pred))
    ax.scatter(transform(X_samples, q.loc, q.scale),
               h(transform(X_samples, q.loc, q.scale)), c=weights, cmap="Blues")

    ax.set_xlabel(r'$x$')
    ax.set_ylabel(r'$h(x)$')

    plt.show()


.. image:: /auto_examples/misc/images/sphx_glr_plot_gauss_hermite_quadrature_004.png
    :alt: plot gauss hermite quadrature
    :class: sphx-glr-single-img


..
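
The change of variables used for Gauss-Hermite quadrature can be checked numerically without TensorFlow. The sketch below is a minimal NumPy-only version. The function name ``gauss_hermite_expectation`` and the two test integrands are assumptions chosen for illustration, picked because their expectations under a Gaussian have closed forms.

```python
import numpy as np


def gauss_hermite_expectation(f, mu, sigma, m):
    """Approximate E[f(x)] for x ~ N(mu, sigma^2) with an m-point rule."""
    # Nodes and weights for integrals of the form int exp(-z^2) g(z) dz.
    z, w = np.polynomial.hermite.hermgauss(m)
    # Change of variables: x = sqrt(2) * sigma * z + mu.
    x = np.sqrt(2.0) * sigma * z + mu
    # The 1 / sqrt(pi) factor comes from the Gaussian normalising constant.
    return np.sum(w * f(x)) / np.sqrt(np.pi)


# Second moment of N(1, 2^2); the closed form is mu^2 + sigma^2.
second_moment = gauss_hermite_expectation(lambda x: x ** 2, mu=1.0, sigma=2.0, m=5)
second_moment_exact = 1.0 ** 2 + 2.0 ** 2

# E[cos(x)] under the same Gaussian; the closed form is exp(-sigma^2 / 2) * cos(mu).
cos_estimate = gauss_hermite_expectation(np.cos, mu=1.0, sigma=2.0, m=25)
cos_exact = np.exp(-(2.0 ** 2) / 2.0) * np.cos(1.0)

print(second_moment, second_moment_exact)
print(cos_estimate, cos_exact)
```

Because an :math:`m`-point Gauss-Hermite rule integrates polynomials up to degree :math:`2m - 1` exactly, the second moment is recovered to floating-point precision with only a few nodes.

..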
   GENERATED FROM PYTHON SOURCE LINES 196-237

.. code-block:: default


    def expectation_gauss_hermite(fn, normal, quadrature_size):

        x, weights = np.polynomial.hermite.hermgauss(quadrature_size)
        y = transform(x, normal.loc, normal.scale)

        return tf.reduce_sum(weights * fn(y), axis=-1) / tf.sqrt(np.pi)


    def divergence_gauss_hermite(p, q, quadrature_size, under_p=True,
                                 discrepancy_fn=tfp.vi.kl_forward):
        """
        Compute D_f[p || q]
            = E_{q(x)}[f(p(x)/q(x))]
            = E_{p(x)}[r(x)^{-1} f(r(x))]              -- r(x) = p(x)/q(x)
            = E_{p(x)}[exp(-log r(x)) g(log r(x))]     -- g(.) = f(exp(.))
            = E_{p(x)}[h(x)]                           -- h(x) = exp(-log r(x)) g(log r(x))
        using Gauss-Hermite quadrature assuming p(x) is Gaussian.
        Note `discrepancy_fn` corresponds to function `g`.
        """
        def log_ratio(x):
            return p.log_prob(x) - q.log_prob(x)

        if under_p:
            # TODO: Raise exception if `p` is non-Gaussian.
            w = lambda x: tf.exp(-log_ratio(x))
            normal = p
        else:
            # TODO: Raise exception if `q` is non-Gaussian.
            w = lambda x: 1.0
            normal = q

        def fn(x):
            return w(x) * discrepancy_fn(log_ratio(x))

        return expectation_gauss_hermite(fn, normal, quadrature_size)


.. GENERATED FROM PYTHON SOURCE LINES 238-242

.. code-block:: default


    divergence_gauss_hermite(p, q, quadrature_size, under_p=False).numpy()


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none


    0.44317514


.. GENERATED FROM PYTHON SOURCE LINES 243-244

Next, we compare the Monte Carlo and Gauss-Hermite estimates across sample
and quadrature sizes, taking the expectation under both :math:`p` and :math:`q`.

.. GENERATED FROM PYTHON SOURCE LINES 244-268

.. code-block:: default


    lst = []

    for size in range(1, max_size+1):

        for under_p in range(2):

            under_p = bool(under_p)

            for seed in range(num_seeds):

                kl = divergence_monte_carlo(p, q, size, under_p=under_p, seed=seed).numpy()
                lst.append(dict(kl=kl, size=size, seed=seed,
                                under="p" if under_p else "q",
                                approximation="Monte Carlo"))

            kl = divergence_gauss_hermite(p, q, size, under_p=under_p).numpy()
            lst.append(dict(kl=kl, size=size, seed=0,
                            under="p" if under_p else "q",
                            approximation="Gauss-Hermite"))

    data = pd.DataFrame(lst)


.. GENERATED FROM PYTHON SOURCE LINES 269-270

Finally, we plot the estimates against the exact KL divergence.

..
   GENERATED FROM PYTHON SOURCE LINES 270-283

.. code-block:: default


    def axhline(*args, **kwargs):
        ax = plt.gca()
        ax.axhline(kl_exact, color="tab:red", label="Exact")


    g = sns.relplot(x="size", y="kl", hue="approximation",
                    col="under", kind="line", data=data)
    g.map(axhline)
    g.set(xscale="log")
    g.set_axis_labels("size", "KL divergence")


.. image:: /auto_examples/misc/images/sphx_glr_plot_gauss_hermite_quadrature_005.png
    :alt: under = q, under = p
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  28.852 seconds)


.. _sphx_glr_download_auto_examples_misc_plot_gauss_hermite_quadrature.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example


  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_gauss_hermite_quadrature.py <plot_gauss_hermite_quadrature.py>`


  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_gauss_hermite_quadrature.ipynb <plot_gauss_hermite_quadrature.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery