diff --git a/evadb/functions/pytesseract_function.py b/evadb/functions/pytesseract_function.py
new file mode 100644
index 0000000000..cc7d08db8c
--- /dev/null
+++ b/evadb/functions/pytesseract_function.py
@@ -0,0 +1,116 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import ast
+
+import numpy as np
+import pandas as pd
+
+from evadb.catalog.catalog_type import NdArrayType
+from evadb.functions.abstract.abstract_function import AbstractFunction
+from evadb.functions.decorators.decorators import forward, setup
+from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
+from evadb.utils.generic_utils import try_to_import_pytesseract
+
+
+class PyTesseractOCRFunction(AbstractFunction):
+    """Run Tesseract OCR on an image and return the recognised text.
+
+    The function is registered as non-batchable, so ``forward`` reads a single
+    image from the ``data`` column. Two optional preprocessing steps
+    (grayscale conversion and median-blur noise removal) are switched on
+    through the arguments of ``setup``.
+    """
+
+    @property
+    def name(self) -> str:
+        return "PyTesseractOCRFunction"
+
+    @setup(cacheable=False, function_type="FeatureExtraction", batchable=False)
+    def setup(
+        self,
+        convert_to_grayscale: str = "False",
+        remove_noise: str = "False",
+        tesseract_path: str = None,
+    ) -> None:  # type: ignore
+        """Check the OCR dependencies and store the preprocessing options.
+
+        Args:
+            convert_to_grayscale: Python literal (e.g. "True" / "False")
+                selecting whether the image is converted to grayscale first.
+            remove_noise: Python literal (e.g. "True" / "False") selecting
+                whether a median blur is applied before OCR.
+            tesseract_path: Path of the tesseract executable. When omitted,
+                the pytesseract default is left in place.
+
+        Raises:
+            ValueError: If the pytesseract package cannot be imported.
+        """
+        try_to_import_pytesseract()
+
+        # Only override the engine location when a path is supplied; assigning
+        # None would overwrite the command pytesseract uses by default.
+        if tesseract_path is not None:
+            # The helper above imports inside its own scope, so the module has
+            # to be imported here to be usable.
+            import pytesseract
+
+            pytesseract.pytesseract.tesseract_cmd = tesseract_path
+
+        self.grayscale_flag = convert_to_grayscale
+        self.remove_noise = remove_noise
+
+    @forward(
+        input_signatures=[
+            PandasDataframe(
+                columns=["data"],
+                column_types=[NdArrayType.FLOAT64],
+                column_shapes=[(None, 3)],
+            ),
+        ],
+        output_signatures=[
+            PandasDataframe(
+                columns=["text"],
+                column_types=[NdArrayType.STR],
+                column_shapes=[(None,)],
+            )
+        ],
+    )
+    def forward(self, frames: pd.DataFrame) -> pd.DataFrame:
+        """Extract the text of the first image in ``frames``.
+
+        Args:
+            frames: DataFrame whose ``data`` column holds the image array.
+
+        Returns:
+            A single-row DataFrame with the recognised text in column ``text``.
+        """
+        # Imported here because the check in setup() does not bind these names.
+        import cv2
+        import pytesseract
+
+        # Positional access: the incoming index is not guaranteed to start at 0.
+        img_data = np.asarray(frames["data"].iloc[0])
+
+        # The flags are Python literals; str() also accepts real booleans.
+        if ast.literal_eval(str(self.grayscale_flag)):
+            img_data = cv2.cvtColor(img_data, cv2.COLOR_RGB2GRAY)
+
+        if ast.literal_eval(str(self.remove_noise)):
+            img_data = cv2.medianBlur(img_data, 5)
+
+        # apply the OCR
+        text = pytesseract.image_to_string(img_data)
+
+        return pd.DataFrame({"text": [text]})
diff --git a/evadb/utils/generic_utils.py b/evadb/utils/generic_utils.py
index d9af319103..f056bc426e 100644
--- a/evadb/utils/generic_utils.py
+++ b/evadb/utils/generic_utils.py
@@ -502,6 +502,32 @@ def try_to_import_norfair():
         )
 
 
+def try_to_import_pytesseract():
+    """Run the cv2 import check, then verify that pytesseract is importable.
+
+    Raises:
+        ValueError: If the pytesseract package is not installed.
+    """
+    try_to_import_cv2()
+
+    try:
+        import pytesseract  # noqa: F401
+    except ImportError:
+        raise ValueError(
+            """Could not import pytesseract python package.
+                Please install it with pip install pytesseract"""
+        )
+
+
+def is_pytessseract_available():
+    """Return True when try_to_import_pytesseract() succeeds, False otherwise."""
+    try:
+        try_to_import_pytesseract()
+        return True
+    except ValueError:
+        return False
+
+
 ##############################
 ## DOCUMENT
 ##############################
diff --git a/setup.py b/setup.py
index 3334fa8361..35bf3147d7 100644
--- a/setup.py
+++ b/setup.py
@@ -92,6 +92,7 @@ def read(path, encoding="utf-8"):
     "boto3",  # AWS
     "norfair>=2.2.0",  # OBJECT TRACKING
     "kornia",  # SIFT FEATURES
+    "pytesseract",  # OCR
 ]
 
 ray_libs = [
diff --git a/test/integration_tests/long/functions/test_pytesseract.py b/test/integration_tests/long/functions/test_pytesseract.py
new file mode 100644
index 0000000000..6b8ed6dc28
--- /dev/null
+++ b/test/integration_tests/long/functions/test_pytesseract.py
@@ -0,0 +1,48 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+from test.markers import pytesseract_skip_marker
+from test.util import get_evadb_for_testing
+
+from evadb.server.command_handler import execute_query_fetch_all
+
+
+class PytesseractTest(unittest.TestCase):
+    """Integration test for the PyTesseractOCRFunction OCR function."""
+
+    def setUp(self) -> None:
+        self.evadb = get_evadb_for_testing()
+        self.evadb.catalog().reset()
+
+        load_image_query = """LOAD IMAGE 'data/ocr/Example.jpg' INTO MyImage;"""
+
+        execute_query_fetch_all(self.evadb, load_image_query)
+
+    def tearDown(self) -> None:
+        execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS MyImage;")
+
+    @pytesseract_skip_marker
+    def test_pytesseract_function(self):
+        function_name = "PyTesseractOCRFunction"
+        execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")
+
+        create_function_query = f"""CREATE FUNCTION IF NOT EXISTS {function_name}
+            IMPL 'evadb/functions/pytesseract_function.py';
+        """
+        execute_query_fetch_all(self.evadb, create_function_query)
+
+        ocr_query = f"SELECT {function_name}(data) FROM MyImage;"
+        output_batch = execute_query_fetch_all(self.evadb, ocr_query)
+        self.assertEqual(1, len(output_batch))
diff --git a/test/markers.py b/test/markers.py
index 6fdd2ad3c7..abd1927964 100644
--- a/test/markers.py
+++ b/test/markers.py
@@ -25,6 +25,7 @@
     is_ludwig_available,
     is_milvus_available,
     is_pinecone_available,
+    is_pytessseract_available,
     is_qdrant_available,
     is_replicate_available,
     is_sklearn_available,
@@ -112,3 +113,7 @@
 stable_diffusion_skip_marker = pytest.mark.skipif(
     is_replicate_available() is False, reason="requires replicate"
 )
+
+pytesseract_skip_marker = pytest.mark.skipif(
+    is_pytessseract_available() is False, reason="requires pytesseract"
+)