Skip to content

Commit

Permalink
Merge pull request #106 from Kohulan/development
Browse files Browse the repository at this point in the history
feat: new input format as numpy array representing the image
  • Loading branch information
Kohulan authored Sep 10, 2024
2 parents b09ac80 + 7034d6f commit 5377ea2
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 106 deletions.
79 changes: 61 additions & 18 deletions DECIMER/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from PIL import Image
from PIL import ImageEnhance
from pillow_heif import register_heif_opener
from typing import Union

import DECIMER.Efficient_Net_encoder as Efficient_Net_encoder
import DECIMER.Transformer_decoder as Transformer_decoder
Expand Down Expand Up @@ -95,26 +96,68 @@ def HEIF_to_pillow(image_path: str):
return heif_file


def remove_transparent(image_path: str):
def remove_transparent(image: Union[str, np.ndarray]) -> Image.Image:
"""
Removes the transparent layer from a PNG image with an alpha channel
Args: image_path (str): path of input image
Returns: PIL.Image
Removes the transparent layer from a PNG image with an alpha channel.
Args:
image (Union[str, np.ndarray]): Path of the input image or a numpy array representing the image.
Returns:
PIL.Image.Image: The image with transparency removed.
"""
try:
png = Image.open(image_path).convert("RGBA")
except Exception as e:
if type(e).__name__ == "UnidentifiedImageError":
png = HEIF_to_pillow(image_path)
else:
print(e)
raise Exception
def process_image(png: Image.Image) -> Image.Image:
"""
Helper function to remove transparency from a single image.
Args:
png (PIL.Image.Image): The input PIL image with transparency.
Returns:
PIL.Image.Image: The image with transparency removed.
"""
background = Image.new("RGBA", png.size, (255, 255, 255))
alpha_composite = Image.alpha_composite(background, png)
return alpha_composite

background = Image.new("RGBA", png.size, (255, 255, 255))
def handle_image_path(image_path: str) -> Image.Image:
"""
Helper function to handle image paths.
Args:
image_path (str): The path to the input image.
Returns:
PIL.Image.Image: The image with transparency removed.
"""
try:
png = Image.open(image_path).convert("RGBA")
except Exception as e:
if type(e).__name__ == "UnidentifiedImageError":
png = HEIF_to_pillow(image_path)
else:
print(e)
raise Exception
return process_image(png)

def handle_numpy_array(array: np.ndarray) -> Image.Image:
"""
Helper function to handle a numpy array.
Args:
array (np.ndarray): The numpy array representing the image.
Returns:
PIL.Image.Image: The image with transparency removed.
"""
png = Image.fromarray(array).convert("RGBA")
return process_image(png)

alpha_composite = Image.alpha_composite(background, png)
# Check if input is a numpy array
if isinstance(image, np.ndarray):
return handle_numpy_array(array=image)

return alpha_composite
return handle_image_path(image_path=image)


def get_bnw_image(image):
Expand Down Expand Up @@ -185,12 +228,12 @@ def increase_brightness(image):
return image


def decode_image(image_path: str):
def decode_image(image_path: Union[str, np.ndarray]):
"""Loads an image and preprocesses the input image in several steps to get
the image ready for DECIMER input.
Args:
image_path (str): path of input image
image_path (Union[str, np.ndarray]): path of input image or numpy array representing the image.
Returns:
Processed image
Expand Down Expand Up @@ -237,7 +280,7 @@ def initialize_encoder_config(
backbone_fn (method): Calls Efficient-Net V2 as backbone for encoder
image_shape (int): Shape of the input image
do_permute (bool, optional): . Defaults to False.
pretrained_weights (keras weights, optional): Use pretrainined efficient net weights or not. Defaults to None.
pretrained_weights (keras weights, optional): Use pretrained efficient net weights or not. Defaults to None.
"""
self.encoder_config = dict(
image_embedding_dim=image_embedding_dim,
Expand Down
13 changes: 7 additions & 6 deletions DECIMER/decimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List
from typing import Tuple

import numpy as np
import pystow
import tensorflow as tf

Expand Down Expand Up @@ -122,19 +123,19 @@ def detokenize_output_add_confidence(


def predict_SMILES(
image_path: str, confidence: bool = False, hand_drawn: bool = False
image_input: [str, np.ndarray], confidence: bool = False, hand_drawn: bool = False
) -> str:
"""Predicts SMILES representation of a molecule depicted in the given image.
Args:
image_path (str): Path of chemical structure depiction image
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn
image_input (str or np.ndarray): Path of chemical structure depiction image or a numpy array representing the image.
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction.
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn.
Returns:
str: SMILES representation of the molecule in the input image, optionally with confidence values
str: SMILES representation of the molecule in the input image, optionally with confidence values.
"""
chemical_structure = config.decode_image(image_path)
chemical_structure = config.decode_image(image_input)

model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2
predicted_tokens, confidence_values = model(tf.constant(chemical_structure))
Expand Down
Loading

0 comments on commit 5377ea2

Please sign in to comment.