From efe322bac7cb4e5c151b1bb28641d6d66a5d065d Mon Sep 17 00:00:00 2001 From: Humphrey Yang <39026988+HumphreyYang@users.noreply.github.com> Date: Thu, 14 Dec 2023 14:42:51 +1100 Subject: [PATCH] FIX: Update Numba Lecture to Address Deprecation of `@jit` (#296) * update a section on type inference. * update lecture to avoid literal box warning * check the type of the function * Update lectures/numba.md Co-authored-by: mmcky * reduce redundancy * further simplifies descriptions * fix typos --------- Co-authored-by: mmcky --- lectures/numba.md | 178 ++++++++++++++++++++++++++++++---------------- 1 file changed, 118 insertions(+), 60 deletions(-) diff --git a/lectures/numba.md b/lectures/numba.md index d9793ef9..6cd40434 100644 --- a/lectures/numba.md +++ b/lectures/numba.md @@ -3,8 +3,10 @@ jupytext: text_representation: extension: .md format_name: myst + format_version: 0.13 + jupytext_version: 1.14.4 kernelspec: - display_name: Python 3 + display_name: Python 3 (ipykernel) language: python name: python3 --- @@ -26,10 +28,9 @@ kernelspec: In addition to what's in Anaconda, this lecture will need the following libraries: -```{code-cell} ipython ---- -tags: [hide-output] ---- +```{code-cell} ipython3 +:tags: [hide-output] + !pip install quantecon ``` @@ -38,7 +39,7 @@ versions are a {doc}`common source of errors `. Let's start with some imports: -```{code-cell} ipython +```{code-cell} ipython3 %matplotlib inline import numpy as np import quantecon as qe @@ -98,13 +99,13 @@ $$ In what follows we set -```{code-cell} python3 +```{code-cell} ipython3 α = 4.0 ``` Here's the plot of a typical trajectory, starting from $x_0 = 0.1$, with $t$ on the x-axis -```{code-cell} python3 +```{code-cell} ipython3 def qm(x0, n): x = np.empty(n+1) x[0] = x0 @@ -122,10 +123,10 @@ plt.show() To speed the function `qm` up using Numba, our first step is -```{code-cell} python3 -from numba import jit +```{code-cell} ipython3 +from numba import njit -qm_numba = jit(qm) +qm_numba = njit(qm) ``` The function `qm_numba` is a version of `qm` that is "targeted" for @@ -135,7 +136,7 @@ We will explain what this means momentarily. Let's time and compare identical function calls across these two versions, starting with the original function `qm`: -```{code-cell} python3 +```{code-cell} ipython3 n = 10_000_000 qe.tic() @@ -145,7 +146,7 @@ time1 = qe.toc() Now let's try qm_numba -```{code-cell} python3 +```{code-cell} ipython3 qe.tic() qm_numba(0.1, int(n)) time2 = qe.toc() @@ -156,13 +157,14 @@ This is already a massive speed gain. In fact, the next time and all subsequent times it runs even faster as the function has been compiled and is in memory: (qm_numba_result)= -```{code-cell} python3 + +```{code-cell} ipython3 qe.tic() qm_numba(0.1, int(n)) time3 = qe.toc() ``` -```{code-cell} python3 +```{code-cell} ipython3 time1 / time3 # Calculate speed gain ``` @@ -194,12 +196,12 @@ Note that, if you make the call `qm(0.5, 10)` and then follow it with `qm(0.9, 2 The compiled code is then cached and recycled as required. -## Decorators and "nopython" Mode +## Decorator Notation In the code above we created a JIT compiled version of `qm` via the call -```{code-cell} python3 -qm_numba = jit(qm) +```{code-cell} ipython3 +qm_numba = njit(qm) ``` In practice this would typically be done using an alternative *decorator* syntax. @@ -208,14 +210,12 @@ In practice this would typically be done using an alternative *decorator* syntax Let's see how this is done. -### Decorator Notation - -To target a function for JIT compilation we can put `@jit` before the function definition. +To target a function for JIT compilation we can put `@njit` before the function definition. Here's what this looks like for `qm` -```{code-cell} python3 -@jit +```{code-cell} ipython3 +@njit def qm(x0, n): x = np.empty(n+1) x[0] = x0 @@ -224,15 +224,21 @@ def qm(x0, n): return x ``` -This is equivalent to `qm = jit(qm)`. +This is equivalent to `qm = njit(qm)`. The following now uses the jitted version: -```{code-cell} python3 -qm(0.1, 10) +```{code-cell} ipython3 +%%time + +qm(0.1, 100_000) ``` -### Type Inference and "nopython" Mode +Numba provides several arguments for decorators to accelerate computation and cache functions [here](https://numba.readthedocs.io/en/stable/user/performance-tips.html). + +In the [following lecture on parallelization](parallel), we will discuss how to use the `parallel` argument to achieve automatic parallelization. + +## Type Inference Clearly type inference is a key part of JIT compilation. @@ -246,29 +252,83 @@ This allows it to generate native machine code, without having to call the Pytho In such a setting, Numba will be on par with machine code from low-level languages. -When Numba cannot infer all type information, some Python objects are given generic object status and execution falls back to the Python runtime. +When Numba cannot infer all type information, it will raise an error. -When this happens, Numba provides only minor speed gains or none at all. +For example, in the case below, Numba is unable to determine the type of function `mean` when compiling the function `bootstrap` -We generally prefer to force an error when this occurs, so we know effective -compilation is failing. +```{code-cell} ipython3 +@njit +def bootstrap(data, statistics, n): + bootstrap_stat = np.empty(n) + n = len(data) + for i in range(n_resamples): + resample = np.random.choice(data, size=n, replace=True) + bootstrap_stat[i] = statistics(resample) + return bootstrap_stat -This is done by using either `@jit(nopython=True)` or, equivalently, `@njit` instead of `@jit`. +def mean(data): + return np.mean(data) -For example, +data = np.array([2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2]) +n_resamples = 10 -```{code-cell} python3 -from numba import njit +print('Type of function:', type(mean)) + +#Error +try: + bootstrap(data, mean, n_resamples) +except Exception as e: + print(e) +``` +But Numba recognizes JIT-compiled functions + +```{code-cell} ipython3 @njit -def qm(x0, n): - x = np.empty(n+1) - x[0] = x0 - for t in range(n): - x[t+1] = 4 * x[t] * (1 - x[t]) - return x +def mean(data): + return np.mean(data) + +print('Type of function:', type(mean)) + +%time bootstrap(data, mean, n_resamples) +``` + +We can check the signature of the JIT-compiled function + +```{code-cell} ipython3 +bootstrap.signatures +``` + +The function `bootstrap` takes one `float64` floating point array, one function called `mean` and an `int64` integer. + +Now let's see what happens when we change the inputs. + +Running it again with a larger integer for `n` and a different set of data does not change the signature of the function. + +```{code-cell} ipython3 +data = np.array([4.1, 1.1, 2.3, 1.9, 0.1, 2.8, 1.2]) +%time bootstrap(data, mean, 100) +bootstrap.signatures ``` +As expected, the second run is much faster. + +Let's try to change the data again and use an integer array as data + +```{code-cell} ipython3 +data = np.array([1, 2, 3, 4, 5], dtype=np.int64) +%time bootstrap(data, mean, 100) +bootstrap.signatures +``` + +Note that a second signature is added. + +It also takes longer to run, suggesting that Numba recompiles this function as the type changes. + +Overall, type inference helps Numba to achieve its performance, but it also limits what Numba supports and sometimes requires careful type checks. + +You can refer to the list of supported Python and Numpy features [here](https://numba.pydata.org/numba-doc/dev/reference/pysupported.html). + ## Compiling Classes As mentioned above, at present Numba can only compile a subset of Python. @@ -285,7 +345,7 @@ created in {doc}`this lecture `. To compile this class we use the `@jitclass` decorator: -```{code-cell} python3 +```{code-cell} ipython3 from numba import float64 from numba.experimental import jitclass ``` @@ -294,11 +354,11 @@ Notice that we also imported something called `float64`. This is a data type representing standard floating point numbers. -We are importing it here because Numba needs a bit of extra help with types when it trys to deal with classes. +We are importing it here because Numba needs a bit of extra help with types when it tries to deal with classes. Here's our code: -```{code-cell} python3 +```{code-cell} ipython3 solow_data = [ ('n', float64), ('s', float64), @@ -361,7 +421,7 @@ After that, targeting the class for JIT compilation only requires adding When we call the methods in the class, the methods are compiled just like functions. -```{code-cell} python3 +```{code-cell} ipython3 s1 = Solow() s2 = Solow(k=8.0) @@ -444,7 +504,7 @@ For larger ones, or for routines using external libraries, it can easily fail. Hence, it's prudent when using Numba to focus on speeding up small, time-critical snippets of code. -This will give you much better performance than blanketing your Python programs with `@jit` statements. +This will give you much better performance than blanketing your Python programs with `@njit` statements. ### A Gotcha: Global Variables @@ -452,17 +512,17 @@ Here's another thing to be careful about when using Numba. Consider the following example -```{code-cell} python3 +```{code-cell} ipython3 a = 1 -@jit +@njit def add_a(x): return a + x print(add_a(10)) ``` -```{code-cell} python3 +```{code-cell} ipython3 a = 2 print(add_a(10)) @@ -492,7 +552,7 @@ Compare speed with and without Numba when the sample size is large. Here is one solution: -```{code-cell} python3 +```{code-cell} ipython3 from random import uniform @njit @@ -581,13 +641,13 @@ We let - 0 represent "low" - 1 represent "high" -```{code-cell} python3 +```{code-cell} ipython3 p, q = 0.1, 0.2 # Prob of leaving low and high state respectively ``` Here's a pure Python version of the function -```{code-cell} python3 +```{code-cell} ipython3 def compute_series(n): x = np.empty(n, dtype=np.int_) x[0] = 1 # Start in state 1 @@ -604,7 +664,7 @@ def compute_series(n): Let's run this code and check that the fraction of time spent in the low state is about 0.666 -```{code-cell} python3 +```{code-cell} ipython3 n = 1_000_000 x = compute_series(n) print(np.mean(x == 0)) # Fraction of time x is in state 0 @@ -614,7 +674,7 @@ This is (approximately) the right output. Now let's time it: -```{code-cell} python3 +```{code-cell} ipython3 qe.tic() compute_series(n) qe.toc() @@ -622,22 +682,20 @@ qe.toc() Next let's implement a Numba version, which is easy -```{code-cell} python3 -from numba import jit - -compute_series_numba = jit(compute_series) +```{code-cell} ipython3 +compute_series_numba = njit(compute_series) ``` Let's check we still get the right numbers -```{code-cell} python3 +```{code-cell} ipython3 x = compute_series_numba(n) print(np.mean(x == 0)) ``` Let's see the time -```{code-cell} python3 +```{code-cell} ipython3 qe.tic() compute_series_numba(n) qe.toc()