Faster Python with NumPy broadcasting and Numba

19 May 2019

One of the most basic concepts in programming is the loop. Yet Python is notoriously bad at executing loops fast.

Loops are especially important in Machine Learning, where most algorithms have to process big datasets.

Slow code is one of Python's biggest flaws, and fixing it can boost the speed of your algorithms and, even better, your productivity.

Today we are going to see two optimization techniques that can be leveraged to make loops in numeric code really fast:

  1. Array Broadcasting in NumPy
  2. JIT (just-in-time) compilation with Numba

What is Array Broadcasting?

It is a simple technique that you already use every day when you write

import numpy as np

a = np.arange(100)
b = a * 2

Here, NumPy understood that when you write a * 2, you actually want to multiply every element of a by 2.

Array broadcasting allows more complex behaviors; see this example:

>>> a = np.arange(6).reshape(2, 3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> b = np.array([1,2])
>>> b
array([1, 2])
>>> a * b
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (2,3) (2,) 
>>> b = b.reshape(2, 1)
>>> b
array([[1],
       [2]])
>>> a * b
array([[ 0,  1,  2],
       [ 6,  8, 10]])

What the hell just happened?

On the first attempt, we tried to multiply two arrays whose shapes are not compatible. Broadcasting aligns shapes starting from the last dimension: here the 3 of (2, 3) would have to match the 2 of (2,), hence the error.

The second time, we gave b a second dimension. When a dimension has size 1, it can be automatically extended to match the other operand. Here, just before doing the multiplication, it is as if we wrote:

>>> b = b.repeat(3, axis=1)
>>> b
array([[1, 1, 1],
       [2, 2, 2]])

Broadcasting can also extend several dimensions at once, as we are going to see in the examples.
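As a quick preview (a toy example), a column of shape (2, 1) and a row of shape (1, 3) are both extended to (2, 3):

>>> col = np.arange(2).reshape(2, 1)
>>> row = np.arange(3).reshape(1, 3)
>>> col + row
array([[0, 1, 2],
       [1, 2, 3]])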

Note that it is possible to trigger the broadcasting without calling b.reshape:

>>> a = np.arange(6).reshape(2, 3)
>>> b = np.array([1,2])
>>> b[:, np.newaxis].shape
(2, 1)
>>> b[:, np.newaxis]
array([[1],
       [2]])
>>> a * b[:, np.newaxis]
array([[ 0,  1,  2],
       [ 6,  8, 10]])

You can remember that np.newaxis inserts a $1$ (a new axis) into the shape of an array.

np.newaxis is generally preferred to reshape because you don't have to write the array sizes and it makes the intent clearer. Neither operation copies the array in memory: both return views.
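A detail worth remembering: np.newaxis is just an alias for None, so b[:, None] is strictly equivalent (np.newaxis simply makes the intent explicit):

>>> np.newaxis is None
True
>>> b[:, None].shape
(2, 1)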

What is Numba?

Numba is just a compiler that takes a subset of the Python language and compiles it to a native function.

Let’s take the simplest example: a function that adds two objects.

from numba import njit, prange

@njit
def f(a, b):
    return a + b

As you can see, using Numba is as simple as applying a decorator to f.

Readers already familiar with Numba may be surprised that I did not use the jit decorator. njit is equivalent to jit(nopython=True). The nopython mode ensures that the code is really compiled, without falling back to the interpreter. It means you have to restrict yourself a bit more, but if you don't, you get an error instead of slow code (which is much easier to fix).

I recommend always using njit by default, and falling back to jit only if you get errors and are willing to accept slower code.

Now, basic Python data types are going to be converted to Numba data types. All NumPy numeric types and arrays are supported, so you shouldn't even notice it.

However, you can get surprises like this:

>>> f(2**63-1, 2**63-1)
-2

The problem is that Python's int is automatically converted to an int64, hence you get the common pitfalls of bounded integers: here, the addition overflows and wraps around.
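You can check the difference against pure Python, whose integers have arbitrary precision:

>>> (2**63 - 1) + (2**63 - 1)    # pure Python: unbounded integers
18446744073709551614
>>> f(2**63 - 1, 2**63 - 1)      # Numba: int64 arithmetic wraps around
-2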

Let’s take a look at NumPy datatypes and check the signatures that Numba used:

>>> f(np.array([1., 2., 3.]), np.array([4., 5., 6.]))
array([5., 7., 9.])
>>> f.signatures
[(int64, int64), (array(float64, 1d, C), array(float64, 1d, C))]

The function f has been called, and thus compiled, with two different signatures: first with two int64, then with two 1-dimensional arrays of float64 (the C stands for C-contiguous memory layout, but you can ignore it).

But adding two integers or arrays is not very impressive. Where Numba really shines is on explicit loops, like in the example below.

Note: don't reimplement linear algebra routines (like np.dot for matrices) in Numba; the NumPy implementation is highly optimized and can be called from inside Numba-compiled code.
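For instance, here is a minimal sketch (a hypothetical helper, not from the original article) that mixes an explicit loop with calls to np.dot, which Numba supports in nopython mode on contiguous float arrays:

import numpy as np
from numba import njit

@njit
def row_norms(m):
    # np.dot is supported inside nopython mode and dispatches to the
    # optimized BLAS routine, so there is no need to rewrite the inner
    # product by hand.
    n = m.shape[0]
    out = np.empty(n)
    for i in range(n):
        out[i] = np.dot(m[i], m[i])
    return out

>>> row_norms(np.ones((3, 4)))
array([4., 4., 4.])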

Pros and cons of each method

Array Broadcasting's pros:

  - it is concise and close to the mathematical notation
  - it only requires NumPy
  - it is orders of magnitude faster than plain Python loops

Array Broadcasting's cons:

  - it can use much more memory, because intermediate arrays are materialized
  - complex broadcasts can be hard to read

Numba's pros:

  - it is usually the fastest option, since explicit loops compile to native code
  - parallelization is almost free with prange
  - it is much easier to deploy than a C extension or Cython

Numba's cons:

  - it makes the code more verbose
  - it only supports a subset of Python and NumPy
  - the first call pays a compilation overhead

A concrete example: Closest centroid

Suppose we have $n + k$ vectors of dimension $d$. The $n$ vectors are data points and the other $k$ are centroids. We want to assign to every data point the closest centroid. This example is actually the assignment step of Lloyd's algorithm for K-means clustering.

Let’s implement this in Python with loops:

def lloyd_simple(data, centroids):
    k, _ = centroids.shape
    return np.array(
        [
            min(range(k), key=lambda i: np.sum((point - centroids[i]) ** 2))
            for point in data
        ]
    )
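A toy example (with made-up points) to see what the function returns:

>>> data = np.array([[0.0, 0.0], [10.0, 10.0], [8.0, 9.0]])
>>> centroids = np.array([[1.0, 1.0], [9.0, 9.0]])
>>> lloyd_simple(data, centroids)
array([0, 1, 1])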

What if we could replace the min with np.argmin?

def distance_to_centroids(point, centroids):
    return np.sum((point[np.newaxis, :] - centroids) ** 2, axis=1)


def lloyd_broadcast_1(data, centroids):
    return np.array(
        [
            np.argmin(distance_to_centroids(point, centroids))
            for point in data
        ]
    )

point[np.newaxis, :] signals that the point should be (virtually) repeated along the first axis so it can be compared to every centroid. point[np.newaxis, :] - centroids thus creates an array of shape (k, d).

Now, what if we could replace even the for loop with a broadcast? We first compute the distance matrix between all data points and all centroids, then take an argmin over the centroid dimension.

def all_distances_to_centroids(data, centroids):
    return np.sum((data[:, np.newaxis, :] - centroids[np.newaxis, :, :]) ** 2, axis=2)


def lloyd_broadcast_2(data, centroids):
    return np.argmin(all_distances_to_centroids(data, centroids), axis=1)

We just saw it was possible to simultaneously broadcast over 2 dimensions!
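To see the shapes involved, here is a toy example (made-up sizes n = 4, k = 2, d = 3): data[:, np.newaxis, :] has shape (n, 1, d), centroids[np.newaxis, :, :] has shape (1, k, d), and their difference broadcasts to (n, k, d):

>>> data = np.random.uniform(size=(4, 3))
>>> centroids = np.random.uniform(size=(2, 3))
>>> (data[:, np.newaxis, :] - centroids[np.newaxis, :, :]).shape
(4, 2, 3)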

Note that this function uses much more memory, as it builds an intermediate array of shape (n, k, d). With the benchmark sizes below (n = 200000, k = 50, d = 100), that is 200000 × 50 × 100 × 8 bytes ≈ 8 GB of float64 values.

Finally, the Numba implementation:

@njit
def lloyd_numba(data, centroids):
    n, d = data.shape
    k, _ = centroids.shape
    answer = np.empty(n, dtype=np.int64)
    for i in range(n):
        min_dist_i = np.finfo(np.float64).max
        for j in range(k):
            # squared distance between point i and centroid j
            dist_i_j = 0.0
            for u in range(d):
                dist_i_j += (data[i, u] - centroids[j, u]) ** 2
            if dist_i_j < min_dist_i:
                min_dist_i = dist_i_j
                answer[i] = j
    return answer

and the version parallelized over the data points:

@njit(parallel=True)
def lloyd_numba_parallel(data, centroids):
    n, d = data.shape
    k, _ = centroids.shape
    answer = np.empty(n, dtype=np.int64)
    # prange distributes the iterations over all available threads
    for i in prange(n):
        min_dist_i = np.finfo(np.float64).max
        for j in range(k):
            dist_i_j = 0.0
            for u in range(d):
                dist_i_j += (data[i, u] - centroids[j, u]) ** 2
            if dist_i_j < min_dist_i:
                min_dist_i = dist_i_j
                answer[i] = j
    return answer

We changed only two things:

  1. the decorator became @njit(parallel=True)
  2. the outer loop uses prange instead of range
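Before benchmarking, a quick sanity check (a small sketch with random inputs, not from the original article) that all implementations agree:

>>> data = np.random.uniform(size=(100, 5))
>>> centroids = np.random.uniform(size=(3, 5))
>>> reference = lloyd_simple(data, centroids)
>>> all(
...     np.array_equal(reference, impl(data, centroids))
...     for impl in (lloyd_broadcast_1, lloyd_broadcast_2,
...                  lloyd_numba, lloyd_numba_parallel)
... )
True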

Benchmark

We generate some random data with:

n_small = 20000
n_big = 200000
k = 50
d = 100

data_small = np.random.uniform(size=(n_small, d))
data_big = np.random.uniform(size=(n_big, d))
centroids = np.random.uniform(size=(k, d))

Here are the results of the benchmark. I ran them on a 2-thread Colab instance, which you can find here together with all the code from this article:

Function              Compilation  Small dataset  Big dataset
lloyd_simple          0 s          5.9 s          59 s
lloyd_broadcast_1     0 s          370 ms         3.6 s
lloyd_broadcast_2     0 s          725 ms         7 s
lloyd_numba           58 ms        137 ms         1.4 s
lloyd_numba_parallel  57 ms        84 ms          921 ms

We can see a few things:

  - broadcasting already gives a speedup of more than an order of magnitude over the naive loops
  - lloyd_broadcast_2 is slower than lloyd_broadcast_1, most likely because of the huge intermediate array it allocates
  - the Numba versions are the fastest, and their compilation time is negligible
  - the parallel version gives a further speedup, even with only 2 threads
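If you want to reproduce the measurements, here is a rough sketch of a timer (the bench helper is mine, not from the notebook; exact numbers will of course vary with the machine):

import time

def bench(f, *args, repeat=3):
    # For the Numba functions, the first call also triggers the compilation.
    start = time.perf_counter()
    f(*args)
    first_call = time.perf_counter() - start
    # The next calls measure the steady-state running time.
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        f(*args)
        best = min(best, time.perf_counter() - start)
    return first_call, best

print(bench(lloyd_numba, data_small, centroids))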

Conclusion

Broadcasting is a powerful technique that can make your code orders of magnitude faster. It is also the most mathematical and concise way to describe your operations. However, the speed comes at the expense of greater memory usage, which you should be careful about: for example, you should probably not broadcast over the size of your dataset like we did in lloyd_broadcast_2.

Numba is definitely the most flexible and fastest way to implement loops and almost any numeric function in Python. Although it makes the code a bit more verbose, it is much easier to deploy than a C extension or Cython. Furthermore (maybe its greatest advantage), it lets you parallelize functions almost instantly and gain even more speed from multiple cores.