Using Numba in Dash Callbacks without Errors

by John | December 30, 2023




 

First, we will create a minimal example to show how Numba's jit can interfere with Dash callbacks in a strange way. If you are using Numba in Dash callbacks and getting the "Callback failed: the server did not respond" error in your browser, this article may shed some light on why it happens and how to fix it. To follow along, create an empty directory containing the following files:

 

DashNumba/
│
├── __init__.py        
│
├── app_funcs.py       
│
└── app.py            

 

 

In the app_funcs.py file, paste the following function to serve as the Numba example. It is extremely contrived, but it is the simplest thing I could think of that demonstrates the problem.

 

import numba as nb
import numpy as np


@nb.jit(nopython=True)
def transform_data(x, y, z):
    output = np.empty(x.shape[0])

    for i in range(x.shape[0]):
        if x[i] < y[i]:
            val = (y[i] ** 2 + z[i] ** 2) ** 0.5
        else:
            val = (x[i] ** 2 + z[i] ** 2) ** 0.5

        output[i] = val
    return output
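As a quick sanity check on what the loop computes: for each element it takes the larger of x and y and returns the Euclidean norm of that value with z. The vectorized NumPy version below is only a reference sketch for verifying the compiled function's output on inputs without NaNs. The name transform_data_reference is my own and is not part of the app.

```python
import numpy as np


def transform_data_reference(x, y, z):
    """Vectorized reference: sqrt(max(x, y)**2 + z**2) per element."""
    # The loop picks y when x < y and x otherwise, i.e. the larger of the two.
    larger = np.maximum(x, y)
    return np.sqrt(larger**2 + z**2)


x = np.array([3.0, 1.0, -2.0])
y = np.array([1.0, 4.0, -5.0])
z = np.array([4.0, 3.0, 0.0])
print(transform_data_reference(x, y, z))  # [5. 5. 2.]
```

Comparing this against transform_data with np.allclose on a small random sample is a cheap way to confirm the Numba version behaves as expected.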

 

And in the app.py file:

 

import dash
from dash import html, dcc, callback, Input, Output, State
import pandas as pd
import numpy as np
import plotly.express as px
from app_funcs import transform_data

app = dash.Dash(__name__)

# Initial DataFrame
df = pd.DataFrame(
    {
        "c1": np.random.normal(size=1_000_000),
        "c2": np.random.normal(size=1_000_000),
        "c3": np.random.normal(size=1_000_000),
    }
)

app.layout = html.Div(
    [
        dcc.Store(id="data", data=df.to_json(orient="split")),
        html.H1("My app with numba"),
        html.Div(
            [
                html.Label("Enter number of rows to remove:"),
                dcc.Input(id="row-input", type="number", min=0, max=len(df), step=1),
                html.Button("Update Data", id="update-button"),
            ]
        ),
        dcc.Graph(id="myplot"),
    ]
)


@app.callback(
    Output("myplot", "figure"),
    Input("update-button", "n_clicks"),
    State("row-input", "value"),
    State("data", "data"),
)
def update_graph(n_clicks, n_rows, json_data):
    if n_clicks is None:
        raise dash.exceptions.PreventUpdate

    # Load DataFrame from JSON
    dff = pd.read_json(json_data, orient="split")

    # Remove the last n rows (skipped for 0 or empty input, since iloc[:-0] would drop every row)
    if n_rows:
        dff = dff.iloc[:-n_rows]

    # Transform data
    x = dff.c1.to_numpy()
    y = dff.c2.to_numpy()
    z = dff.c3.to_numpy()
    result = transform_data(x, y, z)

    # Create the figure
    x_axis = list(range(len(result)))
    fig = px.line(
        x=x_axis, y=result, labels={"x": "x", "y": "res"}, title="Numba Example"
    )

    return fig


if __name__ == "__main__":
    app.run(debug=True)  # run_server is the older spelling of this call

 

 

When we run this app, the first click produces the "server did not respond" error; subsequent clicks behave as intended and plot our contrived data transformation. So what is going on here? Mostly by trial and error, I have narrowed it down to the fact that the first time transform_data is called, it has to be compiled by Numba. For some reason this interrupts the Dash callback and the server doesn't respond. The issue doesn't seem to be specific to Numba, according to the following forum post:

 

"One thing that I’ve seen in local development is where the callbacks write some data to the file system, and that triggers the server reloader. Since the server would reload mid-callback, the requests would fail."

 

It seems plausible that Numba compiling our function, which presumably involves writing some kind of binary to the file system, makes the Dash callback fail the first time, while the next call works because the function is already compiled.
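One way to test the reloader hypothesis is to log every time the app module is imported. The sketch below relies on the WERKZEUG_RUN_MAIN environment variable, which Werkzeug's reloader sets to "true" in the child process that serves requests. The helper name describe_process is my own, and this is a debugging aid rather than a fix.

```python
import os
import time


def describe_process(environ=None):
    """Return a one-line description of which process imported the app."""
    environ = os.environ if environ is None else environ
    # Werkzeug's reloader sets WERKZEUG_RUN_MAIN="true" in the child process
    # that actually serves requests; the watching parent does not have it.
    is_child = environ.get("WERKZEUG_RUN_MAIN") == "true"
    role = "reloader child" if is_child else "parent or no reloader"
    stamp = time.strftime("%H:%M:%S")
    return f"[pid {os.getpid()}] app imported by {role} at {stamp}"


# Place this near the top of app.py. A fresh line in the console right after
# clicking the button would suggest the server restarted mid-callback.
print(describe_process())
```

If a restart does show up, running the app with app.run(debug=True, use_reloader=False) should take the reloader out of the picture while you investigate, at the cost of losing automatic restarts on code changes.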

 

How to fix it? 

 

There are a few options. Let's start with the simplest one. Modify the app_funcs.py file as follows:

 

import numba as nb
import numpy as np


# @nb.jit(nopython=True)
# def transform_data(x, y, z):
#     output = np.empty(x.shape[0])

#     for i in range(x.shape[0]):
#         if x[i] < y[i]:
#             val = (y[i] ** 2 + z[i] ** 2) ** 0.5
#         else:
#             val = (x[i] ** 2 + z[i] ** 2) ** 0.5

#         output[i] = val
#     return output


@nb.njit((nb.float64[:], nb.float64[:], nb.float64[:]))
def transform_data(x, y, z):
    output = np.empty(x.shape[0])

    for i in range(x.shape[0]):
        if x[i] < y[i]:
            val = (y[i] ** 2 + z[i] ** 2) ** 0.5
        else:
            val = (x[i] ** 2 + z[i] ** 2) ** 0.5

        output[i] = val
    return output

 

What we are doing here is instructing Numba to use eager compilation by passing an explicit type signature. Essentially, this means the function is compiled when app.py imports the module, rather than on the first call.
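One consequence of an explicit signature is worth keeping in mind: as far as I understand, Numba then only accepts arguments matching the declared types, so a column that pandas happened to parse as int64 would be expected to raise a TypeError instead of triggering a new compilation. A minimal sketch of a guard, assuming a hypothetical helper named as_float64 that is not part of the original app:

```python
import numpy as np


def as_float64(values):
    """Return `values` as a 1-D, C-contiguous float64 NumPy array."""
    arr = np.ascontiguousarray(values, dtype=np.float64)
    if arr.ndim != 1:
        raise ValueError(f"expected a 1-D array, got shape {arr.shape}")
    return arr


# Integer input is converted so it matches a float64[:] signature.
converted = as_float64([1, 2, 3])
print(converted.dtype, converted.flags["C_CONTIGUOUS"])
```

In the callback this would be used as x = as_float64(dff.c1) and likewise for the other two columns.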

 

With eager compilation, the callback no longer fails with that strange error the first time it is called. The downside is that our app now loads slowly, since the function is compiled at startup. In this example it isn't a big deal since there is only one function, but it will get slow quickly if we compile many functions this way. We can fix this by adding the keyword cache=True to the njit decorator:

 

@nb.njit((nb.float64[:], nb.float64[:], nb.float64[:]), cache=True)
def transform_data(x, y, z):
    output = np.empty(x.shape[0])

    for i in range(x.shape[0]):
        if x[i] < y[i]:
            val = (y[i] ** 2 + z[i] ** 2) ** 0.5
        else:
            val = (x[i] ** 2 + z[i] ** 2) ** 0.5

        output[i] = val
    return output

 

Now Numba will have stored two new binary files in your __pycache__ folder, ending in .nbi and .nbc. This means the function is cached and won't have to be recompiled each time the app starts.
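To confirm that the cache was actually written, you can list those files. The helper below is a small sketch using only the standard library. The name find_numba_cache_files and the placeholder file names in the demonstration are made up, and real cache file names will look different.

```python
import tempfile
from pathlib import Path


def find_numba_cache_files(source_dir):
    """List Numba cache files (*.nbi index, *.nbc data) under __pycache__."""
    cache_dir = Path(source_dir) / "__pycache__"
    if not cache_dir.is_dir():
        return []
    return sorted(
        p.name for p in cache_dir.iterdir() if p.suffix in {".nbi", ".nbc"}
    )


# Demonstration with empty placeholder files (hypothetical names).
with tempfile.TemporaryDirectory() as project:
    pycache = Path(project) / "__pycache__"
    pycache.mkdir()
    for name in ("app_funcs.demo.nbi", "app_funcs.demo.nbc", "app_funcs.demo.pyc"):
        (pycache / name).touch()
    found = find_numba_cache_files(project)

print(found)  # ['app_funcs.demo.nbc', 'app_funcs.demo.nbi']
```

In the real project you would call find_numba_cache_files(".") from the DashNumba directory after starting the app once. Deleting the listed files forces Numba to recompile on the next start.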

 

 

This is a nice, easy solution that should cover the majority of problems that come from using Numba in Dash. But what about cases where we can't afford to wait for the compilation even on the very first start? For that we have one final option (aside from using Cython, that is)!

 

 

Ahead-of-Time compilation.

 

If all else fails, we can always use ahead-of-time (AOT) compilation. Note that this approach is more error-prone and makes errors harder to trace.

 

from numba.pycc import CC
import numba as nb
import numpy as np

cc = CC("app_funcs")


@cc.export("transform_data", "f8[:], f8[:], f8[:]")
def transform_data(x, y, z):
    output = np.empty(x.shape[0])

    for i in range(x.shape[0]):
        if x[i] < y[i]:
            val = (y[i] ** 2 + z[i] ** 2) ** 0.5
        else:
            val = (x[i] ** 2 + z[i] ** 2) ** 0.5

        output[i] = val
    return output


if __name__ == "__main__":
    cc.compile()

 

After running the script above, you should see a file similar to the following in your project directory:

 

app_funcs.cpython-311-x86_64-linux-gnu.so

 

We can import this shared library as normal, without making any changes to our code. As mentioned, this should be a last resort, since the cache option built into the Numba decorator will usually be enough, but if that still causes issues, AOT compilation is always available.
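Because the compiled library and the original app_funcs.py share a module name, it can be worth checking which one Python actually picks up. As far as I know, the import system prefers an extension module over a .py file of the same name in the same directory, but the sketch below lets you verify it rather than assume. The helper names are my own.

```python
import importlib.util
from importlib.machinery import EXTENSION_SUFFIXES


def is_compiled_extension(path):
    """True if `path` ends with one of this interpreter's extension suffixes."""
    return path is not None and path.endswith(tuple(EXTENSION_SUFFIXES))


def module_origin(name):
    """Return the file a module would be imported from, or None if not found."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


# Prints the resolved file and whether it is a compiled extension; the origin
# is None when this is not run from the directory containing app_funcs.
origin = module_origin("app_funcs")
print(origin, is_compiled_extension(origin))
```

Run from the DashNumba directory after compiling, the printed path should end in the shared-library suffix if the AOT module is the one being imported.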