{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": { "5ea6a2562fcb4047a9a1a99f5380cd74": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_a32fdd5eb95b449eb6cb5f5673290f33", "IPY_MODEL_0b98afdde33f4f74bf0743d7a676cebf", "IPY_MODEL_911a83d48a9240f19079a658c11353b2" ], "layout": "IPY_MODEL_fa61455e18aa488880c2fe8b7e524aaf" } }, "a32fdd5eb95b449eb6cb5f5673290f33": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_9de4393c4a4045088df13adb69ccd8c5", "placeholder": "​", "style": "IPY_MODEL_41b993cde00e46d6bccb1d1e300330f7", "value": "Epoch: 100%" } }, "0b98afdde33f4f74bf0743d7a676cebf": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_8c7b555eec6b4008873f7eadacb251b3", "max": 20, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_7f2e91f9f1034d26a5fc73a7db7532c9", "value": 20 } }, "911a83d48a9240f19079a658c11353b2": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c5c35dbf2ca24f91a9f4f8c78d38b658", "placeholder": "​", "style": "IPY_MODEL_61e38aa7bbf246acbe083d08b2c2e54a", "value": " 20/20 [01:28<00:00,  4.36s/it]" } }, "fa61455e18aa488880c2fe8b7e524aaf": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9de4393c4a4045088df13adb69ccd8c5": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "41b993cde00e46d6bccb1d1e300330f7": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "8c7b555eec6b4008873f7eadacb251b3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7f2e91f9f1034d26a5fc73a7db7532c9": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "c5c35dbf2ca24f91a9f4f8c78d38b658": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "61e38aa7bbf246acbe083d08b2c2e54a": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "da05486480fa4c3eb7b3633da055ba48": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_297618bcb32e4421920f8478e5a34601", "IPY_MODEL_4d11564806424942922aa2fcd269ab93", "IPY_MODEL_96eebb59849f487a8f00b6072fc76158" ], "layout": "IPY_MODEL_9a08fba58a894762b1bbecd98cb389a3" } }, "297618bcb32e4421920f8478e5a34601": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_cf66eab7ae164b388621e84f28efcee7", "placeholder": "​", "style": "IPY_MODEL_7a140f6db001409cbeac38270f96dbd2", "value": "Evaluating: 100%" } }, "4d11564806424942922aa2fcd269ab93": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_8044dc73f05a4ecc8d0a440726f0cb76", "max": 177, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_bd71e1dc7fe4479c835ce17c4d904f9d", "value": 177 } }, "96eebb59849f487a8f00b6072fc76158": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a849ad9a474b48deaf2b9cc82340b1e4", "placeholder": "​", "style": "IPY_MODEL_6066cd978b21404abf57e15511154578", "value": " 177/177 [00:00<00:00, 330.07it/s]" } }, "9a08fba58a894762b1bbecd98cb389a3": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "cf66eab7ae164b388621e84f28efcee7": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "7a140f6db001409cbeac38270f96dbd2": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "8044dc73f05a4ecc8d0a440726f0cb76": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "bd71e1dc7fe4479c835ce17c4d904f9d": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "a849ad9a474b48deaf2b9cc82340b1e4": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6066cd978b21404abf57e15511154578": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "markdown", "source": [ "# RNN for Sequence Classification\n", "\n", "What included in this notebook:\n", "\n", "- Implementation of RNN model for text classification task\n", "- Using pre-trained word embeddings to initialize weights for embedding layers" ], "metadata": { "id": "V2s5aqlKrYfK" } }, { "cell_type": "markdown", "source": [ "## Download the data" ], "metadata": { "id": "kEu675S9uyZe" } }, { "cell_type": "code", "source": [ "%%capture\n", "!rm -f titles-en-train.labeled\n", "!rm -f titles-en-test.labeled\n", "\n", "!wget https://raw.githubusercontent.com/neubig/nlptutorial/master/data/titles-en-train.labeled\n", "!wget https://raw.githubusercontent.com/neubig/nlptutorial/master/data/titles-en-test.labeled" ], "metadata": { "id": "ZX-xYOlqvt4F" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Load data\n", "\n", "We will load data into a list of sentences with their labels." ], "metadata": { "id": "LTk2z5cDvwak" } }, { "cell_type": "code", "source": [ "def load_data(file_path):\n", " data = []\n", " with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:\n", " for line in f:\n", " line = line.strip()\n", " if line == '':\n", " continue\n", " lb, text = line.split('\\t')\n", " data.append((text,int(lb)))\n", "\n", " return data" ], "metadata": { "id": "FaEDCdffv4Vu" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "train_data = load_data('./titles-en-train.labeled')\n", "test_data = load_data('./titles-en-test.labeled')\n", "\n", "train_docs, train_labels = zip(*train_data)\n", "test_docs, test_labels = zip(*test_data)" ], "metadata": { "id": "QMPwsd21v6rx" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "train_data[0]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "sfzdVqs8rVJt", "outputId": "ebf833b4-5ee0-45b6-8780-a150e6a65e69" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "('FUJIWARA no Chikamori ( year of birth and death unknown ) was a samurai and poet who lived at the end of the Heian period .',\n", " 1)" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "markdown", "source": [ "## Steps in building RNN model for text classification\n", "\n", "- Create Vocabulary, Vectorizer, Dataset\n", "- Implement model class\n", "- Training loop\n", "- Evaluation on the test data" ], "metadata": { "id": "WoduhTD2wHVj" } }, { "cell_type": "markdown", "source": [ "## Vocabulary, Vectorizer, Dataset\n", "\n", "For each sentence, we need to transform tokens in the sentence into integer indexes that correspond to indexes of words in a vocabulary.\n", "\n", "So we need:\n", "- Create a vocab from training data\n", "- Vectorize data into integer indexes\n", "- Transform data into Data objects" ], "metadata": { "id": "ZA_AckJwyQd6" } }, { "cell_type": "markdown", "source": [ "### Vocablary class" ], "metadata": { "id": "z0NxuWygyVgk" } }, { "cell_type": "code", "source": [ "from collections import defaultdict\n", "\n", "class Vocabulary:\n", " def __init__(self, token_to_idx=None):\n", " \"\"\"\n", " Args:\n", " token_to_idx (dict): a pre-existing map of tokens to indices\n", " \"\"\"\n", " if token_to_idx is None:\n", " token_to_idx = {}\n", " self._token_to_idx = token_to_idx\n", "\n", " self._idx_to_token = {idx: token\n", " for token, idx in self._token_to_idx.items()}\n", "\n", " self.pad_index = 0\n", " self.unk_index = 1\n", "\n", " def lookup_token(self, token):\n", " \"\"\"Retrieve the index associated with the token\n", " or the UNK index if token isn't present.\n", "\n", " Args:\n", " token (str): the token to look up\n", " Returns:\n", " index (int): the index corresponding to the token\n", " Notes:\n", " `unk_index` needs to be >=0 (having been added into the Vocabulary)\n", " for the UNK functionality\n", " \"\"\"\n", " if self.unk_index >= 0:\n", " return self._token_to_idx.get(token, self.unk_index)\n", " else:\n", " return self._token_to_idx[token]\n", "\n", " def lookup_index(self, index):\n", " \"\"\"Return the token associated with the index\n", "\n", " Args:\n", " index (int): the index to look up\n", " Returns:\n", " token (str): the token corresponding to the index\n", " Raises:\n", " KeyError: if the index is not in the Vocabulary\n", " \"\"\"\n", " if index not in self._idx_to_token:\n", " raise KeyError(\"the index (%d) is not in the Vocabulary\" % index)\n", " return self._idx_to_token[index]\n", "\n", " def add_token(self, token):\n", " \"\"\"Update mapping dicts based on the token.\n", "\n", " Args:\n", " token (str): the item to add into the Vocabulary\n", " Returns:\n", " index (int): the integer corresponding to the token\n", " \"\"\"\n", " if token in self._token_to_idx:\n", " index = self._token_to_idx[token]\n", " else:\n", " index = len(self._token_to_idx)\n", " self._token_to_idx[token] = index\n", " self._idx_to_token[index] = token\n", " return index\n", "\n", " @classmethod\n", " def build_vocab(cls, sentences):\n", " \"\"\"Build vocabulary from a list of sentences\n", "\n", " Arguments:\n", " ----------\n", " sentences (list): list of sentences, each sentence is a string\n", "\n", " Return:\n", " ----------\n", " vocab (Vocabulary): a Vocabulary object\n", " \"\"\"\n", " token_to_idx = {\"\": 0, \"\": 1}\n", " vocab = cls(token_to_idx)\n", "\n", " for s in sentences:\n", " for word in s.split():\n", " vocab.add_token(word)\n", " return vocab\n", "\n", " def __str__(self):\n", " return \"\" % len(self)\n", "\n", " def __len__(self):\n", " return len(self._token_to_idx)" ], "metadata": { "id": "xvHJTQlVrl_N" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Let's try to create a Vocabulary from the training data" ], "metadata": { "id": "ynIWZbu2R4KT" } }, { "cell_type": "code", "source": [ "vocab = Vocabulary.build_vocab(train_docs)\n", "print(vocab)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8BBqBBcNR8ws", "outputId": "eb7be335-a485-4afe-e9c9-1395b3de14fa" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n" ] } ] }, { "cell_type": "markdown", "source": [ "### Data Vectorizer function" ], "metadata": { "id": "FsTbAHfgqsoP" } }, { "cell_type": "code", "source": [ "import torch\n", "import numpy as np\n", "\n", "def vectorize(vocab, title):\n", " \"\"\"\n", " Args:\n", " vocab (Vocabulary)\n", " title (str): the string of characters\n", " max_length (int): an argument for forcing the length of index vector\n", " \"\"\"\n", " indices = [vocab.lookup_token(token) for token in title.split()]\n", "\n", " return torch.tensor(indices)" ], "metadata": { "id": "9JBgfQo_ePHW" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(train_docs[0])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8gUOLxzbf1mx", "outputId": "8528d039-a53c-431c-cc7b-75793c53559e" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "FUJIWARA no Chikamori ( year of birth and death unknown ) was a samurai and poet who lived at the end of the Heian period .\n" ] } ] }, { "cell_type": "code", "source": [ "print(vectorize(vocab, train_docs[0]))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3x28FE6af3ya", "outputId": "cf237116-c8ec-40a1-b990-a13bf600348b" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 9, 16, 17, 18,\n", " 19, 20, 21, 7, 20, 22, 23, 24])\n" ] } ] }, { "cell_type": "markdown", "source": [ "### Vectorize training data/test data" ], "metadata": { "id": "8cXDQaASgpua" } }, { "cell_type": "code", "source": [ "train_data = [vectorize(vocab, t) for t in train_docs]\n", "test_data = [vectorize(vocab, t) for t in test_docs]" ], "metadata": { "id": "C_TfqUo6nnqH" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(train_data[0])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ttn9haLGsEhQ", "outputId": "d8e3cd59-d33b-48c9-8653-7b0c63945008" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 9, 16, 17, 18,\n", " 19, 20, 21, 7, 20, 22, 23, 24])\n" ] } ] }, { "cell_type": "markdown", "source": [ "### Label Mapping" ], "metadata": { "id": "X1Er1wOsvcyo" } }, { "cell_type": "code", "source": [ "label2idx = {\n", " -1: 0, 1: 1\n", "}\n", "train_y = [label2idx[lb] for lb in train_labels]\n", "test_y = [label2idx[lb] for lb in test_labels]" ], "metadata": { "id": "uZf8hJCLvejJ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Dataset class\n", "\n", "In order to put data into DataLoader, we need to implement a custom Dataset class that inherite [Dataset class](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)\n", "\n", "It is required to implement two functions `__len__` and `__getitem__`" ], "metadata": { "id": "mz4vx3T7CHcB" } }, { "cell_type": "code", "source": [ "from torch.utils.data import Dataset, DataLoader\n", "\n", "class TextDataset(Dataset):\n", "\n", " def __init__(self, sequences, labels):\n", " self.sequences = sequences\n", " self.labels = labels\n", "\n", " def __len__(self):\n", " return len(self.labels)\n", "\n", " def __getitem__(self, index):\n", " x = self.sequences[index]\n", " y = self.labels[index]\n", "\n", " return x, y" ], "metadata": { "id": "Pe8ThNhcCJjc" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Create train_dataset and test_dataset" ], "metadata": { "id": "AwKoCB-6J9kD" } }, { "cell_type": "code", "source": [ "train_dataset = TextDataset(train_data, train_y)\n", "test_dataset = TextDataset(test_data, test_y)" ], "metadata": { "id": "GSYvWJiKKANm" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print( train_dataset[0] )" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6vLikMEaCNp4", "outputId": "30044ab5-2cc1-4704-bf8f-585f4ae911e5" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "(tensor([ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 9, 16, 17, 18,\n", " 19, 20, 21, 7, 20, 22, 23, 24]), 1)\n" ] } ] }, { "cell_type": "markdown", "source": [ "### Create DataLoader\n", "\n", "We need to define function for processing batches generated by DataLoader" ], "metadata": { "id": "_Ln2o4bqGtBj" } }, { "cell_type": "code", "source": [ "from torch.nn.utils.rnn import pad_sequence\n", "\n", "def collate_batch(batch):\n", " \"\"\"Processing a batch generated by DataLoader\n", "\n", " Arguments:\n", " -----\n", " batch (torch.tensor): a tensor generated by DataLoader\n", " \"\"\"\n", " (x, y) = zip(*batch)\n", " x_lens = torch.tensor([len(x) for x in x])\n", " y = torch.tensor(y, dtype=torch.float32)\n", "\n", " x_pad = pad_sequence(x, batch_first=True, padding_value=0)\n", "\n", " return x_pad, x_lens, y" ], "metadata": { "id": "Z2mClJy-HL8e" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## RNN Model\n", "\n", "Our RNN model for text classification includes following layers:\n", "\n", "- Embedding layer ([nn.Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html))\n", "- RNN Layer ([nn.RNN](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html))\n", "- Linear layer with sigmoid ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html))" ], "metadata": { "id": "cB1appQUv2D2" } }, { "cell_type": "code", "source": [ "import torch.nn as nn\n", "from torch.nn.utils.rnn import pack_padded_sequence\n", "\n", "class TextClassifier(nn.Module):\n", "\n", " def __init__(self, vocab_size, embedding_size, rnn_hidden_size, num_classes,\n", " batch_first=True, padding_idx=0):\n", "\n", " super(TextClassifier, self).__init__()\n", "\n", " self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size,\n", " padding_idx=padding_idx)\n", " self.rnn = nn.LSTM(input_size=embedding_size, hidden_size=rnn_hidden_size,\n", " batch_first=batch_first)\n", " self.fc = nn.Linear(in_features=rnn_hidden_size, out_features=num_classes)\n", "\n", "\n", " def forward(self, x_in, x_lens):\n", " x_embed = self.emb(x_in)\n", " x_packed = pack_padded_sequence(x_embed, x_lens, batch_first=True, enforce_sorted=False)\n", " # If we use simple RNN, there are just two elements to unpack\n", " # _, hidden = self.rnn(x_packed)\n", " _, (hidden, _) = self.rnn(x_packed)\n", "\n", " logits = torch.sigmoid(self.fc(hidden))\n", " return logits" ], "metadata": { "id": "xR481z6Q2EAP" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Create an RNN model" ], "metadata": { "id": "DSRjbOIogupJ" } }, { "cell_type": "code", "source": [ "vocab_size = len(vocab) # 27192\n", "embedding_size = 200\n", "rnn_hidden_size = 128\n", "num_classes = 1\n", "batch_first = True\n", "\n", "model = TextClassifier(vocab_size=vocab_size,\n", " embedding_size=embedding_size,\n", " rnn_hidden_size=rnn_hidden_size,\n", " num_classes=num_classes,\n", " batch_first=batch_first)" ], "metadata": { "id": "fDhF3VkugxfG" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(model)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3Dq7CDJ6lIq4", "outputId": "7df73145-37ad-436e-a544-e6f88cf0ef94" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "TextClassifier(\n", " (emb): Embedding(27192, 200, padding_idx=0)\n", " (rnn): LSTM(200, 128, batch_first=True)\n", " (fc): Linear(in_features=128, out_features=1, bias=True)\n", ")\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Training Loop" ], "metadata": { "id": "3aCXXG7oghzC" } }, { "cell_type": "code", "source": [ "from tqdm.notebook import trange, tqdm\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "learning_rate = 1e-3\n", "batch_size = 16\n", "epochs = 20\n", "\n", "criterion = torch.nn.BCELoss()\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "model.to(device)\n", "\n", "def train():\n", " train_dataloader = DataLoader(\n", " train_dataset,\n", " collate_fn=collate_batch,\n", " batch_size=batch_size,\n", " )\n", " model.train()\n", " train_iterator = trange(int(epochs), desc=\"Epoch\")\n", "\n", " for _ in train_iterator:\n", " for x_in, x_lens, y in train_dataloader:\n", " x_in = x_in.to(device)\n", " y = y.to(device)\n", "\n", " optimizer.zero_grad()\n", " pred = model(x_in, x_lens).squeeze()\n", " loss = criterion(pred, y)\n", " loss.backward()\n", " optimizer.step()\n", "\n", "train()\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 69, "referenced_widgets": [ "5ea6a2562fcb4047a9a1a99f5380cd74", "a32fdd5eb95b449eb6cb5f5673290f33", "0b98afdde33f4f74bf0743d7a676cebf", "911a83d48a9240f19079a658c11353b2", "fa61455e18aa488880c2fe8b7e524aaf", "9de4393c4a4045088df13adb69ccd8c5", "41b993cde00e46d6bccb1d1e300330f7", "8c7b555eec6b4008873f7eadacb251b3", "7f2e91f9f1034d26a5fc73a7db7532c9", "c5c35dbf2ca24f91a9f4f8c78d38b658", "61e38aa7bbf246acbe083d08b2c2e54a" ] }, "id": "8AOV58Iqgj-2", "outputId": "e7770ca9-37bf-4652-c1b8-e574c4a49edd" }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Epoch: 0%| | 0/20 [00:00=0.5).type(torch.long)\n", " preds += _preds.detach().cpu().numpy().tolist()\n", " true_labels += y.detach().cpu().numpy().tolist()\n", "\n", " print(metrics.classification_report(true_labels, preds))\n", "\n", "evaluate()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 236, "referenced_widgets": [ "da05486480fa4c3eb7b3633da055ba48", "297618bcb32e4421920f8478e5a34601", "4d11564806424942922aa2fcd269ab93", "96eebb59849f487a8f00b6072fc76158", "9a08fba58a894762b1bbecd98cb389a3", "cf66eab7ae164b388621e84f28efcee7", "7a140f6db001409cbeac38270f96dbd2", "8044dc73f05a4ecc8d0a440726f0cb76", "bd71e1dc7fe4479c835ce17c4d904f9d", "a849ad9a474b48deaf2b9cc82340b1e4", "6066cd978b21404abf57e15511154578" ] }, "id": "gB5DuSEcpqg_", "outputId": "ed92c8f8-0fc9-41da-aeab-0bf2d77c6ad9" }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Evaluating: 0%| | 0/177 [00:00