{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# The Hungarian Maximum Likelihood Trick" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*by Louis Abraham*\n", "\n", "You can download the notebook from https://louisabraham.github.io/notebooks/hungarian_trick.ipynb" ] }, { "cell_type": "markdown", "metadata": { "toc": "true" }, "source": [ "# Table of Contents\n", "

1  Introduction
2  Data generator
3  Simple model: SVM
4  Multiple digits classification
5  The Hungarian Trick
6  Results
7  Discussion
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction\n", "\n", "This notebook is about a cool (possibly novel) idea I had while working on a real-world problem.\n", "\n", "I wanted to recognize the digits in images like this (because why not?):\n", "\n", "\n", "\n", "I knew all the digits would be different, and wanted to use this information to boost my predictions.\n", "\n", "However, I will only present images from MNIST in this notebook, to avoid being distracted by image processing." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data generator" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:26.044610Z", "start_time": "2018-06-02T09:53:11.642899Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset shape: (70000, 784)\n" ] } ], "source": [ "import numpy as np\n", "from sklearn.datasets import fetch_mldata\n", "import skimage.transform\n", "\n", "np.random.seed(1337)\n", "np.set_printoptions(suppress=True)\n", "\n", "# Note: fetch_mldata was removed in scikit-learn 0.22; newer versions\n", "# provide fetch_openml('mnist_784') instead.\n", "mnist = fetch_mldata('MNIST original')\n", "\n", "DIM = 16\n", "X, y = mnist.data, mnist.target\n", "\n", "def preprocess(x):\n", "    # Downscale a flat 28x28 image to DIM x DIM and flatten it again\n", "    im = x.reshape(28, 28)\n", "    im = skimage.transform.resize(im, (DIM, DIM), anti_aliasing=True, mode='reflect')\n", "    return im.flatten()\n", "\n", "def sample():\n", "    # Draw one random (image, label) pair from the full dataset\n", "    i = np.random.randint(len(X))\n", "    return preprocess(X[i]), y[i]\n", "\n", "TRAIN = 5000\n", "TEST = 1000\n", "X_train, y_train = map(np.array, zip(*(sample() for _ in range(TRAIN))))\n", "X_test, y_test = map(np.array, zip(*(sample() for _ in range(TEST))))\n", "\n", "print('Dataset shape:', X.shape)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:29.006510Z", "start_time": "2018-06-02T09:53:26.048172Z" } }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAswAAAC0CAYAAACAP3qfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAF89JREFUeJzt3XuMnXWdx/HPp1yWUHBLBUqhpRa3agChaK0YzMpF2UIItdqFkpUlC6bFLHaJS7JltZaKmm42iHJRUlYcQK0sy9UsIJXdWEwQmJLSVuRSELZDa7tQ5GJtStvv/tFDdhjOPL/HOc+c8/zOvF/JkznnzGfm+XX4dPj23H6OCAEAAABoblSnFwAAAADUGQMzAAAAUICBGQAAACjAwAwAAAAUYGAGAAAACjAwAwAAAAUYmAEAAIACDMwAAABAAQZmAAAAoMCenV5AM7bZfhAtiwi383z0FhV5KSIOaucJ6S6q0O7fuRLdRTXKdLele5htz7D9lO11thc0+fyf2b6l8fmHbb+nlfMBVaG7qLEXij5Jd5EjeovcDXlgtr2HpGslnSbpSEnn2D5yQOwCSa9ExF9IulLSvwz1fEBV6C5yRXeRI3qLbtDKPczTJa2LiOciYrukn0iaOSAzU9KNjcv/IekU221/yAYYgO4iV3QXOaK3yF4rA/Nhktb3u97XuK1pJiJ2SHpV0rtbOCdQBbqLXNFd5IjeInutvOiv2b/8Bj75vkxmd9CeK2luC+sByqqsu/QWbUZ3kSPmBWSvlXuY+yRN7Hd9gqQNg2Vs7ynpzyVtafbNImJpREyLiGktrAkoo7Lu0lu0Gd1FjpgXkL1WBuZHJU2xPdn23pLmSLp7QOZuSec1Ls+W9F8RwVvAoNPoLnJFd5Ejeov8RcSQD0mnS3pa0rOSvty47WuSzmxc3kfSrZLWSXpE0hElv29wcLR6tLu7nf7zcnTN0Ut3OXI8mBc4cj3KdM11/AfcSH8j8lGj0nf8jxs3LpnZvHlzMrNz585Sa8pRmTcir9JI7y0qs7LdDzXTXVSh3b9zJbqLapTpLltjAwAAAAUYmAEAAIACDMwAAABAAQZmAAAAoAADMwAAAFCAgRkAAAAowMAMAAAAFGBgBgAAAArs2ekF4J3OP//8ZOY73/lOMnPLLbckM5///OeTmV27diUzADDS2dXs21HHDcVQrTIblJXpAV1pH+5hBgAAAAowMAMAAAAFGJgBAACAAgzMAAAAQAEGZgAAAKDAkAdm2xNt/7ft39j+te1/aJI50fartlc1jq+2tlygdXQXuaK7yBG9RTdo5W3ldkj6x4h4zPb+klbaXh4RTwzIPRgRZ7RwHqBqdBe5orvIEb1F9oZ8D3NEbIyIxxqXX5f0G0mHVbUwYLjQXeSK7iJH9BbdoJKNS2y/R9Jxkh5u8umP2X5c0gZJl0TEr6s4Z64mTpyYzCxcuDCZ6e3tTWbKbG4y0jclobvVKfNG/GPGjElmvvnNbyYzt956azLz8ssvJzOrVq1KZuqK7rZXmU1J7r///mTmW9/6VjJz7733llpTjujtbmeddVYyc/HFFyczl156aTLzi1/8IpkZ6bNAGS0PzLb3k3SbpIsj4rUBn35M0qSIeMP26ZLulDRlkO8zV9LcVtcDlFVFd+ktOoHuIkfMC8hZS++SYXsv7S7/jyLi9oGfj4jXIuKNxuV7JO1l+8Bm3ysilkbEtIiY1sqagDKq6i69RbvRXeSIeQG5a+VdMizp+5J+ExFNH2OyfUgjJ9vTG+dLP04KDCO6i1zRXeSI3qIbtPKUjBMknStpje23ngj4z5IOl6SIuE7SbElfsL1D0h8lzYmIaOGcQBXoLnJFd5EjeovsDXlgjohfSip8FUREXCPpmqGeAxgOdBe5orvIEb1FN2CnPwAAAKAAAzMAAABQgIEZAAAAKOA6Pqfedv0WVcLo0aOTmZ/+9KfJzPHHH5/MfPzjH09mymzK0M1vVh4
R6Z0GKpRrb8uYOnVqMrN48eJk5thjj01mXn/99WRmzz3TL78os0nKhAkTkhlJ2rlzZ6lcRVa2++2yurm7VZk0aVIy89vf/jaZOemkk5KZMhtN1FG7f+dK+Xa3zEY4M2fOTGa++MUvJjPbtm1LZlasWJHM3HLLLcnMCy+8kMzUce4s013uYQYAAAAKMDADAAAABRiYAQAAgAIMzAAAAEABBmYAAACgAAMzAAAAUICBGQAAACjAwAwAAAAUYGAGAAAACqS3y0Jpxx13XDLziU98Ipm56KKLkpkyO/StXLkymTn77LOTmaeffjqZQb7K7L734IMPJjNldooqs0Plxo0bk5mjjjoqmbntttuSGaCsWbNmJTPbt29PZvh9Cqncbnd33nlnMnPXXXclM2V2qbz22muTmSVLliQz11xzTTIzf/78ZEaq346ALd/DbPt522tsr7Ld2+Tztn2V7XW2V9v+UKvnBFpFb5Eruotc0V3krKp7mE+KiJcG+dxpkqY0jo9K+l7jI9Bp9Ba5orvIFd1FltrxHOaZkm6K3X4laYzt8W04L9AKeotc0V3kiu6itqoYmEPS/bZX2p7b5POHSVrf73pf47a3sT3Xdm+zh2mAYUBvkSu6i1zRXWSriqdknBARG2wfLGm57Scjov+rf9zka97xTO6IWCppqSTZrtczvdGN6C1yRXeRK7qLbLV8D3NEbGh83CzpDknTB0T6JE3sd32CpA2tnhdoBb1FruguckV3kbOWBmbbo23v/9ZlSadKWjsgdrekv228+vV4Sa9GRPp9o4BhQm+RK7qLXNFd5K7Vp2SMk3SH7be+148j4j7bF0pSRFwn6R5Jp0taJ2mrpL9r8ZxAq+gtckV3kSu6i6y5bm8MLdXvOUmNv+BJV111VTJTZlOSMhtJrF69Opk555xzkplvf/vbycwxxxyTzGzatCmZabeIKPcfriJ1622ZN6uXpN7e9Otmymy2cOqppyYzW7duTWZOPvnkZKanpyeZ2WeffZKZQw45JJmRpJ07d5bKVWRlRExr5wnr1t12e9e73pXM3H777clMmd+VkydPTmb+8Ic/JDN11O7fuRLdLWPUqPSTCcp0t8wGVsuWLUtm5s2bl8xI7d24pEx32RobAAAAKMDADAAAABRgYAYAAAAKMDADAAAABRiYAQAAgAIMzAAAAEABBmYAAACgAAMzAAAAUKDVnf5GhDFjxpTKzZgxI5l56KGHkpknn3yy1PlSbr311mRm0aJFyczVV1+dzMyZM6fUmnbt2lUqh9aV3fxgy5YtycyUKVOSmfnz5yczb775ZjLzjW98I5nZe++9k5nLL788mWnzhiSoqbFjxyYzp5xySjLzwx/+MJkps3kPul+ZzUQ+8IEPJDOf+9znkpmPfOQjycyGDRuSmdmzZyczDzzwQDJTxw3zyuAeZgAAAKAAAzMAAABQgIEZAAAAKMDADAAAABQY8sBs+/22V/U7XrN98YDMibZf7Zf5autLBlpDd5Eruosc0Vt0gyG/S0ZEPCVpqiTZ3kPSi5LuaBJ9MCLOGOp5gKrRXeSK7iJH9BbdoKqnZJwi6dmIeKGi7we0C91FruguckRvkaWqBuY5kpYN8rmP2X7c9r22j6rofEBV6C5yRXeRI3qLLLW8cYntvSWdKenSJp9+TNKkiHjD9umS7pTUdAcE23MlzW11PcNhv/32K5WbPHlyMvPzn/88mSmzuUMZO3bsSGZuvvnmZOYrX/lKMnPooYeWWlNfX1+pXDtU0d069/all14qlTv55JOTmQsvvDCZGT9+fDJz9NFHJzNlNiV57rnnkpkrr7wymclVt3e33U477bRkZtOmTcnMvHnzkplcN22owkiYF8q69NJmP4K3K7Ox2Je+9KVk5qyzzkpmfv/73yczI7m7UjX3MJ8m6bGIeMdvk4h4LSLeaFy+R9Jetg9s9k0iYmlETIuIaRWsCSij5e7SW3QI3UWOmBeQrSoG5nM0yMMrtg+x7cbl6Y3
zvVzBOYEq0F3kiu4iR/QW2WrpKRm295X0KUnz+t12oSRFxHWSZkv6gu0dkv4oaU6M9Pv0UQt0F7miu8gRvUXuWhqYI2KrpHcPuO26fpevkXRNK+cAhgPdRa7oLnJEb5E7dvoDAAAACjAwAwAAAAUYmAEAAIACDMwAAABAgZY3LhkJtm3bVipX1YYj7bRr165kpsxGEmPGjCl1vjptXILdXnzxxWRm4cKFyUzjHaEK3XTTTclMmb9Hn/3sZ5OZV155JZlB9yvTy9mzZyczy5cvT2a2bt1aak3AFVdckcxMmdJ035a3mT9/fjLT29ubzDzyyCPJzEh/0xLuYQYAAAAKMDADAAAABRiYAQAAgAIMzAAAAEABBmYAAACgAAMzAAAAUICBGQAAACjAwAwAAAAUYOOSEt54441Sue3btw/zSv40ZTYcmTlzZjJT5o3/gaOOOiqZOeOMM5KZ9evXJzPPPPNMqTUBY8eOTWY+/OEPJzOLFi2qYjmApHIbop1//vnJzPTp05OZq6++Opn53e9+l8zMmjUrmdmxY0cyk6tS9zDbvsH2Zttr+9021vZy2880Ph4wyNee18g8Y/u8qhYOpNBb5IruIld0F92q7FMyeiTNGHDbAkkPRMQUSQ80rr+N7bGSFkn6qKTpkhYN9hcFGAY9orfIU4/oLvLUI7qLLlRqYI6IFZK2DLh5pqQbG5dvlPTpJl/6V5KWR8SWiHhF0nK98y8SMCzoLXJFd5Eruotu1cqL/sZFxEZJanw8uEnmMEn9n5DY17gN6BR6i1zRXeSK7iJ7w/2iv2avFoumQXuupLnDuxygFHqLXNFd5IruotZauYd5k+3xktT4uLlJpk/SxH7XJ0ja0OybRcTSiJgWEdNaWBOQQm+RK7qLXNFdZK+VgfluSW+9ivU8SXc1yfxM0qm2D2g8ef/Uxm1Ap9Bb5IruIld0F9kr+7ZyyyQ9JOn9tvtsXyBpiaRP2X5G0qca12V7mu1/k6SI2CLpckmPNo6vNW4Dhh29Ra7oLnJFd9GtHNH0KUIdZbtWiyq7ccd9992XzJR5s/LZs2cnM6NGpf+tc8kllyQzixcvTmZWrFiRzHzyk59MZiRp165dpXJViIi27rhSt95Wqczfgeuvvz6ZOfvss5OZMl16+OGHk5mMrWz3Q825drdML3/wgx8kM2PGjElmPvOZzyQz7fz9Vkft/p0r5dvddjr44GavsXy7Z599Npk58sgjk5kyG0/VUZnusjU2AAAAUICBGQAAACjAwAwAAAAUYGAGAAAACjAwAwAAAAUYmAEAAIACDMwAAABAAQZmAAAAoMCenV5ADspu7jJv3rxkZs2aNcnMQw89lMwcdNBByczhhx+ezGzevDmZWbJkSTIz0t+wv9uV2Uzn3HPPTWYeeeSRZObRRx8ttSbggAMOSGbOPPPMZOaCCy5IZvgdh3YrszHP+973vmRmwYIFyUxPT08ys3HjxmSmm3EPMwAAAFCAgRkAAAAowMAMAAAAFGBgBgAAAAowMAMAAAAFkgOz7Rtsb7a9tt9t/2r7Sdurbd9he8wgX/u87TW2V9nurXLhQArdRa7oLnJEb9HNytzD3CNpxoDblks6OiKOkfS0pEsLvv6kiJgaEdOGtkRgyHpEd5GnHtFd5KdH9BZdKjkwR8QKSVsG3HZ/ROxoXP2VpAnDsDagJXQXuaK7yBG9RTer4jnM50u6d5DPhaT7ba+0PbeCcwFVorvIFd1FjugtstXSTn+2vyxph6QfDRI5ISI22D5Y0nLbTzb+Bdrse82VlPVfkueffz6ZKbMzVZndfcoos0NhmczOnTurWE6tVNXdbujt6NGjk5mvf/3rlZzrsssuS2bYUa0Y3f1/H/zgB5OZfffdN5l54oknqlgOCoyUeWHUqHL3Q06dOjWZWbx4cTIzffr0ZGbhwoXJzNKlS5OZkW7I9zDbPk/SGZL+JgaZuiJiQ+PjZkl3SBr0v2xELI2IaTx3CcOtyu7SW7QT3UW
OmBfQDYY0MNueIemfJJ0ZEVsHyYy2vf9blyWdKmltsyzQLnQXuaK7yBG9Rbco87ZyyyQ9JOn9tvtsXyDpGkn7a/fDJqtsX9fIHmr7nsaXjpP0S9uPS3pE0n9GxH3D8qcAmqC7yBXdRY7oLbpZ8jnMEXFOk5u/P0h2g6TTG5efk3RsS6sDWkB3kSu6ixzRW3QzdvoDAAAACjAwAwAAAAUYmAEAAIACDMwAAABAAZfZuKLdbNdvUchORFSzA0xJufZ20qRJyUyZjR2++93vJjMLFixIZrpxo5w/0cp2v79srt3da6+9kpnx48cnM+vXr09m6vj/yrpp9+9cqX7dLbtxyRFHHJHMvPnmm8nMpk2bkplt27aVWtNIVqa73MMMAAAAFGBgBgAAAAowMAMAAAAFGJgBAACAAgzMAAAAQAEGZgAAAKAAAzMAAABQgIEZAAAAKMDGJehabFxSHTv9o6zj75JMsXEJssTGJchVJRuX2L7B9mbba/vddpntF22vahynD/K1M2w/ZXud7fQWX0CF6C5yRXeRI3qLrhYRhYekv5T0IUlr+912maRLEl+3h6RnJR0haW9Jj0s6MnW+xtcGB0cFR1u7W4M/77AdtpNHp9fYRUev6C5HngfzAkeWR5muJe9hjogVkrakck1Ml7QuIp6LiO2SfiJp5hC+DzAkdBe5orvIEb1FN2vlRX8X2V7deAjmgCafP0zS+n7X+xq3NWV7ru1e270trAkoo7Lu0lu0Gd1FjpgXkL2hDszfk/ReSVMlbZR0RZNMsydQx2DfMCKWRsS0dr/YBSNOpd2lt2gjuoscMS+gKwxpYI6ITRGxMyJ2Sbpeux9OGahP0sR+1ydI2jCU8wFVobvIFd1FjugtusWQBmbb4/tdnSVpbZPYo5Km2J5se29JcyTdPZTzAVWhu8gV3UWO6C26RolXoC7T7odR3tTufwVeIOlmSWskrdbuUo9vZA+VdE+/rz1d0tPa/erXL5d5FSKveuWo8Ghrd2vw5x22g3fJaOvRK7rLkefBvMCR5VGma3XduOR/Jb3Q76YDJb3UoeW0Isd157hm6Z3rnhQRB7VzAU16K+X588xxzVKe62625jp0N8efpZTnunNcs1SD37kS3e2wHNcsDbG7tRyYB7Ldm+OT+3Ncd45rluq77rquq0iOa5byXHdd11zXdaXkuO4c1yzVd911XVdKjuvOcc3S0NfdytvKAQAAAF2PgRkAAAAokMvAvLTTCxiiHNed45ql+q67rusqkuOapTzXXdc113VdKTmuO8c1S/Vdd13XlZLjunNcszTEdWfxHGYAAACgU3K5hxkAAADoiNoPzLZn2H7K9jrbCzq9njJsP297je1Vdd7r3vYNtjfbXtvvtrG2l9t+pvHxgE6ucaBB1nyZ7RcbP+9Vtk/v5Boba8qutxLdHU50d3jl0N0ceyvR3eGUQ2+lPLtbdW9rPTDb3kPStZJOk3SkpHNsH9nZVZV2UkRMrflbrvRImjHgtgWSHoiIKZIeaFyvkx69c82SdGXj5z01Iu5p85reJvPeSnR3uPSI7g63une3R/n1VqK7w63uvZXy7G6PKuxtrQdm7d5zfl1EPBcR2yX9RNLMDq+pa0TECklbBtw8U9KNjcs3Svp0WxeVMMia64beDjO6O2zo7jDKsbcS3UWe3a26t3UfmA+TtL7f9b7GbXUXku63vdL23E4v5k80LiI2SlLj48EdXk9ZF9le3XgIptMPC+XaW4nudgLdrUau3c21txLdrUKuvZXy7e6Qelv3gdlNbsvhbT1OiIgPafdDQ39v+y87vaAu9z1J75U0VdJGSVd0djnZ9laiu+1Gd6tDd9uL7laD3rbXkHtb94G5T9LEftcnSNrQobWUFhEbGh83S7pDux8qysUm2+MlqfFxc4fXkxQRmyJiZ0TsknS9Ov/zzrK3Et1tN7pbnYy7m11vJbpblYx7K2XY3VZ6W/eB+VFJU2xPtr23pDmS7u7wmgrZHm17/7c
uSzpV0trir6qVuyWd17h8nqS7OriWUt76C9swS53/eWfXW4nudgLdrUbm3c2utxLdrULmvZUy7G4rvd2z+uVUJyJ22L5I0s8k7SHphoj4dYeXlTJO0h22pd0/3x9HxH2dXVJztpdJOlHSgbb7JC2StETSv9u+QNL/SPrrzq3wnQZZ84m2p2r3w2/PS5rXsQUq295KdHdY0d1hlUV3c+ytRHeHURa9lfLsbtW9Zac/AAAAoEDdn5IBAAAAdBQDMwAAAFCAgRkAAAAowMAMAAAAFGBgBgAAAAowMAMAAAAFGJgBAACAAgzMAAAAQIH/A2fx89dxfvWwAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "plt.rcParams[\"figure.figsize\"] = (12, 7)\n", "\n", "for i in range(4):\n", " plt.subplot(141+i)\n", " plt.imshow(X_train[i].reshape(DIM, DIM), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Simple model: SVM\n", "\n", "I don't really care about the final accuracy, but use SVMs for the sake of simplicity.\n", "\n", "If you don't know about them, the [Wikipedia article](https://en.wikipedia.org/wiki/Support_vector_machine) is well written." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:46.503170Z", "start_time": "2018-06-02T09:53:29.010093Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training time: 16s\n", "Classification report for classifier SVC(C=5, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape='ovr', degree=3, gamma=0.05, kernel='rbf',\n", " max_iter=-1, probability=True, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False):\n", " precision recall f1-score support\n", "\n", " 0.0 0.99 1.00 0.99 98\n", " 1.0 0.98 0.96 0.97 126\n", " 2.0 0.95 0.97 0.96 107\n", " 3.0 0.95 0.94 0.95 123\n", " 4.0 0.92 0.99 0.95 96\n", " 5.0 0.95 0.95 0.95 80\n", " 6.0 0.97 0.98 0.97 89\n", " 7.0 0.96 0.96 0.96 100\n", " 8.0 0.95 0.95 0.95 81\n", " 9.0 0.98 0.91 0.94 100\n", "\n", "avg / total 0.96 0.96 0.96 1000\n", "\n", "\n", "Confusion matrix:\n", "[[ 98 0 0 0 0 0 0 0 0 0]\n", " [ 0 121 3 0 2 0 0 0 0 0]\n", " [ 0 1 104 0 1 0 0 0 1 0]\n", " [ 0 0 2 116 0 2 0 1 2 0]\n", " [ 0 0 0 0 95 0 1 0 0 0]\n", " [ 0 0 0 1 0 76 2 0 1 0]\n", " [ 1 0 0 0 1 0 87 0 0 0]\n", " [ 0 1 0 0 1 0 0 96 0 2]\n", " [ 0 0 0 3 0 1 0 0 77 0]\n", " [ 0 0 0 2 3 1 0 3 0 91]]\n" ] } ], "source": [ "from time import time\n", "\n", "from sklearn import svm, metrics\n", "\n", "# Took the parameters 
from\n", "# https://github.com/ksopyla/svm_mnist_digit_classification\n", "classifier = svm.SVC(C=5, gamma=0.05, probability=True)\n", "\n", "t = time()\n", "classifier.fit(X_train, y_train)\n", "print('Training time: %is' % (time() - t))\n", "\n", "predicted = classifier.predict(X_test)\n", "\n", "print(\"Classification report for classifier %s:\\n%s\\n\"\n", "      % (classifier, metrics.classification_report(y_test, predicted)))\n", "print(\"Confusion matrix:\\n%s\" % metrics.confusion_matrix(y_test, predicted))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Multiple digits classification\n", "\n", "As explained above, I want to test on batches in which all the digits are different, and exploit that information.\n", "\n", "Here is what the labels look like:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:46.579398Z", "start_time": "2018-06-02T09:53:46.506675Z" } }, "outputs": [ { "data": { "text/plain": [ "array([[5, 1, 6, 9, 2, 8, 7],\n", "       [0, 3, 2, 5, 4, 8, 6],\n", "       [5, 9, 2, 3, 4, 1, 7],\n", "       [5, 8, 1, 6, 7, 3, 2],\n", "       [5, 6, 2, 7, 8, 4, 3],\n", "       [4, 8, 0, 6, 3, 9, 7],\n", "       [3, 7, 4, 1, 8, 0, 6],\n", "       [8, 6, 5, 2, 1, 4, 0],\n", "       [6, 9, 0, 5, 2, 8, 1],\n", "       [3, 8, 1, 0, 5, 2, 4]])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import random\n", "\n", "N_DIGITS = 7\n", "TEST2 = 500\n", "\n", "# Map each digit to the indices of all MNIST images with that label\n", "reverse_labels = {digit: np.flatnonzero(mnist.target == digit) for digit in range(10)}\n", "\n", "# Each batch: N_DIGITS distinct digits, one random image per digit\n", "testing_indexes = np.array([[random.choice(reverse_labels[i])\n", "                             for i in random.sample(range(10), N_DIGITS)]\n", "                            for _ in range(TEST2)])\n", "\n", "y[testing_indexes[:10]].astype(int)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We say the model is correct on a batch if its predictions are correct for all `N_DIGITS` images of the batch." 
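, "\n", "\n", "
To make the two metrics concrete, here is a minimal sketch on made-up labels. The arrays and the names `y_true`, `y_pred`, `digit_accuracy` and `batch_accuracy` are illustrative assumptions, not part of the experiment.

```python
import numpy as np

# Made-up ground truth and predictions: 3 batches of 4 digits each.
y_true = np.array([[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [0, 9, 2, 4]])
y_pred = np.array([[1, 2, 3, 4],   # every digit correct
                   [5, 6, 1, 8],   # one digit wrong
                   [0, 9, 2, 4]])  # every digit correct

# Fraction of individual digits predicted correctly.
digit_accuracy = (y_pred == y_true).mean()

# Fraction of batches in which all digits are predicted correctly.
batch_accuracy = (y_pred == y_true).all(axis=1).mean()

print('Per-digit accuracy: %.3f' % digit_accuracy)
print('Per-batch accuracy: %.3f' % batch_accuracy)
```

With one wrong digit out of twelve, the per-digit accuracy is 11/12, but only 2 of the 3 batches are fully correct.
"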
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:53.924758Z", "start_time": "2018-06-02T09:53:46.590923Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy for single digit prediction: 95.7%\n", "Accuracy for single digit power N_DIGITS: 73.4%\n", "Accuracy for multiple digits prediction: 73.6%\n" ] } ], "source": [ "testing_indexes_flat = testing_indexes.flatten()\n", "\n", "X_test_2_flat = np.array(list(map(preprocess, X[testing_indexes_flat])))\n", "X_test_2 = X_test_2_flat.reshape(TEST2, N_DIGITS, DIM * DIM)\n", "y_test_2 = y[testing_indexes]\n", "\n", "predicted = classifier.predict(X_test_2_flat).reshape(TEST2, N_DIGITS)\n", "# Fraction of individual digits predicted correctly\n", "accuracy = (predicted == y_test_2).mean()\n", "# Fraction of batches in which all N_DIGITS digits are predicted correctly\n", "accuracy_multi = (predicted == y_test_2).all(axis=1).mean()\n", "\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy))\n", "print('Accuracy for single digit power N_DIGITS: %.01f%%' % (100 * accuracy ** N_DIGITS))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, the accuracy on the new task is lower: a batch only counts as correct when all `N_DIGITS` predictions are right, so the batch accuracy ends up close to the single-digit accuracy raised to the power `N_DIGITS` (73.4% vs. 73.6% here).\n", "\n", "But what if we could exploit the fact that all the labels in a batch are different?" 
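, "\n", "\n", "
One way to think about it, sketched here on made-up probabilities (the matrix `proba` and all variable names below are illustrative assumptions): instead of taking the most likely label for each image independently, look for the assignment of distinct labels that maximizes the product of the predicted probabilities. SciPy's `scipy.optimize.linear_sum_assignment` can solve this kind of assignment problem once it is written as minimizing a sum of negative log-probabilities.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Made-up class probabilities for a batch of 3 images over 4 classes.
proba = np.array([[0.60, 0.30, 0.05, 0.05],
                  [0.55, 0.40, 0.03, 0.02],
                  [0.10, 0.10, 0.70, 0.10]])

# Independent prediction: most likely class per image (labels may repeat).
independent = proba.argmax(axis=1)

# Joint prediction: distinct labels maximizing the product of probabilities,
# i.e. minimizing the sum of negative log-probabilities.
cost = -np.log(proba + 1e-12)  # small constant guards against log(0)
rows, cols = linear_sum_assignment(cost)
joint = cols[np.argsort(rows)]  # label assigned to each image, in image order

print('Independent:', independent)
print('Joint      :', joint)
```

Here the independent prediction gives label 0 to the first two images, while the joint assignment resolves the clash by giving label 1 to the second image.
"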
] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2018-06-01T14:21:32.346361Z", "start_time": "2018-06-01T14:21:32.339234Z" } }, "source": [ "# The Hungarian Trick\n", "\n", "Look at what we get for some predictions:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:54.661979Z", "start_time": "2018-06-02T09:53:53.929434Z" } }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsMAAACeCAYAAADXAwdCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGHpJREFUeJzt3Xu0lXWdx/HPFxBJ4AheDowXQDwpEzRmdBkRFEdTTMFmEMc0NUvF0SVaXnPhoJmDZpjjQsvGFMVKMm/TrMwraqZmiooRVKSgAiJ3QQUDfvPH8xCn0/fHOXufffbt936txfLsz3nOs397f/c+fvdz9vPdFkIQAAAAkKJOlV4AAAAAUCk0wwAAAEgWzTAAAACSRTMMAACAZNEMAwAAIFk0wwAAAEgWzXCBzKzJzJhHVweoZX2gjvWDWtYPalk/UqhlXTXDZrau2b/NZvZBs8snFrnPt8xsZImXKjPb3sz+ZGYLSr3velALtTSz883sdTN718wWmdkUM+tSqv3Xgxqp46Fm9kRex/ml2m+9qZFaXmJmc8xsrZm9ZmZfL9W+60mN1LKTmX3HzFaa2Qozm2xmVqr914saqeXDLdb5oZm9VKr9l0Jd/Y87hNBjy9d5k3laCOHR2PZm1iWEsLEca3NcImmJpH4Vuv6qViO1vF/SLSGENWa2i6R7JJ0l6YYyr6Nq1Ugd35N0i6Seks4v83XXjBqppSR9SdJsSftIetjM3ggh/KwC66haNVLL/5D0eUlDJHWW9KikPyt7riJXC7UMIRzeYg1PS/pFOdfQmro6MtwaM/uWmc0ws5+Y2VpJXzKzO83s8mbbHLblaK2Z/UTSbpIezF/NfL3Zdifnr56WmdklBa6jSdK/S/p2CW5WkqqhliGEP4cQ1jSLNktqaudNS0qV1PG5EMKdkl4v1e1KUZXU8uoQwkshhE0hhLmSfi7pwBLdxGRUQy0lnSLpOyGExSGENyVdJ+nL7b91aamSWjZfz96SDpA0vR03q+SSaoZz/yrpx5J2lDRjWxuGEL4oabGkI0MIPUII1zX79jBljc8Rkq4ws49KkpkdbGbLW1nDVEkXS1pf3E1AruK1NLOT8l8wyyQNlvSDYm9MwipeR5RM1dTSzDpJGi5pTsG3AlLlazlY0ivNLr+SZyhcpWvZ3CmSZuYvcKpGis3w0yGEn4cQNocQPmjHfi4PIawPIcxS9st2P0kKITwZQtgl9kNmNk7SxhDCz9tx3chUtJb5NtNDCD0lDZJ0s6R32rGOVFW8jiiZaqrllZI2SrqjHetIWcVqaWYmaQdJzf/ytkbZW5lQuKp4XuZ1PUnStHasoUOk2AyX5NVICOHtZhffl9Qjtu0WZtZD0mRJ55ZiDahcLZ19/EHSH5Qd9UdhqqaOaLeqqKWZnSvpeElHhxA+LMWaElSxWoYQQr5tQ7O4QdLaUqwpQVXxvJR0sKSdJd1bivWUUl2dQNdGLceDvKfsFegWfVvZvj0GKTth7tfZCyR1lbSjmb0t6dPV9meDGlDJWnq6
SNq7g6+jHlVbHVG8itfSzM5QdiLkQSGExaXef0IqXcstRx5n5Zf3E295KVala7nFKZJ+FkJ4v4P2X7QUjwy39LKko8yst5n9g6QJLb6/VNLAEl5XP0mfyP+NV/benE/k/0X7lLOWMrPTzawx/3qwsveBP1aq/Ses3HXsZGbdJG2XXbRuZrZdqfafuHLX8hRJV0j6XAhhQan2C0llrqWyt7ecb2a7mdkekr6mKvzzeo0qdy1lZjtIOlZVWkOa4awwcyUtlPRLSXe1+P5/KXuj+GozO6+1nZnZSDNb7X0vhLAxhPD2ln+SVknalF/e1K5bAamMtcwdJOl3ZvaepP+T9L+SLitm4fgb01TeOv6LpA+U1W9g/vWDRawbf2+aylvLbyn7M+yLtnWmKW9dKo1pKm8tb5L0kLKjwbMlPSDph0WsG39vmspbS0kaq+xE818VvtyOZ9lbcwAAAID0cGQYAAAAyaIZBgAAQLJohgEAAJAsmmEAAAAkK+lm2MwGmFkwsy755Qfz0Twdfb2Xm9mdHX09qaCO9YNa1g9qWT+oZf2glr6qb4bNbIGZfZCPyFlqZrfln+RWciGEI0MIt7dxTYd1xBpaud5J+YO47NfdXqnX0cw+ZmYvmNmq/N+jZvaxclx3qVFL62pmP8uvM5jZyHJcb0eglvbPZvaIma00s2Vmdrdlc1drTuq1zK/vUDObZ2bvm9lMM+tfrusuJWr5N9dblr6n6pvh3OgQQg9Jn5T0aUkTW25gmVq5PQUzs72VDaxeUum1tEPKdVysrH47SdpF2UzblrMda0nKtZSkpyV9SdLbrW1YA1KuZW9JP5A0QFJ/ZR/3e1slF9ROydbSzHZR9jG/lyn7PfuCpBkVXVT7JFvLLcrZ99TUnRhCWKRsGP4QSTKzJ8zsKjP7tbLPyR5oZjua2Q/NbImZLTKzb5lZ53z7zmb2HTNbbmavSTqq+f7z/Z3W7PLpZjbXzNaa2e/N7JNmNl3Zp8j9PH/VdlG+7T+b2TOWDal+pfnRIjPby8yezPfziLJmqFBTlX3C2YdF/GxVSbGOIYTVIYQFIRvsbZI2SWoq5v6rJonW8sMQwvUhhKeV1bEuJFrLB0MId4cQ3s0/InaqpAOLugOrSIq1lPRvkubk9Vwv6XJJ+5nZoELvv2qSaC23KF/fE0Ko6n+SFkg6LP96T2WfRnNlfvkJSW9IGiypi7KPU71f0s2SuktqlPS8pPH59mdKmpfvZydJM5V9BneXZvs7Lf96nKRFyl6RmbLGpX/LNeWXd5e0QtLnlb3A+Fx+edf8+89Kuk7S9so+tWytpDub/fxsSSds4z4YJ+kB77pr5R91/Os2qyVtlLRZ0sRK14VaFl/LfLu3JI2sdE2oZftrmW97nqTnKl0Xall4LSX9t6Tvtch+J2lspWtDLau/76l40dv4oFinrIlYqOwjGj/SrIjfbLZtH0kbtnw/z74oaWb+9eOSzmz2vcO38aB4SNK5rT1Q88sXS5reYpuHJJ2i7NXURkndm33vx80fFK3c/h6S/iRpr3I9KKhj6evYYp/dJZ0l6ahK14VatruW9dAMU8vs5/5J0kpJIypdF2pZeC2VfVTz1S2yX0v6cqVrQy2rv+/potrwhRDCo5Hvvdns6/7KXiUtMbMtWadm2+zWYvuF27jOPSX9uY3r6y9pnJmNbpZtp+wV2G6SVoUQ3mtxvXu2cd9XKHvAvd7G7atZynX8qxDCe2b2fUnLzOwfQwjvFLqPKkAt60fytTSzJmV/ij43hPCrQn62yqRcy3WSGlpkDcqOSNailGtZ9r6nVprhbQnNvn5T2SukXUIIG51tl+hvi9FvG/t9U9LebbjOLdtODyGc3nJDy85m7W1m3Zs9MPo5+4g5VNIeZnZWfnlXST81s2tCCNe0cR+1oN7r2FInSTso+1NTLTbD25JaLetZ3dcy38ejyv4MPb2tP1eD6r2Wc5Qdldyyv+75uua08edrSb3Xsux9T02dQNeaEMIS
SQ9LmmJmDWbWycz2NrOD801+KmmCme1hZr0lXbKN3d0i6QIzG2qZJts6pmWppIHNtr1T0mgzOyJ/s3o3MxtpZnuEEBYqO6v1CstGMg2XNFptd6iyN85/Iv+3WNJ4STcWsI+aUo91NLPPmdn++X4blL2XapWkuW3dRy2qx1pKkpltb2bd8otd8/3bNn+oxtVjLc1sd2V/Rr4xhPD9tv5cravHWkq6T9IQMxubPzf/U9LsEMK8AvZRc+q0lmXve+qqGc6dLKmrpN8razZ+JmnL3Mj/UfaellckzVI2hsUVQrhb0lXK3ueyVtkb1HfKvz1Z0kTLzqC8IITwpqRjJF0qaZmyV0wXauv9e4Kkzyp7P9okSXc0vy4zm2NmJ0bWsSKE8PaWf8rOXl8VQljXtrujZtVVHSX1kvQTSWuU/RmqSdKokJ31XO/qrZaS9AdJHyg7sv9Q/nX/bWxfL+qtlqcp+x/8JMvOkl9nZvX+u3WLuqplCGGZpLH5Wlbl+zm+DfdDPai3Wpa977H8zckAAABAcurxyDAAAADQJjTDAAAASBbNMAAAAJJFMwwAAIBk0QwDAAAgWWX90A0zY3RFBYUQSjYHlVpWFrWsH6WqJXWsLJ6T9YNa1o+21pIjwwAAAEgWzTAAAACSRTMMAACAZNEMAwAAIFk0wwAAAEgWzTAAAACSRTMMAACAZNEMAwAAIFk0wwAAAEgWzTAAAACSRTMMAACAZHWp9AKwVefOnd1806ZNZV5J/TDzP5Y8dl/Htg/B/3j5WE7NqlenTv4xgFi+cePGjlwOilTocxXls/POO7v5N7/5zYL2c9lll7n5ypUrC14TsC0cGQYAAECyaIYBAACQLJphAAAAJItmGAAAAMmiGQYAAECyqnqaxN577+3ma9ascfPly5d35HJK5tBDD3Xz0aNHu/l5553XkcupKU1NTW5+zjnnuPm+++7r5ocffribF3qG+rJly9z8xBNPdPNHH33UzVF63bp1c/O7777bzUeNGuXmH//4x9183rx5xS0scb169XLzk08+2c1HjBjh5v369XPz2HNv/vz5bVgdCtG3b183nzFjhpvHarlu3To3X7hwoZtPmTLFzZnig2JxZBgAAADJohkGAABAsmiGAQAAkCyaYQAAACSLZhgAAADJquppErEpCvfee6+bz5w5syOXUzIjR4508/Xr15d3IVUsNtUh9ln1sTPI33jjDTefPXu2m3fu3NnNBw0a5OaNjY1uPmHCBDdnmkT5xCaJHHTQQW7+7LPPuvnatWtLtqZ61LNnTzc/9dRT3fz888938z59+rj5HXfc4eaxiTDjxo1z88mTJ7s5Wjd8+HA3v+aaa9x82LBhbh6b9jBp0iQ3v+KKK9w89nv9rrvucnNsNXDgQDdvaGhw85dffrkjl6PBgwe7+Uc/+lE3f/DBB918w4YN7VoHR4YBAACQLJphAAAAJItmGAAAAMmiGQYAAECyaIYBAACQrKqYJhGbHHDggQe6+f3339+RyymZ2GSCs88+281jZ1+nKITg5uPHj3fziRMnuvmKFSvcPDa5o0sX/ynxzDPPuPnQoUPdnKkR5bPHHnu4+S9/+Us3j/2+Ofnkk9180aJFxS0sEV/72tfc/PLLL3fz2HPv2muvdfMrr7zSzUeNGuXmPXr0cHO0rnfv3m5+2223uXm/fv3c/KabbnLzL3zhC24em1gQm/ASe86n6CMf+Yibx54f06ZNc/OLL77YzV955RU3j/U3O+20k5uPHTvWza+//no337hxo5s3NTW5+ZIlS9y8rTgyDAAAgGTRDAMAACBZNMMAAABIFs0wAAAAkkUzDAAAgGRVxTSJXr16ufmgQYPKvJLSamxsdPPY7X3++ec7cjl1IXYm+ptvvlnQfmJnwh533HFuvu+++7p57IzXGTNmFLQeFG/06NFuvuuuu7p5bJrLggULSrWkpDz++ONu/tprr7l5bDJL7P6PTYfo1q1b64uDq6Ghwc2nTp3q5jvv
vHNB23/jG99w80svvdTN16xZ4+axCS9f+cpX3Py73/2um2/atMnNa0ls0tEvfvELNx8+fLibf/vb33bz3/72twVtH5sOEZsmEXvMxab7zJs3z81Xr17t5u3FkWEAAAAki2YYAAAAyaIZBgAAQLJohgEAAJAsmmEAAAAkqyqmSWy//fZu3rVrVzdfuHBhSa43dhZjLN9uu+3cfNiwYW5+4403uvlTTz3l5suXL3dzbNWvXz83HzdunJvHzjiPTY0YPHiwm2/YsMHNJ06c6ObLli1zcxRv6NChbj558mQ3f/LJJ9389ttvL9maID399NMF5UOGDHHzo48+2s1j03e6d+/u5vPnz3dzbHXqqae6+RFHHOHmI0aMcPPYxJAPP/ywoLxQsekWsf9314NLLrnEzQ8++GA3j90XF154oZtfdNFFBa0nNv1l+vTpbn7WWWe5+auvvurmsd8HH3zwQeuLKwJHhgEAAJAsmmEAAAAki2YYAAAAyaIZBgAAQLJohgEAAJCsqpgmEUIoKD/qqKPcfNGiRW6+3377ufk+++zj5j169HDzPn36uPns2bPdvLGx0c2nTZvm5n/5y1/cPEWxzzefOXOmmw8YMMDNC50YEhM7azo2MWTz5s0F7R9b9ezZ081vueUWN3/33Xfd/Nhjj3Xz9evXF7ewxMWeM/3793fzk046yc1PP/10N49Nftl1113d/J133nHzxx57zM1TFLvvLrjgAjePTVqZM2dOydZUiNhjLvYcjvUMtaRLF78tmzBhgpvHprbEpm7FpkA888wzbj537lw333333d38gQcecPPnn3/ezceOHevmixcvdvOOwpFhAAAAJItmGAAAAMmiGQYAAECyaIYBAACQLJphAAAAJKsqpkksXbrUzc844ww3P+SQQ9z8U5/6lJu///77bn7DDTe4+W9+8xs3j00IaGhocPMTTjjBzV9//XU3x1YrV65089gkkV69ern52rVrC7re2GPu7LPPdvPevXu7+bp16wq63hTFzhSfNGmSm8emwvzoRz9y81WrVhW3MLg+85nPuHlswsuUKVPc/IADDnDz4447zs2vvfZaN99hhx3c/KqrrnLzCy+80M1jj5MNGza4eS0ZOHCgm++4445uHruvO1rfvn3dfOjQoW5+/fXXu/mmTZtKtqZKid2Gyy67zM1vvfVWNy/VdKojjzzSzWNTI2bNmuXmxxxzjJvHpsKUG0eGAQAAkCyaYQAAACSLZhgAAADJohkGAABAsmiGAQAAkKyqmCYRc9tttxWUV0rsrOauXbu6+XPPPdeRy6lr8+bNK8l+YpMMYp/DXg+feV9t9txzTzcfP368m8emsMSmBKA4Y8aMcfPY793Y9I/vfe97bn7RRRe5+Zlnnunmxx9/vJu/9NJLbj5hwgQ3j53lHps2NHLkSDd/66233LwaxX7PLV682M2XLVvWodc7bNgwN7/jjjvcfOLEiW5+zz33FLewGhD7f83NN9/codc7ZMgQN49N63n11VfdPDbxacWKFcUtrEw4MgwAAIBk0QwDAAAgWTTDAAAASBbNMAAAAJJFMwwAAIBkVfU0iVoXO6O2c+fOZV5J9YqdKb7XXnu5+dVXX13Q/rt16+bmX/3qV918ypQpbr5+/Xo337x5c0HrSVHseXDiiSe6eY8ePdz8uuuuc/O33367uIUlrksX/9f/1KlT3Xz58uVuHnuuvvDCC26+evVqNx8+fLib//GPf3TzmPPOO8/NY9MwYo+32MSFejBgwAA3P/bYY938iSeecPPDDjvMzUeNGuXmI0aMcPPYc/uuu+5yc6b7FK+pqcnN77vvPjd/7bXX3PyYY45x82qfGhHDkWEAAAAki2YYAAAAyaIZBgAAQLJohgEAAJAsmmEAAAAky8p5VqaZ1eUpoLvttpubL1q0yM33339/N3/55ZdLtiZPCME/rb8IhdYyduZ67IztjRs3uvmLL75YyNVGaxP7HPbZs2e7+QUXXODmTz75ZEHr
KZVK1rJQjY2Nbj537lw379TJf43+2c9+1s0LnTZQbUpVy0LrGLufY2fw77PPPm4e+70VOzv9sccec/N169a5ea2oxufkoEGD3PzZZ5918549e7p5bJJIr1693PyRRx5x8xtvvNHNH3roITev1NSIaqxloWITQ5566ik3j015GTNmjJsvWLCgmGWVXVtryZFhAAAAJItmGAAAAMmiGQYAAECyaIYBAACQLJphAAAAJItpEiXQt29fN49NSjjkkEPcvKMnE1TjGbJnnnmmm0+ePNnNY2c7r1y50s1nzZrl5rEz5qdPn+7mmzZtcvNKqcZaxvTp08fN58+f7+bvvPOOmzc1Nbl5pc44L5VKTZMoYv9uXuv3f6nU0nPywAMPdPNbb721oP2cc845bh6bJlErj5VaquXAgQPd/OGHH3bzWA2GDx/u5kuXLi1uYVWCaRIAAABAK2iGAQAAkCyaYQAAACSLZhgAAADJohkGAABAspgmUQKdOvmvKWKfDR6bMrF+/fpSLclVjWfIxs5Q32WXXdy8e/fubh77XPU1a9a4ea2c1RxTjbXcxv7dfPfdd3fz2PPpjTfeKNmaqkmtTJPAttXScxLbVo21bGhocPN7773XzRsbG918zJgxbr5gwYKi1lXtmCYBAAAAtIJmGAAAAMmiGQYAAECyaIYBAACQLJphAAAAJItpEgmpxjNkURxqWT+YJlEfeE7Wj1qqZWxaT0ytT1IqFNMkAAAAgFbQDAMAACBZNMMAAABIFs0wAAAAkkUzDAAAgGSVdZoEAAAAUE04MgwAAIBk0QwDAAAgWTTDAAAASBbNMAAAAJJFMwwAAIBk0QwDAAAgWTTDAAAASBbNMAAAAJJFMwwAAIBk0QwDAAAgWTTDAAAASBbNMAAAAJJFMwwAAIBk0QwDAAAgWTTDAAAASBbNMAAAAJJFMwwAAIBk0QwDAAAgWTTDAAAASBbNMAAAAJJFMwwAAIBk0QwDAAAgWTTDAAAASNb/A5hkwXI7oM3UAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVkAAAG5CAYAAAAktkdZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xu0ZGV55/Hvj5ZrIxpFHaFbIIpMGDVCegAHMxLRiDdQF0a8xOiojCsx8TYxaryiyTKZxDjLoEnHGJN4QbxlekwLGgWjRrGRawAxLaDdoHIREBCE7vPMH3u3lqdP1TkNZ5/a+/D9rFWrz67atfdTu6qf85xnv/utVBWSpG7sNO0AJGk5M8lKUodMspLUIZOsJHXIJCtJHTLJSlKHBpNkkxyVZPO041hsSXZP8v+S3JjkYwt8zplJXtx1bIshyQeSvL39+VeTXHont/NXSd64uNEtnbvy+U2yf5JKco8xj78+yfvmWjfJZ5L81oRtD/q4DsGcb1qXkjwHeBXwn4GbgPOAP6qqLy91LD1xPPAA4L5VtWX2g0neAjykqp631IEttqr6EnDQfOsleQHw4qp69MhzX9phaINWVX884bEnbvvZ4zodS1rJJnkV8C7gj2kSy4OA9wDHdbzfFV1u/y7aD/jWXAm2b8ZVUvLYaIKqWpIbcC/gZuCZE9bZlSYJX9Xe3gXs2j52FLB5ZN1fAs4EbgAuAo4deewDwHuB9cAtwOPa+94DfKaN4yvAf2r3cT3wTeCQkW28Fvg2TbV9MfD0kcde0D7/3cCN7XOPnvC65owVeCtwO3BHG9OLZj3vmFmPn9/efybwtjaGm4DPAnuPPO8I4N/a/Z0PHDUhtiuA17Wv8Xrg74DdRo858AfA94F/bO9/Cs1fIDe0+3nEyPYOAc5p4/oocArw9jHv4Wrgk8A1wHXAX7bH6jZga/uabxh5T98+8tyXABuBHwLrgH1GHivgpcB/tK/pZCDtYw8Bvti+b9cCHx1zXPZvt3MizWfxe8CrRx5/C/Bx4IPAj4AXs4DPL/D6dr9XAM8d2d6TgXPbbW0C3rKDsXxw1rr3GPmsvHgHjuuk9/YPgCvb9/ZSJnzmvY18lpZsR03C2LLtzR+zzknA14D7A/dr3+S3jX5I2593bv+DvR7YBXhs+8YfNPLBuRE4kqZa362971rgV9rlLwCXA88HVgBvB84YieWZwD7t859Fk6wf2D72gva1vLKN5Vnt/u4zx2uaL9af/gcZc0y2e7z9j/Nt4KHA7u3yO9rH9qVJWE9qY398u3y/Mdu/Avh3moR3H5rEPZoUtwB/QpNAdgcOBa4GDm+P22+129i1fX3fGTkux9P8gtguybbPPR/4C2Bl+548euT4fnlWnB8Y2c5j2/fy0Ha/7wb+dWTdAj4N3Jvmr6VrgGPaxz4C/OHI5+LRY47L/u12PtLG9/B2O48beV/uAJ7Wbmt35v/8bgHe2cb8GJrP1EEjjz+83dYjgB8AT9uBWCYm2QUe10nv7UE0yX+fkf08eNoJbAi3pWwX3Be4tib/Wfxc4KSqurqqrqGp9H5zjvWOAPakSSy3V9UXaP5TPXtknf9bVV+pqpmquq2971NV9Y12+VPAbVX1D1W1labqOmTbk6vqY1V1Vfv8j9JURYeNbP9q4F1VdUf7+KU01cidifXO+Luq+lZV3QqcCjyyvf95wPqqWt/G/jngbJqkO85fVtWmqvoh8EezYpsB3lxVP2n39RLgr6vqrKraWlV/D/ykfZ1H0CTXbcfl48CGMfs8jOaX2O9X1S1VdVstvC//XOD9VXVOVf2EphJ/VJL9R9Z5R1XdUFXfBc7gZ8fnDpoWzT4L3Odb2/gupKnyR4/NV6vqn9rjfCsL+/y+sT2WXwT+GfgNgKo6s6oubL
d1AU1CfcwOxLIYJr23W2mS7cFJdq6qK6rq24u8/2VpKZPsdcDe8/Su9qGphLb5TnvfXOttqqqZWevuO7K8aY7n/WDk51vnWN5z20KS5yc5L8kNSW4AHgbsPbL+lVU1OrvOXYn1zvj+yM8/Hol9P+CZ2+JuY3808MAJ2xo9VrNfxzUjv6S2bf/Vs7a/un3OPsx9XOayGvjOPL90x/m5z0lV3Uzz+Ro9puOOz2uAAF9PclGS/zHPviYdm9mfsfk+v9dX1S1zPZ7k8CRnJLkmyY007Y7Rz9t8sSyGse9tVW0EXkFTNV+d5JQki73/ZWkpk+xXaXpCT5uwzlU0b/Q2D2rvm2u91Ul2mrXulSPLd3p6sST7AX8DvIzmrP+9af6kzshq+yYZXb4rsU6yo69jE03v9N4jt5VV9Y4Jz1k9K7bR1zF7/5toRoOMbn+PqvoITa9wruMyLs4HjfmlO99r/rnPSZKVNH8pzXtMq+r7VfWSqtoH+J/Ae5I8ZMJTduTYzPf5/YU21rke/zBNb3l1Vd0L+Ct+/vM2XywLMd9xnfTeUlUfrmZkwn7ttv5kB/d/t7RkSbaqbgTeBJyc5GlJ9kiyc5InJvnTdrWPAG9Icr8ke7frf3COzZ1F0896TbuNo4Cn0pxkWQwraT5E1wAkeSFNJTvq/sDvtft/Js2JhfUdxPoDYP9ZSXqSDwJPTfKEJCuS7NaO0Vw14Tm/k2RVkvvQ9I4/OmHdvwFe2lZeSbIyyZOT3JPmF+kWmuNyjyTP4OdbLKO+TpOU39FuY7ckR4685lVJdhnz3A8DL0zyyCS70oxWOauqrpgQNwBJnjlyLK6neZ+3TnjKG9vP6n8BXsjkY7OQz+9bk+yS5FdpTjJtGxt9T+CHVXVbksOA59zFWOYy33Ed+94mOSjJY9vjfRvNX36TjptaSzqEq6reSTNG9g00CWwTTbX4T+0qb6fpH14AXEhzlvrtc2znduBY4Ik0J0DeAzy/qr65SHFeDPw5TdL4Ac2Jhq/MWu0s4MB2/38EHF9V13UQ67b/hNclOWcBsW+iGRL3en52jH+fye/1h2lGKFzW3rY75iPbP5umd/eXNElqI80JlW2v9Rnt8vU0JwQ/OWY7W2l+2TwE+C7NmfdntQ9/gWYUxveTXDvHcz8PvBH4BE2ifjBwwoTXN+q/AmcluZmmcnx5VV0+Yf0vtq/x88CfVdVnJ6w73+f3+zTH5SrgQ8BLRz4Hvw2clOQmmuR86l2MZS7zHdex7y1NP/YdNJ/h79MUGa/fwf3fLW0b1qIdMNeg7qFKcgXNa/mXacfSJ+1JtMuBne9k31gCBnRZrSQNkUlWkjpku0CSOmQlK0kd6tWkFrtk19qNlfOvKN1JD33Ej6cdwpy+dWH/Pve31S3cXrfNHqu7pJ7wayvruh8uzkixb1zwk9Or6phF2dgO6FWS3Y2VHJ6jpx2GlrHTTz9v2iHM6Zj9xg0nnp6v3XHatEPguh9u5eunj7ueZceseOB/zL6Cbkn0KslK0qgCZpiZd70+sycrSR2ykpXUY8XWGnYla5KV1FtNu2DYw0xtF0hSh6xkJfXa0E98mWQl9VZRbB34Vam2CySpQ1ayknpt6Ce+TLKSeqv52ophJ1nbBZLUIStZSb1mu0CSOlLg6AJJ0nidJtkkxyS5NMnGJK/tcl+SlqeZRbpNS2ftgiQrgJOBx9N83fOGJOvar9uWpHkV5eiCCQ4DNlbVZVV1O3AKcFyH+5Ok3unyxNe+wKaR5c3A4bNXSnIicCLAbuzRYTiSBqdg67AL2U6T7FzfDbTd4aqqtcBagL1yn4EfTkmLqZnqcNi6bBdsBlaPLK8Crupwf5LUO11WshuAA5McAFwJnAA8p8P9SVp2wtY5/ygejs6SbFVtSfIy4HRgBfD+qrqoq/1JWn4KmBl4E7HTK76qaj2wvst9SFKfeVmtpF6zXSBJHWmmOhx2knXuAknqkJWspF6bqWFXsiZZSb1lu0CSNJGVrKTeKs
LWgdeCJllJvWZPVpI6Yk9WkjSRlaykHgtba9i1oElWUm8188kOO8kOO3pJ6jkr2QFbf+U50w5hTk/a99BphzDWE/Z55LRDGOP2aQewverHHINDP/FlkpXUW1XD78kOO3pJ6jkrWUm9NmO7QJK60VyMMOw/uIcdvST1nJWspB4b/okvk6yk3vJiBEnSRFayknptq1MdSlI3lsOk3cOOXpJ6zkpWUq/NOLpAkrrhxQiSpImsZCX1VhFHF0hSl7wYQZI0lpWspN6qwrkLJKk7Gfx8ssP+FSFJPWclK6m3CtsFktQpL0aQJI1lJSupt4ow48UIktQd2wWSpLGsZCX1VuFUh5LUobDVixEkSeNYyUrqLdsFktQx2wWSpLGsZCX1VlVsF0hSl4Y+Qcywo5ekRZTkmCSXJtmY5LVzPP6gJGckOTfJBUmeNN82rWQl9VbBkk3anWQFcDLweGAzsCHJuqq6eGS1NwCnVtV7kxwMrAf2n7Rdk6ykHstStgsOAzZW1WUASU4BjgNGk2wBe7U/3wu4ar6N9i/JpofDNaqmHcGcnrTvodMOYU6f2vz1aYcw1tNXHTbtEDQ9eyc5e2R5bVWtHVneF9g0srwZOHzWNt4CfDbJ7wIrgcfNt9P+JVlJajUXIyxa4XVtVa2Z8PhcO5pdYT0b+EBV/XmSRwH/mORhVTUzbqMmWUm9toRTHW4GVo8sr2L7dsCLgGMAquqrSXYD9gauHrdRRxdIUmMDcGCSA5LsApwArJu1zneBowGS/BKwG3DNpI1ayUrqraX8ZoSq2pLkZcDpwArg/VV1UZKTgLOrah3wauBvkrySppXwgqrJJ21MspJ6bWYJ/+CuqvU0w7JG73vTyM8XA0fuyDZtF0hSh6xkJfVWFWz1ixQlqTtD/7Za2wWS1CErWUm91YwuGHYtaJKV1GtD/2YEk6yk3lrky2qnYth1uCT1nJWspB6zJytJnVqqSbu70tmviCSr269puCTJRUle3tW+JKmvuqxktwCvrqpzktwT+EaSz836KgdJGssrviaoqu8B32t/vinJJTQzj5tkJS2YPdkFSLI/cAhw1hyPnQicCLAbeyxFOJK0ZDpPskn2BD4BvKKqfjT78fY7dtYC7JX79PPLtCRNxVLOJ9uVTpNskp1pEuyHquqTXe5L0vLk6IIxkgT4W+CSqnpnV/uRpD7rspI9EvhN4MIk57X3vb6deVyS5rUcLqvtcnTBl5n7K3YlacGGPrpg2NFLUs95Wa2k/ipHF0hSZwpHF0iSJrCSldRrtgskqSPLYQiX7QJJ6pCVrKReG3ola5KV1FvLYYIY2wWS1CErWUm9NvRxsiZZSf1Vw+/J2i6QpA71r5Kt/n05wulXnTf/SlPwhH0eOe0Q5vT0Bz1q2iFMsHXaAcwtPazWevBfcTmMk+1fkpWkEUNPsrYLJKlDVrKSems5jJM1yUrqtRp4krVdIEkdspKV1GtejCBJHSkvRpAkTWIlK6nXhn7iyyQrqceGP4TLdoEkdchKVlKv2S6QpI4shwlibBdIUoesZCX1V/Vy9tMdYpKV1GtDv+LLdoEkdchKVlJvFY4ukKQOeTGCJGkCK1lJveboAknq0NB7srYLJKlDVrKSeqtq+JWsSVZSrzm6QJI0lpWspF5zdIEkdcierCR1pMjgk6w9WUnqkJWspF4beEvWJCupx5bBOFnbBZLUIStZSf028H6BlaykXqvKotwWIskxSS5NsjHJa8es8xtJLk5yUZIPz7dNK1lJApKsAE4GHg9sBjYkWVdVF4+scyDwOuDIqro+yf3n266VrKReq1qc2wIcBmysqsuq6nbgFOC4Weu8BDi5qq5vYqur59toryrZfR5+C29ad860w9jOE/Y5dNohDMvM1mlHMDi3nrb/tEPYzszv7DLtEBb7O772TnL2yPLaqlo7srwvsGlkeTNw+KxtPBQgyVeAFcBbquq0STvtVZKVpA5dW1VrJjw+VzafXQPfAzgQOApYBXwpyc
Oq6oZxGzXJSuqvApZunOxmYPXI8irgqjnW+VpV3QFcnuRSmqS7YdxG7clK6rUl7MluAA5MckCSXYATgHWz1vkn4NcAkuxN0z64bNJGTbKSBFTVFuBlwOnAJcCpVXVRkpOSHNuudjpwXZKLgTOA36+q6yZt13aBpH5bwosRqmo9sH7WfW8a+bmAV7W3BTHJSuoxpzqUJE1gJSup3wY+d4FJVlJ/OdWhJGkSK1lJ/Wa7QJK6ZLtAkjSGlaykfrNdIEkdGniS7bxdkGRFknOTfLrrfUlS3yxFJftymskW9lqCfUlaTpZ2qsNOdFrJJlkFPBl4X5f7kbR8LeFUh53oul3wLuA1wEzH+5GkXuosySZ5CnB1VX1jnvVOTHJ2krNvuM5cLGmWWqTblHRZyR4JHJvkCppvfXxskg/OXqmq1lbVmqpac+/7OmxX0iyVxblNSWdZrapeV1Wrqmp/mq9x+EJVPa+r/UlSHzlOVlKvZeDjZJckyVbVmcCZS7EvScvIlPupi8EmqCR1yHaBpB6b7kmrxWCSldRvtgskSeNYyUrqt4FXshOTbJJnTHq8qj65uOFI0izLOckCT53wWAEmWUmaYGKSraoXLlUgkrSdu8tUh0kekORvk3ymXT44yYu6DU2Smiu+FuM2LQsdXfAB4HRgn3b5W8ArughIkpaThSbZvavqVNp5YatqC7C1s6gkaZu7yVSHtyS5L22oSY4AbuwsKklaJhY6TvZVwDrgwUm+AtwPOL6zqCRpmVhQkq2qc5I8BjgICHBpVd2x2MFcdeFKTvrFQxd7s1LvrXzG1dMOYTsrfrxl2iEAd5OpDpPsBvw28GialsGXkvxVVd3WZXCSNPQhXAttF/wDcBPw7nb52cA/As/sIihJWi4WmmQPqqpfHlk+I8n5XQQkST91N5q0+9x2RAEASQ4HvtJNSJI0YuBDuOabIOZCmvB2Bp6f5Lvt8n7Axd2HJ+nubrmf+HrKkkQhScvUfBPEfGd0Ocn9gd06jUiSRg28kl3oBDHHJvkP4HLgi8AVwGc6jEuSGgPvyS70xNfbgCOAb1XVAcDReOJLkua10CR7R1VdB+yUZKeqOgN4ZIdxSdKiTXM4zZNnCx0ne0OSPYF/BT6U5GqgH9fcSVreBn7F10Ir2eOAW4FXAqcB32byV9NIklj4BDG3jCz+fUexSNL2Bj66YL6LEW5i7pcYoKpqr06ikqTWsr4YoaruuVSBSNJytNATX5I0Hcu5kpWkqZry8KvFsNDRBZKkO8FKVlK/DbySNclK6reBJ1nbBZLUIStZSb3miS9J0lgmWUnqkO0CSf028HaBSVZSf3kxgiRpEitZSf028ErWJCup3waeZG0XSFKHrGQl9VYY/okvk6ykfht4krVdIEkdspKV1F+Ok5WkjtUi3RYgyTFJLk2yMclrJ6x3fJJKsma+bZpkJQlIsgI4GXgicDDw7CQHz7HePYHfA85ayHZNspL6bekq2cOAjVV1WVXdDpwCHDfHem8D/hS4bSEbNclK6rXU4tyAvZOcPXI7cdau9gU2jSxvbu/7WSzJIcDqqvr0QuP3xJfUAzO3LqgoWlI1MzPtEBbbtVU1qYeaOe77aQ2cZCfgL4AX7MhOrWQl9dvStQs2A6tHllcBV40s3xN4GHBmkiuAI4B18538spKV1F87MDJgEWwADkxyAHAlcALwnJ+GUnUjsPe25SRnAv+rqs6etFErWUkCqmoL8DLgdOAS4NSquijJSUmOvbPbtZKV1GtLeTFCVa0H1s+6701j1j1qIds0yUrqN6/4kiSNYyUrqdeGPneBSVZSvw08ydoukKQOWclK6q+lHSfbCZOspN4Kc1/rOiS2CySpQ1aykvrNdoEkdWfoQ7hsF0hShzpNsknuneTjSb6Z5JIkj+pyf5KWoSX8jq8udN0u+D/AaVV1fJJdgD063p+k5Wbg7YLOkmySvYD/TjuLePudObd3tT9J6qMu2wW/CFwD/F2Sc5O8L8nK2SslOX
Hbd+7cwU86DEfS4CzS93tN8+RZl0n2HsChwHur6hDgFmC77zGvqrVVtaaq1uzMrh2GI2mQBt6T7TLJbgY2V9W27yb/OE3SlaQFs5Ido6q+D2xKclB719HAxV3tT5L6qOvRBb8LfKgdWXAZ8MKO9ydpuXF0wXhVdR4w8etyJWkSr/iSJI3l3AWS+sv5ZCWpYwNPsrYLJKlDVrKSeisM/8SXSVZSvw08ydoukKQOWclK6rXUsEtZk6yk/loGQ7hsF0hSh6xkJfWaowskqUsDT7K2CySpQ72qZLPTTuy0x3bfUDN1M7fcMu0QtMzV4Q+bdgjbO/9L044AsF0gSd0aeJK1XSBJHbKSldRfU/5+rsVgkpXUbwNPsrYLJKlDVrKSesupDiWpawOfIMZ2gSR1yEpWUq/ZLpCkrjjVoSRpEitZSb2WmWlHcNeYZCX1m+0CSdI4VrKSes3RBZLUlcKLESRJ41nJSuo12wWS1KWBJ1nbBZLUIStZSb3lVIeS1KUqRxdIksazkpXUa7YLJKlLA0+ytgskqUNWspJ6zXaBJHWlgJlhZ1nbBZLUIStZSf027ELWJCup34bek7VdIEkdspKV1G8Dv6zWJCup12wXSNIykeSYJJcm2ZjktXM8/qokFye5IMnnk+w33zZNspL6qxbxNo8kK4CTgScCBwPPTnLwrNXOBdZU1SOAjwN/Ot92TbKSequZT7YW5bYAhwEbq+qyqrodOAU4bnSFqjqjqn7cLn4NWDXfRnvVk62ZGWZuuWXaYUhLbqcNF007hO1tuXXaETRmFm1Leyc5e2R5bVWtHVneF9g0srwZOHzC9l4EfGa+nfYqyUpSh66tqjUTHs8c981ZAid5HrAGeMx8OzXJSuq1Bf6pvxg2A6tHllcBV20XT/I44A+Bx1TVT+bbqD1ZSf21hCe+gA3AgUkOSLILcAKwbnSFJIcAfw0cW1VXL2SjJllJAqpqC/Ay4HTgEuDUqrooyUlJjm1X+9/AnsDHkpyXZN2Yzf2U7QJJPba0X6RYVeuB9bPue9PIz4/b0W2aZCX1mld8SZLGspKV1G9OECNJHSnI4l2MMBW2CySpQ1aykvrNdoEkdWjYOdZ2gSR1yUpWUq8t4dwFnTDJSuq3gSfZTtsFSV6Z5KIk/57kI0l263J/ktQ3nSXZJPsCv0fzVQ0PA1bQzGojSQtTNJN2L8ZtSrpuF9wD2D3JHcAezDE3oySNExb81TG91VklW1VXAn8GfBf4HnBjVX129npJTkxydpKz72De+W8laVC6bBf8As2XkB0A7AOsbL+y4edU1dqqWlNVa3Zm167CkTRUVYtzm5IuT3w9Dri8qq6pqjuATwL/rcP9SVqOTLJjfRc4IskeSQIcTTPbuCTdbXR24quqzkryceAcYAtwLrB28rMkacS20QUD1unogqp6M/DmLvchaXlzdIEkaSwvq5XUbwOvZE2yknpsuiMDFoPtAknqkJWspP4qBl/JmmQl9dvAh3DZLpCkDlnJSuq1oY+TNclK6reBJ1nbBZLUIStZSf1VwMywK1mTrKQe82IESdIEvapks2InVuy517TD2M7WH/1o2iFomdvpvveZdgjbybU9SQ8Dr2R7chQlaYyBJ1nbBZLUIStZSf3l6AJJ6lJBDXvyAtsFktQhK1lJ/TbwE18mWUn9tQx6srYLJKlDVrKS+s12gSR1aOBJ1naBJHXISlZSjw1/Fi6TrKT+KmDGixEkSWNYyUrqN9sFktQhk6wkdaW84kuSNJ6VrKT+KqiBT3VokpXUb7YLJEnjWMlK6jdHF0hSR6q84kuSNJ6VrKR+s10gSd0p2wWSpHGsZCX1mPPJSlJ3/LZaSdIkVrKS+s25CySpGwWU7QJJ0jhWspL6q8p2gSR1yXaBJGksK1lJ/TbwdkGqR1dTJLkG+M4ibW5v4NpF2tZiMq4dY1w7ZjHj2q+q7rdI27pTkpxG85oWw7VVdcwibWvBepVkF1OSs6tqzbTjmM24doxx7Z
i+xnV3Zk9WkjpkkpWkDi3nJLt22gGMYVw7xrh2TF/juttatj1ZSeqD5VzJStLUmWQlqUPLLskmOSbJpUk2JnnttOPZJsn7k1yd5N+nHcs2SVYnOSPJJUkuSvLyaccEkGS3JF9Pcn4b11unHdOoJCuSnJvk09OOZZskVyS5MMl5Sc6edjz6mWXVk02yAvgW8HhgM7ABeHZVXTzVwIAk/x24GfiHqnrYtOMBSPJA4IFVdU6SewLfAJ427eOVJMDKqro5yc7Al4GXV9XXphnXNkleBawB9qqqp0w7HmiSLLCmqvp4gcTd2nKrZA8DNlbVZVV1O3AKcNyUYwKgqv4V+OG04xhVVd+rqnPan28CLgH2nW5UUI2b28Wd21svqoEkq4AnA++bdiwahuWWZPcFNo0sb6YHSWMIkuwPHAKcNd1IGu2f5OcBVwOfq6pexAW8C3gN0LcL6gv4bJJvJDlx2sHoZ5Zbks0c9/WiAuqzJHsCnwBeUVU/mnY8AFW1taoeCawCDksy9RZLkqcAV1fVN6YdyxyOrKpDgScCv9O2p9QDyy3JbgZWjyyvAq6aUiyD0PY8PwF8qKo+Oe14ZquqG4AzgSWf2GMORwLHtv3PU4DHJvngdENqVNVV7b9XA5+iaZ2pB5Zbkt0AHJjkgCS7ACcA66YcU2+1J5j+Frikqt457Xi2SXK/JPduf94deBzwzelGBVX1uqpaVVX703y2vlBVz5tyWCRZ2Z64JMlK4NeB3oxiubtbVkm2qrYALwNOpzmJc2pVXTTdqBpJPgJ8FTgoyeYkL5p2TDSV2W/SVGTntbcnTTso4IHAGUkuoPnF+bmq6s1wqR56APDlJOcDXwf+uapOm3JMai2rIVyS1DfLqpKVpL4xyUpSh0yyktQhk6wkdcgkK0kdMslqTkn+bdoxSMuBQ7gkqUNWsppTkpvbf49K8sUkpyb5VpJ3JHluO9/rhUke3K731CRntfOs/kuSB7T33y/J55Kck+Svk3wnyd7tY89rt3Ne+9iK6b1iqRsmWS3ELwMvBx5Oc4XYQ6vqMJrp/n63XefLwBFVdQjNdf2vae9/M83lp4fSXFP/IIAkvwQ8i2Zik0cCW4HnLs3LkZbOPaYdgAZhQ1V9DyDJt4HPtvdfCPxa+/Mq4KPtROC7AJe39z8aeDpAVZ2W5Pr2/qOBXwE2NFMosDvNtIbSsmKS1UK8OAKaAAAAx0lEQVT8ZOTnmZHlGX72GXo38M6qWpfkKOAt7f1zTT+57f6/r6rXLW6oUr/YLtBiuRdwZfvzb43c/2XgNwCS/DrwC+39nweOT3L/9rH7JNlviWKVloxJVovlLcDHknwJGP2eqbcCv57kHJoJpb8H3NR+j9gbaGbzvwD4HM3sW9Ky4hAudSrJrsDWqtqS5FHAe9sTXdLdgj1Zde1BwKlJdgJuB14y5XikJWUlK0kdsicrSR0yyUpSh0yyktQhk6wkdcgkK0kd+v9PUlMlIhBjGwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "l = [24754, 18623, 5923, 12665, 0, 42051]\n", "pred = classifier.predict(list(map(preprocess, X[l])))\n", "prob = classifier.predict_proba(list(map(preprocess, X[l])))\n", "\n", "plt.rcParams[\"figure.figsize\"] = (12, 7)\n", "for i, index in enumerate(l):\n", " plt.subplot(1, len(l), i+1)\n", " plt.title('Truth: %i\\nPredicted: %i' % (y[index], pred[i]))\n", " plt.axis('off')\n", " plt.imshow(preprocess(X[index]).reshape(DIM, DIM), cmap='gray')\n", "plt.show()\n", "\n", "plt.title('Colormap of the predictions probabilities')\n", "plt.imshow(prob.T)\n", "plt.xlabel('image')\n", "plt.ylabel('label')\n", "plt.colorbar()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, although the overall result is good, in some cases it can make errors.\n", "\n", "On the example above (how convenient!), the classifier sees a 4 instead of a 7 in the last image,\n", "but its second best guess is correct.\n", "Since there is already a 4 in the batch, and since the SVM is really sure about it,\n", "there _should_ be a way to exploit this information.\n", "\n", "The idea I had is to study the _joint distribution_ of the digits in the same patch,\n", "and instead of maximizing each probability, to maximize their product or, equivalently,\n", "the log-likelihood.\n", "Thus, in the above example, the optimal choice would \"reserve\" the label 4 for the first\n", "image because it is more confident than for the last one.\n", "\n", "Note that SVMs do not output probabilities, [the `predict_proba` function is an estimator](http://scikit-learn.org/stable/modules/svm.html#scores-probabilities).\n", "\n", "Now, what if we tested all the possibilities of disjoint labels?\n", "In our case, testing $\\frac{10!}{(10-\\texttt{N_DIGITS})!}$ possibilities is fast, but it wouldn't be possible with more categories, like letters instead of digits for example.\n", "\n", "It turns out that this 
problem can be formulated as an [assignment problem](https://en.wikipedia.org/wiki/Assignment_problem): we have two categories of objects and we want to find the matching with the smallest cost.\n", "A matching is a set of disjoint pairs, and the cost is the sum of the weights of the chosen edges. These two images describe an instance of the problem (without the weights) and a solution:\n", "\n", "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "\n", "\n", "
\n", "\n", "In our problem, the two categories are images and labels, and the weight matrix looks like the image shown above.\n", "\n", "Because the assignment problem is usually formulated with a sum, instead of searching for $$\\underset{assignment}{\\operatorname{argmax}} \\prod_i \\mathbb{P}(true_i = assignment_i)$$\n", "we look for $$\\underset{assignment}{\\operatorname{argmin}} - \\sum_i \\log \\mathbb{P}(true_i = assignment_i)$$\n", "\n", "So the cost matrix is $m_{ij} = - \\log \\mathbb{P}(true_i = possible\\_label_j)$.\n", "\n", "The advantage of this formulation is that there are algorithms running in polynomial time (cubic in general).\n", "I can now admit that I used \"Hungarian\" as a buzzword: the [Hungarian algorithm](https://en.wikipedia.org/wiki/Hungarian_algorithm) is a polynomial-time algorithm for the assignment problem, and it is implemented in the popular library [SciPy](https://www.scipy.org/).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Results" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:53:58.143734Z", "start_time": "2018-06-02T09:53:54.665422Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Without Hungarian trick:\n", "Accuracy for single digit prediction: 95.7%\n", "Accuracy for multiple digits prediction: 73.6%\n", "\n", "With Hungarian trick:\n", "Accuracy for single digit prediction: 97.6%\n", "Accuracy for multiple digits prediction: 85.2%\n" ] } ], "source": [ "from scipy.optimize import linear_sum_assignment\n", "\n", "def predict(patch):\n", " cost_matrix = -classifier.predict_log_proba(patch)\n", " row_ind, col_ind = linear_sum_assignment(cost_matrix)\n", " # not used here\n", " cost = cost_matrix[row_ind, col_ind].sum()\n", " confidence = np.exp(-cost)\n", " return classifier.classes_[col_ind]\n", "\n", "predicted_trick = np.array(list(map(predict, X_test_2)))\n", "\n", "accuracy_trick = (predicted_trick == 
y_test_2).mean()\n", "accuracy_multi_trick = (predicted_trick == y_test_2).all(axis=1).mean()\n", "\n", "\n", "print('Without Hungarian trick:')\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi))\n", "\n", "print('\\nWith Hungarian trick:')\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy_trick))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi_trick))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try with a bigger value of `N_DIGITS`!" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2018-06-02T09:54:14.055013Z", "start_time": "2018-06-02T09:53:58.151344Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Without Hungarian trick:\n", "Accuracy for single digit prediction: 95.9%\n", "Accuracy for multiple digits prediction: 66.4%\n", "\n", "With Hungarian trick:\n", "Accuracy for single digit prediction: 99.5%\n", "Accuracy for multiple digits prediction: 97.8%\n" ] } ], "source": [ "N_DIGITS = 10\n", "\n", "testing_indexes = np.array([[random.choice(reverse_labels[i])\n", " for i in random.sample(range(10), N_DIGITS)]\n", " for _ in range(TEST2)])\n", "testing_indexes_flat = testing_indexes.flatten()\n", "\n", "X_test_2_flat = np.array(list(map(preprocess, X[testing_indexes_flat])))\n", "X_test_2 = X_test_2_flat.reshape(TEST2, N_DIGITS, DIM * DIM)\n", "y_test_2 = y[testing_indexes]\n", "\n", "predicted = classifier.predict(X_test_2_flat).reshape(TEST2, N_DIGITS)\n", "accuracy = (predicted == y_test_2).mean()\n", "accuracy_multi = (predicted == y_test_2).all(axis=1).mean()\n", "\n", "predicted_trick = np.array(list(map(predict, X_test_2)))\n", "\n", "accuracy_trick = (predicted_trick == y_test_2).mean()\n", "accuracy_multi_trick = (predicted_trick == y_test_2).all(axis=1).mean()\n", "\n", 
"print('Without Hungarian trick:')\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi))\n", "\n", "print('\\nWith Hungarian trick:')\n", "print('Accuracy for single digit prediction: %.01f%%' % (100 * accuracy_trick))\n", "print('Accuracy for multiple digits prediction: %.01f%%' % (100 * accuracy_multi_trick))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Discussion\n", "\n", "With just 5 lines of code, we achieved a huge accuracy improvement!\n", "\n", "This method is not limited to digits (or even images), nor to SVMs.\n", "You can use it for any _batched_ classification task where you know the samples\n", "in a batch have different labels, and for any model that gives probabilities.\n", "\n", "Moreover, the algorithm gives a confidence score (see `predict`) that can be\n", "used to select a parameter. In my application, I discovered that changing the padding\n", "made a huge difference, but couldn't select it because I didn't have validation samples\n", "from the testing distribution, so my code does a grid search _at testing time_ and\n", "maximizes this confidence.\n", "\n", "In fact, OCR software uses the same kind of trick: instead of recognizing each character\n", "separately, it asks the user for the language and uses its knowledge of the dictionary\n", "to boost its predictions. See [here](https://github.com/tesseract-ocr/tesseract/wiki/NeuralNetsInTesseract4.00) how Tesseract uses LSTM neural networks." 
] } ], "metadata": { "date": "2018-06-01", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "title": "The Hungarian maximum likelihood trick", "toc": { "colors": { "hover_highlight": "#DAA520", "navigate_num": "#000000", "navigate_text": "#333333", "running_highlight": "#FF0000", "selected_highlight": "#FFD700", "sidebar_border": "#EEEEEE", "wrapper_background": "#FFFFFF" }, "moveMenuLeft": false, "nav_menu": { "height": "30px", "width": "252px" }, "navigate_menu": true, "number_sections": true, "sideBar": true, "skipTitle": true, "threshold": 4, "toc_cell": true, "toc_section_display": "none", "toc_window_display": true, "widenNotebook": false } }, "nbformat": 4, "nbformat_minor": 2 }