FIX: Update Numba Lecture to Address Deprecation of @jit (#296)
* update a section on type inference.

* update lecture to avoid literal box warning

* check the type of the function

* Update lectures/numba.md

Co-authored-by: mmcky <[email protected]>

* reduce redundancy

* further simplifies descriptions

* fix typos

---------

Co-authored-by: mmcky <[email protected]>
HumphreyYang and mmcky authored Dec 14, 2023
1 parent 995c490 commit efe322b
Showing 1 changed file with 118 additions and 60 deletions.
178 changes: 118 additions & 60 deletions lectures/numba.md
@@ -3,8 +3,10 @@
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.14.4
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---
@@ -26,10 +28,9 @@ kernelspec:

In addition to what's in Anaconda, this lecture will need the following libraries:

```{code-cell} ipython3
:tags: [hide-output]

!pip install quantecon
```

@@ -38,7 +39,7 @@
Please also make sure that you have the latest version of Anaconda, since older versions are a {doc}`common source of errors <troubleshooting>`.

Let's start with some imports:

```{code-cell} ipython3
%matplotlib inline
import numpy as np
import quantecon as qe
@@ -98,13 +99,13 @@ $$

In what follows we set

```{code-cell} ipython3
α = 4.0
```

Here's the plot of a typical trajectory, starting from $x_0 = 0.1$, with $t$ on the x-axis

```{code-cell} ipython3
def qm(x0, n):
x = np.empty(n+1)
x[0] = x0
@@ -122,10 +123,10 @@ plt.show()

To speed the function `qm` up using Numba, our first step is

```{code-cell} ipython3
from numba import njit
qm_numba = njit(qm)
```

The function `qm_numba` is a version of `qm` that is "targeted" for JIT-compilation.
@@ -135,7 +136,7 @@
We will explain what this means momentarily.

Let's time and compare identical function calls across these two versions, starting with the original function `qm`:

```{code-cell} ipython3
n = 10_000_000
qe.tic()
@@ -145,7 +146,7 @@ time1 = qe.toc()

Now let's try `qm_numba`:

```{code-cell} ipython3
qe.tic()
qm_numba(0.1, int(n))
time2 = qe.toc()
@@ -156,13 +157,14 @@
This is already a massive speed gain.
In fact, the next time and all subsequent times it runs even faster as the function has been compiled and is in memory:

(qm_numba_result)=
```{code-cell} python3

```{code-cell} ipython3
qe.tic()
qm_numba(0.1, int(n))
time3 = qe.toc()
```

```{code-cell} ipython3
time1 / time3 # Calculate speed gain
```

@@ -194,12 +196,12 @@
Note that, if you make the call `qm(0.5, 10)` and then follow it with `qm(0.9, 20)`, compilation only takes place on the first call.

The compiled code is then cached and recycled as required.
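
As a quick illustration of this caching, we can time one more call that uses new arguments of the same types; since the compiled version is already in memory, no recompilation is needed:

```{code-cell} ipython3
# New argument values, same types as before, so the cached machine code is reused
qe.tic()
qm_numba(0.9, int(n))
qe.toc()
```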

## Decorator Notation

In the code above we created a JIT compiled version of `qm` via the call

```{code-cell} ipython3
qm_numba = njit(qm)
```

In practice this would typically be done using an alternative *decorator* syntax.
@@ -208,14 +210,12 @@

Let's see how this is done.


To target a function for JIT compilation we can put `@njit` before the function definition.

Here's what this looks like for `qm`

```{code-cell} ipython3
@njit
def qm(x0, n):
x = np.empty(n+1)
x[0] = x0
@@ -224,15 +224,21 @@ def qm(x0, n):
return x
```

This is equivalent to `qm = njit(qm)`.

The following now uses the jitted version:

```{code-cell} ipython3
%%time
qm(0.1, 100_000)
```

Numba's decorators accept several optional arguments that can further speed up computation and cache compiled functions; see the [performance tips](https://numba.readthedocs.io/en/stable/user/performance-tips.html) in the documentation.

In the [following lecture on parallelization](parallel), we will discuss how to use the `parallel` argument to achieve automatic parallelization.
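
As a rough sketch of how such options are passed (the name `qm_fast` and the particular flags are illustrative rather than recommendations), `cache=True` asks Numba to save the compiled machine code to disk and `fastmath=True` allows faster but less strict floating-point arithmetic:

```{code-cell} ipython3
@njit(cache=True, fastmath=True)
def qm_fast(x0, n):
    # same quadratic map as above, compiled with extra decorator options
    x = np.empty(n+1)
    x[0] = x0
    for t in range(n):
        x[t+1] = 4 * x[t] * (1 - x[t])
    return x

qm_fast(0.1, 10)
```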

## Type Inference

Clearly type inference is a key part of JIT compilation.

@@ -246,29 +252,83 @@
This allows it to generate native machine code, without having to call the Python runtime environment.

In such a setting, Numba will be on par with machine code from low-level languages.

When Numba cannot infer all type information, it will raise an error.

For example, in the case below, Numba is unable to determine the type of the function `mean` when compiling the function `bootstrap`:

```{code-cell} ipython3
@njit
def bootstrap(data, statistics, n):
    bootstrap_stat = np.empty(n)
    n = len(data)
    for i in range(n_resamples):
        resample = np.random.choice(data, size=n, replace=True)
        bootstrap_stat[i] = statistics(resample)
    return bootstrap_stat

def mean(data):
    return np.mean(data)

data = np.array([2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2])
n_resamples = 10

print('Type of function:', type(mean))

# Error
try:
    bootstrap(data, mean, n_resamples)
except Exception as e:
    print(e)
```

But Numba recognizes JIT-compiled functions

```{code-cell} ipython3
@njit
def mean(data):
    return np.mean(data)

print('Type of function:', type(mean))

%time bootstrap(data, mean, n_resamples)
```

We can check the signature of the JIT-compiled function

```{code-cell} ipython3
bootstrap.signatures
```

The signature shows that the function `bootstrap` was compiled for one `float64` floating-point array, the function `mean`, and one `int64` integer.

Now let's see what happens when we change the inputs.

Running it again with a larger integer for `n` and a different set of data does not change the signature of the function.

```{code-cell} ipython3
data = np.array([4.1, 1.1, 2.3, 1.9, 0.1, 2.8, 1.2])
%time bootstrap(data, mean, 100)
bootstrap.signatures
```

As expected, the second run is much faster.

Let's change the data again, this time passing an integer array:

```{code-cell} ipython3
data = np.array([1, 2, 3, 4, 5], dtype=np.int64)
%time bootstrap(data, mean, 100)
bootstrap.signatures
```

Note that a second signature is added.

It also takes longer to run, suggesting that Numba recompiles this function as the type changes.

Overall, type inference helps Numba to achieve its performance, but it also limits what Numba supports and sometimes requires careful type checks.

You can refer to the list of supported Python and NumPy features [here](https://numba.pydata.org/numba-doc/dev/reference/pysupported.html).
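
If you prefer not to rely on inference, you can also supply an explicit signature, which makes Numba compile eagerly for exactly those types (a small sketch, with `qm_typed` as an illustrative name):

```{code-cell} ipython3
# The signature says: takes a float64 and an int64, returns a float64 array
@njit("float64[:](float64, int64)")
def qm_typed(x0, n):
    x = np.empty(n+1)
    x[0] = x0
    for t in range(n):
        x[t+1] = 4 * x[t] * (1 - x[t])
    return x

qm_typed.signatures
```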

## Compiling Classes

As mentioned above, at present Numba can only compile a subset of Python.
@@ -285,7 +345,7 @@ created in {doc}`this lecture <python_oop>`.

To compile this class we use the `@jitclass` decorator:

```{code-cell} ipython3
from numba import float64
from numba.experimental import jitclass
```
@@ -294,11 +354,11 @@
Notice that we also imported something called `float64`.

This is a data type representing standard floating point numbers.

We are importing it here because Numba needs a bit of extra help with types when it tries to deal with classes.
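
Before turning to the Solow model, here is a minimal sketch of the general pattern, using a purely illustrative `Point` class with two `float64` fields:

```{code-cell} ipython3
point_data = [
    ('x', float64),   # each attribute is declared with a Numba type
    ('y', float64)
]

@jitclass(point_data)
class Point:

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def norm(self):
        return (self.x**2 + self.y**2)**0.5

Point(3.0, 4.0).norm()
```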

Here's our code:

```{code-cell} ipython3
solow_data = [
('n', float64),
('s', float64),
@@ -361,7 +421,7 @@
After that, targeting the class for JIT compilation only requires adding `@jitclass(solow_data)` before the class definition.

When we call the methods in the class, the methods are compiled just like functions.

```{code-cell} ipython3
s1 = Solow()
s2 = Solow(k=8.0)
@@ -444,25 +504,25 @@
For larger ones, or for routines using external libraries, it can easily fail.

Hence, it's prudent when using Numba to focus on speeding up small, time-critical snippets of code.

This will give you much better performance than blanketing your Python programs with `@njit` statements.
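
One common pattern, sketched below with illustrative function names, is to keep the surrounding program in ordinary Python and reserve `@njit` for the numerical kernel it calls:

```{code-cell} ipython3
@njit
def inner_kernel(x):
    # the tight numerical loop is the only part worth compiling
    total = 0.0
    for i in range(x.shape[0]):
        total += x[i]**2
    return total

def run_simulation(draws=1_000):
    # plain Python here: easy to read, free to use any library, calls the kernel
    data = np.random.randn(draws)
    return inner_kernel(data)

run_simulation()
```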

### A Gotcha: Global Variables

Here's another thing to be careful about when using Numba.

Consider the following example

```{code-cell} ipython3
a = 1
@njit
def add_a(x):
return a + x
print(add_a(10))
```

```{code-cell} ipython3
a = 2
print(add_a(10))
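# Numba treats global variables as compile-time constants, so this call still
# uses the value a = 1 that was captured when `add_a` was first compiled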
@@ -492,7 +552,7 @@
Compare speed with and without Numba when the sample size is large.

Here is one solution:

```{code-cell} ipython3
from random import uniform
@njit
@@ -581,13 +641,13 @@
We let
- 0 represent "low"
- 1 represent "high"

```{code-cell} ipython3
p, q = 0.1, 0.2 # Prob of leaving low and high state respectively
```

Here's a pure Python version of the function

```{code-cell} ipython3
def compute_series(n):
x = np.empty(n, dtype=np.int_)
x[0] = 1 # Start in state 1
@@ -604,7 +664,7 @@ def compute_series(n):
Let's run this code and check that the fraction of time spent in the low
state is about 0.666

```{code-cell} ipython3
n = 1_000_000
x = compute_series(n)
print(np.mean(x == 0)) # Fraction of time x is in state 0
@@ -614,30 +674,28 @@
This is (approximately) the right output.
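
As a cross-check, here is a short sketch using the `quantecon` package imported earlier: the stationary distribution of this two-state chain puts probability $q/(p+q) = 2/3$ on the low state.

```{code-cell} ipython3
# Transition matrix for the two states (0 = low, 1 = high)
P = np.array([[1 - p, p],
              [q, 1 - q]])
mc = qe.MarkovChain(P)
mc.stationary_distributions
```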

Now let's time it:

```{code-cell} ipython3
qe.tic()
compute_series(n)
qe.toc()
```

Next let's implement a Numba version, which is easy

```{code-cell} ipython3
compute_series_numba = njit(compute_series)
```

Let's check we still get the right numbers

```{code-cell} ipython3
x = compute_series_numba(n)
print(np.mean(x == 0))
```

Let's see the time

```{code-cell} ipython3
qe.tic()
compute_series_numba(n)
qe.toc()