Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
flpgrz committed Dec 15, 2022
1 parent 2e7010e commit edb91a5
Show file tree
Hide file tree
Showing 9 changed files with 434 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.idea/
46 changes: 45 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,46 @@
# fractal-jax
# Fractal Jax
Generate figures of the Julia and Mandelbrot sets with Jax.

## Install
This package requires Jax - see the [official JAX documentation](https://github.com/google/jax#installation).
```
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
cd mandelbrot-jax
pip install .
```

## How to use

```python
from fractal_jax import FractalJax

# specify number of iterations, divergence threshold and backend
m = FractalJax(iterations=50, divergence_threshold=2, backend="gpu")
```
```python
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(
m.generate_mandelbrot(x_range=[-2, 1], y_range=[-1.5, 1.5], pixel_res=300)
);
```
![Figure 1](figs/mandelbrot-1.png)

You can also adjust the region which you care about and the pixel resolution:
```python
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(
m.generate_mandelbrot(x_range=[-1, -0.9], y_range=[-.3, -.2], pixel_res=30000))
);
```
![Figure 2](figs/mandelbrot-2.png)

This library also allows you to generate figures of Julia sets:
```python
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(m.generate_julia(-0.5792518264067199 + 0.5448363340450433j, [-1.5, 1.5], [-1.5, 1.5], 300));
```
![Figure 2](figs/julia-1.png)

## Credits
This implementation is based on the analysis made by [jpivarski](https://gist.github.com/jpivarski) in [mandelbrot-on-all-accelerators.ipynb](https://gist.github.com/jpivarski/da343abd8024834ee8c5aaba691aafc7)
Binary file added figs/julia-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/mandelbrot-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/mandelbrot-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions fractal_jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
__version_info__ = ('0', '1', '0')
__version__ = '.'.join(__version_info__)


from fractal_jax.generator import FractalJax
113 changes: 113 additions & 0 deletions fractal_jax/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import Tuple

import jax
from jax._src.typing import Array
import numpy as np


class FractalJax:
"""A class for generating images of Mandelbrot and Julia sets with JAX.
"""
def __init__(
self,
iterations: int,
divergence_threshold: float,
backend: str
):
"""
Attributes
----------
iterations : int
Number of iteration for computing `z = z^2 + c`
divergence_threshold : int
If z > `divergence_threshold`, we assume divergence to inf
backend : str
Whether to use CPU or GPU for jit
"""
self.iterations = iterations
self.divergence_thershold = divergence_threshold
self.backend = backend

self._jit_mandelbrot = jax.jit(self._run_mandelbrot_kernel, backend=backend)
self._jit_julia = jax.jit(self._run_julia_kernel, backend=backend)

def _run_mandelbrot_kernel(self, c: Array, fractal: Array) -> Array:
"""Run z = z^2 + c.
In the Mandelbrot case, c is the point of interest, i.e. the pixel.
"""
z = c
for i in range(self.iterations):
z = z ** 2 + c
diverged = jax.numpy.absolute(z) > self.divergence_thershold
diverging_now = diverged & (fractal == self.iterations)
fractal = jax.numpy.where(diverging_now, i, fractal)
return fractal

def _run_julia_kernel(self, c: complex, z: Array, fractal: Array) -> Array:
"""Run z = z^2 + c.
In the Julia case, c is a constant.
z_0 is the point of interest, i.e. the pixel.
"""
for i in range(self.iterations):
z = z ** 2 + c
diverged = jax.numpy.absolute(z) > self.divergence_thershold
diverging_now = diverged & (fractal == self.iterations)
fractal = jax.numpy.where(diverging_now, i / self.iterations, fractal)
return fractal

def generate_mandelbrot(self, x_range: Tuple[int], y_range: Tuple[int], pixel_res: int) -> np.ndarray:
"""Generate the image of a Mandelbrot set.
Parameters
----------
x_range : Tuple[int]
Min and max on the x-axis in the complex plane
y_range : Tuple[int]
Min and max on the y-axis in the complex plane
pixel_res : int
Pixel resolution for box x- and y-axis
Returns
-------
numpy.ndarray
Image of the generated Mandelbrot set
"""
height = int((y_range[1] - y_range[0]) * pixel_res)
width = int((x_range[1] - x_range[0]) * pixel_res)
y, x = jax.numpy.ogrid[
y_range[1]:y_range[0]:height * 1j,
x_range[0]:x_range[1]:width * 1j
]
c = x + y * 1j
fractal = jax.numpy.full(c.shape, self.iterations, dtype=jax.numpy.int32)
return np.asarray(self._jit_mandelbrot(c, fractal).block_until_ready())

def generate_julia(self, c: complex, x_range: Tuple[int], y_range: Tuple[int], pixel_res: int) -> np.ndarray:
"""Generate the image of a Julia set.
Parameters
----------
c : complex
The c constant which defines the Julia set
x_range : Tuple[int]
Min and max on the x-axis in the complex plane
y_range : Tuple[int]
Min and max on the y-axis in the complex plane
pixel_res : int
Pixel resolution for box x- and y-axis
Returns
-------
numpy.ndarray
Image of the generated Mandelbrot set
"""
height = int((y_range[1] - y_range[0]) * pixel_res)
width = int((x_range[1] - x_range[0]) * pixel_res)
y, x = jax.numpy.ogrid[
y_range[1]:y_range[0]:height * 1j,
x_range[0]:x_range[1]:width * 1j
]
z = x + y * 1j
fractal = jax.numpy.full(z.shape, self.iterations, dtype=jax.numpy.int32)
return np.asarray(self._jit_julia(c, z, fractal).block_until_ready())

246 changes: 246 additions & 0 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from setuptools import setup, find_packages


def description():
description = (
"Generate fractals with JAX."
)
return description


install_requires = [
"matplotlib"
]


setup(
name='fractal_jax',
version='0.1.0',
description=description(),
author="Filippo Grazioli",
author_email="[email protected]",
install_requires=install_requires,
packages=find_packages(),
)

0 comments on commit edb91a5

Please sign in to comment.