Faster Python Loops with Numba Jit

by John | December 29, 2023

 

Python is notoriously slow compared to compiled languages when it comes to looping over large arrays. This is because when Python loops over an array there is a lot of overhead in interpreting types and calling the __getitem__ method on every iteration. While we can often avoid this by using vectorization with numpy and pandas, sometimes there is no choice but to write a very large loop. This can cause significant performance problems, and in some cases can make a problem infeasible due to time constraints. In this article we will show how to use Numba to achieve significant speed-ups over plain Python code. 
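
To make the vectorization point concrete, here is a quick sketch (the array and its size are just illustrative) comparing an interpreted loop with its vectorized numpy equivalent:

import numpy as np

a = np.arange(100_000, dtype=np.float64)

# Interpreted loop: type checks and method dispatch on every iteration
total = 0.0
for x in a:
    total += x * x

# Vectorized equivalent: the loop runs inside NumPy's compiled C code
total_vec = np.sum(a * a)

assert np.isclose(total, total_vec)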

 

Jit Compilation


A high-level overview of just-in-time (JIT) compilation: when the program enters run-time (think of this as the moment you click the run icon in your code editor), any function wrapped in the @jit decorator is compiled into faster machine code the first time it is called, and that compiled version is then used for the rest of the program. This means we get the benefits of compiled-language speed together with Python's flexibility. 
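
As a minimal sketch of what this looks like in practice (the function here is purely illustrative; the decorator's keyword arguments are discussed later in the article):

from numba import jit

@jit
def sum_of_squares(n):
    # An ordinary Python loop; Numba compiles it to machine code
    # the first time the function is called.
    total = 0.0
    for i in range(n):
        total += i * i
    return total

sum_of_squares(10)       # first call triggers compilation
sum_of_squares(10_000)   # subsequent calls reuse the compiled code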

 

 

Jit Example

 

As can be seen from the snippet below, we are running a numerically intensive program with nested loops; when we see code like this, it is a pretty good bet that Numba will speed things up. 

 

import numpy as np
import matplotlib.pyplot as plt
import time 


def mandelbrot_py(c, max_iter):
    z = 0
    n = 0
    while abs(z) <= 2 and n < max_iter:
        z = z * z + c
        n += 1
    return n

def mandelbrot_set_py(width, height, x_min, x_max, y_min, y_max, max_iter):
    x = np.linspace(x_min, x_max, width)
    y = np.linspace(y_min, y_max, height)
    mandelbrot_data = np.zeros((width, height))

    for i in range(width):
        for j in range(height):
            real = x[i]
            imag = y[j]
            c = complex(real, imag)
            mandelbrot_data[i, j] = mandelbrot_py(c, max_iter)

    return mandelbrot_data 



width, height = 800, 800
x_min, x_max = -2.0, 1.0
y_min, y_max = -1.5, 1.5
max_iter = 256



%timeit mandelbrot_data = mandelbrot_set_py(width, height, x_min, x_max, y_min, y_max, max_iter)

 

16.2 s ± 662 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

 

On my computer it takes approximately 16 seconds to create the set, and since we have gone to all that effort, we might as well take a look at a pretty picture generated by it. 

 

def plot_mandelbrot(mandelbrot_data):
    # extent matches the x/y ranges used to build the set; origin='lower'
    # keeps the imaginary axis pointing upwards
    plt.imshow(mandelbrot_data.T, cmap='hot', extent=(-2, 1, -1.5, 1.5), origin='lower')
    plt.xlabel('Real')
    plt.ylabel('Imaginary')
    plt.title('Mandelbrot Set')
    plt.show()


plot_mandelbrot(mandelbrot_data)

 

 

[Image: the Mandelbrot set rendered from the data computed above]

 

 

Now that we are happy the code is working as expected, we can try to speed things up a bit. Let's experiment with the jit decorator. Note that if you haven't got numba installed yet, you can do so with a simple 'pip install numba'. 

Note that the nopython keyword has been set to False for this example. Essentially this means Numba will try to compile the function, and if it is unsuccessful it will fall back to pure Python. Later in this article we will give some examples of the errors Numba raises when nopython=True. Clearly, if we want to be absolutely sure the function will always run, we should use nopython=False. 

from numba import jit

@jit(nopython=False)
def mandelbrot(c, max_iter):
    z = 0
    n = 0
    while abs(z) <= 2 and n < max_iter:
        z = z * z + c
        n += 1
    return n

@jit(nopython=False)
def mandelbrot_set(width, height, x_min, x_max, y_min, y_max, max_iter):
    x = np.linspace(x_min, x_max, width)
    y = np.linspace(y_min, y_max, height)
    mandelbrot_data = np.zeros((width, height))

    for i in range(width):
        for j in range(height):
            real = x[i]
            imag = y[j]
            c = complex(real, imag)
            mandelbrot_data[i, j] = mandelbrot(c, max_iter)

    return mandelbrot_data 

 

 

OK, now let's use the same parameters and test whether we get a speedup. 

 

width, height = 800, 800
x_min, x_max = -2.0, 1.0
y_min, y_max = -1.5, 1.5
max_iter = 256

%timeit mandelbrot_data = mandelbrot_set(width, height, x_min, x_max, y_min, y_max, max_iter)

 


488 ms ± 29.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

 

And, quite incredibly, with just a simple decorator we get over a 30x improvement in the speed of our program!

 

 

However, even this comparison is slightly unfair to the Numba version, as we don't allow it to compile before running the speed test. For a proper comparison I will restart my kernel and run things in a slightly different way. After restarting the kernel, we create a new cell and paste in the following:

 

from numba import jit
import numpy as np
import matplotlib.pyplot as plt
import time 

@jit(nopython=False)
def mandelbrot(c, max_iter):
    z = 0
    n = 0
    while abs(z) <= 2 and n < max_iter:
        z = z * z + c
        n += 1
    return n

@jit(nopython=False)
def mandelbrot_set(width, height, x_min, x_max, y_min, y_max, max_iter):
    x = np.linspace(x_min, x_max, width)
    y = np.linspace(y_min, y_max, height)
    mandelbrot_data = np.zeros((width, height))

    for i in range(width):
        for j in range(height):
            real = x[i]
            imag = y[j]
            c = complex(real, imag)
            mandelbrot_data[i, j] = mandelbrot(c, max_iter)

    return mandelbrot_data 




width, height = 800, 800
x_min, x_max = -2.0, 1.0
y_min, y_max = -1.5, 1.5
max_iter = 256


t = time.time()
x = mandelbrot_set(width, height, x_min, x_max, y_min, y_max, max_iter)
print(f"on the first call it took {time.time()-t} seconds")

 

on the first call it took 2.6250035762786865 seconds

 

And then we do speed-test again:

 

%timeit mandelbrot_set(width, height, x_min, x_max, y_min, y_max, max_iter)

 

393 ms ± 40.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

 

 

The reason for this difference is compilation: Numba compiles the function the first time it is called, and subsequent calls use the cached compiled version. If we restart the kernel, or load the function into a new notebook, it has to be compiled again. Likewise, if we modify the function and reload it, the compile-time overhead returns. Even so, in the example above we still get a very noticeable speed-up despite the initial compilation cost!
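
If that first-call overhead matters, Numba can also write the compiled machine code to disk so it survives kernel restarts, via the cache keyword. A minimal sketch (note that disk caching requires Numba to locate the function's source, so depending on your version it may not work for functions defined directly in a notebook cell):

from numba import jit

@jit(nopython=True, cache=True)
def mandelbrot_cached(c, max_iter):
    # cache=True stores the compiled code in a __pycache__ directory
    # next to the source file, avoiding recompilation in later sessions.
    z = 0
    n = 0
    while abs(z) <= 2 and n < max_iter:
        z = z * z + c
        n += 1
    return n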

 

 

 

Njit vs Jit and Errors

Njit is very similar in performance terms (at least in the example above), so we won't repeat the speed tests for the njit decorator. Instead, we will show the difference in how strict the two are. Recall that we previously had the nopython keyword set to False; njit is simply jit with nopython=True by default, meaning that if for some reason Numba can't compile the function, it will raise an error. 
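
In other words, the two declarations below are interchangeable (the trivial function is just for illustration):

from numba import jit, njit

@njit                   # shorthand for...
def add_one(x):
    return x + 1

@jit(nopython=True)     # ...exactly this
def add_one_explicit(x):
    return x + 1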

 

Let's take an example with the following function, which we know Numba can't compile because it uses the reversed built-in. 

 

from numba import jit, njit

@jit(nopython=False)
def multiply_list_by(x, y):
    new_list =[]
    
    for val in x:
        new_val = val * y 
        new_list.append(new_val)
    
    return reversed(new_list)


x = [1,2,3,4]
y = 4

res = multiply_list_by(x, y)


print(f"Function compiled in nopython mode = {bool(multiply_list_by.nopython_signatures)}")

 

Function compiled in nopython mode = False

 

So basically this means that Numba tried to compile the function into faster machine code, realized that the reversed built-in is not supported, and fell back to Python, or 'object mode' as the Numba documentation describes it. Let's see what happens when we set nopython=True. 

 

@jit(nopython=True)
def multiply_list_by(x, y):
    new_list =[]
    
    for val in x:
        new_val = val * y 
        new_list.append(new_val)
    
    return reversed(new_list)


x = [1,2,3,4]
y = 4

res = multiply_list_by(x, y)

 

TypingError: Cannot determine Numba type of <class 'type'>

 

This is an error you should get used to if you are new to Numba. It essentially means you have done something with types that Numba can't compile; in this case it is the call to reversed. 

 

My personal view is that it is better to always have nopython=True, because if you get the silent fallback to pure Python, you may not realize until it is too late and your code is unexpectedly taking ages, and then it is a real nightmare sifting through the code trying to find the part causing the problem. Judging from the warnings Numba raises, nopython will be set to True by default in a future release, so if you are reading this it is best to check which Numba version you are using. I am currently on 0.58.1. 
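
You can check your installed version like so:

import numba
print(numba.__version__)   # 0.58.1 at the time of writing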

 

Broadly speaking, from what I have seen, njit and jit produce largely similar results, at least on my computer. The only difference is that njit has nopython=True by default and so raises the TypingError when something is wrong. 

 

 

Numba TypingError

 

When Numba can't compile something, it can be very confusing to diagnose exactly what has gone wrong and where the problem lies. So to start, we will give some general advice for using Numba before moving on to more examples. 

 

  • Ensure, as far as possible, that you write your code so that one function does one thing. This is good coding practice in general, but those of us who prefer the one-function monolith will be penalized to the point of insanity when trying to debug Numba functions. 
  • Stick to numpy methods as much as possible; Numba plays very nicely with the vast majority of them. See what I believe to be an exhaustive list here. It also seems to handle most of Python's math library well, from what I have seen so far. 
  • Assume that anything not on that list is unsupported by Numba until proven otherwise. 
  • Write unit tests. This is another general good practice, but it becomes even more important with Numba (see the sketch after this list).
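
On that last point, a handy trick is that every jitted function keeps a reference to the original interpreted function under .py_func, so a test can compare the compiled output against plain Python directly. A minimal sketch using the mandelbrot_set function from earlier (the test name and parameters are just illustrative):

import numpy as np

def test_mandelbrot_set_matches_python():
    # .py_func is the original, uncompiled Python function, so this
    # verifies that compilation has not changed the results.
    compiled = mandelbrot_set(100, 100, -2.0, 1.0, -1.5, 1.5, 50)
    interpreted = mandelbrot_set.py_func(100, 100, -2.0, 1.0, -1.5, 1.5, 50)
    np.testing.assert_allclose(compiled, interpreted)

test_mandelbrot_set_matches_python()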

 

 

 

Common Error Example 1

 

Let's say we have some array called nums and we want to apply a mapping to create a new array; for example, taking each element and adding the square of its value to it. We want to do this by calling a helper function within the main function, as shown below. 

 

from numba import jit 
import numpy as np 


def helper_function(x):
    # x plus its square; note it is NOT jitted, which is the bug here
    return x + x**2

@jit(nopython=True)
def main_function(nums):
    
    new_array = np.empty(nums.shape[0])
    
    for n in nums:
        
        # note: using the value n as an index only works because nums is 0..9
        new_array[n] = helper_function(n)
    
    return new_array


nums = np.arange(0, 10 ,1)


res = main_function(nums)

 

TypingError: Cannot determine Numba type of <class 'function'>

 

The reason for this error is that Numba can't determine the type of our helper function. Anything we build ourselves and call from inside a Numba function must therefore also be a Numba function, or a Numba-supported function. Easily fixed, if we know what to look for:

 

from numba import jit 
import numpy as np 

@jit(nopython=True)
def helper_function(x):
    return x + x**2

@jit(nopython=True)
def main_function(nums):
    
    new_array = np.empty(nums.shape[0])
    
    for n in nums:
        
        new_array[n] = helper_function(n)
    
    return new_array


nums = np.arange(0, 10 ,1)


res = main_function(nums)

 

 

 

 

Common Error Example 2

 

The next example is much more subtle, and is likely to be particularly confusing to those who are not familiar with statically typed languages. In the example below, it appears Numba is converting our results to integers rather than the desired floating-point numbers. 

 

@jit(nopython=True)
def func(nums):
    
    new_array = np.zeros_like(nums)
    
    for n in nums:
        
        new_array[n] = n - np.sqrt(n)
    
    return new_array


nums = np.arange(0, 10 ,1)

print(func(nums))

'''

[0 0 0 1 2 2 3 4 5 6]
'''


'''

Expected

0.0
0.0
0.5857864376269049
1.2679491924311228
2.0
2.76393202250021
3.550510257216822
4.354248688935409
5.17157287525381
6.0
'''

 

This is one of the worst types of error to get, in that it doesn't actually raise an error when you run it, but the results are very wrong indeed. Something like this can be difficult to debug in a larger program. 

 

The issue is actually quite subtle, and is related to types. First let's check what the nums array dtype is. 

 

print(nums.dtype)

'''

int32

'''

 

So we are passing the Numba function an array of integers and telling it to allocate memory for a new array of the same type; however, when we populate each element of this new array we are passing in floating-point numbers. The new_array expects integers, but we assign floats, and rather than raising an error (which you might reasonably expect) the values are silently truncated to integers. 
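
This truncation is not Numba-specific; it is standard NumPy behaviour when assigning floats into an integer array, which is presumably why no error is raised:

import numpy as np

a = np.zeros(3, dtype=np.int32)
a[0] = 1.99       # silently truncated towards zero
print(a)          # [1 0 0]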

 

And here, dear Pythonista, we finally understand those looks, half disgust and half pity, that C++ developers give us when we tell them we program in Python. 

 

This is easily fixed if we know what to look for: we simply declare the type of data the new array should hold. 

 

@jit(nopython=True)
def func(nums):
    
    new_array = np.zeros_like(nums, dtype='float64')
    
    for n in nums:
        
        new_array[n] = n - np.sqrt(n)
    
    return new_array


nums = np.arange(0, 10 ,1)

print(func(nums))

'''
[0.         0.         0.58578644 1.26794919 2.         2.76393202
 3.55051026 4.35424869 5.17157288 6.        ]
'''

 

Given how easy it is to create this sort of typing bug, it may be tempting to avoid numpy's zeros_like function entirely, for instance by using np.empty(nums.shape[0]) or np.zeros(nums.shape, dtype=np.float64) instead. However, zeros_like is so useful when dealing with arrays that are not one-dimensional that I personally will keep using it, and simply make sure to write unit and integration tests for everything that uses Numba. 
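
Common Error Example 3

 

For a final example, consider the small jitted function below, which sums the traces of two matrices.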

 

 

 

@jit(nopython=True)
def func_mat(mat1, mat2):
    return np.trace(mat1) + np.trace(mat2)

np.random.seed(0)
m1 = np.random.normal(size=100).reshape((10,10))
m2 = np.random.normal(size=100).reshape((10,10))

print(func_mat(m1, m2))

'''
2.888964406767875
'''

 

Since we have tested this logic in other parts of our code without Numba, we are pretty sure the results are correct. However, we later try the following:

 

m3 = np.arange(0, 100).reshape((10, 10))
m4 = np.arange(0, 1000, 10).reshape((10,10))

print(func_mat(m3, m4))

 

Which results in:

 

TypingError: No implementation of function Function(<function trace at 0x000001605F3865C0>) found for signature:
 
trace(array(int32, 2d, C))
 
There are 2 candidate implementations:
      - Of which 2 did not match due to:
      Overload in function 'matrix_trace_impl': File: numba\np\linalg.py: Line 2642.
        With argument(s): '(array(int32, 2d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: np.trace() only supported on float and complex arrays.

 

 

We are quite lucky here, as the error is quite descriptive. When you see this type of error it is usually a good idea to play around with the types until you either get a result or give up! In this particular case, simply converting to float works. 

 

m3 = np.arange(0, 100).reshape((10, 10)).astype(np.float64)
m4 = np.arange(0, 1000, 10).reshape((10,10)).astype(np.float64)

print(func_mat(m3, m4))

 

 

 

