Skip to content

Commit

Permalink
Simplify function interface
Browse files Browse the repository at this point in the history
  • Loading branch information
gary-peng committed Nov 26, 2023
1 parent f6f3dda commit f991a9f
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 1 deletion.
2 changes: 2 additions & 0 deletions evadb/functions/My_SimpleUDF.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def My_SimpleUDF(cls, x: int) -> int:
    """Return ``x`` increased by five.

    Args:
        cls: Unused placeholder for the first positional argument.
        x: Value to offset.
    """
    offset = 5
    return x + offset
104 changes: 104 additions & 0 deletions evadb/functions/simple_udf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pandas as pd
import importlib
import pickle
from pathlib import Path
import typing

from evadb.catalog.catalog_type import NdArrayType
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward, setup
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.configuration.constants import EvaDB_ROOT_DIR

class SimpleUDF(AbstractFunction):
    """Wrap a plain, type-annotated Python function so it can run as an EvaDB function.

    ``set_udf`` loads the wrapped callable and records its type hints in
    ``self.types``; ``setup`` turns those hints into the input/output
    descriptors of ``forward``; ``forward`` applies the callable to every row
    of the incoming dataframe.
    """

    @setup(cacheable=False, function_type="SimpleUDF", batchable=False)
    def setup(self):
        """Build the ``forward`` input/output descriptors from ``self.types``.

        ``set_udf`` must have populated ``self.types`` before this runs.
        """
        in_labels = [label for label in self.types if label != "return"]
        in_types = [self.convert_python_types(self.types[label]) for label in in_labels]
        # A UDF without a return annotation maps to ANYTYPE instead of
        # failing with a bare KeyError.
        out_types = [self.convert_python_types(self.types.get("return"))]

        # One shape tuple per column. The earlier ``[(1) * n]`` evaluated to the
        # single-element list ``[n]`` because ``(1)`` is just the integer 1.
        self.forward.tags["input"] = [
            PandasDataframe(
                columns=in_labels,
                column_types=in_types,
                column_shapes=[(1,)] * len(in_labels),
            )
        ]

        self.forward.tags["output"] = [
            PandasDataframe(
                columns=["output"],
                column_types=out_types,
                column_shapes=[(1,)] * len(out_types),
            )
        ]

    @property
    def name(self) -> str:
        """Name reported for this function type."""
        return "SimpleUDF"

    @forward(None, None)
    def forward(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply the wrapped UDF to each row of ``df``.

        Args:
            df: Input dataframe; every row is handed to the UDF as a ``pd.Series``.

        Returns:
            A dataframe with a single ``output`` column holding the UDF results.
        """

        def _forward(row: pd.Series) -> np.ndarray:
            temp = self.udf
            return temp(row)

        ret = pd.DataFrame()
        ret["output"] = df.apply(_forward, axis=1)
        return ret

    def set_udf(self, classname: str, filepath: str):
        """Load the wrapped callable and cache its type hints.

        Files under ``{EvaDB_ROOT_DIR}/simple_udfs/`` are treated as pickled
        callables; any other path is imported as a Python source file and the
        attribute named ``classname`` is used.

        NOTE(review): the loader in generic_utils appears to call this with the
        class itself as ``self``, which would make ``udf``/``types`` shared
        class-level state -- confirm against the caller.

        Args:
            classname: Name of the callable to look up in the imported module.
            filepath: Path to the pickle file or Python source file.

        Raises:
            ImportError: The module could not be imported or does not define
                ``classname``.
            FileNotFoundError: The source file does not exist.
            RuntimeError: Any other failure while loading the source file.
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule has been loaded.
        import importlib.util

        if f"{EvaDB_ROOT_DIR}/simple_udfs/" in filepath:
            # NOTE(security): unpickling executes arbitrary code; only files
            # written by EvaDB itself should ever live in this directory.
            # Load the file that was actually requested rather than a
            # hard-coded name, and make sure the handle is closed.
            with open(filepath, "rb") as f:
                self.udf = pickle.load(f)
        else:
            try:
                abs_path = Path(filepath).resolve()
                spec = importlib.util.spec_from_file_location(abs_path.stem, abs_path)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
            except ImportError as e:
                # ImportError in the case when we are able to find the file but not able to load the module
                err_msg = (
                    f"ImportError : Couldn't load function from {filepath} : {str(e)}. "
                    f"Not able to load the code provided in the file {abs_path}. "
                    "Please ensure that the file contains the implementation code for the function."
                )
                raise ImportError(err_msg) from e
            except FileNotFoundError as e:
                # FileNotFoundError in the case when we are not able to find the file at all at the path.
                err_msg = (
                    f"FileNotFoundError : Couldn't load function from {filepath} : {str(e)}. "
                    "This might be because the function implementation file does not exist. "
                    f"Please ensure the file exists at {abs_path}"
                )
                raise FileNotFoundError(err_msg) from e
            except Exception as e:
                # Default exception, we don't know what exactly went wrong so we just output the error message
                err_msg = f"Couldn't load function from {filepath} : {str(e)}."
                raise RuntimeError(err_msg) from e

            # Try to load the specified function by name. Fail loudly when it
            # is missing instead of continuing with an unset or stale ``udf``.
            if not (classname and hasattr(module, classname)):
                raise ImportError(
                    f"ImportError : Couldn't load function from {filepath} : "
                    f"{classname} is not defined in the file."
                )
            self.udf = getattr(module, classname)

        self.types = typing.get_type_hints(self.udf)

    def convert_python_types(self, type):
        """Map a Python annotation to the matching ``NdArrayType``.

        Args:
            type: The annotation taken from the UDF's type hints.

        Returns:
            BOOL, INT32, FLOAT32 or STR for ``bool``, ``int``, ``float`` and
            ``str`` respectively; ANYTYPE for everything else.
        """
        if type == bool:
            return NdArrayType.BOOL
        elif type == int:
            return NdArrayType.INT32
        elif type == float:
            return NdArrayType.FLOAT32
        elif type == str:
            return NdArrayType.STR
        else:
            return NdArrayType.ANYTYPE
27 changes: 26 additions & 1 deletion evadb/interfaces/relational/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import multiprocessing

import pandas
import pickle

from evadb.configuration.constants import EvaDB_DATABASE_DIR
from evadb.configuration.constants import EvaDB_DATABASE_DIR, EvaDB_ROOT_DIR
from evadb.database import EvaDBDatabase, init_evadb_instance
from evadb.expression.tuple_value_expression import TupleValueExpression
from evadb.functions.function_bootstrap_queries import init_builtin_functions
Expand Down Expand Up @@ -413,6 +414,30 @@ def create_function(
function_name, if_not_exists, impl_path, type, **kwargs
)
return EvaDBQuery(self._evadb, stmt)

def create_simple_function(
    self,
    function_name: str,
    function: callable,
    if_not_exists: bool = True,
) -> "EvaDBQuery":
    """
    Create a function in the database by passing in a function instance.

    The callable is pickled to ``{EvaDB_ROOT_DIR}/simple_udfs/{function_name}``
    and then registered through ``create_function`` with that file as its
    implementation path.

    Args:
        function_name (str): Name of the function to be created.
        function (callable): The function instance. It must be picklable.
        if_not_exists (bool): If True, do not raise an error if the function already exists. If False, raise an error.
    Returns:
        EvaDBQuery: The EvaDBQuery object representing the function created.
    """
    from pathlib import Path

    impl_path = f"{EvaDB_ROOT_DIR}/simple_udfs/{function_name}"
    # The target directory may not exist yet on a fresh checkout.
    Path(impl_path).parent.mkdir(parents=True, exist_ok=True)
    # Write with "wb", not "ab": appending would leave the old pickle at the
    # start of the file, and a later pickle.load would keep returning that
    # stale function. The context manager closes the handle even if
    # pickle.dump raises.
    with open(impl_path, "wb") as f:
        pickle.dump(function, f)

    return self.create_function(function_name, if_not_exists, impl_path)

def create_table(
self, table_name: str, if_not_exists: bool = True, columns: str = None, **kwargs
Expand Down
13 changes: 13 additions & 0 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from evadb.configuration.constants import EvaDB_INSTALLATION_DIR
from evadb.utils.logging_manager import logger
from evadb.configuration.constants import EvaDB_ROOT_DIR


def validate_kwargs(
Expand Down Expand Up @@ -79,6 +80,14 @@ def load_function_class_from_file(filepath, classname=None):
FileNotFoundError: If the file cannot be found.
RuntimeError: Any other type of runtime error.
"""
simple_udf_filepath = None
simple_udf_classname = None
if classname and "_SimpleUDF" in classname:
simple_udf_classname = classname
classname = "SimpleUDF"
simple_udf_filepath = filepath
filepath = f"{EvaDB_ROOT_DIR}/evadb/functions/simple_udf.py"

try:
abs_path = Path(filepath).resolve()
spec = importlib.util.spec_from_file_location(abs_path.stem, abs_path)
Expand All @@ -99,6 +108,10 @@ def load_function_class_from_file(filepath, classname=None):

# Try to load the specified class by name
if classname and hasattr(module, classname):
if classname == "SimpleUDF":
cls = getattr(module, classname)
cls.set_udf(cls, simple_udf_classname, simple_udf_filepath)
return cls
return getattr(module, classname)

# If class name not specified, check if there is only one class in the file
Expand Down
73 changes: 73 additions & 0 deletions test/integration_tests/long/test_simple_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test.util import suffix_pytest_xdist_worker_id_to_dir

import pytest
import pandas as pd

from evadb.configuration.constants import EvaDB_DATABASE_DIR, EvaDB_ROOT_DIR
from evadb.interfaces.relational.db import connect
from evadb.server.command_handler import execute_query_fetch_all

def Func_SimpleUDF(cls, x: int) -> int:
    """Return ``x`` increased by ten.

    Args:
        cls: Unused placeholder for the first positional argument.
        x: Value to offset.
    """
    offset = 10
    return x + offset

@pytest.mark.notparallel
class SimpleFunctionTests(unittest.TestCase):
    """Integration tests for registering plain Python functions as EvaDB functions."""

    def setUp(self):
        # Connect to a per-worker database directory and start from an empty catalog.
        self.db_dir = suffix_pytest_xdist_worker_id_to_dir(EvaDB_DATABASE_DIR)
        self.conn = connect(self.db_dir)
        self.evadb = self.conn._evadb
        catalog = self.evadb.catalog()
        catalog.reset()

    def tearDown(self):
        # Drop everything the tests may have created, in the same order each time.
        for statement in (
            "DROP TABLE IF EXISTS test_table;",
            "DROP FUNCTION IF EXISTS My_SimpleUDF;",
            "DROP FUNCTION IF EXISTS Func_SimpleUDF;",
        ):
            execute_query_fetch_all(self.evadb, statement)

    def _cursor_with_test_table(self):
        """Return a cursor after creating ``test_table`` and inserting the row ``val = 1``."""
        cursor = self.conn.cursor()
        execute_query_fetch_all(
            self.evadb, "CREATE TABLE IF NOT EXISTS test_table (val INTEGER);"
        )
        cursor.insert("test_table", "(val)", "(1)").df()
        return cursor

    def test_from_file(self):
        """A UDF registered from a source file is applied to the stored value."""
        cursor = self._cursor_with_test_table()

        udf_path = f"{EvaDB_ROOT_DIR}/evadb/functions/My_SimpleUDF.py"
        cursor.create_function("My_SimpleUDF", True, udf_path).df()

        actual = cursor.query("SELECT My_SimpleUDF(val) FROM test_table;").df()
        wanted = pd.DataFrame({"output": [6]})

        self.assertTrue(wanted.equals(actual))

    def test_from_function(self):
        """A UDF registered from an in-memory function is applied to the stored value."""
        cursor = self._cursor_with_test_table()

        cursor.create_simple_function("Func_SimpleUDF", Func_SimpleUDF, True).df()

        actual = cursor.query("SELECT Func_SimpleUDF(val) FROM test_table;").df()
        wanted = pd.DataFrame({"output": [11]})

        self.assertTrue(wanted.equals(actual))

0 comments on commit f991a9f

Please sign in to comment.