{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Simple RNN in PyTorch\n",
        "\n",
        "What is included in this notebook:\n",
        "\n",
        "- An implementation of the Elman network (simple RNN) with `nn.RNNCell` and with `nn.RNN`"
      ],
      "metadata": {
        "id": "nmuIpU10YrsP"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Simple RNN networks\n",
        "\n",
        "To demonstrate how an RNN works, we unroll one step by step using [nn.RNNCell](https://pytorch.org/docs/stable/generated/torch.nn.RNNCell.html) in PyTorch."
      ],
      "metadata": {
        "id": "JdYitlMdcI3e"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.\n",
        "device = 'cpu' if not torch.cuda.is_available() else 'cuda:0'\n",
        "\n",
        "class ElmanRNN(nn.Module):\n",
        "    \"\"\"An Elman (simple) RNN unrolled one time step at a time with nn.RNNCell.\"\"\"\n",
        "\n",
        "    def __init__(self, input_size, hidden_size, batch_first=False):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            input_size (int): size of the input vectors\n",
        "            hidden_size (int): size of the hidden state vectors\n",
        "            batch_first (bool): whether the 0th dimension is batch\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self.rnn_cell = nn.RNNCell(input_size, hidden_size)\n",
        "\n",
        "        self.batch_first = batch_first\n",
        "        self.hidden_size = hidden_size\n",
        "\n",
        "    def _initialize_hidden(self, batch_size):\n",
        "        \"\"\"Return an all-zero hidden state of shape (batch_size, hidden_size).\"\"\"\n",
        "        return torch.zeros((batch_size, self.hidden_size))\n",
        "\n",
        "    def forward(self, x_in, initial_hidden=None):\n",
        "        \"\"\"The forward pass of the ElmanRNN\n",
        "        Args:\n",
        "            x_in (torch.Tensor): an input data tensor.\n",
        "                If self.batch_first: x_in.shape = (batch_size, seq_size, feat_size)\n",
        "                Else: x_in.shape = (seq_size, batch_size, feat_size)\n",
        "            initial_hidden (torch.Tensor): the initial hidden state for the RNN,\n",
        "                shape (batch_size, hidden_size). When None, a zero state with the\n",
        "                same device and dtype as x_in is used.\n",
        "        Returns:\n",
        "            hiddens (torch.Tensor): The outputs of the RNN at each time step.\n",
        "                If self.batch_first:\n",
        "                    hiddens.shape = (batch_size, seq_size, hidden_size)\n",
        "                Else: hiddens.shape = (seq_size, batch_size, hidden_size)\n",
        "                A zero-length sequence yields an empty tensor of that shape.\n",
        "        Raises:\n",
        "            ValueError: if x_in is not a 3-dimensional tensor.\n",
        "        \"\"\"\n",
        "        if self.batch_first:\n",
        "            batch_size, seq_size, _ = x_in.size()\n",
        "            # The loop below indexes time on dim 0, so switch to sequence-first.\n",
        "            x_in = x_in.permute(1, 0, 2)\n",
        "        else:\n",
        "            seq_size, batch_size, _ = x_in.size()\n",
        "\n",
        "        if initial_hidden is None:\n",
        "            # Match both device AND dtype of the input: a float32 state fed to a\n",
        "            # .double()/.half() model makes nn.RNNCell fail with a dtype mismatch.\n",
        "            initial_hidden = self._initialize_hidden(batch_size).to(\n",
        "                device=x_in.device, dtype=x_in.dtype\n",
        "            )\n",
        "\n",
        "        hiddens = []\n",
        "        hidden_t = initial_hidden\n",
        "        for t in range(seq_size):\n",
        "            hidden_t = self.rnn_cell(x_in[t], hidden_t)\n",
        "            hiddens.append(hidden_t)\n",
        "\n",
        "        if hiddens:\n",
        "            hiddens = torch.stack(hiddens)\n",
        "        else:\n",
        "            # torch.stack rejects an empty list; return an empty result instead\n",
        "            # of crashing on a zero-length sequence.\n",
        "            hiddens = hidden_t.new_empty((0, batch_size, self.hidden_size))\n",
        "\n",
        "        if self.batch_first:\n",
        "            hiddens = hiddens.permute(1, 0, 2)\n",
        "\n",
        "        return hiddens"
      ],
      "metadata": {
        "id": "IgXWQ5-adVb4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's create a model and run a random input through it."
      ],
      "metadata": {
        "id": "KZvXrFSjinfg"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model = ElmanRNN(100, 20, batch_first=False)\n",
        "# With batch_first=False the layout is (seq_size, batch_size, feat_size):\n",
        "# 4 time steps, a batch of 3 sequences, and 100 features per vector.\n",
        "x_seq_first = torch.randn(4, 3, 100)\n",
        "\n",
        "# The model returns one hidden vector per time step and per sequence, so the\n",
        "# output shape is (seq_size, batch_size, hidden_size) = (4, 3, 20).\n",
        "hiddens = model(x_seq_first)\n",
        "assert hiddens.shape == (4, 3, 20)\n",
        "print(hiddens.size())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Na1IrCIaiqbv",
        "outputId": "f639d924-13a1-4bc3-b154-fd5c41e34906"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([3, 100])\n",
            "torch.Size([3, 100])\n",
            "torch.Size([3, 100])\n",
            "torch.Size([3, 100])\n",
            "torch.Size([4, 3, 20])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We can also use [nn.RNN](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html), which processes the whole sequence in a single call, to build the same Elman network."
      ],
      "metadata": {
        "id": "n_XEglH_iMC8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# batch_first=True: the input layout is (batch_size, seq_size, feat_size).\n",
        "x_batch_first = torch.randn(4, 3, 100)\n",
        "rnn = nn.RNN(100, 20, batch_first=True)\n",
        "# The initial hidden state is (num_layers, batch_size, hidden_size) for a\n",
        "# unidirectional RNN; batch_first does not apply to it. Derive the sizes\n",
        "# from the module and the input instead of repeating literals.\n",
        "h0 = torch.zeros((rnn.num_layers, x_batch_first.size(0), rnn.hidden_size))\n",
        "output, hn = rnn(x_batch_first, h0)\n",
        "print(output.size())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VYcG8A2yijhM",
        "outputId": "c4be8d8b-4a11-4417-eec5-a7444f7526ed"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([4, 3, 20])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Shape of the final hidden state returned by the nn.RNN call.\n",
        "hn.shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W-IcrMb0kcA8",
        "outputId": "390f3f13-0dc8-43e6-c735-056e42c14862"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([1, 4, 20])"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    }
  ]
}