{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "gpuClass": "standard", "widgets": { "application/vnd.jupyter.widget-state+json": { "5b111d01ace6464a84e9cdb62f15347a": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_b8f665ed9f8348cb9109977af001da42", "IPY_MODEL_49b62860994c4242bbe5b95e99ae4ef6", "IPY_MODEL_f3d1f1b01e1349f59be5be007c188ba5" ], "layout": "IPY_MODEL_279562ff2314475198b1f546b1fd9ab7" } }, "b8f665ed9f8348cb9109977af001da42": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a823d460d6ce4e05aa0d2ef70a0924bc", "placeholder": "​", "style": "IPY_MODEL_3e0a7036cd77412297fdd3756b617f69", "value": "Epoch: 100%" } }, "49b62860994c4242bbe5b95e99ae4ef6": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_f997fbb2f7d24296bd12031f3c17d21a", "max": 100, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_44d2fa68af5946b2ae133f442dcd30c6", "value": 100 } }, "f3d1f1b01e1349f59be5be007c188ba5": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_9287f66ee21c4d748b2b8b854cbee65d", "placeholder": "​", "style": "IPY_MODEL_c4e2863ce2cc4b559ca4c4ad4bd614f7", "value": " 100/100 [02:33<00:00,  1.45s/it]" } }, "279562ff2314475198b1f546b1fd9ab7": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a823d460d6ce4e05aa0d2ef70a0924bc": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3e0a7036cd77412297fdd3756b617f69": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "f997fbb2f7d24296bd12031f3c17d21a": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "44d2fa68af5946b2ae133f442dcd30c6": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "9287f66ee21c4d748b2b8b854cbee65d": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "c4e2863ce2cc4b559ca4c4ad4bd614f7": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "5686d310ce4948c1b08d0f906840af2c": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_ac8700369f854305a94393a8ce12a095", "IPY_MODEL_13495cd8b2b24f2dbf5a3899d25e48f1", "IPY_MODEL_244cb193fb1f4a14bbfedd8fa96a2982" ], "layout": "IPY_MODEL_45b3cb3e53264c46a2e09d56a4685524" } }, "ac8700369f854305a94393a8ce12a095": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_621600e26e4d40278a3dd81bd0ff46f6", "placeholder": "​", "style": "IPY_MODEL_478719e8557b41b58f823cbc7967cff9", "value": "Evaluating: 100%" } }, "13495cd8b2b24f2dbf5a3899d25e48f1": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_368bc735dbe848e1b3189ff3f9f70307", "max": 25, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_d7f2f734034a4f77b525a2dd2bd60366", "value": 25 } }, "244cb193fb1f4a14bbfedd8fa96a2982": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_be5c7d7437c142089a170510b3248305", "placeholder": "​", "style": "IPY_MODEL_2b95324497b846a0bc7c1c61d94f6265", "value": " 25/25 [00:00<00:00, 129.48it/s]" } }, "45b3cb3e53264c46a2e09d56a4685524": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "621600e26e4d40278a3dd81bd0ff46f6": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "478719e8557b41b58f823cbc7967cff9": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "368bc735dbe848e1b3189ff3f9f70307": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d7f2f734034a4f77b525a2dd2bd60366": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "be5c7d7437c142089a170510b3248305": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2b95324497b846a0bc7c1c61d94f6265": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "markdown", "source": [ "# RNN for POS Tagging\n", "\n", "What included in the notebook:\n", "\n", "- Implementation of RNN model for POS Tagging" ], "metadata": { "id": "jM909rAz86iC" } }, { "cell_type": "code", "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "\n", "torch.manual_seed(1)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "T2BHCHqO4hEC", "outputId": "afc87696-f734-4dfb-b617-322b69f84ac1" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 1 } ] }, { "cell_type": "code", "source": [ "nltk.__version__" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "id": "d0UY0nkfSXTc", "outputId": "52200a30-d0c0-4db9-889d-1a27da50cbfe" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'3.9.1'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 27 } ] }, { "cell_type": "markdown", "source": [ "## Dataset\n", "\n", "We will use the Treebak data obtained from nltk.\n" ], "metadata": { "id": "0D3at1Q5ACM3" } }, { "cell_type": "code", "source": [ "%%capture\n", "\n", "import nltk\n", "from nltk.corpus import treebank\n", "\n", "nltk.download('universal_tagset')\n", "nltk.download('treebank')" ], "metadata": { "id": "AvmzTdKNAGca" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Load data set" ], "metadata": { "id": "yexlfZRQAJh9" } }, { "cell_type": "code", "source": [ "tagged_sentences = treebank.tagged_sents(tagset='universal')\n", "tagged_sentences[0]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ejPjn4-pAQvg", "outputId": "ad023d6d-45a1-424d-dbc8-16abe849ca8b" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[('Pierre', 'NOUN'),\n", " ('Vinken', 'NOUN'),\n", " (',', '.'),\n", " ('61', 'NUM'),\n", " ('years', 'NOUN'),\n", " ('old', 'ADJ'),\n", " (',', '.'),\n", " ('will', 'VERB'),\n", " ('join', 'VERB'),\n", " ('the', 'DET'),\n", " ('board', 'NOUN'),\n", " ('as', 'ADP'),\n", " ('a', 'DET'),\n", " ('nonexecutive', 'ADJ'),\n", " ('director', 'NOUN'),\n", " ('Nov.', 'NOUN'),\n", " ('29', 'NUM'),\n", " ('.', '.')]" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "markdown", "source": [ "## Create train/test/split" ], "metadata": { "id": "RC--NOf7Az-2" } }, { "cell_type": "code", "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "train_tagged_sentences, test_tagged_sentences = train_test_split(tagged_sentences, test_size=0.2, random_state=42)" ], "metadata": { "id": "SiqKtYo5EFd3" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Seperate sentences and tag sequences" ], "metadata": { "id": "Z23hZnfXEIGm" } }, { "cell_type": "code", "source": [ "def make_x_y(tagged_sentences):\n", " \"\"\"Seperate sentences and tag sequences from tagged sentences\n", "\n", " Arguments\n", " ----------\n", " tagged_sentences\n", "\n", " Returns\n", " ----------\n", " sentences (list): list of sentences. Each sentence is a list of words\n", " tag_sequences\n", " \"\"\"\n", " sentences = []\n", " tag_sequences = []\n", " for s in tagged_sentences:\n", " words, tags = zip(*s)\n", " sentences.append(list(words))\n", " tag_sequences.append(list(tags))\n", " return sentences, tag_sequences" ], "metadata": { "id": "mW6DzBhEAjIS" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "train_sentences, train_tag_sequences = make_x_y(train_tagged_sentences)\n", "test_sentences, test_tag_sequences = make_x_y(test_tagged_sentences)" ], "metadata": { "id": "xGhHIkYJ6vOG" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Steps in building RNN model for POS Tagging\n", "\n", "- Create Vocabulary, Vectorizer, Dataset\n", "- Implement model class\n", "- Training loop\n", "- Evaluation on the test data" ], "metadata": { "id": "SnIOM22QAX3H" } }, { "cell_type": "markdown", "source": [ "## Create Vocabulary\n", "\n", "We modified the Vocabulary class in the previous lecture.\n", "\n", "We need to convert tags into integer indeces, so we will create two vocabularies, one for words and one for tags." ], "metadata": { "id": "YKC5n_msAxLL" } }, { "cell_type": "code", "source": [ "from collections import defaultdict\n", "\n", "class Vocabulary:\n", " def __init__(self, token_to_idx=None, use_unk=True):\n", " \"\"\"\n", " Args:\n", " token_to_idx (dict): a pre-existing map of tokens to indices\n", " \"\"\"\n", " if token_to_idx is None:\n", " token_to_idx = {}\n", " self._token_to_idx = token_to_idx\n", "\n", " self._idx_to_token = {idx: token\n", " for token, idx in self._token_to_idx.items()}\n", "\n", " self.pad_index = 0\n", "\n", " if use_unk:\n", " self.unk_index = 1\n", " else:\n", " self.unk_index = -1\n", "\n", " def lookup_token(self, token):\n", " \"\"\"Retrieve the index associated with the token\n", " or the UNK index if token isn't present.\n", "\n", " Args:\n", " token (str): the token to look up\n", " Returns:\n", " index (int): the index corresponding to the token\n", " Notes:\n", " `unk_index` needs to be >=0 (having been added into the Vocabulary)\n", " for the UNK functionality\n", " \"\"\"\n", " if self.unk_index >= 0:\n", " return self._token_to_idx.get(token, self.unk_index)\n", " else:\n", " return self._token_to_idx[token]\n", "\n", " def lookup_index(self, index):\n", " \"\"\"Return the token associated with the index\n", "\n", " Args:\n", " index (int): the index to look up\n", " Returns:\n", " token (str): the token corresponding to the index\n", " Raises:\n", " KeyError: if the index is not in the Vocabulary\n", " \"\"\"\n", " if index not in self._idx_to_token:\n", " raise KeyError(\"the index (%d) is not in the Vocabulary\" % index)\n", " return self._idx_to_token[index]\n", "\n", " def add_token(self, token):\n", " \"\"\"Update mapping dicts based on the token.\n", "\n", " Args:\n", " token (str): the item to add into the Vocabulary\n", " Returns:\n", " index (int): the integer corresponding to the token\n", " \"\"\"\n", " if token in self._token_to_idx:\n", " index = self._token_to_idx[token]\n", " else:\n", " index = len(self._token_to_idx)\n", " self._token_to_idx[token] = index\n", " self._idx_to_token[index] = token\n", " return index\n", "\n", " @classmethod\n", " def build_vocab(cls, sequences, use_unk=True):\n", " \"\"\"Build vocabulary from a list of sequences\n", " A sequence may be a sequence of words or a sequence of tags.\n", "\n", " Arguments:\n", " ----------\n", " sequences (list): list of sequences, each sentence list of words\n", " or list of tags\n", "\n", " Return:\n", " ----------\n", " vocab (Vocabulary): a Vocabulary object\n", " \"\"\"\n", " if use_unk:\n", " token_to_idx = {\"\": 0, \"\": 1}\n", " else:\n", " token_to_idx = {\"\": 0}\n", "\n", " vocab = cls(token_to_idx, use_unk=use_unk)\n", " for s in sequences:\n", " for word in s:\n", " vocab.add_token(word)\n", " return vocab\n", "\n", " def __str__(self):\n", " return \"\" % len(self)\n", "\n", " def __len__(self):\n", " return len(self._token_to_idx)" ], "metadata": { "id": "11ojPoT_8he1" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Word vocabulary\n", "word_vocab = Vocabulary.build_vocab(train_sentences)\n", "print(word_vocab)\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YO2HGYCKANK9", "outputId": "d6c1fbe1-cf2f-4b1a-e429-9618233d0fce" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n" ] } ] }, { "cell_type": "code", "source": [ "# Tag vocabulary\n", "tag_vocab = Vocabulary.build_vocab(train_tag_sequences, use_unk=False)\n", "print(tag_vocab._token_to_idx)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HaNIuMBwW0dh", "outputId": "e3e34887-d3ca-46e9-9fe8-5bc4fe7501a3" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'': 0, 'NOUN': 1, '.': 2, 'NUM': 3, 'ADJ': 4, 'VERB': 5, 'DET': 6, 'ADP': 7, 'CONJ': 8, 'PRON': 9, 'X': 10, 'ADV': 11, 'PRT': 12}\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Data Vectorizer" ], "metadata": { "id": "63wTCWWtaTHb" } }, { "cell_type": "code", "source": [ "import torch\n", "import numpy as np\n", "\n", "def vectorize(vocab, sequence):\n", " \"\"\"\n", " Args:\n", " vocab (Vocabulary)\n", " sequence (list): list of words or tags\n", " \"\"\"\n", " indices = [vocab.lookup_token(token) for token in sequence]\n", "\n", " return torch.tensor(indices, dtype=torch.long)" ], "metadata": { "id": "fqrUB46vaVGF" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(train_sentences[0])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vQMTRqC_gCAZ", "outputId": "98556011-58ad-4e76-f18c-dad8530f4745" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['Pierre', 'Vinken', ',', '61', 'years', 'old', ',', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'Nov.', '29', '.']\n" ] } ] }, { "cell_type": "code", "source": [ "vectorize(word_vocab, train_sentences[0])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "C6DN2qF3gE5i", "outputId": "eafe79c6-d4cc-4061-add8-a0cce6a0aa7a" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([ 2, 3, 4, 5, 6, 7, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])" ] }, "metadata": {}, "execution_count": 13 } ] }, { "cell_type": "code", "source": [ "print(train_tag_sequences[0])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "75muyN_vgQ0E", "outputId": "d7497118-8e56-41bd-f8e1-f99a43d1a2ea" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['NOUN', 'NOUN', '.', 'NUM', 'NOUN', 'ADJ', '.', 'VERB', 'VERB', 'DET', 'NOUN', 'ADP', 'DET', 'ADJ', 'NOUN', 'NOUN', 'NUM', '.']\n" ] } ] }, { "cell_type": "code", "source": [ "vectorize(tag_vocab, train_tag_sequences[0])" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "cgl3-wergTsz", "outputId": "3a07d952-bcd1-4ed6-ef0a-fc8349131871" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([1, 1, 2, 3, 1, 4, 2, 5, 5, 6, 1, 7, 6, 4, 1, 1, 3, 2])" ] }, "metadata": {}, "execution_count": 15 } ] }, { "cell_type": "markdown", "source": [ "Vectorize train/test data" ], "metadata": { "id": "fTRwKyaNYv46" } }, { "cell_type": "code", "source": [ "train_data = [vectorize(word_vocab, t) for t in train_sentences]\n", "test_data = [vectorize(word_vocab, t) for t in test_sentences]\n", "\n", "train_y = [vectorize(tag_vocab, t) for t in train_tag_sequences]\n", "test_y = [vectorize(tag_vocab, t) for t in test_tag_sequences]" ], "metadata": { "id": "ey8sO7ADYxmj" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Dataset class" ], "metadata": { "id": "QXU1ufU3JMxJ" } }, { "cell_type": "code", "source": [ "from torch.utils.data import Dataset, DataLoader\n", "\n", "class TextDataset(Dataset):\n", "\n", " def __init__(self, sequences, tag_sequences):\n", " \"\"\"\n", " Args:\n", " sequences (list): list of sentences. Each sentence is a list of words\n", " tag_sequences (list): list of tag sequences, each for one sentence\n", " \"\"\"\n", " self.sequences = sequences\n", " self.tag_sequences = tag_sequences\n", "\n", " def __len__(self):\n", " return len(self.sequences)\n", "\n", " def __getitem__(self, index):\n", " x = self.sequences[index]\n", " y = self.tag_sequences[index]\n", "\n", " return x, y" ], "metadata": { "id": "qfQ-PxQyJOau" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Create train_dataset and test_dataset" ], "metadata": { "id": "Z4Z2ksjoLLaH" } }, { "cell_type": "code", "source": [ "train_dataset = TextDataset(train_data, train_y)\n", "test_dataset = TextDataset(test_data, test_y)" ], "metadata": { "id": "14lrsCGpYNwW" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print( train_dataset[1] )" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "IaVUUrZXYfVA", "outputId": "868c81c0-7bfa-4ab1-aa73-b5d2b1dbc62c" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "(tensor([19, 20, 21, 22, 18]), tensor([8, 9, 5, 4, 2]))\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Create DataLoader\n", "\n", "We need to define function for processing batches generated by DataLoader" ], "metadata": { "id": "5J4SlQHDoWiu" } }, { "cell_type": "code", "source": [ "from torch.nn.utils.rnn import pad_sequence\n", "\n", "def collate_batch(batch):\n", " \"\"\"Processing a batch generated by DataLoader\n", "\n", " Arguments:\n", " -----\n", " batch (torch.tensor): a tensor generated by DataLoader\n", " \"\"\"\n", " (x, y) = zip(*batch)\n", " x_lens = torch.tensor([len(x) for x in x])\n", " y_lens = torch.tensor([len(y) for y in y])\n", "\n", " x_pad = pad_sequence(x, batch_first=True, padding_value=0)\n", " y_pad = pad_sequence(y, batch_first=True, padding_value=0)\n", "\n", " return x_pad, y_pad, x_lens, y_lens" ], "metadata": { "id": "SCrjP0hroXlV" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## RNN Tagging Model" ], "metadata": { "id": "fkmP6Xhp9jiD" } }, { "cell_type": "code", "source": [ "import torch.nn as nn\n", "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n", "\n", "\n", "class LSTMTagger(nn.Module):\n", "\n", " def __init__(self, vocab_size, embedding_dim, hidden_dim, tagset_size,\n", " num_layers=1, batch_first=True, padding_idx=0):\n", "\n", " super(LSTMTagger, self).__init__()\n", " self.hidden_dim = hidden_dim\n", "\n", " self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim,\n", " padding_idx=padding_idx)\n", "\n", " # The LSTM takes word embeddings as inputs, and outputs hidden states\n", " # with dimensionality hidden_dim.\n", " self.lstm = nn.LSTM(embedding_dim, hidden_dim,\n", " num_layers=num_layers, bidirectional=True, batch_first=batch_first)\n", " self.fc = nn.Linear(in_features=2*hidden_dim, out_features=tagset_size)\n", "\n", " ## Comment out to disable weight initialization\n", " torch.nn.init.xavier_uniform_(self.emb.weight)\n", " torch.nn.init.xavier_uniform_(self.fc.weight)\n", "\n", " def forward(self, x_in, x_lens):\n", " x_embed = self.emb(x_in)\n", " x_packed = pack_padded_sequence(x_embed, x_lens, batch_first=True, enforce_sorted=False)\n", " output_packed, _ = self.lstm(x_packed)\n", " output_padded, output_lengths = pad_packed_sequence(output_packed, batch_first=True)\n", " tag_space = self.fc(output_padded)\n", " tag_scores = F.log_softmax(tag_space, dim=1)\n", " return tag_scores" ], "metadata": { "id": "FGt-X6ei9nJq" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Create an LSTM Tagger Model" ], "metadata": { "id": "d5y8qIUEAxrK" } }, { "cell_type": "code", "source": [ "vocab_size = len(word_vocab)\n", "embedding_dim = 300\n", "hidden_dim = 128\n", "num_layers = 2\n", "tagset_size = len(tag_vocab)\n", "batch_first = True\n", "\n", "model = LSTMTagger(vocab_size=vocab_size,\n", " embedding_dim=embedding_dim,\n", " hidden_dim=hidden_dim,\n", " num_layers=num_layers,\n", " tagset_size=tagset_size,\n", " batch_first=batch_first)" ], "metadata": { "id": "ZjYlVuHgA1dR" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(model)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "JwHuxQoXB0Aw", "outputId": "31d6351b-2d14-41b5-b550-ffd1c970c0dc" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "LSTMTagger(\n", " (emb): Embedding(11051, 300, padding_idx=0)\n", " (lstm): LSTM(300, 128, num_layers=2, batch_first=True, bidirectional=True)\n", " (fc): Linear(in_features=256, out_features=13, bias=True)\n", ")\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Training Loop" ], "metadata": { "id": "BRtuvVjrB41j" } }, { "cell_type": "code", "source": [ "from tqdm.notebook import trange, tqdm\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "learning_rate = 1e-3\n", "batch_size = 32\n", "epochs = 100\n", "\n", "criterion = torch.nn.CrossEntropyLoss(ignore_index=0)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "model.to(device)\n", "\n", "def train():\n", " train_dataloader = DataLoader(\n", " train_dataset,\n", " collate_fn=collate_batch,\n", " batch_size=batch_size,\n", " )\n", " model.train()\n", " train_iterator = trange(int(epochs), desc=\"Epoch\")\n", "\n", " for _ in train_iterator:\n", " for x_pad, y_pad, x_lens, y_lens in train_dataloader:\n", " x_pad = x_pad.to(device)\n", " y_pad = y_pad.to(device)\n", "\n", " optimizer.zero_grad()\n", " pred = model(x_pad, x_lens)\n", "\n", " pred = pred.view(-1, pred.shape[-1])\n", " y_pad = y_pad.view(-1)\n", "\n", " loss = criterion(pred, y_pad)\n", " loss.backward()\n", " optimizer.step()\n", "\n", "train()\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "5b111d01ace6464a84e9cdb62f15347a", "b8f665ed9f8348cb9109977af001da42", "49b62860994c4242bbe5b95e99ae4ef6", "f3d1f1b01e1349f59be5be007c188ba5", "279562ff2314475198b1f546b1fd9ab7", "a823d460d6ce4e05aa0d2ef70a0924bc", "3e0a7036cd77412297fdd3756b617f69", "f997fbb2f7d24296bd12031f3c17d21a", "44d2fa68af5946b2ae133f442dcd30c6", "9287f66ee21c4d748b2b8b854cbee65d", "c4e2863ce2cc4b559ca4c4ad4bd614f7" ] }, "id": "z8WZa3b0B72x", "outputId": "fbadcac2-d614-4803-a77f-5f815ad0ca3f" }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Epoch: 0%| | 0/100 [00:00