{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"}},"cells":[{"cell_type":"markdown","metadata":{"id":"DsYAMZwocQp4"},"source":["# Logistic Regression for Text Categorization\n","\n","In this notebook, we experiment with logistic regression for text classification, using the scikit-learn library.\n","\n","\n","## Binary classification\n","\n","We download the data set as the first step.\n"]},{"cell_type":"code","metadata":{"id":"B69LTgEec0Si"},"source":["%%capture\n","!rm -f titles-en-train.labeled\n","!rm -f titles-en-test.labeled\n","\n","!wget https://raw.githubusercontent.com/neubig/nlptutorial/master/data/titles-en-train.labeled\n","!wget https://raw.githubusercontent.com/neubig/nlptutorial/master/data/titles-en-test.labeled"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Each sample is written on a single line, as a label and a text separated by a tab. There are two labels, {1, -1}, in the data.\n","\n","```\n","1\tFUJIWARA no Chikamori ( year of birth and death unknown ) was a samurai and poet who lived at the end of the Heian period .\n","-1\tYomi is the world of the dead .\n","```"],"metadata":{"id":"sk1LNxnEdtAU"}},{"cell_type":"markdown","source":["We need to predict whether the title of an article is about an individual or not.\n","\n","- Label 1: The title is about an individual\n","- Label -1: The title is not about an individual"],"metadata":{"id":"B9m_aSfj9QOm"}},{"cell_type":"markdown","metadata":{"id":"BE_vvHs6dfCW"},"source":["### Load data\n","\n","We will load the data into a list of (sentence, label) pairs."]},{"cell_type":"code","metadata":{"id":"1EaRT5e1dn0a"},"source":["def load_data(file_path):\n","    data = []\n","    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:\n","        for line in f:\n","            line = line.strip()\n","            # Skip blank lines\n","            if line == '':\n","                continue\n","            # Each line has the form: label, tab, text\n","            lb, text = line.split('\\t')\n","            data.append((text, int(lb)))\n","\n","    return 
data"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"KslXet-rdrjU"},"source":["Load the training and test data from the files, then split each set into documents and labels."]},{"cell_type":"code","metadata":{"id":"-b3uhb7udv6W"},"source":["train_data = load_data('./titles-en-train.labeled')\n","test_data = load_data('./titles-en-test.labeled')\n","\n","train_docs, train_labels = zip(*train_data)\n","test_docs, test_labels = zip(*test_data)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"FmA_faq5d6oF"},"source":["### Using scikit-learn for feature extraction\n","\n","We can use scikit-learn for [feature extraction](http://scikit-learn.org/stable/modules/feature_extraction.html). We use the bag-of-words representation, in which each document becomes a vector indexed by vocabulary terms. In scikit-learn, we can use `CountVectorizer` for counts, or `TfidfTransformer` / `TfidfVectorizer` for tf-idf weights.\n","\n","### Feature extraction with CountVectorizer\n","\n"]},{"cell_type":"code","metadata":{"id":"bvfGX24PeLms","outputId":"34134b03-548d-4d99-ad5e-21b38ef14364","executionInfo":{"status":"ok","timestamp":1673679980769,"user_tz":-420,"elapsed":366,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["from sklearn.feature_extraction.text import CountVectorizer\n","\n","vectorizer = CountVectorizer(\n","    binary=True,  # Use binary features (term present or absent)\n","    max_features=10000  # Keep only the 10,000 most frequent terms\n",")\n","vectorizer"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["CountVectorizer(binary=True, max_features=10000)"]},"metadata":{},"execution_count":6}]},{"cell_type":"markdown","metadata":{"id":"U2ymNE8PePXa"},"source":["Now, we fit the vectorizer on the training data and transform it into a document-term matrix."]},{"cell_type":"code","metadata":{"id":"4vX0-2s7eSiC","outputId":"7c7129f9-b2df-416d-91ac-81bac3c00200","executionInfo":{"status":"ok","timestamp":1673679996379,"user_tz":-420,"elapsed":374,"user":{"displayName":"Minh 
Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["X_train = vectorizer.fit_transform(train_docs)\n","X_train.shape"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(11288, 10000)"]},"metadata":{},"execution_count":7}]},{"cell_type":"markdown","metadata":{"id":"t7ouScuXeUEE"},"source":["We can use the vectorizer's analyzer to see how a sentence is split into bag-of-words (BoW) tokens. Note that the default analyzer lowercases the text and drops single-character tokens such as \"a\"."]},{"cell_type":"code","metadata":{"id":"9UxomvFDeWHk","outputId":"0e82fbbe-0ddc-455d-8c49-551a8b6f663f","executionInfo":{"status":"ok","timestamp":1673680025437,"user_tz":-420,"elapsed":490,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["analyze = vectorizer.build_analyzer()\n","analyze(\"This is a text document to analyze.\")"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["['this', 'is', 'text', 'document', 'to', 'analyze']"]},"metadata":{},"execution_count":8}]},{"cell_type":"markdown","metadata":{"id":"Bji1eO_AeXg5"},"source":["### Text categorization with logistic regression\n","\n","Now let's try text categorization with the [logistic regression implementation](http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) in scikit-learn. 
See the documentation [here](http://scikit-learn.org/stable/modules/linear_model.html#logistic-regression) for more details."]},{"cell_type":"code","metadata":{"id":"UOe_mZSvecIM","outputId":"a77ffcf4-4fbe-441a-cf0b-d24b46fd240a","executionInfo":{"status":"ok","timestamp":1673680065082,"user_tz":-420,"elapsed":374,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["from sklearn.linear_model import LogisticRegression\n","\n","# Raise the iteration limit above the default of 100 to help the solver converge\n","clf = LogisticRegression(max_iter=500)\n","clf"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["LogisticRegression(max_iter=500)"]},"metadata":{},"execution_count":9}]},{"cell_type":"markdown","metadata":{"id":"83t_d-QXey2R"},"source":["Now, we fit the model on the training data."]},{"cell_type":"code","metadata":{"id":"nxT21G5Ue1Ug","outputId":"91ab994a-2dc3-4d9f-b2c2-2e131a38b202","executionInfo":{"status":"ok","timestamp":1673680070024,"user_tz":-420,"elapsed":358,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["clf.fit(X_train, train_labels)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["LogisticRegression(max_iter=500)"]},"metadata":{},"execution_count":10}]},{"cell_type":"markdown","metadata":{"id":"Z9HVDkGMe2xD"},"source":["### Evaluation on test set\n","\n","Now let's evaluate the model on the test data."]},{"cell_type":"code","metadata":{"id":"LmnGcUuwe6JB"},"source":["# Reuse the vocabulary fitted on the training data (transform, not fit_transform)\n","X_test = vectorizer.transform(test_docs)\n","test_preds = clf.predict(X_test)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"gId4ZRsWe99l","outputId":"cc116005-ff84-45cc-9ffd-112e81f08db7","executionInfo":{"status":"ok","timestamp":1673680136664,"user_tz":-420,"elapsed":408,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["from sklearn 
import metrics\n","\n","accuracy = metrics.accuracy_score(test_labels, test_preds)\n","print(\"# Test accuracy: {}\".format(accuracy))"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["# Test accuracy: 0.9426142401700319\n"]}]},{"cell_type":"markdown","metadata":{"id":"bAKlKMXjfdrM"},"source":["The classification report shows precision, recall, and F1-score for each class:"]},{"cell_type":"code","metadata":{"id":"6ySjIWuuff8G","outputId":"b78b909f-de9d-4a5d-de56-9a5f72a68aed","executionInfo":{"status":"ok","timestamp":1673602801331,"user_tz":-420,"elapsed":305,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["print(metrics.classification_report(test_labels, test_preds))"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["              precision    recall  f1-score   support\n","\n","          -1       0.93      0.96      0.95      1477\n","           1       0.95      0.93      0.94      1346\n","\n","    accuracy                           0.94      2823\n","   macro avg       0.94      0.94      0.94      2823\n","weighted avg       0.94      0.94      0.94      2823\n","\n"]}]},{"cell_type":"markdown","metadata":{"id":"1k0P77WufEej"},"source":["We can predict the label for an input sentence."]},{"cell_type":"code","metadata":{"id":"mrZf2O4cfG2m","outputId":"31b08273-dbba-4ddb-f36d-507048398d8f","executionInfo":{"status":"ok","timestamp":1673602841936,"user_tz":-420,"elapsed":363,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["example = \"FUJIWARA no Chikamori ( year of birth and death unknown ) was a samurai and poet who lived at the end of the Heian period .\"\n","test_x = vectorizer.transform([example])\n","print(\"Predicted class: {}\".format(clf.predict(test_x)))"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Predicted class: [1]\n"]}]},{"cell_type":"markdown","metadata":{"id":"WbtQpM1xfIr_"},"source":["We can also get the prediction 
probabilities."]},{"cell_type":"code","metadata":{"id":"oNLGkjiAfNk0","outputId":"78d72f40-b3f0-4423-e658-6df431d8708e","executionInfo":{"status":"ok","timestamp":1673602846613,"user_tz":-420,"elapsed":302,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["clf.predict_proba(test_x)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([[7.78066924e-04, 9.99221933e-01]])"]},"metadata":{},"execution_count":17}]},{"cell_type":"markdown","metadata":{"id":"oNxASF-zfSUR"},"source":["The columns follow the order of `clf.classes_`: the first value is the probability that the instance belongs to class \"-1\", and the second value is the probability that it belongs to class \"+1\". Let's try another sample."]},{"cell_type":"code","metadata":{"id":"IH288aYPfU-S","outputId":"fbd80559-2b7c-486c-9b97-ae8b4e22970f","executionInfo":{"status":"ok","timestamp":1673602905644,"user_tz":-420,"elapsed":339,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["example2 = \"Yomi is the world of the dead .\"\n","test_x2 = vectorizer.transform([example2])\n","clf.predict_proba(test_x2)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([[0.822939, 0.177061]])"]},"metadata":{},"execution_count":18}]},{"cell_type":"markdown","metadata":{"id":"sq-DqtDpfWom"},"source":["We can combine the probability values with a threshold $t$ to customize our prediction. 
For instance, we can predict \"-1\" only if its probability is greater than 0.6, instead of the default 0.5."]},{"cell_type":"markdown","metadata":{"id":"sokkOpbnftrg"},"source":["### Get top features with the highest weights\n","\n","In this section, we look at the features with the highest weights. Because `clf.coef_[0]` holds the weights for the positive class (`clf.classes_[1]`, i.e. \"+1\"), the largest weights push a prediction toward \"+1\".\n","\n","First, we get the feature names from the vectorizer and define the target names.\n","\n"]},{"cell_type":"code","metadata":{"id":"mGTadetvgAZP","outputId":"9c9bf36d-0802-4694-c71b-cc2d370b56c3","executionInfo":{"status":"ok","timestamp":1673603046229,"user_tz":-420,"elapsed":2,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["feature_names = vectorizer.get_feature_names_out()\n","target_names = [\"-1\", \"+1\"]  # Same order as clf.classes_\n","print(len(clf.coef_), clf.coef_)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["1 [[-0.05390017 0.02947881 -0.00048846 ... 0.0102783 0.0021012\n"," 0.00209807]]\n"]}]},{"cell_type":"code","metadata":{"id":"2-KOQ_9zfwfB","outputId":"be271f82-6a1a-4ac5-e0fd-9e1125df662e","executionInfo":{"status":"ok","timestamp":1673603047996,"user_tz":-420,"elapsed":8,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["import numpy as np\n","\n","topN = 50\n","print(\"top {} keywords:\".format(topN))\n","# argsort is ascending, so the last topN indices have the largest weights\n","top_idx = np.argsort(clf.coef_[0])[-topN:]\n","top_features = [feature_names[i] for i in top_idx]\n","print(\" \".join(top_features))"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["top 50 keywords:\n","was statesman march nihonshoki real kugyo april he shinsengumi may tanka fiction uesugi drama crown november member october august imperial december kutsuki literature september lord performer emperor warlord chapters lived july actors detached miyake poems novel throne emperors noble unknown waka poetry commander director myth professional monk tale 
priests priest\n"]}]},{"cell_type":"markdown","metadata":{"id":"kqqzlMfsfzDL"},"source":["### Try with tf-idf term weighting\n","\n","Now, we use tf-idf term weighting for feature extraction. In the recorded run, the test accuracy (about 0.934) is slightly lower than with the binary bag-of-words features (about 0.943)."]},{"cell_type":"code","metadata":{"id":"nLbICHDbgSSk","outputId":"94cd4df8-b717-43ec-fc04-6448a2721b89","executionInfo":{"status":"ok","timestamp":1673603058199,"user_tz":-420,"elapsed":907,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["from sklearn.feature_extraction.text import TfidfVectorizer\n","\n","# sublinear_tf=True replaces tf with 1 + log(tf);\n","# max_df=0.5 ignores terms that occur in more than half of the documents\n","vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5,\n","                             stop_words='english')\n","X_train = vectorizer.fit_transform(train_docs)\n","\n","clf = LogisticRegression(solver='lbfgs')\n","\n","clf.fit(X_train, train_labels)\n","\n","X_test = vectorizer.transform(test_docs)\n","test_preds = clf.predict(X_test)\n","\n","accuracy = metrics.accuracy_score(test_labels, test_preds)\n","print(\"# Test accuracy: {}\".format(accuracy))"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["# Test accuracy: 0.9344668792065179\n"]}]},{"cell_type":"markdown","metadata":{"id":"lxXVMxqFgUfE"},"source":["## Multiclass Text Classification\n","\n","In this section, we will do multiclass text classification with the 20 newsgroups dataset. 
It will be automatically downloaded, then cached."]},{"cell_type":"code","metadata":{"id":"xCxhx1LNg0Xe"},"source":["from sklearn.datasets import fetch_20newsgroups\n","\n","remove = ('headers', 'footers', 'quotes')\n","\n","data_train = fetch_20newsgroups(subset='train',\n"," shuffle=True, random_state=42,\n"," remove=remove)\n","\n","data_test = fetch_20newsgroups(subset='test',\n"," shuffle=True, random_state=42,\n"," remove=remove)\n","\n","y_train, y_test = data_train.target, data_test.target"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"lOPgKopEhQEs","outputId":"6a4bd09d-995a-47ee-a4df-9b89bbc15e7f","executionInfo":{"status":"ok","timestamp":1673603120484,"user_tz":-420,"elapsed":301,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["def size_mb(docs):\n"," return sum(len(s.encode('utf-8')) for s in docs) / 1e6\n","\n","data_train_size_mb = size_mb(data_train.data)\n","data_test_size_mb = size_mb(data_test.data)\n","\n","print(\"%d documents - %0.3fMB (training set)\" % (\n"," len(data_train.data), data_train_size_mb))\n","print(\"%d documents - %0.3fMB (test set)\" % (\n"," len(data_test.data), data_test_size_mb))\n","print()"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["11314 documents - 13.782MB (training set)\n","7532 documents - 8.262MB (test set)\n","\n"]}]},{"cell_type":"markdown","metadata":{"id":"BgwSQlm7qwib"},"source":["### Feature Extraction\n","\n","We will use TF-IDF features."]},{"cell_type":"code","metadata":{"id":"jgWEx7JWq1un"},"source":["from sklearn.feature_extraction.text import TfidfVectorizer\n","\n","vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5,\n"," stop_words='english')\n","X_train = vectorizer.fit_transform(data_train.data)\n","X_test = 
vectorizer.transform(data_test.data)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"LRjrAQ71rWJx"},"source":["Let's try logistic regression with the 'ovr' (one-vs-rest) strategy, which trains one binary classifier per class."]},{"cell_type":"code","metadata":{"id":"bv6tpNj1rje0","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1673603174196,"user_tz":-420,"elapsed":44926,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"outputId":"3a63e51f-198f-4f7a-ca9c-68cefcff7f5e"},"source":["clf = LogisticRegression(solver='lbfgs', multi_class='ovr')\n","clf.fit(X_train, y_train)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["LogisticRegression(multi_class='ovr')"]},"metadata":{},"execution_count":29}]},{"cell_type":"markdown","metadata":{"id":"cDzbmy-8r0ji"},"source":["Let's evaluate the model on the test set."]},{"cell_type":"code","metadata":{"id":"4UI5rhndr4T2","outputId":"77fcd275-fcbd-4dcf-902f-2c69c7b4b4a9","executionInfo":{"status":"ok","timestamp":1673603174197,"user_tz":-420,"elapsed":10,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["from sklearn import metrics\n","\n","y_preds = clf.predict(X_test)\n","accuracy = metrics.accuracy_score(y_test, y_preds)\n","print(\"# Test accuracy: {}\".format(accuracy))"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["# Test accuracy: 0.6949017525225704\n"]}]},{"cell_type":"markdown","metadata":{"id":"IZCAcFN8sET2"},"source":["Now let's try multinomial logistic regression, which fits a single softmax model over all classes."]},{"cell_type":"code","metadata":{"id":"-yqXTtsXsMIp","outputId":"941d16e2-09ff-4e59-b96e-841a28728fb8","executionInfo":{"status":"ok","timestamp":1673603231212,"user_tz":-420,"elapsed":57022,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["clf = LogisticRegression(solver='lbfgs', 
multi_class='multinomial')\n","clf.fit(X_train, y_train)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["LogisticRegression(multi_class='multinomial')"]},"metadata":{},"execution_count":31}]},{"cell_type":"markdown","metadata":{"id":"Vko1xyrOsRUF"},"source":["We will test multinomial Logistic Regression on the test data."]},{"cell_type":"code","metadata":{"id":"ubhdUJ6WsdAi","outputId":"ddae2fb7-4838-4d7e-a290-29532bc2ccaa","executionInfo":{"status":"ok","timestamp":1673603231214,"user_tz":-420,"elapsed":15,"user":{"displayName":"Minh Pham","userId":"01293297774691882951"}},"colab":{"base_uri":"https://localhost:8080/"}},"source":["y_preds = clf.predict(X_test)\n","accuracy = metrics.accuracy_score(y_test, y_preds)\n","print(\"# Test accuracy: {}\".format(accuracy))"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["# Test accuracy: 0.6946362187997875\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"beYOVmhEhKLY"},"execution_count":null,"outputs":[]}]}