@classmethod
def from_python_type(cls, t):
    """Translate a built-in Python type into the matching member of ``cls``.

    ``int``, ``str``, ``bool``, ``float`` and ``decimal.Decimal`` map to
    ``INT64``, ``STR``, ``BOOL``, ``FLOAT64`` and ``DECIMAL`` respectively.
    Anything else falls back to ``ANYTYPE``.

    Args:
        t: the Python type (typically a type annotation) to translate.

    Returns:
        The corresponding attribute of ``cls``.
    """
    from decimal import Decimal

    # Ordered (python type, member name) pairs; compared with ``==`` in this
    # order, and the member is only looked up once a pair matches.
    known_types = (
        (int, "INT64"),
        (str, "STR"),
        (bool, "BOOL"),
        (float, "FLOAT64"),
        (Decimal, "DECIMAL"),
    )
    for python_type, member_name in known_types:
        if t == python_type:
            return getattr(cls, member_name)
    return cls.ANYTYPE
class UserDefinedFunction(AbstractFunction):
    """
    Adapter that exposes a plain, type-annotated Python function as a function
    class.

    Concrete subclasses are created by ``generate_udf``, which stores the
    wrapped callable in the ``_func`` class attribute. The IO signatures are
    derived from the callable's annotations when ``setup`` runs.

    Input Signatures:
        one column per parameter of the wrapped function, named after the
        parameter and typed via ``NdArrayType.from_python_type``

    Output Signatures:
        one column named after the lower-cased function name, typed from the
        return annotation
    """

    @property
    def name(self) -> str:
        # The function is identified by the name of the wrapped callable.
        return self._func.__name__

    @setup(cacheable=True, batchable=True)
    def setup(self) -> None:
        """Derive the input/output descriptors from the wrapped function.

        Raises:
            AssertionError: if any parameter or the return value of the
                wrapped function has no type annotation.
        """
        import inspect

        sig = inspect.signature(self._func)
        self._inputs = list(sig.parameters.values())

        # Every parameter must be annotated, otherwise no column type can be
        # derived. Raised explicitly (not via ``assert``) so the check also
        # runs when Python is started with -O.
        for param in self._inputs:
            if param.annotation == inspect.Parameter.empty:
                raise AssertionError(
                    f"Parameter {param.name} has no type annotation"
                )

        # The return annotation determines the type of the output column.
        self._output = sig.return_annotation
        if self._output == inspect.Signature.empty:
            raise AssertionError("Return type annotation is empty")

        input_io_arg = PandasDataframe(
            columns=[x.name for x in self._inputs],
            column_types=[
                NdArrayType.from_python_type(x.annotation) for x in self._inputs
            ],
            column_shapes=[(1,) for _ in self._inputs],
        )

        output_io_arg = PandasDataframe(
            columns=[self.name.lower()],
            column_types=[NdArrayType.from_python_type(self._output)],
            column_shapes=[(1,)],
        )

        # set the input and output tags (similar to @forward decorator)
        # NOTE(review): ``forward`` is defined once on this base class, so
        # these tags look shared by every generated UDF class - confirm.
        self.forward.tags["input"] = [input_io_arg]
        self.forward.tags["output"] = [output_io_arg]

    @forward(
        input_signatures=[],
        output_signatures=[],
    )
    def forward(self, in_df: pd.DataFrame):
        """Apply the wrapped function to every row of ``in_df``.

        The values of each row are passed positionally, in column order, so
        the columns are expected to line up with the function's parameters.

        Args:
            in_df: one column per parameter of the wrapped function.

        Returns:
            A DataFrame with a single column named after the lower-cased
            function name, sharing the index of ``in_df``.
        """
        # Unpack each row into positional arguments. Handing the row over as
        # one object (``in_df.apply(func, axis=1)``) breaks functions with
        # more than one parameter and feeds a Series where a scalar is
        # expected.
        rows = in_df.itertuples(index=False, name=None)
        values = [self._func(*row) for row in rows]
        return pd.DataFrame({self.name.lower(): values}, index=in_df.index)


