Problems Using Scipy Stats Functions in Numba

by John | December 31, 2023

 

This post is aimed at those who would like to use statistical functions within Numba njit functions. We will also use it as an opportunity to demonstrate unit testing Numba functions.

 

First, let's take a quick example of the problem. Suppose we have written a script that uses Numba for a simulation or something similar.

 

from scipy.stats import norm 
import numpy as np
import numba as nb 



@nb.njit
def my_numba_func(arr):
    
    cdfs = np.empty_like(arr)
    
    for i in range(arr.shape[0]):
        
        ### some transformation 
        x = arr[i] 
        cdfs[i] = norm.cdf(x)
    return cdfs
    
    


arr = np.random.normal(size=1000)

res = my_numba_func(arr) 


'''

TypingError: Cannot determine Numba type of <class 'scipy.stats._continuous_distns.norm_gen'>

'''

 

And we find that scipy.stats isn't supported. At first this seemed quite strange because I, at least, have always thought of SciPy and NumPy as essentially the same thing. So what gives? Why can't we use the normal distribution functions? We might even check the SciPy source code and see that it is more or less only using NumPy, which Numba supports so heavily.

 

The issue here is that although SciPy uses NumPy, and Numba supports NumPy, Numba is unable to determine the type of the Python wrapper around the NumPy code. Here is a quick example for those who are interested.

 

def helper(x):
    return x**2 


@nb.njit 
def main(x):
    return x + helper(x)

main(5)

'''
TypingError: Cannot determine Numba type of <class 'function'>
'''



 

In this comparative example, helper plays the role of norm.cdf, and main plays the role of the my_numba_func we made at the top of the page.

 

@nb.njit
def helper(x):
    return x**2 

 

Everything would have worked without a problem if helper had been defined as above.

 

So what are our Options in this Scenario?

 

  • Close untitled99.py and go for a walk.
  • Copy and paste the SciPy source code into the editor and wrap the functions in an njit decorator. Not advisable.
  • Write the functions ourselves with ChatGPT, wrap them in an njit decorator, and then test them against the values generated by SciPy to make sure they are correct.

 

We decide on option 3, as we like using Numba too much to give up at the first hurdle.

 

import math
import numpy as np
import numba as nb

@nb.njit()
def norm_pdf(x):
    return (1.0 / math.sqrt(2 * math.pi)) * np.exp(-0.5 * x**2)

@nb.njit()
def norm_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0))) 



vals = np.linspace(-4, 4, 1000)

pdf_us = np.array([norm_pdf(v) for v in vals])

cdf_us = np.array([norm_cdf(v) for v in vals])

print(f'PDF values are same = {np.allclose(pdf_us, norm.pdf(vals))}')
print(f'CDF values are the same {np.allclose(cdf_us, norm.cdf(vals))}')

'''

PDF values are same = True
CDF values are the same True

'''

 

 

 

 

Making Unit Tests for Numba Functions

 

When we write functions in Numba, it is a really good idea to have some sort of automated testing in place. We might modify the code and break it, and it can be hard to find out where the problem is. This is especially important in Numba, because its error tracebacks aren't as descriptive as those in normal Python or NumPy.

 

We have a directory that looks as follows:

 

Project/
│
├── __init__.py
│
├── main.py
│
├── norm.py
│
└── test.py

 

In test.py we have:

 

from scipy.stats import norm
import unittest
import numpy as np
from norm import norm_cdf, norm_pdf

from numpy.testing import assert_almost_equal


class TestMyNumbaFuncs(unittest.TestCase):
    def test_norm_pdf(self):
        vals = np.linspace(-4, 4, 1000)
        pdf_us = np.array([norm_pdf(v) for v in vals])
        pdf_scipy = norm.pdf(vals)
        assert_almost_equal(pdf_us, pdf_scipy, decimal=5)

    def test_norm_cdf(self):
        vals = np.linspace(-4, 4, 1000)
        cdf_us = np.array([norm_cdf(v) for v in vals])
        cdf_scipy = norm.cdf(vals)

        assert_almost_equal(cdf_us, cdf_scipy, decimal=5)


if __name__ == "__main__":
    unittest.main()

 

Now this is very nice: we can feel free to modify the functions or change the decorator arguments. After each change we should go to the test.py file and run it. If everything is OK, we get an output message that looks like this:

 

..
----------------------------------------------------------------------
Ran 2 tests in 1.625s

OK

 

 

 

 

 

 

