diff --git a/evadb/functions/dalle.py b/evadb/functions/dalle.py index 0f850667fc..37c75b77e2 100644 --- a/evadb/functions/dalle.py +++ b/evadb/functions/dalle.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from evadb.utils.generic_utils import try_to_import_openai import os -from evadb.configuration.configuration_manager import ConfigurationManager import pandas as pd from evadb.catalog.catalog_type import NdArrayType +from evadb.configuration.configuration_manager import ConfigurationManager from evadb.functions.abstract.abstract_function import AbstractFunction from evadb.functions.decorators.decorators import forward from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe +from evadb.utils.generic_utils import try_to_import_openai + class DallEFunction(AbstractFunction): @property @@ -35,11 +36,11 @@ def setup(self) -> None: @forward( input_signatures=[ PandasDataframe( - columns = ["prompt"], + columns=["prompt"], column_types=[ NdArrayType.STR, ], - column_shapes=[(None,)] + column_shapes=[(None,)], ) ], output_signatures=[ @@ -65,18 +66,14 @@ def forward(self, text_df): len(openai.api_key) != 0 ), "Please set your OpenAI API key in evadb.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)" - def generate_image(text_df : PandasDataframe): + def generate_image(text_df: PandasDataframe): results = [] queries = text_df[text_df.columns[0]] for query in queries: - response = openai.Image.create( - prompt=query, - n=1, - size="1024x1024" - ) - results.append(response['data'][0]['url']) + response = openai.Image.create(prompt=query, n=1, size="1024x1024") + results.append(response["data"][0]["url"]) return results - df = pd.DataFrame({"response":generate_image(text_df=text_df)}) + df = pd.DataFrame({"response": generate_image(text_df=text_df)}) - return df \ No newline at end of file + return df diff --git a/evadb/functions/function_bootstrap_queries.py b/evadb/functions/function_bootstrap_queries.py index cfa57b7fcc..e22c941a1f 100644 --- a/evadb/functions/function_bootstrap_queries.py +++ b/evadb/functions/function_bootstrap_queries.py @@ -214,6 +214,7 @@ EvaDB_INSTALLATION_DIR ) + def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None: """Load the built-in functions into the system during system bootstrapping. @@ -280,4 +281,4 @@ def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None: db, query, do_not_print_exceptions=False, do_not_raise_exceptions=True ) except Exception: - pass \ No newline at end of file + pass diff --git a/evadb/functions/stable_diffusion.py b/evadb/functions/stable_diffusion.py index a5ac2c7e00..85262d81dd 100644 --- a/evadb/functions/stable_diffusion.py +++ b/evadb/functions/stable_diffusion.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from evadb.utils.generic_utils import try_to_import_replicate import os import pandas as pd @@ -22,13 +21,13 @@ from evadb.functions.abstract.abstract_function import AbstractFunction from evadb.functions.decorators.decorators import forward from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe - - +from evadb.utils.generic_utils import try_to_import_replicate _VALID_STABLE_DIFFUSION_MODEL = [ "sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33", ] + class StableDiffusion(AbstractFunction): @property def name(self) -> str: @@ -36,19 +35,21 @@ def name(self) -> str: def setup( self, - model="sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33" + model="sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33", ) -> None: - assert model in _VALID_STABLE_DIFFUSION_MODEL, f"Unsupported Stable Diffusion {model}" + assert ( + model in _VALID_STABLE_DIFFUSION_MODEL + ), f"Unsupported Stable Diffusion {model}" self.model = model @forward( input_signatures=[ PandasDataframe( - columns = ["prompt"], + columns=["prompt"], column_types=[ NdArrayType.STR, ], - column_shapes=[(None,)] + column_shapes=[(None,)], ) ], output_signatures=[ @@ -65,22 +66,23 @@ def forward(self, text_df): try_to_import_replicate() import replicate - if os.environ.get('REPLICATE_API_TOKEN') is None: - replicate_api_key = 'r8_Q75IAgbaHFvYVfLSMGmjQPcW5uZZoXz0jGalu' # token for testing - os.environ['REPLICATE_API_TOKEN'] = replicate_api_key + if os.environ.get("REPLICATE_API_TOKEN") is None: + replicate_api_key = ( + "r8_Q75IAgbaHFvYVfLSMGmjQPcW5uZZoXz0jGalu" # token for testing + ) + os.environ["REPLICATE_API_TOKEN"] = replicate_api_key # @retry(tries=5, delay=20) - def generate_image(text_df : PandasDataframe): + def generate_image(text_df: PandasDataframe): results = [] queries = text_df[text_df.columns[0]] for query in queries: output = replicate.run( - "stability-ai/" + self.model, - input={"prompt": query} + "stability-ai/" + self.model, input={"prompt": query} ) results.append(output[0]) return results - df = pd.DataFrame({"response":generate_image(text_df=text_df)}) + df = pd.DataFrame({"response": generate_image(text_df=text_df)}) - return df \ No newline at end of file + return df diff --git a/evadb/utils/generic_utils.py b/evadb/utils/generic_utils.py index a45e67b0cb..be18ceafd3 100644 --- a/evadb/utils/generic_utils.py +++ b/evadb/utils/generic_utils.py @@ -597,18 +597,20 @@ def string_comparison_case_insensitive(string_1, string_2) -> bool: return string_1.lower() == string_2.lower() + def try_to_import_replicate(): try: - import replicate # noqa: F401 + import replicate # noqa: F401 except ImportError: raise ValueError( """Could not import replicate python package. Please install it with `pip install replicate`.""" ) - + + def is_replicate_available(): try: try_to_import_replicate() return True except ValueError: - return False \ No newline at end of file + return False diff --git a/test/integration_tests/long/functions/test_dalle.py b/test/integration_tests/long/functions/test_dalle.py index d3774bc361..270e3cac15 100644 --- a/test/integration_tests/long/functions/test_dalle.py +++ b/test/integration_tests/long/functions/test_dalle.py @@ -15,11 +15,12 @@ import unittest -from unittest.mock import patch from test.util import get_evadb_for_testing +from unittest.mock import patch from evadb.server.command_handler import execute_query_fetch_all + class DallEFunctionTest(unittest.TestCase): def setUp(self) -> None: self.evadb = get_evadb_for_testing() @@ -29,19 +30,19 @@ def setUp(self) -> None: """ execute_query_fetch_all(self.evadb, create_table_query) - test_prompts = ['a surreal painting of a cat'] + test_prompts = ["a surreal painting of a cat"] for prompt in test_prompts: insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')""" execute_query_fetch_all(self.evadb, insert_query) - + def tearDown(self) -> None: execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;") - - @patch.dict('os.environ', {'OPENAI_KEY': 'mocked_openai_key'}) - @patch('openai.Image.create', return_value={'data': [{'url': 'mocked_url'}]}) + + @patch.dict("os.environ", {"OPENAI_KEY": "mocked_openai_key"}) + @patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]}) def test_dalle_image_generation(self, mock_openai_create): - function_name = 'DallE' + function_name = "DallE" execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};") @@ -52,10 +53,8 @@ def test_dalle_image_generation(self, mock_openai_create): gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;" output_batch = execute_query_fetch_all(self.evadb, gpt_query) - + self.assertEqual(output_batch.columns, ["dalle.response"]) mock_openai_create.assert_called_once_with( - prompt='a surreal painting of a cat', - n=1, - size="1024x1024" - ) \ No newline at end of file + prompt="a surreal painting of a cat", n=1, size="1024x1024" + ) diff --git a/test/integration_tests/long/functions/test_selfdiffusion.py b/test/integration_tests/long/functions/test_selfdiffusion.py index ce141ed211..cb0df2d78a 100644 --- a/test/integration_tests/long/functions/test_selfdiffusion.py +++ b/test/integration_tests/long/functions/test_selfdiffusion.py @@ -14,12 +14,13 @@ # limitations under the License. import unittest -from unittest.mock import patch from test.markers import stable_diffusion_skip_marker from test.util import get_evadb_for_testing +from unittest.mock import patch from evadb.server.command_handler import execute_query_fetch_all + class StableDiffusionTest(unittest.TestCase): def setUp(self) -> None: self.evadb = get_evadb_for_testing() @@ -29,19 +30,19 @@ def setUp(self) -> None: """ execute_query_fetch_all(self.evadb, create_table_query) - test_prompts = ['pink cat riding a rocket to the moon'] + test_prompts = ["pink cat riding a rocket to the moon"] for prompt in test_prompts: insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')""" execute_query_fetch_all(self.evadb, insert_query) - + def tearDown(self) -> None: execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;") - + @stable_diffusion_skip_marker - @patch('replicate.run', return_value=[{'response': 'mocked response'}]) + @patch("replicate.run", return_value=[{"response": "mocked response"}]) def test_stable_diffusion_image_generation(self, mock_replicate_run): - function_name = 'StableDiffusion' + function_name = "StableDiffusion" execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};") @@ -52,9 +53,9 @@ def test_stable_diffusion_image_generation(self, mock_replicate_run): gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;" output_batch = execute_query_fetch_all(self.evadb, gpt_query) - + self.assertEqual(output_batch.columns, ["stablediffusion.response"]) mock_replicate_run.assert_called_once_with( "stability-ai/sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33", - input={"prompt": 'pink cat riding a rocket to the moon'} - ) \ No newline at end of file + input={"prompt": "pink cat riding a rocket to the moon"}, + ) diff --git a/test/markers.py b/test/markers.py index cc3402ee46..7d98e55348 100644 --- a/test/markers.py +++ b/test/markers.py @@ -99,6 +99,5 @@ ) stable_diffusion_skip_marker = pytest.mark.skipif( - is_replicate_available() is False, - reason="requires replicate" -) \ No newline at end of file + is_replicate_available() is False, reason="requires replicate" +)