def generate_udf(func):
    """Create a ``UserDefinedFunction`` subclass that wraps ``func``.

    Args:
        func: the function to wrap; the generated class is named after
            ``func.__name__``.

    Returns:
        A new class deriving from ``UserDefinedFunction`` whose ``_func``
        attribute is ``func``.
    """
    # ``staticmethod`` keeps ``func`` from being bound to the instance when it
    # is accessed as ``self._func``.
    class_body = {
        "_func": staticmethod(func),
    }
    return type(func.__name__, (UserDefinedFunction,), class_body)
- FileNotFoundError: If the file cannot be found. - RuntimeError: Any othe type of runtime error. - """ - try: - abs_path = Path(filepath).resolve() - spec = importlib.util.spec_from_file_location(abs_path.stem, abs_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - except ImportError as e: - # ImportError in the case when we are able to find the file but not able to load the module - err_msg = f"ImportError : Couldn't load function from {filepath} : {str(e)}. Not able to load the code provided in the file {abs_path}. Please ensure that the file contains the implementation code for the function." - raise ImportError(err_msg) - except FileNotFoundError as e: - # FileNotFoundError in the case when we are not able to find the file at all at the path. - err_msg = f"FileNotFoundError : Couldn't load function from {filepath} : {str(e)}. This might be because the function implementation file does not exist. Please ensure the file exists at {abs_path}" - raise FileNotFoundError(err_msg) - except Exception as e: - # Default exception, we don't know what exactly went wrong so we just output the error message - err_msg = f"Couldn't load function from {filepath} : {str(e)}." - raise RuntimeError(err_msg) - - # Try to load the specified class by name - if classname and hasattr(module, classname): - return getattr(module, classname) - - # If class name not specified, check if there is only one class in the file - classes = [ - obj - for _, obj in inspect.getmembers(module, inspect.isclass) - if obj.__module__ == module.__name__ - ] - if len(classes) != 1: - raise ImportError( - f"{filepath} contains {len(classes)} classes, please specify the correct class to load by naming the function with the same name in the CREATE query." 
def _members_defined_in(module, predicate):
    """Return the members of ``module`` matching ``predicate`` that were
    defined in ``module`` itself (imported names are skipped)."""
    return [
        obj
        for _, obj in inspect.getmembers(module, predicate)
        if obj.__module__ == module.__name__
    ]


def load_function_class_from_file(filepath, classname=None):
    """
    Load a function class from a Python file.

    The lookup proceeds as follows:
      1. If ``classname`` names a class in the file, that class is returned;
         if it names another callable, the callable is wrapped with
         ``generate_udf`` and the generated class is returned.
      2. Otherwise, if the file defines exactly one class, it is returned.
      3. Otherwise, if the file defines exactly one function, it is wrapped
         with ``generate_udf`` and the generated class is returned.

    Args:
        filepath (str): The path to the Python file.
        classname (str, optional): The name of the class or function to load.
            Defaults to None.

    Returns:
        The loaded (or generated) class, not an instance of it.

    Raises:
        ImportError: If the module cannot be loaded, or if no single class or
            function can be selected from the file.
        FileNotFoundError: If the file cannot be found.
        RuntimeError: Any other type of error raised while loading the file.
    """
    # Import the submodule explicitly: a bare ``import importlib`` does not
    # guarantee that ``importlib.util`` is available as an attribute.
    import importlib.util

    try:
        abs_path = Path(filepath).resolve()
        spec = importlib.util.spec_from_file_location(abs_path.stem, abs_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    except ImportError as e:
        # ImportError in the case when we are able to find the file but not able to load the module
        err_msg = f"ImportError : Couldn't load function from {filepath} : {str(e)}. Not able to load the code provided in the file {abs_path}. Please ensure that the file contains the implementation code for the function."
        raise ImportError(err_msg) from e
    except FileNotFoundError as e:
        # FileNotFoundError in the case when we are not able to find the file at all at the path.
        err_msg = f"FileNotFoundError : Couldn't load function from {filepath} : {str(e)}. This might be because the function implementation file does not exist. Please ensure the file exists at {abs_path}"
        raise FileNotFoundError(err_msg) from e
    except Exception as e:
        # Default exception, we don't know what exactly went wrong so we just output the error message
        err_msg = f"Couldn't load function from {filepath} : {str(e)}."
        raise RuntimeError(err_msg) from e

    # Try to load the requested object by name
    if classname and hasattr(module, classname):
        obj = getattr(module, classname)
        if inspect.isclass(obj):
            return obj
        if callable(obj):
            return generate_udf(obj)
        # A non-callable attribute (e.g. a constant) cannot be turned into a
        # function; fall through to the single class / function lookup.

    # If no usable name was given, check if there is only one class in the file
    classes = _members_defined_in(module, inspect.isclass)
    if len(classes) == 1:
        return classes[0]

    # Otherwise accept a file that defines exactly one plain function
    functions = _members_defined_in(module, inspect.isfunction)
    if len(functions) == 1:
        return generate_udf(functions[0])

    raise ImportError(
        f"{filepath} contains {len(classes)} classes, please specify the correct class to load by naming the function with the same name in the CREATE query."
    )
class SimpleUDFTest(unittest.TestCase):
    """Integration tests for registering plain Python functions as functions
    through ``CREATE FUNCTION ... IMPL``."""

    @classmethod
    def setUpClass(cls):
        cls.evadb = get_evadb_for_testing()
        cls.evadb.catalog().reset()

    def _run(self, query):
        # Execute a single query against the shared test database.
        return execute_query_fetch_all(self.evadb, query)

    def write_udf_mod5(self, f):
        # Source of a function returning its argument modulo five.
        f.writelines(["def mod5(id:int)->int:\n", "\treturn id % 5\n"])

    def write_udf_isEven(self, f):
        # Source of a function telling whether its argument is even.
        f.writelines(["def isEven(id:int)->bool:\n", "\treturn id % 2 == 0\n"])

    def setUp(self):
        # Python file holding both functions
        handle, self.temp_path = tempfile.mkstemp(suffix=".py")
        with os.fdopen(handle, "w") as source_file:
            for writer in (self.write_udf_mod5, self.write_udf_isEven):
                writer(source_file)
        # Table with the ids 0..9
        self._run("CREATE TABLE IF NOT EXISTS test (id INTEGER);")
        for row_id in range(10):
            self._run(f"INSERT INTO test (id) VALUES ({row_id});")

    def tearDown(self):
        # Remove the python file, then the table
        os.remove(self.temp_path)
        self._run("DROP TABLE test;")

    def test_first_udf(self):
        # Register, query, compare, unregister
        self._run(f"CREATE FUNCTION mod5 IMPL '{self.temp_path}';")
        result = self._run("SELECT mod5(id) FROM test;")
        expected = pd.DataFrame({"mod5.mod5": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]})
        self.assertTrue(result.frames.equals(expected))
        self._run("DROP FUNCTION mod5;")

    def test_second_udf(self):
        # Register, query, compare, unregister
        self._run(f"CREATE FUNCTION isEven IMPL '{self.temp_path}';")
        result = self._run("SELECT isEven(id) FROM test;")
        expected = pd.DataFrame({"iseven.iseven": [i % 2 == 0 for i in range(10)]})
        self.assertEqual(result.frames.equals(expected), True)
        self._run("DROP FUNCTION isEven;")

    def test_udf_name_missing(self):
        # The file holds two functions and none of them is called "temp"
        with self.assertRaises(Exception):
            self._run(f"CREATE FUNCTION temp IMPL '{self.temp_path}';")

    def test_udf_single_function(self):
        # Rewrite the file so that it holds a single function
        with open(self.temp_path, "w") as source_file:
            self.write_udf_mod5(source_file)
        # The only function is picked up although the names differ
        self._run(f"CREATE FUNCTION mod_five IMPL '{self.temp_path}';")
        result = self._run("SELECT mod_five(id) FROM test;")
        expected = pd.DataFrame({"mod_five.mod5": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]})
        self.assertTrue(result.frames.equals(expected))
        self._run("DROP FUNCTION mod_five;")