.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/misc/plot_lstm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_misc_plot_lstm.py:


Long short-term memory (LSTM) networks
======================================

This example applies Keras ``LSTM`` layers to toy inputs: a many-to-many LSTM
that maps random walks to sequences of hidden states, and a one-to-many
bidirectional LSTM that maps a single scalar input to a whole output sequence.

.. GENERATED FROM PYTHON SOURCE LINES 8-15

.. code-block:: default

    import numpy as np
    import matplotlib.pyplot as plt

    # registers the "3d" projection (no longer required since Matplotlib 3.2)
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import LSTM, Bidirectional, Dense, RepeatVector, TimeDistributed

.. GENERATED FROM PYTHON SOURCE LINES 17-24

.. code-block:: default

    num_seqs = 5
    seq_len = 25
    num_features = 1

    seed = 42  # set random seed for reproducibility
    random_state = np.random.RandomState(seed)

.. GENERATED FROM PYTHON SOURCE LINES 25-27

Many-to-many LSTM
-----------------

With ``return_sequences=True``, the layer returns its hidden state at every
time step instead of only the last one, so a batch of shape
``(num_seqs, seq_len, num_features)`` maps to hidden states of shape
``(num_seqs, seq_len, units)``.

.. GENERATED FROM PYTHON SOURCE LINES 27-32

.. code-block:: default

    # generate random walks in Euclidean space
    inputs = np.cumsum(random_state.randn(num_seqs, seq_len, num_features), axis=1)

    lstm = LSTM(units=1, return_sequences=True)
    output = lstm(inputs)

.. GENERATED FROM PYTHON SOURCE LINES 33-35

.. code-block:: default

    print(output.shape)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    (5, 25, 1)

.. GENERATED FROM PYTHON SOURCE LINES 36-45

.. code-block:: default

    fig, ax = plt.subplots()

    ax.plot(inputs[..., 0].T)

    ax.set_xlabel(r"$t$")
    ax.set_ylabel(r"$x(t)$")

    plt.show()

.. image:: /auto_examples/misc/images/sphx_glr_plot_lstm_001.png
    :alt: plot lstm
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 46-55

.. code-block:: default

    fig, ax = plt.subplots()

    ax.plot(output.numpy()[..., 0].T)

    ax.set_xlabel(r"$t$")
    ax.set_ylabel(r"$h(t)$")

    plt.show()

.. image:: /auto_examples/misc/images/sphx_glr_plot_lstm_002.png
    :alt: plot lstm
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 56-67

.. code-block:: default

    fig, ax = plt.subplots()

    # hidden state against input, one line per sequence
    ax.plot(inputs[..., 0].T, output.numpy()[..., 0].T)

    ax.set_xlabel(r"$x(t)$")
    ax.set_ylabel(r"$h(t)$")

    plt.show()

.. image:: /auto_examples/misc/images/sphx_glr_plot_lstm_003.png
    :alt: plot lstm
    :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 68-70

One-to-many LSTM
----------------

``RepeatVector`` copies each input vector to all ``seq_len`` time steps, so the
recurrent layer turns a single scalar :math:`x` into a whole sequence. With
``merge_mode="concat"``, the forward and backward hidden states (32 units each)
are concatenated into 64 features per step, and ``TimeDistributed(Dense(1))``
maps every step back to a single value.

.. GENERATED FROM PYTHON SOURCE LINES 70-74

.. code-block:: default

    seq_len = 100
    num_index_points = 128
    xmin, xmax = -10.0, 10.0

    X_grid = np.linspace(xmin, xmax, num_index_points).reshape(-1, num_features)

.. GENERATED FROM PYTHON SOURCE LINES 75-82

.. code-block:: default

    model = Sequential([
        RepeatVector(seq_len, input_shape=(num_features,)),
        Bidirectional(LSTM(units=32, return_sequences=True), merge_mode="concat"),
        TimeDistributed(Dense(1))
    ])
    model.summary()

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Model: "sequential_5"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    repeat_vector (RepeatVector) (None, 100, 1)            0
    _________________________________________________________________
    bidirectional (Bidirectional (None, 100, 64)           8704
    _________________________________________________________________
    time_distributed (TimeDistri (None, 100, 1)            65
    =================================================================
    Total params: 8,769
    Trainable params: 8,769
    Non-trainable params: 0
    _________________________________________________________________

.. GENERATED FROM PYTHON SOURCE LINES 83-86

.. code-block:: default

    Z = model(X_grid)
    print(Z.shape)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    (128, 100, 1)
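
The parameter counts in the model summary above can be reproduced by hand. A
Keras ``LSTM`` layer stores an input kernel of shape ``(input_dim, 4 * units)``,
a recurrent kernel of shape ``(units, 4 * units)`` and a bias of length
``4 * units``, one slice per gate; ``Bidirectional`` holds two such layers, and
``TimeDistributed`` shares a single ``Dense(1)`` across all time steps. The
helpers ``lstm_param_count`` and ``dense_param_count`` are hypothetical names
for this sketch, which assumes the default ``use_bias=True``.

.. code-block:: python

    def lstm_param_count(input_dim, units):
        """Weights in one LSTM direction: four gates, each with input, recurrent and bias terms."""
        kernel = input_dim * 4 * units        # input-to-gates weights
        recurrent_kernel = units * 4 * units  # hidden-to-gates weights
        bias = 4 * units                      # one bias vector per gate
        return kernel + recurrent_kernel + bias


    def dense_param_count(input_dim, units):
        """Weights in a fully connected layer: weight matrix plus bias vector."""
        return input_dim * units + units


    num_features = 1  # as in the example
    lstm_units = 32   # units of each LSTM direction

    # Bidirectional(..., merge_mode="concat") wraps two independent LSTMs
    bidirectional = 2 * lstm_param_count(num_features, lstm_units)
    # TimeDistributed reuses one Dense(1) on the 2 * 32 = 64 concatenated features
    time_distributed = dense_param_count(2 * lstm_units, 1)

    print(bidirectional, time_distributed, bidirectional + time_distributed)
    # 8704 65 8769

The same formula gives ``lstm_param_count(1, 1) == 12`` for the single-unit
layer in the many-to-many section, and ``RepeatVector`` contributes no weights.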
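
Both sections rely on the same per-step recurrence. The following is a minimal
NumPy sketch of a single-direction LSTM forward pass with
``return_sequences=True``, assuming the standard gate equations with sigmoid
gates and ``tanh`` activations (the TensorFlow 2 Keras defaults) and Keras'
packing of the gates in input, forget, cell, output order. The function
``lstm_forward`` and its random weights are hypothetical; this is not Keras'
implementation and does not reproduce the plotted values. The input is built
the way ``RepeatVector`` builds it, by copying each scalar of ``X_grid`` to
every time step.

.. code-block:: python

    import numpy as np


    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))


    def lstm_forward(x, kernel, recurrent_kernel, bias):
        """Run a single-direction LSTM over x of shape (batch, time, features).

        Returns the hidden state at every step, shape (batch, time, units),
        i.e. the analogue of return_sequences=True.
        """
        batch, seq_len, _ = x.shape
        units = recurrent_kernel.shape[0]
        h = np.zeros((batch, units))  # hidden state h(t)
        c = np.zeros((batch, units))  # cell state c(t)
        outputs = []
        for t in range(seq_len):
            # pre-activations of all four gates at once
            z = x[:, t, :] @ kernel + h @ recurrent_kernel + bias
            i, f, g, o = np.split(z, 4, axis=-1)  # input, forget, candidate, output
            c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
            h = sigmoid(o) * np.tanh(c)
            outputs.append(h)
        return np.stack(outputs, axis=1)


    rng = np.random.RandomState(42)
    num_features, units, seq_len = 1, 32, 100

    # hypothetical random weights, shaped like those of a Keras LSTM layer
    kernel = rng.randn(num_features, 4 * units) * 0.1
    recurrent_kernel = rng.randn(units, 4 * units) * 0.1
    bias = np.zeros(4 * units)

    # one-to-many: copy each scalar input to every time step, as RepeatVector does
    X_grid = np.linspace(-10.0, 10.0, 128).reshape(-1, num_features)
    repeated = np.repeat(X_grid[:, np.newaxis, :], seq_len, axis=1)

    H = lstm_forward(repeated, kernel, recurrent_kernel, bias)
    print(repeated.shape, H.shape)
    # (128, 100, 1) (128, 100, 32)

Keras' ``Bidirectional`` wrapper additionally runs a second copy over the
time-reversed input, reverses that copy's outputs back into forward time
order, and with ``merge_mode="concat"`` stacks both along the last axis, which
is where the 64 features per step in the summary come from.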

.. GENERATED FROM PYTHON SOURCE LINES 87-101

.. code-block:: default

    T_grid = np.arange(seq_len)

    fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))

    # T_grid (100,) and X_grid (128, 1) broadcast against the (128, 100) outputs
    ax.plot_wireframe(T_grid, X_grid, Z.numpy().squeeze(axis=-1), alpha=0.6)
    # ax.plot_surface(T_grid, X_grid, Z.numpy().squeeze(axis=-1),
    #                 edgecolor='k', linewidth=0.5, cmap="Spectral_r")

    ax.set_xlabel(r'$t$')
    ax.set_ylabel(r'$x$')
    ax.set_zlabel(r"$h(x, t)$")

    plt.show()

.. image:: /auto_examples/misc/images/sphx_glr_plot_lstm_004.png
    :alt: plot lstm
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 4.398 seconds)

.. _sphx_glr_download_auto_examples_misc_plot_lstm.py:

.. only:: html

  .. container:: sphx-glr-footer
     :class: sphx-glr-footer-example

     .. container:: sphx-glr-download sphx-glr-download-python

        :download:`Download Python source code: plot_lstm.py `

     .. container:: sphx-glr-download sphx-glr-download-jupyter

        :download:`Download Jupyter notebook: plot_lstm.ipynb `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_