-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
434 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
.idea/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,46 @@ | ||
# fractal-jax | ||
# Fractal Jax | ||
Generate figures of the Julia and Mandelbrot sets with Jax. | ||
|
||
## Install | ||
This package requires Jax - see the [official JAX documentation](https://github.com/google/jax#installation). | ||
``` | ||
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html | ||
cd mandelbrot-jax | ||
pip install . | ||
``` | ||
|
||
## How to use | ||
|
||
```python | ||
from fractal_jax import FractalJax | ||
|
||
# specify number of iterations, divergence threshold and backend | ||
m = FractalJax(iterations=50, divergence_threshold=2, backend="gpu") | ||
``` | ||
```python | ||
import matplotlib.pyplot as plt | ||
fig, ax = plt.subplots(1, 1, figsize=(5, 5)) | ||
ax.imshow( | ||
m.generate_mandelbrot(x_range=[-2, 1], y_range=[-1.5, 1.5], pixel_res=300) | ||
); | ||
``` | ||
![Figure 1](figs/mandelbrot-1.png) | ||
|
||
You can also adjust the region which you care about and the pixel resolution: | ||
```python | ||
fig, ax = plt.subplots(1, 1, figsize=(5, 5)) | ||
ax.imshow( | ||
m.generate_mandelbrot(x_range=[-1, -0.9], y_range=[-.3, -.2], pixel_res=30000)) | ||
); | ||
``` | ||
![Figure 2](figs/mandelbrot-2.png) | ||
|
||
This library also allows you to generate figures of Julia sets: | ||
```python | ||
fig, ax = plt.subplots(1, 1, figsize=(5, 5)) | ||
ax.imshow(m.generate_julia(-0.5792518264067199 + 0.5448363340450433j, [-1.5, 1.5], [-1.5, 1.5], 300)); | ||
``` | ||
![Figure 2](figs/julia-1.png) | ||
|
||
## Credits | ||
This implementation is based on the analysis made by [jpivarski](https://gist.github.com/jpivarski) in [mandelbrot-on-all-accelerators.ipynb](https://gist.github.com/jpivarski/da343abd8024834ee8c5aaba691aafc7) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
__version_info__ = ('0', '1', '0') | ||
__version__ = '.'.join(__version_info__) | ||
|
||
|
||
from fractal_jax.generator import FractalJax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
from typing import Tuple | ||
|
||
import jax | ||
from jax._src.typing import Array | ||
import numpy as np | ||
|
||
|
||
class FractalJax: | ||
"""A class for generating images of Mandelbrot and Julia sets with JAX. | ||
""" | ||
def __init__( | ||
self, | ||
iterations: int, | ||
divergence_threshold: float, | ||
backend: str | ||
): | ||
""" | ||
Attributes | ||
---------- | ||
iterations : int | ||
Number of iteration for computing `z = z^2 + c` | ||
divergence_threshold : int | ||
If z > `divergence_threshold`, we assume divergence to inf | ||
backend : str | ||
Whether to use CPU or GPU for jit | ||
""" | ||
self.iterations = iterations | ||
self.divergence_thershold = divergence_threshold | ||
self.backend = backend | ||
|
||
self._jit_mandelbrot = jax.jit(self._run_mandelbrot_kernel, backend=backend) | ||
self._jit_julia = jax.jit(self._run_julia_kernel, backend=backend) | ||
|
||
def _run_mandelbrot_kernel(self, c: Array, fractal: Array) -> Array: | ||
"""Run z = z^2 + c. | ||
In the Mandelbrot case, c is the point of interest, i.e. the pixel. | ||
""" | ||
z = c | ||
for i in range(self.iterations): | ||
z = z ** 2 + c | ||
diverged = jax.numpy.absolute(z) > self.divergence_thershold | ||
diverging_now = diverged & (fractal == self.iterations) | ||
fractal = jax.numpy.where(diverging_now, i, fractal) | ||
return fractal | ||
|
||
def _run_julia_kernel(self, c: complex, z: Array, fractal: Array) -> Array: | ||
"""Run z = z^2 + c. | ||
In the Julia case, c is a constant. | ||
z_0 is the point of interest, i.e. the pixel. | ||
""" | ||
for i in range(self.iterations): | ||
z = z ** 2 + c | ||
diverged = jax.numpy.absolute(z) > self.divergence_thershold | ||
diverging_now = diverged & (fractal == self.iterations) | ||
fractal = jax.numpy.where(diverging_now, i / self.iterations, fractal) | ||
return fractal | ||
|
||
def generate_mandelbrot(self, x_range: Tuple[int], y_range: Tuple[int], pixel_res: int) -> np.ndarray: | ||
"""Generate the image of a Mandelbrot set. | ||
Parameters | ||
---------- | ||
x_range : Tuple[int] | ||
Min and max on the x-axis in the complex plane | ||
y_range : Tuple[int] | ||
Min and max on the y-axis in the complex plane | ||
pixel_res : int | ||
Pixel resolution for box x- and y-axis | ||
Returns | ||
------- | ||
numpy.ndarray | ||
Image of the generated Mandelbrot set | ||
""" | ||
height = int((y_range[1] - y_range[0]) * pixel_res) | ||
width = int((x_range[1] - x_range[0]) * pixel_res) | ||
y, x = jax.numpy.ogrid[ | ||
y_range[1]:y_range[0]:height * 1j, | ||
x_range[0]:x_range[1]:width * 1j | ||
] | ||
c = x + y * 1j | ||
fractal = jax.numpy.full(c.shape, self.iterations, dtype=jax.numpy.int32) | ||
return np.asarray(self._jit_mandelbrot(c, fractal).block_until_ready()) | ||
|
||
def generate_julia(self, c: complex, x_range: Tuple[int], y_range: Tuple[int], pixel_res: int) -> np.ndarray: | ||
"""Generate the image of a Julia set. | ||
Parameters | ||
---------- | ||
c : complex | ||
The c constant which defines the Julia set | ||
x_range : Tuple[int] | ||
Min and max on the x-axis in the complex plane | ||
y_range : Tuple[int] | ||
Min and max on the y-axis in the complex plane | ||
pixel_res : int | ||
Pixel resolution for box x- and y-axis | ||
Returns | ||
------- | ||
numpy.ndarray | ||
Image of the generated Mandelbrot set | ||
""" | ||
height = int((y_range[1] - y_range[0]) * pixel_res) | ||
width = int((x_range[1] - x_range[0]) * pixel_res) | ||
y, x = jax.numpy.ogrid[ | ||
y_range[1]:y_range[0]:height * 1j, | ||
x_range[0]:x_range[1]:width * 1j | ||
] | ||
z = x + y * 1j | ||
fractal = jax.numpy.full(z.shape, self.iterations, dtype=jax.numpy.int32) | ||
return np.asarray(self._jit_julia(c, z, fractal).block_until_ready()) | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from setuptools import setup, find_packages | ||
|
||
|
||
def description(): | ||
description = ( | ||
"Generate fractals with JAX." | ||
) | ||
return description | ||
|
||
|
||
install_requires = [ | ||
"matplotlib" | ||
] | ||
|
||
|
||
setup( | ||
name='fractal_jax', | ||
version='0.1.0', | ||
description=description(), | ||
author="Filippo Grazioli", | ||
author_email="[email protected]", | ||
install_requires=install_requires, | ||
packages=find_packages(), | ||
) |