{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# The Hungarian Maximum Likelihood Trick" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*by Louis Abraham*\n", "\n", "You can download the notebook from https://louisabraham.github.io/notebooks/hungarian_trick.ipynb" ] }, { "cell_type": "markdown", "metadata": { "toc": "true" }, "source": [ "# Table of Contents\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction\n", "\n", "This notebook is about a cool (possibly novel) idea I had while working on a real-world problem.\n", "\n", "I wanted to recognize the digits in images like this one (because why not?):\n", "\n", "\n", "\n", "I knew that all the digits in such an image would be different, and wanted to use this information to improve my predictions.\n", "\n", "However, to avoid getting distracted by image processing, I will only use images from MNIST in this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data generator" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:26.044610Z", "start_time": "2018-06-02T09:53:11.642899Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset shape: (70000, 784)\n" ] } ], "source": [ "import numpy as np\n", "from sklearn.datasets import fetch_mldata\n", "import skimage.transform\n", "\n", "np.random.seed(1337)\n", "np.set_printoptions(suppress=True)\n", "\n", "# Note: fetch_mldata was removed in scikit-learn 0.22. Newer versions provide\n", "# fetch_openml('mnist_784'), which returns string labels and does not sort samples by label.\n", "mnist = fetch_mldata('MNIST original')\n", "\n", "DIM = 16\n", "X, y = mnist.data, mnist.target\n", "\n", "def preprocess(x):\n", " # Downscale a flattened 28x28 image to DIM x DIM and flatten it again\n", " im = x.reshape(28, 28)\n", " im = skimage.transform.resize(im, (DIM, DIM), anti_aliasing=True, mode='reflect')\n", " return im.flatten()\n", "\n", "def sample():\n", " # Draw a random (image, label) pair from the whole dataset\n", " i = np.random.randint(len(X))\n", " return preprocess(X[i]), y[i]\n", "\n", "TRAIN = 5000\n", "TEST = 1000\n", "X_train, y_train = map(np.array, zip(*(sample() for _ in range(TRAIN))))\n", "X_test, y_test = map(np.array, zip(*(sample() for _ in range(TEST))))\n", "\n", "print('Dataset shape:', X.shape)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:29.006510Z", "start_time": "2018-06-02T09:53:26.048172Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "plt.rcParams[\"figure.figsize\"] = (12, 7)\n", "\n", "for i in range(4):\n", " plt.subplot(141+i)\n", " plt.imshow(X_train[i].reshape(DIM, DIM), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Simple model: SVM\n", "\n", "I don't really care about the final accuracy, so I use an SVM for the sake of simplicity.\n", "\n", "If you are not familiar with SVMs, the [Wikipedia article](https://en.wikipedia.org/wiki/Support_vector_machine) is well written." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:46.503170Z", "start_time": "2018-06-02T09:53:29.010093Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training time: 16s\n", "Classification report for classifier SVC(C=5, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape='ovr', degree=3, gamma=0.05, kernel='rbf',\n", " max_iter=-1, probability=True, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False):\n", " precision recall f1-score support\n", "\n", " 0.0 0.99 1.00 0.99 98\n", " 1.0 0.98 0.96 0.97 126\n", " 2.0 0.95 0.97 0.96 107\n", " 3.0 0.95 0.94 0.95 123\n", " 4.0 0.92 0.99 0.95 96\n", " 5.0 0.95 0.95 0.95 80\n", " 6.0 0.97 0.98 0.97 89\n", " 7.0 0.96 0.96 0.96 100\n", " 8.0 0.95 0.95 0.95 81\n", " 9.0 0.98 0.91 0.94 100\n", "\n", "avg / total 0.96 0.96 0.96 1000\n", "\n", "\n", "Confusion matrix:\n", "[[ 98 0 0 0 0 0 0 0 0 0]\n", " [ 0 121 3 0 2 0 0 0 0 0]\n", " [ 0 1 104 0 1 0 0 0 1 0]\n", " [ 0 0 2 116 0 2 0 1 2 0]\n", " [ 0 0 0 0 95 0 1 0 0 0]\n", " [ 0 0 0 1 0 76 2 0 1 0]\n", " [ 1 0 0 0 1 0 87 0 0 0]\n", " [ 0 1 0 0 1 0 0 96 0 2]\n", " [ 0 0 0 3 0 1 0 0 77 0]\n", " [ 0 0 0 2 3 1 0 3 0 91]]\n" ] } ], "source": [ "from time import time\n", "\n", "from sklearn import svm, metrics\n", "\n", "# Hyperparameters taken from\n", "# https://github.com/ksopyla/svm_mnist_digit_classification\n", "# probability=True is required for predict_proba and predict_log_proba\n", "classifier = svm.SVC(C=5, gamma=0.05, probability=True)\n", "\n", "t = time()\n", "classifier.fit(X_train, y_train)\n", "print('Training time: %is' % (time() - t))\n", "\n", "predicted = classifier.predict(X_test)\n", "\n", "print(\"Classification report for classifier %s:\\n%s\\n\"\n", " % (classifier, metrics.classification_report(y_test, predicted)))\n", "print(\"Confusion matrix:\\n%s\" % metrics.confusion_matrix(y_test, predicted))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Multiple digits classification\n", "\n", "As explained above, I want to test on batches of distinct digits and exploit that information.\n", "\n", "Here are the labels of the first 10 batches:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:46.579398Z", "start_time": "2018-06-02T09:53:46.506675Z" } }, "outputs": [ { "data": { "text/plain": [ "array([[5, 1, 6, 9, 2, 8, 7],\n", " [0, 3, 2, 5, 4, 8, 6],\n", " [5, 9, 2, 3, 4, 1, 7],\n", " [5, 8, 1, 6, 7, 3, 2],\n", " [5, 6, 2, 7, 8, 4, 3],\n", " [4, 8, 0, 6, 3, 9, 7],\n", " [3, 7, 4, 1, 8, 0, 6],\n", " [8, 6, 5, 2, 1, 4, 0],\n", " [6, 9, 0, 5, 2, 8, 1],\n", " [3, 8, 1, 0, 5, 2, 4]])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import random\n", "\n", "N_DIGITS = 7\n", "TEST2 = 500\n", "\n", "# Indices of all the samples of each digit\n", "reverse_labels = {digit: np.flatnonzero(mnist.target==digit) for digit in range(10)}\n", "\n", "# Each batch: N_DIGITS distinct digits, one random sample of each\n", "testing_indexes = np.array([[random.choice(reverse_labels[i])\n", " for i in random.sample(range(10), N_DIGITS)]\n", " for _ in range(TEST2)])\n", "\n", "y[testing_indexes[:10]].astype(int)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We say the model is correct on a batch if its predictions are correct for all `N_DIGITS` images of the batch." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:53.924758Z", "start_time": "2018-06-02T09:53:46.590923Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy for single digit prediction: 95.7%\n", "Accuracy for single digit power N_DIGITS: 73.4%\n", "Accuracy for multiple digits prediction: 73.6%\n" ] } ], "source": [ "testing_indexes_flat = testing_indexes.flatten()\n", "\n", "X_test_2_flat = np.array(list(map(preprocess, X[testing_indexes_flat])))\n", "X_test_2 = X_test_2_flat.reshape(TEST2, N_DIGITS, DIM * DIM)\n", "y_test_2 = y[testing_indexes]\n", "\n", "# Classify all images at once, then reshape the predictions into batches\n", "predicted = classifier.predict(X_test_2_flat).reshape(TEST2, N_DIGITS)\n", "accuracy = (predicted == y_test_2).mean()\n", "# A batch is correct only if all of its predictions are correct\n", "accuracy_multi = (predicted == y_test_2).all(axis=1).mean()\n", "\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy))\n", "print('Accuracy for single digit power N_DIGITS: %.01f%%' % (100 * accuracy ** N_DIGITS))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, the accuracy on the new task is lower. A batch is correct only if all `N_DIGITS` predictions are correct, so if the errors are independent we expect roughly the single-digit accuracy raised to the power `N_DIGITS`. The numbers above match this.\n", "\n", "But what if we could exploit the fact that all the labels within a batch are different?" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2018-06-01T14:21:32.346361Z", "start_time": "2018-06-01T14:21:32.339234Z" } }, "source": [ "# The Hungarian Trick\n", "\n", "Let's look at the predictions for a few images:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:54.661979Z", "start_time": "2018-06-02T09:53:53.929434Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Hand-picked batch of images with distinct labels\n", "l = [24754, 18623, 5923, 12665, 0, 42051]\n", "pred = classifier.predict(list(map(preprocess, X[l])))\n", "prob = classifier.predict_proba(list(map(preprocess, X[l])))\n", "\n", "plt.rcParams[\"figure.figsize\"] = (12, 7)\n", "for i, index in enumerate(l):\n", " plt.subplot(1, len(l), i+1)\n", " plt.title('Truth: %i\\nPredicted: %i' % (y[index], pred[i]))\n", " plt.axis('off')\n", " plt.imshow(preprocess(X[index]).reshape(DIM, DIM), cmap='gray')\n", "plt.show()\n", "\n", "plt.title('Colormap of the predicted probabilities')\n", "plt.imshow(prob.T)\n", "plt.xlabel('image')\n", "plt.ylabel('label')\n", "plt.colorbar()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, although the overall results are good, the classifier sometimes makes mistakes.\n", "\n", "In the example above (how convenient!), the classifier sees a 4 instead of a 7 in the last image,\n", "but its second-best guess is correct.\n", "Since there is already a 4 in the batch, and the SVM is very confident about that one,\n", "there _should_ be a way to exploit this information.\n", "\n", "My idea is to consider the _joint distribution_ of the digits in the same batch:\n", "instead of maximizing each probability independently, we maximize their product (or, equivalently,\n", "the log-likelihood) over all the assignments in which every image gets a different label.\n", "Thus, in the above example, the optimal choice would \"reserve\" the label 4 for the first\n", "image, because the classifier is more confident about it than about the last one.\n", "\n", "Note that SVMs do not natively output probabilities: [`predict_proba` returns estimates](http://scikit-learn.org/stable/modules/svm.html#scores-probabilities) obtained with Platt scaling.\n", "\n", "Now, what if we tested every possible assignment of distinct labels?\n", "In our case, testing the $\\frac{10!}{(10-\\texttt{N_DIGITS})!}$ possibilities is still feasible (604,800 for `N_DIGITS` = 7), but it quickly becomes intractable with more categories, for example with letters instead of digits.\n", "\n", "It turns out that this problem can be formulated as an [assignment problem](https://en.wikipedia.org/wiki/Assignment_problem): we have two sets of objects and want to find the matching with the smallest cost.\n", "A matching is a set of pairs in which no object appears twice, and its cost is the sum of the weights of the chosen edges. The two images below show an instance of the problem (without the weights) and a solution:\n", "\n", "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "\n", "\n", "
\n", "\n", "In our problem, the two sets are the images and the labels, and the weights are derived from the probability matrix plotted above.\n", "\n", "Because the assignment problem is usually formulated with a sum, instead of searching for $$\\underset{assignment}{\\operatorname{argmax}} \\prod_i \\mathbb{P}(true_i = assignment_i)$$\n", "we look for $$\\underset{assignment}{\\operatorname{argmin}} - \\sum_i \\log \\mathbb{P}(true_i = assignment_i)$$\n", "\n", "So the cost matrix is $m_{ij} = - \\log \\mathbb{P}(true_i = possible\\_label_j)$, with one row per image and one column per label.\n", "\n", "The advantage of this formulation is that the assignment problem can be solved in polynomial time (cubic in general).\n", "I can now admit that I used \"Hungarian\" as a buzzword: the [Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm) is a polynomial-time algorithm for the assignment problem, and the popular library [SciPy](https://www.scipy.org/) solves this problem with `scipy.optimize.linear_sum_assignment`, which also accepts rectangular cost matrices like ours (`N_DIGITS` images × 10 labels).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Results" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:58.143734Z", "start_time": "2018-06-02T09:53:54.665422Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Without Hungarian trick:\n", "Accuracy for single digit prediction: 95.7%\n", "Accuracy for multiple digits prediction: 73.6%\n", "\n", "With Hungarian trick:\n", "Accuracy for single digit prediction: 97.6%\n", "Accuracy for multiple digits prediction: 85.2%\n" ] } ], "source": [ "from scipy.optimize import linear_sum_assignment\n", "\n", "def predict(batch):\n", " # Cost of assigning image i to label j: -log P(label j | image i)\n", " cost_matrix = -classifier.predict_log_proba(batch)\n", " # Minimum-cost assignment of distinct labels to the images\n", " row_ind, col_ind = linear_sum_assignment(cost_matrix)\n", " # Confidence = product of the chosen probabilities (not used here)\n", " cost = cost_matrix[row_ind, col_ind].sum()\n", " confidence = np.exp(-cost)\n", " return classifier.classes_[col_ind]\n", "\n", "predicted_trick = np.array(list(map(predict, X_test_2)))\n", "\n", "accuracy_trick = (predicted_trick == y_test_2).mean()\n", "accuracy_multi_trick = (predicted_trick == y_test_2).all(axis=1).mean()\n", "\n", "\n", "print('Without Hungarian trick:')\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi))\n", "\n", "print('\\nWith Hungarian trick:')\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy_trick))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi_trick))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try with a bigger value of `N_DIGITS`! With `N_DIGITS` = 10, every batch contains each digit exactly once, so the assignment is a permutation of the labels." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:54:14.055013Z", "start_time": "2018-06-02T09:53:58.151344Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Without Hungarian trick:\n", "Accuracy for single digit prediction: 95.9%\n", "Accuracy for multiple digits prediction: 66.4%\n", "\n", "With Hungarian trick:\n", "Accuracy for single digit prediction: 99.5%\n", "Accuracy for multiple digits prediction: 97.8%\n" ] } ], "source": [ "N_DIGITS = 10\n", "\n", "testing_indexes = np.array([[random.choice(reverse_labels[i])\n", " for i in random.sample(range(10), N_DIGITS)]\n", " for _ in range(TEST2)])\n", "testing_indexes_flat = testing_indexes.flatten()\n", "\n", "X_test_2_flat = np.array(list(map(preprocess, X[testing_indexes_flat])))\n", "X_test_2 = X_test_2_flat.reshape(TEST2, N_DIGITS, DIM * DIM)\n", "y_test_2 = y[testing_indexes]\n", "\n", "predicted = classifier.predict(X_test_2_flat).reshape(TEST2, N_DIGITS)\n", "accuracy = (predicted == y_test_2).mean()\n", "accuracy_multi = (predicted == y_test_2).all(axis=1).mean()\n", "\n", "predicted_trick = np.array(list(map(predict, X_test_2)))\n", "\n", "accuracy_trick = (predicted_trick == y_test_2).mean()\n", "accuracy_multi_trick = (predicted_trick == y_test_2).all(axis=1).mean()\n", "\n", 
"print('Without Hungarian trick:')\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi))\n", "\n", "print('\\nWith Hungarian trick:')\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy_trick))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi_trick))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Discussion\n", "\n", "With just 5 lines of code, we achieved a huge accuracy improvement: the fraction of fully correct batches went from 73.6% to 85.2% with 7 digits per batch, and from 66.4% to 97.8% with 10.\n", "\n", "This method is not specific to digits, images, or SVMs.\n", "You can use it for any _batched_ classification task where you know that the samples\n", "in a batch have different labels, with any model that outputs probabilities.\n", "\n", "Moreover, the algorithm gives a confidence score (see `confidence` in `predict`) that can be\n", "used to tune a parameter. In my application, I found that the choice of padding\n", "made a huge difference, but I couldn't tune it because I had no validation samples\n", "from the test distribution. Instead, my code does a grid search _at test time_ and keeps the value that\n", "maximizes this confidence.\n", "\n", "In fact, OCR software uses the same kind of trick: instead of recognizing each character\n", "separately, it asks the user for the language and uses its knowledge of the dictionary\n", "to improve its predictions. See [here](https://github.com/tesseract-ocr/tesseract/wiki/NeuralNetsInTesseract4.00) how Tesseract uses LSTM neural networks." 
] } ], "metadata": { "date": "2018-06-01", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "title": "The Hungarian maximum likelihood trick", "toc": { "colors": { "hover_highlight": "#DAA520", "navigate_num": "#000000", "navigate_text": "#333333", "running_highlight": "#FF0000", "selected_highlight": "#FFD700", "sidebar_border": "#EEEEEE", "wrapper_background": "#FFFFFF" }, "moveMenuLeft": false, "nav_menu": { "height": "30px", "width": "252px" }, "navigate_menu": true, "number_sections": true, "sideBar": true, "skipTitle": true, "threshold": 4, "toc_cell": true, "toc_section_display": "none", "toc_window_display": true, "widenNotebook": false } }, "nbformat": 4, "nbformat_minor": 2 }