Skip to content

Commit

Permalink
Formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
sudoboi committed Sep 29, 2023
1 parent c637a2d commit 68f0d23
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 56 deletions.
23 changes: 10 additions & 13 deletions evadb/functions/dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from evadb.utils.generic_utils import try_to_import_openai
import os
from evadb.configuration.configuration_manager import ConfigurationManager

import pandas as pd

from evadb.catalog.catalog_type import NdArrayType
from evadb.configuration.configuration_manager import ConfigurationManager
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.utils.generic_utils import try_to_import_openai


class DallEFunction(AbstractFunction):
@property
Expand All @@ -35,11 +36,11 @@ def setup(self) -> None:
@forward(
input_signatures=[
PandasDataframe(
columns = ["prompt"],
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)]
column_shapes=[(None,)],
)
],
output_signatures=[
Expand All @@ -65,18 +66,14 @@ def forward(self, text_df):
len(openai.api_key) != 0
), "Please set your OpenAI API key in evadb.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)"

def generate_image(text_df : PandasDataframe):
def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
response = openai.Image.create(
prompt=query,
n=1,
size="1024x1024"
)
results.append(response['data'][0]['url'])
response = openai.Image.create(prompt=query, n=1, size="1024x1024")
results.append(response["data"][0]["url"])
return results

df = pd.DataFrame({"response":generate_image(text_df=text_df)})
df = pd.DataFrame({"response": generate_image(text_df=text_df)})

return df
return df
3 changes: 2 additions & 1 deletion evadb/functions/function_bootstrap_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
EvaDB_INSTALLATION_DIR
)


def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
"""Load the built-in functions into the system during system bootstrapping.
Expand Down Expand Up @@ -280,4 +281,4 @@ def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
db, query, do_not_print_exceptions=False, do_not_raise_exceptions=True
)
except Exception:
pass
pass
32 changes: 17 additions & 15 deletions evadb/functions/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from evadb.utils.generic_utils import try_to_import_replicate
import os

import pandas as pd
Expand All @@ -22,33 +21,35 @@
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe


from evadb.utils.generic_utils import try_to_import_replicate

_VALID_STABLE_DIFFUSION_MODEL = [
"sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33",
]


class StableDiffusion(AbstractFunction):
@property
def name(self) -> str:
return "StableDiffusion"

def setup(
self,
model="sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33"
model="sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33",
) -> None:
assert model in _VALID_STABLE_DIFFUSION_MODEL, f"Unsupported Stable Diffusion {model}"
assert (
model in _VALID_STABLE_DIFFUSION_MODEL
), f"Unsupported Stable Diffusion {model}"
self.model = model

@forward(
input_signatures=[
PandasDataframe(
columns = ["prompt"],
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)]
column_shapes=[(None,)],
)
],
output_signatures=[
Expand All @@ -65,22 +66,23 @@ def forward(self, text_df):
try_to_import_replicate()
import replicate

if os.environ.get('REPLICATE_API_TOKEN') is None:
replicate_api_key = 'r8_Q75IAgbaHFvYVfLSMGmjQPcW5uZZoXz0jGalu' # token for testing
os.environ['REPLICATE_API_TOKEN'] = replicate_api_key
if os.environ.get("REPLICATE_API_TOKEN") is None:
replicate_api_key = (
"r8_Q75IAgbaHFvYVfLSMGmjQPcW5uZZoXz0jGalu" # token for testing
)
os.environ["REPLICATE_API_TOKEN"] = replicate_api_key

# @retry(tries=5, delay=20)
def generate_image(text_df : PandasDataframe):
def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
output = replicate.run(
"stability-ai/" + self.model,
input={"prompt": query}
"stability-ai/" + self.model, input={"prompt": query}
)
results.append(output[0])
return results

df = pd.DataFrame({"response":generate_image(text_df=text_df)})
df = pd.DataFrame({"response": generate_image(text_df=text_df)})

return df
return df
8 changes: 5 additions & 3 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,18 +597,20 @@ def string_comparison_case_insensitive(string_1, string_2) -> bool:

return string_1.lower() == string_2.lower()


def try_to_import_replicate():
try:
import replicate # noqa: F401
import replicate # noqa: F401
except ImportError:
raise ValueError(
"""Could not import replicate python package.
Please install it with `pip install replicate`."""
)



def is_replicate_available():
try:
try_to_import_replicate()
return True
except ValueError:
return False
return False
23 changes: 11 additions & 12 deletions test/integration_tests/long/functions/test_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@


import unittest
from unittest.mock import patch
from test.util import get_evadb_for_testing
from unittest.mock import patch

from evadb.server.command_handler import execute_query_fetch_all


class DallEFunctionTest(unittest.TestCase):
def setUp(self) -> None:
self.evadb = get_evadb_for_testing()
Expand All @@ -29,19 +30,19 @@ def setUp(self) -> None:
"""
execute_query_fetch_all(self.evadb, create_table_query)

