diff --git a/README.md b/README.md
index db91850..7be00fa 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,17 @@ Used for https://github.com/glue-viz/glue-jupyter
 
 (currently requires latest developer version of bqplot)
 
+## Usage
+
+### ImageGL
+
+See https://py.cafe/maartenbreddels/bqplot-image-gl-demo for a demo of the ImageGL widget.
+
+Preview image:
+![preview image](https://py.cafe/preview/maartenbreddels/bqplot-image-gl-demo)
+
+
+
 # Installation
 
 To install use pip:
diff --git a/bqplot_image_gl/imagegl.py b/bqplot_image_gl/imagegl.py
index 18132fb..f0f026a 100644
--- a/bqplot_image_gl/imagegl.py
+++ b/bqplot_image_gl/imagegl.py
@@ -1,3 +1,4 @@
+import os
 import ipywidgets as widgets
 import bqplot
 from traittypes import Array
@@ -6,10 +7,15 @@
 from bqplot.marks import shape
 from bqplot.traits import array_to_json, array_from_json
 from bqplot_image_gl._version import __version__
+from .serialize import image_data_serialization
 
 __all__ = ['ImageGL', 'Contour']
 
 
+# can be 'png', 'webp' or 'none'
+DEFAULT_IMAGE_DATA_COMPRESSION = os.environ.get("BQPLOT_IMAGE_GL_IMAGE_DATA_COMPRESSION", "none")
+
+
 @widgets.register
 class ImageGL(bqplot.Mark):
     """An example widget."""
@@ -24,7 +30,8 @@ class ImageGL(bqplot.Mark):
                   scaled=True,
                   rtype='Color',
                   atype='bqplot.ColorAxis',
-                  **array_serialization)
+                  **image_data_serialization)
+    compression = Unicode(DEFAULT_IMAGE_DATA_COMPRESSION, allow_none=True).tag(sync=True)
     interpolation = Unicode('nearest', allow_none=True).tag(sync=True)
     opacity = Float(1.0).tag(sync=True)
     x = Array(default_value=(0, 1)).tag(sync=True, scaled=True,
{array.shape}" + ) + + # and serialize it to a PNG + png_data = io.BytesIO() + image.save(png_data, format=image_format, lossless=True) + png_bytes = png_data.getvalue() + original_byte_length = array.nbytes + uint8_byte_length = array_bytes.nbytes + compressed_byte_length = len(png_bytes) + return { + "type": "image", + "format": image_format, + "use_colormap": use_colormap, + "min": min, + "max": max, + "data": png_bytes, + # this metadata is only useful/needed for debugging + "shape": array.shape, + "info": { + "original_byte_length": original_byte_length, + "uint8_byte_length": uint8_byte_length, + "compressed_byte_length": compressed_byte_length, + "compression_ratio": original_byte_length / compressed_byte_length, + "MB": { + "original": original_byte_length / 1024 / 1024, + "uint8": uint8_byte_length / 1024 / 1024, + "compressed": compressed_byte_length / 1024 / 1024, + }, + }, + } + + +image_data_serialization = dict( + to_json=array_to_image_or_array, from_json=not_implemented +) diff --git a/js/lib/contour.js b/js/lib/contour.js index ef5458c..b0f2f0b 100644 --- a/js/lib/contour.js +++ b/js/lib/contour.js @@ -34,19 +34,40 @@ class ContourModel extends bqplot.MarkModel { this.update_data(); } - update_data() { + async update_data() { const image_widget = this.get('image'); const level = this.get('level') // we support a single level or multiple this.thresholds = Array.isArray(level) ? level : [level]; if(image_widget) { const image = image_widget.get('image') - this.width = image.shape[1]; - this.height = image.shape[0]; + let data = null; + if(image.image) { + const imageNode = image.image; + this.width = imageNode.width; + this.height = imageNode.height; + // conver the image to a typed array using canvas + const canvas = document.createElement('canvas'); + canvas.width = this.width + canvas.height = this.height + const ctx = canvas.getContext('2d'); + ctx.drawImage(imageNode, 0, 0); + const imageData = ctx.getImageData(0, 0, imageNode.width, imageNode.height); + const {min, max} = image; + // use the r channel as the data, and scale to the range + data = new Float32Array(imageData.data.length / 4); + for(var i = 0; i < data.length; i++) { + data[i] = (imageData.data[i*4] / 255) * (max - min) + min; + } + } else { + this.width = image.shape[1]; + this.height = image.shape[0]; + data = image.data; + } this.contours = this.thresholds.map((threshold) => d3contour .contours() .size([this.width, this.height]) - .contour(image.data, [threshold]) + .contour(data, [threshold]) ) } else { this.width = 1; // precomputed contour_lines will have to be in normalized diff --git a/js/lib/imagegl.js b/js/lib/imagegl.js index 5118cb3..3e1a01e 100644 --- a/js/lib/imagegl.js +++ b/js/lib/imagegl.js @@ -38,9 +38,24 @@ class ImageGLModel extends bqplot.MarkModel { super.initialize(attributes, options); this.on_some_change(['x', 'y'], this.update_data, this); this.on_some_change(["preserve_domain"], this.update_domains, this); + this.listenTo(this, "change:image", () => { + const previous = this.previous("image"); + if(previous.image && previous.image.src) { + URL.revokeObjectURL(previous.image.src); + } + }, this); + this.update_data(); } + close(comm_closed) { + const image = this.get("image"); + if(image.image && image.image.src) { + URL.revokeObjectURL(previous.image.src); + } + return super.close(comm_closed); + } + update_data() { this.mark_data = { x: this.get("x"), y: this.get("y") @@ -79,9 +94,24 @@ ImageGLModel.serializers = Object.assign({}, bqplot.MarkModel.serializers, { x: 
diff --git a/js/lib/contour.js b/js/lib/contour.js
index ef5458c..b0f2f0b 100644
--- a/js/lib/contour.js
+++ b/js/lib/contour.js
@@ -34,19 +34,40 @@ class ContourModel extends bqplot.MarkModel {
         this.update_data();
     }
 
-    update_data() {
+    async update_data() {
         const image_widget = this.get('image');
         const level = this.get('level')
         // we support a single level or multiple
         this.thresholds = Array.isArray(level) ? level : [level];
         if(image_widget) {
             const image = image_widget.get('image')
-            this.width = image.shape[1];
-            this.height = image.shape[0];
+            let data = null;
+            if(image.image) {
+                const imageNode = image.image;
+                this.width = imageNode.width;
+                this.height = imageNode.height;
+                // convert the image to a typed array using canvas
+                const canvas = document.createElement('canvas');
+                canvas.width = this.width
+                canvas.height = this.height
+                const ctx = canvas.getContext('2d');
+                ctx.drawImage(imageNode, 0, 0);
+                const imageData = ctx.getImageData(0, 0, imageNode.width, imageNode.height);
+                const {min, max} = image;
+                // use the r channel as the data, and scale to the range
+                data = new Float32Array(imageData.data.length / 4);
+                for(var i = 0; i < data.length; i++) {
+                    data[i] = (imageData.data[i*4] / 255) * (max - min) + min;
+                }
+            } else {
+                this.width = image.shape[1];
+                this.height = image.shape[0];
+                data = image.data;
+            }
             this.contours = this.thresholds.map((threshold) => d3contour
                 .contours()
                 .size([this.width, this.height])
-                .contour(image.data, [threshold])
+                .contour(data, [threshold])
             )
         } else {
             this.width = 1; // precomputed contour_lines will have to be in normalized
diff --git a/js/lib/imagegl.js b/js/lib/imagegl.js
index 5118cb3..3e1a01e 100644
--- a/js/lib/imagegl.js
+++ b/js/lib/imagegl.js
@@ -38,9 +38,24 @@ class ImageGLModel extends bqplot.MarkModel {
         super.initialize(attributes, options);
         this.on_some_change(['x', 'y'], this.update_data, this);
         this.on_some_change(["preserve_domain"], this.update_domains, this);
+        this.listenTo(this, "change:image", () => {
+            const previous = this.previous("image");
+            if(previous.image && previous.image.src) {
+                URL.revokeObjectURL(previous.image.src);
+            }
+        }, this);
+
         this.update_data();
     }
 
+    close(comm_closed) {
+        const image = this.get("image");
+        if(image.image && image.image.src) {
+            URL.revokeObjectURL(image.image.src);
+        }
+        return super.close(comm_closed);
+    }
+
     update_data() {
         this.mark_data = {
             x: this.get("x"), y: this.get("y")
@@ -79,9 +94,24 @@ ImageGLModel.serializers = Object.assign({}, bqplot.MarkModel.serializers, {
     x: serialize.array_or_json,
     y: serialize.array_or_json,
     image: {
-        deserialize: (obj, manager) => {
-            let state = {buffer: obj.value, dtype: obj.dtype, shape: obj.shape};
-            return jupyter_dataserializers.JSONToArray(state);
+        deserialize: async (obj, manager) => {
+            if(obj.type == "image") {
+                // the data is encoded in an image with LA format
+                // luminance for the intensity, alpha for the mask
+                let image = new Image();
+                const blob = new Blob([obj.data], {type: `image/${obj.format}`});
+                const url = URL.createObjectURL(blob);
+                image.src = url;
+                await new Promise((resolve, reject) => {
+                    image.onload = resolve;
+                    image.onerror = reject;
+                } );
+                return {image, min: obj.min, max: obj.max, use_colormap: obj.use_colormap};
+            } else {
+                // otherwise just a 'normal' ndarray
+                let state = {buffer: obj.value, dtype: obj.dtype, shape: obj.shape};
+                return jupyter_dataserializers.JSONToArray(state);
+            }
         },
         serialize: (ar) => {
             const {buffer, dtype, shape} = jupyter_dataserializers.arrayToJSON(ar);
@@ -114,6 +144,10 @@ class ImageGLView extends bqplot.Mark {
                 // basically the corners of the image
                 image_domain_x : { type: "2f", value: [0.0, 1.0] },
                 image_domain_y : { type: "2f", value: [0.0, 1.0] },
+                // in the case we use an image for the values, the image is normalized, and we need to scale
+                // it back to a particular image range
+                // This needs to be set to [0, 1] for array data (which is not normalized)
+                range_image : { type: "2f", value: [0.0, 1.0] },
                 // extra opacity value
                 opacity: {type: 'f', value: 1.0}
             },
@@ -280,39 +314,56 @@ class ImageGLView extends bqplot.Mark {
     update_image(skip_render) {
         var image = this.model.get("image");
         var type = null;
-        var data = image.data;
-        if(data instanceof Uint8Array) {
-            type = THREE.UnsignedByteType;
-        } else if(data instanceof Float64Array) {
-            console.warn('ImageGLView.data is a Float64Array which WebGL does not support, will convert to a Float32Array (consider sending float32 data for better performance).');
-            data = Float32Array.from(data);
-            type = THREE.FloatType;
-        } else if(data instanceof Float32Array) {
-            type = THREE.FloatType;
-        } else {
-            console.error('only types uint8 and float32 are supported');
-            return;
-        }
-        if(this.scales.image.model.get('scheme') && image.shape.length == 2) {
-            if(this.texture)
+        if(image.image) {
+            // the data is encoded in an image with LA format
+            if(this.texture) {
                 this.texture.dispose();
-            this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.LuminanceFormat, type);
+            }
+            this.texture = new THREE.Texture(image.image);
             this.texture.needsUpdate = true;
+            this.texture.flipY = false;
             this.image_material.uniforms.image.value = this.texture;
-            this.image_material.defines.USE_COLORMAP = true;
+            this.image_material.defines.USE_COLORMAP = image.use_colormap;
             this.image_material.needsUpdate = true;
-        } else if(image.shape.length == 3) {
-            this.image_material.defines.USE_COLORMAP = false;
-            if(this.texture)
-                this.texture.dispose();
-            if(image.shape[2] == 3)
-                this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBFormat, type);
-            if(image.shape[2] == 4)
-                this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBAFormat, type);
-            this.texture.needsUpdate = true;
-            this.image_material.uniforms.image.value = this.texture;
+            this.image_material.uniforms.range_image.value = [image.min, image.max];
         } else {
-            console.error('image data not understood');
+            // we are not dealing with an image, but with an array
+            // which is not normalized, so we can reset the range_image
+            this.image_material.uniforms.range_image.value = [0, 1];
+            var data = image.data;
+            if(data instanceof Uint8Array) {
+                type = THREE.UnsignedByteType;
+            } else if(data instanceof Float64Array) {
+                console.warn('ImageGLView.data is a Float64Array which WebGL does not support, will convert to a Float32Array (consider sending float32 data for better performance).');
+                data = Float32Array.from(data);
+                type = THREE.FloatType;
+            } else if(data instanceof Float32Array) {
+                type = THREE.FloatType;
+            } else {
+                console.error('only types uint8 and float32 are supported');
+                return;
+            }
+            if(this.scales.image.model.get('scheme') && image.shape.length == 2) {
+                if(this.texture)
+                    this.texture.dispose();
+                this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.LuminanceFormat, type);
+                this.texture.needsUpdate = true;
+                this.image_material.uniforms.image.value = this.texture;
+                this.image_material.defines.USE_COLORMAP = true;
+                this.image_material.needsUpdate = true;
+            } else if(image.shape.length == 3) {
+                this.image_material.defines.USE_COLORMAP = false;
+                if(this.texture)
+                    this.texture.dispose();
+                if(image.shape[2] == 3)
+                    this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBFormat, type);
+                if(image.shape[2] == 4)
+                    this.texture = new THREE.DataTexture(data, image.shape[1], image.shape[0], THREE.RGBAFormat, type);
+                this.texture.needsUpdate = true;
+                this.image_material.uniforms.image.value = this.texture;
+            } else {
+                console.error('image data not understood');
+            }
         }
         this.texture.magFilter = interpolations[this.model.get('interpolation')];
         this.texture.minFilter = interpolations[this.model.get('interpolation')];
diff --git a/js/shaders/image-fragment.glsl b/js/shaders/image-fragment.glsl
index fd98496..26b0558 100644
--- a/js/shaders/image-fragment.glsl
+++ b/js/shaders/image-fragment.glsl
@@ -16,6 +16,9 @@ uniform vec2 domain_y;
 uniform vec2 image_domain_x;
 uniform vec2 image_domain_y;
 
+uniform vec2 range_image;
+
+
 bool isnan(float val) {
     return (val < 0.0 || 0.0 < val || val == 0.0) ? false : true;
 }
@@ -32,7 +35,10 @@ void main(void) {
     float y_normalized = scale_transform_linear(y_domain_value, vec2(0., 1.), image_domain_y);
     vec2 tex_uv = vec2(x_normalized, y_normalized);
     #ifdef USE_COLORMAP
-    float raw_value = texture2D(image, tex_uv).r;
+    // r (or g or b) is used for the value, alpha for the mask (is 0 if a nan is found)
+    vec2 pixel_value = texture2D(image, tex_uv).ra;
+    float raw_value = pixel_value[0] * (range_image[1] - range_image[0]) + range_image[0];
+    float opacity_image = pixel_value[1];
     float value = (raw_value - color_min) / (color_max - color_min);
     vec4 color;
     if(isnan(value)) // nan's are interpreted as missing values, and 'not shown'
@@ -41,8 +47,9 @@ void main(void) {
         color = texture2D(colormap, vec2(value, 0.5));
     #else
     vec4 color = texture2D(image, tex_uv);
+    float opacity_image = 1.0;
    #endif
     // since we're working with pre multiplied colors (regarding blending)
     // we also need to multiply rgb by opacity
-    gl_FragColor = color * opacity;
+    gl_FragColor = color * opacity * opacity_image;
 }
diff --git a/setup.py b/setup.py
index 0a1369e..99b0bed 100644
--- a/setup.py
+++ b/setup.py
@@ -51,7 +51,8 @@
     include_package_data=True,
     install_requires=[
         'ipywidgets>=7.0.0',
-        'bqplot>=0.12'
+        'bqplot>=0.12',
+        'pillow',
     ],
     packages=find_packages(),
     zip_safe=False,
diff --git a/tests/ui/contour_test.py b/tests/ui/contour_test.py
index 82d3453..a0b6345 100644
--- a/tests/ui/contour_test.py
+++ b/tests/ui/contour_test.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+import pytest
 import ipywidgets as widgets
 import playwright.sync_api
 from IPython.display import display
@@ -6,7 +8,8 @@
 from bqplot_image_gl import ImageGL, Contour
 
 
-def test_widget_image(solara_test, page_session: playwright.sync_api.Page, assert_solara_snapshot):
+@pytest.mark.parametrize("compression", ["png", "none"])
+def test_widget_image(solara_test, page_session: playwright.sync_api.Page, assert_solara_snapshot, compression, request):
     scale_x = LinearScale(min=0, max=1)
     scale_y = LinearScale(min=0, max=1)
 
@@ -23,7 +26,7 @@ def test_widget_image(solara_test, page_session: playwright.sync_api.Page, asser
     X, Y = np.meshgrid(x, y)
     data = 5. * np.sin(2 * np.pi * (X + Y**2))
 
-    image = ImageGL(image=data, scales=scales_image)
+    image = ImageGL(image=data, scales=scales_image, compression=compression)
 
     contour = Contour(image=image, level=[2, 4], scales=scales_image)
 
@@ -34,5 +37,8 @@ def test_widget_image(solara_test, page_session: playwright.sync_api.Page, asser
     svg = page_session.locator(".bqplot")
     svg.wait_for()
 
-    # page_session.wait_for_timeout(1000)
+    page_session.wait_for_timeout(100)
+    # although the contours are almost the same, precision differences make the rendered image
+    # slightly different per compression mode, so unlike image_test we keep the compression
+    # fixture value in the test/screenshot name
     assert_solara_snapshot(svg.screenshot())
diff --git a/tests/ui/image_test.py b/tests/ui/image_test.py
index fbd79d0..fb6ea16 100644
--- a/tests/ui/image_test.py
+++ b/tests/ui/image_test.py
@@ -1,12 +1,16 @@
+from pathlib import Path
+import pytest
 import ipywidgets as widgets
 import playwright.sync_api
 from IPython.display import display
+import numpy as np
 
 
-def test_widget_image(ipywidgets_runner, page_session: playwright.sync_api.Page, assert_solara_snapshot):
+@pytest.mark.parametrize("compression", ["png", "none"])
+def test_widget_image(ipywidgets_runner, page_session: playwright.sync_api.Page, assert_solara_snapshot, compression, request):
 
-    def kernel_code():
+    def kernel_code(compression=compression):
         import numpy as np
         from bqplot import Figure, LinearScale, Axis, ColorScale
         from bqplot_image_gl import ImageGL
 
@@ -21,15 +25,61 @@ def kernel_code():
         scales_image = {"x": scale_x, "y": scale_y, "image": ColorScale(min=0, max=2)}
         data = np.array([[0., 1.], [2., 3.]])
 
-        image = ImageGL(image=data, scales=scales_image)
+        image = ImageGL(image=data, scales=scales_image, compression=compression)
 
         figure.marks = (image,)
 
        display(figure)
 
-    ipywidgets_runner(kernel_code)
+    ipywidgets_runner(kernel_code, locals=dict(compression=compression))
     svg = page_session.locator(".bqplot")
     svg.wait_for()
     # make sure the image is rendered
     page_session.wait_for_timeout(100)
-    assert_solara_snapshot(svg.screenshot())
+    # we don't want the compression fixture in the testname, because all screenshots should be the same
+    testname = f"{str(Path(request.node.name))}".replace("[", "-").replace("]", "").replace(" ", "-").replace(",", "-")
+    testname = testname.replace(f"-{compression}", "")
+    assert_solara_snapshot(svg.screenshot(), testname=testname)
+
+
+@pytest.mark.parametrize("compression", ["png", "none"])
+@pytest.mark.parametrize("dtype", [np.uint8, np.float32])
+def test_widget_image_rgba(solara_test, page_session: playwright.sync_api.Page, assert_solara_snapshot, compression, request, dtype):
+
+    import numpy as np
+    from bqplot import Figure, LinearScale, Axis, ColorScale
+    from bqplot_image_gl import ImageGL
+    scale_x = LinearScale(min=0, max=1)
+    scale_y = LinearScale(min=0, max=1)
+    scales = {"x": scale_x, "y": scale_y}
+    axis_x = Axis(scale=scale_x, label="x")
+    axis_y = Axis(scale=scale_y, label="y", orientation="vertical")
+
+    figure = Figure(scales=scales, axes=[axis_x, axis_y])
+
+    scales_image = {"x": scale_x, "y": scale_y, "image": ColorScale(min=0, max=2)}
+
+    # four pixels, red, green, blue, and transparent
+    red = [255, 0, 0, 255]
+    green = [0, 255, 0, 255]
+    blue = [0, 0, 255, 255]
+    transparent = [0, 0, 0, 0]
+    data = np.array([[red, green], [blue, transparent]], dtype=dtype)
+    if dtype == np.float32:
+        data = data / 255.
+
+    image = ImageGL(image=data, scales=scales_image, compression=compression)
+
+    figure.marks = (image,)
+
+    display(figure)
+
+    svg = page_session.locator(".bqplot")
+    svg.wait_for()
+    # make sure the image is rendered
+    page_session.wait_for_timeout(100)
+    # we don't want the compression or dtype fixture in the testname, because all screenshots should be the same
+    testname = f"{str(Path(request.node.name))}".replace("[", "-").replace("]", "").replace(" ", "-").replace(",", "-")
+    testname = testname.replace(f"-{compression}", "")
+    testname = testname.replace(f"-{dtype.__name__}", "")
+    assert_solara_snapshot(svg.screenshot(), testname=testname)
diff --git a/tests/ui/snapshots/tests/ui/contour_test.py/test_widget_image-chromium-linux-reference.png b/tests/ui/snapshots/tests/ui/contour_test.py/test_widget_image-chromium-linux-reference.png
deleted file mode 100644
index 5adf92d..0000000
Binary files a/tests/ui/snapshots/tests/ui/contour_test.py/test_widget_image-chromium-linux-reference.png and /dev/null differ
diff --git a/tests/ui/snapshots/tests/ui/contour_test.py/test_widget_image-chromium-none-linux-reference.png b/tests/ui/snapshots/tests/ui/contour_test.py/test_widget_image-chromium-none-linux-reference.png
new file mode 100644
index 0000000..3f8c13b
Binary files /dev/null and b/tests/ui/snapshots/tests/ui/contour_test.py/test_widget_image-chromium-none-linux-reference.png differ
diff --git a/tests/ui/snapshots/tests/ui/contour_test.py/test_widget_image-chromium-png-linux-reference.png b/tests/ui/snapshots/tests/ui/contour_test.py/test_widget_image-chromium-png-linux-reference.png
new file mode 100644
index 0000000..3f8c13b
Binary files /dev/null and b/tests/ui/snapshots/tests/ui/contour_test.py/test_widget_image-chromium-png-linux-reference.png differ
diff --git a/tests/ui/snapshots/tests/ui/image_test.py/test_widget_image_rgba-chromium-darwin-reference.png b/tests/ui/snapshots/tests/ui/image_test.py/test_widget_image_rgba-chromium-darwin-reference.png
new file mode 100644
index 0000000..8a87119
Binary files /dev/null and b/tests/ui/snapshots/tests/ui/image_test.py/test_widget_image_rgba-chromium-darwin-reference.png differ
diff --git a/tests/ui/snapshots/tests/ui/image_test.py/test_widget_image_rgba-chromium-linux-reference.png b/tests/ui/snapshots/tests/ui/image_test.py/test_widget_image_rgba-chromium-linux-reference.png
new file mode 100644
index 0000000..1357077
Binary files /dev/null and b/tests/ui/snapshots/tests/ui/image_test.py/test_widget_image_rgba-chromium-linux-reference.png differ
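For reference, a minimal usage sketch of the new option (adapted from the tests above; not part of the patch). The `compression` trait accepts 'png', 'webp' or 'none', and the process-wide default comes from the `BQPLOT_IMAGE_GL_IMAGE_DATA_COMPRESSION` environment variable, which is read once when `bqplot_image_gl.imagegl` is imported:

    import os
    # optional: change the default for every ImageGL created in this process;
    # must be set before bqplot_image_gl is imported
    os.environ["BQPLOT_IMAGE_GL_IMAGE_DATA_COMPRESSION"] = "png"

    import numpy as np
    from IPython.display import display
    from bqplot import Figure, LinearScale, Axis, ColorScale
    from bqplot_image_gl import ImageGL

    scale_x = LinearScale(min=0, max=1)
    scale_y = LinearScale(min=0, max=1)
    axis_x = Axis(scale=scale_x, label="x")
    axis_y = Axis(scale=scale_y, label="y", orientation="vertical")
    figure = Figure(scales={"x": scale_x, "y": scale_y}, axes=[axis_x, axis_y])

    scales_image = {"x": scale_x, "y": scale_y, "image": ColorScale(min=0, max=2)}
    data = np.array([[0., 1.], [2., 3.]])

    # per-mark override of the default: 'png', 'webp' or 'none'
    image = ImageGL(image=data, scales=scales_image, compression="none")
    figure.marks = (image,)
    display(figure)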