Skip to content

Commit

Permalink
Add stable diffusion integration (georgia-tech-db#1240)
Browse files Browse the repository at this point in the history
Reopen the georgia-tech-db#1111.

---------

Co-authored-by: sudoboi <[email protected]>
Co-authored-by: Abhijith S Raj <[email protected]>
  • Loading branch information
3 people authored and a0x8o committed Nov 22, 2023
1 parent 23dc39f commit f9b931f
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ parts:
title: YOLO
<<<<<<< HEAD
<<<<<<< HEAD
<<<<<<< HEAD
<<<<<<< HEAD
- file: source/reference/ai/stablediffusion
title: Stable Diffusion
Expand All @@ -411,10 +412,16 @@ parts:
<<<<<<< HEAD
>>>>>>> 6d6a14c8 (Bump v0.3.4+ dev)
=======
=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
- file: source/reference/ai/custom
title: Custom Model
<<<<<<< HEAD
=======
=======
- file: source/reference/ai/stablediffusion
title: Stable Diffusion
>>>>>>> bf022329 (Add stable diffusion integration (#1240))

- file: source/reference/ai/custom-ai-function
title: Bring Your Own AI Function
Expand Down
1 change: 1 addition & 0 deletions evadb/evadb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ third_party:
OPENAI_KEY: ""
PINECONE_API_KEY: ""
PINECONE_ENV: ""
REPLICATE_API_TOKEN: ""
29 changes: 29 additions & 0 deletions evadb/functions/dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from PIL import Image

from evadb.catalog.catalog_type import NdArrayType
<<<<<<< HEAD
=======
from evadb.configuration.configuration_manager import ConfigurationManager
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
Expand All @@ -33,8 +37,13 @@ class DallEFunction(AbstractFunction):
def name(self) -> str:
return "DallE"

<<<<<<< HEAD
def setup(self, openai_api_key="") -> None:
self.openai_api_key = openai_api_key
=======
def setup(self) -> None:
pass
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))

@forward(
input_signatures=[
Expand All @@ -56,6 +65,7 @@ def setup(self, openai_api_key="") -> None:
)
def forward(self, text_df):
try_to_import_openai()
<<<<<<< HEAD
from openai import OpenAI

api_key = self.openai_api_key
Expand All @@ -66,15 +76,34 @@ def forward(self, text_df):
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"

client = OpenAI(api_key=api_key)
=======
import openai

# Register API key, try configuration manager first
openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY")
# If not found, try OS Environment Variable
if len(openai.api_key) == 0:
openai.api_key = os.environ.get("OPENAI_KEY", "")
assert (
len(openai.api_key) != 0
), "Please set your OpenAI API key in evadb.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)"
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
<<<<<<< HEAD
response = client.images.generate(prompt=query, n=1, size="1024x1024")

# Download the image from the link
image_response = requests.get(response.data[0].url)
=======
response = openai.Image.create(prompt=query, n=1, size="1024x1024")

# Download the image from the link
image_response = requests.get(response["data"][0]["url"])
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
image = Image.open(BytesIO(image_response.content))

# Convert the image to an array format suitable for the DataFrame
Expand Down
11 changes: 11 additions & 0 deletions evadb/functions/function_bootstrap_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@
"""

<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
stablediffusion_function_query = """CREATE FUNCTION IF NOT EXISTS StableDiffusion
IMPL '{}/functions/stable_diffusion.py';
""".format(
Expand All @@ -243,6 +246,7 @@
EvaDB_INSTALLATION_DIR
)

<<<<<<< HEAD
Upper_function_query = """CREATE FUNCTION IF NOT EXISTS UPPER
INPUT (input ANYTYPE)
OUTPUT (output NDARRAY STR(ANYDIM))
Expand All @@ -269,6 +273,8 @@

=======
>>>>>>> 2dacff69 (feat: sync master staging (#1050))
=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))

def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
"""Load the built-in functions into the system during system bootstrapping.
Expand Down Expand Up @@ -337,6 +343,7 @@ def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
# Mvit_function_query,
Sift_function_query,
Yolo_function_query,
<<<<<<< HEAD
<<<<<<< HEAD
stablediffusion_function_query,
dalle_function_query,
Expand All @@ -345,6 +352,10 @@ def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
Concat_function_query,
=======
>>>>>>> 2dacff69 (feat: sync master staging (#1050))
=======
stablediffusion_function_query,
dalle_function_query,
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
]

# if mode is 'debug', add debug functions
Expand Down
24 changes: 24 additions & 0 deletions evadb/functions/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from PIL import Image

from evadb.catalog.catalog_type import NdArrayType
<<<<<<< HEAD
=======
from evadb.configuration.configuration_manager import ConfigurationManager
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
Expand All @@ -33,8 +37,15 @@ class StableDiffusion(AbstractFunction):
def name(self) -> str:
return "StableDiffusion"

<<<<<<< HEAD
def setup(self, replicate_api_token="") -> None:
self.replicate_api_token = replicate_api_token
=======
def setup(
self,
) -> None:
pass
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))

@forward(
input_signatures=[
Expand All @@ -61,13 +72,26 @@ def forward(self, text_df):
try_to_import_replicate()
import replicate

<<<<<<< HEAD
replicate_api_key = self.replicate_api_token
# If not found, try OS Environment Variable
if replicate_api_key is None:
replicate_api_key = os.environ.get("REPLICATE_API_TOKEN", "")
assert (
len(replicate_api_key) != 0
), "Please set your Replicate API key using SET REPLICATE_API_TOKEN = '' or set the environment variable (REPLICATE_API_TOKEN)"
=======
# Register API key, try configuration manager first
replicate_api_key = ConfigurationManager().get_value(
"third_party", "REPLICATE_API_TOKEN"
)
# If not found, try OS Environment Variable
if len(replicate_api_key) == 0:
replicate_api_key = os.environ.get("REPLICATE_API_TOKEN", "")
assert (
len(replicate_api_key) != 0
), "Please set your Replicate API key in evadb.yml file (third_party, replicate_api_token) or environment variable (REPLICATE_API_TOKEN)"
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
os.environ["REPLICATE_API_TOKEN"] = replicate_api_key

model_id = (
Expand Down
6 changes: 6 additions & 0 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,9 @@ def string_comparison_case_insensitive(string_1, string_2) -> bool:

return string_1.lower() == string_2.lower()
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))


def try_to_import_replicate():
Expand All @@ -771,5 +774,8 @@ def is_replicate_available():
return True
except ValueError:
return False
<<<<<<< HEAD
=======
>>>>>>> 40a10ce1 (Bump v0.3.4+ dev)
=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
13 changes: 13 additions & 0 deletions script/test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ long_integration_test() {
notebook_test() {
<<<<<<< HEAD
<<<<<<< HEAD
=======
<<<<<<< HEAD
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb"
=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
<<<<<<< HEAD
<<<<<<< HEAD
<<<<<<< HEAD
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb" --ignore="tutorials/18-stable-diffusion.ipynb" --ignore="tutorials/19-employee-classification-prediction.ipynb"
Expand All @@ -105,6 +111,7 @@ notebook_test() {
=======
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb"
>>>>>>> 40a10ce1 (Bump v0.3.4+ dev)
<<<<<<< HEAD
=======
=======
>>>>>>> 6d6a14c8 (Bump v0.3.4+ dev)
Expand All @@ -119,6 +126,12 @@ notebook_test() {
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb"
>>>>>>> 40a10ce1 (Bump v0.3.4+ dev)
>>>>>>> 6d6a14c8 (Bump v0.3.4+ dev)
=======
>>>>>>> eva-master
=======
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb" --ignore="tutorials/18-stable-diffusion.ipynb"
>>>>>>> bf022329 (Add stable diffusion integration (#1240))
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
code=$?
print_error_code $code "NOTEBOOK TEST"
}
Expand Down
18 changes: 18 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def read(path, encoding="utf-8"):
"neuralforecast" # MODEL TRAIN AND FINE TUNING
]

imagegen_libs = [
"replicate"
]

### NEEDED FOR DEVELOPER TESTING ONLY

dev_libs = [
Expand Down Expand Up @@ -220,7 +224,14 @@ def read(path, encoding="utf-8"):
"xgboost": xgboost_libs,
"forecasting": forecasting_libs,
# everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
<<<<<<< HEAD
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs
=======
<<<<<<< HEAD
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs,
<<<<<<< HEAD
=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
=======
"forecasting": forecasting_libs,
# everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
Expand Down Expand Up @@ -252,7 +263,14 @@ def read(path, encoding="utf-8"):
# everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs,
>>>>>>> 40a10ce1 (Bump v0.3.4+ dev)
<<<<<<< HEAD
>>>>>>> 6d6a14c8 (Bump v0.3.4+ dev)
=======
>>>>>>> eva-master
=======
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs,
>>>>>>> bf022329 (Add stable diffusion integration (#1240))
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
}

setup(
Expand Down
7 changes: 7 additions & 0 deletions test/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
is_ludwig_available,
is_pinecone_available,
is_qdrant_available,
is_replicate_available,
is_sklearn_available,
>>>>>>> 40a10ce1 (Bump v0.3.4+ dev)
)
Expand Down Expand Up @@ -139,9 +140,15 @@
reason="Run only if forecasting packages available",
)
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))

stable_diffusion_skip_marker = pytest.mark.skipif(
is_replicate_available() is False, reason="requires replicate"
)
<<<<<<< HEAD
=======
>>>>>>> 2dacff69 (feat: sync master staging (#1050))
=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
25 changes: 25 additions & 0 deletions test/unit_tests/test_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,22 @@
import unittest
from io import BytesIO
from test.util import get_evadb_for_testing
<<<<<<< HEAD
from typing import List, Optional
from unittest.mock import MagicMock, patch

from PIL import Image as PILImage
from pydantic import AnyUrl, BaseModel
=======
from unittest.mock import MagicMock, patch

from PIL import Image
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))

from evadb.server.command_handler import execute_query_fetch_all


<<<<<<< HEAD
class Image(BaseModel):
b64_json: Optional[str] # Replace with the actual type if different
revised_prompt: Optional[str] # Replace with the actual type if different
Expand All @@ -36,6 +43,8 @@ class ImagesResponse(BaseModel):
data: List[Image]


=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
class DallEFunctionTest(unittest.TestCase):
def setUp(self) -> None:
self.evadb = get_evadb_for_testing()
Expand All @@ -54,12 +63,21 @@ def setUp(self) -> None:
def tearDown(self) -> None:
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;")

<<<<<<< HEAD
@patch.dict("os.environ", {"OPENAI_API_KEY": "mocked_openai_key"})
@patch("requests.get")
@patch("openai.OpenAI")
def test_dalle_image_generation(self, mock_openai, mock_requests_get):
# Generate a 1x1 white pixel PNG image in memory
img = PILImage.new("RGB", (1, 1), color="white")
=======
@patch.dict("os.environ", {"OPENAI_KEY": "mocked_openai_key"})
@patch("requests.get")
@patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]})
def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
# Generate a 1x1 white pixel PNG image in memory
img = Image.new("RGB", (1, 1), color="white")
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
img_byte_array = BytesIO()
img.save(img_byte_array, format="PNG")
mock_image_content = img_byte_array.getvalue()
Expand All @@ -68,6 +86,7 @@ def test_dalle_image_generation(self, mock_openai, mock_requests_get):
mock_response.content = mock_image_content
mock_requests_get.return_value = mock_response

<<<<<<< HEAD
# Set up the mock for OpenAI instance
mock_openai_instance = mock_openai.return_value
mock_openai_instance.images.generate.return_value = ImagesResponse(
Expand All @@ -80,6 +99,8 @@ def test_dalle_image_generation(self, mock_openai, mock_requests_get):
]
)

=======
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
function_name = "DallE"

execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")
Expand All @@ -92,6 +113,10 @@ def test_dalle_image_generation(self, mock_openai, mock_requests_get):
gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
execute_query_fetch_all(self.evadb, gpt_query)

<<<<<<< HEAD
mock_openai_instance.images.generate.assert_called_once_with(
=======
mock_openai_create.assert_called_once_with(
>>>>>>> 2b924b76 (Add stable diffusion integration (#1240))
prompt="a surreal painting of a cat", n=1, size="1024x1024"
)

0 comments on commit f9b931f

Please sign in to comment.