{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyMeB5e+dwDyZ4C5YI2Lgeyh"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU","gpuClass":"standard","widgets":{"application/vnd.jupyter.widget-state+json":{"b038b4756b0c40c48ad78056281eec15":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_c2d17f202ab142f3b2f8cd2101751bd6","IPY_MODEL_a8cee5d0785c421eb7a8971b689032a3","IPY_MODEL_7811bfb507e74d6392abfff39f0c9663"],"layout":"IPY_MODEL_92894ccb2a9042abb6016cdd1e2591ae"}},"c2d17f202ab142f3b2f8cd2101751bd6":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_0f790adec961470f8f601c26ae792d73","placeholder":"​","style":"IPY_MODEL_3fcebe3181af474999de8f7b082e8df1","value":"Epoch: 
100%"}},"a8cee5d0785c421eb7a8971b689032a3":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_cddf0a308c4c49efa698eb2bd7c25ad6","max":100,"min":0,"orientation":"horizontal","style":"IPY_MODEL_296707bec4a2425f9927233d07f8f1b4","value":100}},"7811bfb507e74d6392abfff39f0c9663":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_b4e2f0626d254240af4ca97237e2dc83","placeholder":"​","style":"IPY_MODEL_13c8b4ecdb264d3a83545ffc820199ab","value":" 100/100 [02:37<00:00, 
1.90s/it]"}},"92894ccb2a9042abb6016cdd1e2591ae":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"0f790adec961470f8f601c26ae792d73":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"
overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"3fcebe3181af474999de8f7b082e8df1":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"cddf0a308c4c49efa698eb2bd7c25ad6":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"296707bec4a2425f9927233d07f8f1b4":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"b4e2f0626d254240af4ca97237e2dc83":{"model_m
odule":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"13c8b4ecdb264d3a83545ffc820199ab":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"bf87806d146d45ed90ed84b3aa42a7fb":{"model_module":"@jupyter-widgets/controls","model_name":"HBoxModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_3961b13915714bb085d816a013fbc193","IPY_MODEL_751ba7294b6a42c8a9df37b878f4863e","IPY_MODEL_a8d71e20189845a081fdf6bd32907423"],"layout":"IPY_MODEL_9456bba57b0c43928aec80be6311e58a"
}},"3961b13915714bb085d816a013fbc193":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_ddefdee25688416288dfff4403e80e23","placeholder":"​","style":"IPY_MODEL_895bdef0ed974438be79e1ead24aa939","value":"Evaluating: 100%"}},"751ba7294b6a42c8a9df37b878f4863e":{"model_module":"@jupyter-widgets/controls","model_name":"FloatProgressModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_d292156d0b094db6a5cba225906dba9c","max":25,"min":0,"orientation":"horizontal","style":"IPY_MODEL_72f62598d4dd4f099ecf0b80925bc52a","value":25}},"a8d71e20189845a081fdf6bd32907423":{"model_module":"@jupyter-widgets/controls","model_name":"HTMLModel","model_module_version":"1.5.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_65d44cec783247dfb50177decd446e1b","placeholder":"​","style":"IPY_MODEL_5065856304db4bde8b90d329e38bcc61","value":" 25/25 [00:00<00:00, 
106.02it/s]"}},"9456bba57b0c43928aec80be6311e58a":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"ddefdee25688416288dfff4403e80e23":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null
,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"895bdef0ed974438be79e1ead24aa939":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"d292156d0b094db6a5cba225906dba9c":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"72f62598d4dd4f099ecf0b80925bc52a":{"model_module":"@jupyter-widgets/controls","model_name":"ProgressStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"65d44cec783247dfb50177decd446e1b":{"model
_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"5065856304db4bde8b90d329e38bcc61":{"model_module":"@jupyter-widgets/controls","model_name":"DescriptionStyleModel","model_module_version":"1.5.0","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}}}}},"cells":[{"cell_type":"markdown","source":["# RNN for POS Tagging\n","\n","What is included in this notebook:\n","\n","- Implementation of an RNN (LSTM) model for POS tagging"],"metadata":{"id":"jM909rAz86iC"}},{"cell_type":"code","source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","\n","torch.manual_seed(1)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"T2BHCHqO4hEC","executionInfo":{"status":"ok","timestamp":1678518895533,"user_tz":-420,"elapsed":612,"user":{"displayName":"Minh 
Pham","userId":"01293297774691882951"}},"outputId":"c4e028bf-e2b9-40f5-a374-8d5e4ee24f04"},"execution_count":59,"outputs":[{"output_type":"execute_result","data":{"text/plain":[""]},"metadata":{},"execution_count":59}]},{"cell_type":"markdown","source":["## Dataset\n","\n","We will use the Treebank corpus obtained from nltk, with its tags mapped to the universal tagset.\n"],"metadata":{"id":"0D3at1Q5ACM3"}},{"cell_type":"code","source":["%%capture\n","\n","import nltk\n","from nltk.corpus import treebank\n","\n","nltk.download('universal_tagset')\n","nltk.download('treebank')"],"metadata":{"id":"AvmzTdKNAGca","executionInfo":{"status":"ok","timestamp":1678518896072,"user_tz":-420,"elapsed":7,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":60,"outputs":[]},{"cell_type":"markdown","source":["Load the dataset"],"metadata":{"id":"yexlfZRQAJh9"}},{"cell_type":"code","source":["tagged_sentences = treebank.tagged_sents(tagset='universal')\n","tagged_sentences[0]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ejPjn4-pAQvg","executionInfo":{"status":"ok","timestamp":1678518896073,"user_tz":-420,"elapsed":6,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"46397ca7-7230-4b41-f55b-787f80dee25b"},"execution_count":61,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[('Pierre', 'NOUN'),\n"," ('Vinken', 'NOUN'),\n"," (',', '.'),\n"," ('61', 'NUM'),\n"," ('years', 'NOUN'),\n"," ('old', 'ADJ'),\n"," (',', '.'),\n"," ('will', 'VERB'),\n"," ('join', 'VERB'),\n"," ('the', 'DET'),\n"," ('board', 'NOUN'),\n"," ('as', 'ADP'),\n"," ('a', 'DET'),\n"," ('nonexecutive', 'ADJ'),\n"," ('director', 'NOUN'),\n"," ('Nov.', 'NOUN'),\n"," ('29', 'NUM'),\n"," ('.', '.')]"]},"metadata":{},"execution_count":61}]},{"cell_type":"markdown","source":["## Create train/test split"],"metadata":{"id":"RC--NOf7Az-2"}},{"cell_type":"code","source":["from sklearn.model_selection import train_test_split\n","\n","train_tagged_sentences, test_tagged_sentences = 
train_test_split(tagged_sentences, test_size=0.2, random_state=42)"],"metadata":{"id":"SiqKtYo5EFd3","executionInfo":{"status":"ok","timestamp":1678518898590,"user_tz":-420,"elapsed":2521,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":62,"outputs":[]},{"cell_type":"markdown","source":["## Separate sentences and tag sequences"],"metadata":{"id":"Z23hZnfXEIGm"}},{"cell_type":"code","source":["def make_x_y(tagged_sentences):\n","    \"\"\"Separate sentences and tag sequences from tagged sentences.\n","\n","    Arguments\n","    ----------\n","    tagged_sentences (iterable): each item is a sequence of (word, tag) pairs\n","\n","    Returns\n","    ----------\n","    sentences (list): list of sentences. Each sentence is a list of words\n","    tag_sequences (list): list of tag lists, aligned one-to-one with sentences\n","    \"\"\"\n","    sentences = []\n","    tag_sequences = []\n","    for s in tagged_sentences:\n","        # zip(*s) yields nothing for an empty sentence, so the unpacking\n","        # would raise; map an empty sentence to empty word/tag lists.\n","        words, tags = zip(*s) if s else ((), ())\n","        sentences.append(list(words))\n","        tag_sequences.append(list(tags))\n","    return sentences, tag_sequences"],"metadata":{"id":"mW6DzBhEAjIS","executionInfo":{"status":"ok","timestamp":1678518898591,"user_tz":-420,"elapsed":34,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":63,"outputs":[]},{"cell_type":"code","source":["train_sentences, train_tag_sequences = make_x_y(train_tagged_sentences)\n","test_sentences, test_tag_sequences = make_x_y(test_tagged_sentences)"],"metadata":{"id":"xGhHIkYJ6vOG","executionInfo":{"status":"ok","timestamp":1678518898591,"user_tz":-420,"elapsed":33,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":64,"outputs":[]},{"cell_type":"markdown","source":["## Steps in building an RNN model for POS tagging\n","\n","- Create Vocabulary, Vectorizer, Dataset\n","- Implement model class\n","- Training loop\n","- Evaluation on the test data"],"metadata":{"id":"SnIOM22QAX3H"}},{"cell_type":"markdown","source":["## Create Vocabulary\n","\n","We modify the Vocabulary class from the previous lecture.\n","\n","We need to convert tags into integer indices, so 
we will create two vocabularies, one for words and one for tags."],"metadata":{"id":"YKC5n_msAxLL"}},{"cell_type":"code","source":["from collections import defaultdict\n","\n","class Vocabulary:\n","    \"\"\"Bidirectional mapping between tokens and integer indices.\n","\n","    Index 0 is reserved for padding and, when `use_unk` is True,\n","    index 1 for unknown tokens (see `build_vocab`).\n","    \"\"\"\n","\n","    PAD_TOKEN = \"<pad>\"\n","    UNK_TOKEN = \"<unk>\"\n","\n","    def __init__(self, token_to_idx=None, use_unk=True):\n","        \"\"\"\n","        Args:\n","            token_to_idx (dict): a pre-existing map of tokens to indices\n","            use_unk (bool): if True, `lookup_token` maps unknown tokens to\n","                `unk_index` (1); otherwise it raises KeyError\n","        \"\"\"\n","        if token_to_idx is None:\n","            token_to_idx = {}\n","        # Copy so that add_token() never mutates the caller's dict.\n","        self._token_to_idx = dict(token_to_idx)\n","\n","        self._idx_to_token = {idx: token\n","                              for token, idx in self._token_to_idx.items()}\n","\n","        self.pad_index = 0\n","\n","        if use_unk:\n","            self.unk_index = 1\n","        else:\n","            self.unk_index = -1\n","\n","    def lookup_token(self, token):\n","        \"\"\"Retrieve the index associated with the token\n","        or the UNK index if token isn't present.\n","\n","        Args:\n","            token (str): the token to look up\n","        Returns:\n","            index (int): the index corresponding to the token\n","        Raises:\n","            KeyError: if the token is missing and use_unk was False\n","        \"\"\"\n","        if self.unk_index >= 0:\n","            return self._token_to_idx.get(token, self.unk_index)\n","        else:\n","            return self._token_to_idx[token]\n","\n","    def lookup_index(self, index):\n","        \"\"\"Return the token associated with the index\n","\n","        Args:\n","            index (int): the index to look up\n","        Returns:\n","            token (str): the token corresponding to the index\n","        Raises:\n","            KeyError: if the index is not in the Vocabulary\n","        \"\"\"\n","        if index not in self._idx_to_token:\n","            raise KeyError(\"the index (%d) is not in the Vocabulary\" % index)\n","        return self._idx_to_token[index]\n","\n","    def add_token(self, token):\n","        \"\"\"Update mapping dicts based on the token.\n","\n","        Args:\n","            token (str): the item to add into the Vocabulary\n","        Returns:\n","            index (int): the integer corresponding to the token\n","        \"\"\"\n","        if token in self._token_to_idx:\n","            index = self._token_to_idx[token]\n","        else:\n","            index = 
len(self._token_to_idx)\n","            self._token_to_idx[token] = index\n","            self._idx_to_token[index] = token\n","        return index\n","\n","    @classmethod\n","    def build_vocab(cls, sequences, use_unk=True):\n","        \"\"\"Build vocabulary from a list of sequences.\n","        A sequence may be a sequence of words or a sequence of tags.\n","\n","        Arguments:\n","        ----------\n","        sequences (list): list of sequences, each a list of words\n","            or a list of tags\n","        use_unk (bool): whether to reserve index 1 for unknown tokens\n","\n","        Return:\n","        ----------\n","        vocab (Vocabulary): a Vocabulary object with PAD_TOKEN at index 0\n","            (and UNK_TOKEN at index 1 when use_unk is True)\n","        \"\"\"\n","        # Reserve 0 for padding (and 1 for unknown tokens) with distinct keys,\n","        # so that real tokens are numbered after the special ones.\n","        if use_unk:\n","            token_to_idx = {cls.PAD_TOKEN: 0, cls.UNK_TOKEN: 1}\n","        else:\n","            token_to_idx = {cls.PAD_TOKEN: 0}\n","\n","        vocab = cls(token_to_idx, use_unk=use_unk)\n","        for s in sequences:\n","            for word in s:\n","                vocab.add_token(word)\n","        return vocab\n","\n","    def __str__(self):\n","        return \"<Vocabulary(size=%d)>\" % len(self)\n","\n","    def __len__(self):\n","        return len(self._token_to_idx)"],"metadata":{"id":"11ojPoT_8he1","executionInfo":{"status":"ok","timestamp":1678518898591,"user_tz":-420,"elapsed":31,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":65,"outputs":[]},{"cell_type":"code","source":["# Word vocabulary\n","word_vocab = Vocabulary.build_vocab(train_sentences)\n","print(word_vocab)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"YO2HGYCKANK9","executionInfo":{"status":"ok","timestamp":1678518898591,"user_tz":-420,"elapsed":31,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"b00ff2e9-6cba-4090-d290-d54fc974a680"},"execution_count":66,"outputs":[{"output_type":"stream","name":"stdout","text":["\n"]}]},{"cell_type":"code","source":["# Tag vocabulary\n","tag_vocab = Vocabulary.build_vocab(train_tag_sequences, use_unk=False)\n","print(tag_vocab._token_to_idx)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"HaNIuMBwW0dh","executionInfo":{"status":"ok","timestamp":1678518898591,"user_tz":-420,"elapsed":26,"user":{"displayName":"Minh 
Pham","userId":"01293297774691882951"}},"outputId":"f5a20a14-7fd5-4e4c-e8b6-dd61a8680224"},"execution_count":67,"outputs":[{"output_type":"stream","name":"stdout","text":["{'': 0, 'NOUN': 1, '.': 2, 'NUM': 3, 'ADJ': 4, 'VERB': 5, 'DET': 6, 'ADP': 7, 'CONJ': 8, 'PRON': 9, 'X': 10, 'ADV': 11, 'PRT': 12}\n"]}]},{"cell_type":"markdown","source":["## Data Vectorizer"],"metadata":{"id":"63wTCWWtaTHb"}},{"cell_type":"code","source":["import torch\n","import numpy as np\n","\n","def vectorize(vocab, sequence):\n","    \"\"\"Convert a sequence of tokens into a tensor of vocabulary indices.\n","\n","    Args:\n","        vocab (Vocabulary): vocabulary whose lookup_token() maps each token\n","        sequence (list): list of words or tags\n","    Returns:\n","        torch.Tensor: 1-D tensor of dtype torch.long, one index per token\n","    \"\"\"\n","    return torch.tensor(list(map(vocab.lookup_token, sequence)), dtype=torch.long)"],"metadata":{"id":"fqrUB46vaVGF","executionInfo":{"status":"ok","timestamp":1678518898591,"user_tz":-420,"elapsed":24,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":68,"outputs":[]},{"cell_type":"code","source":["print(train_sentences[0])"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vQMTRqC_gCAZ","executionInfo":{"status":"ok","timestamp":1678518898591,"user_tz":-420,"elapsed":23,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"0daa905c-15b8-4cb0-c5ec-e982799c3cea"},"execution_count":69,"outputs":[{"output_type":"stream","name":"stdout","text":["['Pierre', 'Vinken', ',', '61', 'years', 'old', ',', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'Nov.', '29', '.']\n"]}]},{"cell_type":"code","source":["vectorize(word_vocab, train_sentences[0])"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"C6DN2qF3gE5i","executionInfo":{"status":"ok","timestamp":1678518898592,"user_tz":-420,"elapsed":22,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"653ad0aa-4b62-4846-b5eb-6f76504f5e01"},"execution_count":70,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([ 2, 3, 4, 5, 6, 
7, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])"]},"metadata":{},"execution_count":70}]},{"cell_type":"code","source":["print(train_tag_sequences[0])"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"75muyN_vgQ0E","executionInfo":{"status":"ok","timestamp":1678518898592,"user_tz":-420,"elapsed":20,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"fd5c8526-882d-471b-b149-a2ba75468275"},"execution_count":71,"outputs":[{"output_type":"stream","name":"stdout","text":["['NOUN', 'NOUN', '.', 'NUM', 'NOUN', 'ADJ', '.', 'VERB', 'VERB', 'DET', 'NOUN', 'ADP', 'DET', 'ADJ', 'NOUN', 'NOUN', 'NUM', '.']\n"]}]},{"cell_type":"code","source":["vectorize(tag_vocab, train_tag_sequences[0])"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"cgl3-wergTsz","executionInfo":{"status":"ok","timestamp":1678518898592,"user_tz":-420,"elapsed":18,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"04585b5a-7294-43d9-ae15-f59986cdd5b7"},"execution_count":72,"outputs":[{"output_type":"execute_result","data":{"text/plain":["tensor([1, 1, 2, 3, 1, 4, 2, 5, 5, 6, 1, 7, 6, 4, 1, 1, 3, 2])"]},"metadata":{},"execution_count":72}]},{"cell_type":"markdown","source":["Vectorize train/test data"],"metadata":{"id":"fTRwKyaNYv46"}},{"cell_type":"code","source":["train_data = [vectorize(word_vocab, t) for t in train_sentences]\n","test_data = [vectorize(word_vocab, t) for t in test_sentences]\n","\n","train_y = [vectorize(tag_vocab, t) for t in train_tag_sequences]\n","test_y = [vectorize(tag_vocab, t) for t in test_tag_sequences]"],"metadata":{"id":"ey8sO7ADYxmj","executionInfo":{"status":"ok","timestamp":1678518898592,"user_tz":-420,"elapsed":17,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":73,"outputs":[]},{"cell_type":"markdown","source":["## Dataset class"],"metadata":{"id":"QXU1ufU3JMxJ"}},{"cell_type":"code","source":["from torch.utils.data import Dataset, 
DataLoader\n","\n","class TextDataset(Dataset):\n","    \"\"\"Map-style dataset pairing each vectorized sentence with its tag sequence.\"\"\"\n","\n","    def __init__(self, sequences, tag_sequences):\n","        \"\"\"\n","        Args:\n","            sequences (list): one vectorized sentence (tensor of word indices) per item\n","            tag_sequences (list): list of tag sequences, each for one sentence\n","        Raises:\n","            ValueError: if the two lists have different lengths\n","        \"\"\"\n","        # Fail early instead of raising IndexError halfway through an epoch.\n","        if len(sequences) != len(tag_sequences):\n","            raise ValueError(\"sequences and tag_sequences must have the same length \"\n","                             \"(got %d and %d)\" % (len(sequences), len(tag_sequences)))\n","        self.sequences = sequences\n","        self.tag_sequences = tag_sequences\n","\n","    def __len__(self):\n","        return len(self.sequences)\n","\n","    def __getitem__(self, index):\n","        \"\"\"Return the (sentence, tags) pair stored at `index`.\"\"\"\n","        return self.sequences[index], self.tag_sequences[index]"],"metadata":{"id":"qfQ-PxQyJOau","executionInfo":{"status":"ok","timestamp":1678518898592,"user_tz":-420,"elapsed":15,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":74,"outputs":[]},{"cell_type":"markdown","source":["Create train_dataset and test_dataset"],"metadata":{"id":"Z4Z2ksjoLLaH"}},{"cell_type":"code","source":["train_dataset = TextDataset(train_data, train_y)\n","test_dataset = TextDataset(test_data, test_y)"],"metadata":{"id":"14lrsCGpYNwW","executionInfo":{"status":"ok","timestamp":1678518898592,"user_tz":-420,"elapsed":15,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":75,"outputs":[]},{"cell_type":"code","source":["print( train_dataset[1] )"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"IaVUUrZXYfVA","executionInfo":{"status":"ok","timestamp":1678518898592,"user_tz":-420,"elapsed":15,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"616fd4eb-93d8-4ab6-d051-6a1efba4e593"},"execution_count":76,"outputs":[{"output_type":"stream","name":"stdout","text":["(tensor([19, 20, 21, 22, 18]), tensor([8, 9, 5, 4, 2]))\n"]}]},{"cell_type":"markdown","source":["## Create DataLoader\n","\n","We need to define a function (passed to DataLoader as `collate_fn`) that pads the variable-length sequences in each batch generated by DataLoader"],"metadata":{"id":"5J4SlQHDoWiu"}},{"cell_type":"code","source":["from torch.nn.utils.rnn import pad_sequence\n","\n","def 
collate_batch(batch):\n","    \"\"\"Pad a batch of (word_indices, tag_indices) pairs produced by DataLoader.\n","\n","    Arguments:\n","    -----\n","    batch (list): list of (x, y) pairs from TextDataset, where x and y are\n","        1-D LongTensors of equal length (a sentence and its tags)\n","\n","    Returns:\n","    -----\n","    x_pad, y_pad (torch.Tensor): (batch, max_len) index tensors padded with 0\n","    x_lens, y_lens (torch.Tensor): unpadded length of each sequence\n","    \"\"\"\n","    xs, ys = zip(*batch)\n","    x_lens = torch.tensor([len(seq) for seq in xs])\n","    y_lens = torch.tensor([len(seq) for seq in ys])\n","\n","    # 0 is the pad index of both vocabularies.\n","    x_pad = pad_sequence(xs, batch_first=True, padding_value=0)\n","    y_pad = pad_sequence(ys, batch_first=True, padding_value=0)\n","\n","    return x_pad, y_pad, x_lens, y_lens"],"metadata":{"id":"SCrjP0hroXlV","executionInfo":{"status":"ok","timestamp":1678518898593,"user_tz":-420,"elapsed":15,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":77,"outputs":[]},{"cell_type":"markdown","source":["## RNN Tagging Model"],"metadata":{"id":"fkmP6Xhp9jiD"}},{"cell_type":"code","source":["import torch.nn as nn\n","from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n","\n","\n","class LSTMTagger(nn.Module):\n","    \"\"\"Embedding -> bidirectional LSTM -> linear layer -> per-token tag log-probabilities.\"\"\"\n","\n","    def __init__(self, vocab_size, embedding_dim, hidden_dim, tagset_size,\n","                 num_layers=1, batch_first=True, padding_idx=0):\n","\n","        super(LSTMTagger, self).__init__()\n","        self.hidden_dim = hidden_dim\n","        # Kept so forward() packs/unpacks in the same layout the LSTM uses.\n","        self.batch_first = batch_first\n","\n","        self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim,\n","                                padding_idx=padding_idx)\n","\n","        # The LSTM takes word embeddings as inputs, and outputs hidden states\n","        # with dimensionality hidden_dim per direction.\n","        self.lstm = nn.LSTM(embedding_dim, hidden_dim,\n","                            num_layers=num_layers, bidirectional=True, batch_first=batch_first)\n","        # Forward and backward states are concatenated, hence 2*hidden_dim.\n","        self.fc = nn.Linear(in_features=2*hidden_dim, out_features=tagset_size)\n","\n","        ## Comment out to disable weight initialization\n","        torch.nn.init.xavier_uniform_(self.emb.weight)\n","        torch.nn.init.xavier_uniform_(self.fc.weight)\n","        # xavier_uniform_ also overwrites the padding row, which nn.Embedding\n","        # initialises to zeros; restore it.\n","        if padding_idx is not None:\n","            with torch.no_grad():\n","                self.emb.weight[padding_idx].fill_(0)\n","\n","    def forward(self, x_in, x_lens):\n","        \"\"\"Return tag log-probabilities for every position of x_in.\n","\n","        Args:\n","            x_in (torch.Tensor): padded word indices, (batch, seq) if\n","                batch_first else (seq, batch)\n","            x_lens (torch.Tensor): true sentence lengths, kept on the CPU as\n","                pack_padded_sequence requires\n","        Returns:\n","            torch.Tensor: (batch, max(x_lens), tagset_size) if batch_first,\n","                else (max(x_lens), batch, tagset_size); padded positions are\n","                meaningless and must be masked (e.g. by the loss's ignore_index)\n","        \"\"\"\n","        x_embed = self.emb(x_in)\n","        x_packed = pack_padded_sequence(x_embed, x_lens, batch_first=self.batch_first,\n","                                        enforce_sorted=False)\n","        output_packed, _ = 
self.lstm(x_packed)\n","        output_padded, _ = pad_packed_sequence(output_packed, batch_first=self.batch_first)\n","        tag_space = self.fc(output_padded)\n","        # Normalise over the tag axis (last). dim=1 is the time axis when\n","        # batch_first=True and would mix scores across positions.\n","        tag_scores = F.log_softmax(tag_space, dim=-1)\n","        return tag_scores"],"metadata":{"id":"FGt-X6ei9nJq","executionInfo":{"status":"ok","timestamp":1678518898593,"user_tz":-420,"elapsed":15,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":78,"outputs":[]},{"cell_type":"markdown","source":["## Create an LSTM Tagger Model"],"metadata":{"id":"d5y8qIUEAxrK"}},{"cell_type":"code","source":["vocab_size = len(word_vocab)\n","embedding_dim = 300\n","hidden_dim = 128\n","num_layers = 2\n","tagset_size = len(tag_vocab)\n","batch_first = True\n","\n","model = LSTMTagger(vocab_size=vocab_size,\n","                   embedding_dim=embedding_dim,\n","                   hidden_dim=hidden_dim,\n","                   num_layers=num_layers,\n","                   tagset_size=tagset_size,\n","                   batch_first=batch_first)"],"metadata":{"id":"ZjYlVuHgA1dR","executionInfo":{"status":"ok","timestamp":1678518898593,"user_tz":-420,"elapsed":14,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}}},"execution_count":79,"outputs":[]},{"cell_type":"code","source":["print(model)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"JwHuxQoXB0Aw","executionInfo":{"status":"ok","timestamp":1678518898593,"user_tz":-420,"elapsed":13,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"fff661d2-42c8-4c32-9aa9-8d5561940558"},"execution_count":80,"outputs":[{"output_type":"stream","name":"stdout","text":["LSTMTagger(\n"," (emb): Embedding(11051, 300, padding_idx=0)\n"," (lstm): LSTM(300, 128, num_layers=2, batch_first=True, bidirectional=True)\n"," (fc): Linear(in_features=256, out_features=13, bias=True)\n",")\n"]}]},{"cell_type":"markdown","source":["## Training Loop"],"metadata":{"id":"BRtuvVjrB41j"}},{"cell_type":"code","source":["from tqdm.notebook import trange, tqdm\n","\n","device = torch.device('cuda' if 
torch.cuda.is_available() else 'cpu') \n","\n","learning_rate = 1e-3\n","batch_size = 32\n","epochs = 100\n","\n","criterion = torch.nn.CrossEntropyLoss(ignore_index=0)\n","optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n","model.to(device)\n","\n","def train():\n"," train_dataloader = DataLoader(\n"," train_dataset,\n"," collate_fn=collate_batch,\n"," batch_size=batch_size,\n"," )\n"," model.train()\n"," train_iterator = trange(int(epochs), desc=\"Epoch\")\n","\n"," for _ in train_iterator:\n"," for x_pad, y_pad, x_lens, y_lens in train_dataloader:\n"," x_pad = x_pad.to(device)\n"," y_pad = y_pad.to(device)\n","\n"," optimizer.zero_grad()\n"," pred = model(x_pad, x_lens)\n"," \n"," pred = pred.view(-1, pred.shape[-1])\n"," y_pad = y_pad.view(-1)\n","\n"," loss = criterion(pred, y_pad)\n"," loss.backward()\n"," optimizer.step()\n","\n","train()\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":49,"referenced_widgets":["b038b4756b0c40c48ad78056281eec15","c2d17f202ab142f3b2f8cd2101751bd6","a8cee5d0785c421eb7a8971b689032a3","7811bfb507e74d6392abfff39f0c9663","92894ccb2a9042abb6016cdd1e2591ae","0f790adec961470f8f601c26ae792d73","3fcebe3181af474999de8f7b082e8df1","cddf0a308c4c49efa698eb2bd7c25ad6","296707bec4a2425f9927233d07f8f1b4","b4e2f0626d254240af4ca97237e2dc83","13c8b4ecdb264d3a83545ffc820199ab"]},"id":"z8WZa3b0B72x","executionInfo":{"status":"ok","timestamp":1678519055651,"user_tz":-420,"elapsed":157070,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"1bf036f3-e625-4faf-d28d-0dc4b6162e57"},"execution_count":81,"outputs":[{"output_type":"display_data","data":{"text/plain":["Epoch: 0%| | 0/100 [00:00