test_prompts = ['a surreal painting of a cat']
test_prompts = ["a surreal painting of a cat"]

for prompt in test_prompts:
insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')"""
execute_query_fetch_all(self.evadb, insert_query)

def tearDown(self) -> None:
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;")
@patch.dict('os.environ', {'OPENAI_KEY': 'mocked_openai_key'})
@patch('openai.Image.create', return_value={'data': [{'url': 'mocked_url'}]})

@patch.dict("os.environ", {"OPENAI_KEY": "mocked_openai_key"})
@patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]})
def test_dalle_image_generation(self, mock_openai_create):
function_name = 'DallE'
function_name = "DallE"

execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

Expand All @@ -52,10 +53,8 @@ def test_dalle_image_generation(self, mock_openai_create):

gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
output_batch = execute_query_fetch_all(self.evadb, gpt_query)

self.assertEqual(output_batch.columns, ["dalle.response"])
mock_openai_create.assert_called_once_with(
prompt='a surreal painting of a cat',
n=1,
size="1024x1024"
)
prompt="a surreal painting of a cat", n=1, size="1024x1024"
)
19 changes: 10 additions & 9 deletions test/integration_tests/long/functions/test_selfdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
# limitations under the License.

import unittest
from unittest.mock import patch
from test.markers import stable_diffusion_skip_marker
from test.util import get_evadb_for_testing
from unittest.mock import patch

from evadb.server.command_handler import execute_query_fetch_all


class StableDiffusionTest(unittest.TestCase):
def setUp(self) -> None:
self.evadb = get_evadb_for_testing()
Expand All @@ -29,19 +30,19 @@ def setUp(self) -> None:
"""
execute_query_fetch_all(self.evadb, create_table_query)

test_prompts = ['pink cat riding a rocket to the moon']
test_prompts = ["pink cat riding a rocket to the moon"]

for prompt in test_prompts:
insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')"""
execute_query_fetch_all(self.evadb, insert_query)

def tearDown(self) -> None:
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;")

@stable_diffusion_skip_marker
@patch('replicate.run', return_value=[{'response': 'mocked response'}])
@patch("replicate.run", return_value=[{"response": "mocked response"}])
def test_stable_diffusion_image_generation(self, mock_replicate_run):
function_name = 'StableDiffusion'
function_name = "StableDiffusion"

execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

Expand All @@ -52,9 +53,9 @@ def test_stable_diffusion_image_generation(self, mock_replicate_run):

gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
output_batch = execute_query_fetch_all(self.evadb, gpt_query)

self.assertEqual(output_batch.columns, ["stablediffusion.response"])
mock_replicate_run.assert_called_once_with(
"stability-ai/sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33",
input={"prompt": 'pink cat riding a rocket to the moon'}
)
input={"prompt": "pink cat riding a rocket to the moon"},
)
5 changes: 2 additions & 3 deletions test/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,5 @@
)

stable_diffusion_skip_marker = pytest.mark.skipif(
is_replicate_available() is False,
reason="requires replicate"
)
is_replicate_available() is False, reason="requires replicate"
)

0 comments on commit 68f0d23

Please sign in to comment.